In [165]:
import json
import glob
from functools import lru_cache
from collections import OrderedDict, defaultdict
from copy import deepcopy
import numpy as np
from tqdm import tqdm_notebook as tqdm
from ipdb import set_trace
import re
import os

In [215]:
# Seed NumPy's global RNG: dial_to_samples draws from np.random when it
# downsamples NONE-valued slots, so the seed fixes that sequence of draws.
np.random.seed(1)

In [7]:
class Schemas(object):
    """Index of service schemas loaded from a schema JSON file.

    The file must contain a JSON list of schema objects.  Each schema is
    stored in ``self.index`` under its ``service_name``.
    """

    def __init__(self, filepath):
        """Load all schemas from ``filepath`` and index them by service name."""
        # Read as UTF-8 explicitly instead of relying on the platform locale.
        with open(filepath, encoding="utf-8") as f:
            schemas = json.load(f)
        self.index = {schema["service_name"]: schema for schema in schemas}
        # Per-instance memo for get().  A functools.lru_cache on the method
        # would key on ``self`` and keep every instance alive in a cache
        # attached to the class.
        self._cache = {}

    def get(self, service):
        """Return the flattened, column-wise view of one service schema.

        The result is a dict of parallel lists:

        * ``name`` / ``desc``: service name and description.
        * ``slot_name``, ``slot_desc``, ``slot_iscat``, ``slot_vals``: one
          entry per slot, in schema order (``slot_vals`` holds the slot's
          ``possible_values`` list).
        * ``intent_name``, ``intent_desc``, ``intent_istrans``,
          ``intent_reqslots``, ``intent_optslots``, ``intent_optvals``: one
          entry per intent; the optional-slot names and their values are split
          into two aligned lists.

        The result is memoised per service and the same dict object is
        returned on every call, so callers should treat it as read-only.

        Raises:
            KeyError: if ``service`` is not in the index.
        """
        cache = getattr(self, "_cache", None)
        if cache is None:
            # Instance was created without running __init__.
            cache = self._cache = {}
        if service in cache:
            return cache[service]

        schema = self.index[service]
        slots = schema["slots"]
        intents = schema["intents"]

        result = dict(
            # service
            name=service,
            desc=schema["description"],

            # slots
            slot_name=[slot["name"] for slot in slots],
            slot_desc=[slot["description"] for slot in slots],
            slot_iscat=[slot["is_categorical"] for slot in slots],
            slot_vals=[slot["possible_values"] for slot in slots],

            # intents
            intent_name=[intent["name"] for intent in intents],
            intent_desc=[intent["description"] for intent in intents],
            intent_istrans=[intent["is_transactional"] for intent in intents],
            intent_reqslots=[intent["required_slots"] for intent in intents],
            intent_optslots=[list(intent["optional_slots"].keys()) for intent in intents],
            intent_optvals=[list(intent["optional_slots"].values()) for intent in intents],
        )

        cache[service] = result
        return result

In [231]:
# convert dial
def init_memory(schema, dial):
    """Build the initial value memory for one dialogue from its service schemas.

    Categorical slots get their own key ``(service, slot)``; all
    non-categorical slots of all services share the single key ``"noncat"``.
    Every key starts with ``"NONE"`` and ``"dontcare"``, followed by the
    slot's possible values and then any optional-slot values declared by the
    intents.  Values are de-duplicated per key and kept in first-seen order.

    Args:
        schema: object whose ``get(service)`` returns the flattened schema
            dict (``slot_name``, ``slot_iscat``, ``slot_vals``,
            ``intent_optslots``, ``intent_optvals``).
        dial: dialogue dict; only ``dial["services"]`` is read.

    Returns:
        ``(memory, index)``: ``memory`` is a ``defaultdict(list)`` mapping
        each key to its ordered values, ``index`` a ``defaultdict(set)``
        holding the same values for membership tests.

    Raises:
        ValueError: if an intent lists an optional slot that is not among the
            service's slots.
    """
    memory = defaultdict(list)
    index = defaultdict(set)  # serv,slot->val, noncat->val

    def remember(key, val):
        # Append val under key unless it has been seen for that key already.
        if val not in index[key]:
            index[key].add(val)
            memory[key].append(val)

    for serv in dial["services"]:
        sch = schema.get(serv)
        slot_names = sch["slot_name"]
        slot_iscat = sch["slot_iscat"]

        # add possible values from slot
        for slot, iscat, slotvals in zip(slot_names, slot_iscat, sch["slot_vals"]):
            key = (serv, slot) if iscat else "noncat"
            for val in ["NONE", "dontcare"] + slotvals:
                remember(key, val)

        # add optional slot vals
        for optslots, optvals in zip(sch["intent_optslots"], sch["intent_optvals"]):
            for slot, val in zip(optslots, optvals):
                # list.index raises ValueError for an unknown slot (it never
                # returns -1), so no separate sentinel check is needed.
                iscat = slot_iscat[slot_names.index(slot)]
                key = (serv, slot) if iscat else "noncat"
                remember(key, val)

    return memory, index


def update_memory(schema, dial_turn, memory, memory_index):
    """Add the non-categorical slot values tagged in one turn to the memory.

    For every slot span in every frame of ``dial_turn``, the tagged substring
    of the utterance is extracted.  If the slot is non-categorical, the value
    (with en dashes replaced by ASCII hyphens) is appended to
    ``memory["noncat"]`` unless it is already recorded in
    ``memory_index["noncat"]``.  Categorical slots are ignored.  Both
    containers are updated in place; nothing is returned.

    Raises:
        ValueError: if a tagged slot is not among the service's slots.
    """
    # update only noncat values..
    utter = dial_turn["utterance"]
    for frame in dial_turn["frames"]:
        sch = schema.get(frame["service"])
        slot_names = sch["slot_name"]
        slot_iscat = sch["slot_iscat"]

        for tag in frame["slots"]:
            slot, st, en = tag["slot"], tag["start"], tag["exclusive_end"]
            # list.index raises ValueError for an unknown slot; it never returns -1.
            if slot_iscat[slot_names.index(slot)]:
                continue

            # Normalise BEFORE the membership test.  Testing the raw value
            # against the normalised index would re-append the same value
            # every time an en-dash span is seen again.
            value = re.sub("\u2013", "-", utter[st:en])  # dial 59_00125 turn 14
            key = "noncat"
            if value not in memory_index[key]:
                memory_index[key].add(value)
                memory[key].append(value)


def memory_to_str(memory, key="noncat"):
    """Render the values stored under ``key`` as one " . "-separated string."""
    separator = " . "
    return separator.join(memory[key])


def query_to_str(items):
    """Concatenate the question fragments in ``items``, separated by " . "."""
    separator = " . "
    return separator.join(items)


def _find_answer_span(entries, context, value):
    """Return ``(start, end)`` character offsets of ``value`` inside ``context``.

    ``context`` is expected to be ``entries`` joined by " . ".  When ``value``
    equals one of the entries, the span of that entry is returned, so a value
    that also occurs inside an earlier, longer entry is not mis-located.
    Otherwise the first substring occurrence is used.

    Raises:
        AssertionError: if ``value`` does not occur in ``context``.
    """
    start = -1
    if value in entries:
        pos = entries.index(value)
        candidate = sum(len(entry) for entry in entries[:pos]) + len(" . ") * pos
        # Trust the computed offset only if it really points at the value.
        if context[candidate:candidate + len(value)] == value:
            start = candidate
    if start == -1:
        start = context.find(value)
    assert start != -1, "value {!r} not found in context".format(value)
    return start, start + len(value)


def _slot_candidates(all_slots, active_slots):
    """Yield ``(slot, answer, is_active)`` for every slot of a frame.

    Active slots come first with their first value (en dashes replaced by
    hyphens).  The remaining schema slots follow with the answer ``"NONE"``,
    in schema order so that iteration is the same on every run.
    """
    for slot, values in active_slots.items():
        yield slot, re.sub("\u2013", "-", values[0]), True
    for slot in dict.fromkeys(all_slots):
        if slot not in active_slots:
            yield slot, "NONE", False


def dial_to_samples(schema, dial):
    """Convert one dialogue into a title plus a list of QA paragraphs.

    For every USER turn, one paragraph is emitted whose context is the
    non-categorical memory string, followed by one paragraph per categorical
    slot that received a question.  Each question concatenates the turn
    number, previous system utterance, user utterance, service description,
    slot description, intent description and two formatted flags.  The answer
    is the slot's value located inside the context.

    Slots with a value in the frame state always get a question.  Slots
    without one are answered with ``"NONE"`` and randomly downsampled, with
    at most two categorical and two non-categorical questions per turn
    counted at the time of the draw.

    Args:
        schema: object whose ``get(service)`` returns the flattened schema dict.
        dial: dialogue dict with ``dialogue_id``, ``services`` and ``turns``.

    Returns:
        ``dict(title=<dialogue id>, paragraphs=[dict(context=..., qas=[...]), ...])``.

    Raises:
        AssertionError: if a USER turn is preceded by a non-SYSTEM turn, or an
            answer value cannot be found in its context.
        ValueError: if an intent or slot name is missing from the schema.
    """
    dial_id = dial["dialogue_id"]
    turns = dial["turns"]
    memory, memory_index = init_memory(schema, dial)
    paragraphs = []

    for turnid, turn in enumerate(turns):
        # memory comes from slot tagger
        update_memory(schema, turn, memory, memory_index)

        if turn["speaker"] != "USER":
            continue

        # utter info
        if turnid > 0:
            # Only check the previous speaker when a previous turn exists;
            # turns[turnid - 1] at turnid == 0 would wrap around to the last turn.
            prev_turn = turns[turnid - 1]
            assert prev_turn["speaker"] == "SYSTEM"
            sys_utter = prev_turn["utterance"]
        else:
            sys_utter = "dialogue started"
        usr_utter = turn["utterance"]
        turn_formatted = "this is turn number {}".format(turnid)

        # per turn QAS
        cat_qas = {}  # serv,slot->context,qas
        noncat_qas = dict(context=memory_to_str(memory), qas=[])

        for frame in turn["frames"]:
            serv = frame["service"]
            sch = schema.get(serv)
            serv_desc = sch["desc"]

            # intent info
            intent = frame["state"]["active_intent"]
            if intent == "NONE":
                intent_desc = "NONE"
                intent_formatted = "there is no intent"
            else:
                # list.index raises ValueError for an unknown intent.
                idx = sch["intent_name"].index(intent)
                intent_desc = sch["intent_desc"][idx]
                if sch["intent_istrans"][idx]:
                    intent_formatted = "intent is transactional"
                else:
                    intent_formatted = "intent is not transactional"

            all_slots = sch["slot_name"]
            all_slots_desc = sch["slot_desc"]
            all_slots_iscat = sch["slot_iscat"]
            active_slots = frame["state"]["slot_values"]  # have values

            for slot, value, is_active in _slot_candidates(all_slots, active_slots):
                slot_idx = all_slots.index(slot)
                slot_desc = all_slots_desc[slot_idx]
                slot_iscat = all_slots_iscat[slot_idx]

                if not is_active:
                    # fill NONE valued slots, downsampled;
                    if slot_iscat:
                        asked = sum(len(x["qas"]) for x in cat_qas.values())
                    else:
                        asked = len(noncat_qas["qas"])
                    # Exactly one draw per inactive slot, before the cap check.
                    # NOTE(review): randn() is standard normal, so "> 0.5"
                    # keeps roughly 31% of candidates; if a 50% rate was
                    # intended this should be rand().  Left unchanged.
                    keep = np.random.randn() > 0.5
                    if not keep or asked >= 2:
                        continue

                if slot_iscat:
                    context = memory_to_str(memory, key=(serv, slot))
                    entries = memory[(serv, slot)]
                    formatted_info = "slot is categorical"
                else:
                    context = noncat_qas["context"]
                    entries = memory["noncat"]
                    formatted_info = "slot is not categorical"

                val_st, val_en = _find_answer_span(entries, context, value)
                question_items = [
                    turn_formatted, sys_utter, usr_utter, serv_desc,
                    slot_desc, intent_desc, intent_formatted, formatted_info,
                ]
                qa = dict(
                    id="{}/{}/{}/{}".format(dial_id, turnid, serv, slot),
                    question=query_to_str(question_items),
                    answers=[dict(text=value, answer_start=val_st, answer_end=val_en)],
                    is_impossible=False,
                    is_cat=bool(slot_iscat),
                )

                if slot_iscat:
                    cat_qas[serv, slot] = dict(context=context, qas=[qa])
                else:
                    noncat_qas["qas"].append(qa)

        # push all turn paragraphs. ALL CAT and Non CAT questions are locally grouped
        paragraphs.append(dict(
            context=noncat_qas["context"], qas=noncat_qas["qas"],
        ))
        for x in cat_qas.values():
            paragraphs.append(dict(
                context=x["context"], qas=x["qas"],
            ))

    return dict(
        title=dial_id,
        paragraphs=paragraphs,
    )

In [232]:
def create_dataset(mode, data_dir="../data"):
    """Convert every dialogue file of one split into a single dataset dict.

    Reads ``{data_dir}/{mode}/schema.json`` and all
    ``{data_dir}/{mode}/dialogues_*.json`` files (in sorted filename order),
    and converts each dialogue with ``dial_to_samples``.

    Args:
        mode: split sub-directory name, e.g. ``"train"`` or ``"dev"``.
        data_dir: root directory holding the split sub-directories.

    Returns:
        ``{"version": "v0", "data": [<one entry per dialogue>, ...]}``.
    """
    files = sorted(glob.glob(f"{data_dir}/{mode}/dialogues_*.json"))
    schema = Schemas(f"{data_dir}/{mode}/schema.json")
    dataset = {"version": "v0", "data": []}
    for filename in tqdm(files):
        # Read as UTF-8 explicitly instead of relying on the platform locale.
        with open(filename, encoding="utf-8") as f:
            dial_list = json.load(f)
        dataset["data"].extend(dial_to_samples(schema, dial) for dial in dial_list)
    return dataset


# Convert both splits.  Each call reads that split's schema.json and
# dialogues_*.json files, and shows a tqdm progress bar over the files.
train_dataset = create_dataset("train")
dev_dataset = create_dataset("dev")

HBox(children=(IntProgress(value=0, max=127), HTML(value='')))

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

In [233]:
# Pretty-print the first converted training dialogue for inspection.
print(json.dumps(train_dataset["data"][0], indent=2))

{
  "title": "1_00000",
  "paragraphs": [
    {
      "context": "NONE . dontcare . Mexican . Chinese . Indian . American . Italian . 2019-03-01",
      "qas": [
        {
          "id": "1_00000/0/Restaurants_1/date",
          "question": "this is turn number 0 . dialogue started . I am feeling hungry so I would like to find a place to eat. . A leading provider for restaurant search and reservations . Date for the reservation or to find availability . Find a restaurant of a particular cuisine in a city . intent is not transactional . slot is not categorical",
          "answers": [
            {
              "text": "NONE",
              "answer_start": 0,
              "answer_end": 4
            }
          ],
          "is_impossible": false,
          "is_cat": false
        },
        {
          "id": "1_00000/0/Restaurants_1/cuisine",
          "question": "this is turn number 0 . dialogue started . I am feeling hungry so I would like to find a place to eat. . A leading prov

In [237]:
# all data
# Create the output directory first: open(..., "w") does not create missing
# parent directories, so on a fresh checkout the writes below would fail.
os.makedirs("../data/bert-all/", exist_ok=True)

with open("../data/bert-all/train.json", "w") as f:
    json.dump(train_dataset, f, indent=2)

with open("../data/bert-all/dev.json", "w") as f:
    json.dump(dev_dataset, f, indent=2)

In [234]:
# smaller versions: the first two dialogues of each split
# exist_ok=True replaces the check-then-create pair with a single call.
os.makedirs("../data/bert-litmus/", exist_ok=True)

with open("../data/bert-litmus/train.json", "w") as f:
    small = train_dataset["data"][:2]
    json.dump({"version": "v0-small", "data": small}, f, indent=2)

with open("../data/bert-litmus/dev.json", "w") as f:
    small = dev_dataset["data"][:2]
    json.dump({"version": "v0-small", "data": small}, f, indent=2)

In [235]:
# 10 dialogs: the first ten dialogues of each split
# exist_ok=True replaces the check-then-create pair with a single call.
os.makedirs("../data/bert-10/", exist_ok=True)

with open("../data/bert-10/train.json", "w") as f:
    json.dump({"version": "v0-10", "data": train_dataset["data"][:10]}, f, indent=2)

with open("../data/bert-10/dev.json", "w") as f:
    json.dump({"version": "v0-10", "data": dev_dataset["data"][:10]}, f, indent=2)

In [236]:
# 100 dialogs: the first hundred dialogues of each split
# exist_ok=True replaces the check-then-create pair with a single call.
os.makedirs("../data/bert-100/", exist_ok=True)

# The version label is "v0-100" here; it previously read "v0-10", the label
# of the 10-dialogue subset.
with open("../data/bert-100/train.json", "w") as f:
    json.dump({"version": "v0-100", "data": train_dataset["data"][:100]}, f, indent=2)

with open("../data/bert-100/dev.json", "w") as f:
    json.dump({"version": "v0-100", "data": dev_dataset["data"][:100]}, f, indent=2)