Pytorch Lightning
* It forces a standar structure on your code - easier to reason about your code (and someone else's)
* It gets rid of a lot of boilerplate
* It abstracts away a lot of extra stuff besides the core code - logging and parallelization on GPU or even TPU. Paralleization can be done by just adding some flags to your Trainer rather than refactoring your code

Dataset
Build a model
Define loss_func and optimizer
Define trainer
Define test
Run trainer and test

In [9]:
import torch
from torch import nn

from torch.nn import functional as F
from torch import optim

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import pytorch_lightning as pl

In [10]:
train_ds = MNIST(root='data', train = True, download=True, transform=ToTensor())
valid_ds = MNIST(root='data', train = False, download=True, transform=ToTensor())

bs = 64
train_dl = DataLoader(train_ds,batch_size=bs,shuffle=True,num_workers=15)
valid_dl = DataLoader(valid_ds,batch_size=bs,shuffle=False,num_workers=15)


In [11]:


#build a model
from typing import Any
import torchmetrics

from pytorch_lightning.utilities.types import STEP_OUTPUT, OptimizerLRScheduler


class MNISTModel(pl.LightningModule):
    def __init__(self, lr = 0.5):
        super().__init__()
        self.lr = lr
        self.lin = nn.Linear(784,10)
        # self.lin2 = nn.Linear(64,32)
        # self.lin3 = nn.Linear(32,10)
        #metrics
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        self.valid_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        


    def forward (self, xb):
        xb = xb.flatten(1,-1) # (bs,1,28,28) -> bs, 784
        return self.lin(xb)
    

    def shared_step(self, batch, train):
        xb, yb = batch
        pred = self.forward(xb)
        loss = F.cross_entropy(pred, yb)

        #logginh
        if (train):
            self.train_accuracy(pred.softmax(dim=-1),yb)
            self.log('train_accuracy',self.train_accuracy,on_step=True,on_epoch=False, prog_bar=True)
        else:
            self.valid_accuracy(pred.softmax(dim=-1),yb)
            self.log('valid_accuracy',self.valid_accuracy,on_step=True,on_epoch=False, prog_bar=True)

        return loss
    
    def training_step(self, batch,batch_idx):
        return  self.shared_step(batch,train=True)

    def validation_step(self, batch, batch_idx):
        return  self.shared_step(batch,train=False)
    
    def configure_optimizers(self) :
        return optim.SGD(self.parameters(), lr=self.lr)

In [12]:
from pytorch_lightning.loggers import TensorBoardLogger

# >tensorboard --logdir tb_logs

tb_logger = TensorBoardLogger('tb_logs', name="my_model")
# init model
mnist_model = MNISTModel()

#init trainer
trainer = pl.Trainer(max_epochs=2,logger=tb_logger)

#training the model
trainer.fit(mnist_model,train_dl, valid_dl)

#optinally - run test
#trainer.test()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name           | Type               | Params
------------------------------------------------------
0 | lin            | Linear             | 7.9 K 
1 | train_accuracy | MulticlassAccuracy | 0     
2 | valid_accuracy | MulticlassAccuracy | 0     
------------------------------------------------------
7.9 K     Trainable params
0         Non-trainable params
7.9 K     Total params
0.031     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\dac\miniconda3\envs\deep_learning\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


                                                                            

c:\Users\dac\miniconda3\envs\deep_learning\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 1: 100%|██████████| 938/938 [00:23<00:00, 39.34it/s, v_num=2, train_accuracy=0.938, valid_accuracy=1.000]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 938/938 [00:23<00:00, 39.33it/s, v_num=2, train_accuracy=0.938, valid_accuracy=1.000]
