In [None]:
import os

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import StratifiedGroupKFold
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

from src.classes.dataset import MRIDataset, MRISubset
from src.classes.models import ResNet50variant
from src.classes.occlusion import OcclusionMapGenerator, OcclusionHeatmap
from src.config import PATH_TO_MODELS, PATH_TO_DATASET_CSV, PATH_TO_DATASET

In [None]:
# Path to the trained ResNet50-variant checkpoint (a state_dict, see load_state_dict below)
MODEL_PATH = os.path.join(PATH_TO_MODELS, "resnet50v.pth")

# Pick the first GPU when available, otherwise fall back to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ResNet50variant().to(device)

# The checkpoint is only used as a state_dict, so restrict unpickling to tensors and
# plain containers: weights_only=True avoids executing arbitrary code embedded in the file.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # Evaluation mode: affects layers such as dropout / batch-norm, if present

# ----------------------------- Load Test Dataset -----------------------------

# Metadata table: semicolon-separated, first line is the header
df = pd.read_csv(PATH_TO_DATASET_CSV, sep=';', header=0)

# Integer label -> class name; the class name is also the image sub-directory
CLASS_NAMES = ['healthy', 'affected']
ID_TO_NAME = dict(enumerate(CLASS_NAMES))

# Row index -> (absolute image path, integer label)
data = dict(
    (
        row_idx,
        (
            os.path.join(PATH_TO_DATASET, ID_TO_NAME[row['label']], str(row['img_name'])),
            row['label'],
        ),
    )
    for row_idx, row in df.iterrows()
)

# Grouping key and stratification target, as plain numpy arrays for the splitter
groups = df['group'].to_numpy()
y = df['label'].to_numpy()

# ----------------------------- Create Train-Test Split -----------------------------

# Stratified (by label) and grouped (by `group`) 5-fold splitter with a fixed seed;
# no group ends up on both sides of a split.
sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=7)

# The splitter only needs a sample array of the right length: use the image indices
X = np.array(list(data))

# Only the first fold is used, i.e. roughly an 80/20 train/test partition
train_index, test_index = next(sgkf.split(X, y=y, groups=groups))

# Full dataset over every image in `data`; the split above selects the test part
dataset = MRIDataset(data)

# Test-time preprocessing: raw image -> PIL image -> float32 tensor.
# NOTE: `transforms.ToDtype` only exists in `torchvision.transforms.v2`; on the v1
# `transforms` module imported in this notebook it raises AttributeError (and it was
# placed before ToTensor, where it would have received a PIL image). ToTensor already
# yields float32 in [0, 1] for 8-bit images; ConvertImageDtype makes the float32
# guarantee explicit for other input dtypes and is a no-op for float32 input.
test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32),
])

# Test subset (positional indices from the first fold) with the test-time transforms
test_dataset = MRISubset(Subset(dataset, test_index), train_bool=False, transform=test_transforms)

# Deterministic order for evaluation
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# ----------------------------- Apply Occlusion Analysis -----------------------------

# Instantiate occlusion-related classes
occlusion_map_generator = OcclusionMapGenerator(model, device)
occlusion_heatmap = OcclusionHeatmap(model, test_dataset, data)

# Parameters for the single-image hierarchical occlusion demo.
# The target class is looked up by name so it stays correct if CLASS_NAMES is reordered.
TARGET_CLASS = CLASS_NAMES.index('affected')
OCCLUSION_STRIDE = 28
OCCLUSION_WINDOW_SIZE = 112
OCCLUSION_MIN_WINDOW_SIZE = 7
IMG_SIZE = 256  # NOTE(review): assumes test images are 256x256 — confirm against the dataset

# Run hierarchical occlusion on the first test image over its full extent
img, _ = test_dataset[0]
output_map = occlusion_map_generator.hierarchical_occlusion(
    image=img, target_class=TARGET_CLASS, stride=OCCLUSION_STRIDE,
    window_size=OCCLUSION_WINDOW_SIZE, min_window_size=OCCLUSION_MIN_WINDOW_SIZE,
    x_start=0, y_start=0, x_end=IMG_SIZE, y_end=IMG_SIZE,
)

# Generate occlusion heatmaps (presumably over the test_dataset passed above — confirm)
occlusion_heatmap.apply_occlusion_heatmap()
