In [1]:
import os

import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim

import torchdiffeq

from tensorboard_utils import Tensorboard
from tensorboard_utils import tensorboard_event_accumulator

import transformer.Constants as Constants
from transformer.Layers import EncoderLayer, DecoderLayer
from transformer.Modules import ScaledDotProductAttention
from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table
from transformer.SubLayers import PositionwiseFeedForward

import dataset

import model_process
import checkpoints
from node_transformer import NodeTransformer

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
#%matplotlib notebook  
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

print("Torch Version", torch.__version__)

%load_ext autoreload
%autoreload 2

Torch Version 1.1.0


In [2]:
# Seed every RNG this notebook relies on so runs are reproducible.
# torch.manual_seed covers the CPU generator (and CUDA devices); numpy must be
# seeded separately because it has its own global generator.
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)

# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA instead of failing later at `model.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)

device cuda


In [3]:
# Preprocessed Multi30k dataset. Set the MULTI30K_DATA environment variable to
# point at your own copy instead of editing this cell; the default is the
# original author's local path.
DATA_PATH = os.environ.get(
    "MULTI30K_DATA", "/home/mandubian/datasets/multi30k/multi30k.atok.low.pt")
# NOTE: torch.load unpickles the file, which can execute arbitrary code --
# only load files you produced yourself or otherwise trust.
data = torch.load(DATA_PATH)

In [4]:
# Maximum token sequence length recorded in the dataset settings; it is used
# below as the model's `len_max_seq`.
settings = data['settings']
max_token_seq_len = settings.max_token_seq_len
print(max_token_seq_len)

52


In [5]:
# Build the training and validation DataLoaders from the loaded dataset.
BATCH_SIZE = 128
train_loader, val_loader = dataset.prepare_dataloaders(
    data,
    batch_size=BATCH_SIZE,
)

### Create an experiment with a name and a unique ID

In [6]:
# Identifiers for this run. Both are passed to the TensorBoard logger (which
# writes to runs/<exp_name>_<unique_id>) and to the training loop.
# Re-using the same pair presumably overwrites that run's artefacts -- change
# unique_id to start a fresh run.
exp_name, unique_id = "transformer_multi30k", "2019-06-07_1000"


### Create Model

In [7]:
# Placeholder binding so the model-construction cell below can check for, and
# discard, a model left over from a previous execution.
model = None

In [8]:
src_vocab_sz = train_loader.dataset.src_vocab_size
print("src_vocab_sz", src_vocab_sz)
tgt_vocab_sz = train_loader.dataset.tgt_vocab_size
print("tgt_vocab_sz", tgt_vocab_sz)

# Drop a model left over from an earlier run of this cell so its memory can be
# reclaimed before the new one is allocated (only effective if nothing else,
# e.g. an existing optimizer, still references its parameters).
# Test against None explicitly: nn.Module truthiness is not an "exists" check
# (container modules define __len__, so an empty one is falsy).
if model is not None:
    del model

# Size both embedding tables with the larger of the two vocabularies so source
# and target token ids live in one index space.
# NOTE(review): presumably needed for source/target embedding weight sharing,
# which is left at NodeTransformer's default here -- confirm.
shared_vocab_sz = max(src_vocab_sz, tgt_vocab_sz)

model = NodeTransformer(
    n_src_vocab=shared_vocab_sz,
    n_tgt_vocab=shared_vocab_sz,
    len_max_seq=max_token_seq_len,
    # Other settings that have been tried: emb_src_tgt_weight_sharing=False,
    # d_word_vec=128, d_model=128, d_inner=512.
    n_head=8,
    # ODE solver method and tolerances passed through to NodeTransformer.
    method='dopri5-ext', rtol=1e-3, atol=1e-3,
    # Both NODE blocks are disabled; presumably this yields a plain
    # (discrete-layer) transformer, matching exp_name -- verify.
    has_node_encoder=False, has_node_decoder=False)

model = model.to(device)

src_vocab_sz 9795
tgt_vocab_sz 17989


### Create Tensorboard metrics logger

In [9]:
# Metric logger for this run; per the cell output, events are written locally
# under runs/<exp_name>_<unique_id>. It is handed to model_process.train below.
tb = Tensorboard(exp_name, unique_name=unique_id)

Writing TensorBoard events locally to runs/transformer_multi30k_2019-06-07_1000


### Create basic optimizer

In [11]:
# Adam hyperparameters, kept as named constants so they are easy to find and tune.
LEARNING_RATE = 1e-4
ADAM_BETAS = (0.9, 0.995)
ADAM_EPS = 1e-9

optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=ADAM_BETAS,
    eps=ADAM_EPS,
)


### Train

In [12]:
# Continuous-space discretization: N_TIMESTEPS evenly spaced times on [0, 1],
# built in numpy (float64) and converted to a float32 torch tensor.
N_TIMESTEPS = 6
timesteps = torch.from_numpy(np.linspace(0., 1, num=N_TIMESTEPS)).float()

EPOCHS = 50
LOG_INTERVAL = 5

# Debugging tip: wrap the call below in `with torch.autograd.detect_anomaly():`
# to locate NaN/inf gradients.
# To resume after restoring a checkpoint, also pass e.g.
# start_epoch=0, best_valid_accu=state["acc"].
model_process.train(
    exp_name,
    unique_id,
    model,
    train_loader,
    val_loader,
    timesteps,
    optimizer,
    device,
    epochs=EPOCHS,
    tb=tb,
    log_interval=LOG_INTERVAL,
)

Loaded model and timesteps to cuda


HBox(children=(IntProgress(value=0, description='Training', max=50, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 0 ]
Adding group train to writers (dict_keys([]))
[Training]  loss: 6.8149000610367345, ppl:  911.32543, accuracy: 27.076 %, elapse: 19205.983ms
Adding group eval to writers (dict_keys(['train']))
[Validation]  loss: 5.35197686920177,  ppl:  211.02505, accuracy: 34.644 %, elapse: 285.239ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 1', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 1 ]
[Training]  loss: 5.709207316331694, ppl:  301.63187, accuracy: 37.813 %, elapse: 19029.006ms
[Validation]  loss: 4.641788004155026,  ppl:  103.72965, accuracy: 41.214 %, elapse: 285.800ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 2', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 2 ]
[Training]  loss: 5.221279466281199, ppl:  185.17095, accuracy: 42.742 %, elapse: 19098.102ms
[Validation]  loss: 4.201269888042661,  ppl:  66.77107, accuracy: 45.018 %, elapse: 287.158ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 3', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 3 ]
[Training]  loss: 4.882612129451231, ppl:  131.97495, accuracy: 46.122 %, elapse: 19079.331ms
[Validation]  loss: 3.87762375528333,  ppl:  48.30928, accuracy: 47.912 %, elapse: 286.661ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 4', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 4 ]
[Training]  loss: 4.625005148037302, ppl:  102.00330, accuracy: 48.558 %, elapse: 19195.698ms
[Validation]  loss: 3.6569998242621247,  ppl:  38.74493, accuracy: 49.846 %, elapse: 285.892ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 5', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 5 ]
[Training]  loss: 4.422383956332246, ppl:  83.29462, accuracy: 50.618 %, elapse: 19190.359ms
[Validation]  loss: 3.4661356542486033,  ppl:  32.01279, accuracy: 51.307 %, elapse: 287.055ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 6', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 6 ]
[Training]  loss: 4.258283854245886, ppl:  70.68857, accuracy: 52.266 %, elapse: 19292.726ms
[Validation]  loss: 3.3301772224125115,  ppl:  27.94329, accuracy: 52.611 %, elapse: 287.839ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 7', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 7 ]
[Training]  loss: 4.123006861047643, ppl:  61.74462, accuracy: 53.653 %, elapse: 19341.239ms
[Validation]  loss: 3.205529012979977,  ppl:  24.66855, accuracy: 53.786 %, elapse: 286.993ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 8', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 8 ]
[Training]  loss: 4.006088463278855, ppl:  54.93158, accuracy: 54.971 %, elapse: 19157.291ms
[Validation]  loss: 3.1169744388141165,  ppl:  22.57797, accuracy: 54.510 %, elapse: 286.757ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 9', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 9 ]
[Training]  loss: 3.903140753192194, ppl:  49.55785, accuracy: 56.118 %, elapse: 19145.078ms
[Validation]  loss: 3.0255612558430047,  ppl:  20.60557, accuracy: 55.570 %, elapse: 286.950ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 10', max=227, style=ProgressStyle(description_width='in…

[ Epoch 10 ]
[Training]  loss: 3.8134498209280165, ppl:  45.30647, accuracy: 57.225 %, elapse: 19227.834ms
[Validation]  loss: 2.963360856738932,  ppl:  19.36294, accuracy: 56.114 %, elapse: 286.996ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 11', max=227, style=ProgressStyle(description_width='in…

[ Epoch 11 ]
[Training]  loss: 3.730016917819589, ppl:  41.67981, accuracy: 58.213 %, elapse: 19257.791ms
[Validation]  loss: 2.9031291136094994,  ppl:  18.23110, accuracy: 57.038 %, elapse: 287.244ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 12', max=227, style=ProgressStyle(description_width='in…

[ Epoch 12 ]
[Training]  loss: 3.6553272047205625, ppl:  38.68018, accuracy: 59.161 %, elapse: 19173.209ms
[Validation]  loss: 2.8594095249368685,  ppl:  17.45122, accuracy: 57.339 %, elapse: 287.387ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 13', max=227, style=ProgressStyle(description_width='in…

[ Epoch 13 ]
[Training]  loss: 3.5854325269291243, ppl:  36.06896, accuracy: 60.070 %, elapse: 19194.095ms
[Validation]  loss: 2.8131415971550613,  ppl:  16.66218, accuracy: 57.948 %, elapse: 287.422ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 14', max=227, style=ProgressStyle(description_width='in…

[ Epoch 14 ]
[Training]  loss: 3.521586382016125, ppl:  33.83807, accuracy: 61.031 %, elapse: 19267.316ms
[Validation]  loss: 2.7764724413348914,  ppl:  16.06226, accuracy: 58.221 %, elapse: 287.343ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 15', max=227, style=ProgressStyle(description_width='in…

[ Epoch 15 ]
[Training]  loss: 3.4651449033742487, ppl:  31.98109, accuracy: 61.772 %, elapse: 19171.697ms
[Validation]  loss: 2.7467189416886058,  ppl:  15.59139, accuracy: 58.643 %, elapse: 289.780ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 16', max=227, style=ProgressStyle(description_width='in…

[ Epoch 16 ]
[Training]  loss: 3.408083768268836, ppl:  30.20730, accuracy: 62.515 %, elapse: 19294.364ms
[Validation]  loss: 2.7245766690061073,  ppl:  15.24996, accuracy: 58.786 %, elapse: 287.600ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 17', max=227, style=ProgressStyle(description_width='in…

[ Epoch 17 ]
[Training]  loss: 3.3537013417763224, ppl:  28.60843, accuracy: 63.296 %, elapse: 19263.549ms
[Validation]  loss: 2.714487761871821,  ppl:  15.09687, accuracy: 59.123 %, elapse: 287.532ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 18', max=227, style=ProgressStyle(description_width='in…

[ Epoch 18 ]
[Training]  loss: 3.305841905589071, ppl:  27.27149, accuracy: 63.963 %, elapse: 19235.801ms
[Validation]  loss: 2.681176654622466,  ppl:  14.60227, accuracy: 59.274 %, elapse: 287.500ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 19', max=227, style=ProgressStyle(description_width='in…

[ Epoch 19 ]
[Training]  loss: 3.254038400699319, ppl:  25.89470, accuracy: 64.741 %, elapse: 19289.994ms
[Validation]  loss: 2.6499454911761675,  ppl:  14.15327, accuracy: 59.711 %, elapse: 287.331ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 20', max=227, style=ProgressStyle(description_width='in…

[ Epoch 20 ]
[Training]  loss: 3.2098861677973685, ppl:  24.77627, accuracy: 65.378 %, elapse: 19367.139ms
[Validation]  loss: 2.6287706687889534,  ppl:  13.85672, accuracy: 59.668 %, elapse: 288.026ms


HBox(children=(IntProgress(value=0, description='Epoch 21', max=227, style=ProgressStyle(description_width='in…

[ Epoch 21 ]
[Training]  loss: 3.1689672080594713, ppl:  23.78291, accuracy: 65.973 %, elapse: 19315.423ms
[Validation]  loss: 2.6248423641053713,  ppl:  13.80240, accuracy: 60.019 %, elapse: 287.227ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 22', max=227, style=ProgressStyle(description_width='in…

[ Epoch 22 ]
[Training]  loss: 3.1232531022303016, ppl:  22.72017, accuracy: 66.678 %, elapse: 19318.844ms
[Validation]  loss: 2.62673569135504,  ppl:  13.82856, accuracy: 60.090 %, elapse: 287.155ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 23', max=227, style=ProgressStyle(description_width='in…

[ Epoch 23 ]
[Training]  loss: 3.082037394471548, ppl:  21.80278, accuracy: 67.251 %, elapse: 19258.919ms
[Validation]  loss: 2.591683523370899,  ppl:  13.35223, accuracy: 60.305 %, elapse: 300.765ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 24', max=227, style=ProgressStyle(description_width='in…

[ Epoch 24 ]
[Training]  loss: 3.045673388802711, ppl:  21.02418, accuracy: 67.771 %, elapse: 19209.310ms
[Validation]  loss: 2.583979128049108,  ppl:  13.24976, accuracy: 60.671 %, elapse: 287.338ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 25', max=227, style=ProgressStyle(description_width='in…

[ Epoch 25 ]
[Training]  loss: 3.007106275681019, ppl:  20.22878, accuracy: 68.449 %, elapse: 19297.897ms
[Validation]  loss: 2.578282076219867,  ppl:  13.17449, accuracy: 60.807 %, elapse: 287.760ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 26', max=227, style=ProgressStyle(description_width='in…

[ Epoch 26 ]
[Training]  loss: 2.970538389044133, ppl:  19.50242, accuracy: 69.022 %, elapse: 19232.895ms
[Validation]  loss: 2.5717140480324074,  ppl:  13.08824, accuracy: 60.585 %, elapse: 288.249ms


HBox(children=(IntProgress(value=0, description='Epoch 27', max=227, style=ProgressStyle(description_width='in…

[ Epoch 27 ]
[Training]  loss: 2.932887938489493, ppl:  18.78179, accuracy: 69.570 %, elapse: 19022.968ms
[Validation]  loss: 2.583282560338536,  ppl:  13.24053, accuracy: 60.563 %, elapse: 291.326ms


HBox(children=(IntProgress(value=0, description='Epoch 28', max=227, style=ProgressStyle(description_width='in…

[ Epoch 28 ]
[Training]  loss: 2.9000793092342505, ppl:  18.17559, accuracy: 70.081 %, elapse: 19208.324ms
[Validation]  loss: 2.5588046949156906,  ppl:  12.92036, accuracy: 60.921 %, elapse: 288.368ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 29', max=227, style=ProgressStyle(description_width='in…

[ Epoch 29 ]
[Training]  loss: 2.8655112512750396, ppl:  17.55803, accuracy: 70.664 %, elapse: 19232.314ms
[Validation]  loss: 2.5717961628138655,  ppl:  13.08931, accuracy: 60.928 %, elapse: 288.946ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 30', max=227, style=ProgressStyle(description_width='in…

[ Epoch 30 ]
[Training]  loss: 2.8326299293565005, ppl:  16.99008, accuracy: 71.201 %, elapse: 19191.897ms
[Validation]  loss: 2.5638964060205334,  ppl:  12.98632, accuracy: 60.542 %, elapse: 286.521ms


HBox(children=(IntProgress(value=0, description='Epoch 31', max=227, style=ProgressStyle(description_width='in…

[ Epoch 31 ]
[Training]  loss: 2.8010262737999496, ppl:  16.46153, accuracy: 71.651 %, elapse: 19151.319ms
[Validation]  loss: 2.553544679116072,  ppl:  12.85258, accuracy: 60.936 %, elapse: 287.549ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 32', max=227, style=ProgressStyle(description_width='in…

[ Epoch 32 ]
[Training]  loss: 2.7683642373515585, ppl:  15.93255, accuracy: 72.264 %, elapse: 19249.395ms
[Validation]  loss: 2.558455457864326,  ppl:  12.91585, accuracy: 61.022 %, elapse: 286.093ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 33', max=227, style=ProgressStyle(description_width='in…

[ Epoch 33 ]
[Training]  loss: 2.738777834456113, ppl:  15.46807, accuracy: 72.763 %, elapse: 19324.383ms
[Validation]  loss: 2.5463302265787306,  ppl:  12.76019, accuracy: 61.029 %, elapse: 287.187ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 34', max=227, style=ProgressStyle(description_width='in…

[ Epoch 34 ]
[Training]  loss: 2.7079431309204782, ppl:  14.99839, accuracy: 73.224 %, elapse: 19169.953ms
[Validation]  loss: 2.5505116966428467,  ppl:  12.81366, accuracy: 61.065 %, elapse: 287.598ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 35', max=227, style=ProgressStyle(description_width='in…

[ Epoch 35 ]
[Training]  loss: 2.6801784852383355, ppl:  14.58770, accuracy: 73.786 %, elapse: 19254.931ms
[Validation]  loss: 2.5525227130010206,  ppl:  12.83945, accuracy: 61.151 %, elapse: 286.926ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 36', max=227, style=ProgressStyle(description_width='in…

[ Epoch 36 ]
[Training]  loss: 2.6538176551126043, ppl:  14.20818, accuracy: 74.162 %, elapse: 19206.389ms
[Validation]  loss: 2.5462488113907247,  ppl:  12.75915, accuracy: 61.165 %, elapse: 292.544ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 37', max=227, style=ProgressStyle(description_width='in…

[ Epoch 37 ]
[Training]  loss: 2.626255959568011, ppl:  13.82192, accuracy: 74.706 %, elapse: 19279.195ms
[Validation]  loss: 2.5454947370983776,  ppl:  12.74953, accuracy: 61.365 %, elapse: 288.078ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 38', max=227, style=ProgressStyle(description_width='in…

[ Epoch 38 ]
[Training]  loss: 2.597747584029436, ppl:  13.43345, accuracy: 75.161 %, elapse: 19240.255ms
[Validation]  loss: 2.554861646300102,  ppl:  12.86952, accuracy: 61.545 %, elapse: 287.290ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 39', max=227, style=ProgressStyle(description_width='in…

[ Epoch 39 ]
[Training]  loss: 2.571709822023426, ppl:  13.08818, accuracy: 75.665 %, elapse: 19283.040ms
[Validation]  loss: 2.5563447319493697,  ppl:  12.88862, accuracy: 61.129 %, elapse: 287.303ms


HBox(children=(IntProgress(value=0, description='Epoch 40', max=227, style=ProgressStyle(description_width='in…

[ Epoch 40 ]
[Training]  loss: 2.54691290225444, ppl:  12.76763, accuracy: 76.105 %, elapse: 19251.555ms
[Validation]  loss: 2.5683553348478134,  ppl:  13.04435, accuracy: 61.086 %, elapse: 289.211ms


HBox(children=(IntProgress(value=0, description='Epoch 41', max=227, style=ProgressStyle(description_width='in…

[ Epoch 41 ]
[Training]  loss: 2.5214596745709623, ppl:  12.44675, accuracy: 76.567 %, elapse: 19161.765ms
[Validation]  loss: 2.5596232192548265,  ppl:  12.93094, accuracy: 61.086 %, elapse: 287.415ms


HBox(children=(IntProgress(value=0, description='Epoch 42', max=227, style=ProgressStyle(description_width='in…

[ Epoch 42 ]
[Training]  loss: 2.498970646289876, ppl:  12.16996, accuracy: 76.955 %, elapse: 19297.633ms
[Validation]  loss: 2.565563572196925,  ppl:  13.00799, accuracy: 61.666 %, elapse: 287.214ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 43', max=227, style=ProgressStyle(description_width='in…

[ Epoch 43 ]
[Training]  loss: 2.476282393978307, ppl:  11.89695, accuracy: 77.352 %, elapse: 19248.771ms
[Validation]  loss: 2.5854675655882855,  ppl:  13.26949, accuracy: 61.265 %, elapse: 287.247ms


HBox(children=(IntProgress(value=0, description='Epoch 44', max=227, style=ProgressStyle(description_width='in…

[ Epoch 44 ]
[Training]  loss: 2.450184462506979, ppl:  11.59048, accuracy: 77.775 %, elapse: 19342.824ms
[Validation]  loss: 2.5886967840248674,  ppl:  13.31241, accuracy: 61.208 %, elapse: 287.669ms


HBox(children=(IntProgress(value=0, description='Epoch 45', max=227, style=ProgressStyle(description_width='in…

[ Epoch 45 ]
[Training]  loss: 2.4313040332251825, ppl:  11.37370, accuracy: 78.162 %, elapse: 19217.146ms
[Validation]  loss: 2.589681234441042,  ppl:  13.32552, accuracy: 61.272 %, elapse: 286.766ms


HBox(children=(IntProgress(value=0, description='Epoch 46', max=227, style=ProgressStyle(description_width='in…

[ Epoch 46 ]
[Training]  loss: 2.407496511081418, ppl:  11.10612, accuracy: 78.595 %, elapse: 19273.083ms
[Validation]  loss: 2.5865377686718696,  ppl:  13.28370, accuracy: 61.172 %, elapse: 287.521ms


HBox(children=(IntProgress(value=0, description='Epoch 47', max=227, style=ProgressStyle(description_width='in…

[ Epoch 47 ]
[Training]  loss: 2.387569536504241, ppl:  10.88700, accuracy: 78.908 %, elapse: 19187.665ms
[Validation]  loss: 2.588779196133543,  ppl:  13.31351, accuracy: 61.537 %, elapse: 287.255ms


HBox(children=(IntProgress(value=0, description='Epoch 48', max=227, style=ProgressStyle(description_width='in…

[ Epoch 48 ]
[Training]  loss: 2.3634617543730814, ppl:  10.62768, accuracy: 79.446 %, elapse: 19346.606ms
[Validation]  loss: 2.586259278008364,  ppl:  13.28000, accuracy: 61.244 %, elapse: 287.271ms


HBox(children=(IntProgress(value=0, description='Epoch 49', max=227, style=ProgressStyle(description_width='in…

[ Epoch 49 ]
[Training]  loss: 2.3470888432062367, ppl:  10.45509, accuracy: 79.760 %, elapse: 19267.039ms
[Validation]  loss: 2.578025202994036,  ppl:  13.17110, accuracy: 61.459 %, elapse: 287.232ms



### Restore the best checkpoint (to resume a previous training run)

In [None]:
# Restore the best validation checkpoint of *this* experiment into `model` and
# `optimizer`. Per the training log, a checkpoint is written whenever
# validation accuracy improves.
# Previously this cell hard-coded another run ("node_transformer_multi30k",
# "2019-05-29_1200"), whose architecture may not match the model built above.
# Override these two names explicitly to load a different run.
RESTORE_EXP_NAME = exp_name
RESTORE_UNIQUE_ID = unique_id

state = checkpoints.restore_best_checkpoint(
    RESTORE_EXP_NAME, RESTORE_UNIQUE_ID, "validation", model, optimizer)

print("accuracy", state["acc"])
print("loss", state["loss"])
model = model.to(device)