In [1]:
import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim

import torchdiffeq

from tensorboard_utils import Tensorboard
from tensorboard_utils import tensorboard_event_accumulator

import transformer.Constants as Constants
from transformer.Layers import EncoderLayer, DecoderLayer
from transformer.Modules import ScaledDotProductAttention
from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table
from transformer.SubLayers import PositionwiseFeedForward

from dataset import TranslationDataset, paired_collate_fn

import model_process
import checkpoints

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib notebook  

print("Torch Version", torch.__version__)

%load_ext autoreload
%autoreload 2

Torch Version 1.1.0


In [37]:
# Force-reinstall tqdm into the user site-packages from the `colab` archive of the chengs/tqdm fork.
# NOTE(review): the archive is an unpinned branch zip, and this cell's execution count (37) is out
# of sequence with the cells around it. Consider a dedicated setup cell using `%pip` with a pinned
# revision so the install targets the kernel's environment and is reproducible.
!pip install --user --force https://github.com/chengs/tqdm/archive/colab.zip

Collecting https://github.com/chengs/tqdm/archive/colab.zip
  Downloading https://github.com/chengs/tqdm/archive/colab.zip
[K     | 808kB 1.3MB/s
Building wheels for collected packages: tqdm
  Building wheel for tqdm (setup.py) ... [?25ldone
[?25h  Stored in directory: /tmp/pip-ephem-wheel-cache-2u43pw1d/wheels/41/18/ee/d5dd158441b27965855b1bbae03fa2d8a91fe645c01b419896
Successfully built tqdm
Installing collected packages: tqdm
  Found existing installation: tqdm 4.28.1
    Uninstalling tqdm-4.28.1:
      Successfully uninstalled tqdm-4.28.1
[33m  The script tqdm is installed in '/home/mandubian/.local/bin' which is not on PATH.
Successfully installed tqdm-4.28.1
[33mYou are using pip version 19.0.3, however version 19.1.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [2]:
def prepare_dataloaders(data, batch_size=64, num_workers=2):
    """Build the training and validation ``DataLoader`` objects from ``data``.

    Args:
        data: Mapping with a ``'dict'`` entry (holding the ``'src'`` and
            ``'tgt'`` word-to-index tables) and ``'train'`` / ``'valid'``
            entries (each holding ``'src'`` and ``'tgt'`` instances).
        batch_size: Batch size given to both loaders.
        num_workers: Number of loader worker processes given to both loaders.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. The training loader
        shuffles its data; the validation loader does not.
    """
    vocab = data['dict']

    def _make_loader(split, shuffle):
        # Both splits share the same vocabularies; only the instances differ.
        dataset = TranslationDataset(
            src_word2idx=vocab['src'],
            tgt_word2idx=vocab['tgt'],
            src_insts=data[split]['src'],
            tgt_insts=data[split]['tgt'])
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            collate_fn=paired_collate_fn,
            shuffle=shuffle)

    # Tuple elements are evaluated left to right: train first, then valid.
    return _make_loader('train', True), _make_loader('valid', False)

In [3]:
# Seed PyTorch's random number generator.
seed = 1
torch.manual_seed(seed)
# Prefer the GPU, but fall back to the CPU so the later `.to(device)` calls
# still work on a machine without CUDA instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)

device cuda


In [4]:
import os

# Path of the serialized dataset file. The default is the original
# machine-specific location; set the MULTI30K_DATA environment variable to
# point elsewhere without editing the notebook.
DATA_PATH = os.environ.get(
    "MULTI30K_DATA",
    "/home/mandubian/datasets/multi30k/multi30k.atok.low.pt",
)
# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load files from a trusted source.
data = torch.load(DATA_PATH)

In [5]:
# Maximum token sequence length as stored in the dataset's saved settings;
# it is passed to the model constructor as `len_max_seq`.
max_token_seq_len = data['settings'].max_token_seq_len
print(max_token_seq_len)

52


In [6]:
# Build the train/validation loaders with a batch size of 64
# (num_workers keeps prepare_dataloaders' default).
train_loader, val_loader = prepare_dataloaders(data, batch_size=64)

In [7]:
def _decode(idx2word, token_ids):
    """Look up each id of a 1-D id tensor in ``idx2word`` and join the words with spaces."""
    words = [idx2word[token] for token in token_ids.numpy()]
    return ' '.join(words)


# Take one batch from the training loader and print it, then decode the first
# row of the source ids (element 0) and of the target ids (element 2).
fst = next(iter(train_loader))
print(fst)
en = _decode(train_loader.dataset.src_idx2word, fst[0][0])
ge = _decode(train_loader.dataset.tgt_idx2word, fst[2][0])
print(en)
print(ge)

(tensor([[   2, 2428, 5671,  ...,    0,    0,    0],
        [   2,  113, 1523,  ...,    0,    0,    0],
        [   2, 5572,  963,  ...,    0,    0,    0],
        ...,
        [   2, 1611, 6919,  ...,    0,    0,    0],
        [   2,  113, 2785,  ...,    0,    0,    0],
        [   2, 5572, 7377,  ...,    0,    0,    0]]), tensor([[1, 2, 3,  ..., 0, 0, 0],
        [1, 2, 3,  ..., 0, 0, 0],
        [1, 2, 3,  ..., 0, 0, 0],
        ...,
        [1, 2, 3,  ..., 0, 0, 0],
        [1, 2, 3,  ..., 0, 0, 0],
        [1, 2, 3,  ..., 0, 0, 0]]), tensor([[    2,  8935,  7643,  ...,     0,     0,     0],
        [    2,  1156,  8610,  ...,     0,     0,     0],
        [    2,   594,  4565,  ...,     0,     0,     0],
        ...,
        [    2, 12709,  4724,  ...,     0,     0,     0],
        [    2,  1156,  1298,  ...,     0,     0,     0],
        [    2, 12709,  5824,  ...,     0,     0,     0]]), tensor([[1, 2, 3,  ..., 0, 0, 0],
        [1, 2, 3,  ..., 0, 0, 0],
        [1, 2, 3,  ...

In [11]:
# NOTE(review): `node_trans` and `ts` are not defined by any cell visible in this
# notebook, and this cell's execution count (11) is the same as the optimizer
# cell's. It looks like a leftover from an earlier session and would raise
# NameError on a fresh Restart & Run All. Confirm whether it should be removed
# or rewritten against `model` / `timesteps` after they are created.
node_trans.eval()  # presumably an nn.Module being put in evaluation mode; confirm

# Unpack the batch into the four tensors passed to the forward call below.
src, src_idx, tgt, tgt_idx = fst

print("fst", fst)
print("ts", ts)
# Forward pass on `device`. There is no torch.no_grad() here, so the displayed
# result still carries a grad_fn.
node_trans(src.to(device), src_idx.to(device), tgt.to(device), tgt_idx.to(device), ts.to(device))

fst (tensor([[   2, 5572, 4113, 8034, 1523, 3968,  995, 7521, 9726, 5572, 8218,    3]]), tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12]]), tensor([[    2, 12709, 12710,  9165,  4869,   762,  4378,  7724,  6533, 14015,
             3]]), tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]]))
ts tensor([0.0000, 0.2000, 0.4000, 0.6000, 0.8000, 1.0000], device='cuda:0')


tensor([[ 0.0000, -0.1402, 15.3975,  ..., -2.3901, -1.0970,  2.5601],
        [ 0.0000, -0.5810,  1.1288,  ..., -0.9210, -1.8688,  0.7077],
        [ 0.0000, -0.8570,  0.4723,  ...,  0.4867, -0.8943,  2.2571],
        ...,
        [ 0.0000, -1.8532,  1.4003,  ..., -0.5003,  0.8397,  0.7987],
        [ 0.0000, -0.8690, -1.3879,  ...,  0.3047,  0.8819,  0.6424],
        [ 0.0000, -2.3078,  1.0934,  ..., -0.0613,  0.4002,  0.9100]],
       device='cuda:0', grad_fn=<ViewBackward>)

### Create an experiment with a name and a unique ID

In [8]:
# Identifiers for this run; both are handed to the Tensorboard logger and to
# model_process.train().
exp_name = "node_transformer_multi30k"
unique_id = "2019-05-29_1200"  # NOTE(review): hard-coded, so a re-run reuses the same run name


### Create Model

In [9]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
from node_transformer import NodeTransformer

# Vocabulary sizes are read from the training dataset.
src_vocab_sz = train_loader.dataset.src_vocab_size
tgt_vocab_sz = train_loader.dataset.tgt_vocab_size

# Build the model with separate source/target embeddings (weight sharing off).
# `method`, `rtol` and `atol` are presumably the ODE solver name and its
# tolerances; verify against NodeTransformer.
model = NodeTransformer(
    n_src_vocab=src_vocab_sz,
    n_tgt_vocab=tgt_vocab_sz,
    len_max_seq=max_token_seq_len,
    emb_src_tgt_weight_sharing=False,
    # NOTE(review): the size overrides below are disabled, so whatever defaults
    # the constructor defines apply. Delete the line or promote it to config.
    #d_word_vec=128, d_model=128, d_inner=512,
    n_layers=1, n_head=8, method='dopri5-ext', rtol=1e-3, atol=1e-3)


### Create Tensorboard metrics logger

In [10]:
# Metrics logger for this run, identified by the experiment name and unique ID.
tb = Tensorboard(exp_name, unique_name=unique_id)

Writing TensorBoard events locally to runs/node_transformer_multi30k_2019-05-29_1200


### Create basic optimizer

In [11]:
# Adam over all model parameters with a fixed learning rate of 1e-5; no
# learning-rate scheduler is set up in this cell.
optimizer = optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.995), eps=1e-9)


In [None]:
# Six evenly spaced values on [0, 1], built in NumPy and converted to a float32
# tensor. Presumably the time points handed to the ODE solver; confirm in
# model_process / NodeTransformer.
timesteps = torch.from_numpy(np.linspace(0., 1, num=6)).float()

# Run training for 20 epochs, logging to `tb`.
model_process.train(
    exp_name,
    unique_id,
    model,
    train_loader,
    val_loader,
    timesteps,
    optimizer,
    device,
    epochs=20,
    tb=tb,
    log_interval=100,
)

  0%|          | 0/453 [00:00<?, ?it/s]

Loaded model and timesteps to cuda
[ Epoch 0 ]
Adding group train to writers (dict_keys([]))


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 7.772877168508285, ppl:  2375.29559, accuracy: 14.472 %, elapse: 932072.848ms




Adding group eval to writers (dict_keys(['train']))
[Validation]  loss: 6.674679250394011,  ppl:  792.09335, accuracy: 24.493 %, elapse: 1127.259ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 1 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 6.754271093470751, ppl:  857.71433, accuracy: 28.150 %, elapse: 894210.277ms




[Validation]  loss: 6.0787044557666645,  ppl:  436.46337, accuracy: 30.074 %, elapse: 1126.177ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 2 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 6.400615447797599, ppl:  602.21556, accuracy: 31.743 %, elapse: 885503.449ms




[Validation]  loss: 5.772487434602004,  ppl:  321.33604, accuracy: 32.323 %, elapse: 1126.302ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 3 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 6.191441062070472, ppl:  488.54963, accuracy: 33.489 %, elapse: 872408.148ms




[Validation]  loss: 5.553710185437039,  ppl:  258.19373, accuracy: 33.820 %, elapse: 1126.762ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 4 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 6.031144473281045, ppl:  416.19108, accuracy: 34.885 %, elapse: 876850.305ms




[Validation]  loss: 5.373627625014551,  ppl:  215.64373, accuracy: 35.597 %, elapse: 1127.067ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 5 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.897737418914761, ppl:  364.21247, accuracy: 36.099 %, elapse: 887652.689ms




[Validation]  loss: 5.216899487869609,  ppl:  184.36168, accuracy: 36.965 %, elapse: 1123.003ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 6 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.780387622675538, ppl:  323.88471, accuracy: 37.093 %, elapse: 892534.662ms




[Validation]  loss: 5.079440900304015,  ppl:  160.68419, accuracy: 38.140 %, elapse: 1122.408ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 7 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.67492754227238, ppl:  291.46722, accuracy: 38.132 %, elapse: 889743.139ms




[Validation]  loss: 4.966722333317662,  ppl:  143.55559, accuracy: 38.964 %, elapse: 1136.502ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 8 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.580504402045974, ppl:  265.20534, accuracy: 38.956 %, elapse: 894299.257ms




[Validation]  loss: 4.86604137353701,  ppl:  129.80604, accuracy: 39.960 %, elapse: 1124.189ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 9 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.493336958695328, ppl:  243.06696, accuracy: 39.799 %, elapse: 898622.197ms




[Validation]  loss: 4.7674782930143635,  ppl:  117.62226, accuracy: 40.884 %, elapse: 1137.485ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 10 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.411757381198697, ppl:  224.02494, accuracy: 40.487 %, elapse: 902221.318ms




[Validation]  loss: 4.670859522614219,  ppl:  106.78949, accuracy: 41.672 %, elapse: 1123.696ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 11 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.334392323844883, ppl:  207.34671, accuracy: 41.252 %, elapse: 908835.289ms




[Validation]  loss: 4.59576444436609,  ppl:  99.06384, accuracy: 42.360 %, elapse: 1126.570ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 12 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.261967273697875, ppl:  192.86053, accuracy: 41.949 %, elapse: 917785.429ms




[Validation]  loss: 4.510300918929992,  ppl:  90.94918, accuracy: 42.854 %, elapse: 1122.581ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 13 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.193884961262795, ppl:  180.16714, accuracy: 42.716 %, elapse: 929965.462ms




[Validation]  loss: 4.439625889603034,  ppl:  84.74323, accuracy: 43.499 %, elapse: 1126.473ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 14 ]


  0%|          | 0/16 [00:00<?, ?it/s]

[Training]  loss: 5.131747499937813, ppl:  169.31273, accuracy: 43.332 %, elapse: 933687.749ms




[Validation]  loss: 4.372634272271697,  ppl:  79.25213, accuracy: 44.136 %, elapse: 1124.963ms
Checkpointing Validation Model...


  0%|          | 0/453 [00:00<?, ?it/s]

[ Epoch 15 ]


 25%|██▍       | 113/453 [03:52<11:22,  2.01s/it]