# Model Selection over the grid search
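This notebook inspects the models trained during the grid search and ranks them on an a priori chosen set of 4 metrics (precision, correlation, dice, and AUPRC). Each model is ranked on every metric separately and the per-metric ranks are averaged; the resulting rank-sorted list helps us choose a top model for test-set evaluation.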

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Encoder identifier.
# NOTE(review): the next cell assigns the same value to `encoder`, which is the
# name the path-building and save cells actually read; `encoder_name` is not
# referenced by any cell visible here. Confirm whether it is still needed.
encoder_name = "minilm"

In [3]:
# Encoder whose grid-search outputs are analysed; it is embedded in every path below.
encoder = "minilm"

# Root directory of this encoder's grid-search outputs (note the trailing slash,
# which the cache paths below rely on).
grid_dir = f"/dfs/scratch1/gmachi/k2/K2/src/outputs/wikisection/{encoder}_gridsearch/"

# Cache locations, each of the form "<grid_dir><encoder>-<suffix>".
results_cache_dir = f"{grid_dir}{encoder}-eval_results"
model_cache_dir = f"{grid_dir}{encoder}-fitted_k2_models"
processor_cache_dir = f"{grid_dir}{encoder}-fitted_k2_processors"
linearized_cache_dir = f"{grid_dir}{encoder}-linearized_data"

In [4]:
# Metrics scored with `top_model_confusion` in the collection cell below;
# after pivoting they are ranked per (model_name, threshold) row.
key_conf_metrics = ["precision", "correlation", "dice"]
# Metrics scored with `top_model_continuous_avg` in the collection cell below;
# after pivoting they are ranked per model_name row.
key_cont_metrics = ["auprc"]

In [5]:
import pandas as pd
from model_selection import top_model_confusion, top_model_continuous_avg


def _collect_scores(metrics, scorer, **scorer_kwargs):
    """Score every metric with ``scorer`` and stack the results into one frame.

    Replaces two copy-pasted loops that differed only in the scorer and its
    extra keyword arguments.

    Parameters
    ----------
    metrics : iterable of str
        Metric names; each is printed as a progress marker and passed to
        ``scorer`` as its first positional argument.
    scorer : callable
        Invoked as ``scorer(metric, results_cache_dir, model_cache_dir,
        return_all=True, **scorer_kwargs)``. Assumed to return a
        ``pandas.DataFrame`` (the results are concatenated and later pivoted
        on 'model_name' / 'metric' / 'score') -- TODO confirm against
        ``model_selection``.
    **scorer_kwargs
        Extra keyword arguments forwarded unchanged to ``scorer``.

    Returns
    -------
    pandas.DataFrame
        Row-wise concatenation of the per-metric frames, each with an added
        ``metric`` column holding the metric name.

    Raises
    ------
    ValueError
        From ``pd.concat`` when ``metrics`` is empty.
    """
    frames = []
    for metric in metrics:
        print(metric)
        scores = scorer(
            metric,
            results_cache_dir,
            model_cache_dir,
            return_all=True,
            **scorer_kwargs,
        )
        # `assign` returns a new frame, so the object handed back by the scorer
        # is not mutated in place.
        frames.append(scores.assign(metric=metric))
    return pd.concat(frames)


conf_res = _collect_scores(key_conf_metrics, top_model_confusion, eval_class=1)
cont_res = _collect_scores(key_cont_metrics, top_model_continuous_avg)

Matplotlib created a temporary cache directory at /tmp/matplotlib-61h_4780 because the default path (/afs/cs.stanford.edu/u/gmachi/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


precision
correlation
dice
auprc


In [8]:
# Confusion-matrix metrics: one row per (model_name, threshold), one column per metric.
conf_pvt = conf_res.pivot(index=['model_name', 'threshold'], columns='metric', values='score')
for met in key_conf_metrics:
    # Rank 1 = highest score; tied rows share the smallest rank of the tie.
    conf_pvt[f'rank_{met}'] = conf_pvt[met].rank(method='min', ascending=False)

# Continuous metrics: one row per model_name, one column per metric.
# Ranks are driven by `key_cont_metrics` instead of a hard-coded 'auprc', so
# editing the config list is picked up here as well.
cont_pvt = cont_res.pivot(index='model_name', columns='metric', values='score')
for met in key_cont_metrics:
    cont_pvt[f'rank_{met}'] = cont_pvt[met].rank(method='min', ascending=False)

# Attach each model's continuous-metric columns to all of its threshold rows.
merged = pd.merge(conf_pvt.reset_index(level=['threshold']), cont_pvt, on='model_name', how='left')

# NOTE(review): confusion ranks are computed over (model_name, threshold) rows
# while continuous ranks are computed over model_name rows, so the two kinds of
# rank may span different ranges before being averaged -- confirm this
# weighting is intended.
rank_cols = [f'rank_{met}' for met in key_conf_metrics + key_cont_metrics]
merged['avg_rank'] = merged[rank_cols].mean(axis=1)
merged.sort_values('avg_rank').head(20)

metric,threshold,correlation,dice,precision,rank_precision,rank_correlation,rank_dice,auprc,rank_auprc,avg_rank
model_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
k25_r1_alpha0.050_tau1.00_lamnan.model,0.7,0.520579,0.590902,0.671789,510.0,65.0,97.0,0.691907,61.0,183.25
k25_r1_alpha10000000000.000_tau1.00_lamnan.model,0.7,0.520579,0.590902,0.671789,510.0,65.0,97.0,0.691907,61.0,183.25
k25_r1_alpha0.025_tau1.00_lamnan.model,0.7,0.520579,0.590902,0.671789,510.0,65.0,97.0,0.691907,61.0,183.25
k25_r1_alpha0.010_tau1.00_lamnan.model,0.7,0.520579,0.590902,0.671789,510.0,65.0,97.0,0.691907,61.0,183.25
k20_r1_alpha0.050_tau1.00_lamnan.model,0.6,0.538465,0.606729,0.644304,687.0,5.0,17.0,0.699523,29.0,184.5
k20_r1_alpha0.010_tau1.00_lamnan.model,0.6,0.538465,0.606729,0.644304,687.0,5.0,17.0,0.699523,29.0,184.5
k20_r1_alpha0.025_tau1.00_lamnan.model,0.6,0.538465,0.606729,0.644304,687.0,5.0,17.0,0.699523,29.0,184.5
k20_r1_alpha10000000000.000_tau1.00_lamnan.model,0.6,0.538465,0.606729,0.644304,687.0,5.0,17.0,0.699523,29.0,184.5
k25_r1_alpha10000000000.000_tau0.00_lamnan.model,0.7,0.524838,0.596116,0.663277,561.0,47.0,66.0,0.689489,73.0,186.75
k25_r1_alpha0.050_tau0.00_lamnan.model,0.7,0.524838,0.596116,0.663277,561.0,47.0,66.0,0.689489,73.0,186.75


In [10]:
import os

# Output directory for the CSV written below.
save_path = "/dfs/scratch1/gmachi/k2/K2/src/outputs/wikisection/gridsearch-all"
# `to_csv` does not create missing parent directories; make sure the target
# exists so the export does not fail on a fresh output tree.
os.makedirs(save_path, exist_ok=True)
save_file = os.path.join(save_path, "complete_"+encoder+".csv")
# NOTE(review): the file is named "complete_<encoder>.csv" but only `cont_pvt`
# (the continuous-metric pivot and its ranks) is written; the combined `merged`
# table with `avg_rank` is never saved. Confirm which frame is meant to be
# exported.
cont_pvt.to_csv(save_file)