In [None]:
import os

import numpy as np
import pandas as pd
import torch
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, HiResCAM
from sklearn.model_selection import StratifiedGroupKFold
from torch.utils.data import Subset
from torchvision import transforms

from src.classes.cam import CAMGenerator
from src.classes.dataset import MRIDataset, MRISubset
from src.classes.models import ResNet50variant
from src.config import PATH_TO_MODELS, PATH_TO_DATASET_CSV, PATH_TO_DATASET, PATH_TO_OUTPUT

In [None]:
# Load the trained classifier from its checkpoint and put it in inference mode.
model_checkpoint = os.path.join(PATH_TO_MODELS, "resnet50v.pth")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Instantiate the architecture, then restore the trained weights.
model = ResNet50variant().to(device)
# The checkpoint is consumed by load_state_dict, i.e. it is a plain state_dict of
# tensors; weights_only=True is sufficient and refuses to unpickle arbitrary objects.
state_dict = torch.load(model_checkpoint, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # switch layers such as dropout / batch-norm to inference behaviour

# Read the dataset index: one row per image, semicolon-separated, header on the first line.
df = pd.read_csv(PATH_TO_DATASET_CSV, sep=';', header=0)

# Integer label -> class folder name.
CLASS_NAMES = ['healthy', 'affected']
ID_TO_NAME = dict(enumerate(CLASS_NAMES))


def _image_entry(row):
    """Return the (image path, label) pair for one CSV row.

    The image is located at ``PATH_TO_DATASET/<class name>/<img_name>``.
    """
    label = row['label']
    class_dir = ID_TO_NAME[label]
    return os.path.join(PATH_TO_DATASET, class_dir, str(row['img_name'])), label


# DataFrame index -> (image path, label)
data_mapping = {row_idx: _image_entry(row) for row_idx, row in df.iterrows()}

# Group-aware, stratified split: each group is kept wholly inside either the train
# or the test side, while the label distribution is kept as balanced as possible.
X = np.array(list(data_mapping))  # image indices, in CSV row order
y = df['label'].to_numpy()
groups = df['group'].to_numpy()

sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=7)

# Only the first of the 5 folds is used, giving roughly an 80/20 train/test split.
fold_iter = sgkf.split(X, y, groups)
train_index, test_index = next(fold_iter)

# Wrap every image, then restrict to the held-out fold produced by the split above.
# NOTE(review): assumes MRIDataset is indexed by position in data_mapping's key
# order, matching the positional test_index — confirm against MRIDataset.
dataset = MRIDataset(data_mapping)

# Test-time preprocessing (no augmentation).
# `transforms.ToDtype` exists only in `torchvision.transforms.v2`; in the v1
# `transforms` module imported here it raises AttributeError. It is also not
# needed: ToTensor already returns a float32 tensor scaled to [0, 1] for 8-bit images.
test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
])
# presumably train_bool=False disables training-only behaviour in MRISubset — TODO confirm
test_dataset = MRISubset(Subset(dataset, test_index), train_bool=False, transform=test_transforms)

# CAM methods to run; each key also names the output sub-directory.
CAM_METHODS = {
    "GradCAM": GradCAM,
    "GradCAMPlusPlus": GradCAMPlusPlus,
    "HiResCAM": HiResCAM
}

# NOTE(review): assumes `model.conv` is the feature layer to explain — adjust to the
# model architecture if needed.
target_layers = [model.conv]

for cam_name, cam_method in CAM_METHODS.items():
    print(f"Generating {cam_name} visualizations...")

    # One output directory per CAM method.
    cam_output_path = os.path.join(PATH_TO_OUTPUT, cam_name)
    os.makedirs(cam_output_path, exist_ok=True)

    # pytorch_grad_cam objects register hooks on target_layers when constructed.
    cam = cam_method(model=model, target_layers=target_layers)
    try:
        cam_generator = CAMGenerator(model, cam, output_path=cam_output_path, class_mapping=ID_TO_NAME)
        cam_generator.save_cam_images(test_dataset, data_mapping=data_mapping, as_numpy=False)
    finally:
        # Detach this method's hooks before the next method attaches to the same
        # layers, even if generation fails. (The library's `with` form also releases
        # them, but its __exit__ silently swallows IndexError.)
        cam.activations_and_grads.release()

print("All CAM visualizations generated successfully.")