In [1]:
# Widen the notebook container so wide tables and tensor dumps are not clipped.
# `display` and `HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [9]:
import glob
import os
import json
import numpy as np
import re
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from collections import OrderedDict, defaultdict
from functools import lru_cache
from ipdb import set_trace
from tqdm import tqdm_notebook as tqdm
from tabulate import tabulate

from allennlp.data import Instance, Token
from allennlp.data.fields import TextField, ListField, MetadataField, ArrayField, IndexField, Field, AdjacencyField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.tokenizers import Token
from allennlp.data.token_indexers import PretrainedBertIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers.word_splitter import BertBasicWordSplitter
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.iterators import BasicIterator
from allennlp.models import Model, SimpleSeq2Seq
from allennlp.modules.token_embedders import PretrainedBertEmbedder
from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertModel
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.training import Trainer
from allennlp.training.learning_rate_schedulers import SlantedTriangular, NoamLR, CosineWithRestarts
from allennlp.training.moving_average import ExponentialMovingAverage
from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy
from allennlp.nn.util import get_text_field_mask, move_to_device

In [2]:
# Seed every random number generator up front so that a fresh
# "Restart Kernel -> Run All" is repeatable. Originally only NumPy was seeded.
import random

SEED = 1
random.seed(SEED)        # Python's built-in RNG
np.random.seed(SEED)     # NumPy's global RNG
torch.manual_seed(SEED)  # PyTorch's RNG

In [3]:
class Schema(object):
    """Lookup table over a schema JSON file, keyed by service name."""

    def __init__(self, filepath):
        """Load the JSON list stored at ``filepath`` and index each entry by ``service_name``."""
        with open(filepath) as handle:
            self.index = {entry["service_name"]: entry for entry in json.load(handle)}

    @lru_cache(maxsize=None)
    def get(self, service):
        """Return a column-oriented view of one service's slots and intents.

        The result is a dict of parallel lists: position ``i`` of every
        ``slot_*`` list describes the i-th slot, and position ``j`` of every
        ``intent_*`` list describes the j-th intent. Results are memoised, so
        repeated calls return the same dict object.

        Raises:
            KeyError: if ``service`` is not present in the loaded schema.
        """
        spec = self.index[service]
        slots = spec["slots"]
        intents = spec["intents"]

        return dict(
            # service
            name=service,
            desc=spec["description"],

            # slots (possible values are recorded for every slot, categorical or not)
            slot_name=[s["name"] for s in slots],
            slot_desc=[s["description"] for s in slots],
            slot_iscat=[s["is_categorical"] for s in slots],
            slot_vals=[s["possible_values"] for s in slots],

            # intents
            intent_name=[i["name"] for i in intents],
            intent_desc=[i["description"] for i in intents],
            intent_istrans=[i["is_transactional"] for i in intents],
            intent_reqslots=[i["required_slots"] for i in intents],
            intent_optslots=[list(i["optional_slots"].keys()) for i in intents],
            intent_optvals=[list(i["optional_slots"].values()) for i in intents],
        )

In [8]:
# Scratch check of the tokenizer API: confirm what `split_words` returns for a
# string that contains a literal "[SEP]" (the cell output shows the type is `list`).
t=BertBasicWordSplitter()
type(t.split_words("hi [SEP] hi"))

list

In [13]:
class Memory(object):
    """Ordered, duplicate-free store of candidate slot values for one dialogue.

    Categorical slots get one candidate list per ``(service, slot)`` key. All
    non-categorical slots share a single list stored under the key ``"noncat"``.
    ``self.memory`` keeps values in insertion order; ``self.index`` mirrors it
    with sets for fast membership tests.
    """

    def __init__(self, schema, services):
        """Seed the memory from the schema of every service in ``services``.

        Args:
            schema: object whose ``get(service)`` returns a dict with the keys
                ``slot_name``, ``slot_iscat``, ``slot_vals``,
                ``intent_optslots`` and ``intent_optvals``.
            services: iterable of service names.

        Raises:
            ValueError: if an optional slot of an intent is not listed among
                the service's slot names.
        """
        self.schema = schema

        # memory for cat slots: serv,slot=>[] or for noncat slots: noncat => []
        self.memory = defaultdict(list)
        self.index = defaultdict(set) # serv,slot->val, noncat->val

        for serv in services:
            sch = schema.get(serv)

            # "NONE" and "dontcare" are always candidates, followed by the
            # possible values the schema lists for the slot.
            for slot, iscat, slotvals in zip(sch["slot_name"], sch["slot_iscat"], sch["slot_vals"]):
                key = (serv, slot) if iscat else "noncat"
                for val in ["NONE", "dontcare"] + slotvals:
                    self._add(key, val)

            # Values attached to the optional slots of each intent.
            for optslots, optvals in zip(sch["intent_optslots"], sch["intent_optvals"]):
                for slot, val in zip(optslots, optvals):
                    # list.index raises ValueError for an unknown slot; it never returns -1.
                    slotid = sch["slot_name"].index(slot)
                    key = (serv, slot) if sch["slot_iscat"][slotid] else "noncat"
                    self._add(key, val)

    def _add(self, key, value):
        """Append ``value`` to the list stored under ``key`` unless it is already there."""
        if value not in self.index[key]:
            self.index[key].add(value)
            self.memory[key].append(value)

    def update(self, dial_turn):
        """Add the non-categorical span values tagged in ``dial_turn`` to the shared memory.

        Each tag in ``frame["slots"]`` selects ``utterance[start:exclusive_end]``.
        The en dash (U+2013) is replaced by "-" *before* the duplicate check, so
        a span that differs from a stored value only by that character is not
        appended a second time.

        Raises:
            ValueError: if a tagged slot is not listed among the service's slot names.
        """
        utter = dial_turn["utterance"]
        for frame in dial_turn["frames"]:
            sch = self.schema.get(frame["service"])
            slot_names = sch["slot_name"]
            slot_iscat = sch["slot_iscat"]

            for tag in frame["slots"]:
                slotid = slot_names.index(tag["slot"])
                if slot_iscat[slotid]:
                    # categorical candidates come from the schema only
                    continue
                value = re.sub("\u2013", "-", utter[tag["start"]:tag["exclusive_end"]]) # dial 59_00125 turn 14
                self._add("noncat", value)

    def get(self, key="noncat"):
        """Return the live candidate list for ``key`` (the shared non-categorical list by default).

        The returned list is the stored object itself, not a copy; later
        ``update`` calls may append to it.
        """
        return self.memory[key]
    
    
class DialogReader(DatasetReader):
    """Dataset reader that turns dialogue JSON files into candidate-selection instances.

    One instance is produced per (user turn, frame, slot present in the frame's
    ``slot_values``). Each instance pairs four query texts (previous system
    utterance, user utterance, service description, slot description) with
    every candidate value currently held in the dialogue's ``Memory``.
    """

    def __init__(self, schema, limit, lazy=False):
        """
        Args:
            schema: ``Schema``-like object providing ``get(service)``.
            limit: maximum number of dialogues to read across all matched files.
            lazy: forwarded to ``DatasetReader``.
        """
        super().__init__(lazy)
        self.token_indexers = {"tokens": PretrainedBertIndexer("bert-base-uncased")}
        self.tokenizer = BertBasicWordSplitter()
        self.schema = schema
        self.limit = limit

    def _load_dialogs(self, path):
        """Return at most ``self.limit`` dialogues from the files matching the glob ``path``.

        Files are visited in sorted filename order. The previous ``count > limit``
        check let ``limit + 1`` dialogues through.
        """
        dialogs = []
        for filename in sorted(glob.glob(path)):
            remaining = self.limit - len(dialogs)
            if remaining <= 0:
                break
            with open(filename) as f:
                dialogs.extend(json.load(f)[:remaining])
        return dialogs

    def _read(self, path):
        """Yield one ``Instance`` per slot value found in the state of each user turn."""
        for dial in self._load_dialogs(path):
            memory = Memory(self.schema, dial["services"])
            for turnid, turn in enumerate(dial["turns"]):
                # The memory is updated on every turn (system turns included)
                # before any instance for that turn is built.
                memory.update(turn)
                if turn["speaker"] != "USER":
                    continue

                usr_utter = turn["utterance"]
                sys_utter = dial["turns"][turnid-1]["utterance"] if turnid > 0 else "dialog started"

                for frame in turn["frames"]:
                    # get schema info
                    serv = frame["service"]
                    sch = self.schema.get(serv)

                    # intent ("NONE" keeps the placeholder description)
                    intent = frame["state"]["active_intent"]
                    all_intents = {s: i for i, s in enumerate(sch["intent_name"])}
                    intent_istrans = False
                    intent_desc = "No intent"
                    if intent != "NONE":
                        intentid = all_intents[intent]  # KeyError on an unknown intent
                        intent_desc = sch["intent_desc"][intentid]
                        intent_istrans = sch["intent_istrans"][intentid]

                    # slots
                    all_slots = {s: i for i, s in enumerate(sch["slot_name"])}
                    all_slots_iscat = sch["slot_iscat"]
                    all_slots_desc = sch["slot_desc"]
                    active_slots = frame["state"]["slot_values"]

                    for slot, values in active_slots.items():
                        slotid = all_slots[slot]  # KeyError on an unknown slot
                        key = (serv, slot) if all_slots_iscat[slotid] else "noncat"
                        # Only the first listed value is used as the target; the
                        # en dash is normalised the same way Memory.update does.
                        target_value = re.sub("\u2013", "-", values[0])

                        item = dict(
                            dialid=dial["dialogue_id"],
                            turnid=turnid,
                            usr_utter=usr_utter,
                            sys_utter=sys_utter,
                            serv=serv,
                            serv_desc=sch["desc"],
                            slot=slot,
                            slot_desc=all_slots_desc[slotid],
                            slot_iscat=all_slots_iscat[slotid],
                            slot_val=target_value,
                            intent=intent,
                            intent_desc=intent_desc,
                            intent_istrans=intent_istrans,
                            memory=memory.get(key),
                        )
                        yield self.text_to_instance(item)

                    # TODO: optionally also emit a few "NONE"-target instances per
                    # user turn for schema slots that are absent from the state
                    # (an earlier, disabled draft sampled up to 3 of them).

    def text_to_instance(self, item):
        """Build an ``Instance`` from the ``item`` dict assembled in ``_read``.

        For every candidate in ``item["memory"]`` each query text is encoded as
        ``query tokens + [SEP] + candidate tokens``. ``memory_loc`` holds 1 at
        the positions whose candidate equals ``item["slot_val"]`` and 0
        elsewhere (padded with -1 when batched).
        """
        fields = {}
        SEP_token = Token("[SEP]")

        # Snapshot the candidates: ``item["memory"]`` is the live list owned by
        # ``Memory`` and keeps growing on later turns. Without the copy, the
        # ``memory_values`` metadata would no longer line up with ``memory_loc``.
        memory_values = list(item["memory"])

        # query part tokens
        sys_utter_tokens = self.tokenizer.split_words(item["sys_utter"])
        usr_utter_tokens = self.tokenizer.split_words(item["usr_utter"])
        serv_desc_tokens = self.tokenizer.split_words(item["serv_desc"])
        slot_desc_tokens = self.tokenizer.split_words(item["slot_desc"])

        # query part features: [query ; candidate-value] for each candidate
        sys_utter = []
        usr_utter = []
        serv_desc = []
        slot_desc = []

        # memory features
        mem_loc = []

        for mem_val in memory_values:
            mem_tokens = self.tokenizer.split_words(mem_val)
            mem_loc.append(int(mem_val == item["slot_val"]))
            sys_utter.append(TextField(sys_utter_tokens + [SEP_token] + mem_tokens, self.token_indexers))
            usr_utter.append(TextField(usr_utter_tokens + [SEP_token] + mem_tokens, self.token_indexers))
            serv_desc.append(TextField(serv_desc_tokens + [SEP_token] + mem_tokens, self.token_indexers))
            slot_desc.append(TextField(slot_desc_tokens + [SEP_token] + mem_tokens, self.token_indexers))

        fields["sys_utter"] = ListField(sys_utter)
        fields["usr_utter"] = ListField(usr_utter)
        fields["serv_desc"] = ListField(serv_desc)
        fields["slot_desc"] = ListField(slot_desc)
        fields["memory_loc"] = ArrayField(np.array(mem_loc), padding_value=-1)

        # positional fields
        fields["turnid"] =  ArrayField(np.array(item["turnid"]), padding_value=-1)

        # meta fields
        fields["id"] = MetadataField("{}/{}/{}/{}".format(item["dialid"], item["turnid"], item["serv"], item["slot"]))
        fields["slot"] = MetadataField(item["slot"])
        fields["serv"] = MetadataField(item["serv"])
        fields["intent"] = MetadataField(item["intent"])
        fields["dialid"] = MetadataField(item["dialid"])
        fields["memory_values"] = MetadataField(memory_values)

        return Instance(fields)

In [14]:
# Load the service schemas for the train and dev splits (one Schema per split).
train_schema = Schema("../data/train/schema.json")
dev_schema = Schema("../data/dev/schema.json")

In [15]:
# Read the datasets. `limit` caps how many dialogues DialogReader loads per
# split, so this is a subset unless the split is smaller than the limit.
reader = DialogReader(train_schema, limit=1000)
train_ds = reader.read("../data/train/dialogues*.json")

# NOTE: `reader` is rebound here; after this cell it refers to the dev reader.
reader = DialogReader(dev_schema, limit=50)
dev_ds = reader.read("../data/dev/dialogues*.json")

# Vocabulary is built over both splits.
vocab = Vocabulary.from_instances(train_ds + dev_ds)


# Pull a single batch of 32 instances; later cells reuse `batch` for
# inspection and for the CPU litmus test.
it = BasicIterator(batch_size=32)
it.index_with(vocab)
batch = next(iter(it(train_ds)))
batch.keys()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=25316), HTML(value='')))




dict_keys(['sys_utter', 'usr_utter', 'serv_desc', 'slot_desc', 'memory_loc', 'turnid', 'id', 'slot', 'serv', 'intent', 'dialid', 'memory_values'])

In [19]:
# Inspect the "tokens-type-ids" tensor of the first instance in the batch
# (presumably one row per memory candidate — TODO confirm).
batch["sys_utter"]["tokens-type-ids"][0]

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0,
         0, 0

In [21]:
# Print the tensor shape of every batch field: plain tensors directly, and
# indexed text fields through their "tokens" entry. Other fields are skipped.
for field_name, field_value in batch.items():
    if type(field_value) is torch.Tensor:
        shape = field_value.shape
    elif type(field_value) is dict and type(field_value["tokens"]) is torch.Tensor:
        shape = field_value["tokens"].shape
    else:
        continue
    print(field_name, "->", shape)

sys_utter -> torch.Size([32, 20, 44])
usr_utter -> torch.Size([32, 20, 27])
serv_desc -> torch.Size([32, 20, 24])
slot_desc -> torch.Size([32, 20, 19])
memory_loc -> torch.Size([32, 20])
turnid -> torch.Size([32])


In [9]:
# Drop every entry from the class-level BERT model cache. The key list is
# materialised first because the cache is mutated while we walk it.
for cached_name in list(PretrainedBertModel._cache.keys()):
    PretrainedBertModel._cache.pop(cached_name)

In [10]:
class CandidateSelector(Model):
    """Model that scores memory candidates against a query and picks one.

    NOTE(review): as written in this cell the class is incomplete. The methods
    below use ``self.dec``, ``self.final``, ``self.cat_final`` and
    ``self.noncat_final``, none of which are assigned in ``__init__``.
    ``forward`` reads the batch keys ``"query"``, ``"memory"`` and
    ``"query_iscat"``, whereas the ``DialogReader`` in this notebook emits
    ``sys_utter`` / ``usr_utter`` / ``serv_desc`` / ``slot_desc``. This cell
    appears to come from a different revision of the notebook — reconcile
    before running.
    """
    
    def __init__(self, vocab):
        super().__init__(vocab)
        # query encoder (BERT weights are fine-tuned: requires_grad=True)
        self.emb = PretrainedBertEmbedder("bert-base-uncased", requires_grad=True)
        # NOTE(review): emb_dim is computed but never used in the visible code;
        # presumably the missing decoder / output layers were sized with it.
        emb_dim = self.emb.get_output_dim()
        
        # metrics
        self.accuracy = BooleanAccuracy()
        
    def get_metrics(self, reset=False):
        """Return the accuracy metric under the key ``"acc"``."""
        return {"acc": self.accuracy.get_metric(reset)}
    
    def encoder(self, batch):
        """Encode query and flattened memory in one pass, then split them apart.

        Query and memory token ids / offsets / type ids are concatenated along
        the last axis (memory candidates flattened over dims 1-2, with memory
        type ids shifted by +1), embedded together, and the output is split
        back at the query length.
        """
        q = batch["query"]
        m = batch["memory"]
        x_tokens = torch.cat((q["tokens"], m["tokens"].flatten(1, 2)), -1) # b,s1+s2
        x_offsets = torch.cat((q["tokens-offsets"], m["tokens-offsets"].flatten(1,2)), -1)
        x_types = torch.cat((q["tokens-type-ids"], (m["tokens-type-ids"]+1).flatten(1,2)), -1)
        
        enc = self.emb(x_tokens, x_offsets, x_types)
        
        q_len = q["mask"].shape[-1]
        tot = enc.shape[1]

        query_enc, mem_enc = enc.split([q_len, tot - q_len], 1)
        
        # restore the (batch, mem, seq) layout of the memory mask, plus a trailing emb axis
        mem_sh = list(m["mask"].shape)
        mem_enc = mem_enc.view(mem_sh + [-1])
        
        return query_enc, mem_enc
    
    
    def decoder(self, query, memory):
        """Sum each candidate over its token axis and pass it with the query through ``self.dec``."""
        # query: encoder output -- batch, seq, emb
        # memory: decoder input -- batch, mem, seq, emb
        memory = memory.sum(2)
        
        query = query.permute(1,0,2) # seq, batch, emb
        memory = memory.permute(1,0,2)
        
        # NOTE(review): self.dec is not defined in __init__ (see class docstring).
        x = self.dec(memory, query)
        x = x.permute(1,0,2)
        
        return x # bme
    
    def mse_loss(self, decoded, target):
        """Softmax over candidates followed by a masked MSE against ``target``; not called by ``forward``."""
        # decoded: batch, mem, emb 
        # tgt: batch, mem
        # NOTE(review): self.final is not defined in __init__.
        predicted = F.softmax(self.final(decoded).squeeze(-1), -1) # batch, mem
        
        # loss (positions where target == -1 are zeroed on both sides)
        mask = (target != -1).float()
        loss = F.mse_loss(predicted * mask, target * mask).unsqueeze(0)
        
        # metric
        predicted_loc = predicted.argmax(-1)
        target_loc = target.argmax(-1)

        mask_loc = mask[torch.arange(mask.shape[0]), target_loc]
        self.accuracy(predicted_loc, target_loc, mask_loc.long())
    
        return loss, predicted_loc, target_loc
    
    def bce_loss(self, decoded, target, query_iscat):
        """Cross-entropy over candidates; this is the loss ``forward`` uses.

        Despite the name, this calls ``F.cross_entropy`` with the argmax of
        ``target`` as the class index. ``query_iscat`` is accepted but unused.
        """
        # decoded: batch, mem, emb
        # tgt: batch, mem
        # NOTE(review): self.cat_final is not defined in __init__.
        predicted = self.cat_final(decoded).squeeze(-1)
        
        # loss
        # NOTE(review): argmax never yields -1, so ignore_index=-1 cannot trigger here.
        loss = F.cross_entropy(predicted, target.argmax(-1), ignore_index=-1).unsqueeze(0)
        
        # metric
        predicted_loc = predicted.argmax(-1)
        target_loc = target.argmax(-1)

        mask = (target != -1).float()
        mask_loc = mask[torch.arange(mask.shape[0]), target_loc]
        self.accuracy(predicted_loc, target_loc, mask_loc.long())
        
        return loss, predicted_loc, target_loc
    
    def bce_loss_1(self, decoded, target, query_iscat):
        """Variant: route categorical and non-categorical rows through separate heads by indexing; not called by ``forward``."""
        # decoded: batch, mem, emb
        # tgt: batch, mem
        cat = (query_iscat == 1).nonzero().squeeze(-1)
        noncat = (query_iscat == 0).nonzero().squeeze(-1)
        
        cat_decoded = decoded[cat]
        cat_out = self.cat_final(cat_decoded).squeeze(-1) # bm
        cat_target = target.argmax(-1)[cat]
        cat_loss = 0
        if len(cat_out) > 0 and len(cat_target) > 0:
            cat_loss = F.cross_entropy(cat_out, cat_target, ignore_index=-1).unsqueeze(0)
        
        noncat_decoded = decoded[noncat]
        noncat_out = self.noncat_final(noncat_decoded).squeeze(-1)
        noncat_target = target.argmax(-1)[noncat]
        noncat_loss = 0
        if len(noncat_out) > 0 and len(noncat_target) > 0:
            noncat_loss = F.cross_entropy(noncat_out, noncat_target, ignore_index=-1).unsqueeze(0)
        
        loss = cat_loss + noncat_loss
        
        # scatter both heads' scores back into one (batch, mem) tensor
        predicted = torch.zeros(decoded.shape[:2], device=target.device).float() # bm
        predicted[cat] = cat_out
        predicted[noncat] = noncat_out
        
        # metric
        predicted_loc = predicted.argmax(-1)
        target_loc = target.argmax(-1)

        mask = (target != -1).float()
        mask_loc = mask[torch.arange(mask.shape[0]), target_loc]
        self.accuracy(predicted_loc, target_loc, mask_loc.long())
        
        return loss, predicted_loc, target_loc
    
    def bce_loss_2(self, decoded, target, query_iscat):
        """Variant: run both heads on every row and combine them with 0/1 masks; not called by ``forward``."""
        # decoded: batch, mem, emb
        # tgt: batch, mem
        cat_mask = (query_iscat == 1).float() # b
        noncat_mask = (query_iscat == 0).float()
        
        cat = self.cat_final(decoded).squeeze(-1) # bm
        noncat = self.noncat_final(decoded).squeeze(-1)
        
        predicted =  cat * cat_mask[:,None] + noncat * noncat_mask[:,None]

        # loss
        # NOTE(review): rows excluded by a mask are not skipped; they enter the
        # loss with all-zero logits and class index 0 — confirm this is intended.
        cat_loss = F.cross_entropy(predicted * cat_mask[:,None], target.argmax(-1) * cat_mask.long(), ignore_index=-1)
        noncat_loss = F.cross_entropy(predicted * noncat_mask[:,None], target.argmax(-1) * noncat_mask.long(), ignore_index=-1)
        loss = cat_loss + noncat_loss
        
        # metric
        predicted_loc = predicted.argmax(-1)
        target_loc = target.argmax(-1)

        mask = (target != -1).float()
        mask_loc = mask[torch.arange(mask.shape[0]), target_loc]
        self.accuracy(predicted_loc, target_loc, mask_loc.long())
        
        return loss, predicted_loc, target_loc
    
    def forward(self, **batch):
        """Encode, decode and score the batch; returns ``loss``, ``pred`` and ``target`` indices."""
        query, memory = self.encoder(batch)
        
        decoded = self.decoder(query, memory) # batch, mem, emb

        target = batch["memory_loc"] # [batch, mem]
        query_iscat = batch["query_iscat"]
        loss, predicted_loc, target_loc = self.bce_loss(decoded, target, query_iscat)
        
        output = dict(
            loss=loss,
            pred=predicted_loc,
            target=target_loc,
        )
        
        return output

In [30]:
# litmus test on CPU: build the model, move the sample batch with
# cuda_device=-1, and run one forward pass to check that it executes end to end.
model = CandidateSelector(vocab).to("cpu")
batch = move_to_device(batch, -1)
model(**batch)

{'loss': tensor(5.8197, grad_fn=<AddBackward0>),
 'pred': tensor([ 4,  3,  5,  7,  5, 13, 10,  5,  8, 17,  1,  2,  4,  7,  6,  9,  6,  1,
          1, 16,  5,  6,  8,  9, 14, 14,  0, 12,  9,  5,  1,  6]),
 'target': tensor([ 9,  3,  3,  3,  9,  1,  8,  9,  2,  9, 13,  9,  2,  2,  3, 15,  9,  3,
         11,  9,  8,  9,  8,  4,  6,  2, 10,  8,  3, 17,  3, 11])}

In [11]:
# Device selection. The notebook was developed on GPU index 2; hard-coding that
# index crashes on hosts with fewer GPUs, so fall back to GPU 0 and then to CPU.
PREFERRED_GPU = 2
if torch.cuda.is_available():
    allen_device = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    torch_device = torch.device("cuda", allen_device)
else:
    allen_device = -1  # AllenNLP's convention for "run on CPU"
    torch_device = torch.device("cpu")

model = CandidateSelector(vocab).to(torch_device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

iterator = BasicIterator(batch_size=16)
iterator.index_with(vocab)

# Only needed if a learning-rate scheduler is re-enabled below.
num_steps = iterator.get_num_batches(train_ds)
lr_scheduler = None #SlantedTriangular(optimizer, 3, num_steps)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    num_epochs=3,
    cuda_device=allen_device,
    #learning_rate_scheduler=lr_scheduler,
    #serialization_dir="../results/4",
    #should_log_learning_rate=True,
    #histogram_interval=50,
    #num_serialized_models_to_keep=1,
    grad_norm=5,
    shuffle=False,
)

trainer.train()

HBox(children=(IntProgress(value=0, max=1516), HTML(value='')))

KeyboardInterrupt: 

In [65]:
# inference on the first dev instance
results = {}
test_iterator = BasicIterator(batch_size=1)
test_iterator.index_with(vocab)

sample = next(iter(test_iterator(dev_ds, shuffle=False)))
sample = move_to_device(sample, allen_device)

key = sample["id"][0]

# Evaluate in eval mode (no dropout) and without building an autograd graph;
# the original ran this forward pass in training mode.
model.eval()
with torch.no_grad():
    output = model(**sample)

t_loc = int(output["target"].item())
p_loc = int(output["pred"].item())

# Map the predicted / target indices back to the candidate strings.
target_val = sample["memory_values"][0][t_loc]
pred_val = sample["memory_values"][0][p_loc]

target_val, pred_val

('2', '2')

In [66]:
def predictor(model, test_ds, device):
    """Run ``model`` over ``test_ds`` and collect one prediction per instance.

    Args:
        model: model whose forward output contains ``"target"`` and ``"pred"``
            index tensors of shape ``(batch,)``.
        test_ds: instances to score; each must carry ``id`` and
            ``memory_values`` metadata.
        device: AllenNLP-style integer device id (``-1`` for CPU, otherwise a
            CUDA index).

    Returns:
        An ``OrderedDict`` mapping instance id to
        ``(target_value, predicted_value, is_correct)``. An index that falls
        outside the instance's candidate list is reported as ``"UNK"``.

    Note:
        Uses the notebook-level ``vocab``. The model is moved to ``device`` and
        left in eval mode.
    """
    results = OrderedDict()
    test_iterator = BasicIterator(batch_size=32)
    test_iterator.index_with(vocab)

    # `move_to_device` only moves tensors (and containers of tensors); it hands
    # an nn.Module back untouched, so the model has to be moved explicitly.
    model = model.cuda(device) if device >= 0 else model.cpu()
    model = model.eval()

    for sample in tqdm(test_iterator(test_ds, shuffle=False, num_epochs=1)):
        sample = move_to_device(sample, device)
        with torch.no_grad():
            output = model(**sample)

        num_samples = output["target"].shape[0]
        for i in range(num_samples):
            key = sample["id"][i]
            candidates = sample["memory_values"][i]
            t_loc = int(output["target"][i].item())
            p_loc = int(output["pred"][i].item())

            # Indices can point into batch padding, beyond this instance's candidates.
            t_val = candidates[t_loc] if t_loc < len(candidates) else "UNK"
            p_val = candidates[p_loc] if p_loc < len(candidates) else "UNK"

            results[key] = (t_val, p_val, t_loc == p_loc)

    return results

In [67]:
# NOTE(review): the variable names are swapped relative to their contents —
# `tr_results` holds the predictions for `dev_ds` and `de_results` those for
# `train_ds`. Later cells pair them with the matching raw datasets, so the
# pipeline is consistent, but the names are misleading.
tr_results = predictor(model, dev_ds, allen_device)
de_results = predictor(model, train_ds, allen_device)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [68]:
def results_to_dial(dev_ds, results):
    """Convert flat per-slot predictions into dialogue-shaped dicts.

    Args:
        dev_ds: list of gold dialogues (dicts with ``dialogue_id``,
            ``services`` and ``turns``); used for the service list and the
            utterance text of every turn.
        results: mapping from ``"dial/turn/serv/slot"`` keys to
            ``(target_value, pred_value, is_correct)`` tuples.

    Returns:
        A list with one dict per dialogue that appears in ``results``, in
        first-seen order. Output turns follow the gold layout: the user turn
        sits at its gold index ``t`` (assumed even) and is followed by the
        gold system turn ``t + 1``. Every user turn has one frame per service;
        predicted values are stored under ``state["slot_values"]``.

    Raises:
        KeyError: if a result refers to a dialogue id missing from ``dev_ds``.
    """
    gold = {d["dialogue_id"]: d for d in dev_ds}

    dialogs = OrderedDict()

    # init dataset
    for key in results:
        dial, _turn, _serv, _slot = key.split("/")
        if dial not in dialogs:
            # GOLD: set services..
            dialogs[dial] = dict(dialogue_id=dial, services=gold[dial]["services"], turns=[])

    # fill dataset
    turn_exists = set()
    slot_exists = set()

    for key, res in tqdm(results.items()):
        dial, turn, serv, slot = key.split("/")
        turn = int(turn)
        _target_value, pred_value, _iscorrect = res

        # create turn. turns with no slots need to be created explicitly.
        if (dial, turn) not in turn_exists:
            gold_turns = gold[dial]["turns"]
            out_turns = dialogs[dial]["turns"]
            for t in range(0, turn + 1, 2):
                if (dial, t) in turn_exists:
                    continue

                # GOLD: set utterances. script gives error otherwise
                usr_turn = dict(
                    speaker="USER",
                    utterance=gold_turns[t]["utterance"],
                    # one empty frame per service of the dialogue
                    frames=[
                        dict(
                            service=s,
                            state=dict(active_intent="", slot_values={}, requested_slots=[]),
                            slots=[],
                        )
                        for s in dialogs[dial]["services"]
                    ],
                )
                out_turns.append(usr_turn)

                # The system turn stored after user turn t is gold turn t + 1.
                # The original copied the utterance of turn t - 1 (and of the
                # last turn, via index -1, when t == 0).
                if t + 1 < len(gold_turns):
                    sys_turn = dict(
                        speaker="SYSTEM",
                        utterance=gold_turns[t + 1]["utterance"],
                        frames=[],
                    )
                    out_turns.append(sys_turn)

                turn_exists.add((dial, t))

        # fill slots
        for frame in dialogs[dial]["turns"][turn]["frames"]:
            if frame["service"] == serv:
                if (dial, turn, serv, slot) not in slot_exists:
                    frame["state"]["slot_values"][slot] = [pred_value]
                    slot_exists.add((dial, turn, serv, slot))

    return list(dialogs.values())

In [74]:
def load_dialogues(pattern):
    """Return the concatenated JSON lists of every file matching the glob ``pattern``.

    Files are read in sorted filename order; each file is expected to contain
    a JSON list. Returns an empty list when nothing matches.
    """
    dialogues = []
    for fname in sorted(glob.glob(pattern)):
        with open(fname) as f:
            dialogues.extend(json.load(f))
    return dialogues


# Raw (un-tokenised) gold dialogues, used to rebuild dialogue-shaped predictions.
raw_dev_ds = load_dialogues("../data/dev/dialogues*.json")
raw_train_ds = load_dialogues("../data/train/dialogues*.json")

In [78]:
# Rebuild dialogue-shaped predictions for the dev split and write them where
# the evaluation script's --prediction_dir points.
# (`tr_results` is the predictor output for `dev_ds`, despite its name.)
dial_results = results_to_dial(raw_dev_ds, tr_results)

with open("../results/4/out/dialogues.json", "w") as f:
    json.dump(dial_results, f, indent=2)

HBox(children=(IntProgress(value=0, max=132), HTML(value='')))




In [79]:
# Rebuild dialogue-shaped predictions for the train split.
# (`de_results` is the predictor output for `train_ds`, despite its name.)
tr_dial_results = results_to_dial(raw_train_ds, de_results)

# Write to a separate directory: the dev predictions already live in
# ../results/4/out/dialogues.json and would be overwritten otherwise.
train_out_dir = "../results/4/out_train"
os.makedirs(train_out_dir, exist_ok=True)
with open(os.path.join(train_out_dir, "dialogues.json"), "w") as f:
    json.dump(tr_dial_results, f, indent=2)

HBox(children=(IntProgress(value=0, max=24253), HTML(value='')))




In [59]:
# eval command
#
# python -m schema_guided_dst.evaluate --dstc8_data_dir data --prediction_dir results/4/out --eval_set dev --output_metric_file results/4/eval.json

/home/suryak/dstc8/venv/bin/python: Error while finding module specification for '/home/suryak/dstc8/schema_guided_dst.evaluate' (ModuleNotFoundError: No module named '/home/suryak/dstc8/schema_guided_dst')


In [72]:
# Display the rebuilt train-split dialogues (this dumps the whole list; the
# recorded output is an empty list).
tr_dial_results

[]

In [17]:
pd.set_option('display.max_rows', 500)
# Tabulate the dev-set predictions returned by `predictor` (stored in
# `tr_results`). The original read `results`, which the visible notebook only
# ever sets to an empty dict, so a fresh run would show an empty table. The
# dict views are materialised as lists so pandas gets plain sequences.
pd.DataFrame(
    list(tr_results.values()),
    index=list(tr_results.keys()),
    columns=["target", "pred", "correct"],
)

Unnamed: 0,target,pred,correct
1_00000/0/Restaurants_2/number_of_seats,2,2,True
1_00000/0/Restaurants_2/time,half past 11 in the morning,half past 11 in the morning,True
1_00000/2/Restaurants_2/location,San Jose,San Jose,True
1_00000/2/Restaurants_2/number_of_seats,2,2,True
1_00000/2/Restaurants_2/restaurant_name,Sino,Sino,True
1_00000/2/Restaurants_2/time,half past 11 in the morning,San Jose,False
1_00000/4/Restaurants_2/date,today,San Jose,False
1_00000/4/Restaurants_2/location,San Jose,San Jose,True
1_00000/4/Restaurants_2/number_of_seats,2,2,True
1_00000/4/Restaurants_2/restaurant_name,Sino,11:30 am,False
