In [1]:
import json
import random
from pathlib import Path

In [2]:
# ANSI terminal escape sequences used to colour printed text.
NC = '\x1b[0m'     # reset: back to the terminal's default attributes
RED = '\x1b[31m'   # red foreground
BLUE = '\x1b[34m'  # blue foreground

In [3]:
# Directory holding the parsed dataset; the corpus files below are located
# relative to it.
dataset_dir = Path('../data/multiwoz2_parsed')

# The raw dump sits in a sibling folder, reached through '..'.
raw_dials_path = dataset_dir.joinpath('..', 'MULTIWOZ2 2', 'data.json')
delex_dials_path = dataset_dir.joinpath('multi-woz', 'delex.json')
train_dials_path = dataset_dir.joinpath('train_dials.json')
valid_dials_path = dataset_dir.joinpath('val_dials.json')
test_dials_path = dataset_dir.joinpath('test_dials.json')

# Directory holding the *_gen.json files for the validation and test splits.
gen_dir = Path('multiwoz/model/data')

valid_dials_gen_path = gen_dir.joinpath('val_dials', 'val_dials_gen.json')
test_dials_gen_path = gen_dir.joinpath('test_dials', 'test_dials_gen.json')

In [4]:
def _load_json(path):
    """Parse the JSON document stored at ``path`` and return the result.

    The file is opened with an explicit UTF-8 codec so that the parsed text
    does not depend on the platform's default locale encoding.
    """
    with open(path, 'r', encoding='utf-8') as json_f:
        return json.load(json_f)


# Corpus files.
raw_dials = _load_json(raw_dials_path)
delex_dials = _load_json(delex_dials_path)
valid_dials = _load_json(valid_dials_path)
test_dials = _load_json(test_dials_path)

# Generated-response files for the validation and test splits.
valid_dials_gen = _load_json(valid_dials_gen_path)
test_dials_gen = _load_json(test_dials_gen_path)

In [5]:
def show_turn(dial_id, turn_id, filt='11111', dials=None, dials_gen=None):
    """Render one user/system exchange of a dialogue as ANSI-coloured text.

    Up to five labelled views of the exchange are produced, in this order:
    raw user text, delexicalised user text, raw system text, delexicalised
    system text (ground truth) and generated system text.

    Args:
        dial_id: key of the dialogue in ``raw_dials`` and, for the views that
            need them, in ``dials`` and ``dials_gen``.
        turn_id: zero-based exchange index. The raw user utterance is read
            from log entry ``2 * turn_id`` and the raw system utterance from
            log entry ``2 * turn_id + 1``.
        filt: string of five flags, one per view in the order listed above;
            a ``'0'`` at a position hides that view, anything else shows it.
        dials: mapping ``dial_id -> {'usr': [...], 'sys': [...]}`` used for
            the delexicalised views. Defaults to the global ``valid_dials``;
            pass e.g. ``test_dials`` to inspect another split.
        dials_gen: mapping ``dial_id -> [generated response per turn]``.
            Defaults to the global ``valid_dials_gen``.

    Returns:
        The selected views joined by newlines, or ``''`` when every view is
        hidden.
    """
    # Every lookup happens inside its own branch, so hiding a view also skips
    # the data access behind it (the dialogue need not exist in that source).
    views = []
    if filt[0] != '0':
        views.append('{}User   (raw):\n{}\n{}'.format(
            RED, raw_dials[dial_id]['log'][turn_id * 2]['text'], NC))
    if filt[1] != '0':
        split = valid_dials if dials is None else dials
        views.append('{}User   (delex) (input):\n{}\n{}'.format(
            RED, split[dial_id]['usr'][turn_id].strip(), NC))
    if filt[2] != '0':
        views.append('{}System (raw):\n{}\n{}'.format(
            BLUE, raw_dials[dial_id]['log'][turn_id * 2 + 1]['text'], NC))
    if filt[3] != '0':
        split = valid_dials if dials is None else dials
        views.append('{}System (delex) (ground truth):\n{}\n{}'.format(
            BLUE, split[dial_id]['sys'][turn_id].strip(), NC))
    if filt[4] != '0':
        generated = valid_dials_gen if dials_gen is None else dials_gen
        views.append('{}System (gen):\n{}\n{}'.format(
            BLUE, generated[dial_id][turn_id], NC))
    return '\n'.join(views)

In [6]:
# Collect every domain name: each key found in some dialogue's goal, except
# the 'message' and 'topic' keys.
domains = set()
for dial in raw_dials.values():
    domains.update(name for name in dial['goal']
                   if name not in ('message', 'topic'))
print(domains)

# Split the dialogue ids into single-domain ones ('SNG' or 'WOZ' in the id)
# and multi-domain ones ('MUL' in the id); any other id is rejected.
single_dial_ids, mul_dial_ids = [], []

for dial_id in raw_dials:
    if any(tag in dial_id for tag in ('SNG', 'WOZ')):
        single_dial_ids.append(dial_id)
    elif 'MUL' in dial_id:
        mul_dial_ids.append(dial_id)
    else:
        assert False, dial_id

# Map each domain to the ids of the single-domain dialogues whose goal holds
# a non-empty entry for that domain.
domain_dial_ids = {}
for dial_id in single_dial_ids:
    goal = raw_dials[dial_id]['goal']
    matched = [domain for domain in domains if goal.get(domain, {})]
    for domain in matched:
        domain_dial_ids.setdefault(domain, []).append(dial_id)
    # Sanity check: exactly one domain goal may be non-empty per dialogue.
    assert len(matched) == 1, goal

{'attraction', 'hotel', 'police', 'taxi', 'hospital', 'restaurant', 'train'}


In [7]:
# Report how many dialogues were grouped under each domain.
for domain in domain_dial_ids:
    print(domain, len(domain_dial_ids[domain]))

hotel 634
police 245
taxi 435
hospital 287
restaurant 1310
train 345
attraction 150


In [8]:
def mix_dialogues(lens, rng=None):
    """Lazily interleave the turns of several dialogues in random order.

    At every step one of the dialogues still being tracked is drawn at
    random and its next turn is yielded. Turns of one dialogue therefore
    always come out in increasing order, while the dialogues themselves are
    shuffled together.

    Args:
        lens: sequence giving the number of turns of each dialogue.
        rng: object exposing ``choice(seq)``, e.g. the ``random`` module or a
            ``random.Random`` instance. When ``None``, a seed below 100 is
            drawn from the ``random`` module, printed, and used to build a
            private generator; the process-wide generator is not reseeded.
            The run can be reproduced with ``rng=random.Random(seed)`` or
            with ``random.seed(seed)`` followed by ``rng=random``.

    Yields:
        ``(dial_id, turn_id)`` tuples, where ``dial_id`` is a position in
        ``lens`` and ``turn_id`` a zero-based turn index of that dialogue.
    """
    turns = [iter(range(length)) for length in lens]
    index = list(range(len(lens)))
    if rng is None:
        seed = random.randrange(100)
        # A dedicated instance produces the same sequence as seeding the
        # module with this value, without clobbering global random state.
        rng = random.Random(seed)
        print('seed = {}'.format(seed))

    while index:
        dial_id = rng.choice(index)
        try:
            turn_id = next(turns[dial_id])
        except StopIteration:
            # Exhaustion is only noticed when a finished dialogue is drawn
            # again; that draw yields nothing and retires the dialogue.
            index.remove(dial_id)
            continue
        yield (dial_id, turn_id)

In [9]:
# Pick two restaurant dialogues by their position in the domain group.
dial_ids = [domain_dial_ids['restaurant'][pos] for pos in (12, 4)]
# One exchange takes two consecutive log entries (user, then system).
n_turns = [len(raw_dials[dial_id]['log']) // 2 for dial_id in dial_ids]

# Seed the shared generator so the interleaving is the same on every run.
# Calling mix_dialogues(n_turns) without rng draws and prints a seed instead.
random.seed(65)
mix_turns = mix_dialogues(n_turns, rng=random)

print(dial_ids)
for index, turn_id in mix_turns:
    dial_id = dial_ids[index]
    print('dial = {}, turn = {}'.format(index, turn_id))
    # '10100' keeps only the raw user text and the raw system text.
    print(show_turn(dial_id, turn_id, filt='10100'))

['WOZ20299.json', 'SNG01608.json']
dial = 1, turn = 0
[31mUser   (raw):
Are there any Portuguese restaurants in Cambridge?
[0m
[34mSystem (raw):
Yes there is a Portuguese restaurant in Cambridge  with two different locations, would you like the addresses?
[0m
dial = 1, turn = 1
[31mUser   (raw):
If one of them has a moderate price range please give me that address. If not tell me about Turkish restaurants instead.
[0m
[34mSystem (raw):
I have two Turkish restuarants, both in the centre and both expensive. May I recommend anatolia?
[0m
dial = 1, turn = 2
[31mUser   (raw):
Actually I need a moderately priced restaurant. Are there any fitting that description?
[0m
[34mSystem (raw):
I am sorry. I mistook that price range. The Anatolia is in the moderate range. Would that work for you?
[0m
dial = 0, turn = 0
[31mUser   (raw):
I want to find a cheap restaurant in the south part of town. 
[0m
[34mSystem (raw):
The Lucky Star is an inexpensive chinese restaurant in the south par