In [None]:
import torch 
from torch import nn

import torchvision
import torchvision.transforms as transforms

import numpy as np
from matplotlib import pyplot as plt

import models.rs as rs
from models.base_models.base_classifier import ConvBase

# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (more slowly) on machines without a usable GPU instead of crashing.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

from IPython.display import display, HTML
# Inject a CSS rule that forces elements with class `container` to 100% width.
display(HTML("<style>.container { width:100% !important; }</style>"))

## Wrapper: test randomized smoothing

### Init

In [None]:
# Base network (moved to `device`) that the smoothed classifier wraps below.
base_classifier = ConvBase().to(device)

# Classification / smoothing / training settings.
number_classes, sigma = 10, 1
batch_size, epochs = 512, 10

# NOTE(review): these three values are not read by any cell visible in this
# review (`k` is reassigned before it is used) — confirm they are still needed.
k, r, m = 20, 0.01, 5

In [None]:
# --------------------

### Load and prepare data
(see https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

In [None]:
# Convert images to tensors, then normalize every channel with mean 0.5 and
# standard deviation 0.5.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([to_tensor, normalize])

In [None]:
# Settings shared by both CIFAR-10 splits.
cifar_kwargs = dict(root='./data', download=True, transform=transform)
make_loader = torch.utils.data.DataLoader

train_set = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
# Mini-batch loader used for training.
train_loader = make_loader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
# Loader whose single batch spans the whole training set.
train_flattened_loader = make_loader(train_set, batch_size=len(train_set), shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)
# Batch size 1: certification below processes one image at a time.
test_loader = make_loader(test_set, batch_size=1, shuffle=False, num_workers=2)

In [None]:
# The loader yields one batch holding every training image; flatten each image
# into a single row so `train_flattened` has shape (num_images, num_features).
for inputs, _ in train_flattened_loader:
    train_flattened = inputs.reshape(len(inputs), -1)

In [None]:
# --------------------

### Train the smoothed classifier
(see https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

In [None]:
# Wrap the base network in the randomized-smoothing classifier and move the
# wrapper to the compute device.
classifier_smoothed = rs.RSClassifier(
    base_classifier=base_classifier,
    num_classes=number_classes,
    sigma=sigma,
    device=device,
)
classifier_smoothed = classifier_smoothed.to(device)

In [None]:
# Training objective: cross-entropy between model outputs and class labels.
loss_func = nn.CrossEntropyLoss()

# Adam over every parameter exposed by the smoothed classifier.
# (Alternative kept for reference: SGD with lr=0.001, momentum=0.9.)
trainable_params = classifier_smoothed.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=0.001)

In [None]:
# Switch to training mode and record the loss / accuracy of every mini-batch.
classifier_smoothed.train()
losses, accuracies = [], []

for epoch in range(epochs):
    print(f'Starting epoch {epoch}')
    for batch_inputs, batch_labels in train_loader:
        inputs = batch_inputs.to(device)
        labels = batch_labels.to(device)

        # Gradients accumulate by default, so clear them before each step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = classifier_smoothed(inputs)
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()

        # Bookkeeping: fraction of correct top-1 predictions and the loss value
        # (both computed from the pre-update forward pass above).
        predicted = outputs.max(dim=1).indices
        hits = (predicted == labels).float()
        accuracies.append(hits.mean().item())
        losses.append(loss.item())

print('Finished training')

In [None]:
# Persist the trained weights so the evaluation section can reload them.
PATH = './cifar_net.pth'
trained_weights = classifier_smoothed.state_dict()
torch.save(trained_weights, PATH)

In [None]:
# Per-batch training loss, one point per optimizer step.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_title("Loss development over iterations")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()

In [None]:
# Per-batch training accuracy, one point per optimizer step.
fig, ax = plt.subplots()
ax.plot(accuracies)
ax.set_title("Accuracy development over iterations")
ax.set_xlabel("Iteration")
ax.set_ylabel("Accuracy")
ax.grid(True)
plt.show()

In [None]:
# --------------------

### Evaluate smoothed classifier
(see https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

In [None]:
# Settings passed to `model.certify` in the evaluation loop.
n_sampling = 100
n_bound = 1000
alpha = 0.1
# `batch_size` is intentionally not reassigned here: the value from the Init
# cell stays in effect (the former `batch_size = batch_size` was a no-op).

In [None]:
# Rebuild the smoothed classifier and restore the weights saved after training.
# The wrapper reuses the same `base_classifier` object that was just trained.
# NOTE(review): the model is not switched to eval mode here — confirm whether
# `certify` does that itself before relying on the reported radii.
model = rs.RSClassifier(
    base_classifier=base_classifier,
    num_classes=number_classes,
    sigma=sigma,
    device=device).to(device)
# `map_location` remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can also be loaded on a CPU-only machine.
model.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
# Development aid: re-import the `rs` module so edits to its source are picked
# up without restarting the kernel.
# NOTE(review): objects created before this cell (e.g. `model`) keep the class
# they were built from — re-create them if the reloaded code should apply.
from importlib import reload
reload(rs)

In [None]:
# Certify every test image and sort the returned radii by whether the
# prediction matched the label; a prediction that is not > -1 is an abstention.
correct_robust_radii, false_robust_radii = [], []
n_abstain = 0

for image, label in test_loader:
    image = image.to(device)
    label = label.to(device)
    prediction, robust_radius = model.certify(
        image, n_sampling, n_bound, model.sigma, alpha)

    if not (prediction > -1):
        n_abstain += 1
        continue

    bucket = correct_robust_radii if prediction == label else false_robust_radii
    bucket.append(robust_radius)

n_correct, n_false = len(correct_robust_radii), len(false_robust_radii)

In [None]:
# Mean radius over the correctly classified test points; nothing is printed
# when no prediction was correct.
if n_correct > 0:
    mean_correct_radius = np.mean(correct_robust_radii)
    print(mean_correct_radius)

# 2D example

### Init

In [None]:
import logging

import knn

from helper_functions import get_toy_dataset_2d, plot_network_boundary_2d, train

In [None]:
# Grab the root logger once and lower its threshold to DEBUG.
log = logging.getLogger()
log.setLevel(logging.DEBUG)

### Run

In [None]:
def _build_mlp(widths):
    """Chain Linear layers of the given widths with a ReLU between each pair."""
    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        if modules:
            # Activation sits between consecutive linear layers only, so the
            # final layer outputs raw scores.
            modules.append(torch.nn.ReLU())
        modules.append(torch.nn.Linear(fan_in, fan_out))
    return torch.nn.Sequential(*modules)


# 2-D input -> 100 -> 1000 -> 1000 -> 100 -> 2 outputs.
network = _build_mlp([2, 100, 1000, 1000, 100, 2]).to(device)
X, y = get_toy_dataset_2d("random", N=100, r=0.01)

In [None]:
# Fit the toy network on the 2-D data, with both tensors moved to `device`.
train(
    network,
    X.to(device),
    y.to(device),
    lr=1e-3,
    epochs=1000,
)

In [None]:
# Plot the plain (unsmoothed) network together with the toy data.
plot_network_boundary_2d(
    network,
    [0, 0],
    [1, 1],
    100,
    data=(X, y),
)

In [None]:
# Smoothed wrapper around the toy network: 2 classes, constant sigma of 0.05.
network_rs = rs.RSClassifier(
    network,
    2,
    sigma=0.05,
    device=device,
)

In [None]:
def network_forward_pass_rs(x):
    """Evaluate `network_rs.predict` on every row of `x`, one row at a time.

    Each row is given a leading batch dimension and passed to `predict` with
    n=10000, the wrapper's own `sigma`, alpha=0.01 and return_all_counts=True.
    The per-row results are combined with `torch.stack`.
    """
    per_point_results = []
    for x_ in x:
        result = network_rs.predict(
            x_.unsqueeze(0),
            n=10000,
            sigma=network_rs.sigma,
            alpha=0.01,
            return_all_counts=True,
        )
        per_point_results.append(result)
    return torch.stack(per_point_results)

In [None]:
# Same plot as for the plain network, now driven by the smoothed predictions.
# ToDo: speedup by batched prediction?
plot_network_boundary_2d(
    network_forward_pass_rs,
    [0, 0],
    [1, 1],
    100,
    data=(X, y),
)

In [None]:
# Neighbour count read by the helper functions below at call time; this
# replaces any earlier value of `k`.
k = 5
# k-NN helper built over the toy points and their labels.
toy_dataset = torch.utils.data.TensorDataset(X, y)
knn_comp = knn.KNNDistComp(toy_dataset, device=device)
def distances_fcn(x):
    """Return element `[1][0]` of `knn_comp.compute_knn_and_dists(x)`.

    NOTE(review): presumably the distances for the first query point — confirm
    against `KNNDistComp`.
    """
    knn_and_dists = knn_comp.compute_knn_and_dists(x)
    return knn_and_dists[1][0]
def knns_fcn(x):
    """Return `knn_comp.compute_knns(x, k)` using the global `k`."""
    neighbours = knn_comp.compute_knns(x, k)
    return neighbours
def mean_distances_fcn(x):
    """Return `knn_comp.compute_mean_dist(x, k)` using the global `k`."""
    mean_dist = knn_comp.compute_mean_dist(x, k)
    return mean_dist
    

In [None]:
# IDRS wrapper around the toy network; it receives the k-NN mean-distance
# callback instead of precomputed distances.
network_idrs = rs.IDRSClassifier(
    network,
    2,
    sigma=0.05,
    distances=None,
    rate=5.0,
    m=1.0,
    device=device,
    mean_distances_fcn=mean_distances_fcn,
)

In [None]:
def network_forward_pass_idrs(x):
    """Evaluate `network_idrs.predict` on every row of `x`, one row at a time.

    For each row (given a leading batch dimension) the sigma is obtained from
    `network_idrs.sigma_fcn` on that same row, then `predict` is called with
    n=10000, alpha=0.01 and return_all_counts=True. The per-row results are
    combined with `torch.stack`.
    """
    per_point_results = []
    for x_ in x:
        result = network_idrs.predict(
            x_.unsqueeze(0),
            n=10000,
            sigma=network_idrs.sigma_fcn(x_.unsqueeze(0)),
            alpha=0.01,
            return_all_counts=True,
        )
        per_point_results.append(result)
    return torch.stack(per_point_results)

In [None]:
# Plot the IDRS predictions; `r` carries the sigma computed for each toy point.
# ToDo: speedup by batched prediction?
plot_network_boundary_2d(
    network_forward_pass_idrs,
    [0, 0],
    [1, 1],
    100,
    data=(X, y),
    x_base=X,
    r=[network_idrs.sigma_fcn(x.unsqueeze(0)) for x in X],
)

In [None]:
# Biased IDRS wrapper around the toy network; it receives all three k-NN
# callbacks plus the names of the variance and bias functions to use.
network_bidrs = rs.BiasedIDRSClassifier(
    network,
    2,
    sigma=0.05,
    distances=None,
    rate=5.0,
    m=1.0,
    device=device,
    knns_fcn=knns_fcn,
    distances_fcn=distances_fcn,
    mean_distances_fcn=mean_distances_fcn,
    variance_func="sigma_knn",
    bias_func="mu_knn_based",
)

In [None]:
def network_forward_pass_bidrs(x):
    """Evaluate `network_bidrs.predict` on every row of `x`, one row at a time.

    For each row (given a leading batch dimension) the sigma and the bias are
    obtained from `network_bidrs.sigma_fcn` and `network_bidrs.bias_fcn` on
    that same row, then `predict` is called with n=10000, alpha=0.01 and
    return_all_counts=True. The per-row results are combined with
    `torch.stack`.
    """
    per_point_results = []
    for x_ in x:
        result = network_bidrs.predict(
            x_.unsqueeze(0),
            n=10000,
            sigma=network_bidrs.sigma_fcn(x_.unsqueeze(0)),
            alpha=0.01,
            return_all_counts=True,
            bias=network_bidrs.bias_fcn(x_.unsqueeze(0)),
        )
        per_point_results.append(result)
    return torch.stack(per_point_results)

In [None]:
# Plot the biased-IDRS predictions; `r` carries the sigma computed for each
# toy point.
# ToDo: speedup by batched prediction?
plot_network_boundary_2d(
    network_forward_pass_bidrs,
    [0, 0],
    [1, 1],
    100,
    data=(X, y),
    x_base=X,
    r=[network_bidrs.sigma_fcn(x.unsqueeze(0)) for x in X],
)