In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [1]:
import glob
import os
import json
import numpy as np
import re

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from collections import OrderedDict, defaultdict
from functools import lru_cache
from ipdb import set_trace

from allennlp.data import Instance, Token
from allennlp.data.fields import TextField, ListField, MetadataField, ArrayField, IndexField, Field, AdjacencyField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.token_indexers import PretrainedBertIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers.word_splitter import BertBasicWordSplitter
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.iterators import BasicIterator
from allennlp.models import Model, SimpleSeq2Seq
from allennlp.modules.token_embedders import PretrainedBertEmbedder
from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper
from allennlp.training import Trainer
from allennlp.training.moving_average import ExponentialMovingAverage
from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy
from allennlp.nn.util import get_text_field_mask, move_to_device

In [2]:
np.random.seed(1)

In [3]:
class Schema(object):

    def __init__(self, filepath):
        with open(filepath) as f:
            self.index = {}
            for schema in json.load(f):
                service_name = schema["service_name"]
                self.index[service_name] = schema

    @lru_cache(maxsize=None)
    def get(self, service):
        result = dict(
            # service
            name=service,
            desc=self.index[service]["description"],
            
            # slots
            slot_name=[],
            slot_desc=[],
            slot_iscat=[], 
            slot_vals=[], # collected only for cat slots.. not sure if that makes sense

            # intents
            intent_name=[],
            intent_desc=[],
            intent_istrans=[],
            intent_reqslots=[],
            intent_optslots=[],
            intent_optvals=[],
        )

        for slot in self.index[service]["slots"]:
            result["slot_name"].append(slot["name"])
            result["slot_desc"].append(slot["description"])
            result["slot_iscat"].append(slot["is_categorical"])
            result["slot_vals"].append(slot["possible_values"])
        
        for intent in self.index[service]["intents"]:
            result["intent_name"].append(intent["name"])
            result["intent_desc"].append(intent["description"])
            result["intent_istrans"].append(intent["is_transactional"])
            result["intent_reqslots"].append(intent["required_slots"])
            result["intent_optslots"].append(list(intent["optional_slots"].keys()))
            result["intent_optvals"].append(list(intent["optional_slots"].values()))

        return result

In [4]:
class Memory(object):
    
    def __init__(self, schema, services):
        self.schema = schema
        
        # memory for cat slots: serv,slot=>[] or for noncat slots: noncat => []
        self.memory = defaultdict(list)
        self.index = defaultdict(set) # serv,slot->val, noncat->val

        for serv in services:
            # add possible values from slot
            sch = schema.get(serv)
            for slot, iscat, slotvals in zip(sch["slot_name"], sch["slot_iscat"], sch["slot_vals"]):
                key = (serv, slot) if iscat else "noncat"
                for val in ["NONE", "dontcare"] + slotvals:
                    if val not in self.index[key]:
                        self.index[key].add(val)
                        self.memory[key].append(val)

            # add optional slot vals
            for optslots, optvals in zip(sch["intent_optslots"], sch["intent_optvals"]):
                for slot, val in zip(optslots, optvals):
                    slotid = sch["slot_name"].index(slot)
                    iscat = sch["slot_iscat"][slotid]
                    assert slotid != -1
                    key = (serv, slot) if iscat else "noncat"
                    if val not in self.index[key]:
                        self.index[key].add(val)
                        self.memory[key].append(val)
                        
    def update(self, dial_turn):
        # update only noncat values..
        utter = dial_turn["utterance"]
        for frame in dial_turn["frames"]:
            sch = self.schema.get(frame["service"])
            slot_names = sch["slot_name"]
            slot_iscat = sch["slot_iscat"]

            for tag in frame["slots"]:
                slot, st, en = tag["slot"], tag["start"], tag["exclusive_end"]
                slotid = slot_names.index(slot)
                iscat = slot_iscat[slotid]
                assert slotid != -1

                if not iscat:
                    value = utter[st:en]
                    key = "noncat"
                    if value not in self.index[key]:
                        value = re.sub("\u2013", "-", value) # dial 59_00125 turn 14
                        self.index[key].add(value)
                        self.memory[key].append(value)
                        
    def get(self, key="noncat"):
        return self.memory[key]
    
    
class DialogReader(DatasetReader):

    def __init__(self, schema, limit, lazy=False):
        super().__init__(lazy)
        self.token_indexers = {"tokens": PretrainedBertIndexer("bert-base-uncased")}
        self.tokenizer = BertBasicWordSplitter()
        self.schema = schema
        self.limit = limit

    def _read(self, path):
        # get a set of dialogs
        count = 0
        dialogs = []
        for filename in glob.glob(path):
            if count > self.limit:
                break
            with open(filename) as f:
                for d in json.load(f):
                    dialogs.append(d)
                    count += 1
                    if count > self.limit:
                        break
        
        # prepare instances
        for dial in dialogs:
            memory = Memory(self.schema, dial["services"])
            for turnid, turn in enumerate(dial["turns"]):
                memory.update(turn)
                if turn["speaker"] == "USER":
                    usr_utter = turn["utterance"]
                    sys_utter = dial["turns"][turnid-1]["utterance"] if turnid > 0 else "dialog started"
                    num_none_questions  = 0
                    
                    for frame in turn["frames"]:
                        # get schema info
                        serv = frame["service"]
                        sch = self.schema.get(serv)
                        
                        # intent
                        intent = frame["state"]["active_intent"]
                        all_intents = {s: i for i, s in enumerate(sch["intent_name"])}
                        intent_istrans = False
                        intent_desc = "No intent"
                        if intent != "NONE":
                            intentid = all_intents[intent]
                            assert intentid != -1
                            intent_desc = sch["intent_desc"][intentid]
                            intent_istrans = sch["intent_istrans"][intentid]
                        
                        # slots
                        all_slots = {s: i for i, s in enumerate(sch["slot_name"])}
                        all_slots_iscat = sch["slot_iscat"]
                        all_slots_desc = sch["slot_desc"]
                        active_slots = frame["state"]["slot_values"]
                        none_slots = set(all_slots) - set(active_slots)
                        
                        # active slots
                        for slot, values in active_slots.items():
                            slotid = all_slots[slot]
                            assert slotid != -1
                            key = (serv, slot) if all_slots_iscat[slotid] else "noncat"
                            target_value = re.sub("\u2013", "-", values[0])
                            
                            item = dict(
                                dialid=dial["dialogue_id"],
                                turnid=turnid,
                                usr_utter=usr_utter,
                                sys_utter=sys_utter,
                                serv=serv,
                                serv_desc=sch["desc"],
                                slot=slot,
                                slot_desc=all_slots_desc[slotid],
                                slot_iscat=all_slots_iscat[slotid],
                                slot_val=target_value,
                                intent=intent,
                                intent_desc=intent_desc,
                                intent_istrans=intent_istrans,
                                memory=memory.get(key),
                            )
                            yield self.text_to_instance(item)
                            
                        # none valued slots
                        for slot in none_slots:
                            if np.random.randn() > 0.5 and num_none_questions < 3:
                                num_none_questions += 1
                                slotid = all_slots[slot]
                                assert slotid != -1
                                key = (serv, slot) if all_slots_iscat[slotid] else "noncat"
                                target_value = "NONE"
                                item = dict(
                                    dialid=dial["dialogue_id"],
                                    turnid=turnid,
                                    usr_utter=usr_utter,
                                    sys_utter=sys_utter,
                                    serv=serv,
                                    serv_desc=sch["desc"],
                                    slot=slot,
                                    slot_desc=all_slots_desc[slotid],
                                    slot_iscat=all_slots_iscat[slotid],
                                    slot_val=target_value,
                                    intent=intent,
                                    intent_desc=intent_desc,
                                    intent_istrans=intent_istrans,
                                    memory=memory.get(key),
                                )
                                yield self.text_to_instance(item)
                        
            
    def text_to_instance(self, item):
        fields = {}
        
        # featurize query
        q_fields = ("sys_utter", "usr_utter", "serv_desc", "slot_desc", "intent_desc")
        query = " . ".join(item[f] for f in q_fields)
        query_tokens = self.tokenizer.split_words(query)
        fields["query"] = TextField(query_tokens, self.token_indexers)
        
        # featurize memory
        mem_values = item["memory"]
        mem_tokens = []
        mem_target = []
        for mem_val in mem_values:
            tokens = self.tokenizer.split_words(mem_val)
            target = int(mem_val == item["slot_val"])
            for t in tokens:
                mem_tokens.append(t)
                mem_target.append(target)
        
        fields["memory"] = TextField(mem_tokens, self.token_indexers)
        fields["memory_target"] = ArrayField(np.array(mem_target))
        
        # positional fields
        fields["turnid"] =  ArrayField(np.array(item["turnid"]))
        
        # meta fields
        fields["id"] = MetadataField("{}/{}/{}/{}".format(item["dialid"], item["turnid"], item["serv"], item["slot"]))
        fields["slot"] = MetadataField(item["slot"])
        fields["serv"] = MetadataField(item["serv"])
        fields["intent"] = MetadataField(item["intent"])
        fields["dialid"] = MetadataField(item["dialid"])
        
        return Instance(fields)

In [5]:
train_schema = Schema("../data/train/schema.json")
dev_schema = Schema("../data/dev/schema.json")

In [6]:
reader = DialogReader(train_schema, limit=10)
train_ds = reader.read("../data/train/dialogues*.json")

reader = DialogReader(dev_schema, limit=10)
dev_ds = reader.read("../data/dev/dialogues*.json")

vocab = Vocabulary.from_instances(train_ds + dev_ds)


it = BasicIterator(batch_size=32)
it.index_with(vocab)
batch = next(iter(it(train_ds)))
batch.keys()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




HBox(children=(IntProgress(value=0, max=1966), HTML(value='')))




dict_keys(['query', 'memory', 'memory_target', 'turnid', 'id', 'slot', 'serv', 'intent', 'dialid'])

In [20]:
class CandidateSelector(Model):
    
    def __init__(self, vocab):
        super().__init__(vocab)
        self.emb = PretrainedBertEmbedder("bert-base-uncased", requires_grad=True)
        emb_dim = self.emb.get_output_dim()
        
        self.register_buffer("query_state", None)
        self.query_enc = nn.GRU(emb_dim, emb_dim, batch_first=True)
        
        self.memory_dec = nn.GRU(2 * emb_dim + 1, emb_dim, batch_first=True)
        
        self.attn_w = nn.Linear(emb_dim * 2, emb_dim)
        self.attn_v = nn.Linear(emb_dim, 1, bias=False)
        
        self.final = nn.Linear(emb_dim, 1)
        
        self.accuracy = BooleanAccuracy()
        
    def get_metrics(self, reset=False):
        return {"acc": self.accuracy.get_metric(reset)}
    
    def encoder(self, inputs, reset=False):
        hidden = self.query_state
        if reset:
            hidden = None
        
        # if hidden size is uneven
        if hidden is not None:
            hidden_bs = hidden.shape[1]
            inputs_bs = inputs.shape[0]
            if hidden_bs > inputs_bs:
                hidden = hidden[:,:inputs_bs,:]
            elif hidden_bs < inputs_bs:
                padding = torch.zeros(
                    hidden.shape[0], inputs_bs-hidden_bs, hidden.shape[2],
                    device=hidden.device, dtype=hidden.dtype)
                hidden = torch.cat([hidden, padding], dim=1)
        
        # forward pass
        outputs, hidden = self.query_enc(inputs, hidden)
        self.query_state = hidden.detach()
        return outputs, hidden
    
    def attention(self, mem_h, query_o):
        # mem_h: l*d,b,e  query_o: bse
        energies = []
        
        sh = mem_h.shape
        mem_h = mem_h.view(sh[1], 1, -1) # b,1, l*d*e
        mem_h = mem_h.repeat(1, query_o.shape[1], 1) # b,s,lde
        
        x = torch.cat((query_o, mem_h), -1) # b,s,e+lde
        energy = torch.tanh(self.attn_w(x)) # bse
        energy = self.attn_v(energy).squeeze(-1) # bs1
        
        attn = F.softmax(energy, -1) # bs
        
        # context
        context = torch.bmm(attn.unsqueeze(1), query_o) # b1e
        
        return context # b1e
    
    def decoder(self, query_o, query_h, mem_inp, mem_out):
        bs = mem_inp.shape[0]

        predicted = []
        prev_out = torch.zeros(bs, 1, 1, device=mem_inp.device, dtype=mem_inp.dtype) # b11
        dec_h = query_h
        
        for m in range(mem_inp.shape[1]):
            inp = mem_inp[:,m:m+1,:] # b1e
            context = self.attention(dec_h, query_o) # bwe
            
            combined_inp = torch.cat((inp, context, prev_out), dim=-1) # b,1,e2+1
            dec_o, dec_h = self.memory_dec(combined_inp, dec_h)
            
            predicted.append(dec_o)
            prev_loc = mem_out[:,m:m+1,None] # teacher forcing
        
        predicted = torch.cat(predicted, dim=1) # bme
        
        return predicted
    
    def forward(self, **batch):
        query = self.emb(batch["query"]["tokens"], batch["query"]["tokens-offsets"]) # bse
        memory = self.emb(batch["memory"]["tokens"], batch["memory"]["tokens-offsets"]) # bme
        memory_loc_oh = batch["memory_target"] # bm
        
        # encode query
        is_new_dial = bool((batch["turnid"] == 0).all())
        query_o, query_h = self.encoder(query, reset=is_new_dial) # be
        
        # decode values
        mem_o = self.decoder(query_o, query_h, memory, memory_loc_oh) # bme
        
        # final logits
        predicted = self.final(mem_o).squeeze(-1) # bm
        
        # loss
        mask = batch["memory"]["mask"].float()
        loss = F.binary_cross_entropy_with_logits(predicted, memory_loc_oh, mask).unsqueeze(0)
        
        # metric
        predicted_loc = predicted.argmax(-1) # b
        memory_loc = memory_loc_oh.argmax(-1) # b
        mask = batch["memory"]["mask"][torch.arange(memory.shape[0]), memory_loc] # b
        self.accuracy(predicted_loc, memory_loc, mask)
        
        output = dict(
            logits=predicted,
            loss=loss,
            pred=predicted_loc,
            target=memory_loc,
        )
        
        return output

In [None]:
!rm -rf ../results/4

allen_device=1
torch_device=1

model = CandidateSelector(vocab).to(torch_device)
optimizer = optim.Adam(model.parameters(), lr=3e-5)
ema = ExponentialMovingAverage(model.named_parameters())

iterator = BasicIterator(batch_size=32)
iterator.index_with(vocab)


trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    num_epochs=10,
    cuda_device=allen_device,
    #serialization_dir="../results/4",
    #should_log_learning_rate=True,
    #histogram_interval=30,
    num_serialized_models_to_keep=1,
    #moving_average=ema,
)

trainer.train()

HBox(children=(IntProgress(value=0, max=27), HTML(value='')))

HBox(children=(IntProgress(value=0, max=27), HTML(value='')))

HBox(children=(IntProgress(value=0, max=27), HTML(value='')))

HBox(children=(IntProgress(value=0, max=27), HTML(value='')))

HBox(children=(IntProgress(value=0, max=27), HTML(value='')))

HBox(children=(IntProgress(value=0, max=27), HTML(value='')))