<a href="https://colab.research.google.com/github/ecutolo/notebooks/blob/master/DLGD_PL_resnet_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# scikit-learn is the PyPI distribution that provides the `sklearn` module;
# %pip installs into the environment of the running kernel.
%pip install torch torchvision scikit-learn



# Adjusting ResNet for MNIST in PyTorch

For a detailed description, see: https://zablo.net/blog/post/using-resnet-for-mnist-in-pytorch-tutorial

In [None]:
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.datasets import MNIST
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time
from torch import nn, optim
import torch
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from torch.utils.data import DataLoader

In [None]:
class MnistResNet(ResNet):
    """ResNet built from ``BasicBlock`` ([2, 2, 2, 2] layout) for 1-channel images.

    The classification head has 10 outputs (``num_classes=10``) and the first
    convolution is replaced so the network accepts a single input channel.

    ``forward`` returns raw, unnormalised class scores (logits). They are meant
    to be passed directly to ``nn.CrossEntropyLoss``, which performs the
    log-softmax itself. Apply ``torch.softmax(logits, dim=-1)`` at the call
    site if probabilities are needed; the arg-max class is the same either way.
    """

    def __init__(self):
        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=10)
        # Swap the first convolution for one that takes 1 input channel.
        self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    def forward(self, x):
        """Return the logits for ``x`` (one score per class, 10 classes)."""
        # Deliberately no softmax here: feeding already-normalised
        # probabilities into CrossEntropyLoss would normalise twice and
        # squash the gradients.
        return super().forward(x)


In [None]:
def get_data_loaders(train_batch_size, val_batch_size):
    """Build the MNIST training and validation ``DataLoader`` objects.

    Every image is resized to 224x224, converted to a tensor and normalised
    with the mean / standard deviation of the training pixels (both divided
    by 255 to match the scale produced by ``ToTensor``).

    Args:
        train_batch_size: batch size of the training loader (shuffled).
        val_batch_size: batch size of the validation loader (not shuffled).

    Returns:
        A ``(train_loader, val_loader)`` tuple. The validation loader is
        backed by the ``train=False`` split of MNIST.
    """
    # `.data` replaces the deprecated `.train_data` alias, which only exists
    # for backwards compatibility and emits a warning.
    train_pixels = MNIST(download=True, train=True, root=".").data.float()

    # Plain Python floats keep the transform free of tensor references.
    pixel_mean = (train_pixels.mean() / 255).item()
    pixel_std = (train_pixels.std() / 255).item()

    data_transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize((pixel_mean,), (pixel_std,)),
    ])

    train_ds = MNIST(download=True, root=".", transform=data_transform, train=True)
    val_ds = MNIST(download=True, root=".", transform=data_transform, train=False)

    train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False)
    return train_loader, val_loader

In [None]:
def calculate_metric(metric_fn, true_y, pred_y):
    """Evaluate ``metric_fn`` on the labels, macro-averaging when supported.

    Args:
        metric_fn: callable taking ``(true_y, pred_y)`` and optionally an
            ``average`` parameter (positional-or-keyword or keyword-only).
        true_y: ground-truth labels.
        pred_y: predicted labels.

    Returns:
        Whatever ``metric_fn`` returns. ``average="macro"`` is passed whenever
        the metric declares an ``average`` parameter.
    """
    # inspect.signature sees keyword-only parameters and follows decorator
    # wrappers (functools.wraps); getfullargspec(...).args does neither, so
    # it would silently skip `average` for metrics that declare it after `*`.
    if "average" in inspect.signature(metric_fn).parameters:
        return metric_fn(true_y, pred_y, average="macro")
    return metric_fn(true_y, pred_y)
    
def print_scores(p, r, f1, a, batch_size):
    """Print the mean precision, recall, F1 and accuracy, one per line.

    Args:
        p, r, f1, a: sequences of per-batch precision, recall, F1 and
            accuracy values.
        batch_size: divisor applied to each sum (the caller in this notebook
            passes the number of validation batches).
    """
    collected = {"precision": p, "recall": r, "F1": f1, "accuracy": a}
    for label, values in collected.items():
        mean_value = sum(values) / batch_size
        print(f"\t{label.rjust(14, ' ')}: {mean_value:.4f}")

In [None]:
start_ts = time.time()

# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
epochs = 1

model = MnistResNet().to(device)
train_loader, val_loader = get_data_loaders(256, 256)

losses = []  # mean training loss of each epoch
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adadelta(model.parameters())

batches = len(train_loader)
val_batches = len(val_loader)

# training loop + eval loop
for epoch in range(epochs):
    total_loss = 0
    progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
    model.train()

    for i, data in progress:
        X, y = data[0].to(device), data[1].to(device)

        model.zero_grad()
        outputs = model(X)
        loss = loss_function(outputs, y)

        loss.backward()
        optimizer.step()
        current_loss = loss.item()
        total_loss += current_loss
        # Show the running mean of the loss over the batches seen so far.
        progress.set_description("Loss: {:.4f}".format(total_loss/(i+1)))

    torch.cuda.empty_cache()

    val_losses = 0.0
    precision, recall, f1, accuracy = [], [], [], []

    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            X, y = data[0].to(device), data[1].to(device)
            outputs = model(X)
            # .item() keeps the accumulator a plain float, so the summary
            # below prints a number rather than a tensor repr.
            val_losses += loss_function(outputs, y).item()

            # Copy labels and predictions to the host once per batch instead
            # of once per metric.
            true_classes = y.cpu()
            predicted_classes = outputs.argmax(dim=1).cpu()

            for acc, metric in zip((precision, recall, f1, accuracy),
                                   (precision_score, recall_score, f1_score, accuracy_score)):
                acc.append(calculate_metric(metric, true_classes, predicted_classes))

    print(f"Epoch {epoch+1}/{epochs}, training loss: {total_loss/batches}, validation loss: {val_losses/val_batches}")
    print_scores(precision, recall, f1, accuracy, val_batches)
    losses.append(total_loss/batches)
print(losses)
print(f"Training time: {time.time()-start_ts}s")

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw





Loss:   0%|          | 0/235 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install pytorch-lightning


Collecting pytorch-lightning
  Downloading pytorch_lightning-1.5.1-py3-none-any.whl (1.0 MB)
[K     |████████████████████████████████| 1.0 MB 4.2 MB/s 
[?25hCollecting PyYAML>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 42.9 MB/s 
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 34.9 MB/s 
Collecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.6.0-py3-none-any.whl (329 kB)
[K     |████████████████████████████████| 329 kB 44.3 MB/s 
[?25hCollecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.11.0-py3-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 46.6 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylin

### TODO 1: implement lightning module

In [None]:
import pytorch_lightning as pl


class MnistLitModel(pl.LightningModule):
  # Exercise placeholder for TODO 1: the body is intentionally left as `...`
  # and has to be filled in before the trainer calls in this cell can work.
  # Code goes here
  ...

# accelerator="auto" lets Lightning use a GPU when one is present and fall
# back to the CPU otherwise, instead of failing on machines without CUDA.
trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)

# training
model = MnistLitModel()
trainer.fit(model=model, train_dataloaders=train_loader)

# testing
# `dataloaders` is the current keyword; `test_dataloaders` is deprecated.
trainer.test(model=model, dataloaders=val_loader)


In [None]:
import pytorch_lightning as pl


class MnistLitModel(pl.LightningModule):
  """Lightning module that trains and tests ``MnistResNet`` with cross-entropy loss."""

  def __init__(self):
    super().__init__()
    self.mod = MnistResNet()
    self.loss_function = nn.CrossEntropyLoss()

  def training_step(self, batch, batch_idx):
    """Return the loss of one ``(inputs, targets)`` training batch."""
    X, y = batch
    outputs = self.mod(X)
    return self.loss_function(outputs, y)

  def configure_optimizers(self):
    """Return an Adadelta optimizer over this module's own parameters."""
    # Use `self`, not the notebook-global `model`: the global may be missing
    # or may point at a different network when this hook runs.
    return optim.Adadelta(self.parameters())

  def test_step(self, batch, batch_idx):
    """Log macro precision/recall/F1 and accuracy for one batch; return its loss."""
    X, y = batch
    outputs = self.mod(X)
    loss = self.loss_function(outputs, y)
    # The metric functions are the scikit-learn ones imported at the top of
    # the notebook, so labels and scores are moved to the CPU first.
    y, outputs = y.cpu(), outputs.cpu()
    predicted_classes = torch.max(outputs, 1)[1]
    self.log_dict({
        "test_precision": precision_score(y, predicted_classes, average='macro'),
        "test_recall": recall_score(y, predicted_classes, average='macro'),
        "test_F1": f1_score(y, predicted_classes, average='macro'),
        "test_acc": accuracy_score(y, predicted_classes),
    }, on_epoch=True)
    return loss

# accelerator="auto" lets Lightning use a GPU when one is present and fall
# back to the CPU otherwise, instead of failing on machines without CUDA.
trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)

# training
model = MnistLitModel()
trainer.fit(model=model, train_dataloaders=train_loader)

# testing
# `dataloaders` is the current keyword; `test_dataloaders` is deprecated.
trainer.test(model=model, dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | mod           | MnistResNet      | 11.2 M
1 | loss_function | CrossEntropyLoss | 0     
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.701    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

### TODO 2: implement lightning data module

In [None]:
class MnistDataModule(pl.LightningDataModule):
  # Exercise placeholder for TODO 2: the data module body is intentionally
  # left as `...` and has to be implemented before the fit call below can work.
  ...
  # Code goes here


# training
model = MnistLitModel()
dm = MnistDataModule()
# NOTE(review): `trainer` is not created in this cell; it must already exist
# from an earlier cell when this one runs.
trainer.fit(model=model, datamodule=dm)