## Generate the baseline accuracy figures

### Generate Figure 6

In [None]:
import pandas as pd

# Gather the per-iteration baseline result CSVs for every dataset /
# imputation method combination used in this figure.
dfs = []
target_imp = ['MICE']
for ds_name in ['BREAST_CANCER', 'MIMIC', 'PHARYNGITIS', 'FICO']:#, 'CKD', 'HEART_DISEASE']:
    for imputation_method in target_imp:
        for s_iter in range(120):
            result_path = f'./parallelized_results/baselines_iter_{s_iter}_{ds_name}_{imputation_method}.csv'
            try:
                df = pd.read_csv(result_path)
            except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
                # Best effort: a missing or unreadable result file is skipped,
                # but unrelated errors (typos, KeyboardInterrupt, ...) are no
                # longer swallowed. The message names the file so the gap can
                # be traced back to a specific run.
                print(f"WARNING: Skipping iteration {s_iter} ({result_path}): {err}")
                continue
            dfs.append(df)
combined_acc_df = pd.concat(dfs, axis=0)

# Metric that the following cells filter on and plot.
target_metric = 'acc'

In [None]:
# Keep only the rows produced with 10 imputations and with one of the
# imputation methods selected in `target_imp`.
mask = (combined_acc_df['num_imputations'] == 10) & \
        (combined_acc_df['missingness_handling'].isin(target_imp))
# Take an explicit copy: the following cells overwrite columns on
# `cur_acc_df`, and doing that on a filtered slice of `combined_acc_df`
# triggers pandas' SettingWithCopyWarning.
cur_acc_df = combined_acc_df[mask].copy()

In [None]:
def get_smim_tag(row):
    """Return the row's ``model_type`` suffixed with its SMIM status.

    ``row`` is anything indexable by ``'use_smim'`` and ``'model_type'``
    (here: one DataFrame row). A truthy ``use_smim`` yields ``' (SMIM)'``,
    anything else yields ``' (No SMIM)'``.
    """
    suffix = ' (SMIM)' if row['use_smim'] else ' (No SMIM)'
    return row['model_type'] + suffix
# Overwrite `model_type` with the SMIM-tagged label that get_smim_tag
# builds for each row.
cur_acc_df['model_type'] = cur_acc_df.apply(get_smim_tag, axis=1)

In [None]:
# The GAM results contain a separate row per validation set, so collapse
# them by averaging the remaining columns within each group of these keys.
val_set_group_keys = ['model_type', 'dataset', 'holdout_set', 'metric', 'missingness_handling']
cur_acc_df = cur_acc_df.groupby(val_set_group_keys).mean().reset_index()
# Displayed as the cell output: number of rows left per model type.
cur_acc_df['model_type'].value_counts()

In [None]:
# Keep only the metric being plotted (`target_metric`) and drop the
# 'GAM_no_missing (SMIM)' model.
# (An earlier version picked AUC for BREAST_CANCER and ACC for FICO instead.)
keep_rows = (cur_acc_df['metric'] == target_metric) & \
    (cur_acc_df['model_type'] != 'GAM_no_missing (SMIM)')
cur_acc_df = cur_acc_df[keep_rows]

# 'Model Type' is the display label used in the figure legend; it starts as
# a copy of `model_type` and the GAM variants get their paper names.
cur_acc_df['Model Type'] = cur_acc_df['model_type']
model_display_names = {
    'GAM_imputation (SMIM)': 'M-GAM (Imputation)',
    'GAM_ind (SMIM)': 'M-GAM (Indicators Only)',
    'GAM_aug (SMIM)': 'M-GAM (w/ Interactions)',
}
for raw_label, display_label in model_display_names.items():
    cur_acc_df.loc[cur_acc_df['Model Type'] == raw_label, 'Model Type'] = display_label

# Human-readable dataset names; datasets not listed keep their raw name.
dataset_display_names = {
    'BREAST_CANCER': 'Breast Cancer',
    'HEART_DISEASE': 'Heart Disease',
    'PHARYNGITIS': 'Pharyngitis',
    'ADULT': 'Adult',
}
for raw_name, display_name in dataset_display_names.items():
    cur_acc_df.loc[cur_acc_df['dataset'] == raw_name, 'dataset'] = display_name

cur_acc_df = cur_acc_df.sort_values(by='Model Type')

In [None]:
"""import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure

for d_ind, dataset in enumerate(cur_acc_df['dataset'].unique()):
    mask = cur_acc_df['dataset'] == dataset
    tmp_cur_acc_df = cur_acc_df[mask]
    sns.set(font_scale=2.5)
    figure(figsize=(6, 8), dpi=80)
    ax = sns.boxplot(
        tmp_cur_acc_df, hue='Model Type', y='metric_value_test'
    )
    sns.move_legend(ax, "upper left", bbox_to_anchor=(0.0, 2.0), ncol=2)
    plt.xlabel('')
    if dataset == "MIMIC":
        plt.ylim((0.86, 0.94))
    if dataset == "PHARYNGITIS":
        plt.title("Pharyngitis")
    else:
        plt.title(dataset)
    plt.ylabel(f'Test {target_metric.upper()}')
    plt.xticks(rotation=0, ha='center')
    plt.show()"""

import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure

# Order rows by dataset name and expose the test metric under the axis label
# used in the figure.
cur_acc_df = cur_acc_df.sort_values('dataset')
cur_acc_df['Test Accuracy'] = cur_acc_df['metric_value_test']

# Legend / hue order: the two M-GAM variants first, then every other model
# in alphabetical order. The frame is rebuilt in that order, one model at a
# time.
mgam_methods = ['M-GAM (Indicators Only)', 'M-GAM (w/ Interactions)']
ordered_list = mgam_methods + sorted(
    m for m in cur_acc_df['Model Type'].unique() if m not in mgam_methods
)
print(ordered_list)
cur_acc_df = pd.concat(
    [cur_acc_df[cur_acc_df['Model Type'] == m] for m in ordered_list], axis=0
)

# One facet per dataset, each with its own y-axis.
g = sns.FacetGrid(
    cur_acc_df,
    col='dataset',
    height=7,
    sharey=False
)
g.map_dataframe(
    sns.boxplot,
    hue='Model Type',
    palette=sns.color_palette('Spectral', n_colors=14),
    y='Test Accuracy'
)
g.set_titles(
    row_template='{row_name}',
    col_template='{col_name}',
)
g.add_legend()
sns.move_legend(g, "upper left", bbox_to_anchor=(0.08, 1.4), ncol=3)

# Highlight the GAM entries in the legend.
for text in g.legend.texts:
    if "GAM" in text.get_text():
        text.set_fontweight("bold")
        text.set_color("red")

# NOTE(review): this runs after the grid has been built, so it presumably
# only affects figures created afterwards -- confirm whether it should be
# moved above the FacetGrid call.
sns.set(font_scale=2.3)

# Zoom the MIMIC facet. Look the axes up by facet name rather than by a
# hard-coded position, so the limit stays on MIMIC even if the set or order
# of datasets changes.
mimic_ax = g.axes_dict.get('MIMIC')
if mimic_ax is not None:
    mimic_ax.set_ylim(0.86, 0.94)

plt.show()

### Generate Appendix Fig 8

In [None]:
import pandas as pd

# Gather the per-iteration baseline result CSVs for every dataset /
# imputation method combination used in this figure.
dfs = []
target_imp = ['MICE', 'Mean', 'MIWAE', 'MissForest']
for ds_name in ['BREAST_CANCER', 'MIMIC', 'PHARYNGITIS', 'FICO', 'CKD', 'HEART_DISEASE']:
    for imputation_method in target_imp:
        for s_iter in range(120):
            result_path = f'./parallelized_results/baselines_iter_{s_iter}_{ds_name}_{imputation_method}.csv'
            try:
                df = pd.read_csv(result_path)
            except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
                # Best effort: a missing or unreadable result file is skipped,
                # but unrelated errors (typos, KeyboardInterrupt, ...) are no
                # longer swallowed. The message names the file so the gap can
                # be traced back to a specific run.
                print(f"WARNING: Skipping iteration {s_iter} ({result_path}): {err}")
                continue
            dfs.append(df)
combined_acc_df = pd.concat(dfs, axis=0)

# Metric that the following cells filter on and plot.
target_metric = 'acc'

In [None]:
# Keep only the rows produced with 10 imputations and with one of the
# imputation methods selected in `target_imp`.
mask = (combined_acc_df['num_imputations'] == 10) & \
        (combined_acc_df['missingness_handling'].isin(target_imp))
# Take an explicit copy: the code below overwrites columns on `cur_acc_df`,
# and doing that on a filtered slice of `combined_acc_df` triggers pandas'
# SettingWithCopyWarning.
cur_acc_df = combined_acc_df[mask].copy()

def get_smim_tag(row):
    """Return the row's ``model_type`` suffixed with its SMIM status.

    ``row`` is anything indexable by ``'use_smim'`` and ``'model_type'``
    (here: one DataFrame row). A truthy ``use_smim`` yields ``' (SMIM)'``,
    anything else yields ``' (No SMIM)'``.
    """
    suffix = ' (SMIM)' if row['use_smim'] else ' (No SMIM)'
    return row['model_type'] + suffix
# Overwrite `model_type` with the SMIM-tagged label that get_smim_tag
# builds for each row.
cur_acc_df['model_type'] = cur_acc_df.apply(get_smim_tag, axis=1)

# The GAM results contain a separate row per validation set, so collapse
# them by averaging the remaining columns within each group of these keys.
val_set_group_keys = ['model_type', 'dataset', 'holdout_set', 'metric', 'missingness_handling']
cur_acc_df = cur_acc_df.groupby(val_set_group_keys).mean().reset_index()
# Displayed as the cell output: number of rows left per model type.
cur_acc_df['model_type'].value_counts()

In [None]:
# Keep only the metric being plotted (`target_metric`) and drop the
# 'GAM_no_missing (SMIM)' model.
# (An earlier version picked AUC for BREAST_CANCER and ACC for FICO instead.)
keep_rows = (cur_acc_df['metric'] == target_metric) & \
    (cur_acc_df['model_type'] != 'GAM_no_missing (SMIM)')
cur_acc_df = cur_acc_df[keep_rows]

# 'Model Type' is the display label used in the figure legend; it starts as
# a copy of `model_type` and the GAM variants get their paper names.
cur_acc_df['Model Type'] = cur_acc_df['model_type']
model_display_names = {
    'GAM_imputation (SMIM)': 'M-GAM (Imputation)',
    'GAM_ind (SMIM)': 'M-GAM (Indicators Only)',
    'GAM_aug (SMIM)': 'M-GAM (w/ Interactions)',
}
for raw_label, display_label in model_display_names.items():
    cur_acc_df.loc[cur_acc_df['Model Type'] == raw_label, 'Model Type'] = display_label

# Human-readable dataset names; datasets not listed keep their raw name.
dataset_display_names = {
    'BREAST_CANCER': 'Breast Cancer',
    'HEART_DISEASE': 'Heart Disease',
    'PHARYNGITIS': 'Pharyngitis',
    'ADULT': 'Adult',
}
for raw_name, display_name in dataset_display_names.items():
    cur_acc_df.loc[cur_acc_df['dataset'] == raw_name, 'dataset'] = display_name

cur_acc_df = cur_acc_df.sort_values(by='Model Type')

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure

# Expose the test metric under the axis label used in the figure.
cur_acc_df['Test Accuracy'] = cur_acc_df['metric_value_test']

# Legend / hue order: the two M-GAM variants first, then every other model
# in alphabetical order. The frame is rebuilt in that order, one model at a
# time.
mgam_methods = ['M-GAM (Indicators Only)', 'M-GAM (w/ Interactions)']
ordered_list = mgam_methods + sorted(
    m for m in cur_acc_df['Model Type'].unique() if m not in mgam_methods
)
print(ordered_list)
cur_acc_df = pd.concat(
    [cur_acc_df[cur_acc_df['Model Type'] == m] for m in ordered_list], axis=0
)

# One row per dataset, one column per imputation method; the y-axis is
# shared within each dataset row.
g = sns.FacetGrid(
    cur_acc_df,
    row='dataset',
    row_order=['Breast Cancer', 'CKD', 'FICO', 'Heart Disease', 'Pharyngitis', 'MIMIC'],
    col='missingness_handling',
    height=7,
    sharey="row",
    margin_titles=True
)
g.map_dataframe(
    sns.boxplot,
    hue='Model Type',
    palette=sns.color_palette('Spectral', n_colors=14),
    y='Test Accuracy'
)
g.set_titles(
    row_template='{row_name}',
    col_template='{col_name}',
)

g.add_legend()
sns.move_legend(g, "upper left", bbox_to_anchor=(0.1, 1.07), ncol=3)

# Highlight the GAM entries in the legend.
for text in g.legend.texts:
    if "GAM" in text.get_text():
        text.set_fontweight("bold")
        text.set_color("red")

# Zoom the MIMIC row. The previous check `ind == 25` could never match: the
# grid has 6 rows x 4 columns = 24 axes (flat indices 0..23), so the limit
# was silently never applied. Select the axes by row name instead.
# NOTE(review): this makes the MIMIC row actually use the 0.86-0.94 range;
# confirm that range suits every imputation column.
for (row_name, _col_name), ax in g.axes_dict.items():
    if row_name == 'MIMIC':
        ax.set_ylim(0.86, 0.94)

# NOTE(review): this runs after the grid has been built, so it presumably
# only affects figures created afterwards -- confirm whether it should be
# moved above the FacetGrid call.
sns.set(font_scale=2.3)
plt.show()

## Generate the runtime comparison figures

In [None]:
# Load in model fitting data
import pandas as pd
dataset_of_interest = ['BREAST_CANCER','FICO','MIMIC', 'PHARYNGITIS', 'CKD', 'HEART_DISEASE']
# Every dataset name plus its '_0.25' / '_0.5' / '_0.75' suffixed variants.
datasets_with_subsamples = []
for d in dataset_of_interest:
    datasets_with_subsamples = datasets_with_subsamples + [d, f'{d}_0.25', f'{d}_0.5', f'{d}_0.75']
main_body_datasets = ['BREAST_CANCER','FICO','MIMIC', 'PHARYNGITIS']
imputations_of_interest = ['Mean', 'MICE', 'MIWAE', 'MissForest']
target_metric = 'acc'

# Errors that mean "this result file is missing or unreadable". Those files
# are skipped silently (best effort, as before); anything else -- a typo,
# KeyboardInterrupt, ... -- is no longer swallowed by a bare `except`.
skippable_read_errors = (OSError, pd.errors.EmptyDataError, pd.errors.ParserError)

# Accuracy / fit-time results: one CSV per (iteration, dataset variant,
# imputation method).
dfs = []
for cur_ds in dataset_of_interest:
    for ds_name in [cur_ds, f'{cur_ds}_0.25', f'{cur_ds}_0.5', f'{cur_ds}_0.75']:
        for imputation_method in imputations_of_interest:
            for s_iter in range(120):
                try:
                    df = pd.read_csv(f'./parallelized_results/baselines_iter_{s_iter}_{ds_name}_{imputation_method}.csv')
                except skippable_read_errors:
                    continue
                dfs.append(df)
combined_acc_df = pd.concat(dfs, axis=0)

# Imputation timing statistics: one CSV per (dataset variant, imputation
# method).
df_list = []
for ds_name in dataset_of_interest:
    for imputation_method in imputations_of_interest:
        for subsample in ['', '_0.25', '_0.5', '_0.75']:
            try:
                timing_df = pd.read_csv(f'../../handling_missing_data/timing_stats_{ds_name}{subsample}_{imputation_method}_5_3.csv')
            except skippable_read_errors:
                continue
            df_list.append(timing_df)
base_timing_df = pd.concat(df_list, axis=0)

# Keep timing rows with m < 10 for the imputation methods of interest.
base_timing_df = base_timing_df[(base_timing_df['m'] < 10) & (base_timing_df['imputation'].isin(imputations_of_interest))]

# Accuracy rows for 10 imputations and the target metric only.
mask = (combined_acc_df['num_imputations'] == 10) & (combined_acc_df['metric'] == target_metric)
acc_df = combined_acc_df[mask]

In [None]:
# Sum the timing columns within each (holdout set, dataset, validation set,
# imputation) group.
# NOTE: This aggregation should be mean for mean imputations
timing_group_keys = ['holdout_set', 'dataset', 'validation_set', 'imputation']
base_timing_df = base_timing_df.groupby(timing_group_keys).sum().reset_index()

In [None]:
# Average the numeric result columns per (dataset, holdout set, model).
# numeric_only=True skips string columns that are not grouping keys (e.g.
# 'metric', 'missingness_handling'); without it pandas >= 2.0 raises a
# TypeError when it tries to average them.
# NOTE(review): this also averages across imputation methods before the
# merge below attaches per-imputation timings -- confirm that is intended.
acc_df = acc_df.groupby(['dataset', 'holdout_set', 'model_type']).mean(numeric_only=True).reset_index()

merged_df = base_timing_df.merge(acc_df, how='inner', on=['dataset','holdout_set'])

# Only MICE actually records a time_overall -- the rest of the imputations
# split the time across a few phases. Aggregate those phases for plotting
mask = merged_df['time_overall'] == 0
merged_df.loc[mask, 'time_overall'] = merged_df.loc[mask, "time_to_fit"]\
    + merged_df.loc[mask, "time_for_train"]\
    + merged_df.loc[mask, "time_for_val"]\
    + merged_df.loc[mask, "time_for_test"]

# And remove the imputation time from M-GAM rows
merged_df['impute_time'] = merged_df['time_overall']

# `mask` selects every model except the two M-GAM variants: those rows pay
# imputation time + fit time, while the M-GAM rows (~mask) only pay fit time.
mask = (merged_df['model_type'] != 'GAM_ind') & (merged_df['model_type'] != 'GAM_aug')
# Initialise as float so the float timings assigned below do not have to be
# written into an integer column.
merged_df['overall_time'] = 0.0
merged_df.loc[mask, 'overall_time'] = merged_df.loc[mask, 'impute_time'] + merged_df.loc[mask, 'mean_fit_time']
merged_df.loc[~mask, 'overall_time'] = merged_df.loc[~mask,'mean_fit_time']
merged_df.loc[~mask, 'impute_time'] = 0

In [None]:
# Drop the 'GAM_no_missing' model, then replace the raw GAM identifiers with
# their display names (the leading space is part of each label).
merged_df = merged_df[merged_df['model_type'] != 'GAM_no_missing']
gam_display_names = {
    'GAM_aug': ' GAM (Interactions)',
    'GAM_ind': ' GAM (Indicators)',
    'GAM_imputation': ' GAM (Imputation)',
}
for raw_model, display_model in gam_display_names.items():
    mask = merged_df['model_type'] == raw_model
    merged_df.loc[mask, 'model_type'] = display_model


In [None]:

# Split the merged frame into the two M-GAM variants and all non-GAM models.
# The M-GAM subsets are explicit copies because their 'imputation' column is
# overwritten below; writing into a plain slice of `merged_df` triggers
# pandas' SettingWithCopyWarning.
merged_df_gam_ind = merged_df[
    merged_df['model_type'] == ' GAM (Indicators)'
].copy()

merged_df_gam_int = merged_df[
    merged_df['model_type'] == ' GAM (Interactions)'
].copy()

# Everything that is not one of the three GAM variants (' GAM (Imputation)'
# rows are therefore dropped from the combined frame).
gam_labels = [' GAM (Imputation)', ' GAM (Indicators)', ' GAM (Interactions)']
merged_df_no_gam = merged_df[~merged_df['model_type'].isin(gam_labels)]

# Treat each M-GAM variant as its own "imputation method" so it can be
# placed next to the real imputation methods.
merged_df_gam_ind.loc[:, 'imputation'] = 'M-GAM (Ind)'
merged_df_gam_int.loc[:, 'imputation'] = 'M-GAM (Int)'
merged_df_mod = pd.concat((merged_df_no_gam, merged_df_gam_int, merged_df_gam_ind), axis=0)

In [None]:
# For every imputation method keep only the rows of the model type with the
# highest mean test metric.
imp_dfs = []
for imputation in merged_df_mod.imputation.unique():
    cur_imp_df = merged_df_mod[merged_df_mod['imputation'] == imputation]
    # numeric_only=True skips string columns such as 'dataset' and
    # 'imputation'; without it pandas >= 2.0 raises a TypeError when it
    # tries to average them.
    agg_imp_df = cur_imp_df.groupby("model_type").mean(numeric_only=True).reset_index()
    best_model_mean_acc = agg_imp_df['metric_value_test'].max()
    # If several models tie, the first one in group order wins.
    is_best = agg_imp_df['metric_value_test'] == best_model_mean_acc
    target_model = agg_imp_df[is_best]['model_type'].values[0]

    imp_dfs.append(cur_imp_df[cur_imp_df['model_type'] == target_model])

only_best_model_df = pd.concat(imp_dfs, axis=0)

### Generate Figure 5

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties

# Restrict to the main-body datasets. Copy the selection because a helper
# column is added below; writing into a plain slice of `only_best_model_df`
# triggers pandas' SettingWithCopyWarning.
df_for_cur_plot = only_best_model_df[only_best_model_df['dataset'].isin(main_body_datasets)].copy()
def tweak_dataset_for_sorting(row):
    """Return the row's dataset name with 'z' appended, for use as a sort key."""
    return row['dataset'] + 'z'
df_for_cur_plot['sortable_dataset'] = df_for_cur_plot.apply(tweak_dataset_for_sorting, axis=1)
df_for_cur_plot = df_for_cur_plot.sort_values(['sortable_dataset', 'imputation'])

# Bar order on the x-axis of every facet.
method_order = ['M-GAM (Ind)', 'M-GAM (Int)', 'Mean', 'MICE', 'MissForest', 'MIWAE']

color_pal = sns.color_palette()
sns.set(font_scale=2.6)
g = sns.FacetGrid(
    df_for_cur_plot,
    col="dataset",
    col_wrap=4,
    height=7,
    sharey=True,
    aspect=1.0)
g.map_dataframe(sns.barplot,
    y='overall_time',
    x='imputation',
    order=method_order
).set(yscale ='log')

g.set_xticklabels(method_order, rotation=45, ha='right')

g.set_xlabels('')
g.set_ylabels('Time (Seconds)')

# Rewrite each facet title: keep only its last space-separated token, turn
# underscores into spaces, and title-case it unless the title mentions one
# of the acronym datasets, which stay upper-case.
acronym_datasets = ("FICO", "MIMIC", "CKD")
for ax in g.axes.flat:
    raw_title = ax.title.get_text()
    pretty_title = raw_title.split(" ")[-1].replace("_", " ")
    if not any(acronym in raw_title for acronym in acronym_datasets):
        pretty_title = pretty_title.title()
    ax.set_title(pretty_title)

# Bold the M-GAM tick labels.
bold_font = FontProperties(weight='bold')
for ax in g.axes.flat:
    for lbl in ax.get_xticklabels():
        if "GAM" in lbl.get_text():
            lbl.set_fontproperties(bold_font)

plt.tight_layout()
plt.show()

### Generate Figure 9

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties

# Restrict to the datasets of interest. Copy the selection because a helper
# column is added below; writing into a plain slice of `only_best_model_df`
# triggers pandas' SettingWithCopyWarning.
df_for_cur_plot = only_best_model_df[only_best_model_df['dataset'].isin(dataset_of_interest)].copy()
def tweak_dataset_for_sorting(row):
    """Return the row's dataset name with 'z' appended, for use as a sort key."""
    return row['dataset'] + 'z'
df_for_cur_plot['sortable_dataset'] = df_for_cur_plot.apply(tweak_dataset_for_sorting, axis=1)
df_for_cur_plot = df_for_cur_plot.sort_values(['sortable_dataset', 'imputation'])

# Bar order on the x-axis of every facet.
method_order = ['M-GAM (Ind)', 'M-GAM (Int)', 'Mean', 'MICE', 'MissForest', 'MIWAE']

color_pal = sns.color_palette()
sns.set(font_scale=2.6)
g = sns.FacetGrid(
    df_for_cur_plot,
    col="dataset",
    col_wrap=3,
    height=7,
    sharey=True,
    aspect=1.0)
g.map_dataframe(sns.barplot,
    y='overall_time',
    x='imputation',
    order=method_order
).set(yscale ='log')

g.set_xticklabels(method_order, rotation=45, ha='right')

g.set_xlabels('')
g.set_ylabels('Time (Seconds)')

# Rewrite each facet title: keep only its last space-separated token, turn
# underscores into spaces, and title-case it unless the title mentions one
# of the acronym datasets, which stay upper-case.
acronym_datasets = ("FICO", "MIMIC", "CKD")
for ax in g.axes.flat:
    raw_title = ax.title.get_text()
    pretty_title = raw_title.split(" ")[-1].replace("_", " ")
    if not any(acronym in raw_title for acronym in acronym_datasets):
        pretty_title = pretty_title.title()
    ax.set_title(pretty_title)

# Bold the M-GAM tick labels.
bold_font = FontProperties(weight='bold')
for ax in g.axes.flat:
    for lbl in ax.get_xticklabels():
        if "GAM" in lbl.get_text():
            lbl.set_fontproperties(bold_font)

plt.tight_layout()
plt.show()

### Generate Figure 10

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.font_manager import FontProperties

# Work on a copy: a plain assignment would alias `only_best_model_df`, and
# the helper column added below would then be written into it as well.
df_for_cur_plot = only_best_model_df.copy()
def tweak_dataset_for_sorting(row):
    """Return the row's dataset name with 'z' appended, for use as a sort key."""
    return row['dataset'] + 'z'
df_for_cur_plot['sortable_dataset'] = df_for_cur_plot.apply(tweak_dataset_for_sorting, axis=1)
df_for_cur_plot = df_for_cur_plot.sort_values(['sortable_dataset', 'imputation'])

# Bar order on the x-axis of every facet.
method_order = ['M-GAM (Ind)', 'M-GAM (Int)', 'Mean', 'MICE', 'MissForest', 'MIWAE']

color_pal = sns.color_palette()
sns.set(font_scale=2.6)
g = sns.FacetGrid(
    df_for_cur_plot,
    col="dataset",
    col_wrap=4,
    height=7,
    sharey=True,
    aspect=1.0)
g.map_dataframe(sns.barplot,
    y='overall_time',
    x='imputation',
    order=method_order
).set(yscale ='log')

g.set_xticklabels(method_order, rotation=45, ha='right')

g.set_xlabels('')
g.set_ylabels('Time (Seconds)')

# Rewrite each facet title: keep only its last space-separated token, turn
# underscores into spaces, and title-case it unless the title mentions one
# of the acronym datasets, which stay upper-case.
acronym_datasets = ("FICO", "MIMIC", "CKD")
for ax in g.axes.flat:
    raw_title = ax.title.get_text()
    pretty_title = raw_title.split(" ")[-1].replace("_", " ")
    if not any(acronym in raw_title for acronym in acronym_datasets):
        pretty_title = pretty_title.title()
    ax.set_title(pretty_title)

# Bold the M-GAM tick labels.
bold_font = FontProperties(weight='bold')
for ax in g.axes.flat:
    for lbl in ax.get_xticklabels():
        if "GAM" in lbl.get_text():
            lbl.set_fontproperties(bold_font)

plt.tight_layout()
plt.show()