In [1]:
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torchvision import transforms
from session_6 import Model3, train, create_mnist_data_loaders


In [2]:
# Fix the global PyTorch RNG seed so runs are repeatable; binding the returned
# generator to `_` keeps the cell from echoing it as output.
_ = torch.manual_seed(seed=1)

In [3]:
# Instantiate the network used in this experiment and show its layer-by-layer
# summary (layer type, output shape and parameter count, as in the table below).
model = Model3()
model.summarize()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 28, 28]              80
              ReLU-2            [-1, 8, 28, 28]               0
       BatchNorm2d-3            [-1, 8, 28, 28]              16
            Conv2d-4            [-1, 8, 28, 28]             584
              ReLU-5            [-1, 8, 28, 28]               0
       BatchNorm2d-6            [-1, 8, 28, 28]              16
         MaxPool2d-7            [-1, 8, 14, 14]               0
           Dropout-8            [-1, 8, 14, 14]               0
            Conv2d-9           [-1, 12, 14, 14]             876
             ReLU-10           [-1, 12, 14, 14]               0
      BatchNorm2d-11           [-1, 12, 14, 14]              24
           Conv2d-12           [-1, 12, 14, 14]           1,308
             ReLU-13           [-1, 12, 14, 14]               0
      BatchNorm2d-14           [-1, 12,

In [4]:
# --- Run configuration -------------------------------------------------------
epochs = 15
batch_size = 128
data_path = "../data"

# --- Objective and optimisation ----------------------------------------------
# NOTE(review): nll_loss expects log-probabilities; presumably Model3 ends in
# log_softmax — confirm in session_6.
loss_fn = F.nll_loss
optimizer = optim.Adam(params=model.parameters(), lr=0.01)
# StepLR multiplies the learning rate by `gamma` every `step_size` scheduler
# steps (0.01 -> 0.001 after 10 steps here).
# NOTE(review): assumes `train` steps the scheduler once per epoch — confirm.
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# --- Preprocessing -----------------------------------------------------------
# Convert images to tensors, then normalise with mean 0.1307 / std 0.3081
# (presumably the MNIST training-set statistics — TODO confirm).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

In [5]:
# Build the train and test loaders; the same transform is passed for both splits.
train_loader, test_loader = create_mnist_data_loaders(
    batch_size,
    data_path,
    train_transform=transform,
    test_transform=transform,
)

In [6]:
# Train for `epochs` epochs; the log below shows a test-set evaluation after
# every training epoch. The LR scheduler is handed over by keyword.
# NOTE(review): presumably `train` calls scheduler.step() once per epoch — confirm in session_6.
train(model, epochs, train_loader, test_loader, loss_fn, optimizer, scheduler=scheduler)

epoch=01 loss=0.2709 batch_id=0468 accuracy=90.57%: 100%|██████████| 469/469 [00:05<00:00, 86.87it/s]


Test set: Average loss: 0.0682, Accuracy: 9814/10000 (98.14%)



epoch=02 loss=0.1380 batch_id=0468 accuracy=94.18%: 100%|██████████| 469/469 [00:04<00:00, 96.93it/s] 


Test set: Average loss: 0.0492, Accuracy: 9853/10000 (98.53%)



epoch=03 loss=0.1142 batch_id=0468 accuracy=94.63%: 100%|██████████| 469/469 [00:05<00:00, 91.29it/s]


Test set: Average loss: 0.0402, Accuracy: 9882/10000 (98.82%)



epoch=04 loss=0.1273 batch_id=0468 accuracy=94.82%: 100%|██████████| 469/469 [00:04<00:00, 97.37it/s] 


Test set: Average loss: 0.0358, Accuracy: 9899/10000 (98.99%)



epoch=05 loss=0.1636 batch_id=0468 accuracy=94.83%: 100%|██████████| 469/469 [00:04<00:00, 95.04it/s] 


Test set: Average loss: 0.0285, Accuracy: 9913/10000 (99.13%)



epoch=06 loss=0.0826 batch_id=0468 accuracy=95.24%: 100%|██████████| 469/469 [00:05<00:00, 93.10it/s] 


Test set: Average loss: 0.0299, Accuracy: 9894/10000 (98.94%)



epoch=07 loss=0.0976 batch_id=0468 accuracy=95.25%: 100%|██████████| 469/469 [00:04<00:00, 95.84it/s] 


Test set: Average loss: 0.0242, Accuracy: 9922/10000 (99.22%)



epoch=08 loss=0.0523 batch_id=0468 accuracy=95.25%: 100%|██████████| 469/469 [00:04<00:00, 94.89it/s]


Test set: Average loss: 0.0319, Accuracy: 9905/10000 (99.05%)



epoch=09 loss=0.0249 batch_id=0468 accuracy=95.16%: 100%|██████████| 469/469 [00:04<00:00, 94.00it/s]


Test set: Average loss: 0.0217, Accuracy: 9928/10000 (99.28%)



epoch=10 loss=0.2502 batch_id=0468 accuracy=95.31%: 100%|██████████| 469/469 [00:04<00:00, 95.41it/s] 


Test set: Average loss: 0.0245, Accuracy: 9924/10000 (99.24%)



epoch=11 loss=0.0588 batch_id=0468 accuracy=95.62%: 100%|██████████| 469/469 [00:05<00:00, 91.83it/s]


Test set: Average loss: 0.0180, Accuracy: 9942/10000 (99.42%)



epoch=12 loss=0.1221 batch_id=0468 accuracy=95.77%: 100%|██████████| 469/469 [00:05<00:00, 92.45it/s]


Test set: Average loss: 0.0181, Accuracy: 9942/10000 (99.42%)



epoch=13 loss=0.1471 batch_id=0468 accuracy=95.76%: 100%|██████████| 469/469 [00:04<00:00, 94.23it/s] 


Test set: Average loss: 0.0182, Accuracy: 9941/10000 (99.41%)



epoch=14 loss=0.1047 batch_id=0468 accuracy=95.83%: 100%|██████████| 469/469 [00:04<00:00, 94.71it/s] 


Test set: Average loss: 0.0180, Accuracy: 9942/10000 (99.42%)



epoch=15 loss=0.0804 batch_id=0468 accuracy=95.79%: 100%|██████████| 469/469 [00:04<00:00, 94.41it/s] 


Test set: Average loss: 0.0193, Accuracy: 9939/10000 (99.39%)



**Target**

The target was to take the previous model and improve and stabilize its accuracy using a learning-rate scheduler.

**Result**

The base model was taken from [`Model2`](../src/session_6/model_2.py). No other changes were made to the model. During training, StepLR was applied with a step size of 10 and a gamma of 0.1.

Parameter Count: 6174

Train Accuracy: 95.79%

Test Accuracy: 99.39%

**Analysis**

StepLR helped! Although the test accuracy at epoch 15 is 99.39%, the test accuracies for epochs 11 to 14 were all above 99.4%. The model can do better if pushed further. Since the data is MNIST, a random rotation transform can be applied. Only a small rotation should be used; too much augmentation would cause the model to underfit again.