# Shift Tests for Dynamically Selecting a Quantification Method

In [None]:
import itertools
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import rgb2hex

from pyquantification.experiments import cached_experiments
from pyquantification.evaluation import (
    display_stat_table,
    color_scale,
    build_shift_classification_df,
    shift_classifier_plot,
    baseline_comparison_table,
    shift_test_runtime_table,
    shift_test_table,
)

In [None]:
# Load the three cached experiment result sets (GSLS, prior shift, runtimes),
# in that order, via pyquantification's experiment cache.
gsls_df, prior_df, shift_test_runtime_df = (
    cached_experiments(cache_key=key)
    for key in (
        'shift_test_gsls_results',
        'shift_test_prior_shift_results',
        'shift_test_runtime_results',
    )
)

In [None]:
# Short display abbreviations for each dataset identifier.
dataset_labels = {
    'handwritten-letters-letter': 'HLL',
    'handwritten-letters-author': 'HLA',
    'arabic-digits': 'DIG',
    'insect-sex': 'ISX',
    'insect-species': 'ISP',
}
# Display labels for the tests passed as `any_shift_tests` to build_shift_classification_df.
any_shift_test_labels = {'ks': 'KS', 'lr': 'LR'}
# Display labels for the tests passed as `non_prior_shift_tests` to build_shift_classification_df.
non_prior_shift_test_labels = {
    'wpa-xks': 'WPA-KS',
    'wpa-xdhd': 'WPA-HD',
    'cdt-xks': 'CDT-KS',
    'cdt-xdhd': 'CDT-HD',
    'aks': 'AKS',
}
# All shift tests, "any" tests first, then "non-prior" tests.
shift_test_labels = dict(any_shift_test_labels)
shift_test_labels.update(non_prior_shift_test_labels)
# Shift-condition labels, in display order.
shift_conditions = ['No shift', 'GSLS shift', 'Prior shift']
# Static quantifiers first, then every "<any test>+<non-prior test>" pairing
# (outer loop over any-shift tests, inner loop over non-prior tests).
shift_classifier_labels = {
    'pcc': 'PCC',
    'em': 'EM',
    'gsls': 'GSLS',
    **{
        f'{any_key}+{non_prior_key}': f'{any_label}+{non_prior_label}'
        for (any_key, any_label), (non_prior_key, non_prior_label) in itertools.product(
            any_shift_test_labels.items(), non_prior_shift_test_labels.items()
        )
    },
}

def print_table_latex(table_df):
    print(' & '.join(map(str, table_df.index.names)) + ' & ' + ' & '.join(map(str, table_df.columns)))
    for index, row in table_df.iterrows():
        index = index if isinstance(index, tuple) else (index,)
        if row.isna().all():
            print('\hline')
        else:
            print(' & '.join([
                str(value) for value in
                itertools.chain(index, row.to_dict().values())
            ]) + r' \\')

def get_shift_condition(row):
    """Return the shift-condition key for one experiment result row.

    Args:
        row: A mapping (e.g. a DataFrame row) with a ``shift_type`` entry and,
            for ``'gsls_shift'`` rows, ``gain_weight`` and ``loss_weight``.

    Returns:
        ``'0_no_shift'`` for GSLS rows whose gain and loss weights sum to zero,
        ``'1_gsls_shift'`` for other GSLS rows, and ``'2_prior_shift'`` for
        prior-shift rows. The numeric prefixes presumably exist so that the
        conditions sort in this order when used as a row grouping — confirm.

    Raises:
        ValueError: If ``shift_type`` is neither ``'gsls_shift'`` nor
            ``'prior_shift'``.
    """
    shift_type = row['shift_type']
    if shift_type == 'gsls_shift':
        # Zero total gain/loss weight is the "no shift" GSLS configuration.
        if (row['gain_weight'] + row['loss_weight']) == 0:
            return '0_no_shift'
        return '1_gsls_shift'
    if shift_type == 'prior_shift':
        return '2_prior_shift'
    raise ValueError(f'Unknown shift_type: {shift_type!r}')

In [None]:
# A GSLS experiment counts as shifted when either its gain or loss weight is
# positive; every prior-shift experiment is labelled 'Prior shift'.
gsls_shifted = (gsls_df['gain_weight'] > 0) | (gsls_df['loss_weight'] > 0)
gsls_df['shift_condition'] = (
    pd.Series('GSLS shift', index=gsls_df.index)
    .where(gsls_shifted, 'No shift')
)
prior_df['shift_condition'] = pd.Series('Prior shift', index=prior_df.index)

# Pool both experiment sets and build the shift-classification frame that the
# rest of the notebook reads from.
plot_df = pd.concat([gsls_df, prior_df]).reset_index()
_, shift_plot_df = build_shift_classification_df(
    plot_df,
    any_shift_tests=any_shift_test_labels.keys(),
    non_prior_shift_tests=non_prior_shift_test_labels.keys(),
)

## Shift Test Experiments

In [None]:
# Light-to-green colormap; also used by format_shift_test_cell for the LaTeX output.
shift_test_colormap = sns.light_palette('green', as_cmap=True)

table_df = shift_test_table(
    shift_plot_df,
    shift_test_labels=shift_test_labels,
    shift_conditions=shift_conditions,
    dataset_labels=dataset_labels,
)
# Whole-number percentages on a fixed 0..1 colour scale.
display(table_df.style.format('{:.0%}').background_gradient(
    cmap=shift_test_colormap,
    vmin=0,
    vmax=1,
))

def format_shift_test_cell(value):
    r"""Format a shift-test rate as a coloured LaTeX table cell.

    ``value`` is mapped through ``shift_test_colormap`` to a
    ``\cellcolor[HTML]{RRGGBB}`` prefix, followed by the value as a
    whole-number percentage with ``%`` escaped for LaTeX
    (e.g. ``0.25`` -> ``\cellcolor[HTML]{...} 25\%``). Presumably ``value`` is
    a fraction in [0, 1], matching the 0..1 scale used for the HTML table.
    """
    hexcolor = rgb2hex(shift_test_colormap(value)).lstrip('#').upper()
    # Raw strings: '\c' and '\%' are invalid escape sequences in normal literals.
    return rf'\cellcolor[HTML]{{{hexcolor}}} {value:.0%}'.replace('%', r'\%')

print_table_latex(table_df.applymap(format_shift_test_cell))

## Dynamic Quantifier Selection Experiments

In [None]:
# shift_classifier_plot is already imported in the first cell of the notebook.
from pathlib import Path

fig = shift_classifier_plot(
    shift_plot_df,
    shift_classifiers=shift_classifier_labels.keys(),
    shift_classifier_labels=shift_classifier_labels,
)
# write_image does not create missing directories, so ensure plots/ exists first.
Path('plots').mkdir(parents=True, exist_ok=True)
fig.write_image("plots/shift-classifiers.svg")
fig.show()

## Dynamic vs Static Quantifiers

In [None]:
# Dynamic shift classifier used as the baseline in the comparison tables below.
best_shift_classifier = 'ks+aks'
# Static quantification methods it is compared against.
base_methods = ['pcc', 'em', 'gsls']
# Display label -> method key, with the baseline listed first.
shift_classifier_methods = {
    shift_classifier_labels[key]: key
    for key in (best_shift_classifier, *base_methods)
}

# Default diverging colour scale for comparison tables; the comparison-table
# helpers merge caller-supplied options over these defaults.
BASE_GRADIENT_OPTIONS = {
    'cmap': sns.diverging_palette(20, 220, as_cmap=True),
    'vmin': -1,
    'vmax': 1,
}

def display_comparison_table(*, table_df, significance_mask, gradient_options=None, format='{:.0%}'):
    """Display a styled method-comparison table in the notebook.

    Args:
        table_df: Table to show. Level 1 of its row index is compared against
            the baseline method's display label.
        significance_mask: Boolean frame with the same shape as ``table_df``;
            True cells are rendered in bold.
        gradient_options: Overrides for ``BASE_GRADIENT_OPTIONS`` (``cmap``,
            ``vmin``, ``vmax``) used for the background gradient.
        format: Format string applied to every cell.

    Every row except the baseline (``best_shift_classifier``) row receives the
    background gradient.
    """
    merged_options = dict(BASE_GRADIENT_OPTIONS)
    if gradient_options is not None:
        merged_options.update(gradient_options)

    gradient_rows = (
        table_df.index.get_level_values(1)
        != shift_classifier_labels[best_shift_classifier]
    )

    def bold_significant(frame):
        assert frame.shape == significance_mask.shape
        styles = pd.DataFrame('', index=frame.index, columns=frame.columns)
        return styles.mask(significance_mask, 'font-weight: bold;')

    styler = table_df.style.format(format)
    styler = styler.background_gradient(
        **merged_options,
        subset=(gradient_rows, table_df.columns),
    )
    styler = styler.apply(bold_significant, axis=None)
    display(styler)

def print_comparison_table_latex(*, table_df, significance_mask, gradient_options=None, format='{:.0%}'):
    r"""Print a method-comparison table as LaTeX rows with coloured cells.

    LaTeX counterpart of ``display_comparison_table``:

    * every value is formatted with ``format`` and ``%`` is escaped as ``\%``;
    * values where ``significance_mask`` is True are wrapped in ``\textbf{...}``;
    * each cell is prefixed with ``\cellcolor[HTML]{RRGGBB}``, using the
      ``cmap``/``vmin``/``vmax`` from ``BASE_GRADIENT_OPTIONS`` overridden by
      ``gradient_options``. Cells in the baseline (``best_shift_classifier``)
      row are left uncoloured, matching the notebook display, which excludes
      that row from the gradient.

    Rows are emitted via ``print_table_latex``.
    """
    gradient_options = {} if gradient_options is None else gradient_options
    gradient_options = {**BASE_GRADIENT_OPTIONS, **gradient_options}
    cmap = gradient_options['cmap']
    vmin = gradient_options['vmin']
    vmax = gradient_options['vmax']

    def cellcolor(value):
        # Linearly rescale value from [vmin, vmax] onto the colormap's [0, 1] input.
        color = cmap((value - vmin) / (vmax - vmin))
        hexcolor = rgb2hex(color).lstrip('#').upper()
        return rf'\cellcolor[HTML]{{{hexcolor}}} '

    # Format values; raw strings avoid the invalid '\%' escape sequence.
    output_df = table_df.applymap(lambda v: format.format(v).replace('%', r'\%'))
    # Bold significant values.
    output_df = output_df.mask(significance_mask, output_df.applymap(lambda v: rf'\textbf{{{v}}}'))
    # Colour every row except the baseline row, mirroring display_comparison_table.
    color_df = table_df.applymap(cellcolor)
    baseline_rows = (
        table_df.index.get_level_values(1)
        == shift_classifier_labels[best_shift_classifier]
    )
    color_df.loc[baseline_rows] = ''
    print_table_latex(color_df + output_df)

### Coverage Comparison

In [None]:
# Coverage of each method, as relative values against the baseline selector.
coverage_table_df, coverage_significance_mask = baseline_comparison_table(
    shift_plot_df,
    metric='coverage',
    relative_values=True,
    methods=shift_classifier_methods,
    baseline_method=best_shift_classifier,
    dataset_labels=dataset_labels,
)
coverage_table_args = {
    'table_df': coverage_table_df,
    'significance_mask': coverage_significance_mask,
}
display_comparison_table(**coverage_table_args)
# For the LaTeX version of this table, call:
# print_comparison_table_latex(**coverage_table_args)

# Coverage comparison grouped by shift condition (keys from get_shift_condition).
condition_coverage_table_df, _ = baseline_comparison_table(
    shift_plot_df.assign(
        shift_condition=shift_plot_df.apply(get_shift_condition, axis=1),
    ),
    row_grouping=['shift_condition'],
    metric='coverage',
    relative_values=True,
    methods=shift_classifier_methods,
    baseline_method=best_shift_classifier,
    dataset_labels=dataset_labels,
)

def format_coverage_cell(value):
    r"""Format a fraction as a whole-number LaTeX percentage, e.g. ``0.5`` -> ``50\%``.

    ``%`` is escaped because it starts a comment in LaTeX. A raw string is used
    because ``'\%'`` is an invalid escape sequence in a normal literal.
    """
    return f'{value:.0%}'.replace('%', r'\%')

# Emit the per-condition coverage table as LaTeX rows with escaped percentages.
print_table_latex(
    condition_coverage_table_df.applymap(format_coverage_cell)
)

### Error Comparison

In [None]:
# Absolute-error comparison against the baseline selector (relative values,
# significance at t_test_alpha=0.05).
error_table_df, error_significance_mask = baseline_comparison_table(
    shift_plot_df,
    metric='absolute_error',
    relative_values=True,
    t_test_alpha=0.05,
    methods=shift_classifier_methods,
    baseline_method=best_shift_classifier,
    dataset_labels=dataset_labels,
)
error_table_args = {
    'table_df': error_table_df,
    'significance_mask': error_significance_mask,
    'format': '{:.1%}',
    # Hue order is reversed relative to BASE_GRADIENT_OPTIONS and the range is
    # narrowed to +/-5 percentage points (presumably so that lower error gets
    # the colour that higher coverage gets in the coverage table).
    'gradient_options': {
        'cmap': sns.diverging_palette(220, 20, as_cmap=True),
        'vmin': -0.05,
        'vmax': 0.05,
    },
}
display_comparison_table(**error_table_args)
# For the LaTeX version of this table, call:
# print_comparison_table_latex(**error_table_args)

# Absolute-error comparison grouped by shift condition; with include_std=True
# each cell holds a (mean, std) pair (see format_error_cell).
condition_error_table_df, _ = baseline_comparison_table(
    shift_plot_df.assign(
        shift_condition=shift_plot_df.apply(get_shift_condition, axis=1),
    ),
    row_grouping=['shift_condition'],
    metric='absolute_error',
    relative_values=True,
    include_std=True,
    t_test_alpha=0.05,
    methods=shift_classifier_methods,
    baseline_method=best_shift_classifier,
    dataset_labels=dataset_labels,
)
def format_error_cell(value):
    r"""Format a ``(mean, std)`` pair as a LaTeX cell, e.g. ``50.00\% (25.00\%)``.

    ``%`` must be escaped because it starts a comment in LaTeX. Left
    unescaped, it would comment out the rest of the row, including the ``\\``
    row terminator. This matches the escaping done by the notebook's other
    cell formatters.
    """
    mean, std = value
    return f'{mean:.2%} ({std:.2%})'.replace('%', r'\%')

# Emit the per-condition error table (mean and std per cell) as LaTeX rows.
print_table_latex(
    condition_error_table_df.applymap(format_error_cell)
)

## Shift Test Runtimes

In [None]:
# Summarise shift-test runtimes per dataset via shift_test_runtime_table.
runtime_table_df = shift_test_runtime_table(
    shift_test_runtime_df,
    dataset_labels=dataset_labels,
    shift_test_labels=shift_test_labels,
)
# Show values to two decimal places.
display(runtime_table_df.style.format('{:.2f}'))
# For the LaTeX version of this table, call:
# print_table_latex(runtime_table_df.applymap(lambda v: f'{v:.2f}'))