# Limiting Quantification Interval Widths

In [None]:
import numpy as np
import pandas as pd

from pyquantification.experiments import cached_experiments
from pyquantification.evaluation import t_test, corrected_resampled_t_test

In [None]:
# Synthetic-shift rejection results, stored under cache keys suffixed 1-3.
synthetic_df = pd.concat([
    cached_experiments(cache_key=f'synthetic_{run}_rejection_results')
    for run in (1, 2, 3)
]).reset_index()

# Sample-shift rejection results for every plankton dataset variant,
# each stored under cache keys suffixed 1-3 (dataset-major order).
sample_df = pd.concat([
    cached_experiments(cache_key=f'sample_{dataset}_{run}_rejection_results')
    for dataset in ('plankton', 'fg-plankton', 'binary-plankton')
    for run in (1, 2, 3)
]).reset_index()

# Separate runtime experiments that are performed serially.
runtime_sample_df = pd.concat([
    cached_experiments(cache_key=f'runtime_sample_{dataset}_{run}_rejection_results')
    for dataset in ('plankton', 'fg-plankton', 'binary-plankton')
    for run in (1, 2, 3)
]).reset_index()

In [None]:
def print_table_latex(table_df):
    r"""Print ``table_df`` as LaTeX tabular rows, one ``print`` call per row.

    Each row is printed as its index label(s) followed by its cell values,
    joined with `` & `` and terminated by `` \\``. Percent signs are escaped
    as ``\%`` and ``<strong>...</strong>`` markup is converted to
    ``\textbf{...}``. A row whose cells are all NaN is printed as ``\hline``.

    Cell values (and index labels) are expected to be strings, as produced by
    ``rejection_table``; ``str.join`` raises ``TypeError`` otherwise.
    """
    for keys, row in table_df.iterrows():
        if row.isna().all():
            # Raw string: '\h' is an invalid escape sequence in a normal
            # literal (SyntaxWarning since Python 3.12).
            print(r'\hline')
            continue
        # A MultiIndex yields tuple labels; a single-level index yields a
        # scalar, which must not be splatted character by character.
        labels = keys if isinstance(keys, tuple) else (keys,)
        line = ' & '.join([*labels, *row.to_dict().values()]) + r' \\'
        print(
            line
            .replace('%', r'\%')
            .replace('<strong>', r'\textbf{')
            .replace('</strong>', '}')
        )


def display_table(table_df):
    """Show ``table_df`` as a styled HTML table via the notebook ``display``.

    All data and header cells are left-aligned. A black top border marks the
    row keyed by (each level-0 index value, first level-1 index value), and a
    black left border marks the column keyed by (each level-0 column value,
    first level-1 column value). Expects both the index and the columns to
    have at least two levels.
    """
    def border_styles(css_property):
        # Apply the same border rule to data cells (td) and header cells (th).
        return [
            {'selector': 'td', 'props': [(css_property, '1px solid black')]},
            {'selector': 'th', 'props': [(css_property, '1px solid black')]},
        ]

    styler = table_df.style
    styler.set_table_styles(
        [
            {'selector': 'td', 'props': [('text-align', 'left')]},
            {'selector': 'th', 'props': [('text-align', 'left')]},
        ],
        overwrite=False,
    )

    row_index = table_df.index
    # axis=1: the dict keys are index (row) labels.
    styler.set_table_styles(
        {
            (group, row_index.get_level_values(1).unique()[0]): border_styles('border-top')
            for group in row_index.get_level_values(0).unique()
        },
        overwrite=False,
        axis=1,
    )

    col_index = table_df.columns
    # axis=0: the dict keys are column labels.
    styler.set_table_styles(
        {
            (group, col_index.get_level_values(1).unique()[0]): border_styles('border-left')
            for group in col_index.get_level_values(0).unique()
        },
        overwrite=False,
        axis=0,
    )
    display(styler)


def _add_rejection_stat_columns(plot_df, configs):
    """Add derived statistic columns for every rejector in ``configs`` to ``plot_df`` in place."""
    for config in configs.values():
        qua_prefix = config['quantifier']

        for rejector in config['rejectors'].values():
            rej_prefix = f'{rejector}_{config["rejection_limit"]}'

            # Nanoseconds -> seconds.
            plot_df[f'{rej_prefix}_runtime_seconds'] = plot_df[f'{rej_prefix}_all_class_time_ns'] / 1_000_000_000
            plot_df[f'{rej_prefix}_rejected_proportion'] = (
                plot_df[f'{rej_prefix}_rejected_count'] / plot_df['test_n']
            )
            # Coverage after rejection minus coverage of the quantifier itself.
            plot_df[f'{rej_prefix}_coverage_difference'] = (
                plot_df[f'{rej_prefix}_coverage'] - plot_df[f'{qua_prefix}_coverage']
            ).astype(float)
            # Positive when the post-rejection interval is narrower than the
            # target width limit; expressed as a fraction of test_n.
            plot_df[f'{rej_prefix}_width_target_diff'] = (
                (plot_df[f'{rej_prefix}_target_width_limit'] - plot_df[f'{rej_prefix}_interval_width'])
                / plot_df['test_n']
            ).astype(float)
            # How far the true count lies outside [count_lower, count_upper]
            # (0 when inside), as a fraction of test_n.
            plot_df[f'{rej_prefix}_distance_outside_interval'] = (
                np.maximum(
                    np.maximum(0, plot_df[f'{rej_prefix}_count_lower'] - plot_df['test_true_count']),
                    np.maximum(0, plot_df['test_true_count'] - plot_df[f'{rej_prefix}_count_upper']),
                ) / plot_df['test_n']
            )


def rejection_table(plot_df, *, dataset_labels, configs, experiment_state_col, baseline_rejector_label,
                    t_test_alpha=0.05, corrected_t_test=True, stats=None):
    """Build a table of formatted rejector statistics per dataset and config.

    NOTE: ``plot_df`` is modified in place. For each rejector it gains the
    columns ``{rejector}_{rejection_limit}_runtime_seconds``,
    ``..._rejected_proportion``, ``..._coverage_difference``,
    ``..._width_target_diff`` and ``..._distance_outside_interval``.

    Args:
        plot_df: Experiment results. Within each selected subset, every
            ``experiment_state_col`` value is asserted to cover exactly
            ``nunique(target_class)`` rows (presumably one per target class).
        dataset_labels: Mapping from ``dataset_name`` values to row labels.
        configs: Mapping from column-group label to a dict with keys
            ``shift_type``, ``quantifier``, ``rejection_limit`` and
            ``rejectors`` (rejector label -> column prefix).
        experiment_state_col: Column that rows are grouped by before the
            statistics are summarised across groups.
        baseline_rejector_label: Rejector label that other rejectors are
            compared against with a t-test. Significant results are wrapped
            in ``<strong>`` tags.
        t_test_alpha: Significance threshold for the t-test p-value.
        corrected_t_test: Use ``corrected_resampled_t_test`` (with a test
            size derived from ``full_train_n`` and ``test_n``) instead of
            ``t_test``.
        stats: Optional list of stat names (second column level) to keep.

    Returns:
        DataFrame indexed by (Dataset, Rejector) with (config label, stat
        name) MultiIndex columns whose cells are formatted strings.
    """
    _add_rejection_stat_columns(plot_df, configs)

    table_rows = []
    for dataset_name, dataset_label in dataset_labels.items():
        # One row per rejector label, accumulating columns across configs.
        dataset_rows = {}
        for config_label, config in configs.items():
            subset_df = plot_df[
                (plot_df['dataset_name'] == dataset_name) &
                (plot_df['shift_type'] == config['shift_type']) &
                (~plot_df[f'{config["quantifier"]}_count'].isna())
            ]

            # NOTE: closes over this iteration's subset_df; it is only called
            # within the same iteration.
            def get_stat(col, *, fmt='{:.2f}', std=False, t_test_col=None,
                         class_agg='mean', median=False):
                """Summarise ``col`` over experiment states as a formatted string.

                Rows are first aggregated with ``class_agg`` ('mean', 'max'
                or 'min') within each ``experiment_state_col`` value. The mean
                (or median, if ``median``) of those values is formatted with
                ``fmt``, optionally followed by their std in parentheses. When
                ``t_test_col`` is given and differs from ``col``, the result
                is bolded if the t-test p-value is <= ``t_test_alpha``.
                """
                if class_agg not in ('mean', 'max', 'min'):
                    raise ValueError(f'Unrecognised class_agg: {class_agg}')

                selected_cols = [experiment_state_col, col]
                if t_test_col is not None:
                    selected_cols.append(t_test_col)
                # De-duplicate while preserving order (t_test_col == col for
                # the baseline rejector). pandas >= 2.0 rejects set indexers.
                selected_cols = list(dict.fromkeys(selected_cols))

                class_groupby = subset_df[selected_cols].groupby(experiment_state_col)
                class_agg_df = getattr(class_groupby, class_agg)()
                # Ensure we are only grouping class-rows together in class_agg_df
                assert (class_agg_df.shape[0] * subset_df['target_class'].nunique()) == subset_df.shape[0]

                stat_val = class_agg_df[col].median() if median else class_agg_df[col].mean()
                stat = fmt.format(stat_val)
                if std:
                    stat += f' ({fmt.format(class_agg_df[col].std())})'
                if t_test_col is not None and t_test_col != col:
                    if corrected_t_test:
                        train_n = subset_df['full_train_n'].mean()
                        assert np.all(subset_df['full_train_n'] == train_n)
                        test_n = subset_df['test_n'].mean()
                        assert np.all(subset_df['test_n'] == test_n)
                        test_size = test_n / (train_n + test_n)
                        p_value = corrected_resampled_t_test(
                            class_agg_df[col].to_numpy(),
                            class_agg_df[t_test_col].to_numpy(),
                            test_size=test_size,
                        )
                    else:
                        p_value = t_test(
                            class_agg_df[col].to_numpy(),
                            class_agg_df[t_test_col].to_numpy(),
                        )
                    if p_value <= t_test_alpha:
                        stat = fr'<strong>{stat}</strong>'
                return stat

            baseline_rej_prefix = f'{config["rejectors"][baseline_rejector_label]}_{config["rejection_limit"]}'
            for rejector_label, rejector in config['rejectors'].items():
                rej_prefix = f'{rejector}_{config["rejection_limit"]}'

                # Initialise dataset row for this rejector
                row = dataset_rows.get(rejector_label, {
                    'Dataset': dataset_label,
                    'Rejector': rejector_label,
                })
                # Populate row with stats for this config
                row[(config_label, 'Rejection')] = get_stat(
                    f'{rej_prefix}_rejected_proportion',
                    fmt='{:.1%}',
                    std=True,
                    t_test_col=f'{baseline_rej_prefix}_rejected_proportion',
                )
                row[(config_label, 'Interval Width: Limit - Actual')] = get_stat(
                    f'{rej_prefix}_width_target_diff',
                    fmt='{:.1%}',
                    std=True,
                    t_test_col=f'{baseline_rej_prefix}_width_target_diff',
                    class_agg='min',
                )
                row[(config_label, 'Coverage: Post - Pre')] = get_stat(
                    f'{rej_prefix}_coverage_difference',
                    fmt='{:.1%}',
                )
                row[(config_label, 'Distance From Post-Interval')] = get_stat(
                    f'{rej_prefix}_distance_outside_interval',
                    fmt='{:.1%}',
                    t_test_col=f'{baseline_rej_prefix}_distance_outside_interval',
                    std=True,
                )
                row[(config_label, 'Runtime Seconds')] = get_stat(
                    f'{rej_prefix}_runtime_seconds',
                    fmt='{:,.2f}',
                    std=True,
                    t_test_col=f'{baseline_rej_prefix}_runtime_seconds',
                    median=True,
                )
                dataset_rows[rejector_label] = row
        table_rows += list(dataset_rows.values())

    # Row index: dataset_label, rejector_label
    table_df = pd.DataFrame(table_rows).set_index(['Dataset', 'Rejector'])
    # Column index: config_label, stat
    table_df.columns = pd.MultiIndex.from_tuples(table_df.columns)

    # Select which stat columns to include
    if stats is not None:
        table_df = table_df.loc[:, pd.IndexSlice[table_df.columns.get_level_values(0).unique(), stats]]

    return table_df


def mip_rejection_check(plot_df, *, configs):
    """Count, for each config, the rows where MIP rejected more than APT.

    Compares the ``{rejector}_{rejection_limit}_rejected_count`` columns of
    each config's ``MIP`` and ``APT`` rejectors.

    Returns:
        A one-row DataFrame with one column per config label.
    """
    def count_excess_rejections(config):
        limit = config['rejection_limit']
        mip_counts = plot_df[f'{config["rejectors"]["MIP"]}_{limit}_rejected_count']
        apt_counts = plot_df[f'{config["rejectors"]["APT"]}_{limit}_rejected_count']
        return (mip_counts > apt_counts).sum()

    counts = {label: count_excess_rejections(config) for label, config in configs.items()}
    return pd.DataFrame([counts])

## Experiment Results

In [None]:
# Short table labels for each synthetic-shift dataset.
synthetic_dataset_labels = dict([
    ('handwritten-letters-letter', 'HLL'),
    ('handwritten-letters-author', 'HLA'),
    ('arabic-digits', 'DIG'),
    ('insect-sex', 'ISX'),
    ('insect-species', 'ISP'),
])
# Each config pairs a shift type with a quantifier; the rejector column
# prefixes are '<rejector_prefix>-mip', '-apt' and '-pt'.
synthetic_configs = {
    config_label: {
        'shift_type': shift_type,
        'quantifier': quantifier,
        'rejection_limit': 'fracmax:0.5',
        'rejectors': {
            'MIP': f'{rejector_prefix}-mip',
            'APT': f'{rejector_prefix}-apt',
            'PT': f'{rejector_prefix}-pt',
        },
    }
    for config_label, shift_type, quantifier, rejector_prefix in (
        ('No shift + PCC', 'no_shift', 'pcc', 'pcc'),
        ('Prior shift + EM', 'prior_shift', 'em', 'em'),
        ('GSLS shift + GSLS', 'gsls_shift', 'gsls', 'ugsls'),
    )
}

synthetic_table = rejection_table(
    synthetic_df,
    configs=synthetic_configs,
    dataset_labels=synthetic_dataset_labels,
    baseline_rejector_label='MIP',
    experiment_state_col='random_state',
    stats=[
        'Rejection',
        'Interval Width: Limit - Actual',
        'Coverage: Post - Pre',
        'Distance From Post-Interval',
    ],
)
display_table(synthetic_table)
print_table_latex(synthetic_table)
print_table_latex(synthetic_table)

In [None]:
# Short table labels for each plankton dataset variant.
sample_dataset_labels = dict([
    ('plankton', 'OPL'),
    ('fg-plankton', 'FPL'),
    ('binary-plankton', 'BPL'),
])
# All sample-shift configs use shift_type 'no_shift' and differ only in the
# quantifier; the rejector column prefixes are '<rejector_prefix>-mip',
# '-apt' and '-pt'.
sample_configs = {
    config_label: {
        'shift_type': 'no_shift',
        'quantifier': quantifier,
        'rejection_limit': 'fracmax:0.5',
        'rejectors': {
            'MIP': f'{rejector_prefix}-mip',
            'APT': f'{rejector_prefix}-apt',
            'PT': f'{rejector_prefix}-pt',
        },
    }
    for config_label, quantifier, rejector_prefix in (
        ('PCC', 'pcc', 'pcc'),
        ('EM', 'em', 'em'),
        ('GSLS', 'gsls', 'ugsls'),
    )
}

sample_table = rejection_table(
    sample_df,
    configs=sample_configs,
    dataset_labels=sample_dataset_labels,
    baseline_rejector_label='MIP',
    experiment_state_col='sample_idx',
    corrected_t_test=False,
    stats=[
        'Rejection',
        'Interval Width: Limit - Actual',
        'Coverage: Post - Pre',
        'Distance From Post-Interval',
    ],
)
display_table(sample_table)
print_table_latex(sample_table)

### Runtime Results

In [None]:
# Runtime statistics only, computed on the serially-run sample experiments.
runtime_sample_table = rejection_table(
    runtime_sample_df,
    configs=sample_configs,
    dataset_labels=sample_dataset_labels,
    baseline_rejector_label='MIP',
    experiment_state_col='sample_idx',
    stats=['Runtime Seconds'],
    corrected_t_test=False,
)
display_table(runtime_sample_table)
print_table_latex(runtime_sample_table)

In [None]:
# Mean test_n per sample_idx. 'test_n' is selected before aggregating: since
# pandas 2.0, a DataFrame-wide groupby().mean() raises TypeError on the
# non-numeric columns (e.g. dataset_name) instead of dropping them.
runtime_test_counts = runtime_sample_df.groupby('sample_idx')['test_n'].mean()
print(f'Runtime test sample sizes have mean: {runtime_test_counts.mean()} and std dev: {runtime_test_counts.std()}')

### MIP Correctness Check

Given that `MIP` rejection uses the same interval-width formulation as `APT` but has more freedom to select the best set of instances to reject, `MIP` should never reject more instances than `APT`. However, the underlying MIP solver has been observed to perform excessive rejection in some cases. We check that these cases remain rare, so that their impact on the results can be considered insignificant.

There are over 250 such cases for sample-shift GSLS because, for 5 samples of the plankton dataset, solver errors cause a fallback to full rejection.

In [None]:
# Count, per config, the rows where MIP rejected more instances than APT,
# for the synthetic-shift and then the sample-shift experiments.
print('Synthetic shift')
display(mip_rejection_check(synthetic_df, configs=synthetic_configs))
print('Sample shift')
display(mip_rejection_check(sample_df, configs=sample_configs))