In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
class Net(nn.Module):
    """Small all-convolutional classifier producing 10 log-probabilities per sample.

    The shape comments below assume a 1 x 28 x 28 input (H x W x C notation).
    Receptive fields are computed as ``r_out = r_in + (k - 1) * jump``, where
    ``jump`` is the product of the strides of all preceding layers.

    The network has three stages:
      * ``convLayer1``: two 3x3 convs, a 1x1 channel-reducing conv, max-pool, dropout.
      * ``convLayer2``: two 3x3 convs, max-pool, dropout.
      * ``convLayer3``: a 3x3 conv followed by a 1x1 conv down to 10 channels.
    ``forward`` then average-pools the remaining 2 x 2 map and applies log-softmax.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.convLayer1 = nn.Sequential(
            nn.Conv2d(1, 8, 3),
            # input 28 x 28 x 1 -- output 26 x 26 x 8, Receptive field - 3
            nn.ReLU(),
            nn.BatchNorm2d(8),
            nn.Conv2d(8, 16, 3),
            # input 26 x 26 x 8 -- output 24 x 24 x 16, Receptive field - 5
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 8, 1),
            # 1x1 (pointwise) conv: squeezes 16 channels down to 8, spatial size unchanged
            # input 24 x 24 x 16 -- output 24 x 24 x 8, Receptive field - 5
            nn.MaxPool2d(2, 2),
            # input 24 x 24 x 8 -- output 12 x 12 x 8, Receptive field - 6 (jump becomes 2)
            nn.Dropout(0.3)
        )
        self.convLayer2 = nn.Sequential(
            nn.Conv2d(8, 16, 3),
            # input 12 x 12 x 8 -- output 10 x 10 x 16, Receptive field - 10
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 16, 3),
            # input 10 x 10 x 16 -- output 8 x 8 x 16, Receptive field - 14
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, 2),
            # input 8 x 8 x 16 -- output 4 x 4 x 16, Receptive field - 16 (jump becomes 4)
            nn.Dropout(0.3)
        )

        self.convLayer3 = nn.Sequential(
            nn.Conv2d(16, 32, 3),
            # input 4 x 4 x 16 -- output 2 x 2 x 32, Receptive field - 24
            nn.Conv2d(32, 10, 1),
            # 1x1 (pointwise) conv: maps 32 channels to the 10 class scores
            # input 2 x 2 x 32 -- output 2 x 2 x 10, Receptive field - 24
        )

    def forward(self, x):
        """Return log-probabilities of shape (N, 10) for a batch ``x`` of 1 x 28 x 28 images."""
        x = self.convLayer1(x)
        x = self.convLayer2(x)
        x = self.convLayer3(x)
        # 2 x 2 average pooling collapses the 2 x 2 x 10 map to 1 x 1 x 10 (receptive field - 28)
        x = F.avg_pool2d(x, 2)
        # Flatten to (N, 10): one score per class
        x = x.view(-1, 10)
        # Normalise over the class dimension explicitly; omitting `dim` relies on a
        # deprecated implicit choice and emits a warning on every forward pass.
        return F.log_softmax(x, dim=1)

In [3]:
!pip install torchsummary
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = Net().to(device)
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              80
              ReLU-2            [-1, 8, 26, 26]               0
       BatchNorm2d-3            [-1, 8, 26, 26]              16
            Conv2d-4           [-1, 16, 24, 24]           1,168
              ReLU-5           [-1, 16, 24, 24]               0
       BatchNorm2d-6           [-1, 16, 24, 24]              32
            Conv2d-7            [-1, 8, 24, 24]             136
         MaxPool2d-8            [-1, 8, 12, 12]               0
           Dropout-9            [-1, 8, 12, 12]               0
           Conv2d-10           [-1, 16, 10, 10]           1,168
             ReLU-11           [-1, 16, 10, 10]               0
      BatchNorm2d-12           [-1, 16, 10, 10]              32
           Conv2d-13             [-1, 16, 8, 8]           2,320
             ReLU-14             [-1, 1



In [4]:
# Seed PyTorch's global RNG before anything random (shuffling, weight init) runs.
torch.manual_seed(1)
batch_size = 128

# Extra DataLoader options are only passed when CUDA is available.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# One preprocessing pipeline shared by both splits, so train and test inputs
# cannot drift apart.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set mean and
# standard deviation -- confirm.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=True, **kwargs)

# Evaluation only aggregates over the whole split, so the order is irrelevant:
# shuffle=False avoids drawing from the global RNG on every evaluation pass.
# download=True lets this loader be built even if the data is not on disk yet.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=False, **kwargs)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [5]:
from tqdm import tqdm
def train(model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model``, and print the training accuracy.

    Args:
        model: network whose forward pass returns per-class log-probabilities
            (the loss used is ``F.nll_loss``).
        device: device every batch is moved to before the forward pass.
        train_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset``; ``len(train_loader.dataset)`` is the accuracy denominator.
        optimizer: object whose ``zero_grad()`` and ``step()`` are called once per batch.
        epoch: current epoch number; accepted for caller compatibility, not used here.

    Returns:
        None. Progress is shown on the tqdm bar and the accuracy is printed.
    """
    model.train()
    correct = 0
    pbar = tqdm(train_loader, position=0, leave=True)
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        # Accuracy is tallied from the same forward pass that is used for the
        # loss, i.e. with the model in training mode and before this batch's update.
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        pbar.set_description(desc=f'loss={loss.item()} batch_id={batch_idx}')

    # '\n' (not the invalid escape '\T') so the line starts with a blank line
    # like the test report, instead of printing a literal backslash.
    print('\nTrain set: Accuracy: {}/{} ({:.4f}%)\n'.format(
        correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [6]:
from torch.optim.lr_scheduler import StepLR

# Hyperparameters for this run, named so they are easy to find and tune.
NUM_EPOCHS = 30
sgd_config = {'lr': 0.04, 'momentum': 0.9}

# Start from a freshly initialised network and train it with SGD + momentum,
# evaluating on the test set after every epoch.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), **sgd_config)
for epoch in range(1, NUM_EPOCHS + 1):
    print(f'EPOCH : {epoch}')
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)

  0%|          | 0/469 [00:00<?, ?it/s]

EPOCH : 1


loss=0.03161170706152916 batch_id=468: 100%|██████████| 469/469 [00:14<00:00, 32.85it/s]

\Train set: Accuracy: 55988/60000 (93.3133%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0554, Accuracy: 9816/10000 (98.16%)

EPOCH : 2


loss=0.10073136538267136 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.65it/s]

\Train set: Accuracy: 58464/60000 (97.4400%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0378, Accuracy: 9871/10000 (98.71%)

EPOCH : 3


loss=0.029696205630898476 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 34.13it/s]

\Train set: Accuracy: 58750/60000 (97.9167%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0408, Accuracy: 9863/10000 (98.63%)

EPOCH : 4


loss=0.037268538028001785 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.58it/s]

\Train set: Accuracy: 58845/60000 (98.0750%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0274, Accuracy: 9915/10000 (99.15%)

EPOCH : 5


loss=0.1320980042219162 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.93it/s]

\Train set: Accuracy: 58935/60000 (98.2250%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0279, Accuracy: 9915/10000 (99.15%)

EPOCH : 6


loss=0.009748919866979122 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.88it/s]

\Train set: Accuracy: 59018/60000 (98.3633%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0240, Accuracy: 9925/10000 (99.25%)

EPOCH : 7


loss=0.03792427107691765 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.85it/s]

\Train set: Accuracy: 59022/60000 (98.3700%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0257, Accuracy: 9907/10000 (99.07%)

EPOCH : 8


loss=0.0502290241420269 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 34.11it/s]

\Train set: Accuracy: 59125/60000 (98.5417%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0273, Accuracy: 9912/10000 (99.12%)

EPOCH : 9


loss=0.018511682748794556 batch_id=468: 100%|██████████| 469/469 [00:14<00:00, 33.41it/s]

\Train set: Accuracy: 59099/60000 (98.4983%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0218, Accuracy: 9930/10000 (99.30%)

EPOCH : 10


loss=0.1410997360944748 batch_id=468: 100%|██████████| 469/469 [00:14<00:00, 33.39it/s]

\Train set: Accuracy: 59160/60000 (98.6000%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0250, Accuracy: 9912/10000 (99.12%)

EPOCH : 11


loss=0.013561318628489971 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.50it/s]

\Train set: Accuracy: 59156/60000 (98.5933%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0220, Accuracy: 9929/10000 (99.29%)

EPOCH : 12


loss=0.05983744561672211 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.82it/s]

\Train set: Accuracy: 59208/60000 (98.6800%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0249, Accuracy: 9919/10000 (99.19%)

EPOCH : 13


loss=0.06692109256982803 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 34.23it/s]

\Train set: Accuracy: 59211/60000 (98.6850%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0271, Accuracy: 9907/10000 (99.07%)

EPOCH : 14


loss=0.06915511935949326 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.57it/s]

\Train set: Accuracy: 59194/60000 (98.6567%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0217, Accuracy: 9932/10000 (99.32%)

EPOCH : 15


loss=0.03698914870619774 batch_id=468: 100%|██████████| 469/469 [00:14<00:00, 33.37it/s]

\Train set: Accuracy: 59268/60000 (98.7800%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0233, Accuracy: 9931/10000 (99.31%)

EPOCH : 16


loss=0.0626743882894516 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.55it/s]

\Train set: Accuracy: 59319/60000 (98.8650%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0205, Accuracy: 9924/10000 (99.24%)

EPOCH : 17


loss=0.02222670614719391 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.58it/s]

\Train set: Accuracy: 59269/60000 (98.7817%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0198, Accuracy: 9934/10000 (99.34%)

EPOCH : 18


loss=0.02927895076572895 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.70it/s]

\Train set: Accuracy: 59290/60000 (98.8167%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0214, Accuracy: 9921/10000 (99.21%)

EPOCH : 19


loss=0.013907006941735744 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.64it/s]

\Train set: Accuracy: 59284/60000 (98.8067%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0203, Accuracy: 9933/10000 (99.33%)

EPOCH : 20


loss=0.03578072041273117 batch_id=468: 100%|██████████| 469/469 [00:14<00:00, 33.35it/s]

\Train set: Accuracy: 59314/60000 (98.8567%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0213, Accuracy: 9933/10000 (99.33%)

EPOCH : 21


loss=0.060387711971998215 batch_id=468: 100%|██████████| 469/469 [00:14<00:00, 32.96it/s]

\Train set: Accuracy: 59316/60000 (98.8600%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0229, Accuracy: 9937/10000 (99.37%)

EPOCH : 22


loss=0.07935784012079239 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.95it/s]

\Train set: Accuracy: 59333/60000 (98.8883%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0187, Accuracy: 9934/10000 (99.34%)

EPOCH : 23


loss=0.01984652504324913 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.56it/s]

\Train set: Accuracy: 59355/60000 (98.9250%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0216, Accuracy: 9927/10000 (99.27%)

EPOCH : 24


loss=0.049010127782821655 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.76it/s]

\Train set: Accuracy: 59321/60000 (98.8683%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0205, Accuracy: 9938/10000 (99.38%)

EPOCH : 25


loss=0.01714135892689228 batch_id=468: 100%|██████████| 469/469 [00:14<00:00, 33.15it/s]

\Train set: Accuracy: 59326/60000 (98.8767%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0189, Accuracy: 9945/10000 (99.45%)

EPOCH : 26


loss=0.06805456429719925 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.82it/s]

\Train set: Accuracy: 59350/60000 (98.9167%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0236, Accuracy: 9927/10000 (99.27%)

EPOCH : 27


loss=0.05174213647842407 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.73it/s]


\Train set: Accuracy: 59311/60000 (98.8517%)



  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0199, Accuracy: 9933/10000 (99.33%)

EPOCH : 28


loss=0.011350960470736027 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.91it/s]

\Train set: Accuracy: 59334/60000 (98.8900%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0205, Accuracy: 9927/10000 (99.27%)

EPOCH : 29


loss=0.0111185722053051 batch_id=468: 100%|██████████| 469/469 [00:14<00:00, 33.09it/s]

\Train set: Accuracy: 59373/60000 (98.9550%)




  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 0.0175, Accuracy: 9939/10000 (99.39%)

EPOCH : 30


loss=0.02119218371808529 batch_id=468: 100%|██████████| 469/469 [00:13<00:00, 33.52it/s]

\Train set: Accuracy: 59415/60000 (99.0250%)







Test set: Average loss: 0.0181, Accuracy: 9937/10000 (99.37%)



# What I did
Made a model with 9.7k parameters

* Ran the model for 50 epochs with a learning rate of 0.05; training started at 87.19% training accuracy. The model eventually reached an accuracy of > 99.40%, the highest being 99.49% (at the 44th epoch). This suggests that the model is capable of achieving higher accuracy if the learning rate is increased.

* Increased the learning rate to 0.01, added image augmentation of 7% tilt, and trained the model for 30 epochs. The model started with a training accuracy of 89.86%. It achieved > 99.40% accuracy on several epochs, starting from the 19th epoch. The highest validation accuracy the model reached was 99.50% (30th epoch), with a training accuracy of 98.78%.

* Increased the learning rate to 0.02 and trained the model for 30 epochs. The model started with a training accuracy of 91.6367%. It first achieved > 99% validation accuracy at the 4th epoch and continued the upward trend until the 8th epoch (this also happened with the 0.01 learning rate). The model achieved its highest validation accuracy of 99.45% several times (first at the 25th epoch).

* Increased the learning rate to 0.04. The model started with 91% training accuracy and first achieved a validation accuracy of > 99.4% (99.44%) at the 15th epoch. It then overshot several times, but kept returning to an accuracy of > 99.4%. The highest validation accuracy it reached was 99.48%.

* Increased the learning rate to 0.08. The model started with 93.3367% training accuracy. It constantly overshot the parameters and was never able to achieve the required > 99.4% accuracy. The model continued its upward trend until the 6th epoch, at which it achieved a validation accuracy of 99.17%; it then followed a rising-and-falling pattern and reached a highest validation accuracy of 99.25%.