In [1]:
import numpy as np
import matplotlib.pyplot as plt

from collections import defaultdict
from sklearn.metrics import f1_score

import utils

# Result directories of the tuned runs compared in this notebook.
# NOTE(review): naming suggests H = HuRIC training data and FN = FrameNet data,
# with *_FT / *_LU being two FrameNet variants — inferred from the paths only, confirm.
H_ONLY='../nlunetwork/results/tuning/optimized/h_only/conf_4/huric/modern_right/'
FN_FT='../nlunetwork/results/tuning/optimized/fn_only_hyper_h/huric/modern_right/'
FN_LU='../nlunetwork/results/tuning/optimized/fn_lu_only_hyper_h/huric/modern_right/'
H_FN_FT='../nlunetwork/results/tuning/optimized/h_and_fn_ft/conf_4/huric/with_framenet_ft/'
H_FN_LU='../nlunetwork/results/tuning/optimized/h_and_fn_lu/conf_4/huric/with_framenet_lu/'

# Load the per-sample predictions of each run. The evaluation code reads the keys
# 'id', 'intent_attentions' and 'intent_pred' from every loaded sample.
# NOTE(review): the meaning of the integer argument to utils.load_json is not visible
# here (presumably a run/epoch/fold selector) — confirm in utils.
samples_h = utils.load_json(H_ONLY,16)
# The *_ft runs are currently disabled; only the *_lu runs are loaded and evaluated.
#samples_fn_ft = utils.load_json(FN_FT, 0)
samples_fn = utils.load_json(FN_LU, 0)
#samples_h_fn_ft = utils.load_json(H_FN_FT,48)
samples_h_fn = utils.load_json(H_FN_LU,25)

# Location of the HuRIC source files ("the XML stuff").
# NOTE(review): not used by the evaluation cells in this section — confirm it is still needed.
HURIC_LOCATION = '../data/huric/modern/source'
# Gold alignment targets, looked up by str(sample id); each entry is read for its
# 'frame' label and for the 0/1 token masks 'lu', 'lu+disc' and 'lu+disc2'.
gold_attn = utils.load_attention_gold('../data/huric/modern_right/lu_disc_final.tsv')

In [2]:
def eval_align_tops(pred_attn_vals, gold_one_hot, args):
    """Score attention alignment by treating the top-k attended tokens as predictions.

    k is the number of gold target tokens (the count of positives in
    ``gold_one_hot``). The k positions with the highest predicted attention are
    marked 1, all others 0, and the binary F1 on the positive class is returned.
    Which of several tied attention values ends up in the top k is not specified
    (it is whatever ``np.argpartition`` picks).

    Args:
        pred_attn_vals: per-token attention weights predicted by the model.
        gold_one_hot: per-token 0/1 gold mask, same length as ``pred_attn_vals``.
        args: unused; present so every ``eval_fn`` callback shares one signature.

    Returns:
        F1 of the predicted mask against the gold mask (``average='binary'``).

    NOTE(review): with no gold targets both masks are all zeros and the F1 is
    ill-defined (sklearn reports 0.0 with a warning under its default
    ``zero_division``) — confirm this is the intended scoring for such samples.
    """
    attn = np.asarray(pred_attn_vals)
    # int() so a float-valued mask (1.0 / 0.0) still yields an integer kth for argpartition
    n_targets = int(sum(gold_one_hot))
    pred_targets = np.zeros(len(attn), dtype=int)
    # argpartition(...)[-0:] would select every index, so n_targets == 0 predicts nothing
    if n_targets > 0:
        top_indexes = np.argpartition(attn, -n_targets)[-n_targets:]
        # O(n) mask assignment instead of an `in` scan over the index array per token
        pred_targets[top_indexes] = 1
    # F1 on the positive (target) class only
    return f1_score(gold_one_hot, pred_targets, average='binary')

def eval_align_treshold(pred_attn_vals, gold_one_hot, args):
    """Score attention alignment by thresholding the attention weights.

    A token counts as a predicted target when its attention weight is strictly
    greater than ``args['treshold']`` (the key spelling is kept because existing
    callers pass it that way). Returns the binary F1 of that 0/1 prediction
    against ``gold_one_hot``.
    """
    cutoff = args['treshold']
    predicted_mask = [int(weight > cutoff) for weight in pred_attn_vals]
    return f1_score(gold_one_hot, predicted_mask, average='binary')

def evaluate_attn(samples, golds, eval_fn, eval_args, levels=('lu', 'lu+disc', 'lu+disc2')):
    """Print mean attention-alignment scores of ``samples`` against ``golds``.

    For every sample and every gold annotation level in ``levels``, ``eval_fn``
    scores ``s['intent_attentions']`` against ``golds[str(s['id'])][level]``.
    Two averages are printed per level:

    * ``<level>``: mean over all samples;
    * ``only_correct_frame_<level>``: mean over the samples whose predicted intent
      (``s['intent_pred']``) equals the gold ``'frame'``. Nothing is printed for
      this group when no sample has a correct frame.

    Args:
        samples: sized sequence of prediction dicts with keys 'id',
            'intent_attentions' and 'intent_pred'.
        golds: mapping from ``str(sample id)`` to a dict holding 'frame' and one
            gold mask per entry of ``levels``.
        eval_fn: callable ``(pred_attn_vals, gold_one_hot, eval_args) -> number``.
        eval_args: forwarded unchanged to ``eval_fn``.
        levels: gold annotation keys to evaluate (defaults to the three levels
            used throughout this notebook).

    Returns:
        None; results are printed.
    """
    totals = defaultdict(int)
    totals_correct = defaultdict(int)
    n_correct_frame = 0
    for s in samples:
        gold = golds[str(s['id'])]
        # score each level once; the same value feeds both the overall and the
        # correct-frame averages
        scores = {what: eval_fn(s['intent_attentions'], gold[what], eval_args) for what in levels}
        for what, measure in scores.items():
            totals[what] += measure
        if gold['frame'] == s['intent_pred']:
            n_correct_frame += 1
            for what, measure in scores.items():
                totals_correct['only_correct_frame_{}'.format(what)] += measure
    for what, val in totals.items():
        print(what, val / len(samples))
    # totals_correct is empty when n_correct_frame == 0, so no division by zero here
    for what, val in totals_correct.items():
        print(what, val / n_correct_frame)
        

In [3]:
# top-k alignment scores for the H_ONLY run (samples_h)
evaluate_attn(samples_h, gold_attn, eval_fn=eval_align_tops, eval_args={})

lu 0.0636998254799302
lu+disc 0.1273996509598604
lu+disc2 0.14761489237929026
only_correct_frame_lu 0.0625
only_correct_frame_lu+disc 0.12409420289855072
only_correct_frame_lu+disc2 0.1444746376811594


In [4]:
# top-k alignment scores for the FN_LU run (samples_fn)
evaluate_attn(samples_fn, gold_attn, eval_fn=eval_align_tops, eval_args={})

lu 0.9511343804537522
lu+disc 0.8685282140779523
lu+disc2 0.8454043048283887
only_correct_frame_lu 0.9923076923076923
only_correct_frame_lu+disc 0.8884615384615384
only_correct_frame_lu+disc2 0.8867521367521367


In [5]:
# top-k alignment scores for the H_FN_LU run (samples_h_fn)
evaluate_attn(samples_h_fn, gold_attn, eval_fn=eval_align_tops, eval_args={})

lu 0.44851657940663175
lu+disc 0.45898778359511344
lu+disc2 0.4714950552646888
only_correct_frame_lu 0.4601593625498008
only_correct_frame_lu+disc 0.4677954847277556
only_correct_frame_lu+disc2 0.48240371845949526


In [6]:
# Threshold-based measurements for the H_ONLY run (samples_h): a token counts as
# predicted when its attention weight exceeds the threshold.
for treshold in (0.01, 0.05, 0.08, 0.1):
    print('treshold', treshold)
    evaluate_attn(samples_h, gold_attn, eval_fn=eval_align_treshold, eval_args={'treshold': treshold})

treshold 0.01
lu 0.11176071359317429
lu+disc 0.18781142262817688
lu+disc2 0.20534540115691982
only_correct_frame_lu 0.1114834943639291
only_correct_frame_lu+disc 0.1859056402534666
only_correct_frame_lu+disc2 0.20381022962544731
treshold 0.05
lu 0.08232361007230111
lu+disc 0.15999335161638842
lu+disc2 0.17728330424665523
only_correct_frame_lu 0.08153036576949615
only_correct_frame_lu+disc 0.15780710835058678
only_correct_frame_lu+disc2 0.1751509661835751
treshold 0.08
lu 0.07963932518906336
lu+disc 0.1557176099060917
lu+disc2 0.1740255962769054
only_correct_frame_lu 0.07874396135265696
only_correct_frame_lu+disc 0.15312715665976548
only_correct_frame_lu+disc2 0.17152777777777803
treshold 0.1
lu 0.08027923211169281
lu+disc 0.1520236017618218
lu+disc2 0.17116263608410223
only_correct_frame_lu 0.07940821256038644
only_correct_frame_lu+disc 0.14929261559696355
only_correct_frame_lu+disc2 0.16855590062111817


In [7]:
# Threshold-based measurements for the FN_LU run (samples_fn).
for treshold in (0.01, 0.05, 0.08, 0.1):
    print('treshold', treshold)
    evaluate_attn(samples_fn, gold_attn, eval_fn=eval_align_treshold, eval_args={'treshold': treshold})

treshold 0.01
lu 0.9392088423502045
lu+disc 0.8661431064572455
lu+disc2 0.8464805119255415
only_correct_frame_lu 0.9717948717948723
only_correct_frame_lu+disc 0.8850427350427362
only_correct_frame_lu+disc2 0.8824786324786337
treshold 0.05


  'precision', 'predicted', average, warn_for)


lu 0.9421175101803378
lu+disc 0.8686445607911604
lu+disc2 0.8489819662594564
only_correct_frame_lu 0.9820512820512824
only_correct_frame_lu+disc 0.8948717948717961
only_correct_frame_lu+disc2 0.8923076923076936
treshold 0.08
lu 0.9450261780104716
lu+disc 0.871553228621294
lu+disc2 0.85189063408959
only_correct_frame_lu 0.9854700854700857
only_correct_frame_lu+disc 0.8982905982905994
only_correct_frame_lu+disc2 0.8957264957264969
treshold 0.1
lu 0.9456079115764984
lu+disc 0.8721349621873208
lu+disc2 0.8524723676556168
only_correct_frame_lu 0.9863247863247865
only_correct_frame_lu+disc 0.8991452991453004
only_correct_frame_lu+disc2 0.8965811965811978


In [8]:
# Threshold-based measurements for the H_FN_LU run (samples_h_fn).
for treshold in (0.01, 0.05, 0.08, 0.1):
    print('treshold', treshold)
    evaluate_attn(samples_h_fn, gold_attn, eval_fn=eval_align_treshold, eval_args={'treshold': treshold})

treshold 0.01
lu 0.5131760521812879
lu+disc 0.5025441787221888
lu+disc2 0.5083462785033462
only_correct_frame_lu 0.5271577913012172
only_correct_frame_lu+disc 0.5157747520297311
only_correct_frame_lu+disc2 0.5243120314833453
treshold 0.05
lu 0.4916562785672731
lu+disc 0.4797390509432387
lu+disc2 0.486312640239341
only_correct_frame_lu 0.5047524188958447
only_correct_frame_lu+disc 0.49400493265034984
only_correct_frame_lu+disc2 0.5030354771390616
treshold 0.08
lu 0.4801628853984873
lu+disc 0.4688273913404795
lu+disc2 0.475400980636582
only_correct_frame_lu 0.4949535192563077
only_correct_frame_lu+disc 0.4836748245114769
only_correct_frame_lu+disc2 0.4927053690001887
treshold 0.1
lu 0.4724258289703315
lu+disc 0.4598603839441527
lu+disc2 0.46643397324025493
only_correct_frame_lu 0.4861221779548469
only_correct_frame_lu+disc 0.47410358565736965
only_correct_frame_lu+disc2 0.4831341301460814
