In [22]:
# Imports
import os
import random
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import confusion_matrix, classification_report

In [23]:
# Function to reset random seeds
def reset_random_seeds(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

In [24]:
# Custom Dataset Class
class MRIDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y.astype(int)  # Convert labels to integers
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return torch.tensor(self.X[idx], dtype=torch.float32).permute(2, 0, 1), torch.tensor(self.y[idx], dtype=torch.long)

In [None]:
# Attention Layer
class AttentionLayer(nn.Module):
    def __init__(self, in_channels):
        super(AttentionLayer, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        # Compute raw attention weights
        attn_scores = self.conv1(x)
        # Reshape to apply softmax over spatial dimensions
        B, C, H, W = attn_scores.shape  # B: Batch, C: Channel, H: Height, W: Width
        attn_scores_flat = attn_scores.view(B, C, -1)  # Flatten H and W
        attn_weights_flat = F.softmax(attn_scores_flat, dim=2)  # Softmax over spatial dimensions
        attn_weights = attn_weights_flat.view(B, C, H, W)  # Reshape back to original dimensions
        return x * attn_weights, attn_weights


# Modified MRI Model with Attention
class AttentionMRIModel(nn.Module):
    def __init__(self):
        super(AttentionMRIModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 100, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.dropout1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv2d(100, 50, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout(0.3)
        self.attention = AttentionLayer(50)
        self.fc = nn.Linear(50 * 16 * 16, 3)  # Adjust size based on output from conv layers

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.dropout1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = self.dropout2(x)
        x, attn_weights = self.attention(x)  # Apply attention
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x, attn_weights

In [38]:
# Load and preprocess data
path = "../data/processed/mri/"

def load_data():
    with open(f"{path}/img_train.pkl", "rb") as fh:
        data = pickle.load(fh)
    X_train_ = pd.DataFrame(data)["img_array"]

    with open(f"{path}/img_test.pkl", "rb") as fh:
        data = pickle.load(fh)
    X_test_ = pd.DataFrame(data)["img_array"]

    with open(f"{path}/img_y_train.pkl", "rb") as fh:
        data = pickle.load(fh)
    y_train = np.array(pd.DataFrame(data)["label"].values.astype(np.float32)).flatten()

    with open(f"{path}/img_y_test.pkl", "rb") as fh:
        data = pickle.load(fh)
    y_test = np.array(pd.DataFrame(data)["label"].values.astype(np.float32)).flatten()

    y_train = np.where(y_train == 2, -1, y_train)
    y_train = np.where(y_train == 1, 2, y_train)
    y_train = np.where(y_train == -1, 1, y_train)

    y_test = np.where(y_test == 2, -1, y_test)
    y_test = np.where(y_test == 1, 2, y_test)
    y_test = np.where(y_test == -1, 1, y_test)

    X_train = np.array([X for X in X_train_.values])
    X_test = np.array([X for X in X_test_.values])

    return X_train, X_test, y_train, y_test

In [39]:
def plot_confusion_matrix(true_labels, pred_labels, class_names, filename):
    cm = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.savefig(filename)
    print(f"saved cm to {filename}")
    plt.close()

In [40]:
def visualize_attention(inputs, attn_weights, save_path):
    """
    Visualizes the input MRI slices with corresponding attention maps.
    """
    inputs = inputs.cpu().detach().numpy()  # Convert to numpy array
    attn_weights = F.interpolate(
        attn_weights, size=(72, 72), mode="bilinear", align_corners=False
    )  # Upsample attention weights to match input size
    attn_weights = attn_weights.cpu().detach().numpy()  # Convert to numpy array

    num_slices = inputs.shape[0]  # Number of slices in the batch
    fig, axes = plt.subplots(num_slices, 3, figsize=(12, 4 * num_slices))
    if num_slices == 1:  # Ensure axes are always iterable
        axes = [axes]

    for i in range(num_slices):
        input_slice = inputs[i]  # Shape: (C, H, W)
        if input_slice.shape[0] == 1:  # Grayscale input
            input_slice = input_slice[0]  # Shape: (H, W)

        # Original Input Slice
        if len(input_slice.shape) == 2:  # Grayscale slice
            axes[i][0].imshow(input_slice, cmap="gray")
        else:  # RGB slice
            axes[i][0].imshow(input_slice.transpose(1, 2, 0))  # Convert CHW to HWC
        axes[i][0].set_title(f"Input Slice {i+1}")

        # Attention Map
        axes[i][1].imshow(attn_weights[i, 0], cmap="hot")  # Display single-channel attention
        axes[i][1].set_title(f"Attention Map Slice {i+1}")

        # Weighted Input
        weighted_input = input_slice * attn_weights[i, 0]  # Apply attention weights
        if len(weighted_input.shape) == 2:  # Grayscale weighted input
            axes[i][2].imshow(weighted_input, cmap="hot")
        else:  # RGB weighted input
            axes[i][2].imshow(weighted_input.transpose(1, 2, 0))  # Convert CHW to HWC
        axes[i][2].set_title(f"Weighted Input Slice {i+1}")

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

In [41]:
# Training and Evaluation Functions
def train_model(model, train_loader, device, epochs=20):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.to(device)
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            outputs, _ = model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()


        if epoch % 10 == 9:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

In [65]:
def calculate_slice_importance_per_sample(attn_weights):
    # Sum attention weights across spatial dimensions (H, W) for each slice
    # Shape of attn_weights: (batch_size, slices, H, W)
    importance = attn_weights.sum(dim=(2, 3))  # Shape: (batch_size, slices)
    
    # Normalize importance across slices
    normalized_importance = importance / importance.sum(dim=1, keepdim=True)  # Sum across slices
    return normalized_importance


def normalize_importance(importance_scores):
    """
    Normalize slice importance values for better interpretability.
    
    Args:
        importance_scores (torch.Tensor): Raw slice importance scores.
        
    Returns:
        torch.Tensor: Normalized importance scores (summing to 1).
    """
    return importance_scores / importance_scores.sum()


def plot_slice_importance(slice_importance, output_path=None):
    """
    Plot and save slice importance scores.
    
    Args:
        slice_importance (torch.Tensor): Normalized slice importance scores.
        output_path (str): Path to save the plot (optional).
    """
    num_slices = len(slice_importance)
    plt.bar(range(1, num_slices + 1), slice_importance.cpu().numpy())
    plt.xlabel('Slice Number')
    plt.ylabel('Importance')
    plt.title('Importance of Each Slice')
    if output_path:
        plt.savefig(output_path)
    plt.show()

In [None]:
def evaluate_model(model, test_loader, device, class_names):
    model.eval()
    true_labels = []
    pred_labels = []
    slice_importance_all = []  # Store slice importance values for all samples

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs, attn_weights = model(data)
            _, preds = torch.max(outputs, 1)
            
            # Calculate slice importance
            slice_importance = calculate_slice_importance_per_sample(attn_weights)
            slice_importance_all.append(slice_importance)
            
            # Print slice importance for debugging
            print(f"Slice Importance (Batch):\n{slice_importance.cpu().numpy()}")


            true_labels.extend(target.cpu().numpy())
            pred_labels.extend(preds.cpu().numpy())
            
            
            # Visualize attention maps for the first sample in this batch
            visualize_attention(data[0], attn_weights, '../outputs/mri/mri_attention_visualization.png')
            
            # Calculate slice importance for this batch
            batch_importance = calculate_slice_importance_per_sample(attn_weights)
            slice_importance_all.append(batch_importance)
    
    print(f"type slice importance all: {type(slice_importance_all[0])}")
    print(f"length of slice important all: {len(slice_importance_all)}")
    print(f"shape of slice importance all [0]: {slice_importance_all[0].shape}")
    print(f"shape of slice importance all [1]: {slice_importance_all[1].shape}")
    
    # Aggregate importance scores across the entire dataset
    slice_importance_all = torch.cat(slice_importance_all, dim=0)  # Combine all batches
    avg_slice_importance = slice_importance_all.mean(dim=0)  # Mean importance across all samples
    normalized_importance = normalize_importance(avg_slice_importance)

    # Print importance scores
    for i, score in enumerate(normalized_importance):
        print(f"Slice {i+1} Importance: {score.item():.4f}")

    # Plot slice importance
    plot_slice_importance(normalized_importance, '../outputs/mri/slice_importance.png')

    # Generate classification report and confusion matrix
    cr = classification_report(true_labels, pred_labels, target_names=class_names, output_dict=True)
    plot_confusion_matrix(true_labels, pred_labels, class_names, '../outputs/mri/mri_attn_confusion_matrix.png')
    
    print(f"Classification Report:\n{classification_report(true_labels, pred_labels, target_names=class_names)}")


In [70]:
# Main driver code
X_train, X_test, y_train, y_test = load_data()

train_dataset = MRIDataset(X_train, y_train)
test_dataset = MRIDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class_names = ["CN", "MCI", "AD"]


In [71]:
model = AttentionMRIModel()
train_model(model, train_loader, device, epochs=250)

Epoch 10/250, Loss: 0.9304105639457703
Epoch 20/250, Loss: 0.6941651701927185
Epoch 30/250, Loss: 0.4358139932155609
Epoch 40/250, Loss: 0.05342188477516174
Epoch 50/250, Loss: 0.07146324962377548
Epoch 60/250, Loss: 0.04892287775874138
Epoch 70/250, Loss: 0.007601427845656872
Epoch 80/250, Loss: 0.03689282760024071
Epoch 90/250, Loss: 0.0018331061583012342
Epoch 100/250, Loss: 0.02023421972990036
Epoch 110/250, Loss: 0.1129530817270279
Epoch 120/250, Loss: 0.009742413647472858
Epoch 130/250, Loss: 0.004372932482510805
Epoch 140/250, Loss: 0.02005944587290287
Epoch 150/250, Loss: 0.08206738531589508
Epoch 160/250, Loss: 0.0292510986328125
Epoch 170/250, Loss: 0.0014368274714797735
Epoch 180/250, Loss: 0.011356942355632782
Epoch 190/250, Loss: 0.013981279917061329
Epoch 200/250, Loss: 0.0009675322799012065
Epoch 210/250, Loss: 0.014309910126030445
Epoch 220/250, Loss: 7.463169458787888e-05
Epoch 230/250, Loss: 0.0005635399138554931
Epoch 240/250, Loss: 0.0037442066241055727
Epoch 250/25

In [72]:
evaluate_model(model, test_loader, device, class_names)

Slice Importance (Batch):
[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
Slice Importance (Batch):
[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
Slice Importance (Batch):
[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
Slice Importance (Batch):
[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
Slice Importance (Batch):
[[1.]
 [1.]
 [1.]
 [1.]
 [1.]
 [1.]]
type slice importance all: <class 'torch.Tensor'>
length of slice important all: 5
shape of slice importance all [0]: torch.Size([8, 1])
shape of slice importance all [1]: torch.Size([8, 1])
Slice 1 Average Importance: 1.0000


tensor([1.], device='cuda:0')