<div class='heading'>
    <div style='float:left;'><h1>CPSC 4300/6300: Applied Data Science</h1></div>
     <img style="float: right; padding-right: 10px" width="100" src="https://raw.githubusercontent.com/bsethwalker/clemson-cs4300/main/images/clemson_paw.png"> </div>
     </div>

**Clemson University**<br>
**Instructor(s):** Aaron Masino <br>

## Lab 7: Introduction to Neural Networks with PyTorch & PyTorch Lightning

In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
import torchvision.transforms.functional as F 

import lightning as L
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

import torchmetrics as TM

from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

In [None]:
def create_data_directory(path: str):
    if not os.path.exists(path):
        os.makedirs(path)

dir_dataroot = os.path.join("..", "data")
create_data_directory(dir_dataroot)

dir_lightning = os.path.join("..", "lightning")
create_data_directory(dir_lightning)

SEED = 123456

# Learning Goals

This lab will introduce you to the creation, training, and evaluation of deep learning neural network models using the [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/pytorch-lightning) libraries. PyTorch contains core capabilities related to the development of deep learning models. PyTorch Lightning provides functionality that abstracts much of the process of training and evaluating deep learning models created with PyTorch.

By the end of this lab, you should be able to:
- Create PyTorch Tensor objects using numpy arrays and built-in PyTorch functions
- Apply element-wise operations on PyTorch Tensors
- Apply matrix operations on PyTorch Tensors
- Explain the use of CPU and GPU processing on Tensors
- Create PyTorch DataSet and DataLoader objects
- Create PyTorch Lightning DataModule objects
- Create PyTorch modules in the context of neural networks
- Create deep learning models by composing PyTorch Modules
- Apply the PyTorch LightningModule class to encapsulate deep learning model training and prediction functionality
- Apply the PyTorch Lightning Trainer class to train a deep learning model 
- Apply the TorchMetrics library to collect deep learning model performance data
- Evaluate deep learning model performance using visualization and classification metrics

# Part 1 Tensors
The primary data structure used in PyTorch is the [Tensor](https://pytorch.org/docs/stable/tensors.html). `Tensor` objects are analogous to numpy arrays. Indeed, they have many of the same attributes, as shown in the code cell below. One attribute that distinguishes PyTorch Tensors from numpy arrays is `device`, which tells us whether the Tensor's memory resides in CPU RAM or on a GPU. Typically, our Python code is executed on our computer's CPU (central processing unit). However, deep learning involves a very large number of matrix multiplications, and GPUs (graphics processing units) can perform these much faster than CPUs, often by orders of magnitude. Hence, it is advantageous to move our Tensors to the GPU during model training and inference. PyTorch makes this process very easy.

In [None]:
# make a 3 dimensional tensor
x = [[[1, 2], [3, 4]], 
     [[5, 6], [7, 8]],
     [[9, 10], [11, 12]]]
t1 = torch.tensor(x)

# tensor attributes
print("t.size() = ", t1.size())
print("t.shape = ", t1.shape)
print("t.dim() = ", t1.dim())
print("t.dtype = ", t1.dtype)
print("t.device = ", t1.device)


## From numpy and back again

PyTorch Tensors can be created from numpy arrays. The `Tensor` method, `numpy` will create a numpy array from the `Tensor` values.

In [None]:
x = np.random.rand(4,3)
t1 = torch.tensor(x)
print("t.dtype = ", t1.dtype)
print("t.shape = ", t1.shape)    

y = t1.numpy()
print("y.dtype = ", y.dtype)
print("y.shape = ", y.shape)
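
A detail worth knowing about these conversions: `torch.tensor` copies the numpy data. By contrast, `torch.from_numpy` shares memory with the source array, and so does `Tensor.numpy` on a CPU Tensor in the other direction. Changes made through one object are therefore visible in the other. The sketch below shows this with a small, arbitrary array. It also shows that numpy's default `float64` dtype carries over to the Tensor. Tensors created directly by PyTorch functions such as `torch.rand` default to `float32` instead.

```python
import numpy as np
import torch

x = np.zeros(3)                  # numpy defaults to float64
t_copy = torch.tensor(x)         # copies the data into new Tensor memory
t_shared = torch.from_numpy(x)   # shares memory with x (no copy)

x[0] = 5.0                       # modify the numpy array in place
print(t_copy[0].item())          # the copy is unaffected
print(t_shared[0].item())        # the shared Tensor sees the change

y = t_shared.numpy()             # for a CPU Tensor, numpy() also shares memory
y[1] = 7.0
print(x)                         # x reflects the change made through y
print(t_shared.dtype)            # torch.float64
print(torch.rand(2).dtype)       # torch.float32 (PyTorch's default float type)
```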

## Directly from PyTorch
PyTorch includes methods for creating Tensors of arbitrary dimension and size, similar to numpy. These include `ones`, `zeros`, and several random sampling methods. For example, `rand` samples values from a uniform distribution on $\left[0,1\right)$, and `randn` samples from the normal distribution with zero mean and unit standard deviation. For more details, see [PyTorch random sampling](https://pytorch.org/docs/main/torch.html#random-sampling).

In [None]:
t1 = torch.ones(2,1)
print(t1)

t1 = torch.zeros(2,2)
print(t1)

# uniform random numbers from [0,1)
t1 = torch.rand(3,2)
print(t1)
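
Because these methods draw random values, each run produces different Tensors. When reproducibility matters (for example, when comparing models), you can seed the random number generator with `torch.manual_seed`. Later in this lab, Lightning's `seed_everything` plays a similar role; it also seeds Python's and numpy's generators. Here is a minimal sketch, using an arbitrary seed value:

```python
import torch

torch.manual_seed(0)          # arbitrary seed chosen for illustration
a = torch.rand(3, 2)

torch.manual_seed(0)          # re-seeding restarts the same random sequence
b = torch.rand(3, 2)

print(torch.equal(a, b))      # True: identical draws after identical seeds
print(((a >= 0) & (a < 1)).all().item())  # True: rand samples lie in [0, 1)
```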

## Tensor indexing, slicing, and values

Tensors are indexed in the same manner as numpy arrays. Portions of a Tensor can be selected with the same slicing syntax used for numpy arrays.

In [None]:
# draw from a normal distribution with mean=0 and std=1
t1 = torch.randn(3,2)
print(t1)

# print the 0,0 element
print(t1[0,0])

# print the first row
print(t1[0])

# print the first column
print(t1[:,0])

# print the last two columns of the last two rows
print(t1[-2:,-2:])

Sometimes it is useful to work with the values held in a Tensor as core Python types (i.e., float, int, bool) rather than as a Tensor object. The Tensor `item` method returns a single value. It requires that the Tensor (or slice of the Tensor) has exactly one element. Otherwise, use the Tensor `tolist` method to obtain the values from a slice of a Tensor or from the entire Tensor.

In [None]:
t1 = torch.rand(2,2)

# get the value of item at 0,0 
print(t1[0,0].item())

# get the value of item at 0,1  
print(t1[0,1].item())

# get the values of the first row
print(t1[0].tolist())

# get all of the values
print(t1.tolist())
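
To see the one-element requirement of `item` in action, the sketch below calls it on a two-element slice and catches the resulting error. By contrast, `tolist` works for slices of any size.

```python
import torch

t = torch.rand(2, 2)

try:
    t[0].item()               # t[0] has two elements, so item() fails
    raised = False
except (RuntimeError, ValueError) as e:
    raised = True
    print('Error:', e)

row = t[0].tolist()           # tolist() handles any number of elements
print(type(row), len(row))    # <class 'list'> 2
print(type(t[0, 0].item()))   # <class 'float'>
```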

## Tensor operations

### Element-wise operations
Just as with numpy arrays, the standard Python arithmetic operators are overridden for PyTorch Tensors to perform element-wise operations. For example, consider Tensors $t_1$ and $t_2$, each with $n$ elements. Applying the binary operator `+` to them produces a single Tensor whose $k^{th}$ element is $t_1^{(k)} + t_2^{(k)}$.

In [None]:
t1 = torch.ones(2,2)
t2 = torch.Tensor([[1,2], [10,20]])
print(t1)
print(t2)

# elementwise addition
t3 = t1 + t2
print('t1 + t2 = \n', t3)

# elementwise subtraction
t3 = t1 - t2
print('t1 - t2 = \n', t3)

# elementwise multiplication
t3 = t1 * t2
print('t1 * t2 = \n', t3)

# elementwise division
t3 = t1 / t2
print('t1 / t2 = \n', t3)
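
The Tensors above have the same shape. As in numpy, PyTorch element-wise operators also *broadcast* compatible shapes. A Python scalar, or a Tensor whose trailing dimensions match, is implicitly expanded to the larger shape. Incompatible shapes raise an error. Here is a brief sketch with arbitrary values:

```python
import torch

t1 = torch.tensor([[1.0, 2.0], [10.0, 20.0]])
row = torch.tensor([100.0, 200.0])

# the 1-D row (shape (2,)) is broadcast across both rows of t1 (shape (2, 2))
print(t1 + row)   # values [[101., 202.], [110., 220.]]

# a Python scalar is broadcast to every element
print(t1 * 2)     # values [[2., 4.], [20., 40.]]

# shapes (2, 2) and (3,) are not broadcast-compatible
try:
    t1 + torch.ones(3)
    raised = False
except RuntimeError as e:
    raised = True
    print('RuntimeError:', e)
```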

### Matrix multiplication & dot products
Matrix multiplication and vector dot products are integral to the computations NNets perform during both model training and inference. PyTorch provides computationally efficient methods for both operations, and these methods are supported on both CPU and GPU.

The `dot` function can be called directly from a Tensor object or from the `torch` module. It expects two 1-D Tensors with the same number of elements. The dot product is commutative, so the order of the arguments does not matter.

In [None]:
t1 = torch.randn(10)
t2 = torch.randn(10)

# dot product
dp = torch.dot(t1, t2)
print('dot product = ', dp.item())
print(dp.shape)

# or .dot product
dp = t1.dot(t2)
print('dot product = ', dp.item())
print(dp.shape)

# or change the order of the arguments
dp = t2.dot(t1)
print('dot product = ', dp.item())
print(dp.shape)
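
The dot product of two 1-D Tensors is the sum of their element-wise products, and the sketch below checks this numerically. It also shows that `torch.dot` accepts only 1-D Tensors: passing 2-D Tensors raises an error.

```python
import torch

t1 = torch.randn(10)
t2 = torch.randn(10)

# the dot product equals the sum of the element-wise products
dp = torch.dot(t1, t2)
manual = (t1 * t2).sum()
print(torch.allclose(dp, manual))   # True

# torch.dot requires 1-D inputs
try:
    torch.dot(torch.randn(2, 3), torch.randn(2, 3))
    raised = False
except RuntimeError as e:
    raised = True
    print('RuntimeError:', e)
```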

There are a few equivalent ways to multiply matrices with PyTorch Tensors. These include the `mm` and `matmul` methods, accessible from either a Tensor object or the `torch` module. The overridden `@` operator also yields the product of two matrices. For 2-D Tensors these are equivalent. However, `matmul` and `@` also support higher-dimensional (batched) inputs, whereas `mm` accepts only 2-D Tensors. Be careful: matrix multiplication is not commutative, so changing the order of the arguments can produce a different result or an error. To multiply two matrices, the first matrix must have the same number of _columns_ as the second matrix has _rows_.

In [None]:
t1 = torch.randn(2,3)
t2 = torch.randn(3,4)
print(t1)
print(t2)

# matrix multiplication - all of the following are equivalent
t3 = torch.mm(t1,t2)
t3 = t1.mm(t2)
t3 = torch.matmul(t1,t2)
t3 = t1 @ t2
print(t3)

# the following will raise an error
try:
    t3 = t2 @ t1
except Exception as e:
    print(e)
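
Even when both orders are valid (square matrices), they generally give different results, as the small hand-picked example below shows. The second part illustrates the batched behavior of `matmul` compared with `mm`. The shapes are arbitrary.

```python
import torch

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[0.0, 1.0], [1.0, 0.0]])

# even for square matrices, the order matters
AB = A @ B   # values [[2., 1.], [4., 3.]] (columns of A swapped)
BA = B @ A   # values [[3., 4.], [1., 2.]] (rows of A swapped)
print(AB)
print(BA)
print(torch.equal(AB, BA))   # False

# matmul (and @) handle a leading batch dimension; mm accepts only 2-D Tensors
batch = torch.randn(5, 2, 3)
W = torch.randn(3, 4)
print(torch.matmul(batch, W).shape)   # torch.Size([5, 2, 4])
try:
    torch.mm(batch, W)
    mm_raised = False
except RuntimeError as e:
    mm_raised = True
    print('RuntimeError:', e)
```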

### Combining Tensors
Tensors can be combined using the `cat` (concatenate), `vstack` (vertical stack), and `hstack` (horizontal stack) methods. The `cat` method takes a `dim` argument, which indicates the Tensor dimension along which the arguments are combined. For Tensors with two or more dimensions, `vstack` is equivalent to calling `cat` with `dim=0`, and `hstack` is equivalent to calling `cat` with `dim=1`. For 1-D Tensors, the behavior differs: `vstack` first turns each input into a row of a 2-D Tensor, while `hstack` concatenates along `dim=0`.

In [None]:
t1 = torch.rand(10)
t2 = torch.rand(2)
print(t1.shape)
print(t2.shape)

# concatenate one dimensional tensors
ts = torch.cat((t1,t2))
print(ts)

In [None]:
t1 = torch.rand(2,3)
t2 = torch.rand(4,3)
print(t1)
print(t2)

# concatenate with vertical stacking
# these are equivalent
ts = torch.cat((t1,t2), dim=0)
ts = torch.vstack((t1,t2))
print(ts)

In [None]:
t1 = torch.rand(2,3)
t2 = torch.rand(2,2)
print(t1)
print(t2)

# concatenate with horizontal stacking
# these are equivalent
ts = torch.cat((t1,t2), dim=1)
ts = torch.hstack((t1,t2))
print(ts)
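
The 1-D special cases are easy to overlook, so here is a short check using two arbitrary 1-D Tensors:

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

print(torch.cat((a, b)).shape)      # torch.Size([6])
print(torch.hstack((a, b)).shape)   # torch.Size([6]): same as cat for 1-D inputs
print(torch.vstack((a, b)).shape)   # torch.Size([2, 3]): each 1-D input becomes a row
```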

### Know your dtypes
The Tensor `dtype` attribute indicates the type of the values held in the Tensor. Many Tensor operations with arity greater than one (i.e., functions that operate on two or more inputs), such as matrix multiplication, require that the Tensor arguments have the same `dtype`. Element-wise arithmetic operators are an exception: they apply PyTorch's type promotion rules, so adding an integer Tensor to a float Tensor yields a float Tensor. Some functions also support only certain dtypes. For example, random sampling functions such as `rand_like` require a floating point dtype. You can cast the values of a Tensor to a different type with the `torch.Tensor.to` method (see [here](https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch.Tensor.to)) and the appropriate `dtype` argument (see the complete list [here](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)). Alternatively, the Tensor class has methods, such as `float`, that directly cast Tensors to specific types.

In [None]:
t1 = torch.tensor([[1,2,3], [4,5,6]])
print(t1.dtype)

try:
    tr = torch.rand_like(t1) # this won't work: rand_like keeps t1's integer dtype, but random sampling needs a floating point dtype
except Exception as e:
    print('Exception:')
    print(e)

# cast t1 to float
tr = torch.rand_like(t1.float())
print(t1.float() @ tr.T)

# equivalent to
tr = torch.rand_like(t1.to(dtype=torch.float32))
print(t1.float() @ tr.T)
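
Not every mixed-dtype operation fails. The short sketch below contrasts two cases. Element-wise addition applies PyTorch's type promotion. Matrix multiplication, as in the cell above, expects matching dtypes. The values are arbitrary.

```python
import torch

ti = torch.tensor([[1, 2], [3, 4]])            # int64
tf = torch.tensor([[0.5, 0.5], [0.5, 0.5]])    # float32

# element-wise arithmetic promotes int64 + float32 to float32
s = ti + tf
print(s.dtype)   # torch.float32

# matrix multiplication does not promote; the dtypes must match
try:
    ti @ tf
    raised = False
except RuntimeError as e:
    raised = True
    print('RuntimeError:', e)

# casting first makes the product valid
print((ti.float() @ tf).dtype)   # torch.float32
```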

### Know your devices
As noted above, PyTorch makes it easy to move Tensors between CPU and GPU memory. It is still important to be able to identify where a given Tensor is stored, because operations on Tensors that reside on different devices raise an error. Of course, not all computing environments support GPU usage, so you must first check whether a PyTorch-supported device is available. The code cell below shows how to check for an NVIDIA GPU with CUDA support. If you are using a Mac, you can instead check for MPS support (see [here](https://pytorch.org/docs/main/notes/mps.html)). PyTorch also supports AMD GPUs through ROCm (see [here](https://pytorch.org/blog/pytorch-for-amd-rocm-platform-now-available-as-python-package/)). For more information on PyTorch device options, see [here](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device).

To specifically check whether a CUDA GPU is available, we can use the `torch.cuda.is_available` method. Once we have confirmed that GPU (or other device) support is available, we can check the location of a given Tensor `t` with the `t.device` attribute. We can move the Tensor to the CUDA GPU using the `t.to('cuda')` method.

In [None]:
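t1 = torch.randn(3,4)

# which device is this on?
print('t.device = ', t1.device)

# check if cuda is available
print('torch.cuda.is_available() = ', torch.cuda.is_available())

# the GPU-specific steps only run when a CUDA device is present
if torch.cuda.is_available():
    # move to gpu
    t1 = t1.to('cuda')
    print('t.device = ', t1.device)

    # t2 is created on the CPU
    t2 = torch.randn(4,2)

    # this raises an error because t1 and t2 are on different devices
    try:
        print(t1 @ t2)
    except Exception as e:
        print('Exception:')
        print(e)

    # move to gpu
    t2 = t2.to('cuda')
    print('t.device = ', t2.device)
    print(t1 @ t2)

# or use the device agnostic way: move t1 to wherever t2 lives
device = 'cuda' if torch.cuda.is_available() else 'cpu'
t1 = torch.randn(3,4)
t2 = torch.randn(4,2).to(device)
t1 = t1.to(t2.device)
print('t.device = ', t1.device)
print(t1 @ t2)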
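
The device checks above can be bundled into a small helper. It prefers CUDA, then Apple's MPS backend, and otherwise falls back to the CPU. The function name `pick_device` is hypothetical and not part of PyTorch. The sketch also shows that you can create Tensors directly on a device with the `device` argument instead of moving them afterwards.

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: return the first available accelerator, else the CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # guard in case this PyTorch build has no MPS backend module
    mps = getattr(torch.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

device = pick_device()
print('device = ', device)

# create Tensors directly on the chosen device, so both operands match
t1 = torch.randn(3, 4, device=device)
t2 = torch.randn(4, 2, device=device)
print((t1 @ t2).device)
```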

# Part 2 Data Handling with PyTorch Lightning 

DataSets, DataLoaders, and DataModules oh my

When developing NNet models, especially deep learning models, we usually work with something other than tabular data, such as images or text files. In many cases, there will be thousands to millions of such files, and it is typically not feasible to load all of the samples into computer memory at the same time. Additionally, as we will discuss in lecture, we don't typically update NNet models using all of the data at once (as is done in standard _Batch Gradient Descent_). Instead, the training data is divided into equal-sized subsets (the last one may be smaller), called __mini-batches__. For each _mini-batch_, the model generates predictions, and these are used to compute the loss function value. The loss function derivatives (computed via backpropagation) are then used to update the model parameters. This process is known as _mini-batch gradient descent_. One pass through all of the _mini-batches_ is called an __epoch__. Typically, NNets are trained for multiple epochs, meaning the entire training dataset is used multiple times to adjust the model's learnable parameters.
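
The relationship between dataset size, mini-batch size, and number of parameter updates is simple arithmetic. With $N$ training samples and mini-batch size $B$, one epoch contains $\lceil N/B \rceil$ mini-batches (the last one may be smaller), and each mini-batch yields one update. The sketch below checks this with a `DataLoader` (introduced below) over random data. The sizes are arbitrary, and the loop only counts the updates rather than performing them.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

n_samples, batch_size, n_epochs = 1000, 32, 3   # arbitrary sizes for illustration
ds = TensorDataset(torch.randn(n_samples, 3), torch.randn(n_samples, 1))
dl = DataLoader(ds, batch_size=batch_size, shuffle=True)

batches_per_epoch = math.ceil(n_samples / batch_size)
print(len(dl), batches_per_epoch)   # 32 32

n_updates = 0
last_batch_size = None
for epoch in range(n_epochs):
    for x, t in dl:
        n_updates += 1              # mini-batch gradient descent: one update per mini-batch
        last_batch_size = len(x)
print(n_updates)                    # 96 updates over 3 epochs
print(last_batch_size)              # 8 = 1000 - 31 * 32
```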

To enable this functionality efficiently, we will use three data handling classes. The first two, `Dataset` and `DataLoader`, are part of the PyTorch library. The third, `LightningDataModule`, is from the PyTorch Lightning library.
- DataSets are responsible for holding information about the dataset and methods to manipulate individual samples.
- DataLoaders use DataSets to load data for use in model training, evaluation, and inference. It is typical to have training, validation, and test DataLoader instances for a given dataset.
- DataModules encapsulate multiple DataLoaders and other methods for working with a dataset during model training and evaluation, including preparing batches of data samples.

We introduce examples of these below, and we will see more in future labs.

## DataSets

The `torch.utils.data.Dataset` class is the _base_ class that can be extended to create custom DataSet classes. Classes that extend `Dataset` must implement the following methods:
1.  `__len__` returns the number of items available in the DataSet (e.g., the number of images)
2.  `__getitem__` returns the item associated with the input argument (e.g., an index)

For more general information on DataSet see [here](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). For DataSet extensions specifically for image data, see [here](https://pytorch.org/vision/main/datasets.html#base-classes-for-custom-datasets) and specifically [ImageFolder](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder). 

In [None]:
class MyDataSet(torch.utils.data.Dataset):
    def __init__(self, x, t):
        self.x = x
        self.t = t

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # use modulo to make the dataset circular: an index past the end wraps around
        # (a DataLoader's default sampler only requests indices in range(len(self)))
        idx = idx % len(self.x)
        return self.x[idx], self.t[idx]
    
x = np.random.rand(100,3)
t = np.random.rand(100,1)
ds = MyDataSet(x,t)

print('len(ds) = ', len(ds))
predictors, targets = ds[0]
print('predictors = ', predictors)
print('targets = ', targets)

# what happens if the idx is larger than the dataset?
# this works because we used modulo in the __getitem__ method
predictors, targets = ds[x.shape[0]+10]
print('wraps around to ds[10]: ', np.array_equal(predictors, ds[10][0]))
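
One practical wrinkle: `np.random.rand` produces `float64` arrays, so `MyDataSet` returns `float64` values. PyTorch layers such as `nn.Linear`, however, use `float32` parameters by default (see *Know your dtypes* above). A variant that converts the arrays to `float32` Tensors once, up front, avoids a dtype mismatch later. The class name `Float32DataSet` is hypothetical; this is a sketch, not part of the lab's code.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset

class Float32DataSet(Dataset):
    """Hypothetical DataSet that stores predictors and targets as float32 Tensors."""
    def __init__(self, x: np.ndarray, t: np.ndarray):
        # convert once so every __getitem__ call returns float32 Tensors
        self.x = torch.as_tensor(x, dtype=torch.float32)
        self.t = torch.as_tensor(t, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.t[idx]

ds32 = Float32DataSet(np.random.rand(100, 3), np.random.rand(100, 1))
predictors, targets = ds32[0]
print(predictors.dtype, targets.dtype)   # torch.float32 torch.float32

# a float32 sample can be passed straight to a default nn.Linear layer
layer = nn.Linear(3, 1)
print(layer(predictors).shape)           # torch.Size([1])
```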


## DataLoader
Typically, we can use the PyTorch `DataLoader` class without modification to load a DataSet. This is going to be useful when we train NNets using mini-batch gradient descent. The `DataLoader` is constructed with a DataSet input and a `batch_size` (which is the size of the mini-batch). The `DataLoader` can be treated as an iterator where each yielded value is a mini-batch of the data.

In [None]:
# create numpy array of 0 and 1 values
ds = MyDataSet(np.random.rand(100,3), np.random.randint(0,2,(100,1)))
dl = DataLoader(ds, batch_size=3, shuffle=True)

# print the first k batches
k = 3
for idx, (x,t) in enumerate(dl):
    if idx == k:
        break
    print('batch ', idx)
    print('predictors = ', x)
    print('target = ', t)
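
When the dataset size is not a multiple of `batch_size`, the final mini-batch is smaller. With the 100 samples and `batch_size=3` used above, there are 34 batches, and the last one holds a single sample. The `DataLoader` argument `drop_last=True` discards that incomplete batch. The self-contained check below uses `TensorDataset` in place of `MyDataSet`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.rand(100, 3), torch.randint(0, 2, (100, 1)))

dl = DataLoader(ds, batch_size=3, shuffle=True)
sizes = [len(x) for x, t in dl]
print(len(dl), sizes[-1])            # 34 1

dl_drop = DataLoader(ds, batch_size=3, shuffle=True, drop_last=True)
print(len(dl_drop))                  # 33: the incomplete batch is dropped
```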

## DataModules
Typically, we will need to split our data into training, validation, and test sets, and we will need a `DataLoader` for each of these. Additionally, there is often some setup work needed to preprocess data samples before the data can be used for model training or inference. It is convenient to have a single object that handles all of this functionality. This is where the PyTorch Lightning `LightningDataModule` class comes in (see [here](https://lightning.ai/docs/pytorch/stable/data/datamodule.html)). It holds the necessary DataLoader objects and prepares data as needed within a custom `setup` method. If needed, it also prepares mini-batch samples before sending them to the model (not shown here; we'll see this when working with text data). When a DataModule is used with a Lightning `Trainer`, the `setup` method is called with a `stage` argument (e.g., `'fit'` or `'test'`).

In [None]:
class MyDataModule(L.LightningDataModule):
    def __init__(self, x, t, batch_size, val_split=0.2, test_split=0.1):
        super().__init__()
        self.x = x
        self.t = t
        self.batch_size = batch_size
        self.val_split = val_split
        self.test_split = test_split
        self.setup()

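    def setup(self, stage=None):
        # create train, validation, and test datasets
        # (the Lightning Trainer passes a stage such as 'fit' or 'test'; it is unused here)

        # split the data, in order, into training, validation, and test sets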
        n = len(self.x)
        n_val = int(n * self.val_split)
        n_test = int(n * self.test_split)
        n_train = n - n_val - n_test
        
        self.x_train = self.x[:n_train]
        self.t_train = self.t[:n_train]
        self.x_val = self.x[n_train:n_train+n_val]
        self.t_val = self.t[n_train:n_train+n_val]
        self.x_test = self.x[n_train+n_val:]
        self.t_test = self.t[n_train+n_val:]

    def train_dataloader(self):
        ds = MyDataSet(self.x_train, self.t_train)
        return DataLoader(ds, batch_size=self.batch_size, shuffle=True)
    
    def val_dataloader(self):
        ds = MyDataSet(self.x_val, self.t_val)
        return DataLoader(ds, batch_size=self.batch_size, shuffle=False)
    
    def test_dataloader(self):
        ds = MyDataSet(self.x_test, self.t_test)
        return DataLoader(ds, batch_size=self.batch_size, shuffle=False)
    
x = np.random.rand(1000,3)
t = np.random.rand(1000,1)
dm = MyDataModule(x, t, batch_size=32)
# get the train dataloader
dl_train = dm.train_dataloader()

# print the first batch from the train dataloader
x,t = next(iter(dl_train))
print('predictors = ', x)
print('target = ', t)
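
`MyDataModule` splits the rows in order, which is fine here because the data are random. For real data that may be ordered (e.g., by class), a shuffled split is safer. The sketch below reproduces the 700/200/100 split sizes implied by `val_split=0.2` and `test_split=0.1` for 1000 samples. It uses `torch.utils.data.random_split` with a seeded generator; the seed value is arbitrary.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

n = 1000
ds = TensorDataset(torch.rand(n, 3), torch.rand(n, 1))

n_val = int(n * 0.2)          # 200
n_test = int(n * 0.1)         # 100
n_train = n - n_val - n_test  # 700

# a seeded generator makes the shuffled split reproducible
gen = torch.Generator().manual_seed(42)
ds_train, ds_val, ds_test = random_split(ds, [n_train, n_val, n_test], generator=gen)
print(len(ds_train), len(ds_val), len(ds_test))   # 700 200 100

dl_train = DataLoader(ds_train, batch_size=32, shuffle=True)
print(len(dl_train), math.ceil(n_train / 32))     # 22 22
```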



# Part 3 Building and Training a Model on MNIST with PyTorch

## MNIST data
We are going to build a fully connected feedforward NNet to classify handwritten digits into one of ten classes. That is, the model will indicate which number in $\left[0,9\right]$ a handwritten digit image represents. We will use the classic MNIST dataset. Before building the model, let's take a look at the MNIST data, which is provided as a custom DataSet in the PyTorch torchvision library.

In [None]:
# load the mnist dataset
# on the first run this will download the dataset
# we need a location to store the dataset
# setting train=True selects the training set
mnist_train_all = MNIST(dir_dataroot, train=True, download=True, transform=ToTensor())

# to get the test set we can set train=False
mnist_test = MNIST(dir_dataroot, train=False, download=True, transform=ToTensor())

# how many samples are in the training set
print('len(mnist_train) = ', len(mnist_train_all))

# how many samples are in the test set
print('len(mnist_test) = ', len(mnist_test))

Let's look at some of the images

In [None]:
def show(imgs):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = img.detach()
        img = F.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

batch_size = 16
# we can create a DataLoader from the MNIST dataset and take the images from its first batch
imgs = next(iter(DataLoader(mnist_train_all, batch_size=batch_size, shuffle=True)))[0]
print('imgs.shape = ', imgs.shape)
grid = make_grid(imgs, nrow=4)
show(grid)


## DataLoaders
We won't need a custom DataSet or DataModules for the MNIST data as the required functionality is provided in the MNIST class. We will need to split the training data into training and validation sets and create DataLoaders for the training, validation, and test sets.

In [None]:
# let's seed everything for reproducibility
seed_everything(SEED)
batch_size = 16

# split the training set into training and validation
val_split = 0.2
val_size = int(len(mnist_train_all) * val_split)
train_size = len(mnist_train_all) - val_size
mnist_train, mnist_val = torch.utils.data.random_split(mnist_train_all, [train_size, val_size])

In [None]:
print("Training samples:",mnist_train_all.data[mnist_train.indices].shape)
print("Unique classes in training:",mnist_train_all.targets[mnist_train.indices].unique())

print("\nValidation samples:", mnist_train_all.data[mnist_val.indices].shape)
print("Unique classes in validation:", mnist_train_all.targets[mnist_val.indices].unique())

print("\nTest samples:", mnist_test.data.shape)
print("Unique classes in test:", mnist_test.targets.unique())


Let's create the DataLoaders

In [None]:

# create the dataloaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

## Building the MNIST NNet Classification Model
We are now ready to build an NNet model using PyTorch. As we consider more complex NNet architectures, it will be useful to split the models into smaller components. This will let us reuse the components in other models, and it will make the functionality of individual pieces of the model easier to understand.

### Modules
We will use the `nn.Module` class to build NNet model components. Modules encapsulate segments of an overall NNet model. We will see later that our NNets may be composed of multiple modules. For now, we will consider a single module representing a fully connected feedforward hidden layer. Note that this does NOT include the output layer; that will be an additional module.

The `MNISTEncoder` module extends the `nn.Module` class. It includes custom implementations of the `__init__` method and the `forward` method.

The `__init__` method creates an `nn.Linear` object, which is the PyTorch implementation of a fully connected feedforward layer. It expects an input Tensor whose last dimension has size `input_size` (e.g., a mini-batch of shape `(batch_size, input_size)`). It contains `hidden_size` computation nodes. Thus, this layer has `input_size` $\times$ `hidden_size` weights (aka edges aka learnable parameters), plus `hidden_size` bias terms (included by default).

The `forward` method takes an input Tensor, `x`, and __flattens__ it. That is, it transforms each sample from a higher dimensional Tensor into a 1-D Tensor; by default, `nn.Flatten` keeps the first (mini-batch) dimension. Suppose each sample in `x` is $D$ dimensional with dimension sizes $\{d_1, \ldots, d_D\}$. Then the product of the dimension sizes must equal `input_size`. For example, if `input_size` is 100 and each sample is 2-D with dimension sizes $d_1$ and $d_2$, then $d_1 d_2 = 100$ is required. The flattened representation of `x` is then passed through the `linear` layer, which outputs the affine transformation (weighted sum plus bias) computed by each node. Finally, those values are passed through the `ReLU` activation function.

We can think of this module as __encoding__ (or __embedding__) the input image `x` into a latent space.

In [None]:
class MNISTEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear(x)
        x = self.relu(x)
        return x

num_hidden = 64
n_classes = 10
input_dim = 28 
encoder = MNISTEncoder(input_dim*input_dim, num_hidden, n_classes)
input_image = torch.randn(1, 1, input_dim, input_dim)
embedding = encoder(input_image)
print('embedding.shape = ', embedding.shape)
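
The shape bookkeeping and parameter count described above can be checked directly. For a mini-batch of 16 MNIST-sized images, `nn.Flatten` keeps the batch dimension and flattens each $1 \times 28 \times 28$ image into $784$ values. An `nn.Linear(784, 64)` layer then has $784 \times 64 = 50176$ weights plus $64$ biases, i.e., $50240$ learnable parameters. The batch size of 16 is arbitrary.

```python
import torch
from torch import nn

batch = torch.randn(16, 1, 28, 28)       # (batch, channels, height, width)
flat = nn.Flatten()(batch)               # flattens all dimensions after the first
print(flat.shape)                        # torch.Size([16, 784])

linear = nn.Linear(28 * 28, 64)
n_params = sum(p.numel() for p in linear.parameters())
print(linear.weight.shape, linear.bias.shape)   # torch.Size([64, 784]) torch.Size([64])
print(n_params)                          # 50240
```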

Next we want to use the _encoded_ representation of the input image, `x`, from the `MNISTEncoder` module to classify the image. To do this, we will construct another module representing the output layer, which we will call `MNISTOutput`. This module will require only a `nn.Linear` module as it will process the Tensor that is output from the `MNISTEncoder.forward` method which is already 1-D. We do not use an activation function here, because in the output layer we want only the _logits_. Note that the `linear` layer accepts a Tensor of size `hidden_size` which corresponds to the number of computation units in the `linear` layer in the `MNISTEncoder`. It outputs a Tensor of size `num_classes` which for MNIST is 10. 

You may be wondering why we haven't included the softmax conversion of the logits in the `forward` method. It turns out that, when training the model, it is more computationally efficient (and numerically stable) to perform this operation as part of the loss function. In the example below, we use the `nn.functional.softmax` function to compute probability estimates from the logits. Be careful, though: we haven't trained the model yet, so these are essentially random values.

In [None]:
class MNISTOutput(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.linear = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.linear(x)
    
output = MNISTOutput(num_hidden, n_classes)
logits = output(embedding)
print('logits.shape = ', logits.shape)
# These are the logits for each class
print('logits = ', logits)
# To get the probabilities we can use the softmax function, but remember, we haven't trained the model yet
probs = nn.functional.softmax(logits, dim=1)
print('probs = ', probs)
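
To make the softmax-in-the-loss point concrete: `nn.functional.cross_entropy` applied to raw logits gives the same value as two other computations. One is taking `log_softmax` and then the negative log-likelihood. The other is averaging $-\log p_{y}$ over the samples, where $p_y$ is the softmax probability of the true class. The logits and targets below are random placeholders.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 10)               # 4 samples, 10 classes (random placeholders)
targets = torch.tensor([3, 0, 9, 1])      # true class index per sample

ce = nn.functional.cross_entropy(logits, targets)

# the same loss, computed in two explicit steps
log_probs = nn.functional.log_softmax(logits, dim=1)
nll = nn.functional.nll_loss(log_probs, targets)

# and once more from the softmax probabilities directly
probs = nn.functional.softmax(logits, dim=1)
manual = -torch.log(probs[torch.arange(4), targets]).mean()

print(ce.item(), nll.item(), manual.item())     # three (nearly) identical values
print(probs.sum(dim=1))                          # each row of probabilities sums to 1
```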

## Create the MNIST Classifier
Now that we have defined the individual components of our model, we need to assemble them. Just as with our data handlers, it will prove convenient to hold our model components in a single object that also provides methods that support model training and inference. It turns out that for deep learning models, there are repeatable coding patterns that have emerged. These are abstracted in the PyTorch Lightning `LightningModule` class, see [here](https://lightning.ai/docs/pytorch/LTS/common/lightning_module.html). 

To build our custom `LightningModule` class, we will need to define the following methods:

1. `__init__` : define any attributes or setup procedures. This will include our model component modules.
2. `forward` : passes an input sample (or mini-batch) through our component modules
3. `training_step` : defines what steps to perform for each training step (i.e. mini-batch). Typically, passing a mini-batch through the network and computing the loss. 
4. `validation_step` : defines what steps to perform for each validation step. Typically, passing a mini-batch through the network and computing the loss
5. `test_step` : defines what steps to perform for each test step. Typically, passing a mini-batch through the network and computing the loss
6. `on_validation_epoch_end` : defines what steps to perform at the end of a validation epoch (i.e., after the model has evaluated all validation mini-batches). This often includes a test for halting model training.
7. `on_test_epoch_end` : defines what steps to perform at the end of a test epoch (i.e., after the model has evaluated all test mini-batches). This often includes storing metric values.
8. `configure_optimizers` : define the optimization class for use in gradient descent. There are many optimizers available (we'll discuss this more in lecture). Here we use the popular Adam optimizer.

PyTorch Lightning will use the `training_step` method and the optimizer returned by the `configure_optimizers` method to handle updating the model parameters. We do NOT need to implement any code for that part.
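
For intuition about what PyTorch Lightning automates, here is a rough sketch of the loop it runs for us. For each mini-batch, it calls the training step to get the loss, clears old gradients, backpropagates, and lets the optimizer update the parameters. This is a simplified illustration on hypothetical toy regression data ($t = 2x + 1$) with plain SGD. It is not Lightning's actual implementation, which also handles validation, logging, devices, and callbacks.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# hypothetical toy data: t = 2x + 1 (no noise), not MNIST
x = torch.randn(256, 1)
t = 2 * x + 1
loader = DataLoader(TensorDataset(x, t), batch_size=16, shuffle=True)

model = nn.Linear(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)   # plays the role of configure_optimizers

def training_step(batch):
    # mirrors a LightningModule training_step: forward pass and loss
    xb, tb = batch
    return nn.functional.mse_loss(model(xb), tb)

for epoch in range(20):
    for batch in loader:
        loss = training_step(batch)
        optimizer.zero_grad()   # clear gradients from the previous step
        loss.backward()         # backpropagation computes the gradients
        optimizer.step()        # gradient descent update of the parameters

print(model.weight.item(), model.bias.item())   # close to 2 and 1
```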

Note that our implementation below uses the [TorchMetrics](https://lightning.ai/docs/torchmetrics/stable/) library. The `MetricTracker` and associated metric classes will do the work of calculating and storing model performance metrics on the validation and test mini-batches. We will use these to assess model performance.

In [None]:
class MNISTClassifier(L.LightningModule):
    def __init__(self, encoder, output, num_classes):
        super().__init__()
        self.encoder = encoder
        self.output = output
        
        # validation metrics - we will use these to compute the metrics at the end of the validation epoch
        self.val_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes)]), maximize=True)
        self.validation_step_outputs = []
        self.validation_step_targets = []

        # test metrics - we will use these to compute the metrics at the end of the test epoch
        self.test_roc = TM.ROC(task="multiclass", num_classes=num_classes, thresholds=list(np.linspace(0.0, 1.0, 20))) # roc and cm have methods we want to call so store them in a variable
        self.test_cm = TM.ConfusionMatrix(task='multiclass', num_classes=num_classes)
        self.test_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes), 
                                                            self.test_roc, self.test_cm]), maximize=True) 
        
        # test outputs and targets - we will store the outputs and targets for the test step
        self.test_step_outputs = []
        self.test_step_targets = []

    # the forward method applies the encoder and output to the input
    def forward(self, x):
        x = self.encoder(x)
        x = self.output(x)
        return x

    # the training step. pass the batch through the model and compute the loss
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        # where is our softmax? We don't need it here because we are using cross_entropy which includes the softmax for efficiency
        loss = nn.functional.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss
    
    # the validation step. pass the batch through the model and compute the loss. Store the outputs and targets for the epoch end step and log the loss
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('val_loss', loss, on_step=True, on_epoch=True)
        
        # store the outputs and targets for the epoch end step
        self.validation_step_outputs.append(logits)
        self.validation_step_targets.append(y)
        return loss
    
    # the test step. pass the batch through the model and compute the loss. Store the outputs and targets for the epoch end step and log the loss
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('test_loss', loss, on_step=True, on_epoch=True)
        self.test_step_outputs.append(logits)
        self.test_step_targets.append(y)
        return loss
    
    # at the end of the epoch compute the metrics
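    def on_validation_epoch_end(self):
        # Lightning runs a short validation "sanity check" before training; skip it so it
        # does not add an extra step to the validation metrics tracker
        if self.trainer.sanity_checking:
            self.validation_step_outputs.clear()
            self.validation_step_targets.clear()
            return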
        # stack all the outputs and targets into a single tensor
        all_preds = torch.vstack(self.validation_step_outputs)
        all_targets = torch.hstack(self.validation_step_targets)
        
        # compute the metrics
        loss = nn.functional.cross_entropy(all_preds, all_targets)
        self.val_metrics_tracker.increment()
        self.val_metrics_tracker.update(all_preds, all_targets)
        self.log('val_loss_epoch_end', loss)
        
        # clear the validation step outputs
        self.validation_step_outputs.clear()
        self.validation_step_targets.clear()
    
    def on_test_epoch_end(self):
        all_preds = torch.vstack(self.test_step_outputs)
        all_targets = torch.hstack(self.test_step_targets)
        
        self.test_metrics_tracker.increment()
        self.test_metrics_tracker.update(all_preds, all_targets)
        # clear the test step outputs
        self.test_step_outputs.clear()
        self.test_step_targets.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
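The `training_step` above passes raw logits straight to `nn.functional.cross_entropy`. The short sketch below, using random stand-in logits and labels (not MNIST data), checks that `cross_entropy` gives the same value as `log_softmax` followed by `nll_loss`. It also shows that taking `argmax` of the logits selects the same class as taking `argmax` of the softmax probabilities.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical raw model outputs (logits) for a batch of 4 samples and 10 classes
logits = torch.randn(4, 10)
targets = torch.tensor([3, 0, 9, 5])

# cross_entropy works directly on the raw logits...
ce = nn.functional.cross_entropy(logits, targets)

# ...and matches log_softmax followed by the negative log-likelihood loss
nll = nn.functional.nll_loss(nn.functional.log_softmax(logits, dim=1), targets)

# softmax is only needed if we want to inspect probabilities
probs = torch.softmax(logits, dim=1)

print(torch.allclose(ce, nll))                                   # True
print(torch.equal(logits.argmax(dim=1), probs.argmax(dim=1)))   # True
```

Softmax is monotonic, so the predicted class is unchanged whether we use logits or probabilities.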

## Train the MNIST model
Now that we have defined our model class, we need to create an instance of it and train it. Because our model is a `LightningModule`, we can use the PyTorch Lightning `Trainer` class to run training for us. It iterates over the training batches, computes the loss with our `training_step`, backpropagates, updates the parameters with the optimizer from `configure_optimizers`, and runs the validation loop. All we need to do is create a `Trainer` and call its `fit` method with the model and the training and validation data loaders.
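To see what the `Trainer` saves us from writing, here is a simplified plain-PyTorch training loop. It is a sketch on hypothetical toy data (random features, a small `nn.Sequential` model, `lr=1e-2`), not Lightning's actual implementation. `Trainer.fit` also handles validation, logging, device placement, checkpointing, and callbacks.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 256 samples, 20 features, labels in {0, 1, 2}
X = torch.randn(256, 20)
y = (X[:, 0] > 0).long() + (X[:, 1] > 0).long()
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 3))
optimizer = optim.Adam(model.parameters(), lr=1e-2)

epoch_losses = []
for epoch in range(5):
    model.train()                      # training mode (affects dropout/batch norm)
    total = 0.0
    for xb, yb in loader:
        optimizer.zero_grad()          # clear gradients from the previous step
        loss = nn.functional.cross_entropy(model(xb), yb)
        loss.backward()                # backpropagate
        optimizer.step()               # update the parameters
        total += loss.item() * len(xb)
    epoch_losses.append(total / len(X))
    print(f"epoch {epoch + 1}: mean train loss {epoch_losses[-1]:.4f}")
```

In the `LightningModule`, the body of the inner loop (minus the optimizer calls) corresponds to `training_step`, and the optimizer comes from `configure_optimizers`.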

Here, we also pass a __callback__ to the `Trainer`. Specifically, we add the `EarlyStopping` callback and tell it to monitor the `val_loss_epoch_end` value with `mode="min"`, since a lower loss is better. The callback checks this value once per training epoch and halts training if it has not improved for `patience` consecutive checks (3 by default).
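The patience rule can be sketched in a few lines of plain Python. `early_stop_epoch` is a hypothetical helper for `mode="min"` with a `min_delta` threshold. It is a simplification of the idea, not Lightning's `EarlyStopping` source, and the loss values are made up.

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Hypothetical helper: return the 1-based epoch at which training
    would stop under patience-based early stopping (mode="min"), or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss      # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1        # no improvement this epoch
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses, one per epoch
losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.38, 0.30]
print(early_stop_epoch(losses))              # 6
print(early_stop_epoch(losses, patience=5))  # None
```

With `patience=3`, three epochs without improvement after epoch 3 stop training at epoch 6. With a larger patience, the improvement at epoch 7 resets the counter.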

In [None]:
seed_everything(SEED)
encoder = MNISTEncoder(input_dim*input_dim, num_hidden, n_classes)
output = MNISTOutput(num_hidden, n_classes)
model = MNISTClassifier(encoder, output, num_classes=n_classes)

max_epochs = 1 # we're using a small number here for demonstration. Try a larger number like 10 for better performance.
trainer = L.Trainer(default_root_dir=dir_lightning, 
                    max_epochs=max_epochs,
                    callbacks=[EarlyStopping(monitor="val_loss_epoch_end", mode="min")])
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

## Validation Set Accuracy
Now we can use the metrics trackers from our model to examine performance. First let's look at the validation set accuracy.

In [None]:
mca = model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()
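One detail when reading this plot: `TM.classification.MulticlassAccuracy` defaults to `average="macro"`. That means it computes each class's recall and takes the unweighted mean, which can differ from the plain fraction of correct predictions (`average="micro"`). The sketch below computes both by hand on hypothetical predictions in which every class appears in the targets.

```python
import torch

# Hypothetical predictions and labels for a 3-class problem
preds  = torch.tensor([0, 0, 0, 0, 1, 2, 2, 1])
target = torch.tensor([0, 0, 0, 0, 1, 1, 2, 2])
num_classes = 3

# Micro accuracy: fraction of all samples classified correctly
micro = (preds == target).float().mean()

# Macro accuracy: recall of each class, then an unweighted mean over classes
per_class = torch.stack(
    [(preds[target == c] == c).float().mean() for c in range(num_classes)]
)
macro = per_class.mean()

print(micro.item(), per_class.tolist(), macro.item())
```

Here 6 of 8 predictions are correct (micro = 0.75). Classes 1 and 2 each have only 50% recall, so macro = (1.0 + 0.5 + 0.5) / 3 ≈ 0.667. On roughly balanced data like MNIST, the two values are usually close.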

## Test Set Performance
We can also evaluate the model on the test set. First, we use the trainer to execute the model on the test set.

In [None]:
trainer.test(model=model, dataloaders=test_loader)

Now we can retrieve the performance metrics from the test metrics tracker. `compute()` returns the values for the most recent tracked step, which here is our single test run.

In [None]:
rslt = model.test_metrics_tracker.compute()

Let's look at some of the test set results. First we'll examine the confusion matrix.

In [None]:
cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted Label')
cmp.set_xticklabels(mnist_train_all.classes, rotation=90)
cmp.set_yticklabels(mnist_train_all.classes, rotation=0)
cmp.set_ylabel('Actual Label');
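In this heatmap each row is an actual class and each column is a predicted class, so the diagonal counts correct predictions and off-diagonal cells count confusions. The small sketch below builds a matrix with the same row/column convention by hand. `confusion_matrix` here is a hypothetical helper written with `torch.bincount`, not the torchmetrics implementation.

```python
import torch

def confusion_matrix(preds, target, num_classes):
    """Hypothetical helper: rows = actual class, columns = predicted class."""
    # Encode each (actual, predicted) pair as one index: actual * num_classes + predicted
    idx = target * num_classes + preds
    counts = torch.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


# Hypothetical predictions and labels for a 3-class problem
preds  = torch.tensor([0, 0, 0, 0, 1, 2, 2, 1])
target = torch.tensor([0, 0, 0, 0, 1, 1, 2, 2])
cm = confusion_matrix(preds, target, 3)
print(cm)
```

Row sums give the number of samples in each actual class, and the trace (6 here) is the number of correct predictions.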

Now let's look at the per-class (one-vs-rest) ROC curves.

In [None]:
fpr, tpr, thresholds = rslt['MulticlassROC']
# one ROC curve per class
for i in range(len(fpr)):
    plt.plot(fpr[i], tpr[i], label=mnist_train_all.classes[i])
# diagonal reference line for a random classifier
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()
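Each curve above treats one digit as the "positive" class and all other digits as "negative". The sketch below reproduces that one-vs-rest idea for a single class with scikit-learn's `roc_curve` and `roc_auc_score`. The probabilities are hypothetical, chosen so that class 0 is perfectly separated.

```python
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score

# Hypothetical softmax probabilities for 6 samples and 3 classes
probs = np.array([
    [0.8, 0.1, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.4, 0.3],
    [0.1, 0.2, 0.7],
    [0.4, 0.1, 0.5],
])
labels = np.array([0, 0, 1, 1, 2, 2])

k = 0  # class treated as "positive"; every other class is "negative"
y_binary = (labels == k).astype(int)
scores = probs[:, k]

fpr, tpr, thresholds = roc_curve(y_binary, scores)
auc = roc_auc_score(y_binary, scores)
print(fpr, tpr, auc)
```

Both class-0 samples score higher than every other sample, so the curve reaches a true positive rate of 1 before any false positives and the AUC is 1.0.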

Ultimately, we will want to generate predictions from the model without using the `Trainer` object. Below, we do this for the test data and use the predictions to generate a classification report with scikit-learn.

In [None]:
device = torch.device("cpu")   # or torch.device("cuda:0") to run on a GPU
# move the model to the same device as the data
model.to(device)
# put the model in evaluation mode so layers such as dropout and batch norm use inference behavior
model.eval()
y_true = []
y_pred = []
# use torch.no_grad() to disable gradient tracking, which saves memory and time during inference
with torch.no_grad():
    # iterate over the test loader minibatches
    for test_images, test_labels in test_loader:
        # move the images to the model's device and predict the class with the largest logit
        pred = model(test_images.to(device)).argmax(dim=1)
        # convert results to Python lists on the CPU so they can be used with scikit-learn
        y_true.extend(test_labels.tolist())
        y_pred.extend(pred.cpu().tolist())

print(classification_report(y_true,y_pred,target_names=mnist_train_all.classes,digits=4))
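If you need predictions in several places, the loop above can be wrapped in a function. `predict_labels` is a hypothetical helper, shown here with a tiny stand-in `nn.Linear` model and random data rather than the MNIST classifier. It keeps the model and each batch on the same device and returns plain Python lists ready for scikit-learn.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def predict_labels(model, loader, device):
    """Hypothetical helper: return (true labels, predicted labels) as Python lists."""
    model.to(device)
    model.eval()                      # inference behavior for dropout/batch norm
    y_true, y_pred = [], []
    with torch.no_grad():             # no gradient tracking during inference
        for images, labels in loader:
            logits = model(images.to(device))
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
            y_true.extend(labels.tolist())
    return y_true, y_pred


# Usage with a tiny stand-in model and random data
torch.manual_seed(0)
X = torch.randn(10, 4)
y = torch.randint(0, 3, (10,))
loader = DataLoader(TensorDataset(X, y), batch_size=4)
y_true, y_pred = predict_labels(nn.Linear(4, 3), loader, torch.device("cpu"))
print(len(y_true), len(y_pred))
```

With the lab's objects, the call would be `predict_labels(model, test_loader, device)`, and the two lists can be passed to `classification_report` as before.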