In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [8]:
import torch.utils.data as D
import torch
import numpy as np
import json
import csv
import glob
import re
import concurrent.futures
import os

from copy import deepcopy
from functools import lru_cache
from collections import OrderedDict
from types import SimpleNamespace
from collections.abc import Iterable
from tqdm import tqdm_notebook as tqdm
from pytorch_transformers import BertTokenizer
from ipdb import set_trace

# Data Reader

## Utils

In [9]:
class Tokenizer:
    
    def __init__(self, bert):
        self.bert = bert
    
    def __call__(self, text, include_sep=True):
        tokens = self.bert.tokenize(text)
        if include_sep:
            tokens.insert(0, "[CLS]")
            tokens.append("[SEP]")
        return tokens
    
    
class TokenIndexer:
    
    def __init__(self, bert):
        self.bert = bert
        
    def inv(self, *args, **kw):
        # tokens MUST be list, tensor doesn't work.
        return self.bert.convert_ids_to_tokens(*args, **kw)
        
    def __call__(self, *args, **kw):
        return self.bert.convert_tokens_to_ids(*args, **kw)

In [10]:
def label_binarize(labels, classes):
    # labels: np.array or tensor [batch, 1]
    # classes: [..] list of classes
    # weirdly,`sklearn.preprocessing.label_binarize` returns [1] or [0]
    # instead of onehot ONLY when executing in this script!
    vectors = [np.zeros(len(classes)) for _ in labels]
    for i, label in enumerate(labels):
        for j, c in enumerate(classes):
            if c == label:
                vectors[i][j] = 1
    return np.array(vectors)
    

def label_inv_binarize(vectors, classes):
    # labels: np.array or tensor [batch, classes]
    # classes: [..] list of classes
    # follows sklearn LabelBinarizer.inverse_transform()
    # given all zeros, predicts label at index 0, instead of returning none!
    # sklearn doesn't have functional API of inverse transform
    labels = []
    for each in vectors:
        index = np.argmax(each)
        labels.append(classes[index])
    return labels

In [11]:
def padded_array(array, value=0):
    # TODO: this does not do type checking; and wow it can be slow on strings.
    # expects array to have fixed _number_ of dimensions
    
    # resolve the shape of padded array
    shape_index = {}
    queue = [(array, 0)]
    while queue:
        subarr, dim = queue.pop(0)
        shape_index[dim] = max(shape_index.get(dim, -1), len(subarr))
        for x in subarr:
            if isinstance(x, Iterable) and not isinstance(x, str):
                queue.append((x, dim+1))
    shape = [shape_index[k] for k in range(max(shape_index) + 1)]
    
    # fill the values 
    padded = np.ones(shape) * value
    queue = [(array, [])]
    while queue:
        subarr, index = queue.pop(0)
        for j, x in enumerate(subarr):
            if isinstance(x, Iterable):
                queue.append((x, index + [j]))
            else:
                padded[tuple(index + [j])] = x
    return padded

## Schema Reader

In [12]:
class Schemas(object):

    def __init__(self, filepath):
        with open(filepath) as f:
            self.index = {}
            for schema in json.load(f):
                service_name = schema["service_name"]
                self.index[service_name] = schema

    @lru_cache(maxsize=None)
    def get(self, service):
        result = dict(
            # service
            name=service,
            desc=self.index[service]["description"],
            
            # slots
            slot_name=[],
            slot_desc=[],
            slot_iscat=[], 
            slot_vals=[], # collected only for cat slots.. not sure if that makes sense

            # intents
            intent_name=[],
            intent_desc=[],
            intent_istrans=[],
            intent_reqslots=[],
            intent_optslots=[],
            intent_optvals=[],
        )

        for slot in self.index[service]["slots"]:
            result["slot_name"].append(slot["name"])
            result["slot_desc"].append(slot["description"])
            result["slot_iscat"].append(slot["is_categorical"])
            result["slot_vals"].append(slot["possible_values"])
        
        for intent in self.index[service]["intents"]:
            result["intent_name"].append(intent["name"])
            result["intent_desc"].append(intent["description"])
            result["intent_istrans"].append(intent["is_transactional"])
            result["intent_reqslots"].append(intent["required_slots"])
            result["intent_optslots"].append(list(intent["optional_slots"].keys()))
            result["intent_optvals"].append(list(intent["optional_slots"].values()))

        return result    

## Dialogue Reader

In [14]:
class DialogueDataset(D.Dataset):
    def __init__(self, filename, schemas, tokenizer, token_indexer):
        with open(filename) as f:
            self.ds = json.load(f)
        self.schemas = schemas
        self.tokenizer = tokenizer
        self.token_indexer = token_indexer
        self.dialogues = []
        for dial in self.ds:
            fields = self.dial_to_fields(dial)
            self.dialogues.append(fields)
        self.schemas = None
        self.tokenizer = None
        self.token_indexer = None
            
    def __getitem__(self, idx):
        return self.dialogues[idx]
    
    def __len__(self):
        return len(self.dialogues)
    
    # Turn,Service,Value
    def fd_serv_name(self, dial, fields):
        resp = dict(value=[])
        for service in dial["services"]:
            resp["value"].append(service)
        return resp
    
    # Turn,Service,Tokens
    def fd_serv_desc(self, dial, fields):
        resp = dict(value=[], tokens=[], ids=[], ids_pos=[], mask=[])
        for service in dial["services"]:
            desc = self.schemas.get(service)["desc"]
            tokens = self.tokenizer(desc)
            ids = self.token_indexer(tokens)
            resp["value"].append(desc)
            resp["tokens"].append(tokens)
            resp["ids"].append(ids)
            resp["ids_pos"].append(list(range(1, len(ids) + 1)))
            resp["mask"].append([1] * len(ids))
        return resp
    
    # Turn,Service,Slots,Value
    def fd_slot_name(self, dial, fields):
        resp = {"value": []}
        for serv in dial["services"]:
            schema = self.schemas.get(serv)
            resp["value"].append(schema["slot_name"])
        return resp
    
    # Turn,Service,Slots,Tokens
    def fd_slot_desc(self, dial, fields):
        resp = dict(value=[], tokens=[], ids=[], ids_pos=[], mask=[])
        for serv in dial["services"]:
            s_desc = self.schemas.get(serv)["slot_desc"]
            s_tokens = [self.tokenizer(d) for d in s_desc]
            s_ids = [self.token_indexer(d) for d in s_tokens]
            s_mask = [[1] * len(d) for d in s_tokens]
            resp["value"].append(s_desc)
            resp["tokens"].append(s_tokens)
            resp["ids"].append(s_ids)
            resp["ids_pos"].append([list(range(1, len(i) + 1)) for i in s_ids])
            resp["mask"].append(s_mask)
        return resp
    
    # Turn,Service,Slot,1
    def fd_slot_iscat(self, dial, fields):
        resp = dict(value=[], ids=[], mask=[])
        for serv in dial["services"]:
            schema = self.schemas.get(serv)
            resp["value"].append(schema["slot_iscat"])
            resp["ids"].append([int(i) for i in schema["slot_iscat"]])
            resp["mask"].append([1] * len(schema["slot_iscat"]))
        return resp

    # Turn,Memory,Tokens
    def fd_slot_memory(self, dial, fields):
        # TODO: iscat feature..
        # memory is a sequence of slot tagger values across all frames in a dial.. that grows per turn
        # maintain snapshot at each turn
        resp = dict(
            value=[], # unfeaturized memory
            value_desc = [],
            value_slot = [], # slots that use this value
            tokens=[], # [turn, memory, tokens]
            tokens_desc = [],
            ids=[], # [turn, memory, tokens]
            ids_desc = [], # [turn, memory, tokens]
            ids_iscat = [],  # [turn, memory, 2] 0/1 onehot
            ids_pos=[], # [turn, memory, 1] positional info for memory cells
            #ids_tokens_pos = [], # [turn, memory, tokens] position info for memory values
            #ids_tokens_desc_pos = [], # [turn, memory, tokens] positional info for memory desc values
            ids_memsize=[], # mem sizes [turn, 1]
            mask=[], # mask on memory [turn, mem-size, tokens]
            mask_desc = [], # mask on mem desc [turn, mem, tokens]
        )
        
        # memory is sequential, initialized by vals from schema, only has values
        memory = ["NONE", "dontcare"] # keep these at index 0, 1
        memory_slot = [["*"], ["*"]]
        memory_desc = ["NONE", "DONTCARE"]
        memory_iscat = [[1,1], [1,1]] # [not used by cat slot, used by cat slot]
        memory_index = {}
        for serv in dial["services"]:
            schema = self.schemas.get(serv)
            servdesc = schema["desc"]
            for slotname, slotdesc, iscat, values in zip(schema["slot_name"], schema["slot_desc"], schema["slot_iscat"], schema["slot_vals"]):
                for val in values:
                    if val not in memory_index:
                        memory.append(val)
                        memory_slot.append([slotname])
                        memory_desc.append(slotdesc + "[SEP]" + servdesc)
                        memory_iscat.append([int(iscat==False), int(iscat==True)]) # LOL; I just can't believe I am not using label_binarize!!
                        memory_index[val] = len(memory)
                    else:
                        idx = memory_index[val] - 1
                        if slotname not in memory_slot[idx]:
                            memory_desc[idx] = memory_desc[idx] + "[SEP]" + slotdesc + "[SEP]" + servdesc
                            memory_slot[idx].append(slotname)
                            memory_iscat[idx][int(iscat)] = 1
        
        # at each user turn create a memory snapshot..
        for turn in dial["turns"]:
            memory = deepcopy(memory)
            memory_desc = deepcopy(memory_desc)
            memory_slot = deepcopy(memory_slot)
            memory_iscat = deepcopy(memory_iscat)
            
            # pick slot tagger's ground truth ; categorical slot vals are already initialized!
            utter = turn["utterance"]
            for frame in turn["frames"]:
                schema = self.schemas.get(frame["service"])
                servdesc = schema["desc"]
                for tag in frame["slots"]:
                    slotname = tag["slot"]
                    iscat = schema["slot_iscat"][schema["slot_name"].index(slotname)]
                    slotdesc = schema["slot_desc"][schema["slot_name"].index(slotname)] # shit! O(N*2)
                    st, en = tag["start"], tag["exclusive_end"]
                    value = utter[st:en]
                    if value not in memory_index:
                        memory.append(value)
                        memory_slot.append([slotname])
                        memory_desc.append(slotdesc)
                        memory_iscat.append([int(iscat==False), int(iscat==True)])
                        memory_index[value] = len(memory)
                    else:
                        idx = memory_index[value] - 1
                        if slotname not in memory_slot[idx]:
                            memory_slot[idx].append(slotname)
                            memory_desc[idx] = memory_desc[idx] + "[SEP]" + slotdesc + "[SEP]" + servdesc
                            memory_iscat[idx][int(iscat)] = 1
                        
            if turn["speaker"] == "USER":
                resp["value"].append(memory)
                resp["value_desc"].append(memory_desc)
                resp["value_slot"].append(memory_slot)
                resp["ids_iscat"].append(memory_iscat)
                resp["ids_memsize"].append(len(memory))

        # tokenize and index the memory values
        # value: [turn, values], ids/tokens: [turn, values, tokens]
        for mem_desc_snapshot, mem_snapshot in zip(resp["value_desc"], resp["value"]):
            mem_tokens = []
            mem_tokens_desc = []
            mem_ids = []
            mem_ids_desc = []
            mem_pos = list(range(1, len(mem_snapshot) + 1))
            mem_mask = []
            mem_mask_desc = []
            for desc, val in zip(mem_desc_snapshot, mem_snapshot):
                # featurize memory
                tokens = self.tokenizer(val)
                mem_tokens.append(tokens)
                mem_ids.append(self.token_indexer(tokens))
                mem_mask.append([1] * len(tokens))
                
                # featurize memory desc
                tokens_desc = self.tokenizer(desc)
                mem_tokens_desc.append(tokens_desc)
                mem_ids_desc.append(self.token_indexer(tokens_desc))
                mem_mask_desc.append([1] * len(tokens_desc))
                
            resp["tokens"].append(mem_tokens)
            resp["tokens_desc"].append(mem_tokens_desc)
            resp["ids"].append(mem_ids)
            resp["ids_desc"].append(mem_ids_desc)
            resp["ids_pos"].append(mem_pos)
            resp["mask"].append(mem_mask)
            resp["mask_desc"].append(mem_mask_desc)
        
        return resp
    
    def fd_slot_memory_loc(self, dial, fields):
        resp = dict(
            value=[],  # string value
            ids=[], # memory loc
            ids_onehot=[], # onehot of memory loc
            mask=[],
            mask_onehot=[],
            mask_none=[], # 1 -> non NONE values
            mask_none_onehot=[],
        )
        
        # query dialog memory snapshots per turn. NOTE: fd_slot_memory should exec first.
        memory = fields["slot_memory"]["value"]
        
        # init snapshot: service, slot -> memory loc, val
        memory_loc = []
        memory_val = []
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                loc = OrderedDict()
                val = OrderedDict()
                for serv in dial["services"]:
                    loc[serv] = OrderedDict()
                    val[serv] = OrderedDict()
                    for slot in self.schemas.get(serv)["slot_name"]:
                        loc[serv][slot] = None
                        val[serv][slot] = None
                memory_loc.append(loc)
                memory_val.append(val)
        
        # fill the memory locations 
        snapshot_id = 0
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                turn_memory = memory[snapshot_id]
                turn_memory_loc = memory_loc[snapshot_id]
                turn_memory_val = memory_val[snapshot_id]
                for frame in turn["frames"]:
                    service = frame["service"]
                    for slot, values in frame["state"]["slot_values"].items():
                        val = re.sub("\u2013", "-", values[0]) # dial 59_00125 turn 14
                        turn_memory_loc[service][slot] = turn_memory.index(val)
                        turn_memory_val[service][slot] = val
                    # add locations to UNKNOWN slots
                    for slot, val in turn_memory_loc[service].items():
                        if val is None:
                            turn_memory_loc[service][slot] = turn_memory.index("NONE")
                            turn_memory_val[service][slot] = "NONE"
                snapshot_id += 1
        
        # featurize
        snapshot_id = 0
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                turn_memory = memory[snapshot_id]
                turn_memory_loc = memory_loc[snapshot_id]
                turn_memory_val = memory_val[snapshot_id]
                none_loc = memory[snapshot_id].index("NONE")
                turn_fields = dict(
                    value=[], ids=[], ids_onehot=[], mask=[], mask_onehot=[],
                    mask_none=[], mask_none_onehot=[],
                )

                for serv in turn_memory_loc:
                    mem_size = len(turn_memory)
                    vals = list(turn_memory_val[serv].values())
                    ids = list(turn_memory_loc[serv].values())
                    ids_onehots = label_binarize(ids, list(range(mem_size)))
                    mask = [1] * len(ids)
                    mask_onehots = [[1] * mem_size for _ in ids]

                    mask_none = [int(v!="NONE") for v in vals]
                    mask_none_onehot = [[1] * mem_size for _ in ids]
                    for v, loc, onehot in zip(vals, ids, mask_none_onehot):
                        if v == "NONE":
                            onehot[loc] = 0
                                
                    turn_fields["value"].append(vals)
                    turn_fields["ids"].append(ids)
                    turn_fields["ids_onehot"].append(ids_onehots)
                    turn_fields["mask"].append(mask)
                    turn_fields["mask_onehot"].append(mask_onehots)
                    turn_fields["mask_none"].append(mask_none)
                    turn_fields["mask_none_onehot"].append(mask_none_onehot)

                # update turn
                for k, v in turn_fields.items():
                    resp[k].append(v)
                
                snapshot_id += 1
            
        return resp
    
    def fd_num_turns(self, dial, fields):
        return {"ids": len(dial["turns"])}
    
    def fd_num_frames(self, dial, fields):
        return {"ids": [len(t["frames"]) for t in dial["turns"]]}
    
    def fd_usr_utter(self, dial, fields):
        resp = dict(value=[], ids=[], ids_pos=[], mask=[], tokens=[])
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                utter = turn["utterance"]
                tokens = self.tokenizer(utter)
                ids = self.token_indexer(tokens)
                resp["value"].append(utter)
                resp["ids"].append(ids)
                resp["ids_pos"].append(list(range(1, len(ids) + 1)))
                resp["tokens"].append(tokens)
                resp["mask"].append([1] * len(tokens))
        return resp
    
    def fd_sys_utter(self, dial, fields):
        resp = dict(value=[], ids=[], ids_pos=[], mask=[], tokens=[])
        for turn in dial["turns"]:
            if turn["speaker"] == "SYSTEM":
                utter = turn["utterance"]
                tokens = self.tokenizer(utter)
                ids = self.token_indexer(tokens)
                resp["value"].append(utter)
                resp["ids"].append(ids)
                resp["ids_pos"].append(list(range(1, len(ids) + 1)))
                resp["tokens"].append(tokens)
                resp["mask"].append([1] * len(tokens))
        return resp
    
    def fd_dial_id(self, dial, fields):
        return {"value": dial["dialogue_id"]}
    
    def dial_to_fields(self, dial):
        fields = {}
        ordered_funcs = [
            "fd_dial_id",
            "fd_num_turns", "fd_num_frames",
            "fd_serv_name", "fd_serv_desc", 
            "fd_slot_name", "fd_slot_desc", "fd_slot_iscat",
            "fd_slot_memory", "fd_slot_memory_loc",
            "fd_usr_utter", "fd_sys_utter"]
        for func in ordered_funcs:
            name = func.split("fd_", maxsplit=1)[-1]
            value = getattr(self, func)(dial, fields)
            if value is not None:
                fields[name] = value
        return fields

## Load Data

In [15]:
bert_ = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer = Tokenizer(bert_)
token_indexer = TokenIndexer(bert_)

train_schemas = Schemas("../data/train/schema.json")
test_schemas = Schemas("../data/dev/schema.json")

In [16]:
ds = DialogueDataset("../data/train/dialogues_001.json", train_schemas, tokenizer, token_indexer)

In [17]:
# load training dataset
train_dial_sets = []
train_dial_files = sorted(glob.glob("../data/train/dialogues*.json"))
num_workers = min(20, len(train_dial_files))

def worker(filename):
    return DialogueDataset(filename, train_schemas, tokenizer, token_indexer)

with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
    for ds in tqdm(executor.map(worker, train_dial_files), total=len(train_dial_files)):
        train_dial_sets.append(ds)

train_ds = D.ConcatDataset(train_dial_sets)

HBox(children=(IntProgress(value=0, max=127), HTML(value='')))




In [18]:
# load test dataset
test_dial_sets = []
test_dial_files = sorted(glob.glob("../data/dev/dialogues*.json"))
num_workers = min(20, len(train_dial_files))

def worker(filename):
    return DialogueDataset(filename, test_schemas, tokenizer, token_indexer)

with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
    for ds in tqdm(executor.map(worker, test_dial_files), total=len(test_dial_files)):
        test_dial_sets.append(ds)

test_ds = D.ConcatDataset(test_dial_sets)

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




In [142]:
train_ds[0]["slot_name"]["value"]

[['restaurant_name',
  'date',
  'time',
  'serves_alcohol',
  'has_live_music',
  'phone_number',
  'street_address',
  'party_size',
  'price_range',
  'city',
  'cuisine']]

# Training Utils

In [19]:
import torch.nn as nn

In [21]:
def move_to_device(obj, device):
    if type(obj) is list:
        return [move_to_device(o, device) for o in obj]
    elif type(obj) is dict:
        return {k: move_to_device(v, device) for k, v in obj.items()}
    elif type(obj) is torch.Tensor or isinstance(obj, nn.Module):
        return obj.to(device)
    return obj

In [145]:
def dialogue_mini_batcher(dialogues):
    default_padding = 0
    batch = {}
    for dial in dialogues:
        # populate the batch
        for field, data in dial.items():
            if field not in batch:
                batch[field] = {}
            for attr, val in data.items():
                if attr == "padding":
                    batch[field][attr] = val
                else:
                    batch[field][attr] = batch[field].get(attr, [])
                    batch[field][attr].append(val)

    # padding on field attributes
    for field_name, data in batch.items():
        for attr in data:
            if attr.startswith(("ids", "mask")):
                data[attr] = padded_array(data[attr], default_padding)
                data[attr] = torch.tensor(data[attr], device="cpu") # whatif its in device0 in epoch0, then at epoch1, sent to device1

    return batch


class DialogIterator(object):
    """A simple wrapper on DataLoader"""
    
    def __init__(self, dataset, batch_size, *args, **kw):
        self.length = None
        self.dataset = dataset
        self.batch_size = batch_size
        self.iterator = D.DataLoader(dataset, batch_size, *args, **kw)
        
    def __len__(self):
        if self.length is None:
            self.length = 0
            for dial in self.dataset:
                num_turns = dial["num_turns"]["ids"] # num slo
                num_services = len(dial["serv_name"]["value"])
                num_slots = sum([len(x) for x in dial["slot_name"]["value"]])
                self.length += num_turns * num_services * num_slots
            self.length = int(self.length / self.batch_size)
        return self.length
    
    def __iter__(self):
        for batch in self.iterator:
            num_turns = batch["usr_utter"]["ids"].shape[1]
            num_services = batch["serv_desc"]["ids"].shape[1]
            num_slots = batch["slot_desc"]["ids"].shape[2]
            for turnid in range(num_turns):
                for serviceid in range(num_services):
                    for slotid in range(num_slots):
                        inputs = dict(turnid=turnid, serviceid=serviceid, slotid=slotid)
                        inputs.update(batch)
                        yield inputs

In [152]:
def get_sample(ds, size=1):
    return next(iter(DialogIterator(ds, size, collate_fn=dialogue_mini_batcher)))

def print_shapes(ds):
    it = get_sample(ds, size=1)
    for field, val in it.items():
        if type(val) is dict:
            for attr in val:
                if attr.startswith("ids") or attr.startswith("mask"):
                    print(field, attr, "-->", val[attr].shape)
                    

# TODO: why mem ids, and mem loc size is diff by 1?
print_shapes(train_ds)

num_turns ids --> torch.Size([1])
num_frames ids --> torch.Size([1, 24])
serv_desc ids --> torch.Size([1, 1, 10])
serv_desc ids_pos --> torch.Size([1, 1, 10])
serv_desc mask --> torch.Size([1, 1, 10])
slot_desc ids --> torch.Size([1, 1, 11, 12])
slot_desc ids_pos --> torch.Size([1, 1, 11, 12])
slot_desc mask --> torch.Size([1, 1, 11, 12])
slot_iscat ids --> torch.Size([1, 1, 11])
slot_iscat mask --> torch.Size([1, 1, 11])
slot_memory ids --> torch.Size([1, 12, 29, 10])
slot_memory ids_desc --> torch.Size([1, 12, 29, 40])
slot_memory ids_iscat --> torch.Size([1, 12, 29, 2])
slot_memory ids_pos --> torch.Size([1, 12, 29])
slot_memory ids_memsize --> torch.Size([1, 12])
slot_memory mask --> torch.Size([1, 12, 29, 10])
slot_memory mask_desc --> torch.Size([1, 12, 29, 40])
slot_memory_loc ids --> torch.Size([1, 12, 1, 11])
slot_memory_loc ids_onehot --> torch.Size([1, 12, 1, 11, 29])
slot_memory_loc mask --> torch.Size([1, 12, 1, 11])
slot_memory_loc mask_onehot --> torch.Size([1, 12, 1, 11

# Model

In [24]:
from pytorch_transformers import BertModel, BertConfig
from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy
from collections import OrderedDict

import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch

In [182]:
class BertDownstream(nn.Module):
    # https://huggingface.co/pytorch-transformers/model_doc/bert.html
    # https://github.com/hanxiao/bert-as-service/blob/master/docs/section/faq.rst#id6
    
    def __init__(self, btype, requires_grad=False):
        super().__init__()
        self.emb = BertModel.from_pretrained(btype, output_hidden_states=True)
        for name, param in self.emb.named_parameters():
            param.requires_grad = requires_grad

    @property
    def output_dim(self):
        return 768
    
    def forward(self, input_ids, position_ids=None, attention_mask=None, flat=(1,-1)):
        sh = list(input_ids.shape)
        
        input_ids = input_ids.flatten(*flat).long()
        if position_ids is not None:
            position_ids = position_ids.flatten(*flat).long()
        if attention_mask is not None:
            attention_mask = attention_mask.flatten(*flat).long()
        outputs = self.emb(input_ids, position_ids=position_ids, attention_mask=attention_mask)
        
        # average of hidden layers: bs->bse
        pooled = torch.mean(torch.stack([outputs[2][i] for i in range(2, 11)]), dim=0)
        
        # unflatten
        pooled = pooled.view(sh + [-1])
        
        return pooled # (..., e)


class CandidateSelector(nn.Module):
    """
    memory -- represents temporal hash table of text values, that monotonically grows
    memory_desc -- represents sequential description to every value in hash table
    slot, utter -- represent query/key descriptions
    
    The idea is that we model a hash function 
        F: slot, utter, time -> memory_location
    
    This is a mix of ideas from
        - Memory Networks by Jason Weston, FAIR
        - E2E Memory Networks by Sukhbaatar, NYU/FAIR
        - Pointer Networks by Oriol Vinayls, Google Brain
        - Alignment and Translation by Bahdanau, MILA
    
        - Weston and Sukhbaatar demonstrate memory models on Facebook bAbI tasks. that capture some temporal challenges.
            while Vinayls tweaks the idea of Bahdanau's attention, to classification/regress targets represented in latent space.
            (NOTE: their motivation is subtly different, but this explanation still stands valid)
    
    Unique contribution:
        - Explore the problem where memory grows and has temporal information.
        - Model a hash function that understands temporal (just like Vinayls) but but 
            new challenges are:
            - applied to dialog state (no one else did before), where the hash function has additional challenges of dealing with 
                guessing user behavior dist. and language subtlities (eg: coreference and ellipsis resolutions.)
            - the hash function is learning from real world data, unlike previous works where toy data are used. 
                the hash key is can change just because user explicitly/implicitly says something..
            - the model is setup in a way it supports growth, and scale across domains. Sukhbaatar's approach is fixed (verify again!)
            
    Viewpoints/Quesitons
        - Where does meta data belong. such as attr -- iscatategorical. should Hash function figure out everything..
            or the memory may cache this information too as they appear in the history (at inference time, this will have cumulative error)?
        - one way to think is that, the memory representation remains static.. and the hash function figures it out
            then should the slot descriptions/serv desc for resp. memory values be tagged to the representation?
    """
        
    def __init__(self):
        super().__init__()
        self.emb = BertDownstream("bert-base-uncased")
        emb_dim = self.emb.output_dim
        
        self.l0 = nn.LSTM(emb_dim, emb_dim, batch_first=True, num_layers=2) # encode utter
        self.l1 = nn.LSTM(emb_dim, emb_dim, batch_first=True, num_layers=2) # encode memory
        self.l2 = nn.LSTM(emb_dim, emb_dim, batch_first=True, num_layers=2) # encode query slot
        
        self.turn_pos = nn.Linear(1, emb_dim)
        self.mem_pos = nn.Linear(1, emb_dim)
        
        # attentions
        self.l3 = nn.Linear(2*emb_dim, emb_dim)
        self.l4 = nn.Linear(emb_dim, 1)
        self.l5 = nn.Linear(emb_dim, 1)
         
        # metrics
        self.acc = CategoricalAccuracy()
        self.goal_acc = CategoricalAccuracy()
        
        #self.register_buffer("utter_state", None)

        # init weights: classifier's performance changes heavily on these
        for name, param in self.named_parameters():
            if name.startswith(("l0.", "11.", "l2.")):
                print("Initializing bias/weights of ", name)
                if "weight" in name:
                    nn.init.xavier_uniform_(param)
                else:
                    param.data.fill_(0.01)

    def get_metrics(self, reset=False, goal_reset=False):
        return dict(
            acc=self.acc.get_metric(reset), # avg per slot
            goal_acc=self.goal_acc.get_metric(goal_reset), # avg per turn
        )
    
    def encode_utter(self, batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        
        usr_utter, usr_mask, usr_pos = batch["usr_utter"]["ids"][:,turnid,:], batch["usr_utter"]["mask"][:,turnid,:], batch["usr_utter"]["ids_pos"][:,turnid,:]
        sys_utter, sys_mask, sys_pos = batch["sys_utter"]["ids"][:,turnid,:], batch["sys_utter"]["mask"][:,turnid,:], batch["sys_utter"]["ids_pos"][:,turnid,:]
        
        usr_utter = self.emb(usr_utter, position_ids=usr_pos, attention_mask=usr_mask) # bse
        sys_utter = self.emb(sys_utter, position_ids=sys_pos, attention_mask=sys_mask) # bse
        
        utter = torch.cat([sys_utter, usr_utter], dim=1) # b2se

        utter, utter_h = self.l0(utter)
        
        posx = torch.zeros(usr_utter.shape[0], 1, device=usr_utter.device) + turnid
        
        return utter_h[0][-1] + self.turn_pos(posx.float()) # be

    def encode_slot_desc(self, batch):
        serviceid = batch["serviceid"]
        slotid = batch["slotid"]
        
        desc = batch["slot_desc"]["ids"][:,serviceid,slotid,:] # bt
        desc_pos = batch["slot_desc"]["ids_pos"][:,serviceid,slotid,:]
        desc_mask = batch["slot_desc"]["mask"][:,serviceid,slotid,:]
        sh = desc.shape
        desc = self.emb(desc, position_ids=desc_pos, attention_mask=desc_mask) # bte
        
        serv_desc = batch["serv_desc"]["ids"][:,serviceid,:] # bt
        serv_desc_pos = batch["serv_desc"]["ids_pos"][:,serviceid,:]
        serv_desc_mask = batch["serv_desc"]["mask"][:,serviceid,:]
        sh1 = serv_desc.shape
        serv_desc = self.emb(serv_desc, serv_desc_pos, serv_desc_mask) # bte

        desc = torch.cat([desc, serv_desc], dim=1) # d2te
        desc, desc_h = self.l2(desc)

        desc_h = desc_h[0][-1]
    
        return desc_h # be
    
    def encode_memory(self, batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        slotid = batch["slotid"]
        
        # Encode memory
        mem = batch["slot_memory"]["ids"][:,turnid,:] # bms
        mem_mask = batch["slot_memory"]["mask"][:,turnid,:] 
        mem_pos = batch["slot_memory"]["ids_pos"][:,turnid,:] # position of memory cell, not tokens. bm
        sh = mem.shape
        
        mem_pos_r = mem_pos[:,:,None].repeat(1,1,sh[-1])# bm->bms
        mem = self.emb(mem, position_ids=mem_pos_r, attention_mask=mem_mask, flat=(0, 1)) # bmse
        
        mem, mem_h = self.l1(mem.flatten(0,1))
        mem_h = (mem_h[0][-1] + mem_h[1][-1]).view(sh[0], sh[1], -1)
        
        return mem_h + self.mem_pos(mem_pos.unsqueeze(-1).float()) # bme
    
    def get_key(self, utter, slot, attn=None):
        key = self.l3(torch.cat([utter, slot], dim=-1)) # b2e->be
        return key # be
    
    def get_value(self, memory, key, layers=1):
        # memory: bme, key: be
        attn = self.l5(memory * key.unsqueeze(1)).squeeze(-1)
        attn = F.softmax(attn, -1)
        
        output = memory * attn.unsqueeze(-1)
        output = self.l4(output).squeeze(-1) # bm1
        output = F.softmax(output, dim=-1)
        
        return output, attn # bm
        
    def forward(self, **batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        slotid = batch["slotid"]
        
        # doc: GRU outputs [batch, seq, emb * dir] [layers * dir, batch, emb]

        # get fixed size encodings
        utter = self.encode_utter(batch)
        slot_desc = self.encode_slot_desc(batch)
        memory = self.encode_memory(batch)
        
        key = self.get_key(utter, slot_desc)
        score, attn = self.get_value(memory, key)
        
        output = {"score": score, "mem_attn": attn} # bsm

        if "slot_memory_loc" in batch:
            target = batch["slot_memory_loc"]
            memsize = target["ids_onehot"].shape[-1]
            pred_score = score.view(-1, memsize) # bsm->bs,m 
            
            # calc loss
            target_score_oh = target["ids_onehot"][:,turnid,serviceid,slotid,:].float() # bm
            target_mask_oh = target["mask_onehot"][:,turnid,serviceid,slotid,:].float()
            output["loss"] = F.binary_cross_entropy(pred_score, target_score_oh, target_mask_oh).unsqueeze(0) # don't return scalar
            
            # log metrics
            target_score = target["ids"][:,turnid,serviceid,slotid].float() # [batch, slots]
            target_mask = (target["mask"][:,turnid,serviceid,slotid] * target["mask_none"][:,turnid,serviceid,slotid]).float()
            self.acc(score, target_score, target_mask)
            self.goal_acc(score, target_score, target_mask)
            
            output["target_ids"] = target_score
            output["pred_ids"] = torch.argmax(score, dim=-1)
            output["mask"] = target["mask"][:,turnid,serviceid,slotid].float()
            output["mask_none"] = target["mask_none"][:,turnid,serviceid,slotid].float()

        return output

# Trainer

In [31]:
from allennlp.training.tensorboard_writer import TensorboardWriter
import time

In [183]:
def module(m):
    if type(m) is nn.DataParallel:
        return m.module
    return m

def train(model, optimizer, batch_size, num_epochs, train_ds, test_ds, device):
    model = model.train()
    current_batch = 0
    tensorboard = TensorboardWriter(
        get_batch_num_total=lambda: current_batch,
        summary_interval=10,
        serialization_dir="../data/tensorboard/"
    )
    
    histogram_weights = [name for name, p in model.named_parameters() if not ".emb." in name]
    
    for epoch in range(num_epochs):
        train_iter = DialogIterator(train_ds, batch_size, collate_fn=dialogue_mini_batcher)
        num_batches = len(train_iter) 
        if test_ds:
            test_iter = DialogIterator(test_ds, batch_size, collate_fn=dialogue_mini_batcher)
            num_batches += len(test_iter)
        
        with tqdm(total=num_batches) as pbar:
            # train
            pbar.set_description("Train {}".format(epoch))
            model = model.train()
            train_iter = DialogIterator(train_ds, batch_size, collate_fn=dialogue_mini_batcher)
            
            # to know when the dialog and service changes
            prev_turnid = -1
            
            for i, batch in enumerate(train_iter):
                metrics = OrderedDict()
                batch = move_to_device(batch, device)
                current_batch += 1

                optimizer.zero_grad()
                output = model(**batch)
                output["loss"] = output["loss"].mean()
                output["loss"].backward()
                optimizer.step()
                
                if (output["mask"] * output["mask_none"]).sum() > 0:
                    num_turns = batch["usr_utter"]["ids"].shape[1]
                    num_slots = batch["slot_desc"]["ids"].shape[2]
                    goal_reset = batch["turnid"] == num_turns - 1
                    reset = prev_turnid != batch["turnid"]
                    
                    metrics.update(module(model).get_metrics(reset, goal_reset))
                    metrics["loss"] = output["loss"].item()

                    # update tensorboard logs
                    tensorboard.add_train_scalar("loss", metrics["loss"], timestep=current_batch)
                    tensorboard.log_metrics(train_metrics=metrics, epoch=current_batch)
                    tensorboard.add_train_histogram("target_ids", output["target_ids"])
                    tensorboard.add_train_histogram("pred_ids", output["pred_ids"])
                    tensorboard.add_train_histogram("mem_attn", output["mem_attn"])
                    tensorboard.log_parameter_and_gradient_statistics(model, None)

                    metrics.update(turnid=batch["turnid"], servid=batch["serviceid"], slotid=batch["slotid"])
                    pbar.set_postfix(metrics)
                    
                prev_turnid = batch["turnid"]
                pbar.update(1)

            # test
            if test_ds:
                pbar.set_description("Test {}".format(epoch))
                with torch.no_grad():
                    model = model.eval()
                    prev_turnid = -1
                    for i, batch in enumerate(test_iter):
                        metrics = OrderedDict()
                        batch = move_to_device(batch, device)
                        current_batch += 1
                        output = model(**batch)
                        output["loss"] = output["loss"].mean()
                        # at each new turn
                        if prev_turnid != batch["turnid"]:
                            metrics.update(module(model).get_metrics(reset=True, goal_reset=True))
                        else:
                            curr_met = module(model).get_metrics(reset=True)
                            metrics.update(acc=curr_met["acc"])
                        metrics["loss"] = output["loss"].item()
                        metrics.update(turnid=batch["turnid"], servid=batch["serviceid"], slotid=batch["slotid"])
                        pbar.set_postfix(metrics)
                        prev_turnid = batch["turnid"]
                        pbar.update(1)

In [123]:
#torch.save(model, "../data/model0.pkl")
try:
    del model
except NameError:
    torch.cuda.empty_cache()

In [None]:
print("Remove tensorboard logs")
!rm -rf ../data/tensorboard/*

print("set number of devices") # not sure we can in jupyter once program already kicked in the first time
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
device = "cuda"

print("loading model")
model = CandidateSelector()
#model = nn.DataParallel(model)
model = move_to_device(model, device)

optim = torch.optim.Adam(model.parameters(), lr=3e-5)
train_samples = [train_ds[i] for i in range(100)]
test_samples = [test_ds[i] for i in range(10)]

print("started training")
train(
    model=model,
    optimizer=optim,
    train_ds=train_samples,
    test_ds=None, #test_samples,
    device=device,
    num_epochs=20,
    batch_size=16,
)

Remove tensorboard logs
set number of devices
loading model
Initializing bias/weights of  l0.weight_ih_l0
Initializing bias/weights of  l0.weight_hh_l0
Initializing bias/weights of  l0.bias_ih_l0
Initializing bias/weights of  l0.bias_hh_l0
Initializing bias/weights of  l0.weight_ih_l1
Initializing bias/weights of  l0.weight_hh_l1
Initializing bias/weights of  l0.bias_ih_l1
Initializing bias/weights of  l0.bias_hh_l1
Initializing bias/weights of  l2.weight_ih_l0
Initializing bias/weights of  l2.weight_hh_l0
Initializing bias/weights of  l2.bias_ih_l0
Initializing bias/weights of  l2.bias_hh_l0
Initializing bias/weights of  l2.weight_ih_l1
Initializing bias/weights of  l2.weight_hh_l1
Initializing bias/weights of  l2.bias_ih_l1
Initializing bias/weights of  l2.bias_hh_l1
started training


HBox(children=(IntProgress(value=0, max=1322), HTML(value='')))

# Inference

In [118]:
from tabulate import tabulate

In [120]:
def infer_batch(model, test_ds, device):
    result_table = []
    
    # all samples
    avg_goal_acc = 0
    avg_joint_acc = 0

    model = model.eval()
    
    batch_iterator = D.DataLoader(test_ds, 1, collate_fn=dialogue_mini_batcher)
    with tqdm(batch_iterator, total=len(batch_iterator)) as pbar:
        for batch in pbar:
            batch = move_to_device(batch, device)

            # iterate over dialog turn/services
            num_turns = batch["usr_utter"]["ids"].shape[1]
            num_services = batch["serv_desc"]["ids"].shape[1]
            true_batch_size = batch["usr_utter"]["ids"].shape[0]

            # across one batch
            total_goal_acc = {k: 0 for k in range(true_batch_size)} 
            total_joint_acc = {k: 0 for k in range(true_batch_size)}
            unpadded_num_turns = {k: 0 for k in range(true_batch_size)} 

            for turnid in range(num_turns):
                # across one turn
                goal_acc = {k: 0 for k in range(true_batch_size)} # per batch counter
                joint_acc = {k: 1 for k in range(true_batch_size)} # per batch counterr, binary 1/0
                unpadded_num_slots = {k: 0 for k in range(true_batch_size)}

                for sid in range(num_services):
                    inputs = dict(turnid=turnid, serviceid=sid)
                    inputs.update(batch)
                    with torch.no_grad():
                        result = model(**inputs)

                    # iterate step by step now.
                    for b in range(true_batch_size):

                        fields = dict(
                            prd_mem_loc = result["pred_ids"][b], # B,slot
                            tgt_mem_loc = result["target_ids"][b],
                            mask = result["mask"][b],
                            mask_none = result["mask_none"][b], # 1 -> target is not None
                            dial_id = inputs["dial_id"]["value"][b],
                            serv_name = inputs["serv_name"]["value"][b],
                            slot_name = inputs["slot_name"]["value"][b],
                            num_turns = inputs["num_turns"]["ids"][b],
                            memory = inputs["slot_memory"]["value"][b], # B,turn,memory
                        )

                        for slotid, (prd_loc, tgt_loc) in enumerate(zip(fields["prd_mem_loc"], fields["tgt_mem_loc"])):
                            try:
                                memory = fields["memory"][turnid]
                                # clip the locations.. they somehow predict in padding regions!! Make it NONE
                                prd_loc = int(prd_loc.item()) if prd_loc < len(memory) else -1
                                tgt_loc = int(tgt_loc.item()) if tgt_loc < len(memory) else -1

                                prd_val = memory[prd_loc] if prd_loc >= 0 else "<UNK>"
                                tgt_val = memory[tgt_loc] if tgt_loc >= 0 else "<UNK>"
                                mask = fields["mask"][slotid].item()
                                mask_none = fields["mask_none"][slotid].item()
                                dialid = fields["dial_id"]
                                servname = fields["serv_name"][sid]
                                slotname = fields["slot_name"][sid][slotid]
                                num_turns = fields["num_turns"].item() / 2 # usr-sys turn pairs
                            except: # padding issues.
                                continue
                            
                            # update acc counter. Multiply mask_none to ignore acc on NONE slots
                            if mask * mask_none == 1: 
                                goal_acc[b] += int(prd_loc == tgt_loc)
                                joint_acc[b] *= int(prd_loc == tgt_loc)
                                unpadded_num_slots[b] += 1

                            item = OrderedDict(
                                dialid = fields["dial_id"],
                                num_turns = num_turns,
                                turnid = turnid,
                                servid = sid,
                                servname = fields["serv_name"][sid],
                                memsize = len(memory),
                                slotname = slotname,
                                tgt_val = tgt_val,
                                tgt_loc = tgt_loc,
                                prd_val = prd_val,
                                prd_loc = prd_loc,
                                correct = 100*int(tgt_loc == prd_loc),
                                mask = (mask * mask_none),
                            )

                            result_table.append(item)

                # update acc across the batch
                for k in range(true_batch_size):
                    total_goal_acc[k] += goal_acc[k] / max(unpadded_num_slots[k], 1)
                    total_joint_acc[k] += joint_acc[k]
                    unpadded_num_turns[k] += int(unpadded_num_slots[k] > 0)

            # update avg
            for k in range(true_batch_size):
                avg_goal_acc += total_goal_acc[k] / unpadded_num_turns[k]
                avg_joint_acc += total_joint_acc[k] / unpadded_num_turns[k]
        
    print("Avg Goal Acc: ", avg_goal_acc / len(test_ds))
    print("Avg Joint Goal Acc: ", avg_joint_acc / len(test_ds))
    print()
    
    print(tabulate(
        [list(item.values()) for item in result_table],
        list(result_table[0].keys()),
        "fancy_grid",
    ))
    
    
    
litmus_test_samples = [train_ds[0]]
infer_batch(model, litmus_test_samples, device="cuda")

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


Avg Goal Acc:  0.0
Avg Joint Goal Acc:  0.09090909090909091

╒══════════╤═════════════╤══════════╤══════════╤═══════════════╤═══════════╤═════════════════╤═══════════╤═══════════╤═══════════╤═══════════╤═══════════╤════════╕
│   dialid │   num_turns │   turnid │   servid │ servname      │   memsize │ slotname        │ tgt_val   │   tgt_loc │ prd_val   │   prd_loc │   correct │   mask │
╞══════════╪═════════════╪══════════╪══════════╪═══════════════╪═══════════╪═════════════════╪═══════════╪═══════════╪═══════════╪═══════════╪═══════════╪════════╡
│  1_00000 │          12 │        0 │        0 │ Restaurants_1 │        19 │ restaurant_name │ NONE      │         0 │ Indian    │        16 │         0 │      0 │
├──────────┼─────────────┼──────────┼──────────┼───────────────┼───────────┼─────────────────┼───────────┼───────────┼───────────┼───────────┼───────────┼────────┤
│  1_00000 │          12 │        0 │        0 │ Restaurants_1 │        19 │ date            │ NONE      │         0 │