In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
%matplotlib inline

import torch.cuda as cuda
import os

# Local imports
from models import AdversarialNet

In [2]:
# GPU functions
def use_gpu():
    """ The shortcut to retrieve the environment variable 'MY_GPU'

    If 'MY_GPU' has not been set yet, it is initialised through
    set_global_gpu() (which defaults to cuda availability) and then read.
    ARGS:
        None
    RETURNS
        bool - True iff 'MY_GPU' holds the string 'True'
    RAISES
        AssertionError - if 'MY_GPU' holds anything but 'True' / 'False'
    """
    # Look the variable up explicitly rather than wrapping the read in a bare
    # `except:`, which would also swallow KeyboardInterrupt and real errors.
    str_val = os.environ.get('MY_GPU')
    if str_val is None:
        set_global_gpu()
        str_val = os.environ['MY_GPU']
    assert str_val in ['True', 'False'], \
        "MY_GPU must be 'True' or 'False', got %r" % (str_val,)
    return str_val == 'True'

def set_global_gpu(manual=None):
    """ Writes the GPU flag into the environment variable 'MY_GPU'.
    ARGS:
        manual : bool or None - when given, its string form is stored
                 verbatim; when None, the result of cuda.is_available()
                 is stored instead
    RETURNS
        None
    """
    # Only an omitted (None) argument falls back to auto-detection; an
    # explicit False is honoured as "do not use the GPU".
    flag = cuda.is_available() if manual is None else manual
    os.environ['MY_GPU'] = str(flag)

# Query the GPU flag (as the cell's last expression its value is displayed).
# If 'MY_GPU' is not set yet, this call also initialises it via set_global_gpu().
# NOTE(review): the returned flag is not consumed by any other cell visible here --
# neither the model nor the batches are moved to a CUDA device; confirm that is intended.
use_gpu()

True

In [3]:
# Training parameters
n_epochs       = 10  # Total number of epochs
adversary_pace = 2   # Adversary pass runs when epoch index % adversary_pace == 0
batch_size     = 200 # Mini-batch size used by both the train and the test loader

# Reproducibility: seed torch's global RNG before any loader or model is built,
# so that shuffling (and any torch-based weight initialisation) repeats across runs.
seed = 0
torch.manual_seed(seed)

# Dataset parameters
num_classes = 10  # MNIST digits 0-9

In [4]:
# Data preprocessing: convert images to tensors, then shift/scale each pixel
# with mean 0.5 and std 0.5 (single channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training split: downloaded into ./data on first use, reshuffled every epoch
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# TODO: Previous code had 'batch_size = 4' here, was that intentional?
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Test split: same preprocessing, served in a fixed order
testset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

# Currently unused
# Class names indexed by MNIST's integer label, i.e. the digits '0'..'9'.
# Built as a tuple (not a generator expression) so it can be indexed and
# iterated more than once.
classes = tuple(str(i) for i in range(num_classes))

In [5]:
# Objective shared by both players: the primary minimises it, the adversary
# maximises it (by descending on its negation in the training loop).
criterion = nn.CrossEntropyLoss()

# Network holding both players' weights
advNet = AdversarialNet()

# Primary optimizer: plain SGD over the primary weights only
optimizer_min = optim.SGD(
    advNet.primary_weights,
    lr=0.01,
    momentum=0.,
    nesterov=False,
)

# Adversary optimizer: plain SGD over the adversary weights only
# TODO: May need to give this more power (?)
# Doesn't seem to do anything in its round currently
optimizer_max = optim.SGD(
    advNet.adversary_weights,
    lr=0.01,
    momentum=0.,
    nesterov=False,
)

In [None]:
# Optimization history logging: one crossentropy value per processed batch,
# primary and adversary batches interleaved in execution order.
# TODO: Save separate histories for primary, adversary, epoch, etc.
history = []


def _run_pass(tag, epoch, step_optimizer, other_optimizer, maximize=False):
    """ Runs one full pass over trainloader in which a single player is updated.

    Reads the notebook globals trainloader, advNet and criterion, and appends
    the crossentropy of every batch to the global history list.
    ARGS:
        tag : str - player name shown in the progress line
        epoch : int - zero-based epoch index (printed one-based)
        step_optimizer : optimizer whose weights are updated in this pass
        other_optimizer : optimizer of the opposing player; only used to
                          clear that player's gradients after backprop
        maximize : bool - if True, the negated crossentropy is backpropagated
                   so that the descent step increases the crossentropy
    RETURNS
        None
    """
    for i, (inputs, labels) in enumerate(trainloader, 0):
        # Zero gradients of the player being updated
        step_optimizer.zero_grad()

        # Forward pass
        outputs = advNet(inputs)
        crossentropy = criterion(outputs, labels)
        loss = -crossentropy if maximize else crossentropy
        # Backpropagate
        loss.backward()
        # Cancel out gradients for the opposing player.
        # NOTE(review): this looks redundant for correctness, because each
        # optimizer already zeroes its own gradients at the start of every
        # batch it steps on; it only stops stale gradients from piling up
        # during this pass. Assumes primary_weights and adversary_weights are
        # disjoint parameter sets -- TODO confirm against models.AdversarialNet.
        other_optimizer.zero_grad()
        # Descend
        step_optimizer.step()

        # Read the scalar once; it is the un-negated crossentropy for both players
        value = crossentropy.item()
        # Print statistics
        # TODO: This currently spams the console, it's fine
        print('[%s, Epoch %d, batch %d] Crossentropy: %.3f' %
              (tag, epoch+1, i+1, value))
        # Save history
        history.append(value)


# Epochs
for epoch in range(n_epochs):
    # Primary optimization - every epoch
    _run_pass('Primary', epoch, optimizer_min, optimizer_max)

    # Adversary optimization - once every adversary_pace epochs
    if epoch % adversary_pace == 0:
        _run_pass('Adversary', epoch, optimizer_max, optimizer_min, maximize=True)



[Primary, Epoch 1, batch 1] Crossentropy: 2.318
[Primary, Epoch 1, batch 2] Crossentropy: 2.322
[Primary, Epoch 1, batch 3] Crossentropy: 2.272
[Primary, Epoch 1, batch 4] Crossentropy: 2.309
[Primary, Epoch 1, batch 5] Crossentropy: 2.299
[Primary, Epoch 1, batch 6] Crossentropy: 2.268
[Primary, Epoch 1, batch 7] Crossentropy: 2.277
[Primary, Epoch 1, batch 8] Crossentropy: 2.303
[Primary, Epoch 1, batch 9] Crossentropy: 2.271
[Primary, Epoch 1, batch 10] Crossentropy: 2.265
[Primary, Epoch 1, batch 11] Crossentropy: 2.268
[Primary, Epoch 1, batch 12] Crossentropy: 2.269
[Primary, Epoch 1, batch 13] Crossentropy: 2.289
[Primary, Epoch 1, batch 14] Crossentropy: 2.271
[Primary, Epoch 1, batch 15] Crossentropy: 2.274
[Primary, Epoch 1, batch 16] Crossentropy: 2.270
[Primary, Epoch 1, batch 17] Crossentropy: 2.270
[Primary, Epoch 1, batch 18] Crossentropy: 2.270
[Primary, Epoch 1, batch 19] Crossentropy: 2.265
[Primary, Epoch 1, batch 20] Crossentropy: 2.265
[Primary, Epoch 1, batch 21] 

[Primary, Epoch 1, batch 167] Crossentropy: 2.071
[Primary, Epoch 1, batch 168] Crossentropy: 2.059
[Primary, Epoch 1, batch 169] Crossentropy: 2.088
[Primary, Epoch 1, batch 170] Crossentropy: 2.039
[Primary, Epoch 1, batch 171] Crossentropy: 2.054
[Primary, Epoch 1, batch 172] Crossentropy: 2.054
[Primary, Epoch 1, batch 173] Crossentropy: 2.031
[Primary, Epoch 1, batch 174] Crossentropy: 2.048
[Primary, Epoch 1, batch 175] Crossentropy: 2.055
[Primary, Epoch 1, batch 176] Crossentropy: 2.052
[Primary, Epoch 1, batch 177] Crossentropy: 2.087
[Primary, Epoch 1, batch 178] Crossentropy: 2.071
[Primary, Epoch 1, batch 179] Crossentropy: 1.987
[Primary, Epoch 1, batch 180] Crossentropy: 2.043
[Primary, Epoch 1, batch 181] Crossentropy: 2.042
[Primary, Epoch 1, batch 182] Crossentropy: 2.024
[Primary, Epoch 1, batch 183] Crossentropy: 2.087
[Primary, Epoch 1, batch 184] Crossentropy: 2.032
[Primary, Epoch 1, batch 185] Crossentropy: 2.039
[Primary, Epoch 1, batch 186] Crossentropy: 2.025


[Adversary, Epoch 1, batch 30] Crossentropy: 1.892
[Adversary, Epoch 1, batch 31] Crossentropy: 1.881
[Adversary, Epoch 1, batch 32] Crossentropy: 1.917
[Adversary, Epoch 1, batch 33] Crossentropy: 1.906
[Adversary, Epoch 1, batch 34] Crossentropy: 1.928
[Adversary, Epoch 1, batch 35] Crossentropy: 1.909
[Adversary, Epoch 1, batch 36] Crossentropy: 1.905
[Adversary, Epoch 1, batch 37] Crossentropy: 1.872
[Adversary, Epoch 1, batch 38] Crossentropy: 1.866
[Adversary, Epoch 1, batch 39] Crossentropy: 1.940
[Adversary, Epoch 1, batch 40] Crossentropy: 1.909
[Adversary, Epoch 1, batch 41] Crossentropy: 1.889
[Adversary, Epoch 1, batch 42] Crossentropy: 1.857
[Adversary, Epoch 1, batch 43] Crossentropy: 1.902
[Adversary, Epoch 1, batch 44] Crossentropy: 1.914
[Adversary, Epoch 1, batch 45] Crossentropy: 1.895
[Adversary, Epoch 1, batch 46] Crossentropy: 1.905
[Adversary, Epoch 1, batch 47] Crossentropy: 1.902
[Adversary, Epoch 1, batch 48] Crossentropy: 1.949
[Adversary, Epoch 1, batch 49] 

In [None]:
# Plot raw global history: one point per processed batch, with primary and
# adversary batches interleaved in the order they were run.
fig, ax = plt.subplots()
ax.plot(history)
ax.set_xlabel('Optimization step (primary and adversary batches, in run order)')
ax.set_ylabel('Crossentropy')
ax.set_title('Raw loss history');