# MNIST training

In [None]:
import numpy as np
from datetime import datetime 
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision
import matplotlib.pyplot as plt
import ipdb
import time
from mnist_model import ConvNet
from utils import *

In [None]:
# Pick the compute device: use the GPU when PyTorch can see one, otherwise fall
# back to the CPU (a hard-coded 'cuda' crashes on machines without CUDA).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# parameters
RANDOM_SEED = 42            # seed for torch.manual_seed before model creation
learning_rate = 0.001       # Adam learning rate
batch_size_train = 128
batch_size_test = 10000     # large enough to hold the MNIST test split (10,000 images) in one batch
num_workers = 10            # DataLoader worker processes
n_classes = 10              # MNIST digits 0-9
activation_cutoff = 99.9    # forwarded to utils.validate; looks like a percentile — TODO confirm meaning there

## data

In [None]:
# Images are only converted to float tensors (ToTensor); no normalisation step is applied.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train / test splits; only the training split triggers a download if missing.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
valid_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# Training batches are shuffled and use pinned memory for faster host-to-GPU copies;
# validation batches keep a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size_test,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
def training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device,
                  print_every=1, target_valid_acc=99.2):
    """Train ``model`` for up to ``epochs`` epochs, validating after every epoch.

    Each epoch runs ``train`` once over ``train_loader``, then ``validate`` on
    ``valid_loader`` under ``torch.no_grad()``. Every ``print_every`` epochs it
    computes train/valid accuracy with ``get_accuracy`` and prints a progress
    line. On those reporting epochs it stops early once the validation
    accuracy exceeds ``target_valid_acc`` percent. After the loop the loss
    curves are plotted with ``plot_losses``.

    ``validate``, ``get_accuracy`` and ``plot_losses`` presumably come from
    ``utils`` (star import). ``validate`` also receives the module-level
    ``activation_cutoff``.

    Args:
        model: network to train; passed through ``train``/``validate``.
        criterion: loss function forwarded to ``train`` and ``validate``.
        optimizer: optimizer forwarded to ``train``.
        train_loader: DataLoader over the training data.
        valid_loader: DataLoader over the validation data.
        epochs: maximum number of epochs.
        device: device forwarded to the helpers.
        print_every: report (and check early stopping) every this many epochs.
        target_valid_acc: validation-accuracy threshold in percent for early
            stopping; ``None`` disables early stopping.

    Returns:
        ``(model, optimizer, (train_losses, valid_losses))`` where the loss
        lists hold one value per completed epoch.
    """
    train_losses = []
    valid_losses = []

    for epoch in range(epochs):
        model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device)
        train_losses.append(train_loss)

        with torch.no_grad():
            model, valid_loss = validate(valid_loader, model, criterion, device, activation_cutoff)
        valid_losses.append(valid_loss)

        # Accuracy is only computed on reporting epochs. The early-stopping check
        # must stay inside this branch: before, it read `valid_acc` on every
        # epoch and crashed (UnboundLocalError) when print_every > 1.
        if epoch % print_every != print_every - 1:
            continue

        train_acc = get_accuracy(model, train_loader, device=device)
        valid_acc = get_accuracy(model, valid_loader, device=device)
        print(f'{datetime.now().time().replace(microsecond=0)} --- '
              f'Epoch: {epoch}\t'
              f'Train loss: {train_loss:.4f}\t'
              f'Valid loss: {valid_loss:.4f}\t'
              f'Train accuracy: {100 * train_acc:.2f}\t'
              f'Valid accuracy: {100 * valid_acc:.2f}')

        if target_valid_acc is not None and 100 * valid_acc > target_valid_acc:
            break

    plot_losses(train_losses, valid_losses)
    return model, optimizer, (train_losses, valid_losses)

In [None]:
def train(train_loader, model, criterion, optimizer, device, beta=0.0003, norm_p=2):
    """Run one training epoch and return the mean per-sample loss.

    The model is expected to return ``(y_hat, probs, out1, out2, out3)``.
    The optimised objective is ``criterion(y_hat, y_true)`` plus activation
    penalties. These are the ``norm_p``-norms (over all elements) of ``out1``,
    ``out2``, ``out3`` and ``y_hat``, weighted by ``0.1``, ``0.5``, ``2`` and
    ``20`` times ``beta`` respectively.

    Args:
        train_loader: DataLoader yielding ``(X, y_true)`` batches.
        model: network to train; switched to ``train()`` mode.
        criterion: classification loss applied to ``(y_hat, y_true)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        beta: base weight of the activation penalties (default ``0.0003``).
        norm_p: order of the vector norm used for the penalties (default ``2``).

    Returns:
        ``(model, optimizer, epoch_loss)``. ``epoch_loss`` averages the full
        objective (classification loss *including* the penalties) over all
        samples in ``train_loader.dataset``.
    """
    def _act_norm(t):
        # Vector norm over all elements of t (equivalent to torch.norm(t, p)).
        return torch.linalg.vector_norm(t, ord=norm_p)

    model.train()
    running_loss = 0.0
    for X, y_true in train_loader:
        optimizer.zero_grad()
        X = X.to(device)
        y_true = y_true.to(device)

        # Forward pass; probabilities are not needed for the loss.
        y_hat, _probs, out1, out2, out3 = model(X)
        loss = (
            criterion(y_hat, y_true)
            + 0.1 * beta * _act_norm(out1)
            + 0.5 * beta * _act_norm(out2)
            + 2 * beta * _act_norm(out3)
            + 20 * beta * _act_norm(y_hat)
        )

        # Backward pass
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch value is a per-sample mean.
        running_loss += loss.item() * X.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    return model, optimizer, epoch_loss

## training

In [None]:
# Cross-entropy on the model's logits; creating it consumes no randomness.
criterion = nn.CrossEntropyLoss()

# Seed the global RNG right before building the network so weight init is reproducible.
torch.manual_seed(RANDOM_SEED)
model = ConvNet(n_classes).to(device)

# Adam with a small L2 weight decay on all parameters.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

In [None]:
# Train for at most 40 epochs (training_loop can stop early); the returned loss history is discarded.
model, optimizer, _ = training_loop(
    model, criterion, optimizer, train_loader, valid_loader, epochs=40, device=device,
)

In [None]:
# Persist only the learned parameters (state_dict). After loading them into a
# fresh model, don't forget to call model.eval() before inference.
weights = model.state_dict()
torch.save(weights, "./mnist-convnet.pth")

In [None]:
# Print the weight shape of every convolutional and fully connected layer, in module order.
for layer in model.modules():
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        print(layer.weight.shape)