# Baselines

In [25]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
from baselines import get_baseline
from baselines import multi_class_roc_auc, multi_class_spearman
from baselines import empirical_dist
from pprint import pprint
import pandas as pd
import numpy as np

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Read the tab-separated baseline annotations, then index the rows by
# `rev_id` while keeping `rev_id` available as an ordinary column.
d = pd.read_csv('../../data/annotations/baseline/annotations.tsv', sep='\t')
d = d.set_index('rev_id', drop=False)

In [20]:
# Keep only rows in the 'user' namespace, split by the value of `sample`.
# `sample` is accessed with brackets because `d.sample` is a DataFrame method.
d_b = d[(d['ns'] == 'user') & (d['sample'] == 'blocked')]
d_r = d[(d['ns'] == 'user') & (d['sample'] == 'random')]

In [27]:
# --- Experiment configuration -------------------------------------------

# Columns of the annotation frames to evaluate (each is looked up as d[task]).
tasks = ['attack', 'recipient', 'aggression']

# Datasets to evaluate, keyed by the label shown in the printed report.
ds = {
    'blocked': d_b,
    'random': d_r,
}

# Scoring functions, keyed by the label shown in the printed report.
metrics = {
    'ROC': multi_class_roc_auc,
    'spearman': multi_class_spearman,
}

# Forwarded unchanged to get_baseline; also determines the pair sizes (K/2).
# NOTE(review): presumably the number of annotations per comment -- confirm.
K = 20

# Number of repetitions of the baseline computation per configuration.
iters = 1

In [28]:
def get_all_baselines(K, tasks, metrics, ds, iters, pairs):
    """Run the baseline for every task/dataset/metric and print a summary.

    For each combination, ``get_baseline`` is called ``iters`` times with the
    dataset's ``task`` column, ``K``, ``empirical_dist``, the metric function
    and ``pairs``. The per-pair scores it returns are collected, and the mean
    and (population) standard deviation across repetitions are printed.

    Parameters
    ----------
    K : passed through unchanged to ``get_baseline``.
    tasks : iterable of keys used to select a column from each dataset.
    metrics : dict mapping a display name to a metric function.
    ds : dict mapping a display name to a dataset indexable by task.
    iters : int, number of times ``get_baseline`` is run per combination.
    pairs : sequence of 2-item integer pairs; must cover every key that
        ``get_baseline`` returns.

    Returns
    -------
    None. The report is written to stdout only.
    """
    scores = {}

    # Phase 1: gather one score per pair for each repetition.
    for task in tasks:
        scores[task] = {}
        for data_name, frame in ds.items():
            scores[task][data_name] = {}
            for metric_name, metric_fn in metrics.items():
                per_pair = {pair: [] for pair in pairs}
                scores[task][data_name][metric_name] = per_pair
                labels = frame[task]
                for _ in range(iters):
                    run = get_baseline(labels, K, empirical_dist, metric_fn, pairs)
                    for pair, value in run.items():
                        per_pair[pair].append(value)

    # Phase 2: print "mean (std)" per pair, nested by task / dataset / metric.
    for task in tasks:
        print('Task: ', task)
        for data_name in ds:
            print('\tData: ', data_name)
            for metric_name in metrics:
                print('\t\tMetric: ', metric_name)
                for pair in pairs:
                    values = np.array(scores[task][data_name][metric_name][pair])
                    summary = (pair[0], pair[1], values.mean(), values.std())
                    print('\t\t\t(%d, %d): %0.3f (%0.3f)' % summary)

In [29]:
# Build the list of integer pairs handed to get_baseline; F is half of K.
F = int(K / 2)
# Two families: the "diagonal" pairs (1, 1) ... (F, F), followed by
# (1, F) ... (F-1, F). The second family stops at F-1 because (F, F) is
# already the last diagonal pair; listing it twice only repeated the same
# row in the printed report.
pairs = [(n, n) for n in range(1, F + 1)] + [(n, F) for n in range(1, F)]

In [30]:
# Run every task/dataset/metric combination and print the report.
# The repetition count comes from the `iters` configuration constant rather
# than a hard-coded literal, so changing the config actually takes effect.
get_all_baselines(K, tasks, metrics, ds, iters, pairs)

Task:  attack
	Data:  random
		Metric:  spearman
			(1, 1): 0.149 (0.000)
			(2, 2): 0.183 (0.000)
			(3, 3): 0.228 (0.000)
			(4, 4): 0.250 (0.000)
			(5, 5): 0.270 (0.000)
			(6, 6): 0.289 (0.000)
			(7, 7): 0.284 (0.000)
			(8, 8): 0.287 (0.000)
			(9, 9): 0.300 (0.000)
			(10, 10): 0.304 (0.000)
			(1, 10): 0.197 (0.000)
			(2, 10): 0.226 (0.000)
			(3, 10): 0.239 (0.000)
			(4, 10): 0.252 (0.000)
			(5, 10): 0.282 (0.000)
			(6, 10): 0.295 (0.000)
			(7, 10): 0.288 (0.000)
			(8, 10): 0.295 (0.000)
			(9, 10): 0.299 (0.000)
			(10, 10): 0.304 (0.000)
		Metric:  ROC
			(1, 1): 0.573 (0.000)
			(2, 2): 0.788 (0.000)
			(3, 3): 0.762 (0.000)
			(4, 4): 0.958 (0.000)
			(5, 5): 0.891 (0.000)
			(6, 6): 0.953 (0.000)
			(7, 7): 0.950 (0.000)
			(8, 8): 0.985 (0.000)
			(9, 9): 0.966 (0.000)
			(10, 10): 0.994 (0.000)
			(1, 10): 0.690 (0.000)
			(2, 10): 0.936 (0.000)
			(3, 10): 0.857 (0.000)
			(4, 10): 0.951 (0.000)
			(5, 10): 0.937 (0.000)
			(6, 10): 0.973 (0.000)
			(7, 10): 0.9