In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import numpy as np
import joblib
import sys
import imodels
import pmlb
import imodelsx.process_results
from collections import defaultdict

# Make the experiment scripts importable. Guarded so re-running this cell does
# not keep appending the same entry to sys.path.
if '../experiments/' not in sys.path:
    sys.path.append('../experiments/')

# results_dir = '../results/gam_shap_nov5'
results_dir = '../results/gam_shap_no_interactions_nov6'
experiment_filename = '../experiments/05_shap_gam.py'

# Load every saved run from results_dir, fill args missing from a run using the
# experiment script's defaults, then average the runs over the 'seed' key.
r = imodelsx.process_results.get_results_df(results_dir)
d = imodelsx.process_results.fill_missing_args_with_default(r, experiment_filename)
d = imodelsx.process_results.average_over_seeds(
    d, experiment_filename, key_to_average_over='seed'
)
# Drop datasets whose name contains '_fri_' (presumably PMLB's synthetic
# Friedman datasets -- TODO confirm). `.copy()` so the column assignments
# below act on an independent frame rather than a slice of the previous one.
d = d[~d.dataset_name.str.contains('_fri_', regex=False, na=False)].copy()


# checking
cols_varied = imodelsx.process_results.get_experiment_keys(
    d, experiment_filename)
print('experiment varied these params:', cols_varied)
# Later cells read both ROC-AUC columns; create them as NaN when no run
# recorded them, so those cells do not fail with a KeyError.
for metric_col in ['roc_auc_test', 'roc_auc_train']:
    if metric_col not in d:
        d[metric_col] = np.nan

print('num_datasets run in different groups')
display(d.groupby([x for x in cols_varied if x != 'dataset_name']).size())

# Maintenance snippets -- DESTRUCTIVE (actually_delete=True removes saved runs);
# keep commented out unless deliberately cleaning up the results directory.
# imodelsx.process_results.delete_runs_in_dataframe(r[r.use_normalize_feature_targets], actually_delete=True)
# r.to_pickle('../results/agg.pkl')
# imodelsx.process_results.delete_runs_in_dataframe(r[(r.use_multitask == 0) * (r.linear_penalty != 'ridge')], actually_delete=True)
# imodelsx.process_results.get_experiment_keys(r, experiment_filename)

In [None]:
# Run count per hyperparameter setting (every varied key except the dataset).
display(
    d.groupby([key for key in cols_varied if key != 'dataset_name']).size()
)

### Compare hyperparameter settings on the shared datasets

In [None]:
cols_varied = imodelsx.process_results.get_experiment_keys(
    d, experiment_filename)
# d = d[(d.use_internal_classifiers == 0) * (d.use_onehot_prior == 0)]
# Group by every varied key except the dataset: one group per setting.
cols_varied_d_ = [x for x in cols_varied if x != 'dataset_name']
groups = d.groupby(cols_varied_d_)

# Keep only datasets that were run under *every* setting, so per-dataset
# comparisons between settings line up. Guard the empty case, because
# set.intersection() with no arguments raises TypeError.
dset_names = [set(d.loc[g]['dataset_name'].values)
              for g in groups.groups.values()]
dset_names_shared = sorted(set.intersection(*dset_names)) if dset_names else []
print('Num datasets run in each setting:')
display(groups.size())
dc = d[d.dataset_name.isin(dset_names_shared)]
print(len(dset_names_shared), 'completed shared datasets')


if len(dset_names_shared) > 0:
    # compute stats per group; the group at position baseline_group_idx
    # (in groupby order) is the baseline that win_rate compares against
    baseline_group_idx = 0
    groups = dc.groupby(cols_varied_d_)
    group_idxs = list(groups.groups.values())
    baseline_group = dc.loc[group_idxs[baseline_group_idx].values].sort_values(
        by='dataset_name')
    stat_cols = defaultdict(list)
    for group in group_idxs:
        g = dc.loc[group].sort_values(by='dataset_name')
        # Fraction of datasets where this setting's test ROC AUC beats the
        # baseline's. The positional comparison assumes one row per dataset in
        # each group; both frames are sorted by dataset_name.
        stat_cols['win_rate'].append(
            (g['roc_auc_test'].values >
             baseline_group['roc_auc_test'].values).mean())
        # Append (not assign): assigning a scalar here made every row show the
        # last group's value once the dict was turned into a DataFrame.
        stat_cols['roc_auc_test__>=15_features'].append(
            g.loc[g['n_features'] >= 15, 'roc_auc_test'].mean())
        for k in ['roc_auc_test', 'roc_auc_train']:
            stat_cols[k].append(g[k].mean())
            stat_cols[f'{k}_median'].append(g[k].median())
    stat_cols = pd.DataFrame(stat_cols)

    # One row per setting: the group-key columns plus the stats computed above
    # (both follow the same groupby order).
    stats = groups['roc_auc_test'].mean().reset_index()
    for col in stat_cols.columns:
        stats[col] = stat_cols[col].values
    # The baseline compared with itself is meaningless.
    stats.loc[baseline_group_idx, 'win_rate'] = np.nan

    # colour the mean test ROC AUC column by value
    display(
        stats
        .style
        .background_gradient(
            # other candidates: 'roc_auc_test_median', 'win_rate',
            # 'roc_auc_test__>=15_features', 'roc_auc_train'
            cmap='viridis', subset=['roc_auc_test'],
        )
        .format(precision=3)
    )

### Performance vs. fraction of data used for training

In [None]:
# Plot each summary metric from `stats` against the training fraction, with one
# line per use_multitask value. The metric names must be columns of `stats`.
# That table only holds ROC-AUC metrics; the old 'r2_test' names raised KeyError.
plot_metrics = ['roc_auc_test', 'roc_auc_test_median', 'win_rate']
# This plot only makes sense when both keys were varied (i.e. are group columns).
missing_cols = {'train_frac', 'use_multitask'} - set(stats.columns)
if missing_cols:
    print('Skipping train-frac plot; stats is missing columns:', sorted(missing_cols))
else:
    fig, axes = plt.subplots(1, len(plot_metrics), figsize=(12, 3))
    for ax, met in zip(axes, plot_metrics):
        tab = stats.pivot_table(index=['train_frac'], columns=[
            'use_multitask'], values=met)
        for multitask in tab.columns:
            ax.plot(tab.index, tab[multitask], 'o-',
                    label={0: 'Single-task', 1: 'Multi-task'}[multitask])
        ax.set_ylabel(met)
        ax.set_xlabel('Fraction of data used for training')
    axes[-1].legend()
    plt.show()