In [1]:
import torch
import torchvision
import os
import numpy
import collections

In [2]:
class Convolutional(torch.nn.Module):
    """Two-stage convolutional classifier with a single hidden dense layer.

    Each stage is ``conv -> 2x2 max-pool -> sigmoid``. Both convolutions use
    5x5 kernels with stride 1 and padding 2, so only the pooling reduces the
    spatial size (by a factor of two per stage).

    Args:
        K: number of hidden units in the first fully connected layer.
        O: number of outputs (logits) produced by the network.
    """

    def __init__(self, K, O):
        super().__init__()
        nn = torch.nn
        channels = 32  # feature maps produced by each convolution
        conv_kwargs = dict(kernel_size=(5, 5), stride=1, padding=2)

        # Feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, **conv_kwargs)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        # Classifier head. 7*7*channels input features correspond to a 7x7 map
        # after two poolings, i.e. a 28x28 single-channel input image.
        self.fc1 = nn.Linear(7 * 7 * channels, K, bias=True)
        self.fc2 = nn.Linear(K, O)

        self.activation = nn.Sigmoid()
        # Registered but not used by forward(); see the note there.
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return the logits for a batch ``x`` of single-channel images."""
        stage1 = self.conv1(x)
        stage1 = self.activation(self.pool(stage1))

        # Batch-norm variant of this stage (currently disabled):
        #   self.activation(self.bn(self.pool(self.conv2(stage1))))
        stage2 = self.conv2(stage1)
        stage2 = self.activation(self.pool(stage2))

        features = torch.flatten(stage2, 1)  # keep the batch dimension
        hidden = self.activation(self.fc1(features))
        return self.fc2(hidden)

In [3]:
# Experiment switches: which kind of perturbed samples (if any) to train on.
train_adversarial = False
train_with_noise = False

# Results file name; the noise flag takes precedence over the adversarial flag.
filename = "Results_{}_bn.txt".format(
    "noise" if train_with_noise else ("adv" if train_adversarial else "none")
)

# Seed before the model is built so that weight initialisation is repeatable.
torch.manual_seed(42)

# MNIST, converted to tensors by ToTensor().
transform = torchvision.transforms.ToTensor()
train_set = torchvision.datasets.MNIST(
    root="data/MNIST", train=True, download=True, transform=transform
)
test_set = torchvision.datasets.MNIST(
    root="data/MNIST", train=False, download=True, transform=transform
)

# Loaders: shuffled mini-batches for training, fixed order for evaluation.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False)

# Model (50 hidden units, 10 outputs), loss function and optimizer.
network = Convolutional(50, 10)
loss = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(params=network.parameters(), lr=1e-2, momentum=0.9)

In [4]:
def FGS(x, t, alpha=0.3, model=None, criterion=None):
    """Create Fast Gradient Sign adversarial samples for a batch.

    Every input value is moved by ``alpha`` in the direction that increases
    the loss (the sign of the loss gradient w.r.t. the input), and the result
    is clamped to the range [0, 1].

    The function has no side effects on its arguments or on the model: the
    caller's ``x`` is neither modified nor switched to ``requires_grad``, and
    the ``.grad`` fields of the model parameters are left untouched, so it can
    safely be called between ``backward()`` and ``optimizer.step()``.

    Args:
        x: batch of inputs.
        t: targets for ``x``, in the format expected by ``criterion``.
        alpha: step size of the perturbation.
        model: network to attack; defaults to the notebook-level ``network``.
        criterion: loss function; defaults to the notebook-level ``loss``.

    Returns:
        A detached tensor with the same shape as ``x`` holding the
        adversarial samples.
    """
    model = network if model is None else model
    criterion = loss if criterion is None else criterion

    # Private leaf copy: we need the gradient w.r.t. the input, but must not
    # flip requires_grad on the tensor owned by the caller.
    x_adv = x.detach().clone().requires_grad_(True)

    # enable_grad() keeps this working even when called under torch.no_grad().
    with torch.enable_grad():
        J = criterion(model(x_adv), t)
        # autograd.grad returns dJ/dx only and does not accumulate anything
        # into the parameters' .grad, so no zero_grad() is required.
        (grad,) = torch.autograd.grad(J, x_adv)

    # Gradient ascent step on the input, kept inside the valid range [0, 1].
    return torch.clamp(x_adv.detach() + alpha * torch.sign(grad), 0, 1)

def noise(x, alpha=0.3):
    """Perturb a batch with random sign noise of magnitude ``alpha``.

    Each element of ``x`` is independently shifted by ``+alpha`` or ``-alpha``
    (both with probability 0.5) and the result is clamped to [0, 1].

    Args:
        x: batch of inputs.
        alpha: magnitude of the perturbation added to every element.

    Returns:
        A tensor with the same shape as ``x`` holding the noisy samples.
    """
    # Sample on the device and in the (floating point) dtype of x, so that
    # GPU or float64 inputs work without device mismatch or precision loss.
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    probs = torch.full(x.shape, 0.5, dtype=dtype, device=x.device)
    signs = torch.bernoulli(probs) * 2 - 1  # values in {-1, +1}
    return torch.clamp(x + alpha * signs, 0, 1)

In [5]:
# training

for epoch in range(50):
    network.train()
    for x, t in train_loader:
        # Build the perturbed batch *before* touching the parameter gradients,
        # so the optimizer step never depends on side effects of FGS.
        # Noise takes precedence over adversarial samples, matching the
        # naming of the results file.
        if train_with_noise:
            x_hat = noise(x)
        elif train_adversarial:
            x_hat = FGS(x, t)
        else:
            x_hat = None

        optimizer.zero_grad()
        J = loss(network(x), t)
        if x_hat is not None:
            # detach(): the perturbed batch is treated as plain input data.
            J = J + loss(network(x_hat.detach()), t)
        J.backward()
        optimizer.step()  # combined step on clean + perturbed loss

    # evaluation
    network.eval()
    correct_clean = 0
    correct_adv = 0
    for x, t in test_loader:
        with torch.no_grad():
            # compute classification accuracy on the clean samples
            correct = torch.argmax(network(x), dim=1) == t
        correct_clean += int(correct.sum())

        if not bool(correct.any()):
            continue  # nothing to attack in this batch

        # create adversarial samples for correctly classified samples only
        x_ok = x[correct]
        t_ok = t[correct]
        x_hat = FGS(x_ok, t_ok)

        with torch.no_grad():
            # count the samples that are still classified correctly
            still_correct = torch.argmax(network(x_hat), dim=1) == t_ok
        correct_adv += int(still_correct.sum())

    # Adversarial accuracy is relative to the correctly classified clean
    # samples; it is undefined (nan) when there are none.
    clean_acc = correct_clean / len(test_set)
    adv_acc = correct_adv / correct_clean if correct_clean else float("nan")
    print(f"Epoch {epoch+1}: \n Clean acc: {correct_clean}/{len(test_set)} = {clean_acc:.4f} \n Adver acc: {correct_adv}/{correct_clean} = {adv_acc:.4f}")

Epoch 1: 
 Clean acc: 1135/10000 = 0.1135 
 Adver acc: 1135/1135 = 1.0000
Epoch 2: 
 Clean acc: 1135/10000 = 0.1135 
 Adver acc: 1135/1135 = 1.0000
Epoch 3: 
 Clean acc: 1490/10000 = 0.1490 
 Adver acc: 218/1490 = 0.1463
Epoch 4: 
 Clean acc: 6144/10000 = 0.6144 
 Adver acc: 255/6144 = 0.0415
Epoch 5: 
 Clean acc: 8441/10000 = 0.8441 
 Adver acc: 20/8441 = 0.0024
Epoch 6: 
 Clean acc: 8938/10000 = 0.8938 
 Adver acc: 17/8938 = 0.0019
Epoch 7: 
 Clean acc: 9196/10000 = 0.9196 
 Adver acc: 12/9196 = 0.0013
Epoch 8: 
 Clean acc: 9349/10000 = 0.9349 
 Adver acc: 18/9349 = 0.0019
Epoch 9: 
 Clean acc: 9458/10000 = 0.9458 
 Adver acc: 8/9458 = 0.0008
Epoch 10: 
 Clean acc: 9517/10000 = 0.9517 
 Adver acc: 18/9517 = 0.0019
Epoch 11: 
 Clean acc: 9577/10000 = 0.9577 
 Adver acc: 20/9577 = 0.0021
Epoch 12: 
 Clean acc: 9593/10000 = 0.9593 
 Adver acc: 23/9593 = 0.0024
Epoch 13: 
 Clean acc: 9652/10000 = 0.9652 
 Adver acc: 19/9652 = 0.0020
Epoch 14: 
 Clean acc: 9688/10000 = 0.9688 
 Adver acc: