<h3><center>Intelligent Analysis of Medical Images</center></h3>
<h4><center>HW 2: Practical Part</center></h4>
<table width='100%' style="border: none;">
    <tr style="border: none; text-align: center;">
        <td style="border: none;"><h5>Javad Razi</h5></td>
        <td style="border: none;"><h5>401204354</h5></td>
        <td style="border: none;"><h5>j.razi@outlook.com</h5></td>
</table>
<hr/>
<br/>

# **Part A**

# Install and Import Libraries
In the next two cells, we install any missing libraries and then import everything required for the implementation. This includes `torch` for model building and training, `torchvision` for datasets and data transforms, `gdown` for downloading the data, `wandb` for experiment tracking, and additional libraries (`scikit-learn`, `matplotlib`) for metrics and visualization.

In [None]:
try: 
    import gdown
except:
    %pip install gdown
    
try:
    import wandb
except:
    %pip install wandb
try:
    import torch
except:
    %pip install torch

try:
    import torchvision
except:
    %pip install torchvision
    
try:
    import pickle
except:
    %pip install pickle

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import os
import pandas as pd
import pickle
import wandb

  from .autonotebook import tqdm as notebook_tqdm


# Downloading the Datasets

In [None]:
import os
import gdown

# Create the dataset directory if it doesn't exist (no-op when it does).
os.makedirs('./datasets', exist_ok=True)

# Google Drive file ID for each split. Each file is saved under its split name so the
# loading cell can open `./datasets/<split>.pickle` directly.
# NOTE(review): this assumes the IDs were listed in the same order the loading cell
# reads the splits (train, test, validation) — confirm against the Drive files.
DATASET_FILE_IDS = {
    'train': '1i9Ei3QSmPBnYzqknvZPg_6TNsPomxQYG',
    'test': '1-2cN6EuFQnIM53q1N3MQCVcoid1NPBWk',
    'validation': '1-0bl_TuSM-JQ4nCwAW7ySn3uYMAGlaz4',
}

# Download each split once; re-running the cell skips files already on disk.
for split, file_id in DATASET_FILE_IDS.items():
    url = f'https://drive.google.com/uc?id={file_id}'
    output = f'./datasets/{split}.pickle'
    if not os.path.exists(output):
        gdown.download(url, output, quiet=False)


## Load The Pickle Files

In [None]:
import os
import pickle

DATASET_DIR = './datasets'


def load_split(split):
    """Unpickle and return the dataset split stored at ``./datasets/<split>.pickle``.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
    These files come from a Google Drive download, so only load files from a
    source you trust.
    """
    with open(os.path.join(DATASET_DIR, f'{split}.pickle'), 'rb') as f:
        return pickle.load(f)


train_data = load_split('train')
test_data = load_split('test')
validation_data = load_split('validation')


## Initialize W&B (WandB)

In [None]:
# Off/on switches for wandb. Logging is switched off for now since we don't want to
# log anything yet; call `wandb_on()` before `wandb.init` to log to the W&B server.
import os


def wandb_off():
    """Disable wandb (WANDB_MODE='disabled') and silence its console output."""
    os.environ['WANDB_MODE'] = 'disabled'
    os.environ['WANDB_SILENT'] = 'true'


def wandb_on():
    """Enable online wandb logging (WANDB_MODE='online') with normal console output."""
    os.environ['WANDB_MODE'] = 'online'
    os.environ['WANDB_SILENT'] = 'false'


wandb_off()

In [None]:
# The API Key is required to login. If I should've provided my own API key, please let me know and I will provide it ASAP. 
# Never hardcode the key in this notebook; wandb can pick it up from the
# WANDB_API_KEY environment variable or prompt for it interactively.
wandb.login()

In [None]:
# Start the W&B run for Part A (AlexNet trained from scratch); everything logged
# until the next wandb.finish() goes to this run.
wandb.init(project='mri-alexnet', entity='jrazi', name='jrazi-alexnet-from-scratch')

# AlexNet Architecture
Here we define the AlexNet architecture in PyTorch. We create a class `AlexNet` that inherits from `nn.Module` and define all the layers in the `__init__` method. The `forward` method dictates the data flow through the network.


In [1]:
class AlexNet(nn.Module):
    """AlexNet-style CNN classifier built from scratch.

    Five convolutional layers (96-256-384-384-256 channels) with ReLU and max
    pooling, followed by an adaptive average pool to a fixed 256x6x6 map, so any
    input large enough to survive the strided convolutions/pools (e.g. 227x227)
    is accepted. A three-layer fully connected head with dropout produces logits.

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input images. Defaults to 3 (original
            behaviour); use 1 for single-channel images.
        dropout: Dropout probability before the first two fully connected
            layers. Defaults to 0.5, ``nn.Dropout``'s default.

    Submodule names and ``nn.Sequential`` indices are unchanged, so existing
    state dicts still load.
    """

    def __init__(self, num_classes=3, in_channels=3, dropout=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Fixes the spatial size fed to the classifier regardless of input size.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map images ``(N, in_channels, H, W)`` to logits ``(N, num_classes)``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

NameError: name 'nn' is not defined

# Dataset and Dataloaders
In this cell, we define our custom dataset class which will handle the MRI images. We'll also create the dataloaders for training, validation, and testing datasets.

In [None]:
class MRIDataset(Dataset):
    """Map-style dataset over an in-memory sequence of ``(image, label)`` pairs.

    Args:
        data: Indexable collection whose items unpack into ``(image, label)``.
        transform: Optional callable applied to each image (never the label);
            skipped when falsy.
    """

    def __init__(self, data, transform=None):
        self.data, self.transform = data, transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample, target = self.data[index]
        return (self.transform(sample) if self.transform else sample), target


# Define transforms for the dataset.
# NOTE(review): ToTensor followed by a 3-channel Normalize assumes each image is an
# HxWx3 array or an RGB PIL image — confirm against the pickled data.
IMAGE_SIZE = (227, 227)  # AlexNet uses 227x227 inputs
BATCH_SIZE = 64
# ImageNet channel statistics (also the normalisation the pretrained AlexNet expects).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.ToTensor(),
    # Explicit antialias keeps tensor resizing identical across torchvision versions
    # (the tensor default changed in 0.17 and older versions warn about it).
    transforms.Resize(IMAGE_SIZE, antialias=True),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Create dataset instances
train_dataset = MRIDataset(train_data, transform=transform)
valid_dataset = MRIDataset(validation_data, transform=transform)
test_dataset = MRIDataset(test_data, transform=transform)

# Create dataloaders; only the training set is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Loss Function and Optimizer
In this cell, we define the loss function and the optimizer for our AlexNet model. We use Cross-Entropy Loss for our multi-class classification problem and the Adam optimizer with a learning rate of 1e-4.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to moves the parameters in place and returns the same module.
model = AlexNet(num_classes=3).to(device)

# Cross-entropy for the 3-class problem; Adam over every model parameter.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training Loop
Here we define the training loop for our AlexNet model. We track the loss (per batch and averaged per epoch) on both the training and validation datasets; accuracy and the other classification metrics are computed on the test set afterwards.

In [3]:
# Training parameters
num_epochs = 10
train_steps = len(train_loader)
valid_steps = len(valid_loader)

# Per-batch training / validation losses, in the order they were computed
train_losses = []
valid_losses = []
# Per-epoch mean training / validation losses (these are what the loss plot draws)
avg_train_losses = []
avg_valid_losses = []

for epoch in range(num_epochs):
    model.train()  # Set model to training mode (enables dropout)
    total_train_loss = 0.0
    total_valid_loss = 0.0

    # Training
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_train_loss += batch_loss
        train_losses.append(batch_loss)
        wandb.log({'train_loss': batch_loss})

    # Validation (evaluation mode, no gradient tracking)
    model.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_loss = loss_function(outputs, labels).item()
            total_valid_loss += batch_loss
            valid_losses.append(batch_loss)
            wandb.log({'valid_loss': batch_loss})

    # Mean of this epoch's per-batch losses
    avg_train_loss = total_train_loss / train_steps
    avg_valid_loss = total_valid_loss / valid_steps
    avg_train_losses.append(avg_train_loss)
    avg_valid_losses.append(avg_valid_loss)

    # Log average losses to wandb
    wandb.log({'epoch': epoch, 'avg_train_loss': avg_train_loss, 'avg_valid_loss': avg_valid_loss})

    # Print training and validation progress
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}')

NameError: name 'train_loader' is not defined

# Training and Validation Loss Plots
In this cell, we visualize the training and validation loss over the epochs to understand the learning trend.

In [None]:
# Per-epoch mean losses from the training loop, plotted against 1-based epoch number.
epochs = range(1, len(avg_train_losses) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, avg_train_losses, label="train")
ax.plot(epochs, avg_valid_losses, label="validation")
ax.set(title="Training and Validation Loss", xlabel="epoch", ylabel="Loss")
ax.legend()
plt.show()

# Evaluate the Model
In this cell, we evaluate the trained AlexNet model on the test dataset and calculate the classification metrics like accuracy, precision, recall, and F1 score.

In [None]:
# Collect ground-truth and predicted labels over the whole test set
model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        logits = model(batch_images.to(device))
        y_true.extend(batch_labels.cpu().numpy())
        y_pred.extend(logits.argmax(dim=1).cpu().numpy())

# Accuracy plus precision / recall / F1 weighted by each class's support
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1 = (
    metric(y_true, y_pred, average='weighted')
    for metric in (precision_score, recall_score, f1_score)
)

wandb.log({'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1})

print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}')

In [None]:
# Bundle the from-scratch test metrics for the later comparison, then log them
# to this run under a `from_scratch_` prefix.
from_scratch_metrics = dict(accuracy=accuracy, precision=precision, recall=recall, f1_score=f1)

wandb.log({f'from_scratch_{name}': value for name, value in from_scratch_metrics.items()})

In [None]:
# Close the Part A (from-scratch) W&B run so Part B logs to a fresh run.
wandb.finish()

# **Part B**

## Initialize WandB

In [None]:
# Start a separate W&B run for Part B (pretrained AlexNet fine-tuned on the MRI data).
wandb.init(project='mri-alexnet', entity='jrazi', name='jrazi-alexnet-pretrained')

# Load Pre-Trained AlexNet
In this cell, we load a pre-trained AlexNet model from torchvision's models. We then replace the final classification layer to match the number of classes in our dataset. We'll also define the loss function and optimizer for this pre-trained model.

In [None]:
from torchvision import models

NUM_CLASSES = 3

# Load AlexNet with ImageNet weights. `pretrained=True` has been deprecated since
# torchvision 0.13 in favour of the `weights` enum; it is used only as a fallback
# on older releases that lack `AlexNet_Weights`.
if hasattr(models, 'AlexNet_Weights'):
    pretrained_alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
else:
    pretrained_alexnet = models.alexnet(pretrained=True)

# Replace the last classifier layer (index 6, the ImageNet output layer) with a
# NUM_CLASSES-way head that keeps the same input width.
pretrained_alexnet.classifier[6] = nn.Linear(pretrained_alexnet.classifier[6].in_features, NUM_CLASSES)
pretrained_alexnet = pretrained_alexnet.to(device)

# Loss and optimizer: every layer (pretrained and the new head) is fine-tuned.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pretrained_alexnet.parameters(), lr=1e-4)

# Train and Validate Pre-Trained AlexNet
Here we fine-tune and validate the pre-trained AlexNet on our MRI dataset. We track the loss (per batch and averaged per epoch) for both the training and validation sets.

In [4]:
# Training parameters
num_epochs = 10
train_steps = len(train_loader)
valid_steps = len(valid_loader)

# Per-batch training / validation losses, in the order they were computed
pretrained_train_losses = []
pretrained_valid_losses = []
# Per-epoch mean training / validation losses (these are what the loss plot draws)
pretrained_avg_train_losses = []
pretrained_avg_valid_losses = []

# Training loop: same procedure and W&B keys as Part A, logged to the Part B run
for epoch in range(num_epochs):
    pretrained_alexnet.train()  # Set model to training mode
    total_train_loss = 0.0
    total_valid_loss = 0.0

    # Training
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = pretrained_alexnet(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_train_loss += batch_loss
        pretrained_train_losses.append(batch_loss)
        wandb.log({'train_loss': batch_loss})

    # Validation (evaluation mode, no gradient tracking)
    pretrained_alexnet.eval()
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = pretrained_alexnet(images)
            batch_loss = criterion(outputs, labels).item()
            total_valid_loss += batch_loss
            pretrained_valid_losses.append(batch_loss)
            wandb.log({'valid_loss': batch_loss})

    # Mean of this epoch's per-batch losses
    avg_train_loss = total_train_loss / train_steps
    avg_valid_loss = total_valid_loss / valid_steps
    pretrained_avg_train_losses.append(avg_train_loss)
    pretrained_avg_valid_losses.append(avg_valid_loss)

    # Log average losses to wandb
    wandb.log({'epoch': epoch, 'avg_train_loss': avg_train_loss, 'avg_valid_loss': avg_valid_loss})

    # Print training and validation progress
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}')

# Visualize Loss for Pre-Trained AlexNet
We plot the per-epoch training and validation loss of the pre-trained AlexNet using matplotlib.

In [5]:
# Per-epoch mean losses of the pre-trained model, plotted against 1-based epoch number.
epochs = range(1, len(pretrained_avg_train_losses) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, pretrained_avg_train_losses, label="train")
ax.plot(epochs, pretrained_avg_valid_losses, label="validation")
ax.set(title="Pre-Trained AlexNet: Training and Validation Loss", xlabel="epoch", ylabel="Loss")
ax.legend()
plt.show()

# Evaluate Pre-Trained AlexNet
In this cell, we evaluate the pre-trained AlexNet on the test dataset and calculate the classification metrics.

In [6]:
# Collect true and predicted test labels for the fine-tuned pretrained model
pretrained_alexnet.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        logits = pretrained_alexnet(batch_images.to(device))
        y_true.extend(batch_labels.cpu().numpy())
        y_pred.extend(logits.argmax(dim=1).cpu().numpy())

# Accuracy plus precision / recall / F1 weighted by each class's support
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1 = (
    score_fn(y_true, y_pred, average='weighted')
    for score_fn in (precision_score, recall_score, f1_score)
)

print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}')

In [None]:
# Bundle the pretrained model's test metrics for the comparison below, then log
# them to this run under a `pretrained_` prefix.
pretrained_metrics = dict(accuracy=accuracy, precision=precision, recall=recall, f1_score=f1)

wandb.log({'pretrained_' + metric: score for metric, score in pretrained_metrics.items()})

# Comparison of From-Scratch and Pre-Trained AlexNet
In this cell, we compare the performance of our from-scratch AlexNet with the pre-trained AlexNet on our dataset, based on the metrics we calculated.

In [7]:
# (metrics-dict key, printed label) pairs, in display order
_METRIC_LABELS = (
    ('accuracy', 'Accuracy'),
    ('precision', 'Precision'),
    ('recall', 'Recall'),
    ('f1_score', 'F1-Score'),
)

# Log both models' metrics side by side, interleaved metric by metric
comparison_log = {}
for key, _ in _METRIC_LABELS:
    comparison_log[f'from_scratch_{key}'] = from_scratch_metrics[key]
    comparison_log[f'pretrained_{key}'] = pretrained_metrics[key]
wandb.log(comparison_log)

# Print out for inspection; a blank line separates the two models
for heading, metrics, trailer in (
    ("From-Scratch AlexNet Metrics:", from_scratch_metrics, "\n"),
    ("Pre-Trained AlexNet Metrics:", pretrained_metrics, ""),
):
    print(heading)
    print("\n".join(f"{label}: {metrics[key]:.4f}" for key, label in _METRIC_LABELS) + trailer)

In [None]:
# Close the Part B (pretrained) W&B run before Part C starts its own run.
wandb.finish()

# **Part C**

# Implementing Supervised Contrastive Loss
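Here we define the Supervised Contrastive (SupCon) loss and update our training loop accordingly. For each anchor $i$ in a batch, let $P(i)$ be the other samples sharing its label and $A(i)$ all other samples. Let $z$ be the L2-normalised features and $\tau$ the temperature. Then:

$$\mathcal{L}_i = -\frac{1}{|P(i)|}\sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}$$

The batch loss is the mean of $\mathcal{L}_i$ over the anchors that have at least one positive.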

In [None]:
class SupervisedContrastiveLoss(nn.Module):
    """Supervised contrastive (SupCon) loss, "L_out" formulation.

    For each anchor ``i`` with positives ``P(i)`` (other samples sharing its
    label) and candidates ``A(i)`` (every other sample)::

        loss_i = -1/|P(i)| * sum_{p in P(i)} log( exp(z_i.z_p / T) / sum_{a in A(i)} exp(z_i.z_a / T) )

    where ``z`` are the L2-normalised features and ``T`` is the temperature.
    The result is the mean of ``loss_i`` over anchors with at least one
    positive. If no anchor has a positive (e.g. a batch of one, or all labels
    distinct), a zero that is still attached to the autograd graph is returned.

    Args:
        temperature: Softmax temperature ``T`` (> 0); smaller values sharpen the
            similarity distribution.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the batch loss.

        Args:
            features: ``(batch, ...)`` embeddings; trailing dims are flattened.
            labels: ``batch`` integer class labels (any shape with that many elements).

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``labels`` and ``features`` disagree on batch size.
        """
        batch_size = features.shape[0]
        features = F.normalize(features.reshape(batch_size, -1), dim=1)
        labels = labels.reshape(-1)
        if labels.shape[0] != batch_size:
            raise ValueError(
                f'labels has {labels.shape[0]} entries but features has batch size {batch_size}')

        self_mask = torch.eye(batch_size, dtype=torch.bool, device=features.device)

        # Temperature-scaled cosine similarities. A sample is never its own
        # candidate, so the diagonal is set to -inf and drops out of the softmax.
        logits = torch.matmul(features, features.T) / self.temperature
        logits = logits.masked_fill(self_mask, float('-inf'))
        log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

        # pos_mask[i, j] is True when j != i and both samples share a label.
        pos_mask = labels.unsqueeze(0).eq(labels.unsqueeze(1)) & ~self_mask
        pos_counts = pos_mask.sum(dim=1)

        # Zero out non-positives (including the -inf diagonal) before summing;
        # multiplying by a 0/1 mask instead would turn -inf * 0 into NaN.
        pos_log_prob_sum = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)

        has_pos = pos_counts > 0
        if not bool(has_pos.any()):
            return features.sum() * 0.0
        per_anchor = -pos_log_prob_sum[has_pos] / pos_counts[has_pos]
        return per_anchor.mean()


## Initialize WandB

In [None]:
# Start a separate W&B run for Part C (SupCon training followed by cross-entropy fine-tuning).
wandb.init(project='mri-alexnet', entity='jrazi', name='jrazi-alexnet-supervised-contrastive-loss')

# Retrain & Fine-Tune with Categorical Cross-Entropy
Here, we retrain the from-scratch AlexNet using Supervised Contrastive Loss for 10 epochs followed by fine-tuning with Cross-Entropy Loss for 5 epochs.

In [8]:
# Fresh from-scratch AlexNet: Supervised Contrastive training, then cross-entropy
# fine-tuning, both phases sharing one Adam optimizer.
model = AlexNet(num_classes=3)
model.to(device)
contrastive_loss_function = SupervisedContrastiveLoss()
cross_entropy_loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training parameters
contrastive_epochs = 10
cross_entropy_epochs = 5
train_steps = len(train_loader)
valid_steps = len(valid_loader)

# Per-epoch mean losses for THIS experiment. They are reset here so Part A's history,
# stored under the same names, does not leak into the Part C plot. Validation only
# runs in the cross-entropy phase, so contrastive epochs record NaN (drawn as a gap)
# to keep the two lists aligned epoch by epoch.
train_losses = []
valid_losses = []

# Phase 1: Supervised Contrastive Loss.
# NOTE(review): the contrastive loss is applied directly to the model's final
# 3-dimensional outputs; there is no separate projection head.
for epoch in range(contrastive_epochs):
    model.train()  # Set model to training mode
    total_train_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = contrastive_loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_train_loss += batch_loss
        # Log batch loss
        wandb.log({'contrastive_train_loss': batch_loss})

    # Calculate average loss; no validation in this phase
    avg_train_loss = total_train_loss / train_steps
    train_losses.append(avg_train_loss)
    valid_losses.append(float('nan'))

    # Log average loss to wandb
    wandb.log({'epoch': epoch, 'contrastive_avg_train_loss': avg_train_loss})

    # Print training progress
    print(f'Contrastive Epoch [{epoch+1}/{contrastive_epochs}], Train Loss: {avg_train_loss:.4f}')

# Transition to Cross-Entropy Loss training
wandb.log({'phase': 'cross_entropy_loss_training'})

# Phase 2: Cross-Entropy fine-tuning with validation
for epoch in range(cross_entropy_epochs):
    model.train()  # Set model to training mode
    total_train_loss = 0.0
    total_valid_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = cross_entropy_loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_train_loss += batch_loss
        # Log batch loss
        wandb.log({'cross_entropy_train_loss': batch_loss})

    # Validation
    model.eval()  # Set model to evaluate mode
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_loss = cross_entropy_loss_function(outputs, labels).item()
            total_valid_loss += batch_loss
            # Log batch validation loss
            wandb.log({'cross_entropy_valid_loss': batch_loss})

    # Calculate average losses
    avg_train_loss = total_train_loss / train_steps
    avg_valid_loss = total_valid_loss / valid_steps
    train_losses.append(avg_train_loss)
    valid_losses.append(avg_valid_loss)

    # Log average losses to wandb
    wandb.log({'epoch': epoch + contrastive_epochs,  # offset by the number of contrastive epochs
               'cross_entropy_avg_train_loss': avg_train_loss,
               'cross_entropy_avg_valid_loss': avg_valid_loss})

    # Print training and validation progress
    print(f'Cross-Entropy Epoch [{epoch+1}/{cross_entropy_epochs}], Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}')

# Visualizing the Loss for Contrastive Training & Fine-Tuning

In [None]:
# Per-epoch mean losses for Part C. The training curve ends with the cross-entropy
# epochs; validation values (if shorter) are aligned to the END of that curve.
# Contrastive and cross-entropy train losses are on different scales, so a dashed
# line marks the switch between the two phases.
n_train = len(train_losses)
train_epochs = range(1, n_train + 1)
valid_epochs = range(n_train - len(valid_losses) + 1, n_train + 1)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_epochs, train_losses, label="train")
ax.plot(valid_epochs, valid_losses, label="validation")
ax.axvline(n_train - cross_entropy_epochs + 0.5, color="gray", linestyle="--",
           label="SupCon -> cross-entropy")
ax.set(title="Training and Validation Loss", xlabel="epoch", ylabel="Loss")
ax.legend()
plt.show()

# Evaluate Fine-Tuned Model
In this cell, we evaluate the fine-tuned model on the test dataset using our classification metrics.

In [9]:
# Test-set evaluation of the SupCon-trained, cross-entropy-fine-tuned model
model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        class_scores = model(batch_images.to(device))
        y_true.extend(batch_labels.cpu().numpy())
        y_pred.extend(class_scores.argmax(dim=1).cpu().numpy())

# Accuracy plus precision / recall / F1 weighted by each class's support
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1 = (
    scorer(y_true, y_pred, average='weighted')
    for scorer in (precision_score, recall_score, f1_score)
)

# Log metrics to wandb under test_* keys
wandb.log(dict(test_accuracy=accuracy, test_precision=precision, test_recall=recall, test_f1_score=f1))

print(f'Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}')

In [None]:
# Finish the wandb run after evaluation (closes the Part C supervised-contrastive run)
wandb.finish()