In [1]:
import torchvision
import torch
import tqdm
import numpy as np
from torchvision.datasets import MNIST
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from lets_plot import *

import sklearn
import sklearn.metrics

LetsPlot.setup_html()

In [2]:
import importlib

import mnist
import adversarial

importlib.reload(mnist)
importlib.reload(adversarial)

from mnist import *
from adversarial import *


## Load the dataset

In [3]:

mnist_train = load_data(train=True)
mnist_test = load_data(train=False)

train_dataset = MnistDataset(mnist_train)


In [4]:
train_X = torch.stack([t[0] for t in mnist_train])
train_y = torch.stack([t[1] for t in mnist_train])

In [5]:
test_X = torch.stack([t[0] for t in mnist_test])
test_y = torch.stack([t[1] for t in mnist_test])

## Train the model

In [6]:
model = MnistClassifier()


In [7]:
loss_fct = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.01)

epochs = 1

train_loader = DataLoader(train_dataset, batch_size=64)

for e in (range(epochs)):
    model.train()
    losses = []
    
    pbar = tqdm.tqdm(train_loader)
    for batch in pbar:
        optimizer.zero_grad()
        out = model(batch['data'])
        loss = loss_fct(out, batch['label'])
        loss.backward()
        
        losses.append(loss.item())
        pbar.set_description(f'Loss: {loss.item():.04f}')
        
        optimizer.step()
        

Loss: 1.4740: 100%|██████████| 938/938 [00:22<00:00, 41.58it/s]


In [8]:
ggplot({'X': list(range(len(losses))), 'Y': losses}) + geom_line(aes(x='X', y='Y'))

In [9]:
model.eval()
train_pred = model(train_X)
print(sklearn.metrics.classification_report(train_pred.argmax(dim=1).detach().numpy(), train_y))

              precision    recall  f1-score   support

           0       0.99      0.98      0.98      5993
           1       0.99      0.87      0.93      7705
           2       0.92      0.98      0.95      5640
           3       0.94      0.99      0.96      5786
           4       0.99      0.84      0.91      6886
           5       0.96      0.97      0.97      5387
           6       0.97      0.98      0.97      5809
           7       0.97      0.96      0.97      6340
           8       0.91      0.95      0.93      5566
           9       0.81      0.98      0.89      4888

    accuracy                           0.95     60000
   macro avg       0.94      0.95      0.95     60000
weighted avg       0.95      0.95      0.95     60000



In [10]:
model.eval()
test_pred = model(test_X)
print(sklearn.metrics.classification_report(test_pred.argmax(dim=1).detach().numpy(), test_y))

              precision    recall  f1-score   support

           0       0.99      0.97      0.98      1003
           1       1.00      0.89      0.94      1264
           2       0.94      0.98      0.96       986
           3       0.95      0.99      0.97       965
           4       0.98      0.86      0.92      1128
           5       0.97      0.97      0.97       889
           6       0.96      0.99      0.97       928
           7       0.96      0.95      0.96      1040
           8       0.93      0.96      0.94       948
           9       0.84      0.99      0.91       849

    accuracy                           0.95     10000
   macro avg       0.95      0.95      0.95     10000
weighted avg       0.95      0.95      0.95     10000



## Sample and train adversarial examples

In [11]:
adv_examples_count = 100
adv_X, adv_y = train_X[:adv_examples_count], train_y[:adv_examples_count]
adv_examples = train_adv_bim(model, CrossEntropyLoss(), adv_X.clone(), adv_y.clone())


In [12]:
model.eval()
adv_pred = model(adv_examples).argmax(dim=1).detach()
adv_pred, adv_y

(tensor([3, 0, 9, 1, 4, 8, 4, 3, 4, 1, 3, 3, 3, 8, 7, 7, 7, 1, 8, 1, 9, 0, 4, 1,
         2, 3, 9, 3, 7, 8, 7, 1, 1, 8, 1, 0, 1, 0, 4, 1, 1, 1, 1, 4, 1, 7, 1, 5,
         8, 5, 5, 7, 7, 1, 7, 1, 0, 4, 9, 1, 9, 1, 8, 8, 4, 3, 6, 7, 5, 8, 2, 7,
         1, 1, 7, 7, 1, 1, 1, 1, 8, 0, 0, 1, 7, 1, 7, 4, 0, 4, 6, 7, 4, 8, 9, 4,
         1, 8, 5, 7]),
 tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1,
         1, 2, 4, 3, 2, 7, 3, 8, 6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, 5,
         9, 3, 3, 0, 7, 4, 9, 8, 0, 9, 4, 1, 4, 4, 6, 0, 4, 5, 6, 1, 0, 0, 1, 7,
         1, 6, 3, 0, 2, 1, 1, 7, 9, 0, 2, 6, 7, 8, 3, 9, 0, 4, 6, 7, 4, 6, 8, 0,
         7, 8, 3, 1]))

## Evaluate model on the real and adversarial datasets

In [13]:
real_pred = model(adv_X).argmax(dim=1).detach().numpy()
conf = sklearn.metrics.confusion_matrix(real_pred, adv_y).astype(np.float)
ggplot() + geom_image(conf)

In [14]:
print(sklearn.metrics.classification_report(adv_pred.numpy(), adv_y))
conf = sklearn.metrics.confusion_matrix(adv_pred.numpy(), adv_y).astype(np.float)
ggplot() + geom_image(conf)

              precision    recall  f1-score   support

           0       0.46      0.75      0.57         8
           1       0.50      0.25      0.33        28
           2       0.00      0.00      0.00         2
           3       0.36      0.50      0.42         8
           4       0.27      0.25      0.26        12
           5       0.20      0.20      0.20         5
           6       0.18      1.00      0.31         2
           7       0.50      0.29      0.37        17
           8       0.12      0.08      0.10        12
           9       0.00      0.00      0.00         6

    accuracy                           0.29       100
   macro avg       0.26      0.33      0.26       100
weighted avg       0.35      0.29      0.30       100



In [15]:
ggplot() + geom_image(adv_X[1].numpy().reshape(28, 28))

In [16]:
ggplot() + geom_image(adv_examples[1].numpy().reshape(28, 28))