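# Multi-class Experiments

Compares ruleset generators (Wɪᴛᴀɴ, Wɪᴛᴀɴ-Core, Snuba, HDC, CBI, semi-supervised and active-learning baselines, and full supervision) on four multi-class datasets. It reports test macro-F1 for several interaction counts and the standard deviation of that score across random seeds.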

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd

from witan_experiments import (is_cached,
                               save_to_cache,
                               load_from_cache,
                               run_experiments)
from witan_experiments.evaluation import (summarise_experiments,
                                          metric_line_grid,
                                          build_metric_df,
                                          display_metric_table,
                                          median_stds_df)
from witan_experiments.config import prepare_experiment_configs
from witan_experiments.rule_seeders import BlankRS, ClassSubsetAccRS
from witan_experiments.rule_generators import (TrueRG,
                                               WitanRG,
                                               SnubaRG,
                                               SemiSupervisedRG,
                                               ActiveLearningRG,
                                               CbiRG,
                                               HdcRG)
from witan_experiments.labellers import SnorkelLblr
from witan_experiments.models import AnnClf

In [None]:
# Ruleset generation methods under comparison, keyed by display label.
# Insertion order matters: the plotting cell derives its 'ruleset_generator'
# category order from these keys.
ruleset_generators = dict([
    ('Full supervision', TrueRG()),
    ('Wɪᴛᴀɴ', WitanRG()),
    ('Wɪᴛᴀɴ-Core', WitanRG(a=False, o=1)),
    ('Snuba', SnubaRG()),
    ('HDC', HdcRG()),
    ('CBI', CbiRG(clf=AnnClf())),
    ('Semi-supervised', SemiSupervisedRG()),
    ('Active learning', ActiveLearningRG(clf=AnnClf(), init_count=0)),
])

# Settings shared by every experiment group below. Each value is a list of
# options; presumably prepare_experiment_configs expands them into a grid of
# individual experiments — TODO confirm.
base_config = {
    'rule_seeder': [BlankRS()],
    'rngseed': [1, 2, 3, 4, 5],
    'interaction_count': [10, 25, 50, 100, 150, 200],
    'labeller': [SnorkelLblr()],
    'classifier': [AnnClf()],
}

def _config_variant(**overrides):
    """Return a new dict holding ``base_config`` with ``overrides`` applied on top."""
    return {**base_config, **overrides}


parallel_configs = [
    # Unseeded ruleset_generators that are not affected by rngseed, so they
    # only need to be executed for the first rngseed.
    _config_variant(
        rngseed=base_config['rngseed'][:1],
        ruleset_generator=[
            ruleset_generators[label]
            for label in ('Full supervision', 'Wɪᴛᴀɴ', 'Wɪᴛᴀɴ-Core', 'HDC')
        ],
    ),
    # Unseeded ruleset_generators, executed for every rngseed.
    _config_variant(
        ruleset_generator=[
            ruleset_generators[label]
            for label in ('Snuba', 'Semi-supervised', 'Active learning')
        ],
    ),
    # Seeded ruleset_generators: the shared BlankRS seeder is swapped out.
    _config_variant(
        rule_seeder=[ClassSubsetAccRS(c=2)],
        ruleset_generator=[ruleset_generators['CBI']],
    ),
]

# Dataset names to run every experiment group on.
datasets = [
    'twentynews',
    'dbpedia',
    'agnews',
    'nyttopics',
]


def _experiment_configs_for(dataset):
    """Expand every entry of ``parallel_configs`` into experiment configs for ``dataset``."""
    expanded = []
    for parallel_config in parallel_configs:
        expanded.extend(
            prepare_experiment_configs(**parallel_config, dataset_name=[dataset])
        )
    return expanded


# Maps each dataset name to {'parallel_configs': [experiment configs...]}.
dataset_configs = {
    dataset: {'parallel_configs': _experiment_configs_for(dataset)}
    for dataset in datasets
}

In [None]:
# Running all experiments is expensive, so the summarised results are stored
# under CACHE_KEY and only recomputed when no cached entry exists.
# NOTE(review): the key does not encode the configs above, so editing them
# will not invalidate an existing cache entry — clear it manually.
CACHE_KEY = 'multiclass-experiments'
# Forwarded to run_experiments(continue_on_failure=...).
CONTINUE_ON_FAILURE = False
# Worker counts forwarded to run_experiments / summarise_experiments.
EXPERIMENT_WORKERS = 3
RULE_WORKERS = 4
SUMMARY_WORKERS = 8

if not is_cached(CACHE_KEY):
    summary_frames = []
    for dataset, dataset_config in dataset_configs.items():
        print(f'\nRunning experiments for: {dataset}')
        dataset_results = run_experiments(
            dataset_config['parallel_configs'],
            default_workers=EXPERIMENT_WORKERS,
            rule_workers=RULE_WORKERS,
            continue_on_failure=CONTINUE_ON_FAILURE,
        )
        summary_frames.append(
            summarise_experiments(dataset_results, workers=SUMMARY_WORKERS)
        )
    save_to_cache(CACHE_KEY, pd.concat(summary_frames))

# df always comes from the cache, whether it was just computed or saved earlier.
df = load_from_cache(CACHE_KEY)

## Macro-F1 Score Plots

In [None]:
# Line grid of test macro-F1, one facet column per dataset.
# The suffix is appended to each legend label by metric_line_grid, so the
# 'ruleset_generator' category order must use the same suffixed labels.
legend_label_suffix = '  '
fig = metric_line_grid(
    df,
    metric='test_macro_f1',
    facet_col='dataset_name',
    ruleset_generators=ruleset_generators,
    legend_y=1.08,
    facet_row_spacing=0.15,
    facet_col_spacing=0.1,
    legend_label_suffix=legend_label_suffix,
    category_orders={
        'dataset_name': ['TWN', 'DBP', 'AGN', 'NYT'],
        'ruleset_generator': [rg + legend_label_suffix for rg in ruleset_generators.keys()],
    },
)
# Writing the image fails when the output directory is missing (e.g. on a
# fresh checkout), so create it first; exist_ok keeps re-runs idempotent.
Path('plots').mkdir(parents=True, exist_ok=True)
fig.write_image('plots/multi-f1-lines.svg')
fig.show()

### Standard Deviations Across Random Seeds

In [None]:
# Test macro-F1 per ruleset_generator, aggregated over rngseeds with 'std'.
f1_std_df = build_metric_df(
    df,
    method='ruleset_generator',
    metric='test_macro_f1',
    rngseed_agg='std',
    labelled_methods=ruleset_generators,
)
f1_std_table = display_metric_table(f1_std_df)
display(f1_std_table)

#### Median Standard Deviations

In [None]:
# Summarise the per-seed standard deviations with median_stds_df over the
# configured datasets and interaction counts.
interaction_counts = base_config['interaction_count']
display(median_stds_df(f1_std_df, datasets=datasets, ics=interaction_counts))