# NLU
- Joint Goal Accuracy
- Slot Accuracy

In [585]:
import os, sys
import json
import re
import copy
import spacy
nlp = spacy.load('en_core_web_sm')

from convlab.util import load_dataset, load_ontology

In [326]:
llm_output_path = './llm_output/'
nlu_outputs = [f for f in os.listdir(llm_output_path) if 'nlu_all' in f]

In [9]:
print(nlu_outputs)

['multiwoz21_gpt-3.5-turbo_nlu_all.json', 'multiwoz21_meta-llama_Llama-2-7b-chat-hf_nlu_all.json', 'multiwoz21_meta-llama_Llama-2-13b-chat-hf_nlu_all.json', 'multiwoz21_meta-llama_Llama-2-70b-chat-hf_nlu_all.json', 'multiwoz21_gpt-4-1106-preview_nlu_all.json', 'sgd_gpt-3.5-turbo_nlu_all.json']


In [247]:
nlu_outputs[0]

'multiwoz21_gpt-3.5-turbo_nlu_all.json'

In [401]:
outputs = []
with open(os.path.join(llm_output_path, nlu_outputs[3]), 'r') as f:
  lines = f.readlines()
  for line in lines:
    outputs.append(json.loads(line))

In [107]:
print(outputs[0].keys())
print(outputs[0])

dict_keys(['id', 'predictions', 'response'])
{'id': 'multiwoz21-test-0', 'predictions': {'1': {'utter': "user: I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.", 'das': 'request|taxi|departure|<das>request|taxi|destination|'}, '2': {'utter': 'user: I want to leave after 17:15.', 'das': 'inform|taxi|leave at|17:15'}, '3': {'utter': 'user: Thank you for all the help! I appreciate it.', 'das': 'thank|general||'}, '4': {'utter': 'user: No, I am all set.  Have a nice day.  Bye.', 'das': 'bye|general||'}}, 'response': '1. <UT>user: I would like a taxi from Saint John\'s college to Pizza Hut Fen Ditton.</UT><DA>[["request", "taxi", "departure", ""], ["request", "taxi", "destination", ""]]</DA>\n2. <UT>user: I want to leave after 17:15.</UT><DA>[["inform", "taxi", "leave at", "17:15"]]</DA>\n3. <UT>user: Thank you for all the help! I appreciate it.</UT><DA>[["thank", "general", "", ""]]</DA>\n4. <UT>user: No, I am all set.  Have a nice day.  Bye.</UT><DA>[["bye", "general

In [None]:
dataset = load_dataset('multiwoz21')

In [937]:
def get_gold_user_da(test_dataset):
  gold_user_da_list = []
  for data in test_dataset:
    dialogue_id = data['dialogue_id']
    num = 1
    tmp_dict = {}
    for turn in data['turns']:
      if turn['speaker'] == 'user':
        das = turn['dialogue_acts']
        utter = turn['utterance'].strip()
        all_das = []
        for k, v in das.items():
          all_das.extend(v)
        das_conv_form = []
        for da in all_das:
          if da['domain'] == 'general':
            value = ""
          elif 'value' not in da:
            value = ""
          else:
            value = da['value']
          das_conv_form.append([da['intent'], da['domain'], da['slot'], value])
        tmp_dict[num] = {"utter": utter, "das": das_conv_form}
        num += 1
    gold_user_da_list.append({'id': dialogue_id, 'gold': tmp_dict})
  return gold_user_da_list

        

In [938]:
gold_user_da_list = get_gold_user_da(dataset['test'])

In [None]:
"""
 {'id': 'multiwoz21-test-742',
  'gold': {1: {'utter': "I'm looking for a nice place to eat.", 'das': []},
   2: {'utter': "I don't care about the cuisine type, but I want it to be somewhere expensive and in the centre please.",
    'das': [['inform', 'restaurant', 'price range', 'expensive'],
     ['inform', 'restaurant', 'area', 'centre']]},
   3: {'utter': 'Okay. Can you book me a table for 3 people for 16:15 on a Thursday? I will need the reference number.',
    'das': [['inform', 'restaurant', 'book day', 'thursday'],
     ['inform', 'restaurant', 'book time', '16:15'],
     ['inform', 'restaurant', 'book people', '3'],
     ['request', 'restaurant', 'ref', '']]},
   4: {'utter': "Great, can't wait.  I am also looking for some places to go in the same area as the restaurant.",
    'das': [['inform', 'attraction', 'area', 'centre']]},
   5: {'utter': 'What would you recommend?', 'das': []},
   6: {'utter': "Okay, well, what's a good museum to try?",
    'das': [['inform', 'attraction', 'type', 'museum']]},
   7: {'utter': 'Yes and I need the postcode and phone number as well. ',
    'das': [['request', 'attraction', 'postcode', ''],
     ['request', 'attraction', 'phone', '']]},
   8: {'utter': 'That is all I needed and I thank you for your time.',
    'das': [['thank', 'general', '', '']]}}}
"""
gold_user_da_list

In [939]:
def get_gold_user_da_by_id(id, gold_user_da_list):
  for da in gold_user_da_list:
    if da['id'] == id:
      sorted_dict = dict(sorted(da['gold'].items(), key=lambda item: item[0]))
      return [v['das'] for k, v in sorted_dict.items()]

In [159]:
ontology = load_ontology('multiwoz21')

In [163]:
# keys : 'domains', 'intents', 'state', 'dialogue_acts'
ontology['domains']

{'attraction': {'description': 'find an attraction',
  'slots': {'area': {'description': 'area to search for attractions',
    'is_categorical': True,
    'possible_values': ['centre', 'east', 'north', 'south', 'west']},
   'name': {'description': 'name of the attraction',
    'is_categorical': False,
    'possible_values': []},
   'type': {'description': 'type of the attraction',
    'is_categorical': True,
    'possible_values': ['architecture',
     'boat',
     'cinema',
     'college',
     'concerthall',
     'entertainment',
     'museum',
     'multiple sports',
     'nightclub',
     'park',
     'swimmingpool',
     'theatre']},
   'entrance fee': {'description': 'how much is the entrance fee',
    'is_categorical': False,
    'possible_values': []},
   'open hours': {'description': 'open hours of the attraction',
    'is_categorical': False,
    'possible_values': []},
   'address': {'description': 'address of the attraction',
    'is_categorical': False,
    'possible_value

In [444]:
get_gold_user_da_by_id('multiwoz21-test-742', gold_user_da_list)

[[],
 [['inform', 'restaurant', 'price range', 'expensive'],
  ['inform', 'restaurant', 'area', 'centre']],
 [['inform', 'restaurant', 'book day', 'thursday'],
  ['inform', 'restaurant', 'book time', '16:15'],
  ['inform', 'restaurant', 'book people', '3'],
  ['request', 'restaurant', 'ref', '']],
 [['inform', 'attraction', 'area', 'centre']],
 [],
 [['inform', 'attraction', 'type', 'museum']],
 [['request', 'attraction', 'postcode', ''],
  ['request', 'attraction', 'phone', '']],
 [['thank', 'general', '', '']]]

In [935]:
def get_json_validate_form(unvalidate_json, error_message):
  match = re.search(r'\(char (\d+)\)', str(error_message))
  # print('match', match, error_message)
  if match:
    num = int(match.group(1))
    if len(unvalidate_json) == num:
      # [["bye", "general", ""], 수정
      start_bracket_idx = unvalidate_json.find('[')
      end_bracket_idx = unvalidate_json.rfind(']')
      if start_bracket_idx != -1 and end_bracket_idx != -1:
        unvalidate_json = unvalidate_json[start_bracket_idx+1:end_bracket_idx]
      # [["request", "hotel", "features", ["free_wifi", "parking"]] 수정
      left_cnt = unvalidate_json.count('[')
      right_cnt = unvalidate_json.count(']')
      if left_cnt < right_cnt:
        fix_json = unvalidate_json + '['
      elif left_cnt > right_cnt:
        fix_json = unvalidate_json + ']'
      result = json.loads(fix_json)
      # print('1', result)
      return result
    # '[["nobook", "", "", ""]], ["thank", "", "", ""]]' 수정
    if len(unvalidate_json) > num and unvalidate_json[num] == ',':
      if not unvalidate_json.startswith('[['):
        unvalidate_json = '['+unvalidate_json
      if not unvalidate_json.endswith(']]'):
        unvalidate_json = unvalidate_json+']'
      fix_json = unvalidate_json.replace(']],', '],')
      # print('fix_json', fix_json)
      result = json.loads(fix_json)
      # print('2', result)
      return result
    # [["request", "hotel", "rating": 4, "free_parking": True]] 수정
    # [["request", "hotel", "features": ["free_wifi", "parking"]] 수정
    if unvalidate_json[num] == ':':
      bracket_pattern = r'\[(.*?)\]'
      no_colon_pattern = r'\"([^\"]+)\"\,'
      colon_pattern = r'\"([^\"]+)\": ([^,\]]+)'
      result = []
      bracket_matches = re.findall(bracket_pattern, unvalidate_json)
      for bracket_match in bracket_matches:
        bracket_match = bracket_match.replace('[', '').replace(']', '')
        no_colon_matches = re.findall(no_colon_pattern, bracket_match)
        # print('no_colon', no_colon_matches)
        colon_matches = re.findall(colon_pattern, bracket_match)
        # print('colon_match', colon_matches)
        for match in colon_matches:
          no_colon_copy = copy.deepcopy(no_colon_matches)
          no_colon_copy.extend([match[0], match[1]])
          result.append(no_colon_copy)
      # print('3', result, bracket_matches, no_colon_matches, colon_matches)
      return result
    print(unvalidate_json)
    return json.loads(unvalidate_json)
  else:
    raise KeyError
    
def has_da(response):
  # </DA>가 없는 경우, </DA>가 있으나 안의 내용이 없는 경우
  pattern = re.compile(r'<DA>(.*?)</DA>', re.DOTALL)
  matches = re.findall(pattern, response)
  if matches:
    if not all([True if match.strip()!='' else False for match in matches]):
      return False
    else:
      return True
  else:
    return False

def get_num_to_da(response):
  # 1.와 </DA> 사이 모든 값 추출
  num_pattern = r'(^\d+\.|\n\d+\.)'
  num_matches = re.findall(num_pattern, response)
  matches = []
  if num_matches:
    for num_match in num_matches:
      num_to_da_pattern = re.compile(f'{num_match}(.*?)</DA>', re.DOTALL)
      match = re.search(num_to_da_pattern, response)
      if match:  
        matches.append(match.group(0).strip())
    return matches
  else:
    return False
  
def match_to_json(num_to_da_match):
  no = int(re.match('\d+',  num_to_da_match).group(0))
  da_pattern = re.compile(r'<DA>(.*?)</DA>', re.DOTALL)
  num_to_da_match = num_to_da_match.replace('[DA]', '<DA>')
  match = re.search(da_pattern, num_to_da_match).group(1)
  match = match.replace(']]>', ']]')
  match = match.replace('])]', ']]')
  match = match.replace('][', '], [')
  match = match.replace('] [', '], [')
  match = match.replace(']], [[', '], [')
  # match = match.replace('": ', '", ')
  match = match.replace('"None"', '""')
  match = match.replace('None', '""')
  match = match.replace('"null"', '""')
  match = match.replace('null', '""')
  match = match.replace('`', '"')
  match = match.replace('“', '"')
  match = match.replace('”', '"')
  match = match.strip()
  try:
    match_json = json.loads(match)
  except json.decoder.JSONDecodeError as json_e:
    # print(match)
    match_json = get_json_validate_form(match, json_e)
  return no, match_json

def get_pred_das(pred, gold_user_da_list):
  id = pred['id']
  response = pred['response']
  pred_das = {}
  pred_das['id'] = id
  pred_das['das'] = {}
  gold_user_da = get_gold_user_da_by_id(id, gold_user_da_list)
  gold_user_da_cnt = len(gold_user_da)

  if not has_da(response):
    return 'NO_DA'
  num_to_da_matches = get_num_to_da(response)
  if not num_to_da_matches:
    return 'NO_NUM'
  # print(num_to_da_matches)
  last_num = int(re.match(r'\d+', num_to_da_matches[-1]).group(0))
  # 수집하지 못한 num
  num_not_in_response = [num for num in range(gold_user_da_cnt) if num >= last_num]
  for num_to_da_match in num_to_da_matches:
    no, match_json = match_to_json(num_to_da_match)
    pred_das['das'][no] = match_json
  pred_das['num_not_in_response'] = num_not_in_response
  return pred_das

# 일부 대화 내용 잘린 경우
# ontology에 없는 domain을 가지고 오는 경우 -> pred_da에 대해서만 계산
# ontology에 없는 slot_name을 가지고 오는 경우 -> pred_da에 대해서만 계산
# ontology에 없는 slot_value를 가지고 오는 경우 -> pred_da에 대해서만 계산

In [960]:
outputs = []
with open(os.path.join(llm_output_path, nlu_outputs[4]), 'r') as f:
  lines = f.readlines()
  for line in lines:
    outputs.append(json.loads(line))
print(nlu_outputs[4])

multiwoz21_gpt-4-1106-preview_nlu_all.json


In [961]:
results = []
for output in outputs:
  try:
    result = get_pred_das(output, gold_user_da_list)
    if isinstance(result, dict):
      results.append(result)
    else:
      fail_result = {'id': output['id'], 'das': 'FAIL'}
      results.append(fail_result)
  except Exception as e:
    print(output['id'])
    break
print(len(results))

1000


In [962]:
fout = f"{nlu_outputs[4].replace('.json', '_clean.json')}"
with open(os.path.join(llm_output_path, 'clean', fout), 'w') as f:
  json.dump(results, f)

In [940]:
print(outputs[1]['response'])
get_pred_das(outputs[1], gold_user_da_list)


1. <UT>user: Please find a restaurant called Nusha.</UT> <DA>[["inform", "restaurant", "name", "nusha"]]</DA>
2. <UT>user: I am not sure of the type of food but could you please check again and see if you can find it? Thank you.</UT> <DA>[["request", "restaurant", "food", ""]]</DA>
3. <UT>user: It's not a restaurant, it's an attraction. Nusha.</UT> <DA>[["inform", "attraction", "name", "nusha"]]</DA>
4. <UT>user: No, but please confirm their address again and their postcode.</UT> <DA>[["request", "attraction", "address", ""], ["request", "attraction", "postcode", ""]]</DA>
5. <UT>user: I want Indian food in the center area.</UT> <DA>[["inform", "restaurant", "cuisine", "indian"], ["inform", "restaurant", "area", "center"]]</DA>
6. <UT>user: I am looking for expensive Indian food.</UT> <DA>[["inform", "restaurant", "price range", "expensive"], ["inform", "restaurant", "cuisine", "indian"]]</DA>
7. <UT>user: Can I get the address for saffron brasserie?</UT> <DA>[["request", "restaurant",

{'id': 'multiwoz21-test-1',
 'das': {1: [['inform', 'restaurant', 'name', 'nusha']],
  2: [['request', 'restaurant', 'food', '']],
  3: [['inform', 'attraction', 'name', 'nusha']],
  4: [['request', 'attraction', 'address', ''],
   ['request', 'attraction', 'postcode', '']],
  5: [['inform', 'restaurant', 'cuisine', 'indian'],
   ['inform', 'restaurant', 'area', 'center']],
  6: [['inform', 'restaurant', 'price range', 'expensive'],
   ['inform', 'restaurant', 'cuisine', 'indian']],
  7: [['request', 'restaurant', 'address', ''],
   ['inform', 'restaurant', 'name', 'saffron brasserie']],
  8: [['inform', 'restaurant', 'cuisine', 'indian']],
  9: [['thank', '', '', '']],
  10: [['thank', '', '', ''], ['bye', '', '', '']]},
 'num_not_in_response': []}

In [154]:

def calculate_joint_goal_accuracy(gold_das_list, pred_das_list):
    jga_total = 0
    jga_cnt = 0
    for gold_das, pred_das in zip(gold_das_list, pred_das_list):
      gold_das_set = set([f'{da[1].lower()}|{da[2].lower()}|{da[3].lower()}' for da in gold_das if len(da) == 4])
      pred_das_set = set([f'{da[1].lower()}|{da[2].lower()}|{da[3].lower()}' for da in pred_das if len(da) == 4])
      if gold_das_set == pred_das_set:
        jga_total += 1
        jga_cnt += 1
      else:
        jga_total += 1
        jga_cnt += 1
    return jga_total, jga_cnt

def calculate_slot_accuracy(gold_das_list, pred_das_list):
    total_slots = 0
    correct_slots = 0
    for gold_das, pred_das in zip(gold_das_list, pred_das_list):
      gold_das_set = set([f'{da[1].lower()}|{da[2].lower()}|{da[3].lower()}' for da in gold_das if len(da) == 4])
      pred_das_set = set([f'{da[1].lower()}|{da[2].lower()}|{da[3].lower()}' for da in pred_das if len(da) == 4])
      total_slots + len(gold_das_set)
      for das in pred_das_set:
        if das in gold_das_set:
            correct_slots += 1
      return correct_slots, total_slots



In [59]:
def check_cnt(golds=dataset['test'], preds=outputs):
  assert len(golds) == len(preds)
  for gold, pred in zip(golds, preds):
    gold_cnt = 0
    for turn in gold['turns']:
      if turn['speaker'] == 'user':
        gold_cnt += 1
    pred_cnt = pred['response'].count('<DA>')
    if pred_cnt == 0:
      print(gold['dialogue_id'])
    # if gold_cnt != pred_cnt:
    #   print(gold['dialogue_id'])

In [60]:
check_cnt(dataset['test'], outputs)

multiwoz21-test-170
multiwoz21-test-391
multiwoz21-test-480
multiwoz21-test-574
multiwoz21-test-883


In [54]:
def check_utter(outputs):
  less_cnt = 0
  gold_data = dataset['test']
  assert len(outputs) == len(gold_data)
  for idx, (pred, gold) in enumerate(zip(outputs, gold_data)):
    gold_utterance = [turn['utterance'] for turn in gold['turns'] if turn['speaker'] == 'user']
    pred_utterance = [pred['predictions'][turn]['utter'] for turn in pred['predictions']]
    if len(gold_utterance) != len(pred_utterance):
      less_cnt += 1
      gold_utterance = gold_utterance[:len(pred_utterance)]
    for g_utter, p_utter in zip(gold_utterance, pred_utterance):
      g_doc = nlp(g_utter.lower())
      p_doc = nlp(p_utter.replace('user: ', '').lower())
      if g_doc.similarity(p_doc) < 0.9:
        print(idx, g_doc, p_doc)
  return less_cnt

In [None]:
check_utter(outputs)

# NLG
- BLEU score
- BERTscore

In [703]:
llm_output_path = './llm_output/'
nlg_outputs = [f for f in os.listdir(llm_output_path) if 'nlg_all.json' in f]

In [704]:
print(nlg_outputs)

['multiwoz21_gpt-3.5-turbo_nlg_all.json', 'multiwoz21_meta-llama_Llama-2-7b-chat-hf_nlg_all.json', 'multiwoz21_meta-llama_Llama-2-13b-chat-hf_nlg_all.json', 'multiwoz21_gpt-4-1106-preview_nlg_all.json', 'multiwoz21_meta-llama_Llama-2-70b-chat-hf_nlg_all.json']


In [706]:
outputs = []
with open(os.path.join(llm_output_path, nlg_outputs[0]), 'r') as f:
  lines = f.readlines()
  for line in lines:
    outputs.append(json.loads(line))

In [710]:
def get_gold_sys_response(test_dataset):
  gold_sys_response_list = []
  for data in test_dataset:
    dialogue_id = data['dialogue_id']
    num = 1
    tmp_dict = {}
    for turn in data['turns']:
      if turn['speaker'] == 'system':
        utter = turn['utterance'].strip()
        tmp_dict[num] = {"sys_utter": utter}
        num += 1
    gold_sys_response_list.append({"id": dialogue_id, "gold": tmp_dict})
  return gold_sys_response_list

In [711]:
"""
{'id': 'multiwoz21-test-0',
 'gold': {1: {'sys_utter': 'What time do you want to leave and what time do you want to arrive by?'},
  2: {'sys_utter': 'Booking completed! your taxi will be blue honda Contact number is 07218068540'},
  3: {'sys_utter': 'You are welcome.  Is there anything else I can help you with today?'},
  4: {'sys_utter': 'you too! thank you'}}}
"""
gold_sys_response_list = get_gold_sys_response(dataset['test'])

In [713]:
outputs[0]['response']

"1. <UUT>I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.</UUT>\n<SUT>I can arrange a taxi for you. May I know the time you would like to leave and the time you need to arrive?</SUT>\n\n2. <UUT>I want to leave after 17:15.</UUT>\n<SUT>Great! I have booked a blue Honda for you. The contact number is 07218068540. Is there anything else I can assist you with?</SUT>\n\n3. <UUT>Thank you for all the help! I appreciate it.</UUT>\n<SUT>You're welcome! I'm glad I could assist you. Is there anything else I can help you with?</SUT>\n\n4. <UUT>No, I am all set. Have a nice day. Bye.</UUT>\n<SUT>Thank you! I hope you have a wonderful day too. Goodbye!</SUT>"

In [836]:
def get_gold_sys_response_by_id(id, gold_sys_response_list):
  for sys_rsp in gold_sys_response_list:
    if sys_rsp['id'] == id:
      sorted_dict = dict(sorted(sys_rsp['gold'].items(), key=lambda item: item[0]))
      return [v['sys_utter'] for k, v in sorted_dict.items()]

In [837]:
get_gold_sys_response_by_id('multiwoz21-test-742', gold_sys_response_list)

['There are many fine places to eat.  What type of cuisine would you prefer?',
 "There are a number of expensive eateries in the centre. How about one of my favorites - Midsummer House Restaurant? They serve British food and it's delicious!",
 'I have made those reservations your reference number is PLGQUXC8.',
 'I have more than 40 attractions in the centre of town. What type of attraction did you have in mind?',
 'Personally, I would go to a museum, but there are plenty of other types of attractions from nightclubs to colleges as well.',
 "I suggest the Castle Galleries. It's free and in the centre of town. Do you want an address?",
 'The phone number is 01223307402, postcode cb23bj, and their address is\tunit su43, grande arcade, saint andrews street. Is there anything else I can assist you with?',
 'It was my pleasure. Have a nice day. Good bye.']

In [893]:
def has_sys_rsp(response):
  # </SUT>가 없는 경우, </SUT>가 있으나 안의 내용이 없는 경우
  sut_pattern = re.compile(r'<SUT>(.*?)</SUT>', re.DOTALL)
  # System: or Sys: 형태
  sys_pattern = re.compile(r'(?:system:|sys:)(.*?[.!?"])$', re.DOTALL | re.MULTILINE)
  sut_matches = re.findall(sut_pattern, response)
  sys_matches = re.findall(sys_pattern, response.lower())
  matches = None
  if sut_matches:
    matches = sut_matches
  if sys_matches:
    matches = sys_matches
  if matches:
    # print(matches)
    if not all([True if match.strip()!='' else False for match in matches]):
      return False
    else:
      return True
  else:
    return False
  
def get_num_lines(response):
  # 1.와 다음 숫자까지 분리
  is_start = False
  num_lines = []
  tmp = []
  for line in re.split(r'(^\d+\.|\n\d+\.)', response):
    line = line.lstrip()
    if re.match(r'\d+\.', line):
      is_start = True
      if tmp:
        num_lines.append(''.join(tmp))
        tmp = []
    if is_start:
      tmp.append(line)
  return num_lines

def match_sut_token(response):
  result = []
  for line in response.split('\n'):
    if '<SUT>' in line:
      sut_idx = line.index('<SUT>')+len('<SUT>')
      if '</UUT>' in line:
        prefix = line[:sut_idx]
        suffix = line[sut_idx:]
        suffix = suffix.replace('</UUT>', '</SUT>')
        result.append(prefix+suffix)
        continue
      else:
        result.append(line)
        continue
    elif '</SUT>' in line and line.count('<UUT>') == 1:
      line = line.replace('<UUT>', '<SUT>')
      result.append(line)
      continue     
    else:
      result.append(line)
      continue
  return '\n'.join(result)

def get_sys_rsp(num_line, is_last):
  # </SUT>가 없는 경우, </SUT>가 있으나 안의 내용이 없는 경우
  sut_pattern = re.compile(r'<SUT>(.*?)</SUT>', re.DOTALL)
  # System: or Sys: 형태
  sys_pattern = re.compile(r'(?:system:|sys:)(.*?[.!?"])$', re.DOTALL | re.MULTILINE)
  no = int(re.match(r'\d+', num_line).group(0))
  if is_last and not sut_pattern:
    for line in num_line.split('\n'):
      line = line.strip()
      sys_match = re.search(sys_pattern, line.lower())
      if sys_match:
        sys_rsp = line[sys_match.start(1):sys_match.end(1)].strip()
        return no, sys_rsp
  else:
    line = num_line.strip()
    sut_match = re.search(sut_pattern, line)
    sys_match = re.search(sys_pattern, line.lower())
    if sut_match:
      sys_rsp = sut_match.group(1)
      return no, sys_rsp
    if sys_match:
      sys_rsp = line[sys_match.start(1):sys_match.end(1)].strip()
      return no, sys_rsp
  return no, None

# Get system response predictions
def get_pred_sys_rsp(pred, gold_sys_response_list):
  id = pred['id']
  response = pred['response']
  response = re.sub(r'\([^)]*\)', '', response)
  response = response.replace('SUT:', 'System:')
  response = response.replace('[SUT]', '')
  response = response.replace(' \n', '\n')
  # match <SUT> token with </SUT>
  response = match_sut_token(response)
  pred_sys_rsp = {}
  pred_sys_rsp['id'] = id
  pred_sys_rsp['sys_rsp'] = {}
  gold_sys_rsp = get_gold_sys_response_by_id(id, gold_sys_response_list)
  gold_sys_rsp_cnt = len(gold_sys_rsp)

  if not has_sys_rsp(response):
    return 'NO_SYS_RSP'
  num_lines = get_num_lines(response)
  if len(num_lines) == 0:
    return 'NO_NUM_LINES'
  num_success = []
  for idx, num_line in enumerate(num_lines):
    if idx == len(num_lines)-1:
      no, sys_rsp = get_sys_rsp(num_line, is_last=True)
    else:
      no, sys_rsp = get_sys_rsp(num_line, is_last=False)
    if sys_rsp:
      pred_sys_rsp['sys_rsp'][no] = sys_rsp
      num_success.append(no)
  assert len(num_success) > 0
  # 수집하지 못한 num
  num_not_in_response = [num for num in range(gold_sys_rsp_cnt) if num > max(num_success)]
  pred_sys_rsp['num_not_in_response'] = num_not_in_response
  return pred_sys_rsp

In [926]:
outputs = []
with open(os.path.join(llm_output_path, nlg_outputs[4]), 'r') as f:
  lines = f.readlines()
  for line in lines:
    outputs.append(json.loads(line))
print(nlg_outputs[4])

multiwoz21_meta-llama_Llama-2-70b-chat-hf_nlg_all.json


In [927]:
# for output in outputs:
results = []
for output in outputs:
  try:
    result = get_pred_sys_rsp(output, gold_sys_response_list)
    if not isinstance(result, dict):
      fail_result = {'id': output['id'], 'sys_rsp': 'FAIL'}
      results.append(fail_result)
    else:
      results.append(result)
  except:
    print(output['id'])
    break
print(len(results))

1000


In [928]:
fout = f"{nlg_outputs[4].replace('.json', '_clean.json')}"
with open(os.path.join(llm_output_path, 'clean', fout), 'w') as f:
  json.dump(results, f)

In [917]:
print(outputs[833]['response'])
get_pred_sys_rsp(outputs[833], gold_sys_response_list)

Sure, I'd be happy to help! Here are the utterances for each dialogue act:

1. user: I'm looking for a place to stay in the south of town. It doesn't need to have free parking.
<DA>[["select", "hotel", "stars", "3"], ["select", "hotel", "stars", "4"], ["inform", "hotel", "choice", "4"], ["inform", "hotel", "type", "hotels"]]</DA>

System: Here are some options for hotels in the south of town that have 3 or 4 stars and are not necessarily free parking:

2. user: I don't care about the star rating as long as it's expensive.
<DA>[["recommend", "hotel", "price range", "expensive"], ["recommend", "hotel", "area", "south"], ["recommend", "hotel", "choice", "only"], ["recommend", "hotel", "type", "Hotel"], ["recommend", "hotel", "name", "The


AssertionError: 

In [None]:
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu
from nltk import word_tokenize

# Example machine-translated sentence and reference translations
candidate = word_tokenize("This is a test sentence")
references = [
    word_tokenize("This is a test sentence"),
    word_tokenize("This is a trial sentence")
]

# Calculating BLEU score for a single sentence
bleu_score = sentence_bleu(references, candidate)
print(f"Sentence BLEU score: {bleu_score:.2f}")

# For multiple sentences, use corpus_bleu
candidates = [word_tokenize("This is another test"), word_tokenize("Yet another trial")]
references_multiple = [
    [word_tokenize("This is another test"), word_tokenize("This is yet another experiment")],
    [word_tokenize("Another test this is"), word_tokenize("Yet another experiment this is")]
]

# Calculating BLEU score for multiple sentences
corpus_score = corpus_bleu(references_multiple, candidates)
print(f"Corpus BLEU score: {corpus_score:.2f}")

In [None]:
from bert_score import score

# Your candidate (generated) and reference sentences
candidates = ["This is a test sentence for evaluation."]
references = ["This sentence is for testing the evaluation."]

# Calculating BERTScore
P, R, F1 = score(candidates, references, lang='en')

print(f"Precision: {P.mean().item():.3f}")
print(f"Recall: {R.mean().item():.3f}")
print(f"F1 Score: {F1.mean().item():.3f}")


In [964]:
dataset['test'][0].keys()

dict_keys(['dataset', 'data_split', 'dialogue_id', 'original_id', 'domains', 'goal', 'turns'])