# Experiments with Real-World Dataset Shift

In [None]:
import itertools
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import rgb2hex

from pyquantification.experiments import cached_experiments
from pyquantification.evaluation import (
    build_shift_classification_df,
    display_stat_table,
    color_scale,
    shift_test_table,
    baseline_comparison_table,
)

In [None]:
# Short display labels for each dataset key.
dataset_labels = {
    'binary-plankton': 'BPL',
    'fg-plankton': 'FPL',
    'plankton': 'OPL',
}

# Display labels for the two families of shift tests, keyed by test name.
any_shift_test_labels = {'ks': 'KS', 'lr': 'LR'}
non_prior_shift_test_labels = {'aks': 'AKS'}

# All shift tests: the "any shift" tests first, then the "non-prior shift" tests.
shift_test_labels = dict(any_shift_test_labels)
shift_test_labels.update(non_prior_shift_test_labels)

# Method labels: the fixed methods first, followed by one '<any>+<non-prior>'
# entry for every pairing of an any-shift test with a non-prior-shift test.
shift_classifier_labels = {
    'count': 'CC',
    'pcc': 'PCC',
    'em': 'EM',
    'gsls': 'GSLS',
}
shift_classifier_labels.update(
    (f'{any_key}+{non_prior_key}', f'{any_label}+{non_prior_label}')
    for (any_key, any_label), (non_prior_key, non_prior_label) in itertools.product(
        any_shift_test_labels.items(),
        non_prior_shift_test_labels.items(),
    )
)

def print_table_latex(table_df):
    r"""Print the rows of ``table_df`` as the body rows of a LaTeX table.

    Each row is written to stdout as its index label(s) followed by its cell
    values, joined with `` & `` and terminated with `` \\``. A row whose cells
    are all NaN is treated as a separator and printed as ``\hline``.

    Percent signs are escaped as ``\%``. Text that already contains ``\%``
    is left as-is rather than being escaped a second time, so pre-formatted
    LaTeX cells can be passed in safely.

    Args:
        table_df: DataFrame to print. The index may be flat or a MultiIndex;
            tuple labels contribute one output column per element.

    Returns:
        None. The table rows are printed to stdout.
    """
    for row_label, row in table_df.iterrows():
        if row.isna().all():
            print(r'\hline')
            continue
        # MultiIndex labels arrive as tuples; wrap flat labels so both chain.
        labels = row_label if isinstance(row_label, tuple) else (row_label,)
        cells = [
            # Undo any existing escape before escaping, so an already-escaped
            # `\%` does not become `\\%` (a LaTeX line break plus a comment).
            str(value).replace(r'\%', '%').replace('%', r'\%')
            for value in itertools.chain(labels, row.to_dict().values())
        ]
        print(' & '.join(cells) + r' \\')

In [None]:
# Collect one shift-classification frame per cached experiment result set.
shift_dfs = []
for cache_key in ('binary_sample_results', 'fg_sample_results', 'sample_results'):
    raw_df = cached_experiments(cache_key=cache_key)
    # Only the second element of the returned pair is needed here.
    _, shift_df = build_shift_classification_df(
        raw_df,
        any_shift_tests=any_shift_test_labels.keys(),
        non_prior_shift_tests=non_prior_shift_test_labels.keys(),
    )
    shift_dfs.append(shift_df)

# Stack the per-dataset frames and tag every row with its shift condition
# and the short display label of its dataset.
shift_plot_df = pd.concat(shift_dfs).assign(
    shift_condition='Sample shift',
    dataset_label=lambda df: df['dataset_name'].map(dataset_labels),
)

# Alternative labelling that splits the 'plankton' rows at a 1% threshold on
# test_true_count / test_n; rows of the other datasets keep their label.
is_plankton = shift_plot_df['dataset_name'] == 'plankton'
true_fraction = shift_plot_df['test_true_count'] / shift_plot_df['test_n']
alternative_dataset_label = (
    shift_plot_df['dataset_label']
    .mask(is_plankton & (true_fraction <= 0.01), 'OPL(q <= 1%)')
    .mask(is_plankton & (true_fraction > 0.01), 'OPL(q > 1%)')
)
display(alternative_dataset_label.value_counts())

## Shift Tests

In [None]:
# Light-to-green colormap shared by the on-screen table and the LaTeX cells.
shift_test_colormap = sns.light_palette('green', as_cmap=True)

# Table of shift-test results for the 'Sample shift' condition.
table_df = shift_test_table(
    shift_plot_df,
    dataset_labels=dataset_labels,
    shift_conditions=['Sample shift'],
    shift_test_labels=shift_test_labels,
)

# Show the values as percentages, shaded on a fixed 0..1 scale.
shift_test_styler = table_df.style.format('{:.2%}')
shift_test_styler = shift_test_styler.background_gradient(
    cmap=shift_test_colormap,
    vmin=0,
    vmax=1,
)
display(shift_test_styler)

def format_shift_test_cell(value):
    r"""Format one shift-test table value as a coloured LaTeX table cell.

    The value is looked up in the module-level ``shift_test_colormap`` and the
    resulting colour is emitted as a ``\cellcolor[HTML]{RRGGBB}`` prefix,
    followed by the value formatted as a percentage with two decimals. The
    percent sign is escaped as ``\%``.

    Args:
        value: Number to format.
            NOTE(review): presumably a fraction in [0, 1], matching the
            vmin/vmax used for the on-screen gradient - confirm.

    Returns:
        The LaTeX source for the cell as a string.
    """
    color = shift_test_colormap(value)
    # LaTeX's HTML colour model takes the hex digits without the leading '#'.
    hexcolor = rgb2hex(color).strip('#').upper()
    # Raw strings keep the LaTeX backslashes literal instead of relying on
    # Python passing unknown escape sequences (`\c`, `\%`) through unchanged.
    return rf'\cellcolor[HTML]{{{hexcolor}}} {value:.2%}'.replace('%', r'\%')

#print_table_latex(table_df.applymap(format_shift_test_cell))

## Error Evaluation

In [None]:
# baseline_comparison_table is given a label -> method-key mapping, i.e. the
# inverse of shift_classifier_labels.
error_methods = {label: key for key, label in shift_classifier_labels.items()}
# Group rows by the alternative labelling instead of the plain dataset label.
error_input_df = shift_plot_df.assign(dataset_label=alternative_dataset_label)

error_table_df, _ = baseline_comparison_table(
    error_input_df,
    methods=error_methods,
    baseline_method='count',
    dataset_labels={},
    metric='absolute_error',
    include_std=True,
    row_grouping=['dataset_label'],
)

# Each cell is indexed at [0] and [1]; render it as "first (second)" with both
# parts as percentages.
# NOTE(review): presumably (mean, std) because include_std=True - confirm.
styled_error_table_df = error_table_df.applymap(
    lambda cell: '{:.2%} ({:.2%})'.format(cell[0], cell[1])
)
display(styled_error_table_df)
print_table_latex(styled_error_table_df)

## Coverage Evaluation

In [None]:
# baseline_comparison_table is given a label -> method-key mapping, i.e. the
# inverse of shift_classifier_labels.
coverage_methods = {label: key for key, label in shift_classifier_labels.items()}
# Group rows by the alternative labelling instead of the plain dataset label.
coverage_input_df = shift_plot_df.assign(dataset_label=alternative_dataset_label)

coverage_table_df, _ = baseline_comparison_table(
    coverage_input_df,
    methods=coverage_methods,
    baseline_method='count',
    dataset_labels={},
    metric='coverage',
    row_grouping=['dataset_label'],
)

# Render every cell as a percentage with two decimals.
styled_coverage_table_df = coverage_table_df.applymap('{:.2%}'.format)
display(styled_coverage_table_df)
print_table_latex(styled_coverage_table_df)