### In this Notebook I develop a model on MNIST dataset using PyTorch Lightning and a CNN

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from torchmetrics import Accuracy

import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms

import matplotlib.pyplot as plt

import pytorch_lightning as pl

from pytorch_lightning.loggers import WandbLogger

In [2]:
# globals
BATCH_SIZE = 64

# wandb_logger = WandbLogger()

In [3]:
# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())

mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, num_workers=4)
val_loader = DataLoader(mnist_val, batch_size=BATCH_SIZE, num_workers=4)

In [4]:
# have a look at inputs

for i, batch in enumerate(train_loader):
    inputs, targets = batch
    
    print(inputs.shape)
    
    if i == 0:
        break

torch.Size([64, 1, 28, 28])


In [5]:
# now let's define a simple CNN network to train on MNIST

N_INPUT = 28*28*1
N_CLASSES = 10

class CNNNet(pl.LightningModule):
    def __init__(self):
        super(CNNNet, self).__init__()
        self.features = nn.Sequential(
        nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2))
        
        self.avgpool = nn.AdaptiveAvgPool2d((3,3))
        
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*3*3, 128),
            nn.Linear(128, N_CLASSES))
        
        # memorizzo anche la loss function
        self.loss_fn = torch.nn.CrossEntropyLoss()
        
        self.metric = Accuracy()
    
    # this one is only used for inference
    # better to keep it independent from training_step
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        
        return x
    
    def single_batch(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        
        return x
    
    def training_step(self, train_batch, batch_idx):
        inputs, targets = train_batch
    
        outputs = self.single_batch(inputs)
        
        loss = self.loss_fn(outputs, targets)
        
        acc = self.metric(F.softmax(outputs, dim=1), targets)
        
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        
        return {"loss":loss, "acc":acc}
    
    def validation_step(self, val_batch, batch_idx):
        inputs, targets = val_batch
        
        outputs = self.single_batch(inputs)
        
        loss = self.loss_fn(outputs, targets)
        
        val_acc = self.metric(F.softmax(outputs, dim=1), targets)
        
        self.log('val_loss', loss)
        
        # in questo modo loggo sulla progress bar
        self.log('val_acc', val_acc, prog_bar=True)
        
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        
        return optimizer

In [6]:
model = CNNNet()

model

CNNNet(
  (features): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(3, 3))
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=576, out_features=128, bias=True)
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
  (loss_fn): CrossEntropyLoss()
  (metric): Accuracy()
)

In [7]:
for i, batch in enumerate(train_loader):
    inputs, targets = batch
    
    print('Input shape:', inputs.shape)
    outputs = model(inputs)
    print('Output shape:', outputs.shape)
    
    if i == 0:
        break

Input shape: torch.Size([64, 1, 28, 28])
Output shape: torch.Size([64, 10])


In [8]:
trainer = pl.Trainer(max_epochs=10, auto_lr_find=True) # , logger=wandb_logger)

trainer.fit(model, train_loader, val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name       | Type              | Params
-------------------------------------------------
0 | features   | Sequential        | 18.8 K
1 | avgpool    | AdaptiveAvgPool2d | 0     
2 | classifier | Sequential        | 75.1 K
3 | loss_fn    | CrossEntropyLoss  | 0     
4 | metric     | Accuracy          | 0     
-------------------------------------------------
94.0 K    Trainable params
0         Non-trainable params
94.0 K    Total params
0.376     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



In [None]:
trainer.validate(val_dataloaders=val_loader)