# Experiment 1

In [None]:
from transformer import Transformer
from dataset.scan_dataset import ScanDataset, ExperimentType
import torch

# Seed torch's global RNG so that weight initialisation (and any later
# torch-driven shuffling / dropout) is repeatable under
# "Restart Kernel -> Run All", as far as the backend allows.
SEED = 42
torch.manual_seed(SEED)

# Pick the compute device: CUDA if available, otherwise Apple MPS, otherwise CPU.
# To force CPU execution (e.g. while debugging), override `device` right after
# this block.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Sequence-length budget. `max_len` is passed as the datasets' input length;
# the output length is `max_len + 20`. It is named once here because the same
# value must be given to both datasets (out_seq_len) and to the model (max_len).
max_len = 128
max_out_len = max_len + 20

train_dataset = ScanDataset(
    ExperimentType.E1_TRAIN,
    in_seq_len=max_len,
    out_seq_len=max_out_len,
    device=device,
)
# The test split reuses the vocabulary object built by the training split
# (presumably so token indices agree between the two -- verify in ScanDataset).
test_dataset = ScanDataset(
    ExperimentType.E1_TEST,
    vocab=train_dataset.vocab,
    in_seq_len=max_len,
    out_seq_len=max_out_len,
    device=device,
)

# A single vocabulary is used for both source and target, so both vocab sizes
# and both padding indices come from `train_dataset.vocab`.
model = Transformer(
    src_vocab_size=len(train_dataset.vocab),
    tgt_vocab_size=len(train_dataset.vocab),
    src_pad_idx=train_dataset.vocab.pad_idx,
    tgt_pad_idx=train_dataset.vocab.pad_idx,
    dropout=0.05,
    emb_dim=128,
    num_layers=1,
    num_heads=8,
    forward_dim=512,
    max_len=max_out_len,
)

In [2]:
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.nn import utils
import torch

from tqdm import tqdm

# Move the model to the target device *before* constructing the optimizer, so
# the optimizer is built over parameters that already live where they will be
# used (the ordering recommended by the torch.optim documentation).
model.to(device)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

# Upper bound on the total gradient norm, applied before every optimizer step.
grad_clip = 1.0
# Targets equal to the vocabulary's padding index are ignored by the loss.
criterion = CrossEntropyLoss(ignore_index=train_dataset.vocab.pad_idx)
optimizer = Adam(
    model.parameters(),
    lr=7e-4,
    weight_decay=0.00001,
)

# Explicitly switch to training mode. A fresh model is already in this mode,
# but if this cell is re-run after something has called model.eval()
# (e.g. an evaluation helper), dropout would otherwise stay disabled.
model.train()

for epoch in range(70):
    # Per-epoch list of batch losses; used for the running mean printed below.
    losses = []
    for step, batch in enumerate(tqdm(train_loader)):
        inputs, decoder_inputs, target_label_indices = batch

        optimizer.zero_grad()
        out = model(inputs, decoder_inputs)
        # CrossEntropyLoss expects the class dimension at index 1, so the last
        # two axes are swapped (assumes `out` is (batch, seq, vocab) -- TODO confirm).
        loss = criterion(out.permute(0, 2, 1), target_label_indices)
        loss.backward()
        utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        losses.append(loss.item())
        # Every 100 steps, report the mean loss over the epoch so far
        # (not the loss of the current batch alone).
        if step % 100 == 0:
            print(f"Epoch {epoch} Loss: {sum(losses) / len(losses)}")

  6%|▌         | 16/262 [00:00<00:03, 70.25it/s]

Epoch 0 Loss: 3.1776909828186035


 48%|████▊     | 127/262 [00:00<00:00, 149.98it/s]

Epoch 0 Loss: 0.9928240339354714


 85%|████████▌ | 223/262 [00:01<00:00, 156.77it/s]

Epoch 0 Loss: 0.7550483783382681


100%|██████████| 262/262 [00:01<00:00, 141.48it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.96it/s]

Epoch 1 Loss: 0.2918003499507904


 49%|████▉     | 128/262 [00:00<00:00, 158.20it/s]

Epoch 1 Loss: 0.24642197492689188


 85%|████████▌ | 224/262 [00:01<00:00, 158.33it/s]

Epoch 1 Loss: 0.20120967135055742


100%|██████████| 262/262 [00:01<00:00, 158.51it/s]
  6%|▋         | 17/262 [00:00<00:01, 160.17it/s]

Epoch 2 Loss: 0.11123456060886383


 50%|█████     | 131/262 [00:00<00:00, 158.38it/s]

Epoch 2 Loss: 0.0942174742127409


 87%|████████▋ | 227/262 [00:01<00:00, 156.19it/s]

Epoch 2 Loss: 0.0819613243550507


100%|██████████| 262/262 [00:01<00:00, 157.71it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.10it/s]

Epoch 3 Loss: 0.0511094368994236


 49%|████▉     | 129/262 [00:00<00:00, 159.29it/s]

Epoch 3 Loss: 0.05033732878763487


 86%|████████▌ | 225/262 [00:01<00:00, 156.19it/s]

Epoch 3 Loss: 0.047814774637420975


100%|██████████| 262/262 [00:01<00:00, 157.73it/s]
 12%|█▏        | 32/262 [00:00<00:01, 156.61it/s]

Epoch 4 Loss: 0.03816741704940796


 49%|████▉     | 128/262 [00:00<00:00, 157.78it/s]

Epoch 4 Loss: 0.036526938107344184


 85%|████████▌ | 224/262 [00:01<00:00, 153.12it/s]

Epoch 4 Loss: 0.033386934354011695


100%|██████████| 262/262 [00:01<00:00, 155.26it/s]
 12%|█▏        | 32/262 [00:00<00:01, 157.82it/s]

Epoch 5 Loss: 0.029422592371702194


 49%|████▉     | 128/262 [00:00<00:00, 151.46it/s]

Epoch 5 Loss: 0.025381847423049483


 87%|████████▋ | 229/262 [00:01<00:00, 161.18it/s]

Epoch 5 Loss: 0.02456951811353662


100%|██████████| 262/262 [00:01<00:00, 157.35it/s]
  6%|▋         | 17/262 [00:00<00:01, 160.29it/s]

Epoch 6 Loss: 0.03295540064573288


 45%|████▌     | 119/262 [00:00<00:00, 161.37it/s]

Epoch 6 Loss: 0.021659116140834177


 84%|████████▍ | 221/262 [00:01<00:00, 160.29it/s]

Epoch 6 Loss: 0.020194289028941104


100%|██████████| 262/262 [00:01<00:00, 161.19it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.06it/s]

Epoch 7 Loss: 0.017161671072244644


 45%|████▍     | 117/262 [00:00<00:00, 157.38it/s]

Epoch 7 Loss: 0.018374166160932567


 89%|████████▉ | 233/262 [00:01<00:00, 159.79it/s]

Epoch 7 Loss: 0.017439411531448068


100%|██████████| 262/262 [00:01<00:00, 158.91it/s]
 12%|█▏        | 32/262 [00:00<00:01, 159.45it/s]

Epoch 8 Loss: 0.013162399642169476


 45%|████▍     | 117/262 [00:00<00:00, 161.12it/s]

Epoch 8 Loss: 0.013465047913077887


 84%|████████▎ | 219/262 [00:01<00:00, 160.36it/s]

Epoch 8 Loss: 0.013874198162965291


100%|██████████| 262/262 [00:01<00:00, 160.41it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.60it/s]

Epoch 9 Loss: 0.011165250092744827


 45%|████▌     | 118/262 [00:00<00:00, 159.94it/s]

Epoch 9 Loss: 0.012741375587751517


 84%|████████▍ | 220/262 [00:01<00:00, 161.90it/s]

Epoch 9 Loss: 0.01269314523012521


100%|██████████| 262/262 [00:01<00:00, 160.70it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.30it/s]

Epoch 10 Loss: 0.009522500447928905


 45%|████▌     | 119/262 [00:00<00:00, 160.35it/s]

Epoch 10 Loss: 0.010640291317301516


 84%|████████▍ | 221/262 [00:01<00:00, 161.97it/s]

Epoch 10 Loss: 0.010927883101933038


100%|██████████| 262/262 [00:01<00:00, 161.66it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.61it/s]

Epoch 11 Loss: 0.011910147033631802


 45%|████▌     | 119/262 [00:00<00:00, 161.64it/s]

Epoch 11 Loss: 0.008972342162955515


 84%|████████▍ | 220/262 [00:01<00:00, 157.92it/s]

Epoch 11 Loss: 0.010079883239394174


100%|██████████| 262/262 [00:01<00:00, 160.72it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.50it/s]

Epoch 12 Loss: 0.005803688894957304


 45%|████▌     | 118/262 [00:00<00:00, 159.24it/s]

Epoch 12 Loss: 0.009557264951509561


 84%|████████▍ | 220/262 [00:01<00:00, 161.52it/s]

Epoch 12 Loss: 0.00939211741481347


100%|██████████| 262/262 [00:01<00:00, 160.80it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.73it/s]

Epoch 13 Loss: 0.013139878399670124


 45%|████▌     | 118/262 [00:00<00:00, 159.61it/s]

Epoch 13 Loss: 0.008997436042294128


 84%|████████▎ | 219/262 [00:01<00:00, 159.09it/s]

Epoch 13 Loss: 0.008227958898778209


100%|██████████| 262/262 [00:01<00:00, 159.67it/s]
 13%|█▎        | 33/262 [00:00<00:01, 160.43it/s]

Epoch 14 Loss: 0.009270712733268738


 45%|████▌     | 118/262 [00:00<00:00, 162.03it/s]

Epoch 14 Loss: 0.008340563623944648


 84%|████████▎ | 219/262 [00:01<00:00, 160.33it/s]

Epoch 14 Loss: 0.008289134627977266


100%|██████████| 262/262 [00:01<00:00, 160.98it/s]
 12%|█▏        | 32/262 [00:00<00:01, 157.70it/s]

Epoch 15 Loss: 0.01944763958454132


 45%|████▍     | 117/262 [00:00<00:00, 159.76it/s]

Epoch 15 Loss: 0.009149754212449977


 88%|████████▊ | 230/262 [00:01<00:00, 159.46it/s]

Epoch 15 Loss: 0.008342175718410802


100%|██████████| 262/262 [00:01<00:00, 159.81it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.26it/s]

Epoch 16 Loss: 0.007401658222079277


 45%|████▌     | 119/262 [00:00<00:00, 160.83it/s]

Epoch 16 Loss: 0.006477800446866763


 84%|████████▍ | 221/262 [00:01<00:00, 161.57it/s]

Epoch 16 Loss: 0.00665150652755404


100%|██████████| 262/262 [00:01<00:00, 161.00it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.57it/s]

Epoch 17 Loss: 0.004865958355367184


 50%|█████     | 131/262 [00:00<00:00, 156.67it/s]

Epoch 17 Loss: 0.0075306091088764735


 89%|████████▉ | 233/262 [00:01<00:00, 160.47it/s]

Epoch 17 Loss: 0.008054935633997196


100%|██████████| 262/262 [00:01<00:00, 159.56it/s]
  6%|▌         | 16/262 [00:00<00:01, 153.23it/s]

Epoch 18 Loss: 0.006317215971648693


 49%|████▉     | 129/262 [00:00<00:00, 155.68it/s]

Epoch 18 Loss: 0.006534178010510779


 88%|████████▊ | 230/262 [00:01<00:00, 161.38it/s]

Epoch 18 Loss: 0.007173433126962804


100%|██████████| 262/262 [00:01<00:00, 158.82it/s]
  6%|▋         | 17/262 [00:00<00:01, 160.86it/s]

Epoch 19 Loss: 0.004148106090724468


 45%|████▍     | 117/262 [00:00<00:00, 157.47it/s]

Epoch 19 Loss: 0.008714709687968808


 84%|████████▎ | 219/262 [00:01<00:00, 161.17it/s]

Epoch 19 Loss: 0.007972555507243198


100%|██████████| 262/262 [00:01<00:00, 159.99it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.40it/s]

Epoch 20 Loss: 0.0028498955070972443


 45%|████▌     | 119/262 [00:00<00:00, 160.95it/s]

Epoch 20 Loss: 0.0036071762442588806


 84%|████████▍ | 221/262 [00:01<00:00, 161.47it/s]

Epoch 20 Loss: 0.004079795079230474


100%|██████████| 262/262 [00:01<00:00, 161.39it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.49it/s]

Epoch 21 Loss: 0.0034863559994846582


 45%|████▌     | 119/262 [00:00<00:00, 162.67it/s]

Epoch 21 Loss: 0.005079039004815761


 84%|████████▍ | 221/262 [00:01<00:00, 162.22it/s]

Epoch 21 Loss: 0.005340582294394239


100%|██████████| 262/262 [00:01<00:00, 161.82it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.69it/s]

Epoch 22 Loss: 0.0032699638977646828


 45%|████▌     | 119/262 [00:00<00:00, 161.82it/s]

Epoch 22 Loss: 0.006126383526969959


 84%|████████▍ | 220/262 [00:01<00:00, 159.06it/s]

Epoch 22 Loss: 0.005980226761761089


100%|██████████| 262/262 [00:01<00:00, 160.37it/s]
  6%|▌         | 16/262 [00:00<00:01, 153.21it/s]

Epoch 23 Loss: 0.0016275182133540511


 49%|████▉     | 129/262 [00:00<00:00, 157.63it/s]

Epoch 23 Loss: 0.0044659754551836465


 86%|████████▋ | 226/262 [00:01<00:00, 156.83it/s]

Epoch 23 Loss: 0.006399235134163584


100%|██████████| 262/262 [00:01<00:00, 155.36it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.06it/s]

Epoch 24 Loss: 0.003418310545384884


 45%|████▌     | 119/262 [00:00<00:00, 161.79it/s]

Epoch 24 Loss: 0.004587520923613027


 84%|████████▍ | 221/262 [00:01<00:00, 161.64it/s]

Epoch 24 Loss: 0.004164060900353859


100%|██████████| 262/262 [00:01<00:00, 161.12it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.72it/s]

Epoch 25 Loss: 0.006975054275244474


 50%|█████     | 132/262 [00:00<00:00, 159.32it/s]

Epoch 25 Loss: 0.0038444642654716934


 88%|████████▊ | 230/262 [00:01<00:00, 155.63it/s]

Epoch 25 Loss: 0.004786801763227561


100%|██████████| 262/262 [00:01<00:00, 157.38it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.12it/s]

Epoch 26 Loss: 0.006058776751160622


 49%|████▉     | 129/262 [00:00<00:00, 158.60it/s]

Epoch 26 Loss: 0.006157381450273281


 87%|████████▋ | 227/262 [00:01<00:00, 159.45it/s]

Epoch 26 Loss: 0.005453165617959101


100%|██████████| 262/262 [00:01<00:00, 158.58it/s]
  6%|▋         | 17/262 [00:00<00:01, 160.23it/s]

Epoch 27 Loss: 0.00034862736356444657


 50%|█████     | 131/262 [00:00<00:00, 153.66it/s]

Epoch 27 Loss: 0.002691184437665443


 87%|████████▋ | 228/262 [00:01<00:00, 153.57it/s]

Epoch 27 Loss: 0.0036122026675599697


100%|██████████| 262/262 [00:01<00:00, 155.88it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.24it/s]

Epoch 28 Loss: 0.014517286792397499


 50%|█████     | 131/262 [00:00<00:00, 154.74it/s]

Epoch 28 Loss: 0.0072189184138551354


 87%|████████▋ | 227/262 [00:01<00:00, 157.67it/s]

Epoch 28 Loss: 0.0068306684875119456


100%|██████████| 262/262 [00:01<00:00, 157.03it/s]
 12%|█▏        | 32/262 [00:00<00:01, 159.50it/s]

Epoch 29 Loss: 0.00260677607730031


 49%|████▉     | 129/262 [00:00<00:00, 159.32it/s]

Epoch 29 Loss: 0.004102239459015355


 86%|████████▌ | 225/262 [00:01<00:00, 158.82it/s]

Epoch 29 Loss: 0.004562020792529701


100%|██████████| 262/262 [00:01<00:00, 159.38it/s]
 12%|█▏        | 32/262 [00:00<00:01, 157.21it/s]

Epoch 30 Loss: 0.0008589741191826761


 49%|████▉     | 128/262 [00:00<00:00, 158.69it/s]

Epoch 30 Loss: 0.003930763715225742


 85%|████████▌ | 224/262 [00:01<00:00, 156.65it/s]

Epoch 30 Loss: 0.0043977280167645


100%|██████████| 262/262 [00:01<00:00, 157.90it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.15it/s]

Epoch 31 Loss: 0.001119435066357255


 49%|████▉     | 129/262 [00:00<00:00, 156.68it/s]

Epoch 31 Loss: 0.0038794862624548267


 86%|████████▌ | 225/262 [00:01<00:00, 157.94it/s]

Epoch 31 Loss: 0.004504667760529764


100%|██████████| 262/262 [00:01<00:00, 157.34it/s]
 12%|█▏        | 32/262 [00:00<00:01, 156.41it/s]

Epoch 32 Loss: 0.006417559459805489


 49%|████▉     | 128/262 [00:00<00:00, 156.67it/s]

Epoch 32 Loss: 0.0033797655690583576


 85%|████████▌ | 224/262 [00:01<00:00, 156.94it/s]

Epoch 32 Loss: 0.0031980129482400655


100%|██████████| 262/262 [00:01<00:00, 156.93it/s]
 11%|█▏        | 30/262 [00:00<00:01, 147.91it/s]

Epoch 33 Loss: 0.00339919188991189


 48%|████▊     | 125/262 [00:00<00:00, 155.23it/s]

Epoch 33 Loss: 0.006841665482503537


 84%|████████▍ | 221/262 [00:01<00:00, 157.97it/s]

Epoch 33 Loss: 0.00527067088635826


100%|██████████| 262/262 [00:01<00:00, 155.43it/s]
 12%|█▏        | 32/262 [00:00<00:01, 157.76it/s]

Epoch 34 Loss: 0.004556750878691673


 49%|████▉     | 128/262 [00:00<00:00, 157.92it/s]

Epoch 34 Loss: 0.004035952791460965


 85%|████████▌ | 224/262 [00:01<00:00, 157.59it/s]

Epoch 34 Loss: 0.003916663556653353


100%|██████████| 262/262 [00:01<00:00, 157.73it/s]
  6%|▋         | 17/262 [00:00<00:01, 159.97it/s]

Epoch 35 Loss: 0.001883780350908637


 51%|█████     | 133/262 [00:00<00:00, 159.67it/s]

Epoch 35 Loss: 0.002667911616839828


 87%|████████▋ | 229/262 [00:01<00:00, 158.09it/s]

Epoch 35 Loss: 0.0026622303436187084


100%|██████████| 262/262 [00:01<00:00, 159.24it/s]
  6%|▋         | 17/262 [00:00<00:01, 160.38it/s]

Epoch 36 Loss: 0.0008481974946334958


 50%|████▉     | 130/262 [00:00<00:00, 157.21it/s]

Epoch 36 Loss: 0.0056015719958733175


 86%|████████▋ | 226/262 [00:01<00:00, 154.78it/s]

Epoch 36 Loss: 0.004569919491175155


100%|██████████| 262/262 [00:01<00:00, 157.04it/s]
 12%|█▏        | 32/262 [00:00<00:01, 156.79it/s]

Epoch 37 Loss: 0.00037809181958436966


 49%|████▉     | 129/262 [00:00<00:00, 156.33it/s]

Epoch 37 Loss: 0.004917292010804436


 86%|████████▌ | 225/262 [00:01<00:00, 156.88it/s]

Epoch 37 Loss: 0.004046006649020197


100%|██████████| 262/262 [00:01<00:00, 156.38it/s]
  6%|▌         | 16/262 [00:00<00:01, 154.01it/s]

Epoch 38 Loss: 0.0024894950911402702


 49%|████▉     | 129/262 [00:00<00:00, 158.32it/s]

Epoch 38 Loss: 0.005212896883154592


 87%|████████▋ | 227/262 [00:01<00:00, 157.98it/s]

Epoch 38 Loss: 0.004708919360314551


100%|██████████| 262/262 [00:01<00:00, 157.33it/s]
 12%|█▏        | 32/262 [00:00<00:01, 156.84it/s]

Epoch 39 Loss: 0.004270167555660009


 51%|█████     | 133/262 [00:00<00:00, 160.41it/s]

Epoch 39 Loss: 0.003925454999910354


 83%|████████▎ | 218/262 [00:01<00:00, 160.99it/s]

Epoch 39 Loss: 0.0027821698652890823


100%|██████████| 262/262 [00:01<00:00, 160.31it/s]
 12%|█▏        | 32/262 [00:00<00:01, 156.56it/s]

Epoch 40 Loss: 0.0030823235865682364


 49%|████▉     | 128/262 [00:00<00:00, 158.04it/s]

Epoch 40 Loss: 0.0028235508262417232


 88%|████████▊ | 230/262 [00:01<00:00, 159.57it/s]

Epoch 40 Loss: 0.0021186229516650707


100%|██████████| 262/262 [00:01<00:00, 159.27it/s]
 13%|█▎        | 33/262 [00:00<00:01, 161.09it/s]

Epoch 41 Loss: 0.004591302014887333


 45%|████▍     | 117/262 [00:00<00:00, 159.67it/s]

Epoch 41 Loss: 0.005183895568059075


 84%|████████▎ | 219/262 [00:01<00:00, 159.94it/s]

Epoch 41 Loss: 0.005270398898248269


100%|██████████| 262/262 [00:01<00:00, 160.30it/s]
 12%|█▏        | 32/262 [00:00<00:01, 154.19it/s]

Epoch 42 Loss: 0.008115112781524658


 49%|████▉     | 128/262 [00:00<00:00, 155.02it/s]

Epoch 42 Loss: 0.002246657763475248


 85%|████████▌ | 224/262 [00:01<00:00, 158.18it/s]

Epoch 42 Loss: 0.002214063617122239


100%|██████████| 262/262 [00:01<00:00, 156.80it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.58it/s]

Epoch 43 Loss: 0.008706697262823582


 49%|████▉     | 128/262 [00:00<00:00, 158.56it/s]

Epoch 43 Loss: 0.005502275363277035


 85%|████████▌ | 224/262 [00:01<00:00, 158.79it/s]

Epoch 43 Loss: 0.004857929801366834


100%|██████████| 262/262 [00:01<00:00, 158.72it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.67it/s]

Epoch 44 Loss: 0.002917452482506633


 49%|████▉     | 128/262 [00:00<00:00, 157.11it/s]

Epoch 44 Loss: 0.00389252879007346


 85%|████████▌ | 224/262 [00:01<00:00, 157.17it/s]

Epoch 44 Loss: 0.002973043794997977


100%|██████████| 262/262 [00:01<00:00, 157.23it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.65it/s]

Epoch 45 Loss: 0.0007208812749013305


 49%|████▉     | 129/262 [00:00<00:00, 158.03it/s]

Epoch 45 Loss: 0.0016140796053672837


 87%|████████▋ | 229/262 [00:01<00:00, 159.86it/s]

Epoch 45 Loss: 0.0016278788510438365


100%|██████████| 262/262 [00:01<00:00, 158.34it/s]
  6%|▌         | 16/262 [00:00<00:01, 152.34it/s]

Epoch 46 Loss: 0.0006411890499293804


 48%|████▊     | 126/262 [00:00<00:00, 151.95it/s]

Epoch 46 Loss: 0.0035713745509032254


 85%|████████▍ | 222/262 [00:01<00:00, 150.75it/s]

Epoch 46 Loss: 0.0030443185506868107


100%|██████████| 262/262 [00:01<00:00, 151.24it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.16it/s]

Epoch 47 Loss: 0.0005346708931028843


 49%|████▉     | 128/262 [00:00<00:00, 152.16it/s]

Epoch 47 Loss: 0.0033486865957654864


 87%|████████▋ | 228/262 [00:01<00:00, 159.72it/s]

Epoch 47 Loss: 0.0025515689684928915


100%|██████████| 262/262 [00:01<00:00, 157.47it/s]
  6%|▋         | 17/262 [00:00<00:01, 160.47it/s]

Epoch 48 Loss: 0.00019869717652909458


 45%|████▌     | 119/262 [00:00<00:00, 162.28it/s]

Epoch 48 Loss: 0.0067095975681164275


 84%|████████▍ | 221/262 [00:01<00:00, 162.11it/s]

Epoch 48 Loss: 0.004909327673967302


100%|██████████| 262/262 [00:01<00:00, 162.01it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.50it/s]

Epoch 49 Loss: 0.002208410995081067


 45%|████▌     | 119/262 [00:00<00:00, 160.54it/s]

Epoch 49 Loss: 0.002727041414611521


 84%|████████▍ | 221/262 [00:01<00:00, 162.22it/s]

Epoch 49 Loss: 0.0021063491623673184


100%|██████████| 262/262 [00:01<00:00, 161.48it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.95it/s]

Epoch 50 Loss: 0.0014953955542296171


 45%|████▌     | 119/262 [00:00<00:00, 161.04it/s]

Epoch 50 Loss: 0.0014352801604263282


 84%|████████▍ | 221/262 [00:01<00:00, 161.60it/s]

Epoch 50 Loss: 0.0017071571167059179


100%|██████████| 262/262 [00:01<00:00, 161.75it/s]
  6%|▌         | 16/262 [00:00<00:01, 152.20it/s]

Epoch 51 Loss: 0.00041407288517802954


 49%|████▉     | 129/262 [00:00<00:00, 156.92it/s]

Epoch 51 Loss: 0.0017850598583163896


 88%|████████▊ | 231/262 [00:01<00:00, 161.96it/s]

Epoch 51 Loss: 0.002581557929872035


100%|██████████| 262/262 [00:01<00:00, 159.14it/s]
  6%|▋         | 17/262 [00:00<00:01, 160.44it/s]

Epoch 52 Loss: 0.005993106868118048


 45%|████▌     | 119/262 [00:00<00:00, 161.55it/s]

Epoch 52 Loss: 0.005401434269065203


 84%|████████▍ | 221/262 [00:01<00:00, 161.29it/s]

Epoch 52 Loss: 0.003940890183346809


100%|██████████| 262/262 [00:01<00:00, 159.34it/s]
 12%|█▏        | 32/262 [00:00<00:01, 157.63it/s]

Epoch 53 Loss: 0.0005646736244671047


 50%|█████     | 131/262 [00:00<00:00, 159.61it/s]

Epoch 53 Loss: 0.0019395710019530765


 89%|████████▊ | 232/262 [00:01<00:00, 161.19it/s]

Epoch 53 Loss: 0.0017309610613746875


100%|██████████| 262/262 [00:01<00:00, 159.79it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.14it/s]

Epoch 54 Loss: 0.0004747233761008829


 45%|████▌     | 119/262 [00:00<00:00, 162.46it/s]

Epoch 54 Loss: 0.002578104346155507


 84%|████████▍ | 221/262 [00:01<00:00, 161.12it/s]

Epoch 54 Loss: 0.003530075259537689


100%|██████████| 262/262 [00:01<00:00, 161.71it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.68it/s]

Epoch 55 Loss: 0.0007134710322134197


 45%|████▌     | 119/262 [00:00<00:00, 161.96it/s]

Epoch 55 Loss: 0.0010554131627887359


 84%|████████▍ | 221/262 [00:01<00:00, 159.58it/s]

Epoch 55 Loss: 0.001485183374942929


100%|██████████| 262/262 [00:01<00:00, 161.42it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.98it/s]

Epoch 56 Loss: 0.0002699150936678052


 45%|████▌     | 119/262 [00:00<00:00, 161.30it/s]

Epoch 56 Loss: 0.0033876674982694917


 84%|████████▍ | 221/262 [00:01<00:00, 160.97it/s]

Epoch 56 Loss: 0.002949161422275301


100%|██████████| 262/262 [00:01<00:00, 161.77it/s]
  6%|▋         | 17/262 [00:00<00:01, 160.54it/s]

Epoch 57 Loss: 0.000522589311003685


 50%|█████     | 132/262 [00:00<00:00, 157.92it/s]

Epoch 57 Loss: 0.0032116069183523517


 88%|████████▊ | 231/262 [00:01<00:00, 156.43it/s]

Epoch 57 Loss: 0.004309989760869613


100%|██████████| 262/262 [00:01<00:00, 156.64it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.58it/s]

Epoch 58 Loss: 0.0027493308298289776


 49%|████▉     | 129/262 [00:00<00:00, 158.09it/s]

Epoch 58 Loss: 0.0028503705885164475


 86%|████████▌ | 225/262 [00:01<00:00, 157.94it/s]

Epoch 58 Loss: 0.002919723317348537


100%|██████████| 262/262 [00:01<00:00, 158.77it/s]
  6%|▋         | 17/262 [00:00<00:01, 162.02it/s]

Epoch 59 Loss: 0.0004586885042954236


 45%|████▌     | 119/262 [00:00<00:00, 160.96it/s]

Epoch 59 Loss: 0.001876777371303478


 84%|████████▍ | 221/262 [00:01<00:00, 161.61it/s]

Epoch 59 Loss: 0.00195557218671525


100%|██████████| 262/262 [00:01<00:00, 161.46it/s]
  6%|▌         | 16/262 [00:00<00:01, 154.79it/s]

Epoch 60 Loss: 0.003375139320269227


 49%|████▉     | 128/262 [00:00<00:00, 155.14it/s]

Epoch 60 Loss: 0.0021547755463645875


 86%|████████▌ | 225/262 [00:01<00:00, 150.19it/s]

Epoch 60 Loss: 0.002059714319601433


100%|██████████| 262/262 [00:01<00:00, 152.03it/s]
 12%|█▏        | 32/262 [00:00<00:01, 157.68it/s]

Epoch 61 Loss: 0.0023854486644268036


 49%|████▉     | 128/262 [00:00<00:00, 153.52it/s]

Epoch 61 Loss: 0.0018850041473846846


 85%|████████▌ | 224/262 [00:01<00:00, 151.14it/s]

Epoch 61 Loss: 0.002037408909045417


100%|██████████| 262/262 [00:01<00:00, 153.11it/s]
  6%|▌         | 16/262 [00:00<00:01, 156.89it/s]

Epoch 62 Loss: 0.005339410156011581


 49%|████▉     | 128/262 [00:00<00:00, 157.31it/s]

Epoch 62 Loss: 0.003961238685198875


 86%|████████▌ | 225/262 [00:01<00:00, 149.85it/s]

Epoch 62 Loss: 0.0023904346694575565


100%|██████████| 262/262 [00:01<00:00, 155.53it/s]
  6%|▋         | 17/262 [00:00<00:01, 161.71it/s]

Epoch 63 Loss: 0.00022741044813301414


 50%|████▉     | 130/262 [00:00<00:00, 155.37it/s]

Epoch 63 Loss: 0.0012207556792058277


 86%|████████▋ | 226/262 [00:01<00:00, 154.61it/s]

Epoch 63 Loss: 0.002262327480318855


100%|██████████| 262/262 [00:01<00:00, 155.50it/s]
 12%|█▏        | 32/262 [00:00<00:01, 155.45it/s]

Epoch 64 Loss: 0.014762593433260918


 50%|█████     | 131/262 [00:00<00:00, 159.33it/s]

Epoch 64 Loss: 0.0035066572192302527


 88%|████████▊ | 230/262 [00:01<00:00, 158.58it/s]

Epoch 64 Loss: 0.0027856924877700685


100%|██████████| 262/262 [00:01<00:00, 157.81it/s]
 12%|█▏        | 31/262 [00:00<00:01, 152.93it/s]

Epoch 65 Loss: 0.007827476598322392


 48%|████▊     | 127/262 [00:00<00:00, 158.44it/s]

Epoch 65 Loss: 0.0015908035253448928


 87%|████████▋ | 228/262 [00:01<00:00, 161.06it/s]

Epoch 65 Loss: 0.001288596487151452


100%|██████████| 262/262 [00:01<00:00, 159.22it/s]
 12%|█▏        | 32/262 [00:00<00:01, 159.61it/s]

Epoch 66 Loss: 0.00014813581947237253


 49%|████▉     | 128/262 [00:00<00:00, 154.82it/s]

Epoch 66 Loss: 0.000579412707473707


 87%|████████▋ | 227/262 [00:01<00:00, 159.46it/s]

Epoch 66 Loss: 0.0020422924861493775


100%|██████████| 262/262 [00:01<00:00, 157.88it/s]
 12%|█▏        | 32/262 [00:00<00:01, 157.98it/s]

Epoch 67 Loss: 0.00023588872863911092


 45%|████▍     | 117/262 [00:00<00:00, 158.73it/s]

Epoch 67 Loss: 0.002296016346818841


 89%|████████▊ | 232/262 [00:01<00:00, 157.41it/s]

Epoch 67 Loss: 0.002136761553203849


100%|██████████| 262/262 [00:01<00:00, 159.06it/s]
  6%|▌         | 16/262 [00:00<00:01, 150.62it/s]

Epoch 68 Loss: 0.00206825346685946


 50%|█████     | 132/262 [00:00<00:00, 158.87it/s]

Epoch 68 Loss: 0.0026010752996743314


 87%|████████▋ | 228/262 [00:01<00:00, 157.37it/s]

Epoch 68 Loss: 0.0036222568891371427


100%|██████████| 262/262 [00:01<00:00, 157.78it/s]
 12%|█▏        | 32/262 [00:00<00:01, 158.12it/s]

Epoch 69 Loss: 0.0017086296575143933


 45%|████▍     | 117/262 [00:00<00:00, 160.82it/s]

Epoch 69 Loss: 0.003602335573719252


 84%|████████▎ | 219/262 [00:01<00:00, 160.87it/s]

Epoch 69 Loss: 0.0022714917191949476


100%|██████████| 262/262 [00:01<00:00, 161.25it/s]


In [None]:
# Evaluate the trained model on the E1 test split.
from evaluate import evaluate_model_batchwise
# The training vocabulary is passed in; the test dataset was constructed with
# this same vocab object.
evals = evaluate_model_batchwise(model, test_loader, train_dataset.vocab, device=device)
# Bare last expression so the notebook displays the result.
# NOTE(review): the recorded output ("4181 4182" followed by 0.99976...) suggests
# this is the fraction of correctly generated test items -- confirm in evaluate.py.
evals

cuda
Generated: ['I_TURN_RIGHT', 'I_TURN_RIGHT', 'I_TURN_RIGHT', 'I_TURN_RIGHT', 'I_TURN_RIGHT']
Target: ['I_TURN_RIGHT', 'I_TURN_RIGHT', 'I_TURN_RIGHT', 'I_TURN_RIGHT', 'I_TURN_RIGHT', 'I_TURN_RIGHT', 'I_TURN_RIGHT']
4181 4182


0.9997608799617408