# Quantification under GSLS Shift

In [None]:
import os

from pyquantification.experiments import cached_experiments
from pyquantification.evaluation import (
    get_colormap,
    display_dataset_summary,
    display_stat_table,
    color_scale,
    plot_remain_weight,
    plot_error_bars_sample,
)

In [None]:
# Identifiers of the quantification methods passed to the experiment runner.
quantification_methods = ['count', 'pcc', 'em', 'gsls', 'true-weight-gsls']

# Dataset name -> short label used when presenting results.
dataset_labels = {
    'handwritten-letters-letter': 'HLL',
    'handwritten-letters-author': 'HLA',
    'arabic-digits': 'DIG',
    'insect-sex': 'ISX',
    'insect-species': 'ISP',
}

# Results for the full grid of settings below; given the `cache_key`,
# `cached_experiments` presumably reuses previously stored results rather
# than re-running them -- confirm against pyquantification.experiments.
results_df = cached_experiments(
    cache_key='gsls_results',
    dataset_names=list(dataset_labels),
    classifier_names=['logreg'],
    loss_weights=[0, 0.3, 0.7, 1],
    gain_weights=[0, 0.3, 0.7, 1],
    random_states=list(range(1000)),  # random states 0..999
    shift_types=['gsls_shift'],
    bin_counts=['auto'],
    random_priors_options=[True],
    quantification_methods=quantification_methods,
    classification_workers=12,
    continue_on_failure=True,
)

In [None]:
# Colormap for consistent dataset label colours
colormap = get_colormap(dataset_labels.values())

# Move the index levels into ordinary columns (allows for initial filtering
# of results) and add the short label that corresponds to each dataset name.
plot_df = results_df.reset_index().assign(
    dataset_label=lambda df: df['dataset_name'].map(dataset_labels),
)

## Dataset Summary

In [None]:
# Show a summary for the datasets listed in `dataset_labels`
# (name -> short label); the exact contents of the summary are defined by
# pyquantification.evaluation.display_dataset_summary.
display_dataset_summary(dataset_labels)

## Quantification Method Comparison

In [None]:
# Table of the 'coverage' stat per method, with rows grouped by
# (gain_weight, loss_weight). Values are rendered as whole percentages;
# cell colouring comes from `color_scale(threshold=0.8)`.
display_stat_table(
    plot_df,
    stat='coverage',
    row_grouping=['gain_weight', 'loss_weight'],
    # Column heading -> quantification method identifier.
    methods={
        'PCC': 'pcc',
        'EM': 'em',
        'GSLS': 'gsls',
        'True Weight GSLS': 'true-weight-gsls',
    },
    color_func=color_scale(threshold=0.8),
    format_string='{:.0%}',
)

In [None]:
# Table of the 'absolute_error' stat per method, with rows grouped by
# dataset name. Values are rendered as percentages with one decimal place,
# and `include_std=True` requests the standard deviation alongside them
# (presumably next to the mean -- confirm in display_stat_table).
display_stat_table(
    plot_df,
    stat='absolute_error',
    row_grouping='dataset_name',
    # Column heading -> quantification method identifier.
    methods={
        'CC': 'count',
        'PCC': 'pcc',
        'EM': 'em',
        'GSLS': 'gsls',
    },
    format_string='{:.1%}',
    include_std=True,
)

## GSLS with True Weights vs Estimated Weights

In [None]:
# Make sure the export directory exists; otherwise writing the SVG below can
# fail on a fresh checkout where "plots/" has not been created yet.
os.makedirs("plots", exist_ok=True)

# Remain-weight plot for the 'gsls' method, coloured per dataset label via
# the shared colormap.
fig = plot_remain_weight(plot_df, 'gsls', colormap=colormap)
# Save the figure as SVG, then render it inline.
fig.write_image("plots/remain-weights.svg")
fig.show()

In [None]:
# Table of the 'width' stat comparing GSLS with true weights against GSLS
# with estimated weights, rows grouped by (gain_weight, loss_weight).
# Values are rendered as percentages with two decimal places; cell colouring
# comes from `color_scale(inverted=True)`.
display_stat_table(
    plot_df,
    stat='width',
    row_grouping=['gain_weight', 'loss_weight'],
    # Column heading -> quantification method identifier.
    methods={
        'True Weight GSLS': 'true-weight-gsls',
        'GSLS': 'gsls',
    },
    color_func=color_scale(inverted=True),
    include_std=True,
    format_string='{:.2%}',
)

## Visual Comparison of Quantification Intervals

In [None]:
# Make sure the export directory exists; otherwise writing the SVG below can
# fail on a fresh checkout where "plots/" has not been created yet.
os.makedirs("plots", exist_ok=True)

# Error-bar plot for a sample of 'insect-sex' results, comparing the PCC, EM
# and GSLS methods. `seed=7` fixes the sample (presumably a random selection
# of experiments -- confirm in plot_error_bars_sample).
fig = plot_error_bars_sample(
    plot_df,
    'gsls',
    dataset_name='insect-sex',
    seed=7,
    methods=['pcc', 'em', 'gsls'],
    include_fit_weights=False,
)
# Save the figure as SVG, then render it inline.
fig.write_image("plots/quant-lines.svg")
fig.show()