In [None]:
import os

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import StratifiedGroupKFold

from src.classes.ensemble import ExplainabilityEnsemble
from src.classes.models import ResNet18variant, ResNet50variant
from src.config import PATH_TO_DATASET_CSV, PATH_TO_DATASET, ID_TO_NAME, PATH_TO_MODELS, PATH_TO_OUTPUT

In [None]:
# Load the dataset index (semicolon-separated, header on the first row).
# Columns used below: 'label', 'img_name' (and 'group' in the split cell).
df = pd.read_csv(PATH_TO_DATASET_CSV, sep=';', header=0)

# Map each row index -> (image path, label). Images are stored in one
# sub-directory per class, whose name is ID_TO_NAME[label].
# Column-wise zip is used instead of `df.iterrows()`: iterrows builds a Series
# per row (slow) and does not preserve dtypes across a row, so mixed numeric
# rows get upcast (e.g. an integer img_name would become "123.0" after str()).
data = {
    idx: (os.path.join(PATH_TO_DATASET, ID_TO_NAME[label], str(img_name)), label)
    for idx, label, img_name in zip(df.index, df['label'], df['img_name'])
}

# Hold-out split: only the first fold of a StratifiedGroupKFold is used, giving a
# single train/test split. Grouping keeps every sample of one 'group' on the same
# side; stratification keeps the label distribution similar in both sides.
N_SPLITS = 5      # test set is roughly 1 / N_SPLITS of the data
SPLIT_SEED = 7    # fixed so the shuffled split is reproducible

# Labels and group ids, positionally aligned with the rows of `df`.
y = df['label'].to_numpy()
groups = df['group'].to_numpy()

sgkf = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SPLIT_SEED)

# Sample ids = keys of `data`. split() returns *positional* indices into X, and
# `test_index` is later handed to the ensemble together with `data`; this only
# lines up if the keys are exactly 0..n-1 (true for read_csv's default RangeIndex).
X = np.array(list(data.keys()))
assert np.array_equal(X, np.arange(len(X))), "data keys must be 0..n-1 so split indices address data directly"

train_index, test_index = next(sgkf.split(X, y, groups))

# Sanity checks: the split partitions the data and no group leaks across it.
assert len(train_index) + len(test_index) == len(X)
assert not set(groups[train_index]) & set(groups[test_index]), "group leakage between train and test"

# Log dataset statistics (label 1 = affected, 0 = healthy, as printed below).
print(f"Total dataset size: {len(data)}")
print(f"Training samples: {len(train_index)} | Test samples: {len(test_index)}")
print(f"Training - Affected: {np.count_nonzero(y[train_index] == 1)}, Healthy: {np.count_nonzero(y[train_index] == 0)}")
print(f"Test - Affected: {np.count_nonzero(y[test_index] == 1)}, Healthy: {np.count_nonzero(y[test_index] == 0)}")

# Prefer the first CUDA GPU; fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _load_model(model_cls, filename, device):
    """Build `model_cls`, load its weights from PATH_TO_MODELS/filename, and return it in eval mode.

    Args:
        model_cls: zero-argument model constructor (e.g. ResNet18variant).
        filename: checkpoint file name inside PATH_TO_MODELS; assumed to contain a
            plain state_dict (it is passed straight to load_state_dict).
        device: torch.device the model and the loaded tensors are placed on.

    Returns:
        The model on `device`, switched to eval(). A freshly constructed nn.Module is
        in training mode, and load_state_dict does not change that; eval() makes any
        BatchNorm/Dropout layers behave as at inference time.
    """
    model = model_cls().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers;
    # without it torch.load can execute arbitrary pickled code from the file.
    state_dict = torch.load(
        os.path.join(PATH_TO_MODELS, filename), map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    return model.eval()


# Pre-trained classifiers to be explained.
m18 = _load_model(ResNet18variant, "final-resnet18v.pth", device)
m50 = _load_model(ResNet50variant, "final-resnet50v.pth", device)

# (display name, model) pairs handed to the ensemble.
models = [('Resnet18v', m18), ('Resnet50v', m50)]

# Run the explainability ensemble; test_index is passed, so presumably only the
# held-out samples are processed. NOTE(review): meaning of size=7 is defined by
# ExplainabilityEnsemble — confirm and consider naming it.
ensemble_processor = ExplainabilityEnsemble(models, PATH_TO_DATASET, PATH_TO_OUTPUT, test_index, data, device, size=7)
ensemble_processor.run()
