In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter 
import torchvision
from torchvision import datasets, transforms

import numpy as np

from collections import OrderedDict,namedtuple
from itertools import product

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [3]:
# Download data for FashionMNIST

# Training
train_dataset = datasets.MNIST(
    "./data",
    download = True,
    train = True,
    transform = transforms.Compose([transforms.ToTensor()])
)

# Sanity Testing
# Check the distribution of labels in training set
print(train_dataset.targets.bincount())

# Each batch has [images, labels]. the count of images/labels is the same as batch size
train_batch_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size = 50,
    shuffle=True
)


# Test
test_dataset = datasets.MNIST(
    "./data",
    download=True,
    train=False,
    transform = transforms.Compose([transforms.ToTensor()])
)

# Sanity Testing
# Check the distribution of labels in test set
print(test_dataset.targets.bincount())

test_batch_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size = 50
)


# Training Dataset size
print("Training Dataset size", len(train_batch_dataloader.dataset))

# Testing Dataset size
print("Testing Dataset size", len(test_batch_dataloader.dataset))


tensor([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949])
tensor([ 980, 1135, 1032, 1010,  982,  892,  958, 1028,  974, 1009])
Training Dataset size 60000
Testing Dataset size 10000


In [8]:
print(len(train_dataset))
train_batch = next(iter(train_batch_dataloader))
print(train_batch[0].shape)

print(len(train_batch_dataloader.dataset), len(test_batch_dataloader.dataset))

60000
torch.Size([50, 1, 28, 28])
60000 10000


In [5]:
class MNISTCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 6, kernel_size = 5)
        self.conv2 = nn.Conv2d(in_channels = 6, out_channels = 12, kernel_size = 5)
        self.linear_size = None
        X = torch.rand(28,28).reshape(-1, 1, 28,28)
        self.convs(X)
        
        self.fc1 = nn.Linear(in_features = self.linear_size, out_features = 120)
        self.fc2 = nn.Linear(in_features = 120, out_features = 60)
        self.out = nn.Linear(in_features = 60, out_features = 10)
        
    def convs(self, X):
        X = F.relu(self.conv1(X))
        X = F.max_pool2d(X,  kernel_size = 2, stride = 2)
        
        X = F.relu(self.conv2(X))
        X = F.max_pool2d(X,  kernel_size = 2, stride = 2)
        
        X =  X.flatten(1,-1)
        
        if self.linear_size is None:
            self.linear_size = X.shape[1]
        
        
        return X
    
    def forward(self,X):
        X = self.convs(X)
        
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        
        X = self.out(X)
        
        return X.to(device)

network = MNISTCNN()
network = network.to(device)
print(network)
print(network.linear_size)
        
        
        
        

MNISTCNN(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=192, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=60, bias=True)
  (out): Linear(in_features=60, out_features=10, bias=True)
)
192


In [6]:
class RunBuilder():
    
    @staticmethod
    def get_runs(params):
        # params is an ordered dict
        keys = params.keys()
        
        values =params.values()
        
        Run = namedtuple('Run', keys)
        
        Runs = []
        
        for param_value_combination_tuple in product(*values):
            run = Run(*param_value_combination_tuple)
            Runs.append(run)
        
        return Runs


In [9]:
# Run Manager to manage Run/Epoch level operations and logging to TensorBoard
class RunManager():
    def __init__(self):
        self.run =  None
        self.network = None
        self.tb = None
        
        self.train_batch_dataloader = None
        self.test_batch_dataloader = None
        
        self.epoch_id = 0
        self.correct_predictions = 0
        self.total_loss=0.0
        
    def run_start(self, run, network, train_batch_dataloader, test_batch_dataloader):
        self.run = run
        self.tb = SummaryWriter(comment=f"MNIST-{run}")
        self.network = network
        
        self.train_batch_dataloader = train_batch_dataloader
        self.test_batch_dataloader = test_batch_dataloader
        
        self.epoch_id = 0
    
    def run_end(self):
        self.tb.close()
    
    def epoch_start(self):
        self.correct_predictions = 0
        self.total_loss=0.0
        
    
    def epoch_end(self):
        accuracy = 100*(self.correct_predictions/len(self.train_batch_dataloader.dataset))
        
        self.tb.add_scalar("Training Accuracy", accuracy,self.epoch_id)
        self.tb.add_scalar("Training Loss", self.total_loss, self.epoch_id)
        
        for name, param in self.network.named_parameters():
            self.tb.add_histogram(f"{name}.grad", param.grad, self.epoch_id)
        
        self.total_loss=0.0
        self.epoch_id+=1
        
    # This will be called per batch
    def track_correct_predictions(self, predictions, actual):
        correct_predictions = self.get_correct_predictions(predictions, actual)
        self.correct_predictions+= correct_predictions
    
    # This will be called per batch
    def track_total_loss(self,loss):
        self.total_loss+=loss
    
    def get_correct_predictions(self, predictions, actual):
        predictions = predictions.argmax(dim=1)
        # actual = actual.argmax(dim=1)
        
        matches = predictions.eq(actual).sum().item()
        
        return matches
    
    def record_validation_stats(self, accuracy, loss, epoch):
        self.tb.add_scalar("Validation Accuracy: ", accuracy, epoch)
        self.tb.add_scalar("Validation Loss: ", loss, epoch)
        
rm = RunManager()


In [10]:
params = OrderedDict(
    lr = [0.001,0.01],
    batch_size = [100, 1000]
)

runs = RunBuilder.get_runs(params)
rm = RunManager()
EPOCHS = 3

for run in runs:
    optimizer = optim.Adam(network.parameters(), lr =run.lr)
    train_batch_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size = run.batch_size
    )
    test_batch_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size = run.batch_size
    )    
    rm.run_start(run, network, train_batch_loader, test_batch_loader)
    for epoch in range(EPOCHS):
        rm.epoch_start()
        
        for train_batch in train_batch_loader:
            train_images, train_labels = train_batch
            
            train_images = train_images.reshape(run.batch_size, 1, 28,28)
            train_images = train_images.to(device)
            
            train_labels = train_labels.to(device)
            
            network.zero_grad()
            
            predicted = network(train_images)
            loss = F.cross_entropy(predicted, train_labels)
            
            rm.track_total_loss(loss.item())
            rm.track_correct_predictions(predicted, train_labels)
            
            loss.backward()
            optimizer.step()
        
        rm.epoch_end()
    
    with torch.no_grad():
        correct_predictions = 0
        total_loss = 0.0
        for test_batch in test_batch_loader:
            test_images, test_labels = test_batch
            test_images = test_images.to(device)
            test_labels = test_labels.to(device)
            
            predicted = network(test_images)
            
            correct_predictions += rm.get_correct_predictions(predicted, test_labels)
            
            loss = F.cross_entropy(predicted, test_labels).item()
            total_loss += loss
            
        test_accuracy = 100*(correct_predictions/len(test_batch_loader.dataset))
        print("Test Accuracy ", test_accuracy)
        rm.record_validation_stats(test_accuracy, total_loss, epoch)
    rm.run_end()
            

Test Accuracy  97.71
Test Accuracy  98.53
Test Accuracy  97.61
Test Accuracy  98.63
