In [None]:
import json
import os
import shutil

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.v2 as transforms
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import StratifiedGroupKFold
from torch.utils.data import Subset, DataLoader

from src.classes.Dataset import MRIDataset, MRISubset
from src.classes.Models import ResNet50variant, ResNet18variant
from src.config import PATH_TO_DATASET, PATH_TO_DATASET_CSV, PATH_TO_OUTPUT
from src.functions.train_eval import train_model, train_final_model, evaluate, evaluate_img
from src.functions.utils_train import oversampling, plot_results, class_results

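# Model Training and Evaluation

The data are split by group into a training set and a held-out test set. Group k-fold cross-validation on the training set selects the number of training epochs; the final model is then trained on the whole training set and evaluated once on the test set.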

In [None]:
# Define input paths
input_path = PATH_TO_DATASET
df = pd.read_csv(PATH_TO_DATASET_CSV, sep=';', header=0)

# Define class names and mapping
class_names = ['healthy', 'affected']
id2name = {idx: c for idx, c in enumerate(class_names)}

# Generate dataset mapping: sample index -> (image path, label).
# Image paths are built as <input_path>/<class name>/<img_name>.
data = {idx: (os.path.join(input_path, id2name[row['label']], str(row['img_name'])), row['label']) for idx, row in
        df.iterrows()}

# Convert labels and groups to numpy arrays
y = df['label'].to_numpy()
groups = df['group'].to_numpy()

# Set up cross-validation with stratified group splitting
k_folds = 5
sgkf = StratifiedGroupKFold(n_splits=k_folds, shuffle=True, random_state=7)
X = np.array(list(data.keys()))

# Later cells index `y` with values taken from X (e.g. y[train_index_resampled]),
# which is only valid when the keys are exactly 0..n-1.
assert np.array_equal(X, np.arange(len(df))), "dataset keys must be the range 0..n-1"

# Perform train-test split: the first fold (~1/k_folds of the data) becomes the held-out test set
train_index, test_index = next(sgkf.split(X, y, groups))

# A group appearing on both sides would leak information from training into the test set.
assert set(groups[train_index]).isdisjoint(groups[test_index]), "train and test groups overlap"

# Print dataset distribution information
print(f"Train groups: {set(groups[train_index])}, Test groups: {set(groups[test_index])}")
print(f"Dataset size: {len(data)}")
print(
    f"Train size: {len(train_index)}, Affected: {np.count_nonzero(y[train_index] == 1)}, Healthy: {np.count_nonzero(y[train_index] == 0)}")
print(
    f"Test size: {len(test_index)}, Affected: {np.count_nonzero(y[test_index] == 1)}, Healthy: {np.count_nonzero(y[test_index] == 0)}")


In [None]:
# Define training parameters
k_folds = 10
learning_rate = 0.001
momentum = 0.9
epochs = 100
patience = 5
delta = 0.006
title = 'train-resnet50v'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define transformations
# NOTE(review): ToDtype runs on the PIL image produced by ToPILImage, i.e. before ToTensor;
# confirm it has the intended effect (ToTensor is what produces the tensor fed to the model).
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.GaussianBlur(5, sigma=(0.1, 0.5)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.ToDtype(torch.float32),
    transforms.ToTensor()
])

test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToDtype(torch.float32),
    transforms.ToTensor()
])

# Load dataset
dataset = MRIDataset(data)

# Define class mapping
class_names = ['healthy', 'affected']
id2name = {idx: c for idx, c in enumerate(class_names)}
num_classes = 2

# Define loss function
criterion = torch.nn.CrossEntropyLoss()

# Per-fold metrics
accuracy, precision, recall, f1, training_epochs = [], [], [], [], []

# Cross-validation only ever sees the training portion; the held-out test set stays untouched.
X_train, y_train, groups_train = X[train_index], y[train_index], groups[train_index]

# Set up Stratified Group K-Fold cross-validation
sgkf = StratifiedGroupKFold(n_splits=k_folds, shuffle=True, random_state=1)
figure, axis = plt.subplots(k_folds, 2, figsize=(20, 30))
figure.tight_layout(pad=5.0)

# Perform cross-validation
for fold, (new_train_index, valid_index) in enumerate(sgkf.split(X_train, y_train, groups_train)):
    # Oversample this fold's training split only (sampling_strategy=1 presumably targets
    # balanced classes); the validation split keeps its real class distribution.
    new_train_index_resampled, new_train_y_resampled = oversampling(
        X_train[new_train_index], y_train[new_train_index], seed=fold, sampling_strategy=1
    )

    y_valid = y_train[valid_index]
    train_affected = np.count_nonzero(new_train_y_resampled == 1)
    train_healthy = np.count_nonzero(new_train_y_resampled == 0)
    valid_affected = np.count_nonzero(y_valid == 1)
    valid_healthy = np.count_nonzero(y_valid == 0)

    print(f'\n--------------------------------\nFOLD {fold}\n--------------------------------')
    print(
        f'Training set: {len(new_train_index_resampled)}, Validation set: {len(valid_index)}, Test set: {len(test_index)}')
    print(
        f'Training + Validation - Affected: {valid_affected + train_affected}, Healthy: {valid_healthy + train_healthy}')
    print(
        f'Training set - Affected: {train_affected}, Healthy: {train_healthy}')
    print(
        f'Validation set - Affected: {valid_affected}, Healthy: {valid_healthy}')
    print(
        f'Test set - Affected: {np.count_nonzero(y[test_index] == 1)}, Healthy: {np.count_nonzero(y[test_index] == 0)}')

    # Prepare datasets and dataloaders
    train = MRISubset(Subset(dataset, new_train_index_resampled), train_bool=True, transform=train_transforms)
    valid = MRISubset(Subset(dataset, X_train[valid_index]), train_bool=False, transform=test_transforms)
    datasets = {'train': train, 'valid': valid}
    # Only the training split needs shuffling; a fixed validation order keeps evaluation reproducible.
    dataloaders = {
        'train': DataLoader(train, batch_size=32, shuffle=True),
        'valid': DataLoader(valid, batch_size=32, shuffle=False),
    }

    # Initialize and train the model
    model = ResNet50variant().to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    # NOTE(review): `dataloaders` is not passed to train_model; confirm how it obtains this fold's data.
    model, res, history = train_model(fold, model, criterion, optimizer, epochs, patience, delta, title)
    plot_results(history, axis, fold)

    # Store results (res holds validation labels/predictions as returned by train_model)
    accuracy.append(accuracy_score(res['labels'], res['preds']))
    precision.append(precision_score(res['labels'], res['preds']))
    recall.append(recall_score(res['labels'], res['preds']))
    f1.append(f1_score(res['labels'], res['preds']))
    training_epochs.append(history['epochs'])

# Print cross-validation results
print(
    f'\n--------------------------------\nK-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS\n--------------------------------')
for fold in range(k_folds):
    print(
        f'Fold {fold}: Accuracy: {accuracy[fold]}, Precision: {precision[fold]}, Recall: {recall[fold]}, F1: {f1[fold]}')

print(
    f'Average accuracy: {np.mean(accuracy)}, Average precision: {np.mean(precision)}, Average recall: {np.mean(recall)}, Average F1: {np.mean(f1)}')
print(f'Average training epochs: {np.mean(training_epochs)}')
print(
    f'Variance of accuracy: {np.var(accuracy)}, Variance of precision: {np.var(precision)}, Variance of recall: {np.var(recall)}, Variance of F1: {np.var(f1)}')

# Save results; create the directory first so saving does not depend on train_model having made it.
output_dir = f'{input_path}/models/{title}'
os.makedirs(output_dir, exist_ok=True)
figure.savefig(f'{output_dir}/plot.png')

results = {
    'title': title,
    'learning rate': learning_rate,
    'momentum': momentum,
    'patience': patience,
    'delta': delta,
    'final': {'accuracy': np.mean(accuracy), 'precision': np.mean(precision), 'recall': np.mean(recall),
              'f1': np.mean(f1), 'avg epochs': np.mean(training_epochs)},
    'variance': {'accuracy': np.var(accuracy), 'precision': np.var(precision), 'recall': np.var(recall),
                 'f1': np.var(f1)}
}
for fold in range(k_folds):
    results[str(fold)] = {'accuracy': accuracy[fold], 'precision': precision[fold], 'recall': recall[fold],
                          'f1': f1[fold]}

with open(f"{output_dir}/results.json", "w") as outfile:
    json.dump(results, outfile)

# Mean early-stopping epoch across folds; used as the epoch budget of the final model
avg_epochs = np.mean(training_epochs)


In [None]:
# Final model: retrain on the whole training portion and evaluate once on the held-out test set
learning_rate = 0.001
momentum = 0.9
# Epoch budget = mean early-stopping epoch from cross-validation. int() guarantees a plain
# Python int whatever type round() yields for the numpy float avg_epochs.
epochs = int(round(avg_epochs))  # Hyperparameter

title = 'final-resnet50v'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Data Augmentation for Training and Testing
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.GaussianBlur(5, sigma=(0.1, 0.5)),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.ToDtype(torch.float32),
    transforms.ToTensor()
])

test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToDtype(torch.float32),
    transforms.ToTensor()
])

dataset = MRIDataset(data)

# Mapping class labels to names
class_names = ['healthy', 'affected']
id2name = {idx: c for idx, c in enumerate(class_names)}

num_classes = 2
criterion = nn.CrossEntropyLoss()

# Oversampling (whole training portion; the test set is never resampled)
train_index_resampled, train_y_resampled = oversampling(X[train_index], y[train_index], seed=0, sampling_strategy=1)

# Dataset Information
print('Dataset length:', len(data))
print('Training set length:', len(train_index))
print('Affected cases in training set:', np.count_nonzero(y[train_index] == 1))
print('Healthy cases in training set:', np.count_nonzero(y[train_index] == 0), '\n')

print('Resampled training set length:', len(train_index_resampled))
# Count from the labels returned by oversampling, as the cross-validation cell does.
print('Affected cases in resampled training set:', np.count_nonzero(train_y_resampled == 1))
print('Healthy cases in resampled training set:', np.count_nonzero(train_y_resampled == 0), '\n')

print('Test set length:', len(test_index))
print('Affected cases in test set:', np.count_nonzero(y[test_index] == 1))
print('Healthy cases in test set:', np.count_nonzero(y[test_index] == 0), '\n')

# DataLoader Setup
train = MRISubset(Subset(dataset, train_index_resampled), train_bool=True, transform=train_transforms)
test = MRISubset(Subset(dataset, test_index), train_bool=False, transform=test_transforms)

datasets = {'train': train, 'test': test}
dataset_sizes = {x: len(datasets[x]) for x in ['train', 'test']}

# Shuffle only the training data; the test loader keeps a fixed order.
dataloaders = {
    'train': DataLoader(train, batch_size=32, shuffle=True),
    'test': DataLoader(test, batch_size=32, shuffle=False),
}

# Model Setup
resnet50v = ResNet50variant().to(device)

optimizer = optim.SGD(resnet50v.parameters(), lr=learning_rate, momentum=momentum)

# Training the Model
# NOTE(review): `dataloaders` is not passed to train_final_model; confirm how it obtains the training data.
resnet50v = train_final_model(resnet50v, criterion, optimizer, epochs, title)

# Model Evaluation
res = evaluate(resnet50v, dataloaders['test'])
class_results(res)

# Compute the test metrics once and reuse them for printing and saving
stats = {
    'accuracy': accuracy_score(res["labels"], res["preds"]),
    'precision': precision_score(res["labels"], res["preds"]),
    'recall': recall_score(res["labels"], res["preds"]),
    'f1': f1_score(res["labels"], res["preds"]),
}

# Print Evaluation Metrics
print()
print(f'Accuracy: {stats["accuracy"]}')
print(f'Precision: {stats["precision"]}')
print(f'Recall: {stats["recall"]}')
print(f'F1: {stats["f1"]}')

# Save Results to File
results = {
    'title': title,
    'learning rate': learning_rate,
    'momentum': momentum,
    'epochs': epochs,
    'final': stats,
}

# Create the directory first so saving does not depend on train_final_model having made it.
output_dir = f"{input_path}/models/{title}"
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/results.json", "w") as outfile:
    json.dump(results, outfile)

---


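## SHAP

Compute SHAP values for every test image with the saved final ResNet50 variant. The explainer uses a blur masker and explains the top-ranked output class. For each image, one plot is saved in a folder named after the predicted class, and the raw SHAP values are saved as a `.npy` array.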

In [None]:
# Prediction function
def my_predict(img: np.ndarray) -> torch.Tensor:
    """
    Forward pass in the form the SHAP explainer calls: a batch of images in, model outputs out.

    Reads the module-level ``model`` and ``device`` (and ``nhwc_to_nchw``); ``model`` must be
    assigned before the explainer runs.

    :param img: Batch of images, expected channels-last (the layout given to the SHAP image masker).
    :return: Raw model outputs for the batch, on ``device`` (presumably logits, since training
        uses CrossEntropyLoss — confirm against the model definition).
    """
    model.eval()
    img = nhwc_to_nchw(torch.Tensor(img))  # Convert NHWC to NCHW format
    img = img.to(device)
    # Pure inference: do not build an autograd graph for every masked batch SHAP evaluates.
    with torch.no_grad():
        output = model(img)
    return output

# Conversion from NCHW to NHWC
def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    """
    Move the channel axis last: (N, C, H, W) -> (N, H, W, C), or (C, H, W) -> (H, W, C).

    A tensor whose last dimension already has size 3 is treated as channels-last and returned
    unchanged, as is any tensor that is neither 3-D nor 4-D.
    """
    rank = x.dim()
    if rank not in (3, 4):
        return x
    if x.shape[-1] == 3:
        return x
    order = (0, 2, 3, 1) if rank == 4 else (1, 2, 0)
    return x.permute(*order)

# Conversion from NHWC to NCHW
def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    """
    Move the channel axis first: (N, H, W, C) -> (N, C, H, W), or (H, W, C) -> (C, H, W).

    A tensor whose channel position (third from the end) already has size 3 is treated as
    channels-first and returned unchanged, as is any tensor that is neither 3-D nor 4-D.
    """
    rank = x.dim()
    if rank not in (3, 4):
        return x
    if x.shape[-3] == 3:
        return x
    order = (0, 3, 1, 2) if rank == 4 else (2, 0, 1)
    return x.permute(*order)

In [None]:
# Hyperparameters
topk = 1  # number of top-ranked output classes to explain per image
batch_size = 50
n_evals = 5000

input_path = PATH_TO_DATASET
output_path = PATH_TO_OUTPUT

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define test transformations
test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToDtype(torch.float32),
    transforms.ToTensor()
])

# Dataset and Subset creation
dataset = MRIDataset(data)
test = MRISubset(Subset(dataset, test_index), train_bool=False, transform=test_transforms)

# Class names mapping
class_names = ['healthy', 'affected']

# Stack all test images into one batch and convert it to channels-last,
# the layout given to the SHAP image masker below.
shap_test = torch.stack([test[i][0] for i in range(len(test))])
shap_test = nchw_to_nhwc(shap_test)

# Recreate the output folder from scratch (this deletes anything already in output_path)
if os.path.exists(output_path):
    shutil.rmtree(output_path)
os.makedirs(output_path)

# Load pre-trained model (my_predict reads this global `model`)
model = ResNet50variant()
model.load_state_dict(
    torch.load(f"{input_path}/models/final-resnet50v/saved_model-final.pth", map_location=torch.device('cpu'))
)
model = model.to(device)
model.eval()

# Masker for image
masker_blur = shap.maskers.Image("blur(64,64)", shap_test[0].shape)

# Initialize the SHAP explainer with the prediction function defined above
explainer = shap.Explainer(my_predict, masker_blur, output_names=class_names, seed=11)

# Get image file names from the dataset paths
img_names = [os.path.basename(data[k][0]) for k in test_index]

# Raw SHAP arrays all go to one folder
numpy_output_path = os.path.join(output_path, 'shap_values')
os.makedirs(numpy_output_path, exist_ok=True)

# Store shap values and generate plots
shap_data = []

for i in range(len(test)):
    input_img, _ = test[i]
    input_img = input_img.to(device)

    # Predict class (inference only)
    with torch.no_grad():
        predicted_class_index = torch.argmax(model(input_img.unsqueeze(0)), dim=1).item()

    # Compute SHAP values for the top-ranked output(s)
    shap_values = explainer(
        shap_test[i].unsqueeze(0),
        max_evals=n_evals,
        batch_size=batch_size,
        outputs=shap.Explanation.argsort.flip[:topk]
    )

    shap_values.data = shap_values.data.cpu().numpy()
    shap_data.append(shap_values)

    # Create subdirectory based on predicted class
    subdir = id2name[predicted_class_index]
    final_output_path = os.path.join(output_path, subdir)
    os.makedirs(final_output_path, exist_ok=True)

    # Plot SHAP image; close it afterwards so one open figure per image does not pile up in memory
    shap.image_plot(shap_values, show=False)
    plt.savefig(f'{final_output_path}/{img_names[i]}')
    plt.close()

    # Save SHAP values as numpy arrays
    np.save(os.path.join(numpy_output_path, f'{img_names[i]}.npy'), shap_values.values)


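## Ensemble Method

The ensemble is built separately for each test image and each model (ResNet18 and ResNet50 variants). The precomputed GradCAM, GradCAM++, HiResCAM and SHAP maps are coarsened into square tiles; the occlusion map is added as well when it contains values ≥ 1. These maps are merged into a single ensemble map, which is saved as a heatmap blended over the original image.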

In [None]:
# Function to discretize an array with a sliding window
def discretize(arr: np.ndarray, window_size: int) -> np.ndarray:
    """
    Coarsens a 2D array into non-overlapping square tiles, filling each tile with its mean.

    Tiles start at (0, 0); tiles on the bottom/right edges are smaller when the array size
    is not a multiple of ``window_size``.

    :param arr: Input 2D numpy array.
    :param window_size: Side length of each tile.
    :return: float32 array with the same shape as ``arr``.
    """
    n_rows, n_cols = arr.shape
    blocky = np.zeros((n_rows, n_cols), dtype=np.float32)

    for top in range(0, n_rows, window_size):
        rows = slice(top, min(top + window_size, n_rows))
        for left in range(0, n_cols, window_size):
            cols = slice(left, min(left + window_size, n_cols))
            blocky[rows, cols] = np.mean(arr[rows, cols])

    return blocky

# Function to add text to an image
def add_text_to_image(img: np.ndarray, text: str, space: int = 10,
                      color: tuple = (255, 255, 255), thickness: int = 2,
                      scale: float = 0.5) -> np.ndarray:
    """
    Draws horizontally centered text near the top of an image.

    The text is drawn into ``img`` itself with ``cv2.putText``; the same array is returned.

    :param img: Input image (numpy array), either (H, W) grayscale or (H, W, C).
    :param text: The text to be added.
    :param space: Pixels added to the text height to place the text baseline. ``None`` falls
        back to 10; an explicit ``0`` is honored.
    :param color: Color of the text in BGR format (default is white).
    :param thickness: Thickness of the text (default is 2).
    :param scale: Scale of the text (default is 0.5).
    :return: The image with added text.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    (text_width, text_height), _ = cv2.getTextSize(text, font, scale, thickness)

    # Only the width is needed; indexing shape[1] also works for 2D grayscale images,
    # where unpacking three values would raise.
    width = img.shape[1]
    x = (width - text_width) // 2
    # Explicit None check: a truthiness test would silently turn space=0 into 10.
    y = text_height + (10 if space is None else space)

    # Add the text to the image
    cv2.putText(img, text, (x, y), font, scale, color, thickness)

    return img

# Function to normalize an array to a range between 0 and 1
def normalize(arr: np.ndarray) -> np.ndarray:
    """
    Clips negative values to 0, then min-max scales the result to [0, 1].

    :param arr: Input 2D or 1D numpy array.
    :return: Scaled array. An empty or all-zero input is returned unchanged (same object).
        If the clipped array has no spread (e.g. every value is negative, or all values are
        equal), an all-zero float array is returned instead of dividing 0 by 0 (NaNs).
    """
    if not np.any(arr):
        return arr  # empty or all zeros: nothing to scale

    clipped = np.maximum(0, arr)  # ensure non-negative values
    shifted = clipped - clipped.min()  # shift values to start from 0
    value_range = shifted.max()  # shifted.min() is 0 by construction
    if value_range == 0:
        return np.zeros_like(shifted, dtype=float)
    return shifted / value_range


In [None]:
# Utility function to apply threshold and normalize values
def threshold(img: np.ndarray, res: list, thres: float, binary: bool = False, size: int = 14) -> np.ndarray:
    """
    Combines several maps by voting on which pixels reach a threshold.

    Every map is normalized; a pixel votes for a map when its normalized value is >= ``thres``.
    Pixels with at least ``len(res) - 1`` votes (all maps, or all but one) are kept.

    :param img: Unused; kept for call compatibility.
    :param res: List of 2D result arrays to combine.
    :param thres: Threshold on the normalized values.
    :param binary: If True, return the 0/1 vote mask instead of values (default: False).
    :param size: Unused; kept for call compatibility.
    :return: 0/1 int mask when ``binary``, otherwise the normalized mean of the above-threshold
        values, zeroed outside the mask.
    """
    n_maps = len(res)
    rows, cols = res[0].shape
    votes = np.zeros((rows, cols), dtype=float)
    kept_sum = np.zeros((rows, cols), dtype=float)

    for scaled in map(normalize, res):
        passes = scaled >= thres
        votes += np.where(passes, 1, 0)
        kept_sum += np.where(passes, scaled, 0)

    consensus = np.where(votes >= n_maps - 1, 1, 0)
    if binary:
        return consensus
    return normalize(kept_sum / n_maps) * consensus


# Utility function to average the results of multiple arrays
def average(res: list, binary: bool = False, size: int = 14) -> np.ndarray:
    """
    Combines several maps by averaging their positive normalized values.

    A pixel is supported by a map when its normalized value is > 0; pixels supported by at
    least half of the maps are kept, and of those only mean values above 0.5 survive.

    :param res: List of 2D result arrays to average.
    :param binary: If True, return the 0/1 support mask instead of values (default: False).
    :param size: Unused; kept for call compatibility.
    :return: 0/1 int mask when ``binary``, otherwise the masked mean map with values <= 0.5 zeroed.
    """
    n_maps = len(res)
    rows, cols = res[0].shape
    total = np.zeros((rows, cols), dtype=float)
    support = np.zeros((rows, cols), dtype=float)

    for scaled in map(normalize, res):
        positive = scaled > 0
        total += np.where(positive, scaled, 0)
        support += np.where(positive, 1, 0)

    consensus = np.where(support >= n_maps / 2, 1, 0)
    if binary:
        return consensus

    mean_map = (total / n_maps) * consensus
    return np.where(mean_map > 0.5, mean_map, 0)


# Utility function to find intersection of multiple result arrays
def intersection(res: list, binary: bool = False, size: int = 14) -> np.ndarray:
    """
    Keeps only the pixels on which every map agrees.

    Every map is normalized; a pixel counts for a map when its normalized value is strictly
    above 0.2. Only pixels counted by all maps are kept.

    :param res: List of 2D result arrays to combine.
    :param binary: If True, return the 0/1 agreement mask instead of values (default: False).
    :param size: Unused; kept for call compatibility.
    :return: 0/1 int mask when ``binary``, otherwise the mean of the above-0.2 values on agreed
        pixels (0 elsewhere).
    """
    cutoff = 0.2
    n_maps = len(res)
    rows, cols = res[0].shape
    summed = np.zeros((rows, cols), dtype=float)
    hits = np.zeros((rows, cols), dtype=float)

    for scaled in map(normalize, res):
        strong = scaled > cutoff
        summed += np.where(strong, scaled, 0)
        hits += np.where(strong, 1, 0)

    agreed = np.where(hits == n_maps, 1, 0)
    if binary:
        return agreed

    combined = (summed / n_maps) * agreed
    return np.where(combined != 0, combined, 0)


In [None]:
# Clear and create output directory
if os.path.exists(output_path):
    shutil.rmtree(output_path)
os.mkdir(output_path)

# Image and label initialization
img_names = []
labels = []
for k in test_index:
    name, label = data[k]
    labels.append(label)
    subname = name.split('/')
    img_names.append(subname[-1])

# Model initialization
m18 = ResNet18variant()
m18.load_state_dict(torch.load('drive/MyDrive/Tesi/data_tif_no_background/models/final-resnet18v/saved_model-final.pth', map_location=torch.device(device)))
m18 = m18.to(device)

m50 = ResNet50variant()
m50.load_state_dict(torch.load('drive/MyDrive/Tesi/data_tif_no_background/models/final-resnet50v/saved_model-final.pth', map_location=torch.device(device)))
m50 = m50.to(device)

models = [('Resnet18v', m18), ('Resnet50v', m50)]

# Set size for discretization
size = 7

# Process each image and generate ensemble results
for n, name in enumerate(img_names):

    # Process each model's predictions and explainability
    for m, model in models:

        img_orig, _ = test.__getitem__(n)
        rgb_img = img_orig.repeat(3, 1, 1).numpy().transpose((1, 2, 0))

        pred = evaluate_img(model, img_orig.unsqueeze(0)).item()

        # Load GradCAM, GradCAM++, HiResCAM, Shap5000, and Occlusion maps
        gradcam = np.load(os.path.join(input_path, "cam_array", m, "GradCAM", img_names[n] + '.npy'))
        gradcampp = np.load(os.path.join(input_path, "cam_array", m, "GradCAMPlusPlus", img_names[n] + '.npy'))
        hirescam = np.load(os.path.join(input_path, "cam_array", m, "HiResCAM", img_names[n] + '.npy'))

        shap5000 = np.squeeze(np.load(os.path.join(input_path, "Shap", f"{m}-5000", "shap_values", img_names[n] + '.npy'))[:, :, :, :, 0][0])

        # Determine Occlusion map based on prediction
        occlusion_file = os.path.join(input_path, "Occlusion_res", m, f"{img_names[n]}.npy" if pred == 1 else f"{img_names[n]}-2.npy")
        occlusion = np.load(occlusion_file)

        # Determine explainability components
        explainability = [discretize(gradcam, size), discretize(gradcampp, size), discretize(hirescam, size), discretize(shap5000, size)]
        if np.any(occlusion >= 1):
            explainability.append(discretize(occlusion, size))

        # Call ensemble algorithm to combine the explainability maps
        res = ensemble_algorithm(explainability)  # Replace with your ensemble function

        # Generate heatmap from ensemble results
        heatmap = cv2.applyColorMap(np.uint8(255 * res), cv2.COLORMAP_PARULA)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = np.float32(heatmap) / 255

        # Blend heatmap with original image
        final_img = cv2.addWeighted(rgb_img, 0.5, heatmap, 0.5, 0)

        # Prepare output directory structure
        subdir = id2name[labels[n]]
        final_output_path = os.path.join(output_path, m, subdir)

        # Create necessary directories if they don't exist
        os.makedirs(final_output_path, exist_ok=True)

        # Save the final image
        final_img = cv2.cvtColor(final_img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f'{final_output_path}/{img_names[n]}', final_img)
