## Wrapper test: randomized smoothing on CIFAR-10

In [None]:
%load_ext autoreload
%autoreload 2

import torch 
from torch import nn

import torchvision
import torchvision.transforms as transforms

import numpy as np
from matplotlib import pyplot as plt

from models.rs import RSClassifier
from models.base_models.base_classifier import ConvBase

# Use the first CUDA device when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Base network that the smoothed classifier below wraps.
base_classifier = ConvBase().to(device)

# Dataset / model / training settings.
number_classes = 10   # CIFAR-10 has ten classes
sigma = 1             # passed to RSClassifier (presumably the Gaussian noise std)
batch_size, epochs = 128, 1

# NOTE(review): k, r and m are not used in the visible cells — confirm where they are consumed.
k, r, m = 20, 0.01, 5

In [None]:
# --------------------

### Load and prepare data
(see https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

In [None]:
# Convert images to tensors (values in [0, 1]) and then map every RGB channel
# to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# CIFAR-10 train/test splits (downloaded to ./data if missing), preprocessed with `transform`.
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=2)
# A single batch containing the entire (shuffled) training set.
train_flattened_loader = torch.utils.data.DataLoader(
    train_set, batch_size=len(train_set), shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
# batch_size=1 so that every test image is certified on its own.
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=1, shuffle=False, num_workers=2)

In [None]:
# Materialise the whole training set as one 2-D tensor with one flattened image per row.
# train_flattened_loader uses batch_size=len(train_set), so its first batch is the entire
# dataset. Take that batch explicitly instead of looping: a loop would silently keep only
# the last batch if the loader's batch size were ever changed.
inputs, _ = next(iter(train_flattened_loader))
if inputs.shape[0] != len(train_set):
    raise ValueError(
        f'expected one batch with all {len(train_set)} training images, '
        f'got {inputs.shape[0]}')
train_flattened = inputs.reshape(inputs.shape[0], -1)

In [None]:
# --------------------

### Train the smoothed classifier
(see https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

In [None]:
# Wrap the base network in the randomized-smoothing classifier and move it to `device`.
classifier_smoothed = RSClassifier(
    base_classifier=base_classifier,
    num_classes=number_classes,
    sigma=sigma,
    device=device,
).to(device)

In [None]:
# Cross-entropy objective and an Adam optimiser over every parameter of the smoothed
# classifier. Alternative that was tried:
# torch.optim.SGD(classifier_smoothed.parameters(), lr=0.001, momentum=0.9).
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=classifier_smoothed.parameters(), lr=0.001)

In [None]:
# Train the smoothed classifier, recording the loss and batch accuracy of every step
# for the plots below. nn.Module.train() also switches registered submodules
# (presumably including base_classifier) to training mode.
classifier_smoothed.train()
losses, accuracies = [], []

for epoch in range(epochs):
    print(f'Starting epoch {epoch}')
    for i, data in enumerate(train_loader):
        inputs = data[0].to(device)
        labels = data[1].to(device)

        optimizer.zero_grad()

        # Forward pass; the predicted class is the index of the largest output.
        outputs = classifier_smoothed(inputs)
        predicted = torch.max(outputs, dim=1).indices
        accuracy = torch.eq(predicted, labels).float().mean().item()
        accuracies.append(accuracy)

        # NOTE(review): CrossEntropyLoss expects unnormalised logits — confirm that
        # RSClassifier's forward returns logits rather than probabilities.
        loss = loss_func(outputs, labels)
        loss.backward()
        losses.append(loss.item())

        optimizer.step()

print('Finished training')

In [None]:
# Checkpoint the trained smoothed classifier; the evaluation section reloads it from PATH.
PATH = './cifar_net.pth'
torch.save(obj=classifier_smoothed.state_dict(), f=PATH)

In [None]:
# Training loss, one point per optimisation step.
ax = plt.gca()
ax.plot(losses)
ax.set_title("Loss development over iterations")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()

In [None]:
# Training-batch accuracy, one point per optimisation step.
ax = plt.gca()
ax.plot(accuracies)
ax.set_title("Accuracy development over iterations")
ax.set_xlabel("Iteration")
ax.set_ylabel("Accuracy")
ax.grid(True)
plt.show()

In [None]:
# --------------------

### Evaluate (certify) the smoothed classifier
(see https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

In [None]:
# Parameters passed to model.certify() in the certification cell.
# NOTE(review): which sample count is used to select the top class and which to estimate
# its probability depends on RSClassifier.certify's signature — confirm the order.
number_samples = 100
number_samples0 = 1000
alpha = 0.1  # passed to certify; presumably the allowed failure probability — TODO confirm
# Certification reuses `batch_size` from the configuration cell at the top of the notebook.

In [None]:
# Rebuild the smoothed classifier and restore the weights saved after training.
# NOTE(review): this reuses the same `base_classifier` instance that was trained, so the
# reloaded model shares those weights rather than being an independent copy.
model = RSClassifier(
    base_classifier=base_classifier,
    num_classes=number_classes,
    sigma=sigma,
    device=device).to(device)
# Certification must run in inference mode. nn.Module starts in training mode, and the
# shared base_classifier was presumably left in training mode by the training cell.
model.eval()
# map_location keeps loading working if the checkpoint was written on another device.
model.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
# Certify every test image. Radii of correct and wrong predictions are collected
# separately, and abstentions are counted.
correct_robust_radii = []
false_robust_radii = []
abstain = 0

# Certification only needs forward passes; disabling autograd avoids building (and,
# if a tensor radius is returned, retaining) computation graphs for every sample.
with torch.no_grad():
    for image, label in test_loader:
        image, label = image.to(device), label.to(device)
        # NOTE(review): argument order (number_samples, number_samples0) must match
        # RSClassifier.certify's signature — confirm.
        prediction, robust_radius = model.certify(
            image, number_samples, number_samples0, alpha, batch_size)

        # Any prediction not greater than -1 is counted as an abstention
        # (presumably -1 is the abstain sentinel returned by certify).
        if prediction > -1:
            # test_loader uses batch_size=1, so `label` holds exactly one class index.
            if prediction == label.item():
                correct_robust_radii.append(robust_radius)
            else:
                false_robust_radii.append(robust_radius)
        else:
            abstain += 1

correct = len(correct_robust_radii)
false = len(false_robust_radii)

In [None]:
# Report the certification counts and the average certified radius of the correctly
# classified test points.
print(f'correct: {correct}, wrong: {false}, abstained: {abstain}')
if correct_robust_radii:
    # torch.mean accepts only tensors, not Python lists. float() handles both Python
    # numbers and one-element tensors, so convert each radius before averaging.
    average_radius = np.mean([float(radius) for radius in correct_robust_radii])
    print(f'Average robust radius of correct predictions: {average_radius:.4f}')