# Model comparison (dataset)
# Drop PI questions, then eliminate questions by selection metric on train set

Reduced models were obtained by the following procedure:
* drop all questions that contribute to the PI sum score
* iteratively drop the next question according to the selection metric computed on the **train set**
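The procedure above is a form of greedy backward elimination. The sketch below shows one way it could work. It is not the notebook's implementation. `backward_elimination` and the `score_subset` callback are hypothetical: the callback should fit a model on the given questions using the train set only and return the selection metric. Stopping at one remaining question is also an assumption.

```python
from typing import Callable, List, Sequence


def backward_elimination(
    questions: Sequence[str],
    pi_questions: Sequence[str],
    score_subset: Callable[[List[str]], float],
    higher_is_better: bool = True,
) -> List[List[str]]:
    """Return the question subsets visited by greedy backward elimination.

    score_subset is a hypothetical callback: it should fit a model on the given
    questions using the train set only and return the selection metric.
    """
    pi = set(pi_questions)
    # Step 1: drop every question that contributes to the PI sum score
    current = [q for q in questions if q not in pi]
    path = [list(current)]
    pick = max if higher_is_better else min
    # Step 2: repeatedly drop the question whose removal gives the best metric
    # (stopping at one remaining question is an assumption of this sketch)
    while len(current) > 1:
        scores = {q: score_subset([c for c in current if c != q]) for q in current}
        to_drop = pick(scores, key=scores.get)
        current.remove(to_drop)
        path.append(list(current))
    return path


# Toy usage: the "metric" is just the summed weight of the kept questions
weights = {"q1": 3, "q2": 1, "q3": 2, "p1": 5}
path = backward_elimination(
    ["q1", "q2", "q3", "p1"], ["p1"], lambda qs: sum(weights[q] for q in qs)
)
print(path)  # [['q1', 'q2', 'q3'], ['q1', 'q3'], ['q1']]
```

For error-type metrics (`mse`, `xent` and their `_class` variants), pass `higher_is_better=False`.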

Selection metric criteria:
* `ca`: min of mean conditional accuracy
* `ca_class`: max of min conditional accuracy
* `ca_prod`: max of product of conditional accuracies
* `mse`: min of mean squared error
* `mse_class`: min of max conditional mean squared error
* `xent`: min of cross-entropy
* `xent_class`: min of max conditional cross-entropy
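The quantities behind these criteria can be sketched as follows. This sketch makes several assumptions. "Conditional accuracy" is taken to mean accuracy conditioned on the true class (per-class recall). The model is assumed to output class probabilities. Squared error is taken Brier-style against one-hot labels. `selection_metrics` is a hypothetical helper, and `mod_evaluation` may define these quantities differently. Each criterion then compares these values across candidate question subsets. For example, `ca_class` takes the max over candidates of the worst-class accuracy.

```python
import numpy as np


def selection_metrics(y_true, proba, eps=1e-12):
    """Per-candidate values for the selection criteria (hypothetical helper).

    y_true: integer labels 0..K-1, shape (n,)
    proba:  predicted class probabilities, shape (n, K)
    """
    y_true = np.asarray(y_true)
    proba = np.asarray(proba, dtype=float)
    n_classes = proba.shape[1]
    y_pred = proba.argmax(axis=1)
    onehot = np.eye(n_classes)[y_true]
    # Per-sample squared error against the one-hot label (Brier-style, an assumption)
    sq_err = ((proba - onehot) ** 2).sum(axis=1)
    # Per-sample cross-entropy of the true class, clipped to avoid log(0)
    xent = -np.log(np.clip(proba[np.arange(len(y_true)), y_true], eps, 1.0))
    classes = np.unique(y_true)
    # "Conditional" = restricted to samples whose true class is k
    cond_acc = np.array([(y_pred[y_true == k] == k).mean() for k in classes])
    cond_mse = np.array([sq_err[y_true == k].mean() for k in classes])
    cond_xent = np.array([xent[y_true == k].mean() for k in classes])
    return {
        "ca": cond_acc.mean(),
        "ca_class": cond_acc.min(),
        "ca_prod": cond_acc.prod(),
        "mse": sq_err.mean(),
        "mse_class": cond_mse.max(),
        "xent": xent.mean(),
        "xent_class": cond_xent.max(),
    }


m = selection_metrics([0, 0, 1, 1], [[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.3, 0.7]])
```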

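Reference: I. Guyon and A. Elisseeff, "An Introduction to Variable and Feature Selection", *JMLR* 3 (2003). http://jmlr.csail.mit.edu/papers/volume3/guyon03a/guyon03a.pdf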

## Environment initialization

In [1]:
%autosave 0
%matplotlib notebook
%load_ext autoreload
%autoreload 2

import ipywidgets as widgets
import plotly.graph_objects as go
import plotly.express as px

import sys
sys.path.append("../")

import mod_evaluation
import mod_viewer
import mod_helper

Autosave disabled


## Execution params

In [2]:
results_path = 'data/results'

model_ref_id = 'linear'

n_splits = 25

metrics = mod_evaluation.sort_params_train

train_val_random = None

In [3]:
cache_pre = 'model_'+model_ref_id
cache_post = str(n_splits)

if train_val_random is not None:
    cache_post += '_r'+str(train_val_random)

## Load results

In [4]:
from copy import deepcopy

def get_model_id(model_id):
    # Cached model ids have the form '<n> + ...'; return the leading integer n,
    # or 0 if the id contains no ' + '
    id_split = model_id.split(' + ')
    if len(id_split) > 1:
        return int(id_split[0])
    return 0

available = {0: [], 1: []}

for item in mod_evaluation.list_cache(results_path):
    if '_0_' in item:
        available[0] += [item]
    if '_1_' in item:
        available[1] += [item]    

my_data = {}

for run_type in available:
    
    info, stats, stats_val = {}, {}, {}

    for item in available[run_type]:

        for metric_id in metrics:

            cache_sig = mod_evaluation.cache_sig_gen(
                metric_id, 
                cache_pre=cache_pre, 
                cache_post=cache_post
            )
    
            if cache_sig in item:
                
                run_id = item.split('_')[-1]
                
                if run_id not in info:
                    info[run_id], stats[run_id], stats_val[run_id] = {}, {}, {}
                    
                if metric_id not in info[run_id]:
                    info[run_id][metric_id], stats[run_id][metric_id], stats_val[run_id][metric_id] = {}, {}, {}
                    
                cache = mod_evaluation.from_cache(
                    item, 
                    results_path
                )
                
                for model_id in cache['info']:
                    
                    my_model_id = get_model_id(model_id)
                    
                    info[run_id][metric_id][my_model_id] = cache['info'][model_id]
                    stats[run_id][metric_id][my_model_id] = cache['stats'][model_id]
                    stats_val[run_id][metric_id][my_model_id] = cache['stats_val'][model_id]

    my_data[run_type] = [deepcopy(info), deepcopy(stats), deepcopy(stats_val)]
    
    print('Loaded', run_type)

Loaded 0
Loaded 1


## Global stats

In [5]:
empty = {0:{}, 1:{}}

df_multi, df_multi_val = deepcopy(empty), deepcopy(empty)
df_multi_ca, df_multi_val_ca = deepcopy(empty), deepcopy(empty)

df_flat, df_flat_val = deepcopy(empty), deepcopy(empty)
df_flat_ca, df_flat_val_ca = deepcopy(empty), deepcopy(empty)

for run_type in my_data:
    
    info, stats, stats_val = my_data[run_type]
    info_flat, stats_flat, stats_val_flat = {}, {}, {}
    
    for run_id in info:
        
        for metric_id in info[run_id]:
            
            if metric_id not in info_flat:
                info_flat[metric_id], stats_flat[metric_id], stats_val_flat[metric_id] = {}, {}, {}

            for model_id in info[run_id][metric_id]:

                if model_id not in info_flat[metric_id]:
                    # First run seen for this model: copy, so the += below cannot mutate stats[run_id]
                    info_flat[metric_id][model_id] = info[run_id][metric_id][model_id]
                    stats_flat[metric_id][model_id] = deepcopy(stats[run_id][metric_id][model_id])
                    stats_val_flat[metric_id][model_id] = deepcopy(stats_val[run_id][metric_id][model_id])

                else:
                    # Pool results across runs (assumes stats entries are lists, so += concatenates)
                    stats_flat[metric_id][model_id] += stats[run_id][metric_id][model_id]
                    stats_val_flat[metric_id][model_id] += stats_val[run_id][metric_id][model_id]
                    
        df_multi[run_type][run_id], df_multi_val[run_type][run_id] = mod_evaluation.get_df_questions(
            info[run_id], stats[run_id], stats_val[run_id],
            ci=False
        )

        df_multi_ca[run_type][run_id], df_multi_val_ca[run_type][run_id] = mod_evaluation.get_df_questions_ca(
            info[run_id], stats[run_id], stats_val[run_id]
        )

    df_flat[run_type], df_flat_val[run_type] = mod_evaluation.get_df_questions(
        info_flat, stats_flat, stats_val_flat,
        ci=True
    )

    df_flat_ca[run_type], df_flat_val_ca[run_type] = mod_evaluation.get_df_questions_ca(
        info_flat, stats_flat, stats_val_flat,
    )

## Holdout set variability

In [None]:
display(mod_helper.tab_plot_accuracy_multi(
    df_multi[1],
    df_multi_val[1]
))

display(mod_helper.tab_plot_accuracy_multi(
    df_multi[0],
    df_multi_val[0]
))

# Model comparison

## Mean accuracy on validation set (cross-validation)

* Mean accuracy on validation set (cross-validation) according to selection metric
* Confidence interval estimated by bootstrap method over cross-validation repetitions

(clicking a legend label shows/hides its trace; double-clicking a label isolates that trace)

* `ca_class`: max of min conditional accuracy
* `mse_class`: min of max conditional mean square error

Figures: **original dataset** (top), **new dataset** (bottom)
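The confidence intervals in these plots are described as bootstrap estimates over cross-validation repetitions. The exact variant used by `mod_evaluation`/`mod_viewer` is not shown here. As an illustration only, the sketch below computes a percentile bootstrap of the mean accuracy from a list of per-repetition accuracies. `bootstrap_mean_ci`, the number of resamples, and the 95% level are assumptions.

```python
import numpy as np


def bootstrap_mean_ci(values, n_boot=10_000, alpha=0.05, seed=0):
    """Percentile-bootstrap CI for the mean of per-repetition scores (sketch)."""
    values = np.asarray(values, dtype=float)
    rng = np.random.default_rng(seed)
    # Resample the repetitions with replacement, n_boot times
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    boot_means = values[idx].mean(axis=1)
    # Percentile interval of the resampled means
    lo, hi = np.percentile(boot_means, [100 * alpha / 2, 100 * (1 - alpha / 2)])
    return values.mean(), lo, hi


# Toy per-repetition validation accuracies (made-up numbers)
mean_acc, lo, hi = bootstrap_mean_ci([0.70, 0.72, 0.75, 0.71, 0.74])
```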

In [6]:
display(mod_viewer.plot_accuracy_mse(df_flat[1]))

display(mod_viewer.plot_accuracy_mse(df_flat[0]))

*(Output: interactive Plotly figures, original dataset then new dataset.)*

## Mean conditional accuracy on validation set (cross-validation)

* Mean conditional accuracy on validation set (cross-validation) according to selection metric
* Confidence interval estimated by bootstrap method over cross-validation repetitions

(clicking a legend label shows/hides its trace; double-clicking a label isolates that trace)

Figures: **original dataset** (top), **new dataset** (bottom)

In [7]:
display(mod_viewer.tab_plot_conditional_accuracy(
    df_flat[1],
    df_flat_ca[1],
    info_flat
))

display(mod_viewer.tab_plot_conditional_accuracy(
    df_flat[0],
    df_flat_ca[0],
    info_flat
))

*(Output: tabbed interactive Plotly figures, original dataset then new dataset.)*

# Model validation

## Mean accuracy on holdout set

* Accuracy on holdout set according to selection metric
* Holdout accuracy outside confidence interval bounds may indicate (1) model overfitting or (2) data domain shift

(clicking a legend label shows/hides its trace; double-clicking a label isolates that trace)

Figures: **original dataset** (top), **new dataset** (bottom)
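The overfitting/domain-shift check described above can be made explicit. The sketch below lists the models whose holdout accuracy falls outside their cross-validation interval. The column layout of `df_flat`/`df_flat_val` is not shown in this notebook, so the plain-dict inputs and the `flag_outside_ci` helper are hypothetical.

```python
def flag_outside_ci(holdout, ci):
    """Return {model_id: 'below' | 'above'} for holdout scores outside the CV interval.

    holdout: {model_id: holdout accuracy}
    ci:      {model_id: (lower bound, upper bound)} from cross-validation
    """
    flags = {}
    for model_id, acc in holdout.items():
        lo, hi = ci[model_id]
        if acc < lo:
            # Below the CV interval: the pattern typically associated with overfitting
            flags[model_id] = "below"
        elif acc > hi:
            # Above the CV interval: can also point to a shift between data domains
            flags[model_id] = "above"
    return flags


# Toy numbers, not results from this notebook
flags = flag_outside_ci(
    {0: 0.70, 1: 0.80, 2: 0.90},
    {0: (0.72, 0.78), 1: (0.75, 0.85), 2: (0.80, 0.88)},
)
```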

In [8]:
display(mod_viewer.tab_plot_accuracy(
    df_flat[1],
    info_flat,
    df_questions_holdout=df_flat_val[1]
))

display(mod_viewer.tab_plot_accuracy(
    df_flat[0],
    info_flat,
    df_questions_holdout=df_flat_val[0]
))

*(Output: tabbed interactive Plotly figures, original dataset then new dataset.)*

## Mean conditional accuracy on holdout set

* Conditional accuracy on holdout set according to selection metric
* Holdout conditional accuracy outside confidence interval bounds may indicate (1) model overfitting or (2) data domain shift

(clicking a legend label shows/hides its trace; double-clicking a label isolates that trace)

Figures: **original dataset** (top), **new dataset** (bottom)

In [9]:
display(mod_viewer.tab_plot_conditional_accuracy(
    df_flat_val[1],
    df_flat_val_ca[1],
    info_flat,
    holdout=True
))

display(mod_viewer.tab_plot_conditional_accuracy(
    df_flat_val[0],
    df_flat_val_ca[0],
    info_flat,
    holdout=True
))

*(Output: tabbed interactive Plotly figures, original dataset then new dataset.)*