# Coursework: Self-supervised learning

In this coursework, you will explore the popular self-supervised contrastive learning approach [SimCLR](https://arxiv.org/abs/2002.05709).

You will be asked to implement some of the key components of SimCLR, including a suitable data augmentation strategy (for generating positive pairs), the SimCLR loss function, and the SimCLR training step. Additionally, you will be using transfer learning strategies for evaluating the performance of different pre-trained models for a downstream classification task.

The coursework is divided into three parts:
- **Part A:** Implementation of a suitable dataset for contrastive model training;
- **Part B:** Implementation of the SimCLR loss and training step;
- **Part C:** Implementation of transfer learning strategies (linear probing and finetuning) for model evaluation.

**Important:** Read the text descriptions carefully and look out for hints and comments indicating a specific 'TASK'. Make sure to add sufficient documentation to your code.

**Submission:** You are asked to submit two versions of your notebook:
1. You should submit the raw notebook in `.ipynb` format with *all outputs cleared*. Please name your file `coursework.ipynb`.
2. Additionally, you will be asked to submit an exported version of your notebook in `.pdf` format, with *all outputs included*. We will primarily use this version for marking, but we will use the raw notebook to check for correct implementations. Please name this file `coursework_export.pdf`.

## Your details

Please add your details below. You can work in groups of up to two.

Authors: **Kangle Yuan** & **Jiqiu Hu**

DoC alias: **ky523** & **jh523**

## Setup

In [None]:
# On Google Colab, run the following line to install PyTorch Lightning and
# the MedMNIST dataset
! pip install lightning medmnist

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import models
from torchvision import transforms
from pytorch_lightning import (LightningModule, LightningDataModule, Trainer,
                               seed_everything)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from torchmetrics.functional import auroc
from PIL import Image
from medmnist.info import INFO
from medmnist.dataset import MedMNIST

## **Part A:** Implement a dataset suitable for contrastive learning.

We will be using the [MedMNIST Pneumonia](https://medmnist.com/) dataset, a medical imaging dataset with the characteristics of MNIST. Its small image size allows efficient experimentation. The dataset contains real chest X-ray images downsampled to 28 x 28 pixels, with binary labels indicating the presence of [Pneumonia](https://www.nhs.uk/conditions/pneumonia/) (an inflammation of the lungs).

### **Task A-1:** Complete the dataset implementation.

You are asked to implement a dataset class `SimCLRPneumoniaMNISTDataset` suitable for training a self-supervised model with a contrastive objective. For each sample, your dataset class should return two 'views' of the corresponding image, forming the positive pairs for contrastive learning. It is up to you to design a suitable augmentation pipeline for generating these views. Please provide a short description in plain language of what your data augmentation pipeline is meant to do.

To get you started, we have provided the skeleton of the dataset class in the cell below. Once you have implemented your dataset class, you are asked to run the provided visualisation code to visualise one batch of your training dataloader.

*Note:* You can use the same data augmentation pipeline for training, validation, and testing.
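The idea of a positive pair can be illustrated without torchvision. The minimal NumPy sketch below draws two independent random views of the same 28 x 28 image. Each view uses a random horizontal flip, a random shift (reflect-pad, then crop) and a random brightness/contrast change. The helper names `random_view` and `two_views` and all parameter ranges are illustrative assumptions. This is not the torchvision pipeline used in the dataset class.

```python
import numpy as np

def random_view(img, rng, max_shift=2, jitter=0.2):
    """Return one randomly augmented copy of a 2D uint8 image as float32 in [0, 1]."""
    x = img.astype(np.float32) / 255.0
    # Random horizontal flip with probability 0.5
    if rng.random() < 0.5:
        x = x[:, ::-1]
    # Random shift: reflect-pad by max_shift pixels, then crop back to the original size
    h, w = x.shape
    padded = np.pad(x, max_shift, mode="reflect")
    top = rng.integers(0, 2 * max_shift + 1)
    left = rng.integers(0, 2 * max_shift + 1)
    x = padded[top:top + h, left:left + w]
    # Random brightness and contrast change around the image mean
    brightness = rng.uniform(-jitter, jitter)
    contrast = rng.uniform(1 - jitter, 1 + jitter)
    mean = x.mean()
    x = (x - mean) * contrast + mean + brightness
    return np.clip(x, 0.0, 1.0).astype(np.float32)

def two_views(img, rng):
    """Two independent augmentations of the same image form one positive pair."""
    return random_view(img, rng), random_view(img, rng)

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(28, 28), dtype=np.uint8)  # stand-in for an X-ray
view1, view2 = two_views(image, rng)
print(view1.shape, view2.shape)  # (28, 28) (28, 28)
```

Because every call draws fresh random parameters, the two views share the same underlying content but differ in appearance. The contrastive objective then has to pull their embeddings together.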

In [None]:
class SimCLRPneumoniaMNISTDataset(MedMNIST):
    def __init__(self, split = 'train'):
        ''' Dataset class for PneumoniaMNIST.
        The provided init function will automatically download the necessary
        files the first time the class is initialised.

        :param split: 'train', 'val' or 'test', select subset

        '''
        self.flag = "pneumoniamnist"
        self.size = 28
        self.size_flag = ""
        self.root = './data/coursework/'
        self.info = INFO[self.flag]
        self.download()

        npz_file = np.load(os.path.join(self.root, "pneumoniamnist.npz"))

        self.split = split

        # Load all the images
        assert self.split in ['train','val','test']

        self.imgs = npz_file[f'{self.split}_images']
        self.labels = npz_file[f'{self.split}_labels']

        # TASK: Define here your data augmentation pipeline
        # Add a short description in plain language.
        #
        # Each call to the pipeline produces a different random view of the
        # same X-ray; two such views form a positive pair.
        #
        # Random Rotation: rotate the image by a small angle (up to 3 degrees)
        # to mimic imperfect patient positioning.
        #
        # Random Horizontal Flip: flip the image horizontally with a 50% chance
        # to mimic mirrored views of the X-ray.
        #
        # Random Resized Crop: crop 77-92% of the image area and resize it back
        # to 28x28 pixels, so that the model sees different parts of the lungs.
        #
        # Color Jitter (applied with 80% probability): change the brightness
        # and contrast to mimic differences in X-ray exposure and
        # post-processing. Saturation and hue have no effect on these
        # single-channel images.
        #
        # Gaussian Blur (applied with 50% probability): blur with a 3x3 kernel
        # to mimic differences in image sharpness.
        #
        # Normalize: standardise intensities with a fixed mean and std.

        self.augmentation_pipeline = transforms.Compose([
            # convert the 28x28 uint8 array to a PIL image
            transforms.ToPILImage(),
            # randomly rotate the image by an angle in [-3, 3] degrees
            transforms.RandomRotation(3),
            # flip image horizontally with 50% probability
            transforms.RandomHorizontalFlip(),
            # crop a random part of the image and then resize back to 28x28
            transforms.RandomResizedCrop(self.size, scale=(0.77, 0.92)),
            # with 80% probability, randomly change brightness and contrast
            # (saturation and hue do not affect grayscale images)
            transforms.RandomApply([transforms.ColorJitter(brightness=0.8,
                    contrast=0.8, saturation=0.8, hue=0.2)], p=0.8),
            # convert the PIL image to a [1, 28, 28] float tensor in [0, 1]
            transforms.ToTensor(),
            # apply Gaussian blur with a 3x3 kernel with 50% probability
            transforms.RandomApply([transforms.GaussianBlur(3)], p=0.5),
            # normalise with a fixed mean and std (the ImageNet statistics
            # of the red channel)
            transforms.Normalize(mean=[0.485], std=[0.229])
        ])


    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, index):
        # TASK: Fill in the blanks such that you return two tensors
        # of shape [1, 28, 28], img_view1 and img_view2, representing two
        # augmented views of the image.
        img = self.imgs[index]
        # The pipeline is random, so each call produces a different view
        img_view1 = self.augmentation_pipeline(img)
        img_view2 = self.augmentation_pipeline(img)

        return img_view1, img_view2

We use a [LightningDataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) for handling your PneumoniaMNIST dataset. You do not need to make any modifications to the code below.

In [None]:
class SimCLRPneumoniaMNISTDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size
        self.train_set = SimCLRPneumoniaMNISTDataset(split='train')
        self.val_set = SimCLRPneumoniaMNISTDataset(split='val')
        self.test_set = SimCLRPneumoniaMNISTDataset(split='test')

    def train_dataloader(self):
        return DataLoader(dataset=self.train_set, batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.val_set, batch_size=self.batch_size,
                          shuffle=False)

    def test_dataloader(self):
        return DataLoader(dataset=self.test_set, batch_size=self.batch_size,
                          shuffle=False)

#### **Check** dataset implementation.

Run the below cell to visualise a batch of your training dataloader.

In [None]:
# DO NOT MODIFY THIS CELL! IT IS FOR CHECKING THE IMPLEMENTATION ONLY.

# Initialise data module
datamodule = SimCLRPneumoniaMNISTDataModule()
# Get train dataloader
train_dataloader = datamodule.train_dataloader()
# Get first batch
batch = next(iter(train_dataloader))
# Visualise the images
view1, view2 = batch
f, ax = plt.subplots(2, 8, figsize=(12,4))
for i in range(8):
  ax[0,i].imshow(view1[i, 0], cmap='gray')
  ax[1,i].imshow(view2[i, 0], cmap='gray')
  ax[0,i].set_title('view 1')
  ax[1,i].set_title('view 2')
  ax[0, i].axis("off")
  ax[1, i].axis("off")

## **Part B:** Implement the SimCLR loss and training step.

In this part, we ask you to:
1. Implement the SimCLR loss function, as per the equation in the lecture notes (and the [original paper](https://arxiv.org/abs/2002.05709)).
2. Once you have implemented the loss, implement the training step function in the provided LightningModule.
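For reference, the original SimCLR paper (NT-Xent loss) defines the per-pair loss as shown below. Here $z_1, \dots, z_{2N}$ are the projected embeddings of the $2N$ augmented views of a batch of $N$ images. $\mathrm{sim}(u, v) = u^\top v / (\lVert u \rVert \lVert v \rVert)$ is the cosine similarity, and $\tau$ is the temperature. The paper indexes the two views of image $k$ as $2k-1$ and $2k$. The notation here may differ from the lecture slides.

$$
\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)},
\qquad
\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N} \left[\ell_{2k-1,2k} + \ell_{2k,2k-1}\right]
$$

The views can instead be stacked as $[\text{view 1}; \text{view 2}]$. In that layout, the positive partner of row $i$ is row $i + N$ (or $i - N$), and $\mathcal{L}$ is simply the mean of $\ell$ over all $2N$ rows.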

### **Task B-1:** SimCLR loss function.

For the implementation of the SimCLR loss, you should follow the 'recipe' from the lecture slides. We provide a code skeleton to get you started. Fill in all the blanks.

*Hint:* In PyTorch, to compute scalar products (also called dot products) between many elements efficiently, note that for two batches of $d$-dimensional feature vectors $v1$ and $v2$ of size $[N, d]$ (with $N$ being the batch size) computing the matrix multiplication `torch.mm(v1, v2.t())` returns a matrix $S$ of size $[N, N]$ where each element $S[i, j]$ is the scalar product of $v1_i$ and $v2_j$.
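The hint can be checked on small random data. The sketch below uses NumPy only so it runs without PyTorch. Its `@` operator plays the role of `torch.mm`. It also shows that, after L2-normalising the rows, the same product yields cosine similarities.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 4, 8
v1 = rng.normal(size=(N, d))
v2 = rng.normal(size=(N, d))

# NumPy analogue of torch.mm(v1, v2.t()): S[i, j] = <v1_i, v2_j>
S = v1 @ v2.T
print(S.shape)  # (4, 4)

# Check one entry against an explicit dot product
assert np.isclose(S[1, 2], np.dot(v1[1], v2[2]))

# After L2-normalising the rows, the same product gives cosine similarities in [-1, 1]
u1 = v1 / np.linalg.norm(v1, axis=1, keepdims=True)
u2 = v2 / np.linalg.norm(v2, axis=1, keepdims=True)
cos = u1 @ u2.T
```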

In [None]:
import torch
import torch.nn.functional as F

def simclr_loss(embedding_view1, embedding_view2, tau=1.0):
    ''' SimCLR (NT-Xent) loss.

    :param embedding_view1: tensor of shape [N, feature_dim] for the first views
    :param embedding_view2: tensor of shape [N, feature_dim] for the second views
    :param tau: temperature
    :return: scalar loss averaged over all 2 * N views
    '''
    device = embedding_view1.device

    # Step 1: Normalize the embeddings so that dot products are cosine
    # similarities.
    embedding_view1 = F.normalize(embedding_view1, dim=1)
    embedding_view2 = F.normalize(embedding_view2, dim=1)

    # Step 2: gather all embeddings into one big tensor of size
    # [2 * N, feature_dim]; rows i and i + N come from the same image.
    z_all_views = torch.cat([embedding_view1, embedding_view2], dim=0)

    # Step 3: compute all possible similarities, a matrix of size [2 * N, 2 * N].
    # all_similarities[i, j] is the similarity between z_all_views[i] and
    # z_all_views[j].
    all_similarities = torch.mm(z_all_views, z_all_views.t())

    # Step 4: mask of size [2 * N, 2 * N] with mask[i, j] = 1 if z_all_views[i]
    # and z_all_views[j] form a positive pair. There should be exactly 2 * N
    # non-zero elements in this matrix.
    batch_size = embedding_view1.size(0)
    sample_ids = torch.arange(batch_size, device=device).repeat(2)
    # element-wise comparison: True where both rows come from the same image
    positive_pair_mask = (sample_ids.unsqueeze(0) ==
                          sample_ids.unsqueeze(1)).float()
    positive_pair_mask.fill_diagonal_(0)  # exclude self-similarities

    # Step 5: self-mask of shape [2 * N, 2 * N] that is 1 for all valid pairs
    # and 0 for all self-pairs (i = j), used for the denominator.
    mask_exclude_self = 1 - torch.eye(2 * batch_size, device=device)

    exp_similarities = torch.exp(all_similarities / tau)

    # Step 6: numerators, a vector of size [2 * N] whose i-th element is
    # exp(sim(i, j) / tau) for the positive partner j of i. Each row of the
    # mask has exactly one non-zero entry, so the row sum selects it.
    numerators_all = torch.sum(exp_similarities * positive_pair_mask, dim=1)

    # Step 7: denominators, a vector of size [2 * N] whose i-th element is
    # the sum of exp(sim(i, k) / tau) over all k != i.
    denominators_all = torch.sum(exp_similarities * mask_exclude_self, dim=1)

    # Step 8: final loss, -log(numerator / denominator) averaged over all
    # 2 * N views.
    loss = -torch.log(numerators_all / denominators_all)
    return loss.mean()


#### **Check** SimCLR loss function.

To check your implementation, please run the following tests. Note that we will also use other tests on different inputs to test your code.
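You may want to cross-check `simclr_loss` on inputs of your own. An independent NumPy reference can help with that. The sketch below was written for this purpose, and `nt_xent_reference` is a hypothetical helper name. It uses a log-sum-exp formulation for numerical stability. This deliberately deviates from the direct exp/sum recipe in the skeleton, but it should give the same value up to floating-point error.

```python
import numpy as np

def nt_xent_reference(e1, e2, tau=1.0):
    """NumPy reference for the SimCLR (NT-Xent) loss; e1, e2 have shape [N, d]."""
    # L2-normalise so that dot products are cosine similarities
    e1 = e1 / np.linalg.norm(e1, axis=1, keepdims=True)
    e2 = e2 / np.linalg.norm(e2, axis=1, keepdims=True)
    z = np.concatenate([e1, e2], axis=0)          # [2N, d]
    n = e1.shape[0]
    logits = (z @ z.T) / tau                       # [2N, 2N]
    np.fill_diagonal(logits, -np.inf)              # drop self-similarities (exp(-inf) = 0)
    # Row i's positive partner sits N rows away in the concatenated layout
    pos_idx = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable log-sum-exp over all k != i
    row_max = logits.max(axis=1, keepdims=True)
    log_denominator = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    log_numerator = logits[np.arange(2 * n), pos_idx]
    return float(np.mean(log_denominator - log_numerator))

rng = np.random.default_rng(0)
a, b = rng.random((3, 5)), rng.random((3, 5))
print(nt_xent_reference(a, b, tau=0.5))
# Identical embeddings: every similarity is 1, so each row gives log(3) when N = 2
same = np.ones((2, 4))
print(np.isclose(nt_xent_reference(same, same, tau=0.5), np.log(3)))  # True
```

Passing the same random inputs (for example, NumPy arrays converted with `torch.from_numpy`) to both functions should give matching values.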

In [None]:
# DO NOT MODIFY THIS CELL! IT IS FOR CHECKING THE IMPLEMENTATION ONLY.

seed_everything(33)

expected_results = [torch.tensor(1.7518), torch.tensor(1.6376),
                    torch.tensor(4.194),  torch.tensor(4.1754)]
for i, (N, feature_dim) in enumerate(zip([3, 3, 33, 33], [5, 125, 5, 125])):
  embedding_view1 = torch.rand((N, feature_dim))
  embedding_view2 = torch.rand((N, feature_dim))
  loss = simclr_loss(embedding_view1.clone(), embedding_view2.clone(), tau=0.5)
  print(f"Expected loss: {expected_results[i]}, Computed loss: {loss}")
  assert torch.isclose(loss, expected_results[i], rtol=1e-3)
print("Passed all tests successfully !")

### **Task B-2:** SimCLR training step.

In this next task you are asked to complete the blanks in the provided [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html).

We provide the implementation of an image encoder (the CNN backbone that will act as feature extractor). No changes are needed for this part.
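With 28 x 28 inputs, the ResNet-50 backbone shrinks the feature map to a single spatial position before `avgpool`. The encoder therefore returns one 2048-dimensional vector per image, which is the input size of the projector's first `Linear(2048, 1024)` layer. The arithmetic below is a sketch that assumes torchvision's ResNet-50 layout: a stride-2 `conv1` and `maxpool`, and a stride-2 3x3 convolution at the start of `layer2`, `layer3` and `layer4`.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution/pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28
sizes = []
size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1 (1 -> 64 channels)
sizes.append(size)
size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
sizes.append(size)
sizes.append(size)  # layer1 keeps the resolution
for _ in range(3):  # layer2, layer3, layer4 each halve it (rounding up)
    size = conv_out(size, kernel=3, stride=2, padding=1)
    sizes.append(size)
print(sizes)  # [14, 7, 7, 4, 2, 1]
```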

In [None]:
class ImageEncoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net = models.resnet50(weights=None)
        del self.net.fc
        self.net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2,
                                         padding=3, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.net.conv1(x)
        x = self.net.bn1(x)
        x = self.net.relu(x)
        x0 = self.net.maxpool(x)
        x1 = self.net.layer1(x0)
        x2 = self.net.layer2(x1)
        x3 = self.net.layer3(x2)
        x4 = self.net.layer4(x3)
        x4 = self.net.avgpool(x4)
        x4 = torch.flatten(x4, 1)
        return x4

Next, you will need to complete the implementation of the SimCLR model. In order to make the training step work correctly, you will need to implement the `process_batch` function.

In [None]:
class SimCLRModel(LightningModule):
    def __init__(self, learning_rate: float = 0.001):
        super().__init__()
        self.learning_rate = learning_rate

        self.encoder = ImageEncoder()

        self.projector = torch.nn.Sequential(
            torch.nn.Linear(2048, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 128),
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

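    def process_batch(self, batch):
        # TASK: Implement the process_batch function
        # The batch holds two augmented views of the same images.
        view1, view2 = batch
        # Shared encoder: [B, 1, 28, 28] -> [B, 2048] representations
        h_1 = self.encoder(view1)
        h_2 = self.encoder(view2)
        # Shared projection head: [B, 2048] -> [B, 128] embeddings
        z_1 = self.projector(h_1)
        z_2 = self.projector(h_2)
        # Contrastive loss between the two sets of embeddings (default tau=1.0)
        loss = simclr_loss(z_1, z_2)
        return loss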

    def training_step(self, batch, batch_idx):
        loss = self.process_batch(batch)
        self.log('train_loss', loss, prog_bar=True)
        if batch_idx == 0:
            grid = torchvision.utils.make_grid(torch.cat((batch[0][0:4, ...],
                            batch[1][0:4, ...]), dim=0), nrow=4, normalize=True)
            self.logger.experiment.add_image('train_images', grid,
                                             self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.process_batch(batch)
        self.log('val_loss', loss, prog_bar=True)

#### **Check** SimCLR training step.

Here you can test that your code runs fine by training the model for 5 epochs using the cell below.

Report the training and validation loss at the end of 5 epochs.

In [None]:
# DO NOT MODIFY THIS CELL! IT IS FOR CHECKING THE IMPLEMENTATION ONLY.

seed_everything(33, workers=True)

data = SimCLRPneumoniaMNISTDataModule(batch_size=32)

model = SimCLRModel()

trainer = Trainer(
    max_epochs=5,
    accelerator='auto',
    devices=1,
    logger=TensorBoardLogger(save_dir='./lightning_logs/coursework/',
                             name='simclr'),
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min'),
               TQDMProgressBar(refresh_rate=10)],
)
trainer.fit(model=model, datamodule=data)

## **Part C:** Linear probing and model finetuning.

In this part, you are given two different image encoders that were pre-trained with different datasets and training strategies. The objective of this task is to assess the performance of these two encoders in a downstream classification task. To this end, you are asked to implement the evaluation routines seen in the lecture: linear probing and model finetuning. The downstream task is the prediction of pneumonia in the (small) chest X-ray images from the PneumoniaMNIST dataset.

This part can be broken down into the following tasks:
1. Adapt your PneumoniaMNIST dataset for the image classification task.
2. Implement a classification model with a linear layer attached to a pre-trained image encoder.
3. For both pre-trained encoders:
    - a) Train the classifier on top of the frozen encoder (linear probing)
    - b) Finetune the entire model (including the encoder).
4. Evaluate all models on the test set, and provide a brief summary (no more than 300 words) with an analysis of your findings.
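Step 3a can be illustrated with a toy example of linear probing. In the sketch below, a fixed random projection followed by a ReLU acts as a hypothetical frozen "encoder". Only a logistic-regression head is trained on its features, using gradient descent, and the encoder weights are never updated. The toy data, the feature standardisation and the hyperparameters are illustrative assumptions. `ImageClassifier` itself relies on PyTorch's `requires_grad = False` instead.

```python
import numpy as np

def bce(p, y, eps=1e-12):
    """Mean binary cross-entropy between probabilities p and labels y."""
    return float(-np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps)))

rng = np.random.default_rng(0)

# Toy data: two Gaussian blobs in 10 dimensions with binary labels
x = np.concatenate([rng.normal(-1.0, 1.0, (100, 10)),
                    rng.normal(1.0, 1.0, (100, 10))])
y = np.concatenate([np.zeros(100), np.ones(100)])

# Hypothetical frozen "encoder": a fixed random projection followed by a ReLU
W_enc = rng.normal(size=(10, 32))
W_enc_before = W_enc.copy()
features = np.maximum(x @ W_enc, 0.0)
# Standardise the frozen features (a common, optional step before linear probing)
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)

# Only the linear head (w, b) is trained
w, b, lr = np.zeros(32), 0.0, 0.1
initial_loss = bce(np.full(len(y), 0.5), y)
for _ in range(500):
    p = 1.0 / (1.0 + np.exp(-(features @ w + b)))  # sigmoid probabilities
    grad = p - y                                    # gradient of the loss w.r.t. the logits
    w -= lr * features.T @ grad / len(y)
    b -= lr * grad.mean()

p = 1.0 / (1.0 + np.exp(-(features @ w + b)))
final_loss = bce(p, y)
accuracy = float(np.mean((p > 0.5) == y))
print(f"loss {initial_loss:.3f} -> {final_loss:.3f}, accuracy {accuracy:.2f}")
```

Finetuning (step 3b) would additionally update `W_enc` with gradients from the same loss.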

### **Task C-1:** Adapt your PneumoniaMNIST dataset for the image classification task.

We can base our implementation largely on the `SimCLRPneumoniaMNISTDataset` and adapt it to make it suitable for image classification. Think about a suitable data augmentation pipeline. Check previous tutorials for inspiration.

In [None]:
class PneumoniaMNISTDataset(MedMNIST):
    def __init__(self, split = 'train', augmentation: bool = False):
        ''' Dataset class for PneumoniaMNIST.
        The provided init function will automatically download the necessary
        files the first time the class is initialised.

        :param split: 'train', 'val' or 'test', select subset

        '''
        self.flag = "pneumoniamnist"
        self.size = 28
        self.size_flag = ""
        self.root = './data/coursework/'
        self.info = INFO[self.flag]
        self.download()

        npz_file = np.load(os.path.join(self.root, "pneumoniamnist.npz"))

        self.split = split

        # Load all the images
        assert self.split in ['train','val','test']

        self.imgs = npz_file[f'{self.split}_images']
        self.labels = npz_file[f'{self.split}_labels']

        self.do_augment = augmentation

        # TASK: Define here your data augmentation pipeline suitable for
        # classification.
        # Check previous tutorials for inspiration.
        # Label-preserving augmentations for training: a mild random crop
        # (resized back to 28x28) and a random horizontal flip.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            # crop random part of the image and then resize back
            transforms.RandomResizedCrop(self.size, scale=(0.77, 0.92)),
            # flip image horizontally with 50% probability
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        # Deterministic conversion to a [1, 28, 28] tensor in [0, 1], used
        # when augmentation is disabled (validation and testing)
        self.base_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, index):
        # TASK: Implement the __getitem__ function to return the image and its
        # class label.
        img = self.imgs[index]
        label = self.labels[index]

        if self.do_augment:
            img = self.transform(img)
        else:
            img = self.base_transform(img)
        return img, label

Again, we use a [LightningDataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) for handling your PneumoniaMNIST dataset. No changes needed for this part.

In [None]:
class PneumoniaMNISTDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
        self.train_set = PneumoniaMNISTDataset(split='train', augmentation=True)
        self.val_set = PneumoniaMNISTDataset(split='val', augmentation=False)
        self.test_set = PneumoniaMNISTDataset(split='test', augmentation=False)

    def train_dataloader(self):
        return DataLoader(dataset=self.train_set, batch_size=self.batch_size,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.val_set, batch_size=self.batch_size,
                          shuffle=False)

    def test_dataloader(self):
        return DataLoader(dataset=self.test_set, batch_size=self.batch_size,
                          shuffle=False)

#### **Check** dataset implementation.

Run the below cell to visualise a batch of your training dataloader.

In [None]:
# DO NOT MODIFY THIS CELL! IT IS FOR CHECKING THE IMPLEMENTATION ONLY.

# Initialise data module
datamodule = PneumoniaMNISTDataModule()
# Get train dataloader
train_dataloader = datamodule.train_dataloader()
# Get first batch
batch = next(iter(train_dataloader))
# Visualise the images
images, labels = batch
f, ax = plt.subplots(1, 8, figsize=(12,4))
for i in range(8):
  ax[i].imshow(images[i, 0], cmap='gray')
  ax[i].set_title('label: ' + str(labels[i].item()))
  ax[i].axis("off")

### **Task C-2:** Implement a classification model with a linear layer attached to a pre-trained image encoder.

We first download the weights of the two pre-trained image encoders. One of them has been trained with the self-supervised SimCLR objective on a large publicly available chest X-ray dataset (different from PneumoniaMNIST). The other encoder is a standard ImageNet backbone that has been trained with a supervised classification objective on the ImageNet dataset.

In [None]:
! wget https://www.doc.ic.ac.uk/~bglocker/teaching/mli/coursework.zip
! unzip coursework.zip

We provide the function for loading the encoders. No changes needed here.

In [None]:
def load_encoder_from_checkpoint(checkpoint_path):
  ckpt = torch.load(checkpoint_path, map_location='cpu')
  simclr_module = SimCLRModel()
  print(simclr_module.load_state_dict(state_dict=ckpt))
  return simclr_module.encoder.eval()

imagenet_model = './data/coursework/model_imagenet.ckpt'
chestxray_model = './data/coursework/model_chestxray.ckpt'


Now, implement a classification model as a LightningModule for image classification using a pre-trained image encoder.

The model should have a flag `freeze_encoder` in the init function. If set to `True`, all weights in the encoder are frozen (used for linear probing); if set to `False`, all weights are trainable (used for model finetuning).

*Hint:* Check out previous tutorials for inspiration on how to implement a classification model as a LightningModule. For the coursework, we recommend using the Area Under the Receiver Operating Characteristic Curve (ROC-AUC) performance metric (instead of accuracy). ROC-AUC is a measure of the overall discriminative power of a classification model. You can use the readily available implementation in [torchmetrics](https://lightning.ai/docs/torchmetrics/stable/classification/auroc.html#functional-interface). You should log the ROC-AUC similar to how we logged accuracy in previous tutorials.
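The quantity that ROC-AUC measures can be sketched in a few lines of NumPy. The function below is a hypothetical `roc_auc` helper, not the torchmetrics API. It computes the fraction of positive-negative pairs in which the positive case receives the higher score, with ties counted as one half. This pairwise (Mann-Whitney) form equals the area under the ROC curve. The function returns NaN when only one class is present, because the ROC curve is undefined in that case. torchmetrics handles that case differently and emits a warning, so values logged for such batches should not be read as real AUCs.

```python
import numpy as np

def roc_auc(scores, labels):
    """ROC-AUC as the probability that a random positive is scored above a random negative.

    Ties count as 1/2. Returns NaN if only one class is present, because the ROC
    curve (and hence its area) is undefined in that case.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels).astype(bool)
    pos, neg = scores[labels], scores[~labels]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")
    diff = pos[:, None] - neg[None, :]  # all positive-negative pairs
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

print(roc_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
print(roc_auc([0.9, 0.8, 0.7], [1, 1, 1]))          # nan
```

The metric therefore depends only on the ranking of the scores. For this reason, computing it once over all predictions of an epoch is generally more reliable than averaging per-batch values. One way to do this is the module-based `torchmetrics.classification.BinaryAUROC`, which can be passed to `self.log`.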

In [None]:
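# TASK: Implement the ImageClassifier class
# Check previous tutorials for inspiration on how to implement an `ImageClassifier`

class ImageClassifier(LightningModule):
    def __init__(self, pretrained_encoder: torch.nn.Module,
                 freeze_encoder: bool = True, output_dim: int = 2,
                 learning_rate: float = 0.001):
        ''' Linear classifier on top of a (pre-trained) image encoder.

        :param pretrained_encoder: backbone returning 2048-d feature vectors
        :param freeze_encoder: if True, encoder weights are frozen (linear
            probing); if False, all weights are trained (finetuning)
        :param output_dim: number of classes
        :param learning_rate: learning rate for the Adam optimiser
        '''
        super().__init__()
        self.encoder = pretrained_encoder
        self.flag = freeze_encoder
        self.output_dim = output_dim
        self.learning_rate = learning_rate
        # Linear probing: stop gradient updates of the encoder weights
        if self.flag:
            for param in self.encoder.parameters():
                param.requires_grad = False
        # Linear classification head on top of the 2048-d encoder features
        self.ln = nn.Linear(2048, self.output_dim)

    def forward(self, x):
        # Extract a 2048-d feature vector per image with the encoder
        encoding = self.encoder(x)
        # Map the features to class logits with the linear layer
        return self.ln(encoding)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def process_batch(self, batch):
        x, y = batch
        logits = self(x)
        # Labels arrive with shape [B, 1]; cross_entropy expects a [B] tensor
        # of class indices with dtype long
        y = y.reshape(y.size(0)).long()
        loss = F.cross_entropy(logits, y)
        roc_auc = auroc(logits, y, task='multiclass', num_classes=self.output_dim)
        return loss, roc_auc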

    def training_step(self, batch, batch_idx):
        loss, roc_auc = self.process_batch(batch)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_roc_auc', roc_auc, prog_bar=True)
        if batch_idx == 0:
            grid = torchvision.utils.make_grid(batch[0][0:16, ...],
                                               nrow=4, normalize=True)
            self.logger.experiment.add_image('train_images', grid, self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, roc_auc = self.process_batch(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_roc_auc', roc_auc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, roc_auc = self.process_batch(batch)
        self.log('test_loss', loss)
        self.log('test_roc_auc', roc_auc)


### **Task C-3a:** Implement training and testing for linear probing.

Train two classification models using linear probing, one for each of the two provided image encoders. Evaluate on both the validation and test sets.

*Note:* Training for 25 epochs should be sufficient.

In [None]:
seed_everything(33, workers=True)

data = PneumoniaMNISTDataModule(batch_size=32)

# TASK: Implement the linear probing training and testing routines.

# Linear probing with the ImageNet encoder (encoder weights frozen)
print("------------------Linear probing: imagenet_model------------------")
imagenet_encode = load_encoder_from_checkpoint(imagenet_model)
model = ImageClassifier(imagenet_encode, freeze_encoder=True, output_dim=2,
                        learning_rate=0.001)

trainer = Trainer(
    max_epochs=25,
    accelerator='auto',
    devices=1,
    logger=TensorBoardLogger(save_dir='./lightning_logs/linear/classification/',
                             name='imagenet_model'),
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min'),
               TQDMProgressBar(refresh_rate=10)],
)
trainer.fit(model=model, datamodule=data)
# Evaluate the checkpoint with the lowest validation loss on the validation
# and test sets
trainer.validate(model=model, datamodule=data, ckpt_path='best')
trainer.test(model=model, datamodule=data, ckpt_path='best')

# Linear probing with the chest X-ray encoder (encoder weights frozen)
print("------------------Linear probing: chestxray_model------------------")
chestxray_encode = load_encoder_from_checkpoint(chestxray_model)
model = ImageClassifier(chestxray_encode, freeze_encoder=True, output_dim=2,
                        learning_rate=0.001)

trainer = Trainer(
    max_epochs=25,
    accelerator='auto',
    devices=1,
    logger=TensorBoardLogger(save_dir='./lightning_logs/linear/classification/',
                             name='chestxray_model'),
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min'),
               TQDMProgressBar(refresh_rate=10)],
)
trainer.fit(model=model, datamodule=data)
trainer.validate(model=model, datamodule=data, ckpt_path='best')
trainer.test(model=model, datamodule=data, ckpt_path='best')

### **Task C-3b:** Implement training and testing for model finetuning.

Repeat the experiments, but this time using model finetuning instead of linear probing. Evaluate on both the validation and test sets.

In [None]:
seed_everything(33, workers=True)

data = PneumoniaMNISTDataModule(batch_size=32)

# TASK: Implement the model finetuning training and testing routines.
# Finetuning with the ImageNet encoder (all weights trainable)
print("------------------Finetuning: imagenet_model------------------")
imagenet_encode = load_encoder_from_checkpoint(imagenet_model)
model = ImageClassifier(imagenet_encode, freeze_encoder=False,
                        output_dim=2, learning_rate=0.001)

trainer = Trainer(
    max_epochs=25,
    accelerator='auto',
    devices=1,
    logger=TensorBoardLogger(save_dir='./lightning_logs/finetuning/classification/',
                             name='imagenet_model'),
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min'),
               TQDMProgressBar(refresh_rate=10)],
)
trainer.fit(model=model, datamodule=data)
# Evaluate the checkpoint with the lowest validation loss on the validation
# and test sets
trainer.validate(model=model, datamodule=data, ckpt_path='best')
trainer.test(model=model, datamodule=data, ckpt_path='best')

# Finetuning with the chest X-ray encoder (all weights trainable)
print("------------------Finetuning: chestxray_model------------------")
chestxray_encode = load_encoder_from_checkpoint(chestxray_model)
model = ImageClassifier(chestxray_encode, freeze_encoder=False,
                        output_dim=2, learning_rate=0.001)

trainer = Trainer(
    max_epochs=25,
    accelerator='auto',
    devices=1,
    logger=TensorBoardLogger(save_dir='./lightning_logs/finetuning/classification/',
                             name='chestxray_model'),
    callbacks=[ModelCheckpoint(monitor='val_loss', mode='min'),
               TQDMProgressBar(refresh_rate=10)],
)
trainer.fit(model=model, datamodule=data)
trainer.validate(model=model, datamodule=data, ckpt_path='best')
trainer.test(model=model, datamodule=data, ckpt_path='best')

### **Task C-4:** Your evaluation report.

Provide a brief summary (no more than 300 words) with an analysis of your findings. Try explaining the observed performance.

### Validation Performance Analysis:
1. **ROC AUC Scores**:
   - The validation ROC AUC scores of the ImageNet and Chest X-ray encoders are above 0.9 for both linear probing and fine-tuning, suggesting that both separate the two classes well. The ImageNet encoder improves from 0.971 with linear probing to 0.992 with fine-tuning, while the Chest X-ray encoder improves from 0.932 to 0.991. These improvements indicate that fine-tuning helps the model capture subtle image patterns, improving its discriminative performance on the validation set.
  
2. **Validation Loss**:
    - The validation loss (cross-entropy) measures the model's error on unseen data. With linear probing, the ImageNet and Chest X-ray encoders reached validation losses of 0.194 and 0.315, respectively. With fine-tuning, where all parameters were trainable, the validation loss increased for the ImageNet encoder (to 0.361) but decreased for the Chest X-ray encoder (to 0.158). The increase for the ImageNet encoder could indicate that the more flexible model is starting to overfit the training data. Since its ROC AUC still improved, the higher loss may reflect overconfident, less well-calibrated probabilities rather than a worse ranking of cases. Conversely, the decrease for the Chest X-ray encoder suggests that fine-tuning the full model captured features relevant to pneumonia without compromising its ability to generalize.

### Zero ROC AUC Analysis:
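  - Some training batches showed a ROC AUC of exactly zero for the ImageNet encoder, during both linear probing and fine-tuning. Classifying all samples in a batch as a single class (0 or 1) would not by itself produce this. ROC AUC depends only on how the predicted scores rank positive against negative cases, and constant scores give 0.5. A value of exactly zero requires either every negative case to be scored above every positive one, or a batch containing only one class, for which ROC AUC is undefined. In the latter case, torchmetrics emits a warning and, as far as we can tell, returns 0. This seems the more plausible explanation (for example, in a small final batch of an epoch). The accompanying increase in loss may indicate that the encoder struggles with the images in those batches. Since the validation ROC AUC remains high, these zero values are unlikely to reflect the encoder's overall discriminative ability.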

## Logging

In [None]:
#%load_ext tensorboard
#%tensorboard --logdir './lightning_logs/coursework/'
#%tensorboard --logdir './lightning_logs/finetuning'
#%tensorboard --logdir './lightning_logs/linear'