In [1]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [1]:
import torch.utils.data as D
import torch
import numpy as np
import json
import csv
import glob
import re
import concurrent.futures
import os

from copy import deepcopy
from functools import lru_cache
from collections import OrderedDict
from types import SimpleNamespace
from collections.abc import Iterable
from tqdm import tqdm_notebook as tqdm
from pytorch_transformers import BertTokenizer
from ipdb import set_trace

# Data Reader

## Utils

In [2]:
class Tokenizer:
    
    def __init__(self, bert):
        self.bert = bert
    
    def __call__(self, text, include_sep=True):
        tokens = self.bert.tokenize(text)
        if include_sep:
            tokens.insert(0, "[CLS]")
            tokens.append("[SEP]")
        return tokens
    
    
class TokenIndexer:
    
    def __init__(self, bert):
        self.bert = bert
        
    def inv(self, *args, **kw):
        # tokens MUST be list, tensor doesn't work.
        return self.bert.convert_ids_to_tokens(*args, **kw)
        
    def __call__(self, *args, **kw):
        return self.bert.convert_tokens_to_ids(*args, **kw)

In [3]:
def label_binarize(labels, classes):
    # labels: np.array or tensor [batch, 1]
    # classes: [..] list of classes
    # weirdly,`sklearn.preprocessing.label_binarize` returns [1] or [0]
    # instead of onehot ONLY when executing in this script!
    vectors = [np.zeros(len(classes)) for _ in labels]
    for i, label in enumerate(labels):
        for j, c in enumerate(classes):
            if c == label:
                vectors[i][j] = 1
    return np.array(vectors)
    

def label_inv_binarize(vectors, classes):
    # labels: np.array or tensor [batch, classes]
    # classes: [..] list of classes
    # follows sklearn LabelBinarizer.inverse_transform()
    # given all zeros, predicts label at index 0, instead of returning none!
    # sklearn doesn't have functional API of inverse transform
    labels = []
    for each in vectors:
        index = np.argmax(each)
        labels.append(classes[index])
    return labels

In [4]:
def padded_array(array, value=0):
    # TODO: this does not do type checking; and wow it can be slow on strings.
    # expects array to have fixed _number_ of dimensions
    
    # resolve the shape of padded array
    shape_index = {}
    queue = [(array, 0)]
    while queue:
        subarr, dim = queue.pop(0)
        shape_index[dim] = max(shape_index.get(dim, -1), len(subarr))
        for x in subarr:
            if isinstance(x, Iterable) and not isinstance(x, str):
                queue.append((x, dim+1))
    shape = [shape_index[k] for k in range(max(shape_index) + 1)]
    
    # fill the values 
    padded = np.ones(shape) * value
    queue = [(array, [])]
    while queue:
        subarr, index = queue.pop(0)
        for j, x in enumerate(subarr):
            if isinstance(x, Iterable):
                queue.append((x, index + [j]))
            else:
                padded[tuple(index + [j])] = x
    return padded

## Schema Reader

In [5]:
class Schemas(object):

    def __init__(self, filepath):
        with open(filepath) as f:
            self.index = {}
            for schema in json.load(f):
                service_name = schema["service_name"]
                self.index[service_name] = schema

    @lru_cache(maxsize=None)
    def get(self, service):
        result = dict(
            # service
            name=service,
            desc=self.index[service]["description"],
            
            # slots
            slot_name=[],
            slot_desc=[],
            slot_iscat=[], 
            slot_vals=[], # collected only for cat slots.. not sure if that makes sense

            # intents
            intent_name=[],
            intent_desc=[],
            intent_istrans=[],
            intent_reqslots=[],
            intent_optslots=[],
            intent_optvals=[],
        )

        for slot in self.index[service]["slots"]:
            result["slot_name"].append(slot["name"])
            result["slot_desc"].append(slot["description"])
            result["slot_iscat"].append(slot["is_categorical"])
            result["slot_vals"].append(slot["possible_values"])
        
        for intent in self.index[service]["intents"]:
            result["intent_name"].append(intent["name"])
            result["intent_desc"].append(intent["description"])
            result["intent_istrans"].append(intent["is_transactional"])
            result["intent_reqslots"].append(intent["required_slots"])
            result["intent_optslots"].append(list(intent["optional_slots"].keys()))
            result["intent_optvals"].append(list(intent["optional_slots"].values()))

        return result    

## Dialogue Reader

In [6]:
class DialogueDataset(D.Dataset):
    def __init__(self, filename, schemas, tokenizer, token_indexer):
        with open(filename) as f:
            self.ds = json.load(f)
        self.schemas = schemas
        self.tokenizer = tokenizer
        self.token_indexer = token_indexer
        self.dialogues = []
        for dial in self.ds:
            fields = self.dial_to_fields(dial)
            self.dialogues.append(fields)
        self.schemas = None
        self.tokenizer = None
        self.token_indexer = None
            
    def __getitem__(self, idx):
        return self.dialogues[idx]
    
    def __len__(self):
        return len(self.dialogues)
    
    def fd_serv_name(self, dial, fields):
        resp = dict(value=[])
        for service in dial["services"]:
            resp["value"].append(service)
        return resp
        
    def fd_serv_desc(self, dial, fields):
        resp = dict(value=[], tokens=[], ids=[], mask=[])
        for service in dial["services"]:
            desc = self.schemas.get(service)["desc"]
            resp["value"].append(desc)
            resp["tokens"].append(self.tokenizer(desc))
            resp["ids"].append(self.token_indexer(resp["tokens"][-1]))
            resp["mask"].append([1] * len(resp["tokens"][-1]))
        return resp
    
    def fd_slot_name(self, dial, fields):
        resp = {"value": []}
        for serv in dial["services"]:
            schema = self.schemas.get(serv)
            resp["value"].append(schema["slot_name"])
        return resp
    
    def fd_slot_desc(self, dial, fields):
        resp = dict(value=[], tokens=[], ids=[], mask=[])
        for serv in dial["services"]:
            s_desc = self.schemas.get(serv)["slot_desc"]
            s_tokens = [self.tokenizer(d) for d in s_desc]
            s_ids = [self.token_indexer(d) for d in s_tokens]
            s_mask = [[1] * len(d) for d in s_tokens]
            resp["value"].append(s_desc)
            resp["tokens"].append(s_tokens)
            resp["ids"].append(s_ids)
            resp["mask"].append(s_mask)
        return resp

    def fd_slot_memory(self, dial, fields):
        # memory is a sequence of slot tagger values across all frames in a dial.. that grows per turn
        # maintain snapshot at each turn
        resp = dict(
            value=[], # unfeaturized memory
            tokens=[], # tokenized memory [turn, mem-size, tokens]
            ids=[], # indexed memory [turn mem-size, tokens]
            mask=[], # mask on memory [turn, mem-size, tokens]
            ids_memsize=[] # mem sizes [turn, 1]
        )
        
        # memory is sequential, initialized by vals from schema, only has values
        memory = ["NONE", "dontcare"] # keep these at index 0, 1
        memory_index = set()
        for serv in dial["services"]:
            schema = self.schemas.get(serv)
            for values in schema["slot_vals"]:
                for val in values:
                    if val not in memory_index:
                        memory.append(val)
                        memory_index.add(val)
        
        # at each user turn create a memory snapshot..
        for turn in dial["turns"]:
            memory = deepcopy(memory)
            # pick slot tagger's ground truth ; categorical slot vals are already initialized!
            utter = turn["utterance"]
            for frame in turn["frames"]:
                for tag in frame["slots"]:
                    st, en = tag["start"], tag["exclusive_end"]
                    value = utter[st:en]
                    if value not in memory_index:
                        memory.append(value)
                        memory_index.add(value)
            if turn["speaker"] == "USER":
                resp["value"].append(memory)
                resp["ids_memsize"].append(len(memory))

        # tokenize and index the memory values
        # value: [turn, values], ids/tokens: [turn, values, tokens]
        for mem_snapshot in resp["value"]:
            mem_tokens = []
            mem_ids = []
            mem_mask = []
            for val in mem_snapshot:
                tokens = self.tokenizer(val)
                mem_tokens.append(tokens)
                mem_ids.append(self.token_indexer(tokens))
                mem_mask.append([1] * len(tokens))
            resp["tokens"].append(mem_tokens)
            resp["ids"].append(mem_ids)
            resp["mask"].append(mem_mask)
        
        return resp
    
    def fd_slot_memory_loc(self, dial, fields):
        resp = dict(
            value=[],  # string value
            ids=[], # memory loc
            ids_onehot=[], # onehot of memory loc
            mask=[],
            mask_onehot=[],
            mask_none=[], # 1 -> non NONE values
            mask_none_onehot=[],
        )
        
        # query dialog memory snapshots per turn. NOTE: fd_slot_memory should exec first.
        memory = fields["slot_memory"]["value"]
        
        # init snapshot: service, slot -> memory loc, val
        memory_loc = []
        memory_val = []
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                loc = OrderedDict()
                val = OrderedDict()
                for serv in dial["services"]:
                    loc[serv] = OrderedDict()
                    val[serv] = OrderedDict()
                    for slot in self.schemas.get(serv)["slot_name"]:
                        loc[serv][slot] = None
                        val[serv][slot] = None
                memory_loc.append(loc)
                memory_val.append(val)
        
        # fill the memory locations 
        snapshot_id = 0
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                turn_memory = memory[snapshot_id]
                turn_memory_loc = memory_loc[snapshot_id]
                turn_memory_val = memory_val[snapshot_id]
                for frame in turn["frames"]:
                    service = frame["service"]
                    for slot, values in frame["state"]["slot_values"].items():
                        val = re.sub("\u2013", "-", values[0]) # dial 59_00125 turn 14
                        turn_memory_loc[service][slot] = turn_memory.index(val)
                        turn_memory_val[service][slot] = val
                    # add locations to UNKNOWN slots
                    for slot, val in turn_memory_loc[service].items():
                        if val is None:
                            turn_memory_loc[service][slot] = turn_memory.index("NONE")
                            turn_memory_val[service][slot] = "NONE"
                snapshot_id += 1
        
        # featurize
        snapshot_id = 0
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                turn_memory = memory[snapshot_id]
                turn_memory_loc = memory_loc[snapshot_id]
                turn_memory_val = memory_val[snapshot_id]
                none_loc = memory[snapshot_id].index("NONE")
                turn_fields = dict(
                    value=[], ids=[], ids_onehot=[], mask=[], mask_onehot=[],
                    mask_none=[], mask_none_onehot=[],
                )

                for serv in turn_memory_loc:
                    mem_size = len(turn_memory)
                    vals = list(turn_memory_val[serv].values())
                    ids = list(turn_memory_loc[serv].values())
                    ids_onehots = label_binarize(ids, list(range(mem_size)))
                    mask = [1] * len(ids)
                    mask_onehots = [[1] * mem_size for _ in ids]

                    mask_none = [int(v!="NONE") for v in vals]
                    mask_none_onehot = [[1] * mem_size for _ in ids]
                    for v, loc, onehot in zip(vals, ids, mask_none_onehot):
                        if v == "NONE":
                            onehot[loc] = 0
                                
                    turn_fields["value"].append(vals)
                    turn_fields["ids"].append(ids)
                    turn_fields["ids_onehot"].append(ids_onehots)
                    turn_fields["mask"].append(mask)
                    turn_fields["mask_onehot"].append(mask_onehots)
                    turn_fields["mask_none"].append(mask_none)
                    turn_fields["mask_none_onehot"].append(mask_none_onehot)

                # update turn
                for k, v in turn_fields.items():
                    resp[k].append(v)
                
                snapshot_id += 1
            
        return resp
    
    def fd_num_turns(self, dial, fields):
        return {"ids": len(dial["turns"])}
    
    def fd_num_frames(self, dial, fields):
        return {"ids": [len(t["frames"]) for t in dial["turns"]]}
    
    def fd_usr_utter(self, dial, fields):
        resp = dict(value=[], ids=[], mask=[], tokens=[])
        for turn in dial["turns"]:
            if turn["speaker"] == "USER":
                utter = turn["utterance"]
                tokens = self.tokenizer(utter)
                ids = self.token_indexer(tokens)
                resp["value"].append(utter)
                resp["ids"].append(ids)
                resp["tokens"].append(tokens)
                resp["mask"].append([1] * len(tokens))
        return resp
    
    def fd_sys_utter(self, dial, fields):
        resp = dict(value=[], ids=[], mask=[], tokens=[])
        for turn in dial["turns"]:
            if turn["speaker"] == "SYSTEM":
                utter = turn["utterance"]
                tokens = self.tokenizer(utter)
                ids = self.token_indexer(tokens)
                resp["value"].append(utter)
                resp["ids"].append(ids)
                resp["tokens"].append(tokens)
                resp["mask"].append([1] * len(tokens))
        return resp
    
    def fd_dial_id(self, dial, fields):
        return {"value": dial["dialogue_id"]}
    
    def dial_to_fields(self, dial):
        fields = {}
        ordered_funcs = [
            "fd_dial_id",
            "fd_num_turns", "fd_num_frames",
            "fd_serv_name", "fd_serv_desc", 
            "fd_slot_name", "fd_slot_desc",
            "fd_slot_memory", "fd_slot_memory_loc",
            "fd_usr_utter", "fd_sys_utter"]
        for func in ordered_funcs:
            name = func.split("fd_", maxsplit=1)[-1]
            value = getattr(self, func)(dial, fields)
            if value is not None:
                fields[name] = value
        return fields

## Load Data

In [7]:
bert_ = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer = Tokenizer(bert_)
token_indexer = TokenIndexer(bert_)

train_schemas = Schemas("../data/train/schema.json")
test_schemas = Schemas("../data/dev/schema.json")

In [8]:
# load training dataset
train_dial_sets = []
train_dial_files = sorted(glob.glob("../data/train/dialogues*.json"))
num_workers = min(20, len(train_dial_files))

def worker(filename):
    return DialogueDataset(filename, train_schemas, tokenizer, token_indexer)

with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
    for ds in tqdm(executor.map(worker, train_dial_files), total=len(train_dial_files)):
        train_dial_sets.append(ds)

train_ds = D.ConcatDataset(train_dial_sets)

HBox(children=(IntProgress(value=0, max=127), HTML(value='')))




In [9]:
# load test dataset
test_dial_sets = []
test_dial_files = sorted(glob.glob("../data/dev/dialogues*.json"))
num_workers = min(20, len(train_dial_files))

def worker(filename):
    return DialogueDataset(filename, test_schemas, tokenizer, token_indexer)

with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
    for ds in tqdm(executor.map(worker, test_dial_files), total=len(test_dial_files)):
        test_dial_sets.append(ds)

test_ds = D.ConcatDataset(test_dial_sets)

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




# Training Utils

In [10]:
import torch.nn as nn

In [11]:
def move_to_device(obj, device):
    if type(obj) is list:
        return [move_to_device(o, device) for o in obj]
    elif type(obj) is dict:
        return {k: move_to_device(v, device) for k, v in obj.items()}
    elif type(obj) is torch.Tensor or isinstance(obj, nn.Module):
        return obj.to(device)
    return obj

In [12]:
def dialogue_mini_batcher(dialogues):
    default_padding = 0
    batch = {}
    for dial in dialogues:
        # populate the batch
        for field, data in dial.items():
            if field not in batch:
                batch[field] = {}
            for attr, val in data.items():
                if attr == "padding":
                    batch[field][attr] = val
                else:
                    batch[field][attr] = batch[field].get(attr, [])
                    batch[field][attr].append(val)

    # padding on field attributes
    for field_name, data in batch.items():
        for attr in data:
            if attr.startswith("ids") or attr.startswith("mask"):
                if type(data[attr]) is not torch.Tensor: # don't move b/w CPU and GPU unnecessarily
                    data[attr] = padded_array(data[attr], default_padding)
                    data[attr] = torch.tensor(data[attr], device="cpu")

    return batch


class DialogIterator(object):
    """A simple wrapper on DataLoader"""
    
    def __init__(self, dataset, batch_size, *args, **kw):
        self.length = None
        self.dataset = dataset
        self.batch_size = batch_size
        self.iterator = D.DataLoader(dataset, batch_size, *args, **kw)
        
    def __len__(self):
        if self.length is None:
            self.length = 0
            for dial in self.dataset:
                self.length += dial["num_turns"]["ids"] + len(dial["serv_name"]["value"])
            self.length = int(self.length / self.batch_size)
        return self.length
    
    def __iter__(self):
        for batch in self.iterator:
            num_turns = batch["usr_utter"]["ids"].shape[1]
            num_services = batch["serv_desc"]["ids"].shape[1]
            for turnid in range(num_turns):
                for sid in range(num_services):
                    inputs = dict(turnid=turnid, serviceid=sid)
                    inputs.update(batch)
                    yield inputs
                    
#next(iter(DialogIterator(train_ds, 1, collate_fn=dialogue_mini_batcher)))

In [13]:
def print_shapes(ds):
    it = next(iter(DialogIterator(ds, 1, collate_fn=dialogue_mini_batcher)))
    for field, val in it.items():
        if type(val) is dict:
            for attr in val:
                if attr.startswith("ids") or attr.startswith("mask"):
                    print(field, attr, "-->", val[attr].shape)
                    

# TODO: why mem ids, and mem loc size is diff by 1?
print_shapes(train_ds)

num_turns ids --> torch.Size([1])
num_frames ids --> torch.Size([1, 24])
serv_desc ids --> torch.Size([1, 1, 10])
serv_desc mask --> torch.Size([1, 1, 10])
slot_desc ids --> torch.Size([1, 1, 11, 12])
slot_desc mask --> torch.Size([1, 1, 11, 12])
slot_memory ids --> torch.Size([1, 12, 29, 10])
slot_memory mask --> torch.Size([1, 12, 29, 10])
slot_memory ids_memsize --> torch.Size([1, 12])
slot_memory_loc ids --> torch.Size([1, 12, 1, 11])
slot_memory_loc ids_onehot --> torch.Size([1, 12, 1, 11, 29])
slot_memory_loc mask --> torch.Size([1, 12, 1, 11])
slot_memory_loc mask_onehot --> torch.Size([1, 12, 1, 11, 29])
slot_memory_loc mask_none --> torch.Size([1, 12, 1, 11])
slot_memory_loc mask_none_onehot --> torch.Size([1, 12, 1, 11, 29])
usr_utter ids --> torch.Size([1, 12, 25])
usr_utter mask --> torch.Size([1, 12, 25])
sys_utter ids --> torch.Size([1, 12, 31])
sys_utter mask --> torch.Size([1, 12, 31])


# Model

In [100]:
from pytorch_transformers import BertModel, BertConfig
from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy
from collections import OrderedDict

import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch

In [101]:
class BertDownstream(nn.Module):
    # https://huggingface.co/pytorch-transformers/model_doc/bert.html
    
    def __init__(self, btype, requires_grad=False):
        super().__init__()
        self.emb = BertModel.from_pretrained(btype)
        for param in self.emb.parameters():
            param.requires_grad = requires_grad
    
    def forward(self, *args, **kw):
        outputs = self.emb(*args, **kw)
        return outputs[0] # pooler [b,s]->[b,s,e]

In [112]:
class CandidateSelector(nn.Module):

    def __init__(self):
        super().__init__()
        self.emb = BertDownstream("bert-base-uncased")
        self.emb_dim = 768
        
        # encode utter and desc tokens
        self.l0 = nn.GRU(self.emb_dim, self.emb_dim, batch_first=True, num_layers=1)
        
        # encode memory cells and rank 
        self.l1 = nn.GRU(self.emb_dim, self.emb_dim, batch_first=True, num_layers=1)

        # encode slot desc
        self.l2 = nn.GRU(self.emb_dim, self.emb_dim, batch_first=True, num_layers=1)
       
        # attentions
        self.l4 = nn.Linear(2*self.emb_dim, self.emb_dim)
        self.l6 = nn.Bilinear(self.emb_dim, self.emb_dim, 1, bias=False)
         
        # metrics
        self.acc = CategoricalAccuracy()
        self.goal_acc = CategoricalAccuracy()

        # init weights: classifier's performance changes heavily on these
        for name, param in self.named_parameters():
            if name.startswith(("l0.", "11.", "l2.")):
                print("Initializing bias/weights of ", name)
                if "weight" in name:
                    nn.init.xavier_uniform_(param)
                else:
                    param.data.fill_(0.)

    def get_metrics(self, reset=False, goal_reset=False):
        return dict(
            acc=self.acc.get_metric(reset), # avg per service
            goal_acc=self.goal_acc.get_metric(goal_reset), # avg per turn
        )
    
    def compute_score1(self, memory, slot, utter):
        # memory: bme
        # slot: bse
        # utter: b,e*dir
        s = slot.shape[1]
        m = memory.shape[1]
        
        utter = utter[:,None,:].expand(-1,s,-1) #be->bse

        key = self.l4(torch.cat([utter, slot], dim=-1)) #bs2e
        key = torch.tanh(key)
        key = key[:,:,None,:].expand(-1,-1,m,-1).contiguous() # bsme
        
        memory = memory[:,None,:,:].expand(-1,s,-1,-1).contiguous()
        
        energy = self.l6(key, memory) #bsme->bsm1
        energy = F.softmax(energy.squeeze(-1), dim=-1)
        return energy
    
    def encode_utter(self, batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        
        usr_utter, usr_mask = batch["usr_utter"]["ids"][:,turnid,:], batch["usr_utter"]["mask"][:,turnid,:]
        sys_utter, sys_mask = batch["sys_utter"]["ids"][:,turnid,:], batch["sys_utter"]["mask"][:,turnid,:]
        
        utter = torch.cat([sys_utter, usr_utter], dim=-1).long()
        utter_mask = torch.cat([sys_mask, usr_mask], dim=-1).long()
        
        utter = self.emb(utter, attention_mask=utter_mask) # bs->bse
        utter, utter_h = self.l0(utter)
        
        return utter_h[-1] # be

    def encode_memory(self, batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        
        # get fixed size encoding of memory cells
        memory = batch["slot_memory"]["ids"][:,turnid,:]
        memory_mask = batch["slot_memory"]["mask"][:,turnid,:]
        sh = memory.shape
        
        # EMB can conditionally represent across memories!
        memory = memory.contiguous().view(sh[0], -1).long() # b,m,s -> b,m*s
        memory_mask = memory_mask.view(sh[0], -1).contiguous().long()
        
        memory = self.emb(memory, attention_mask=memory_mask) # b,ms -> b,ms,e
        memory, memory_h = self.l1(memory) # b,ms,e
        memory = memory.view(sh[0], sh[1], sh[2], -1)
        memory = memory[:,:,-1,:]
        
        return memory # bme
    
    def encode_slot_desc(self, batch):
        serviceid = batch["serviceid"]
        
        desc = batch["slot_desc"]["ids"][:,serviceid,:].long() # [batch, slots, tokens]
        desc_mask = batch["slot_desc"]["mask"][:,serviceid,:].long()
        sh = desc.shape
        
        # bst -> b,st -> b,st,e
        desc = desc.contiguous().view(sh[0], -1)
        desc_mask = desc_mask.contiguous().view(sh[0], -1)
        desc = self.emb(desc, attention_mask=desc_mask)
        desc, desc_h = self.l2(desc)
        desc = desc.view(sh[0], sh[1], sh[2], -1)
        desc = desc[:,:,-1,:]
        return desc # bse
        
        
    def forward(self, **batch):
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]
        
        # doc: GRU outputs [batch, seq, emb * dir] [layers * dir, batch, emb]

        # get fixed size encodings
        utter = self.encode_utter(batch) # be
        memory = self.encode_memory(batch) # bme
        slot_desc = self.encode_slot_desc(batch) #bse
        
        # map: slot desc -> memory
        score = self.compute_score1(memory, slot_desc, utter) # bsm
        output = {"score": score}

        if "slot_memory_loc" in batch:
            # calc loss
            target_score = batch["slot_memory_loc"]["ids_onehot"][:,turnid,serviceid,:].contiguous().view(-1, 1).float() # [batch, slots, memory]
            target_mask = batch["slot_memory_loc"]["mask_onehot"][:,turnid,serviceid,:].contiguous().view(-1, 1).float()
            output["loss"] = F.binary_cross_entropy(score.contiguous().view(-1,1), target_score, target_mask)
            output["loss"] = output["loss"].unsqueeze(0) # don't return scalar.
            
            # calc acc. don't calc acc on slots that are predicted NONE.
            target_score = batch["slot_memory_loc"]["ids"][:,turnid,serviceid,:].float() # [batch, slots]
            target_mask = (batch["slot_memory_loc"]["mask"][:,turnid,serviceid,:] * \
                           batch["slot_memory_loc"]["mask_none"][:,turnid,serviceid,:]).float()
            
            self.acc(score, target_score, target_mask)
            self.goal_acc(score, target_score, target_mask)
            
            output["target"] = target_score
            output["target_ids"] = torch.argmax(target_score, dim=-1)
            output["pred_ids"] = torch.argmax(score, dim=-1)

        return output

# Trainer

In [16]:
from allennlp.training.tensorboard_writer import TensorboardWriter

In [17]:
def module(m):
    if type(m) is nn.DataParallel:
        return m.module
    return m

def train(model, optimizer, batch_size, num_epochs, train_ds, test_ds, device):
    current_batch = 1 # tensorboard doesn't like 0
    tensorboard = TensorboardWriter(
        get_batch_num_total=lambda: current_batch,
        summary_interval=10,
        serialization_dir="../data/tensorboard/"
    )
    for epoch in range(num_epochs):
        train_iter = DialogIterator(train_ds, batch_size, collate_fn=dialogue_mini_batcher)
        num_batches = len(train_iter) 
        if test_ds:
            test_iter = DialogIterator(test_ds, batch_size, collate_fn=dialogue_mini_batcher)
            num_batches += len(test_iter)
        
        with tqdm(total=num_batches) as pbar:
            # train
            pbar.set_description("Train {}".format(epoch))
            model = model.train()
            train_iter = DialogIterator(train_ds, batch_size, collate_fn=dialogue_mini_batcher)
            metrics = OrderedDict()
            
            # to know when the dialog and service changes
            turnid = -1
            
            for i, batch in enumerate(train_iter):
                current_batch += 1
                batch = move_to_device(batch, device)
                optimizer.zero_grad()
                output = model(**batch)
                output["loss"] = output["loss"].mean()
                output["loss"].backward()
                optimizer.step()
                
                # at each new turn
                if turnid != batch["turnid"]:
                    metrics.update(module(model).get_metrics(reset=True, goal_reset=True))
                else:
                    curr_met = module(model).get_metrics(reset=True)
                    metrics.update(acc=curr_met["acc"])
                    
                metrics["loss"] = output["loss"].item()
                pbar.update(1)
                pbar.set_postfix(metrics)
                
                metrics["turnid"] = batch["turnid"]
                metrics["servid"] = batch["serviceid"]
                turnid = batch["turnid"]

#                 # update tensorboard logs
#                 tensorboard.add_train_histogram("target", output["target"])
#                 tensorboard.add_train_histogram("pred", output["score"])
#                 tensorboard.add_train_histogram("target_ids", output["target_ids"])
#                 tensorboard.add_train_histogram("pred_ids", output["pred_ids"])
                tensorboard.add_train_scalar("loss", metrics["loss"], timestep=current_batch)
                tensorboard.log_metrics(train_metrics=metrics, epoch=current_batch)
                tensorboard.log_parameter_and_gradient_statistics(model, None)

            # test
            if test_ds:
                pbar.set_description("Test {}".format(epoch))
                metrics = OrderedDict(epoch=epoch)
                turnid = -1
                with torch.no_grad():
                    model = model.eval()
                    for i, batch in enumerate(test_iter):
                        current_batch += 1
                        batch = move_to_device(batch, device)
                        output = model(**batch)
                        output["loss"] = output["loss"].mean()

                        # at each new turn
                        if turnid != batch["turnid"]:
                            metrics.update(module(model).get_metrics(reset=True, goal_reset=True))
                        else:
                            curr_met = module(model).get_metrics(reset=True)
                            metrics.update(acc=curr_met["acc"])

                        metrics["loss"] = output["loss"].item()
                        pbar.update(1)
                        pbar.set_postfix(metrics)
                        turnid = batch["turnid"]

In [58]:
#torch.save(model, "../data/model0.pkl")
try:
    del model
except NameError:
    torch.cuda.empty_cache()

In [113]:
print("Remove tensorboard logs")
!rm -rf ../data/tensorboard/*

print("set number of devices") # not sure we can in jupyter once program already kicked in the first time
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

print("loading model")
model = CandidateSelector()
model = nn.DataParallel(model)
model = move_to_device(model, "cuda")

optim = torch.optim.Adam(model.parameters(), lr=3e-4)
train_samples = [train_ds[i] for i in range(1000)]
test_samples = [test_ds[i] for i in range(100)]

print("started training")
train(
    model=model,
    optimizer=optim,
    train_ds=train_samples,
    test_ds=test_samples,
    device="cuda",
    num_epochs=20,
    batch_size=32,
)

Remove tensorboard logs
set number of devices
loading model
Initializing bias/weights of  l0.weight_ih_l0
Initializing bias/weights of  l0.weight_hh_l0
Initializing bias/weights of  l0.bias_ih_l0
Initializing bias/weights of  l0.bias_hh_l0
Initializing bias/weights of  l2.weight_ih_l0
Initializing bias/weights of  l2.weight_hh_l0
Initializing bias/weights of  l2.bias_ih_l0
Initializing bias/weights of  l2.bias_hh_l0
started training


HBox(children=(IntProgress(value=0, max=545), HTML(value='')))

> [0;32m<ipython-input-112-8768bb56e119>[0m(108)[0;36mencode_slot_desc[0;34m()[0m
[0;32m    107 [0;31m[0;34m[0m[0m
[0m[0;32m--> 108 [0;31m        [0mdesc[0m [0;34m=[0m [0mdesc[0m[0;34m.[0m[0mview[0m[0;34m([0m[0msh[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m,[0m [0msh[0m[0;34m[[0m[0;36m1[0m[0;34m][0m[0;34m,[0m [0msh[0m[0;34m[[0m[0;36m2[0m[0;34m][0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    109 [0;31m        [0mdesc[0m [0;34m=[0m [0mdesc[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m[0;34m:[0m[0;34m,[0m[0;34m-[0m[0;36m1[0m[0;34m,[0m[0;34m:[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m> [0;32m<ipython-input-112-8768bb56e119>[0m(108)[0;36mencode_slot_desc[0;34m()[0m
[0;32m    107 [0;31m[0;34m[0m[0m
[0m[0;32m--> 108 [0;31m        [0mdesc[0m [0;34m=[0m [0mdesc[0m[0;34m.[0m[0mview[0m[0;34m([0m[0msh[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m,[0m [0msh[0m[

KeyboardInterrupt: 

In [92]:
# def run_infer(model, test_ds, batch_size, device):
#     with torch.no_grad():
#         model = move_to_device(model, device)
#         model = model.eval()
#         test_iter = DialogIterator(test_ds, batch_size, collate_fn=dialogue_mini_batcher)
#         with tqdm(test_iter, leave=False) as test_pbar:
#             for i, batch in enumerate(test_pbar):
#                 batch = move_to_device(batch, device)
#                 output = model(**batch)
                
#                 # results..
#                 dial_ids = batch["dial_id"]
#                 turn_id = batch["turnid"]
#                 serv_id = batch["serviceid"]
                
#                 mem_loc = torch.argmax(output["score"], dim=-1)
                
                
                
#                 outputs.append((batch,out))
#         return outputs
    
# results = run_infer(model, [test_ds[0]], 1, "cuda")

In [None]:
#b=next(iter(DialogIterator(train_ds, 1, collate_fn=dialogue_mini_batcher)))
#b["slot_memory_loc"]["ids"]