In [None]:
import glob
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from google.cloud import bigquery
from scipy.stats import sem
from sklearn.calibration import calibration_curve

# Standard Calibration Curves 

## Helper Plot Functions

In [None]:
def plot_calibration_curve(predictionfile, cohortfile="processed_data/cohort.csv", outcome_col="outcome", pred_col="Prediction", n_bins=10, title='Calibration Curve'):
    """
    Plots a single calibration curve for a given model's predicted probabilities.

    The prediction file is inner-joined to the cohort file on 'MRN', and the
    curve is computed with sklearn's calibration_curve using equal-width
    (strategy='uniform') probability bins.

    Parameters:
    - predictionfile: CSV path; must contain 'MRN' and `pred_col`.
    - cohortfile: CSV path; must contain 'MRN' and `outcome_col`.
    - outcome_col: Column with true binary outcomes (0 or 1).
    - pred_col: Column with predicted probabilities.
    - n_bins: Number of equal-width bins to divide predicted probabilities into.
    - title: Title of the figure.

    Returns:
    - DataFrame with one row per non-empty bin and the columns 'Decile'
      (1-based number of the equal-width bin), 'Mean Predicted Probability'
      and 'Proportion of Positives'. The figure is shown as a side effect.
    """
    df_predictions = pd.read_csv(predictionfile)
    cohort = pd.read_csv(cohortfile)
    df = pd.merge(cohort, df_predictions, on='MRN', how='inner')

    # Extract true labels and predicted probabilities
    y_true = df[outcome_col]
    y_prob = df[pred_col]

    # Compute actual vs. predicted probabilities using sklearn's calibration_curve
    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy='uniform')

    # calibration_curve does not return bins that received no samples, so the
    # result can be shorter than n_bins. Recover each point's 1-based bin number
    # from the equal-width bin its mean prediction falls in, instead of assuming
    # that all n_bins bins are present (which raised on a length mismatch).
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_numbers = np.searchsorted(bin_edges[1:-1], prob_pred) + 1

    # Plot Calibration Curve
    plt.figure(figsize=(6, 6))
    plt.plot(prob_pred, prob_true, "s-", label="Model Calibration")
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect Calibration")  # Reference line

    # Formatting
    plt.xlabel("Mean Predicted Probability", fontsize=18)
    plt.ylabel("Fraction of Positives", fontsize=18)
    plt.title(title, fontsize=20)
    plt.legend(fontsize=16)
    plt.grid()

    # Show plot
    plt.show()

    # Table of the plotted points, one row per non-empty bin
    decile_table = pd.DataFrame({
        'Decile': bin_numbers,
        'Mean Predicted Probability': prob_pred,
        'Proportion of Positives': prob_true
    })
    return decile_table

In [None]:
def plot_multiple_calibration_curves(prediction_files, outcome_col="outcome", pred_col="Prediction", n_bins=10, cohortfile="processed_data/cohort.csv"):
    """
    Plots calibration curves for multiple models using their prediction files.

    Each prediction file is inner-joined to the cohort file on 'MRN' and drawn
    as one curve on a shared figure.

    Parameters:
    - prediction_files: List of CSV paths; each must contain 'MRN' and `pred_col`.
    - outcome_col: Column (in the cohort file) with true binary outcomes (0 or 1).
    - pred_col: Column with predicted probabilities.
    - n_bins: Number of equal-width bins to divide predicted probabilities into.
    - cohortfile: CSV path; must contain 'MRN' and `outcome_col`.

    Returns:
    - None. Shows a plot with calibration curves for multiple models.
    """
    plt.figure(figsize=(6, 6))

    # The cohort is the same for every model, so read it once.
    cohort = pd.read_csv(cohortfile)

    # Loop through each prediction file
    for prediction_file in prediction_files:
        df_predictions = pd.read_csv(prediction_file)
        df = pd.merge(cohort, df_predictions, on='MRN', how='inner')

        # Extract true labels and predicted probabilities
        y_true = df[outcome_col]
        y_prob = df[pred_col]

        # Compute actual vs. predicted probabilities
        prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy='uniform')

        # Legend label: file name without directory, extension or the
        # "final_[test_]predictions_" prefix. Stripping the prefix from the full
        # path never matched when the file lives in a sub-directory.
        model_name = os.path.basename(prediction_file).removesuffix(".csv")
        for prefix in ("final_test_predictions_", "final_predictions_"):
            model_name = model_name.removeprefix(prefix)

        # Plot Calibration Curve
        plt.plot(prob_pred, prob_true, "s-", label=model_name)

    # Plot perfect calibration line
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect Calibration")

    # Formatting
    plt.xlabel("Mean Predicted Probability")
    plt.ylabel("Fraction of Positives")
    plt.title("Calibration Curves (Multiple Models)")
    plt.legend()
    plt.grid()

    # Show plot
    plt.show()

In [None]:
# def plot_calibration_curves_with_race(df, outcome_col="outcome", race_col='race', pred_col="Prediction", n_bins=10, title="Calibration Curves"):
def plot_calibration_curve_with_race(predictionfile, cohortfile="processed_data/cohort.csv", outcome_col="outcome", pred_col="Prediction", n_bins=10, title='Calibration Curve', demofile='processed_data/demo_not_1h_encoded.csv'):
    """
    Plots one calibration curve per race/ethnicity group on a shared figure.

    Predictions are inner-joined to the cohort on 'MRN', then left-joined to the
    demographics file, which must provide 'race' and 'ethnicity' columns.

    Parameters:
    - predictionfile: CSV path; must contain 'MRN' and `pred_col`.
    - cohortfile: CSV path; must contain 'MRN' and `outcome_col`.
    - outcome_col: Column with true binary outcomes (0 or 1).
    - pred_col: Column with predicted probabilities.
    - n_bins: Number of equal-width bins to divide predicted probabilities into.
    - title: Title of the figure.
    - demofile: CSV path; must contain 'MRN', 'race' and 'ethnicity'.

    Returns:
    - dict mapping each group label to a DataFrame with one row per non-empty
      bin: 'Decile' (1-based number of the equal-width bin),
      'Mean Predicted Probability' and 'Proportion of Positives'.
    """
    df_predictions = pd.read_csv(predictionfile)
    cohort = pd.read_csv(cohortfile)
    demo_df = pd.read_csv(demofile)

    # Merge data
    df = pd.merge(cohort, df_predictions, on='MRN', how='inner')
    df = df.merge(demo_df, on='MRN', how='left')

    # Fold the sparse categories into larger ones: 'aian' -> 'other', 'nhpi' -> 'asian'.
    df['race_consolidated'] = df['race'].replace({'aian': 'other', 'nhpi': 'asian'})
    # Hispanic ethnicity takes precedence over the recorded race.
    df.loc[df['ethnicity'].str.lower() == 'hispanic', 'race_consolidated'] = 'hispanic'
    df['race_consolidated'] = df['race_consolidated'].replace({'hispanic': 'Hispanic', 'white': 'Non-Hispanic White', 'asian': 'Non-Hispanic Asian',
                                                              'black': 'Non-Hispanic Black', 'other': 'Other/Unknown'})
    # Patients without a demographics row (left merge) have a missing race.
    # Comparing against NaN selects no rows and calibration_curve fails on an
    # empty group, so count them as 'Other/Unknown'.
    df['race_consolidated'] = df['race_consolidated'].fillna('Other/Unknown')

    plt.figure(figsize=(10, 10))
    races = df['race_consolidated'].unique()

    # Edges of the equal-width bins used by strategy='uniform'
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)

    # Dictionary to store decile tables for each race
    decile_tables = {}

    for race in races:
        # Filter the DataFrame for the current race
        df_race = df[df['race_consolidated'] == race]

        # Extract true labels and predicted probabilities
        y_true = df_race[outcome_col]
        y_prob = df_race[pred_col]

        # Compute actual vs. predicted probabilities using sklearn's calibration_curve
        prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins, strategy='uniform')

        # Plot Calibration Curve for the current race
        plt.plot(prob_pred, prob_true, "s-", label=f"{race}")

        # calibration_curve omits bins with no samples (common in small
        # subgroups), so derive each point's bin number from its mean prediction
        # rather than assuming all n_bins bins are present.
        decile_tables[race] = pd.DataFrame({
            'Decile': np.searchsorted(bin_edges[1:-1], prob_pred) + 1,
            'Mean Predicted Probability': prob_pred,
            'Proportion of Positives': prob_true
        })

    # Plot reference line for perfect calibration
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect Calibration")

    # Formatting
    plt.xlabel("Mean Predicted Probability", fontsize=18)
    plt.ylabel("Fraction of Positives", fontsize=18)
    plt.title(title, fontsize=20)
    plt.legend(fontsize=16)
    plt.grid()

    # Show plot
    plt.show()

    return decile_tables




In [None]:
# Prediction file used by every calibration cell below (path is relative to the working directory).
best_test_predictions = 'out/final_test_predictions_AoUencoder_StanfordFinetune-15-100pct.csv'

In [None]:
decile_tables = plot_calibration_curve_with_race(best_test_predictions, title='Calibration Curves Stratified by Race/Ethnicity')


In [None]:
dec_table = plot_calibration_curve(best_test_predictions, title='Calibration Curve: Glaucoma Prediction')

In [None]:
dec_table.to_csv('standard_decile_table.csv', index=False)

In [None]:
# glob returns matches in arbitrary order; sort so the legend order is the same on every run.
prediction_files = sorted(glob.glob("out/final_test_predictions_AoUencoder_StanfordFinetune-15*.csv"))

# Plot calibration curves for all models
plot_multiple_calibration_curves(prediction_files)

In [None]:
# glob returns matches in arbitrary order; sort so the legend order is the same on every run.
prediction_files = sorted(glob.glob("out/final_test_predictions_AoUencoder_StanfordFinetune-18*.csv"))

# Plot calibration curves for all models
plot_multiple_calibration_curves(prediction_files)

# Special Calibration Curves 

## Helper Plot Functions

In [None]:
def plot_custom_calibration_curve(
    predictionfile,
    outcomefile,
    measure_col,
    title_string="",
    pred_col="Prediction",
    label_string="",
    n_bins=10,
    ylabel=None,
    ylim=None,
    auto_ylim=False,
    ylim_quantiles=(0.02, 0.98),
):
    """
    Plots the mean of a measure against the mean predicted probability per quantile bin.

    Predictions are inner-joined to the outcome file on "MRN", rows missing
    either value are dropped, and the predictions are cut into `n_bins`
    quantile bins (duplicate bin edges are dropped, so fewer bins may result).

    Parameters:
    - predictionfile: CSV path; must contain "MRN" and `pred_col`.
    - outcomefile: CSV path; must contain "MRN" and `measure_col`.
    - measure_col: Column whose per-bin mean is plotted on the y-axis.
    - title_string: Name used in the title and default y-label; defaults to `measure_col`.
    - pred_col: Column with predicted probabilities.
    - label_string: Legend label; defaults to `measure_col`.
    - n_bins: Number of quantile bins requested.
    - ylabel: Y-axis label; defaults to "Mean <title_string>".
    - ylim: Explicit (low, high) y-limits; takes precedence over `auto_ylim`.
    - auto_ylim: If True (and `ylim` is None), fit the y-limits to the
      `ylim_quantiles` of the bin means plus 5% padding.
    - ylim_quantiles: Quantile pair used when `auto_ylim` is True.

    Returns:
    - DataFrame with "Decile" (1-based bin number), "Mean Predicted Probability"
      and "Mean <measure_col>". The figure is shown as a side effect.
    """
    df_predictions = pd.read_csv(predictionfile)
    cohort = pd.read_csv(outcomefile)
    merged = cohort.merge(df_predictions, on="MRN", how="inner")

    # Keep only rows where both the prediction and the measure are present
    binned = merged[[pred_col, measure_col]].dropna()
    binned["prob_bin"] = pd.qcut(binned[pred_col], q=n_bins, labels=False, duplicates="drop")

    bin_stats = binned.groupby("prob_bin").agg(
        mean_pred_prob=(pred_col, "mean"),
        mean_measure=(measure_col, "mean"),
    )
    bin_stats = bin_stats.reset_index()

    # Fall back to the measure's column name for any text that was not supplied
    label_string = label_string or measure_col
    title_string = title_string or measure_col
    if ylabel is None:
        ylabel = f"Mean {title_string}"

    _, ax = plt.subplots(figsize=(6, 6))
    ax.plot(bin_stats["mean_pred_prob"], bin_stats["mean_measure"], "s-", label=label_string)
    ax.set_xlabel("Mean Predicted Probability (Decile)", fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)
    ax.set_title(f"Custom Calibration Curve: {title_string}", fontsize=20)
    ax.legend(fontsize=16)
    ax.grid()

    if ylim is not None:
        ax.set_ylim(*ylim)
    else:
        bin_means = bin_stats["mean_measure"].to_numpy()
        if auto_ylim and len(bin_means):
            lo, hi = np.nanquantile(bin_means, ylim_quantiles)
            # Degenerate (flat) range gets a fixed pad so the axis is not zero-height
            pad = 0.05 * (hi - lo) if hi > lo else 1.0
            ax.set_ylim(lo - pad, hi + pad)
        elif len(bin_means) and np.nanmin(bin_means) >= 0 and np.nanmax(bin_means) <= 1:
            # auto-bound only when the measure clearly lives in [0,1]
            ax.set_ylim(0, 1)

    plt.show()

    return pd.DataFrame({
        "Decile": bin_stats["prob_bin"] + 1,
        "Mean Predicted Probability": bin_stats["mean_pred_prob"],
        f"Mean {measure_col}": bin_stats["mean_measure"],
    })

## IOP Calibration

### Get Max IOP

In [None]:
def run_query(query, project_id='som-nero-phi-sywang-starr'):
    """
    Runs a SQL query on BigQuery and returns the result as a pandas DataFrame.

    Parameters:
    - query: SQL text to execute.
    - project_id: Google Cloud project used both for the client and for the
      query job; defaults to the project this notebook was written against.

    Returns:
    - DataFrame with the query result.
    """
    # Set up the BigQuery client
    client = bigquery.Client(project=project_id)

    # Execute the query and materialize the result
    return client.query(query, project=project_id).to_dataframe()

In [None]:
# Pull every non-null exam value stored under the two concept IDs below for the
# cohort patients, one row per recorded value.
# NOTE(review): presumably these are the IOP fields (one per eye) — confirm the concept IDs.
query = """
SELECT cohort.pat_mrn, iop.smrtdta_elem_value as iop
FROM `som-nero-phi-sywang-starr.gps_stanford_clinic.cohort_systemic` as cohort, `som-nero-phi-sywang-starr.SOURCE_02022024.sf_oph_enc_exam` as iop
where concept_id in ('EPIC#OPH153', 'EPIC#OPH154')
and cohort.pat_mrn = iop.pat_mrn
and smrtdta_elem_value is not null;

"""
df_iop = run_query(query)
# Sanity check: total measurements and how many distinct patients have at least one
print(f"# of rows: {len(df_iop)}")
print(f"# of unique pats: {len(df_iop['pat_mrn'].unique())}")

In [None]:
df_iop.head()

In [None]:
def extract_first_number(s):
    """Return the first run of digits in ``str(s)`` as an int, or None if there is none.

    Only the leading digit run is used, so "15.5" yields 15.
    """
    digits = re.search(r'\d+', str(s))
    if digits is None:
        return None
    return int(digits.group())

In [None]:
# Parse each raw IOP entry to a number by taking the first run of digits in the text
# (a decimal entry such as "15.5" therefore becomes 15; entries without digits become missing).
df_iop["iop_num"]=df_iop["iop"].apply(extract_first_number)

In [None]:
#check for outliers 
df_iop["iop_num"].hist()

In [None]:
sum(df_iop["iop_num"]>100)

In [None]:
# Treat parsed values above 100 as outliers: set them to missing so they cannot become a patient's maximum.
df_iop["iop_num"] = df_iop["iop_num"].where(df_iop["iop_num"] <= 100, np.nan)

In [None]:
# Collapse to one row per patient: the maximum IOP across all of that patient's recorded values.
df_maxiop=df_iop.groupby("pat_mrn", as_index=False)["iop_num"].max()

In [None]:
# Add an integer 'MRN' column: the plotting helpers merge the prediction and outcome files on 'MRN'.
df_maxiop['MRN']=df_maxiop['pat_mrn'].astype(int)

In [None]:
df_maxiop.head()

In [None]:
df_maxiop["iop_num"].hist()

In [None]:
df_maxiop.to_csv("processed_data/maxiop.csv", index=False)

In [None]:
best_test_predictions

### Calibration Curve

In [None]:
# Mean of each patient's max IOP per predicted-probability bin; y-axis is fixed to 16-24 for the figure.
iop_table = plot_custom_calibration_curve(
    best_test_predictions,
    "processed_data/maxiop.csv",
    measure_col="iop_num",
    title_string="Max IOP",
    label_string="Model Calibration",
    ylabel="Mean Max IOP (mmHg)",
    ylim=(16, 24),
)



In [None]:
iop_table

In [None]:
iop_table.to_csv('iop_dec_table.csv', index=False)

## CDR Calibration 

### Get Max CDR

In [None]:
# Pull every non-null exam value stored under the two concept IDs below for the
# cohort patients, one row per recorded value.
# NOTE(review): presumably these are the cup-to-disc ratio fields (one per eye) — confirm the concept IDs.
query = """SELECT cohort.pat_mrn, cdr.smrtdta_elem_value as cdr
FROM `som-nero-phi-sywang-starr.gps_stanford_clinic.cohort_systemic` as cohort, `som-nero-phi-sywang-starr.SOURCE_02022024.sf_oph_enc_exam` as cdr
where concept_id in ('EPIC#OPH1090', 'EPIC#OPH1091')
and cohort.pat_mrn = cdr.pat_mrn
and smrtdta_elem_value is not null"""
df_cdr = run_query(query)
print(f"# of rows: {len(df_cdr)}")
# Count patients in the CDR result (this previously reported the IOP frame's patient count).
print(f"# of unique pats: {len(df_cdr['pat_mrn'].unique())}")

In [None]:
df_cdr.head()

In [None]:
def extract_cdr(text):
    """
    Parses a free-text cup-to-disc entry into a float.

    Matching is case-insensitive. If the text contains "ntc",
    "near total cupping", "totally cupped" or "total cupping", the result is
    1.0. Otherwise the first number in the text (e.g. "0.45", ".3" or "2") is
    returned.

    Parameters:
    - text: Input value; anything that is not missing is converted with str().

    Returns:
    - The parsed value as a float, or None when the input is missing or
      contains no number.
    """
    if pd.isna(text):
        return None

    normalized = str(text).lower().strip()

    # Descriptions of a fully cupped disc carry no number; they map to 1.0
    for phrase in ("ntc", "near total cupping", "totally cupped", "total cupping"):
        if phrase in normalized:
            return 1.0

    # First decimal ("0.45"), leading-dot decimal (".3") or whole number ("2")
    number = re.search(r"\d+\.\d+|\.\d+|\d+", normalized)
    return float(number.group()) if number else None


In [None]:
# Parse each raw CDR entry to a float (text describing total cupping becomes 1.0; entries without a number become missing).
df_cdr["cdr_num"]=df_cdr["cdr"].apply(extract_cdr)

In [None]:
sum(df_cdr["cdr_num"]>1)

In [None]:
# Set parsed values above 1 to missing so they cannot become a patient's maximum.
df_cdr["cdr_num"] = df_cdr["cdr_num"].where(df_cdr["cdr_num"] <= 1, np.nan)

In [None]:
df_cdr["cdr_num"].hist()

In [None]:
# Collapse to one row per patient: the maximum CDR across all of that patient's recorded values.
df_maxcdr=df_cdr.groupby("pat_mrn", as_index=False)["cdr_num"].max()

In [None]:
# Add an integer 'MRN' column: the plotting helpers merge the prediction and outcome files on 'MRN'.
df_maxcdr['MRN']=df_maxcdr['pat_mrn'].astype(int)

In [None]:
df_maxcdr["cdr_num"].hist()

In [None]:
df_maxcdr.to_csv("processed_data/maxcdr.csv", index=False)

### Calibration Curve

In [None]:
# Mean of each patient's max CDR per predicted-probability bin.
cdr_table = plot_custom_calibration_curve(
    best_test_predictions,
    "processed_data/maxcdr.csv",
    measure_col="cdr_num",
    title_string="Max CDR",
    label_string="Model Calibration",
    ylabel="Mean Max CDR",
    # Fixed y-range for the figure; passing auto_ylim=True (without ylim) would fit the range to the bin means instead.
    ylim=(0.3, 0.65),
)

In [None]:
cdr_table

In [None]:
cdr_table.to_csv('cdr_decile_table.csv', index=False)

## Glaucoma Meds/Laser/Treatment Calibration 

In [None]:
# Reload the cohort table; the cells below add the glaucmeds / glaucsurg / slt / glauctx flag columns to this frame.
cohort = pd.read_csv("processed_data/cohort.csv")
cohort.head()

### Get info on glaucoma meds

In [None]:
# Pull every medication record of a cohort patient whose medication_id is in the list below.
# NOTE(review): the IDs are presumably the glaucoma medications — confirm against the source vocabulary.
query = """
SELECT cohort.pat_mrn, med.medication_id, med.generic
FROM `som-nero-phi-sywang-starr.gps_stanford_clinic.cohort_systemic` as cohort, `som-nero-phi-sywang-starr.SOURCE_02022024.sf_oph_med` as med
where medication_id in (22952, 80897, 15016, 95690, 22991, 91540, 222820, 222947, 223383, 223440, 229853, 229855, 
 1048, 78246, 9268, 10393, 1026, 10394, 10584, 15111, 90818, 15112, 88037, 114915, 40463, 
 114914, 40464, 220775, 11561, 7970, 24575, 11562, 7971, 12023, 24576, 125362, 38874, 202894, 
 202896, 216320, 18829, 79326, 125837, 70393, 31161, 95023, 17881, 211278, 211284, 29901, 
 83155, 17442, 77371, 222856, 222884, 224089, 200270, 29884, 161715, 82555, 6278, 6279, 86065, 
 6280, 6282, 6288, 82414, 28834, 94057, 112, 113, 38199, 80001, 167878, 174506, 4961, 5500, 
 4962, 17152, 181593, 242643, 244901, 242000)
and cohort.pat_mrn = med.pat_mrn
"""
df_glaucmed = run_query(query)
# Sanity check: total medication records and how many distinct patients have at least one
print(f"# of rows: {len(df_glaucmed)}")
print(f"# of unique pats: {len(df_glaucmed['pat_mrn'].unique())}")

In [None]:
df_glaucmed['MRN']=df_glaucmed['pat_mrn'].astype(int)

In [None]:
df_glaucmed.head()

In [None]:
# Same medication IDs as in the SQL filter that produced df_glaucmed.
# NOTE(review): the list is duplicated by hand — keep it in sync with the query.
med_ids = [
    22952, 80897, 15016, 95690, 22991, 91540, 222820, 222947, 223383, 223440, 229853, 229855,
    1048, 78246, 9268, 10393, 1026, 10394, 10584, 15111, 90818, 15112, 88037, 114915, 40463,
    114914, 40464, 220775, 11561, 7970, 24575, 11562, 7971, 12023, 24576, 125362, 38874, 202894,
    202896, 216320, 18829, 79326, 125837, 70393, 31161, 95023, 17881, 211278, 211284, 29901,
    83155, 17442, 77371, 222856, 222884, 224089, 200270, 29884, 161715, 82555, 6278, 6279, 86065,
    6280, 6282, 6288, 82414, 28834, 94057, 112, 113, 38199, 80001, 167878, 174506, 4961, 5500,
    4962, 17152, 181593, 242643, 244901, 242000
]

# Distinct generic names that actually occur for these IDs (missing names skipped)
generic_names = df_glaucmed[df_glaucmed['medication_id'].isin(med_ids)]['generic'].dropna().unique()

generic_names_str = ', '.join(generic_names)

print(generic_names_str)


In [None]:
# 1 if the patient has at least one record in df_glaucmed, else 0.
cohort["glaucmeds"]=cohort["MRN"].isin(df_glaucmed["MRN"]).astype(int)

### Get info on glaucoma surgeries and SLT (selective laser trabeculoplasty)

In [None]:
# CPT codes counted as a glaucoma surgery. Kept in one list so the five
# per-column filters below cannot drift apart.
glaucoma_cpt_codes = [
    "0191T", "66989", "66991", "0253T", "0474T", "0449T", "0376T", "66170", "66172", "66179",
    "66180", "0192T", "66174", "66175", "66710", "66711", "66987", "66988", "66720", "66740",
    "66155", "66160", "65820", "65850", "66183", "67250", "67255", "66184", "66185",
]
cpt_in_list = ", ".join(f'"{code}"' for code in glaucoma_cpt_codes)
# A surgery row can carry a matching code in any of cpt1..cpt5.
cpt_filter = "\n or ".join(f"cpt{i} in ({cpt_in_list})" for i in range(1, 6))

query = f"""
SELECT cohort.pat_mrn, surg.cpt1, surg.ALL_PROC_AS_ORDERED
FROM `som-nero-phi-sywang-starr.gps_stanford_clinic.cohort_systemic` as cohort, `som-nero-phi-sywang-starr.SOURCE_02022024.sf_oph_surgery_all` as surg
where cohort.pat_mrn = surg.pat_mrn
and ({cpt_filter}
) """
df_glaucsurg = run_query(query)
# Sanity check: total surgery rows and how many distinct patients have at least one
print(f"# of rows: {len(df_glaucsurg)}")
print(f"# of unique pats: {len(df_glaucsurg['pat_mrn'].unique())}")

In [None]:
df_glaucsurg['MRN']=df_glaucsurg['pat_mrn'].astype(int)

In [None]:
# 1 if the patient has at least one row in df_glaucsurg, else 0.
cohort["glaucsurg"]=cohort["MRN"].isin(df_glaucsurg["MRN"]).astype(int)

In [None]:
# Pull all procedure rows with concept 2110962, mapped from person_id to MRN via the crosswalk.
# The query is not restricted to the cohort; the isin() lookup in the next cell does that.
# NOTE(review): 2110962 is presumably the SLT procedure concept — confirm.
query = """
SELECT MRN, procedure_concept_id, procedure_source_value FROM `som-nero-phi-sywang-starr.gps_stanford_clinic.procedure_occurrence` as proc, `som-nero-phi-sywang-starr.gps_stanford_clinic.mrn_crosswalk` as crosswalk
where procedure_concept_id in (2110962)
and proc.person_id = crosswalk.person_id
"""
df_slt=run_query(query)
print(f"# of rows: {len(df_slt)}")
print(f"# of unique pats: {len(df_slt['MRN'].unique())}")

In [None]:
# Cast the crosswalk MRNs to int so they compare against the cohort's MRN column.
df_slt['MRN'] = df_slt['MRN'].astype(int)
# 1 if the patient has at least one row in df_slt, else 0.
cohort["slt"] = cohort["MRN"].isin(df_slt["MRN"]).astype(int)
# Any treatment: 1 if at least one of the three 0/1 flags is set.
cohort["glauctx"] = cohort[['slt', 'glaucmeds', 'glaucsurg']].any(axis=1).astype(int)

In [None]:
cohort.to_csv("processed_data/glauctx.csv", index=False)

### Calibration Curve

In [None]:
# Calibration of the predictions against the "any glaucoma treatment" flag instead of the original outcome.
dec_tab = plot_calibration_curve(best_test_predictions, 
                       'processed_data/glauctx.csv', outcome_col="glauctx",title='Calibration Curve: Any Glaucoma Treatment')



In [None]:
dec_tab

In [None]:
dec_tab.to_csv('glauctx_decile_table.csv', index=False)

In [None]:
def plot_calibration_curve_v2(predictionfile, cohortfile="processed_data/cohort.csv", outcome_col="outcome", pred_col="Prediction", n_bins=10, title='Calibration Curve'):
    """
    Draws a calibration curve onto the current matplotlib axes.

    Unlike plot_calibration_curve, this neither creates a figure nor calls
    plt.show(), so it can be used to fill one panel of a multi-panel figure.
    Nothing is returned.

    Parameters:
    - predictionfile: CSV path; must contain 'MRN' and `pred_col`.
    - cohortfile: CSV path; must contain 'MRN' and `outcome_col`.
    - outcome_col: Column with true binary outcomes (0 or 1).
    - pred_col: Column with predicted probabilities.
    - n_bins: Number of equal-width bins to divide predicted probabilities into.
    - title: Title of the panel.
    """
    predictions = pd.read_csv(predictionfile)
    outcomes = pd.read_csv(cohortfile)
    merged = outcomes.merge(predictions, on='MRN', how='inner')

    frac_positive, mean_predicted = calibration_curve(
        merged[outcome_col], merged[pred_col], n_bins=n_bins, strategy='uniform'
    )

    # Draw on whichever axes is current (callers select it with plt.sca)
    ax = plt.gca()
    ax.plot(mean_predicted, frac_positive, "s-", label="Model Calibration")
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Perfect Calibration")
    ax.set_xlabel("Mean Predicted Probability", fontsize=18)
    ax.set_ylabel("Fraction of Positives", fontsize=18)
    ax.set_title(title, fontsize=20)
    ax.legend(fontsize=16)
    ax.grid()

def plot_custom_calibration_curve_v2(predictionfile, outcomefile, measure_col, title_string='', pred_col='Prediction', label_string='', n_bins=10):
    """
    Draws the per-bin mean of a measure against the mean predicted probability
    onto the current matplotlib axes.

    Unlike plot_custom_calibration_curve, this neither creates a figure nor
    calls plt.show(), so it can be used to fill one panel of a multi-panel
    figure. Nothing is returned.

    Parameters:
    - predictionfile: CSV path; must contain 'MRN' and `pred_col`.
    - outcomefile: CSV path; must contain 'MRN' and `measure_col`.
    - measure_col: Column whose per-bin mean is plotted on the y-axis.
    - title_string: Name used in the title and y-label; defaults to `measure_col`.
    - pred_col: Column with predicted probabilities.
    - label_string: Legend label; defaults to `measure_col`.
    - n_bins: Number of quantile bins requested (duplicate edges are dropped).
    """
    df_predictions = pd.read_csv(predictionfile)
    cohort = pd.read_csv(outcomefile)
    df = pd.merge(cohort, df_predictions, on='MRN', how='inner')

    # Keep only rows where both the prediction and the measure are present
    df = df[[pred_col, measure_col]].dropna()

    df["prob_bin"] = pd.qcut(df[pred_col], q=n_bins, labels=False, duplicates="drop")

    bin_stats = df.groupby("prob_bin").agg(
        mean_pred_prob=(pred_col, "mean"),
        mean_measure=(measure_col, "mean")
    ).reset_index()

    # Resolve both defaults before any text is built; the y-label used to be
    # formatted before title_string received its default and came out nameless.
    if label_string == '':
        label_string = measure_col
    if title_string == '':
        title_string = measure_col

    plt.plot(bin_stats["mean_pred_prob"], bin_stats["mean_measure"], "s-", label=f"{label_string}")

    plt.xlabel("Mean Predicted Probability", fontsize=18)
    plt.ylabel(f"Decile Mean of {title_string}", fontsize=18)
    plt.title(f"Custom Calibration Curve: {title_string}", fontsize=20)
    plt.legend(fontsize=16)
    plt.grid()

fig, axs = plt.subplots(2, 2, figsize=(14, 12))

# One entry per panel: (letter, axes, plotting function, positional args, keyword args)
panels = [
    # A: Calibration Curve: Glaucoma Prediction
    ('A', axs[0, 0], plot_calibration_curve_v2,
     (best_test_predictions,),
     {'title': 'Calibration Curve: Glaucoma Prediction'}),
    # B: Custom Calibration Curve: Max IOP
    ('B', axs[0, 1], plot_custom_calibration_curve_v2,
     (best_test_predictions, 'processed_data/maxiop.csv'),
     {'measure_col': 'iop_num', 'label_string': 'Model Calibration', 'title_string': 'Max IOP'}),
    # C: Custom Calibration Curve: Max CDR
    ('C', axs[1, 0], plot_custom_calibration_curve_v2,
     (best_test_predictions, 'processed_data/maxcdr.csv'),
     {'measure_col': 'cdr_num', 'label_string': 'Model Calibration', 'title_string': 'Max CDR'}),
    # D: Calibration Curve: Any Glaucoma Treatment
    ('D', axs[1, 1], plot_calibration_curve_v2,
     (best_test_predictions, 'processed_data/glauctx.csv'),
     {'outcome_col': 'glauctx', 'title': 'Calibration Curve: Any Glaucoma Treatment'}),
]

for letter, ax, draw_panel, args, kwargs in panels:
    # The *_v2 helpers draw on the current axes, so select the panel first
    plt.sca(ax)
    draw_panel(*args, **kwargs)
    # Panel letter just outside the top-left corner (axes coordinates)
    ax.text(-0.1, 1.1, letter, transform=ax.transAxes, fontsize=24, fontweight='bold', va='top', ha='right')

plt.tight_layout()

# savefig does not create missing directories; make sure the output folder exists
figure_path = 'figures/calibrations.tiff'
os.makedirs(os.path.dirname(figure_path), exist_ok=True)
plt.savefig(figure_path, format='tiff')

plt.show()

