# Adjusting ResNet for MNIST in PyTorch

For detailed description, go to: https://zablo.net/blog/post/using-resnet-for-mnist-in-pytorch-tutorial

In [1]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time
from torch import nn, optim
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader



In [15]:
class MnistResNet(ResNet):
    def __init__(self):
        super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        
    def forward(self, x):
        return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)


In [21]:
def get_data_loaders(train_batch_size, val_batch_size):
    mnist = MNIST(download=False, train=True, root=".").train_data.float()
    
    data_transform = Compose([ Resize((224, 224)),ToTensor(), Normalize((mnist.mean()/255,), (mnist.std()/255,))])

    train_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
                              batch_size=train_batch_size, shuffle=True)

    val_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
                            batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

In [22]:
def calculate_metric(metric_fn, true_y, pred_y):
    if "average" in inspect.getfullargspec(metric_fn).args:
        return metric_fn(true_y, pred_y, average="macro")
    else:
        return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    for name, scores in zip(("precision", "recall", "F1", "accuracy"), (p, r, f1, a)):
        print(f"\t{name.rjust(14, ' ')}: {sum(scores)/batch_size:.4f}")

In [23]:
start_ts = time.time()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 5

model = MnistResNet().to(device)
train_loader, val_loader = get_data_loaders(256, 256)

losses = []
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())

batches = len(train_loader)
val_batches = len(val_loader)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()
    
    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)
        
        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))
        
    torch.cuda.empty_cache()
    
    val_losses = 0
    precision, recall, f1, accuracy = [], [], [], []
    
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            val_losses += loss_function(outputs, y)

            predicted_classes = torch.max(outputs, 1)[1]
            
            for acc, metric in zip((precision, recall, f1, accuracy), 
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(
                    calculate_metric(metric, y.cpu(), predicted_classes.cpu())
                )
        
        
        
    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
print(losses)
print(f"Training time: {time.time()-start_ts}s")

HBox(children=(IntProgress(value=0, description='Loss: ', max=235, style=ProgressStyle(description_width='init…

Epoch 1/5, training loss: 1.6594452599261669, validation loss: 1.739317536354065
	     precision: 0.8446
	        recall: 0.7342
	            F1: 0.7219
	      accuracy: 0.7344


HBox(children=(IntProgress(value=0, description='Loss: ', max=235, style=ProgressStyle(description_width='init…

Epoch 2/5, training loss: 1.4798276926608795, validation loss: 1.4951493740081787
	     precision: 0.9747
	        recall: 0.9708
	            F1: 0.9709
	      accuracy: 0.9716


HBox(children=(IntProgress(value=0, description='Loss: ', max=235, style=ProgressStyle(description_width='init…

Epoch 3/5, training loss: 1.4725456608102676, validation loss: 1.4825210571289062
	     precision: 0.9819
	        recall: 0.9803
	            F1: 0.9802
	      accuracy: 0.9810


HBox(children=(IntProgress(value=0, description='Loss: ', max=235, style=ProgressStyle(description_width='init…

Epoch 4/5, training loss: 1.4697661815805638, validation loss: 1.5628594160079956
	     precision: 0.9341
	        recall: 0.8986
	            F1: 0.8892
	      accuracy: 0.8989


HBox(children=(IntProgress(value=0, description='Loss: ', max=235, style=ProgressStyle(description_width='init…

Epoch 5/5, training loss: 1.467640884379123, validation loss: 1.4743891954421997
	     precision: 0.9881
	        recall: 0.9883
	            F1: 0.9878
	      accuracy: 0.9881
[1.6594452599261669, 1.4798276926608795, 1.4725456608102676, 1.4697661815805638, 1.467640884379123]
Training time: 2109.7608716487885s
