# Multi-class Experiments

In [None]:
import numpy as np
import pandas as pd

from witan_experiments import (is_cached,
                               save_to_cache,
                               load_from_cache,
                               run_experiments)
from witan_experiments.evaluation import (summarise_experiments,
                                          metric_line_grid)
from witan_experiments.config import prepare_experiment_configs
from witan_experiments.rule_seeders import BlankRS
from witan_experiments.rule_generators import (TrueRG,
                                               WitanRG,
                                               SnubaRG,
                                               SemiSupervisedRG,
                                               ActiveLearningRG)
from witan_experiments.labellers import SnorkelLblr
from witan_experiments.models import AnnClf

In [None]:
# Rule-set generators under comparison, keyed by the label used for display.
# Built from ordered (label, generator) pairs; the resulting dict keeps this
# order, which is also the order handed to the experiment grid below.
ruleset_generators = dict([
    ('Full supervision', TrueRG()),
    ('Wɪᴛᴀɴ', WitanRG()),
    ('Wɪᴛᴀɴ-Core', WitanRG(a=False, o=1)),
    ('Snuba', SnubaRG()),
    ('Semi-supervised', SemiSupervisedRG()),
    ('Active learning', ActiveLearningRG(clf=AnnClf(), init_count=0)),
])

# Settings shared by every dataset. Each value is a list of alternatives;
# presumably `prepare_experiment_configs` expands their combinations into
# individual experiment configs (confirm against its implementation).
base_config = {
    'rule_seeder': [BlankRS()],
    'rngseed': [1],
    # Every generator defined above, in declaration order.
    'ruleset_generator': [*ruleset_generators.values()],
    'interaction_count': [10, 25, 50, 100, 150, 200],
    'labeller': [SnorkelLblr()],
    'classifier': [AnnClf()],
}

# Names of the datasets included in this notebook.
datasets = ['twentynews', 'dbpedia', 'agnews', 'nyttopics']

# For each dataset name, the materialised list of experiment configs built
# from the shared base settings plus that single dataset.
dataset_configs = {
    name: list(
        prepare_experiment_configs(**base_config, dataset_name=[name])
    )
    for name in datasets
}

In [None]:
CACHE_KEY = 'multiclass-experiments'

# Run the experiments only when nothing is stored under CACHE_KEY yet.
if not is_cached(CACHE_KEY):
    # Options applied to every `run_experiments` call.
    run_options = {
        'default_workers': 3,
        'rule_workers': 1,
        'continue_on_failure': False,
    }

    results = {}
    for dataset, configs in dataset_configs.items():
        print(f'\nExperiments for: {dataset}')
        dataset_results = run_experiments(configs, **run_options)
        # Merge this dataset's results into the combined mapping.
        results.update(dataset_results)

    # Summarise all runs and store the summary under CACHE_KEY.
    df = summarise_experiments(results)
    save_to_cache(CACHE_KEY, df)

# Always read the summary back from the cache, whether or not the
# experiments were run in this session.
df = load_from_cache(CACHE_KEY)

In [None]:
# Rows for the 25- and 100-interaction settings only ...
has_table_count = df['interaction_count'].isin([25, 100])
# ... excluding rows whose generator is the 'Full supervision' entry.
is_full_supervision = df['ruleset_generator'].isin(
    [ruleset_generators['Full supervision']]
)
table_df = df[has_table_count & ~is_full_supervision]

## F1 Score Plots

In [None]:
from pathlib import Path

# Line plots of the 'test_macro_f1' metric, faceted by 'dataset_name'.
fig = metric_line_grid(
    df,
    metric='test_macro_f1',
    facet_col='dataset_name',
    ruleset_generators=ruleset_generators,
    legend_y=1.08,
    facet_row_spacing=0.15,
    facet_col_spacing=0.1,
    legend_label_suffix='  ',
)

# Make sure the output directory exists before exporting: writing the image
# fails on a fresh checkout when 'plots/' is missing (assuming `fig` is a
# Plotly figure, whose `write_image` does not create parent directories).
plot_path = 'plots/multi-f1-lines.svg'
Path(plot_path).parent.mkdir(parents=True, exist_ok=True)
fig.write_image(plot_path)
fig.show()