In [None]:
# System and Utility Libraries
import os
import sys
import random
import glob
import json
import copy
import warnings
from collections import Counter

# Scientific and Data Processing Libraries
import numpy as np
import pandas as pd
import h5py
from scipy import stats

# Visualization Libraries
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import to_rgba, LinearSegmentedColormap
import seaborn as sns
%matplotlib inline

# Machine Learning and Deep Learning Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms, models
from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights
import timm

# Metrics and Evaluation Libraries
from sklearn.metrics import (
    roc_auc_score, roc_curve, recall_score, f1_score, 
    precision_recall_curve, average_precision_score, 
    precision_score, cohen_kappa_score, confusion_matrix, 
    accuracy_score, classification_report, auc
)
from sklearn.preprocessing import label_binarize, LabelBinarizer, LabelEncoder

# Configuration Libraries
from omegaconf import OmegaConf, DictConfig

# Progress Bar
from tqdm import tqdm

# Filter Warnings
warnings.filterwarnings("ignore")

### Config and data loading

In [None]:
# Read the preprocessing config and keep only the
# classic_mil_on_embeddings_bag -> bracs_224_224_patches section used below.
preproc_conf = OmegaConf.load("../conf/preproc.yaml")['classic_mil_on_embeddings_bag']['bracs_224_224_patches']

In [None]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print("Device:", DEVICE)

In [None]:
# Folder the test arrays are loaded from, taken from the `img_dir_lvl4` config entry.
parent_folder = preproc_conf.img_dir_lvl4

In [None]:
%%time
# Load the test images and their labels. os.path.join is used so the paths
# resolve whether or not `parent_folder` ends with a path separator.
X_val = np.load(os.path.join(parent_folder, 'bracs_level4_regions_224_test_data_macenkonorm_bracs.npy'))
y_val = np.load(os.path.join(parent_folder, 'bracs_level4_regions_224_test_label.npy'))

# Integer-encode the labels (LabelEncoder maps each class to 0..n_classes-1).
# NOTE: despite the `_oh` suffix, `y_val_oh` holds integer codes, not one-hot vectors.
lb = LabelEncoder()
lb.fit(y_val)
y_val_oh = lb.transform(y_val)
X_val.shape, y_val.shape, y_val_oh.shape

In [None]:
# Show the class names learned by the encoder; position i corresponds to integer code i.
lb.classes_

In [None]:
# Class names exactly as stored by the fitted encoder (index i <-> integer label i)
lb_classes = lb.classes_

# Abbreviations in the order the figures should present them
desired_order = ['N', 'PB', 'UDH', 'ADH', 'FEA', 'DCIS', 'IC']

# Abbreviation -> full class name
class_mapping = dict(zip(
    desired_order,
    ['NORMAL', 'PATHOLOGICAL-BENIGN', 'UDH', 'ADH', 'FEA', 'DCIS', 'INVASIVE-CARCINOMA'],
))

# Full class name -> abbreviation
reverse_mapping = dict((full_name, abbr) for abbr, full_name in class_mapping.items())

# Abbreviation of every encoder class, kept in encoder order
lb_classes_abbr = np.array([reverse_mapping[full_name] for full_name in lb_classes])

# Encoder indices rearranged so that they follow `desired_order`
sorted_indices = np.argsort(list(map(desired_order.index, lb_classes_abbr)))

# Display the permutation and the abbreviations it produces
sorted_indices, lb_classes_abbr[sorted_indices]

### Model

In [None]:
# Instantiate ResNet-50 with torchvision's default pretrained weights; the
# fine-tuned checkpoint is loaded on top of it further down.
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
# Follow the device selected earlier instead of hard-coding CUDA, so this
# alias stays valid on CPU-only machines.
device = DEVICE

In [None]:
# Inspect the model's current fully connected head.
model.fc

In [None]:
# Replace the head with a 7-output linear layer on the 2048 input features,
# then move the whole network to the selected device.
model.fc = nn.Linear(2048, 7)
model = model.to(DEVICE)

In [None]:
def default_transforms(mean = (0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the image preprocessing pipeline.

    The pipeline converts an image to a tensor with ``transforms.ToTensor`` and
    then normalises each channel with ``transforms.Normalize``.

    Parameters:
    mean: per-channel means passed to ``Normalize``.
    std: per-channel standard deviations passed to ``Normalize``.

    Returns:
    A ``transforms.Compose`` object applying both steps in order.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)

In [None]:
# Path of the checkpoint to evaluate, relative to the configured weights directory.
weight_to_use = os.path.join(
    preproc_conf.weights_dir,
    'weights_train_resnet50_smallLRbackbone_simplehead_level4_macenko_bracs_50epochs/checkpoint_epoch_33_1.4022CE_49_acc.pth',
)

# Map the stored tensors onto the device chosen earlier so the checkpoint also
# loads on CPU-only machines (a fixed 'cuda' target raises there).
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
loaded_model = torch.load(weight_to_use, map_location=torch.device(DEVICE))
model.load_state_dict(loaded_model['model'])
# Switch to evaluation mode; the trailing semicolon hides the module repr in the cell output.
model.eval();

In [None]:
# Pre-allocate the batch tensor. ToTensor turns each (H, W, C) array into a
# (C, H, W) tensor, so the buffer has to be (N, C, H, W); allocating it as
# (N, C, W, H) only works when the patches happen to be square.
num_images, height, width, channels = X_val.shape
imgs_tensor = torch.zeros((num_images, channels, height, width), dtype=torch.float32)

# Build the transform pipeline once instead of once per image
preprocess = default_transforms()
for i in range(num_images):
    imgs_tensor[i] = preprocess(X_val[i])

In [None]:
# Batch the preprocessed images; shuffle=False keeps predictions aligned with the labels
data_loader = DataLoader( imgs_tensor, batch_size=128, shuffle=False, num_workers=1)

# Run the classifier batch by batch and collect class probabilities
bracs_class_pred_val_all = []
with torch.no_grad():
    for img in data_loader:
        img = img.to(DEVICE)
        # softmax turns the logits into per-class probabilities
        bracs_class_pred_val = F.softmax(model(img), dim=1)
        bracs_class_pred_val_all.append(bracs_class_pred_val.cpu().numpy())

# Stack the per-batch arrays into one (n_samples, n_classes) array
bracs_class_pred_val_all = np.concatenate(bracs_class_pred_val_all)

In [None]:
# Sanity check: one row of class probabilities per test image.
bracs_class_pred_val_all.shape

In [None]:
# Per-class precision/recall/F1 using the highest-probability class as the prediction.
print(classification_report( y_val_oh, np.argmax( bracs_class_pred_val_all, axis=1) ))

In [None]:
# Persist the class probabilities. np.save does not create missing folders,
# so make sure the output directory exists first.
os.makedirs('Predictions', exist_ok=True)
np.save('Predictions/Predictions_weights_train_resnet50_simplehead_level4_macenkonorm_bracs_50epochs_checkpoint_epoch_33_1.4022CE_49_acc.npy', bracs_class_pred_val_all)

### Rename variables

In [None]:
# Aliases used by the evaluation cells below.
# `biopsy_df_local_test` and `y_true` both refer to the integer-encoded labels;
# `final_pred_ensemble_local_test` refers to the class-probability array.
y_true = biopsy_df_local_test = y_val_oh
final_pred_ensemble_local_test = bracs_class_pred_val_all

# Hard predictions: index of the most probable class for every sample
predicted_labels = final_pred_ensemble_local_test.argmax(axis=1)

In [None]:
# Labels and hard predictions must have the same length.
biopsy_df_local_test.shape, y_true.shape, predicted_labels.shape

In [None]:
# One-vs-rest ROC AUC computed from the class probabilities, for both averaging schemes
for averaging in ("macro", "weighted"):
    roc_auc = roc_auc_score(y_true, final_pred_ensemble_local_test, multi_class="ovr", average=averaging)
    print(f"ROC {averaging}:", roc_auc)

# F1 computed from the hard (argmax) predictions, for both averaging schemes
for averaging in ("macro", "weighted"):
    f1 = f1_score(y_true, predicted_labels, average=averaging)
    print(f"F1 {averaging}:", f1)

accuracy = accuracy_score(y_true, predicted_labels)
print("Accuracy: ", accuracy)

### Metrics on the test set - bootstrap

In [None]:
def generate_bootstrap_samples(y_true, y_pred, num_bootstrap_samples=10000, random_seed=42):
    """Draw bootstrap resamples in which every class of ``y_true`` is present.

    Each resample draws ``len(y_true)`` indices with replacement and applies
    them to both ``y_true`` and ``y_pred`` so labels and predictions stay
    paired. A draw is rejected and repeated until it contains every class of
    ``y_true``; with at least two classes this also guarantees that every
    one-vs-rest split has positives and negatives.

    The set of classes is taken from ``y_true`` itself rather than assumed to
    be 0..6: a fixed class list makes the rejection loop run forever whenever
    the data does not contain exactly those labels.

    Parameters:
    y_true: array of class labels; must support indexing with an integer array.
    y_pred: array of predictions aligned with ``y_true`` along the first axis.
    num_bootstrap_samples: number of accepted resamples to return.
    random_seed: if not None, seeds NumPy's global random generator before sampling.

    Returns:
    List of ``(sample_y_true, sample_y_pred)`` tuples.

    Raises:
    ValueError: if ``y_true`` holds fewer than two distinct classes, because no
        resample could ever provide both sides of a one-vs-rest split.
    """
    n = len(y_true)
    # Loop-invariant: number of classes every accepted resample must contain
    num_classes = len(np.unique(y_true))
    if num_classes < 2:
        raise ValueError(
            "y_true must contain at least two distinct classes to build one-vs-rest bootstrap samples"
        )

    if random_seed is not None:
        np.random.seed(random_seed)

    bootstrap_samples = []

    for _ in tqdm(range(num_bootstrap_samples)):
        while True:
            # Sample indices with replacement
            indices = np.random.randint(0, n, n)
            sample_y_true = y_true[indices]

            # Accept only resamples in which no class has dropped out
            if len(np.unique(sample_y_true)) == num_classes:
                bootstrap_samples.append((sample_y_true, y_pred[indices]))
                break

    return bootstrap_samples

In [None]:
# Draw the resamples once; the metric and curve cells below all reuse this
# same list.
bootstrap_samples = generate_bootstrap_samples(
    y_true, 
    final_pred_ensemble_local_test, 
    num_bootstrap_samples=10000, 
    random_seed=42
)

In [None]:
# Number of resamples and the label / prediction shapes of the first one.
len(bootstrap_samples), bootstrap_samples[0][0].shape, bootstrap_samples[0][1].shape

In [None]:
def bootstrap_metric_ci(bootstrap_samples, y_true, y_pred, confidence_level=0.95, **kwargs):
    """Point estimates and percentile-bootstrap confidence intervals for multi-class metrics.

    Every metric is computed once on the full data (the reported value) and
    once per bootstrap resample (the distribution the interval bounds are read
    from). The number of classes is taken from the number of columns of
    ``y_pred``; class ``i`` is scored with column ``i``.

    Parameters:
    bootstrap_samples: iterable of ``(labels, scores)`` pairs, e.g. the output
        of ``generate_bootstrap_samples``.
    y_true: array of integer class labels for the full data set.
    y_pred: array of shape (n_samples, n_classes) with per-class scores.
    confidence_level: central coverage of the interval (0.95 -> 2.5th / 97.5th percentiles).
    **kwargs: accepted for call compatibility; not used.

    Returns:
    metric_cis: dict ``metric name -> (full-data value, lower, upper)``.
    classwise_roc_cis: dict ``"Class i AUC" -> (value, lower, upper)`` for one-vs-rest ROC AUC.
    classwise_pr_cis: dict ``"Class i PR AUC" -> (value, lower, upper)`` for the
        trapezoidal area under the one-vs-rest precision-recall curve.
    bootstrap_metrics: dict with the raw per-resample values, including the
        list of ``classification_report`` dicts under ``"classification_report_calc"``.
    """
    class_ids = list(range(y_pred.shape[1]))
    summary_names = [
        "accuracy", "f1_macro", "f1_weighted",
        "roc_auc_macro", "roc_auc_weighted",
        "pr_auc_macro", "pr_auc_weighted",
    ]
    lower_q = (1 - confidence_level) / 2 * 100
    upper_q = (1 + confidence_level) / 2 * 100

    def _summary_metrics(labels, scores):
        # Aggregate (macro / weighted) metrics for one label/score pair.
        labels_binarized = label_binarize(labels, classes=class_ids)
        # Threshold-based metrics use the argmax class as the hard prediction
        hard_predictions = np.argmax(scores, axis=1)
        return {
            "accuracy": accuracy_score(labels, hard_predictions),
            "f1_macro": f1_score(labels, hard_predictions, average='macro'),
            "f1_weighted": f1_score(labels, hard_predictions, average='weighted'),
            "roc_auc_macro": roc_auc_score(labels_binarized, scores, multi_class='ovr', average='macro'),
            "roc_auc_weighted": roc_auc_score(labels_binarized, scores, multi_class='ovr', average='weighted'),
            "pr_auc_macro": average_precision_score(labels_binarized, scores, average='macro'),
            "pr_auc_weighted": average_precision_score(labels_binarized, scores, average='weighted'),
        }

    def _classwise_areas(labels, scores):
        # One-vs-rest ROC AUC and trapezoidal PR AUC for every class.
        roc_areas, pr_areas = {}, {}
        for class_index in class_ids:
            binarized = (labels == class_index).astype(int)
            fpr, tpr, _ = roc_curve(binarized, scores[:, class_index])
            roc_areas[class_index] = auc(fpr, tpr)
            precision, recall, _ = precision_recall_curve(binarized, scores[:, class_index])
            pr_areas[class_index] = auc(recall, precision)
        return roc_areas, pr_areas

    def _with_ci(point_estimate, draws):
        # Pair a full-data value with the percentile bounds of its bootstrap draws.
        return (point_estimate, np.percentile(draws, lower_q), np.percentile(draws, upper_q))

    bootstrap_metrics = {
        "accuracy": [],
        "f1_macro": [],
        "f1_weighted": [],
        "roc_auc_macro": [],
        "roc_auc_weighted": [],
        "classification_report_calc": [],
        "pr_auc_macro": [],
        "pr_auc_weighted": [],
        "classwise_aucs": {i: [] for i in class_ids},
        "classwise_pr_aucs": {i: [] for i in class_ids}
    }

    for sample_y_true, sample_y_pred in tqdm(bootstrap_samples):
        for metric_name, value in _summary_metrics(sample_y_true, sample_y_pred).items():
            bootstrap_metrics[metric_name].append(value)

        # Keep the full report so per-class precision/recall/F1 intervals can be derived later
        bootstrap_metrics["classification_report_calc"].append(
            classification_report(sample_y_true, np.argmax(sample_y_pred, axis=1), output_dict=True)
        )

        roc_areas, pr_areas = _classwise_areas(sample_y_true, sample_y_pred)
        for class_index in class_ids:
            bootstrap_metrics["classwise_aucs"][class_index].append(roc_areas[class_index])
            bootstrap_metrics["classwise_pr_aucs"][class_index].append(pr_areas[class_index])

    # Reported values come from the full data set, not from the bootstrap mean
    single_metric_values = _summary_metrics(y_true, y_pred)
    single_roc_areas, single_pr_areas = _classwise_areas(y_true, y_pred)

    metric_cis = {
        metric_name: _with_ci(single_metric_values[metric_name], bootstrap_metrics[metric_name])
        for metric_name in summary_names
    }
    classwise_roc_cis = {
        f"Class {class_index} AUC": _with_ci(
            single_roc_areas[class_index], bootstrap_metrics["classwise_aucs"][class_index]
        )
        for class_index in class_ids
    }
    classwise_pr_cis = {
        f"Class {class_index} PR AUC": _with_ci(
            single_pr_areas[class_index], bootstrap_metrics["classwise_pr_aucs"][class_index]
        )
        for class_index in class_ids
    }

    return metric_cis, classwise_roc_cis, classwise_pr_cis, bootstrap_metrics

In [None]:
metric_cis, classwise_roc_cis, classwise_pr_cis, bootstrap_metrics = bootstrap_metric_ci(
    bootstrap_samples, 
    y_true, 
    final_pred_ensemble_local_test
)

# Re-key the class-wise results by class name instead of class index.
# The PR entries keep the "PR AUC" suffix so they cannot be confused with the ROC entries.
classwise_roc_cis = {f'{lb.classes_[i]} AUC': classwise_roc_cis[f'Class {i} AUC'] for i in range(len(lb.classes_))}
classwise_pr_cis = {f'{lb.classes_[i]} PR AUC': classwise_pr_cis[f'Class {i} PR AUC'] for i in range(len(lb.classes_))}

# Display the full-data metric values together with their bootstrap intervals
print("Internal test set (local test) - Single Metric Values and Bootstrap CIs:")
for metric_name, (value, lower, upper) in metric_cis.items():
    print(f"{metric_name:<40}: {round(value, 3)}, 95% CI = [{round(lower, 3)}, {round(upper, 3)}]")

print("\nOne-vs-Rest Class-wise AUCs and Bootstrap CIs (ROC):")
for class_name, (value, lower, upper) in classwise_roc_cis.items():
    print(f"{class_name:<40}: {round(value, 3)}, 95% CI = [{round(lower, 3)}, {round(upper, 3)}]")

print("\nOne-vs-Rest Class-wise AUCs and Bootstrap CIs (PR):")
for class_name, (value, lower, upper) in classwise_pr_cis.items():
    print(f"{class_name:<40}: {round(value, 3)}, 95% CI = [{round(lower, 3)}, {round(upper, 3)}]")

### Parse sklearn classification report

In [None]:
def compute_ci_for_classification_report(bootstrap_metrics, lb_classes, confidence_level=0.95):
    """
    Summarise bootstrapped classification reports as (mean, lower, upper) triples.

    Parameters:
    bootstrap_metrics: dict
        Must contain the key 'classification_report_calc' holding a list of
        classification-report dictionaries, one per bootstrap iteration. Class
        rows are looked up by the string form of the class index ("0", "1", ...).
    lb_classes: list or array
        Class names; the name at position i labels the report row str(i).
    confidence_level: float
        Central coverage of the percentile interval (default 0.95).

    Returns:
    ci_results: dict
        One entry per class name plus "macro_avg" and "weighted_avg", each a
        dict with "precision", "recall" and "f1_score" triples, and an
        "accuracy" entry holding a single triple.
    """
    reports = bootstrap_metrics['classification_report_calc']
    lower_q = (1 - confidence_level) / 2 * 100
    upper_q = (1 + confidence_level) / 2 * 100

    def summarise(values):
        # (bootstrap mean, lower percentile bound, upper percentile bound)
        return (np.mean(values), np.percentile(values, lower_q), np.percentile(values, upper_q))

    # Output metric name -> field name used inside the classification report
    report_field = {"precision": "precision", "recall": "recall", "f1_score": "f1-score"}

    def summarise_row(row_key):
        # Collect one report row across all iterations and summarise each metric
        return {
            out_name: summarise([report[row_key][field] for report in reports])
            for out_name, field in report_field.items()
        }

    ci_results = {}

    # Per-class rows, keyed by class name
    for class_id in range(len(lb_classes)):
        ci_results[lb_classes[class_id]] = summarise_row(str(class_id))

    # Overall rows
    ci_results["macro_avg"] = summarise_row("macro avg")
    ci_results["weighted_avg"] = summarise_row("weighted avg")
    ci_results["accuracy"] = summarise([report["accuracy"] for report in reports])

    return ci_results

def print_ci_results_with_original(ci_results, original_report, lb_classes):
    """
    Print per-class and overall precision / recall / F1 with bootstrap intervals.

    The printed point value is taken from the original (non-bootstrapped)
    classification report; only the interval bounds come from `ci_results`.

    Parameters:
    ci_results: dict
        Maps class names, "macro_avg" and "weighted_avg" to dicts of
        (mean, lower, upper) triples; an "accuracy" entry is skipped.
    original_report: dict
        Classification report whose class rows are keyed by str(index) and
        which also holds "macro avg", "weighted avg" and "accuracy".
    lb_classes: list or array
        Class names; position i names the report row str(i).
    """
    overall_keys = ("macro_avg", "weighted_avg")

    # Re-key the original report so it can be addressed with the keys of ci_results
    reference = {class_name: original_report[str(idx)] for idx, class_name in enumerate(lb_classes)}
    reference["macro_avg"] = original_report["macro avg"]
    reference["weighted_avg"] = original_report["weighted avg"]
    reference["accuracy"] = original_report["accuracy"]

    def emit_rows(group_key):
        # One output line per metric of the given class / average
        for metric_name, (_, lower_bound, upper_bound) in ci_results[group_key].items():
            # The report spells the F1 field with a hyphen
            field = 'f1-score' if metric_name == 'f1_score' else metric_name
            point_value = reference[group_key][field]
            print(f"  {metric_name.capitalize().replace('_', '-')} : {point_value:.3f}, 95% CI = [{lower_bound:.3f}, {upper_bound:.3f}]")

    print("One-vs-Rest Class-wise Metrics and Bootstrap CIs:")

    # Class rows first; the overall entries are handled separately below
    for class_name in ci_results:
        if class_name in overall_keys or class_name == "accuracy":
            continue
        print(f"\n{class_name.upper()} Metrics:")
        emit_rows(class_name)

    print("\nOverall Metrics:")
    for avg_type in overall_keys:
        print(f"\n{avg_type.replace('_', ' ').capitalize()}:")
        emit_rows(avg_type)

In [None]:
# Same report as above, printed with three decimals for comparison with the CI printout below.
print(classification_report( y_val_oh, np.argmax( bracs_class_pred_val_all, axis=1), digits=3 ))

In [None]:
# Intervals come from the bootstrapped reports; the printed point values come
# from the report computed on the full test set (passed as a dict).
class_reports_ci_results = compute_ci_for_classification_report(bootstrap_metrics, lb.classes_, confidence_level=0.95)
print_ci_results_with_original(class_reports_ci_results, classification_report( 
    y_val_oh, 
    np.argmax( bracs_class_pred_val_all, axis=1), digits=3, output_dict=True ), lb.classes_ )

### ROC curve with CI - simple interp

In [None]:
def calculate_bootstrap_roc_curve(bootstrap_samples, class_index):
    """Pointwise 95% bootstrap band of the one-vs-rest ROC curve of one class.

    For every resample the ROC curve of ``class_index`` against all other
    classes is computed and linearly interpolated onto a common grid of 101
    false-positive rates, so the curves can be compared point by point.

    Parameters:
    bootstrap_samples: iterable of ``(labels, scores)`` pairs; ``scores`` is
        indexed as ``scores[:, class_index]``.
    class_index: integer label of the positive class.

    Returns:
    base_fpr: the common FPR grid (101 points in [0, 1]).
    tpr_lower, tpr_upper: 2.5th / 97.5th percentile of the TPR at each grid point.
    mean_tprs: mean TPR at each grid point.
    auc_lower, auc_upper: 2.5th / 97.5th percentile of the per-resample AUCs.

    When ``bootstrap_samples`` is empty the three TPR arrays are empty and the
    AUC bounds are NaN, so callers can detect that no band is available by
    checking the array length.
    """
    base_fpr = np.linspace(0, 1, 101)
    tprs = []
    aucs = []

    for sample_y_true, sample_y_pred in tqdm(bootstrap_samples):
        # Binarize the true labels for one-vs-rest classification
        binarized_y_true = (sample_y_true == class_index).astype(int)

        # ROC curve and area for this resample
        fpr, tpr, _ = roc_curve(binarized_y_true, sample_y_pred[:, class_index])
        aucs.append(auc(fpr, tpr))

        # Interpolate onto the shared FPR grid and pin the curve to the origin
        tpr = np.interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

    # Nothing to aggregate: report "no band" instead of failing inside NumPy
    if not tprs:
        return base_fpr, np.empty(0), np.empty(0), np.empty(0), np.nan, np.nan

    tprs = np.array(tprs)
    mean_tprs = tprs.mean(axis=0)

    tpr_lower = np.percentile(tprs, 2.5, axis=0)
    tpr_upper = np.percentile(tprs, 97.5, axis=0)

    # 95% CI for the AUC values
    auc_lower = np.percentile(aucs, 2.5)
    auc_upper = np.percentile(aucs, 97.5)

    return base_fpr, tpr_lower, tpr_upper, mean_tprs, auc_lower, auc_upper

In [None]:
def plot_roc_with_saved_bootstrap(bootstrap_samples, y_true, y_pred):
    """Plot one-vs-rest ROC curves with bootstrap bands for the 7 classes.

    One panel is drawn per class, ordered by the notebook-level
    ``desired_order``; class names are abbreviated through the notebook-level
    ``reverse_mapping`` and ``lb_classes``. The solid curve is computed on the
    full data, the shaded band comes from ``calculate_bootstrap_roc_curve``.
    The figure is saved as PNG and SVG under ``paper_figures/`` and then shown.

    Parameters:
    bootstrap_samples: resamples forwarded to ``calculate_bootstrap_roc_curve``.
    y_true: integer labels of shape (n,), or an (n, 7) one-hot array.
    y_pred: array of shape (n, 7) with per-class scores.

    Raises:
    ValueError: if ``y_pred`` does not have 7 columns or its length differs from ``y_true``.
    """
    if y_pred.shape[1] != 7:
        raise ValueError("The number of classes should be 7")

    if y_pred.shape[0] != y_true.shape[0]:
        raise ValueError("Mismatched shape between y_true and y_pred")

    # One-hot encode y_true if necessary
    if len(y_true.shape) == 1 or y_true.shape[1] != 7:
        y_true = F.one_hot(torch.from_numpy(y_true).to(torch.int64), 7).numpy()

    fig, axs = plt.subplots(1, 7, figsize=(42, 6), dpi=150)

    # Panel order: encoder indices rearranged to follow `desired_order`
    lb_classes_abbr = [reverse_mapping[cls] for cls in lb_classes]
    sorted_indices = np.argsort([desired_order.index(abbr) for abbr in lb_classes_abbr])

    for i, class_ind in enumerate(sorted_indices):
        # ROC curve and AUC on the full data
        fpr, tpr, _ = roc_curve(y_true[:, class_ind], y_pred[:, class_ind])
        roc_auc = auc(fpr, tpr)

        # Bootstrap band and AUC interval from the saved resamples
        base_fpr, tpr_lower, tpr_upper, mean_tprs, auc_lower, auc_upper = calculate_bootstrap_roc_curve(bootstrap_samples, class_ind)

        axs[i].plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.3f})', color='blue', lw=2)
        if len(tpr_lower) > 0 and len(tpr_upper) > 0:  # Check if CI is computed
            axs[i].fill_between(base_fpr, tpr_lower, tpr_upper, color='lightblue', alpha=0.2, label=f'95% CI [{auc_lower:.3f}, {auc_upper:.3f}]')
        else:
            print("CI NOT COMPUTED")

        # Plot no skill line
        axs[i].plot([0, 1], [0, 1], linestyle='dashed', lw=2, color='black', label="No skill classifier")

        axs[i].axis("square")
        axs[i].set_xlim([-0.05, 1.05])
        axs[i].set_ylim([-0.05, 1.05])

        axs[i].set_xlabel('False Positive Rate', fontsize=22)
        axs[i].set_ylabel('True Positive Rate', fontsize=22)

        # Panel title: abbreviation of the class shown in this panel
        axs[i].set_title(f'{lb_classes_abbr[class_ind]}', fontsize=22, pad=20)
        axs[i].legend(loc='lower right', fontsize=16)

        # Set ticks for both axes
        axs[i].set_xticks([0.0, 0.25, 0.5, 0.75, 1.0])
        axs[i].set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
        axs[i].tick_params(axis='both', which='major', labelsize=22)
        axs[i].xaxis.label.set_size(22)
        axs[i].yaxis.label.set_size(22)

    plt.tight_layout()
    # savefig does not create missing folders
    os.makedirs('paper_figures', exist_ok=True)
    plt.savefig('paper_figures/roc_curve_with_ci_bracs_test_resnet50_simplehead.png')
    plt.savefig('paper_figures/roc_curve_with_ci_bracs_test_resnet50_simplehead.svg')
    plt.show()

In [None]:
# Draw and save the ROC figure for the test set, reusing the saved bootstrap resamples.
plot_roc_with_saved_bootstrap(bootstrap_samples, biopsy_df_local_test, final_pred_ensemble_local_test)

### PR curve with CI - simple interp

In [None]:
def calculate_bootstrap_pr_curve_simple_interp(bootstrap_samples, class_index):
    """Pointwise 95% bootstrap band of the one-vs-rest precision-recall curve of one class.

    For every resample the precision-recall curve of ``class_index`` against
    all other classes is computed and linearly interpolated onto a common grid
    of 101 recall values. The per-resample area is the trapezoidal area under
    the curve (``auc(recall, precision)``).

    Parameters:
    bootstrap_samples: iterable of ``(labels, scores)`` pairs; ``scores`` is
        indexed as ``scores[:, class_index]``.
    class_index: integer label of the positive class.

    Returns:
    base_recalls: the common recall grid (101 points in [0, 1]).
    precision_lower, precision_upper: 2.5th / 97.5th percentile of the precision at each grid point.
    mean_precisions: mean precision at each grid point.
    pr_auc_lower, pr_auc_upper: 2.5th / 97.5th percentile of the per-resample areas.

    When ``bootstrap_samples`` is empty the three precision arrays are empty
    and the area bounds are NaN, so callers can detect that no band is
    available by checking the array length.
    """
    base_recalls = np.linspace(0, 1, 101)
    precisions = []
    pr_aucs = []

    for sample_y_true, sample_y_pred in tqdm(bootstrap_samples):
        # Binarize the true labels for one-vs-rest classification
        binarized_y_true = (sample_y_true == class_index).astype(int)

        # Precision-recall curve and trapezoidal area for this resample
        precision, recall, _ = precision_recall_curve(binarized_y_true, sample_y_pred[:, class_index])
        pr_aucs.append(auc(recall, precision))

        # The curve is reversed before interpolation so recall is in increasing order for np.interp
        precisions.append(np.interp(base_recalls, recall[::-1], precision[::-1]))

    # Nothing to aggregate: report "no band" instead of failing inside NumPy
    if not precisions:
        return base_recalls, np.empty(0), np.empty(0), np.empty(0), np.nan, np.nan

    precisions = np.array(precisions)
    mean_precisions = precisions.mean(axis=0)

    precision_lower = np.percentile(precisions, 2.5, axis=0)
    precision_upper = np.percentile(precisions, 97.5, axis=0)

    # 95% CI for the PR AUC values
    pr_auc_lower = np.percentile(pr_aucs, 2.5)
    pr_auc_upper = np.percentile(pr_aucs, 97.5)

    return base_recalls, precision_lower, precision_upper, mean_precisions, pr_auc_lower, pr_auc_upper

In [None]:
def plot_pr_with_saved_bootstrap_simple_interp(bootstrap_samples, y_true, y_pred):
    """Plot one-vs-rest precision-recall curves with bootstrap bands for the 7 classes.

    One panel is drawn per class, ordered by the notebook-level
    ``desired_order``; class names are abbreviated through the notebook-level
    ``reverse_mapping`` and ``lb_classes``. The solid curve is computed on the
    full data, the shaded band comes from
    ``calculate_bootstrap_pr_curve_simple_interp``. The figure is saved as PNG
    and SVG under ``paper_figures/`` and then shown.

    The legend value is the trapezoidal area under the PR curve
    (``auc(recall, precision)``) and is labelled "AUC" accordingly; it is not
    the average precision returned by ``average_precision_score``.

    Parameters:
    bootstrap_samples: resamples forwarded to ``calculate_bootstrap_pr_curve_simple_interp``.
    y_true: integer labels of shape (n,), or an (n, 7) one-hot array.
    y_pred: array of shape (n, 7) with per-class scores.

    Raises:
    ValueError: if ``y_pred`` does not have 7 columns or its length differs from ``y_true``.
    """
    if y_pred.shape[1] != 7:
        raise ValueError("The number of classes should be 7")

    if y_pred.shape[0] != y_true.shape[0]:
        raise ValueError("Mismatched shape between y_true and y_pred")

    # One-hot encode y_true if necessary
    if len(y_true.shape) == 1 or y_true.shape[1] != 7:
        y_true = np.eye(7)[y_true]

    fig, axs = plt.subplots(1, 7, figsize=(42, 6), dpi=150)

    # Panel order: encoder indices rearranged to follow `desired_order`
    lb_classes_abbr = [reverse_mapping[cls] for cls in lb_classes]
    sorted_indices = np.argsort([desired_order.index(abbr) for abbr in lb_classes_abbr])

    for i, class_ind in enumerate(sorted_indices):
        # Precision-recall curve and trapezoidal area on the full data
        precision, recall, _ = precision_recall_curve(y_true[:, class_ind], y_pred[:, class_ind])
        pr_auc = auc(recall, precision)

        # Bootstrap band and area interval from the saved resamples
        base_recalls, precision_lower, precision_upper, mean_precisions, pr_auc_lower, pr_auc_upper = calculate_bootstrap_pr_curve_simple_interp(bootstrap_samples, class_ind)

        axs[i].plot(recall, precision, label=f'PR curve (AUC = {pr_auc:.3f})', color='green', lw=2)
        if len(precision_lower) > 0 and len(precision_upper) > 0:  # Check if CI is computed
            axs[i].fill_between(base_recalls, precision_lower, precision_upper, color='lightgreen', alpha=0.2, label=f'95% CI [{pr_auc_lower:.3f}, {pr_auc_upper:.3f}]')
        else:
            print("CI NOT COMPUTED")

        # No-skill line: the prevalence of the positive class
        axs[i].hlines(
            y_true[:, class_ind].sum() / y_true[:, class_ind].shape[0],
            0,
            1,
            color="black",
            label="No skill classifier",
            lw=2,
            linestyles='dashed',
            alpha=1.)

        axs[i].axis("square")
        axs[i].set_xlim([-0.05, 1.05])
        axs[i].set_ylim([-0.05, 1.05])

        axs[i].set_xlabel('Recall', fontsize=22)
        axs[i].set_ylabel('Precision', fontsize=22)

        # Panel title: abbreviation of the class shown in this panel
        axs[i].set_title(f'{lb_classes_abbr[class_ind]}', fontsize=22, pad=20)
        axs[i].legend(fontsize=16)

        # Set ticks for both axes
        axs[i].set_xticks([0.0, 0.25, 0.5, 0.75, 1.0])
        axs[i].set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])
        axs[i].tick_params(axis='both', which='major', labelsize=22)
        axs[i].xaxis.label.set_size(22)
        axs[i].yaxis.label.set_size(22)

    plt.tight_layout()
    # savefig does not create missing folders
    os.makedirs('paper_figures', exist_ok=True)
    plt.savefig('paper_figures/pr_curve_with_ci_bracs_test_resnet50_simplehead.png')
    plt.savefig('paper_figures/pr_curve_with_ci_bracs_test_resnet50_simplehead.svg')
    plt.show()

In [None]:
# Draw and save the precision-recall figure for the test set, reusing the saved bootstrap resamples.
plot_pr_with_saved_bootstrap_simple_interp(bootstrap_samples, biopsy_df_local_test, final_pred_ensemble_local_test)

### Confusion matrix - default argmax

In [None]:
# Cross-tabulate true labels (rows) against argmax predictions (columns)
confusion_matrix_initial = pd.crosstab( biopsy_df_local_test, np.argmax( final_pred_ensemble_local_test, axis=1),
                                rownames=['True stages'], colnames=['Predicted stage'] )

# Make sure all 7 classes appear as rows and columns; missing ones are filled with 0
confusion_matrix_padded = confusion_matrix_initial.reindex(index=range(0, 7), columns=range(0, 7), fill_value=0)

# String version of the counts; NOTE(review): not used by the plot below - confirm whether it is still needed
custom_heatmap = np.round(confusion_matrix_initial,0).astype(int).astype(str) #+ ' ± ' + std_confusion_matrix_tuned.round(0).astype(int).astype(str)


# Visualizing it on a heatmap; ticks follow the encoder's class order
plt.figure(figsize=(6,6), dpi=100)
plt.title( 'Confusion matrix of MIL')
sns.heatmap(confusion_matrix_padded, annot=True, fmt='.5g', 
            xticklabels=np.array(lb.classes_), 
            yticklabels=np.array(lb.classes_), annot_kws={"size": 14}
           )
plt.tick_params(labelsize=12, rotation=25)
plt.xlabel('Predicted stage')
plt.show()

In [None]:
# Cross-tabulate true labels (rows) against argmax predictions (columns)
confusion_matrix_initial = pd.crosstab(
    index=biopsy_df_local_test,
    columns=np.argmax(final_pred_ensemble_local_test, axis=1),
    rownames=['True stages'],
    colnames=['Predicted stage'],
)

# Put rows and columns into the display order; classes absent from the crosstab become 0
confusion_matrix_padded = confusion_matrix_initial.reindex(
    index=sorted_indices,
    columns=sorted_indices,
    fill_value=0,
)

# String version of the reordered counts
custom_heatmap = confusion_matrix_padded.round(0).astype(int).astype(str)

# Heatmap whose tick labels follow the same display order as the reindexed matrix
plt.figure(figsize=(8, 8), dpi=100)
plt.title('Confusion matrix of MIL')

sns.heatmap(
    confusion_matrix_padded,
    annot=True, fmt='.5g',
    xticklabels=desired_order, yticklabels=desired_order,
    annot_kws={"size": 14},
)

plt.tick_params(labelsize=12, rotation=25)
plt.xlabel('Predicted stage')
plt.ylabel('True stage')
plt.show()

In [None]:
# Sorting indices to reorder rows and columns of the confusion matrix
sorted_indices = np.argsort([desired_order.index(abbr) for abbr in lb_classes_abbr])

# Reorder once. fill_value=0 matters: without it a class that never occurs as a
# row/column becomes NaN, which turns every row sum into NaN (zeroing the whole
# normalised matrix) and makes the int() conversion of the counts fail.
confusion_matrix_sorted = confusion_matrix_initial.reindex(
    index=sorted_indices, columns=sorted_indices, fill_value=0
)
confusion_counts = confusion_matrix_sorted.values.astype(int)

# Row-normalise: each row is divided by the number of true samples of that class
row_totals = confusion_counts.sum(axis=1, keepdims=True)
with np.errstate(divide='ignore', invalid='ignore'):
    confusion_matrix_normalized = confusion_counts.astype('float') / row_totals

# Rows without any sample give 0/0 = NaN; show them as zeros
confusion_matrix_normalized = np.nan_to_num(confusion_matrix_normalized)

# Annotation text: row fraction on the first line, raw count in parentheses below
custom_annotations = np.empty_like(confusion_matrix_normalized, dtype=object)
for i in range(confusion_matrix_normalized.shape[0]):
    for j in range(confusion_matrix_normalized.shape[1]):
        actual_value = int(confusion_counts[i, j])
        percentage_value = confusion_matrix_normalized[i, j]
        custom_annotations[i, j] = f"{percentage_value:.2f}\n({actual_value})"

# Plot
fig, ax = plt.subplots(figsize=(8, 8), dpi=150)

# Heatmap coloured by row fraction and annotated with fraction + count
sns.heatmap(confusion_matrix_normalized, annot=custom_annotations, fmt="", cmap='viridis', ax=ax, vmin=0, vmax=1, cbar_kws={"shrink": 0.65})
ax.set_ylabel('True stages', fontsize=14)
ax.set_xlabel('Predicted stages', fontsize=14)
ax.tick_params(labelsize=14)
ax.set_box_aspect(1)

# Set custom tick labels for x and y axes to use abbreviations
ax.set_xticklabels(desired_order, rotation=0)
ax.set_yticklabels(desired_order, rotation=0)

# Add colorbar labels and adjust size
cbar = ax.collections[0].colorbar  # Access the colorbar
cbar.set_label('Normalized scale', fontsize=14)
cbar.set_ticks([0, 0.25, 0.5, 0.75, 1])
cbar.ax.tick_params(labelsize=10)  # Adjust tick size

plt.tight_layout()

# Save the figure; savefig does not create missing folders
os.makedirs('paper_figures', exist_ok=True)
plt.savefig('paper_figures/confmat_bracs_test_combined_percentages_resnet50_simplehead.png')
plt.savefig('paper_figures/confmat_bracs_test_combined_percentages_resnet50_simplehead.svg', format='svg')

# Show the plot
plt.show()