In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
%matplotlib inline

import torch.cuda as cuda
import os

# Local imports
from models import AdversarialNet

In [None]:

def cudafy(use_gpu, seq, device=None):
    """Move every element of ``seq`` to the requested device and return a tuple.

    Args:
        use_gpu: Truthy to move the elements to a GPU; any falsy value
            (``False``, ``0``, ``None``) moves them to the CPU instead.
        seq: Iterable of objects exposing ``.cpu()``, ``.cuda()`` and ``.to()``
            (e.g. ``torch.Tensor``).
        device: Optional target handed to ``.to()`` when ``use_gpu`` is truthy.
            When it is ``None``, ``.cuda()`` is called with no arguments.

    Returns:
        A tuple holding one moved element per element of ``seq``, in order.
    """
    # Truthiness test rather than `is False`, so 0 / None are not mistaken
    # for a request to use the GPU.
    if not use_gpu:
        return tuple(item.cpu() for item in seq)
    if device is not None:
        return tuple(item.to(device) for item in seq)
    return tuple(item.cuda() for item in seq)
# Whether PyTorch reports CUDA as available; used below to decide if the model
# and each batch are moved onto the GPU.
USE_GPU = torch.cuda.is_available()

In [None]:
# --- Training schedule ---
# Total number of epochs to run
n_epochs = 10
# The adversary gets a turn once every `adversary_pace` epochs
adversary_pace = 2
# Batch size handed to the data loaders
batch_size = 200

# --- Dataset ---
num_classes = 10

In [None]:
# Data preprocessing: convert each image to a tensor, then normalise its single
# channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Loader objects (MNIST is downloaded into ./data if it is not already there)
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# TODO: Previous code had 'batch_size = 4' here, was that intentional?
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=0)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=0)

# Currently unused.
# MNIST targets are the digits 0 .. num_classes-1, so the names start at '0'.
# Built as a tuple so it can be indexed by label and iterated more than once;
# a bare generator expression would be exhausted after a single pass.
classes = tuple(str(digit) for digit in range(num_classes))

In [None]:
# Loss that the primary minimises and the adversary maximises
criterion = nn.CrossEntropyLoss()

# Adversarial network; its parameters go onto the GPU when one is available
advNet = AdversarialNet()
if USE_GPU:
    advNet.cuda()

# Both players use plain SGD with the same settings
sgd_settings = dict(lr=0.01, momentum=0., nesterov=False)

# Primary optimizer: updates only the primary weights
optimizer_min = optim.SGD(advNet.primary_weights, **sgd_settings)
# Adversary optimizer: updates only the adversary weights
# TODO: May need to give this more power (?)
# Doesn't seem to do anything in its round currently
optimizer_max = optim.SGD(advNet.adversary_weights, **sgd_settings)

In [None]:
# Optimization history logging: one cross-entropy value per processed batch,
# primary and adversary batches interleaved in the order they were run.
# TODO: Save separate histories for primary, adversary, epoch, etc.
history = []

# Epochs
for epoch in range(n_epochs):
    # Each phase is described by:
    #   (console tag, optimizer that steps, optimizer that stays idle,
    #    whether the phase ascends the cross-entropy instead of descending it)
    # The primary plays every epoch.
    phases = [('Primary', optimizer_min, optimizer_max, False)]
    # Adversary optimization - once every adversary_pace epochs, after the
    # primary has finished its pass over the training set.
    if epoch % adversary_pace == 0:
        phases.append(('Adversary', optimizer_max, optimizer_min, True))

    for phase_name, active_optimizer, idle_optimizer, ascend in phases:
        for i, data in enumerate(trainloader, 0):
            # Get data
            inputs, labels = cudafy(USE_GPU, data)  # If using GPU, move data onto GPU

            # Zero gradients of the player about to move
            active_optimizer.zero_grad()

            # Forward pass
            outputs = advNet(inputs)
            crossentropy = criterion(outputs, labels)
            # The optimizer always descends, so the adversary (which wants a
            # larger cross-entropy) is given the negated loss.
            loss = -crossentropy if ascend else crossentropy
            # Backpropagate
            loss.backward()
            # Cancel out the gradients held by the idle player's optimizer.
            # NOTE(review): presumably backward() fills those too because both
            # weight groups take part in the forward pass -- confirm against
            # AdversarialNet. The call is kept before step(), as in the
            # original; that ordering would matter if the two weight groups
            # shared parameters.
            # TODO: Is this really needed if each optimizer only works with a
            # subset of weights and zeroes its own gradients at the start of
            # its turn?
            idle_optimizer.zero_grad()
            # Descend
            active_optimizer.step()

            # Read the scalar once; it is reused for logging and for history.
            batch_crossentropy = crossentropy.item()
            # Print statistics
            # TODO: This currently spams the console, it's fine
            print('[%s, Epoch %d, batch %d] Crossentropy: %.3f' %
                  (phase_name, epoch+1, i+1, batch_crossentropy))
            # Save history
            history.append(batch_crossentropy)

In [None]:
# Plot raw global history: one point per recorded batch, in the order the
# values were appended to `history`.
fig, ax = plt.subplots()
ax.plot(history)
# Label the figure so it can be read on its own; the trailing semicolon keeps
# the return value of the last call out of the cell output.
ax.set(xlabel='Batch update', ylabel='Cross-entropy', title='Raw training history');