### In this Notebook I develop a Notebook on MNIST dataset using PyTorch Lightning

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from torchmetrics import Accuracy

import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms

import matplotlib.pyplot as plt

import pytorch_lightning as pl

from pytorch_lightning.loggers import WandbLogger

In [2]:
# globals
BATCH_SIZE = 64

wandb_logger = WandbLogger()

In [3]:
# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())

mnist_train, mnist_val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, num_workers=4)
val_loader = DataLoader(mnist_val, batch_size=BATCH_SIZE, num_workers=4)

In [4]:
# have a look at inputs

for i, batch in enumerate(train_loader):
    inputs, targets = batch
    
    print(inputs.shape)
    
    if i == 0:
        break

torch.Size([64, 1, 28, 28])


In [5]:
# now let's define a simple network to train on MNIST
N_INPUT = 28*28*1
N_CLASSES = 10

class SimpleNet(pl.LightningModule):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(N_INPUT, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50,N_CLASSES)
        
        # memorizzo anche la loss function
        self.loss_fn = torch.nn.CrossEntropyLoss()
        
        self.metric = Accuracy()
    
    # this one is only used for inference
    # better to keep it independent from training_step
    def forward(self, x):
        x = x.view(-1, N_INPUT)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # network produces Logits
        x = self.fc3(x)
        
        return x
    
    def single_batch(self, x):
        x = x.view(-1, N_INPUT)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # network produces Logits
        x = self.fc3(x)
        
        return x
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        
        return optimizer
    
    def training_step(self, train_batch, batch_idx):
        inputs, targets = train_batch
    
        outputs = self.single_batch(inputs)
        
        loss = self.loss_fn(outputs, targets)
        
        acc = self.metric(F.softmax(outputs, dim=1), targets)
        
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        
        return {"loss":loss, "acc":acc}
    
    def validation_step(self, val_batch, batch_idx):
        inputs, targets = val_batch
        
        outputs = self.single_batch(inputs)
        
        loss = self.loss_fn(outputs, targets)
        
        val_acc = self.metric(F.softmax(outputs, dim=1), targets)
        
        self.log('val_loss', loss)
        self.log('val_acc', val_acc)

In [6]:
model = SimpleNet()

In [7]:
trainer = pl.Trainer(max_epochs=10, auto_lr_find=True, logger=wandb_logger)

trainer.fit(model, train_loader, val_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
[34m[1mwandb[0m: Currently logged in as: [33mlsaetta[0m (use `wandb login --relogin` to force relogin)



  | Name    | Type             | Params
---------------------------------------------
0 | fc1     | Linear           | 65.9 K
1 | fc2     | Linear           | 4.2 K 
2 | fc3     | Linear           | 510   
3 | loss_fn | CrossEntropyLoss | 0     
4 | metric  | Accuracy         | 0     
---------------------------------------------
70.7 K    Trainable params
0         Non-trainable params
70.7 K    Total params
0.283     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [8]:
trainer.validate(val_dataloaders=val_loader)

Validating: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.9733999967575073, 'val_loss': 0.10329815745353699}
--------------------------------------------------------------------------------


[{'val_loss': 0.10329815745353699, 'val_acc': 0.9733999967575073}]