# MNIST Classification Using PyTorch Lightning

In [1]:
# Built-in library
from pathlib import Path
import re
import json
from typing import Any, Optional, Sequence, TypeAlias, Union
import logging
import warnings

# Standard imports
import numpy as np
import numpy.typing as npt
from pprint import pprint
import pandas as pd
import polars as pl
from rich.console import Console
from rich.theme import Theme

# Named styles for rich console output: "info", "warning" and "error" map to hex colours.
custom_theme = Theme(
    styles={
        "info": "#76FF7B",
        "warning": "#FBDDFE",
        "error": "#FF0000",
    }
)
console = Console(theme=custom_theme)

# Visualization
import matplotlib.pyplot as plt


# Pandas settings
# Raise pandas' display limits so wide/long frames and long cell values are not truncated.
pd.set_option("display.max_rows", 1_000)
pd.set_option("display.max_columns", 1_000)
pd.set_option("display.max_colwidth", 600)

# NOTE: this silences *every* warning (including library deprecation notices) for the session.
warnings.filterwarnings("ignore")


# Black code formatter (Optional)
# NOTE(review): `lab_black` is a third-party IPython extension (nb_black) — TODO confirm it is
# installed in the kernel environment, otherwise this line fails on a fresh kernel.
%load_ext lab_black

# auto reload imports
# With `%autoreload 2`, edits to imported modules (e.g. `src.utilities`) are picked up
# before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
def go_up_from_current_directory(go_up: int = 1) -> None:
    """Prepend an ancestor of the current working directory to ``sys.path``.

    This is used to go up a number of directories (e.g. to a project root) so
    that packages living there become importable from the notebook.

    Args:
    -----
    go_up: int, default=1
        This indicates the number of times to go back up from the current directory.
        Must be non-negative; ``0`` adds the current working directory itself.

    Returns:
    --------
    None
        The resolved absolute path is printed.

    Raises:
    -------
    ValueError
        If ``go_up`` is negative.
    """
    import os
    import sys

    if go_up < 0:
        raise ValueError(f"go_up must be non-negative, got {go_up}")

    # Resolve relative to the working directory explicitly: inside a notebook
    # ``__name__`` is a module name ("__main__"), not a file path.
    target_dir = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * go_up)))

    # Re-running the cell must not pile up duplicate entries; move the path to the front instead.
    while target_dir in sys.path:
        sys.path.remove(target_dir)
    sys.path.insert(0, target_dir)
    print(target_dir)

In [3]:
from watermark import watermark
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm

# Record interpreter and package versions so the results below can be reproduced.
# torchvision (MNIST dataset/transforms) and torchmetrics (accuracy metrics) are used
# later in the notebook, so their versions are reported too.
print(
    watermark(
        packages="polars,scikit-learn,torch,torchvision,torchmetrics,lightning",
        python=True,
    )
)
print("Torch CUDA available?: ", torch.cuda.is_available())

Python implementation: CPython
Python version       : 3.11.8
IPython version      : 8.22.2

polars      : 0.20.18
scikit-learn: 1.4.1.post1
torch       : 2.2.2
lightning   : 2.2.1

Torch CUDA available?:  False


## Set Up the DataLoaders

In [4]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
import lightning as L


def get_dataset_loaders(
    fp: str = "../../data/mnist",
    batch_size: int = 64,
    seed: int = 42,
    val_size: int = 5_000,
    num_workers: int = 0,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """This function returns train, validation and test data loaders for MNIST.

    The MNIST training set is split into a training and a validation subset
    using a generator seeded with ``seed``, so the split is reproducible.

    Args:
    -----
    fp: str, default="../../data/mnist"
        Directory where MNIST is stored (downloaded there if missing).
    batch_size: int, default=64
        Number of samples per batch for all three loaders.
    seed: int, default=42
        Seed of the generator used by ``random_split``.
    val_size: int, default=5_000
        Number of training images held out for validation; the rest are used for training.
    num_workers: int, default=0
        Number of worker processes used by each ``DataLoader``.

    Returns:
    --------
    tuple[DataLoader, DataLoader, DataLoader]
        ``(train_loader, val_loader, test_loader)``; only the training loader shuffles.

    Raises:
    -------
    ValueError
        If ``val_size`` is negative or larger than the training set.
    """
    to_tensor = transforms.ToTensor()

    full_train_dataset = datasets.MNIST(
        root=fp, train=True, transform=to_tensor, download=True
    )
    test_dataset = datasets.MNIST(
        root=fp, train=False, transform=to_tensor, download=True
    )

    num_train_images = len(full_train_dataset)
    if not 0 <= val_size <= num_train_images:
        raise ValueError(
            f"val_size must be in [0, {num_train_images}], got {val_size}"
        )

    train_dataset, val_dataset = random_split(
        full_train_dataset,
        lengths=[num_train_images - val_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )

    def _make_loader(dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the shared batch size / worker settings."""
        return DataLoader(
            dataset=dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle,
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = _make_loader(test_dataset, shuffle=False)

    return train_loader, val_loader, test_loader

In [5]:
class PyTorchMLP(nn.Module):
    """Multi-layer perceptron (MLP) classifier with configurable hidden layers.

    Every hidden layer is an ``nn.Linear`` followed by ``nn.ReLU``. The output
    layer returns raw logits (no softmax), which is what ``F.cross_entropy``
    expects. The default ``hidden_units=(50, 25)`` builds the original
    two-hidden-layer network, with identical ``state_dict`` keys.
    """

    def __init__(
        self,
        num_features: int,
        num_classes: int,
        hidden_units: Sequence[int] = (50, 25),
    ) -> None:
        """Build the layer stack.

        Args:
        -----
        num_features: int
            Size of each flattened input example (e.g. 784 for 28x28 images).
        num_classes: int
            Number of output logits.
        hidden_units: Sequence[int], default=(50, 25)
            Width of each hidden layer, in order. An empty sequence gives a
            single linear layer.
        """
        super().__init__()

        layers: list[nn.Module] = []
        in_features = num_features
        for out_features in hidden_units:
            # Hidden layer: affine transform followed by a ReLU non-linearity.
            layers.extend([nn.Linear(in_features, out_features), nn.ReLU()])
            in_features = out_features
        # Output layer
        layers.append(nn.Linear(in_features, num_classes))

        self.all_layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        # Flatten every dimension after the batch dimension, e.g. (N, 1, 28, 28) -> (N, 784).
        x = torch.flatten(x, start_dim=1)
        logits: torch.Tensor = self.all_layers(x)
        return logits

In [6]:
# Alias used in annotations for whatever torch module gets wrapped.
Model: TypeAlias = nn.Module


class LightningModel(L.LightningModule):
    """
    A PyTorch Lightning module that wraps a PyTorch model, providing training and validation steps, as well
    as configuring the optimizer.

    The `LightningModel` class takes a PyTorch model and a learning rate as input, and provides the
    following functionality:

        - The `forward` method simply passes the input through the PyTorch model.
        - The `training_step` method computes the loss using cross-entropy loss and logs the training loss.
        - The `validation_step` method computes the loss using cross-entropy loss and logs the validation loss.
        - The `configure_optimizers` method sets up an Adam optimizer with the provided learning rate.

    This class can be used as a building block for more complex PyTorch Lightning models, allowing you to easily integrate your PyTorch models into a PyTorch Lightning pipeline.
    """

    def __init__(self, model: Model, learning_rate: float = 0.001) -> None:
        super().__init__()

        self.learning_rate = learning_rate
        # The wrapped module is stored (not copied), so training updates the caller's model in place.
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's logits for the input batch ``x``."""
        return self.model(x)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute and log the cross-entropy loss of one ``(features, labels)`` training batch."""
        # Fetch the data
        features, true_labels = batch
        # Forward prop
        logits = self(features)
        # Compute loss
        loss = F.cross_entropy(logits, true_labels)
        self.log("train_loss", loss)
        return loss  # this is passed to the optimizer for training

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Compute and log (to the progress bar) the cross-entropy loss of one validation batch."""
        # Fetch the data
        features, true_labels = batch
        # Forward prop
        logits = self(features)
        # Compute loss
        loss = F.cross_entropy(logits, true_labels)
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters using ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [7]:
# Put the directory two levels up (the project root, printed below) at the front of
# sys.path so the project's `src` package can be imported from this notebook.
go_up_from_current_directory(go_up=2)


# Project-local helper used below to score the plain PyTorch model on each loader.
from src.utilities import compute_accuracy

/Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch


In [8]:
# Build the MNIST loaders (train/val split is seeded inside get_dataset_loaders).
train_loader, val_loader, test_loader = get_dataset_loaders()

# No global seed is set in this cell, so weight initialisation and shuffle order
# (and therefore the accuracies printed below) vary between runs; see the
# "Add Deterministic Behaviour" section further down.
pytorch_model: Model = PyTorchMLP(num_features=784, num_classes=10)
# LightningModel stores `pytorch_model` itself, so fitting updates `pytorch_model` in place.
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.05)

trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",  # set to "auto" or "gpu" to use GPUs if available
    devices="auto",  # Uses all available GPUs if applicable
)

trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# NOTE(review): compute_accuracy comes from src.utilities; the `*100` formatting
# below presumes it returns a fraction in [0, 1] — TODO confirm.
train_acc: float = compute_accuracy(pytorch_model, train_loader)
val_acc: float = compute_accuracy(pytorch_model, val_loader)
test_acc: float = compute_accuracy(pytorch_model, test_loader)
print(
    f"Train Acc {train_acc*100:.2f}%"
    f" | Val Acc {val_acc*100:.2f}%"
    f" | Test Acc {test_acc*100:.2f}%"
)


# Save only the plain PyTorch weights (state dict); later cells overwrite this same file.
PATH: str = "./lightning.pt"
torch.save(pytorch_model.state_dict(), PATH)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | PyTorchMLP | 40.8 K
-------------------------------------
40.8 K    Trainable params
0         Non-trainable params
40.8 K    Total params
0.163     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


Train Acc 88.63% | Val Acc 87.58% | Test Acc 87.62%


## Add Metrics Using TorchMetrics

In [13]:
import torchmetrics


class LightningModel(L.LightningModule):
    """
    A PyTorch Lightning module that wraps a PyTorch model, providing training and validation steps with
    accuracy tracking (via TorchMetrics), as well as configuring the optimizer.

    The `LightningModel` class takes a PyTorch model and a learning rate as input, and provides the
    following functionality:

        - The `forward` method simply passes the input through the PyTorch model.
        - The `training_step` method computes the cross-entropy loss, logs the training loss, and
          accumulates the training accuracy, which is logged once per epoch.
        - The `validation_step` method computes the cross-entropy loss, logs the validation loss, and
          accumulates and logs the validation accuracy.
        - The `configure_optimizers` method sets up an Adam optimizer with the provided learning rate.

    This class can be used as a building block for more complex PyTorch Lightning models, allowing you to easily integrate your PyTorch models into a PyTorch Lightning pipeline.
    """

    def __init__(
        self, model: Model, learning_rate: float = 0.001, num_classes: int = 10
    ) -> None:
        """
        Args:
        -----
        model: Model
            The PyTorch module to train; it is stored, not copied.
        learning_rate: float, default=0.001
            Learning rate for the Adam optimizer.
        num_classes: int, default=10
            Number of classes the accuracy metrics expect.
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model

        # === NEW! ===
        # Set up attributes for computing accuracy
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's logits for the input batch ``x``."""
        return self.model(x)

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute the loss of one ``(features, labels)`` batch and update the training accuracy."""
        # Fetch the data
        features, true_labels = batch
        # Forward prop
        logits = self(features)
        # Compute loss
        loss = F.cross_entropy(logits, true_labels)
        self.log("train_loss", loss)

        # === NEW! ===
        # Compute training accuracy batch by batch and log
        # it after each epoch. This is much faster!
        predicted_labels = torch.argmax(logits, dim=1)
        self.train_acc(predicted_labels, true_labels)
        self.log(
            "train_acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
        )
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Log the loss of one validation batch and update the validation accuracy."""
        # Fetch the data
        features, true_labels = batch
        # Forward prop
        logits = self(features)
        # Compute loss
        loss = F.cross_entropy(logits, true_labels)
        self.log("val_loss", loss, prog_bar=True)

        # === NEW! ===
        # Compute validation accuracy batch by batch and log
        # it after each epoch. This is much faster!
        predicted_labels = torch.argmax(logits, dim=1)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters using ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [10]:
# Same training setup as before, now with the metrics-enabled LightningModel
# (train_acc / val_acc appear on the progress bar).
train_loader, val_loader, test_loader = get_dataset_loaders()

# No global seed is set here, so results vary from run to run.
pytorch_model: Model = PyTorchMLP(num_features=784, num_classes=10)
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.05)

trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",  # set to "auto" or "gpu" to use GPUs if available
    devices="auto",  # Uses all available GPUs if applicable
)

trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# Overwrites the state dict saved by the previous training cell.
PATH: str = "./lightning.pt"
torch.save(pytorch_model.state_dict(), PATH)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | PyTorchMLP         | 40.8 K
1 | train_acc | MulticlassAccuracy | 0     
2 | val_acc   | MulticlassAccuracy | 0     
-------------------------------------------------
40.8 K    Trainable params
0         Non-trainable params
40.8 K    Total params
0.163     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


### Add a Test Set

In [15]:
class LightningModel(L.LightningModule):
    """
    A PyTorch Lightning module that wraps a PyTorch model, providing training, validation and test
    steps with accuracy tracking (via TorchMetrics), as well as configuring the optimizer.

    The `LightningModel` class takes a PyTorch model and a learning rate as input, and provides the
    following functionality:

        - The `forward` method simply passes the input through the PyTorch model.
        - The `_shared_step` helper runs the forward pass and returns the cross-entropy loss, the true
          labels and the predicted labels of a batch.
        - The `training_step` method logs the training loss and the epoch-level training accuracy.
        - The `validation_step` method logs the validation loss and the validation accuracy.
        - The `test_step` method logs the test accuracy under the key "accuracy".
        - The `configure_optimizers` method sets up an Adam optimizer with the provided learning rate.

    This class can be used as a building block for more complex PyTorch Lightning models, allowing you to easily integrate your PyTorch models into a PyTorch Lightning pipeline.
    """

    def __init__(
        self, model: Model, learning_rate: float = 0.001, num_classes: int = 10
    ) -> None:
        """
        Args:
        -----
        model: Model
            The PyTorch module to train; it is stored, not copied.
        learning_rate: float, default=0.001
            Learning rate for the Adam optimizer.
        num_classes: int, default=10
            Number of classes the accuracy metrics expect.
        """
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model

        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

        # === NEW! ===
        # Set up attributes for computing test accuracy
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's logits for the input batch ``x``."""
        return self.model(x)

    def _shared_step(
        self, batch: tuple[torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one ``(features, labels)`` batch through the model.

        Returns ``(loss, true_labels, predicted_labels)``, where the predictions are the
        arg-max over the class logits.
        """
        # Fetch the data
        features, true_labels = batch
        # Forward prop
        logits = self(features)
        # Compute loss
        loss = F.cross_entropy(logits, true_labels)
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, true_labels, predicted_labels

    def training_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Log the training loss, update the training accuracy and return the loss."""
        # === NEW! ===
        # Use the shared step function
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss)
        self.train_acc(predicted_labels, true_labels)
        self.log(
            "train_acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True
        )
        return loss

    def validation_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Log the validation loss and update the validation accuracy."""
        # === NEW! ===
        # Use the shared step function
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)

    # === NEW! ===
    def test_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> None:
        """Update the test accuracy; it is logged under "accuracy", the key callers read back."""
        _, true_labels, predicted_labels = self._shared_step(batch)

        self.test_acc(predicted_labels, true_labels)
        self.log("accuracy", self.test_acc)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters using ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [12]:
train_loader, val_loader, test_loader = get_dataset_loaders()

# No global seed is set here either, so these accuracies differ from earlier runs.
pytorch_model: Model = PyTorchMLP(num_features=784, num_classes=10)
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.05)

trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",  # set to "auto" or "gpu" to use GPUs if available
    devices="auto",  # Uses all available GPUs if applicable
)

trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# === NEW ===
# Evaluate the model using the test step
# `trainer.test` is called without a model, so (as the log output shows) Lightning
# restores the weights from the checkpoint written by `fit` before each evaluation.
# Each call returns a list with one dict of logged metrics; "accuracy" is the key
# logged by `test_step`.
train_acc = trainer.test(dataloaders=train_loader, verbose=True)[0]["accuracy"]
val_acc = trainer.test(dataloaders=val_loader, verbose=True)[0]["accuracy"]
test_acc = trainer.test(dataloaders=test_loader, verbose=True)[0]["accuracy"]

# === NEW ===
print(
    f"Train Acc {train_acc*100:.2f}%"
    f" | Val Acc {val_acc*100:.2f}%"
    f" | Test Acc {test_acc*100:.2f}%"
)


PATH: str = "./lightning.pt"
torch.save(pytorch_model.state_dict(), PATH)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | PyTorchMLP         | 40.8 K
1 | train_acc | MulticlassAccuracy | 0     
2 | val_acc   | MulticlassAccuracy | 0     
3 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
40.8 K    Trainable params
0         Non-trainable params
40.8 K    Total params
0.163     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_7/checkpoints/epoch=9-step=8600.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_7/checkpoints/epoch=9-step=8600.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_7/checkpoints/epoch=9-step=8600.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_7/checkpoints/epoch=9-step=8600.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_7/checkpoints/epoch=9-step=8600.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_7/checkpoints/epoch=9-step=8600.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Train Acc 83.90% | Val Acc 83.46% | Test Acc 83.48%


<br>

## Load the PyTorch Model

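```py
# To load the saved weights back into a fresh model:
model = PyTorchMLP(num_features=784, num_classes=10)
# `weights_only=True` restricts unpickling to tensors and plain containers (safer for shared files).
model.load_state_dict(torch.load(PATH, weights_only=True))
model.eval()
```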

### Add Deterministic Behaviour To Lightning

In [16]:
train_loader, val_loader, test_loader = get_dataset_loaders()

# === NEW ===
# Seed torch's global RNG after the loaders are built (their train/val split uses its
# own seeded generator) and before the model is created, so weight initialisation
# is reproducible. The training loader's shuffle order presumably also draws from
# this global RNG, since no generator is passed to its DataLoader.
torch.manual_seed(42)
pytorch_model: Model = PyTorchMLP(num_features=784, num_classes=10)
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.05)

trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",  # set to "auto" or "gpu" to use GPUs if available
    devices="auto",  # Uses all available GPUs if applicable
    deterministic=True,  # === NEW ===
)

trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)


# Score each split via `test_step`; the returned dict key "accuracy" is the one it logs.
train_acc = trainer.test(dataloaders=train_loader, verbose=True)[0]["accuracy"]
val_acc = trainer.test(dataloaders=val_loader, verbose=True)[0]["accuracy"]
test_acc = trainer.test(dataloaders=test_loader, verbose=True)[0]["accuracy"]


print(
    f"Train Acc {train_acc*100:.2f}%"
    f" | Val Acc {val_acc*100:.2f}%"
    f" | Test Acc {test_acc*100:.2f}%"
)


PATH: str = "./lightning.pt"
torch.save(pytorch_model.state_dict(), PATH)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | PyTorchMLP         | 40.8 K
1 | train_acc | MulticlassAccuracy | 0     
2 | val_acc   | MulticlassAccuracy | 0     
3 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
40.8 K    Trainable params
0         Non-trainable params
40.8 K    Total params
0.163     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_9/checkpoints/epoch=9-step=8600.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_9/checkpoints/epoch=9-step=8600.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_9/checkpoints/epoch=9-step=8600.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_9/checkpoints/epoch=9-step=8600.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_9/checkpoints/epoch=9-step=8600.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_9/checkpoints/epoch=9-step=8600.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Train Acc 88.33% | Val Acc 87.24% | Test Acc 88.11%


<hr>

## Organizing Code With a LightningDataModule

In [19]:
class MNISTDataModule(L.LightningDataModule):
    """LightningDataModule that serves MNIST train / validation / test / predict loaders.

    The MNIST training set is split into training and validation subsets with a
    generator seeded by ``seed``, so the split is reproducible across runs.
    """

    def __init__(
        self,
        data_dir: str = "../../data/mnist",
        batch_size: int = 64,
        seed: int = 42,
        val_size: int = 5_000,
        num_workers: int = 0,
    ) -> None:
        """
        Args:
        -----
        data_dir: str, default="../../data/mnist"
            Directory where MNIST is downloaded and read from.
        batch_size: int, default=64
            Number of samples per batch for every loader.
        seed: int, default=42
            Seed of the generator used for the train/validation split.
        val_size: int, default=5_000
            Number of training images held out for validation; the rest are used for training.
        num_workers: int, default=0
            Number of worker processes used by each ``DataLoader``.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seed = seed
        self.val_size = val_size
        self.num_workers = num_workers

    def prepare_data(self) -> None:
        """Prepare the dataset by downloading the training and test sets from the internet."""
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str) -> None:
        """Load the datasets from ``data_dir`` and split off the validation subset.

        ``stage`` is part of Lightning's hook signature; every split is built
        regardless of its value, so all ``*_dataloader`` methods work after any call.

        Raises:
        -------
        ValueError
            If ``val_size`` is negative or larger than the training set.
        """
        to_tensor = transforms.ToTensor()
        self.mnist_test = datasets.MNIST(self.data_dir, transform=to_tensor, train=False)
        # The predict split reuses the MNIST test images.
        self.mnist_predict = datasets.MNIST(
            self.data_dir, transform=to_tensor, train=False
        )
        mnist_full = datasets.MNIST(self.data_dir, transform=to_tensor, train=True)

        num_train_images = len(mnist_full)
        if not 0 <= self.val_size <= num_train_images:
            raise ValueError(
                f"val_size must be in [0, {num_train_images}], got {self.val_size}"
            )
        self.mnist_train, self.mnist_val = random_split(
            mnist_full,
            [num_train_images - self.val_size, self.val_size],
            generator=torch.Generator().manual_seed(self.seed),
        )

    def _make_loader(
        self, dataset, shuffle: bool = False, drop_last: bool = False
    ) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self.num_workers,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled training loader; the last incomplete batch is dropped."""
        return self._make_loader(self.mnist_train, shuffle=True, drop_last=True)

    def val_dataloader(self) -> DataLoader:
        """Sequential loader over the held-out validation split."""
        return self._make_loader(self.mnist_val)

    def test_dataloader(self) -> DataLoader:
        """Sequential loader over the MNIST test set."""
        return self._make_loader(self.mnist_test)

    def predict_dataloader(self) -> DataLoader:
        """Sequential loader over the prediction split (the MNIST test images)."""
        return self._make_loader(self.mnist_predict)

In [20]:
torch.manual_seed(42)

# === NEW ===
data_module = MNISTDataModule()

pytorch_model: Model = PyTorchMLP(num_features=784, num_classes=10)
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.05)

trainer = L.Trainer(
    max_epochs=10,
    accelerator="auto",  # set to "auto" or "gpu" to use GPUs if available
    devices="auto",  # Uses all available GPUs if applicable
    deterministic=True,  # === NEW ===
)

# === NEW ===
# During `fit`, Lightning calls `data_module.prepare_data()` and `data_module.setup("fit")`,
# so the split datasets exist as attributes of `data_module` afterwards.
trainer.fit(model=lightning_model, datamodule=data_module)

# Evaluate on the data module's own splits rather than the `train_loader` /
# `val_loader` / `test_loader` globals left over from earlier cells, which made
# this cell silently depend on hidden notebook state.
# `train_dataloader()` shuffles and drops the last partial batch, so a sequential
# loader is used to score every training image exactly once.
train_eval_loader = DataLoader(
    data_module.mnist_train, batch_size=data_module.batch_size, shuffle=False
)
train_acc = trainer.test(dataloaders=train_eval_loader, verbose=True)[0]["accuracy"]
val_acc = trainer.test(dataloaders=data_module.val_dataloader(), verbose=True)[0][
    "accuracy"
]
test_acc = trainer.test(dataloaders=data_module.test_dataloader(), verbose=True)[0][
    "accuracy"
]


print(
    f"Train Acc {train_acc*100:.2f}%"
    f" | Val Acc {val_acc*100:.2f}%"
    f" | Test Acc {test_acc*100:.2f}%"
)


PATH: str = "./lightning.pt"
torch.save(pytorch_model.state_dict(), PATH)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | PyTorchMLP         | 40.8 K
1 | train_acc | MulticlassAccuracy | 0     
2 | val_acc   | MulticlassAccuracy | 0     
3 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
40.8 K    Trainable params
0         Non-trainable params
40.8 K    Total params
0.163     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_11/checkpoints/epoch=9-step=8590.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_11/checkpoints/epoch=9-step=8590.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_11/checkpoints/epoch=9-step=8590.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_11/checkpoints/epoch=9-step=8590.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Restoring states from the checkpoint path at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_11/checkpoints/epoch=9-step=8590.ckpt
Loaded model weights from the checkpoint at /Users/neidu/Desktop/Projects/Personal/My_Projects/Deep-Learning-With-Pytorch/notebook/PyTorchLightning/lightning_logs/version_11/checkpoints/epoch=9-step=8590.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

Train Acc 91.05% | Val Acc 90.68% | Test Acc 91.01%
