CPSC 8810 Machine Learning for Biomedical Applications

# Assignment 3 - Classifiers with Image and Text Data
# Molecular Similarity Dataset
In this assignment, you are asked to use the `molecular-similarity` dataset to create classification models that predict whether or not a pair of molecules are similar. The available data include images of the 2D molecular structure and the [Simplified molecular-input line-entry system (SMILES)](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) string representation of the molecular formula. The data are arranged in pairs where each pair contains the image and SMILES formula for two molecules and the similarity score (the fraction of 20 human expert raters that rated the molecules as similar).

The data are contained in the course repository in the _assignments/source_data/molecular-similarity_ folder. The folder contains a _data.csv_ file in which each row represents the information for a pair of molecules `(a,b)`. Each row gives the SMILES formula for both molecules and the names of the corresponding 2D image files. The `frac_similar` column is the similarity score for the pair of molecules. Detailed information on the data set can be found [here](https://www.kaggle.com/datasets/tanvirnwu/molecular-similarity-prediction-dataset) and in the related journal article [Molecular Similarity Perception Based on Machine-Learning Models](https://www.mdpi.com/1422-0067/23/11/6114).
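
To make the row layout concrete, the sketch below parses one made-up row the same way the dataset class later in this notebook does (split on commas, drop the first column). The row values, the leading index column, and the second image name are assumptions for illustration; only the field order and the meaning of `frac_similar` come from the description above.

```python
# Hypothetical row: the leading index, the second molecule and its image name are invented.
row = "0,Cc1ccsc1-c1cccnc1,CCO,image_molecule_000a.png,image_molecule_000b.png,0.35"

# Split on commas and drop the first column, leaving
# [smilea, smileb, imagea, imageb, frac_similar]
fields = row.strip().split(",")[1:]
smiles_a, smiles_b, image_a, image_b, frac = fields
frac_similar = float(frac)

# frac_similar is the fraction of the 20 raters who judged the pair similar
n_raters = 20
n_similar_votes = round(frac_similar * n_raters)
print(smiles_a, smiles_b, frac_similar, n_similar_votes)
```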

__Please read through the notebook and follow the instructions for each of the 6 problems. For extra credit, you may also do Problem 7.__

In [None]:
# Google Colab setup
# mount the google drive - this is necessary to access supporting src
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# install any packages not found in the Colab environment
!pip install lightning
!pip install 'portalocker>=2.0.0'

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torchtext
from torchtext.vocab import build_vocab_from_iterator
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torchvision import transforms, datasets, models
from torchvision.io import read_image
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from torchtext.datasets import AG_NEWS
from torchtext.functional import to_tensor
import lightning as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch import seed_everything
import lightning.pytorch.trainer as trainer
import torchmetrics as TM
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

# local project imports
import sys
sys.path.append("/content/drive/MyDrive/Colab Notebooks/CPSC-8810-ML-BioMed/src")
from torchvision_utils import show

In [None]:
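dir_dataroot = "/content/drive/MyDrive/Colab Notebooks/CPSC-8810-ML-BioMed/assignments/source_data"
# folder holding data.csv and the image directories for this assignment
molecular_data_dir = os.path.join(dir_dataroot, "molecular-similarity")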

dir_lightning = "/content/drive/MyDrive/Colab Notebooks/CPSC-8810-ML-BioMed/lightning"

rs = 123456 # random seed for everything

# Data Module
As in the practicums, it is useful to create a PyTorch Lightning data module to handle loading and processing data for batch training with our PyTorch modules. In this assignment, we are working with multimodal data that includes both images and text as possible inputs to the model. The `MolecularSimilarityDataModule` in the code cell below handles all data loading and should __NOT__ be modified. However, it is important that you review the code to understand how the data module works. In particular, the data module takes the following inputs to its constructor (`__init__`):
- `data_dir` - the directory containing the _data.csv_ file and the _images-2D-#x#_ directory, where _#_ is the image pixel dimension (e.g., 128)
- `image_dim` - the pixel dimension (assumed to be square) of the 2D molecular images
- `embedding_dim` - the token embedding dimension when using the SMILES text input
- `modalities` - determines which data to use for input: 0 - SMILES text only, 1 - 2D images only, 2 - both text and images
- `class_threshold_map` - a dictionary of the form `{c1:(0, x1), c2:(x1, x2), ..., cK:(x_{K-1}, 1.01)}` that defines the `frac_similar` bins for the classes `[c1, c2, ..., cK]`; each bin includes its lower bound and excludes its upper bound, which is why the last upper bound is 1.01 rather than 1.0
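
As a small illustration of how `class_threshold_map` turns a score into a class, the sketch below mirrors the half-open binning done by the data module's `label_pipeline` method. The function name `label_from_fraction` is made up for this example; the thresholds are the ones used later in this notebook.

```python
def label_from_fraction(frac_similar, class_threshold_map):
    """Return the class whose half-open bin [low, high) contains frac_similar."""
    for label, (low, high) in class_threshold_map.items():
        if low <= frac_similar < high:
            return label
    return None  # no bin matched


ctm = {0: (0, 0.33), 1: (0.33, 0.66), 2: (0.66, 1.01)}

# A score exactly on a boundary falls into the upper bin; 1.0 lands in the
# last bin only because that bin's upper bound is 1.01.
print([label_from_fraction(f, ctm) for f in (0.0, 0.33, 0.65, 1.0)])  # [0, 1, 1, 2]
```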

The data module class creates train, validation, and test dataloaders using the `MolecularSimilarityDataSet` class that extends the PyTorch `Dataset` class. This class derives the sample set size (the number of scored pairs of molecules) from the _data.csv_ file.

In the problems below, we will use this data module class to support creating a text-only model, an image-only model, and (for extra credit) a multimodal model that uses both the text and image inputs.
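
One detail of the `setup` method in the cell below is easy to miss: `val_fraction` is applied to the training portion, not to the full dataset. The sketch below reproduces that arithmetic with the default fractions; the function name `split_sizes` and the dataset size of 1000 are placeholders chosen for illustration, not the real sample count.

```python
def split_sizes(n_data, train_fraction=0.875, val_fraction=0.145):
    """Return (n_train, n_val, n_test) using the same arithmetic as setup()."""
    n_train = int(train_fraction * n_data)   # training + validation portion
    n_val = int(val_fraction * n_train)      # carved out of the training portion
    n_train = n_train - n_val
    n_test = n_data - n_train - n_val        # whatever remains
    return n_train, n_val, n_test


print(split_sizes(1000))  # (749, 126, 125)
```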

In [None]:
####################################################################################################
# DO NOT CHANGE THIS CELL
####################################################################################################
class MolecularSimilarityDataSet(torch.utils.data.Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = self.load_data()
        self.len = len(self.samples)

    def load_data(self):
        samples = []
        with open(os.path.join(self.data_dir, 'data.csv'), 'r') as f:
            # skip the header
            next(f)
            for line in f:
                data = line.strip().split(',')
                samples.append(data[1:])
        return samples

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # each sample is a list: [smilea, smileb, imagea, imageb, frac_similar]
        return self.samples[idx]

class MolecularSimilarityDataModule(L.LightningDataModule):
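    def __init__(self, data_dir, class_threshold_map, train_fraction=0.875, val_fraction=0.145, batch_size=10,
                 embedding_dim=8, image_dim=128, modalities=2, class_name_map=None):
        """
        Args:
            data_dir: str - path to the data directory
            class_threshold_map: dict - maps class index to its (lower, upper) frac_similar bin
            train_fraction: float - fraction of the data used for training plus validation
            val_fraction: float - fraction of the training portion held out for validation
            batch_size: int - batch size
            embedding_dim: int - dimension of the embeddings
            image_dim: int - pixel dimension of the (square) images
            modalities: int - 0: text only, 1: image only, 2: text and image
            class_name_map: dict - maps class index to a display name
        """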
        super().__init__()
        self.data_dir = data_dir
        self.val_fraction = val_fraction
        self.train_fraction = train_fraction
        self.batch_size = batch_size
        self.embedding_dim = embedding_dim
        self.modalities = modalities
        self.vocab = None
        self.dataset = None
        self.max_tokens = None
        self.class_threshold_map = class_threshold_map
        self.class_name_map = class_name_map
        self.image_dim = image_dim

    def tokenize(self, text):
        s = set(text)
        if ' ' in s:
            s.remove(' ')
        return list(s)

    def label_pipeline(self, frac_similar):
        # bin the samples based on the fraction similarity thresholds specified in the class_map
        # {0:(0, 0.25), 1:(0.25, 0.75), 2:(0.75,1.01)}
        for k, v in self.class_threshold_map.items():
            lb = v[0]
            rb = v[1]
            if frac_similar>=lb and frac_similar < rb:
                return k

    def max_tokens_in(self, data_iterable):
        if self.vocab is None:
            self.build_vocab(data_iterable)
        text_to_tokens = lambda x: self.vocab(self.tokenize(x))
        max_tokens = 0
        #each sample is a list: [smilea, smileb, imagea, imageb, frac_similar]
        for sample in data_iterable:
            smilea = sample[0]
            smileb = sample[1]
            l1 = len(text_to_tokens(smilea))
            l2 = len(text_to_tokens(smileb))
            l = max(l1, l2)
            if l > max_tokens:
                max_tokens = l
        return max_tokens+2 # add 2 for start and end tokens

    def max_sample_length(self):
        return self.max_tokens

    def build_vocab(self, data_iter):
        def yield_tokens(data_iter):
            for sample in data_iter:
                smilea = set(sample[0])
                smileb = set(sample[1])
                smileab = smilea.union(smileb)
                if ' ' in smileab:
                    smileab.remove(' ')
                yield smileab

        self.vocab = build_vocab_from_iterator(yield_tokens(data_iter), specials=['<unk>', '<pad>', '<end>', '<start>'])
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.padding_idx = self.vocab['<pad>']
        self.end_idx = self.vocab['<end>']
        self.start_idx = self.vocab['<start>']
        return self.vocab

    def collate_batch(self, batch):
        #each sample is a list: [smilea, smileb, imagea, imageb, frac_similar]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        label_list = []
        if self.modalities == 0 or self.modalities == 2: # text input is used
            text_batch = torch.zeros(len(batch), 2*self.max_tokens, dtype=torch.long)
        if self.modalities == 1 or self.modalities == 2: # image input is used
            image_batch = torch.zeros(len(batch), 6, self.image_dim, self.image_dim)
            image_to_tensor = transforms.ToTensor()

        cnt = 0
        for sample in batch:
            label_list.append(self.label_pipeline(float(sample[4])))
            if self.modalities == 0 or self.modalities == 2:
                text = [self.start_idx]
                text.extend(self.vocab(self.tokenize(sample[0])))
                if len(text) < self.max_tokens-1:
                    text.extend([self.padding_idx] * (self.max_tokens - len(text)-1))
                text.extend([self.end_idx, self.start_idx])
                text.extend(self.vocab(self.tokenize(sample[1])))
                if len(text) < 2*self.max_tokens-1:
                    text.extend([self.padding_idx] * (2*self.max_tokens - len(text)-1))
                text.append(self.end_idx)
                text_batch[cnt] = torch.tensor(text, dtype=torch.int64)
            if self.modalities == 1 or self.modalities == 2:
                imga = read_image(os.path.join(self.data_dir, f"images-2D-{self.image_dim}x{self.image_dim}", sample[2]))
                imgb = read_image(os.path.join(self.data_dir, f"images-2D-{self.image_dim}x{self.image_dim}", sample[3]))
                imga = image_to_tensor(imga.numpy())
                imga = imga.permute((1, 0, 2)).contiguous()
                imgb = image_to_tensor(imgb.numpy())
                imgb = imgb.permute((1,0,2)).contiguous()
                image_batch[cnt] = torch.cat((imga, imgb), 0)
            cnt += 1
        label_list = torch.tensor(label_list, dtype=torch.int64)

        if self.modalities == 0:
            return text_batch.to(device), label_list.to(device)
        elif self.modalities == 1:
            return image_batch.to(device), label_list.to(device)
        else:
            return text_batch.to(device), image_batch.to(device), label_list.to(device)

    def setup(self, stage=None):
        self.dataset = MolecularSimilarityDataSet(self.data_dir)
        if self.modalities == 0 or self.modalities == 2:
            # build the vocabulary
            self.build_vocab(self.dataset)
            # find max tokens
            self.max_tokens = self.max_tokens_in(self.dataset)
        n_data = len(self.dataset)
        n_train = int(self.train_fraction * n_data)
        n_val = int(self.val_fraction * n_train)
        n_train = n_train - n_val
        n_test = n_data - n_train - n_val
        train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(self.dataset, [n_train, n_val, n_test])

        self._train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.collate_batch)
        self._val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_batch)
        self._test_dataloader = DataLoader(test_dataset, batch_size=self.batch_size, collate_fn=self.collate_batch)

    def train_dataloader(self):
        return self._train_dataloader

    def test_dataloader(self):
        return self._test_dataloader

    def val_dataloader(self):
        return self._val_dataloader



# Text Input Molecular Similarity Classifier

In the following section, you will build a classification model that predicts if a given pair of molecules are
1. Similar - defined as having a `frac_similar` score between 0.66 and 1.0
2. Uncertain - defined as having a `frac_similar` score between 0.33 and 0.66
3. Not Similar - defined as having a `frac_similar` score between 0 and 0.33

In this first model, the predictions will be made using only the SMILES text representation of the molecular formulas. An example SMILES formula is _Cc1ccsc1-c1cccnc1_. You can view more examples in the _assignments/source_data/molecular-similarity/data.csv_ file. Importantly, these formulas are not composed of words like those we saw in Practicum 7. Hence, we require a different tokenizing strategy for inputting the formula to the model. We will tokenize the SMILES formula at the character level; that is, each character in the formula, e.g., _C_, will be represented as a token in our vocabulary. This tokenization is handled by the `MolecularSimilarityDataModule`. The model we build below will learn an embedding for each character.
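
The sketch below illustrates the idea in plain Python, without torchtext. It follows the data module as written: `tokenize` keeps the distinct characters of a formula (order and repeats are not preserved), and each pair is laid out as `<start> a <pad>... <end> <start> b <pad>... <end>`. The alphabetical vocabulary order, the `sorted` call, the helper names, and the second formula `CCO` are simplifications or assumptions for this example; the real module builds its vocabulary with `build_vocab_from_iterator`.

```python
SPECIALS = ["<unk>", "<pad>", "<end>", "<start>"]


def tokenize(smiles):
    # Distinct characters with spaces dropped, as in the data module;
    # sorted() is added here only to make the output reproducible.
    return sorted(set(smiles) - {" "})


def build_vocab(formulas):
    chars = sorted({ch for f in formulas for ch in tokenize(f)})
    return {tok: i for i, tok in enumerate(SPECIALS + chars)}


def encode_pair(a, b, vocab, max_tokens):
    def encode_one(s):
        ids = [vocab["<start>"]] + [vocab.get(ch, vocab["<unk>"]) for ch in tokenize(s)]
        ids += [vocab["<pad>"]] * (max_tokens - 1 - len(ids))  # pad up to the <end> slot
        return ids + [vocab["<end>"]]
    return encode_one(a) + encode_one(b)


formulas = ["Cc1ccsc1-c1cccnc1", "CCO"]  # second formula is a made-up example
vocab = build_vocab(formulas)
max_tokens = max(len(tokenize(f)) for f in formulas) + 2  # +2 for <start> and <end>
ids = encode_pair(formulas[0], formulas[1], vocab, max_tokens)
print(len(vocab), max_tokens, len(ids))
print(ids)
```

Each encoded pair has length `2 * max_tokens`, which matches the sequence length passed to the encoder and classifier tests later in this section.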

# Problem 1 (1 point)

In the code cell below, create an instance of `MolecularSimilarityDataModule` that uses the text-only modality. The arguments to the constructor should be:
- data_dir = molecular_data_dir
- modalities = 0
- batch_size = 10
- class_threshold_map = ctm
- class_name_map = cnm

In [None]:
seed_everything(rs)
ctm = {0:(0, 0.33), 1:(0.33, 0.66), 2:(0.66,1.01)}
cnm = {0:'Not Similar', 1:'Uncertain', 2:'Similar'}

##### Problem 1 start your code here #####
dm = None
##### Problem 1 end your code here #####

dm.setup()
print(len(dm.train_dataloader()))
print(len(dm.val_dataloader()))
print(len(dm.test_dataloader()))

Now that we have the data module ready, we can build the molecular similarity classification model using only the SMILES text inputs. We will design the model following the same procedure that was presented in Practicum 7. Specifically, the model will include an _encoder_ and a _classifier_. The input will pass through the _encoder_, which will create a numerical representation of the input SMILES text formulas. This representation will then be passed to the _classifier_ to predict the class (similar, uncertain, or not similar). As in the practicum, the _encoder_ will be based on a Transformer and the _classifier_ will be a feed-forward linear layer.

# Problem 2 (2 points)
In the code cell below, complete the implementation of the `TextEncoder` class. In the `__init__` method, you will need to create the `embedding` layer and the `transformer_encoder` layer. You should use the input arguments to the `__init__` function to set the arguments in the `embedding` and `transformer_encoder`. In the forward method, you will need to pass the input, `x`, through both layers.

__HINT__: See Practicum 7.

In [None]:
class TextEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads,
                 num_transformer_layers=1, dim_feedforward=128, activation='relu', dropout=0.1):
        super().__init__()
        ##### Problem 2 start your code here #####
        self.embedding = None
        self.te_layer = None
        self.transformer_encoder = None
        ##### Problem 2 end your code here #####

    def forward(self, x):
        ##### Problem 2 start your code here #####

        ##### Problem 2 end your code here #####
        return x

# Test the TextEncoder with random input data
encoder = TextEncoder(len(dm.vocab), dm.embedding_dim, num_heads=2)
input_text = torch.randint(0, len(dm.vocab), (dm.batch_size, 2*dm.max_sample_length()))
print("input shape",input_text.shape)
output_features = encoder(input_text)
print("Output shape:", output_features.shape)

# Problem 3 (2 points)
In the code cell below, complete the implementation of the `TextClassifier` class. In the `__init__` method, you will need to create the `Flatten` layer and the `linear` feed-forward layer. You should use the input arguments to the `__init__` function to set the arguments in the `linear` layer. In the forward method, you will need to pass the input, `x`, through both layers.

__HINT__: See Practicum 7.

In [None]:
class TextClassifier(nn.Module):
    def __init__(self, embedding_dim, seq_length, num_classes):
        super().__init__()
        ##### Problem 3 start your code here #####
        self.flatten = None
        self.linear = None
        ##### Problem 3 end your code here #####

    def forward(self, x):
        ##### Problem 3 start your code here #####

        ##### Problem 3 end your code here #####
        return x

# Test the module with random input data
classifier = TextClassifier(dm.embedding_dim, 2*dm.max_sample_length(), len(dm.class_name_map))
input_tensor = torch.randn(dm.batch_size, 2*dm.max_sample_length(), dm.embedding_dim)  # shape [batch_size, 2*max_sample_length, embedding_dim]
print(input_tensor.shape)
output = classifier(input_tensor)
print("Output shape:", output.shape)  # Expected output shape: [batch_size, num_classes] = [10, 3]

We are now ready to create our classification model. The code cell below extends the `LightningModule` from PyTorch Lightning to create the overall molecular similarity classification model. Note that this model takes an _encoder_ module and a _classifier_ module as inputs. Importantly, it is agnostic to the type of input passed to these modules (i.e., it could be text or images). It assumes the _encoder_ and _classifier_ are constructed to handle the data appropriately. Hence, we will see in the next section that we can reuse this model for image inputs.

In [None]:
####################################################################################################
# DO NOT CHANGE THIS CELL
####################################################################################################
class SingleModalityClassifierModel(L.LightningModule):
    def __init__(self, encoder, classifier, num_classes):
        super().__init__()
        # model layers
        self.encoder = encoder
        self.classifier = classifier

        # validation metrics
        self.val_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes)]))
        self.validation_step_outputs = []
        self.validation_step_targets = []

        # test metrics
        self.test_roc = TM.ROC(task="multiclass", num_classes=num_classes, thresholds=list(np.linspace(0.0, 1.0, 20))) # roc and cm have methods we want to call so store them in a variable
        self.test_cm = TM.ConfusionMatrix(task='multiclass', num_classes=num_classes)
        self.test_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes),
                                                            self.test_roc, self.test_cm]))
        self.test_step_outputs = []
        self.test_step_targets = []

    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('val_loss', loss, on_step=True, on_epoch=True)

        # store the outputs and targets for the epoch end step
        self.validation_step_outputs.append(logits)
        self.validation_step_targets.append(y)
        return loss

    def on_validation_epoch_end(self):
        # stack all the outputs and targets into a single tensor
        all_preds = torch.vstack(self.validation_step_outputs)
        all_targets = torch.hstack(self.validation_step_targets)

        # compute the metrics
        loss = nn.functional.cross_entropy(all_preds, all_targets)
        self.val_metrics_tracker.increment()
        self.val_metrics_tracker.update(all_preds, all_targets)
        self.log('val_loss_epoch_end', loss)

        # clear the validation step outputs
        self.validation_step_outputs.clear()
        self.validation_step_targets.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log('test_loss', loss, on_step=True, on_epoch=True)
        self.test_step_outputs.append(logits)
        self.test_step_targets.append(y)
        return loss

    def on_test_epoch_end(self):
        all_preds = torch.vstack(self.test_step_outputs)
        all_targets = torch.hstack(self.test_step_targets)

        self.test_metrics_tracker.increment()
        self.test_metrics_tracker.update(all_preds, all_targets)
        # clear the test step outputs
        self.test_step_outputs.clear()
        self.test_step_targets.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

Now that we've constructed the model, we can train the model using the PyTorch Lightning `Trainer` class. The code cell below implements the training following the same procedure used in Practicum 7.

In [None]:
seed_everything(rs)
encoder = TextEncoder(len(dm.vocab), dm.embedding_dim, num_heads=2)
classifier = TextClassifier(dm.embedding_dim, 2*dm.max_sample_length(), len(dm.class_name_map))
molecular_text_model = SingleModalityClassifierModel(encoder, classifier, num_classes=len(dm.class_name_map))

trainer = L.Trainer(default_root_dir=dir_lightning,
                    max_epochs=100,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=7)])
trainer.fit(model=molecular_text_model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())
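
The `EarlyStopping` callback above monitors `val_loss` in `min` mode with `patience=7`, so training ends once the validation loss has gone 7 consecutive checks without improving on its best value. The function below is a simplified sketch of that rule in plain Python; it ignores options such as `min_delta`, and the loss values are invented for illustration.

```python
def early_stop_epoch(val_losses, patience=7):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    waited = 0  # consecutive epochs without a new best loss
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, waited = loss, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None


# Loss improves for three epochs, then stalls for seven: stop at epoch 10.
losses = [1.0, 0.8, 0.7] + [0.75] * 7
print(early_stop_epoch(losses))  # 10
```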

## Validation Set Performance

Below, we can examine performance on the validation set across training epochs. You should see that the validation accuracy approaches a value near 0.9.

In [None]:
mca = molecular_text_model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()

## Test Set Performance

Now let's examine the model performance on the test set. We follow the same procedure as in Practicum 7 to obtain the model predictions for the test set and view the confusion matrix, ROC curve, and classification report. You should see that the text-based model performs reasonably well on the test data, with an overall F-score near 0.8 and ROC curves that lie well above the random-guessing line. You should also see in the confusion matrix that the model misclassifies several of the "uncertain" examples.

In [None]:
trainer.test(model=molecular_text_model, dataloaders=dm.test_dataloader())

In [None]:
rslt = molecular_text_model.test_metrics_tracker.compute()

In [None]:
cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted Label')
cmp.set_xticklabels(dm.class_name_map.values(), rotation=90)
cmp.set_yticklabels(dm.class_name_map.values(), rotation=0)
cmp.set_ylabel('Actual Label');

In [None]:
fpr, tpr, thresholds = rslt['MulticlassROC']
for i in range(len(dm.class_name_map)):
    plt.plot(fpr[i], tpr[i], label=dm.class_name_map[i])
plt.xlabel('False Positive Rate')
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()

In [None]:
# Print the classification report
device = torch.device("cpu")   #"cuda:0"
molecular_text_model.eval()
y_true=[]
y_pred=[]
with torch.no_grad():
    for test_data in dm.test_dataloader():
        test_samples, test_labels = test_data[0].to(device), test_data[1].to(device)
        pred = molecular_text_model(test_samples).argmax(dim=1)
        for i in range(len(pred)):
            y_true.append(test_labels[i].item())
            y_pred.append(pred[i].item())

print(classification_report(y_true,y_pred,target_names=list(dm.class_name_map.values()),digits=4))
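
The per-class precision, recall, and F-score in the classification report can all be derived from the confusion matrix. The sketch below shows that calculation in plain Python, assuming rows are actual classes and columns are predicted classes (the same orientation as the heatmap labels above). The counts are invented for illustration and are not results from this model.

```python
def per_class_f1(cm):
    """cm[i][j] = number of samples with actual class i predicted as class j."""
    scores = []
    for k in range(len(cm)):
        tp = cm[k][k]
        predicted_k = sum(row[k] for row in cm)  # column sum
        actual_k = sum(cm[k])                    # row sum
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        denom = precision + recall
        scores.append(2 * precision * recall / denom if denom else 0.0)
    return scores


# Invented counts: rows/columns ordered Not Similar, Uncertain, Similar
cm = [[8, 2, 0],
      [1, 5, 4],
      [0, 1, 9]]
f1 = per_class_f1(cm)
macro_f1 = sum(f1) / len(f1)
print([round(s, 4) for s in f1], round(macro_f1, 4))
```

In this made-up matrix the middle class has the lowest F-score because its samples are spread across the neighboring classes, which is the kind of pattern to look for in the "uncertain" row of the real confusion matrix.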

# Image Input Molecular Similarity Classifier

Next, let's build a model where the predictions will be made using only the image of the 2D molecular structure. An example image is shown here:

![alt text](../source_data/molecular-similarity/images-2D-128x128/image_molecule_000a.png)

You can view more examples in the _assignments/source_data/molecular-similarity/images-2D-128x128_ directory. These images have already been resized to the same dimensions (128 x 128 pixels). Additionally, the `MolecularSimilarityDataModule` normalizes the pixel values to the range [0, 1].

As in the preceding section, you will build a classification model, this time using the input images, that predicts if a given pair of molecules are
1. Similar - defined as having a `frac_similar` score between 0.66 and 1.0
2. Uncertain - defined as having a `frac_similar` score between 0.33 and 0.66
3. Not Similar - defined as having a `frac_similar` score between 0 and 0.33

# Problem 4 (1 point)

In the code cell below, create an instance of `MolecularSimilarityDataModule` that uses the image-only modality. The arguments to the constructor should be:
- data_dir = molecular_data_dir
- modalities = 1
- batch_size = 10
- image_dim = 128
- class_threshold_map = ctm
- class_name_map = cnm

In [None]:
seed_everything(rs)
ctm = {0:(0, 0.33), 1:(0.33, 0.66), 2:(0.66,1.01)}
cnm = {0:'Not Similar', 1:'Uncertain', 2:'Similar'}

##### Problem 4 start your code here #####
dm = None
##### Problem 4 end your code here #####

dm.setup()
print(len(dm.train_dataloader()))
print(len(dm.val_dataloader()))
print(len(dm.test_dataloader()))

Here, we can view a batch of data. Notice that the batch tensor has shape [batch_size, 6, 128, 128]. This is because there are two images to be compared for the molecular similarity problem, one image representing the structure of each molecule. Each image is RGB and thus has 3 channels. The two images are stacked along the channel dimension, giving 6 channels in the overall input.

In [None]:
imgs, labels = next(iter(dm.train_dataloader()))
print(imgs.shape)
imgsa = imgs[:, :3, :, :]
print(imgsa.shape)
grid = make_grid(imgsa, nrow=5)
show(grid)
imgsb = imgs[:, 3:, :, :]
print(imgsb.shape)
grid = make_grid(imgsb, nrow=5)
show(grid)


Now that we have the data module ready, we can build the molecular similarity classification model using only the images as inputs. We will design the model following the same approach we used for the text model above and as presented for image data in Practicum 6. Again, the model will include an _encoder_ and a _classifier_. However, this time the _encoder_ will use convolutional layers to create a representation of the two-image input.

# Problem 5 (2 points)
In the code cell below, complete the implementation of the `ImageEncoder` class. In the `__init__` method, you will need to create each of the two `Conv2d` layers that are to be used in the `Sequential` layer that represents the `encoder`.

You should use the input arguments to the `__init__` function to set the arguments in the `Conv2d` layers. Specifically, the first `Conv2d` layer, `conv2d_1`, should use the `input_channels` argument for its `in_channels`, `output_channels[0]` for its `out_channels`, and `kernel_sizes[0]` for its `kernel_size`. Similarly, `conv2d_2` should use `output_channels[0]` for its `in_channels`, `output_channels[1]` for its `out_channels`, and `kernel_sizes[1]` for its `kernel_size`. For more details, refer to the [PyTorch Conv2d doc](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html).

In the forward method, you will need to pass the input, `x`, through the `encoder` layer.

__HINT__: See Practicum 6.

In [None]:
class ImageEncoder(nn.Module):
    def __init__(self, input_channels=3, output_channels=[16,32], kernel_sizes=[5, 3]):
        super().__init__()
        ##### Problem 5 start your code here #####
        conv2d_1 = None
        conv2d_2 = None
        ##### Problem 5 end your code here #####

        self.encoder = nn.Sequential(
            conv2d_1,
            nn.ReLU(),                                             # ReLU activation function
            nn.MaxPool2d(kernel_size=2, stride=2),                 # max pooling layer with kernel size 2x2
            conv2d_2,
            nn.ReLU(),                                             # ReLU activation function
            nn.MaxPool2d(kernel_size=2, stride=2),                  # max pooling layer with kernel size 2x2
        )

    def forward(self, x):
        ##### Problem 5 start your code here #####

        ##### Problem 5 end your code here #####
        return x

# Test the ImageEncoder with random input data
encoder = ImageEncoder(input_channels=6, output_channels=[16,32])
input_image = torch.randn(5, 6, dm.image_dim, dm.image_dim)  # batch_size x channels x height x width
output_features = encoder(input_image)
print("Output shape:", output_features.shape)
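
The spatial size of the encoder output depends on the padding of the two `Conv2d` layers, which the problem statement does not specify, and it determines the flattened input size the classifier needs. The sketch below applies the standard output-size formula, `floor((n + 2*padding - kernel) / stride) + 1`, to a 128 x 128 input. The helper names and the padding values are assumptions for illustration; check them against your own layers and the shape printed above.

```python
def conv_out(n, kernel, padding=0, stride=1):
    """Output size along one spatial dimension of a convolution."""
    return (n + 2 * padding - kernel) // stride + 1


def pool_out(n, kernel=2, stride=2):
    """Output size along one spatial dimension of max pooling (no padding)."""
    return (n - kernel) // stride + 1


def encoder_out(n, kernel_sizes=(5, 3), padding=0):
    # conv -> pool -> conv -> pool, as in the Sequential encoder above
    for k in kernel_sizes:
        n = pool_out(conv_out(n, k, padding=padding))
    return n


for p in (0, 1):  # assumed padding values
    side = encoder_out(128, padding=p)
    print(f"padding={p}: {side} x {side}, flattened size with 32 channels = {32 * side * side}")
```

With these assumptions, a padding of 1 on both layers gives a 31 x 31 feature map, while no padding gives 30 x 30, so the `input_dim` passed to the classifier must match whichever choice the encoder makes.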

# Problem 6 (2 points)
In the code cell below, complete the implementation of the `ImageClassifier` class. In the `__init__` method, you will need to create the `Flatten` layer and the `linear` feed-forward layer. You should use the input arguments to the `__init__` function to set the arguments in the `linear` layer. In the forward method, you will need to pass the input, `x`, through both layers.

__HINT__: See Practicum 6.

In [None]:
class ImageClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        ##### Problem 6 start your code here #####
        self.flatten = None
        self.linear = None
        ##### Problem 6 end your code here #####

    def forward(self, x):
        ##### Problem 6 start your code here #####

        ##### Problem 6 end your code here #####
        return x

# Test the module with random input data
module = ImageClassifier(32*31*31, 3)
input_tensor = torch.randn(5, 32, 31, 31)  # Batch size of 5, input tensor shape [5, 32, 31, 31]
output = module(input_tensor)
print("Output shape:", output.shape)  # Expected output shape: [5, 3]

We can now build a model that uses only the molecular image data to predict the similarity class. Importantly, we can reuse the `SingleModalityClassifierModel` class to construct our model by passing it the new `ImageEncoder` and `ImageClassifier`, which handle the image inputs. Below, we fit the model using the PyTorch Lightning `Trainer`.

In [None]:
seed_everything(rs)
encoder = ImageEncoder(input_channels=6, output_channels=[16, 32])
classifier = ImageClassifier(32*31*31, len(dm.class_name_map))
molecular_image_model = SingleModalityClassifierModel(encoder, classifier, num_classes=len(dm.class_name_map))

trainer = L.Trainer(default_root_dir=dir_lightning,
                    max_epochs=100,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=7)])
trainer.fit(model=molecular_image_model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())

## Validation Set Performance

Below, we can examine performance on the validation set across training epochs. You should see that the validation accuracy is not as good as for the text-based model, approaching a value near 0.5. Why do you think this is?

In [None]:
mca = molecular_image_model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()

Test Set Performance

In [None]:
trainer.test(model=molecular_image_model, dataloaders=dm.test_dataloader())

## Test Set Performance

Now let's examine the model performance on the test set. We follow the same procedure as for the text-based model to obtain the model predictions for the test set and view the confusion matrix, ROC curve, and classification report. You should see that the image-based model performs poorly compared to the text-based model on the test set. This is not surprising given the validation set performance and is likely due to the small dataset size.
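
As a toy illustration of how a confusion matrix relates to the per-class precision and recall in a classification report, the sketch below uses scikit-learn with made-up labels. The label values are hypothetical and are not results from this dataset.

```python
from sklearn.metrics import classification_report, confusion_matrix

names = ['Not Similar', 'Uncertain', 'Similar']
# Made-up labels for eight samples (not model output)
y_true = [0, 0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 1, 2, 2, 2, 0]

# In scikit-learn, rows are the true classes and columns are the predicted classes
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
print(cm)

# Recall divides the diagonal by the row sums, precision by the column sums
recall = cm.diagonal() / cm.sum(axis=1)
precision = cm.diagonal() / cm.sum(axis=0)
print("recall:   ", recall)
print("precision:", precision)

print(classification_report(y_true, y_pred, target_names=names, digits=4))
```

The precision and recall columns of the printed report should match the values computed by hand from the matrix.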

In [None]:
rslt = molecular_image_model.test_metrics_tracker.compute()

In [None]:
cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted Label')
cmp.set_xticklabels(dm.class_name_map.values(), rotation=90)
cmp.set_yticklabels(dm.class_name_map.values(), rotation=0)
cmp.set_ylabel('Actual Label');

In [None]:
fpr, tpr, thresholds = rslt['MulticlassROC']
for i in range(len(dm.class_name_map)):
    plt.plot(fpr[i], tpr[i], label=dm.class_name_map[i])
plt.xlabel('False Positive Rate')
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()

In [None]:
# Print the classification report
device = torch.device("cpu")   #"cuda:0"
molecular_image_model.eval()
y_true=[]
y_pred=[]
with torch.no_grad():
    for test_data in dm.test_dataloader():
        test_samples, test_labels = test_data[0].to(device), test_data[1].to(device)
        pred = molecular_image_model(test_samples).argmax(dim=1)
        for i in range(len(pred)):
            y_true.append(test_labels[i].item())
            y_pred.append(pred[i].item())

print(classification_report(y_true,y_pred,target_names=list(dm.class_name_map.values()),digits=4))

# Multimodal model (Extra Credit)
Now that we've developed a text-only and an image-only input model, we can ask how we could incorporate both text and image data. There are multiple ways this could be done. Here, we will again use the same conceptual approach where we use an _encoder_ to encode both the image and text data. We will then combine these representations and pass them to a feed-forward network that will classify them.

We can obtain the text and image data in our batch data by setting `modalities=2` in the `MolecularSimilarityDataModule` constructor. If you examine the code for the `MolecularSimilarityDataModule` class carefully, you will see that in this case the batch data is now a 3-element tuple containing the text, image, and label tensors for the batch. This is because the text and image tensors do not have compatible tensor shapes and thus cannot be easily combined. This will require us to modify our `LightningModule` class to handle this.
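
To make the 3-element batch concrete, here is a minimal sketch using a toy dataset. `ToyPairDataset` and its tensor sizes are hypothetical stand-ins, not the course's `MolecularSimilarityDataModule`; the 6 image channels are assumed here to be two stacked RGB images.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyPairDataset(Dataset):
    """Hypothetical stand-in that yields (text, image, label) triples."""

    def __init__(self, n_samples=6, seq_len=8, image_dim=16):
        self.text = torch.randint(0, 20, (n_samples, seq_len))        # token ids
        self.image = torch.randn(n_samples, 6, image_dim, image_dim)  # assumed 6-channel image pair
        self.label = torch.randint(0, 3, (n_samples,))                # class index per pair

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        return self.text[idx], self.image[idx], self.label[idx]


loader = DataLoader(ToyPairDataset(), batch_size=3)

# The default collate function batches each element of the triple separately
text_batch, image_batch, y = next(iter(loader))
print(text_batch.shape, image_batch.shape, y.shape)
```

Because the text and image tensors keep their own shapes, each step of the model has to unpack the batch into three parts rather than two.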

First, let's create a data module that provides the multimodal data.

In [None]:
seed_everything(rs)
class_threshold_map = {0:(0, 0.33), 1:(0.33, 0.66), 2:(0.66,1.01)}
class_name_map = {0:'Not Similar', 1:'Uncertain', 2:'Similar'}
dm = MolecularSimilarityDataModule(data_dir=os.path.join(dir_dataroot, 'molecular-similarity'), image_dim=128, modalities=2, batch_size=10,
                                   class_threshold_map=class_threshold_map, class_name_map=class_name_map)
dm.setup()
print(len(dm.train_dataloader()))
print(len(dm.val_dataloader()))
print(len(dm.test_dataloader()))
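
The `class_threshold_map` above assigns each `frac_similar` score to one of three classes. The sketch below shows one way such a mapping could work, assuming half-open ranges `[low, high)`. The actual rule lives inside `MolecularSimilarityDataModule` and may differ, and `score_to_class` is a hypothetical helper that is not part of the course code.

```python
class_threshold_map = {0: (0, 0.33), 1: (0.33, 0.66), 2: (0.66, 1.01)}
class_name_map = {0: 'Not Similar', 1: 'Uncertain', 2: 'Similar'}


def score_to_class(frac_similar, threshold_map):
    """Hypothetical helper: return the class whose [low, high) range holds the score."""
    for label, (low, high) in threshold_map.items():
        if low <= frac_similar < high:
            return label
    raise ValueError(f"score {frac_similar} is outside every class range")


for score in (0.0, 0.33, 0.5, 1.0):
    print(score, class_name_map[score_to_class(score, class_threshold_map)])
```

Under this assumption, the upper bound of 1.01 on the last range ensures that a score of exactly 1.0 still falls into the `Similar` class.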

In our model, we will reuse the `TextEncoder` and `ImageEncoder` classes we created previously. However, we need a new _classifier_ that handles input from these two encoders. This is implemented below. Notice that in the `forward` method, the input argument `x` is now a tuple containing the output from the text encoder and the output from the image encoder.

In [None]:
class ImageTextClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.flatten = nn.Flatten()  # Flatten the input tensor
        self.linear = nn.Linear(input_dim, num_classes)  # Linear layer with num_classes output units

    def forward(self, x):
        text_enc = self.flatten(x[0])  # Flatten the input tensor
        img_enc = self.flatten(x[1])
        # concatenate the text and image encodings
        x = torch.hstack((text_enc, img_enc))
        x = self.linear(x)   # Pass through the linear layer
        return x

# Test the module with random input data
module = ImageTextClassifier(32*31*31+2*dm.max_sample_length()*dm.embedding_dim, 3)
input_image_tensor = torch.randn(5, 32, 31, 31)  # Batch size of 5, input tensor shape [5, 32, 31, 31]
input_text_tensor = torch.randn(5, 2*dm.max_sample_length(), dm.embedding_dim)  # Batch size of 5, input tensor shape [5, 2*max_sample_length, embedding_dim]

output = module((input_text_tensor, input_image_tensor))
print("Output shape:", output.shape)  # Expected output shape: [5, 3]

# Problem 7 (2 points extra credit)

In the code cell below, complete the `forward` method of the `MultiModalityClassifierModel`. Hint: the input `x` to the `forward` method is a tuple containing the text input and the image input.

In [None]:
class MultiModalityClassifierModel(L.LightningModule):
    def __init__(self, text_encoder, image_encoder, classifier, num_classes):
        super().__init__()
        # model layers
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        self.classifier = classifier

        # validation metrics
        self.val_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes)]))
        self.validation_step_outputs = []
        self.validation_step_targets = []

        # test metrics
        self.test_roc = TM.ROC(task="multiclass", num_classes=num_classes, thresholds=list(np.linspace(0.0, 1.0, 20))) # roc and cm have methods we want to call so store them in a variable
        self.test_cm = TM.ConfusionMatrix(task='multiclass', num_classes=num_classes)
        self.test_metrics_tracker = TM.wrappers.MetricTracker(TM.MetricCollection([TM.classification.MulticlassAccuracy(num_classes=num_classes),
                                                            self.test_roc, self.test_cm]))
        self.test_step_outputs = []
        self.test_step_targets = []

    def forward(self, x):
        ##### Problem 7 extra credit start your code here #####

        ##### Problem 7 end your code here #####
        return x

    def training_step(self, batch, batch_idx):
        text_batch,image_batch, y = batch
        logits = self.forward((text_batch,image_batch))
        loss = nn.functional.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        text_batch, image_batch, y = batch
        logits = self.forward((text_batch,image_batch))
        loss = nn.functional.cross_entropy(logits, y)
        self.log('val_loss', loss, on_step=True, on_epoch=True)

        # store the outputs and targets for the epoch end step
        self.validation_step_outputs.append(logits)
        self.validation_step_targets.append(y)
        return loss

    def on_validation_epoch_end(self):
        # stack all the outputs and targets into a single tensor
        all_preds = torch.vstack(self.validation_step_outputs)
        all_targets = torch.hstack(self.validation_step_targets)

        # compute the metrics
        loss = nn.functional.cross_entropy(all_preds, all_targets)
        self.val_metrics_tracker.increment()
        self.val_metrics_tracker.update(all_preds, all_targets)
        self.log('val_loss_epoch_end', loss)

        # clear the validation step outputs
        self.validation_step_outputs.clear()
        self.validation_step_targets.clear()

    def test_step(self, batch, batch_idx):
        text_batch, image_batch, y = batch
        logits = self.forward((text_batch,image_batch))
        loss = nn.functional.cross_entropy(logits, y)
        self.log('test_loss', loss, on_step=True, on_epoch=True)
        self.test_step_outputs.append(logits)
        self.test_step_targets.append(y)
        return loss

    def on_test_epoch_end(self):
        all_preds = torch.vstack(self.test_step_outputs)
        all_targets = torch.hstack(self.test_step_targets)

        self.test_metrics_tracker.increment()
        self.test_metrics_tracker.update(all_preds, all_targets)
        # clear the test step outputs
        self.test_step_outputs.clear()
        self.test_step_targets.clear()

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

The code below will now fit the multimodal model. Note that if you change the number of kernels in the `image_encoder`, you will need to update the `input_dim` argument to the `ImageTextClassifier` accordingly.
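
One way to avoid hard-coding the image feature size is to pass a dummy tensor through the encoder and read off the flattened dimension. The sketch below uses a hypothetical stand-in encoder (`stand_in_encoder`), since the notebook's `ImageEncoder` may use different layers. This stand-in yields a 32x32 feature map rather than the 31x31 map assumed elsewhere in this notebook.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in encoder; the notebook's ImageEncoder may use different layers
stand_in_encoder = nn.Sequential(
    nn.Conv2d(6, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)

with torch.no_grad():
    dummy = torch.zeros(1, 6, 128, 128)  # one fake 6-channel 128x128 image
    encoded = stand_in_encoder(dummy)

# Flatten everything except the batch dimension to get the per-sample feature size
image_feature_dim = encoded.flatten(start_dim=1).shape[1]
print(encoded.shape, image_feature_dim)
```

The same probe could be run on the real `image_encoder`, and the resulting value added to the flattened text encoding size to obtain the classifier's `input_dim`.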

In [None]:
seed_everything(rs)
image_encoder = ImageEncoder(input_channels=6, output_channels=[16,32])
text_encoder = TextEncoder(len(dm.vocab), dm.embedding_dim, num_heads=2)
classifier = ImageTextClassifier(32*31*31+2*dm.max_sample_length()*dm.embedding_dim, num_classes=len(dm.class_name_map))
molecular_image_text_model = MultiModalityClassifierModel(text_encoder, image_encoder, classifier, num_classes=len(dm.class_name_map))

trainer = L.Trainer(default_root_dir=dir_lightning,
                    max_epochs=100,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=7)])
trainer.fit(model=molecular_image_text_model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader())

## Validation Set Performance
Here, we again plot the validation set performance.

In [None]:
mca = molecular_image_text_model.val_metrics_tracker.compute_all()['MulticlassAccuracy']
plt.plot(range(1, len(mca)+1), mca, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Epoch Validation Accuracy')
plt.grid()

## Test Set Performance

Finally, we can evaluate the test set performance. Overall, we see that the multimodal model does not perform as well as the text-only model but does perform better than the image-only model. Likely, if we had a much larger dataset, we would find that a multimodal model would perform best. Of course, this could require substantially more training time.

In [None]:
trainer.test(model=molecular_image_text_model, dataloaders=dm.test_dataloader())

In [None]:
rslt = molecular_image_text_model.test_metrics_tracker.compute()

In [None]:
cmp = sns.heatmap(rslt['MulticlassConfusionMatrix'], annot=True, fmt='d', cmap='Blues')
cmp.set_xlabel('Predicted Label')
cmp.set_xticklabels(dm.class_name_map.values(), rotation=90)
cmp.set_yticklabels(dm.class_name_map.values(), rotation=0)
cmp.set_ylabel('Actual Label');

In [None]:
fpr, tpr, thresholds = rslt['MulticlassROC']
for i in range(len(dm.class_name_map)):
    plt.plot(fpr[i], tpr[i], label=dm.class_name_map[i])
plt.xlabel('False Positive Rate')
plt.plot([0, 1], [0, 1], 'k--', label='Random')
plt.ylabel('True Positive Rate')
plt.legend()
plt.grid()

In [None]:
# Print the classification report
device = torch.device("cpu")   #"cuda:0"
molecular_image_text_model.eval()
y_true=[]
y_pred=[]
with torch.no_grad():
    for test_data in dm.test_dataloader():
        test_text, test_img, test_labels = test_data[0].to(device), test_data[1].to(device), test_data[2].to(device)
        pred = molecular_image_text_model((test_text, test_img)).argmax(dim=1)
        for i in range(len(pred)):
            y_true.append(test_labels[i].item())
            y_pred.append(pred[i].item())

print(classification_report(y_true,y_pred,target_names=list(dm.class_name_map.values()),digits=4))