In [39]:
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import torch.nn as nn
from models import ConvNet
from helper_functions import train_convnet, evaluate_classifier, rotation_collate

In [40]:
# Model depth: number of conv blocks in ConvNet (also part of the checkpoint filename).
NUM_BLOCKS = 4

# Seed the RNGs the notebook draws from (DataLoader shuffling, and presumably
# ConvNet weight init) so Restart & Run All is reproducible (cuDNN
# nondeterminism aside).
SEED = 0
torch.manual_seed(SEED)  # seeds the default generator on all devices, incl. CUDA
np.random.seed(SEED)

## Create Datasets

In [41]:
# Full MNIST training split; only a small class-balanced subset of it is used
# for supervised training below.
dataset = datasets.MNIST(root='./data',
                         train=True,
                         download=True,
                         transform=transforms.ToTensor())

# Number of labelled images kept per digit (10 digits -> 10 * SAMPLES_PER_CLASS images).
SAMPLES_PER_CLASS = 10

# One integer label per image, converted once for vectorised lookups.
targets = np.array(dataset.targets)

# Deterministically take the first SAMPLES_PER_CLASS indices of each digit so the
# labelled subset is identical across runs.
indices = []
for digit in range(10):
    digit_indices = np.where(targets == digit)[0][:SAMPLES_PER_CLASS]
    # Slicing silently returns fewer items if a class is short; fail loudly instead.
    assert len(digit_indices) == SAMPLES_PER_CLASS, (
        f"digit {digit}: only {len(digit_indices)} samples available")
    indices.extend(digit_indices.tolist())  # plain Python ints for Subset

supervised_dataset = Subset(dataset, indices)

In [42]:
# MNIST test split (train=False), preprocessed with the same ToTensor
# transform as the training images.
test_dataset = datasets.MNIST(
    root='./data',
    download=True,
    train=False,
    transform=transforms.ToTensor(),
)

## Train

In [43]:
# Mini-batch sizes, defined once. 256 is the validation batch size the logged
# training run used (40 validation batches per epoch).
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 256

# Small labelled subset, reshuffled every epoch.
supervised_train_loader = DataLoader(supervised_dataset,
                                     batch_size=TRAIN_BATCH_SIZE,
                                     shuffle=True,
                                     num_workers=4,
                                     persistent_workers=True)

# NOTE: this "validation" loader wraps the MNIST *test* split, and the same loader
# is used for the final evaluation, so that score is not an untouched held-out estimate.
supervised_val_loader = DataLoader(test_dataset,
                                   batch_size=VAL_BATCH_SIZE,
                                   shuffle=False,
                                   num_workers=4,
                                   persistent_workers=True)

In [44]:
# Fresh 10-class classifier whose depth is set by NUM_BLOCKS, then moved to the GPU.
MNIST_model = ConvNet(num_blocks=NUM_BLOCKS, num_classes=10)
MNIST_model = MNIST_model.cuda()

In [45]:
# Show the model's module structure as the cell output (bare last expression).
MNIST_model

ConvNet(
  (blocks): ModuleList(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), 

In [46]:
# Trains MNIST_model on supervised_train_loader, reporting metrics on
# supervised_val_loader each epoch. The loaders come from the data-loader cell
# rather than being rebuilt here, so batch sizes live in a single place.
NUM_EPOCHS = 40

criterion = nn.CrossEntropyLoss()

# NOTE(review): lr=0.01 is 10x Adam's default; the logged run shows large loss
# spikes (e.g. validation loss 18.2 at epoch 33). Consider a lower lr.
optimizer = torch.optim.Adam(MNIST_model.parameters(),
                             lr=0.01,
                             weight_decay=0.001)

# NOTE(review): step_size=50 exceeds NUM_EPOCHS. If train_convnet calls
# scheduler.step() once per epoch, the LR is never halved. Confirm how
# train_convnet steps the scheduler.
learning_rate_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                          step_size=50,
                                                          gamma=0.5)

# Trailing ';' keeps the call's return value out of the cell output.
train_convnet(MNIST_model,
              supervised_train_loader,
              supervised_val_loader,
              criterion, optimizer,
              learning_rate_scheduler,
              num_epochs=NUM_EPOCHS,
              filename=f'baseline_model_{NUM_BLOCKS}.pth');

0


Epoch 1/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 68.46it/s, loss=2.9440]
Epoch 1/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 111.93it/s]


Epoch [1/40] - 
Train Loss: 3.2733, Train Accuracy: 13.00%, 
Validation Loss: 2.3941, Validation Accuracy: 14.40%



Epoch 2/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 163.45it/s, loss=1.8164]
Epoch 2/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 220.61it/s]


Epoch [2/40] - 
Train Loss: 2.1032, Train Accuracy: 26.00%, 
Validation Loss: 2.3452, Validation Accuracy: 13.10%



Epoch 3/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 244.84it/s, loss=1.5572]
Epoch 3/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 161.08it/s]


Epoch [3/40] - 
Train Loss: 1.6829, Train Accuracy: 42.00%, 
Validation Loss: 2.3647, Validation Accuracy: 16.05%



Epoch 4/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 183.21it/s, loss=1.3101]
Epoch 4/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 188.83it/s]


Epoch [4/40] - 
Train Loss: 1.1209, Train Accuracy: 65.00%, 
Validation Loss: 1.8685, Validation Accuracy: 37.62%



Epoch 5/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 206.60it/s, loss=0.5782]
Epoch 5/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 151.52it/s]


Epoch [5/40] - 
Train Loss: 0.7507, Train Accuracy: 75.00%, 
Validation Loss: 1.4530, Validation Accuracy: 46.17%



Epoch 6/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 208.11it/s, loss=0.6209]
Epoch 6/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 206.42it/s]


Epoch [6/40] - 
Train Loss: 0.4736, Train Accuracy: 89.00%, 
Validation Loss: 1.7331, Validation Accuracy: 48.54%



Epoch 7/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 238.18it/s, loss=0.2055]
Epoch 7/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 195.30it/s]


Epoch [7/40] - 
Train Loss: 0.2367, Train Accuracy: 91.00%, 
Validation Loss: 1.2233, Validation Accuracy: 59.47%



Epoch 8/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 232.89it/s, loss=0.1035]
Epoch 8/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 210.06it/s]


Epoch [8/40] - 
Train Loss: 0.0988, Train Accuracy: 97.00%, 
Validation Loss: 1.1070, Validation Accuracy: 65.18%



Epoch 9/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 238.46it/s, loss=0.9428]
Epoch 9/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 184.91it/s]


Epoch [9/40] - 
Train Loss: 0.1386, Train Accuracy: 96.00%, 
Validation Loss: 0.7240, Validation Accuracy: 75.56%



Epoch 10/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 251.69it/s, loss=0.1070]
Epoch 10/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 177.29it/s]


Epoch [10/40] - 
Train Loss: 0.1270, Train Accuracy: 97.00%, 
Validation Loss: 0.5935, Validation Accuracy: 80.37%



Epoch 11/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 246.37it/s, loss=0.0124]
Epoch 11/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 180.67it/s]


Epoch [11/40] - 
Train Loss: 0.0717, Train Accuracy: 98.00%, 
Validation Loss: 0.5651, Validation Accuracy: 81.48%



Epoch 12/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 232.88it/s, loss=0.0127]
Epoch 12/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 178.55it/s]


Epoch [12/40] - 
Train Loss: 0.0550, Train Accuracy: 99.00%, 
Validation Loss: 0.8658, Validation Accuracy: 72.02%



Epoch 13/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 219.95it/s, loss=0.1893]
Epoch 13/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 169.71it/s]


Epoch [13/40] - 
Train Loss: 0.0414, Train Accuracy: 99.00%, 
Validation Loss: 0.5399, Validation Accuracy: 82.76%



Epoch 14/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 221.25it/s, loss=0.0752]
Epoch 14/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 161.91it/s]


Epoch [14/40] - 
Train Loss: 0.0955, Train Accuracy: 98.00%, 
Validation Loss: 0.7997, Validation Accuracy: 78.86%



Epoch 15/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 200.52it/s, loss=0.0320]
Epoch 15/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 171.57it/s]


Epoch [15/40] - 
Train Loss: 0.0760, Train Accuracy: 99.00%, 
Validation Loss: 1.3192, Validation Accuracy: 62.82%



Epoch 16/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 238.01it/s, loss=0.0029]
Epoch 16/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 186.58it/s]


Epoch [16/40] - 
Train Loss: 0.0887, Train Accuracy: 97.00%, 
Validation Loss: 1.5609, Validation Accuracy: 56.83%



Epoch 17/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 231.17it/s, loss=0.0212]
Epoch 17/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 180.02it/s]


Epoch [17/40] - 
Train Loss: 0.1781, Train Accuracy: 94.00%, 
Validation Loss: 0.6726, Validation Accuracy: 79.66%



Epoch 18/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 162.15it/s, loss=0.0727]
Epoch 18/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 215.10it/s]


Epoch [18/40] - 
Train Loss: 0.0912, Train Accuracy: 99.00%, 
Validation Loss: 1.2884, Validation Accuracy: 70.80%



Epoch 19/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 225.37it/s, loss=0.0226]
Epoch 19/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 191.42it/s]


Epoch [19/40] - 
Train Loss: 0.0418, Train Accuracy: 98.00%, 
Validation Loss: 0.7734, Validation Accuracy: 77.76%



Epoch 20/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 240.38it/s, loss=0.0114]
Epoch 20/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 200.73it/s]


Epoch [20/40] - 
Train Loss: 0.0219, Train Accuracy: 99.00%, 
Validation Loss: 0.9511, Validation Accuracy: 72.47%



Epoch 21/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 237.94it/s, loss=0.0348]
Epoch 21/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 197.81it/s]


Epoch [21/40] - 
Train Loss: 0.0154, Train Accuracy: 100.00%, 
Validation Loss: 0.6105, Validation Accuracy: 80.54%



Epoch 22/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 230.58it/s, loss=0.0007]
Epoch 22/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 175.43it/s]


Epoch [22/40] - 
Train Loss: 0.0036, Train Accuracy: 100.00%, 
Validation Loss: 0.5623, Validation Accuracy: 81.85%



Epoch 23/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 169.94it/s, loss=0.0025]
Epoch 23/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 188.42it/s]


Epoch [23/40] - 
Train Loss: 0.0022, Train Accuracy: 100.00%, 
Validation Loss: 0.5196, Validation Accuracy: 82.99%



Epoch 24/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 246.44it/s, loss=0.0010]
Epoch 24/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 186.72it/s]


Epoch [24/40] - 
Train Loss: 0.0010, Train Accuracy: 100.00%, 
Validation Loss: 0.4753, Validation Accuracy: 84.47%



Epoch 25/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 240.42it/s, loss=0.0027]
Epoch 25/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 194.88it/s]


Epoch [25/40] - 
Train Loss: 0.0020, Train Accuracy: 100.00%, 
Validation Loss: 0.5089, Validation Accuracy: 83.39%



Epoch 26/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 243.32it/s, loss=0.0005]
Epoch 26/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 188.98it/s]


Epoch [26/40] - 
Train Loss: 0.0007, Train Accuracy: 100.00%, 
Validation Loss: 0.4841, Validation Accuracy: 84.16%



Epoch 27/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 243.50it/s, loss=0.0006]
Epoch 27/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 143.83it/s]


Epoch [27/40] - 
Train Loss: 0.0010, Train Accuracy: 100.00%, 
Validation Loss: 0.4407, Validation Accuracy: 85.89%



Epoch 28/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 238.91it/s, loss=0.0025]
Epoch 28/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 168.49it/s]


Epoch [28/40] - 
Train Loss: 0.0008, Train Accuracy: 100.00%, 
Validation Loss: 0.4313, Validation Accuracy: 86.07%



Epoch 29/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 191.29it/s, loss=0.0006]
Epoch 29/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 169.18it/s]


Epoch [29/40] - 
Train Loss: 0.0005, Train Accuracy: 100.00%, 
Validation Loss: 0.4290, Validation Accuracy: 86.08%



Epoch 30/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 145.80it/s, loss=0.0018]
Epoch 30/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 175.59it/s]


Epoch [30/40] - 
Train Loss: 0.0008, Train Accuracy: 100.00%, 
Validation Loss: 0.4378, Validation Accuracy: 85.67%



Epoch 31/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 207.96it/s, loss=0.0013]
Epoch 31/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 200.69it/s]


Epoch [31/40] - 
Train Loss: 0.0005, Train Accuracy: 100.00%, 
Validation Loss: 0.4147, Validation Accuracy: 86.60%



Epoch 32/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 212.27it/s, loss=0.0238]
Epoch 32/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 135.68it/s]


Epoch [32/40] - 
Train Loss: 0.0128, Train Accuracy: 100.00%, 
Validation Loss: 1.8150, Validation Accuracy: 64.93%



Epoch 33/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 160.16it/s, loss=0.7749]
Epoch 33/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 185.29it/s]


Epoch [33/40] - 
Train Loss: 0.7024, Train Accuracy: 79.00%, 
Validation Loss: 18.2262, Validation Accuracy: 25.47%



Epoch 34/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 230.31it/s, loss=0.4342]
Epoch 34/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 174.58it/s]


Epoch [34/40] - 
Train Loss: 0.6792, Train Accuracy: 75.00%, 
Validation Loss: 2.6696, Validation Accuracy: 64.21%



Epoch 35/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 167.38it/s, loss=0.3435]
Epoch 35/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 172.51it/s]


Epoch [35/40] - 
Train Loss: 0.4414, Train Accuracy: 84.00%, 
Validation Loss: 1.5903, Validation Accuracy: 55.62%



Epoch 36/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 194.66it/s, loss=0.1277]
Epoch 36/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 178.36it/s]


Epoch [36/40] - 
Train Loss: 0.2406, Train Accuracy: 91.00%, 
Validation Loss: 0.6257, Validation Accuracy: 80.15%



Epoch 37/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 186.64it/s, loss=0.3997]
Epoch 37/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 215.77it/s]


Epoch [37/40] - 
Train Loss: 0.0746, Train Accuracy: 100.00%, 
Validation Loss: 0.6144, Validation Accuracy: 81.04%



Epoch 38/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 210.98it/s, loss=0.0994]
Epoch 38/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 174.42it/s]


Epoch [38/40] - 
Train Loss: 0.1508, Train Accuracy: 96.00%, 
Validation Loss: 0.8512, Validation Accuracy: 75.05%



Epoch 39/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 237.09it/s, loss=0.0189]
Epoch 39/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 183.26it/s]


Epoch [39/40] - 
Train Loss: 0.1508, Train Accuracy: 98.00%, 
Validation Loss: 0.5629, Validation Accuracy: 82.88%



Epoch 40/40 [Training]: 100%|██████████| 13/13 [00:00<00:00, 144.17it/s, loss=0.0059]
Epoch 40/40 [Validation]: 100%|██████████| 40/40 [00:00<00:00, 196.68it/s]

Epoch [40/40] - 
Train Loss: 0.0243, Train Accuracy: 100.00%, 
Validation Loss: 0.5064, Validation Accuracy: 85.38%






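## Evaluate

Note: `supervised_val_loader` wraps the MNIST **test** split, which was also reported as the validation set during training. The accuracy below is therefore not a fully independent held-out estimate.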

In [None]:
# Rebuild the architecture from the configured NUM_BLOCKS. A local redefinition
# here could silently diverge from the model that was actually trained.
MNIST_model = ConvNet(num_classes=10, num_blocks=NUM_BLOCKS).cuda()
# Checkpoint presumably written by train_convnet under ./models/ (confirm).
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. It also silences torch.load's FutureWarning.
MNIST_model.load_state_dict(
    torch.load(f'./models/baseline_model_{NUM_BLOCKS}.pth', weights_only=True))

  MNIST_model.load_state_dict(torch.load(f'./models/baseline_model_{NUM_BLOCKS}.pth'))


<All keys matched successfully>

In [54]:
# Score the loaded checkpoint on supervised_val_loader, i.e. the MNIST test split
# (the same split monitored as validation during training).
evaluate_classifier(MNIST_model,supervised_val_loader)

Test Accuracy: 85.21%
