In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import pickle as pkl
# import notebook_helper
import imodelsx.process_results
from sklearn.tree import plot_tree
import sys
import numpy as np
import llm_tree.data
import dvu
import viz
import scipy.stats
import warnings
dvu.set_style()
plt.rcParams['font.size'] = '14'
sys.path.append('../experiments/')
results_dir = '../results/feb11/'

# load results as dataframe
r = imodelsx.process_results.get_results_df(results_dir, use_cached=True)

# fill missing args with default values from argparse
experiment_filename = '../experiments/01_train_model.py'
r = imodelsx.process_results.fill_missing_args_with_default(r, experiment_filename)
r = r[~(r.model_name == 'hstree')]
r = r[~(r.model_name == 'linear_finetune')]

In [None]:
# average over random seeds
ravg = imodelsx.process_results.average_over_seeds(r, experiment_filename)

default_params = {
    'max_features': 1,
    'ngrams': 2,
    'refinement_strategy': 'llm',
    'use_llm_prompt_context': 0,
    'use_stemming': 0,
    # 'n_estimators': 1,
    # 'subsample_frac': 1,
}

### Single-tree curves

In [None]:
rcurve = ravg[ravg.n_estimators == 1] # exclude ensembles
rcurve = rcurve[rcurve.subsample_frac == 1] # exclude subsampling


groupings = ['model_name', 'max_features', 'ngrams',
             'refinement_strategy', 'use_llm_prompt_context', 'use_stemming']
metric = 'roc_auc_test'
# metric = 'accuracy_test'

# viz.plot_train_and_test(rcurve, groupings, metric)

In [None]:
rp = rcurve
for k in default_params:
    rp = rp[rp[k] == default_params[k]]
viz.plot_curves(rp, fname_save='../results/figs/perf_curves_individual.pdf', metric=metric, figsize=(6, 4.75), legend=False,
                dset_names = ['financial_phrasebank', 'rotten_tomatoes', 'sst2', 'emotion'])
plt.show()

### Ensemble curves

In [None]:
rens = ravg
rens = rens[rens.max_depth  == 8]
rens = rens[rens.subsample_frac  == 1]
for k in default_params:
    rens = rens[rens[k] == default_params[k]]

# groupings to plo
groupings = 'model_name'
# groupings = ['model_name', 'max_features', 'refinement_strategy', 'use_llm_prompt_context', 'ngrams']
metric = 'roc_auc_test'
# metric = 'accuracy_test'

# viz.plot_train_and_test(rens, groupings, metric, x='n_estimators')
fig = viz.plot_curves(rens, x='n_estimators', fname_save='../results/figs/perf_curves_ensemble.pdf', metric=metric, figsize=(6, 4.75), legend=False,
                dset_names = ['financial_phrasebank', 'rotten_tomatoes', 'sst2', 'emotion'])

### Final figure

In [None]:
from viz import *
dvu.set_style()
plt.figure(figsize=(6, 4.75))
plt.rcParams['font.size'] = '14'
x = 'n_estimators'
y = 'roc_auc_test'
rp = rens
R, C = 2, 2
# dset_names = rp['dataset_name'].unique()
dset_names = ['financial_phrasebank', 'rotten_tomatoes', 'sst2', 'emotion']
for i in range(R * C):
    plt.subplot(R, C, i + 1)
    dset_name = dset_names[i]
    rd = rp[rp.dataset_name == dset_name] #.sort_values(by='model_name', ascending=False)
    groupings = 'model_name'
    rd = rd.sort_values(by=['model_name', 'n_estimators'], ascending=False)
    for (k, g) in rd.groupby(by=groupings, sort=False):
        # print(k)
        if 'llm_tree' in k:
            kwargs = {'lw': 2.5, 'alpha': 0.9, 'ls': '-', 'marker': '.', 'color': 'mediumseagreen', 'ms': 8}
        else:
            if 'decision_tree' in k:
                color = '#BBB'
            else:
                color = '#111'
            kwargs = {'alpha': 0.8, 'lw': 1.5, 'ls': '-', 'marker': '.', 'color': color, 'ms': 8}
        # if i == 0:
        kwargs['label'] = MODELS_RENAME_DICT.get(k, k)
        if metric + '_err' in g.columns:
            plt.errorbar(g[x], g[metric], yerr=g[metric + '_err'], **kwargs)
        else:
            plt.plot(g[x], g[metric], **kwargs)
    plt.title(DSETS_RENAME_DICT.get(dset_name, dset_name), fontsize='large')
    if i % 2 == 0:
        plt.ylabel(f'ROC AUC', fontsize='medium')
    if i >= 2:
        plt.xlabel('# estimators', fontsize='large')
    # plt.xscale('log')
plt.tight_layout()
plt.legend(fontsize='medium', bbox_to_anchor=(0.31, -0.45))

plt.savefig('acc_ens.pdf', bbox_inches='tight')
# plt.show()

### Subsampling curves

In [None]:
# Ensemble curves (note, only ran 1 seed for this)
rens = ravg
rens = rens[rens.max_depth  == 8]
rens = rens[rens.n_estimators == 1]
for k in default_params:
    rens = rens[rens[k] == default_params[k]]

# groupings to plo
groupings = 'model_name'
# groupings = ['model_name', 'max_features', 'refinement_strategy', 'use_llm_prompt_context', 'ngrams']
metric = 'roc_auc'
# metric = 'accuracy'

# viz.plot_train_and_test(rens, groupings, metric, x='n_estimators')
viz.plot_curves(rens, x='subsample_frac', fname_save='../results/figs/perf_curves_subsampling.pdf')

### Ablations table (with cross-validation)

In [None]:
# apply cross validation (selects best max_depth)
d = ravg
d.n_estimators == 1
d.subsample_frac == 1
groupings = ['model_name', 'max_features', 'ngrams',
             'refinement_strategy', 'use_llm_prompt_context', 'use_stemming'] 
# ravg_cv = (
#     d
#     .sort_values(by='accuracy_cv', ascending=False)
#     .groupby(by=groupings + ['dataset_name'])
#     .first()  # selects best max_depth
#     .reset_index()
# )
ravg_cv = ravg[ravg.max_depth == 12]

# plt.figure(figsize=(8, 3))
# sns.barplot(x='roc_auc_test', y=ravg_cv['dataset_name'].map(viz.DSETS_RENAME_DICT), hue='model_name', data=ravg_cv)
# plt.xlim(left=0.5)
# plt.ylabel('Dataset')
# plt.show()

In [None]:
# display all columns
with pd.option_context('display.max_columns', None):
    display(ravg_cv.head())

In [None]:
metric = 'roc_auc_test'
# metric = 'accuracy_test'
def round_3(x):
    return x.apply(lambda x: f'{x:.3f}')
ravg_cv['met_with_err'] = round_3(ravg_cv[metric]) + ' \\err{' + round_3(ravg_cv[metric+'_err']) + '}'
# print(ravg_cv['met_with_err'])
ablations = (
    ravg_cv
    .pivot_table(index=groupings, columns='dataset_name', values='met_with_err',
                 aggfunc=lambda x: ' '.join(x)) # needed to allow for string values
    .reset_index()
    .rename_axis(None, axis=1)
)
# display(ablations)

def rename_ablations(row):
    tup = tuple(
        row[groupings]
              .values.tolist())
    return {
            ('decision_tree', 1, 2,  'llm', 0, 0): 'CART',
            (          'id3', 1, 2,  'llm', 0, 0): 'ID3',
            (     'llm_tree', 1, 2, 'embs', 0, 0): 'Aug-Tree (Embeddings)',
            (     'llm_tree', 1, 2,  'llm', 0, 0): 'Aug-Tree',
            (     'llm_tree', 1, 2,  'llm', 0, 1): 'Aug-Tree (Stemming)',
            (     'llm_tree', 1, 2,  'llm', 1, 0): 'Aug-Tree (Contextual prompt)',
            (     'llm_tree', 1, 3,  'llm', 0, 0): 'Aug-Tree (Trigrams)',
            (     'llm_tree', 5, 2,  'llm', 0, 0): 'Aug-Tree (5 CART features)',
    }[tup]
    
ablations.index = ablations.apply(
    lambda x: rename_ablations(x), axis=1
)
ablations = ablations.drop(columns=groupings)
ablations = ablations.reindex(['Aug-Tree', 'Aug-Tree (Embeddings)', 'Aug-Tree (Contextual prompt)', 'Aug-Tree (5 CART features)', 'Aug-Tree (Stemming)', 'Aug-Tree (Trigrams)', 'CART', 'ID3'])
ablations = ablations.rename(index={'Aug-Tree': '\\textbf{Aug-Tree}'})
ablations.iloc[0] = ablations.iloc[0].apply(lambda x: '\\textbf{' + str(x) + '}')

print(
    ablations
    .rename(columns=viz.DSETS_RENAME_DICT)
    .style
    .format(precision=3)
    .to_latex(hrules=True, ).replace('_', ' ')
    .replace('\nCART', '\n\\midrule\nCART')
)

# Appendix accuracy table

In [39]:
# apply cross validation (selects best max_depth)
d_default = ravg
for k in default_params:
    d_default = d_default[d_default[k] == default_params[k]]
d_default = d_default[d_default.subsample_frac == 1]
d_default = d_default[d_default.model_name == 'llm_tree']

In [44]:
d_single = d_default[d_default.max_depth == 12]
metric = 'accuracy_test'
d_single['met_with_err'] = round_3(d_single[metric]) + ' \\err{' + round_3(d_single[metric+'_err']) + '}'
d_single[['dataset_name', 'met_with_err']].set_index('dataset_name').T

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  d_single['met_with_err'] = round_3(d_single[metric]) + ' \\err{' + round_3(d_single[metric+'_err']) + '}'


dataset_name,emotion,financial_phrasebank,rotten_tomatoes,sst2
met_with_err,0.637 \err{0.045},0.818 \err{0.014},0.613 \err{0.009},0.571 \err{0.018}


In [43]:
d_ens_acc = d_default[d_default.n_estimators == 40]
metric = 'accuracy_test'
d_ens_acc['met_with_err'] = round_3(d_ens_acc[metric]) + ' \\err{' + round_3(d_ens_acc[metric+'_err']) + '}'
d_ens_acc[['dataset_name', 'met_with_err']].set_index('dataset_name').T

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  d_ens_acc['met_with_err'] = round_3(d_ens_acc[metric]) + ' \\err{' + round_3(d_ens_acc[metric+'_err']) + '}'


dataset_name,emotion,financial_phrasebank,rotten_tomatoes,sst2
met_with_err,0.800 \err{0.008},0.848 \err{0.006},0.619 \err{0.004},0.614 \err{0.016}
