In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import torch.nn as nn
from models import ConvNet
from helper_functions import train_convnet, evaluate_classifier, rotation_collate

2024-11-29 16:17:32.962586: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-29 16:17:32.972961: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732893452.986145 2719517 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732893452.989918 2719517 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-29 16:17:33.004132: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Experiment configuration.
NUM_BLOCKS = 4  # number of blocks in the ConvNet under study (passed as num_blocks=)
SEED = 0        # global seed for weight initialisation and DataLoader shuffling

# Seed the RNGs used below so Restart & Run All is repeatable.
# torch.manual_seed also seeds all CUDA devices; cuDNN kernels may still be
# nondeterministic, so GPU results can differ slightly between runs.
torch.manual_seed(SEED)
np.random.seed(SEED)

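## Create Datasets

Build a small, class-balanced supervised training set (the first 10 training images of each digit, 100 images in total) and load the full MNIST test split for evaluation.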

In [3]:
# Training pool: the full MNIST training split; ToTensor() turns each image into a float tensor.
dataset = datasets.MNIST(root='./data',
                         train=True,
                         download=True,
                         transform=transforms.ToTensor())

# Low-label regime: keep only the first SAMPLES_PER_CLASS training images of
# each digit, giving a small, class-balanced supervised set.
NUM_CLASSES = 10
SAMPLES_PER_CLASS = 10

targets = np.array(dataset.targets)

indices = []
for digit in range(NUM_CLASSES):
    digit_indices = np.flatnonzero(targets == digit)[:SAMPLES_PER_CLASS]
    # Fail loudly instead of silently building an unbalanced subset.
    assert len(digit_indices) == SAMPLES_PER_CLASS, (
        f"digit {digit}: only {len(digit_indices)} samples available")
    indices.extend(digit_indices.tolist())

supervised_dataset = Subset(dataset, indices)
assert len(supervised_dataset) == NUM_CLASSES * SAMPLES_PER_CLASS

In [4]:
# Held-out data: the official MNIST test split, preprocessed with the same
# ToTensor() transform as the training split.
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

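## Train

Note: the MNIST test split is used as the validation set below, so validation accuracy is not independent of the final test accuracy.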

In [5]:
# DataLoader settings shared by both loaders: four worker processes that stay
# alive across epochs.
loader_kwargs = dict(num_workers=4, persistent_workers=True)

# Tiny supervised set -> small shuffled batches; test split -> larger, ordered batches.
supervised_train_loader = DataLoader(supervised_dataset, batch_size=8, shuffle=True, **loader_kwargs)
supervised_val_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, **loader_kwargs)

# Baseline ConvNet, built with ConvNet's default num_blocks (none is passed).
# NOTE(review): baseline_model is not trained in the visible notebook.
baseline_model = ConvNet(num_classes=10).cuda()
criterion = nn.CrossEntropyLoss()

# NOTE(review): this optimizer and the StepLR below are bound to baseline_model's
# parameters. A later cell that reuses this scheduler alongside a different
# optimizer will not decay that other optimizer's learning rate.
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=0.0001, weight_decay=0.001)
learning_rate_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)


In [6]:
# Model under study: a 10-way ConvNet with NUM_BLOCKS blocks, placed on the GPU.
MNIST_model = ConvNet(
    num_classes=10,
    num_blocks=NUM_BLOCKS,
).cuda()

In [7]:
# Loaders for the NUM_BLOCKS model. These replace the loaders built earlier;
# the training loader is identical, the validation loader uses batch_size=8.
supervised_train_loader = DataLoader(supervised_dataset, batch_size=8, shuffle=True, num_workers=4, persistent_workers=True)
supervised_val_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4, persistent_workers=True)

optimizer = torch.optim.Adam(MNIST_model.parameters(), lr=0.001, weight_decay=0.001)

# Bind a fresh scheduler to *this* optimizer. The scheduler created earlier is
# attached to baseline_model's optimizer, so reusing it would step the wrong
# optimizer and leave MNIST_model's learning rate unscheduled.
# NOTE(review): whether train_convnet steps the scheduler per batch or per epoch
# is not visible here; either way 10 epochs x 13 batches = 130 steps < step_size=500,
# so no decay happens in this configuration.
learning_rate_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)

# NOTE(review): supervised_val_loader wraps the MNIST *test* split, so validation
# metrics (and any best-checkpoint selection inside train_convnet) are not
# independent of the final test evaluation.
train_convnet(MNIST_model,
              supervised_train_loader,
              supervised_val_loader,
              criterion, optimizer,
              learning_rate_scheduler,
              num_epochs=10,
              filename=f'mnist_model_{NUM_BLOCKS}.pth')

Epoch 1/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 30.75it/s, loss=1.1419]
Epoch 1/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 652.95it/s]


Epoch [1/10] - Train Loss: 1.9692, Train Accuracy: 35.00%, Validation Loss: 2.4703, Validation Accuracy: 11.35%


Epoch 2/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 178.89it/s, loss=0.7990]
Epoch 2/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 687.93it/s]


Epoch [2/10] - Train Loss: 0.8115, Train Accuracy: 86.00%, Validation Loss: 3.0947, Validation Accuracy: 11.35%


Epoch 3/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 221.33it/s, loss=0.3390]
Epoch 3/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 690.20it/s]


Epoch [3/10] - Train Loss: 0.2658, Train Accuracy: 97.00%, Validation Loss: 1.4041, Validation Accuracy: 50.26%


Epoch 4/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 224.53it/s, loss=0.0727]
Epoch 4/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 663.68it/s]


Epoch [4/10] - Train Loss: 0.0839, Train Accuracy: 99.00%, Validation Loss: 0.7519, Validation Accuracy: 75.99%


Epoch 5/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 213.83it/s, loss=0.0450]
Epoch 5/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 713.84it/s]


Epoch [5/10] - Train Loss: 0.0301, Train Accuracy: 100.00%, Validation Loss: 0.6175, Validation Accuracy: 80.13%


Epoch 6/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 207.53it/s, loss=0.0399]
Epoch 6/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 713.44it/s]


Epoch [6/10] - Train Loss: 0.0159, Train Accuracy: 100.00%, Validation Loss: 0.5036, Validation Accuracy: 84.29%


Epoch 7/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 91.44it/s, loss=0.0076]
Epoch 7/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 718.24it/s]


Epoch [7/10] - Train Loss: 0.0091, Train Accuracy: 100.00%, Validation Loss: 0.4927, Validation Accuracy: 84.07%


Epoch 8/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 214.90it/s, loss=0.0100]
Epoch 8/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 719.78it/s]


Epoch [8/10] - Train Loss: 0.0073, Train Accuracy: 100.00%, Validation Loss: 0.4889, Validation Accuracy: 84.24%


Epoch 9/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 196.63it/s, loss=0.0045]
Epoch 9/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 685.91it/s]


Epoch [9/10] - Train Loss: 0.0062, Train Accuracy: 100.00%, Validation Loss: 0.4957, Validation Accuracy: 83.73%


Epoch 10/10 [Training]: 100%|██████████| 13/13 [00:00<00:00, 221.15it/s, loss=0.0072]
Epoch 10/10 [Validation]: 100%|██████████| 1250/1250 [00:01<00:00, 709.39it/s]

Epoch [10/10] - Train Loss: 0.0044, Train Accuracy: 100.00%, Validation Loss: 0.4782, Validation Accuracy: 84.16%





## Evaluate

In [8]:
# Final accuracy of the in-memory MNIST_model on supervised_val_loader, which is
# built from the MNIST test split -- the same loader used for validation during training.
# NOTE(review): the reported accuracy equals the last epoch's validation accuracy, which
# suggests these are the final-epoch weights rather than a reloaded best checkpoint;
# confirm what train_convnet writes to mnist_model_{NUM_BLOCKS}.pth.
evaluate_classifier(MNIST_model,supervised_val_loader)

Test Accuracy: 84.16%
