In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import scipy.io as sio

from matplotlib import pyplot as plt

# Local imports
from models import AdversarialNet
from models import cudafy, accuracy

# Whether CUDA is available in this session; this flag gates moving the network
# onto the GPU and is passed to the cudafy/accuracy helpers further down.
USE_GPU = torch.cuda.is_available()

In [None]:
# Training schedule: every one of the `meta_epochs` rounds runs `primary_epochs`
# passes of primary training followed by `adversary_epochs` passes of adversary
# training. `batch_size` is shared by the train and test loaders.
meta_epochs = 4
primary_epochs = 8
adversary_epochs = 8
batch_size = 200

# Universal adversary strength (multiplicative)
eps = 0.02

# Architecture parameters
num_layers = 2  # Including final softmax

# Printing parameters: statistics are reported every `batch_print` batches
batch_print = 20

In [None]:
# Data preprocessing: convert images to tensors, then normalize the single
# MNIST channel with mean 0.5 and std 0.5
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Loader objects (training data is reshuffled every epoch, test data is not)
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# TODO: Previous code had 'batch_size = 4' here, was that intentional?
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=0)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=0)

# Optimization target
criterion = nn.CrossEntropyLoss()

# Adversarial network
advNet = AdversarialNet(eps)

if USE_GPU: # if we can use the GPU, move the parameters onto the GPU
    advNet.cuda()

# Two optimizers over disjoint parameter groups exposed by the network:
# Primary optimizer (descends on the loss)
optimizer_min = optim.Adam(advNet.primary_weights, lr=0.001, amsgrad=True)
# Adversary optimizer (steps on the negated loss, i.e. ascends)
optimizer_max = optim.Adam(advNet.adversary_weights, lr=0.001, amsgrad=True)

# Optimization history logging (appended to by the training cell)
history_loss = []
history_acc  = []

# Preload weights.
# map_location remaps the stored tensors onto the device used in this session,
# so a checkpoint written from a CUDA run can still be read on a CPU-only
# machine (without it, torch.load tries to restore tensors to the device they
# were saved from and fails when CUDA is unavailable).
model_path = 'models/adv_mnist_pretrained.pt'
load_device = torch.device('cuda' if USE_GPU else 'cpu')
advNet.load_state_dict(torch.load(model_path, map_location=load_device))

# Reset adversary weights out to zero
advNet.fc1.cancel_adversary()
advNet.fc2.cancel_adversary()

In [None]:
def _run_phase(phase, meta_epoch, num_epochs, active_optimizer, frozen_optimizer, maximize):
    """Run one optimization phase (primary or adversary) over the training set.

    Parameters
    ----------
    phase : str
        Label used in the progress line ('Primary' or 'Adversary').
    meta_epoch : int
        Index of the enclosing meta-epoch (only used for printing).
    num_epochs : int
        Number of full passes over `trainloader` in this phase.
    active_optimizer : torch.optim.Optimizer
        Optimizer whose parameters are updated in this phase.
    frozen_optimizer : torch.optim.Optimizer
        Optimizer of the other parameter group; its gradients are cleared
        after every backward pass. It is never stepped here.
    maximize : bool
        If True, backpropagate the negated cross-entropy so that the
        optimizer step ascends on the loss (adversary phase).

    Every `batch_print` batches the cross-entropy of the current batch and the
    validation accuracy are printed and appended to `history_loss` /
    `history_acc`. The cross-entropy is always reported with its true
    (non-negated) sign, in both phases.
    """
    for epoch in range(num_epochs):
        for i, data in enumerate(trainloader):
            # Get data
            inputs, labels = cudafy(USE_GPU, data) # If using GPU, move data onto GPU

            # Zero gradients of the parameters being trained
            active_optimizer.zero_grad()

            # Forward pass
            outputs  = advNet(inputs)
            crossent = criterion(outputs, labels)

            # Backpropagate; the adversary maximizes, so it uses the negated loss
            objective = -crossent if maximize else crossent
            objective.backward()

            # Cancel out gradients that backward() produced for the other group
            frozen_optimizer.zero_grad()
            # Descend (primary) / ascend (adversary)
            active_optimizer.step()

            # Print statistics
            if i % batch_print == 0:
                # Compute validation accuracy
                val_acc = accuracy(advNet, testloader, USE_GPU)
                crossent_value = crossent.item()
                print('[%s, Meta-Epoch %d, Epoch %d, Batch %d] Crossentropy: %.4f Val. acc: %.4f' %
                      (phase, meta_epoch, epoch, i+1, crossent_value, val_acc))
                # Save history
                history_loss.append(crossent_value)
                history_acc.append(val_acc)


# Meta-epochs
for meta_epoch in range(meta_epochs):
    # Update primary weights with adversary weights and reset adversary weights to zero
    advNet.fc1.update_primary()
    advNet.fc2.update_primary()

    # Primary optimization
    _run_phase('Primary', meta_epoch, primary_epochs,
               active_optimizer=optimizer_min, frozen_optimizer=optimizer_max,
               maximize=False)

    # Reset adversary weights out of zero region
    advNet.fc1.reset_adversary()
    advNet.fc2.reset_adversary()

    # Adversary optimization
    _run_phase('Adversary', meta_epoch, adversary_epochs,
               active_optimizer=optimizer_max, frozen_optimizer=optimizer_min,
               maximize=True)

In [None]:
# Path stem shared by every artifact written below; it encodes the adversary
# strength, the per-phase epoch counts and the layer count of this run
filename = ''.join(['models/adv_mnist_eps', str(eps),
                    '_primary', str(primary_epochs),
                    '_adv', str(adversary_epochs),
                    '_layers', str(num_layers)])

# Save network
torch.save(advNet.state_dict(), filename + '.pt')

# Save figures: one PNG per logged curve, drawn against the logging index
logged_curves = ((history_acc,  'Validation accuracy', '_acc.png'),
                 (history_loss, 'Cross-entropy loss',  '_crossent.png'))
for curve, y_label, suffix in logged_curves:
    fig, ax = plt.subplots()
    ax.plot(curve)
    ax.grid()
    ax.set_xlabel('Batch index')
    ax.set_ylabel(y_label)
    fig.savefig(filename + suffix, dpi=600)
    # Close the figure so it is not kept open after being written
    plt.close(fig)

# Save history to .mat
mat_payload = {'history_acc': history_acc,
               'history_loss': history_loss,
               'eps': eps}
sio.savemat(filename + '.mat', mat_payload)