<a href="https://colab.research.google.com/github/gkdivya/EVA/blob/main/4_ArchitecturalBasics/Experiments/MNIST_Exp2_WithTransitionBlock.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Experiment - 2

**Objective**: In the previous [experiment](https://github.com/gkdivya/EVA/blob/main/4_ArchitecturalBasics/Experiments/MNIST_Exp1_WithLessParams.ipynb), we reduced the number of parameters to **5490**. In this notebook, we refine the architecture with reference to this [Kaggle notebook](https://www.kaggle.com/enwei26/mnist-digits-pytorch-cnn-99).

The idea is to add a transition block (max pooling followed by a 1x1 convolution) and observe the effect on accuracy.

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [5]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Convolution block 1: padding=1 keeps the spatial size at 28x28
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, 3, padding=1, bias=False),  # Input:28x28 Output:28x28 GRF:3x3
            nn.ReLU(),

            nn.Conv2d(10, 10, 3, padding=1, bias=False), # Input:28x28 Output:28x28 GRF:5x5
            nn.ReLU()
        )

        # Transition block 1: max pooling halves the size, then a 1x1 convolution
        self.trans1 = nn.Sequential(
            nn.MaxPool2d(2, 2),  # Input:28x28 Output:14x14 GRF:6x6
            nn.Conv2d(10, 10, 1, bias=False), # Input:14x14 Output:14x14 GRF:6x6
            nn.ReLU()
        )

        # Convolution block 2: padding=1 keeps the spatial size at 14x14
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 10, 3, padding=1, bias=False), # Input:14x14 Output:14x14 GRF:10x10
            nn.ReLU(),

            nn.Conv2d(10, 10, 3, padding=1, bias=False), # Input:14x14 Output:14x14 GRF:14x14
            nn.ReLU()
        )

        # Transition block 2: max pooling halves the size, then a 1x1 convolution
        self.trans2 = nn.Sequential(
            nn.MaxPool2d(2, 2),  # Input:14x14 Output:7x7 GRF:16x16
            nn.Conv2d(10, 10, 1, bias=False), # Input:7x7 Output:7x7 GRF:16x16
            nn.ReLU()
        )

        # Convolution block 3: no padding, so each layer shrinks the map by 2
        self.conv3 = nn.Sequential(
            nn.Conv2d(10, 10, 3, bias=False), # Input:7x7 Output:5x5 GRF:24x24
            nn.ReLU(),

            nn.Conv2d(10, 10, 3, bias=False), # Input:5x5 Output:3x3 GRF:32x32
            nn.ReLU(),

            nn.Conv2d(10, 10, 3, bias=False)  # Input:3x3 Output:1x1 GRF:40x40 (one value per class)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.trans1(x)
        x = self.conv2(x)
        x = self.trans2(x)
        x = self.conv3(x)

        x = x.view(-1,10)
        return F.log_softmax(x,dim=1)
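
The output sizes and receptive fields of a layer stack like this can be traced by hand with two standard formulas: `out = (in + 2*padding - kernel) // stride + 1` and `rf_out = rf_in + (kernel - 1) * jump_in`, where the jump is the product of all earlier strides. The sketch below is a plain-Python illustration that is not part of the original notebook; the `LAYERS` table and the `trace` helper are names introduced here, and the kernel, stride and padding values are read off the `Net` definition above.

```python
# Each entry: (name, kernel, stride, padding), listed in the order Net applies them.
# ReLU layers are omitted because they change neither the size nor the receptive field.
LAYERS = [
    ("conv1 3x3 (pad 1)", 3, 1, 1),
    ("conv1 3x3 (pad 1)", 3, 1, 1),
    ("trans1 maxpool",    2, 2, 0),
    ("trans1 1x1",        1, 1, 0),
    ("conv2 3x3 (pad 1)", 3, 1, 1),
    ("conv2 3x3 (pad 1)", 3, 1, 1),
    ("trans2 maxpool",    2, 2, 0),
    ("trans2 1x1",        1, 1, 0),
    ("conv3 3x3",         3, 1, 0),
    ("conv3 3x3",         3, 1, 0),
    ("conv3 3x3",         3, 1, 0),
]


def trace(size=28):
    """Return (name, input size, output size, receptive field) for every layer."""
    rf, jump = 1, 1
    rows = []
    for name, kernel, stride, padding in LAYERS:
        out = (size + 2 * padding - kernel) // stride + 1
        rf += (kernel - 1) * jump   # growth depends on the jump *before* this layer
        jump *= stride
        rows.append((name, size, out, rf))
        size = out
    return rows


rows = trace()
for name, n_in, n_out, rf in rows:
    print(f"{name:<20} {n_in:>2}x{n_in:<2} -> {n_out:>2}x{n_out:<2}  RF {rf}x{rf}")
```

The traced sizes (28, 28, 14, 14, 14, 14, 7, 7, 5, 3, 1) agree with the shapes that `torchsummary` reports, and the final 1x1 map is what lets `x.view(-1, 10)` produce one score per class.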

In [6]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = Net().to(device)
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 10, 28, 28]              90
              ReLU-2           [-1, 10, 28, 28]               0
            Conv2d-3           [-1, 10, 28, 28]             900
              ReLU-4           [-1, 10, 28, 28]               0
         MaxPool2d-5           [-1, 10, 14, 14]               0
            Conv2d-6           [-1, 10, 14, 14]             100
              ReLU-7           [-1, 10, 14, 14]               0
            Conv2d-8           [-1, 10, 14, 14]             900
              ReLU-9           [-1, 10, 14, 14]               0
           Conv2d-10           [-1, 10, 14, 14]             900
             ReLU-11           [-1, 10, 14, 14]               0
        MaxPool2d-12             [-1, 10, 7, 7]               0
           Conv2d-13             [-1, 10, 7, 7]             100
             ReLU-14             [-1, 1

In [7]:
torch.manual_seed(1)
batch_size = 128

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.RandomRotation(5),
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=batch_size, shuffle=True, **kwargs)
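
As a quick sanity check on the loader settings, the sketch below works out how many batches one epoch contains and what range the `Normalize((0.1307,), (0.3081,))` transform maps pixel values to. It is an illustration only: the split sizes of 60,000 and 10,000 are the standard MNIST sizes, and the variable names are introduced here.

```python
import math

TRAIN_IMAGES = 60_000   # standard MNIST training split
TEST_IMAGES = 10_000    # standard MNIST test split
batch_size = 128
mean, std = 0.1307, 0.3081

# DataLoader keeps the last, smaller batch by default, hence the ceiling.
train_batches = math.ceil(TRAIN_IMAGES / batch_size)
test_batches = math.ceil(TEST_IMAGES / batch_size)
last_train_batch = TRAIN_IMAGES - (train_batches - 1) * batch_size

# ToTensor scales pixels to [0, 1]; Normalize then applies (x - mean) / std.
lowest = (0.0 - mean) / std
highest = (1.0 - mean) / std

print(train_batches, test_batches, last_train_batch)
print(f"normalized pixel range: [{lowest:.4f}, {highest:.4f}]")
```

The 469 training batches match the `469/469` counter that the progress bar shows during training.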


In [8]:
from tqdm import tqdm
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    pbar = tqdm(train_loader)
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        pbar.set_description(desc= f'epoch={epoch} loss={loss.item():.10f} batch_id={batch_idx:05d}')
    #print( f'Epoch {epoch} - \nTrain set : loss={loss.item()} batch_id={batch_idx}')


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
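
Because the model returns `log_softmax` outputs and both functions use `F.nll_loss`, the reported loss is the cross-entropy. A useful reference value is the loss of a model that assigns equal probability to all 10 digits. The sketch below computes it in plain Python; the helper name `nll_of_uniform_guess` is introduced here for illustration.

```python
import math


def nll_of_uniform_guess(num_classes):
    """Negative log-likelihood when every class gets probability 1 / num_classes."""
    log_probs = [math.log(1.0 / num_classes)] * num_classes
    target = 0  # any target index gives the same value for a uniform prediction
    return -log_probs[target]


baseline = nll_of_uniform_guess(10)
print(f"{baseline:.4f}")  # prints 2.3026
```

A loss that stays near 2.3026, as in the first three epochs of the training log, means the network is still predicting close to a uniform distribution and has not started learning yet.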

In [9]:

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.02, momentum=0.7)

for epoch in range(1, 20):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

epoch=1 loss=2.3025302887 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.85it/s]
Test set: Average loss: 2.3025, Accuracy: 1535/10000 (15.35%)

epoch=2 loss=2.3024733067 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 31.07it/s]
Test set: Average loss: 2.3024, Accuracy: 1932/10000 (19.32%)

epoch=3 loss=2.3011991978 batch_id=00468: 100%|██████████| 469/469 [00:14<00:00, 31.30it/s]
Test set: Average loss: 2.3011, Accuracy: 2054/10000 (20.54%)

epoch=4 loss=0.3724213541 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 31.14it/s]
Test set: Average loss: 0.2411, Accuracy: 9223/10000 (92.23%)

epoch=5 loss=0.0761757568 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 31.05it/s]
Test set: Average loss: 0.1442, Accuracy: 9546/10000 (95.46%)

epoch=6 loss=0.1145213544 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.65it/s]
Test set: Average loss: 0.0959, Accuracy: 9701/10000 (97.01%)

epoch=7 loss=0.0502433740 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.88it/s]
Test set: Average loss: 0.0793, Accuracy: 9747/10000 (97.47%)

epoch=8 loss=0.0942621008 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.92it/s]
Test set: Average loss: 0.0742, Accuracy: 9778/10000 (97.78%)

epoch=9 loss=0.1641373932 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 31.16it/s]
Test set: Average loss: 0.0548, Accuracy: 9821/10000 (98.21%)

epoch=10 loss=0.0312329028 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.58it/s]
Test set: Average loss: 0.0661, Accuracy: 9797/10000 (97.97%)

epoch=11 loss=0.0440519862 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 31.16it/s]
Test set: Average loss: 0.0472, Accuracy: 9849/10000 (98.49%)

epoch=12 loss=0.0892613009 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.45it/s]
Test set: Average loss: 0.0513, Accuracy: 9840/10000 (98.40%)

epoch=13 loss=0.0982613191 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.71it/s]
Test set: Average loss: 0.0483, Accuracy: 9841/10000 (98.41%)

epoch=14 loss=0.0309470203 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.25it/s]
Test set: Average loss: 0.0442, Accuracy: 9866/10000 (98.66%)

epoch=15 loss=0.0287094563 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.73it/s]
Test set: Average loss: 0.0395, Accuracy: 9877/10000 (98.77%)

epoch=16 loss=0.0033145845 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.74it/s]
Test set: Average loss: 0.0395, Accuracy: 9882/10000 (98.82%)

epoch=17 loss=0.0782022402 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.93it/s]
Test set: Average loss: 0.0407, Accuracy: 9865/10000 (98.65%)

epoch=18 loss=0.0716896206 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.62it/s]
Test set: Average loss: 0.0393, Accuracy: 9877/10000 (98.77%)

epoch=19 loss=0.0532136746 batch_id=00468: 100%|██████████| 469/469 [00:15<00:00, 30.69it/s]
Test set: Average loss: 0.0391, Accuracy: 9882/10000 (98.82%)


## Summary
Refined the model architecture by adding transition blocks: max pooling followed by a 1x1 convolution. The two 1x1 convolutions increased the number of parameters by 200 :(
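
The parameter count can be checked by hand, since a bias-free convolution has `in_channels * out_channels * kernel * kernel` weights. The sketch below is a plain-Python illustration; the `conv_params` helper and the `CONVS` list are introduced here, with the layer shapes copied from the `Net` definition.

```python
def conv_params(in_ch, out_ch, kernel):
    """Number of weights in a 2D convolution created with bias=False."""
    return in_ch * out_ch * kernel * kernel


# (in_channels, out_channels, kernel_size) for every Conv2d in Net, in order.
CONVS = [
    (1, 10, 3), (10, 10, 3),               # conv1
    (10, 10, 1),                            # trans1 (1x1)
    (10, 10, 3), (10, 10, 3),               # conv2
    (10, 10, 1),                            # trans2 (1x1)
    (10, 10, 3), (10, 10, 3), (10, 10, 3),  # conv3
]

total = sum(conv_params(*c) for c in CONVS)
pointwise = sum(conv_params(*c) for c in CONVS if c[2] == 1)

print(total, pointwise, total - pointwise)  # prints 5690 200 5490
```

Each 1x1 convolution contributes 10 * 10 = 100 weights, so the two transition blocks account for the 200 extra parameters on top of the 5490 from the previous experiment.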

A transition block might not add much value for MNIST, but in bigger networks a 1x1 convolution helps reduce the number of channels, which in turn helps the network go deeper without compromising on the feature maps.
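
To make the channel-reduction argument concrete, the sketch below compares a direct 3x3 convolution with the same layer preceded by a 1x1 reduction. The channel counts (256 and 64) are hypothetical values chosen only for illustration and do not come from this notebook, and `conv_params` is a helper introduced here.

```python
def conv_params(in_ch, out_ch, kernel):
    """Number of weights in a 2D convolution created with bias=False."""
    return in_ch * out_ch * kernel * kernel


WIDE = 256     # hypothetical channel count in a larger network
NARROW = 64    # hypothetical channel count after the 1x1 reduction

# Option A: a 3x3 convolution applied directly to the wide feature maps.
direct = conv_params(WIDE, WIDE, 3)

# Option B: a 1x1 convolution squeezes the channels first, then the 3x3 expands them.
with_reduction = conv_params(WIDE, NARROW, 1) + conv_params(NARROW, WIDE, 3)

print(direct)          # prints 589824
print(with_reduction)  # prints 163840
```

In this hypothetical setting the 1x1 reduction cuts the weight count to well under a third, which is the budget a deeper network can spend on additional layers. With only 10 channels in and out, as in this MNIST model, there is nothing to reduce, so the 1x1 layers only add parameters.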

With just 5690 parameters, the MNIST model reaches 98.82% test accuracy in 19 epochs (`range(1, 20)` runs epochs 1 through 19).

And the difference between train and validation accuracy is ~0.2