# MNIST Classification using neural networks

### Import modules

In [21]:
import torch
from torch import nn
from torchvision import datasets
from torch.utils.data import DataLoader
from torch import optim
import torch.nn.functional as F
from torchvision import transforms
from torchinfo import summary

### Download MNIST data and preprocess it

In [22]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 64

# Convert every image to a tensor, then randomly invert some of them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomInvert(),
])

# Both MNIST splits are stored under ./data and downloaded when missing.
train_set = datasets.MNIST(root='data', train=True, transform=transform, download=True)
test_set = datasets.MNIST(root='data', train=False, transform=transform, download=True)

# NOTE(review): the test split goes through the same random inversion and its
# loader is shuffled as well -- confirm both are intended for evaluation.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

### Define a CNN-based model with a dense network at the end

In [23]:
# conv layers of the network
class CNN(nn.Module):
    """Convolutional feature extractor made of three conv -> max-pool -> ReLU stages.

    Args:
        channels: number of channels of the input image.
        nf: number of filters of the first convolution; the second and third
            convolutions output 2*nf and 4*nf channels respectively.
    """

    def __init__(self, channels, nf):
        super().__init__()

        # All convolutions are 3x3 with stride 1 and 'same' padding, so only
        # the 2x2 max-pool changes the spatial size of the feature map.
        conv_kwargs = dict(kernel_size=3, stride=1, padding='same')
        self.conv1 = nn.Conv2d(channels, nf, **conv_kwargs)
        self.conv2 = nn.Conv2d(nf, 2 * nf, **conv_kwargs)
        self.conv3 = nn.Conv2d(2 * nf, 4 * nf, **conv_kwargs)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Registered (its parameters are part of the module) but not applied
        # anywhere in forward().
        self.bn_conv = nn.BatchNorm2d(nf)

    def forward(self, x):
        """Apply the three stages in order and return the final feature map."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(self.maxpool(conv(x)))
        return x

# dense layers of the network
class FCNN(nn.Module):
    """Dense head mapping flattened conv features to per-class scores.

    Args:
        nf: base filter count of the conv stack; the head expects
            4*nf*3*3 input features per sample.
        num_classes: number of scores produced per sample.
    """

    def __init__(self, nf, num_classes):
        super().__init__()
        # in: a flattened (4*nf)x3x3 feature map
        flat_features = 4 * nf * 3 * 3
        self.fc1 = nn.Linear(flat_features, 4 * nf)
        self.bn_d1 = nn.BatchNorm1d(4 * nf)
        self.fc2 = nn.Linear(4 * nf, nf)
        self.bn_d2 = nn.BatchNorm1d(nf)
        self.fc3 = nn.Linear(nf, num_classes)
        # One dropout module (p=0.25) shared by both hidden layers.
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return the output of the last linear layer for a batch of flat features."""
        # Each hidden layer is: linear -> batch norm -> dropout -> ReLU.
        hidden = F.relu(self.dropout(self.bn_d1(self.fc1(x))))
        hidden = F.relu(self.dropout(self.bn_d2(self.fc2(hidden))))
        # The final layer's output is returned as is (no activation applied).
        return self.fc3(hidden)

# combined model
class MODEL(nn.Module):
    """Complete classifier: the CNN feature extractor followed by the FCNN head.

    Args:
        channels: number of channels of the input image (forwarded to CNN).
        nf: base filter count (forwarded to both CNN and FCNN).
        num_classes: number of output scores per sample (forwarded to FCNN).
    """

    def __init__(self, channels, nf, num_classes):
        super().__init__()
        # Kept on the instance so callers can read the configuration back.
        self.channels = channels
        self.nf = nf
        self.conv = CNN(channels, nf)
        self.dense = FCNN(nf, num_classes)

    def forward(self, x):
        """Extract conv features, flatten all dims after the first, and classify."""
        features = self.conv(x)
        return self.dense(features.flatten(1))

In [24]:
# A util function to print model summary
def print_model_summary(model):
    """Print torchinfo summaries of the conv part and the dense part of a model.

    Args:
        model: object exposing ``conv`` and ``dense`` sub-modules plus the
            ``channels`` and ``nf`` attributes (as MODEL does).
    """
    # Use the model's own channel count rather than assuming one input
    # channel, so models built with a different `channels` value can be
    # summarised too. The 28x28 image size is still fixed.
    print(summary(model.conv, input_size=(model.channels, 28, 28)))
    # The dense part takes the flattened conv output: 4*nf channels of 3x3.
    print(summary(model.dense, input_size=(1, 4 * model.nf * 3 * 3)))

In [26]:
# check if output shape is as expected
test_model = MODEL(1, 16, 10)
test_input = torch.randn(size=(64, 1, 28, 28))
output = test_model(test_input)
# A bare `output.shape` is only echoed when it is the last expression of a
# cell, so print it explicitly to actually see the shape being checked.
print(output.shape)
print_model_summary(test_model)


Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      [64, 3, 3]                32
├─Conv2d: 1-1                            [16, 28, 28]              160
├─MaxPool2d: 1-2                         [16, 14, 14]              --
├─Conv2d: 1-3                            [32, 14, 14]              4,640
├─MaxPool2d: 1-4                         [32, 7, 7]                --
├─Conv2d: 1-5                            [64, 7, 7]                18,496
├─MaxPool2d: 1-6                         [64, 3, 3]                --
Total params: 23,328
Trainable params: 23,328
Non-trainable params: 0
Total mult-adds (M): 10.44
Input size (MB): 0.00
Forward/backward pass size (MB): 0.18
Params size (MB): 0.09
Estimated Total Size (MB): 0.27
Layer (type:depth-idx)                   Output Shape              Param #
FCNN                                     [1, 10]                   --
├─Linear: 1-1                            [1, 64]                   36,