# In [4]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import seaborn as sns
from sklearn.metrics import roc_auc_score, roc_curve, auc, precision_recall_curve, average_precision_score, mean_squared_error, log_loss
from sklearn.calibration import calibration_curve

# Colorblind-friendly palette (Okabe-Ito). The plotting functions pick colors by
# experiment index modulo its length, so entry order decides which experiment
# gets which color.
CB_PALETTE = [
    "#000000",  # black
    "#E69F00",  # orange
    "#56B4E9",  # sky blue (canonical Okabe-Ito value)
    "#009E73",  # bluish green
    "#F0E442",  # yellow
    "#0072B2",  # blue
    "#D55E00",  # vermillion
    "#CC79A7",  # reddish purple
]

# Display names for raw feature identifiers, applied to the `fnames` column of
# the feature-importance table before plotting. Keys of the form "<n>-0.0" look
# like UK Biobank "<field>-<instance>.<array>" column names — assumption, confirm
# against the data dictionary.
rename = {
    "4700-0.0": 'Age Cataract Diagnosed',
    "5901-0.0": 'Age Diabetic Retinopathy Diagnosed',
    "30780-0.0": 'LDL',
    "head_injury": 'Head Injury',
    "22038-0.0": 'Min/Week Moderate Activity',
    "20161-0.0": 'Years of Smoking',
    "alcohol_consumption": 'Alcohol Consumption',  # title case, like the other labels
    "hypertension": 'Hypertension',
    "obesity": 'Obesity',
    "diabetes": 'Diabetes',
    "hearing_loss": 'Hearing Loss',
    "depression": 'Depression',
    "freq_friends_family_visit": 'Frequency of Friends/Family Visits',
    "24012-0.0": 'Distance to Major Road',
    "24018-0.0": 'NO2 Air Pollution',
    "24019-0.0": 'PM10 Air Pollution',
    "24006-0.0": 'PM2.5 Air Pollution',
    "24015-0.0": 'Amount of Major Roads',
    "24011-0.0": 'Traffic Intensity',
    '6138-0.0': 'Education Level',
    '845-0.0': 'Years Education',
    'curr_age': 'Age',  # changed from 'Current Age' to 'Age'
    "NIFK9/BIN1 (hg38)": "BIN1 (hg38)",  # TODO: confirm this locus-to-gene label
}

def feature_importances_plot(path_to_experiment, ax=None, color=CB_PALETTE[1], top_n=20):
    """Draw a horizontal bar chart of the most important features with std error bars.

    Reads ``<path_to_experiment>/summary_stats/features.txt`` (a CSV that must
    contain ``fnames``, ``avg_fi`` and ``std_fi`` columns). Feature names are
    mapped to readable labels through the module-level ``rename`` table. The
    ``top_n`` rows with the largest ``avg_fi`` are kept and drawn with the most
    important feature at the top.

    Args:
        path_to_experiment: experiment directory.
        ax: Axes to draw on. When None, a new figure is created that is 8 inches
            wide and 0.3 inch tall per bar, with a 3-inch minimum so that an
            empty or very short table does not give a zero-height figure.
        color: bar color.
        top_n: number of features to show (default 20).

    Returns:
        DataFrame with columns ``fnames``, ``avg_fi`` and ``std_fi`` for the
        plotted features, sorted by descending ``avg_fi``.
    """
    df = pd.read_csv(os.path.join(path_to_experiment, 'summary_stats', 'features.txt'))
    df['fnames'] = df['fnames'].replace(rename)
    df = df.sort_values(by='avg_fi', ascending=False).head(top_n)
    if ax is None:
        _, ax = plt.subplots(figsize=(8, max(len(df) * 0.3, 3)))
    ax.barh(df['fnames'], df['avg_fi'], color=color, capsize=4, xerr=df['std_fi'])
    ax.set_xlabel('Average Feature Importance', fontsize=22, fontweight='bold')
    ax.set_ylabel('Features', fontsize=22, fontweight='bold')
    ax.tick_params(axis='x', labelsize=16, length=8)
    ax.tick_params(axis='y', labelsize=16)
    ax.set_title('', fontsize=0)
    # barh draws the first row at the bottom; invert so the top feature is on top.
    ax.invert_yaxis()
    return df[['fnames', 'avg_fi', 'std_fi']]

def mean_roc_curve(true_labels_list, predicted_probs_list):
    """Average ROC curves across folds on a fixed 100-point FPR grid.

    Each fold's ROC curve is linearly interpolated onto ``linspace(0, 1, 100)``
    and pinned to TPR=0 at FPR=0. The mean curve is pinned to TPR=1 at FPR=1.

    Args:
        true_labels_list: per-fold arrays of binary ground-truth labels.
        predicted_probs_list: per-fold arrays of scores, aligned with the labels.

    Returns:
        A 7-tuple:
        (fpr_grid, mean_tpr, std_tpr, auc_of_mean_curve, std_of_fold_aucs,
         mean_partial_auc, std_partial_auc).
        ``auc_of_mean_curve`` is the trapezoidal AUC of the averaged curve.
        The spread is the standard deviation of per-fold ``roc_auc_score``.
        The partial AUCs are ``roc_auc_score(..., max_fpr=0.25)`` per fold.
    """
    fpr_grid = np.linspace(0, 1, 100)
    resampled_tprs, fold_aucs, fold_partial_aucs = [], [], []

    for y_true, y_score in zip(true_labels_list, predicted_probs_list):
        fold_aucs.append(roc_auc_score(y_true, y_score))
        fold_partial_aucs.append(roc_auc_score(y_true, y_score, max_fpr=0.25))
        fold_fpr, fold_tpr, _thresholds = roc_curve(y_true, y_score)
        on_grid = np.interp(fpr_grid, fold_fpr, fold_tpr)
        on_grid[0] = 0.0  # every curve starts at the origin
        resampled_tprs.append(on_grid)

    averaged_tpr = np.mean(resampled_tprs, axis=0)
    averaged_tpr[-1] = 1.0  # and ends at (1, 1)

    return (
        fpr_grid,
        averaged_tpr,
        np.std(resampled_tprs, axis=0),
        auc(fpr_grid, averaged_tpr),
        np.std(fold_aucs),
        np.mean(fold_partial_aucs),
        np.std(fold_partial_aucs),
    )

def extract_true_pred(exp, filename="test_labels_predictions.parquet", engine='fastparquet'):
    """Collect per-fold test predictions and labels from an experiment directory.

    Every immediate subdirectory of ``exp`` that contains ``filename`` counts as
    one fold. Its ``y_pred`` and ``y_test`` columns are collected. Subdirectories
    are visited in sorted name order, because ``os.listdir`` returns entries in
    arbitrary order and the fold order would otherwise not be reproducible.

    Args:
        exp: experiment directory path.
        filename: name of the per-fold parquet file.
        engine: parquet engine passed to ``pandas.read_parquet``.

    Returns:
        ``(preds, trues)``: two lists of pandas Series with one entry per fold,
        in the same order. Both lists are empty when no fold file is found.
    """
    preds = []
    trues = []
    for folder in sorted(os.listdir(exp)):
        folder_path = os.path.join(exp, folder)
        if not os.path.isdir(folder_path):
            continue
        file_path = os.path.join(folder_path, filename)
        if os.path.isfile(file_path):
            df = pd.read_parquet(file_path, engine=engine)
            preds.append(df.y_pred)
            trues.append(df.y_test)
    return preds, trues

# Experiment result directories compared in the ROC, PR and calibration panels.
# The plotting functions split each path on '/' and look up component 2 in
# name_map and component 3 (unless it is 'none') in genes_map to build the
# legend labels.
experiments = ['./results/LDE_only'] + [
    f'./results_all/{feature_set}/{genetics}/allages/AD/lgbm'
    for feature_set, genetics in (
        ('age_alone', 'none'),
        ('all_demographics', 'apoe'),
        ('all_demographics', 'LDE'),
        ('demographics_lancet2024', 'none'),
        ('demographics_lancet2024', 'apoe'),
        ('demographics_lancet2024', 'LDE'),
    )
]

# Legend names for the feature-set path component.
name_map = dict(
    LDE_only='Genetics alone',
    age_alone='Age alone',
    all_demographics='All demographics',
    demographics_lancet2024='Demographics + lancet factors',
)

# Legend suffixes for the genetics path component.
genes_map = dict(
    apoe='APOE',
    LDE='All SNPs',
)

def plot_roc_curve(experiments, ax=None):
    """Plot the fold-averaged ROC curve of each experiment on one axis.

    For each experiment directory, the per-fold test predictions are loaded
    with ``extract_true_pred`` and summarised by ``mean_roc_curve``. The mean
    curve is drawn with a +/-1 std band of interpolated TPR, clipped to [0, 1].

    The legend label has two parts. The first is a readable name built from the
    path: ``name_map`` on path component 2, plus ``genes_map`` on component 3
    unless it is 'none'. The second is the AUC of the mean curve with the std
    of the per-fold AUCs.

    Experiments with no prediction files are skipped, as in the PR and
    calibration plotters; ``mean_roc_curve`` cannot handle an empty fold list.
    Colors follow each experiment's position in ``experiments`` so they stay
    consistent across panels.

    Args:
        experiments: iterable of experiment directory paths.
        ax: Axes to draw on; a new 8x8-inch figure is created when None.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))
    for i, exp_path in enumerate(experiments):
        preds, trues = extract_true_pred(exp_path)
        if not preds or not trues:
            continue
        (mean_fpr, mean_tpr, std_tpr, mean_auc, std_auc, _, _) = mean_roc_curve(trues, preds)
        color = CB_PALETTE[i % len(CB_PALETTE)]
        parts = exp_path.split('/')
        main_key = parts[2] if len(parts) > 2 else exp_path
        title = name_map.get(main_key, main_key)
        if len(parts) > 3 and parts[3] != 'none':
            gene_key = parts[3]
            gene_label = genes_map.get(gene_key, gene_key)
            title += f" ({gene_label})"
        label = f"{title}\nAUC: {mean_auc:.3f} ± {std_auc:.3f}"
        ax.plot(mean_fpr, mean_tpr, color=color, label=label, lw=3, alpha=0.9)
        tpr_upper = np.minimum(mean_tpr + std_tpr, 1)
        tpr_lower = np.maximum(mean_tpr - std_tpr, 0)
        ax.fill_between(mean_fpr, tpr_lower, tpr_upper, color=color, alpha=0.15)
    ax.legend(loc='lower right', fontsize=18)
    ax.set_xlabel("False Positive Rate", fontsize=22, fontweight='bold')
    ax.set_ylabel("True Positive Rate", fontsize=22, fontweight='bold')
    ax.tick_params(axis='both', labelsize=18, length=8)
    ax.set_title('', fontsize=0)
    return ax

def mean_pr_curve(true_labels_list, predicted_probs_list):
    """Average precision-recall curves across folds on a 100-point recall grid.

    Args:
        true_labels_list: per-fold arrays of binary ground-truth labels.
        predicted_probs_list: per-fold arrays of scores, aligned with the labels.

    Returns:
        ``(mean_precision, std_precision, recall_grid, mean_ap, std_ap)``. The
        precision arrays are aligned with ``recall_grid = linspace(0, 1, 100)``.
        The AP statistics are the mean and std of per-fold
        ``average_precision_score``.
    """
    recall_grid = np.linspace(0, 1, 100)
    fold_curves = []
    fold_aps = []

    for y_true, y_score in zip(true_labels_list, predicted_probs_list):
        fold_aps.append(average_precision_score(y_true, y_score))
        precision, recall, _thresholds = precision_recall_curve(y_true, y_score)
        # precision_recall_curve yields recall in decreasing order; np.interp
        # needs increasing sample points, so both arrays are flipped first.
        fold_curves.append(np.interp(recall_grid, recall[::-1], precision[::-1]))

    return (
        np.mean(fold_curves, axis=0),
        np.std(fold_curves, axis=0),
        recall_grid,
        np.mean(fold_aps),
        np.std(fold_aps),
    )

def plot_multiple_pr_curves(experiments, ax=None):
    """Overlay fold-averaged precision-recall curves for several experiments.

    Each experiment directory contributes one curve: the mean interpolated
    precision over folds, with a +/-1 std band clipped to [0, 1]. Its legend
    entry shows the mean and std of per-fold average precision. Experiments
    without prediction files are skipped; colors are still assigned by position
    in ``experiments``.

    Args:
        experiments: iterable of experiment directory paths.
        ax: Axes to draw on; a new 8x8-inch figure is created when None.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    for idx, run_dir in enumerate(experiments):
        fold_preds, fold_trues = extract_true_pred(run_dir)
        if not fold_preds or not fold_trues:
            continue
        prec_mean, prec_std, recall_grid, ap_mean, ap_std = mean_pr_curve(fold_trues, fold_preds)

        # Legend name: feature-set key at path component 2, optional genetics key at 3.
        segments = run_dir.split('/')
        key = segments[2] if len(segments) > 2 else run_dir
        display_name = name_map.get(key, key)
        gene = segments[3] if len(segments) > 3 else 'none'
        if gene != 'none':
            display_name += f" ({genes_map.get(gene, gene)})"

        shade = CB_PALETTE[idx % len(CB_PALETTE)]
        legend_text = f"{display_name}\nAP: {ap_mean:.3f} ± {ap_std:.3f}"
        band_low = np.maximum(prec_mean - prec_std, 0)
        band_high = np.minimum(prec_mean + prec_std, 1)
        ax.plot(recall_grid, prec_mean, color=shade, label=legend_text, lw=3, alpha=0.9)
        ax.fill_between(recall_grid, band_low, band_high, color=shade, alpha=0.15)

    ax.set_xlabel('Recall', fontsize=22, fontweight='bold')
    ax.set_ylabel('Precision', fontsize=22, fontweight='bold')
    ax.tick_params(axis='both', labelsize=18, length=8)
    ax.set_xlim(-0.05, 1.0)
    ax.set_ylim(-0.05, 1.05)
    ax.legend(loc='upper right', fontsize=8)
    ax.set_title('', fontsize=0)
    ax.grid(False)
    return ax

def plot_multiple_calibration_curves(experiments, ax=None):
    """
    Plot calibration (reliability) curves for several experiments on log-log axes.

    For each experiment, the per-fold test predictions are pooled and binned into
    10 quantile bins with ``calibration_curve``. Each result is drawn as a line
    with square markers. Bins whose observed fraction or mean prediction is zero
    are dropped because they cannot be shown on log axes. Each legend entry
    reports the pooled mean squared error (the Brier score for binary labels)
    and the log loss. A dashed identity line spanning all plotted points marks
    perfect calibration.

    Experiments without prediction files are skipped. Any other failure for an
    experiment is printed and that experiment is skipped (best effort). A failed
    experiment's bins are not used when setting the axis limits.

    If ax is provided, plot on that axis (for panel use). Otherwise, create a new
    figure, apply tight_layout and show it.

    Returns:
        The Axes that was drawn on.
    """
    colors = CB_PALETTE
    created_fig = ax is None
    if created_fig:
        _, ax = plt.subplots(figsize=(8, 8))

    ax.set_xlabel("Mean Predicted Probability", fontsize=22, fontweight='bold')
    ax.set_ylabel("Fraction of Positives", fontsize=22, fontweight='bold')
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.tick_params(axis='both', labelsize=18, length=8)

    # Coordinates of every drawn bin, used for the reference line and axis limits.
    pred_l = []
    true_l = []

    for i, exp_path in enumerate(experiments):
        try:
            preds, trues = extract_true_pred(exp_path)
            if not preds or not trues:
                continue
            all_preds = np.concatenate(preds)
            all_trues = np.concatenate(trues)

            # Calibration curve with equal-count (quantile) bins.
            prob_true, prob_pred = calibration_curve(all_trues, all_preds, n_bins=10, strategy="quantile")

            # Drop zero-valued bins: they cannot be placed on log axes.
            keep = (prob_true > 0) & (prob_pred > 0)
            prob_true = prob_true[keep]
            prob_pred = prob_pred[keep]

            # Metrics are computed before anything is recorded. log_loss raises,
            # for example, when all_trues holds a single class, and such an
            # experiment must not widen the axis limits.
            mse = mean_squared_error(all_trues, all_preds)
            ll = log_loss(all_trues, all_preds)

            # Legend name: feature-set key at path component 2, optional genetics key at 3.
            parts = exp_path.split('/')
            main_key = parts[2] if len(parts) > 2 else exp_path
            title = name_map.get(main_key, main_key)
            if len(parts) > 3 and parts[3] != 'none':
                gene_key = parts[3]
                gene_label = genes_map.get(gene_key, gene_key)
                title += f" ({gene_label})"

            color = colors[i % len(colors)]
            label = f"{title}\nMSE: {mse:.3f}, LL: {ll:.3f}"

            ax.plot(prob_pred, prob_true,
                    color=color,
                    marker='s',
                    label=label,
                    lw=2,
                    markersize=6)
        except Exception as e:  # best effort: report and move on to the next experiment
            print(f"Error processing {exp_path}: {e}")
            continue
        # Only experiments that were actually drawn contribute to the limits.
        pred_l.extend(prob_pred)
        true_l.extend(prob_true)

    # Diagonal reference line with a 10% margin around the plotted points.
    if pred_l and true_l:
        min_val = min(min(pred_l), min(true_l)) * 0.9
        max_val = max(max(pred_l), max(true_l)) * 1.1
        ax.plot([min_val, max_val], [min_val, max_val],
                linestyle='--', color='gray', alpha=0.8, label='Perfect Calibration')
        ax.set_xlim(min_val, max_val)
        ax.set_ylim(min_val, max_val)

    ax.legend(loc='upper left', fontsize=8, frameon=True)
    ax.grid(True, alpha=0.3)
    ax.set_title('', fontsize=0)

    if created_fig:
        plt.tight_layout()
        plt.show()
    return ax


def generate_conf_mtx(path, ax=None):
    """Draw the averaged confusion matrix of an experiment as an annotated heatmap.

    Reads ``<path>/summary_stats/metrics.txt``. This is a headerless CSV whose
    first column holds metric names (it must include TN, FP, FN and TP) and whose
    remaining columns hold the values to average (presumably one per fold; confirm
    against the code that writes metrics.txt). Non-numeric values are coerced to
    NaN and ignored by the mean.

    Rows are the actual class and columns the predicted class, positive first,
    so the cells agree with the axis labels::

                    pred +   pred -
        actual +     TP        FN
        actual -     FP        TN

    Each cell is annotated with its name, mean count and share of the total (%).

    Args:
        path: experiment directory.
        ax: Axes to draw on; a new 6x6-inch figure is created when None.

    Returns:
        The Axes that was drawn on.
    """
    metrics = pd.read_csv(os.path.join(path, 'summary_stats', 'metrics.txt'), header=None).T
    # After transposing, the first row holds the metric names.
    metrics.columns = metrics.iloc[0]
    metrics = metrics.drop(0).reset_index(drop=True)
    counts = {name: pd.to_numeric(metrics[name], errors='coerce').mean()
              for name in ('TN', 'FP', 'FN', 'TP')}

    # Row = actual class, column = predicted class (see docstring).
    cell_names = (('TP', 'FN'),
                  ('FP', 'TN'))
    conf_matrix = np.array([[counts[name] for name in row] for row in cell_names], dtype=float)
    percent_matrix = conf_matrix / conf_matrix.sum() * 100
    labels = np.array([
        [f"{name}\n{conf_matrix[r, c]:.0f}\n({percent_matrix[r, c]:.1f}%)"
         for c, name in enumerate(row)]
        for r, row in enumerate(cell_names)
    ])

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))
    sns.heatmap(conf_matrix, annot=labels, fmt='', cmap="Blues", cbar=False,
                xticklabels=['', ''], yticklabels=['', ''], square=True, linewidths=0.5, ax=ax)
    ax.set_xlabel('Predicted Values', fontsize=22, fontweight='bold')
    ax.set_ylabel('Actual Values', fontsize=22, fontweight='bold')
    ax.set_title('', fontsize=0)
    ax.tick_params(axis='both', labelsize=22, length=8)
    # Enlarge and embolden the in-cell annotations created by seaborn.
    for text in ax.texts:
        text.set_fontsize(24)
        text.set_fontweight('bold')
    return ax


# --- AGGREGATE PANELS INTO ONE FIGURE ---

def publication_figure(experiments, conf_matrix_path, feature_imp_path, output_png='publication_figure.png'):
    """Assemble the five-panel summary figure and save it to ``output_png``.

    Layout (2 x 3 grid)::

        A  ROC curves          B  precision-recall curves   C  confusion matrix
        D  calibration curves  E  feature importances (spans the last two columns)

    Args:
        experiments: experiment directories plotted in panels A, B and D.
        conf_matrix_path: experiment directory whose summary_stats/metrics.txt feeds panel C.
        feature_imp_path: experiment directory whose summary_stats/features.txt feeds panel E.
        output_png: output file, written at 300 dpi and then closed.
    """
    import matplotlib.gridspec as gridspec
    import matplotlib.font_manager as fm

    # Axes.legend ignores `fontsize` whenever `prop` is supplied, so the legend
    # size must be set on the FontProperties itself.
    legend_font = fm.FontProperties(weight='bold', size=9)
    panel_title = dict(loc='left', fontsize=28, fontweight='bold', pad=20)

    fig = plt.figure(constrained_layout=True, figsize=(24, 14))
    gs = gridspec.GridSpec(2, 3, width_ratios=[1.2, 1, 1], height_ratios=[1, 1.1], figure=fig)

    # A: ROC curves
    ax_roc = fig.add_subplot(gs[0, 0])
    plot_roc_curve(experiments, ax=ax_roc)
    ax_roc.set_title('A', **panel_title)
    ax_roc.legend(loc='lower right', frameon=True, prop=legend_font)

    # B: precision-recall curves
    ax_prc = fig.add_subplot(gs[0, 1])
    plot_multiple_pr_curves(experiments, ax=ax_prc)
    ax_prc.set_title('B', **panel_title)
    ax_prc.legend(loc='upper right', frameon=True, prop=legend_font)

    # C: confusion matrix
    ax_conf = fig.add_subplot(gs[0, 2])
    generate_conf_mtx(conf_matrix_path, ax=ax_conf)
    ax_conf.set_title('C', **panel_title)

    # D: calibration curves
    ax_cal = fig.add_subplot(gs[1, 0])
    plot_multiple_calibration_curves(experiments, ax=ax_cal)
    ax_cal.set_title('D', **panel_title)
    ax_cal.legend(loc='upper left', frameon=True, prop=legend_font)

    # Thicker tick marks and bold tick labels on the curve panels A, B and D.
    for axis in (ax_roc, ax_prc, ax_cal):
        axis.tick_params(axis='both', labelsize=18, length=8, width=2)
        for tick in axis.xaxis.get_major_ticks() + axis.yaxis.get_major_ticks():
            tick.label1.set_fontweight('bold')

    # E: feature importances, spanning grid columns 1 and 2 of the bottom row.
    ax_fi = fig.add_subplot(gs[1, 1:])
    feature_importances_plot(feature_imp_path, ax=ax_fi, color="#009E73")
    ax_fi.set_title('E', **panel_title)

    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(output_png, dpi=300, bbox_inches='tight')
    plt.close(fig)

# Example usage: build the summary figure, using the demographics + Lancet-2024
# + all-SNP model for both the confusion-matrix and feature-importance panels.
# Any experiment folder with summary_stats/metrics.txt (panel C) or
# summary_stats/features.txt (panel E) can be substituted.
_final_model_dir = './results_all/demographics_lancet2024/LDE/allages/AD/lgbm'
publication_figure(
    experiments=experiments,
    conf_matrix_path=_final_model_dir,
    feature_imp_path=_final_model_dir,
    output_png='publication_figure.png',
)