In [1]:
import os
import json
from copy import deepcopy

from convlab.util.multiwoz.state import default_state
from convlab.dst.rule.multiwoz.dst_util import normalize_value
from convlab.dst.dst import DST
from convlab.util import load_dataset, load_ontology, load_database
from typing import List, Dict, Union, Optional
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# from datasets import load_dataset, load_metric

In [2]:
# Checkpoint name shared by the tokenizer and the seq2seq model loaded in the cells below.
model_checkpoint = 't5-small'

In [3]:
# Load the tokenizer that belongs to the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Display the tokenizer's special tokens ('</s>', '<unk>', '<pad>' and the '<extra_id_N>' tokens).
tokenizer.all_special_tokens

['</s>',
 '<unk>',
 '<pad>',
 '<extra_id_0>',
 '<extra_id_1>',
 '<extra_id_2>',
 '<extra_id_3>',
 '<extra_id_4>',
 '<extra_id_5>',
 '<extra_id_6>',
 '<extra_id_7>',
 '<extra_id_8>',
 '<extra_id_9>',
 '<extra_id_10>',
 '<extra_id_11>',
 '<extra_id_12>',
 '<extra_id_13>',
 '<extra_id_14>',
 '<extra_id_15>',
 '<extra_id_16>',
 '<extra_id_17>',
 '<extra_id_18>',
 '<extra_id_19>',
 '<extra_id_20>',
 '<extra_id_21>',
 '<extra_id_22>',
 '<extra_id_23>',
 '<extra_id_24>',
 '<extra_id_25>',
 '<extra_id_26>',
 '<extra_id_27>',
 '<extra_id_28>',
 '<extra_id_29>',
 '<extra_id_30>',
 '<extra_id_31>',
 '<extra_id_32>',
 '<extra_id_33>',
 '<extra_id_34>',
 '<extra_id_35>',
 '<extra_id_36>',
 '<extra_id_37>',
 '<extra_id_38>',
 '<extra_id_39>',
 '<extra_id_40>',
 '<extra_id_41>',
 '<extra_id_42>',
 '<extra_id_43>',
 '<extra_id_44>',
 '<extra_id_45>',
 '<extra_id_46>',
 '<extra_id_47>',
 '<extra_id_48>',
 '<extra_id_49>',
 '<extra_id_50>',
 '<extra_id_51>',
 '<extra_id_52>',
 '<extra_id_53>',
 '<extra_

In [4]:
# Load the seq2seq model from the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [5]:
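# NOTE: disabled experiment. When enabled, these lines register '<ds>' and '<db>' as
# additional special tokens and resize the model's token embeddings to the new vocabulary size.
# The `Embedding(32102, 512)` output recorded for this cell comes from an earlier run with the
# lines active; a comment-only cell produces no output.
# special_tokens_dict = {'additional_special_tokens': ['<ds>', '<db>']}
# num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
# model.resize_token_embeddings(len(tokenizer))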

Embedding(32102, 512)

In [2]:
# Load the 'multiwoz21' dataset and pull out its three splits.
dataset = load_dataset('multiwoz21')
train_dataset, valid_dataset, test_dataset = (
  dataset['train'],
  dataset['validation'],
  dataset['test'],
)

In [6]:
def get_dialog_contexts(dialogue: Dict) -> List[List]:
  """Collect the dialogue history that precedes every system turn.

  Args:
    dialogue: A dialogue dict with a ``'turns'`` list; every turn must provide
      the ``'speaker'`` and ``'utterance'`` keys.

  Returns:
    One list per system turn, in dialogue order. Each list holds the utterances
    of all turns before that system turn (the system turn itself is excluded).
  """
  turns = dialogue['turns']
  # Extract the utterances once instead of rebuilding the full list for every system turn.
  utterances = [turn['utterance'] for turn in turns]
  # Each slice is a fresh list, so the returned contexts never alias one another.
  return [
    utterances[:idx]
    for idx, turn in enumerate(turns)
    if turn['speaker'] == 'system'
  ]
  
def serialize_dialogue_acts(dialogue_acts: Dict) -> str:
  """Flatten a turn's dialogue acts into one bracketed string.

  Acts from every act type are grouped by ``[intent][domain]`` in first-seen
  order. Each group lists its ``[slot][value]`` pairs separated by commas, and
  an act without a ``'value'`` key contributes an empty value. Groups are
  joined with ``;``, e.g. ``[inform][hotel]([price range][cheap],[type][hotel])``.

  Args:
    dialogue_acts: Mapping from act type to a list of act dicts carrying
      ``'intent'``, ``'domain'``, ``'slot'`` and optionally ``'value'``.

  Returns:
    The serialized string; empty when there are no acts.
  """
  grouped = {}
  for acts in dialogue_acts.values():
    for act in acts:
      group_key = f"[{act['intent']}][{act['domain']}]"
      slot_value = f"[{act['slot']}][{act.get('value', '')}]"
      grouped.setdefault(group_key, []).append(slot_value)
  serialized_groups = []
  for group_key, slot_values in grouped.items():
    serialized_groups.append(group_key + '(' + ','.join(slot_values) + ')')
  return ';'.join(serialized_groups)

In [7]:
# Sanity check: build the contexts (one per system turn) for the first training dialogue.
get_dialog_contexts(train_dataset[0])

[['am looking for a place to to stay that has cheap price range it should be in a type of hotel'],
 ['am looking for a place to to stay that has cheap price range it should be in a type of hotel',
  'Okay, do you have a specific area you want to stay in?',
  "no, i just need to make sure it's cheap. oh, and i need parking"],
 ['am looking for a place to to stay that has cheap price range it should be in a type of hotel',
  'Okay, do you have a specific area you want to stay in?',
  "no, i just need to make sure it's cheap. oh, and i need parking",
  'I found 1 cheap hotel for you that includes parking. Do you like me to book it?',
  'Yes, please. 6 people 3 nights starting on tuesday.'],
 ['am looking for a place to to stay that has cheap price range it should be in a type of hotel',
  'Okay, do you have a specific area you want to stay in?',
  "no, i just need to make sure it's cheap. oh, and i need parking",
  'I found 1 cheap hotel for you that includes parking. Do you like me to bo

In [8]:
# Sanity check: serialize the dialogue acts of the first turn of the first training dialogue.
serialize_dialogue_acts(train_dataset[0]['turns'][0]['dialogue_acts'])

'[inform][hotel]([price range][cheap],[type][hotel])'

In [9]:
class RuleDST(DST):
  """Rule based DST which trivially updates new values from NLU result to states.

  Attributes:
    ontology(dict):
      Ontology returned by ``load_ontology`` for the given dataset name.
    state(dict):
      Dialog state. Starts from ``convlab.util.multiwoz.state.default_state`` with its
      ``'belief_state'`` replaced by a copy of ``ontology['state']``.
    default_belief_state(dict):
      Copy of ``ontology['state']`` used to reset the belief state.
    value_dict(dict):
      Contents of ``value_dict.json``; passed to ``normalize_value`` when an
      informed value is stored.
  """

  def __init__(self, dataset_name='multiwoz21', value_dict_path=None):
    """Load the ontology and value dictionary and build the initial state.

    Args:
      dataset_name: Name passed to ``load_ontology``.
      value_dict_path: Path of the ``value_dict.json`` file. When ``None`` the
        original hard-coded location is used, which only exists on the
        author's machine; pass an explicit path elsewhere.
    """
    DST.__init__(self)
    self.ontology = load_ontology(dataset_name)
    self.state = default_state()
    self.default_belief_state = deepcopy(self.ontology['state'])
    self.state['belief_state'] = deepcopy(self.default_belief_state)
    if value_dict_path is None:
      value_dict_path = os.path.join('/home3/hgsun/ConvLab-3', 'data/multiwoz/value_dict.json')
    # Use a context manager so the file handle is closed right after parsing.
    with open(value_dict_path) as f:
      self.value_dict = json.load(f)

  def update(self, user_act=None):
    """Update ``belief_state`` and ``request_state`` from user dialogue acts.

    :param user_act: iterable of ``(intent, domain, slot, value)`` tuples;
      ``None`` is treated as no acts.
    :return: the updated ``self.state`` dict (the same object, not a copy).
    """
    if user_act is None:
      user_act = []
    for intent, domain, slot, value in user_act:
      # Acts for domains that are not part of the belief state are ignored.
      if domain not in self.state['belief_state']:
        continue
      if intent == 'inform':
        if slot == 'none' or slot == '':
          continue
        domain_dic = self.state['belief_state'][domain]
        if slot in domain_dic:
          domain_dic[slot] = normalize_value(
            self.value_dict, domain, slot, value)
        else:
          # Unknown slots are logged instead of raising so the remaining acts are still processed.
          with open('unknown_slot.log', 'a+') as f:
            f.write(
              'unknown slot name <{}> of domain <{}>\n'.format(slot, domain))
      elif intent == 'request':
        # Record the requested slot with 0 unless it is already present.
        requested = self.state['request_state'].setdefault(domain, {})
        requested.setdefault(slot, 0)
    # self.state['user_action'] = user_act  # should be added outside DST module
    return self.state

  def init_session(self):
    """Initialize ``self.state`` with a default state, which ``convlab.util.multiwoz.state.default_state`` returns.

    The ``'belief_state'`` entry is then replaced by a fresh copy of ``self.default_belief_state``.
    """
    self.state = default_state()
    self.state['belief_state'] = deepcopy(self.default_belief_state)


In [3]:
# Inspect ConvLab's stock default state. Its hotel belief state comes pre-filled
# ('stars': '4', 'internet': 'yes', 'type': 'hotel'), unlike the all-empty ontology['state']
# that RuleDST installs as its belief state.
default_state()

{'user_action': [],
 'system_action': [],
 'belief_state': {'attraction': {'type': '', 'name': '', 'area': ''},
  'hotel': {'name': '',
   'area': '',
   'parking': '',
   'price range': '',
   'stars': '4',
   'internet': 'yes',
   'type': 'hotel',
   'book stay': '',
   'book day': '',
   'book people': ''},
  'restaurant': {'food': '',
   'price range': '',
   'name': '',
   'area': '',
   'book time': '',
   'book day': '',
   'book people': ''},
  'taxi': {'leave at': '',
   'destination': '',
   'departure': '',
   'arrive by': ''},
  'train': {'leave at': '',
   'destination': '',
   'day': '',
   'arrive by': '',
   'departure': '',
   'book people': ''},
  'hospital': {'department': ''}},
 'booked': {},
 'request_state': {},
 'terminated': False,
 'history': []}

In [4]:
# Load the 'multiwoz21' ontology; RuleDST copies its belief-state template from ontology['state'].
ontology = load_ontology('multiwoz21')

In [5]:
# The ontology's belief-state template: every slot of every domain is an empty string.
ontology['state']

{'attraction': {'type': '', 'name': '', 'area': ''},
 'hotel': {'name': '',
  'area': '',
  'parking': '',
  'price range': '',
  'stars': '',
  'internet': '',
  'type': '',
  'book stay': '',
  'book day': '',
  'book people': ''},
 'restaurant': {'food': '',
  'price range': '',
  'name': '',
  'area': '',
  'book time': '',
  'book day': '',
  'book people': ''},
 'taxi': {'leave at': '', 'destination': '', 'departure': '', 'arrive by': ''},
 'train': {'leave at': '',
  'destination': '',
  'day': '',
  'arrive by': '',
  'departure': '',
  'book people': ''},
 'hospital': {'department': ''}}

In [10]:
# Instantiate the tracker and inspect the value dictionary it loaded from value_dict.json.
dst = RuleDST()
dst.value_dict # all DB values

{'train': {'arriveby': ['05:51',
   '07:51',
   '09:51',
   '11:51',
   '13:51',
   '15:51',
   '17:51',
   '19:51',
   '21:51',
   '23:51',
   '06:08',
   '08:08',
   '10:08',
   '12:08',
   '14:08',
   '16:08',
   '18:08',
   '20:08',
   '22:08',
   '24:08',
   '07:27',
   '09:27',
   '11:27',
   '13:27',
   '15:27',
   '17:27',
   '19:27',
   '21:27',
   '23:27',
   '01:27',
   '07:07',
   '09:07',
   '11:07',
   '13:07',
   '15:07',
   '17:07',
   '19:07',
   '21:07',
   '23:07',
   '01:07',
   '05:58',
   '06:58',
   '07:58',
   '08:58',
   '09:58',
   '10:58',
   '11:58',
   '12:58',
   '13:58',
   '14:58',
   '15:58',
   '16:58',
   '17:58',
   '18:58',
   '19:58',
   '20:58',
   '21:58',
   '22:58',
   '23:58',
   '06:55',
   '07:55',
   '08:55',
   '09:55',
   '10:55',
   '11:55',
   '12:55',
   '13:55',
   '14:55',
   '15:55',
   '16:55',
   '17:55',
   '18:55',
   '19:55',
   '20:55',
   '21:55',
   '22:55',
   '23:55',
   '24:55',
   '06:35',
   '07:35',
   '08:35',
   '09:

In [11]:
# Reset the tracker to a fresh default state before feeding it dialogue acts.
dst.init_session()

In [12]:
# Flatten the first turn's dialogue acts into (intent, domain, slot, value) tuples,
# the shape RuleDST.update unpacks, and run one tracking step.
da = train_dataset[0]['turns'][0]['dialogue_acts']
user_acts = []
for k, v in da.items():
  for v_each in v:
    # Binary acts carry no 'value' key, so default to '' (as serialize_dialogue_acts does)
    # instead of raising KeyError.
    user_acts.append((v_each['intent'], v_each['domain'], v_each['slot'], v_each.get('value', '')))
state = dst.update(user_acts)

In [13]:
# Inspect the state returned by dst.update: the hotel 'price range' and 'type' slots are now filled.
state

{'user_action': [],
 'system_action': [],
 'belief_state': {'attraction': {'type': '', 'name': '', 'area': ''},
  'hotel': {'name': '',
   'area': '',
   'parking': '',
   'price range': 'cheap',
   'stars': '',
   'internet': '',
   'type': 'hotel',
   'book stay': '',
   'book day': '',
   'book people': ''},
  'restaurant': {'food': '',
   'price range': '',
   'name': '',
   'area': '',
   'book time': '',
   'book day': '',
   'book people': ''},
  'taxi': {'leave at': '',
   'destination': '',
   'departure': '',
   'arrive by': ''},
  'train': {'leave at': '',
   'destination': '',
   'day': '',
   'arrive by': '',
   'departure': '',
   'book people': ''},
  'hospital': {'department': ''}},
 'booked': {},
 'request_state': {},
 'terminated': False,
 'history': []}

In [14]:
# Load the 'multiwoz21' database and look up hotels using the tracked belief state.
database = load_database('multiwoz21')
# Example of a hand-written query state with the same {domain: {slot: value}} shape:
# state = {"hotel": {"area": "east", "price range": "moderate"}}
# The whole belief state (all domains) is passed; topk=3 presumably caps the number
# of returned entries — TODO confirm against the database API.
res = database.query("hotel", state['belief_state'], topk=3)

In [15]:
# Show the hotel entries returned by the query (a single entry in the recorded output).
res

[{'address': 'back lane, cambourne',
  'area': 'west',
  'internet': 'yes',
  'parking': 'yes',
  'id': '28',
  'location': [52.2213805555556, -0.0680333333333333],
  'name': 'the cambridge belfry',
  'phone': '01954714600',
  'postcode': 'cb236bw',
  'price': {'double': '60', 'single': '60'},
  'pricerange': 'cheap',
  'stars': '4',
  'takesbookings': 'yes',
  'type': 'hotel',
  'Ref': '00000028'}]

In [18]:
# Dump the second test dialogue turn by turn: the utterance, then its annotated dialogue acts.
for turn in test_dataset[1]['turns']:
  print(turn['utterance'], turn['dialogue_acts'], sep='\n')

Please find a restaurant called Nusha.
{'categorical': [], 'non-categorical': [{'intent': 'inform', 'domain': 'attraction', 'slot': 'name', 'value': 'Nusha', 'start': 32, 'end': 37}], 'binary': [{'intent': 'inform', 'domain': 'restaurant', 'slot': ''}]}
I don't seem to be finding anything called Nusha.  What type of food does the restaurant serve?
{'categorical': [], 'non-categorical': [{'intent': 'nooffer', 'domain': 'restaurant', 'slot': 'name', 'value': 'Nusha', 'start': 43, 'end': 48}], 'binary': [{'intent': 'request', 'domain': 'restaurant', 'slot': 'food'}]}
I am not sure of the type of food but could you please check again and see if you can find it? Thank you.
{'categorical': [], 'non-categorical': [], 'binary': [{'intent': 'thank', 'domain': 'general', 'slot': ''}]}
Could you double check that you've spelled the name correctly? The closest I can find is Nandos.
{'categorical': [], 'non-categorical': [{'intent': 'inform', 'domain': 'restaurant', 'slot': 'name', 'value': 'Nandos