In [None]:
import pandas as pd
from pyquantification.experiments import cached_experiments, DATASETS
from pyquantification.evaluation import corrected_resampled_t_test

In [None]:
# Quantification method names as they appear in the experiment result columns.
quantification_methods = ['count', 'pcc', 'em', 'gsls', 'true-weight-gsls']

# Load each experiment family's results through `cached_experiments`:
# GSLS-shift results, prior-shift results, and bin-count sensitivity results.
gsls_df = cached_experiments(cache_key='gsls_results')
prior_df = cached_experiments(cache_key='prior_shift_results')
bins_df = cached_experiments(cache_key='bins_results')

In [None]:
# Short label used for each dataset in the LaTeX tables (order = table order).
dataset_labels = dict([
    ('handwritten-letters-letter', 'HLL'),
    ('handwritten-letters-author', 'HLA'),
    ('arabic-digits', 'DIG'),
    ('insect-sex', 'ISX'),
    ('insect-species', 'ISP'),
])
# Instantiate every labelled dataset once, in label order.
datasets = {name: DATASETS[name]() for name in dataset_labels}
# Gain/loss weight grid used to select cells of the GSLS-shift results.
gain_weights = [0.0, 0.3, 0.7, 1.0]
loss_weights = [0.0, 0.3, 0.7, 1.0]

def print_table_latex(table_df):
    r"""Print the rows of ``table_df`` as LaTeX tabular body lines.

    Each row is printed as its cell values joined by `` & `` and terminated
    with `` \\``. A row whose cells are all missing (e.g. produced by
    appending ``{}`` to the row list) is printed as ``\hline`` and acts as a
    group separator. A missing cell inside an otherwise populated row is
    printed as an empty cell instead of raising ``TypeError``.

    Args:
        table_df: DataFrame whose non-missing cells are LaTeX-ready strings
            (other values are converted with ``str``).
    """
    for _, row in table_df.iterrows():
        if row.isna().all():
            # Raw string: '\h' is not a valid Python escape sequence.
            print(r'\hline')
        else:
            cells = ('' if pd.isna(value) else str(value) for value in row)
            print(' & '.join(cells) + r' \\')

In [None]:
def coverage_table():
    r"""Build the interval-coverage table for each dataset and method.

    Each row is one (dataset, method) pair holding the mean
    ``{method}_coverage`` from ``gsls_df`` for every
    ``'gw{gain_weight}, lw{loss_weight}'`` combination, plus the mean
    coverage from ``prior_df`` in the ``prior_shift`` column (``'N.A.'`` for
    TGSLS). An empty row follows each dataset so that ``print_table_latex``
    emits ``\hline`` there.

    Returns:
        pd.DataFrame of LaTeX-formatted cell strings.
    """
    plot_methods = {
        'pcc': 'PCC',
        'em': 'EM',
        'gsls': 'GSLS',
        'true-weight-gsls': 'TGSLS',
    }

    def format_cell(mean):
        # A missing result (mean of an empty selection) is shown as 'N.A.',
        # matching the placeholder used for TGSLS under prior shift.
        if pd.isna(mean):
            return 'N.A.'
        # Whole percentage with LaTeX-escaped '%'; raw string because '\%'
        # is not a valid Python escape sequence.
        str_mean = f'{mean:.0%}'.replace('%', r'\%')
        # Highlight coverage of at least 80% (presumably the nominal interval
        # coverage -- confirm against the experiment configuration).
        if mean >= 0.8:
            str_mean = r'\textbf{' + str_mean + '}'
        return str_mean

    rows = []
    for dataset_name, dataset_label in dataset_labels.items():
        # These selections depend only on the dataset, so compute them once
        # per dataset rather than once per method/weight combination.
        dataset_gsls_df = gsls_df[gsls_df['dataset_name'] == dataset_name]
        dataset_prior_df = prior_df[prior_df['dataset_name'] == dataset_name]
        for method, method_label in plot_methods.items():
            row = {
                'dataset': dataset_label,
                'method': method_label,
            }
            for gain_weight in gain_weights:
                for loss_weight in loss_weights:
                    cell_gsls_df = dataset_gsls_df[
                        (dataset_gsls_df['gain_weight'] == gain_weight) &
                        (dataset_gsls_df['loss_weight'] == loss_weight)
                    ]
                    row[f'gw{gain_weight}, lw{loss_weight}'] = format_cell(
                        cell_gsls_df[f'{method}_coverage'].mean())
            if method == 'true-weight-gsls':
                # No prior-shift result is reported for TGSLS.
                row['prior_shift'] = 'N.A.'
            else:
                row['prior_shift'] = format_cell(dataset_prior_df[f'{method}_coverage'].mean())
            rows.append(row)
        # Empty row -> rendered as \hline between datasets.
        rows.append({})
    return pd.DataFrame(rows)
    
table_df = coverage_table()
# Show the table as rich notebook output, then print its LaTeX body rows.
for render in (display, print_table_latex):
    render(table_df)

In [None]:
def width_table():
    """Build the interval-width table comparing GSLS with TGSLS.

    Each dataset gets a GSLS row and a TGSLS row. Each
    ``'gw{gain_weight}, lw{loss_weight}'`` column holds the mean (std) of
    ``gsls_width`` / ``true-weight-gsls_width`` from ``gsls_df`` as whole
    percentages. The smaller mean is bold when the corrected resampled
    t-test between the two width samples is below ``t_test_alpha``. An empty
    row separates datasets.

    Returns:
        pd.DataFrame of LaTeX-formatted cell strings.
    """
    t_test_alpha = 0.05

    def format_cell(mean, std, is_bold=False):
        # Render "mean & (std)" as whole percentages, optionally in bold.
        mean, std = mean * 100, std * 100

        mean_str = f'{mean:.0f}'
        if std == 0.0:
            std_str = '(0)'
        elif std < 1:
            # A small non-zero spread would otherwise round to a misleading "(0)".
            std_str = r'(\textless1)'
        else:
            std_str = f'({std:.0f})'

        if is_bold:
            mean_str = r'\textbf{' + mean_str + '}'
            std_str = r'\textbf{' + std_str + '}'

        return f'{mean_str} & {std_str}'

    rows = []
    for dataset_name, dataset_label in dataset_labels.items():
        est_row = {'dataset': dataset_label, 'method': 'GSLS'}
        tru_row = {'dataset': dataset_label, 'method': 'TGSLS'}
        dataset = datasets[dataset_name]
        # Test fraction of the data, passed to the corrected resampled t-test.
        test_size = dataset.test_n / (dataset.test_n + dataset.train_n)
        # Select the dataset once instead of once per weight combination.
        dataset_df = gsls_df[gsls_df['dataset_name'] == dataset_name]
        for gain_weight in gain_weights:
            for loss_weight in loss_weights:
                cell_df = dataset_df[
                    (dataset_df['gain_weight'] == gain_weight) &
                    (dataset_df['loss_weight'] == loss_weight)
                ]
                est_widths = cell_df['gsls_width']
                tru_widths = cell_df['true-weight-gsls_width']
                est_mean, tru_mean = est_widths.mean(), tru_widths.mean()
                significant = corrected_resampled_t_test(est_widths.to_numpy(),
                                                         tru_widths.to_numpy(),
                                                         test_size=test_size) < t_test_alpha

                cell_key = f'gw{gain_weight}, lw{loss_weight}'
                est_row[cell_key] = format_cell(est_mean, est_widths.std(),
                                                is_bold=(significant and (est_mean < tru_mean)))
                tru_row[cell_key] = format_cell(tru_mean, tru_widths.std(),
                                                is_bold=(significant and (tru_mean < est_mean)))
        # Empty row -> rendered as \hline between datasets.
        rows += [est_row, tru_row, {}]
    return pd.DataFrame(rows)

table_df = width_table()
# Show the table as rich notebook output, then print its LaTeX body rows.
for render in (display, print_table_latex):
    render(table_df)

In [None]:
def error_table():
    """Build the absolute-error table for each dataset and method.

    Columns report the mean (std) absolute error as whole percentages under:
    no shift (gain_weight + loss_weight == 0 in ``gsls_df``), GSLS shift
    (gain_weight + loss_weight > 0), and prior shift (``prior_df``). A
    non-CC cell is bold when its mean error is below the CC ('count')
    baseline and the corrected resampled t-test against CC is below
    ``t_test_alpha``. An empty row separates datasets.

    Returns:
        pd.DataFrame of LaTeX-formatted cell strings.
    """
    t_test_alpha = 0.05
    plot_methods = {
        'count': 'CC',
        'pcc': 'PCC',
        'em': 'EM',
        'gsls': 'GSLS',
    }

    def format_cell(mean, std, is_bold=False):
        # Render "mean & (std)" as whole percentages, optionally in bold.
        mean, std = mean * 100, std * 100

        mean_str = f'{mean:.0f}'
        if std == 0.0:
            std_str = '(0)'
        elif std < 1.0:
            # A small non-zero spread would otherwise round to a misleading "(0)".
            std_str = r'(\textless1)'
        else:
            std_str = f'({std:.0f})'

        if is_bold:
            mean_str = r'\textbf{' + mean_str + '}'
            std_str = r'\textbf{' + std_str + '}'

        return f'{mean_str} & {std_str}'

    def error_cell(df, method, test_size):
        # Format `method`'s error over `df`, bold when it significantly beats CC.
        errors = df[f'{method}_absolute_error']
        baseline_errors = df['count_absolute_error']
        if method == 'count':
            # CC is the baseline, so it is never tested against itself.
            significant = False
        else:
            significant = corrected_resampled_t_test(baseline_errors.to_numpy(),
                                                     errors.to_numpy(),
                                                     test_size=test_size) < t_test_alpha
        return format_cell(
            errors.mean(),
            errors.std(),
            (significant and (errors.mean() < baseline_errors.mean())),
        )

    rows = []
    for dataset_name, dataset_label in dataset_labels.items():
        dataset = datasets[dataset_name]
        # Test fraction of the data, passed to the corrected resampled t-test.
        test_size = dataset.test_n / (dataset.test_n + dataset.train_n)

        # These selections depend only on the dataset, so compute them once
        # per dataset rather than once per method.
        dataset_gsls_df = gsls_df[gsls_df['dataset_name'] == dataset_name]
        total_weight = dataset_gsls_df['gain_weight'] + dataset_gsls_df['loss_weight']
        no_shift_df = dataset_gsls_df[total_weight == 0.0]
        gsls_shift_df = dataset_gsls_df[total_weight > 0.0]
        prior_shift_df = prior_df[prior_df['dataset_name'] == dataset_name]

        for method, method_label in plot_methods.items():
            row = {
                'dataset': dataset_label,
                'method': method_label,
            }
            row['no_shift'] = error_cell(no_shift_df, method, test_size)
            row['gsls_shift'] = error_cell(gsls_shift_df, method, test_size)
            row['prior_shift'] = error_cell(prior_shift_df, method, test_size)
            rows.append(row)
        # Empty row -> rendered as \hline between datasets.
        rows.append({})
    return pd.DataFrame(rows)
    
table_df = error_table()
# Show the table as rich notebook output, then print its LaTeX body rows.
for render in (display, print_table_latex):
    render(table_df)

In [None]:
# Insect-sex dataset variants in the bin-count sensitivity (runtime) table,
# mapped to their table labels (order = table order).
sensitivity_dataset_labels = dict(zip(
    ['insect-sex', 'insect-sex_smaller', 'insect-sex_smallest'],
    ['ISX-500', 'ISX-250', 'ISX-50'],
))
# Number of classes covered by one `gsls_all_class_time_ns` measurement;
# assumed to be 2 for every insect-sex variant -- confirm.
class_count = 2
# Bin-count settings selected from the bins experiment results.
sensitivity_bin_counts = [5, 'auto', 50]

def runtime_table():
    """Build the per-class GSLS runtime table for the bin-count sensitivity study.

    Produces one row per entry of ``sensitivity_dataset_labels``. Each
    ``'{bin_count} bins'`` column holds ``"mean & (std)"`` (one decimal) of
    ``gsls_all_class_time_ns`` from ``bins_df``, converted to milliseconds
    and divided by ``class_count``.

    Returns:
        pd.DataFrame of LaTeX-formatted cell strings.
    """
    ns_per_ms = 1_000_000

    def runtime_row(dataset_name, dataset_label):
        # One table row: per-class runtime statistics for every bin count.
        row = {'dataset': dataset_label}
        for bin_count in sensitivity_bin_counts:
            selected = (
                (bins_df['dataset_name'] == dataset_name) &
                (bins_df['bin_count'] == bin_count)
            )
            # Nanoseconds -> milliseconds, then per class computed in that time.
            per_class_ms = bins_df.loc[selected, 'gsls_all_class_time_ns'] / ns_per_ms / class_count
            row[f'{bin_count} bins'] = f'{per_class_ms.mean():.1f} & ({per_class_ms.std():.1f})'
        return row

    return pd.DataFrame([
        runtime_row(dataset_name, dataset_label)
        for dataset_name, dataset_label in sensitivity_dataset_labels.items()
    ])
    
table_df = runtime_table()
# Show the table as rich notebook output, then print its LaTeX body rows.
for render in (display, print_table_latex):
    render(table_df)