In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D

from _load_llm_results import *
warnings.filterwarnings("ignore")

%matplotlib inline

# Shared plotting configuration and the long-format results table used by the
# figure cells below.

# Common keyword arguments for the seaborn box plots.
boxplot_kwargs = {
    'sharey': True,
    'notch': False,
    'showcaps': True,
    'flierprops': {"marker": "x"},
    'medianprops': {"color": "k", "linewidth": 1, 'alpha': 0.5},
}

# Facet-column order of the phenotypes (values of the `target` column).
# NOTE(review): 'Htn-Hypokalemia Heuristic' is cased differently from
# 'HTN-Hypokalemia Diagnosis'; every entry must match a `target` value exactly
# or its facet comes out empty -- confirm against results_df.
phenotypes_order = ['HTN Heuristic', 'Htn-Hypokalemia Heuristic', 'Resistant HTN Heuristic',
                    'HTN Diagnosis', 'HTN-Hypokalemia Diagnosis', 'Resistant HTN Diagnosis']

# Display order of the four prompt / feature-set settings.
settings_order = ['Simple prompt,\nfew features', 'Simple prompt,\nall features',
                  'Rich prompt,\nfew features', 'Rich prompt,\nall features']

# Output directory for paper figures.
# NOTE(review): only the first bar-plot cell writes here; the other figure
# cells save to '../paper/' -- confirm which location is intended.
paper_dir = '../../paper/floats/'

# Models in display order; '-iter' entries are the iterative (SEDI) variants.
model_order = [
    'gpt-3.5-turbo',
    'gpt-3.5-turbo-iter',
    'gpt-4o-mini',
    'gpt-4o-mini-iter',
    'gpt-4o',
    'gpt-4o-iter',
    'gpt-4-turbo',
    'gpt-4-turbo-iter',
]

# One colour per model. With fewer colours than hue levels seaborn cycles the
# palette (different models then share a colour), and the warning it emits is
# hidden by this notebook's warning filter.
palette = sns.color_palette("hls", len(model_order))

# Long format for seaborn: one row per (result, metric), metric name in
# 'variable' and its score in 'value'.
results_df_melted = pd.melt(
    results_df,
    id_vars=['model', 'target', 'fold', 'RunID', 'random_state', 'prompt_richness', 'few_feature']
)

print(results_df.columns)
print(results_df.shape)

In [None]:
# Number of results per experimental setting.
#
# Count how many runs exist for each (model, target, prompt, feature-set)
# combination -- any of 'fold', 'RunID' or 'random_state' identifies a run --
# then pivot so targets are rows and settings are columns (fits on screen).
setting_keys = ['model', 'target', 'prompt_richness', 'few_feature']
run_counts = results_df.groupby(setting_keys).count()[['fold', 'RunID', 'random_state']]
counts_table = run_counts.pivot_table(
    index=['target'],
    columns=['model', 'prompt_richness', 'few_feature'],
    values='random_state',
)
counts_table.fillna(0).astype('int').style.background_gradient(axis=None, cmap='viridis')

In [None]:
sns.set_style('whitegrid')

# Bar plots of the median test score for each prompt/feature setting, faceted
# by base model (rows) and phenotype (columns), contrasting zero-shot runs with
# the iterative ('-iter', labelled SEDI) runs.
for yaxis in ['average_precision_score_test', 'roc_auc_score_test']:
    data = results_df_melted[results_df_melted['variable'] == yaxis].copy()
    data['Strategy'] = ['SEDI' if 'iter' in v else 'Zero shot' for v in data['model'].values]

    yaxis = yaxis.replace('_', ' ').capitalize()

    data = data.rename(columns={'value': yaxis})
    data['Model'] = data['model'].apply(lambda x: x.replace('-iter', ''))

    # Labels identical to `settings_order`; the ",\n" sits outside the
    # f-string replacement fields so this also parses on Python < 3.12.
    data['Setting'] = data[['few_feature', 'prompt_richness']].apply(
        lambda row:
        f"{'Rich prompt' if row.prompt_richness else 'Simple prompt'},\n"
        f"{'few features' if row.few_feature else 'all features'}",
        axis=1)

    g = sns.catplot(
        data=data,
        x=yaxis,
        y='Setting',
        order=settings_order,
        hue='Strategy',
        palette=palette,
        row='Model',
        row_order=[m for m in model_order if m in data['Model'].unique()],
        col='target',
        col_order=phenotypes_order,
        aspect=1,
        height=3,
        margin_titles=True,
        kind='bar',
        estimator=np.median,
        errorbar=('ci', 95),
        seed=0,  # fixed bootstrap seed so the CIs are reproducible between runs
        linewidth=1.5,
        dodge=True,
        capsize=.4,
        err_kws={"color": ".5", "linewidth": 1.5},
        edgecolor=".5",
        sharex='col',
    )
    g.set_titles(col_template='{col_name}', row_template='{row_name}')
    g.set_ylabels('')

    plt.subplots_adjust(left=0.1, right=1, bottom=0, top=0.92, hspace=0.05)

    for ax in g.axes.flat:
        ax.xaxis.grid(True)
        ax.yaxis.grid(True)
        if 'roc' in yaxis.lower():
            ax.set_xlim(left=0.4)

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.0, 1), ncols=2, frameon=True)

    # The figure facets every phenotype, so the file name is per metric only.
    # (It used to interpolate `target`, which this cell never defines.)
    plt.savefig(f"{paper_dir}/llm_comparison_{yaxis}.pdf")
    plt.savefig(f"{paper_dir}/llm_comparison_{yaxis}.png")
    plt.show()

In [None]:
# Quick check: base model names present in `data` as left by the previous cell.
data['Model'].unique()

In [None]:
# Box plots of each test metric (and model size) per model, with the
# individual results overlaid as a swarm; one facet per phenotype, coloured by
# prompt/feature setting.
for yaxis in [
    'accuracy_score_test',
    'average_precision_score_test',
    'roc_auc_score_test',
    'size',
]:
    data = results_df_melted[results_df_melted['variable'] == yaxis]

    yaxis = yaxis.replace('_', ' ').capitalize()

    data = data.rename(columns={'value': yaxis, 'model': 'Model'})

    # Labels identical to `settings_order`; the ",\n" sits outside the
    # f-string replacement fields so this also parses on Python < 3.12.
    hue = data[['few_feature', 'prompt_richness']].apply(
        lambda row: f"{'Rich prompt' if row.prompt_richness else 'Simple prompt'},\n"
                    f"{'few features' if row.few_feature else 'all features'}",
        axis=1)
    hue.name = 'Setting'

    g = sns.catplot(
        data=data,
        x="Model", y=yaxis, order=model_order,
        col="target", col_wrap=3, col_order=phenotypes_order,
        hue=hue, hue_order=settings_order,
        aspect=0.8, estimator=np.median,
        palette=palette,
        linewidth=1.5,
        kind="box", **boxplot_kwargs
    )

    # Dotted separators between neighbouring models: one fewer than the number
    # of x categories (a fixed 0.5..4.5 left the last models unseparated).
    for sep_x in np.arange(len(model_order) - 1) + 0.5:
        g.refline(x=sep_x, color='gray', lw=0.5, linestyle=':', zorder=0)

    for title, ax in g.axes_dict.items():
        ax.set_title(title)
        ax.grid(which='both', axis='both', ls=":", linewidth=.8)
        for tick in ax.get_xticklabels():
            tick.set(rotation=30, ha='center', va='top', ma='right')

    # Overlay the individual results on the boxes.
    g.map_dataframe(sns.swarmplot, y=yaxis, dodge=True,
                    x="Model", order=model_order,
                    hue=hue, hue_order=settings_order,
                    palette=palette, size=3,
                    linewidth=0.5, alpha=0.5)

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.86, 0.625))

    plt.savefig(f"../paper/llm_comparison_{yaxis}.pdf")
    plt.savefig(f"../paper/llm_comparison_{yaxis}.png")
    plt.show()

In [None]:
# Point plots contrasting zero-shot and iterative runs of each base model
# (rows) on the diagnosis phenotypes (columns), per prompt/feature setting.
for yaxis in [
    'accuracy_score_test',
    'average_precision_score_test',
    'roc_auc_score_test',
    'size',
]:
    # .copy(): the columns below are added to an independent frame, not to a
    # filtered view of results_df_melted (the SettingWithCopy warning this
    # would raise is hidden by the notebook's warning filter).
    data = results_df_melted[results_df_melted['variable'] == yaxis].copy()

    data['Iterative'] = data['model'].str.contains("-iter")
    data['model'] = data['model'].str.replace('-iter', '')

    yaxis = yaxis.replace('_', ' ').capitalize()

    data = data.rename(columns={'value': yaxis, 'model': 'Model'})

    hue = data[['few_feature', 'prompt_richness']].apply(
        lambda row: f"{'Rich prompt' if row.prompt_richness else 'Simple prompt'},\n{'few features' if row.few_feature else 'all features'}", axis=1)
    hue.name = 'Setting'

    g = sns.catplot(
        data=data,
        x='Iterative', y=yaxis,
        hue=hue,
        # Fixed level order so each setting keeps the same colour as in the
        # other figures, regardless of row order in the data.
        hue_order=settings_order,
        col="target", col_order=phenotypes_order[3:],
        row='Model',
        aspect=0.8, height=3, estimator=np.median,
        seed=0,  # fixed bootstrap seed so the CIs are reproducible
        palette=palette,
        linewidth=1.5,
        kind="point", dodge=True, margin_titles=True
    )

    # With both row and col facets, axes_dict keys are (row_name, col_name).
    for (_, col_name), ax in g.axes_dict.items():
        ax.set_title(col_name)
        ax.grid(which='major', axis='y', linewidth=.8)
        ax.grid(which='both', axis='x', linewidth=.8)

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.71, 0.33), frameon=True, fontsize=9, title=None)

    plt.tight_layout()
    plt.savefig(f"../paper/pointplot_llms_{yaxis}.pdf")
    plt.savefig(f"../paper/pointplot_llms_{yaxis}.png")
    plt.show()

In [None]:
# Box plots of each metric per prompt/feature setting, coloured by model, with
# the individual results overlaid as a swarm; one facet per phenotype.
for yaxis in [
    'accuracy_score_test',
    'average_precision_score_test',
    'roc_auc_score_test',
    'size',
]:
    data = results_df_melted[results_df_melted['variable'] == yaxis]

    yaxis = yaxis.replace('_', ' ').capitalize()

    data = data.rename(columns={'value': yaxis, 'model': 'Model'})

    # Labels identical to `settings_order`; the ",\n" sits outside the
    # f-string replacement fields so this also parses on Python < 3.12.
    data['Setting'] = data[['few_feature', 'prompt_richness']].apply(
        lambda row: f"{'Rich prompt' if row.prompt_richness else 'Simple prompt'},\n"
                    f"{'few features' if row.few_feature else 'all features'}",
        axis=1)

    g = sns.catplot(
        data=data, y=yaxis,
        hue="Model", hue_order=model_order,
        col="target", col_wrap=3, col_order=phenotypes_order,
        x='Setting', order=settings_order,
        aspect=0.8, estimator=np.mean,
        palette=palette,
        linewidth=1,
        kind="box", **boxplot_kwargs
    )

    # Dotted separators between neighbouring settings: one fewer than the
    # number of x categories (a fixed 0.5..4.5 drew lines past the last one).
    for sep_x in np.arange(len(settings_order) - 1) + 0.5:
        g.refline(x=sep_x, color='gray', lw=0.5, linestyle=':', zorder=0)

    for title, ax in g.axes_dict.items():
        ax.set_title(title)
        ax.grid(which='major', axis='y', ls=":", linewidth=.8)
        ax.grid(which='both', axis='x', linewidth=.8)
        for tick in ax.get_xticklabels():
            tick.set(rotation=90, ha='right', va='top', ma='right')

    # Overlay the individual results on the boxes.
    g.map_dataframe(sns.swarmplot, y=yaxis, dodge=True,
                    hue="Model", hue_order=model_order,
                    x='Setting', order=settings_order,
                    palette=palette, size=2,
                    linewidth=0.5, alpha=0.5)

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.525, 0.325), frameon=True)

    plt.tight_layout()
    plt.savefig(f"../paper/prompt_comparison_{yaxis}.pdf", bbox_inches='tight')
    plt.savefig(f"../paper/prompt_comparison_{yaxis}.png", bbox_inches='tight')
    plt.show()

In [None]:
# Pareto-style view: mean test score vs. mean model size for each
# (phenotype, model, setting), with +/-1 std error bars on both axes.
# Layout follows Fig. 3 of "A flexible symbolic regression method for
# constructing interpretable clinical prediction models".
for yaxis in ['average_precision_score_test', 'roc_auc_score_test']:

    data = results_df.rename(columns={yaxis: yaxis.replace('_', ' ').capitalize(), 'model': 'Model'})
    yaxis = yaxis.replace('_', ' ').capitalize()

    # Labels identical to `settings_order`; the ",\n" sits outside the
    # f-string replacement fields so this also parses on Python < 3.12.
    data['Setting'] = data[['few_feature', 'prompt_richness']].apply(
        lambda row: f"{'Rich prompt' if row.prompt_richness else 'Simple prompt'},\n"
                    f"{'few features' if row.few_feature else 'all features'}",
        axis=1)

    data = data.groupby(['target', 'Model', 'Setting'])[[yaxis, 'size']].agg(['mean', 'std']).reset_index()
    # Flatten the (column, stat) MultiIndex: ('size', 'mean') -> 'sizemean', etc.
    data.columns = list(map(''.join, data.columns.values))

    g = sns.relplot(
        data=data,
        x="sizemean", y=f"{yaxis}mean", aspect=1, height=4,
        col="target", col_wrap=3, col_order=phenotypes_order,
        kind="scatter",
        hue='Model', hue_order=model_order,
        palette=palette,
        style='Setting', style_order=settings_order,
        linewidth=1.0, s=125, alpha=0.75,
    )

    for ds, plot_ax in g.axes_dict.items():
        plot_ax.set_title(ds)
        plot_ax.grid(which='major', axis='y', linewidth=.8, ls=':')
        plot_ax.grid(which='major', axis='x', linewidth=.5, ls=':')

        # Error bars for every point of this facet in one call, drawn behind
        # the markers. Groups with a single result have NaN std and get none.
        facet = data[data['target'] == ds]
        if facet.empty:
            continue
        plot_ax.errorbar(
            x=facet['sizemean'].to_numpy(),
            y=facet[f"{yaxis}mean"].to_numpy(),
            xerr=facet['sizestd'].to_numpy(),
            yerr=facet[f"{yaxis}std"].to_numpy(),
            fmt='none',  # no marker, just error bars
            color='black',
            capsize=2,
            elinewidth=1,
            capthick=1,
            alpha=0.5,
            zorder=-999,
        )

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.87, 0.675))
    plt.savefig(f"../paper/llm_pareto_grouped_{yaxis}.pdf")
    plt.savefig(f"../paper/llm_pareto_grouped_{yaxis}.png")
    plt.show()

In [None]:
# Unaggregated Pareto view: each individual result's test score against its
# model size, coloured by model and marked by prompt/feature setting.
# Layout follows Fig. 3 of "A flexible symbolic regression method for
# constructing interpretable clinical prediction models".
def _setting_label(row):
    """Two-line prompt / feature-set label, formatted like `settings_order`."""
    prompt = 'Rich prompt' if row.prompt_richness else 'Simple prompt'
    features = 'few features' if row.few_feature else 'all features'
    return f"{prompt},\n{features}"


for yaxis in ['average_precision_score_test', 'roc_auc_score_test']:
    pretty_metric = yaxis.replace('_', ' ').capitalize()
    data = results_df.rename(columns={yaxis: pretty_metric, 'model': 'Model'})
    yaxis = pretty_metric

    data['Setting'] = data[['few_feature', 'prompt_richness']].apply(_setting_label, axis=1)

    g = sns.relplot(
        data=data,
        kind="scatter",
        x="size",
        y=yaxis,
        hue='Model',
        hue_order=model_order,
        style='Setting',
        style_order=settings_order,
        col="target",
        col_wrap=3,
        col_order=phenotypes_order,
        palette=palette,
        aspect=1,
        height=4,
        s=125,
        linewidth=1.0,
        alpha=0.75,
    )

    for facet_name, facet_ax in g.axes_dict.items():
        facet_ax.set_title(facet_name)
        facet_ax.grid(which='major', axis='y', linewidth=.8, ls=':')
        facet_ax.grid(which='major', axis='x', linewidth=.5, ls=':')

    sns.move_legend(g, "upper left", bbox_to_anchor=(0.87, 0.675))
    plt.savefig(f"../paper/llm_pareto_{yaxis}.pdf")
    plt.savefig(f"../paper/llm_pareto_{yaxis}.png")
    plt.show()

In [None]:
def parse(pred):
    """Parse a bracketed, whitespace-separated list of numbers into floats.

    Square brackets are dropped and the remaining text is split on any
    whitespace (spaces, tabs or line breaks), so line-wrapped input works.

    Args:
        pred: text such as "[0.1 0.2 0.3]", possibly wrapped over several lines.

    Returns:
        list[float]: the parsed values, in order.
    """
    # Line breaks are intentionally kept: str.split() already treats them as
    # separators, whereas deleting them would fuse the numbers on either side
    # of a break ("0.1\n0.2" -> "0.10.2").
    cleaned = pred.replace('[', '').replace(']', '')
    return [float(token) for token in cleaned.split()]

def prc_values(y, y_pred_proba, n_points=21):
    """Precision-recall curve resampled onto an evenly spaced recall grid.

    Args:
        y: true binary labels (positive class = 1).
        y_pred_proba: predicted scores for the positive class.
        n_points: number of grid points on [0, 1] (default 21, steps of 0.05).

    Returns:
        (mean_recall, precision): the recall grid and the precision linearly
        interpolated onto it.
    """
    precision, recall, _ = precision_recall_curve(y, y_pred_proba, pos_label=1)
    # The last point of precision_recall_curve is the (recall=0, precision=1)
    # end point; replace its precision with the best precision actually reached.
    precision[-1] = np.max(precision[:-1])
    # recall is returned in decreasing order; np.interp needs increasing x.
    # A stable sort keeps tied recall values in their original relative order,
    # so which precision is used at a tie does not depend on the sort algorithm.
    order = np.argsort(recall, kind='stable')
    mean_recall = np.linspace(0.0, 1, n_points)
    # np.interp directly rather than a bare `interp` (presumably the
    # deprecated/removed scipy alias of numpy.interp).
    precision = np.interp(mean_recall, recall[order], precision[order])
    return mean_recall, precision

def roc_values(y, y_pred_proba, n_points=21):
    """ROC curve resampled onto an evenly spaced false-positive-rate grid.

    Args:
        y: true binary labels (positive class = 1).
        y_pred_proba: predicted scores for the positive class.
        n_points: number of grid points on [0, 1] (default 21, steps of 0.05).

    Returns:
        (mean_fpr, tpr): the FPR grid and the TPR linearly interpolated onto it.
    """
    fpr, tpr, _ = roc_curve(y, y_pred_proba, pos_label=1)
    # roc_curve returns fpr and tpr both non-decreasing. A *stable* sort by fpr
    # therefore keeps tied fpr values (vertical steps of the curve) in
    # increasing-tpr order; an unstable sort could scramble them and make
    # np.interp pick an arbitrary tpr at those grid points.
    order = np.argsort(fpr, kind='stable')
    mean_fpr = np.linspace(0, 1, n_points)
    tpr = np.interp(mean_fpr, fpr[order], tpr[order])
    return mean_fpr, tpr

In [None]:
# PRC and ROC curves for the rich-prompt results: one figure per phenotype, one
# curve per (model, feature set), averaged over runs with 95% Student-t
# confidence bands, plus the heuristic's operating point when the phenotype
# has one.
# NOTE(review): `order`, `dnames_to_ugly`, `dnames_to_nice`, `heuristics`,
# `targets_rev`, `marker_choice`, `sem` and `t` are not defined in this
# notebook; presumably they come from `_load_llm_results` -- confirm.
data = results_df
data = data[data['prompt_richness']]

spacing, fontsize = 3, 18


def _read_test_split(run_id, target_new, fold):
    """Return the held-out test rows of one run/fold; fold "ALL" concatenates folds A-E."""
    def _path(fold_id):
        return ('../data/Dataset' + str(run_id) + '/' +
                target_new + '/' + target_new + fold_id + 'Test.csv')

    if fold == "ALL":
        return pd.concat([pd.read_csv(_path(f)) for f in ['A', 'B', 'C', 'D', 'E']])
    return pd.read_csv(_path(fold))


def _ci95_halfwidth(samples):
    """Element-wise half-width of a two-sided 95% Student-t CI of the mean over axis 0."""
    return sem(samples, axis=0) * t.ppf(1.95 / 2, len(samples) - 1)


for target, perf_t in data.groupby('target'):
    target_new = dnames_to_ugly[target]
    has_heuristic = target_new in heuristics.keys()
    print(target, target_new, f"(shape {perf_t.shape})", sep=",")

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    # Models present for this phenotype, in display order. The same list drives
    # both the curve colours and the legend, so the two always agree.
    models_t = [o for o in order if o in perf_t.model.unique()]

    # Per-run heuristic metrics. The heuristic does not depend on the model, so
    # it is evaluated only alongside the first curve that gets plotted.
    run_precision_h, run_recall_h, run_fpr_h, run_tpr_h = [], [], [], []

    n_curves = 0
    for m, model_nice in enumerate(models_t):
        for few_feature in [True, False]:
            for prompt_richness in [True, False]:
                perf_t_m = perf_t.loc[(perf_t.model == model_nice)
                                      & (perf_t.few_feature == few_feature)
                                      & (perf_t.prompt_richness == prompt_richness)]
                if len(perf_t_m) == 0:
                    continue

                print(f'- graphing {model_nice} - few_feature {few_feature} - prompt_richness {prompt_richness} - shape {perf_t_m.shape}')
                track_heuristic = n_curves == 0 and has_heuristic

                mean_run_precisions = []  # one interpolated PR curve per run
                mean_run_tprs = []        # one interpolated ROC curve per run

                for RunID, perf_t_m_id in perf_t_m.groupby('RunID'):
                    precisions, tprs = [], []
                    precisions_h, recalls_h, fprs_h, tprs_h = [], [], [], []

                    for fold, perf_t_m_id_f in perf_t_m_id.groupby('fold'):
                        # True labels of this run/fold's test split
                        df = _read_test_split(RunID, target_new, fold)
                        y = df[targets_rev[target_new]].values

                        for random_state, perf_t_m_id_f_r in perf_t_m_id_f.groupby('random_state'):
                            print(" -", RunID, fold, random_state)

                            if track_heuristic:
                                y_heuristic = df[heuristics[target_new]].values
                                true_pos = np.sum((y == 1) & (y_heuristic == 1))
                                precision_h = true_pos / np.sum(y_heuristic == 1)
                                recall_h = true_pos / np.sum(y == 1)
                                precisions_h.append(precision_h)
                                recalls_h.append(recall_h)
                                fprs_h.append(np.sum((y == 0) & (y_heuristic == 1)) / np.sum(y == 0))
                                tprs_h.append(recall_h)  # TPR == recall

                            # Predicted probabilities: exactly one row per run/fold/seed
                            assert len(perf_t_m_id_f_r) == 1
                            y_pred_proba = np.array(perf_t_m_id_f_r['pred_proba'].values[0])

                            mask = np.array([v is None for v in y_pred_proba])
                            if np.sum(mask) > 0:
                                print(f"    There are {np.sum(mask)} non-numeric values (out of {len(mask)}). Set to zero")
                                y_pred_proba[mask] = 0.0

                            mean_recall, precision = prc_values(y, y_pred_proba)
                            precisions.append(precision)

                            mean_fpr, tpr = roc_values(y, y_pred_proba)
                            tprs.append(tpr)

                    # One entry per run, averaged over all of the run's folds.
                    # (These appends used to sit inside the fold loop, adding a
                    # running mean per fold instead of a single mean per run.)
                    mean_run_precisions.append(np.mean(precisions, axis=0))
                    mean_run_tprs.append(np.mean(tprs, axis=0))
                    if track_heuristic:
                        run_precision_h.append(np.mean(precisions_h, axis=0))
                        run_recall_h.append(np.mean(recalls_h, axis=0))
                        run_fpr_h.append(np.mean(fprs_h, axis=0))
                        run_tpr_h.append(np.mean(tprs_h, axis=0))

                # Average the per-run curves over runs
                mean_precisions = np.mean(mean_run_precisions, axis=0)
                mean_tprs = np.mean(mean_run_tprs, axis=0)

                color = palette[m]
                linestyle = '--' if few_feature else '-'
                label = (model_nice
                         + ('\nRich prompt, ' if prompt_richness else '\nSimple prompt, ')
                         + ('few features' if few_feature else 'all features'))

                # Precision/Recall curve with its confidence band
                axs[0].plot(mean_recall, mean_precisions, alpha=1, c=color, ls=linestyle,
                            label=label, marker=marker_choice[model_nice], markevery=spacing)
                half_width = _ci95_halfwidth(mean_run_precisions)
                axs[0].fill_between(mean_recall,
                                    np.maximum(mean_precisions - half_width, 0),
                                    np.minimum(mean_precisions + half_width, 1),
                                    color=color, alpha=.1, label=r'95% Confidence Interval')

                # ROC curve with its confidence band
                axs[1].plot(mean_fpr, mean_tprs, alpha=1, c=color, ls=linestyle,
                            label=label, marker=marker_choice[model_nice], markevery=spacing)
                half_width = _ci95_halfwidth(mean_run_tprs)
                axs[1].fill_between(mean_fpr,
                                    np.maximum(mean_tprs - half_width, 0),
                                    np.minimum(mean_tprs + half_width, 1),
                                    color=color, alpha=.1)
                n_curves += 1

    # Chance diagonal on the ROC panel (drawn once)
    axs[1].plot([0, 1], [0, 1], ':k', label=None)

    legend_handles, legend_labels = [], []
    if run_recall_h:
        # Heuristic operating point, averaged over runs. Only drawn for
        # phenotypes that define a heuristic (previously this reused the lists
        # of an earlier phenotype, or failed with a NameError).
        mean_recall_h = np.mean(run_recall_h, axis=0)
        mean_precision_h = np.mean(run_precision_h, axis=0)
        mean_fpr_h = np.mean(run_fpr_h, axis=0)
        mean_tpr_h = np.mean(run_tpr_h, axis=0)
        print(mean_recall_h, mean_precision_h, mean_fpr_h, mean_tpr_h)

        axs[0].plot(mean_recall_h, mean_precision_h, 'Xk', label='Heuristic')
        heuristic_handle, = axs[1].plot(mean_fpr_h, mean_tpr_h, 'Xk', label='Heuristic')
        legend_handles.append(heuristic_handle)
        legend_labels.append('Heuristic')

    fig.suptitle(dnames_to_nice[target_new], fontsize=fontsize)

    axs[0].set_xlabel("Recall (Sensitivity)", fontsize=fontsize)
    axs[0].set_ylabel("Precision", fontsize=fontsize)
    axs[0].grid()
    axs[1].set_xlabel("1 - Specificity", fontsize=fontsize)
    axs[1].set_ylabel("Sensitivity", fontsize=fontsize)
    axs[1].grid()

    # Legend from proxy artists, so nothing extra is drawn on the axes (the old
    # dummy plt.plot(0, 0, marker=...) calls put markers at the ROC origin).
    legend_handles += [Line2D([], [], color='black', linestyle='-'),
                       Line2D([], [], color='black', linestyle='--')]
    legend_labels += ['All features', 'Few features']
    legend_handles += [Line2D([], [], alpha=1, color=palette[m], marker=marker_choice[mn])
                       for m, mn in enumerate(models_t)]
    legend_labels += models_t
    axs[1].legend(legend_handles, legend_labels)

    fig.tight_layout()
    sns.despine(fig=fig)

    fig.savefig(f"../paper/PRC_ROC_{target_new}.pdf")
    fig.savefig(f"../paper/PRC_ROC_{target_new}.png")
    plt.show()