In [1]:
import _base_path
import pickle
import json
import numpy as np
import pandas as pd
from resources.data_io import load_mappings
from sklearn.metrics import f1_score

Setting base bath to "c:\Users\Korbi\Desktop\CICLe"


In [2]:
# Dataset name; used to build the paths below ../data/.
DATA = 'incidents'
# Models whose result pickles are read from ../results/<model>/.
MODELS = ['roberta-base', 'xlm-roberta-base']
# Languages that get their own column in the LaTeX table.
LANGUAGES = ['en', 'de']
# Label column of the dataset CSV; also the key into support_zones.json.
LABEL = 'hazard-category'

# Load Class-Mappings:

In [3]:
# Load the class mapping for LABEL from the splits directory and display it.
splits_dir = f"../data/{DATA}/splits/"
class_map = load_mappings(splits_dir, LABEL)
class_map

array(['allergens', 'biological', 'chemical',
       'food additives and flavourings', 'food contact materials',
       'foreign bodies', 'fraud', 'migration', 'organoleptic aspects',
       'other hazard', 'packaging defect'], dtype=object)

In [4]:
# support_zones.json holds, per label, a pair of class-name lists:
# (high-support classes, low-support classes).
with open(f'../data/{DATA}/support_zones.json', 'r') as file:
    support_zones = json.load(file)

high_support, low_support = support_zones[LABEL]

In [5]:
# Display the high-support class names read from support_zones.json.
high_support

['biological']

In [6]:
# Display the low-support class names read from support_zones.json.
low_support

['foreign bodies',
 'chemical',
 'fraud',
 'other hazard',
 'packaging defect',
 'organoleptic aspects',
 'food additives and flavourings',
 'migration',
 'food contact materials']

In [7]:
# Number of samples per class in the full dataset.
counts = pd.read_csv(f'../data/{DATA}/{DATA}_final.csv')[LABEL].value_counts()

# Re-read the class names instead of reusing `class_map`: this cell rebinds
# `class_map` to a list of tuples, so deriving it from its previous value would
# silently produce wrong rows when the cell is executed a second time.
class_names = load_mappings(f"../data/{DATA}/splits/", LABEL)

# One (name, id, count) row per class. The id is the position of the class in
# the mapping; classes that never occur in the CSV get a count of 0.
class_map = list(zip(
    class_names,
    range(len(class_names)),
    [counts.get(c, 0) for c in class_names]
))

# Most frequent classes first.
class_map.sort(key=lambda row: row[2], reverse=True)
class_map

[('biological', 1, 2558),
 ('allergens', 0, 2527),
 ('foreign bodies', 5, 943),
 ('chemical', 2, 578),
 ('fraud', 6, 527),
 ('other hazard', 9, 187),
 ('packaging defect', 10, 100),
 ('organoleptic aspects', 8, 81),
 ('food additives and flavourings', 3, 32),
 ('migration', 7, 14),
 ('food contact materials', 4, 1)]

# Load Results:

In [8]:
# Number of splits (runs) expected per model.
N_SPLITS = 5

# results[model][language] -> list with one {'labels', 'predictions'} dict per split.
results = {}

for m in MODELS:
    # Collect every split of a model before committing it, so that a model with
    # a missing file is skipped as a whole instead of leaving partial results.
    per_language = {}

    try:
        for i in range(N_SPLITS):
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(f'../results/{m}/{m}-{LABEL}-{i:d}.pickle', 'rb') as f:
                r = pickle.load(f)

            with open(f'../data/{DATA}/splits/split_{LABEL.split("-")[0]}_{i:d}.pickle', 'rb') as f:
                test_languages = pickle.load(f)['test']['language'].values

            # Split labels and predictions of this run by the language of each test sample.
            for ln in np.unique(test_languages):
                is_language = test_languages == ln

                per_language.setdefault(ln, []).append({
                    'labels':      r['labels'][is_language],
                    'predictions': r['predictions'][is_language]
                })

    except FileNotFoundError as error:
        # Best effort: models without a complete set of files are left out, but not silently.
        print(f'Skipping {m}: {error}')
        continue

    if per_language:
        results[m] = per_language

In [9]:
def calculate_metrics(classes=None):
    """Compute the macro-F1 of every run, per model and language.

    Args:
        classes: Iterable of class names to evaluate on. Only samples whose
            true label belongs to one of these classes are scored. ``None``
            (the default) selects every class in ``class_map``.

    Returns:
        dict: ``metrics[model][language]`` is a float array with one macro-F1
        score per run stored in ``results[model][language]``.

    Raises:
        KeyError: If a model in ``results`` has no entry for one of ``LANGUAGES``.
    """
    # Resolve the default at call time so that it follows the current `class_map`.
    if classes is None:
        class_ids = [i for _, i, _ in class_map]
    else:
        class_ids = [i for c, i, _ in class_map if c in classes]

    metrics = {}

    for model, per_language in results.items():
        metrics[model] = {}

        for language in LANGUAGES:
            scores = []

            for r in per_language[language]:
                # Keep only the samples whose true label is one of the selected classes.
                mask = np.isin(r['labels'], class_ids)
                y_true = r['labels'][mask]
                y_pred = r['predictions'][mask]

                # NOTE(review): `labels` is not passed, so the macro average runs over every
                # label occurring in y_true or y_pred. Predictions outside the selected classes
                # therefore count as extra classes with F1 = 0 - confirm this is intended.
                scores.append(f1_score(y_true, y_pred, average='macro', zero_division=0.0))

            # Sized by the number of runs actually available.
            metrics[model][language] = np.array(scores, dtype=float)

    return metrics

In [10]:
# Macro-F1 on all classes, on the high-support classes and on the low-support classes.
metrics_all, metrics_high_support, metrics_low_support = (
    calculate_metrics(),
    calculate_metrics(high_support),
    calculate_metrics(low_support),
)

In [11]:
def metric2latex(metrics_dict):
    """Format per-model, per-language scores as LaTeX table cells.

    Args:
        metrics_dict: Mapping ``model -> language -> array of scores``; every
            array must have the same length.

    Returns:
        Array of strings with one row per model (in the iteration order of
        ``metrics_dict``) and one column per entry of ``LANGUAGES``. Each cell
        reads "mean +- mean absolute deviation"; cells whose rounded mean ties
        the best rounded mean of their column are bold on a coloured background.
    """
    scores = np.array(
        [[metrics_dict[model][language] for language in LANGUAGES] for model in metrics_dict],
        dtype=float
    )

    avg = scores.mean(axis=-1)
    # Mean absolute deviation of the individual scores from their average.
    err = np.abs(scores - avg[..., np.newaxis]).mean(axis=-1)

    # Compare after rounding to two decimals so that ties in the printed values are all marked.
    best = np.round(avg, 2) == np.round(avg.max(axis=0), 2)

    def format_cell(a, e, b):
        # \u007B and \u007D are the escapes for literal '{' and '}' inside the f-string.
        if b:
            return f'\\cellcolor\u007Bblue!15\u007D\\footnotesize $\\bf {a:.2f}$ \\tiny $\\bf\\pm {e:.2f}$'
        return f'\\footnotesize ${a:.2f}$ \\tiny $\\pm {e:.2f}$'

    return np.vectorize(format_cell)(avg, err, best)

In [12]:
ltx_all = metric2latex(metrics_all)
ltx_hs  = metric2latex(metrics_high_support)
ltx_ls  = metric2latex(metrics_low_support)

# The rows returned by metric2latex follow the order of the models in the metrics
# dict, which differs from MODELS when a model has no results. Look the rows up by
# model name instead of by position in MODELS.
sections = [
    dict(zip(metrics_all,          ltx_all)),
    dict(zip(metrics_high_support, ltx_hs)),
    dict(zip(metrics_low_support,  ltx_ls)),
]

# Placeholder that keeps the column count intact for models without results.
empty_cells = ' &'*(len(LANGUAGES)-1)

for model in MODELS:
    cells = [
        ' & '.join(section[model]) if model in section else empty_cells
        for section in sections
    ]

    # One line per section; the last one closes the LaTeX table row.
    row = f'{model.upper()} &\n' + ' &\n'.join(cells) + ' \\\\\n'

    print(row)

ROBERTA-BASE &
\cellcolor{blue!15}\footnotesize $\bf 0.64$ \tiny $\bf\pm 0.06$ & \cellcolor{blue!15}\footnotesize $\bf 0.22$ \tiny $\bf\pm 0.03$ &
\cellcolor{blue!15}\footnotesize $\bf 0.14$ \tiny $\bf\pm 0.03$ & \footnotesize $0.14$ \tiny $\pm 0.03$ &
\cellcolor{blue!15}\footnotesize $\bf 0.48$ \tiny $\bf\pm 0.05$ & \cellcolor{blue!15}\footnotesize $\bf 0.16$ \tiny $\bf\pm 0.03$ \\

XLM-ROBERTA-BASE &
\cellcolor{blue!15}\footnotesize $\bf 0.64$ \tiny $\bf\pm 0.05$ & \footnotesize $0.21$ \tiny $\pm 0.02$ &
\cellcolor{blue!15}\footnotesize $\bf 0.14$ \tiny $\bf\pm 0.01$ & \cellcolor{blue!15}\footnotesize $\bf 0.17$ \tiny $\bf\pm 0.02$ &
\cellcolor{blue!15}\footnotesize $\bf 0.48$ \tiny $\bf\pm 0.05$ & \footnotesize $0.15$ \tiny $\pm 0.02$ \\

