# Quantification under Prior Shift

In [None]:
# Load IPython's autoreload extension and set mode 2, so imported modules
# (e.g. pyquantification) are reloaded before each cell executes. This lets
# library edits take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pyquantification.experiments import cached_experiments
from pyquantification.evaluation import (
    display_stat_table,
    color_scale,
)

In [None]:
# Quantification methods passed to cached_experiments for every configuration.
quantification_methods = ['count', 'pcc', 'em', 'gsls', 'true-weight-gsls']

# Dataset name -> short label; the keys select which datasets are evaluated.
dataset_labels = {
    'insect-sex': 'ISX',
    'insect-species': 'ISP',
    'arabic-digits': 'DIG',
    'handwritten-letters-letter': 'HLL',
    'handwritten-letters-author': 'HLA',
}

# Prior-shift experiments: a single logreg classifier, zero loss/gain
# weights, random priors, and random states 0..999. Results are stored under
# the 'prior_shift_results' cache key (presumably reused on later runs if
# already present — confirm in cached_experiments).
results_df = cached_experiments(
    cache_key='prior_shift_results',
    dataset_names=list(dataset_labels),
    classifier_names=['logreg'],
    loss_weights=[0],
    gain_weights=[0],
    random_states=list(range(1000)),
    shift_types=['prior_shift'],
    bin_counts=['auto'],
    random_priors_options=[True],
    quantification_methods=quantification_methods,
    classification_workers=12,
    continue_on_failure=True,
)

In [None]:
# Flatten the results (index levels become ordinary columns) into plot_df.
# Any filtering of results before the tables are built can be applied here.
plot_df = results_df.reset_index(drop=False)

## Quantification Method Comparison

In [None]:
# 'coverage' statistic per method, formatted as a whole-number percentage and
# colored via color_scale with a 0.8 threshold. The count method ('CC') is
# not included in this table.
display_stat_table(
    plot_df,
    stat='coverage',
    row_grouping=['single_grouping'],
    methods={
        'PCC': 'pcc',
        'EM': 'em',
        'GSLS': 'gsls',
    },
    color_func=color_scale(threshold=0.8),
    format_string='{:.0%}',
)

In [None]:
# 'absolute_error' statistic per method, grouped by dataset_name, formatted
# as a percentage with one decimal place and shown alongside its standard
# deviation (include_std=True).
display_stat_table(
    plot_df,
    stat='absolute_error',
    row_grouping='dataset_name',
    methods={
        'CC': 'count',
        'PCC': 'pcc',
        'EM': 'em',
        'GSLS': 'gsls',
    },
    format_string='{:.1%}',
    include_std=True,
)