# Bootstrap Quantification under Sample Shift

In [None]:
import itertools
import pandas as pd

from pyquantification.experiments import cached_experiments
from pyquantification.evaluation import baseline_comparison_table

In [None]:
# Quantification methods compared in this notebook.
quantification_methods = ['count', 'em-bs']

# Keyword arguments shared by every `cached_experiments` call in this
# notebook; each call adds its own `cache_key` and `dataset_names`.
conf = {
    'classifier_names': ['logreg'],
    'calibration_methods': ['uncalibrated'],
    'loss_weights': [0],
    'gain_weights': [0],
    'random_states': [0],
    'shift_types': ['no_shift'],
    'bin_counts': ['auto'],
    'random_priors_options': [False],
    'quantification_methods': quantification_methods,
    'classification_workers': 1,
    'quantification_workers': 10,
    'continue_on_failure': True,
    # None: run on all samples.
    'sample_idxs': None,
}

In [None]:
# Experiment results for the 'binary-plankton' dataset using the shared `conf`.
# NOTE(review): presumably results are stored/reused under `cache_key` --
# confirm in pyquantification.experiments.
binary_results_df = cached_experiments(
    **conf,
    cache_key='bootstrap_binary_sample_results',
    dataset_names=['binary-plankton'],
)

In [None]:
# Experiment results for the 'fg-plankton' dataset using the shared `conf`.
# NOTE(review): presumably results are stored/reused under `cache_key` --
# confirm in pyquantification.experiments.
fg_results_df = cached_experiments(
    **conf,
    cache_key='bootstrap_fg_sample_results',
    dataset_names=['fg-plankton'],
)

In [None]:
# Experiment results for the 'plankton' dataset using the shared `conf`.
# NOTE(review): presumably results are stored/reused under `cache_key` --
# confirm in pyquantification.experiments.
orig_results_df = cached_experiments(
    **conf,
    cache_key='bootstrap_sample_results',
    dataset_names=['plankton'],
)

In [None]:
# Short display labels keyed by dataset name.
dataset_labels = {'binary-plankton': 'BPL', 'fg-plankton': 'FPL', 'plankton': 'OPL'}
# Short display labels keyed by quantification method name.
shift_classifier_labels = {'count': 'CC', 'em-bs': 'EM-BS'}

# Stack the three result sets and tag every row with the shift condition and
# the short dataset label.
plot_df = pd.concat([binary_results_df, fg_results_df, orig_results_df]).assign(
    shift_condition='Sample shift',
    dataset_label=lambda df: df['dataset_name'].map(dataset_labels),
)

# Alternative labelling that splits the 'plankton' rows at a 1% threshold on
# test_true_count / test_n (presumably the true class prevalence q in the test
# sample -- confirm against the results schema). Other datasets keep their
# regular label.
is_original_plankton = plot_df['dataset_name'] == 'plankton'
true_fraction = plot_df['test_true_count'] / plot_df['test_n']
alternative_dataset_label = (
    plot_df['dataset_label']
    .mask(is_original_plankton & (true_fraction <= 0.01), 'OPL(q <= 1%)')
    .mask(is_original_plankton & (true_fraction > 0.01), 'OPL(q > 1%)')
)
# Show how many rows fall under each label.
display(alternative_dataset_label.value_counts())

In [None]:
def print_table_latex(table_df):
    """Print the rows of *table_df* as LaTeX table rows.

    Each row is printed as its index label(s) followed by its cell values,
    joined with ``' & '`` and terminated with ``' \\\\'``. A row whose cells
    are all missing is printed as ``\\hline`` instead. Every ``%`` in a label
    or value is escaped as ``\\%``; no other characters are escaped.

    Args:
        table_df: DataFrame to print. The index may be a plain index or a
            MultiIndex (each level becomes its own leading cell).
    """
    for index, row in table_df.iterrows():
        # Normalise single-level index labels to a tuple so both index kinds
        # are handled by the same join below.
        index_cells = index if isinstance(index, tuple) else (index,)
        if row.isna().all():
            # Raw string: '\h' is not a valid escape sequence.
            print(r'\hline')
            continue
        # `tolist()` keeps one entry per column even when column labels are
        # duplicated (a dict keyed by label would drop cells).
        cells = [
            str(value).replace('%', '\\%')
            for value in itertools.chain(index_cells, row.tolist())
        ]
        print(' & '.join(cells) + r' \\')
            
# Build the coverage comparison table against the 'count' baseline, grouping
# rows by the alternative dataset label (which splits the 'plankton' rows at
# the 1% threshold). `methods` maps display label -> method name, i.e. the
# inverse of `shift_classifier_labels`. The second return value is unused.
coverage_table_df, _ = baseline_comparison_table(
    plot_df.assign(dataset_label=alternative_dataset_label),
    methods={v: k for k, v in shift_classifier_labels.items()},
    baseline_method='count',
    dataset_labels={},
    metric='coverage',
    row_grouping=['dataset_label'],
)
# Format cells as percentages, but leave missing cells as NaN: formatting NaN
# would yield the string 'nan%', and `print_table_latex` could then no longer
# recognise an all-NaN row (its `row.isna().all()` check) to print `\hline`.
styled_coverage_table_df = coverage_table_df.applymap(
    lambda v: v if pd.isna(v) else f'{v:.2%}'
)
display(styled_coverage_table_df)
print_table_latex(styled_coverage_table_df)