# Lecture 29: Training a LeNet for MNIST Classification

In [None]:
%matplotlib inline
import copy
import time
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms,datasets

# Print the installed PyTorch version (this notebook was last updated for PyTorch 1.0.0).
print(torch.__version__)

## Load the MNIST data

In [None]:
# Preprocessing: Resize(32) rescales each image so its shorter side is 32 px
# (32x32 for square digits). LeNet's conv/pool stack needs a 32x32 input to
# produce the 16 x 5 x 5 = 400 features fc1 expects. ToTensor then converts
# each image to a float tensor.
apply_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
])
BatchSize = 100

# Both splits are downloaded into ./MNIST if they are not already there.
trainset = datasets.MNIST(root='./MNIST', train=True, download=True,
                          transform=apply_transform)
testset = datasets.MNIST(root='./MNIST', train=False, download=True,
                         transform=apply_transform)

# Only the training data is shuffled; evaluation order does not matter.
trainLoader = torch.utils.data.DataLoader(trainset, batch_size=BatchSize,
                                          shuffle=True, num_workers=4)
testLoader = torch.utils.data.DataLoader(testset, batch_size=BatchSize,
                                         shuffle=False, num_workers=4)

In [None]:
# Report how many samples each split contains (read from the loaders' datasets).
print('No. of samples in train set: {}'.format(len(trainLoader.dataset)))
print('No. of samples in test set: {}'.format(len(testLoader.dataset)))

## Define the network architecture

In [None]:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv -> ReLU -> max-pool stages followed by three
    fully connected layers.

    For a 32x32 input the conv/pool stack yields 16 x 5 x 5 = 400 features
    per sample, which is the input size of ``fc1``.

    Args:
        in_channels: number of channels in the input images (1 by default).
        num_classes: number of output classes (10 by default).

    ``forward`` returns per-class log-probabilities of shape
    (batch, num_classes), meant to be paired with ``nn.NLLLoss``.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension. Reshaping to
        # (-1, 400) would instead fold surplus spatial positions of a
        # non-32x32 input into the batch dimension and silently produce the
        # wrong number of rows. With an explicit batch size the mismatch
        # raises in fc1.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.log_softmax(self.fc3(x), dim=1)


In [None]:
# Build the model and print its module summary.
net = LeNet()
print(net)

In [None]:
# Snapshot the conv1/conv2 kernels before training, for later comparison.
# detach().clone() makes an independent copy outside the autograd graph, so
# optimizer updates to the live parameters do not change it.
init_weightConv1 = net.conv1.weight.detach().clone()
init_weightConv2 = net.conv2.weight.detach().clone()

In [None]:
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
use_gpu = torch.cuda.is_available()
print('GPU is available!' if use_gpu else 'GPU is not available!')
device = "cuda" if use_gpu else "cpu"

# Move the model's parameters onto the chosen device.
net = net.to(device)

## Define the loss function and optimizer

In [None]:
# LeNet.forward ends in log_softmax, so the negative log-likelihood of its
# output is the usual cross-entropy objective.
criterion = nn.NLLLoss()
# Adam over every learnable parameter of the network, learning rate 1e-4.
optimizer = optim.Adam(params=net.parameters(), lr=1e-4)

## Train the network

In [None]:
def _train_one_epoch(model, loader, loss_fn, opt, dev):
    """Run one training pass of `model` over `loader`.

    Returns the mean of the per-batch losses (with nn.NLLLoss's default
    reduction each batch loss is already a batch mean), or NaN when `loader`
    yields no batches.
    """
    model.train()  # training mode (affects dropout/batch-norm layers, if any)
    runningLoss = 0.0
    numBatches = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(dev), labels.to(dev)
        opt.zero_grad()                        # clear gradients of the previous step
        loss = loss_fn(model(inputs), labels)  # forward pass + loss
        loss.backward()                        # backpropagate
        opt.step()                             # update the parameters
        runningLoss += loss.item()
        numBatches += 1
    # Count batches explicitly instead of reusing the loop index, so an empty
    # loader does not hit an undefined name or divide by zero.
    return runningLoss / numBatches if numBatches else float('nan')


def _evaluate_accuracy(model, loader, dev):
    """Return the top-1 accuracy of `model` on `loader`, in percent.

    The denominator is the number of samples actually evaluated, so the
    result is right for any test-set size. Returns NaN for an empty loader.
    """
    model.eval()  # evaluation mode (affects dropout/batch-norm layers, if any)
    correct = 0
    total = 0
    with torch.no_grad():  # no autograd bookkeeping needed for evaluation
        for inputs, labels in loader:
            inputs, labels = inputs.to(dev), labels.to(dev)
            predicted = model(inputs).argmax(dim=1)  # most likely class per sample
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total if total else float('nan')


iterations = 10  # number of training epochs
trainLoss = []   # mean training loss per epoch
testAcc = []     # test accuracy (%) per epoch
start = time.time()
for epoch in range(iterations):
    epochStart = time.time()
    avgTrainLoss = _train_one_epoch(net, trainLoader, criterion, optimizer, device)
    trainLoss.append(avgTrainLoss)
    # Evaluate on the test set after every epoch.
    avgTestAcc = _evaluate_accuracy(net, testLoader, device)
    testAcc.append(avgTestAcc)
    epochEnd = time.time() - epochStart
    print('Iteration: {:.0f} /{:.0f}  ;  Training Loss: {:.6f} ; Testing Acc: {:.3f} ; Time consumed: {:.0f}m {:.0f}s '
          .format(epoch + 1, iterations, avgTrainLoss, avgTestAcc, epochEnd//60, epochEnd%60))
end = time.time() - start
print('Training completed in {:.0f}m {:.0f}s'.format(end//60, end%60))

# Plot the learning curves once, after training. With `%matplotlib inline`
# the figures are only rendered when the cell finishes. Redrawing the full
# curve every epoch just stacked identical lines onto the same figures.
epochAxis = range(1, len(trainLoss) + 1)  # 1-based, matching the printed epochs

# Training loss vs. epochs
plt.figure()
plt.plot(epochAxis, trainLoss, 'r-', label='train')
plt.legend(loc='upper left')
plt.xlabel('Epochs')
plt.ylabel('Training loss')

# Testing accuracy vs. epochs
plt.figure()
plt.plot(epochAxis, testAcc, 'g-', label='test')
plt.legend(loc='upper left')
plt.xlabel('Epochs')
plt.ylabel('Testing accuracy')


In [None]:
# Snapshot the trained conv1/conv2 kernels for comparison with the initial
# ones. The copies are detached from autograd and stay on the same device as
# the network.
trained_weightConv1 = net.conv1.weight.detach().clone()
trained_weightConv2 = net.conv2.weight.detach().clone()

## Visualize the convolution weights

In [None]:
# Helper to display an image tensor
def imshow(img, strlabel, figsize=(10, 10)):
    """Show a channels-first image tensor (e.g. a ``make_grid`` output) in a new figure.

    Args:
        img: tensor laid out as (C, H, W). It may live on any device and may
            require grad, because it is detached and copied to the CPU first.
        strlabel: title of the figure.
        figsize: figure size in inches (width, height), 10 x 10 by default.

    The size is passed to ``plt.figure`` rather than written into the global
    ``plt.rcParams``, so later, unrelated figures keep their own default size.
    """
    # Show magnitudes only. This should be a no-op for grids that
    # make_grid(normalize=True) has already rescaled to [0, 1].
    npimg = np.abs(img.detach().cpu().numpy())
    plt.figure(figsize=figsize)
    plt.title(strlabel)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # (C, H, W) -> (H, W, C)

In [None]:
# The plotting helpers need CPU tensors. Tensor.cpu() returns the tensor
# itself when it is already on the CPU, so this is a no-op without a GPU.
trained_weightConv1 = trained_weightConv1.cpu()
trained_weightConv2 = trained_weightConv2.cpu()

# conv1: every kernel (weight shape (6, 1, 5, 5)).
conv1Panels = (
    (init_weightConv1, 'Initial Weights: conv1'),
    (trained_weightConv1, 'Trained Weights: conv1'),
    (init_weightConv1 - trained_weightConv1, 'Difference of weights: conv1'),
)
# conv2: the 6 input-channel slices of the first output filter. [0] gives
# shape (6, 5, 5); unsqueeze(1) makes each slice a 1-channel tile.
conv2Panels = (
    (init_weightConv2[0].unsqueeze(1), 'Initial Weights: conv2(1)'),
    (trained_weightConv2[0].unsqueeze(1), 'Trained Weights: conv2(1)'),
    ((init_weightConv2[0] - trained_weightConv2[0]).unsqueeze(1), 'Difference of weights: conv2(1)'),
)

for weights, title in conv1Panels + conv2Panels:
    imshow(torchvision.utils.make_grid(weights, nrow=6, normalize=True), title)