## Annotated Transformer

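* MXNet/Gluon port of "The Annotated Transformer" (Harvard NLP), trained here on a toy task of copying integer sequences: http://nlp.seas.harvard.edu/2018/04/03/attention.html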

In [1]:
import sys, os

sys.path.append("../python/")
from transformer import *

# Run on the first GPU when MXNet can see one; otherwise fall back to the CPU so the
# notebook still runs top to bottom (Restart & Run All) on a machine without CUDA.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

### Synthetic data

In [2]:
# Copy-task data configuration.
# NOTE(review): argument meanings are inferred from names — V presumably the vocabulary
# size (sampled tokens in the printed batch below fall in 1..10), batch the batch size,
# n_batch the number of batches the generator yields, in_seq_len the sequence length.
# Confirm against transformer.data_gen.
V, batch, n_batch, in_seq_len = 11, 30, 20, 10
dat = data_gen(V, batch, n_batch, in_seq_len, ctx = ctx)

### Architecture

In [3]:
# Task: copy 10 input integers
out_seq_len = 10
dropout = .1
# NOTE(review): `data` is not used by any visible cell (inspection uses `dat`, training
# builds a new data_gen each epoch); kept in case other code refers to it.
data = data_gen(V, batch, n_batch, in_seq_len, ctx = ctx)
# Pass the named `dropout` constant rather than a duplicated literal, so changing it
# above actually changes the model.
# NOTE(review): positional args presumably are (src_vocab, tgt_vocab, src_len, tgt_len) — confirm.
model = make_model(V, V, in_seq_len, out_seq_len, N = 2, dropout = dropout, d_model = 128, ctx = ctx)
model.collect_params().initialize(mx.init.Xavier(), ctx = ctx)
# Adam with beta1=0.9, beta2=0.98, eps=1e-9 as in "Attention Is All You Need";
# the learning rate here is a constant 1e-4 (no warmup schedule in this cell).
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': 1e-4, 'beta1': 0.9, 'beta2': 0.98 , 'epsilon': 1e-9})
# from_logits=False: KLDivLoss applies log_softmax to the predictions itself, so the
# model output is expected to be unnormalized scores.
loss = gluon.loss.KLDivLoss(from_logits = False)

In [4]:
# Take the first batch from the generator so its fields can be inspected below.
# next() consumes the same first element the old enumerate/break loop did. It raises
# StopIteration immediately if `dat` is empty, instead of leaving a stale or undefined `dd`.
dd = next(iter(dat))

In [5]:
# Show the first two examples of each batch field. In the output below, `trg` is `src`
# without its last token and `trg_y` is `src` without its first token (shifted by one).
for field in ('src', 'trg', 'trg_y'):
    print(field + ' :')
    print('{}'.format(getattr(dd, field)[:2].asnumpy()))

src :
[[ 1.  9.  4.  1.  9.  3.  6.  4.  2. 10.]
 [ 1.  9.  9. 10.  3.  6.  3.  9.  9.  6.]]
trg :
[[ 1.  9.  4.  1.  9.  3.  6.  4.  2.]
 [ 1.  9.  9. 10.  3.  6.  3.  9.  9.]]
trg_y :
[[ 9.  4.  1.  9.  3.  6.  4.  2. 10.]
 [ 9.  9. 10.  3.  6.  3.  9.  9.  6.]]


In [15]:
# Train for 20 epochs; each epoch gets a newly constructed data_gen generator.
for epoch in range(20):
    epoch_data = data_gen(V, batch, n_batch, in_seq_len, ctx = ctx)
    run_epoch(epoch, epoch_data, model, trainer, loss, ctx = ctx)

2019-05-23 00:12:48,051 - transformer - INFO - Epoch Step: 0 Loss: 0.162606 Tokens per Sec: 2094.781793
2019-05-23 00:12:49,477 - transformer - INFO - Epoch Step: 1 Loss: 0.160456 Tokens per Sec: 3997.028441
2019-05-23 00:12:50,905 - transformer - INFO - Epoch Step: 2 Loss: 0.161431 Tokens per Sec: 2328.966069
2019-05-23 00:12:52,236 - transformer - INFO - Epoch Step: 3 Loss: 0.158512 Tokens per Sec: 4158.784896
2019-05-23 00:12:53,450 - transformer - INFO - Epoch Step: 4 Loss: 0.154697 Tokens per Sec: 4445.752466
2019-05-23 00:12:54,689 - transformer - INFO - Epoch Step: 5 Loss: 0.153123 Tokens per Sec: 4031.247504
2019-05-23 00:12:56,059 - transformer - INFO - Epoch Step: 6 Loss: 0.153150 Tokens per Sec: 4701.958821
2019-05-23 00:12:57,404 - transformer - INFO - Epoch Step: 7 Loss: 0.151267 Tokens per Sec: 3911.732066
2019-05-23 00:12:58,774 - transformer - INFO - Epoch Step: 8 Loss: 0.151548 Tokens per Sec: 3980.590518
2019-05-23 00:13:00,166 - transformer - INFO - Epoch Step: 9 Los

In [17]:
def greedy_decode(model, src, src_mask, max_len, start_symbol, ctx = None):
    """Greedily decode a target sequence for ``src``, one token at a time.

    Parameters
    ----------
    model : object exposing ``encode(src, src_mask)`` and
        ``decode(memory, src_mask, ys, tgt_mask)``. ``decode`` is assumed to return
        scores over the target vocabulary on axis 2 (TODO confirm against make_model).
    src : NDArray of source token ids; assumed (batch, src_len) — TODO confirm.
    src_mask : NDArray mask passed unchanged to ``encode`` and ``decode``.
    max_len : int, number of tokens to generate after the start symbol.
    start_symbol : token id used to seed every decoded sequence.
    ctx : mx.Context, optional
        Device for the decoder input and target mask. Defaults to ``src.context``,
        so the function no longer depends on a notebook-global ``ctx``.

    Returns
    -------
    NDArray with ``src.shape[0]`` rows and ``max_len + 1`` columns: the start symbol
    followed by the ``max_len`` greedily chosen token ids (float-valued).
    """
    if ctx is None:
        ctx = src.context
    memory = model.encode(src, src_mask)
    # One start symbol per row, so batches larger than 1 are decoded too.
    ys = nd.full((src.shape[0], 1), start_symbol, ctx = ctx)
    for _ in range(max_len):
        # Presumably a causal (subsequent-position) mask over the current prefix length.
        tgt_mask = subsequent_mask(ys.shape[1]).as_in_context(ctx)
        out = model.decode(memory, src_mask, ys, tgt_mask)
        # Pick the highest-scoring token at the last position and append it.
        next_word = nd.argmax(out, axis = 2)
        ys = nd.concat(ys, next_word[:, -1].expand_dims(axis = 1), dim = 1)
    return ys

# Decode one hand-written sequence with the trained model.
demo_tokens = [1, 5, 2, 3, 2, 5, 7, 8, 9, 10]
src = nd.array([demo_tokens], ctx = ctx)
print('src = {}'.format(src))
src_mask = nd.ones_like(src, ctx = ctx)

# predict_mode runs the forward pass in inference mode (e.g. dropout disabled).
with autograd.predict_mode():
    # The start symbol plus 9 generated tokens gives 10 tokens, the length of src.
    res = greedy_decode(model, src, src_mask, max_len = 9, start_symbol = 1)

print('tgt = {}'.format(res))

src = 
[[ 1.  5.  2.  3.  2.  5.  7.  8.  9. 10.]]
<NDArray 1x10 @gpu(0)>
tgt = 
[[ 1.  1.  2.  3.  2.  5.  7.  8.  9. 10.]]
<NDArray 1x10 @gpu(0)>
