# Synthetic "True Source Probabilities" Example

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from sklearn.calibration import calibration_curve

from pyquantification.experiments import (
    cached_experiments,
    classification_cache_key,
    load_from_cache,
    split_test,
    DATASETS,
)
from pyquantification.evaluation import BASE_LAYOUT, t_test, corrected_resampled_t_test

## Synthetic Dataset Plots

In [None]:
def plot_synthetic_dataset(dataset, *, n_rows=2000):
    """Display an overlaid histogram of ``x`` for the leading rows of a dataset.

    One semi-transparent histogram trace is drawn per distinct
    ``dist + '+' + class`` label.

    Args:
        dataset: Object exposing a ``df`` DataFrame with ``x``, ``dist`` and
            ``class`` columns. ``dist`` and ``class`` are concatenated with
            ``'+'``, so they are expected to hold strings.
        n_rows: Number of leading rows of ``dataset.df`` to plot.
    """
    # The original author notes that the leading 2000 rows give an equal sized
    # source/target set -- presumably this relies on the row order of the
    # synthetic datasets; confirm before changing ``n_rows``.
    plot_df = dataset.df[:n_rows].copy()
    # Build the trace label from the sliced rows only, so no index alignment
    # against the full frame is needed.
    plot_df['colour'] = plot_df['dist'] + '+' + plot_df['class']

    fig = px.histogram(
        plot_df,
        x='x',
        color='colour',
    )
    # Overlay the traces (instead of stacking) and make them translucent so
    # overlapping distributions stay visible.
    fig.update_layout(barmode='overlay')
    fig.update_traces(opacity=0.75)
    display(fig)

### No Shift

In [None]:
# Call the DATASETS entry for 'synthetic-true-prob-no-shift' and plot the result.
plot_synthetic_dataset(DATASETS['synthetic-true-prob-no-shift']())

### Prior Shift

In [None]:
# Call the DATASETS entry for 'synthetic-true-prob-prior-shift' and plot the result.
plot_synthetic_dataset(DATASETS['synthetic-true-prob-prior-shift']())

### GSLS Shift

In [None]:
# Call the DATASETS entry for 'synthetic-true-prob-gsls-shift' and plot the result.
plot_synthetic_dataset(DATASETS['synthetic-true-prob-gsls-shift']())

## Experiments

In [None]:
# One experiment configuration per synthetic dataset: the quantifier to run,
# the rejection limit, and the rejectors to evaluate for that quantifier.
configs = [
    dict(
        dataset_name='synthetic-true-prob-no-shift',
        quantifier='pcc',
        rejection_limit='fracmax:0.5',
        rejectors=['pcc-pt', 'pcc-apt', 'pcc-mip'],
    ),
    dict(
        dataset_name='synthetic-true-prob-prior-shift',
        quantifier='em',
        rejection_limit='fracmax:0.5',
        rejectors=['em-pt', 'em-apt', 'em-mip'],
    ),
    dict(
        dataset_name='synthetic-true-prob-gsls-shift',
        quantifier='gsls',
        rejection_limit='fracmax:0.5',
        rejectors=['ugsls-pt', 'ugsls-apt', 'ugsls-mip'],
    ),
]
# Short label for each dataset name.
dataset_labels = {
    'synthetic-true-prob-no-shift': 'SNS',
    'synthetic-true-prob-prior-shift': 'SPS',
    'synthetic-true-prob-gsls-shift': 'SGS',
}

# Run (or load from cache) the experiments of each configuration in turn and
# stack the per-configuration results into a single frame.
results_dfs = []
for config in configs:
    print(f'Running {config["dataset_name"]} experiments')
    results_dfs.append(
        cached_experiments(
            # Settings that vary per configuration.
            cache_key=f'sample_{config["dataset_name"]}_rejection_results',
            dataset_names=[config['dataset_name']],
            quantification_methods=[config['quantifier']],
            rejectors=config['rejectors'],
            rejection_limits=[config['rejection_limit']],
            # Settings shared by every configuration.
            classifier_names=['logreg', 'source-prob'],
            calibration_methods=['uncalibrated'],
            loss_weights=[0],
            gain_weights=[0],
            random_states=[0],
            shift_types=['no_shift'],
            bin_counts=['auto'],
            random_priors_options=[False],
            classification_workers=1,
            quantification_workers=10,
            continue_on_failure=True,
            # Run on all samples
            sample_idxs=None,
        )
    )
results_df = pd.concat(results_dfs)

In [None]:
# Work on a copy of the results, annotated with each dataset's short label.
plot_df = results_df.assign(
    dataset_label=results_df['dataset_name'].map(dataset_labels),
)

## Coverage

In [None]:
def print_coverage_table_latex(table_df):
    """Print ``table_df`` as the body rows of a LaTeX ``tabular``.

    Each row is printed as its cells joined by `` & `` and terminated with
    ``\\\\``; percent signs are escaped for LaTeX. A row whose cells are all
    missing is printed as a ``\\hline`` separator instead.

    Args:
        table_df: DataFrame of cell values. Cells are converted with ``str``
            before joining.
    """
    for _, row in table_df.iterrows():
        if row.isna().all():
            # Raw string: '\h' is not a valid Python escape sequence.
            print(r'\hline')
        else:
            cells = (str(value) for value in row.to_dict().values())
            print(' & '.join(cells).replace('%', r'\%') + r' \\')

def coverage_table():
    """Build a table of mean coverage per dataset, classifier and method.

    Reads the module-level ``plot_df`` and ``dataset_labels``. For every
    combination, the ``{method}_coverage`` column is first averaged within
    each experiment (``classifier_name``, ``dataset_name``, ``sample_idx``),
    which merges the rows of the different target classes, and then averaged
    across experiments. Combinations whose mean is missing are skipped.

    Returns:
        DataFrame with the string columns ``classifier``, ``dataset``,
        ``method`` and ``coverage`` (formatted as a whole percentage). An
        empty (all-missing) row follows each dataset's rows as a separator.
    """
    experiment_grouping = ['classifier_name', 'dataset_name', 'sample_idx']
    plot_methods = {
        'pcc': 'PCC',
        'em': 'EM',
        'gsls': 'GSLS',
    }

    rows = []
    for dataset_name, dataset_label in dataset_labels.items():
        for classifier_name in sorted(plot_df['classifier_name'].unique()):
            cell_mask = (
                (plot_df['dataset_name'] == dataset_name)
                & (plot_df['classifier_name'] == classifier_name)
            )
            for method, method_label in plot_methods.items():
                coverage_col = f'{method}_coverage'
                # Keep only the grouping keys and the coverage column: since
                # pandas 2.0, groupby().mean() raises on non-numeric columns
                # instead of silently dropping them.
                cell_df = plot_df.loc[cell_mask, experiment_grouping + [coverage_col]].copy()
                # Ensure coverage is numeric so that it can be averaged.
                cell_df[coverage_col] = cell_df[coverage_col].astype(float)
                # Group by experiment first to group different target_classes together.
                experiment_means = cell_df.groupby(experiment_grouping, dropna=False)[coverage_col].mean()
                mean_coverage = experiment_means.mean()
                if pd.isna(mean_coverage):
                    continue
                rows.append({
                    'classifier': classifier_name,
                    'dataset': dataset_label,
                    'method': method_label,
                    'coverage': f'{mean_coverage:.0%}',
                })
        # Empty row: separator between datasets.
        rows.append({})

    return pd.DataFrame(rows)

# Build the coverage table, show it in the notebook, then print its LaTeX rows.
coverage_table_df = coverage_table()
display(coverage_table_df)
print_coverage_table_latex(coverage_table_df)

## Constrained Quantification

In [None]:
def display_rejection_table(table_df):
    """Show ``table_df`` in the notebook with left-aligned cells and borders.

    A top border is drawn on the row keyed by each first-level index value
    paired with the first second-level index value, and a left border on the
    column keyed by each first-level column value paired with the first
    second-level column value.

    Args:
        table_df: DataFrame with two-level row and column indexes.
    """
    def black_border(css_property):
        # The same rule is applied to body cells and header cells.
        return [
            {'selector': cell_tag, 'props': [(css_property, '1px solid black')]}
            for cell_tag in ('td', 'th')
        ]

    styler = table_df.style

    # Left-align every body and header cell.
    styler.set_table_styles(
        [
            {'selector': cell_tag, 'props': [('text-align', 'left')]}
            for cell_tag in ('td', 'th')
        ],
        overwrite=False,
    )

    # Top border above the first row of each first-level row group.
    row_borders = {}
    for classifier in table_df.index.get_level_values(0).unique():
        first_row_label = table_df.index.get_level_values(1).unique()[0]
        row_borders[(classifier, first_row_label)] = black_border('border-top')
    styler.set_table_styles(row_borders, overwrite=False, axis=1)

    # Left border before the first column of each first-level column group.
    column_borders = {}
    for config_label in table_df.columns.get_level_values(0).unique():
        first_column_label = table_df.columns.get_level_values(1).unique()[0]
        column_borders[(config_label, first_column_label)] = black_border('border-left')
    styler.set_table_styles(column_borders, overwrite=False, axis=0)

    display(styler)


def print_rejection_table_latex(table_df):
    """Print ``table_df`` as the body rows of a LaTeX ``tabular``.

    Each row is printed as its index keys followed by its cells, joined by
    `` & `` and terminated with ``\\\\``. Percent signs are escaped and
    ``<strong>...</strong>`` markers are converted to ``\\textbf{...}``.
    A row whose cells are all missing is printed as ``\\hline`` instead.

    Args:
        table_df: DataFrame whose row index yields an iterable of keys per
            row (e.g. a MultiIndex). Keys and cells are converted with
            ``str`` before joining.
    """
    for keys, row in table_df.iterrows():
        if row.isna().all():
            # Raw string: '\h' is not a valid Python escape sequence.
            print(r'\hline')
            continue
        cells = [str(value) for value in (*keys, *row.to_dict().values())]
        print(
            (' & '.join(cells) + r' \\')
            .replace('%', r'\%')
            .replace('<strong>', r'\textbf{')
            .replace('</strong>', '}')
        )
    

def rejection_table(plot_df, *, dataset_labels, configs, experiment_state_col, baseline_rejector_label,
                    t_test_alpha=0.05, corrected_t_test=True, stats=None):
    """Build a table of formatted rejection statistics per classifier and rejector.

    NOTE: ``plot_df`` is modified in place: derived per-rejector columns
    (``*_runtime_seconds``, ``*_rejected_proportion``, ``*_coverage_difference``,
    ``*_width_target_diff``, ``*_distance_outside_interval``) are added to it.

    Args:
        plot_df: Results DataFrame with one row per experiment and target class.
        dataset_labels: Unused; accepted for interface compatibility.
        configs: Mapping of config label to a dict with the keys
            ``dataset_name``, ``quantifier``, ``rejection_limit`` and
            ``rejectors`` (a mapping of rejector label to rejector name).
        experiment_state_col: Column identifying a single experiment; rows
            sharing a value are aggregated together before statistics are
            computed.
        baseline_rejector_label: Key of ``config['rejectors']`` that the other
            rejectors are compared against in the significance tests.
        t_test_alpha: p-values less than or equal to this are marked significant.
        corrected_t_test: If true use ``corrected_resampled_t_test`` (with the
            test size derived from ``full_train_n`` and ``test_n``), otherwise
            use ``t_test``.
        stats: Optional list of statistic names (second column level) to keep.

    Returns:
        DataFrame indexed by (``Classifier``, ``Rejector``) with two-level
        columns (config label, statistic name). Cells are formatted strings;
        significant cells are wrapped in ``<strong>...</strong>``.

    Raises:
        ValueError: If an unrecognised class aggregation is requested.
    """
    # Pre-compute statistic columns
    for config in configs.values():
        qua_prefix = config['quantifier']

        for rejector in config['rejectors'].values():
            rej_prefix = f'{rejector}_{config["rejection_limit"]}'

            plot_df[f'{rej_prefix}_runtime_seconds'] = plot_df[f'{rej_prefix}_all_class_time_ns'] / 1_000_000_000
            plot_df[f'{rej_prefix}_rejected_proportion'] = (
                plot_df[f'{rej_prefix}_rejected_count'] / plot_df['test_n']
            )
            plot_df[f'{rej_prefix}_coverage_difference'] = (
                plot_df[f'{rej_prefix}_coverage'].astype(float) - plot_df[f'{qua_prefix}_coverage'].astype(float)
            ).astype(float)
            plot_df[f'{rej_prefix}_width_target_diff'] = (
                (plot_df[f'{rej_prefix}_target_width_limit'] - plot_df[f'{rej_prefix}_interval_width'])
                / plot_df['test_n']
            ).astype(float)
            # How far the true count lies below the lower bound or above the
            # upper bound (zero when inside), as a fraction of the test set.
            plot_df[f'{rej_prefix}_distance_outside_interval'] = (
                np.maximum(
                    np.maximum(0, plot_df[f'{rej_prefix}_count_lower'] - plot_df['test_true_count']),
                    np.maximum(0, plot_df['test_true_count'] - plot_df[f'{rej_prefix}_count_upper']),
                ) / plot_df['test_n']
            )

    table_rows = []
    for classifier_name in sorted(plot_df['classifier_name'].unique()):
        classifier_rows = {}
        for config_label, config in configs.items():
            # Rows of this classifier/dataset for which the quantifier produced a count.
            subset_df = plot_df[
                (plot_df['classifier_name'] == classifier_name) &
                (plot_df['dataset_name'] == config['dataset_name']) &
                (~plot_df[f'{config["quantifier"]}_count'].isna())
            ]

            def get_stat(col, *, fmt='{:.2f}', std=False, t_test_col=None,
                         class_agg='mean', median=False):
                """Format the statistic of ``col`` over the experiments in ``subset_df``."""
                selected_cols = [experiment_state_col, col]
                if t_test_col is not None:
                    selected_cols.append(t_test_col)
                # De-duplicate while preserving order (the baseline rejector
                # has t_test_col == col). A list is required: pandas no longer
                # accepts a set as a column indexer.
                selected_cols = list(dict.fromkeys(selected_cols))

                class_groupby = subset_df[selected_cols].groupby(experiment_state_col)
                if class_agg == 'mean':
                    class_agg_df = class_groupby.mean()
                elif class_agg == 'max':
                    class_agg_df = class_groupby.max()
                elif class_agg == 'min':
                    class_agg_df = class_groupby.min()
                else:
                    raise ValueError(f'Unrecognised class_agg: {class_agg}')
                # Ensure we are only grouping class-rows together in class_agg_df
                assert (class_agg_df.shape[0] * subset_df['target_class'].nunique()) == subset_df.shape[0]

                if median:
                    stat_val = class_agg_df[col].median()
                else:
                    stat_val = class_agg_df[col].mean()
                stat = fmt.format(stat_val)
                if std:
                    stat += f' ({fmt.format(class_agg_df[col].std())})'
                # Only test against the baseline when this is not the baseline itself.
                if t_test_col is not None and t_test_col != col:
                    if corrected_t_test:
                        # The corrected test needs a single test-set fraction,
                        # so every row must share the same train/test sizes.
                        train_n = subset_df['full_train_n'].mean()
                        assert np.all(subset_df['full_train_n'] == train_n)
                        test_n = subset_df['test_n'].mean()
                        assert np.all(subset_df['test_n'] == test_n)
                        test_size = test_n / (train_n + test_n)
                        p_value = corrected_resampled_t_test(
                            class_agg_df[col].to_numpy(),
                            class_agg_df[t_test_col].to_numpy(),
                            test_size=test_size,
                        )
                    else:
                        p_value = t_test(
                            class_agg_df[col].to_numpy(),
                            class_agg_df[t_test_col].to_numpy(),
                        )
                    significant = p_value <= t_test_alpha
                    if significant:
                        stat = f'<strong>{stat}</strong>'
                return stat

            baseline_rej_prefix = f'{config["rejectors"][baseline_rejector_label]}_{config["rejection_limit"]}'
            for rejector_label, rejector in config['rejectors'].items():
                rej_prefix = f'{rejector}_{config["rejection_limit"]}'

                # Initialise classifier row for this rejector
                row = classifier_rows.get(rejector_label, {
                    'Classifier': classifier_name,
                    'Rejector': rejector_label,
                })
                # Populate row with stats for this config
                row[(config_label, 'Rejection')] = get_stat(
                    f'{rej_prefix}_rejected_proportion',
                    fmt='{:.1%}',
                    std=True,
                    t_test_col=f'{baseline_rej_prefix}_rejected_proportion',
                )
                row[(config_label, 'Interval Width: Limit - Actual')] = get_stat(
                    f'{rej_prefix}_width_target_diff',
                    fmt='{:.1%}',
                    std=True,
                    t_test_col=f'{baseline_rej_prefix}_width_target_diff',
                    class_agg='min',
                )
                row[(config_label, 'Coverage: Post - Pre')] = get_stat(
                    f'{rej_prefix}_coverage_difference',
                    fmt='{:.1%}',
                )
                row[(config_label, 'Distance From Post-Interval')] = get_stat(
                    f'{rej_prefix}_distance_outside_interval',
                    fmt='{:.1%}',
                    t_test_col=f'{baseline_rej_prefix}_distance_outside_interval',
                    std=True,
                )
                row[(config_label, 'Runtime Seconds')] = get_stat(
                    f'{rej_prefix}_runtime_seconds',
                    fmt='{:,.2f}',
                    std=True,
                    t_test_col=f'{baseline_rej_prefix}_runtime_seconds',
                    median=True,
                )
                classifier_rows[rejector_label] = row
        table_rows += list(classifier_rows.values())

    # Row index: classifier, rejector_label
    table_df = pd.DataFrame(table_rows).set_index(['Classifier', 'Rejector'])
    # Column index: config_label, stat
    table_df.columns = pd.MultiIndex.from_tuples(table_df.columns)

    # Select which stat columns to include
    if stats is not None:
        table_df = table_df.loc[:, pd.IndexSlice[table_df.columns.get_level_values(0).unique(), stats]]

    return table_df

# One column group per quantifier. The order of each ``rejectors`` mapping
# determines the row order of the table within a classifier.
rejection_table_configs = {
    'PCC': dict(
        dataset_name='synthetic-true-prob-no-shift',
        quantifier='pcc',
        rejection_limit='fracmax:0.5',
        rejectors={'MIP': 'pcc-mip', 'APT': 'pcc-apt', 'PT': 'pcc-pt'},
    ),
    'EM': dict(
        dataset_name='synthetic-true-prob-prior-shift',
        quantifier='em',
        rejection_limit='fracmax:0.5',
        rejectors={'MIP': 'em-mip', 'APT': 'em-apt', 'PT': 'em-pt'},
    ),
    'GSLS': dict(
        dataset_name='synthetic-true-prob-gsls-shift',
        quantifier='gsls',
        rejection_limit='fracmax:0.5',
        rejectors={'MIP': 'ugsls-mip', 'APT': 'ugsls-apt', 'PT': 'ugsls-pt'},
    ),
}

# Build the table with MIP as the baseline rejector and the plain
# (uncorrected) t-test, keeping every statistic except the runtime.
rejection_table_df = rejection_table(
    plot_df,
    dataset_labels=dataset_labels,
    configs=rejection_table_configs,
    experiment_state_col='sample_idx',
    baseline_rejector_label='MIP',
    corrected_t_test=False,
    stats=[
        'Rejection',
        'Interval Width: Limit - Actual',
        'Coverage: Post - Pre',
        'Distance From Post-Interval',
    ],
)
# Show the styled table in the notebook, then print its LaTeX rows.
display_rejection_table(rejection_table_df)
print_rejection_table_latex(rejection_table_df)