# Mean absolute error as an inter-annotator agreement (IAA) measure

In [1]:
import pandas as pd
from sklearn.metrics import mean_absolute_error
from itertools import permutations, combinations

import sys
sys.path.insert(0, '..')
from utils.config import PATHS
from utils.data_process import drop_disregard, fix_week_14, pad_sen_id

## Load and pre-process data

In [2]:
# load annotations: concatenate every parsed IAA annotation pickle into one frame.
# Files are read in sorted order because Path.glob yields them in an arbitrary,
# filesystem-dependent order; sorting keeps the row order of `df` (and anything
# derived from it, such as `df.batch.unique()`) reproducible across machines.

path = PATHS.getpath('data') / 'iaa'

annotation_files = sorted(path.glob('*_parsed.pkl'))
if not annotation_files:
    # pd.concat([]) would only say "No objects to concatenate"; point at the cause.
    raise FileNotFoundError(f"no '*_parsed.pkl' files found in {path}")
df = pd.concat([pd.read_pickle(fp) for fp in annotation_files], ignore_index=True)

In [3]:
# load batch info: one pickle per batch present in `df`, each tagged with a
# `batch` column naming its batch.
# The frames are collected in a list and concatenated once: DataFrame.append
# was removed in pandas 2.0 and is quadratic when called in a loop.

path = PATHS.getpath('data_to_inception_conll')

batch_frames = [
    pd.read_pickle(path / f'{batch}.pkl').assign(batch=batch)
    for batch in df.batch.unique()
]
# pd.concat([]) raises; keep the original empty-frame result when there are no batches.
info = pd.concat(batch_frames, ignore_index=True) if batch_frames else pd.DataFrame()

In [4]:
# Restrict the annotations to the notes whose sampling method is 'kwd_iaa'.

iaa = info.loc[info['samp_meth'] == 'kwd_iaa', 'NotitieID'].unique()
df = df[df['NotitieID'].isin(iaa)]

In [5]:
# Process annotations.
# Domain codes, the matching `<domain>_lvl` level columns, and the remaining
# per-row annotation flags.

domains = ['ADM', 'ATT', 'BER', 'ENR', 'ETN', 'FAC', 'INS', 'MBW', 'STM']
levels = [domain + '_lvl' for domain in domains]
other = ['target', 'background', 'plus']

# Broadcast "any row in the group is flagged" onto every row of the group:
# per sentence for background/target, per note for disregard. Also add a
# padded sentence id, then run the project's `fix_week_14` step.
df = (
    df
    .assign(
        background_sent=lambda d: d.groupby('sen_id').background.transform('any'),
        target_sent=lambda d: d.groupby('sen_id').target.transform('any'),
        disregard_note=lambda d: d.groupby('NotitieID').disregard.transform('any'),
        pad_sen_id=df.sen_id.apply(pad_sen_id),
    )
    .pipe(fix_week_14)
)

# Treat missing domain / other flags as False.
df[domains + other] = df[domains + other].fillna(False)

In [6]:
# Replace annotator names with a one-letter alias (their first letter).
# Two annotators sharing an initial would be silently merged by every
# groupby('annotator') below, so check that the aliases stay one-to-one.
# (Re-running the cell is safe: a single letter maps to itself.)

aliases = df['annotator'].str[0]
assert aliases.nunique() == df['annotator'].nunique(), "first-letter annotator aliases collide"
df['annotator'] = aliases

## Create sentence-level labels

In [7]:
# Create a sentence-level label for each domain: per annotator, the mean of
# every `*_lvl` column over the rows of a sentence. After unstacking the
# annotators and stacking the level columns, the frame is indexed by
# (pad_sen_id, level) with one column per annotator.
# No note is dropped here even though some annotators skipped some batches:
# a sentence an annotator did not label is NaN in that annotator's column,
# and the pairwise loop below drops such rows per annotator pair (dropna).

sent_labels = (
    df
    .groupby(['annotator', 'pad_sen_id'])[levels]
    .mean()
    .unstack('annotator')
    .stack(0)
)

## Pairwise metrics

In [8]:
# Unordered annotator pairs (MAE is symmetric, so these drive the pairwise
# loop) and ordered pairs (`pairs`, not used in the cells shown here).
combis = list(combinations(sent_labels.columns, 2))
pairs = list(permutations(sent_labels.columns, 2))

In [9]:
# Pairwise MAE for every unordered annotator pair and every `*_lvl` level,
# computed on the sentences both annotators labelled (rows where either is NaN
# are dropped). `support` is the number of such sentences.
cols = [
    'annotator1',
    'annotator2',
    'level',
    'support',
    'mae',
]

# One cross-section per level, computed once rather than once per pair.
# Levels that never occur in `sent_labels` are skipped instead of making
# `xs` raise KeyError. Dict order follows `levels`, so the row order matches
# the original nested loop.
present_levels = set(sent_labels.index.get_level_values(1))
level_frames = {
    level: sent_labels.xs(level, level=1)
    for level in levels
    if level in present_levels
}

# Collect plain records and build the frame once: DataFrame.append was removed
# in pandas 2.0, is quadratic in a loop, and left the numeric columns as object.
records = []
for annotator1, annotator2 in combis:
    for level, level_frame in level_frames.items():
        ys = level_frame[[annotator1, annotator2]].dropna()
        if ys.empty:
            continue
        records.append({
            'annotator1': annotator1,
            'annotator2': annotator2,
            'level': level,
            'support': len(ys),
            'mae': mean_absolute_error(ys[annotator1], ys[annotator2]),
        })
classreport = pd.DataFrame(records, columns=cols)

## Average mean absolute error (MAE) per domain

In [10]:
def _mean_2dp(values):
    """Return the mean of `values`, rounded to two decimals."""
    return round(values.mean(), 2)


def _support_range(values):
    """Return `[min, max]` of `values` as a two-element list."""
    return [values.min(), values.max()]


# Per level: average pairwise MAE and the range of pairwise support,
# transposed so levels become columns.
(
    classreport
    .groupby('level')
    .agg({'mae': _mean_2dp, 'support': _support_range})
    .T
)

level,ADM_lvl,ATT_lvl,BER_lvl,ENR_lvl,ETN_lvl,FAC_lvl,INS_lvl,MBW_lvl,STM_lvl
mae,0.25,0.32,0.38,0.39,0.28,0.17,0.3,0.32,0.31
support,"[15, 26]","[1, 3]","[1, 3]","[7, 11]","[4, 16]","[7, 14]","[1, 10]","[1, 4]","[6, 19]"


## Pairwise mean absolute error (MAE) per domain

In [11]:
# Pairwise MAE table: one row per annotator pair, one column per level,
# followed by a 'mean' row (unweighted mean over the pairs; NaN skipped).
table = (
    classreport
    .set_index(['annotator1', 'annotator2', 'level'])
    .unstack(-1)
    .xs('mae', axis=1, level=0)
    .reset_index()
    .rename_axis('', axis=1)
)
# Only the level columns that actually occur (a level with no support has no column).
level_cols = [lvl for lvl in levels if lvl in table.columns]

# Put the summary row at the next free label. A hard-coded 15 only fits exactly
# 6 annotators (15 pairs) and would overwrite a pair row with more annotators.
mean_row = len(table)
table.loc[mean_row, level_cols] = table[level_cols].mean()
table.loc[mean_row, 'annotator1'] = 'mean'
table.loc[mean_row, 'annotator2'] = ''

# Two decimals ("{:.2f}") to match the rounding of the per-domain summary;
# "{:.2}" meant two *significant* digits (e.g. 0.091).
(
    table.style
    .background_gradient(cmap='Greens', axis=None)
    .format({lvl: "{:.2f}" for lvl in level_cols}, na_rep='N/A')
    .apply(
        lambda s: len(s) * ["background: black; color: white"],
        subset=pd.IndexSlice[[mean_row], :],
    )
)

Unnamed: 0,annotator1,annotator2,ADM_lvl,ATT_lvl,BER_lvl,ENR_lvl,ETN_lvl,FAC_lvl,INS_lvl,MBW_lvl,STM_lvl
0,a,k,0.23,,0.0,0.2,0.27,0.0,0.0,0.5,0.22
1,a,m,0.19,0.0,0.5,0.44,0.56,0.27,0.3,0.0,0.44
2,a,o,0.18,0.0,1.0,0.43,0.33,0.33,0.0,0.0,0.39
3,a,s,0.2,0.33,0.5,0.44,0.5,0.18,0.4,0.5,0.49
4,a,v,0.32,0.5,0.0,0.18,0.25,0.17,0.6,0.33,0.29
5,k,m,0.34,1.0,1.0,0.44,0.2,0.0,0.0,0.5,0.26
6,k,o,0.37,,1.0,0.29,0.25,0.0,0.0,0.5,0.048
7,k,s,0.33,1.0,1.0,0.25,0.33,0.1,0.0,0.67,0.35
8,k,v,0.26,1.0,,0.18,0.2,0.22,0.6,0.75,0.33
9,m,o,0.091,0.0,0.0,0.6,0.33,0.1,0.0,0.0,0.12
