In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [1]:
import torch.utils.data as D
import torch
import numpy as np
import json
import csv
import glob
import re
import concurrent.futures
import os

from copy import deepcopy
from functools import lru_cache
from collections import OrderedDict
from types import SimpleNamespace
from collections.abc import Iterable
from tqdm import tqdm_notebook as tqdm
from pytorch_transformers import BertTokenizer
from ipdb import set_trace

# Data Reader

## Utils

In [2]:
class Tokenizer:
    
    def __init__(self, bert):
        self.bert = bert
    
    def __call__(self, text, include_sep=True):
        tokens = self.bert.tokenize(text)
        if include_sep:
            tokens.insert(0, "[CLS]")
            tokens.append("[SEP]")
        return tokens
    
    
class TokenIndexer:
    
    def __init__(self, bert):
        self.bert = bert
        
    def inv(self, *args, **kw):
        # tokens MUST be list, tensor doesn't work.
        return self.bert.convert_ids_to_tokens(*args, **kw)
        
    def __call__(self, *args, **kw):
        return self.bert.convert_tokens_to_ids(*args, **kw)

In [3]:
def label_binarize(labels, classes):
    # labels: np.array or tensor [batch, 1]
    # classes: [..] list of classes
    # weirdly,`sklearn.preprocessing.label_binarize` returns [1] or [0]
    # instead of onehot ONLY when executing in this script!
    vectors = [np.zeros(len(classes)) for _ in labels]
    for i, label in enumerate(labels):
        for j, c in enumerate(classes):
            if c == label:
                vectors[i][j] = 1
    return np.array(vectors)
    

def label_inv_binarize(vectors, classes):
    # labels: np.array or tensor [batch, classes]
    # classes: [..] list of classes
    # follows sklearn LabelBinarizer.inverse_transform()
    # given all zeros, predicts label at index 0, instead of returning none!
    # sklearn doesn't have functional API of inverse transform
    labels = []
    for each in vectors:
        index = np.argmax(each)
        labels.append(classes[index])
    return labels

In [4]:
def padded_array(array, value=0):
    # TODO: this does not do type checking; and wow it can be slow on strings.
    # expects array to have fixed _number_ of dimensions
    
    # resolve the shape of padded array
    shape_index = {}
    queue = [(array, 0)]
    while queue:
        subarr, dim = queue.pop(0)
        shape_index[dim] = max(shape_index.get(dim, -1), len(subarr))
        for x in subarr:
            if isinstance(x, Iterable) and not isinstance(x, str):
                queue.append((x, dim+1))
    shape = [shape_index[k] for k in range(max(shape_index) + 1)]
    
    # fill the values 
    padded = np.ones(shape) * value
    queue = [(array, [])]
    while queue:
        subarr, index = queue.pop(0)
        for j, x in enumerate(subarr):
            if isinstance(x, Iterable):
                queue.append((x, index + [j]))
            else:
                padded[tuple(index + [j])] = x
    return padded

## Schema Reader

In [5]:
class Schemas(object):

    def __init__(self, filepath):
        with open(filepath) as f:
            self.index = {}
            for schema in json.load(f):
                service_name = schema["service_name"]
                self.index[service_name] = schema

    @lru_cache(maxsize=None)
    def get(self, service):
        result = dict(
            # service
            name=service,
            desc=self.index[service]["description"],
            
            # slots
            slot_name=[],
            slot_desc=[],
            slot_iscat=[], 
            slot_vals=[], # collected only for cat slots.. not sure if that makes sense

            # intents
            intent_name=[],
            intent_desc=[],
            intent_istrans=[],
            intent_reqslots=[],
            intent_optslots=[],
            intent_optvals=[],
        )

        for slot in self.index[service]["slots"]:
            result["slot_name"].append(slot["name"])
            result["slot_desc"].append(slot["description"])
            result["slot_iscat"].append(slot["is_categorical"])
            result["slot_vals"].append(slot["possible_values"])
        
        for intent in self.index[service]["intents"]:
            result["intent_name"].append(intent["name"])
            result["intent_desc"].append(intent["description"])
            result["intent_istrans"].append(intent["is_transactional"])
            result["intent_reqslots"].append(intent["required_slots"])
            result["intent_optslots"].append(list(intent["optional_slots"].keys()))
            result["intent_optvals"].append(list(intent["optional_slots"].values()))

        return result    

## Dialogue Reader

In [77]:
class DialogueDataset(D.Dataset):
    def __init__(self, filename, schemas, tokenizer, token_indexer):
        with open(filename) as f:
            self.ds = json.load(f)
        self.schemas = schemas
        self.tokenizer = tokenizer
        self.token_indexer = token_indexer
        self.dialogues = []
        for dial in self.ds:
            fields = self.dial_to_fields(dial)
            self.dialogues.append(fields)
        self.schemas = None
        self.tokenizer = None
        self.token_indexer = None
            
    def __getitem__(self, idx):
        return self.dialogues[idx]
    
    def __len__(self):
        return len(self.dialogues)
    
    def fd_serv_name(self, dial, fields):
        resp = dict(value=[])
        for service in dial["services"]:
            resp["value"].append(service)
        return resp
        
    def fd_serv_desc(self, dial, fields):
        resp = dict(value=[], tokens=[], ids=[], ids_pos=[], mask=[])
        for service in dial["services"]:
            desc = self.schemas.get(service)["desc"]
            tokens = self.tokenizer(desc)
            ids = self.token_indexer(tokens)
            resp["value"].append(desc)
            resp["tokens"].append(tokens)
            resp["ids"].append(ids)
            resp["ids_pos"].append(list(range(1, len(ids) + 1)))
            resp["mask"].append([1] * len(ids))
        return resp
    
    def fd_slot_name(self, dial, fields):
        resp = {"value": []}
        for serv in dial["services"]:
            schema = self.schemas.get(serv)
            resp["value"].append(schema["slot_name"])
        return resp
    
    def fd_slot_desc(self, dial, fields):
        resp = dict(value=[], tokens=[], ids=[], ids_pos=[], mask=[])
        for serv in dial["services"]:
            s_desc = self.schemas.get(serv)["slot_desc"]
            s_tokens = [self.tokenizer(d) for d in s_desc]
            s_ids = [self.token_indexer(d) for d in s_tokens]
            s_mask = [[1] * len(d) for d in s_tokens]
            resp["value"].append(s_desc)
            resp["tokens"].append(s_tokens)
            resp["ids"].append(s_ids)
            resp["ids_pos"].append([list(range(1, len(i) + 1)) for i in s_ids])
            resp["mask"].append(s_mask)
        return resp

    def fd_slot_memory(self, dial, fields):
        # TODO: iscat feature..
        # memory is a sequence of slot tagger values across all frames in a dial.. that grows per turn
        # maintain snapshot at each turn
        resp = dict(
            value=[], # unfeaturized memory
            value_desc = [],
            value_slot = [], # value -> slots that use it..
            tokens=[], # tokenized memory [turn, mem-size, tokens]
            tokens_desc = [],
            ids=[], # indexed memory [turn mem-size, tokens]
            ids_desc = [], # append desc if slots have same desc!
            ids_pos=[],
            ids_memsize=[], # mem sizes [turn, 1]
            mask=[], # mask on memory [turn, mem-size, tokens]
            mask_desc = [],
        )
        
        # memory is sequential, initialized by vals from schema, only has values
        memory = ["NONE", "dontcare"] # keep these at index 0, 1
        memory_slot = [["*"], ["*"]]
        memory_desc = ["NONE", "DONTCARE"]
        memory_index = {}
        for serv in dial["services"]:
            schema = self.schemas.get(serv)
            servdesc = schema["desc"]
            for slotname, slotdesc, values in zip(schema["slot_name"], schema["slot_desc"], schema["slot_vals"]):
                for val in values:
                    if val not in memory_index:
                        memory.append(val)
                        memory_slot.append([slotname])
                        memory_desc.append(slotdesc + "[SEP]" + servdesc)
                        memory_index[val] = len(memory)
                    else:
                        idx = memory_index[val] - 1
                        if slotname not in memory_slot[idx]:
                            memory_desc[idx] = memory_desc[idx] + "[SEP]" + slotdesc + "[SEP]" + servdesc
                            memory_slot[idx].append(slotname)
        
        # at each user turn create a memory snapshot..
        for turn in dial["turns"]:
            memory = deepcopy(memory)
            memory_desc = deepcopy(memory_desc)
            memory_slot = deepcopy(memory_slot)
            
            # pick slot tagger's ground truth ; categorical slot vals are already initialized!
            utter = turn["utterance"]
            for frame in turn["frames"]:
                schema = self.schemas.get(frame["service"])
                servdesc = schema["desc"]
                for tag in frame["slots"]:
                    slotname = tag["slot"]
                    slotdesc = schema["slot_desc"][schema["slot_name"].index(slotname)] # shit! O(N*2)
                    st, en = tag["start"], tag["exclusive_end"]
                    value = utter[st:en]
                    if value not in memory_index:
                        memory.append(value)
                        memory_slot.append([slotname])
                        memory_desc.append(slotdesc)
                        memory_index[value] = len(memory)
                    else:
                        idx = memory_index[value] - 1
                        if slotname not in memory_slot[idx]:
                            memory_slot[idx].append(slotname)
                            memory_desc[idx] = memory_desc[idx] + "[SEP]" + slotdesc + "[SEP]" + servdesc
                        
            if turn["speaker"] == "USER":
                resp["value"].append(memory)
                resp["value_desc"].append(memory_desc)
                resp["value_slot"].append(memory_slot)
                resp["ids_memsize"].append(len(memory))

        # tokenize and index the memory values
        # value: [turn, values], ids/tokens: [turn, values, tokens]
        for mem_desc_snapshot, mem_snapshot in zip(resp["value_desc"], resp["value"]):
            mem_tokens = []
            mem_tokens_desc = []
            mem_ids = []
            mem_ids_desc = []
            mem_pos = list(range(1, len(mem_snapshot) + 1))
            mem_mask = []
            mem_mask_desc = []
            for desc, val in zip(mem_desc_snapshot, mem_snapshot):
                # featurize memory
                tokens = self.tokenizer(val)
                mem_tokens.append(tokens)
                mem_ids.append(self.token_indexer(tokens))
                mem_mask.append([1] * len(tokens))
                
                # featurize memory desc
                tokens_desc = self.tokenizer(desc)
                mem_tokens_desc.append(tokens_desc)
                mem_ids_desc.append(self.token_indexer(tokens_desc))
                mem_mask_desc.append([1] * len(tokens_desc))
                
            resp["tokens"].append(mem_tokens)
            resp["tokens_desc"].append(mem_tokens_desc)
            resp["ids"].append(mem_ids)
            resp["ids_desc"].append(mem_ids_desc)
            resp["ids_pos"].append(mem_pos)
            resp["mask"].append(mem_mask)
            resp["mask_desc"].append(mem_mask_desc)
        
        return resp
    
    def fd_slot_memory_loc(self, dial, fields):
        resp = dict(
            value=[],  # string value
            ids=[], # memory loc
            ids_onehot=[], # onehot of memory loc
            mask=[],
            mask_onehot=[],
            mask_none=[], # 1 -> non NONE values
            mask_none_onehot=[],
        )
        
        # query dialog memory snapshots per turn. NOTE: fd_slot_memory should exec first.
        memory = fields["slot_memory"]["value"]
        
        # init snapshot: service, slot -> memory loc, val
        memory_loc = []
        memory_val = []
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                loc = OrderedDict()
                val = OrderedDict()
                for serv in dial["services"]:
                    loc[serv] = OrderedDict()
                    val[serv] = OrderedDict()
                    for slot in self.schemas.get(serv)["slot_name"]:
                        loc[serv][slot] = None
                        val[serv][slot] = None
                memory_loc.append(loc)
                memory_val.append(val)
        
        # fill the memory locations 
        snapshot_id = 0
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                turn_memory = memory[snapshot_id]
                turn_memory_loc = memory_loc[snapshot_id]
                turn_memory_val = memory_val[snapshot_id]
                for frame in turn["frames"]:
                    service = frame["service"]
                    for slot, values in frame["state"]["slot_values"].items():
                        val = re.sub("\u2013", "-", values[0]) # dial 59_00125 turn 14
                        turn_memory_loc[service][slot] = turn_memory.index(val)
                        turn_memory_val[service][slot] = val
                    # add locations to UNKNOWN slots
                    for slot, val in turn_memory_loc[service].items():
                        if val is None:
                            turn_memory_loc[service][slot] = turn_memory.index("NONE")
                            turn_memory_val[service][slot] = "NONE"
                snapshot_id += 1
        
        # featurize
        snapshot_id = 0
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                turn_memory = memory[snapshot_id]
                turn_memory_loc = memory_loc[snapshot_id]
                turn_memory_val = memory_val[snapshot_id]
                none_loc = memory[snapshot_id].index("NONE")
                turn_fields = dict(
                    value=[], ids=[], ids_onehot=[], mask=[], mask_onehot=[],
                    mask_none=[], mask_none_onehot=[],
                )

                for serv in turn_memory_loc:
                    mem_size = len(turn_memory)
                    vals = list(turn_memory_val[serv].values())
                    ids = list(turn_memory_loc[serv].values())
                    ids_onehots = label_binarize(ids, list(range(mem_size)))
                    mask = [1] * len(ids)
                    mask_onehots = [[1] * mem_size for _ in ids]

                    mask_none = [int(v!="NONE") for v in vals]
                    mask_none_onehot = [[1] * mem_size for _ in ids]
                    for v, loc, onehot in zip(vals, ids, mask_none_onehot):
                        if v == "NONE":
                            onehot[loc] = 0
                                
                    turn_fields["value"].append(vals)
                    turn_fields["ids"].append(ids)
                    turn_fields["ids_onehot"].append(ids_onehots)
                    turn_fields["mask"].append(mask)
                    turn_fields["mask_onehot"].append(mask_onehots)
                    turn_fields["mask_none"].append(mask_none)
                    turn_fields["mask_none_onehot"].append(mask_none_onehot)

                # update turn
                for k, v in turn_fields.items():
                    resp[k].append(v)
                
                snapshot_id += 1
            
        return resp
    
    def fd_num_turns(self, dial, fields):
        return {"ids": len(dial["turns"])}
    
    def fd_num_frames(self, dial, fields):
        return {"ids": [len(t["frames"]) for t in dial["turns"]]}
    
    def fd_usr_utter(self, dial, fields):
        resp = dict(value=[], ids=[], ids_pos=[], mask=[], tokens=[])
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                utter = turn["utterance"]
                tokens = self.tokenizer(utter)
                ids = self.token_indexer(tokens)
                resp["value"].append(utter)
                resp["ids"].append(ids)
                resp["ids_pos"].append(list(range(1, len(ids) + 1)))
                resp["tokens"].append(tokens)
                resp["mask"].append([1] * len(tokens))
        return resp
    
    def fd_sys_utter(self, dial, fields):
        resp = dict(value=[], ids=[], ids_pos=[], mask=[], tokens=[])
        for turn in dial["turns"]:
            if turn["speaker"] == "SYSTEM":
                utter = turn["utterance"]
                tokens = self.tokenizer(utter)
                ids = self.token_indexer(tokens)
                resp["value"].append(utter)
                resp["ids"].append(ids)
                resp["ids_pos"].append(list(range(1, len(ids) + 1)))
                resp["tokens"].append(tokens)
                resp["mask"].append([1] * len(tokens))
        return resp
    
    def fd_dial_id(self, dial, fields):
        return {"value": dial["dialogue_id"]}
    
    def dial_to_fields(self, dial):
        fields = {}
        ordered_funcs = [
            "fd_dial_id",
            "fd_num_turns", "fd_num_frames",
            "fd_serv_name", "fd_serv_desc", 
            "fd_slot_name", "fd_slot_desc",
            "fd_slot_memory", "fd_slot_memory_loc",
            "fd_usr_utter", "fd_sys_utter"]
        for func in ordered_funcs:
            name = func.split("fd_", maxsplit=1)[-1]
            value = getattr(self, func)(dial, fields)
            if value is not None:
                fields[name] = value
        return fields

## Load Data

In [7]:
bert_ = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer = Tokenizer(bert_)
token_indexer = TokenIndexer(bert_)

train_schemas = Schemas("../data/train/schema.json")
test_schemas = Schemas("../data/dev/schema.json")

In [78]:
ds = DialogueDataset("../data/train/dialogues_001.json", train_schemas, tokenizer, token_indexer)

In [81]:
# load training dataset
train_dial_sets = []
train_dial_files = sorted(glob.glob("../data/train/dialogues*.json"))
num_workers = min(20, len(train_dial_files))

def worker(filename):
    return DialogueDataset(filename, train_schemas, tokenizer, token_indexer)

with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
    for ds in tqdm(executor.map(worker, train_dial_files), total=len(train_dial_files)):
        train_dial_sets.append(ds)

train_ds = D.ConcatDataset(train_dial_sets)

HBox(children=(IntProgress(value=0, max=127), HTML(value='')))




In [82]:
# load test dataset
test_dial_sets = []
test_dial_files = sorted(glob.glob("../data/dev/dialogues*.json"))
num_workers = min(20, len(train_dial_files))

def worker(filename):
    return DialogueDataset(filename, test_schemas, tokenizer, token_indexer)

with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
    for ds in tqdm(executor.map(worker, test_dial_files), total=len(test_dial_files)):
        test_dial_sets.append(ds)

test_ds = D.ConcatDataset(test_dial_sets)

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




# Training Utils

In [11]:
import torch.nn as nn

In [51]:
def move_to_device(obj, device):
    if type(obj) is list:
        return [move_to_device(o, device) for o in obj]
    elif type(obj) is dict:
        return {k: move_to_device(v, device) for k, v in obj.items()}
    elif type(obj) is torch.Tensor or isinstance(obj, nn.Module):
        return obj.to(device)
    return obj

In [83]:
def dialogue_mini_batcher(dialogues):
    default_padding = 0
    batch = {}
    for dial in dialogues:
        # populate the batch
        for field, data in dial.items():
            if field not in batch:
                batch[field] = {}
            for attr, val in data.items():
                if attr == "padding":
                    batch[field][attr] = val
                else:
                    batch[field][attr] = batch[field].get(attr, [])
                    batch[field][attr].append(val)

    # padding on field attributes
    for field_name, data in batch.items():
        for attr in data:
            if attr.startswith(("ids", "mask")):
                data[attr] = padded_array(data[attr], default_padding)
                data[attr] = torch.tensor(data[attr], device="cpu") # whatif its in device0 in epoch0, then at epoch1, sent to device1

    return batch


class DialogIterator(object):
    """A simple wrapper on DataLoader"""
    
    def __init__(self, dataset, batch_size, *args, **kw):
        self.length = None
        self.dataset = dataset
        self.batch_size = batch_size
        self.iterator = D.DataLoader(dataset, batch_size, *args, **kw)
        
    def __len__(self):
        if self.length is None:
            self.length = 0
            for dial in self.dataset:
                self.length += (dial["num_turns"]["ids"] * len(dial["serv_name"]["value"]))
            self.length = int(self.length / self.batch_size)
        return self.length
    
    def __iter__(self):
        for batch in self.iterator:
            num_turns = batch["usr_utter"]["ids"].shape[1]
            num_services = batch["serv_desc"]["ids"].shape[1]
            for turnid in range(num_turns):
                for sid in range(num_services):
                    inputs = dict(turnid=turnid, serviceid=sid)
                    inputs.update(batch)
                    yield inputs

In [84]:
def get_sample(ds, size=1):
    return next(iter(DialogIterator(ds, size, collate_fn=dialogue_mini_batcher)))

def print_shapes(ds):
    it = get_sample(ds, size=1)
    for field, val in it.items():
        if type(val) is dict:
            for attr in val:
                if attr.startswith("ids") or attr.startswith("mask"):
                    print(field, attr, "-->", val[attr].shape)
                    

# TODO: why mem ids, and mem loc size is diff by 1?
print_shapes(train_ds)

num_turns ids --> torch.Size([1])
num_frames ids --> torch.Size([1, 24])
serv_desc ids --> torch.Size([1, 1, 10])
serv_desc ids_pos --> torch.Size([1, 1, 10])
serv_desc mask --> torch.Size([1, 1, 10])
slot_desc ids --> torch.Size([1, 1, 11, 12])
slot_desc ids_pos --> torch.Size([1, 1, 11, 12])
slot_desc mask --> torch.Size([1, 1, 11, 12])
slot_memory ids --> torch.Size([1, 12, 29, 10])
slot_memory ids_desc --> torch.Size([1, 12, 29, 40])
slot_memory ids_pos --> torch.Size([1, 12, 29])
slot_memory ids_memsize --> torch.Size([1, 12])
slot_memory mask --> torch.Size([1, 12, 29, 10])
slot_memory mask_desc --> torch.Size([1, 12, 29, 40])
slot_memory_loc ids --> torch.Size([1, 12, 1, 11])
slot_memory_loc ids_onehot --> torch.Size([1, 12, 1, 11, 29])
slot_memory_loc mask --> torch.Size([1, 12, 1, 11])
slot_memory_loc mask_onehot --> torch.Size([1, 12, 1, 11, 29])
slot_memory_loc mask_none --> torch.Size([1, 12, 1, 11])
slot_memory_loc mask_none_onehot --> torch.Size([1, 12, 1, 11, 29])
usr_ut

# Model

In [55]:
from pytorch_transformers import BertModel, BertConfig
from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy
from collections import OrderedDict

import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch

In [86]:
class BertDownstream(nn.Module):
    # https://huggingface.co/pytorch-transformers/model_doc/bert.html
    # https://github.com/hanxiao/bert-as-service/blob/master/docs/section/faq.rst#id6
    
    def __init__(self, btype, requires_grad=False):
        super().__init__()
        self.emb = BertModel.from_pretrained(btype, output_hidden_states=True)
        for name, param in self.emb.named_parameters():
            param.requires_grad = requires_grad

    @property
    def output_dim(self):
        return 768
    
    def forward(self, input_ids, position_ids=None, attention_mask=None, flat=None):
        sh = list(input_ids.shape)
        
        if flat:
            st, en = flat
        elif len(sh) == 3:
            st, en = 0, 1
        elif len(sh) in (0,1,2):
            st, en = 0, 0
        else:
            st, en = 1, -2
        
        input_ids = torch.flatten(input_ids, st, en).long()
        if position_ids is not None:
            position_ids = torch.flatten(position_ids, st, en).long()
        if attention_mask is not None:
            attention_mask = torch.flatten(attention_mask, st, en).long()
        
        outputs = self.emb(input_ids, position_ids=position_ids, attention_mask=attention_mask)
        pooled = outputs[2][4]
        pooled = pooled.view(sh + [-1])
        
        return pooled #*->*e


class CandidateSelector(nn.Module):
    """
    memory -- represents temporal hash table of text values, that monotonically grows
    memory_desc -- represents sequential description to every value in hash table
    slot, utter -- represent query/key descriptions
    
    The idea is that we model a hash function 
        F: slot, utter, time -> memory_location
    
    This is a mix of ideas from
        - Memory Networks by Jason Weston, FAIR
        - E2E Memory Networks by Sukhbaatar, NYU/FAIR
        - Pointer Networks by Oriol Vinayls, Google Brain
        - Alignment and Translation by Bahdanau, MILA
    
        - Weston and Sukhbaatar demonstrate memory models on Facebook bAbI tasks. that capture some temporal challenges.
            while Vinayls tweaks the idea of Bahdanau's attention, to classification/regress targets represented in latent space.
            (NOTE: their motivation is subtly different, but this explanation still stands valid)
    
    Unique contribution:
        - Explore the problem where memory grows and has temporal information.
        - Model a hash function that understands temporal (just like Vinayls) but but 
            new challenges are:
            - applied to dialog state (no one else did before), where the hash function has additional challenges of dealing with 
                guessing user behavior dist. and language subtlities (eg: coreference and ellipsis resolutions.)
            - the hash function is learning from real world data, unlike previous works where toy data are used. 
                the hash key is can change just because user explicitly/implicitly says something..
            - the model is setup in a way it supports growth, and scale across domains. Sukhbaatar's approach is fixed (verify again!)
    """
        
    def __init__(self):
        super().__init__()
        self.emb = BertDownstream("bert-base-uncased")
        emb_dim = self.emb.output_dim
        
        self.l0 = nn.GRU(emb_dim, emb_dim, batch_first=True, num_layers=1) # encode utter
        self.l1 = nn.GRU(emb_dim, emb_dim, batch_first=True, num_layers=1) # encode memory
        self.l2 = nn.GRU(emb_dim, emb_dim, batch_first=True, num_layers=1) # encode query slot
       
        # attentions
        self.l3 = nn.Linear(2*emb_dim, emb_dim)
        self.l4 = nn.Bilinear(emb_dim, emb_dim, 1)
         
        # metrics
        self.acc = CategoricalAccuracy()
        self.goal_acc = CategoricalAccuracy()
        
        self.register_buffer("utter_state", None)

        # init weights: classifier's performance changes heavily on these
        for name, param in self.named_parameters():
            if name.startswith(("l0.", "11.", "l2.")):
                print("Initializing bias/weights of ", name)
                if "weight" in name:
                    nn.init.xavier_uniform_(param)
                else:
                    param.data.fill_(0.)

    def get_metrics(self, reset=False, goal_reset=False):
        return dict(
            acc=self.acc.get_metric(reset), # avg per service
            goal_acc=self.goal_acc.get_metric(goal_reset), # avg per turn
        )
    
    def encode_utter(self, batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        
        usr_utter, usr_mask, usr_pos = batch["usr_utter"]["ids"][:,turnid,:], batch["usr_utter"]["mask"][:,turnid,:], batch["usr_utter"]["ids_pos"][:,turnid,:]
        sys_utter, sys_mask, sys_pos = batch["sys_utter"]["ids"][:,turnid,:], batch["sys_utter"]["mask"][:,turnid,:], batch["sys_utter"]["ids_pos"][:,turnid,:]
        
        usr_utter = self.emb(usr_utter, position_ids=usr_pos, attention_mask=usr_mask)
        sys_utter = self.emb(sys_utter, position_ids=sys_pos, attention_mask=sys_mask)
        
        utter = torch.cat([sys_utter, usr_utter], dim=1) # b2se

        utter_h = self.utter_state.detach() if turnid != 0 else None
        utter, utter_h = self.l0(utter, utter_h)
        self.utter_state = utter_h
        
        return utter_h[-1] # be

    def encode_memory(self, batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        
        # Encode memory
        memory = batch["slot_memory"]["ids"][:,turnid,:]
        memory_mask = batch["slot_memory"]["mask"][:,turnid,:]
        memory_pos = batch["slot_memory"]["ids_pos"][:,turnid,:]
        sh = memory.shape
        
        memory_pos = memory_pos[:,:,None].repeat(1,1,sh[-1])# bm->bms
        memory = self.emb(memory, position_ids=memory_pos, attention_mask=memory_mask) # bmse
        
        # Encode memory desc
        memory_desc = batch["slot_memory"]["ids_desc"][:,turnid,:]
        memory_desc_mask = batch["slot_memory"]["mask_desc"][:,turnid,:]
        sh2 = memory_desc.shape
        
        memory_pos = batch["slot_memory"]["ids_pos"][:,turnid,:]
        memory_desc_pos = memory_pos[:,:,None].repeat(1,1,sh2[-1]) # bm->bms
        memory_desc = self.emb(memory_desc, position_ids=memory_desc_pos, attention_mask=memory_desc_mask) # bmse
        
        # across text
        memory = torch.cat([memory, memory_desc], dim=2) # bm2se
        memory = torch.flatten(memory, 0, 1) # bm,2s,e
        memory, memory_h = self.l1(memory)
        
        memory_h = memory_h[-1].view(sh[0], sh[1], -1) # pool two layers in bidirectional.
        return memory_h # bme
    
    def encode_slot_desc(self, batch):
        serviceid = batch["serviceid"]
        
        desc = batch["slot_desc"]["ids"][:,serviceid,:] # [batch, slots, tokens]
        desc_pos = batch["slot_desc"]["ids_pos"][:,serviceid,:]
        desc_mask = batch["slot_desc"]["mask"][:,serviceid,:]
        sh = desc.shape
        
        desc = self.emb(desc, position_ids=desc_pos, attention_mask=desc_mask) # bste
        desc = torch.flatten(desc, 0, 1) # bs,t,e
        desc, desc_h = self.l2(desc)
    
        desc_h = desc_h[-1].view(sh[0], sh[1], -1)
    
        return desc_h # bse
        
    def compute_score1(self, memory, slot, utter):
        # memory: bme
        # slot: bse
        # utter: be
        s = slot.shape[1]
        m = memory.shape[1]
        
        utter = utter[:,None,:].repeat(1,s,1) #be->bse

        key = self.l3(torch.cat([utter, slot], dim=-1)) # bs2e->bse
        key = torch.tanh(key)
        key = key[:,:,None,:].repeat(1,1,m,1) # bsme
        
        memory = memory[:,None,:,:].repeat(1,s,1,1) # bsme
        
        energy = self.l4(key, memory) #bsme->bsm1
        energy = F.softmax(energy.squeeze(-1), dim=-1)
        return energy
        
    def forward(self, **batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        
        # doc: GRU outputs [batch, seq, emb * dir] [layers * dir, batch, emb]

        # get fixed size encodings
        utter = self.encode_utter(batch) # be
        memory = self.encode_memory(batch) # bme
        slot_desc = self.encode_slot_desc(batch) # bse
        
        # map: slot desc -> memory
        score = self.compute_score1(memory, slot_desc, utter) # bsm
        output = {"score": score}

        if "slot_memory_loc" in batch:
            target = batch["slot_memory_loc"]
            memsize = target["ids_onehot"].shape[-1]
            pred_score = score.view(-1, memsize) # bsm->bs,m 
            
            # onehot targets
            target_score_oh = target["ids_onehot"][:,turnid,serviceid,:].contiguous().view(-1, memsize).float() # bsm->bs,m
            target_mask_oh = (target["mask_onehot"][:,turnid,serviceid,:] * target["mask_none_onehot"][:,turnid,serviceid,:]).contiguous().view(-1, memsize).float()
            
            # target
            target_score = target["ids"][:,turnid,serviceid,:].float() # [batch, slots]
            target_mask = (target["mask"][:,turnid,serviceid,:] * target["mask_none"][:,turnid,serviceid,:]).float()
            
            # loss
            output["loss"] = F.binary_cross_entropy(pred_score, target_score_oh, target_mask_oh).unsqueeze(0) # don't return scalar
            
            # log metrics
            self.acc(score, target_score, target_mask)
            self.goal_acc(score, target_score, target_mask)
            
            output["target_ids"] = target_score
            output["pred_ids"] = torch.argmax(score, dim=-1)
            output["mask"] = target["mask"][:,turnid,serviceid,:].float()
            output["mask_none"] = target["mask_none"][:,turnid,serviceid,:].float()

        return output

# Trainer

In [17]:
from allennlp.training.tensorboard_writer import TensorboardWriter
import time

In [57]:
def module(m):
    if type(m) is nn.DataParallel:
        return m.module
    return m

def train(model, optimizer, batch_size, num_epochs, train_ds, test_ds, device):
    model = model.train()
    current_batch = 1 # tensorboard doesn't like 0
    tensorboard = TensorboardWriter(
        get_batch_num_total=lambda: current_batch,
        summary_interval=10,
        serialization_dir="../data/tensorboard/"
    )
    
    histogram_weights = [name for name, p in model.named_parameters() if not ".emb." in name]
    
    for epoch in range(num_epochs):
        train_iter = DialogIterator(train_ds, batch_size, collate_fn=dialogue_mini_batcher)
        num_batches = len(train_iter) 
        if test_ds:
            test_iter = DialogIterator(test_ds, batch_size, collate_fn=dialogue_mini_batcher)
            num_batches += len(test_iter)
        
        with tqdm(total=num_batches) as pbar:
            # train
            pbar.set_description("Train {}".format(epoch))
            model = model.train()
            train_iter = DialogIterator(train_ds, batch_size, collate_fn=dialogue_mini_batcher)
            metrics = OrderedDict()
            
            # to know when the dialog and service changes
            turnid = -1
            
            for i, batch in enumerate(train_iter):
                current_batch += 1
                batch = move_to_device(batch, device)
                optimizer.zero_grad()
                output = model(**batch)
                output["loss"] = output["loss"].mean()
                output["loss"].backward()
                optimizer.step()
                
                # at each new turn
                if turnid != batch["turnid"]:
                    metrics.update(module(model).get_metrics(reset=True, goal_reset=True))
                else:
                    curr_met = module(model).get_metrics(reset=True)
                    metrics.update(acc=curr_met["acc"])
                    
                metrics["loss"] = output["loss"].item()
                pbar.update(1)
                pbar.set_postfix(metrics)
                
                metrics["turnid"] = batch["turnid"]
                metrics["servid"] = batch["serviceid"]

                # update tensorboard logs
                tensorboard.add_train_scalar("loss", metrics["loss"], timestep=current_batch)
                tensorboard.log_metrics(train_metrics=metrics, epoch=current_batch)
                tensorboard.log_parameter_and_gradient_statistics(model, None)
                
                # histograms, I guess take longer..
                if turnid == 0:
                    st_time = time.time()
                    tensorboard.log_histograms(model, histogram_weights)
                    tensorboard.add_train_histogram("target_ids", output["target_ids"])
                    tensorboard.add_train_histogram("pred_ids", output["pred_ids"])
                    metrics["time"] = time.time() - st_time
                
                # at the end of the loop, update prev turnid
                turnid = batch["turnid"]

            # test
            if test_ds:
                pbar.set_description("Test {}".format(epoch))
                metrics = OrderedDict(epoch=epoch)
                turnid = -1
                with torch.no_grad():
                    model = model.eval()
                    for i, batch in enumerate(test_iter):
                        current_batch += 1
                        batch = move_to_device(batch, device)
                        output = model(**batch)
                        output["loss"] = output["loss"].mean()

                        # at each new turn
                        if turnid != batch["turnid"]:
                            metrics.update(module(model).get_metrics(reset=True, goal_reset=True))
                        else:
                            curr_met = module(model).get_metrics(reset=True)
                            metrics.update(acc=curr_met["acc"])

                        metrics["loss"] = output["loss"].item()
                        pbar.update(1)
                        pbar.set_postfix(metrics)
                        turnid = batch["turnid"]

In [19]:
#torch.save(model, "../data/model0.pkl")
try:
    del model
except NameError:
    torch.cuda.empty_cache()

In [None]:
print("Remove tensorboard logs")
!rm -rf ../data/tensorboard/*

print("set number of devices") # not sure we can in jupyter once program already kicked in the first time
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
device = "cuda"

print("loading model")
model = CandidateSelector()
#model = nn.DataParallel(model)
model = move_to_device(model, device)

optim = torch.optim.Adam(model.parameters(), lr=3e-4)
train_samples = [train_ds[i] for i in range(1000)]
test_samples = [test_ds[i] for i in range(10)]

print("started training")
train(
    model=model,
    optimizer=optim,
    train_ds=train_samples,
    test_ds=test_samples,
    device=device,
    num_epochs=20,
    batch_size=16,
)

Remove tensorboard logs
set number of devices
loading model
Initializing bias/weights of  l0.weight_ih_l0
Initializing bias/weights of  l0.weight_hh_l0
Initializing bias/weights of  l0.bias_ih_l0
Initializing bias/weights of  l0.bias_hh_l0
Initializing bias/weights of  l2.weight_ih_l0
Initializing bias/weights of  l2.weight_hh_l0
Initializing bias/weights of  l2.bias_ih_l0
Initializing bias/weights of  l2.bias_hh_l0
started training


HBox(children=(IntProgress(value=0, max=953), HTML(value='')))

# Inference

In [24]:
from tabulate import tabulate

In [35]:
def infer_batch(model, test_ds, batch_size, device):
    result_table = []

    goal_acc = []
    joint_acc = []

    total_turn_acc = None
    total_joint_acc = None
    num_turns = 0
    num_slots = 0

    model = model.eval()
    
    for batch in DialogIterator(test_ds, batch_size, collate_fn=dialogue_mini_batcher):
        batch = move_to_device(batch, device)
        with torch.no_grad():
            result = model(**batch)
        
        turnid = batch["turnid"]
        servid = batch["serviceid"]
        num_batches = batch["usr_utter"]["ids"].shape[0]
        
        if turnid == 0:
            # update previous turn's.
            if total_turn_acc is not None:
                goal_acc.append(total_turn_acc / num_turns)
                joint_acc.append(total_joint_acc / num_turns)
            
            # reset counter
            total_turn_acc = 0
            total_joint_acc = 1 # its binary 1 or 0
            num_turns += 1

        batch_fields = dict(
            prd_mem_loc = result["pred_ids"], # B,slot
            tgt_mem_loc = result["target_ids"],
            mask = result["mask"],
            mask_none = result["mask_none"], # 1 -> target is not None
            dial_id = batch["dial_id"]["value"],
            serv_name = batch["serv_name"]["value"],
            slot_name = batch["slot_name"]["value"],
            num_turns = batch["num_turns"]["ids"],
            memory = batch["slot_memory"]["value"], # B,turn,memory
        )

        for b in range(num_batches):
            num_slots = 0

            fields = {k: v[b] for k, v in batch_fields.items()}
            memory = fields["memory"][turnid]
            
            for slotid, (prd_loc, tgt_loc) in enumerate(zip(fields["prd_mem_loc"], fields["tgt_mem_loc"])):
                # clip the locations.. they somehow predict in padding regions!! Make it NONE
                prd_loc = int(prd_loc.item()) if prd_loc < len(memory) else 0
                tgt_loc = int(tgt_loc.item()) if tgt_loc < len(memory) else 0
                
                prd_val = memory[prd_loc]
                tgt_val = memory[tgt_loc]
                mask = fields["mask"][slotid].item()
                mask_none = fields["mask_none"][slotid].item()
                dialid = fields["dial_id"]
                servname = fields["serv_name"][servid]
                slotname = fields["slot_name"][servid][slotid]
                num_turns = fields["num_turns"].item() / 2 # usr-sys turn pairs
                
                # update acc counter
                if mask * mask_none == 1:
                    total_turn_acc += int(prd_loc == tgt_loc)
                    total_joint_acc *= int(prd_loc == tgt_loc)
                    num_slots += 1

                item = OrderedDict(
                    dialid = fields["dial_id"],
                    num_turns = num_turns,
                    turnid = turnid,
                    servid = servid,
                    servname = fields["serv_name"][servid],
                    memsize = len(memory),
                    slotname = slotname,
                    tgt_val = tgt_val,
                    tgt_loc = tgt_loc,
                    prd_val = prd_val,
                    prd_loc = prd_loc,
                    correct = 100*int(tgt_loc == prd_loc),
                    mask = (mask * mask_none),
                )

                result_table.append(item)
            
            # avg acc across slots per turn
            total_turn_acc /= num_slots
    
    # the last batch
    goal_acc.append(total_turn_acc / num_turns)
    joint_acc.append(total_joint_acc / num_turns)
        
    print("Avg Goal Acc: ", sum(goal_acc) / len(goal_acc))
    print("Avg Joint Goal Acc: ", sum(joint_acc) / len(joint_acc))
    print()
    
    print(tabulate(
        [list(item.values()) for item in result_table],
        list(result_table[0].keys()),
        "fancy_grid",
    ))
    
    
        
infer_batch(model, [test_ds[0]], batch_size=2, device="cuda")

Avg Goal Acc:  0.12333333333333334
Avg Joint Goal Acc:  0.0

╒══════════╤═════════════╤══════════╤══════════╤═══════════════╤═══════════╤════════════════════════╤═════════════════════════════╤═══════════╤═══════════╤═══════════╤═══════════╤════════╕
│   dialid │   num_turns │   turnid │   servid │ servname      │   memsize │ slotname               │ tgt_val                     │   tgt_loc │ prd_val   │   prd_loc │   correct │   mask │
╞══════════╪═════════════╪══════════╪══════════╪═══════════════╪═══════════╪════════════════════════╪═════════════════════════════╪═══════════╪═══════════╪═══════════╪═══════════╪════════╡
│  1_00000 │           6 │        0 │        0 │ Restaurants_2 │        15 │ restaurant_name        │ NONE                        │         0 │ NONE      │         0 │       100 │      0 │
├──────────┼─────────────┼──────────┼──────────┼───────────────┼───────────┼────────────────────────┼─────────────────────────────┼───────────┼───────────┼───────────┼───────────┼────