In [None]:
import os
import shutil

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms.v2 as transforms
from sklearn.model_selection import StratifiedGroupKFold
from torch.utils.data import Subset, DataLoader

from src.classes.Dataset import MRIDataset, MRISubset
from src.classes.Models import ResNet50variant
from src.config import PATH_TO_DATASET_CSV, PATH_TO_DATASET, PATH_TO_MODELS, PATH_TO_OUTPUT
from src.functions.train_eval import evaluate
from src.functions.utils_train import class_results

# Model Training and Evaluation

In [None]:
# Root directory of the image files; images are looked up under <input_path>/<class name>/<img_name>
input_path = PATH_TO_DATASET

# Metadata table: semicolon-separated, first row holds the column names
df = pd.read_csv(PATH_TO_DATASET_CSV, sep=';', header=0)

# Class names; the position in this list is the numeric class id
class_names = ['healthy', 'affected']

# class id -> class name
id2name = dict(enumerate(class_names))

# DataFrame index -> (full image path, label)
data = {
    idx: (
        os.path.join(input_path, id2name[row['label']], str(row['img_name'])),
        row['label'],
    )
    for idx, row in df.iterrows()
}

# Stratification targets and group ids, one entry per DataFrame row
y = df['label'].to_numpy()
groups = df['group'].to_numpy()

# Number of cross-validation folds
k_folds = 5

# Splitter that keeps the class balance per fold and never splits a group across folds;
# the fixed random_state makes the shuffle reproducible
sgkf = StratifiedGroupKFold(n_splits=k_folds, shuffle=True, random_state=7)

# Sample identifiers handed to the splitter (the keys of `data`)
X = np.array(list(data))

# Only the first of the k_folds splits is used as the train/test partition
train_index, test_index = next(iter(sgkf.split(X, y, groups)))

In [None]:
# Set device to GPU if available, otherwise fallback to CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define test data transformations (no augmentation, conversion steps only)
# NOTE(review): ToDtype sits between ToPILImage and ToTensor here — confirm this
# ordering is intended and that the resulting tensors have the expected dtype/range.
test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToDtype(torch.float32),
    transforms.ToTensor()
])

# Load the pre-trained weights; map_location remaps the stored tensors onto `device`
# NOTE(review): this cell neither calls model.to(device) nor model.eval() —
# presumably `evaluate` takes care of both; confirm.
path_to_model = os.path.join(PATH_TO_MODELS, "resnet50v.pth")
model = ResNet50variant()
model.load_state_dict(torch.load(path_to_model, map_location=device))

# Create the dataset from the index -> (path, label) dictionary built in the loading cell
dataset = MRIDataset(data)

# Define class names and mapping dictionary
# (same values as the definitions in the data-loading cell)
class_names = ['healthy', 'affected']
id2name = {idx: c for idx, c in enumerate(class_names)}

# Create the test dataset: restrict `dataset` to the held-out indices, then wrap it
# so that the test transformations are applied
test_dataset = MRISubset(Subset(dataset, test_index), train_bool=False, transform=test_transforms)

# Create a dataloader for the test set (batches of 32, default ordering)
dataloaders = {"test": DataLoader(test_dataset, batch_size=32)}

# Evaluate the model on the test set
res = evaluate(model, dataloaders["test"])

# Print the classification results
class_results(res)

## Analysis of CAM Methods

In [None]:
def cam_image(model, cam, img, target=None, transform=None, plot=False):
    """
    Builds the Class Activation Map overlay for a single image.

    The CAM is always computed on ``img`` as passed in; ``transform`` only affects
    the picture the heatmap is drawn on.

    :param model: Trained PyTorch model, used only when ``target`` is None.
    :param cam: CAM object, called as ``cam(input_tensor=..., targets=...)``.
    :param img: Image tensor without batch dimension (it is repeated 3x along the
        first axis for display, so a single-channel image is presumably expected).
    :param target: Class id to explain. If None, the prediction of ``evaluate_img`` is used.
    :param transform: Optional callable applied to ``img`` before building the RGB background.
    :param plot: If True, shows the background image and the overlay side by side.
    :return: The overlay produced by ``show_cam_on_image``.
    """
    batch = img.unsqueeze(0)

    # Fall back to the model's own prediction when no class was requested
    class_id = evaluate_img(model, batch).item() if target is None else target
    heatmap = cam(input_tensor=batch, targets=[ClassifierOutputTarget(class_id)])[0]

    background = transform(img) if transform else img

    # Channel-first tensor -> channel-last array with three identical channels
    rgb_img = background.repeat(3, 1, 1).numpy().transpose((1, 2, 0))
    cam_img = show_cam_on_image(rgb_img, heatmap, use_rgb=True)

    if plot:
        plt.figure(figsize=(10, 4))
        for position, panel in enumerate((rgb_img, cam_img), start=1):
            plt.subplot(1, 2, position)
            plt.imshow(panel)
            plt.axis('off')
        plt.show()

    return cam_img


def save_cam(test, data, output_path, model, cam):
    """
    Generates the CAM overlay of every sample in ``test`` and writes it to disk.

    Images are stored as ``<output_path>/<predicted class name>/<original file name>``,
    where the class name comes from the notebook-level ``id2name`` mapping.
    Any existing ``output_path`` directory is deleted first.

    :param test: Dataset wrapper; must support ``len``, integer indexing returning
        ``(img, label)`` and expose ``subset.indices``.
    :param data: Dictionary mapping dataset indices to ``(image path, label)``.
    :param output_path: Directory to save CAM images (recreated from scratch).
    :param model: Trained PyTorch model.
    :param cam: CAM object passed on to ``cam_image``.
    """
    # Start from an empty directory so overlays from a previous run cannot linger
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    # makedirs (not mkdir) so a missing parent directory does not abort the run
    os.makedirs(output_path)

    for i in range(len(test)):
        img, _ = test[i]

        # Map the position inside the subset back to the key used in `data`
        path, _ = data[test.subset.indices[i]]
        name = os.path.basename(path)

        pred = evaluate_img(model, img.unsqueeze(0)).item()
        cam_img = cam_image(model, cam, img, target=pred)

        # One sub-folder per predicted class
        final_output_path = os.path.join(output_path, id2name[pred])
        os.makedirs(final_output_path, exist_ok=True)

        # cam_image is asked for RGB output, while OpenCV writes BGR
        target_file = os.path.join(final_output_path, name)
        if not cv2.imwrite(target_file, cv2.cvtColor(cam_img, cv2.COLOR_RGB2BGR)):
            # imwrite reports failure through its return value instead of raising
            print(f'Could not write {target_file}')
    print('Save completed')


def save_cam_array(test, data, output_path, model, cam):
    """
    Saves the raw (grayscale) CAM of every sample in ``test`` as a ``.npy`` file.

    Files are written to ``<output_path>/<original file name>.npy``; the original
    file name is kept in full, including its extension. Any existing
    ``output_path`` directory is deleted first.

    :param test: Dataset wrapper; must support ``len``, integer indexing returning
        ``(img, label)`` and expose ``subset.indices``.
    :param data: Dictionary mapping dataset indices to ``(image path, label)``.
    :param output_path: Directory to save CAM numpy arrays (recreated from scratch).
    :param model: Trained PyTorch model.
    :param cam: CAM object, called as ``cam(input_tensor=..., targets=...)``.
    """
    # Start from an empty directory so arrays from a previous run cannot linger
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    # makedirs (not mkdir) so a missing parent directory does not abort the run
    os.makedirs(output_path)

    for i in range(len(test)):
        img, _ = test[i]

        # Map the position inside the subset back to the key used in `data`
        path, _ = data[test.subset.indices[i]]
        name = os.path.basename(path)

        # The same single-image batch serves the prediction and the CAM
        test_img = img.unsqueeze(0)
        pred = evaluate_img(model, test_img).item()

        # Explain the class the model actually predicted
        target = [ClassifierOutputTarget(pred)]
        grayscale_cam = cam(input_tensor=test_img, targets=target)[0]

        np.save(os.path.join(output_path, f'{name}.npy'), grayscale_cam)
    print('Save completed')

## GradCAM

In [None]:
# NOTE(review): GradCAM is not imported in the visible import cell — presumably it
# comes from the pytorch_grad_cam package; add the import at the top and confirm.
target_layers = [model.conv]
cam = GradCAM(model=model, target_layers=target_layers)

output_path = os.path.join(PATH_TO_OUTPUT, 'GradCAM')

# `test_dataset` is the MRISubset built in the evaluation cell (no variable named
# `test` is defined anywhere). save_cam has no return statement, so there is
# nothing to unpack here.
save_cam(test_dataset, data, output_path, model, cam)

## GradCAM++

In [None]:
# NOTE(review): GradCAMPlusPlus is not imported in the visible import cell —
# presumably it comes from the pytorch_grad_cam package; confirm.
target_layers = [model.conv]
cam = GradCAMPlusPlus(model=model, target_layers=target_layers)

output_path = os.path.join(PATH_TO_OUTPUT, 'GradCAMPlusPlus')

# `test_dataset` is the MRISubset built in the evaluation cell (no variable named
# `test` is defined anywhere).
save_cam(test_dataset, data, output_path, model, cam)

## HiResCAM

In [None]:
# NOTE(review): HiResCAM is not imported in the visible import cell — presumably it
# comes from the pytorch_grad_cam package; confirm.
target_layers = [model.conv]
cam = HiResCAM(model=model, target_layers=target_layers)

output_path = os.path.join(PATH_TO_OUTPUT, 'HiResCAM')

# `test_dataset` is the MRISubset built in the evaluation cell (no variable named
# `test` is defined anywhere).
save_cam(test_dataset, data, output_path, model, cam)

## Occlusion

In [None]:
def hierarchical_occlusion(model, image, target_class, stride, window_size, min_window_size, x_start, y_start, x_end,
                           y_end, zero=False):
    """
    Coarse-to-fine occlusion analysis of one image.

    A first pass slides a window of ``window_size`` over the whole region. Each
    following pass halves window and stride and only revisits the areas flagged by
    the previous pass. The per-pass maps are summed.

    :param model: The model used for evaluation.
    :param image: The image tensor to analyze.
    :param target_class: The target class, forwarded to ``class_occlusion``.
    :param stride: Stride of the sliding window in the first pass.
    :param window_size: Window size of the first pass.
    :param min_window_size: No further pass is started once the window is this small or smaller.
    :param x_start, y_start, x_end, y_end: Bounds of the region to scan.
    :param zero: Occlusion fill mode, forwarded to ``class_occlusion``.
    :return: Sum of the difference maps of all passes.
    """
    region = (x_start, y_start, x_end, y_end)
    step, win = stride, window_size

    # Coarsest pass: passing None as previous areas makes every position eligible
    total_map, hot_areas = class_occlusion(model, image, target_class, step, win, *region, None, zero)

    # Refine while the last pass flagged something and the window can still shrink
    while hot_areas and win > min_window_size:
        step = max(step // 2, 1)  # the stride never drops below one pixel
        win //= 2
        level_map, hot_areas = class_occlusion(model, image, target_class, step, win, *region, hot_areas, zero)
        total_map += level_map

    return total_map


def class_occlusion(model, image, target_class, stride, window_size, x_start, y_start, x_end, y_end, old_areas, zero):
    """
    Slides an occlusion window over a region of the image and records where hiding
    pixels changes the model's prediction.

    For every eligible window position the window is overwritten with a constant,
    the image is re-evaluated, and the absolute difference between the original and
    the occluded prediction is written into the map (keeping the maximum where
    windows overlap).

    :param model: The model to evaluate.
    :param image: The input image tensor, indexed as (channel, row, column).
    :param target_class: Class under analysis; only used to pick the fill value when ``zero`` is False.
    :param stride: The stride for the sliding window (values below 1 are treated as 1).
    :param window_size: The size of the occlusion window.
    :param x_start, y_start, x_end, y_end: The coordinates defining the region of interest.
    :param old_areas: ``(y, x)`` corners flagged by the previous, coarser pass, or None
        to visit every position. Only windows starting inside one of these areas
        (taken with twice the current window size) are evaluated.
    :param zero: If True the window is filled with 0. Otherwise it is filled with 1,
        except for ``target_class == 1`` where 0 is used.
    :return: ``(diff_map, new_areas)`` - an (H, W) float array and the list of
        ``(y, x)`` window corners whose difference was exactly 1.
    """
    # .item() yields a plain Python number, so the arithmetic below never mixes
    # torch tensors with NumPy arrays (which fails for CUDA or grad-tracking outputs).
    original_label = evaluate_img(model, image.unsqueeze(0)).item()

    _, H, W = image.shape

    diff_map = np.zeros((H, W), dtype=float)
    new_areas = []

    step = max(stride, 1)

    # The fill value does not depend on the window position
    fill_value = 0 if (zero or target_class == 1) else 1

    for i in range(y_start, y_end, step):
        y_occlusion_end = min(y_end, i + window_size)  # keep the window inside the region
        for j in range(x_start, x_end, step):
            # Skip positions the previous pass did not flag
            if not is_inside(i, j, old_areas, window_size * 2):
                continue

            x_occlusion_end = min(x_end, j + window_size)

            # Occlude a copy so the caller's image stays untouched
            occluded_image = image.clone()
            occluded_image[:, i:y_occlusion_end, j:x_occlusion_end] = fill_value

            occluded_label = evaluate_img(model, occluded_image.unsqueeze(0)).item()

            # Keep the strongest change seen so far for each pixel of the window
            diff = abs(original_label - occluded_label)
            diff_map[i:y_occlusion_end, j:x_occlusion_end] = np.maximum(
                diff, diff_map[i:y_occlusion_end, j:x_occlusion_end])

            # A difference of exactly 1 marks the window for the next, finer pass
            if diff == 1:
                new_areas.append((i, j))

    return diff_map, new_areas


def is_inside(y, x, areas, window_size):
    """
    Tells whether the point (y, x) falls into one of the given square areas.

    Each area is described by its top-left corner ``(ay, ax)`` and spans
    ``window_size`` pixels in both directions; the far edges are exclusive.

    :param y, x: The position to check.
    :param areas: List of ``(ay, ax)`` corners, or None to accept every position.
    :param window_size: Side length of each area.
    :return: True if ``areas`` is None or the point lies in at least one area, False otherwise.
    """
    # No restriction given (first pass): everything counts as inside
    if areas is None:
        return True

    return any(
        top <= y < top + window_size and left <= x < left + window_size
        for top, left in areas
    )

In [None]:
def apply_occlusion_heatmap(model, test_data, data):
    """
    Displays the stored occlusion heatmap of every test image and prints summary counts.

    For each sample the prediction is computed, the matching pre-computed map is
    loaded through ``load_occlusion_map`` (variant "Resnet18v"; suffix "" for
    prediction 1, "-2" otherwise) and drawn over the image. A map without any
    value >= 1 is replaced by an all-zero map of the image's size.

    :param model: Trained PyTorch model used for the predictions.
    :param test_data: Dataset wrapper; must support ``len``, integer indexing
        returning ``(img, label)`` and expose ``subset.indices``.
    :param data: Dictionary mapping dataset indices to ``(image path, label)``.
    """
    # p1/p0: number of samples predicted 1/0;
    # count_*: how many of them have a map with at least one value >= 1
    stats = {"p1": 0, "p0": 0, "count_affected": 0, "count_healthy": 0}

    for i in range(len(test_data)):
        img, _ = test_data[i]
        path, _ = data[test_data.subset.indices[i]]
        # basename follows the platform's path separator; splitting on a literal
        # backslash only works for Windows-style paths
        name = os.path.basename(path)

        _, H, W = img.shape
        pred = evaluate_img(model, img.unsqueeze(0)).item()

        # Channel-first tensor -> channel-last array with three identical channels
        rgb_img = img.repeat(3, 1, 1).numpy().transpose((1, 2, 0))

        occlusion_map = load_occlusion_map(name, "Resnet18v", suffix="" if pred == 1 else "-2")
        if np.any(occlusion_map >= 1):
            stats["count_affected" if pred == 1 else "count_healthy"] += 1
        else:
            # Nothing significant stored: show a blank map matching the image size
            occlusion_map = np.zeros((H, W), dtype=float)

        stats["p1" if pred == 1 else "p0"] += 1
        display_heatmap(rgb_img, occlusion_map, name)

    print(
        f"Pred 1: {stats['p1']}, Pred 0: {stats['p0']}, Count 1: {stats['count_affected']}, Count 0: {stats['count_healthy']}")


def hierarchical_occlusion_with_params(model, test_data, data):
    """
    Displays an occlusion heatmap for every test image, computing it when no stored map exists.

    Works like ``apply_occlusion_heatmap`` but obtains the map through
    ``load_or_compute_occlusion``, which falls back to a fresh hierarchical
    occlusion run when the ``.npy`` file is missing.

    :param model: Trained PyTorch model used for the predictions.
        NOTE(review): the fallback computation inside ``load_or_compute_occlusion``
        reads the notebook-level ``model`` variable, not this argument.
    :param test_data: Dataset wrapper; must support ``len``, integer indexing
        returning ``(img, label)`` and expose ``subset.indices``.
    :param data: Dictionary mapping dataset indices to ``(image path, label)``.
    """
    # p1/p0: number of samples predicted 1/0;
    # count_*: how many of them have a map with at least one value >= 1
    stats = {"p1": 0, "p0": 0, "count_affected": 0, "count_healthy": 0}

    for i in range(len(test_data)):
        img, _ = test_data[i]
        path, _ = data[test_data.subset.indices[i]]
        # basename follows the platform's path separator; splitting on a literal
        # backslash only works for Windows-style paths
        name = os.path.basename(path)

        _, H, W = img.shape
        pred = evaluate_img(model, img.unsqueeze(0)).item()

        # Channel-first tensor -> channel-last array with three identical channels
        rgb_img = img.repeat(3, 1, 1).numpy().transpose((1, 2, 0))

        # Signature is (name, model_variant, img, pred, H, W, suffix); it builds its
        # path from PATH_TO_MODELS itself, so that constant must not be passed in
        occlusion_map = load_or_compute_occlusion(name, "Resnet18v", img, pred, H, W,
                                                  suffix="" if pred == 1 else "-2")
        if np.any(occlusion_map >= 1):
            stats["count_affected" if pred == 1 else "count_healthy"] += 1

        stats["p1" if pred == 1 else "p0"] += 1
        display_heatmap(rgb_img, occlusion_map, name)

    print(
        f"Pred 1: {stats['p1']}, Pred 0: {stats['p0']}, Count 1: {stats['count_affected']}, Count 0: {stats['count_healthy']}")


def load_occlusion_map(name, model_variant, suffix=""):
    """
    Loads a stored occlusion map and scales it with ``normalize``.

    The file is expected at ``PATH_TO_MODELS/<model_variant>/<name><suffix>.npy``.

    :param name: Image file name the map belongs to.
    :param model_variant: Sub-directory of ``PATH_TO_MODELS`` holding the maps.
    :param suffix: Text inserted between ``name`` and the ``.npy`` extension.
    :return: The normalized map, or a 256x256 array of zeros when the file does not exist.
    """
    map_path = os.path.join(PATH_TO_MODELS, model_variant, f"{name}{suffix}.npy")

    # Missing file: hand back a blank map instead of failing
    if not os.path.exists(map_path):
        return np.zeros((256, 256), dtype=float)

    stored_map = np.load(map_path)
    return normalize(stored_map)


def load_or_compute_occlusion(name, model_variant, img, pred, H, W, suffix=""):
    """
    Returns the normalized occlusion map of an image, from disk if available, otherwise computed.

    The stored map is looked up at
    ``PATH_TO_MODELS/<"Occlusion2" if pred == 1 else "Occlusion4">/<model_variant>/<name><suffix>.npy``.
    When it is missing, ``hierarchical_occlusion`` is run over the full image with
    zero-filled windows (stride 28, window 112 shrinking down to 7). The computed
    map is not written back to disk.

    NOTE(review): ``model`` is not a parameter of this function; the computation
    uses the notebook-level ``model`` variable. Confirm that is the intended model.

    :param name: Image file name the map belongs to.
    :param model_variant: Sub-directory holding the maps of one model variant.
    :param img: Image tensor to analyze when no stored map exists.
    :param pred: Predicted class of the image; selects the directory and is the
        target class of the computation.
    :param H, W: Image height and width, used as the bounds of the scanned region.
    :param suffix: Text inserted between ``name`` and the ``.npy`` extension.
    :return: Occlusion map scaled by ``normalize``.
    """
    map_path = os.path.join(PATH_TO_MODELS, "Occlusion2" if pred == 1 else "Occlusion4", model_variant,
                            f"{name}{suffix}.npy")
    if os.path.exists(map_path):
        return normalize(np.load(map_path))

    computed_map = hierarchical_occlusion(model, img, pred, stride=28, window_size=112, min_window_size=7, x_start=0,
                                          y_start=0, x_end=W, y_end=H, zero=True)
    # The hierarchical result is a sum over several passes and can exceed 1.
    # Normalize it like the cached maps so both paths return the same scale.
    return normalize(computed_map)


def display_heatmap(rgb_img, occlusion_map, name):
    """
    Shows an image with its occlusion map blended on top as a colour heatmap.

    :param rgb_img: Background image as an (H, W, 3) array. A floating-point image
        is assumed to be in the [0, 1] range — TODO confirm against the dataset transforms.
    :param occlusion_map: (H, W) map; values are expected in [0, 1] and are clipped to that range.
    :param name: Text used as the figure title.
    """
    # Clip before the 8-bit cast: values above 1 would otherwise wrap around
    intensity = np.uint8(255 * np.clip(occlusion_map, 0, 1))
    heatmap = cv2.applyColorMap(intensity, cv2.COLORMAP_PARULA)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # addWeighted needs both inputs in the same dtype (no output dtype is passed).
    # For a float background, bring the 8-bit heatmap to that dtype and to [0, 1].
    if np.issubdtype(rgb_img.dtype, np.floating):
        heatmap = (heatmap / 255.0).astype(rgb_img.dtype)

    # The caller typically passes a transposed view; give OpenCV a contiguous buffer
    background = np.ascontiguousarray(rgb_img)

    final_img = cv2.addWeighted(background, 1, heatmap, 0.5, 0)
    final_img = normalize(final_img)
    plt.imshow(final_img)
    plt.title(name)
    plt.colorbar()
    plt.show()


def normalize(image):
    """
    Scales an array so that its maximum becomes 1.

    :param image: Array-like of numbers.
    :return: ``image / max(image)`` when the maximum is positive; otherwise
        (all values <= 0, or an empty array) the input is returned unchanged.
    """
    # np.max raises on empty input, so there is nothing to scale in that case
    if np.size(image) == 0:
        return image

    peak = np.max(image)
    return image / peak if peak > 0 else image

In [None]:
# Load both trained variants; weights are mapped to the CPU
# NOTE(review): ResNet18variant is not imported in the visible import cell —
# presumably it lives next to ResNet50variant in src.classes.Models; confirm.
m18 = ResNet18variant()
m18.load_state_dict(torch.load(os.path.join(PATH_TO_MODELS, "resnet18v.pth"), map_location=torch.device('cpu')))

m50 = ResNet50variant()
m50.load_state_dict(torch.load(os.path.join(PATH_TO_MODELS, "resnet50v.pth"), map_location=torch.device('cpu')))

# Choose the model (ResNet18 in this case). This rebinds the notebook-level
# `model`, which load_or_compute_occlusion also reads when it has to compute a map.
model = m18

# `test_dataset` is the MRISubset built in the evaluation cell (no variable named
# `test` is defined anywhere).

# Apply occlusion heatmap to test data
apply_occlusion_heatmap(model, test_dataset, data)

# Apply hierarchical occlusion with different parameters to test data
hierarchical_occlusion_with_params(model, test_dataset, data)