In [1]:
import sys
sys.path.append("../")

import numpy as np
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim

import torchdiffeq

from tensorboard_utils import Tensorboard
from tensorboard_utils import tensorboard_event_accumulator

import transformer.Constants as Constants
from transformer.Layers import EncoderLayer, DecoderLayer
from transformer.Modules import ScaledDotProductAttention
from transformer.Models import Decoder, get_attn_key_pad_mask, get_non_pad_mask, get_sinusoid_encoding_table
from transformer.SubLayers import PositionwiseFeedForward

import dataset

import model_process
import checkpoints
from node_transformer import NodeTransformer

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
#%matplotlib notebook  
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

print("Torch Version", torch.__version__)

%load_ext autoreload
%autoreload 2

Torch Version 1.1.0


In [2]:
# Seed the RNGs this run relies on so results are reproducible.
seed = 1
torch.manual_seed(seed)
np.random.seed(seed)

# Use the GPU when present, otherwise fall back to CPU so the notebook still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)

device cuda


In [3]:
import os

# Location of the preprocessed Multi30k dataset; set MULTI30K_DATA to point elsewhere
# instead of editing this user-specific default path.
DATA_PATH = os.environ.get(
    "MULTI30K_DATA", "/home/mandubian/datasets/multi30k/multi30k.atok.low.pt")
# NOTE: torch.load unpickles the file, which can execute arbitrary code — only load
# datasets from a trusted source.
data = torch.load(DATA_PATH)

In [4]:
# Longest token sequence recorded in the dataset settings; later passed to the model
# as `len_max_seq`.
data_settings = data['settings']
max_token_seq_len = data_settings.max_token_seq_len
print(max_token_seq_len)

52


In [5]:
BATCH_SIZE = 128
train_loader, val_loader = dataset.prepare_dataloaders(data, batch_size=BATCH_SIZE)

### Create an experiment with a name and a unique ID

In [6]:
# Experiment identity: used together to name the TensorBoard run directory and to
# locate checkpoints for this experiment.
exp_name, unique_id = "transformer_6_layers", "2019-06-20_2330"


### Create Model

In [7]:
# Ensure `model` is defined (and drops any previous model reference) before the
# model-construction cell checks whether an old model needs to be freed.
model = None

In [8]:
# NOTE(review): star import kept because it may register/extend the ODE solvers used
# via METHOD below (e.g. 'dopri5-ext'); replace with explicit names once confirmed.
from odeint_ext_adams import *

src_vocab_sz = train_loader.dataset.src_vocab_size
print("src_vocab_sz", src_vocab_sz)
tgt_vocab_sz = train_loader.dataset.tgt_vocab_size
print("tgt_vocab_sz", tgt_vocab_sz)

# Drop a model left over from a previous run of this cell before building a new one.
# `globals().get` keeps this working on a fresh kernel even if the placeholder cell
# was skipped, and avoids relying on nn.Module truthiness.
# NOTE: an existing `optimizer` still references the old parameters, so GPU memory is
# only released once the optimizer is recreated as well.
if globals().get("model") is not None:
    del model

# Experiment hyper-parameters.
RTOL = 0.01                 # ODE solver tolerance, passed as both rtol and atol
LR = 1e-4
N_HEAD = 8
N_LAYERS = 6
METHOD = 'dopri5-ext'       # ODE solver name passed to NodeTransformer
HAS_NODE_ENCODER = False
HAS_NODE_DECODER = False
HAS_SEPARATED_NODE_DECODER = False
ADD_TIME = False

# Run description handed to model_process.train as `checkpoint_desc`.
checkpoint_desc = {
    #"rtol":RTOL,
    "lr": LR,
    "n_layers": N_LAYERS,
    "n_head": N_HEAD,
    #"method":METHOD,
    "has_node_encoder": HAS_NODE_ENCODER,
    "has_node_decoder": HAS_NODE_DECODER,
    # Previously keyed "has_separated_node_encoder" although it records the
    # separated *decoder* flag.
    "has_separated_node_decoder": HAS_SEPARATED_NODE_DECODER,
    "add_time": ADD_TIME,
}

# Source and target embeddings are both sized to the larger vocabulary.
# NOTE(review): presumably required because NodeTransformer shares src/tgt embedding
# weights by default — confirm before using different sizes.
vocab_sz = max(src_vocab_sz, tgt_vocab_sz)

model = NodeTransformer(
    n_src_vocab=vocab_sz,
    n_tgt_vocab=vocab_sz,
    len_max_seq=max_token_seq_len,
    # Other knobs tried: emb_src_tgt_weight_sharing=False,
    # d_word_vec=256, d_model=256, d_inner=1024.
    n_layers=N_LAYERS,
    n_head=N_HEAD, method=METHOD, rtol=RTOL, atol=RTOL,
    has_node_encoder=HAS_NODE_ENCODER, has_node_decoder=HAS_NODE_DECODER,
    has_separated_node_decoder=HAS_SEPARATED_NODE_DECODER, add_time=ADD_TIME,
)

model = model.to(device)

src_vocab_sz 9795
tgt_vocab_sz 17989


### Create Tensorboard metrics logger

In [9]:
# TensorBoard events are written under this directory, one sub-folder per run.
TB_OUTPUT_DIR = "../runs"
tb = Tensorboard(exp_name, output_dir=TB_OUTPUT_DIR, unique_name=unique_id)

Writing TensorBoard events locally to ../runs/transformer_6_layers_2019-06-20_2330


### Create the Adam optimizer

In [10]:
# Adam with beta2 = 0.995 and a very small eps.
# Alternative that can be swapped in: optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
ADAM_BETAS = (0.9, 0.995)
ADAM_EPS = 1e-9
optimizer = optim.Adam(model.parameters(), lr=LR, betas=ADAM_BETAS, eps=ADAM_EPS)


### Train

In [11]:
# Integration times handed to the training loop: the interval [0, 1] discretized at
# its two end points, as a float32 CPU tensor.
timesteps = torch.linspace(0., 1., steps=2)

EPOCHS = 100
LOG_INTERVAL = 5

# Keyword options for the training loop.
# - To continue after restoring a checkpoint, also pass best_valid_accu=state["acc"].
# - To debug NaN/inf gradients, wrap the call in `torch.autograd.detect_anomaly()`.
train_options = dict(
    epochs=EPOCHS,
    tb=tb,
    log_interval=LOG_INTERVAL,
    start_epoch=0,
    checkpoint_desc=checkpoint_desc,
)

model_process.train(
    exp_name, unique_id, model,
    train_loader, val_loader, timesteps,
    optimizer, device,
    **train_options
)

Loaded model and timesteps to cuda


HBox(children=(IntProgress(value=0, description='Training', style=ProgressStyle(description_width='initial')),…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 0 ]
Adding group train to writers (dict_keys([]))
[Training]  loss: 7.2235593511588565, ppl:  1371.36154, accuracy: 16.599 %, elapse: 61566.379ms
Adding group eval to writers (dict_keys(['train']))
[Validation]  loss: 5.975024863551293,  ppl:  393.47788, accuracy: 27.932 %, elapse: 886.642ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 1', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 1 ]
[Training]  loss: 6.131624574610461, ppl:  460.18316, accuracy: 30.963 %, elapse: 61741.634ms
[Validation]  loss: 5.2839826120245,  ppl:  197.15350, accuracy: 33.047 %, elapse: 886.225ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 2', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 2 ]
[Training]  loss: 5.611737301748947, ppl:  273.61918, accuracy: 35.105 %, elapse: 61829.323ms
[Validation]  loss: 4.799503687417168,  ppl:  121.45013, accuracy: 36.256 %, elapse: 889.352ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 3', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 3 ]
[Training]  loss: 5.249176728696815, ppl:  190.40945, accuracy: 38.443 %, elapse: 61881.548ms
[Validation]  loss: 4.434336158571531,  ppl:  84.29615, accuracy: 39.802 %, elapse: 885.400ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 4', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 4 ]
[Training]  loss: 4.970103780322496, ppl:  144.04184, accuracy: 41.545 %, elapse: 61515.407ms
[Validation]  loss: 4.145467690815513,  ppl:  63.14715, accuracy: 42.919 %, elapse: 887.919ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 5', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 5 ]
[Training]  loss: 4.741854278772327, ppl:  114.64659, accuracy: 44.069 %, elapse: 61628.993ms
[Validation]  loss: 3.9454950589113653,  ppl:  51.70193, accuracy: 44.588 %, elapse: 901.415ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 6', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 6 ]
[Training]  loss: 4.5556971131717985, ppl:  95.17308, accuracy: 46.124 %, elapse: 61227.118ms
[Validation]  loss: 3.804431938511892,  ppl:  44.89974, accuracy: 46.056 %, elapse: 888.103ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 7', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 7 ]
[Training]  loss: 4.403000811119757, ppl:  81.69565, accuracy: 47.906 %, elapse: 61592.751ms
[Validation]  loss: 3.6607101880954938,  ppl:  38.88895, accuracy: 47.181 %, elapse: 892.708ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 8', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 8 ]
[Training]  loss: 4.269500272654117, ppl:  71.48590, accuracy: 49.554 %, elapse: 61653.779ms
[Validation]  loss: 3.5257453888396375,  ppl:  33.97909, accuracy: 48.241 %, elapse: 899.850ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 9', max=227, style=ProgressStyle(description_width='ini…

[ Epoch 9 ]
[Training]  loss: 4.153875382076213, ppl:  63.68031, accuracy: 51.029 %, elapse: 61861.103ms
[Validation]  loss: 3.4440650549152876,  ppl:  31.31399, accuracy: 48.986 %, elapse: 888.685ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 10', max=227, style=ProgressStyle(description_width='in…

[ Epoch 10 ]
[Training]  loss: 4.049625650546544, ppl:  57.37597, accuracy: 52.437 %, elapse: 61895.532ms
[Validation]  loss: 3.3790054523515294,  ppl:  29.34158, accuracy: 49.953 %, elapse: 887.341ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 11', max=227, style=ProgressStyle(description_width='in…

[ Epoch 11 ]
[Training]  loss: 3.954736531867953, ppl:  52.18194, accuracy: 53.670 %, elapse: 62236.721ms
[Validation]  loss: 3.3172995258995095,  ppl:  27.58576, accuracy: 50.340 %, elapse: 909.275ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 12', max=227, style=ProgressStyle(description_width='in…

[ Epoch 12 ]
[Training]  loss: 3.8644316215776335, ppl:  47.67617, accuracy: 54.984 %, elapse: 61240.938ms
[Validation]  loss: 3.268442962425675,  ppl:  26.27040, accuracy: 51.035 %, elapse: 902.026ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 13', max=227, style=ProgressStyle(description_width='in…

[ Epoch 13 ]
[Training]  loss: 3.7836560495446614, ppl:  43.97653, accuracy: 56.116 %, elapse: 61776.531ms
[Validation]  loss: 3.221307188966348,  ppl:  25.06086, accuracy: 51.408 %, elapse: 886.904ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 14', max=227, style=ProgressStyle(description_width='in…

[ Epoch 14 ]
[Training]  loss: 3.7020620794527606, ppl:  40.53080, accuracy: 57.324 %, elapse: 61467.325ms
[Validation]  loss: 3.178625622078498,  ppl:  24.01373, accuracy: 51.873 %, elapse: 889.173ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 15', max=227, style=ProgressStyle(description_width='in…

[ Epoch 15 ]
[Training]  loss: 3.626175195168419, ppl:  37.56885, accuracy: 58.479 %, elapse: 61280.264ms
[Validation]  loss: 3.181361767049932,  ppl:  24.07952, accuracy: 52.031 %, elapse: 889.200ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 16', max=227, style=ProgressStyle(description_width='in…

[ Epoch 16 ]
[Training]  loss: 3.5552920700777126, ppl:  34.99804, accuracy: 59.516 %, elapse: 62018.015ms
[Validation]  loss: 3.129823591914267,  ppl:  22.86994, accuracy: 52.583 %, elapse: 886.944ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 17', max=227, style=ProgressStyle(description_width='in…

[ Epoch 17 ]
[Training]  loss: 3.485296185512063, ppl:  32.63209, accuracy: 60.653 %, elapse: 61636.351ms
[Validation]  loss: 3.1065540643021525,  ppl:  22.34392, accuracy: 52.797 %, elapse: 888.376ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 18', max=227, style=ProgressStyle(description_width='in…

[ Epoch 18 ]
[Training]  loss: 3.417452145880192, ppl:  30.49163, accuracy: 61.772 %, elapse: 61733.647ms
[Validation]  loss: 3.0798077321307042,  ppl:  21.75422, accuracy: 53.299 %, elapse: 888.237ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 19', max=227, style=ProgressStyle(description_width='in…

[ Epoch 19 ]
[Training]  loss: 3.353798248410696, ppl:  28.61120, accuracy: 62.747 %, elapse: 61543.150ms
[Validation]  loss: 3.0865037509402535,  ppl:  21.90037, accuracy: 53.414 %, elapse: 887.147ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 20', max=227, style=ProgressStyle(description_width='in…

[ Epoch 20 ]
[Training]  loss: 3.291516870416551, ppl:  26.88361, accuracy: 63.770 %, elapse: 61766.417ms
[Validation]  loss: 3.057862569903378,  ppl:  21.28202, accuracy: 53.156 %, elapse: 901.946ms


HBox(children=(IntProgress(value=0, description='Epoch 21', max=227, style=ProgressStyle(description_width='in…

[ Epoch 21 ]
[Training]  loss: 3.2308629888999056, ppl:  25.30148, accuracy: 64.776 %, elapse: 61564.438ms
[Validation]  loss: 3.025019910429651,  ppl:  20.59441, accuracy: 53.521 %, elapse: 907.845ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 22', max=227, style=ProgressStyle(description_width='in…

[ Epoch 22 ]
[Training]  loss: 3.171966132266974, ppl:  23.85434, accuracy: 65.763 %, elapse: 61511.059ms
[Validation]  loss: 3.057421511211405,  ppl:  21.27264, accuracy: 53.679 %, elapse: 900.449ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 23', max=227, style=ProgressStyle(description_width='in…

[ Epoch 23 ]
[Training]  loss: 3.1163941959925805, ppl:  22.56487, accuracy: 66.723 %, elapse: 61124.784ms
[Validation]  loss: 3.0611816860985743,  ppl:  21.35277, accuracy: 53.693 %, elapse: 886.142ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 24', max=227, style=ProgressStyle(description_width='in…

[ Epoch 24 ]
[Training]  loss: 3.0595907375453293, ppl:  21.31883, accuracy: 67.700 %, elapse: 61677.086ms
[Validation]  loss: 3.033094198299484,  ppl:  20.76137, accuracy: 54.366 %, elapse: 887.021ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 25', max=227, style=ProgressStyle(description_width='in…

[ Epoch 25 ]
[Training]  loss: 3.0079640168011736, ppl:  20.24614, accuracy: 68.573 %, elapse: 61956.717ms
[Validation]  loss: 3.0568231139520203,  ppl:  21.25991, accuracy: 54.144 %, elapse: 887.208ms


HBox(children=(IntProgress(value=0, description='Epoch 26', max=227, style=ProgressStyle(description_width='in…

[ Epoch 26 ]
[Training]  loss: 2.9573720794054266, ppl:  19.24732, accuracy: 69.459 %, elapse: 61530.807ms
[Validation]  loss: 3.051384579393223,  ppl:  21.14460, accuracy: 53.765 %, elapse: 887.437ms


HBox(children=(IntProgress(value=0, description='Epoch 27', max=227, style=ProgressStyle(description_width='in…

[ Epoch 27 ]
[Training]  loss: 2.907264680751677, ppl:  18.30666, accuracy: 70.346 %, elapse: 61767.996ms
[Validation]  loss: 3.023755570163067,  ppl:  20.56839, accuracy: 54.144 %, elapse: 888.344ms


HBox(children=(IntProgress(value=0, description='Epoch 28', max=227, style=ProgressStyle(description_width='in…

[ Epoch 28 ]
[Training]  loss: 2.8561339292737955, ppl:  17.39415, accuracy: 71.267 %, elapse: 61525.572ms
[Validation]  loss: 3.040335550205065,  ppl:  20.91226, accuracy: 54.266 %, elapse: 888.879ms


HBox(children=(IntProgress(value=0, description='Epoch 29', max=227, style=ProgressStyle(description_width='in…

[ Epoch 29 ]
[Training]  loss: 2.811232707694427, ppl:  16.63041, accuracy: 72.050 %, elapse: 61771.679ms
[Validation]  loss: 3.0446501878828176,  ppl:  21.00268, accuracy: 54.417 %, elapse: 886.931ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 30', max=227, style=ProgressStyle(description_width='in…

[ Epoch 30 ]
[Training]  loss: 2.766295042615599, ppl:  15.89962, accuracy: 72.803 %, elapse: 61931.606ms
[Validation]  loss: 3.048543635322641,  ppl:  21.08462, accuracy: 54.660 %, elapse: 903.055ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 31', max=227, style=ProgressStyle(description_width='in…

[ Epoch 31 ]
[Training]  loss: 2.7213144256506467, ppl:  15.20029, accuracy: 73.659 %, elapse: 61374.191ms
[Validation]  loss: 3.078574803610126,  ppl:  21.72741, accuracy: 54.696 %, elapse: 894.747ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 32', max=227, style=ProgressStyle(description_width='in…

[ Epoch 32 ]
[Training]  loss: 2.676646051568525, ppl:  14.53626, accuracy: 74.516 %, elapse: 61708.758ms
[Validation]  loss: 3.091830735323089,  ppl:  22.01735, accuracy: 54.273 %, elapse: 889.571ms


HBox(children=(IntProgress(value=0, description='Epoch 33', max=227, style=ProgressStyle(description_width='in…

[ Epoch 33 ]
[Training]  loss: 2.635380696009296, ppl:  13.94862, accuracy: 75.226 %, elapse: 61801.278ms
[Validation]  loss: 3.065088425813991,  ppl:  21.43636, accuracy: 54.667 %, elapse: 890.943ms


HBox(children=(IntProgress(value=0, description='Epoch 34', max=227, style=ProgressStyle(description_width='in…

[ Epoch 34 ]
[Training]  loss: 2.5955309207385153, ppl:  13.40370, accuracy: 75.868 %, elapse: 61728.124ms
[Validation]  loss: 3.075884901765886,  ppl:  21.66905, accuracy: 54.359 %, elapse: 905.097ms


HBox(children=(IntProgress(value=0, description='Epoch 35', max=227, style=ProgressStyle(description_width='in…

[ Epoch 35 ]
[Training]  loss: 2.555540369113705, ppl:  12.87826, accuracy: 76.631 %, elapse: 61893.173ms
[Validation]  loss: 3.069434929973494,  ppl:  21.52973, accuracy: 54.782 %, elapse: 894.484ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 36', max=227, style=ProgressStyle(description_width='in…

[ Epoch 36 ]
[Training]  loss: 2.5160290404708068, ppl:  12.37934, accuracy: 77.334 %, elapse: 61692.509ms
[Validation]  loss: 3.0907347522207895,  ppl:  21.99323, accuracy: 55.147 %, elapse: 888.752ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 37', max=227, style=ProgressStyle(description_width='in…

[ Epoch 37 ]
[Training]  loss: 2.480366587020254, ppl:  11.94564, accuracy: 77.976 %, elapse: 61651.067ms
[Validation]  loss: 3.1059780190535675,  ppl:  22.33105, accuracy: 54.653 %, elapse: 885.491ms


HBox(children=(IntProgress(value=0, description='Epoch 38', max=227, style=ProgressStyle(description_width='in…

[ Epoch 38 ]
[Training]  loss: 2.4440151729382067, ppl:  11.51920, accuracy: 78.676 %, elapse: 61409.153ms
[Validation]  loss: 3.130931957820671,  ppl:  22.89531, accuracy: 54.603 %, elapse: 890.444ms


HBox(children=(IntProgress(value=0, description='Epoch 39', max=227, style=ProgressStyle(description_width='in…

[ Epoch 39 ]
[Training]  loss: 2.4066043645102333, ppl:  11.09622, accuracy: 79.378 %, elapse: 61310.510ms
[Validation]  loss: 3.1300716677524356,  ppl:  22.87562, accuracy: 54.653 %, elapse: 887.732ms


HBox(children=(IntProgress(value=0, description='Epoch 40', max=227, style=ProgressStyle(description_width='in…

[ Epoch 40 ]
[Training]  loss: 2.37166661335434, ppl:  10.71524, accuracy: 79.958 %, elapse: 61641.054ms
[Validation]  loss: 3.128873928782506,  ppl:  22.84824, accuracy: 54.904 %, elapse: 903.072ms


HBox(children=(IntProgress(value=0, description='Epoch 41', max=227, style=ProgressStyle(description_width='in…

[ Epoch 41 ]
[Training]  loss: 2.339608342409426, ppl:  10.37717, accuracy: 80.619 %, elapse: 61573.939ms
[Validation]  loss: 3.126169720252167,  ppl:  22.78653, accuracy: 55.298 %, elapse: 887.470ms
Checkpointing Validation Model...


HBox(children=(IntProgress(value=0, description='Epoch 42', max=227, style=ProgressStyle(description_width='in…

[ Epoch 42 ]
[Training]  loss: 2.3057704804396466, ppl:  10.03190, accuracy: 81.219 %, elapse: 61534.669ms
[Validation]  loss: 3.154124844970091,  ppl:  23.43252, accuracy: 54.925 %, elapse: 900.562ms


HBox(children=(IntProgress(value=0, description='Epoch 43', max=227, style=ProgressStyle(description_width='in…

[ Epoch 43 ]
[Training]  loss: 2.2770802726710504, ppl:  9.74818, accuracy: 81.706 %, elapse: 61848.229ms
[Validation]  loss: 3.152339972195358,  ppl:  23.39073, accuracy: 55.033 %, elapse: 885.443ms


HBox(children=(IntProgress(value=0, description='Epoch 44', max=227, style=ProgressStyle(description_width='in…

[ Epoch 44 ]


KeyboardInterrupt: 

In [19]:
# Tighten the encoder's solver tolerances to 1e-3 (both relative and absolute).
# NOTE(review): only meaningful when the encoder is an ODE block (HAS_NODE_ENCODER).
inner_encoder = model.encoder.encoder
inner_encoder.rtol = inner_encoder.atol = 0.001
# Decoder equivalent, currently disabled:
#model.decoder.decoder.rtol = 0.001
#model.decoder.decoder.atol = 0.001

### Restore the best checkpoint (to resume a previous training run)

In [11]:
# Reload model and optimizer state from this experiment's best validation checkpoint.
# NOTE(review): the saved output below was produced by a different experiment
# (different checkpoint file name) and is stale — re-run to refresh it.
state = checkpoints.restore_best_checkpoint(
    exp_name, unique_id, "validation", model, optimizer)

for label, key in (("accuracy", "acc"), ("loss", "loss")):
    print(label, state[key])
model = model.to(device)

Extracting state from checkpoints/node_transformer_separated_dopri5_multi30k_encoder_only_add_time_2019-06-20_1230_validation_best.pth
Loading model state_dict from state found in checkpoints/node_transformer_separated_dopri5_multi30k_encoder_only_add_time_2019-06-20_1230_validation_best.pth
Loading optimizer state_dict from state found in checkpoints/node_transformer_separated_dopri5_multi30k_encoder_only_add_time_2019-06-20_1230_validation_best.pth
accuracy 0.5886524822695035
loss 2.558568739534082
