In [228]:
import torch
from torch import nn,optim
from torchvision import datasets as dsets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import matplotlib.pyplot as plt
import numpy as np
from statistics import mean

In [229]:
batch_size = 20

train = dsets.MNIST(root = './data', train = True, transform=transforms.ToTensor(), download=True)
valid = dsets.MNIST(root = './data', train = False, transform=transforms.ToTensor(), download=True)

sample_set_size = 60000
indices = torch.randperm(len(train))[:sample_set_size]
train_loader = DataLoader(dataset = train, batch_size = batch_size, sampler=SubsetRandomSampler                           (indices))
val_loader = DataLoader(dataset = valid, batch_size = 10000, shuffle = True)

In [230]:
class CNN(nn.Module):

    def __init__(self, out1 = 16, out2 = 32):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=out1, kernel_size=2, padding=3, stride=1)
        self.mp1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = nn.Conv2d(in_channels=out1, out_channels=out2, kernel_size=2, padding=3, stride=1)
        self.mp2 = nn.MaxPool2d(kernel_size=2)

        self.linear1 = nn.Linear(10 * 10 * out2, 100)
        self.linear2 = nn.Linear(100,10)

    def forward(self,x):
        x = torch.relu(self.conv1(x))
        x = self.mp1(x)
        x = torch.relu(self.conv2(x))
        x = self.mp2(x)

        x = x.view(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
        x = torch.relu(self.linear1(x))
        x = self.linear2(x)
        return x

In [231]:
def train_model(model, epochs, learning_rate, final_lr, train_loader, val_loader, momentum):

    to_plot = []
    accuracy = []
    criterion = nn.CrossEntropyLoss()
    num_batches = len(list(train_loader))
    print(f'{num_batches} batches per epoch.')

    lr_decay = (final_lr - learning_rate) / int((3/4) * epochs)

    for epoch in range(epochs):
        model.train()
        opt = optim.SGD(model.parameters(), lr=learning_rate, momentum = momentum)

        print('EPOCH : ', epoch + 1)
        print('Batch : ', end = '')
        counter = 0
        mean_loss_list = []
        for x,y in train_loader :

            opt.zero_grad()
            yhat = model.forward(x)

            loss = criterion(yhat,y)
            mean_loss_list.append(loss.item())

            loss.backward()
            opt.step()

            if counter % (num_batches/10) == 0 : 
                print(str(counter) + ' --> ', end = '', flush = True)
            counter += 1

        mean_loss = mean(mean_loss_list)
        to_plot.append(mean_loss)

        correct = 0
        incorrect = 0

        for x,y in val_loader :
            model.eval()

            yhat = model.forward(x)

            for i in range(yhat.shape[0]):
                pred = yhat[i].argmax().item()
                actual = y[i].item()

                if pred == actual :
                    correct += 1
                else :
                    incorrect += 1
        
        acc = (correct / (correct + incorrect)) * 100
        accuracy.append(acc)

        print()
        print('Mean loss : ', mean_loss)
        print(f'Accuracy on test set : {acc} %')
        print('-' * 50)

        learning_rate = learning_rate - lr_decay

    return to_plot, accuracy

In [232]:
model = CNN()
epochs = 5
lr = 0.01
momentum = 0.85
loss, acc = train_model(model = model, epochs = epochs, learning_rate = lr, final_lr = lr/10, train_loader =                             train_loader, val_loader = val_loader, momentum = momentum)

3000 batches per epoch.
EPOCH :  1
Batch : 0 --> 300 --> 600 --> 900 --> 1200 --> 1500 --> 1800 --> 2100 --> 2400 --> 2700 --> 
Mean loss :  0.3003675595802876
Accuracy on test set : 97.23 %
--------------------------------------------------
EPOCH :  2
Batch : 0 --> 300 --> 600 --> 900 --> 1200 --> 1500 --> 1800 --> 2100 --> 2400 --> 2700 --> 
Mean loss :  0.0766290470463476
Accuracy on test set : 98.33 %
--------------------------------------------------
EPOCH :  3
Batch : 0 --> 300 --> 600 --> 900 --> 1200 --> 1500 --> 1800 --> 2100 --> 2400 --> 2700 --> 
Mean loss :  0.05506903864178215
Accuracy on test set : 98.11999999999999 %
--------------------------------------------------
EPOCH :  4
Batch : 0 --> 300 --> 600 --> 900 --> 1200 --> 1500 --> 1800 --> 2100 --> 2400 --> 2700 --> 
Mean loss :  0.04294508469010907
Accuracy on test set : 98.57000000000001 %
--------------------------------------------------
EPOCH :  5
Batch : 0 --> 300 --> 600 --> 900 --> 1200 --> 1500 --> 1800 --> 21