In [20]:
import numpy as np
import re
import collections

In [29]:
# modified from https://github.com/jasonwu0731/trade-dst/, https://rajpurkar.github.io/SQuAD-explorer/

# gets metrics for a all dialogue conversations
# all_diag_pred: dictionary of conversation_name -> diag_pred
# all_diag_gold: dictionary of conversation_name -> diag_gold
# diag_pred: list of predicted slots per dialogue turn (list of sets of tuples)
# diag_gold: list of gold slots per dialogue turn (list of sets of tuples)
# slot_temp: set of slots we are concerned about (template)
def evaluate_metrics(all_diag_pred, all_diag_gold, slot_temp):
    total, turn_acc, joint_acc, F1_pred, F1_count = 0, 0, 0, 0, 0
    for fname, diag_pred in all_diag_pred:
        diag_gold = all_diag_gold[fname]
        for t in range(len(diag_gold)):
            curr_pred = diag_pred[t]
            curr_gold = diag_gold[t] 
            if curr_pred == curr_gold:
                joint_acc += 1
            total += 1

            # Compute prediction slot accuracy
            temp_acc = self.compute_acc(curr_pred, curr_gold, slot_temp)
            turn_acc += temp_acc

            # Compute prediction joint F1 score
            temp_f1, temp_r, temp_p, count = self.compute_prf(curr_pred, curr_gold)
            F1_pred += temp_f1
            F1_count += count

    joint_acc_score = joint_acc / float(total) if total!=0 else 0
    turn_acc_score = turn_acc / float(total) if total!=0 else 0
    F1_score = F1_pred / float(F1_count) if F1_count!=0 else 0
    return joint_acc_score, F1_score, turn_acc_score

# gold, pred: dictionaries of slot -> list of values
# slot_temp: set of slots we are concerned about (template)
# returns: (# slots which have any correct prediction) / (# slots)
def compute_acc(gold, pred, slot_temp):
    pred_correct = 0
    # total up number of slots where any gold value was predicted
    for slot in slot_temp:
        for gold_val in gold[slot]:
            if gold_val in pred[slot]:
                pred_correct += 1
                break
    
    ACC_TOTAL = len(slot_temp)
    ACC = pred_correct
    ACC = ACC / float(ACC_TOTAL)
    return ACC

# gold, pred: dictionaries of slot -> list of values
def compute_prf(gold, pred, slot_temp):
    TP, FP, FN = 0, 0, 0
    if len(gold)!= 0:
        count = 1
        for slot in slot_temp:
            for gold_val in gold[slot]:
                if gold_val in pred[slot]:
                    TP += 1
                else:
                    FN += 1
            for pred_val in pred[slot]:
                if pred_val not in gold[slot]:
                    FP += 1
        precision = TP / float(TP+FP) if (TP+FP)!=0 else 0
        recall = TP / float(TP+FN) if (TP+FN)!=0 else 0
        F1 = 2 * precision * recall / float(precision + recall) if (precision+recall)!=0 else 0
#     else:
#         if len(pred)==0:
#             precision, recall, F1, count = 1, 1, 1, 1
#         else:
#             precision, recall, F1, count = 0, 0, 0, 1
    return F1, recall, precision, count

def compute_bleu(gold, pred):
    from sacrebleu.metrics import BLEU
    def get_bleu(gold, pred):
        # return BLEU, bleu-1, -2, -3, -4
        x = bleu.corpus_score([pred], [[gold]])
        b1, b2, b3, b4 = [float(s) for s in x._verbose.split()[0].split('/')]
        return {'bleu': [x.score], 'bleu-1': [b1], 'bleu-2': [b2], 'bleu-3': [b3], 'bleu-4': [b4]}
    
    bleu = BLEU()
    miss_gold = 0
#     miss_slot = []
    scorer = {'bleu': [], 'bleu-1': [], 'bleu-2': [], 'bleu-3': [], 'bleu-4': []}
    for g in gold:
        if g not in pred:
            miss_gold += 1
#             miss_slot.append(g[0])
#     wrong_pred = 0
    for s, v in pred:
        d_gold = dict(gold)
        if s in d_gold:
            bleus = get_bleu(v, d_gold[s])
            scorer = {key:scorer.get(key,[])+bleus.get(key,[]) 
                      for key in set(list(scorer.keys())+list(bleus.keys()))}
    for i in range(miss_gold):
        scorer = {k: scorer[k] + [0.] for k in scorer}
    return {k: np.mean(scorer[k]) for k in scorer} #, scorer

# Word-based F1, as used for span prediction
# a_fold, a_pred: strings (spans of text)
def compute_f1_span(a_gold, a_pred):
    gold_toks = a_gold.split()
    pred_toks = a_pred.split()
    common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(gold_toks)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

# word-based F1: average of slots' F1 scores, where a slot's score is the max F1 of all gold/pred pairs
def compute_f1(gold, pred, slot_temp):
    f1s = []
    for slot in slot_temp:
        # if gold answer was no answer, f1 is 1 if pred is same or 0 if not
        if len(gold[slot]) == 0:
            f1s.append(int(len(gold[slot]) == len(pred[slot])))
        else:
            f1_max = 0
            for gold_val in gold[slot]:
                for pred_val in pred[slot]:
                    f1 = compute_f1_span(gold_val, pred_val)
                    f1_max = max(f1_max, f1)
            f1s.append(f1_max)
    print(f1s)
    return np.mean(f1s)
                

In [30]:
slot_temp = {"Target_Object-ARG", "Place-Arg", "Start_Time-Arg", "End_Time-Arg"}

gold = {"Target_Object-ARG":["bike lights (both front and rear)"], "Place-Arg": ["in front of [ORG]"],
        "Start_Time-Arg": ["This evening ([DATE]), between #:##"], "End_Time-Arg": ["#:##"]}

pred = {"Place-Arg": ["in front of [ORG]"], "Target_Object-ARG":["bike lights"], 
        "End_Time-Arg": ["This evening"], "Start_Time-Arg":[]}

In [31]:
print(compute_acc(gold, pred, slot_temp))
print(compute_prf(gold, pred, slot_temp))
print(compute_f1(gold, pred, slot_temp))

0.25
(0.28571428571428575, 0.25, 0.3333333333333333, 1)
[0, 0.5, 0, 1.0]
0.375


In [73]:
print(compute_bleu(gold, pred))

{'bleu-3': 18.05, 'bleu-4': 17.5, 'bleu': 18.507465927846496, 'bleu-1': 20.833333333333332, 'bleu-2': 19.05}


In [18]:
# turn annotation file into list of dict

f = open("../data/val_gold/event_1054347.ann")
lines = f.readlines()
slot_2_tag = {"Target_Object-ARG":[], "Place-Arg":[], "Start_Time-Arg":[], "End_Time-Arg":[]}

#make dict of tags to location and string
annot_dict = {}
for line in lines:
    toks = re.split(r'\t+', line)
    if toks[0][0] != "E":
        annot_dict[toks[0]] = (toks[1], toks[2].strip())
        
print(annot_dict)

#make dict of slot needed to tag(s)
for line in lines:
    if line[:2] == "E1":
        #found the theft event
        toks = line[2:].split()
        for tok in toks:
            tag = tok.split(":")
            if tag[0] == "Start_Time-Arg":
                slot_2_tag["Start_Time-Arg"] += [tag[1]]
            elif tag[0] == "End_Time-Arg":
                slot_2_tag["End_Time-Arg"] += [tag[1]]
            elif tag[0] == "Target_Object-ARG":
                slot_2_tag["Target_Object-ARG"] += [tag[1]]
            elif tag[0] == "Place-Arg":
                slot_2_tag["Place-Arg"] += [tag[1]]
                
f.close()

print(slot_2_tag)

#make list of dicts (entry i is slot values at utterance i)
f = open("../data/val/event_1054347.txt")
lines = f.readlines()
slots_gold = []

len_total = 0
for line in lines:
    len_total += len(line)
    slots_gold_i = {"Target_Object-ARG":[], "Place-Arg":[], "Start_Time-Arg":[], "End_Time-Arg":[]}
    # for each slot, add any gold annotated strings which are in the dialogue history
    for slot in slot_temp:
        tags = slot_2_tag[slot]
        for tag in tags:
            loc = annot_dict[tag][0]
            toks = loc.split()
            if int(toks[1]) <= len_total:
                slots_gold_i[slot] = annot_dict[tag][1]
    slots_gold.append(slots_gold_i)
    
print(slots_gold)

{'T1': ('Intent_Inform 0 4', 'User'), 'T2': ('Time 6 41', 'This evening ([DATE]), between #:##'), 'T3': ('Time 56 60', '#:##'), 'T4': ('Stolen 113 119', 'stolen'), 'T5': ('Object_Stolen 74 107', 'bike lights (both front and rear)'), 'T6': ('Location_Region-General 132 149', 'in front of [ORG]'), 'T7': ('Location_Region-General 176 219', 'sign directly in front of a security camera'), 'T8': ('Intent_Thank 0 4', 'User'), 'T9': ('Intent_Thank 498 503', 'Admin'), 'T10': ('Intent_NotifyOthersInCharge 498 503', 'Admin'), 'T11': ('Intent_Other 651 656', 'Admin'), 'T12': ('Intent_Confirm 706 710', 'User'), 'T13': ('Intent_Inform 706 710', 'User'), 'T14': ('Intent_Thank 706 710', 'User'), 'T15': ('Intent_AskForDetail_Location 851 856', 'Admin'), 'T16': ('Person_Individual 71 73', 'my'), 'T17': ('Person_Individual 124 126', 'my'), 'T18': ('Intent_Inform 887 891', 'User'), 'T19': ('Location_Region-General 913 917', 'home'), 'T20': ('Intent_AskToVisit 921 926', 'Admin'), 'T21': ('Intent_AskForDeta

1