# PyTorch Lightning Computer Vision

This notebook presents an overview of PyTorch Lightning and its key features.

## <a id="sec_1">1. Introduction</a>


When building deep learning models with PyTorch, it's easy to end up with complex, cluttered code that mixes model logic, training loops, logging, and checkpointing.
This can make it harder to experiment, debug, and scale.
PyTorch Lightning offers a clean, structured way to organize PyTorch code by abstracting away the boilerplate.
It enables researchers and engineers to focus on what really matters—the model and the research logic—while taking care of engineering concerns like training loops, GPU usage, mixed precision, and distributed training.
By using PyTorch Lightning, you get more readable code, faster iteration, and easier scaling without giving up the flexibility of native PyTorch.

In this tutorial, we will explore the PyTorch Lightning API by reproducing the work done in tutorial `03_pytorch_computer_vision.ipynb`.

### 1.1 PyTorch Lightning

PyTorch Lightning is a high-level framework built on top of PyTorch that simplifies the process of building and training deep learning models.
It introduces a structured, modular design that encourages clean, readable code by separating the science (model architecture and forward pass) from the engineering (training loops, logging, checkpointing, and device management).
With PyTorch Lightning, developers can write less boilerplate and focus more on the core aspects of model development, such as experimentation and tuning.

One of the key advantages of PyTorch Lightning is its scalability and flexibility.
Whether you're training on a single CPU, multiple GPUs, or even using distributed clusters, Lightning handles the complexity under the hood with minimal code changes.
It also integrates seamlessly with popular logging and experiment tracking tools like TensorBoard.



### 1.2 Importing PyTorch Lightning

To import the Lightning module, you must first install it. If you already installed the modules in the `requirements.txt` file, Lightning is already installed.
If not, you can install Lightning with `pip` by running the following command in your terminal:

```bash
pip install lightning
```

We recommend installing all the packages in `requirements.txt`. To do so, run the following command in your terminal:

```bash
pip install -r requirements.txt
```

Now we can import the lightning module.

In [1]:
try:
    import lightning as L
except ImportError:
    try:
        # Try to install it (into the running kernel's environment) and import again
        print("[INFO]: Could not import the lightning module. Trying to install it!")
        %pip install lightning
        import lightning as L
    except ImportError as e:
        raise ImportError("[ERROR] Couldn't find the lightning module ... \n" +
                          "Please install it before running the notebook.\n" +
                          "You might want to install the modules listed in requirements.txt\n" +
                          "To do so, run: \"pip install -r requirements.txt\"") from e

### 1.3 What we're going to cover

In this tutorial, we'll explore how to use Lightning to build and train a model for a computer vision task.
Specifically, we'll reproduce the experiment from the `03_pytorch_computer_vision.ipynb` tutorial using the TinyVGG model.


| **Topic** | **Contents** |
| ----- | ----- |
| [**2. Basic setup**](#sec_2) | Import useful modules (torch, torchvision, and lightning). |
| [**3. Getting the dataset**](#sec_3) | Download the dataset and split the train set into train and validation subsets. |
| [**4. Setting up the DataLoader**](#sec_4) | Create the train, validation, and test data loaders. |
| [**5. Building a PyTorch model**](#sec_5) | Build a TinyVGG model in PyTorch - same as in tutorial `03_pytorch_computer_vision.ipynb`. |
| [**6. Building the Lightning model**](#sec_6) | Build a Lightning model that encapsulates the TinyVGG PyTorch model. |
| [**7. Training the Lightning model**](#sec_7) | Train the Lightning model. |
| [**8. Monitoring performance on validation set and testing the model**](#sec_8) | Modify the Lightning model to track validation statistics during training and to report test statistics during testing. |
| [**9. Exercises**](#sec_9) | Suggested Exercises. |


### 1.4 Where can you get help?

In addition to discussing with your colleagues or the course professor, you might also consider researching or posting PyTorch-related questions on the [PyTorch developer forums](https://discuss.pytorch.org/) and Lightning-related questions on the [PyTorch Lightning forum](https://lightning.ai/forums/).

And of course, there's the [PyTorch documentation](https://pytorch.org/docs/stable/index.html) and [Lightning documentation](https://lightning.ai/docs/overview/getting-started) sites.

## <a id="sec_2">2. Basic setup</a>

Let's import the basic modules, such as lightning, torch, and other utility modules.

In [2]:
# Import PyTorch
import torch

# Import torchvision
import torchvision

# Import lightning
import lightning as L

# Check versions
# Note: your PyTorch version shouldn't be lower than 1.10.0 and torchvision version shouldn't be lower than 0.11
print(f"PyTorch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
print(f"Lightning version: {L.__version__}")

PyTorch version: 2.6.0
torchvision version: 0.21.0
Lightning version: 2.5.0.post0


This time, let's skip selecting the target device (Lightning's `Trainer` selects it automatically, as we'll see later) and just focus on defining our `set_seed()` function. We'll use PyTorch Lightning's `seed_everything()` utility to handle seed setting in a clean and consistent way.

In [3]:
# Auxiliary function to set both the CPU and GPU seeds
def set_seed(seed):
    L.seed_everything(seed)

## <a id='sec_3'>3. Getting the dataset</a>

As in the previous tutorial, we will use `torchvision` to get the FashionMNIST dataset.

In [4]:
# Setup training data
train_dataset = torchvision.datasets.FashionMNIST(
    root="data", # where to download data to?
    train=True, # get training data
    download=True, # download data if it doesn't exist on disk
    transform=torchvision.transforms.ToTensor(), # images come as PIL format, we want to turn into Torch tensors
    target_transform=None # you can transform labels as well
)

# Setup testing data
test_dataset = torchvision.datasets.FashionMNIST(
    root="data", # where to download data to?
    train=False, # get test data
    download=True, # download data if it doesn't exist on disk
    transform=torchvision.transforms.ToTensor(), # images come as PIL format, we want to turn into Torch tensors
    target_transform=None # you can transform labels as well
)

class_names = train_dataset.classes

Alright. Now, let's split `train_dataset` into training and validation sets.

In [5]:
from torch.utils.data import random_split

set_seed(42)

# Define split sizes (e.g., 80% training, 20% validation)
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

# Split the dataset
train_data, val_data = random_split(train_dataset, [train_size, val_size])

Seed set to 42
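The split above depends on the global seed set by `set_seed(42)`. An alternative, sketched here under the assumption of torch >= 1.13 (which accepts fractional lengths), passes a dedicated `torch.Generator` to `random_split`, so the split stays the same even if other code consumes random numbers first. The `TensorDataset` of 100 items and the `split()` helper are hypothetical stand-ins for `train_dataset` and the cell above.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for train_dataset: 100 items
dataset = TensorDataset(torch.arange(100))


def split(ds, seed: int = 42):
    # A private generator makes the split independent of the global RNG state
    generator = torch.Generator().manual_seed(seed)
    return random_split(ds, [0.8, 0.2], generator=generator)


train_a, val_a = split(dataset)
train_b, val_b = split(dataset)
print(len(train_a), len(val_a))            # 80 20
print(train_a.indices == train_b.indices)  # True
```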


## <a id='sec_4'>4. Setting Up the DataLoader</a>

Now that we have the train, validation, and test splits, let's build the data loaders.

In [6]:
DL_BATCH_SIZE = 32
DL_NUM_WORKERS = 15 # Number of worker processes used to load the data and
                    # apply the data transformations.
                    # Change this according to the number of cores in your machine.
                    # Mine has 16, hence I'll use 15 to keep one free.

train_dataloader = torch.utils.data.DataLoader(train_data,
                                           batch_size=DL_BATCH_SIZE,
                                           num_workers=DL_NUM_WORKERS,
                                           drop_last=False,
                                           shuffle=True)

val_dataloader = torch.utils.data.DataLoader(val_data,
                                          batch_size=DL_BATCH_SIZE,
                                          num_workers=DL_NUM_WORKERS,
                                          shuffle=False)

test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=DL_BATCH_SIZE,
                                          num_workers=DL_NUM_WORKERS,
                                          shuffle=False)

# Let's check out what we've created
print(f"Dataloaders: {train_dataloader, val_dataloader, test_dataloader}")
print(f"Length of train dataloader: {len(train_dataloader)} batches of {DL_BATCH_SIZE} samples")
print(f"Length of validation dataloader: {len(val_dataloader)} batches of {DL_BATCH_SIZE} samples")
print(f"Length of test dataloader: {len(test_dataloader)} batches of {DL_BATCH_SIZE} samples")

Dataloaders: (<torch.utils.data.dataloader.DataLoader object at 0x303b40090>, <torch.utils.data.dataloader.DataLoader object at 0x305df1490>, <torch.utils.data.dataloader.DataLoader object at 0x1023485d0>)
Length of train dataloader: 1500 batches of 32 samples
Length of validation dataloader: 375 batches of 32 samples
Length of test dataloader: 313 batches of 32 samples
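When these loaders are later handed to the `Trainer`, Lightning suggests setting `persistent_workers=True` to speed up worker initialization. A minimal sketch of a loader factory that follows that advice is shown below; `make_loader()` is a hypothetical helper, and deriving the worker count from `os.cpu_count()` is an assumption that mirrors the "keep one core free" choice above. PyTorch only accepts `persistent_workers=True` when `num_workers > 0`, hence the guard.

```python
import os

from torch.utils.data import DataLoader


def make_loader(dataset, batch_size: int = 32, shuffle: bool = False) -> DataLoader:
    # Keep one core free, as in the DL_NUM_WORKERS choice above
    num_workers = max((os.cpu_count() or 1) - 1, 0)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Keep worker processes alive between epochs; only valid with num_workers > 0
        persistent_workers=num_workers > 0,
    )
```

For example, `make_loader(train_data, batch_size=DL_BATCH_SIZE, shuffle=True)` would replace the first `DataLoader(...)` call in the cell above.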




Alright. So far, everything looks pretty similar to previous tutorials—and that’s by design.
PyTorch Lightning works seamlessly with standard PyTorch datasets and dataloaders, building directly on top of these familiar abstractions.
> Note: Lightning also provides the `LightningDataModule` class, which is designed to organize the dataloaders. In this tutorial we will not use this abstraction.

## <a id='sec_5'>5. Building a PyTorch model</a>

Let's build the same CNN model as in tutorial `03_pytorch_computer_vision.ipynb`.
The CNN model we're going to be using is known as TinyVGG from the [CNN Explainer](https://poloclub.github.io/cnn-explainer/) website.

Again, it follows the typical structure of a convolutional neural network:

`Input layer -> [Convolutional layer -> activation layer -> pooling layer] -> Output layer`

Where the contents of `[Convolutional layer -> activation layer -> pooling layer]` can be upscaled and repeated multiple times, depending on requirements.

In the previous tutorial we named it `FashionMNISTModelV2`; here, we will name it `TinyVGG`.

In [7]:
import torch.nn as nn

# Create a convolutional neural network
class TinyVGG(nn.Module):
    """
    Model architecture copying TinyVGG from:
    https://poloclub.github.io/cnn-explainer/
    """
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
        super().__init__()
        self.block_1 = nn.Sequential(
            nn.Conv2d(input_shape, hidden_units, 3, 1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, 3, 1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=hidden_units*7*7, out_features=output_shape)
        )

    def forward(self, x: torch.Tensor):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.classifier(x)
        return x

Done!

Once again, not much different from the previous tutorials.
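The `in_features=hidden_units*7*7` of the classifier follows from the spatial sizes: each 3x3 convolution with `padding=1` keeps a 28x28 image at 28x28, and each 2x2 max-pool with stride 2 halves it, so two blocks give 28 → 14 → 7. The sketch below checks this arithmetic with the standard output-size formula and with a dummy forward pass through equivalent layers (assuming `hidden_units=10`, as used later in this notebook).

```python
import torch
import torch.nn as nn


def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Output size of a Conv2d/MaxPool2d layer (dilation=1)
    return (size + 2 * padding - kernel) // stride + 1


size = 28
for _ in range(2):                              # two conv blocks
    size = out_size(size, kernel=3, padding=1)  # 3x3 conv, padding 1: unchanged
    size = out_size(size, kernel=3, padding=1)  # second conv: unchanged
    size = out_size(size, kernel=2, stride=2)   # 2x2 max-pool: halved
print(size)  # 7

# Cross-check with layers equivalent to TinyVGG's block_1 and block_2
features = nn.Sequential(
    nn.Conv2d(1, 10, 3, padding=1), nn.ReLU(),
    nn.Conv2d(10, 10, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(10, 10, 3, padding=1), nn.ReLU(),
    nn.Conv2d(10, 10, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)
out = features(torch.zeros(1, 1, 28, 28))
print(out.shape)  # torch.Size([1, 10, 7, 7])
```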

## <a id="sec_6">6. Building the Lightning model</a>

In a **PyTorch Model**, you typically define just the model architecture in a class that inherits from `nn.Module`.
Training logic, validation, testing, and other utilities are implemented separately, often in a script or training loop, as illustrated in previous tutorials.

The **Lightning Model** includes not only the model architecture (forward), but also the training, validation, and test steps, as well as optimizer configuration—all neatly encapsulated in a single class.
This additional information is used by a Trainer object (more on this later), which takes care of training, validating, and testing the model.

Let’s create a simple Lightning model that encapsulates our TinyVGG architecture, and also defines the optimizer to be used during training, along with the logic for performing each training step.


In [8]:
# The Lightning model extends the L.LightningModule, instead of torch.nn.Module
class TinyVGG_LightningModel_V1(L.LightningModule):
    # You might define the init method as you like
    def __init__(self, input_shape: int, hidden_units: int, num_classes: int, learning_rate: float = 0.1):
        super().__init__()
        # We’re creating the model architecture directly in the __init__ method,
        # but it’s also common practice to receive the model as an argument instead.
        # This approach adds flexibility, allowing you to easily swap in different
        # architectures without modifying the Lightning module itself.
        self.model = TinyVGG(input_shape=input_shape,
                             hidden_units=hidden_units,
                             output_shape=num_classes)
        # Define the loss function
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    # The forward method defines how input data passes through the model.
    # It's the same as in standard torch.nn.Module classes.
    def forward(self, x):
        return self.model(x)

    # This method is called by the training loop. It allows you to define how each
    # training step must be performed. In this case, we will compute the loss, log it
    # using the self.log() method and then return the loss.
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)   # This calls self.forward(x) under the hood
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)
        return loss

    # This method will be invoked by the trainer object to configure the optimizer.
    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)


As shown in the previous code block, a Lightning model class extends the `lightning.LightningModule` class, rather than the `nn.Module` class used in standard PyTorch.

Like typical Python classes, it includes a constructor method (`__init__`), which is called when the object is instantiated. 
It also implements a `forward()` method, which mirrors the behavior of PyTorch's `forward()` method: it receives a sample input, runs the model on it, and returns the output.

In addition, the model class defines two important methods:

* `training_step()`: This method allows for customization of the loss computation during training, which can vary depending on the model. 
  It is called by the trainer during each step of the training loop and is where you define the logic to calculate and return the loss.

* `configure_optimizers()`: This method specifies which optimizer(s) and learning rate scheduler(s) to use for training. 
  In this example, it returns a PyTorch SGD optimizer instance. 
  The trainer calls this method before the training loop begins.

Now that the model is defined, let's proceed to train it.

## <a id="sec_7">7. Training the Lightning model</a>

To train the model, all we need to do is:
(i) instantiate the model,
(ii) create a `Trainer` object using the `lightning.Trainer` class, and
(iii) call `.fit()` with the training data.

Let’s get started!

In [9]:
# Let's set the seed for determinism
set_seed(42)

# Instantiate the model
model = TinyVGG_LightningModel_V1(1, 10, len(class_names))

# Instantiate the trainer object
trainer = L.Trainer(max_epochs=5)

# Train the model using the training data loader
trainer.fit(model, train_dataloader)

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | TinyVGG          | 7.7 K  | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
7.7 K     Trainable params
0         Non-trainable params
7.7 K     Total params
0.031     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode
/opt/homebrew/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


`Trainer.fit` stopped: `max_epochs=5` reached.


What?! It automatically detected my GPU and used it to train the model? Nice!!! 😄

The Trainer object called the `configure_optimizers()` method from our Lightning model to set up the optimizer.

It also used the model’s `training_step()` method to compute the loss at each step.
Notice that this method doesn’t include any calls to reset gradients, perform backpropagation, or update the optimizer—that’s all handled automatically by the Trainer under the hood.

Under the hood, Lightning does the following (pseudocode):

```python
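# Pseudocode: roughly what trainer.fit(model, train_dataloader) does
optimizer = model.configure_optimizers()  # called once, before the training loop

model.train()                 # put the model in training mode
torch.set_grad_enabled(True)  # enable gradient calculation

for epoch in range(max_epochs):
    for batch_idx, batch in enumerate(train_dataloader):
        # (Lightning also moves the batch to the selected device here)

        # Invoke the Lightning model's training_step() method to compute the loss
        loss = model.training_step(batch, batch_idx)

        # clear gradients
        optimizer.zero_grad()

        # backward
        loss.backward()

        # update parameters
        optimizer.step()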
```

### 7.1 Monitoring training statistics with TensorBoard

As discussed in our previous tutorials, it's often helpful to track training statistics—such as training and validation loss—throughout the training process.

PyTorch Lightning provides several built-in logger classes that can be integrated with the `Trainer` to track training and evaluation metrics.
By default, it uses the `TensorBoardLogger`, which logs data in a format compatible with `TensorBoard`. 
This logger automatically creates a directory named `lightning_logs/` to store the logs.
Each time a new `Trainer` runs, a new subdirectory (`version_0`, `version_1`, and so on) is created within `lightning_logs/`.
This subfolder contains all the information logged via the `self.log()` method, as well as any model checkpoints (we’ll discuss checkpoints in more detail later).

After running the previous code block, you should see a `lightning_logs/` folder in your working directory.

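A common way to view these logs is TensorBoard, a tool that reads the contents of these folders and displays the data in clear, interactive charts. You can launch it from a terminal with `tensorboard --logdir lightning_logs` and open the address it prints (by default, http://localhost:6006).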

The following image shows the TensorBoard dashboard with statistics collected by the previous training process.
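The second chart shows the training loss recorded by the `self.log()` call in `training_step()`. Note that, by default, the `Trainer` writes step-level values every 50 training steps (see its `log_every_n_steps` argument), not at every single step.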

![Training statistics on TensorBoard](https://raw.githubusercontent.com/eborin/SSL-course/main/images/06-tensorboard-1.png)

If you want to monitor training performance in real time while the `Trainer` is running, you can simply click the reload icon in the `TensorBoard` interface. 
This forces `TensorBoard` to refresh the view by reloading the log files and updating the charts accordingly.

For more details on using TensorBoard, visit the official [TensorBoard website](https://www.tensorflow.org/tensorboard).

## <a id="sec_8">8. Monitoring performance on validation set and testing the model</a>

To monitor the performance on the validation set or to evaluate the model on a test set, we need to tell the trainer how to perform these steps.
This is done the same way we told the trainer how to compute the training loss: we implement (override) methods of the Lightning model.

The API, along with tutorials on how to make these changes, can be found on the [Lightning API references page](https://lightning.ai/docs/pytorch/stable/api_references.html).

There are also several [how-to guides](https://lightning.ai/docs/pytorch/stable/common/) covering common tasks with PyTorch Lightning.

The [How to Organize PyTorch Into Lightning](https://lightning.ai/docs/pytorch/stable/starter/converting.html) page provides a helpful step-by-step guide to converting a PyTorch model into a Lightning model.
I relied on this page to add validation and test steps to our Lightning model.

Let's copy and paste the previous `TinyVGG_LightningModel_V1` class and extend it with code to compute the validation and test losses.

In [10]:
# The Lightning model extends the L.LightningModule, instead of torch.nn.Module
class TinyVGG_LightningModel_V2(L.LightningModule):
    # You might define the init method as you like
    def __init__(self, input_shape: int, hidden_units: int, num_classes: int, learning_rate: float = 0.1):
        super().__init__()
        # We’re creating the model architecture directly in the __init__ method,
        # but it’s also common practice to receive the model as an argument instead.
        # This approach adds flexibility, allowing you to easily swap in different
        # architectures without modifying the Lightning module itself.
        self.model = TinyVGG(input_shape=input_shape,
                             hidden_units=hidden_units,
                             output_shape=num_classes)
        # Define the loss function
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    # The forward method defines how input data passes through the model.
    # It's the same as in standard torch.nn.Module classes.
    def forward(self, x):
        return self.model(x)

    # This method is called by the training loop. It allows you to define how each
    # training step must be performed. In this case, we will compute the loss, log it
    # using the self.log() method and then return the loss.
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)   # This calls self.forward(x) under the hood
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)
        return loss

    # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
    # Changes from previous class
    # --------------------------------------------------------------------------------
    # This method is called by the validation loop, which is executed at the end of each
    # training epoch. It allows you to define how each validation step must be performed.
    # In this case, it is very similar to the training step; the only difference is that
    # we register the loss using the "val_loss" identifier.
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)   # This calls self.forward(x) under the hood
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss)
        return loss

    # This method is called by the test loop, which is run by trainer.test().
    # It allows you to define how each test step must be performed.
    # Again, in this case, it is very similar to the training and validation steps,
    # the only difference is that we register the loss using the "test_loss" identifier.
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)   # This calls self.forward(x) under the hood
        loss = self.loss_fn(logits, y)
        self.log("test_loss", loss)
        return loss

    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    # This method will be invoked by the trainer object to configure the optimizer.
    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

Ok. Now, let's train this model for 5 epochs.

In [11]:
# Let's set the seed for determinism
set_seed(42)

# Instantiate the model again
model_v2 = TinyVGG_LightningModel_V2(1, 10, len(class_names))

# Instantiate the trainer object
trainer = L.Trainer(max_epochs=5)

trainer.fit(model_v2, train_dataloader, val_dataloader)

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | TinyVGG          | 7.7 K  | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
7.7 K     Trainable params
0         Non-trainable params
7.7 K     Total params
0.031     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode


/opt/homebrew/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


`Trainer.fit` stopped: `max_epochs=5` reached.


This new training process created a new folder inside `lightning_logs/`.

If you open TensorBoard, you will notice it has logs for two training processes: `version_0` and `version_1`.
You can plot and compare the `train_loss` of both training processes.
You can also check the validation loss for `version_1`, since the model trained in this run logs it in its `validation_step()` method.

The following image shows the TensorBoard dashboard with statistics collected by both training processes.

![Training statistics on TensorBoard](https://raw.githubusercontent.com/eborin/SSL-course/main/images/06-tensorboard-2.png)

We can also test the model's performance using the `trainer.test()` method.
To do so, we use the `test_dataloader` as follows:

In [12]:
trainer.test(model_v2, test_dataloader)

/opt/homebrew/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


[{'test_loss': 0.3259245753288269}]

The `test_step()` method logs the loss (and any other metrics) for each batch.
The `Trainer.test()` method then aggregates these values—typically by computing the epoch-level average across all test batches—and returns the final results as a dictionary.
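To our understanding, with the default `reduce_fx="mean"` of `self.log()`, Lightning weights each batch value by its batch size, so the reported number equals the mean over all test samples even though the last batch is smaller (10,000 test images with batch size 32 give 312 full batches plus one batch of 16). The sketch below reproduces that arithmetic with synthetic per-sample losses; it is an illustration, not Lightning's internal code.

```python
import torch

torch.manual_seed(0)
per_sample_loss = torch.rand(10_000, dtype=torch.float64)  # synthetic per-sample losses
batches = torch.split(per_sample_loss, 32)                  # 312 batches of 32 + 1 of 16

batch_means = torch.stack([b.mean() for b in batches])      # what each test_step would log
batch_sizes = torch.tensor([len(b) for b in batches], dtype=torch.float64)

# Batch-size-weighted average of the per-batch values
weighted = (batch_means * batch_sizes).sum() / batch_sizes.sum()

print(len(batches))                                         # 313
print(torch.isclose(weighted, per_sample_loss.mean()).item())  # True
```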

The output of the loss function isn't typically used as a primary metric for evaluating model performance. 
Instead, metrics like Accuracy, F1-score, Dice coefficient, and others are more commonly used, depending on the task.

Next, let's modify our Lightning model to compute and log additional evaluation metrics.

In [13]:
# The Lightning model extends the L.LightningModule, instead of torch.nn.Module
class TinyVGG_LightningModel_V3(L.LightningModule):
    # You might define the init method as you like
    def __init__(self, input_shape: int, hidden_units: int, num_classes: int, learning_rate: float = 0.1):
        super().__init__()
        # We’re creating the model architecture directly in the __init__ method,
        # but it’s also common practice to receive the model as an argument instead.
        # This approach adds flexibility, allowing you to easily swap in different
        # architectures without modifying the Lightning module itself.
        self.model = TinyVGG(input_shape=input_shape,
                             hidden_units=hidden_units,
                             output_shape=num_classes)
        # Define the loss function
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    # The forward method defines how input data passes through the model.
    # It's the same as in standard torch.nn.Module classes.
    def forward(self, x):
        return self.model(x)

    # This method is called by the training loop. It allows you to define how each
    # training step must be performed. In this case, we will compute the loss, log it
    # using the self.log() method and then return the loss.
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)   # This calls self.forward(x) under the hood
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)
        return loss

    # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
    # Changes from previous class
    # --------------------------------------------------------------------------------
    def compute_accuracy(self, y_pred, y_true):
        correct = torch.eq(y_true, y_pred).sum().item()
        acc = (correct / len(y_pred)) * 100
        return acc

    # This method is called by the validation loop, which is executed at the end of each
    # training epoch. It allows you to define how each validation step must be performed.
    # In addition to the loss (logged as "val_loss"), it now computes and logs the
    # accuracy (as "val_acc"), showing both values on the progress bar.
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)   # This calls self.forward(x) under the hood
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        # Compute and log the accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.compute_accuracy(preds, y)
        self.log("val_acc", acc, prog_bar=True)
        return {"val_acc":acc, "val_loss": loss}

    # This method is called by the test loop, which is run by trainer.test().
    # It allows you to define how each test step must be performed.
    # Like validation_step(), it logs both the loss ("test_loss") and the
    # accuracy ("test_acc").
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)   # This calls self.forward(x) under the hood
        # Compute and log the loss
        loss = self.loss_fn(logits, y)
        self.log("test_loss", loss, prog_bar=True) # Also show it on the progress bar
        # Compute and log the accuracy
        preds = torch.argmax(logits, dim=1)
        acc = self.compute_accuracy(preds, y)
        self.log("test_acc", acc, prog_bar=True)
        return {"test_acc":acc, "test_loss": loss}
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

    # This method will be invoked by the trainer object to configure the optimizer.
    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)



Ok. Let's try again, this time training for 20 epochs, as we did in tutorial `03_pytorch_computer_vision.ipynb`.

In [14]:
# Let's set the seed for determinism
set_seed(42)

# Instantiate the model again
model_v3 = TinyVGG_LightningModel_V3(1, 10, len(class_names))

# Instantiate the trainer object
trainer = L.Trainer(max_epochs=20)

# Train the model
trainer.fit(model_v3, train_dataloader, val_dataloader)

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | model   | TinyVGG          | 7.7 K  | train
1 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
7.7 K     Trainable params
0         Non-trainable params
7.7 K     Total params
0.031     Total estimated model params size (MB)
17        Modules in train mode
0         Modules in eval mode


`Trainer.fit` stopped: `max_epochs=20` reached.
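
The cell above calls `set_seed(42)` right before creating `model_v3`. `set_seed` is defined earlier in the notebook, so its exact implementation is not repeated here. The stand-alone sketch below uses `torch.manual_seed` directly, and a hypothetical `make_model()` that builds a tiny `nn.Linear` in place of TinyVGG. It illustrates why the seed is set just before instantiation: the same seed yields the same initial weights, while constructing the model again without re-seeding does not.

```python
import torch


def make_model() -> torch.nn.Module:
    # Hypothetical stand-in for TinyVGG_LightningModel_V3: any randomly initialized module
    return torch.nn.Linear(4, 2)


torch.manual_seed(42)
model_a = make_model()

torch.manual_seed(42)
model_b = make_model()

# Same seed right before construction -> identical initial weights
same_init = all(torch.equal(p, q) for p, q in zip(model_a.parameters(), model_b.parameters()))
print(same_init)  # True

# No re-seed: the random number generator state has moved on, so the weights differ
model_c = make_model()
differs = not torch.equal(model_a.weight, model_c.weight)
print(differs)  # True
```

Seeding alone may not give bit-identical training runs on GPU or MPS backends. Lightning's `Trainer(deterministic=True)` option additionally requests deterministic algorithms where PyTorch provides them.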


The validation accuracy looks very similar to what we achieved with the CNN-based model in the `03_pytorch_computer_vision.ipynb` tutorial.

To get a more detailed view, you can use TensorBoard to explore the validation loss and accuracy curves over time.

The image below shows the TensorBoard dashboard on my machine after running all the previous code blocks.
Notice that the third version includes logs for validation accuracy across the epochs.

![Training statistics on TensorBoard - v3](https://raw.githubusercontent.com/eborin/SSL-course/main/images/06-tensorboard-3.png)

Now, let's evaluate the model's performance on the test set.

In [15]:
# Test the model
trainer.test(model_v3, test_dataloader)

[{'test_loss': 0.29931285977363586, 'test_acc': 89.4800033569336}]

Nice: this result is also very similar to what we saw with the CNN-based model in the `03_pytorch_computer_vision.ipynb` tutorial.
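
The numbers in the list returned by `trainer.test()` come from the `self.log(...)` calls in `test_step()`. By default, Lightning accumulates each logged per-batch value and reduces them at the end of the epoch with a mean weighted by batch size. The batch size is inferred from the batch, or it can be passed as `batch_size=`.

The sketch below reproduces that reduction in plain Python on made-up batch accuracies; the values are hypothetical and not taken from the run above. It shows why weighting matters when the last batch is smaller: the weighted mean equals the accuracy over all samples, while a plain average of batch accuracies does not.

```python
# Hypothetical per-batch test results: (batch_size, batch accuracy in %).
# The last batch is smaller, as happens when the test set size is not a
# multiple of the batch size.
batches = [(32, 90.625), (32, 87.5), (16, 93.75)]

total_samples = sum(size for size, _ in batches)

# Batch-size-weighted mean: equals the accuracy over all 80 samples (72 correct)
weighted_mean = sum(size * acc for size, acc in batches) / total_samples

# Unweighted mean of batch accuracies: gives the small last batch too much weight
plain_mean = sum(acc for _, acc in batches) / len(batches)

print(f"weighted mean: {weighted_mean}")  # weighted mean: 90.0
print(f"plain mean:    {plain_mean}")     # plain mean:    90.625
```
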

## <a id="sec_9">9. Exercises</a>

While there is no designated exercise for this tutorial, the following activities can provide valuable insights into factors that influence model performance:

* **Prediction**: Use the Trainer object to perform predictions with your Lightning model.
To enable this, implement the `predict_step()` method in your LightningModule, following the guidelines in the [LightningModule prediction loop documentation](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html). Use the modified model to predict the class labels of the samples in the test dataset.

* **Logger**: The default logger for Lightning is the [TensorBoard logger](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.tensorboard.html#module-lightning.pytorch.loggers.tensorboard); however, you can use [different loggers](https://lightning.ai/docs/pytorch/stable/api_references.html#loggers) to record your statistics.
Try to modify the previous code to log your statistics using a different logger (e.g., the [CSV logger](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.csv_logs.html#module-lightning.pytorch.loggers.csv_logs)).

* **Checkpoints**: By default, Lightning only saves a single checkpoint with the state of the last training epoch.
Beyond that, Lightning provides a flexible API for model checkpointing.
For example, you can use the [ModelCheckpoint](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint) callback to configure the trainer to save the model periodically by monitoring a specific metric.
(Any metric logged using `log()` or `log_dict()` can be used as a monitor key.)
Try modifying the previous code to save a checkpoint whenever the validation loss improves, and load this best checkpoint before evaluating the model on the test set.

* **Early stopping**: Explore the [EarlyStopping callback](https://lightning.ai/docs/pytorch/stable/common/early_stopping.html), and modify the previous code to train for up to 100 epochs, but stop early if the validation accuracy does not improve for 5 consecutive epochs.

* Review the [Level Up documentation](https://lightning.ai/docs/pytorch/stable/expertise_levels.html) and see if you can pick up any new techniques to apply to the code you wrote earlier.
