# Practicum 06 - Image Classification with CNNs

In this practicum, we will use the [Alzheimer MRI Preprocessed Dataset](https://www.kaggle.com/datasets/sachinkumar413/alzheimer-mri-dataset) to build a classification model that estimates dementia severity for individuals with Alzheimer's disease. The models will use only the MRI image as input. We will use [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) to build the models.

We will illustrate many (though certainly not all) of the primary concepts of PyTorch (PT) and PyTorch Lightning (PTL) that are needed to build deep learning models, including:
1. PyTorch
    1. Tensors
    2. nn.Modules
2. PyTorch Lightning
    1. Data modules and Data loaders
    2. Lightning modules
    3. Callbacks
    4. Early stopping
    5. Metrics
We will also use the [TorchMetrics](https://lightning.ai/docs/torchmetrics/stable/) library to simplify performance metric calculation.

For more information on PyTorch and PyTorch Lightning, see these tutorials:
1. [PyTorch Tutorials](https://pytorch.org/tutorials/)
2. [PyTorch Lightning Tutorials](https://lightning.ai/docs/pytorch/stable/levels/core_skills.html)

In [None]:
import os
import torch
from torch import optim, nn, utils, Tensor
from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torchvision import transforms, datasets, models
import torchvision.transforms.functional as F
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch import seed_everything
import torchmetrics as TM
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

import sys
sys.path.append('../../src') # alternatively add to path using: pip install -e /path/to/src
from torchvision_utils import show

# PyTorch Lightning Demonstration on FashionMNIST

We will begin by creating a PTL model to classify images from the [FashionMNIST](https://github.com/zalandoresearch/fashion-mnist?tab=readme-ov-file) dataset, one of the many datasets available in the [PyTorch torchvision](https://pytorch.org/vision/stable/datasets.html) library. The FashionMNIST dataset contains 60,000 training images and 10,000 test images of clothing divided into 10 classes. In this example, we will build a deep learning model with 2 convolutional layers followed by a feed-forward layer to predict the class of an input image.

First, we need to create a directory for the FashionMNIST data. We will create the directory `../../data`, which is ignored by git. We also need a directory to save trained models. By default, PTL saves snapshots of the model during training, called __checkpoints__, as discussed below. Rather than saving these in the current directory, we will create a directory `../../lightning`, which is also ignored by git.

In [None]:
def create_data_directory(path: str):
    if not os.path.exists(path):
        os.makedirs(path)

dir_dataroot = os.path.join("..", "..", "data")
create_data_directory(dir_dataroot)

dir_lightning = os.path.join("..", "..", "lightning")
create_data_directory(dir_lightning)

rs = 123456 # random seed for everything

## Datasets, data loaders, data modules

Now we can pull the data using `torchvision.datasets.FashionMNIST`, whose objects are instances of a PT [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset). When we first instantiate the `FashionMNIST` object, it automatically downloads the data to the specified directory. The download includes both the training and test data; however, with `train=True` our object references only the training data. We will split this data into training and validation sets and use those to create PT [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)s. We will then create a second instance of the FashionMNIST dataset with the argument `train=False` to obtain the test data, and create a third `DataLoader` from it. We will use these data loaders to train and test the model.
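To make the `Dataset`/`DataLoader` relationship concrete, here is a minimal sketch of a custom map-style dataset. `SquaresDataset` is a hypothetical toy class, not part of this practicum. Any class that implements `__len__` and `__getitem__` can be wrapped in a `DataLoader`, which collates the individual `(input, target)` pairs into batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the tensor [i] with label i**2."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return self.n

    def __getitem__(self, idx):
        # return one (input, target) pair, just as FashionMNIST returns (image, label)
        x = torch.tensor([float(idx)])
        y = idx ** 2
        return x, y

ds = SquaresDataset(10)
loader = DataLoader(ds, batch_size=4, shuffle=False)
xb, yb = next(iter(loader))   # the default collate function stacks samples into batch tensors
print(len(ds), xb.shape, yb.tolist())  # 10 torch.Size([4, 1]) [0, 1, 4, 9]
```

Real datasets such as FashionMNIST follow the same pattern. `__getitem__` typically loads one image from disk and applies the `transform`, so only the current batch has to be in memory.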

### Why all these data handlers?
Unlike the models we've seen previously, which used tabular data that we loaded entirely into memory, most deep learning models (such as the ones developed here) use unstructured data that is stored in many files and is usually too large to load into memory at once. For the same reason, the models cannot be trained on all of the data in a single iteration. Instead, we select small subsets of the data, called __batches__, and use each one to update the model during learning (so-called __batch training__). Model training is divided into __epochs__, where each epoch is one full pass through the training data. Assuming there are $n$ samples in the training data and the batch size is $k \ll n$, the model will be updated $\lceil n/k \rceil$ times during each epoch. That is exactly $n/k$ times when $k$ divides $n$, and $\lfloor n/k \rfloor + 1$ times otherwise, because the final batch is smaller.
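The number of updates per epoch can be checked directly. The sketch below uses a hypothetical helper, `updates_per_epoch`, and a toy `TensorDataset`. With `drop_last=False` (the `DataLoader` default), the final, smaller batch is kept.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

def updates_per_epoch(n_samples, batch_size, drop_last=False):
    # floor(n/k) full batches, plus one partial batch when k does not divide n (unless it is dropped)
    return n_samples // batch_size if drop_last else math.ceil(n_samples / batch_size)

# FashionMNIST setup below: 80% of 60,000 = 48,000 training samples, batch size 16
print(updates_per_epoch(48_000, 16))  # 3000

# Cross-check against a real DataLoader: 10 samples in batches of 4 -> sizes 4, 4, 2
ds = TensorDataset(torch.arange(10.0))
loader = DataLoader(ds, batch_size=4)
batch_sizes = [len(b[0]) for b in loader]
print(len(loader), batch_sizes)  # 3 [4, 4, 2]
```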

In [None]:
# set the random seed for everything
seed_everything(rs)
batch_size = 16
ds_fashion_train_all = FashionMNIST(dir_dataroot, train=True, download=True, transform=ToTensor())

# split into train and validation
train_size = int(0.8 * len(ds_fashion_train_all))
val_size = len(ds_fashion_train_all) - train_size
ds_fashion_train, ds_fashion_val = utils.data.random_split(ds_fashion_train_all, [train_size, val_size])
train_loader = DataLoader(ds_fashion_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(ds_fashion_val, batch_size=batch_size, shuffle=False)

# get test data
ds_fashion_test = FashionMNIST(dir_dataroot, train=False, download=True, transform=ToTensor())
test_loader = DataLoader(ds_fashion_test, batch_size=batch_size, shuffle=False)

### Torch Tensors
[PyTorch Tensors](https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html) are the primary data structures used in PyTorch models. PT Tensors have many of the same properties as NumPy ndarrays. In fact, a PT Tensor `x` can be converted to a NumPy array with `na = x.numpy()`, and a PT Tensor can be created from a NumPy array `na` with `x = torch.from_numpy(na)`.
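One detail worth knowing: for CPU tensors, `torch.from_numpy` and `.numpy()` share memory with the source instead of copying it, while `torch.tensor(...)` always copies. The short sketch below uses arbitrary toy values to demonstrate this. (Tensors on a GPU must be moved with `.cpu()` first, and tensors that track gradients need `.detach()`.)

```python
import numpy as np
import torch

na = np.arange(6, dtype=np.float32).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
x = torch.from_numpy(na)      # shares memory with na (no copy)
na[0, 0] = 100.0
print(x[0, 0].item())         # 100.0: the change is visible through the tensor

back = x.numpy()              # also shares memory (CPU tensors only)
y = torch.tensor(na)          # torch.tensor always makes a copy
na[0, 1] = -1.0
print(y[0, 1].item())         # 1.0: the copy is unaffected
```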

Below, we examine the shapes of the raw FashionMNIST `data` tensors (before `ToTensor` is applied). The first dimension is the number of samples, and the second and third dimensions are the height and width of each image (28×28 pixels).

In [None]:
print("Training samples:",ds_fashion_train_all.data[ds_fashion_train.indices].shape)
print("Unique classes in training:",ds_fashion_train_all.targets[ds_fashion_train.indices].unique())

print("\nValidation samples:", ds_fashion_train_all.data[ds_fashion_val.indices].shape)
print("Unique classes in validation:", ds_fashion_train_all.targets[ds_fashion_val.indices].unique())

print("\nTest samples:", ds_fashion_test.data.shape)
print("Unique classes in test:", ds_fashion_test.targets.unique())
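The cell above works because `random_split` returns `Subset` objects. Each `Subset` stores a list of positions (`.indices`) into the parent dataset, which is why `ds_fashion_train_all.data[ds_fashion_train.indices]` recovers the training images. The sketch below demonstrates this on a hypothetical 10-sample toy dataset and uses an explicit `generator` so the split is reproducible.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for FashionMNIST: 10 "images" whose value equals their index
parent = TensorDataset(torch.arange(10))
gen = torch.Generator().manual_seed(0)
train, val = random_split(parent, [8, 2], generator=gen)

print(type(train).__name__, len(train), len(val))  # Subset 8 2
# A Subset stores positions into the parent, so train[i] is parent[train.indices[i]]
for i in range(len(train)):
    assert train[i][0] == parent[train.indices[i]][0]
# The two subsets are disjoint and together cover the parent
print(sorted(train.indices + val.indices) == list(range(10)))  # True
```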

We can access individual samples by indexing the dataset, `ds_fashion_train` (each sample is an `(image, label)` pair, where the image is a 1×28×28 tensor), or obtain whole batches by wrapping `train_loader` in `iter`. Here, we display one batch of images from the training data loader.

In [None]:
# imgs = [ds_fashion_train[i][0] for i in range(batch_size)]
imgs = next(iter(train_loader))[0]
print(imgs.shape)
grid = make_grid(imgs)
show(grid)

## Modules & Models
Let's start building our FashionMNIST classification model. Conceptually, we will think of this model as a pipeline with two modules:
1. Encoder - this portion of the model contains the convolutional layers and outputs a _representation_ of the input image in the form of a stack of small feature maps (one per convolutional filter).
2. Classifier - this portion contains a fully connected layer that transforms the encoder representation into class-label `logits`. Note that we do not need to include a softmax function in the model: `nn.functional.cross_entropy` applies it internally (as a log-softmax), and we can apply softmax outside the model whenever we need probabilities.
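The following sketch shows why the model can output raw logits. `cross_entropy` equals `nll_loss` applied to `log_softmax`, and softmax can be applied afterwards when probabilities are needed. The logits and targets are arbitrary toy values.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])

# cross_entropy applies log_softmax internally, so the model can output raw logits
ce = F.cross_entropy(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))  # True

# When probabilities are needed (e.g. for reporting), apply softmax outside the model
probs = torch.softmax(logits, dim=1)
print(probs.sum(dim=1))            # each row sums to 1
print(probs.argmax(dim=1))         # tensor([0, 2]): the same classes as logits.argmax(dim=1)
```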

We can create our modules by extending the [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module) class or one of its [many subclasses](https://pytorch.org/docs/stable/nn.html). Most often, we only need to concern ourselves with two methods: the `__init__` method, where we specify the subparts of the module, and the `forward` method, where we specify how a given input should be processed.

Below, we create an `ImageEncoder` module that contains 2 convolutional layers each followed by a `ReLU` activation function and a max pooling operation. 

In [None]:
class ImageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # input has 1 channel
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),  # conv2d layer with 1 input channel, 16 filters, kernel size 3x3
            nn.ReLU(),                                             # ReLU activation function
            nn.MaxPool2d(kernel_size=2, stride=2),                 # max pooling layer with kernel size 2x2
            # input has 16 channels
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), # conv2d layer with 16 input channels, 32 filters, kernel size 3x3
            nn.ReLU(),                                             # ReLU activation function
            nn.MaxPool2d(kernel_size=2, stride=2)                  # max pooling layer with kernel size 2x2
        )

    def forward(self, x):
        x = self.encoder(x)
        return x

# Test the ImageEncoder with random input data
encoder = ImageEncoder()
input_image = torch.randn(1, 1, 28, 28)  # batch_size x channels x height x width
output_features = encoder(input_image)
print("Output shape:", output_features.shape)
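The encoder's output shape of 32×7×7 follows from the standard output-size formula for convolution and pooling layers (with dilation 1): $H_{out} = \left\lfloor \frac{H_{in} + 2p - k}{s} \right\rfloor + 1$, where $k$ is the kernel size, $s$ the stride and $p$ the padding. The sketch below applies the formula with a hypothetical helper, `conv_out`, and cross-checks the result against real layers. A 3×3 convolution with padding 1 and stride 1 keeps the size, and each 2×2 max pool halves it, so 28 → 14 → 7.

```python
import torch
from torch import nn

def conv_out(size, kernel_size, stride=1, padding=0):
    # Output length along one spatial dimension for Conv2d / MaxPool2d (dilation=1)
    return (size + 2 * padding - kernel_size) // stride + 1

size = 28
for _ in range(2):  # two conv + pool stages, as in ImageEncoder
    size = conv_out(size, kernel_size=3, stride=1, padding=1)  # 3x3 conv with padding 1 keeps the size
    size = conv_out(size, kernel_size=2, stride=2)             # 2x2 max pool halves it
print(size)  # 7

# Cross-check the arithmetic against actual PyTorch layers
layers = nn.Sequential(
    nn.Conv2d(1, 16, 3, padding=1), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, 3, padding=1), nn.MaxPool2d(2),
)
print(layers(torch.zeros(1, 1, 28, 28)).shape)  # torch.Size([1, 32, 7, 7])
```

This is also where the `32 * 7 * 7` input size of the `Classifier`'s linear layer comes from.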

Next, we create our classifier module, which contains a fully connected feed-forward layer. It also uses a `Flatten` module to reshape each incoming sample (the 32×7×7 output of the convolutional encoder) into a single 1D tensor of 32·7·7 = 1,568 values. The `forward` method then passes this tensor to the linear layer.

In [None]:
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # Flatten the input tensor
        self.linear = nn.Linear(32 * 7 * 7, 10)  # Linear layer with 10 output units

    def forward(self, x):
        x = self.flatten(x)  # Flatten the input tensor
        x = self.linear(x)   # Pass through the linear layer
        return x

# Test the module with random input data
classifier = Classifier()
input_tensor = torch.randn(5, 32, 7, 7)  # Batch size of 5, input tensor shape [5, 32, 7, 7]
output = classifier(input_tensor)
print("Output shape:", output.shape)  # Expected output shape: [5, 10]
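A quick sketch of what `nn.Flatten` does with its default `start_dim=1`: the batch dimension is kept, and only the per-sample dimensions are merged. This is equivalent to reshaping each sample into a vector. The random input simply mimics the encoder's output shape.

```python
import torch
from torch import nn

x = torch.randn(5, 32, 7, 7)
flat = nn.Flatten()(x)                       # default start_dim=1: the batch dimension is kept
print(flat.shape)                            # torch.Size([5, 1568]), since 32 * 7 * 7 = 1568
print(torch.equal(flat, x.reshape(5, -1)))   # True: same as reshaping each sample to a vector
print(nn.Flatten(start_dim=0)(x).shape)      # torch.Size([7840]): flattening the batch too would mix samples
```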

Now that we have our modules, we can compose them into the final model. Here, we use the PTL [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html) class to build the model. A `LightningModule` is an extension of the PT `nn.Module` class that, when used with the PTL [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html), greatly reduces the amount of code needed to manage the model training process (more below). In a `LightningModule`, there are several methods we usually want to override:
1. `__init__` : similar to nn.Module - define the model components and setup behavior
2. `forward` : same as nn.Module
3. `training_step` : define the loss computation for a single training batch 
4. `validation_step` : define the loss computation for a single validation batch (by default, the validation loop runs at the end of each training epoch). The validation loss can be used for things like early stopping.
5. `test_step` : define the loss computation for a single test batch 
6. `configure_optimizers` : used to specify the optimization object that will update the model parameters given a loss value

We may also want to use the PTL LightningModule __callback hooks__, such as `on_validation_epoch_end` and `on_test_epoch_end`, which let us define actions that run at the end of an epoch.

We will also use the [TorchMetrics](https://lightning.ai/docs/torchmetrics/stable/) library to compute and track performance metrics such as accuracy, ROC curves, and confusion matrices.

Below, we build an `ImageClassifierModel` that takes as input an encoder and a classifier (both of which are expected to be `nn.Module`s) and an integer, `num_classes`.

In [None]:
class ImageClassifierModel(L.LightningModule):
    def __init__(self, encoder, classifier, num_classes):
        super().__init__()
        # model layers
        self.encoder = encoder
        self.classifier = classifier
        
        # validation metrics
        self.val_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes)]))
        self.validation_step_outputs = []
        self.validation_step_targets = []

        # test metrics
        self.test_roc = TM.ROC(task="multiclass", num_classes=num_classes, thresholds=list(np.linspace(0.0, 1.0, 20))) # roc and cm have methods we want to call so store them in a variable
        self.test_cm = TM.ConfusionMatrix(task='multiclass', num_classes=num_classes)
        self.test_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes), 
                                                            self.test_roc, self.test_cm]))
        self.test_step_outputs = []
        self.test_step_targets = []

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('val_loss', loss, on_step=True, on_epoch=True)
        
        # store the outputs and targets for the epoch end step
        self.validation_step_outputs.append(logits)
        self.validation_step_targets.append(y)
        return loss
    
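    def on_validation_epoch_end(self):
        # skip the short sanity-check pass Lightning runs before training so it is not recorded as an epoch
        if self.trainer.sanity_checking:
            self.validation_step_outputs.clear()
            self.validation_step_targets.clear()
            return
        # stack all the outputs and targets into a single tensor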
        all_preds = torch.vstack(self.validation_step_outputs)
        all_targets = torch.hstack(self.validation_step_targets)
        
        # compute the metrics
        loss = nn.functional.cross_entropy(all_preds, all_targets)
        self.val_metrics_tracker.increment()
        self.val_metrics_tracker.update(all_preds, all_targets)
        self.log('val_loss_epoch_end', loss)
        
        # clear the validation step outputs
        self.validation_step_outputs.clear()
        self.validation_step_targets.clear()
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('test_loss', loss, on_step=True, on_epoch=True)
        self.test_step_outputs.append(logits)
        self.test_step_targets.append(y)
        return loss
    
    def on_test_epoch_end(self):
        all_preds = torch.vstack(self.test_step_outputs)
        all_targets = torch.hstack(self.test_step_targets)
        
        self.test_metrics_tracker.increment()
        self.test_metrics_tracker.update(all_preds, all_targets)
        # clear the test step outputs
        self.test_step_outputs.clear()
        self.test_step_targets.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

## Model Training
Now that we have specified the model architecture we can train the model. Here we will use the PTL [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) class. The `Trainer` will handle all of the training process: sampling batches from the training data, calling the model train step, calling the optimizer, updating the model parameters, storing training metrics, etc. 

Importantly, the `Trainer` will also automatically check whether a __GPU__ is available and, if so, handle transferring the model and data to and from the GPU as needed.
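For comparison, here is a minimal sketch of the plain PyTorch loop that the `Trainer` writes for us. It moves batches to the device, computes the loss (what `training_step` returns), zeroes the gradients, backpropagates, and steps the optimizer. The two-feature toy data and the linear model are hypothetical stand-ins. A real `Trainer` additionally handles validation, logging, checkpointing, and callbacks.

```python
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy data: 2-feature points labelled by the sign of their first feature
X = torch.randn(256, 2)
y = (X[:, 0] > 0).long()
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

model = nn.Linear(2, 2).to(device)
optimizer = optim.Adam(model.parameters(), lr=5e-2)

for epoch in range(20):
    model.train()
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)               # the Trainer moves batches to the GPU for us
        loss = nn.functional.cross_entropy(model(xb), yb)   # what training_step computes
        optimizer.zero_grad()                               # clear gradients from the previous batch
        loss.backward()                                     # backpropagate
        optimizer.step()                                    # update the parameters

model.eval()
with torch.no_grad():
    acc = (model(X.to(device)).argmax(dim=1).cpu() == y).float().mean().item()
print(f"training accuracy: {acc:.2f}")
```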

In [None]:
seed_everything(rs)
encoder = ImageEncoder()
classifier = Classifier()
fashion_model = ImageClassifierModel(encoder, classifier, num_classes=10)

trainer = L.Trainer(default_root_dir=dir_lightning, 
                    max_epochs=5,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)])
trainer.fit(model=fashion_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

### Validation Set Accuracy
The validation-step performance metrics are stored in the variable `fashion_model.val_metrics_tracker`. Calling its `compute_all` method returns a dictionary-like object containing each tracked metric's value for every validation epoch. In this case, the only metric is __multiclass accuracy__.

In [None]:
mca = fashion_model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()

### Test Set Performance
Now let's take a look at the test set performance. We can again use the `trainer` to run the model on the test data by calling the `trainer.test` method and passing it the `test_loader`.

In [None]:
trainer.test(model=fashion_model, dataloaders=test_loader)

Here, our test-step performance metrics are stored in the variable `fashion_model.test_metrics_tracker`. Calling its `compute` method returns a dictionary-like object containing each metric's value for the most recent test run. In this case, those metrics are __MulticlassAccuracy__, __MulticlassConfusionMatrix__, and __MulticlassROC__.

In [None]:
rslt = fashion_model.test_metrics_tracker.compute()

In [None]:
# Plot the confusion matrix
cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted Label')
cmp.set_xticklabels(ds_fashion_train_all.classes, rotation=90)
cmp.set_yticklabels(ds_fashion_train_all.classes, rotation=0)
cmp.set_ylabel('Actual Label');

In [None]:
# Plot the 1 vs all ROC curves
fpr, tpr, thresholds = rslt['MulticlassROC']
for i in range(10):
    plt.plot(fpr[i], tpr[i], label=ds_fashion_train_all.classes[i])
plt.xlabel('False Positive Rate')
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()

In [None]:
# Print the classification report
device = torch.device("cpu")   # or torch.device("cuda:0") if a GPU is available
fashion_model.to(device)       # make sure the model and the data are on the same device
fashion_model.eval()           # switch to inference mode
y_true = []
y_pred = []
with torch.no_grad():
    for test_images, test_labels in test_loader:
        test_images = test_images.to(device)
        pred = fashion_model(test_images).argmax(dim=1)   # predicted class = index of the largest logit
        y_true.extend(test_labels.tolist())
        y_pred.extend(pred.cpu().tolist())

print(classification_report(y_true, y_pred, target_names=ds_fashion_train_all.classes, digits=4))

## Loading saved models

You may have noticed that during training, PTL saved `.ckpt` files to the `../../lightning` directory. These checkpoint files store the trained model weights (along with the optimizer and training state) and can be loaded directly into a new instance of the model. The loaded model can then be used for further training or for inference on new test samples.
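Because the version number and the `epoch=...-step=...` part of a checkpoint's file name depend on the run, it can help to locate the newest checkpoint programmatically. The helper `latest_checkpoint` below is hypothetical: it simply returns the most recently modified `.ckpt` file under a directory. It is demonstrated on a throwaway directory with placeholder file names. Within the same session, `trainer.checkpoint_callback.best_model_path` also holds the path of the checkpoint saved by Lightning's default `ModelCheckpoint` callback.

```python
import glob
import os
import tempfile

def latest_checkpoint(root):
    """Hypothetical helper: return the most recently modified .ckpt file under root, or None."""
    ckpts = glob.glob(os.path.join(root, "**", "*.ckpt"), recursive=True)
    return max(ckpts, key=os.path.getmtime) if ckpts else None

# Demonstration with a throwaway directory laid out like lightning_logs/version_N/checkpoints
with tempfile.TemporaryDirectory() as tmp:
    for version, mtime in [(0, 1_000), (1, 2_000)]:
        ckpt_dir = os.path.join(tmp, "lightning_logs", f"version_{version}", "checkpoints")
        os.makedirs(ckpt_dir)
        path = os.path.join(ckpt_dir, "model.ckpt")   # placeholder file name
        open(path, "w").close()
        os.utime(path, (mtime, mtime))                # fake modification times
    found = os.path.relpath(latest_checkpoint(tmp), tmp)
    empty = latest_checkpoint(os.path.join(tmp, "nothing-here"))
print(found)  # the version_1 checkpoint, since it was modified most recently
```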

In [None]:
# load model from checkpoint
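# YOU MAY NEED TO UPDATE THE PATH TO THE CHECKPOINT FILE: the version number and the epoch/step values in the file name
# depend on your run (e.g., 5 full epochs of 3,000 steps each would give epoch=4-step=15000.ckpt)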
enc = ImageEncoder()
clf = Classifier()
mp = os.path.join(dir_lightning, 'lightning_logs/version_0/checkpoints/epoch=7-step=24000.ckpt')
model = ImageClassifierModel.load_from_checkpoint(mp, encoder=enc, classifier=clf, num_classes=10)

In [None]:
# Run the model on the test data
trainer.test(model=model, dataloaders=test_loader)

# Alzheimer MRI Classification

Now, we will build on our knowledge of PT and PTL to create a model that classifies input MRI images of a patient's brain to infer the severity of dementia. This is a multi-class problem where the goal is to differentiate between:
1. Non Demented
2. Very Mild Dementia
3. Mild Dementia
4. Moderate Dementia

## Dataset and Dataloader
This dataset is not provided in the PT library, so we need to write our own data-handling code. We will extend the PTL [LightningDataModule](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html#lightning.pytorch.core.LightningDataModule), which organizes the dataset splits and their data loaders in one place. Also, because our data is organized in a directory whose subdirectories correspond to the outcome labels, we will use the [torchvision ImageFolder](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder) class to manage the dataset.

In [None]:
class AlzheimerDataModule(L.LightningDataModule):
    def __init__(self, root_dir, transform=None, batch_size=16):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        if transform is None:
            self.transform=transforms.Compose([
                      transforms.RandomRotation(10),      # rotate +/- 10 degrees
                      transforms.RandomHorizontalFlip(),  # reverse 50% of images
                      transforms.ToTensor()])
        else:
            self.transform = transform
        self.classes = sorted(os.listdir(self.root_dir))

    def setup(self, stage=None):
        dataset = datasets.ImageFolder(root=self.root_dir, transform=self.transform)
        n_data = len(dataset)
        n_train = int(0.8 * n_data)
        n_test = n_data - n_train

        train_dataset, test_dataset = torch.utils.data.random_split(dataset, [n_train, n_test])

        train_size = int(0.9 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = utils.data.random_split(train_dataset, [train_size, val_size])

        self._train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self._val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size)
        self._test_dataloader = DataLoader(test_dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self._train_dataloader

    def test_dataloader(self):
        return self._test_dataloader
    
    def val_dataloader(self):
        return self._val_dataloader
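The two-stage split in `setup` first holds out 20% of the images for testing and then reserves 10% of the remainder for validation. This gives roughly a 72% / 8% / 20% train/validation/test split. The sketch below reproduces the arithmetic for a hypothetical dataset of 1,000 images. It also shows that passing a dedicated `generator` to `random_split` makes a split reproducible without relying on the global seed. The data module above relies on `seed_everything` instead.

```python
import torch
from torch.utils.data import TensorDataset, random_split

def split_sizes(n):
    # Mirrors AlzheimerDataModule.setup: 80/20 train/test, then 90/10 of train into train/val
    n_train = int(0.8 * n)
    n_test = n - n_train
    n_fit = int(0.9 * n_train)
    n_val = n_train - n_fit
    return n_fit, n_val, n_test

print(split_sizes(1000))  # (720, 80, 200)

# A dedicated generator makes the split reproducible independently of the global RNG state
ds = TensorDataset(torch.arange(1000))
a, _ = random_split(ds, [800, 200], generator=torch.Generator().manual_seed(42))
b, _ = random_split(ds, [800, 200], generator=torch.Generator().manual_seed(42))
print(a.indices == b.indices)  # True
```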

Now we can create variables for our dataloaders for model training and testing.

In [None]:
seed_everything(rs)
alzheimers_dm = AlzheimerDataModule(root_dir=os.path.join(dir_dataroot, 'alzheimer-mri', 'Dataset'), batch_size=32)
alzheimers_dm.setup()
alzheimers_train_dataloader=alzheimers_dm.train_dataloader()
alzheimers_val_dataloader=alzheimers_dm.val_dataloader()
alzheimers_test_dataloader=alzheimers_dm.test_dataloader()

First, let's visualize a sample of the data.

In [None]:
batch = next(iter(alzheimers_train_dataloader))
imgs = batch[0]
labels = batch[1]
grid = make_grid(imgs)
print("Batch Dimensions:", imgs.shape)
print("Batch Targets:",labels)
show(grid)

## Alzheimer's Classification Model

# Problem 1 (3 points)
Define new Image Encoder and Classifier modules for the MRI data. Specifically, in the code cells below, adapt the `__init__` methods of the `ImageEncoder` and `Classifier` classes we created for the FashionMNIST data. Change the number of filters (output channels) in the two conv2d layers to 32 and 64, respectively. Also change the number of input channels of the first conv2d layer to 3, since the MRI images are loaded as 3-channel images. The linear layer in `MRIClassifier` must output 4 values, one per dementia class. Hint: you can use the output shape from the `MRIImageEncoder` (first cell) to determine the appropriate input size for the linear layer in the `MRIClassifier`.

In [None]:
class MRIImageEncoder(nn.Module):
    def __init__(self):
        ##### START YOUR CODE HERE #####
        pass
        ##### END YOUR CODE HERE #####

    def forward(self, x):
        x = self.encoder(x)
        return x

# Test the MRIImageEncoder with random input data
encoder = MRIImageEncoder()
input_image = torch.randn(1, 3, 128, 128)  # batch_size x channels x height x width
output_features = encoder(input_image)
print("Output shape:", output_features.shape)

In [None]:
class MRIClassifier(nn.Module):
    def __init__(self):
        ##### START YOUR CODE HERE #####
        pass
        ##### END YOUR CODE HERE #####

    def forward(self, x):
        x = self.flatten(x)  # Flatten the input tensor
        x = self.linear(x)   # Pass through the linear layer
        return x

# Test the module with random input data
module = MRIClassifier()
input_tensor = torch.randn(5, 64, 32, 32)  # Batch size of 5, input tensor shape [5, 64, 32, 32]; should match the MRIImageEncoder output shape printed above
output = module(input_tensor)
print("Output shape:", output.shape)  # Expected output shape: [5, 4]

# Problem 2 (2 points)
We always want to make our code as reusable as possible. What change could you make to the `__init__` methods of the `ImageEncoder` and `Classifier` classes so that they could be reused for the MRI data without creating new classes?

# Problem 3 (1 point)
In the code cell below, train the MRI model using a PTL `Trainer` (as shown with the FashionMNIST model above), but set `max_epochs=20` when constructing the `Trainer` object. Be sure to pass the correct training and validation loaders (`alzheimers_train_dataloader` and `alzheimers_val_dataloader`).

In [None]:
seed_everything(rs)
encoder = MRIImageEncoder()
classifier = MRIClassifier()
mri_model = ImageClassifierModel(encoder, classifier, 4)

##### START YOUR CODE HERE #####
trainer = None
##### END YOUR CODE HERE #####

## Validation Set Accuracy

In [None]:
mca = mri_model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()

## Test Set Performance

In [None]:
trainer.test(model=mri_model, dataloaders=alzheimers_dm.test_dataloader())

In [None]:
rslt = mri_model.test_metrics_tracker.compute()

In [None]:
cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted label')
cmp.set_xticklabels(alzheimers_dm.classes, rotation=45)
cmp.set_yticklabels(alzheimers_dm.classes, rotation=0)
cmp.set_ylabel('Actual label');

In [None]:
fpr, tpr, thresholds = rslt['MulticlassROC']
for i in range(4):
    plt.plot(fpr[i], tpr[i], label=alzheimers_dm.classes[i])
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()

# Transfer Learning
In many cases, especially when we have limited data, we can use __transfer learning__, in which a model trained on a very large dataset from another domain is fine-tuned for a new problem. Here, we demonstrate transfer learning on the MRI problem. Specifically, we use the [SqueezeNet](https://pytorch.org/vision/stable/models/generated/torchvision.models.squeezenet1_1.html#torchvision.models.SqueezeNet1_1_Weights) model, pretrained on ImageNet data, as our encoder. SqueezeNet is a deep convolutional network, built from many convolutional layers grouped into "Fire" modules, that was trained to classify ImageNet images. We will use only the "encoder" portion of SqueezeNet (stored in the model's `features` attribute) and add our own layers to serve as the classifier. When we train the model, we will update only the classifier parameters, though we could also fine-tune the encoder parameters if we had sufficient data and computation time.

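First, let's recreate our data module with additional transforms, because the pretrained SqueezeNet weights were trained on 224×224 RGB images normalized with the ImageNet per-channel means and standard deviations.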

In [None]:
seed_everything(rs)
transform=transforms.Compose([
        transforms.RandomRotation(10),      # rotate +/- 10 degrees
        transforms.RandomHorizontalFlip(),  # reverse 50% of images
        transforms.Resize(224),             # resize the shorter side to 224 pixels (aspect ratio preserved)
        transforms.CenterCrop(224),         # crop a 224x224 patch from the center
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],   # ImageNet per-channel means
                             [0.229, 0.224, 0.225])   # ImageNet per-channel standard deviations
])

alzheimers_dm = AlzheimerDataModule(root_dir=os.path.join(dir_dataroot, 'alzheimer-mri', 'Dataset'), batch_size=32, transform=transform)
alzheimers_dm.setup()
alzheimers_train_dataloader=alzheimers_dm.train_dataloader()
alzheimers_val_dataloader=alzheimers_dm.val_dataloader()
alzheimers_test_dataloader=alzheimers_dm.test_dataloader()

Now we can create the `SqueezeNetMRIImageEncoder` using the `SqueezeNet` features backbone.

In [None]:
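class SqueezeNetMRIImageEncoder(L.LightningModule):  # a LightningModule (not just an nn.Module) so that it provides .freeze()
    def __init__(self):
        super().__init__()
        backbone = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.DEFAULT)  # ImageNet-pretrained weights, downloaded on first use
        self.encoder = backbone.features  # keep only the convolutional feature extractor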

    def forward(self, x):
        x = self.encoder(x)
        return x

# Test the ImageEncoder with random input data
encoder = SqueezeNetMRIImageEncoder()
input_image = torch.randn(1, 3, 224, 224)  # batch_size x channels x height x width
output_features = encoder(input_image)
print("Output shape:", output_features.shape)

Next, we create the classifier, which is slightly more complicated than the one we created previously. Following the original SqueezeNet design, it uses a final 1×1 convolutional layer followed by global average pooling instead of a linear layer.

In [None]:
class SqueezeNetMRIClassifierSquezeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_classes = 4
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.3), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
        )

    def forward(self, x):
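        x = self.classifier(x)       # dropout -> 1x1 conv -> ReLU -> global average pooling: [N, 4, 1, 1]
        return torch.flatten(x, 1)   # [N, 4]; unlike squeeze(), this keeps the batch dimension when N == 1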

# Test the module with random input data
module = SqueezeNetMRIClassifierSquezeNet()
input_tensor = torch.randn(5, 512, 13, 13)  # Batch size of 5, input tensor shape [5, 512, 13, 13] (the SqueezeNet encoder output for 224x224 images)
output = module(input_tensor)
print("Output shape:", output.shape)  # Expected output shape: [5, 4]
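Two properties of this classifier head can be checked with a small sketch that uses random inputs. First, a 1×1 convolution is a linear map from 512 to 4 channels applied independently at every spatial location. Second, after `AdaptiveAvgPool2d((1, 1))`, calling `squeeze()` removes every size-1 dimension, including the batch dimension when the batch holds a single sample. `torch.flatten(x, 1)`, which torchvision's own SqueezeNet uses, keeps the batch dimension.

```python
import torch
from torch import nn

head = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Conv2d(512, 4, kernel_size=1),  # 1x1 conv: a 512 -> 4 linear map at every spatial location
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),      # average each of the 4 maps over the 13x13 grid
)
head.eval()  # disable dropout so the results are deterministic

single = head(torch.randn(1, 512, 13, 13))  # shape [1, 4, 1, 1]
print(single.squeeze().shape)               # torch.Size([4]): squeeze() also drops the batch dimension
print(torch.flatten(single, 1).shape)       # torch.Size([1, 4]): flatten(x, 1) keeps it

# The 1x1 conv equals an nn.Linear applied to each pixel's 512-dim feature vector
conv = head[1]
linear = nn.Linear(512, 4)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(4, 512))
    linear.bias.copy_(conv.bias)
    feats = torch.randn(2, 512, 13, 13)
    via_linear = linear(feats.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # channels-last, then back
    same = torch.allclose(conv(feats), via_linear, atol=1e-4)
print(same)  # True
```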

We are now ready to create our overall model.

In [None]:
class ImageClassifierSqueezeNet(ImageClassifierModel):
    def __init__(self, encoder, classifier, num_classes, freeze_encoder=True):
        super().__init__(encoder, classifier, num_classes)
        if freeze_encoder:
            self.encoder.freeze()
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.classifier.parameters(), lr=1e-3)
        return optimizer
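The sketch below mirrors the freezing logic of `ImageClassifierSqueezeNet` on tiny hypothetical stand-in modules, so nothing has to be downloaded. `LightningModule.freeze()` sets `requires_grad = False` on all of a module's parameters and switches it to eval mode. Combined with an optimizer that receives only the classifier's parameters, this ensures that training updates only the new head and that no gradients are computed for the encoder.

```python
import torch
from torch import nn, optim

# Hypothetical stand-ins for the pretrained encoder and the new classifier head
encoder = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4))

# Roughly what LightningModule.freeze() does: stop gradient tracking and switch to eval mode
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

def n_trainable(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

print(n_trainable(encoder), n_trainable(classifier))  # 0 36

# Only the classifier's parameters are handed to the optimizer
optimizer = optim.Adam(classifier.parameters(), lr=1e-3)
x = torch.randn(2, 3, 16, 16)
loss = nn.functional.cross_entropy(classifier(encoder(x)), torch.tensor([0, 3]))
loss.backward()
print(encoder[0].weight.grad is None)  # True: no gradients flow into the frozen encoder
```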

Finally, we can train and evaluate our model.

In [None]:
seed_everything(rs)
encoder = SqueezeNetMRIImageEncoder()
classifier = SqueezeNetMRIClassifierSquezeNet()
mri_squeeze_model = ImageClassifierSqueezeNet(encoder, classifier, 4)

trainer = L.Trainer(default_root_dir=dir_lightning, 
                    max_epochs=50,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=5)],
                    accelerator='auto')  # use a GPU if one is available, otherwise fall back to the CPU
trainer.fit(model=mri_squeeze_model, train_dataloaders=alzheimers_dm.train_dataloader(), val_dataloaders=alzheimers_dm.val_dataloader())

### Validation Set Accuracy

In [None]:
mca = mri_squeeze_model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()

### Test Set Performance

In [None]:
trainer.test(model=mri_squeeze_model, dataloaders=alzheimers_dm.test_dataloader())

In [None]:
rslt = mri_squeeze_model.test_metrics_tracker.compute()

In [None]:
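cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted label')
cmp.set_xticklabels(alzheimers_dm.classes, rotation=45)
cmp.set_yticklabels(alzheimers_dm.classes, rotation=0)
cmp.set_ylabel('Actual label');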

In [None]:
fpr, tpr, thresholds = rslt['MulticlassROC']
class_map = {0: 'Mild', 1: 'Moderate', 2: 'Healthy', 3: 'Very Mild'}
for i in range(4):
    plt.plot(fpr[i], tpr[i], label=class_map[i])
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()

In [None]:
device = torch.device("cpu")   # or torch.device("cuda:0") if a GPU is available
class_names = sorted(os.listdir(alzheimers_dm.root_dir))
mri_squeeze_model.to(device)   # make sure the model and the data are on the same device
mri_squeeze_model.eval()       # switch to inference mode
y_true = []
y_pred = []
with torch.no_grad():
    for test_images, test_labels in alzheimers_dm.test_dataloader():
        test_images = test_images.to(device)
        pred = mri_squeeze_model(test_images).argmax(dim=1)   # predicted class = index of the largest logit
        y_true.extend(test_labels.tolist())
        y_pred.extend(pred.cpu().tolist())

print(classification_report(y_true, y_pred, target_names=class_names, digits=4))