In [1]:
# Widen the notebook container to the full browser width.
# `display`/`HTML` are imported from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [2]:
import torch.utils.data as D
import torch
import numpy as np
import json
import csv
import glob
import concurrent.futures

from functools import lru_cache
from collections import OrderedDict
from types import SimpleNamespace
from collections.abc import Iterable
from tqdm import tqdm_notebook as tqdm
from pytorch_transformers import BertTokenizer
from ipdb import set_trace

# Load dataset

In [3]:
class Schemas(object):
    """Index of service schemas loaded from a JSON schema file.

    The file must contain a JSON list of schema objects. Each object is read
    for its ``service_name``, ``description``, ``slots`` and ``intents``
    entries. Lookups are memoized per instance (not with ``functools.lru_cache``
    on the methods, which would keep every instance alive in a class-level
    cache for the life of the process).
    """

    def __init__(self, filepath):
        """Load the schema list at `filepath` and index it by service name."""
        with open(filepath) as f:
            schemas = json.load(f)
        # the file is closed before indexing; later duplicates overwrite earlier ones
        self.index = {schema["service_name"]: schema for schema in schemas}
        # per-instance memo tables
        self._desc_cache = {}
        self._summary_cache = {}

    def _lookup_desc(self, section, service, name):
        """Return the description of the item called `name` in `section` of `service`.

        Returns None when no item matches. Raises KeyError for an unknown service.
        """
        key = (section, service, name)
        if key not in self._desc_cache:
            self._desc_cache[key] = next(
                (
                    item["description"]
                    for item in self.index[service][section]
                    if item["name"] == name
                ),
                None,
            )
        return self._desc_cache[key]

    def get_service_desc(self, service):
        """Return the description string of `service` (KeyError if unknown)."""
        return self.index[service]["description"]

    def get_slot_desc(self, service, slot):
        """Return the description of `slot` in `service`, or None if absent."""
        return self._lookup_desc("slots", service, slot)

    def get_intent_desc(self, service, intent):
        """Return the description of `intent` in `service`, or None if absent."""
        return self._lookup_desc("intents", service, intent)

    def get(self, service):
        """Return a flattened summary of `service` as a dict of parallel lists.

        The ``slot_*`` lists are aligned with each other, and so are the
        ``intent_*`` lists. The same dict object is returned on repeated calls
        for the same service, so callers must not mutate it.
        """
        if service in self._summary_cache:
            return self._summary_cache[service]

        schema = self.index[service]
        result = dict(
            # service
            service_name=service,
            service_desc=schema["description"],

            # slots
            slot_name=[],
            slot_desc=[],
            slot_iscat=[],
            slot_vals=[],  # possible_values is copied for every slot, categorical or not

            # intents
            intent_name=[],
            intent_desc=[],
            intent_istrans=[],
            intent_reqslots=[],
            intent_optslots=[],
            intent_optvals=[],
        )

        for slot in schema["slots"]:
            result["slot_name"].append(slot["name"])
            result["slot_desc"].append(slot["description"])
            result["slot_iscat"].append(slot["is_categorical"])
            result["slot_vals"].append(slot["possible_values"])

        for intent in schema["intents"]:
            result["intent_name"].append(intent["name"])
            result["intent_desc"].append(intent["description"])
            result["intent_istrans"].append(intent["is_transactional"])
            result["intent_reqslots"].append(intent["required_slots"])
            # optional_slots maps slot name -> default value; split into two lists
            result["intent_optslots"].append(list(intent["optional_slots"].keys()))
            result["intent_optvals"].append(list(intent["optional_slots"].values()))

        self._summary_cache[service] = result
        return result

In [4]:
class Tokenizer:
    """Callable wrapper that tokenizes text with the wrapped `bert` tokenizer."""

    def __init__(self, bert):
        # object exposing a `tokenize(text)` method
        self.bert = bert

    def __call__(self, text, include_sos=True):
        """Tokenize `text`; by default wrap the result in "[CLS]" ... "[SEP]"."""
        pieces = self.bert.tokenize(text)
        if not include_sos:
            return pieces
        # markers are added in place on the list returned by the tokenizer
        pieces.insert(0, "[CLS]")
        pieces.append("[SEP]")
        return pieces
    
    
class TokenIndexer:
    """Callable wrapper that forwards to `bert.convert_tokens_to_ids`."""

    def __init__(self, bert):
        # object exposing a `convert_tokens_to_ids(...)` method
        self.bert = bert

    def __call__(self, *positional, **keyword):
        """Pass all arguments through unchanged and return the converter's result."""
        convert = self.bert.convert_tokens_to_ids
        return convert(*positional, **keyword)

In [5]:
def label_binarize(labels, classes):
    """One-hot encode `labels` against `classes`.

    Args:
        labels: iterable of labels to encode.
        classes: list of all known classes; defines the column order.

    Returns:
        Float np.array of shape [len(labels), len(classes)]. Row ``i`` has a 1
        in every column whose class equals ``labels[i]``. A label that matches
        no class yields an all-zero row.

    A local replacement for `sklearn.preprocessing.label_binarize`, which was
    observed to return [1] or [0] instead of a one-hot row in this notebook.
    """
    labels = list(labels)
    # one row per *label* (not per class), so any number of labels is supported
    # and an empty label list still yields a well-formed (0, n_classes) array
    vectors = np.zeros((len(labels), len(classes)))
    for i, label in enumerate(labels):
        for j, c in enumerate(classes):
            if c == label:
                vectors[i][j] = 1
    return vectors
    

def label_inv_binarize(vectors, classes):
    """Map each score/one-hot row in `vectors` back to a label from `classes`.

    Mirrors sklearn's LabelBinarizer.inverse_transform(), which has no
    functional API. The label at the arg-max position of each row is picked,
    so an all-zero row decodes to ``classes[0]`` rather than to "no label".
    """
    return [classes[np.argmax(row)] for row in vectors]

In [6]:
def padded_array(array, value=0):
    """Convert a ragged nested sequence into a dense float array padded with `value`.

    Args:
        array: nested sequence (lists / arrays) with a fixed *number* of
            dimensions but possibly different lengths along each dimension.
        value: fill value for positions not covered by `array`.

    Returns:
        np.ndarray (float) whose extent along each dimension is the longest
        sub-sequence found at that nesting depth.

    No type checking is done beyond treating ``str``/``bytes`` as scalars.
    Without that rule a string would be descended into forever, because a
    one-character string iterates to itself. A non-numeric string therefore
    fails with numpy's conversion error instead of hanging.
    """

    def is_nested(item):
        return isinstance(item, Iterable) and not isinstance(item, (str, bytes))

    # Pass 1: resolve the padded shape. Visiting order is irrelevant (only a
    # per-depth maximum is kept), so a LIFO stack is used; popping from the
    # front of a list would make the traversal quadratic.
    extents = {}
    pending = [(array, 0)]
    while pending:
        subarr, dim = pending.pop()
        extents[dim] = max(extents.get(dim, -1), len(subarr))
        for x in subarr:
            if is_nested(x):
                pending.append((x, dim + 1))
    shape = [extents[k] for k in range(max(extents) + 1)]

    # Pass 2: copy every leaf to its index in the pre-filled array.
    padded = np.ones(shape) * value
    pending = [(array, ())]
    while pending:
        subarr, index = pending.pop()
        for j, x in enumerate(subarr):
            if is_nested(x):
                pending.append((x, index + (j,)))
            else:
                padded[index + (j,)] = x
    return padded

In [7]:
class DialogueDataset(D.Dataset):
    """Dialogues from one JSON file, pre-converted into dicts of named fields.

    Every method named ``field_<name>`` produces one dialogue-level field and
    every method named ``field_turn_<name>`` produces one turn-level field.
    `text_to_fields` discovers them through ``dir(self)``. A field is a dict
    that may carry ``value`` (raw data), ``tokens``, ``ids``, ``mask`` and
    ``padding`` (fill value used when ``ids``/``mask`` are densified).
    """

    def __init__(self, filename, schemas, tokenizer, token_indexer):
        """Load `filename` and eagerly convert every dialogue to fields.

        Args:
            filename: path to a JSON file holding a list of dialogues.
            schemas: object providing ``get_service_desc(service)`` and
                ``get(service)`` (see the ``field_*`` methods for the keys read).
            tokenizer: callable mapping text to a list of tokens.
            token_indexer: callable mapping a list of tokens to a list of ids.
        """
        with open(filename) as f:
            self.ds = json.load(f)
        self.schemas = schemas
        self.tokenizer = tokenizer
        self.token_indexer = token_indexer
        self.dialogues = []
        self.default_padding = 0
        for dial in self.ds:
            fields = self.text_to_fields(dial)
            self.dialogues.append(fields)
        # these can't be pickled, and are not required once fields are built
        self.tokenizer = None
        self.token_indexer = None
        self.schemas = None

    def __getitem__(self, idx):
        """Return the field dict of the dialogue at position `idx`."""
        return self.dialogues[idx]

    def __len__(self):
        """Return the number of dialogues."""
        return len(self.dialogues)

    def _encode_descriptions(self, dialogue, schema_key):
        """Tokenize and index the per-service description lists under `schema_key`.

        Returns value/tokens/ids/mask, each a list with one entry per service
        in ``dialogue["services"]`` holding one item per description.
        (The name deliberately does not start with ``field_`` so that
        `text_to_fields` does not treat this helper as a field.)
        """
        resp = dict(
            value=[],
            tokens=[],
            ids=[],
            mask=[]
        )
        for service in dialogue["services"]:
            s_desc = list(self.schemas.get(service)[schema_key])
            s_tokens = [self.tokenizer(d) for d in s_desc]
            s_ids = [self.token_indexer(d) for d in s_tokens]
            s_mask = [[1] * len(d) for d in s_tokens]
            resp["value"].append(s_desc)
            resp["tokens"].append(s_tokens)
            resp["ids"].append(s_ids)
            resp["mask"].append(s_mask)
        return resp

    def field_dialogue_id(self, dialogue):
        """Dialogue identifier."""
        return {"value": dialogue["dialogue_id"]}

    def field_turn_speaker(self, turnid, dialogue):
        """Speaker of turn `turnid`."""
        return {"value": dialogue["turns"][turnid]["speaker"]}

    def field_turn_utter(self, turnid, dialogue):
        """Utterance of turn `turnid`: raw text, tokens, token ids and an all-ones mask."""
        text = dialogue["turns"][turnid]["utterance"]
        tokens = self.tokenizer(text)
        token_indices = self.token_indexer(tokens)
        token_mask = [1] * len(tokens)
        return {"value": text, "tokens": tokens, "ids": token_indices, "mask": token_mask}

    def field_turn_sys_utter(self, turnid, dialogue):
        """Utterance field for SYSTEM turns; None for any other speaker."""
        turn = dialogue["turns"][turnid]
        if turn["speaker"] == "SYSTEM":
            return self.field_turn_utter(turnid, dialogue)

    def field_turn_usr_utter(self, turnid, dialogue):
        """Utterance field for USER turns; None for any other speaker."""
        turn = dialogue["turns"][turnid]
        if turn["speaker"] == "USER":
            return self.field_turn_utter(turnid, dialogue)

    def field_service(self, dialogue):
        """List of services of the dialogue."""
        return {"value": dialogue["services"]}

    def field_service_desc(self, dialogue):
        """Description of every dialogue service: text, tokens, ids and mask."""
        resp = dict(
            value=[],
            tokens=[],
            ids=[],
            mask=[]
        )
        for service in dialogue["services"]:
            desc = self.schemas.get_service_desc(service)
            resp["value"].append(desc)
            resp["tokens"].append(self.tokenizer(desc))
            resp["ids"].append(self.token_indexer(resp["tokens"][-1]))
            resp["mask"].append([1] * len(resp["tokens"][-1]))
        return resp

    def field_turn_service_exist(self, turnid, dialogue):
        """Per-service indicator of which dialogue services have a frame in this turn."""
        turn = dialogue["turns"][turnid]
        services = dialogue["services"]
        # order frames by dialog.services list, to establish one to one mappings across fields
        sorted_frames = sorted(turn["frames"], key=lambda x: services.index(x["service"]))
        exists_onehot = label_binarize([f["service"] for f in sorted_frames], classes=services)
        exists = np.sum(exists_onehot, axis=0) # eg: [1, 0, 1, 0]
        return {"ids": exists, "padding": -1}

    def field_intent(self, dialogue):
        """Intent names of every dialogue service."""
        return {"value": [self.schemas.get(s)["intent_name"] for s in dialogue["services"]]}

    def field_intent_desc(self, dialogue):
        """Intent descriptions of every dialogue service: text, tokens, ids and mask."""
        return self._encode_descriptions(dialogue, "intent_desc")

    def field_turn_intent_exist(self, turnid, dialogue):
        """One encoding per service marking the active intent of a USER turn.

        Returns None for non-USER turns. Services without a frame in this
        turn get an all-zero encoding.
        """
        turn = dialogue["turns"][turnid]
        if turn["speaker"] == "USER":
            # maintain order of services; onehot per service
            exists_onehot = OrderedDict()
            for service in dialogue["services"]:
                exists_onehot[service] = None

            # fill encodings of existing services
            # this _will_ be onehot assuming each service has only one intent!
            for frame in turn["frames"]:
                service = frame["service"]
                all_intents = self.schemas.get(service)["intent_name"]
                intent = frame["state"]["active_intent"]
                encoding = label_binarize([intent], classes=all_intents)[0]
                exists_onehot[service] = encoding

            # fill with empty encodings for remaining
            for service in exists_onehot:
                if exists_onehot[service] is None:
                    all_intents = self.schemas.get(service)["intent_name"]
                    encoding = np.array([0] * len(all_intents))
                    exists_onehot[service] = encoding

            return {"ids": list(exists_onehot.values()), "padding": -1}

    def field_turn_intent_changed(self, turnid, dialogue):
        """Per service, which intent positions differ from the turn two steps back.

        Returns None for non-USER turns. The comparison turn is
        ``max(0, turnid - 2)``, which must itself be a USER turn.
        """
        turn = dialogue["turns"][turnid]
        if turn["speaker"] == "USER":
            # one hot encoding of which intent exists.. in current and prev user turn
            intent_exist = self.field_turn_intent_exist(turnid, dialogue)["ids"]
            prev_intent_exist = self.field_turn_intent_exist(max(0, turnid-2), dialogue)["ids"]
            assert len(intent_exist) == len(prev_intent_exist)

            # its an array of onehots.. one per service
            changed = []
            for curr, prev in zip(intent_exist, prev_intent_exist):
                # at the first turn, assume intent-exist => intent-changed
                if turnid == 0:
                    changed_at_service = curr
                else:
                    changed_at_service = (curr != prev) * 1
                changed.append(changed_at_service)

            return {"ids": changed, "padding": -1}

    def field_slots(self, dialogue):
        """Slot names of every dialogue service."""
        slot_list = []
        for service in dialogue["services"]:
            slots = self.schemas.get(service)["slot_name"]
            slot_list.append(slots)
        return {"value": slot_list}

    def field_slots_desc(self, dialogue):
        """Slot descriptions of every dialogue service: text, tokens, ids and mask."""
        return self._encode_descriptions(dialogue, "slot_desc")

    def field_slots_iscat(self, dialogue):
        """Per service, 1/0 flags telling whether each slot is categorical."""
        iscat_list = []
        for service in dialogue["services"]:
            iscat = [int(i) for i in self.schemas.get(service)["slot_iscat"]]
            iscat_list.append(iscat)
        return {"ids": iscat_list, "padding": -1}

    def field_num_turns(self, dialogue):
        """Number of turns in the dialogue."""
        return {"value": len(dialogue["turns"])}

    def field_turn_num_frames(self, turnid, dialogue):
        """Number of frames in turn `turnid`."""
        return {"value": len(dialogue["turns"][turnid]["frames"])}

    def text_to_fields(self, dialogue):
        """Run every ``field_*`` method on `dialogue` and collect the results.

        Dialogue-level fields are stored as returned. Turn-level fields are
        collected into one list per attribute, appended in turn order; a turn
        for which the method returns None contributes nothing. Finally the
        ``ids`` and ``mask`` attributes are densified with `padded_array`,
        using the field's padding value.

        Intended layout of the collected fields:

            dialogue_id      [Batch,]
            num_turns        [Batch,]
            num_frames       [Batch, Turn] equals to number of services per turn

            # messages
            speaker          [Batch, Turn]
            utter            [Batch, Turn, Tokens]
            sys_utter        [Batch, Turn, Tokens] only system utters
            usr_utter        [Batch, Turn, Tokens] only user utters

            # services
            service          [Batch, Service] all dialog services
            service_desc     [Batch, Service, Tokens] service descriptions
            service_exist    [Batch, Turn, Service] binarized

            # intents
            intent           [Batch, Service, Intent]
            intent_desc      [Batch, Service, Intent, Tokens]
            intent_exist     [Batch, Turn, Service, Intent]
            intent_changed   [Batch, Turn, Service]

            # state slots
            slots            [Batch, Service, Slot]
            slots_desc       [Batch, Service, Slot, Tokens]
            slots_iscat      [Batch, Service, Slot]
        """
        fields = {}

        # filter the field names in the instance
        dial_field_funcs = []
        turn_field_funcs = []
        for attr in dir(self):
            if attr.startswith("field_turn_"):
                turn_field_funcs.append(attr)
            elif attr.startswith("field_"):
                dial_field_funcs.append(attr)

        # fill dialogue level fields
        for func in dial_field_funcs:
            name = func.split("field_", maxsplit=1)[-1]
            resp = getattr(self, func)(dialogue)
            resp["padding"] = resp.get("padding", self.default_padding)
            fields[name] = resp

        # fill turn level fields
        for turnid in range(len(dialogue["turns"])):
            for func in turn_field_funcs:
                name = func.split("field_turn_", maxsplit=1)[-1]
                resp = getattr(self, func)(turnid, dialogue) or {}
                entry = fields.setdefault(name, {"padding": self.default_padding})
                # Take the padding from any turn that reports one. Looking only
                # at the first turn would silently fall back to the default when
                # that turn's response is empty (e.g. a speaker-specific field).
                if "padding" in resp:
                    entry["padding"] = resp["padding"]
                for k, v in resp.items():
                    if k != "padding":
                        entry.setdefault(k, []).append(v)

        # combine the turn field ids and mask.. with default padding or the one given by func resp
        for name, data in fields.items():
            padding_value = data["padding"]
            for attr in ["ids", "mask"]:
                if attr in data:
                    data[attr] = padded_array(data[attr], padding_value)

        return fields

In [8]:
# One pretrained BERT tokenizer instance backs both the text tokenizer and the
# token -> id indexer wrappers.
bert_ = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer = Tokenizer(bert_)
token_indexer = TokenIndexer(bert_)

# Schema indexes: one for the train split, one for the dev split (the dev split
# is what the notebook refers to as "test").
train_schemas = Schemas("../data/train/schema.json")
test_schemas = Schemas("../data/dev/schema.json")

In [9]:
# load training dataset
# Each dialogue file becomes one DialogueDataset, built in a worker process;
# the per-file datasets are then concatenated into a single training set.
train_dial_sets = []
train_dial_files = sorted(glob.glob("../data/train/dialogues*.json"))

def worker(filename):
    # Only the filename is passed as an argument; the schemas, tokenizer and
    # token indexer are read from module-level names.
    # NOTE(review): this presumably relies on worker processes inheriting those
    # globals (fork start method) — confirm on the target platform.
    return DialogueDataset(filename, train_schemas, tokenizer, token_indexer)

# NOTE(review): max_workers=20 is hard-coded — confirm it suits the host machine.
with concurrent.futures.ProcessPoolExecutor(max_workers=20) as executor:
    # executor.map yields results in the order of train_dial_files
    for ds in tqdm(executor.map(worker, train_dial_files), total=len(train_dial_files)):
        train_dial_sets.append(ds)

train_ds = D.ConcatDataset(train_dial_sets)

HBox(children=(IntProgress(value=0, max=127), HTML(value='')))




In [10]:
# load test dataset
# Same procedure as for the training set, but on the dev split and with the
# dev schemas.
test_dial_sets = []
test_dial_files = sorted(glob.glob("../data/dev/dialogues*.json"))

# NOTE: this redefines the module-level name `worker` used when loading the
# training set; from here on `worker` builds datasets with `test_schemas`.
def worker(filename):
    return DialogueDataset(filename, test_schemas, tokenizer, token_indexer)

# NOTE(review): max_workers=20 is hard-coded — confirm it suits the host machine.
with concurrent.futures.ProcessPoolExecutor(max_workers=20) as executor:
    # executor.map yields results in the order of test_dial_files
    for ds in tqdm(executor.map(worker, test_dial_files), total=len(test_dial_files)):
        test_dial_sets.append(ds)

test_ds = D.ConcatDataset(test_dial_sets)

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




In [11]:
def dialogue_mini_batcher(dialogues):
    """Collate a list of dialogue field dicts into one batch dict.

    For every field, each attribute except "padding" is gathered into a list
    with one entry per dialogue; "padding" keeps the value of the last
    dialogue seen. The "ids" and "mask" attributes are then densified with
    `padded_array` (filled with the field's padding) and turned into CPU
    tensors.
    """
    batch = {}
    for dial in dialogues:
        for field, data in dial.items():
            collected = batch.setdefault(field, {})
            for attr, val in data.items():
                if attr != "padding":
                    collected.setdefault(attr, []).append(val)
                else:
                    collected[attr] = val

    # pad the ragged id/mask lists and convert them to tensors
    for data in batch.values():
        for attr in ("ids", "mask"):
            if attr not in data:
                continue
            dense = padded_array(data[attr], data["padding"])
            data[attr] = torch.tensor(dense, device="cpu")

    return batch

In [14]:
def move_to_device(obj, device):
    """Recursively move tensors and modules inside `obj` to `device`.

    Plain dicts and lists (exact types only) are rebuilt with their members
    moved; objects whose exact type is torch.Tensor, and any nn.Module, are
    moved with ``.to(device)``. Everything else is returned untouched.
    """
    kind = type(obj)
    if kind is dict:
        return {key: move_to_device(value, device) for key, value in obj.items()}
    if kind is list:
        return [move_to_device(item, device) for item in obj]
    if kind is torch.Tensor or isinstance(obj, torch.nn.Module):
        return obj.to(device)
    return obj

# Model

In [12]:
from allennlp.modules.token_embedders import PretrainedBertEmbedder
from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy
from collections import OrderedDict

import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torch

In [81]:
class DialogState(nn.Module):
    """Two-layer GRU that carries its hidden state from one call to the next."""

    def __init__(self, emb_size):
        super().__init__()
        self.l0 = nn.GRU(emb_size, emb_size, num_layers=2, batch_first=True)
        # hidden state kept between calls; empty until the first forward pass
        self.register_buffer("state_l0", None)

    def forward(self, usr_utter, reset=False):
        """Encode `usr_utter` starting from the stored state and return the new state.

        With ``reset=True`` the stored state is dropped first. A stored state
        is detached before use, so gradients do not flow into earlier calls.
        The GRU's final hidden state ([layers * dir, batch, embed]) is stored
        and returned.
        """
        previous = None if reset else self.state_l0
        if previous is not None:
            previous = previous.detach()
        self.state_l0 = previous

        # per-step outputs are not needed, only the final hidden state
        _, hidden = self.l0(usr_utter, previous)

        self.state_l0 = hidden
        return hidden


class UserIntentPredictor(nn.Module):
    """Scores every intent of one service against the current user utterance.

    Text is embedded with a pretrained BERT embedder constructed with
    requires_grad=False, encoded with bidirectional GRUs, combined with a
    dialogue context vector from `DialogState`, and reduced to one sigmoid
    score per intent.
    """

    def __init__(self):
        super().__init__()
        # embedding layer for encoding text
        self.emb = PretrainedBertEmbedder("bert-base-uncased", requires_grad=False, top_layer_only=True)
        
        # encode utter and desc tokens
        self.l0 = nn.GRU(self.emb.output_dim, self.emb.output_dim, batch_first=True, num_layers=2, bidirectional=True)
        self.l1 = nn.GRU(self.emb.output_dim, self.emb.output_dim, batch_first=True, num_layers=2, bidirectional=True)
        
        # maintain dialog state
        self.l2 = DialogState(2 * self.emb.output_dim)

        # final FC
        self.l3 = nn.Linear(2 * self.emb.output_dim, 1)
        
        # metrics
        self.accuracy = CategoricalAccuracy()

        # init weights: classifier's performance changes heavily on these
        # (only the two encoder GRUs l0/l1 are re-initialised: Xavier-uniform
        # for weights, zeros for biases)
        for name, param in self.named_parameters():
            if name.startswith("l0.") or name.startswith("l1."):
                print("Initializing bias/weights of ", name)
                if "weight" in name:
                    nn.init.xavier_uniform_(param)
                else:
                    param.data.fill_(0.)

    def get_metrics(self, reset):
        """Return {"acc": ...} from the accuracy metric; `reset` is forwarded to it."""
        return {"acc": self.accuracy.get_metric(reset)}

    def get_score(self, usr_utter, usr_utter_h, intent_desc, intent_desc_h, context):
        """Return one sigmoid score per intent, shape [batch, intents].

        Only `usr_utter_h`, `intent_desc` and `context` enter the computation;
        `usr_utter` and `intent_desc_h` are accepted but unused.
        """
        # einsum sums out `e` (utterance hidden dim) and `t` (description tokens)
        # and multiplies element-wise with the context along `x`
        mat = torch.einsum("be,bitx,bx->bix", usr_utter_h, intent_desc, context)
        mat = self.l3(mat).squeeze(-1)
        mat = torch.sigmoid(mat)
        return mat

    def forward(self, **batch):
        """Score the intents of service `serviceid` at turn `turnid`.

        Reads batch["turnid"], batch["serviceid"], batch["usr_utter"]["ids"]
        and batch["intent_desc"]["ids"]. When "intent_exist" is present,
        batch["intent_changed"]["ids"] is read as well, the accuracy metric is
        updated and a "loss" entry is added to the output.

        Returns:
            dict with "score" ([batch, intents]) and, optionally, "loss".
        """
        turnid = batch["turnid"]
        serviceid = batch["serviceid"]

        usr_utter = batch["usr_utter"]["ids"][:,turnid,:].long() # [batch, tokens]
        usr_utter = self.emb(usr_utter) # [batch, tokens, emb]
        usr_utter, usr_utter_h = self.l0(usr_utter) # [batch, seq, emb * dir] [layers * dir, batch, emb]
        usr_utter_h = usr_utter_h[-1] # [batch, emb]

        #  slice after embedding.. not working otherwise
        intent_desc = batch["intent_desc"]["ids"][:,serviceid,:].long() # [batch, intents, tokens]
        shape = list(intent_desc.shape)

        intent_desc = self.emb(intent_desc)  # [batch, intents, tokens, emb]
        # all descriptions of a sample are flattened into one sequence for the GRU,
        # then reshaped back to [batch, intents, tokens, ...]
        intent_desc, intent_desc_h = self.l1(torch.flatten(intent_desc, 1, 2)) # [batch, intents * tokens, emb * dir] [layers * dir, batch, emb]
        intent_desc = intent_desc.reshape(shape + [-1]) # [batch, intents, tokens, emb * dir]
        intent_desc_h = intent_desc_h[-1] # [batch, emb]
        
        # context
        # the dialogue state is reset at the first turn of every batch
        context = self.l2(usr_utter, reset=turnid==0) # [layers * dim, batch, emb * dir]
        context = context[-1] # [batch, emb * dir]

        
        # compute prediction onehot
        score = self.get_score(
            usr_utter,
            usr_utter_h,
            intent_desc,
            intent_desc_h,
            context,
        )

        output = {"score": score}

        if "intent_exist" in batch:
            # target -- onehot. squared to keep padding=-1 unchanged after mult
            # (`**` binds tighter than `*`: this is exist * (changed ** 2))
            target_score = (batch["intent_exist"]["ids"][:,turnid,serviceid,:] * 
                            batch["intent_changed"]["ids"][:,turnid,serviceid,:] ** 2) # [Batch, Intent]
            target_values, target_labels = torch.max(target_score, -1) # [Batch,]

            # update the accuracy counters. reset at every dialogue
            # rows whose maximum target is -1 (padding only) are masked out
            mask = (target_values != -1).float()
            self.accuracy(
                score.float(),
                target_labels.float(), 
                mask=mask,
            )
            
#             with open("switch.txt") as f:
#                 switch = f.read()
#                 if switch == "1":
#                     # print("Target", target_score.tolist())
#                     # print("Pred: ", score.tolist())
#                     print("Target", torch.argmax(target_score, -1).tolist())
#                     print("Pred: ", torch.argmax(score, -1).tolist())
#                     print("Mask: ", mask.long().tolist())
#                     print("Acc: ", self.accuracy.get_metric())
#                     print("\n\n")

            # calculate loss
            # element-wise mask: padded (-1) target entries get zero weight
            mask = (target_score != -1).float()
            output["loss"] = F.binary_cross_entropy(
                score.float(),
                target_score.float(),
                weight=mask,
            )
#             mask = (target_score != -1).float()
#             output["loss"] = F.mse_loss(
#                 score * mask,
#                 target_score * mask,
#                 reduction="sum"
#             )
#             output["loss"] /= target_score.shape[0]

        return output

In [85]:
class DialogIterator(object):
    """Thin DataLoader wrapper that yields one item per (turn, service) of every batch."""

    def __init__(self, dataset, batch_size, *args, **kw):
        # length is computed lazily on the first len() call
        self.length = None
        self.dataset = dataset
        self.batch_size = batch_size
        self.iterator = D.DataLoader(dataset, batch_size, *args, **kw)

    def __len__(self):
        """Estimated item count: sum of (turns + services) per dialogue, divided by batch size."""
        if self.length is None:
            total = sum(
                dial["num_turns"]["value"] + len(dial["service"]["value"])
                for dial in self.dataset
            )
            self.length = int(total / self.batch_size)
        return self.length

    def __iter__(self):
        """Yield each batch once per (turnid, serviceid) pair, turns in the outer loop.

        Every yielded dict holds "turnid" and "serviceid" followed by all
        entries of the batch.
        """
        for batch in self.iterator:
            turn_count = batch["usr_utter"]["ids"].shape[1]
            service_count = batch["service_desc"]["ids"].shape[1]
            for turnid in range(turn_count):
                for sid in range(service_count):
                    yield {"turnid": turnid, "serviceid": sid, **batch}

In [53]:
class Metrics(OrderedDict):
    """Ordered metric dict that renders as "key: value, key: value, ..."."""

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # values of exact type float are rounded to 4 decimals; others are shown as-is
        parts = []
        for key, value in self.items():
            shown = round(value, 4) if type(value) is float else value
            parts.append("{}: {}".format(key, shown))
        return ", ".join(parts)

In [76]:
class MyDataParallel(nn.DataParallel):
    """DataParallel that falls back to the wrapped module for unknown attributes.

    Lets callers write ``wrapped.get_metrics(...)`` instead of
    ``wrapped.module.get_metrics(...)``.
    """

    def __init__(self, model):
        super(MyDataParallel, self).__init__(model)

    def __getattr__(self, name):
        """Resolve `name` on the wrapper first, then on the wrapped module.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped module has it.
        """
        try:
            # The result must be returned: nn.Module serves parameters, buffers
            # and submodules (including `module` itself) through __getattr__, so
            # dropping it would make every such lookup evaluate to None.
            return super(MyDataParallel, self).__getattr__(name)
        except AttributeError:
            # `module` missing means the wrapper is not initialised yet; looking
            # it up again below would recurse forever.
            if name == "module":
                raise
            return getattr(self.module, name)

In [88]:
def train(model, optimizer, batch_size, num_epochs, train_ds, test_ds, device):
    """Train `model` for `num_epochs`, evaluating on `test_ds` after every epoch.

    The model is moved to `device` and wrapped in nn.DataParallel. Each epoch
    walks `train_ds` through a DialogIterator, doing one optimizer step per
    yielded item, then (when `test_ds` is non-empty) walks `test_ds` under
    torch.no_grad(). Progress bars show the latest loss, turn id and the
    model's metrics, which are reset after every step. Nothing is returned.
    """
    model = nn.DataParallel(move_to_device(model, device))
    for epoch in range(num_epochs):
        # ---- training pass ----
        metrics = Metrics(run="train")
        model.train()
        loader = DialogIterator(train_ds, batch_size, collate_fn=dialogue_mini_batcher)
        progress = tqdm(loader)
        for batch in progress:
            batch = move_to_device(batch, device)
            optimizer.zero_grad()
            output = model(**batch)
            # DataParallel returns one loss per replica; average them
            loss = output["loss"].mean()
            loss.backward()
            optimizer.step()
            metrics["loss"] = loss.item()
            metrics["turn"] = batch["turnid"]
            metrics.update(model.module.get_metrics(reset=True))
            progress.set_description(str(metrics))

        # ---- evaluation pass ----
        metrics = Metrics(run="test")
        if not test_ds:
            continue
        loader = DialogIterator(test_ds, batch_size, collate_fn=dialogue_mini_batcher)
        progress = tqdm(loader)
        with torch.no_grad():
            model.eval()
            for batch in progress:
                batch = move_to_device(batch, device)
                output = model(**batch)
                metrics["loss"] = output["loss"].mean().item()
                metrics["turn"] = batch["turnid"]
                metrics.update(model.module.get_metrics(reset=True))
                progress.set_description(str(metrics))
            
            
# Build the model and optimizer, then run a single training epoch on GPU.
print("loading model")
model = UserIntentPredictor()
# NOTE: this rebinds the name `optim`, which the imports cell bound to the
# `torch.optim` module; after this line `optim` is the Adam optimizer instance.
optim = torch.optim.Adam(model.parameters(), lr=1e-6)
# the commented-out expression selects a 100-dialogue subset instead of the full set
train_samples = train_ds #[train_ds[i] for i in range(100)]
test_samples = test_ds

print("started training")
train(
    model=model,
    optimizer=optim,
    train_ds=train_samples,
    test_ds=test_samples,
    device="cuda",
    num_epochs=1,
    batch_size=128,
)

loading model
Initializing bias/weights of  l0.weight_ih_l0
Initializing bias/weights of  l0.weight_hh_l0
Initializing bias/weights of  l0.bias_ih_l0
Initializing bias/weights of  l0.bias_hh_l0
Initializing bias/weights of  l0.weight_ih_l0_reverse
Initializing bias/weights of  l0.weight_hh_l0_reverse
Initializing bias/weights of  l0.bias_ih_l0_reverse
Initializing bias/weights of  l0.bias_hh_l0_reverse
Initializing bias/weights of  l0.weight_ih_l1
Initializing bias/weights of  l0.weight_hh_l1
Initializing bias/weights of  l0.bias_ih_l1
Initializing bias/weights of  l0.bias_hh_l1
Initializing bias/weights of  l0.weight_ih_l1_reverse
Initializing bias/weights of  l0.weight_hh_l1_reverse
Initializing bias/weights of  l0.bias_ih_l1_reverse
Initializing bias/weights of  l0.bias_hh_l1_reverse
Initializing bias/weights of  l1.weight_ih_l0
Initializing bias/weights of  l1.weight_hh_l0
Initializing bias/weights of  l1.bias_ih_l0
Initializing bias/weights of  l1.bias_hh_l0
Initializing bias/weig

HBox(children=(IntProgress(value=0, max=2809), HTML(value='')))

HBox(children=(IntProgress(value=0, max=416), HTML(value='')))

In [91]:
# Persist the trained model. `model` is the module-level UserIntentPredictor
# (train() wraps it in DataParallel only under a local name). The whole module
# object is serialized, not just its state_dict.
torch.save(model, "../data/model-00.pt")