# Implementing Yann LeCun's LeNet-5 in PyTorch

## Setup

In [None]:
import numpy as np
from datetime import datetime 
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision
import matplotlib.pyplot as plt
import ipdb
import time
from lenet import LeNet5
from utils import *
# Use the GPU when CUDA is actually available; otherwise fall back to the CPU
# so the notebook still runs on machines without a usable GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# --- optimisation / reproducibility ---
RANDOM_SEED = 42       # handed to torch.manual_seed before the model is built
LEARNING_RATE = 1e-3   # learning rate given to the optimizer

# --- data loading ---
BATCH_SIZE = 128       # batch size of the train and validation loaders
num_workers = 10       # worker count passed to every DataLoader

# --- problem dimensions ---
N_CLASSES = 10         # argument passed to the LeNet5 constructor
IMG_SIZE = 32          # image side length (the transform resizes inputs to 32x32)

## Data

In [None]:
# define transforms
# Resize uses IMG_SIZE so the input resolution is controlled from one place.
# transforms.ToTensor() automatically scales the images to [0,1] range;
# Normalize then shifts/scales with mean 0.1307 and std 0.3081
# (presumably the usual MNIST statistics -- confirm if the dataset changes).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# download and create datasets
# Only the training split requests a download; the held-out split reads from
# the same root directory.
train_dataset = datasets.MNIST(root='mnist_data', train=True, transform=transform, download=True)
valid_dataset = datasets.MNIST(root='mnist_data', train=False, transform=transform)

# define the data loaders
# Settings shared by all three loaders.
_loader_kwargs = dict(num_workers=num_workers, pin_memory=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, **_loader_kwargs)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False, **_loader_kwargs)
# Draws single, shuffled samples from the held-out split (same data as valid_loader).
test_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=True, **_loader_kwargs)

In [None]:
def training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every=1):
    '''
    Train and validate `model` for `epochs` epochs, logging to TensorBoard.

    Each epoch runs `train` on `train_loader`, writes per-layer histograms of
    weights, biases and activations, then runs `validate` on `valid_loader`
    with gradients disabled. Every `print_every` epochs the train/validation
    accuracies are computed, logged and printed.

    `validate`, `get_accuracy`, `get_weights_biases` and `plot_losses` are not
    defined in this file (presumably they come from `utils`).

    Args:
        model: network to train; forwarded to `train` / `validate`.
        criterion: loss function forwarded to `train` / `validate`.
        optimizer: optimizer forwarded to `train`.
        train_loader: loader used for training and for the train accuracy.
        valid_loader: loader used for validation loss and accuracy.
        epochs: number of epochs to run.
        device: device forwarded to the helpers.
        print_every: accuracy is computed and a summary printed on epochs
            where `epoch % print_every == print_every - 1`.

    Returns:
        Tuple `(model, optimizer, (train_losses, valid_losses))` where the two
        lists hold one loss value per epoch.
    '''
    train_losses = []
    valid_losses = []

    # The context manager closes the writer even when an epoch raises, so
    # already-logged events are not left unflushed.
    with SummaryWriter() as writer:
        for epoch in range(epochs):
            # training
            model, optimizer, train_loss, out1, out2, out3, out4, out5 = train(
                train_loader, model, criterion, optimizer, device)
            weights, biases = get_weights_biases(model)
            activations = (out1, out2, out3, out4, out5)
            for idx, layer_weights in enumerate(weights):
                tag = 'Layer' + str(idx + 1)
                writer.add_histogram(tag + '/weights', layer_weights, epoch)
                writer.add_histogram(tag + '/biases', biases[idx], epoch)
                writer.add_histogram(tag + '/activations', activations[idx], epoch)

            train_losses.append(train_loss)
            writer.add_scalar('Loss/train', train_loss, epoch)

            # validation
            with torch.no_grad():
                model, valid_loss = validate(valid_loader, model, criterion, device)
            valid_losses.append(valid_loss)
            # Log the validation loss here (previously the training loss was
            # written under this tag by mistake).
            writer.add_scalar('Loss/test', valid_loss, epoch)

            if epoch % print_every == (print_every - 1):
                train_acc = get_accuracy(model, train_loader, device=device)
                valid_acc = get_accuracy(model, valid_loader, device=device)
                writer.add_scalar('Accuracy/train', train_acc, epoch)
                writer.add_scalar('Accuracy/test', valid_acc, epoch)
                print(f'{datetime.now().time().replace(microsecond=0)} --- '
                      f'Epoch: {epoch}\t'
                      f'Train loss: {train_loss:.4f}\t'
                      f'Valid loss: {valid_loss:.4f}\t'
                      f'Train accuracy: {100 * train_acc:.2f}\t'
                      f'Valid accuracy: {100 * valid_acc:.2f}')

    plot_losses(train_losses, valid_losses)
    return model, optimizer, (train_losses, valid_losses)

In [None]:
def train(train_loader, model, criterion, optimizer, device):
    '''
    Function for the training step of the training loop

    Runs one pass over `train_loader`, optimising `criterion` plus a penalty
    on the norms of the model's intermediate outputs, and prints the largest
    activation seen per layer during the pass.

    Args:
        train_loader: iterable of `(X, y_true)` batches exposing `.dataset`
            (its length is used to average the loss).
        model: module whose forward pass returns
            `(y_hat, probs, out1, out2, out3, out4)`.
        criterion: loss applied to `(y_hat, y_true)`.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        `(model, optimizer, epoch_loss, out1, out2, out3, out4, y_hat)`.
        `epoch_loss` is the sample-weighted mean of the *penalised* loss
        (criterion plus the norm terms). The five tensors come from the last
        batch only and are detached from the autograd graph.
    '''
    # Penalty settings do not change between batches, so define them once.
    norm = 3
    beta = 0.001

    model.train()
    running_loss = 0
    # Running per-layer maxima, stored as plain Python numbers so that no
    # batch's autograd graph is kept alive between iterations.
    peaks = [0, 0, 0, 0, 0]
    for X, y_true in train_loader:
        optimizer.zero_grad()
        X = X.to(device)
        y_true = y_true.to(device)
        # Forward pass
        y_hat, probs, out1, out2, out3, out4 = model(X)
        # For the final output the absolute value is tracked; for the
        # intermediate outputs the raw maximum.
        batch_peaks = (out1.max(), out2.max(), out3.max(), out4.max(), y_hat.abs().max())
        peaks = [max(seen, current.item()) for seen, current in zip(peaks, batch_peaks)]
        # Task loss plus weighted norm penalties on every returned output.
        loss = (criterion(y_hat, y_true)
                + 10 * beta * torch.norm(y_hat, 2)
                + 2 * beta * torch.norm(out4, norm)
                + 1.3 * beta * torch.norm(out3, norm)
                + 0.7 * beta * torch.norm(out2, norm)
                + 0.6 * beta * torch.norm(out1, norm))
        # Backward pass
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * X.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    print("Max activations: {0:.2f}, {1:.2f}, {2:.2f}, {3:.2f}, {4:.2f}".format(*peaks))
    # Detach so callers (e.g. histogram logging) do not hold the last batch's graph.
    return (model, optimizer, epoch_loss,
            out1.detach(), out2.detach(), out3.detach(), out4.detach(), y_hat.detach())

In [None]:
# Seed torch's RNG before the model is constructed.
torch.manual_seed(RANDOM_SEED)

# Build the network and move it to the target device.
# (Multi-GPU option left disabled: model = nn.DataParallel(model))
model = LeNet5(N_CLASSES)
model = model.to(DEVICE)

# Loss and optimizer handed to the training loop.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Train for 17 epochs; keep the returned model and optimizer, discard the loss history.
model, optimizer, _ = training_loop(
    model, criterion, optimizer, train_loader, valid_loader,
    epochs=17, device=DEVICE,
)

In [None]:
# Run the biggest_weight helper on the trained model. The helper is not
# defined in this file (presumably provided by utils) -- confirm what it reports.
biggest_weight(model)

In [None]:
# Put the model in evaluation mode and push one sample from test_loader through it.
model.eval()
X, y_true = next(iter(test_loader))
X, Y = X.to(DEVICE), y_true.to(DEVICE)
# The first element of the model's return value is the one plotted.
Y_hat = model(X)[0]
plt.plot(Y_hat.detach().cpu().numpy().ravel())

In [None]:
import numpy as np
from matplotlib import pyplot as plt

# Flatten every parameter tensor of the model and join them into one 1-D tensor.
weights = torch.cat([param.flatten() for param in model.parameters()])

# 200-bin histogram over all parameter values; the return value is stored in `capture`.
capture = plt.hist(weights.detach().cpu().numpy(), bins=200)

In [None]:
# Print each state_dict key next to the size of its tensor.
for param_tensor, stored in model.state_dict().items():
    print(param_tensor, "\t", stored.size())

In [None]:
# Save only the model's state_dict to disk rather than the whole module object.
# After loading it back, don't forget to call model.eval() before inference.
torch.save(model.state_dict(), "./quartz-lenet.pth")

**Plotting the images**

In [None]:
# Layout of the preview grid: N_ROWS rows with ROW_IMG images in each row.
N_ROWS = 5
ROW_IMG = 10

In [None]:
# Show the first ROW_IMG * N_ROWS raw training images in an N_ROWS x ROW_IMG grid.
fig = plt.figure()
for index in range(ROW_IMG * N_ROWS):
    # Subplot slots are 1-based while dataset indices are 0-based; offsetting
    # only the slot keeps the very first training image in the preview.
    plt.subplot(N_ROWS, ROW_IMG, index + 1)
    plt.axis('off')
    plt.imshow(train_dataset.data[index])
fig.suptitle('MNIST Dataset - preview');

In [None]:
# Same preview of the first ROW_IMG * N_ROWS raw training images, drawn with
# the reversed grayscale colormap.
fig = plt.figure()
for index in range(ROW_IMG * N_ROWS):
    # Subplot slots are 1-based while dataset indices are 0-based; offsetting
    # only the slot keeps the very first training image in the preview.
    plt.subplot(N_ROWS, ROW_IMG, index + 1)
    plt.axis('off')
    plt.imshow(train_dataset.data[index], cmap='gray_r')
fig.suptitle('MNIST Dataset - preview');