In [1]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))
                                ])

trainset = datasets.MNIST('./mnist', train=True, transform=transform, download=True)
evalset = datasets.MNIST('./mnist', train=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True)
evalloader = torch.utils.data.DataLoader(evalset, batch_size=1, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist/MNIST/raw



In [4]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 12 * 12, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        out = self.fc2(x)
        return out

epochs = 5

model = CNN().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
criterion = nn.CrossEntropyLoss().to(device)

model.train()
for ep in tqdm(range(epochs)):
    train_loss = 0
    for xs, ys in trainloader:
        xs, ys = xs.to(device), ys.to(device)
        optimizer.zero_grad()
        out = model(xs)
        loss = criterion(out, ys)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    scheduler.step()
    train_loss /= len(trainloader)
    sys.stdout.write(f'[{ep+1}/{epochs}] Loss {train_loss:.4f}\n')

model.eval()
eval_loss = 0
correct = 0
with torch.no_grad():
    for x, y in evalloader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = criterion(out, y)
        eval_loss += loss.item()
        correct += (torch.argmax(out, 1) == y).sum()
    eval_loss /= len(evalloader)
    acc = 100 * correct / len(evalset)
    sys.stdout.write(f'\n==> [Eval] Acc {acc:.2f}%, Loss {eval_loss:.4f}')

  0%|          | 0/5 [00:00<?, ?it/s]

[1/5] Loss 0.4729
[2/5] Loss 0.1178
[3/5] Loss 0.0861
[4/5] Loss 0.0714
[5/5] Loss 0.0654

==> [Eval] Acc 98.79%, Loss 0.0373

In [9]:
# Fast Gradient Sign Method (FGSM)
def fgsm(x, data_grad, epsilon):
    perturb = x + epsilon * data_grad.sign() # Add perturbation to pixel values
    perturb = torch.clip(perturb, 0, 1)
    return perturb

In [8]:
epsilons = [0.0, 0.05, 0.1, 0.15, 0.2]

accs = []
examples = []
num_samples = 5

with torch.no_grad():
  for e in tqdm(epsilons):
      correct = 0
      adv_examples = []

      for x, y in evalloader:
          x, y = x.to(device), y.to(device)
          x.requires_grad = True

          out = model(x)
          init_pred = torch.argmax(out, 1)
          if init_pred.item() != y.item():
              continue

          loss = criterion(out, y)

          model.zero_grad()
          loss.backward()

          x_grad = x.grad.data
          perturb_x = fgsm(x, x_grad, e)

          out = model(perturb_x)
          pred = torch.argmax(out, 1)
          if pred.item() == y.item():
              correct += 1
              if (e == 0) and (len(adv_examples) < num_samples):
                  adv = perturb_x.squeeze().detach().cpu().numpy()
                  adv_examples.append((init_pred.item(), pred.item(), adv))
          else:
              if len(adv_examples) < num_samples:
                  adv = perturb_x.squeeze().detach().cpu().numpy()
                  adv_examples.append((init_pred.item(), pred.item(), adv))

      acc = 100 * correct / len(evalset)
      sys.stdout.write(f'Epsilon {e}\tAcc {acc:>.2f}%\n')

      accs.append(acc)
      examples.append(adv_examples)

  0%|          | 0/5 [00:00<?, ?it/s]

RuntimeError: ignored

In [None]:
plt.rcParams['figure.figsize'] = (10, 10)

In [None]:
plt.plot(epsilons, accs, 'o-')
for x, y in zip(epsilons, accs):
    plt.text(x, y, str(y))
plt.yticks(np.arange(0, 110, 10))
plt.xlabel('Epsilon')
plt.ylabel('Accuracy')
plt.show()

In [None]:
cnt = 0
for i, ep in enumerate(epsilons):
    for j, sample in enumerate(examples[i]):
        cnt += 1
        plt.subplot(len(epsilons), num_samples, cnt)
        plt.xticks([], [])
        plt.yticks([], [])
        if j == 0:
            plt.ylabel(f'Epsilon {ep}')
        orig, adv, ex = sample
        plt.imshow(ex)
        plt.title(f'{orig} -> {adv}')
plt.tight_layout()
plt.show()