## MNIST image classification with PyTorchLightning

In [32]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd


import torch

import pytorch_lightning as pl

import torchmetrics

import torchvision

In [33]:
# To install pytorch lightning:
# conda install lightning -c conda-forge

# Define a model

In [34]:
# In the following, we will use the accuracy metric from torchmetrics
# Here is its most basic use
acc = torchmetrics.Accuracy(task="multiclass", num_classes=3)
y = torch.Tensor([0,1,2])
y_hat = torch.Tensor([0,0,0])
print(acc(y, y_hat))
# Here our prediction is of float rather than int. 
# torchmetrics interprets this as logits and automatically apply softmax
acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
y = torch.Tensor([[1.0, 2.0], [1.0, -1.0], [-2.0, -1.0]])
# According to this y, the prediction should be [1, 0, 1], the accuracy should be 1/3
pred = torch.argmax(y, dim=1)
print(pred)
y_hat = torch.Tensor([0,0,0])
print(acc(y, y_hat))
print(acc(pred, y_hat))

tensor(0.3333)
tensor([1, 0, 1])
tensor(0.3333)
tensor(0.3333)


In [35]:
# Another method to calculate accuracy, which is implemented in the following codes
# The key steps are to use update() and compute()
acc = torchmetrics.Accuracy(task="multiclass", num_classes=3)
y = torch.Tensor([0,1,2])
y_hat = torch.Tensor([0,0,0])
acc.update(y, y_hat)
print(acc.compute())

tensor(0.3333)


In [36]:
# The benefit of using update() and compute() is that 
# we can accumulate the true labels over several iteration
# Imagin that there are two batches, each with 3 examples
# For the first batch, only one example is predicted correctly.
# For the second batch, two examples are predicted correctly.
# So the overall accuracy is (1+2)/(3+3)=0.5
# This can be achieved by using update() for each batch and use compute() at the end
acc = torchmetrics.Accuracy(task="multiclass", num_classes=3)
y = torch.Tensor([0,1,2])
y_hat = torch.Tensor([0,0,0])
acc.update(y, y_hat)
y = torch.Tensor([0,0,2])
y_hat = torch.Tensor([0,0,0])
acc.update(y, y_hat)
print(acc.compute())

tensor(0.5000)


In [37]:
# Another basic python utility used below is the extraction of arguments for a function
def my_func(a,b,c):
    print(a,b,c)
    
x = [1,2,3]
my_func(*x)
# By using *x, the entries in x is extracted to pass to my_func
# The number of elements in x must be the same as the number of arguments of my_func

# A related utility is to put * in the function definition
# In this case, the function accepts any number of inputs
def my_func2(*args):
    print(args)
    
my_func2(1,2,3)
my_func2(1,2)

# We can use the two methods together to achieve arbitrary number of inputs when these inputs are in a list.
x = [1,2,3]
my_func2(*x)

1 2 3
(1, 2, 3)
(1, 2)
(1, 2, 3)


In [38]:
class MultiLayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape=(1,28,28), hidden_units=(32,16)):
        super().__init__()
        
        # Define three accuracies to keep track of the training process
        self.accuracy_train = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.accuracy_valid = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.accuracy_test = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        
        input_size = image_shape[0] * image_shape[1] * image_shape[2]
        
        # Define layers
        # Except the output layer, each linear layer is followed by a ReLU layer
        all_layers = []
        all_layers.append(torch.nn.Flatten())
        for num_unit in hidden_units:
            all_layers.append(torch.nn.Linear(input_size, num_unit))
            all_layers.append(torch.nn.ReLU())
            input_size = num_unit
        all_layers.append(torch.nn.Linear(hidden_units[-1], 10))
        self.model = torch.nn.Sequential(*all_layers)
        # This is essentially the same as Sequential(layer1, layer2, ...)
        
    def forward(self, X):
        # There is a forward method in pl.LightningModule.
        # We use self.model(X) to create an instance and call the forward method in pl.LightningModule.
        y = self.model(X)
        return y
    
    def training_step(self, batch, batch_idx):
        # The most important step in this function is to calculate the loss function
        # This naturally requires calculation of prediction, or at least computing the logits
        # When using a loss function in torch.nn, pay attention to the shapes of its input
        X, y = batch
        # Logits are outputs of the network before being transformed by softmax
        logits = self(X)
        loss = torch.nn.functional.cross_entropy(self(X), y)
        pred = torch.argmax(logits, dim=1)
        self.accuracy_train.update(pred, y)
        # self.log() is a method in PyTorch Lightning that logs a given metric value
        # The self.log() method takes two arguments:
        # name (string): The name of the metric to log.
        # value (tensor or float): The value of the metric to log.
        # prog_bar=True: the metric value will be displayed on the progress bar 
        # during training and validation.
        self.log("train_loss", loss, prog_bar=True) #??
        return loss
    
    def on_train_epoch_end(self):  

        # on_train_epoch_end is a method in PyTorch Lightning 
        # that is called at the end of each training epoch. 
        # It allows you to perform some operations on the model 
        # and the training data after each epoch has completed.
        
        # Here, we calculate the training accuracy accumulated over all batches.
        # Note that we don't specify prog_bar, so it takes it default value (False)
        self.log("train_acc", self.accuracy_train.compute())
    
    def validation_step(self, batch, batch_idx):
        # This is almost the same as the training step
        X, y = batch
        logits = self(X)
        loss = torch.nn.functional.cross_entropy(self(X), y)
        pred = torch.argmax(logits, dim=1)
        self.accuracy_valid.update(pred, y)
        # The validation loss and accuracy are logged here
        self.log("valid_loss", loss, prog_bar=True)
        # There are no batches for the validatio step, so we directly compute the accuracy
        self.log("valid_acc", self.accuracy_valid.compute(), prog_bar=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        # Similar as validation_step()
        X, y = batch
        logits = self(X)
        loss = torch.nn.functional.cross_entropy(self(X), y)
        pred = torch.argmax(logits, dim=1)
        self.accuracy_test.update(pred, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.accuracy_test.compute(), prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        # Define an optimizer
        # Note that the name is 'configure_optimizers', not 'configure_optimizer'!!
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer
        

In [39]:
# Test the model output
model_test = MultiLayerPerceptron()
# Generate data that mimic 2 images, each of size 1-28-28
images = torch.randn(2, 1, 28, 28)
# The output for each image contains 10 logits.
model_test(images).shape

torch.Size([2, 10])

In [40]:
# Test the loss function
y = torch.tensor([1,2])
logits = model_test(images)
loss = torch.nn.functional.cross_entropy(logits, y)
print(loss)

tensor(2.4901, grad_fn=<NllLossBackward0>)


In [41]:
# Test the prediction
pred = torch.argmax(logits, dim=1)
print(pred)

tensor([6, 6])


# Prepare data

In [42]:
class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path='./data'):
        super().__init__()
        self.data_path = data_path
        self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
        
    def prepare_data(self):
        # Note that this following result is not assigned to any variable
        # Its function is only to download the dataset.
        torchvision.datasets.MNIST(root=self.data_path)
    
    def setup(self, stage=None):
        # stage = 'fit' or 'validate' or 'test' or 'predict'
        mnist_all = torchvision.datasets.MNIST(
            root=self.data_path,
            train=True,
            transform=self.transform,
            download=False
        )
        
        # mnist_all contains the training and validation datasets
        # So we split it. The two datasets have 55000 and 5000 examples respectively.
        # To randomly split the data, we need a generator
        # We create the generator by using torch.Generator() 
        # and use its manual_seed method to get reproducible result.
        self.train, self.val = torch.utils.data.random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )
        
        # Create the test set
        self.test = torchvision.datasets.MNIST(
            root=self.data_path,
            train=False,
            transform=self.transform,
            download=False
        )
        
    def train_dataloader(self): 
        # By using the num_workers=4, 
        # the data loading will be performed in parallel in 4 subprocesses.
        return torch.utils.data.DataLoader(self.train, batch_size=64, num_workers=4)
        
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val, batch_size=64, num_workers=4)
        
    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test, batch_size=64, num_workers=4)

In [43]:
torch.manual_seed(1)
mnist_dm = MnistDataModule()

In [44]:
mnist_classifier = MultiLayerPerceptron()

In [45]:
if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=20, accelerator="gpu", devices=1)
else:
    trainer = pl.Trainer(max_epochs=20)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [46]:
trainer.fit(model=mnist_classifier, datamodule=mnist_dm)

You are using a CUDA device ('NVIDIA GeForce RTX 2050') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | accuracy_train | MulticlassAccuracy | 0     
1 | accuracy_valid | MulticlassAccuracy | 0     
2 | accuracy_test  | MulticlassAccuracy | 0     
3 | model          | Sequential         | 25.8 K
------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


## Tensor Board

In [54]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 15148), started 0:24:07 ago. (Use '!kill 15148' to kill it.)

In [58]:
mnist_classifier_continue = MultiLayerPerceptron.load_from_checkpoint('./lightning_logs/version_9/checkpoints/epoch=19-step=17200.ckpt')

In [59]:
if torch.cuda.is_available():
    trainer = pl.Trainer(max_epochs=5, accelerator="gpu", devices=1)
else:
    trainer = pl.Trainer(max_epochs=5)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [60]:
trainer.fit(model=mnist_classifier_continue, datamodule=mnist_dm)

You are using a CUDA device ('NVIDIA GeForce RTX 2050') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | accuracy_train | MulticlassAccuracy | 0     
1 | accuracy_valid | MulticlassAccuracy | 0     
2 | accuracy_test  | MulticlassAccuracy | 0     
3 | model          | Sequential         | 25.8 K
------------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [62]:
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/

Reusing TensorBoard on port 6006 (pid 15148), started 0:35:57 ago. (Use '!kill 15148' to kill it.)

In [63]:
# Model performance on the test set
trainer.test(model=mnist_classifier_continue, datamodule=mnist_dm)

You are using a CUDA device ('NVIDIA GeForce RTX 2050') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9504699110984802
        test_loss           0.17379100620746613
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.17379100620746613, 'test_acc': 0.9504699110984802}]