# How Do the Weights of a Simple Neural Network Change During Training?

I'll use Bokeh for the live visualizations and PyTorch to implement the network. To get interpretable results, I'll use a network without a hidden layer.

In [1]:
import gzip
import os
import pickle
import subprocess
import time
from itertools import islice
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure, ColumnDataSource
from bokeh.transform import linear_cmap
from bokeh.palettes import Plasma11, Viridis11
from bokeh.util.hex import hexbin
from bokeh.layouts import gridplot
output_notebook()

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms

In [2]:
# Locate (and, if necessary, download) the pickled MNIST archive, then load it
# into the notebook-level variable `mnist`.
dir_path = Path().absolute()
dataset_path = dir_path.parent / "data/mnist.pkl.gz"
if not dataset_path.exists():
    print('Downloading dataset with curl ...')
    # parents/exist_ok: create any missing intermediate directories and do not
    # fail if the directory appeared in the meantime.
    dataset_path.parent.mkdir(parents=True, exist_ok=True)
    url = 'http://ericjmichaud.com/downloads/mnist.pkl.gz'
    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are handled correctly. --fail makes curl exit non-zero on
    # an HTTP error instead of saving the server's error page as the dataset.
    result = subprocess.run(
        ['curl', '--fail', '-L', url, '-o', str(dataset_path)],
        check=False,
    )
    if result.returncode != 0 and dataset_path.exists():
        # Remove a partial download so the next run retries instead of trying
        # to unpickle a truncated file.
        dataset_path.unlink()
if not dataset_path.exists():
    # Stop here with a clear message rather than failing later in gzip.open.
    raise FileNotFoundError('Download failed: {} does not exist'.format(dataset_path))
print('Dataset acquired')
# NOTE(security): pickle.load can execute arbitrary code contained in the file,
# and this file is fetched over plain HTTP. Only run this against a source you trust.
with gzip.open(dataset_path, 'rb') as f:
    mnist = pickle.load(f)
print('Loaded data to variable `mnist`')

Dataset acquired
Loaded data to variable `mnist`


In [3]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Single precision is used both as the explicit `dtype` for conversions in this
# notebook and as torch's default floating-point dtype.
dtype = torch.float32
torch.set_default_dtype(dtype)

In [4]:
class MNISTDataset(Dataset):
    """MNIST Digits Dataset.

    Thin `Dataset` wrapper around an indexable collection of samples.

    Args:
        data: Indexable, sized collection of samples (supports `len()` and
            integer indexing).
        transform: Optional callable applied to each sample on access. When
            `None`, samples are returned unchanged.
    """
    def __init__(self, data, transform=None):
        self.mnist = data
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped collection."""
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return the sample at `idx`, passed through `transform` if one was given."""
        sample = self.mnist[idx]
        # Compare against None explicitly: a callable whose truth value is
        # False (e.g. one defining __len__ or __bool__) must still be applied.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

class ToTensor(object):
    """Convert samples (np.ndarray, np.ndarray) to (torch.tensor, torch.tensor).

    Args:
        dtype: Optional torch dtype for the image tensor. When `None` (the
            default), torch's current default floating-point dtype is used, so
            the transform does not depend on a notebook-level global.
    """
    def __init__(self, dtype=None):
        self.dtype = dtype

    def __call__(self, sample):
        """Return `(image, label)` tensors for an `(image, one_hot_label)` sample.

        The label is the index of the largest entry of `one_hot_label`,
        returned as a 0-dimensional `torch.long` tensor.
        """
        image, one_hot_label = sample
        # Resolve the dtype at call time so later changes to torch's default
        # dtype are picked up.
        image_dtype = self.dtype if self.dtype is not None else torch.get_default_dtype()
        image = torch.from_numpy(image).to(image_dtype)
        label = torch.tensor(int(np.argmax(one_hot_label)), dtype=torch.long)
        return (image, label)


In [5]:
class SoftmaxRegression(nn.Module):
    """Single-layer softmax network.

    One bias-free linear map from `n_in` inputs to `n_out` outputs, followed by
    a log-softmax taken along dimension 1.
    """
    def __init__(self, n_in, n_out):
        super().__init__()
        self.linear = nn.Linear(in_features=n_in, out_features=n_out, bias=False)

    def forward(self, x):
        """Return log-probabilities for `x`; normalisation is over dim 1."""
        logits = self.linear(x)
        return F.log_softmax(logits, dim=1)


In [6]:
# Split `mnist` by position: the first 60,000 samples are used for training,
# everything after that for testing. Each sample is converted to tensors on access.
training = MNISTDataset(data=mnist[:60000], transform=ToTensor())
test = MNISTDataset(data=mnist[60000:], transform=ToTensor())

# Both loaders yield shuffled mini-batches of 20 samples.
training_loader = torch.utils.data.DataLoader(
    dataset=training,
    batch_size=20,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test,
    batch_size=20,
    shuffle=True,
)

In [7]:
# 28*28 inputs (one per pixel of a flattened image) mapped to 10 outputs.
model = SoftmaxRegression(n_in=28 * 28, n_out=10)
model = model.to(device)

# The model emits log-probabilities, so negative log-likelihood is the matching loss.
loss_fn = nn.NLLLoss()

# Adam with its default hyperparameters.
optimizer = torch.optim.Adam(params=model.parameters())

In [8]:
# --- DEFINE OUTPUT ---

weight_image = figure(plot_width=5*150, plot_height=2*150, title='Output Weights')
weight_image.xgrid.visible = False
weight_image.ygrid.visible = False

# Create one image glyph per output unit up front, each backed by its own
# ColumnDataSource. Live updates then only swap the data in these sources;
# calling `weight_image.image(...)` on every update would add new renderers
# each time and make the figure grow for the whole training run.
weight_sources = []
for i in range(model.linear.out_features):
    # Tiles are laid out five per row on a 29-unit pitch (28-unit tile plus a
    # 1-unit gap); later rows are placed below the first.
    x = i % 5
    y = i // 5
    source = ColumnDataSource(data={'image': [np.zeros((28, 28))]})
    weight_image.image(image='image', x=x*29, y=-y*29, dw=28, dh=28,
                       palette=Viridis11, source=source)
    weight_sources.append(source)

display = show(weight_image, notebook_handle=True)

def update_output():
    """Redraw the weight tiles from the model's current weights.

    Reads `model.linear.weight`, reshapes each row to 28x28, writes it into
    the matching tile's data source and pushes the change to the notebook
    output identified by `display`.
    """
    weight = model.linear.weight.detach().cpu().numpy()
    for source, unit_weights in zip(weight_sources, weight):
        source.data = {'image': [unit_weights.reshape(28, 28)]}
    push_notebook(handle=display)


# --- TRAIN ---

num_batches = 0
# NOTE(review): `ac` and `ei` are not used anywhere in the visible cells; they
# are kept in case a later cell reads them — confirm and delete if not.
ac = 0
ei = None

num_epochs = 1      # passes over the training set
update_every = 5    # refresh the live plot every this many batches
max_batches = 100   # total number of batches to train on before stopping

for epoch in range(num_epochs):
    # `break` below only leaves the inner loop, so re-check the cap here to
    # make it hold across epochs as well.
    if num_batches >= max_batches:
        break
    for sample, target in training_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(sample.to(device)), target.to(device))
        loss.backward()
        optimizer.step()
        num_batches += 1
        if num_batches % update_every == 0:
            update_output()
        if num_batches >= max_batches:
            break