In [82]:
import json
# from sklearn.metrics import f1_score, accuracy_score
# import numpy as np
from utils.Constants import SLOT_VALS
from utils.dst import ignore_none, default_cleaning, IGNORE_TURNS_TYPE2
import argparse
import sys
sys.argv = ['test_args.py']


parser = argparse.ArgumentParser()
parser.add_argument('--eval_file', default=str,
                    help='evaluate file name (json)')
parser.add_argument('--default_cleaning', action='store_true',
                    help='use default cleaning from multiwoz')
parser.add_argument('--type2_cleaning', action='store_true',
                    help='use type 2 cleaning, refer to [https://arxiv.org/abs/2005.00796]')
args = parser.parse_args()
args.eval_file = "test_eval_test_nocarry.json"
data = json.load(open(args.eval_file, 'r'))

In [83]:
data["SNG0073.json"]["generated_turn_belief"][3][1] = data["PMUL3688.json"]["generated_turn_belief"][3][1] + "1"

In [84]:
data["SNG0073.json"]["generated_turn_belief"][2][1] = "some thing wrong"

In [85]:
data = {
    "a": data["SNG0073.json"]
}

In [86]:

slot_template = {slot:"" for slot in SLOT_VALS}
def get_slot_map(slot_triplet_str_list):
    slot_map = slot_template.copy()
    for slot_triplet_str in slot_triplet_str_list:
        slot_triplets = slot_triplet_str.split()
        key = slot_triplets[0] + " " + slot_triplets[1]
        val = slot_triplets[2]
        if key not in SLOT_VALS:
            continue
        slot_map[key] = val
    return slot_map

def get_unique_slot_map(preds, targets):
    unique_slots = set()
    pred_map = {}
    target_map = {}
    
    for pred_str in preds:
        triplet = pred_str.split()
        key = triplet[0] + " " + triplet[1]
        val = triplet[2]
        pred_map[key] = val
        unique_slots.add(key)
    
    for target_str in targets:
        triplet = target_str.split()
        key = triplet[0] + " " + triplet[1]
        val = triplet[2]
        target_map[key] = val
        unique_slots.add(key)
    
    return unique_slots.copy(), pred_map.copy(), target_map.copy()
        

In [87]:
from tqdm.auto import tqdm
num_turns = 0
joint_acc = 0
slot_acc = 0
r_slot_acc = 0

num_slots = len(SLOT_VALS)
num_r_slots = 0

clean_tokens = ['<|endoftext|>']

for dial in tqdm(data):
    dialogue_pred = data[dial]['generated_turn_belief']
    dialogue_target = data[dial]['target_turn_belief']

    for turn_id, (turn_target, turn_pred) in enumerate(zip(dialogue_target, dialogue_pred)):
        
        # clean
        for bs in turn_pred:
            if bs in clean_tokens + ['', ' '] or bs.split()[-1] == 'none':
                turn_pred.remove(bs)
                
        new_turn_pred = []
        for bs in turn_pred:
            for tok in clean_tokens:
                bs = bs.replace(tok, '').strip()
                new_turn_pred.append(bs)
        turn_pred = new_turn_pred

        turn_pred, turn_target = ignore_none(turn_pred, turn_target)
        
        # MultiWOZ default cleaning
        if args.default_cleaning:
            turn_pred, turn_target = default_cleaning(turn_pred, turn_target)

        join_flag = False
        
        # calculate joint accuracy
        if set(turn_target) == set(turn_pred):
            joint_acc += 1
            join_flag = True
        
        pred_slot_map = get_slot_map(turn_pred)
        target_slot_map = get_slot_map(turn_target)
        
        # calculate slot accuracy
        for slot_key in SLOT_VALS:
            if target_slot_map[slot_key] == pred_slot_map[slot_key]:
                slot_acc += 1
                
        # calculate relative slot accuracy
        unique_slots, unique_pred_map, unique_target_map = get_unique_slot_map(turn_pred, turn_target)
        for slot_key in unique_slots:
            if slot_key not in unique_target_map.keys(): continue
            if slot_key not in unique_pred_map.keys(): continue
            if unique_target_map[slot_key] == unique_pred_map[slot_key]:
                r_slot_acc += 1
        num_r_slots += len(unique_slots)
        
            
#         elif args.type2_cleaning: # check for possible Type 2 noisy annotations
#             flag = True
#             for bs in turn_target:
#                 if bs not in turn_pred:
#                     flag = False
#                     break
#             if flag:
#                 for bs in turn_pred:
#                     if bs not in dialogue_target_final:
#                         flag = False
#                         break

#             if flag: # model prediction might be correct if found in Type 2 list of noisy annotations
#                 dial_name = dial.split('.')[0]
#                 if dial_name in IGNORE_TURNS_TYPE2 and turn_id in IGNORE_TURNS_TYPE2[dial_name]: # ignore these turns
#                     pass
#                 else:
#                     joint_acc += 1
#                     join_flag = True

        num_turns += 1

  0%|          | 0/1 [00:00<?, ?it/s]

In [88]:
print('joint accuracy: {}'.format(joint_acc / num_turns))

joint accuracy: 0.5


In [89]:
total_slot_num = num_slots * num_turns
slot_acc /= total_slot_num

print('slot accuracy: {}'.format(slot_acc))

slot accuracy: 0.9758064516129032


In [90]:
r_slot_acc /= num_r_slots

print('relative slot accuracy: {}'.format(r_slot_acc))

relative slot accuracy: 0.6923076923076923
