# Training a vanilla 784-200-200-10 MNIST classifier in $\Theta$-space


Sam Greydanus. 3 June 2017. MIT License.

Just your regular old MNIST classifier, trained directly in its full parameter space $\Theta$. Use it as a baseline for comparison against the models trained in $\omega$-space (parameter spaces of vastly reduced dimensionality). **It should reach 98–99% accuracy on the test set.**

In [1]:
% matplotlib inline
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from torchvision import datasets, models, transforms, utils
import numpy as np
import matplotlib.pyplot as plt
import os

def reseed(seed=123):
    """Reseed the NumPy and PyTorch global RNGs (for reproducibility).

    Both generators must be reset: SimpleNN draws its initial weights with
    np.random, while the shuffling DataLoaders draw from torch's generator.

    Args:
        seed: integer seed applied to both generators.

    Returns:
        The torch.Generator returned by torch.manual_seed.
    """
    np.random.seed(seed)
    return torch.manual_seed(seed)


ms = reseed()  # seed once at import time; `ms` kept for compatibility

## Hyperparameters

In [2]:
# book-keeping
fig_dir = 'figures/'
os.makedirs(fig_dir, exist_ok=True)  # no-op when the directory already exists

# model settings
D_side = 28          # side length of an MNIST image
D_img = D_side**2    # flattened input size
D_hidden = 200
batch_size = 256

# train settings
lr = 1e-3
global_step = 0
eval_every = 50      # run a test-set evaluation every this many steps
epochs = 10
n_train = 60000      # number of images in the MNIST training split
total_steps = int(n_train*epochs/batch_size)  # optimizer steps covering `epochs` passes
print("using {} epochs".format(epochs))

using 10 epochs


## Dataloader

In [3]:
modes = ['train', 'test']
# ToTensor scales pixels to [0, 1]; no mean/std normalization is applied.
trans = transforms.Compose([transforms.ToTensor(),])
dsets = {k: datasets.MNIST('./data', train=k=='train', download=True, transform=trans) for k in modes}
loaders = {k: torch.utils.data.DataLoader(dsets[k], batch_size=batch_size, shuffle=True) for k in modes}

def mnist(mode='train'):
    """Return one random batch (X, y) from split `mode`, with X flattened to (batch, D_img).

    Each call creates a new iterator over the shuffling loader and takes its
    first batch, so successive calls sample batches independently instead of
    walking through an epoch.
    """
    X, y = next(iter(loaders[mode]))
    # view keeps the real row count rather than assuming a full batch
    return Variable(X).view(-1, D_img), Variable(y)

## Build the model

In [4]:
class SimpleNN(torch.nn.Module):
    """Two-hidden-layer MLP (input -> h -> h -> output) stored as one flat parameter.

    Every weight matrix and bias is a view into `flat_theta`, a single
    (theta_dim, 1) nn.Parameter, so the optimizer sees exactly one parameter.
    """

    def __init__(self, batch_size, input_dim, h_dim, output_dim):
        super(SimpleNN, self).__init__()
        # Kept for backward compatibility; forward() broadcasts biases and so
        # accepts any number of rows regardless of this value.
        self.batch_size = batch_size
        # param_meta maps each param name to (rows, cols, initial_stdev)
        self.param_meta = {'W1': (input_dim, h_dim, 0.001), 'W2': (h_dim, h_dim, 0.001),
                           'W3': (h_dim, output_dim, 0.01),
                           'b1': (1, h_dim, 0.0), 'b2': (1, h_dim, 0.0), 'b3': (1, output_dim, 0.0)}
        self.names = list(self.param_meta.keys())
        self.counts = [self.param_meta[n][0] * self.param_meta[n][1] for n in self.names]
        self.slices = np.cumsum([0] + self.counts)  # start/stop offsets into flat_theta
        self.theta_dim = int(self.slices[-1])

        # Each block ~ N(0, stdev^2), drawn with NumPy so np.random.seed controls it.
        flat_params = [np.random.randn(self.counts[i], 1) * self.param_meta[n][2]
                       for i, n in enumerate(self.names)]
        theta_init = torch.Tensor(np.concatenate(flat_params, axis=0))
        self.flat_theta = nn.Parameter(theta_init, requires_grad=True)
        print('\tthis model\'s theta space has {} parameters'.format(self.theta_dim))

    @property
    def thetas(self):
        """Dict mapping param name -> (rows, cols) view of `flat_theta`.

        Rebuilt on every access so each forward pass records its own autograd
        graph. The views share storage with flat_theta, so optimizer updates
        are visible immediately.
        """
        views = {}
        for i, n in enumerate(self.names):
            rows, cols, _ = self.param_meta[n]
            start, stop = int(self.slices[i]), int(self.slices[i + 1])
            views[n] = self.flat_theta[start:stop].view(rows, cols)
        return views

    def forward(self, X):
        """Map X of shape (batch, input_dim) to log-probabilities of shape (batch, output_dim)."""
        t = self.thetas
        # (1, h) biases broadcast over any batch size, including a short last batch.
        h1 = F.relu(X.mm(t['W1']) + t['b1'])
        h2 = F.relu(h1.mm(t['W2']) + t['b2'])
        return F.log_softmax(h2.mm(t['W3']) + t['b3'], dim=1)

## Training utilities
For reference: training the $\omega$-space models for 10 epochs with [$\omega$]=100 takes 10–20 min on my MacBook Air.

In [5]:
def accuracy(model, loaders, mode='test'):
    """Evaluate `model` on every example of one split, exactly once.

    Args:
        model: classifier mapping a (batch, pixels) input to per-class
            log-probabilities.
        loaders: dict mapping split name -> DataLoader yielding (images, labels).
        mode: which split of `loaders` to evaluate.

    Returns:
        (loss, correct, total): mean NLL per evaluated example, number of
        correct predictions (int), and number of examples in the split.

    Raises:
        AssertionError: if `mode` is not a key of `loaders`.
    """
    assert mode in loaders, 'incorrect mode supplied'
    loader = loaders[mode]
    was_training = model.training
    saved_batch_size = getattr(model, 'batch_size', None)
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    try:
        with torch.no_grad():
            for X, y in loader:
                n = X.size(0)
                X = X.view(n, -1)  # flatten images to (batch, pixels)
                if saved_batch_size is not None:
                    # A model whose forward tiles biases to self.batch_size must
                    # see the real size so the smaller final batch also works.
                    model.batch_size = n
                y_hat = model(X)
                # nll_loss averages over the batch; re-weight by batch size so a
                # short final batch does not skew the split-wide mean.
                loss_sum += F.nll_loss(y_hat, y).item() * n
                pred = y_hat.max(1)[1]
                correct += int(pred.eq(y).sum().item())
                seen += n
    finally:
        if saved_batch_size is not None:
            model.batch_size = saved_batch_size
        model.train(was_training)

    loss = loss_sum / max(seen, 1)
    total = len(loader.dataset)
    return loss, correct, total

In [6]:
def train(model, optimizer, global_step=0, verbose=False):
    """Train `model` on randomly sampled MNIST batches, tracking test accuracy.

    Uses the notebook globals `total_steps`, `eval_every`, `loaders` and the
    `mnist()` batch sampler.

    Args:
        model: classifier returning per-class log-probabilities.
        optimizer: torch optimizer over `model`'s parameters.
        global_step: index of the first step (lets a resumed run keep counting).
        verbose: unused; kept for call compatibility.

    Returns:
        (loss_hist, acc_hist): lists of (step, running_loss) and
        (step, test_accuracy_percent) tuples.
    """
    def _acc_msg(c, t):
        # progress-line summary of a test-set evaluation
        return 'accuracy: {:.4f}% ({}/{})'.format(100*c/t, c, t)

    running_loss = None
    acc_msg = '...' ; print('\ttraining...')
    loss_hist = []
    acc_hist = []
    step = global_step  # defined even when total_steps == 0

    # generic train loop: exactly total_steps optimizer updates
    for step in range(global_step, global_step + total_steps):

        # ======== OPTIMIZATION STEP ======== #
        # forward on a freshly sampled training batch
        X, y = mnist(mode='train')
        y_hat = model(X)

        # backward + parameter update
        loss = F.nll_loss(y_hat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # exponential moving average (decay 0.99) of the batch loss
        batch_loss = loss.item()
        running_loss = batch_loss if running_loss is None else .99*running_loss + (1-.99)*batch_loss
        loss_hist.append((step, running_loss))

        # ======== DISPLAY PROGRESS ======== #
        print('\tstep {} / {} / loss: {:.4f}'.format(step, acc_msg, running_loss), end="\r")
        if step % eval_every == 0:
            _, c, t = accuracy(model, loaders, mode='test')
            acc_msg = _acc_msg(c, t)
            acc_hist.append((step, 100*c/t))

    # final evaluation after the last update
    _, c, t = accuracy(model, loaders, mode='test')
    acc_hist.append((step, 100*c/t))
    acc_msg = _acc_msg(c, t)
    shown_loss = float('nan') if running_loss is None else running_loss
    print('\tstep {} / {} / loss: {:.4f}'.format(step, acc_msg, shown_loss))
    return loss_hist, acc_hist

## Train the baseline model in $\Theta$-space

In [7]:
# Reseed right before construction so the NumPy-drawn initial theta is repeatable.
reseed()
model = SimpleNN(batch_size=batch_size, input_dim=D_img, h_dim=D_hidden, output_dim=10)
adam_betas = (0.9, 0.99)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=adam_betas, weight_decay=0)
loss_hist, acc_hist = train(model, optimizer, global_step=global_step, verbose=False)

	this model's theta space has 199210 parameters
	training...
	step 250 / accuracy: 97.7600% (9776/10000) / loss: 0.3509

KeyboardInterrupt: 

## Plot loss

In [None]:
f3 = plt.figure(figsize=[8,5])

# loss_hist holds (step, running_loss) pairs
xy = np.stack(loss_hist)
# raw strings: '\o' and '\T' are invalid escapes in normal string literals
plt.plot(xy[:,0], xy[:,1], linewidth=3.0, label=r'[$\omega$]={}'.format(model.theta_dim))

title = r"2-layer NN loss on MNIST ([$\Theta$]={})".format(model.theta_dim)
plt.title(title, fontsize=16)
plt.xlabel('train step', fontsize=14) ; plt.setp(plt.gca().axes.get_xticklabels(), fontsize=14)
plt.ylabel('loss', fontsize=14) ; plt.setp(plt.gca().axes.get_yticklabels(), fontsize=14)
plt.ylim([0,2.5])
plt.legend()

# save before show(): some backends close or clear the figure once shown
f3.savefig(os.path.join(fig_dir, 'vanilla-loss.png'), bbox_inches='tight')
plt.show()

## Plot accuracy

In [None]:
f3 = plt.figure(figsize=[8,5])

# acc_hist holds (step, test_accuracy_percent) pairs
xy = np.stack(acc_hist)
# raw strings: '\o' and '\T' are invalid escapes in normal string literals
plt.plot(xy[:,0], xy[:,1], linewidth=3.0, label=r'[$\omega$]={}'.format(model.theta_dim))

title = r"2-layer NN accuracy on MNIST ([$\Theta$]={})".format(model.theta_dim)
plt.title(title, fontsize=16)
plt.xlabel('train step', fontsize=14) ; plt.setp(plt.gca().axes.get_xticklabels(), fontsize=14)
plt.ylabel('accuracy (%)', fontsize=14) ; plt.setp(plt.gca().axes.get_yticklabels(), fontsize=14)

# the annotation promises the *max* accuracy, so take the max, not the last entry
best_acc = max(acc for _, acc in acc_hist)
results_msg = ('epochs: {}\nlearning rate : {}\nbatch size: {}\nmax test accuracy: {:.2f}%'
               .format(epochs, lr, batch_size, best_acc))
f3.text(0.92, .50, results_msg, ha='left', va='center', fontsize=12)
plt.ylim([0,100])
plt.legend()

# save before show(): some backends close or clear the figure once shown
f3.savefig(os.path.join(fig_dir, 'vanilla-acc.png'), bbox_inches='tight')
plt.show()