In [1]:
import torch
import torch.nn.functional as F

from modules.model import TransformerLSDecoder

decoder = TransformerLSDecoder(vocab_size=10, d_model=256, n_head=8, d_inner=16, n_layer=12, dropout=0.2, emb_dropout=0.2, chunk_rank=1, chunk_size=16, mem_len=16, window_len=4, grad_chk=False, pre_ln=True, use_gelu=True, use_bias=False, clamp_len=-1, cpos_clamp_len=-1, probing=False)


In [2]:
_mems = None
num_epochs = 10
optimizer = torch.optim.Adam(decoder.parameters(), lr=0.001)


In [6]:
src_tokens = torch.randint(0, 10, (1, 16))
trg_tokens = src_tokens.roll(shifts=-1, dims=1)
trg_tokens

tensor([[2, 4, 3, 5, 5, 4, 3, 3, 8, 4, 6, 8, 4, 3, 7, 7]])

In [7]:
from nltk.metrics import edit_distance

for epoch in range(1, num_epochs+1):
    bsz = src_tokens.size(0)
    mems = _mems

    if mems is None:
        # first time init
        mems = decoder.init_hid_cache(bsz)

    output, mems, _ = decoder(
        x=src_tokens, 
        h_cache=mems,
    )

    _mems = mems
    loss = F.cross_entropy(output.view(-1, output.size(-1)), trg_tokens.view(-1))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    if (epoch+1) % 1 == 0:
        logits = torch.argmax(output, dim=-1)
        pred_str = ''.join(map(str, trg_tokens.squeeze().tolist()))
        trg_str = ''.join(map(str, trg_tokens.squeeze().tolist()))
        edit_dist = 1 - edit_distance(pred_str, trg_str) / len(trg_str.split())
        print("Prediction:", pred_str)
        print("Target:", trg_str)
        print(f"Epoch {epoch}, Loss: {loss.item()}, Edit Distance: {edit_dist}")
        print("=============")


Prediction: 2435543384684377
Target: 2435543384684377
Epoch 1, Loss: 3.156334400177002, Edit Distance: 1.0
Prediction: 2435543384684377
Target: 2435543384684377
Epoch 2, Loss: 3.019670248031616, Edit Distance: 1.0
Prediction: 2435543384684377
Target: 2435543384684377
Epoch 3, Loss: 2.6530404090881348, Edit Distance: 1.0
Prediction: 2435543384684377
Target: 2435543384684377
Epoch 4, Loss: 2.5919203758239746, Edit Distance: 1.0
Prediction: 2435543384684377
Target: 2435543384684377
Epoch 5, Loss: 2.25030779838562, Edit Distance: 1.0
Prediction: 2435543384684377
Target: 2435543384684377
Epoch 6, Loss: 2.0488860607147217, Edit Distance: 1.0
Prediction: 2435543384684377
Target: 2435543384684377
Epoch 7, Loss: 2.147839069366455, Edit Distance: 1.0
Prediction: 2435543384684377
Target: 2435543384684377
Epoch 8, Loss: 2.0518670082092285, Edit Distance: 1.0
Prediction: 2435543384684377
Target: 2435543384684377
Epoch 9, Loss: 1.9159176349639893, Edit Distance: 1.0
Prediction: 2435543384684377
Targ