In [1]:
import pytorch_lightning as pl
import torch 
import torch.nn as nn

from torchmetrics import Accuracy

torch.set_float32_matmul_precision('medium')


In [2]:
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms


class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path='./MNIST/') -> None:
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def prepare_data(self):
        MNIST(self.data_path, download=True)

    def setup(self, stage: str):

        mnist_all = MNIST(
            root=self.data_path,
            train=True,
            transform=self.transform,
            download=False
        )

        self.train, self.val = random_split(
            mnist_all, [55000, 5000], 
            generator=torch.Generator().manual_seed(1)
        )

        self.test = MNIST(
            root=self.data_path,
            train=False,
            transform=self.transform,
            download=False
        )
    
    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64, num_workers=4)
    
    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64, num_workers=4)
    
    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64, num_workers=4)


In [3]:
from typing import Any


class CNNNetwork(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.valid_acc   = Accuracy(task="multiclass", num_classes=10)
        self.test_acc  = Accuracy(task="multiclass", num_classes=10)

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                                kernel_size=5, padding=2)
        self.a1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                                kernel_size=5, padding=2)
        self.a2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.fl = nn.Flatten()
        self.fc3 = nn.Linear(3136, 1024)
        self.a3 = nn.ReLU()
        self.dp3 = nn.Dropout(p=0.5)
        self.fc4 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.a1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.a2(x)
        x = self.pool2(x)
        x = self.fl(x)
        x = self.fc3(x)
        x = self.a3(x)
        x = self.dp3(x)
        x = self.fc4(x)
        return x
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x), y)
        preds = torch.argmax(logits, dim=1)
        self.train_acc.update(preds, y)
        self.log("train loss: ",loss, prog_bar=True)
        return loss
    
    def on_train_epoch_end(self):
        self.log("train acc :", self.train_acc.compute())

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x), y)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(self(x), y)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = 0.001)
        return optimizer



        

In [4]:
torch.manual_seed(1)
mnist_dm = MnistDataModule()
mnistClassifier = CNNNetwork()

# trainer = pl.Trainer(max_epochs=10, accelerator="auto")

trainer = pl.Trainer(max_epochs=20, accelerator="auto", enable_checkpointing=True)

trainer.fit(model=mnistClassifier, datamodule=mnist_dm)
            # ,ckpt_path='/home/tej/Documents/Courses/Learning/ML_With_PyTorch_Scikit_Practice/Chapter13/lightning_logs/version_0/checkpoints/epoch=49-step=43000.ckpt')

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/tej/Documents/Courses/Learning/ML_With_PyTorch_Scikit_Practice/Chapter14/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name      | Type               | Params
--------------------------------------------------
0  | train_acc | MulticlassAccuracy | 0     
1  | valid_acc | MulticlassAccuracy | 0     
2  | test_acc  | MulticlassAccuracy | 0     
3  | conv1     | Conv2d             | 832   
4  | a1        | ReLU               | 0     
5  | pool1     | MaxPool2d          | 0     
6  | conv2     | Conv2d             | 51.3 K
7  | a2        | ReLU               | 0     
8  | pool2     | MaxPool2d          | 0     
9  | fl        | Flatten            | 0     
10 | fc3       | Linear             | 3.2 M 
11 | a3        | ReLU               | 0     
12 | dp3       | Dropout            | 0     
13 | fc

Epoch 19: 100%|██████████| 860/860 [00:03<00:00, 254.68it/s, v_num=0, train loss: =0.0677, valid_loss=0.0497, valid_acc=0.989]  

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 860/860 [00:03<00:00, 250.39it/s, v_num=0, train loss: =0.0677, valid_loss=0.0497, valid_acc=0.989]


In [5]:
# %load_ext tensorboard
# %tensorboard --logdir ./Chapter13/lightning_logs/

In [6]:
mnistClassifier

CNNNetwork(
  (train_acc): MulticlassAccuracy()
  (valid_acc): MulticlassAccuracy()
  (test_acc): MulticlassAccuracy()
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (a1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (a2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fl): Flatten(start_dim=1, end_dim=-1)
  (fc3): Linear(in_features=3136, out_features=1024, bias=True)
  (a3): ReLU()
  (dp3): Dropout(p=0.5, inplace=False)
  (fc4): Linear(in_features=1024, out_features=10, bias=True)
)