In [1]:
import os

import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim

import torchdiffeq

from tensorboard_utils import Tensorboard
from tensorboard_utils import tensorboard_event_accumulator

import transformer.Constants as Constants
from transformer.Layers import EncoderLayer, DecoderLayer
from transformer.Modules import ScaledDotProductAttention
from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table
from transformer.SubLayers import PositionwiseFeedForward

from dataset import TranslationDataset, paired_collate_fn

import model_process
import checkpoints
from node_transformer import NodeTransformer

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
#%matplotlib notebook  
# Render matplotlib figures inline, at retina resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Record the torch version this notebook was run with.
print("Torch Version", torch.__version__)

# Reload edited project modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2



Torch Version 1.1.0


In [None]:
#!pip install --user --force https://github.com/chengs/tqdm/archive/colab.zip

In [2]:
def prepare_dataloaders(data, batch_size=64, num_workers=2):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        data: mapping read as ``data['dict'][side]`` (word-to-index tables) and
            ``data[split][side]`` (instances), with ``split`` in
            ``'train'``/``'valid'`` and ``side`` in ``'src'``/``'tgt'``.
        batch_size: batch size used by both loaders.
        num_workers: ``num_workers`` value handed to both loaders.

    Returns:
        ``(train_loader, valid_loader)``. Only the training loader shuffles.
    """
    word2idx = data['dict']

    def _loader_for(split, shuffle):
        # Wrap one split in a TranslationDataset, then in a DataLoader that
        # batches with paired_collate_fn.
        insts = data[split]
        split_dataset = TranslationDataset(
            src_word2idx=word2idx['src'],
            tgt_word2idx=word2idx['tgt'],
            src_insts=insts['src'],
            tgt_insts=insts['tgt'])
        return torch.utils.data.DataLoader(
            split_dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            collate_fn=paired_collate_fn,
            shuffle=shuffle)

    # Tuple elements are evaluated left to right: train first, then valid.
    return _loader_for('train', shuffle=True), _loader_for('valid', shuffle=False)

In [3]:
# Seed torch's global random number generator so repeated runs start from the same state.
seed = 1
torch.manual_seed(seed)
# Use the GPU when one is visible; otherwise fall back to the CPU so every later
# `.to(device)` call still works on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)

device cuda


In [4]:
# Location of the Multi30k data file. Set the MULTI30K_DATA_PATH environment
# variable to point at your own copy; the default is the path used originally.
data_path = os.environ.get(
    "MULTI30K_DATA_PATH",
    "/home/mandubian/datasets/multi30k/multi30k.atok.low.pt")
# NOTE: torch.load unpickles the file, which can execute arbitrary code.
# Only load files from a source you trust.
data = torch.load(data_path)

In [5]:
# Read the max_token_seq_len setting stored alongside the dataset; it is passed
# to the model further down as len_max_seq.
max_token_seq_len = data['settings'].max_token_seq_len
print(max_token_seq_len)

52


In [6]:
# Build the training and validation loaders with a batch size of 64
# (num_workers keeps the default of prepare_dataloaders).
train_loader, val_loader = prepare_dataloaders(data, batch_size=64)

In [None]:
# Fetch one training batch and show it in raw form.
fst = next(iter(train_loader))
print(fst)
# Map the ids of the first sequence in elements 0 and 2 of the batch back to
# words through the dataset's index-to-word tables.
en = ' '.join(train_loader.dataset.src_idx2word[token_id] for token_id in fst[0][0].numpy())
ge = ' '.join(train_loader.dataset.tgt_idx2word[token_id] for token_id in fst[2][0].numpy())
print(en)
print(ge)

In [None]:
# Run a single forward pass in evaluation mode on the batch held in `fst`.
# NOTE(review): `node_trans` and `ts` are not defined in any cell visible in this
# notebook. They look like earlier names for `model` and `timesteps`, which are
# only created in later cells. Confirm, then rename and move this cell below them;
# as written it raises NameError on a fresh kernel.
node_trans.eval()

# The batch is unpacked into four tensors.
# NOTE(review): presumably source tokens, source positions, target tokens and
# target positions. Verify against paired_collate_fn.
src, src_idx, tgt, tgt_idx = fst

print("fst", fst)
print("ts", ts)
# Every input is moved to `device` before the call. The result is not assigned,
# so it is only shown as the cell output.
node_trans(src.to(device), src_idx.to(device), tgt.to(device), tgt_idx.to(device), ts.to(device))

### Create an experiment with a name and a unique ID

In [12]:
# Experiment name and run identifier. Both are passed to the Tensorboard logger
# and to model_process.train in later cells.
exp_name = "node_transformer_multi30k"
unique_id = "2019-06-04_0030"


### Create Model

In [25]:

# Vocabulary sizes as reported by the training dataset.
src_vocab_sz = train_loader.dataset.src_vocab_size
print("src_vocab_sz", src_vocab_sz)
tgt_vocab_sz = train_loader.dataset.tgt_vocab_size
print("tgt_vocab_sz", tgt_vocab_sz)

# Source and target sides are both sized to the larger of the two vocabularies.
# NOTE(review): presumably so both embeddings have the same shape and can be
# shared. Confirm in NodeTransformer.
model = NodeTransformer(
    n_src_vocab=max(src_vocab_sz, tgt_vocab_sz),
    n_tgt_vocab=max(src_vocab_sz, tgt_vocab_sz),
    len_max_seq=max_token_seq_len,
    #emb_src_tgt_weight_sharing=False,
    #d_word_vec=128, d_model=128, d_inner=512,
    n_layers=1,
    n_head=8,
    method='dopri5-ext',
    rtol=1e-3,
    atol=1e-3,
).to(device)  # move the freshly built model onto the selected device

src_vocab_sz 9795
tgt_vocab_sz 17989


### Create Tensorboard metrics logger

In [24]:
# Metrics logger for this run, created from the experiment name and run id.
# It is handed to model_process.train as `tb`.
tb = Tensorboard(exp_name, unique_name=unique_id)

Writing TensorBoard events locally to runs/node_transformer_multi30k_2019-06-04_0030


### Create basic optimizer

In [26]:
# Adam over all model parameters, with explicit learning rate, betas and eps
# rather than the library defaults.
optimizer = optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.995), eps=1e-9)


### Restore best checkpoint (to restart past training)

In [None]:
# Optional cell: restore the best "validation" checkpoint into `model` and
# `optimizer`. The run id used here ("2019-05-29_1200") is not the `unique_id`
# of the current run.
# NOTE(review): presumably the call updates `model` and `optimizer` in place.
# Confirm in checkpoints.restore_best_checkpoint.
state = checkpoints.restore_best_checkpoint(
    "node_transformer_multi30k", "2019-05-29_1200", "validation", model, optimizer)

# Show the metrics stored with the checkpoint.
print("accuracy", state["acc"])
print("loss", state["loss"])
# Make sure the model sits on the selected device after restoring.
model = model.to(device)

In [27]:
# Training schedule
EPOCHS = 50
LOG_INTERVAL = 25

# Continuous space discretization: 6 evenly spaced points from 0 to 1,
# converted to a float32 torch tensor.
timesteps = torch.from_numpy(np.linspace(0., 1, num=6)).float()

# Launch training. The commented-out keywords show how to resume from the
# `state` returned by a checkpoint restore.
model_process.train(
    exp_name, unique_id,
    model,
    train_loader, val_loader, timesteps,
    optimizer, device,
    epochs=EPOCHS, tb=tb, log_interval=LOG_INTERVAL,
    #start_epoch=0, best_valid_accu=state["acc"]
)

Loaded model and timesteps to cuda


HBox(children=(IntProgress(value=0, description='Training', max=50, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=453, style=ProgressStyle(description_width='ini…

[ Epoch 0 ]
Adding group train to writers (dict_keys([]))
[Training]  loss: 7.673015976626173, ppl:  2149.55469, accuracy: 15.012 %, elapse: 935993.343ms
Adding group eval to writers (dict_keys(['train']))
[Validation]  loss: 6.588887055954044,  ppl:  726.97134, accuracy: 23.884 %, elapse: 1111.448ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 1', max=453, style=ProgressStyle(description_width='ini…

[ Epoch 1 ]
[Training]  loss: 6.704420181324197, ppl:  816.00475, accuracy: 28.164 %, elapse: 909428.343ms
[Validation]  loss: 5.997417363321826,  ppl:  402.38823, accuracy: 29.687 %, elapse: 1372.342ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 2', max=453, style=ProgressStyle(description_width='ini…

[ Epoch 2 ]
[Training]  loss: 6.36515670140328, ppl:  581.23590, accuracy: 31.653 %, elapse: 903455.974ms
[Validation]  loss: 5.68437115573398,  ppl:  294.23276, accuracy: 32.674 %, elapse: 1374.231ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 3', max=453, style=ProgressStyle(description_width='ini…

[ Epoch 3 ]


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[Training]  loss: 5.323315736725406, ppl:  205.06269, accuracy: 41.436 %, elapse: 930754.356ms
[Validation]  loss: 4.569684212610368,  ppl:  96.51363, accuracy: 42.374 %, elapse: 1377.306ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 12', max=453, style=ProgressStyle(description_width='in…

[ Epoch 12 ]
[Training]  loss: 5.257532256891236, ppl:  192.00708, accuracy: 42.056 %, elapse: 932487.923ms
[Validation]  loss: 4.496541437339933,  ppl:  89.70634, accuracy: 43.169 %, elapse: 1391.434ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 13', max=453, style=ProgressStyle(description_width='in…

[ Epoch 13 ]


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[Training]  loss: 4.784259053178228, ppl:  119.61270, accuracy: 46.848 %, elapse: 966451.126ms
[Validation]  loss: 4.000704927852326,  ppl:  54.63665, accuracy: 47.647 %, elapse: 1389.359ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 22', max=453, style=ProgressStyle(description_width='in…

[ Epoch 22 ]
[Training]  loss: 4.7428512830280445, ppl:  114.76095, accuracy: 47.191 %, elapse: 978156.048ms
[Validation]  loss: 3.950902076785139,  ppl:  51.98224, accuracy: 48.198 %, elapse: 1398.504ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 23', max=453, style=ProgressStyle(description_width='in…

[ Epoch 23 ]


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[Training]  loss: 4.49069979210693, ppl:  89.18383, accuracy: 49.743 %, elapse: 1001095.063ms
[Validation]  loss: 3.6967943438708897,  ppl:  40.31785, accuracy: 50.419 %, elapse: 1372.895ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 30', max=453, style=ProgressStyle(description_width='in…

[ Epoch 30 ]
[Training]  loss: 4.459186214825103, ppl:  86.41716, accuracy: 50.061 %, elapse: 1000625.380ms
[Validation]  loss: 3.669616467478419,  ppl:  39.23685, accuracy: 50.519 %, elapse: 1375.330ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 31', max=453, style=ProgressStyle(description_width='in…

[ Epoch 31 ]
[Training]  loss: 4.429238127960041, ppl:  83.86750, accuracy: 50.366 %, elapse: 1006574.517ms
[Validation]  loss: 3.6451685621451664,  ppl:  38.28923, accuracy: 50.713 %, elapse: 1388.275ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 32', max=453, style=ProgressStyle(description_width='in…

[ Epoch 32 ]
[Training]  loss: 4.4001569067058695, ppl:  81.46365, accuracy: 50.698 %, elapse: 1006354.985ms
[Validation]  loss: 3.613294664704447,  ppl:  37.08804, accuracy: 51.007 %, elapse: 1376.540ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 33', max=453, style=ProgressStyle(description_width='in…

[ Epoch 33 ]
[Training]  loss: 4.373145470002462, ppl:  79.29265, accuracy: 50.941 %, elapse: 1014332.490ms
[Validation]  loss: 3.5885333589944213,  ppl:  36.18097, accuracy: 51.064 %, elapse: 1376.040ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 34', max=453, style=ProgressStyle(description_width='in…

[ Epoch 34 ]


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



[Training]  loss: 4.319252679539181, ppl:  75.13246, accuracy: 51.556 %, elapse: 1011321.536ms
[Validation]  loss: 3.5412620218138837,  ppl:  34.51044, accuracy: 51.444 %, elapse: 1375.584ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 36', max=453, style=ProgressStyle(description_width='in…

[ Epoch 36 ]
[Training]  loss: 4.295056058796522, ppl:  73.33633, accuracy: 51.783 %, elapse: 1021624.530ms
[Validation]  loss: 3.5125179655598537,  ppl:  33.53260, accuracy: 51.544 %, elapse: 1377.074ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 37', max=453, style=ProgressStyle(description_width='in…

[ Epoch 37 ]
[Training]  loss: 4.270003449667737, ppl:  71.52188, accuracy: 52.045 %, elapse: 1017845.837ms
[Validation]  loss: 3.4886752260526364,  ppl:  32.74254, accuracy: 51.938 %, elapse: 1395.465ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 38', max=453, style=ProgressStyle(description_width='in…

[ Epoch 38 ]
[Training]  loss: 4.2460895012291315, ppl:  69.83180, accuracy: 52.344 %, elapse: 1023487.059ms
[Validation]  loss: 3.464681356869896,  ppl:  31.96627, accuracy: 52.045 %, elapse: 1393.364ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 39', max=453, style=ProgressStyle(description_width='in…

[ Epoch 39 ]
[Training]  loss: 4.2243275487367, ppl:  68.32854, accuracy: 52.541 %, elapse: 1029272.552ms
[Validation]  loss: 3.445078241132513,  ppl:  31.34574, accuracy: 52.325 %, elapse: 1381.492ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 40', max=453, style=ProgressStyle(description_width='in…

[ Epoch 40 ]
[Training]  loss: 4.202677105282119, ppl:  66.86510, accuracy: 52.787 %, elapse: 1034462.481ms
[Validation]  loss: 3.4257358114052763,  ppl:  30.74526, accuracy: 52.353 %, elapse: 1379.078ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 41', max=453, style=ProgressStyle(description_width='in…

[ Epoch 41 ]
[Training]  loss: 4.180017964632561, ppl:  65.36703, accuracy: 53.058 %, elapse: 1032145.901ms
[Validation]  loss: 3.398947835933582,  ppl:  29.93259, accuracy: 52.676 %, elapse: 1394.871ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 42', max=453, style=ProgressStyle(description_width='in…

[ Epoch 42 ]
[Training]  loss: 4.160548458897218, ppl:  64.10667, accuracy: 53.294 %, elapse: 1036668.451ms
[Validation]  loss: 3.3789322748764237,  ppl:  29.33943, accuracy: 52.776 %, elapse: 1381.764ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 43', max=453, style=ProgressStyle(description_width='in…

[ Epoch 43 ]
[Training]  loss: 4.138735805230966, ppl:  62.72348, accuracy: 53.517 %, elapse: 1046645.635ms
[Validation]  loss: 3.3645643393711047,  ppl:  28.92089, accuracy: 52.826 %, elapse: 1399.013ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 44', max=453, style=ProgressStyle(description_width='in…

[ Epoch 44 ]
[Training]  loss: 4.1195876148813175, ppl:  61.53386, accuracy: 53.752 %, elapse: 1046423.693ms
[Validation]  loss: 3.3419581138009082,  ppl:  28.27444, accuracy: 53.163 %, elapse: 1380.700ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 45', max=453, style=ProgressStyle(description_width='in…

[ Epoch 45 ]
[Training]  loss: 4.100825609024572, ppl:  60.39013, accuracy: 53.946 %, elapse: 1058879.628ms
[Validation]  loss: 3.329712604911034,  ppl:  27.93031, accuracy: 53.163 %, elapse: 1381.177ms


HBox(children=(IntProgress(value=0, description='Epoch 46', max=453, style=ProgressStyle(description_width='in…

[ Epoch 46 ]
[Training]  loss: 4.081136880947215, ppl:  59.21275, accuracy: 54.215 %, elapse: 1063959.605ms
[Validation]  loss: 3.308474959031807,  ppl:  27.34339, accuracy: 53.349 %, elapse: 1379.291ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 47', max=453, style=ProgressStyle(description_width='in…

[ Epoch 47 ]
[Training]  loss: 4.063845857741888, ppl:  58.19770, accuracy: 54.346 %, elapse: 1065927.727ms
[Validation]  loss: 3.2919823931916685,  ppl:  26.89613, accuracy: 53.428 %, elapse: 1393.012ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 48', max=453, style=ProgressStyle(description_width='in…

[ Epoch 48 ]


Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, b

KeyboardInterrupt: 