In [8]:
import numpy as np

In [70]:
# modified from https://github.com/jasonwu0731/trade-dst/

# gets metrics for a all dialogue conversations
# all_diag_pred: dictionary of conversation_name -> diag_pred
# all_diag_gold: dictionary of conversation_name -> diag_gold
# diag_pred: list of predicted slots per dialogue turn (list of sets of tuples)
# diag_gold: list of gold slots per dialogue turn (list of sets of tuples)
# slot_temp: set of slots we are concerned about (template)
def evaluate_metrics(all_diag_pred, all_diag_gold, slot_temp):
    total, turn_acc, joint_acc, F1_pred, F1_count = 0, 0, 0, 0, 0
    for fname, diag_pred in all_diag_pred:
        diag_gold = all_diag_gold[fname]
        for t in range(len(diag_gold)):
            curr_pred = diag_pred[t]
            curr_gold = diag_gold[t] 
            if curr_pred == curr_gold:
                joint_acc += 1
            total += 1

            # Compute prediction slot accuracy
            temp_acc = self.compute_acc(curr_pred, curr_gold, slot_temp)
            turn_acc += temp_acc

            # Compute prediction joint F1 score
            temp_f1, temp_r, temp_p, count = self.compute_prf(curr_pred, curr_gold)
            F1_pred += temp_f1
            F1_count += count

    joint_acc_score = joint_acc / float(total) if total!=0 else 0
    turn_acc_score = turn_acc / float(total) if total!=0 else 0
    F1_score = F1_pred / float(F1_count) if F1_count!=0 else 0
    return joint_acc_score, F1_score, turn_acc_score

# gold, pred: sets of tuples of slot, value
# slot_temp: set of slots we are concerned about (template)
def compute_acc(gold, pred, slot_temp):
    miss_gold = 0
    miss_slot = []
    for g in gold:
        if g not in pred:
            miss_gold += 1
            miss_slot.append(g[0])
    wrong_pred = 0
    for p in pred:
        if p not in gold and p[0] not in miss_slot:
            wrong_pred += 1
    ACC_TOTAL = len(slot_temp)
    ACC = len(slot_temp) - miss_gold - wrong_pred
    ACC = ACC / float(ACC_TOTAL)
    return ACC

# gold, pred: sets of tuples of slot, value
def compute_prf(gold, pred):
    TP, FP, FN = 0, 0, 0
    if len(gold)!= 0:
        count = 1
        for g in gold:
            if g in pred:
                TP += 1
            else:
                FN += 1
        for p in pred:
            if p not in gold:
                FP += 1
        precision = TP / float(TP+FP) if (TP+FP)!=0 else 0
        recall = TP / float(TP+FN) if (TP+FN)!=0 else 0
        F1 = 2 * precision * recall / float(precision + recall) if (precision+recall)!=0 else 0
    else:
        if len(pred)==0:
            precision, recall, F1, count = 1, 1, 1, 1
        else:
            precision, recall, F1, count = 0, 0, 0, 1
    return F1, recall, precision, count

def compute_bleu(gold, pred):
    from sacrebleu.metrics import BLEU
    def get_bleu(gold, pred):
        # return BLEU, bleu-1, -2, -3, -4
        x = bleu.corpus_score([pred], [[gold]])
        b1, b2, b3, b4 = [float(s) for s in x._verbose.split()[0].split('/')]
        return {'bleu': [x.score], 'bleu-1': [b1], 'bleu-2': [b2], 'bleu-3': [b3], 'bleu-4': [b4]}
    
    bleu = BLEU()
    miss_gold = 0
#     miss_slot = []
    scorer = {'bleu': [], 'bleu-1': [], 'bleu-2': [], 'bleu-3': [], 'bleu-4': []}
    for g in gold:
        if g not in pred:
            miss_gold += 1
#             miss_slot.append(g[0])
#     wrong_pred = 0
    for s, v in pred:
        d_gold = dict(gold)
        if s in d_gold:
            bleus = get_bleu(v, d_gold[s])
            scorer = {key:scorer.get(key,[])+bleus.get(key,[]) 
                      for key in set(list(scorer.keys())+list(bleus.keys()))}
    for i in range(miss_gold):
        scorer = {k: scorer[k] + [0.] for k in scorer}
    return {k: np.mean(scorer[k]) for k in scorer} #, scorer



In [71]:
slot_temp = {"object_stolen", "location", "time_start", "time_end"}

gold = {("object_stolen", "bike lights (both front and rear)"), ("location", "in front of [ORG]"),
        ("time_start", "This evening ([DATE]), between #:##"), ("time_end", "#:##")}

pred = {("location", "in front of [ORG]"), ("object_stolen", "bike lights"), ("time_end", "This evening")}

In [72]:
print(compute_acc(gold, pred, slot_temp))
print(compute_prf(gold, pred))

0.25
(0.28571428571428575, 0.25, 0.3333333333333333, 1)


In [73]:
print(compute_bleu(gold, pred))

{'bleu-3': 18.05, 'bleu-4': 17.5, 'bleu': 18.507465927846496, 'bleu-1': 20.833333333333332, 'bleu-2': 19.05}
