In [1]:
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing
from torch.utils.data import Dataset, DataLoader

from adversarialbox.attacks import FGSMAttack, LinfPGDAttack
from adversarialbox.train import adv_train, FGSM_train_rnd
from adversarialbox.utils import to_var, pred_batch, test

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Load data. Need MNIST dataset in directory
transform = torchvision.transforms.Compose([torchvision.transforms.Resize(32),torchvision.transforms.ToTensor()])
mnist_train = torchvision.datasets.MNIST('../data/', download=True, train=True, transform=transform)                                 
mnist_test = torchvision.datasets.MNIST('../data/', download=True, train=False, transform=transform)

train_loader = DataLoader(mnist_train, batch_size=256, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=1000, shuffle=True)

In [3]:
images = iter(mnist_test)

In [4]:
im = next(images)

In [5]:
torch.min(im[0])

tensor(0.)

In [6]:
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = torch.nn.Dropout2d()
        self.fc1 = torch.nn.Linear(20*5*5, 50)
        self.fc2 = torch.nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 20*5*5)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

In [7]:
model = CNN()
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.5)
loss_fn = torch.nn.CrossEntropyLoss()
adversary = LinfPGDAttack(model=model, loss_fn=loss_fn)



for epoch in range(30):
    model.train()
    batch_loss = 0
    for idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.long().to(device)
        optimizer.zero_grad()
        yhat = model(x)
        loss = loss_fn(yhat, y)
        
        # adversarial training
        if epoch+1 > 1:
            # use predicted label to prevent label leaking
            y_pred = pred_batch(x, model)
            x_adv = adv_train(x.cpu(), y_pred, model, adversary)
            loss_adv = loss_fn(model(x_adv.to(device)), y)
            loss = (loss + loss_adv) / 2
        batch_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
        test_loss /= len(test_loader.dataset)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
        
    print(epoch, batch_loss/len(train_loader))




Test set: Avg. loss: -0.0464, Accuracy: 1335/10000 (13%)

0 2.300164095898892

Test set: Avg. loss: -0.3353, Accuracy: 6847/10000 (68%)

1 2.258717658671927

Test set: Avg. loss: -1.9778, Accuracy: 8401/10000 (84%)

2 1.9187473150009804

Test set: Avg. loss: -2.8204, Accuracy: 8786/10000 (88%)

3 1.665797394894539


KeyboardInterrupt: 