In [1]:
# Imports
import os
import random
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

In [2]:
# Function to reset random seeds
def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [3]:
# Custom Dataset Class
class MRIDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y.astype(int)  # Convert labels to integers
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return torch.tensor(self.X[idx], dtype=torch.float32).permute(2, 0, 1), torch.tensor(self.y[idx], dtype=torch.long)

In [4]:
# Attention Layer
class AttentionLayer(nn.Module):
    def __init__(self, in_channels):
        super(AttentionLayer, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        # Compute raw attention weights
        attn_scores = self.conv1(x)
        # Reshape to apply softmax over spatial dimensions
        B, C, H, W = attn_scores.shape  # B: Batch, C: Channel, H: Height, W: Width
        attn_scores_flat = attn_scores.view(B, C, -1)  # Flatten H and W
        attn_weights_flat = F.softmax(attn_scores_flat, dim=2)  # Softmax over spatial dimensions
        attn_weights = attn_weights_flat.view(B, C, H, W)  # Reshape back to original dimensions
        return x * attn_weights, attn_weights


# Modified MRI Model with Attention
class AttentionMRIModel(nn.Module):
    def __init__(self):
        super(AttentionMRIModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 100, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.dropout1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv2d(100, 50, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout(0.3)
        self.attention = AttentionLayer(50)
        self.fc = nn.Linear(50 * 16 * 16, 3)  # Adjust size based on output from conv layers

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = self.dropout2(x)
        x, attn_weights = self.attention(x)  # Apply attention
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x, attn_weights

In [5]:
# Load and preprocess data
path = "../data/processed/mri/"

def load_data():
    with open(f"{path}/img_train.pkl", "rb") as fh:
        data = pickle.load(fh)
    X_train_ = pd.DataFrame(data)["img_array"]

    with open(f"{path}/img_test.pkl", "rb") as fh:
        data = pickle.load(fh)
    X_test_ = pd.DataFrame(data)["img_array"]

    with open(f"{path}/img_y_train.pkl", "rb") as fh:
        data = pickle.load(fh)
    y_train = np.array(pd.DataFrame(data)["label"].values.astype(np.float32)).flatten()

    with open(f"{path}/img_y_test.pkl", "rb") as fh:
        data = pickle.load(fh)
    y_test = np.array(pd.DataFrame(data)["label"].values.astype(np.float32)).flatten()

    y_train = np.where(y_train == 2, -1, y_train)
    y_train = np.where(y_train == 1, 2, y_train)
    y_train = np.where(y_train == -1, 1, y_train)

    y_test = np.where(y_test == 2, -1, y_test)
    y_test = np.where(y_test == 1, 2, y_test)
    y_test = np.where(y_test == -1, 1, y_test)

    X_train = np.array([X for X in X_train_.values])
    X_test = np.array([X for X in X_test_.values])

    return X_train, X_test, y_train, y_test

In [7]:
def visualize_attention(inputs, attn_weights, save_path):
    """
    Visualizes the input MRI slices with corresponding attention maps.
    """
    inputs = inputs.cpu().detach().numpy()  # Convert to numpy array
    attn_weights = F.interpolate(
        attn_weights, size=(72, 72), mode="bilinear", align_corners=False
    )  # Upsample attention weights to match input size
    attn_weights = attn_weights.cpu().detach().numpy()  # Convert to numpy array

    num_slices = inputs.shape[0]  # Number of slices in the batch
    fig, axes = plt.subplots(num_slices, 3, figsize=(12, 4 * num_slices))
    if num_slices == 1:  # Ensure axes are always iterable
        axes = [axes]

    for i in range(num_slices):
        input_slice = inputs[i]  # Shape: (C, H, W)
        if input_slice.shape[0] == 1:  # Grayscale input
            input_slice = input_slice[0]  # Shape: (H, W)

        # Original Input Slice
        if len(input_slice.shape) == 2:  # Grayscale slice
            axes[i][0].imshow(input_slice, cmap="gray")
        else:  # RGB slice
            axes[i][0].imshow(input_slice.transpose(1, 2, 0))  # Convert CHW to HWC
        axes[i][0].set_title(f"Input Slice {i+1}")

        # Attention Map
        axes[i][1].imshow(attn_weights[i, 0], cmap="hot")  # Display single-channel attention
        axes[i][1].set_title(f"Attention Map Slice {i+1}")

        # Weighted Input
        weighted_input = input_slice * attn_weights[i, 0]  # Apply attention weights
        if len(weighted_input.shape) == 2:  # Grayscale weighted input
            axes[i][2].imshow(weighted_input, cmap="hot")
        else:  # RGB weighted input
            axes[i][2].imshow(weighted_input.transpose(1, 2, 0))  # Convert CHW to HWC
        axes[i][2].set_title(f"Weighted Input Slice {i+1}")

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

In [8]:
# Training and Evaluation Functions
def train_model(model, train_loader, device, epochs=20):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.to(device)
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            outputs, _ = model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()


        if epoch % 10 == 9:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

In [23]:
def evaluate_model(model, test_loader, device, class_names):
    path = "../outputs/mri"
    model.eval()
    model.to(device)

    all_preds = []
    all_targets = []
    all_attn_weights = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs, attn_weights = model(data)
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            all_attn_weights.append(attn_weights.cpu().numpy())

    # Convert collected lists to arrays
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    all_attn_weights = np.concatenate(all_attn_weights, axis=0)  # Shape: (N, C, H, W)

    # 1. Classification Report
    report = classification_report(all_targets, all_preds, target_names=class_names)
    print("Classification Report:\n", report)

    # 2. Confusion Matrix
    cm = confusion_matrix(all_targets, all_preds)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.savefig(f"{path}/mri_attn2_confusion_matrix.png")
    plt.close()

    # 3. Attention Weight Visualization
    # Visualize attention weights for a few test samples
    num_samples_to_visualize = min(5, len(test_loader.dataset))

    for i, (data, target) in enumerate(test_loader):
        if i >= num_samples_to_visualize:
            break

        data, target = data.to(device), target.to(device)
        _, attn_weights = model(data)
        save_path = f"../outputs/mri/attention_weights_sample_{i+1}.png"
        visualize_attention(data[0], attn_weights, save_path)

    # 4. Slice Importance
    # Average attention weights across height and width to determine slice importance
    slice_importances = all_attn_weights.mean(axis=(2, 3))  # Shape: (N, C)
    average_importance = slice_importances.mean(axis=0)  # Shape: (C,)

    print(slice_importances.shape)
    print(average_importance.shape)

    # Print slice importances
    for i, importance in enumerate(average_importance, start=1):
        print(f"Slice {i} Importance: {importance:.4f}")

    # Plot slice importance
    plt.figure(figsize=(8, 6))
    plt.bar(range(1, 4), average_importance, tick_label=["Slice 1", "Slice 2", "Slice 3"], color="skyblue")
    plt.title("Slice Importance")
    plt.xlabel("Slice")
    plt.ylabel("Average Attention Weight")
    plt.savefig(f"{path}/mri_attn2_slice_importance.png")
    plt.close()

    print("Evaluation complete. Results saved:")
    print("- Classification report printed in console.")
    print("- Confusion matrix saved as confusion_matrix.png.")
    print("- Attention weights visualization saved as attention_weights.png.")
    print("- Slice importance plot saved as slice_importance.png.")

In [11]:
# Main driver code
X_train, X_test, y_train, y_test = load_data()

train_dataset = MRIDataset(X_train, y_train)
test_dataset = MRIDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_names = ["CN", "MCI", "AD"]


In [12]:
model = AttentionMRIModel()
train_model(model, train_loader, device, epochs=250)

Epoch 10/250, Loss: 0.8890260457992554
Epoch 20/250, Loss: 0.6775628924369812
Epoch 30/250, Loss: 0.7630151510238647
Epoch 40/250, Loss: 0.6653193235397339
Epoch 50/250, Loss: 0.22531187534332275
Epoch 60/250, Loss: 0.4359924793243408
Epoch 70/250, Loss: 0.2643098831176758
Epoch 80/250, Loss: 0.47648367285728455
Epoch 90/250, Loss: 0.2895069122314453
Epoch 100/250, Loss: 0.2881888449192047
Epoch 110/250, Loss: 0.32141241431236267
Epoch 120/250, Loss: 0.3178722858428955
Epoch 130/250, Loss: 0.06007321923971176
Epoch 140/250, Loss: 0.06855204701423645
Epoch 150/250, Loss: 0.4204598069190979
Epoch 160/250, Loss: 0.19609060883522034
Epoch 170/250, Loss: 0.20010161399841309
Epoch 180/250, Loss: 0.10602924972772598
Epoch 190/250, Loss: 0.016724903136491776
Epoch 200/250, Loss: 0.07689861953258514
Epoch 210/250, Loss: 0.20947209000587463
Epoch 220/250, Loss: 0.06550080329179764
Epoch 230/250, Loss: 0.09307173639535904
Epoch 240/250, Loss: 0.02349722571671009
Epoch 250/250, Loss: 0.05928821489

In [24]:
evaluate_model(model, test_loader, device, class_names)

Classification Report:
               precision    recall  f1-score   support

          CN       0.96      1.00      0.98        26
         MCI       1.00      1.00      1.00         6
          AD       1.00      0.83      0.91         6

    accuracy                           0.97        38
   macro avg       0.99      0.94      0.96        38
weighted avg       0.97      0.97      0.97        38

(38, 1)
(1,)
Slice 1 Importance: 0.0039
Evaluation complete. Results saved:
- Classification report printed in console.
- Confusion matrix saved as confusion_matrix.png.
- Attention weights visualization saved as attention_weights.png.
- Slice importance plot saved as slice_importance.png.
