### Imports

In [None]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import TensorDataset, DataLoader
from pathlib import Path
import os
import matplotlib.pyplot as plt


from art.attacks.evasion import ElasticNet
from art.attacks.evasion import ProjectedGradientDescentPyTorch as PGDAttack
from art.estimators.classification import PyTorchClassifier
from art.utils import load_mnist

from adv_utils import *
from adv_attacks import EAD_L1, PGD


# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Set the inline backend's figure format to 'retina'.
%config InlineBackend.figure_format='retina'
# Derive both the CUDA flag and the device from the same runtime check so they
# can never disagree (a hard-coded `use_cuda = True` is wrong on CPU-only hosts).
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [None]:
# Fetch the MNIST train/test splits plus the value bounds reported by load_mnist.
(x_train, y_train), (x_test, y_test), min_, max_ = load_mnist()

# Swap axes 1 and 3 of each image array and cast to float32.
# NOTE(review): presumably this moves a trailing channel axis to position 1 for
# PyTorch; swapaxes also exchanges the two remaining axes — confirm that is intended.
x_train, x_test = (
    np.swapaxes(images, 1, 3).astype(np.float32) for images in (x_train, x_test)
)

# Wrap the arrays as tensors and expose them through batched loaders.
# NOTE(review): the training loader does not request shuffling — confirm this is intended.
train_dataset = TensorDataset(*(torch.Tensor(arr) for arr in (x_train, y_train)))
train_dataloader = DataLoader(dataset=train_dataset, batch_size=100, num_workers=6)

test_dataset = TensorDataset(*(torch.Tensor(arr) for arr in (x_test, y_test)))
test_dataloader = DataLoader(dataset=test_dataset, batch_size=1000, num_workers=6)

### Classifier

In [None]:
from CNN import CNN_mnist

# Instantiate the MNIST CNN and move it to the selected device.
net = CNN_mnist().to(device)

# Cross-entropy loss and an Adam optimizer over all of the network's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)

# Values passed to adv_train as `eps` and `a` for each norm.
# NOTE(review): `a` is presumably the attack step size — confirm in adv_utils.
eps_l2 = 2
a_l2 = 0.1
eps_linf = 0.3
a_linf = 0.01

### L2 Adversarial Training

In [None]:
# Run epochs 1..10 of adversarial training with norm=2.
for e in range(1, 11):
    print('\nEpoch: %d' % int(e))

    # adjust_learning_rate returns the optimizer to use for this epoch
    # (original note: reduce lr by 10 at epoch 5).
    optimizer = adjust_learning_rate(net, optimizer, e, dataset='mnist')

    # Warm-up that works well on MNIST: the first epoch uses a third of the
    # full epsilon, every later epoch uses the full value.
    if e == 1:
        epsilon = eps_l2 / 3.0
    else:
        epsilon = eps_l2

    adv_train(net, optimizer, norm=2, eps=epsilon, a=a_l2)
    acc = test(net)

    # Checkpoint after every epoch; the same file is overwritten each time.
    torch.save(net.state_dict(), 'adv_2_mnist.pth')

### Linf Adversarial Training

In [None]:
# Start from a freshly initialised model and optimizer. Without this, running
# the notebook top to bottom would keep training whatever weights `net` already
# holds from any earlier training cell, so the checkpoint saved below would not
# be a purely Linf-trained model.
net = CNN_mnist().to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3)

# Run epochs 1..10 of adversarial training with norm=np.inf.
for e in range(1, 11):
    print('\nEpoch: %d' % int(e))

    # adjust_learning_rate returns the optimizer to use for this epoch
    # (original note: reduce lr by 10 at epoch 5).
    optimizer = adjust_learning_rate(net, optimizer, e, dataset='mnist')

    # Warm-up that works well on MNIST: the first epoch uses a third of the
    # full epsilon, every later epoch uses the full value.
    epsilon = eps_linf / 3.0 if e == 1 else eps_linf

    adv_train(net, optimizer, norm=np.inf, eps=epsilon, a=a_linf)
    acc = test(net)

    # Checkpoint after every epoch; the same file is overwritten each time.
    torch.save(net.state_dict(), 'adv_inf_mnist.pth')