# ROC and Precision-Recall Curve Plots for Model Performance

This notebook generates ROC and precision-recall (PR) curve plots for each task-dataset combination, overlaying all models on the same axes and reporting each model's area under the curve (AUROC or AUPRC) in the legend.

In [274]:
# Import Required Libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from scipy.special import expit  # Sigmoid function for logits

In [275]:
# Location of the prediction CSVs, and the kind of score each model writes.
# 'probability' scores are used as-is; 'logit' scores are passed through a
# sigmoid when they are loaded.
predictions_path = '/Users/sophiaehlers/Documents/pulse/output/predictions'

_probability_models = ('RandomForest', 'LightGBM', 'XGBoost')
_logit_models = ('CNN', 'InceptionTime', 'LSTM', 'GRU')

model_types = {
    **dict.fromkeys(_probability_models, 'probability'),
    **dict.fromkeys(_logit_models, 'logit'),
}

In [276]:
# Read one prediction file and return (scores, labels)
def load_predictions(file_path, model_type):
    """Load the 'predictions' and 'labels' columns from a CSV file.

    Args:
        file_path: Path of the CSV file to read.
        model_type: When equal to 'logit', the predictions are mapped
            through the sigmoid (``expit``); any other value leaves them
            untouched.

    Returns:
        A ``(predictions, labels)`` tuple of pandas Series.
    """
    frame = pd.read_csv(file_path)
    scores = frame['predictions']
    if model_type == 'logit':
        # Convert raw logits into probabilities.
        scores = expit(scores)
    return scores, frame['labels']

# Collect all task-dataset combinations
# File names are expected to start with "<model>_<task>_<dataset>", optionally
# followed by further "_"-separated fields and a file extension.
# Sorting makes the per-combination model order deterministic (os.listdir
# returns entries in arbitrary order).
files = sorted(os.listdir(predictions_path))
combinations = {}
for file in files:
    file_path = os.path.join(predictions_path, file)
    # Skip hidden entries (e.g. macOS ".DS_Store") and sub-directories; they
    # are not prediction files and used to crash the parsing below.
    if file.startswith('.') or not os.path.isfile(file_path):
        continue
    # Split the stem rather than the full name so a three-field name such as
    # "CNN_mortality_hirid.csv" yields dataset "hirid", not "hirid.csv".
    parts = os.path.splitext(file)[0].split('_')
    if len(parts) < 3:
        # Does not follow the naming scheme; ignore instead of IndexError.
        continue
    model, task, dataset = parts[:3]
    key = (task, dataset)
    combinations.setdefault(key, []).append((model, file_path))

In [277]:
# Display names for datasets and tasks (raw key -> label used in plot titles)
dataset_name_mapping = dict(hirid="HiRID", miiv="MIMIC-IV", eicu="eICU")
task_name_mapping = dict(mortality="Mortality", aki="AKI", sepsis="Sepsis")

# Line colour per model; the key order also drives the legend ordering
model_color_mapping = dict(
    RandomForest='#0073e6',
    LightGBM='#003d99',
    XGBoost='#00bfff',
    CNN='#a6d854',
    InceptionTime='#32CD32',
    LSTM='#228B22',
    GRU='#006400',
)

# Row (task) and column (dataset) order of the combined figure; this is the
# insertion order of the display-name mappings above
task_order = list(task_name_mapping)
dataset_order = list(dataset_name_mapping)

In [278]:
# Function to create plots (ROC or PRC)
def create_plot(ax, task, dataset, model_files, curve_type):
    """Draw one ROC or precision-recall panel comparing all models on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        task: Task key; mapped through ``task_name_mapping`` for the title.
        dataset: Dataset key; mapped through ``dataset_name_mapping`` for the
            title.
        model_files: Iterable of ``(model_name, csv_path)`` pairs. Each model
            name must be a key of ``model_types``.
        curve_type: ``'roc'`` for a ROC curve or ``'prc'`` for a
            precision-recall curve.

    Raises:
        ValueError: If ``curve_type`` is neither ``'roc'`` nor ``'prc'``.
        KeyError: If a model name is missing from ``model_types``.
    """
    # Resolve the axis labels up front so they are defined even when
    # model_files is empty, and so an invalid curve_type fails clearly.
    if curve_type == 'roc':
        xlabel, ylabel = 'False Positive Rate', 'True Positive Rate'
    elif curve_type == 'prc':
        xlabel, ylabel = 'Recall', 'Precision'
    else:
        raise ValueError(f"curve_type must be 'roc' or 'prc', got {curve_type!r}")

    prevalence = None  # fraction of positive labels, needed for the PR baseline
    for model, file_path in model_files:
        predictions, y_true = load_predictions(file_path, model_types[model])
        if curve_type == 'roc':
            x, y, _ = roc_curve(y_true, predictions)
        else:
            # precision_recall_curve returns (precision, recall, thresholds);
            # recall goes on the x-axis.
            y, x, _ = precision_recall_curve(y_true, predictions)
            if prevalence is None:
                # Assumes the positive class is encoded as 1 and that all
                # models of a task/dataset share the same labels -- the first
                # file is used as the reference. TODO confirm.
                prevalence = float(np.mean(np.asarray(y_true) == 1))
        # auc() integrates with the trapezoidal rule.
        auc_value = auc(x, y)
        color = model_color_mapping.get(model, '#000000')
        ax.plot(x, y, label=f'{model} (AUC = {auc_value:.2f})', color=color)

    # Add random guessing line
    if curve_type == 'roc':
        ax.plot([0, 1], [0, 1], 'r--', label='Random Guessing')
    elif prevalence is not None:
        # A no-skill classifier has precision equal to the positive-class
        # prevalence at every recall level, not a fixed 0.5.
        ax.plot([0, 1], [prevalence, prevalence], 'r--', label='Random Guessing')

    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_title(f'{task_name_mapping.get(task, task)} - {dataset_name_mapping.get(dataset, dataset)}', fontsize=18, fontweight='bold', pad=10)
    ax.grid(True)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    # Adjust legend ordering: baseline first, then models in colour-map order;
    # anything not listed there is placed last instead of raising.
    handles, legend_labels = ax.get_legend_handles_labels()
    if handles:
        legend_rank = {
            name: position
            for position, name in enumerate(['Random Guessing'] + list(model_color_mapping))
        }
        ordered = sorted(
            zip(handles, legend_labels),
            key=lambda entry: legend_rank.get(entry[1].split(' (')[0], len(legend_rank)),
        )
        sorted_handles, sorted_labels = zip(*ordered)
        ax.legend(sorted_handles, sorted_labels, loc='lower right', fontsize=12)

In [279]:
# Function to generate and save plots
def generate_plots(curve_type, output_suffix):
    """Render and save one figure per task/dataset plus a combined grid.

    For every entry of the module-level ``combinations`` a single-panel
    figure is written to ``{task}_{dataset}_{output_suffix}.png``; afterwards
    a ``len(task_order)`` x ``len(dataset_order)`` grid is written to
    ``combined_{output_suffix}.png``. All files go to the module-level
    ``output_dir``.

    Args:
        curve_type: ``'roc'`` or ``'prc'``; forwarded to ``create_plot``.
        output_suffix: Suffix used in the output PNG file names.
    """
    # Generate and save individual plots
    for (task, dataset), model_files in combinations.items():
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            create_plot(ax, task, dataset, model_files, curve_type)
            fig.savefig(os.path.join(output_dir, f'{task}_{dataset}_{output_suffix}.png'))
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

    # Create a combined figure with all individual plots.
    # squeeze=False keeps `axes` two-dimensional even for a single row/column.
    fig, axes = plt.subplots(
        len(task_order), len(dataset_order), figsize=(20, 20), squeeze=False
    )
    try:
        fig.suptitle(f'Combined {curve_type.upper()} Curves for All Tasks and Datasets', fontsize=24, y=0.96)
        fig.subplots_adjust(hspace=0.4, wspace=0.4, top=0.95)
        for i, task in enumerate(task_order):
            for j, dataset in enumerate(dataset_order):
                ax = axes[i, j]
                model_files = combinations.get((task, dataset))
                if model_files:
                    create_plot(ax, task, dataset, model_files, curve_type)
                else:
                    # Hide grid cells with no predictions instead of leaving
                    # an empty frame with ticks.
                    ax.axis('off')
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        fig.savefig(os.path.join(output_dir, f'combined_{output_suffix}.png'))
    finally:
        plt.close(fig)

# Produce the ROC and the precision-recall figure sets.
# The output folder is created on demand before any figure is written.
output_dir = '/Users/sophiaehlers/Documents/pulse/visualizations/roc_prc_curves'
os.makedirs(output_dir, exist_ok=True)
for curve_kind, file_suffix in (('roc', 'roc_curve'), ('prc', 'prc_curve')):
    generate_plots(curve_kind, file_suffix)