In [1]:
# Standard-library, scientific-stack and project imports for the ensemble evaluation notebook.
import gc
import json
import os, h5py
import math
import multiprocessing
import numpy as np
import pandas as pd
import torch
import importlib
import logging
from itertools import combinations
from pathlib import Path
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from sklearn.utils import resample
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
# NOTE(review): multiprocessing is imported twice (plain above and aliased here); only the alias is used in this cell.
import multiprocessing as mp
# Select the 'spawn' start method for worker processes; force=True replaces any start method
# already set in this kernel, so re-running the cell does not raise.
# NOTE(review): presumably chosen because the workers touch CUDA/torch state - confirm.
mp.set_start_method('spawn', force=True)

# Custom modules for data handling, balancing, training, evaluation, and model architectures
import modeleval
import avg_ensemble

# Reload the project modules so edits made on disk are picked up without restarting the kernel.
# The reload happens before the `from ... import` statements below so those names bind to the fresh code.
importlib.reload(modeleval)
importlib.reload(avg_ensemble)

# Import specific functions from custom modules to keep code clean and readable
from modeleval import (
    dh_test_model, nam_dagostino_chi2, get_baseline_hazard_at_timepoints, combined_test_model, evaluate_cif_predictions
)
from avg_ensemble import (
    load_bootstrap_predictions, generate_combinations, get_cif_for_combination, ensemble_cif_arrays, save_metrics_to_individual_json,
    get_processed_combinations, process_combination, process_combination_worker
)

import psutil
# Release PyTorch's cached GPU memory and run a garbage-collection pass.
# gc.collect() is the last expression of the cell, so its return value is what the cell displays.
torch.cuda.empty_cache()
gc.collect()

17

In [2]:
# Notebook-wide configuration: seed, split count, column roles, feature groups, time grid, model labels
RANDOM_SEED = 12345
N_SPLIT = 2

# Column roles
DURATION_COL = 'date_from_sub_60'
EVENT_COL = 'endpoint'
CLUSTER_COL = 'key'

# Feature groups (the group names suggest the preprocessing applied to each - TODO confirm)
CAT_FEATURES = ['gender', 'dm', 'ht', 'sprint']
LOG_FEATURES = ['a1c', 'po4', 'UACR_mg_g', 'Cr']
STANDARD_FEATURES = ['age', 'alb', 'ca', 'hb', 'hco3']
PASSTHROUGH_FEATURES = [CLUSTER_COL, DURATION_COL, EVENT_COL]

# Full feature list: categorical, then log, then standard features, in that order
FEATURE_COLS = [*CAT_FEATURES, *LOG_FEATURES, *STANDARD_FEATURES]

# Six grid points spaced 365 apart: 0, 365, ..., 1825
TIME_GRID = np.arange(6) * 365

# Labels of the CIF prediction arrays that can be ensembled
cif_array_labels = [
    # labels prefixed "deepsurv_"
    "deepsurv_ann_clustering",
    "deepsurv_ann_enn",
    "deepsurv_ann_tomek",
    "deepsurv_lstm_clustering",
    "deepsurv_lstm_NearMiss",
    # labels prefixed "deephit_"
    "deephit_ann_clustering",
    "deephit_ann_NearMiss",
    "deephit_lstm_clustering",
    "deephit_lstm_NearMiss",
]

In [3]:
# Directory containing the bootstrap HDF5 files
# NOTE(review): both paths are absolute, machine-specific locations; consider a configurable base directory.
directory_path = '/mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/'
output_dir = "/mnt/d/PYDataScience/g3_regress/data/results/full_ensemble"

# Load all bootstrap predictions
bootstrap_data = load_bootstrap_predictions(directory_path)
# Enumerate the model-label combinations to evaluate (produced by the project helper).
all_combinations = generate_combinations(cif_array_labels)

# Unpack the three entries read from the loaded bundle.
# NOTE(review): the exact container types (dict vs. list, keyed by iteration?) are defined in avg_ensemble - confirm.
predictions = bootstrap_data["predictions"]
durations = bootstrap_data["durations"]
events = bootstrap_data["events"]
# Keep only the combinations that contain 9 labels (i.e. as many as cif_array_labels holds).
combo = [i for i in all_combinations if len(i) == 9]


Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstrap_iteration_1.h5...
Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstrap_iteration_10.h5...
Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstrap_iteration_100.h5...
Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstrap_iteration_101.h5...
Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstrap_iteration_102.h5...
Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstrap_iteration_103.h5...
Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstrap_iteration_104.h5...
Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstrap_iteration_105.h5...
Loading /mnt/d/PYDataScience/g3_regress/data/results/bootstrap_predictions/20241130/bootstr

In [None]:
def calculate_null_brier(durations, events, time_grid):
    """
    Calculate the null (covariate-free) integrated Brier score for every competing event.

    For each non-zero event code, a cumulative incidence function (CIF) is fitted on the
    whole sample, interpolated onto ``time_grid``, and assigned unchanged to every subject.
    The resulting "one curve for everybody" prediction is scored with ``EvalSurv``.

    NOTE(review): ``AalenJohansenFitter`` and ``EvalSurv`` must already be available in the
    notebook namespace (presumably from lifelines and pycox); no visible cell imports them.

    Args:
        durations (array-like): Array of durations (time to event or censoring).
        events (array-like): Array of event indicators (0 for censored, 1 for Event 1, 2 for Event 2, etc.).
        time_grid (array-like): Time points at which the CIF is evaluated and over which the
            Brier score is integrated. Must use the same time unit as ``durations``.

    Returns:
        dict: Maps each event code (1, 2, ...) to its integrated Brier score (a single float).
              Empty when ``events`` contains only the censoring code 0.
    """
    # Work on arrays so that `events == code` is always an element-wise comparison,
    # whatever container (list, Series, ndarray) the caller passed in.
    durations = np.asarray(durations)
    events = np.asarray(events)
    time_grid = np.asarray(time_grid)
    n_subjects = len(durations)

    # Dictionary to store results for each event
    null_brier_scores = {}

    for event_of_interest in np.unique(events):
        if event_of_interest == 0:
            # Skip the censored group
            continue

        print(f"Calculating null Brier scores for Event_{event_of_interest}...")

        # Fit Aalen-Johansen estimator
        ajf = AalenJohansenFitter()
        ajf.fit(durations, events, event_of_interest=event_of_interest)

        # Always interpolate the CIF onto the requested grid. The previous shortcut reused the
        # raw CIF whenever it merely had as many rows as the grid, which mislabelled the values
        # with grid times even if the fitted times were different.
        cif = ajf.cumulative_density_
        cif_times = np.asarray(cif.index.values, dtype=float)
        cif_values = np.asarray(cif.values, dtype=float).reshape(-1)
        order = np.argsort(cif_times, kind="stable")  # np.interp requires increasing x
        # Linear interpolation, clamped to the first/last CIF value outside the fitted range
        # (same rule as the original interp1d call).
        interpolated_cif = np.interp(time_grid, cif_times[order], cif_values[order])

        # Same curve for every subject: rows are grid times, columns are subjects.
        surv_probs = 1 - interpolated_cif
        surv_df = pd.DataFrame(np.tile(surv_probs[:, None], (1, n_subjects)), index=time_grid)

        # Evaluate using EvalSurv
        ev = EvalSurv(surv_df, durations, events == event_of_interest, censor_surv="km")

        # Integrated Brier score over the grid
        null_brier_scores[event_of_interest] = ev.integrated_brier_score(time_grid)

    return null_brier_scores

In [58]:
# Compute the null (covariate-free) integrated Brier score for every bootstrap iteration.
# NOTE(review): this grid is 0..5 whereas TIME_GRID in the constants cell is 0..1825 in steps of 365;
# confirm that the bootstrap durations are expressed in the same unit as this grid.
time_grid = np.array([0, 1, 2, 3, 4, 5])
null_brier_scores = []

predictions = bootstrap_data["predictions"]
durations = bootstrap_data["durations"]
events = bootstrap_data["events"]

# `predictions` only drives the iteration count here: itr_predictions itself is never used,
# and durations/events are looked up by the running position i.
for i, itr_predictions in enumerate(predictions.values()):
    itr_durations = durations[i]
    itr_events = events[i]
    
    # Calculate null Brier score for this bootstrap iteration
    # (a dict mapping each event code to its integrated Brier score)
    null_brier_score = calculate_null_brier(itr_durations, itr_events, time_grid)
    null_brier_scores.append(null_brier_score)
    
    print(f"Bootstrap {i+1}: Integrated Null Brier Score = {null_brier_score}")


Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 1: Integrated Null Brier Score = {1: 0.061694604425891364, 2: 0.10291029302777473}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 2: Integrated Null Brier Score = {1: 0.07210566705267989, 2: 0.0906691929506529}
Calculating null Brier scores for Event_1...


                To resolve ties, data is randomly jittered.


Calculating null Brier scores for Event_2...


                To resolve ties, data is randomly jittered.


Bootstrap 3: Integrated Null Brier Score = {1: 0.07893109559757475, 2: 0.09340716801582283}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 4: Integrated Null Brier Score = {1: 0.05985476890747403, 2: 0.08604777724318272}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 5: Integrated Null Brier Score = {1: 0.05953871538254415, 2: 0.08898374804288758}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 6: Integrated Null Brier Score = {1: 0.0662235607115377, 2: 0.07880531947207751}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 7: Integrated Null Brier Score = {1: 0.06925440752583926, 2: 0.09344632502149115}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 8: Integrated Null Brier Score = {1: 0.058303067926930095, 2: 0.10317787634300475

                To resolve ties, data is randomly jittered.


Calculating null Brier scores for Event_2...


                To resolve ties, data is randomly jittered.


Bootstrap 105: Integrated Null Brier Score = {1: 0.08445420556096903, 2: 0.08835524153163987}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 106: Integrated Null Brier Score = {1: 0.06778723162634208, 2: 0.08745863817273374}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 107: Integrated Null Brier Score = {1: 0.06359017364349637, 2: 0.08573206442027229}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 108: Integrated Null Brier Score = {1: 0.06601062860423751, 2: 0.08964876893660115}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 109: Integrated Null Brier Score = {1: 0.060067066891537256, 2: 0.0847283153504141}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 110: Integrated Null Brier Score = {1: 0.06647429118566725, 2: 0.07948

                To resolve ties, data is randomly jittered.


Calculating null Brier scores for Event_2...


                To resolve ties, data is randomly jittered.


Bootstrap 121: Integrated Null Brier Score = {1: 0.07706042845221608, 2: 0.0999302727151227}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 122: Integrated Null Brier Score = {1: 0.07822699235788835, 2: 0.09013472462302834}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 123: Integrated Null Brier Score = {1: 0.06221369744241547, 2: 0.09406782840816205}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 124: Integrated Null Brier Score = {1: 0.058284786086219585, 2: 0.07920478411221003}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 125: Integrated Null Brier Score = {1: 0.07539813469140054, 2: 0.08111614703414118}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 126: Integrated Null Brier Score = {1: 0.060786181748507306, 2: 0.0842

                To resolve ties, data is randomly jittered.


Calculating null Brier scores for Event_2...


                To resolve ties, data is randomly jittered.


Bootstrap 246: Integrated Null Brier Score = {1: 0.0805294558776731, 2: 0.08962738824408509}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 247: Integrated Null Brier Score = {1: 0.07575143983958713, 2: 0.09226609839282082}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 248: Integrated Null Brier Score = {1: 0.07009836628334166, 2: 0.08588168423349751}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 249: Integrated Null Brier Score = {1: 0.06447197218308862, 2: 0.07773426859419047}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 250: Integrated Null Brier Score = {1: 0.06967561700599466, 2: 0.0830420445493337}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 251: Integrated Null Brier Score = {1: 0.06357018865741326, 2: 0.0856756

                To resolve ties, data is randomly jittered.


Calculating null Brier scores for Event_2...


                To resolve ties, data is randomly jittered.


Bootstrap 257: Integrated Null Brier Score = {1: 0.07895871342383268, 2: 0.08305716354398847}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 258: Integrated Null Brier Score = {1: 0.08274513260132664, 2: 0.0874767031457569}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 259: Integrated Null Brier Score = {1: 0.06287294868684275, 2: 0.09010654646769185}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 260: Integrated Null Brier Score = {1: 0.06064386090175462, 2: 0.09206337174128597}
Calculating null Brier scores for Event_1...


                To resolve ties, data is randomly jittered.


Calculating null Brier scores for Event_2...


                To resolve ties, data is randomly jittered.


Bootstrap 261: Integrated Null Brier Score = {1: 0.07726041144660391, 2: 0.08542172435577444}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 262: Integrated Null Brier Score = {1: 0.05767979784463293, 2: 0.08306705061666934}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 263: Integrated Null Brier Score = {1: 0.060354690829223845, 2: 0.0902281608517779}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 264: Integrated Null Brier Score = {1: 0.05967383665898949, 2: 0.09509685052347516}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 265: Integrated Null Brier Score = {1: 0.06239564437634243, 2: 0.09193401093435759}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 266: Integrated Null Brier Score = {1: 0.06307044627444286, 2: 0.08873

                To resolve ties, data is randomly jittered.


Calculating null Brier scores for Event_2...


                To resolve ties, data is randomly jittered.


Bootstrap 456: Integrated Null Brier Score = {1: 0.0807864409677053, 2: 0.07997279079554326}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 457: Integrated Null Brier Score = {1: 0.06479502451138211, 2: 0.08219478243406333}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 458: Integrated Null Brier Score = {1: 0.060889287538215986, 2: 0.09211112747461467}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 459: Integrated Null Brier Score = {1: 0.05870993923116006, 2: 0.09131757573597363}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 460: Integrated Null Brier Score = {1: 0.06258178103666723, 2: 0.07726042770775143}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 461: Integrated Null Brier Score = {1: 0.06855281160884322, 2: 0.08550

                To resolve ties, data is randomly jittered.


Calculating null Brier scores for Event_2...


                To resolve ties, data is randomly jittered.


Bootstrap 497: Integrated Null Brier Score = {1: 0.07039862783962784, 2: 0.0882040080568875}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 498: Integrated Null Brier Score = {1: 0.06192550916422801, 2: 0.07564096203065898}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 499: Integrated Null Brier Score = {1: 0.05665736533987973, 2: 0.10004783008623482}
Calculating null Brier scores for Event_1...
Calculating null Brier scores for Event_2...
Bootstrap 500: Integrated Null Brier Score = {1: 0.06201496183793596, 2: 0.09382599532172556}


In [59]:
# Display the collected per-bootstrap results: one {event code: integrated Brier score} dict per iteration.
# NOTE(review): this renders the whole list; a slice such as null_brier_scores[:5] keeps the output short.
null_brier_scores

[{1: 0.061694604425891364, 2: 0.10291029302777473},
 {1: 0.07210566705267989, 2: 0.0906691929506529},
 {1: 0.07893109559757475, 2: 0.09340716801582283},
 {1: 0.05985476890747403, 2: 0.08604777724318272},
 {1: 0.05953871538254415, 2: 0.08898374804288758},
 {1: 0.0662235607115377, 2: 0.07880531947207751},
 {1: 0.06925440752583926, 2: 0.09344632502149115},
 {1: 0.058303067926930095, 2: 0.10317787634300475},
 {1: 0.07753072055776787, 2: 0.0899862422467708},
 {1: 0.06444143041009415, 2: 0.08800857872645236},
 {1: 0.06025304805359072, 2: 0.09391560770601028},
 {1: 0.05388229659990389, 2: 0.0938341536049427},
 {1: 0.056584908950354264, 2: 0.08517260786758754},
 {1: 0.06810491845430788, 2: 0.08843183755727183},
 {1: 0.06254903807462206, 2: 0.08475460933239316},
 {1: 0.06605179627314803, 2: 0.0789107716631382},
 {1: 0.06648111280130942, 2: 0.0908625736255064},
 {1: 0.06631804934016197, 2: 0.09180484059470054},
 {1: 0.07049810422128762, 2: 0.09012735310122381},
 {1: 0.06175909173500417, 2: 0.102

In [9]:
# Show the combinations that contain 9 labels (same filter that defines `combo` in the loading cell).
[i for i in all_combinations if len(i) == 9]

[('deepsurv_ann_clustering',
  'deepsurv_ann_enn',
  'deepsurv_ann_tomek',
  'deepsurv_lstm_clustering',
  'deepsurv_lstm_NearMiss',
  'deephit_ann_clustering',
  'deephit_ann_NearMiss',
  'deephit_lstm_clustering',
  'deephit_lstm_NearMiss')]

In [2]:
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem, t, ttest_1samp

def load_metrics_to_dataframe(output_dir):
    """
    Load all JSON metrics from the specified directory and organize them into a Pandas DataFrame.

    Every ``*.json`` file is expected to hold an object that maps a combination key to a
    dictionary of metrics. Each (file, combination) pair becomes one row; the metric
    dictionary is spread into columns next to "Bootstrap" (the file name) and "Combination"
    (the key, still a string at this point).

    Files that are not valid JSON are reported on stdout and skipped.

    Args:
        output_dir (str): Directory containing JSON files.

    Returns:
        pd.DataFrame: DataFrame containing all metrics, sorted by combination. Rows sharing a
        combination keep file-name order. When no metrics are found, an empty frame with the
        columns "Bootstrap" and "Combination" is returned.
    """
    metrics_list = []

    # Sorted listing: os.listdir order is arbitrary, and a stable order makes the result reproducible
    for file_name in sorted(os.listdir(output_dir)):
        if not file_name.endswith(".json"):  # Only process JSON files
            continue

        file_path = os.path.join(output_dir, file_name)
        try:
            with open(file_path, "r", encoding="utf-8") as infile:
                data = json.load(infile)
        except json.JSONDecodeError as e:
            print(f"Error decoding {file_path}: {e}")
            continue

        for combo, metrics in data.items():
            metrics_list.append({
                "Bootstrap": file_name,
                "Combination": combo,
                **metrics
            })

    if metrics_list:
        # Stable sort so that ties on "Combination" keep the file-name order established above
        metrics_df = (
            pd.DataFrame(metrics_list)
            .sort_values(by="Combination", kind="stable")
            .reset_index(drop=True)
        )
    else:
        # Nothing loaded: sorting a column-less frame by "Combination" would raise KeyError
        metrics_df = pd.DataFrame(columns=["Bootstrap", "Combination"])

    print(f"Loaded metrics from {len(metrics_df)} combinations.")
    return metrics_df

def separate_event_metrics(df):
    """
    Turn one row per (bootstrap, combination) into one row per (bootstrap, combination, event).

    The nested per-event metrics are pulled out for Event_1 and Event_2, and the
    Nam-D'Agostino entries are filtered down to the matching event code.

    Args:
        df (pd.DataFrame): Frame with the columns "Bootstrap", "Combination",
            "concordance_indices", "integrated_brier_scores", "neg_log_likelihoods"
            (each a mapping keyed by "Event_1"/"Event_2") and "nam_dagostino_results"
            (a list of dicts carrying an "Event" entry).

    Returns:
        pd.DataFrame: All Event_1 rows followed by all Event_2 rows, with a fresh index and
        an "Event" column holding "Event_1" or "Event_2".
    """
    per_event_frames = []

    for event_code in (1, 2):
        event_key = f"Event_{event_code}"

        # One flat record per input row for this event
        records = [
            {
                "Bootstrap": row["Bootstrap"],
                "Combination": row["Combination"],
                "Concordance_Index": row["concordance_indices"][event_key],
                "Integrated_Brier_Score": row["integrated_brier_scores"][event_key],
                "Neg_Log_Likelihood": row["neg_log_likelihoods"][event_key],
                # Keep only the calibration entries that belong to this event
                "Nam_Dagostino": [
                    entry for entry in row["nam_dagostino_results"] if entry["Event"] == event_code
                ],
            }
            for _, row in df.iterrows()
        ]

        event_frame = pd.DataFrame(records)
        event_frame["Event"] = event_key
        per_event_frames.append(event_frame)

    # Stack Event_1 on top of Event_2 and renumber the rows
    return pd.concat(per_event_frames, axis=0).reset_index(drop=True)

def sort_metrics_by_combination_length(df):
    """
    Order the metrics by how many models each combination contains.

    Note that the "Combination_Length" column is written onto the frame that was passed in,
    so the caller's DataFrame gains that column as a side effect.

    Args:
        df (pd.DataFrame): Metrics frame whose "Combination" column holds tuples.

    Returns:
        pd.DataFrame: Copy sorted by "Combination_Length" and then "Combination" (both
        ascending) with a fresh index.

    Raises:
        AssertionError: If any "Combination" entry is not a tuple.
    """
    combos = df["Combination"]

    # len() of a string key would count characters, so insist on real tuples
    assert all(isinstance(entry, tuple) for entry in combos), "Combination column must contain tuples."

    # Number of models in each combination (stored on the input frame)
    df["Combination_Length"] = combos.apply(len)

    # Primary key: size of the combination; secondary key: the combination itself, for a stable layout
    ordered = df.sort_values(by=["Combination_Length", "Combination"], ascending=True)
    return ordered.reset_index(drop=True)

def calculate_statistics_per_combination_length(
    df,
    metrics=("Concordance_Index", "Integrated_Brier_Score", "Neg_Log_Likelihood"),
    popmean=0.5,
):
    """
    Calculate the mean, 95% CI, and p-value of each metric, grouped by combination length and event.

    The confidence interval is a Student-t interval around the group mean. The p-value comes
    from a one-sample t-test of the group values against a reference value.

    Args:
        df (pd.DataFrame): Input DataFrame with the columns "Combination_Length", "Event"
            and one column per entry of ``metrics``.
        metrics (iterable of str): Metric columns to summarise. Defaults to the three metrics
            that were previously hard-coded.
        popmean (float or dict): Reference value for the one-sample t-test. A single number is
            applied to every metric (default 0.5, the previous hard-coded value). A dict maps
            metric name -> reference value; metrics missing from the dict fall back to 0.5.
            NOTE(review): 0.5 is a natural null only for the concordance index; pass
            metric-specific references for Brier score / log-likelihood if the p-values are reported.

    Returns:
        pd.DataFrame: One row per (Combination_Length, Event) with the columns
        "<metric>_Mean", "<metric>_CI_Low", "<metric>_CI_High" and "<metric>_P_Value".
        Groups with a single observation keep their mean but get NaN for the interval and p-value.
    """
    results = []

    # Group by Combination_Length and Event
    for (comb_length, event), group in df.groupby(["Combination_Length", "Event"]):
        stats = {
            "Combination_Length": comb_length,
            "Event": event,
        }

        for metric in metrics:
            values = group[metric].values
            reference = popmean.get(metric, 0.5) if isinstance(popmean, dict) else popmean
            mean = np.mean(values)

            if len(values) > 1:
                std_err = sem(values)  # Standard error of the mean
                ci_low, ci_high = t.interval(0.95, len(values) - 1, loc=mean, scale=std_err)  # 95% CI
                _, p_value = ttest_1samp(values, reference)
            else:
                # Spread is undefined for a single observation; skip the scipy calls
                # (which would only emit warnings) and report NaN explicitly.
                ci_low = ci_high = p_value = np.nan

            stats[f"{metric}_Mean"] = mean
            stats[f"{metric}_CI_Low"] = ci_low
            stats[f"{metric}_CI_High"] = ci_high
            stats[f"{metric}_P_Value"] = p_value

        results.append(stats)

    return pd.DataFrame(results)

def plot_metric_vs_combination_length(df, metric_name, event_col="Event"):
    """
    Draw the mean of one metric, with its 95% confidence band, against the ensemble size.

    One line and one shaded band are drawn per distinct value of ``event_col``, all on the
    same figure.

    Args:
        df (pd.DataFrame): Summary frame holding "Combination_Length", the event column and
            the "<metric_name>_Mean", "<metric_name>_CI_Low", "<metric_name>_CI_High" columns.
        metric_name (str): The base name of the metric to plot (e.g., "Concordance_Index").
        event_col (str): The column representing the event type (default is "Event").

    Returns:
        None: Displays the plot.

    Raises:
        ValueError: If any of the required metric/event columns is missing.
    """
    mean_col, low_col, high_col = (
        f"{metric_name}_{suffix}" for suffix in ("Mean", "CI_Low", "CI_High")
    )

    # Fail early, before any figure is created
    if any(name not in df.columns for name in (mean_col, low_col, high_col, event_col)):
        raise ValueError(f"One or more required columns for '{metric_name}' are missing in the DataFrame.")

    events_present = df[event_col].unique()
    plt.figure(figsize=(10, 6))

    for event in events_present:
        rows = df[df[event_col] == event]
        sizes = rows["Combination_Length"]

        # Mean curve plus shaded confidence band for this event
        plt.plot(sizes, rows[mean_col], label=f"{metric_name} (Event {event})")
        plt.fill_between(sizes, rows[low_col], rows[high_col], alpha=0.2, label=f"95% CI (Event {event})")

    # Titles, axis labels and legend
    plt.title(f"{metric_name} vs. Number of Model Predictions Ensembled")
    plt.xlabel("Number of Model Predictions Ensembled")
    plt.ylabel(metric_name)
    plt.legend()
    plt.grid(False)
    plt.show()

def reshape_brier_scores(metrics_df):
    """
    Flatten the nested per-event Brier series into one column per event and time index.

    Each entry of "brier_series" maps an event name (e.g. "Event_1") to a sequence of scores.
    Score number ``k`` of event ``E`` ends up in the column "<e lower-cased>_year_<k>",
    for example "event_1_year_0".

    Note that "Combination_Length" is written onto ``metrics_df`` itself, so the caller's
    frame gains that column as a side effect.

    Args:
        metrics_df (pd.DataFrame): Frame with "Bootstrap", "Combination" (tuples) and
            "brier_series" columns.

    Returns:
        pd.DataFrame: One row per input row with "Bootstrap", "Combination",
        "Combination_Length" and the flattened Brier columns.

    Raises:
        AssertionError: If any "Combination" entry is not a tuple.
    """
    combos = metrics_df["Combination"]

    # len() of a string key would count characters, so insist on real tuples
    assert all(isinstance(entry, tuple) for entry in combos), "Combination column must contain tuples."

    # Number of models in each combination (stored on the input frame)
    metrics_df["Combination_Length"] = combos.apply(len)

    flat_rows = []
    for _, record in metrics_df.iterrows():
        # Identification columns first ...
        flat = {key: record[key] for key in ("Bootstrap", "Combination", "Combination_Length")}

        # ... then one column per (event, position in the series)
        flat.update(
            (f"{event_name.lower()}_year_{position}", value)
            for event_name, series in record["brier_series"].items()
            for position, value in enumerate(series)
        )
        flat_rows.append(flat)

    return pd.DataFrame(flat_rows)

def calculate_brier_statistics(
    reshaped_brier_df,
    events=("event_1", "event_2"),
    n_time_points=6,
    popmean=0.0,
):
    """
    Calculate the mean, 95% CI, and p-value for Brier scores for each combination length and time point.

    For every combination length, event and time index, the column
    "<event>_year_<index>" is summarised with its mean, a Student-t 95% confidence interval
    and the p-value of a one-sample t-test against ``popmean``.

    Args:
        reshaped_brier_df (pd.DataFrame): DataFrame with reshaped Brier scores; must contain
            "Combination_Length" and the flattened "<event>_year_<index>" columns.
        events (iterable of str): Event prefixes to summarise (default: "event_1", "event_2").
        n_time_points (int): Number of time indices per event, starting at 0 (default: 6).
        popmean (float): Reference value of the one-sample t-test (default 0.0, as before).

    Returns:
        pd.DataFrame: One row per (Combination_Length, Event, Time_Point) with the columns
        "Brier_Mean", "Brier_CI_Low", "Brier_CI_High" and "P_Value". All four are NaN when
        fewer than two non-missing scores are available, including when the corresponding
        column does not exist in the input.
    """
    aggregated_data = []

    for combination_length in reshaped_brier_df["Combination_Length"].unique():
        subset = reshaped_brier_df[reshaped_brier_df["Combination_Length"] == combination_length]

        for event in events:
            for year in range(n_time_points):
                column_name = f"{event}_year_{year}"

                # A series shorter than n_time_points simply has no such column;
                # treat that like "no observations" instead of failing with KeyError.
                if column_name in subset.columns:
                    values = subset[column_name].dropna()
                else:
                    values = pd.Series(dtype=float)

                if len(values) > 1:
                    mean = values.mean()
                    std_err = values.std(ddof=1) / np.sqrt(len(values))
                    ci_low, ci_high = t.interval(0.95, len(values) - 1, loc=mean, scale=std_err)
                    _, p_value = ttest_1samp(values, popmean)
                else:
                    # Not enough data for a spread estimate
                    mean, ci_low, ci_high, p_value = [np.nan] * 4

                aggregated_data.append({
                    "Combination_Length": combination_length,
                    "Event": event,
                    "Time_Point": year,
                    "Brier_Mean": mean,
                    "Brier_CI_Low": ci_low,
                    "Brier_CI_High": ci_high,
                    "P_Value": p_value
                })

    return pd.DataFrame(aggregated_data)

def plot_brier_scores(brier_stats_df):
    """
    Draw one figure per event showing the mean Brier score over the time points.

    Within each figure there is one line (with a shaded band between "Brier_CI_Low" and
    "Brier_CI_High") per combination length. Only the events "event_1" and "event_2" are plotted.

    Args:
        brier_stats_df (pd.DataFrame): DataFrame with Brier score statistics; needs the
            columns "Event", "Combination_Length", "Time_Point", "Brier_Mean",
            "Brier_CI_Low" and "Brier_CI_High".
    """
    for event in ["event_1", "event_2"]:
        per_event = brier_stats_df[brier_stats_df["Event"] == event]
        plt.figure(figsize=(10, 6))

        for n_models in per_event["Combination_Length"].unique():
            curve = per_event[per_event["Combination_Length"] == n_models]
            time_points = curve["Time_Point"]

            # Mean line, then the confidence band (the band carries no legend entry)
            plt.plot(time_points, curve["Brier_Mean"], label=f"Combination Length {n_models}")
            plt.fill_between(time_points, curve["Brier_CI_Low"], curve["Brier_CI_High"], alpha=0.2)

        # "event_1" -> "Event 1" for the title
        plt.title(f"Brier Scores for {event.replace('_', ' ').title()}")
        plt.xlabel("Time Point")
        plt.ylabel("Brier Score")
        plt.legend()
        plt.grid(True)
        plt.show()


In [None]:
# Load the per-bootstrap metric files of the averaged ensemble.
# NOTE(review): this rebinds output_dir, which the loading cell earlier set to the ".../full_ensemble" directory;
# the path is also an absolute, machine-specific location.
output_dir = "/mnt/d/PYDataScience/g3_regress/data/results/avg_ensemble"
metrics_df = load_metrics_to_dataframe(output_dir)
# The "Combination" values are JSON keys, i.e. strings; turn them back into Python objects because the
# sorting/reshaping helpers assert that the column holds tuples.
# NOTE(review): eval executes whatever expression the JSON files contain. If the keys are plain tuple
# literals, ast.literal_eval would parse them without running code - consider switching.
metrics_df['Combination'] = metrics_df['Combination'].apply(eval)