In [7]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from model_robustness.model_test.networks import MLP
from advertorch.attacks import GradientSignAttack

import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(0)
np.random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [41]:
def epoch(mode, device, net, dataloader, optimizer, criterion):
    loss_avg, acc_avg, num_exp = 0,0,0
    if mode == "train":
        net.train()

    else:
        net.eval()

    for i, data in enumerate(dataloader):
        imgs, labels = data
        imgs = imgs.to(device)
        labels = labels.to(device)

        n_b = labels.shape[0]

        outputs = net(imgs)
        loss = criterion(outputs, labels)

        acc = np.sum(np.equal(np.argmax(outputs.cpu().data.numpy(), axis=-1), labels.cpu().data.numpy()))

        loss_avg += loss.item()
        acc_avg += acc
        num_exp += n_b

        if mode == "train":
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    loss_avg /= num_exp
    acc_avg /= num_exp

    return loss_avg, acc_avg

In [42]:
dataset = torch.load("../model_robustness/data/MNIST/dataset.pt")

trainset = dataset["trainset"]
testset = dataset["testset"]

trainloader = DataLoader(
    dataset=trainset,
    batch_size=64,
    shuffle=True
)

testloader = DataLoader(
    dataset=testset,
    batch_size=64,
    shuffle=False
)

In [43]:
aux_loader = DataLoader(
    dataset=testset,
    batch_size=len(testset),
    shuffle=False
)

for cln_data, true_label in aux_loader:
    break

cln_data, true_label = cln_data.to(device), true_label.to(device)

In [44]:
model = MLP()
model.to(device)

MLP(
  (fc_1): Linear(in_features=784, out_features=64, bias=True)
  (fc_2): Linear(in_features=64, out_features=128, bias=True)
  (fc_3): Linear(in_features=128, out_features=10, bias=True)
)

In [46]:
adversary = GradientSignAttack(
    model,
    loss_fn=nn.CrossEntropyLoss(reduction="sum"),
    eps=0.1,
    targeted=False
)

# adv_untargeted = adversary.perturb(cln_data, true_label)
target = torch.ones_like(true_label) * 3
adversary.targeted = True
adv_untargeted = adversary.perturb(cln_data, target)

In [47]:
adv_data = torch.utils.data.TensorDataset(adv_untargeted, true_label)

adv_loader = DataLoader(
    dataset=adv_data,
    batch_size=64,
    shuffle=False
)

In [48]:
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

In [49]:
for e in range(10):
    train_loss, train_acc = epoch("train", device, model, trainloader, optimizer, criterion)
    print(f"[{e + 1}] TRAINING \n loss: {train_loss:.3f}, accuracy: {train_acc:.3f}")

    test_loss, test_acc = epoch("test", device, model, testloader, optimizer, criterion)
    print(f"[{e + 1}] TESTING \n loss: {test_loss:.3f}, accuracy: {test_acc:.3f}")

    test_loss, test_acc = epoch("test", device, model, adv_loader, optimizer, criterion)
    print(f"[{e + 1}] PERTURBATION \n loss: {test_loss:.3f}, accuracy: {test_acc:.3f}")

[1] TRAINING 
 loss: 0.005, accuracy: 0.907
[1] TESTING 
 loss: 0.003, accuracy: 0.940
[1] PERTURBATION 
 loss: 0.008, accuracy: 0.878
[2] TRAINING 
 loss: 0.003, accuracy: 0.941
[2] TESTING 
 loss: 0.004, accuracy: 0.933
[2] PERTURBATION 
 loss: 0.008, accuracy: 0.906
[3] TRAINING 
 loss: 0.003, accuracy: 0.948
[3] TESTING 
 loss: 0.003, accuracy: 0.954
[3] PERTURBATION 
 loss: 0.006, accuracy: 0.920
[4] TRAINING 
 loss: 0.003, accuracy: 0.953
[4] TESTING 
 loss: 0.003, accuracy: 0.949
[4] PERTURBATION 
 loss: 0.007, accuracy: 0.897
[5] TRAINING 
 loss: 0.003, accuracy: 0.955
[5] TESTING 
 loss: 0.003, accuracy: 0.953
[5] PERTURBATION 
 loss: 0.006, accuracy: 0.918
[6] TRAINING 
 loss: 0.002, accuracy: 0.958
[6] TESTING 
 loss: 0.003, accuracy: 0.956
[6] PERTURBATION 
 loss: 0.006, accuracy: 0.921
[7] TRAINING 
 loss: 0.002, accuracy: 0.958
[7] TESTING 
 loss: 0.005, accuracy: 0.940
[7] PERTURBATION 
 loss: 0.007, accuracy: 0.892
[8] TRAINING 
 loss: 0.002, accuracy: 0.958
[8] TESTING