# Ablation Experiments

The experiments below compare ablated variants of our proposed Wɪᴛᴀɴ method, in order to justify the design decisions made in the algorithm.

In [None]:
import numpy as np
import pandas as pd

from witan_experiments import (is_cached,
                               save_to_cache,
                               load_from_cache,
                               run_experiments)
from witan_experiments.evaluation import (summarise_experiments,
                                          build_metric_df,
                                          display_metric_table)
from witan_experiments.config import prepare_experiment_configs
from witan_experiments.rule_seeders import BlankRS
from witan_experiments.rule_generators import TrueRG, WitanRG
from witan_experiments.labellers import SnorkelLblr
from witan_experiments.models import AnnClf

## Experiments

In [None]:
# Ruleset generators under comparison, keyed by the label shown in the
# result tables. Every entry except 'Full supervision' is a WitanRG variant
# built with different constructor options; 'Full supervision' uses TrueRG.
ruleset_generators = {
    'Wɪᴛᴀɴ': WitanRG(),
    'Core': WitanRG(a=False, o=1),
    'Without ANDs': WitanRG(a=False),
    'Without ORs': WitanRG(o=1),
    'Without GE': WitanRG(ge=1),
    'With feedback': WitanRG(f=True),
    'Full supervision': TrueRG(),
}

# Shared experiment settings. Each value is a list of options; presumably
# prepare_experiment_configs expands these into one config per combination
# (TODO confirm against witan_experiments.config).
base_config = {
    'rule_seeder': [BlankRS()],
    'rngseed': [1],
    'ruleset_generator': list(ruleset_generators.values()),
    'interaction_count': [25, 100],
    'labeller': [SnorkelLblr()],
    'classifier': [AnnClf()],
}

# Datasets to run the ablation on, in execution order.
datasets = [
    'imdb',
    'bias_pa',
    'bias_pt',
    'bias_jp',
    'bias_pp',
    'amazon',
    'yelp',
    'plots',
    'fakenews',
    'binary_dbpedia',
    'binary_agnews',
    'airline_tweets',
    'damage',
    'spam',
    'twentynews',
    'dbpedia',
    'agnews',
    'nyttopics',
]

# One materialised list of experiment configs per dataset, so each dataset
# can be run (and reported on) separately.
dataset_configs = {
    name: list(prepare_experiment_configs(**base_config, dataset_name=[name]))
    for name in datasets
}

In [None]:
# Experiment summaries are stored under this key. If an entry already
# exists, the experiments are not re-run.
CACHE_KEY = 'ablation-experiments'

if not is_cached(CACHE_KEY):
    summaries = []
    for dataset_name, experiment_configs in dataset_configs.items():
        print(f'\nRunning experiments for: {dataset_name}')
        results = run_experiments(
            experiment_configs,
            default_workers=2,
            rule_workers=4,
            continue_on_failure=False,  # stop on the first failing experiment
        )
        summaries.append(summarise_experiments(results, workers=8))
    df = pd.concat(summaries)
    save_to_cache(CACHE_KEY, df)

# Reload from the cache in both cases, so the analysis below always works on
# the cached copy of the results.
df = load_from_cache(CACHE_KEY)

In [None]:
# Restrict the results reported in the tables below to runs with 25 or 100
# interactions.
table_df = df.loc[df['interaction_count'].isin((25, 100))]

## F1 Score Results

In [None]:
# Build a table of test-set macro-F1 for each ruleset generator, labelled via
# `ruleset_generators`, with the full Wɪᴛᴀɴ method as the comparison baseline.
f1_df = build_metric_df(
    table_df,
    method='ruleset_generator',
    metric='test_macro_f1',
    labelled_methods=ruleset_generators,
)
table = display_metric_table(f1_df, baseline_label='Wɪᴛᴀɴ')

# Show the table inline, then emit a LaTeX version for the paper.
display(table)
latex_source = table.to_latex(multirow_align='t', convert_css=True)
print(latex_source)