In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import copy
import time
from datetime import timedelta

import spacy # 一个自然语言文本处理库

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F

from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
from torchtext.data.functional import to_map_style_dataset  # 将迭代器转化为 Dataset 类型，可直接索引


from transformer_package import make_model, scheduler, LabelSmoothingKL

In [27]:
class ComputeLoss:
    def __init__(self, criterion):
        self.criterion = criterion

    def __call__(self, x, y):
        x = F.log_softmax(x, dim=-1)  # if KL divergence is used
        loss = (
            self.criterion(    # x: [b, len, vocab_size] ---> [b*len, vocab_size];  y: [b, len] ---> [b*len, ]
                x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1)
            )
        )
        return loss

In [28]:
def train_epoch(epoch, data_loader, model, loss_compute, optimizer, scheduler, padding_idx):
    """Train a single epoch."""
    start = time.time()
    total_loss = 0
    total_tokens = 0
    model.train()
    for i, (src, tgt) in enumerate(data_loader):
        # src, tgt shape: [batch_size, max_len]
        tgt_y = copy.deepcopy(tgt[:, 1:])  # 真实的序列，用来构建loss，第一位往往是起始符
        tgt_seq = copy.deepcopy(tgt[:, :-1])   # 输入decoder的序列，最后一位的token是用不到的。因为按照decoder的工作原理，最后一个token的生成是不会依赖到最后一个token的信息的。
        ntokens = (tgt_y != padding_idx).data.sum()
        
        # get padding mask and sequence mask
        src_mask = model.padding_mask(src, padding_idx)
        tgt_mask = model.padding_mask(tgt_seq, padding_idx) & model.sequence_mask(tgt_seq.size(-1))
        
        # train
        logit = model(src, tgt_seq, src_mask, tgt_mask)
        loss = loss_compute(logit, tgt_y)
        
        
        total_loss += loss
        total_tokens += ntokens
        
        loss /= ntokens # mean loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True) # 优化内存使用
        scheduler.step()

        lr = optimizer.param_groups[0]["lr"]
        elapsed = time.time() - start
        print(
            (
                "| Epoch {:3d} | {:5d}/{:5d} batches | Loss: {:6.2f} "
                + "| Tokens: {:5d} | Learning Rate: {:6.1e} | Time: {} |"
            ).format(epoch, i, len(data_loader), loss, ntokens, lr, timedelta(seconds=elapsed))
        )
        
        del loss

    return total_loss / total_tokens


def eval_epoch(epoch, data_loader, model, loss_compute, padding_idx):
    """Eval a single epoch."""
    start = time.time()
    total_loss = 0
    total_tokens = 0
    model.eval()
    with torch.no_grad():
        for i, (src, tgt) in enumerate(data_loader):
            # src, tgt shape: [batch_size, max_len]
            tgt_y = copy.deepcopy(tgt[:, 1:])  # 真实的序列，用来构建loss，第一位往往是起始符
            tgt_seq = copy.deepcopy(tgt[:, :-1])   # 输入decoder的序列，最后一位的token是用不到的。因为按照decoder的工作原理，最后一个token的生成是不会依赖到最后一个token的信息的。
            ntokens = (tgt_y != padding_idx).data.sum()
            
            # get padding mask and sequence mask
            src_mask = model.padding_mask(src, padding_idx)
            tgt_mask = model.padding_mask(tgt_seq, padding_idx) & model.sequence_mask(tgt_seq.size(-1))
            
            # train
            logit = model(src, tgt_seq, src_mask, tgt_mask)
            loss = loss_compute(logit, tgt_y)
        

            total_loss += loss
            total_tokens += ntokens
            
            
            loss /= ntokens # mean loss
    
            elapsed = time.time() - start
            print(
                (
                    "| Epoch {:3d} | {:5d}/{:5d} batches | Loss: {:6.2f} "
                    + "| Tokens: {:5d} | Time: {} |"
                ).format(epoch, i, len(data_loader), loss, ntokens, timedelta(seconds=elapsed))
            )
            
            del loss
            
    return total_loss / total_tokens

# Multi30k German-English Translation task

## Tokenization

In [29]:
def load_tokenizers():

    try:
        spacy_de = spacy.load("de_core_news_sm")
    except IOError:
        os.system("python -m spacy download de_core_news_sm")  # 大概率要翻墙才行
        spacy_de = spacy.load("de_core_news_sm")

    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        os.system("python -m spacy download en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_de, spacy_en

# spacy_de, spacy_en = load_tokenizers()
# doc = spacy_en.tokenizer("This is a sentence.")
# print([(w.text, w.pos_) for w in doc])


def tokenize(text, tokenizer):
    return [tok.text for tok in tokenizer.tokenizer(text)]


def yield_tokens(data_iter, tokenizer, index):
    for from_to_tuple in data_iter:
        yield tokenizer(from_to_tuple[index])

## Build vocabulary

In [30]:
def build_vocabulary(spacy_de, spacy_en):
    def tokenize_de(text):
        return tokenize(text, spacy_de)

    def tokenize_en(text):
        return tokenize(text, spacy_en)

    print("Building German Vocabulary ...")
    train, val, test = datasets.Multi30k(language_pair=("de", "en"))
    vocab_src = build_vocab_from_iterator(
        yield_tokens(train + val + test, tokenize_de, index=0),  # 'de'
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],   # 分别代表起始符、终止符、padding字符、未知字符
    )

    print("Building English Vocabulary ...")
    train, val, test = datasets.Multi30k(language_pair=("de", "en"))
    vocab_tgt = build_vocab_from_iterator(
        yield_tokens(train + val + test, tokenize_en, index=1),  # 'en'
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )

    vocab_src.set_default_index(vocab_src["<unk>"])  # This index will be returned when OOV token is queried.
    vocab_tgt.set_default_index(vocab_tgt["<unk>"])

    return vocab_src, vocab_tgt



def load_vocab(spacy_de, spacy_en):
    if not os.path.exists("vocab.pt"):
        vocab_src, vocab_tgt = build_vocabulary(spacy_de, spacy_en)
        torch.save((vocab_src, vocab_tgt), "vocab.pt")
    else:
        vocab_src, vocab_tgt = torch.load("vocab.pt")
    print("Finished.\nVocabulary sizes:")
    print(len(vocab_src))
    print(len(vocab_tgt))
    return vocab_src, vocab_tgt



# vocab_src, vocab_tgt = load_vocab(spacy_de, spacy_en)
# print(vocab_src.get_stoi())

## Data Loader

In [31]:
def collate_batch(batch, src_pipline, tgt_pipline, src_vocab, tgt_vocab, max_padding=128, pad_id=2):  # <blank> token id
    '''
    负责在 DataLoad 提取一个 batch 的样本时，完成一系列预处理工作。
    所以，我们将 collate_batch 函数通过参数 collate_fn 传入 DataLoader，
    即可实现对变长数据的处理。
    '''
    bs_id = torch.tensor([0])  # <s> token id
    eos_id = torch.tensor([1])  # </s> token id
    src_list, tgt_list = [], []
    
    for (src, tgt) in batch:
        # 为每一句话添加起始符和结束符
        processed_src = torch.cat([bs_id, 
                                  torch.as_tensor(src_vocab(src_pipline(src)), dtype=torch.int64),
                                  eos_id],
                                  dim=0)
        processed_tgt = torch.cat([bs_id, 
                                  torch.as_tensor(tgt_vocab(tgt_pipline(tgt)), dtype=torch.int64),
                                  eos_id],
                                  dim=0)
        
        # 给长度不足max_padding的句子打padding
        processed_src = F.pad(processed_src, (0, max_padding - len(processed_src)), value=pad_id)  # # warning - overwrites values for negative values of padding - len
        src_list.append(processed_src)
        processed_tgt = F.pad(processed_tgt, (0, max_padding - len(processed_tgt)), value=pad_id)
        tgt_list.append(processed_tgt)
        
       
    src_batch = torch.stack(src_list)
    tgt_batch = torch.stack(tgt_list)
    
    return (src_batch, tgt_batch)



def create_dataloaders(vocab_src, vocab_tgt, spacy_de, spacy_en, batch_size=512, max_padding=128):
    
    def tokenize_de(text):    # src_pipline
        return tokenize(text, spacy_de)

    def tokenize_en(text):   # tgt_pipline
        return tokenize(text, spacy_en)
        
    def collate_fn(batch):
        return collate_batch(batch, 
                             tokenize_de, 
                             tokenize_en, 
                             vocab_src, 
                             vocab_tgt,
                             max_padding=max_padding,
                             pad_id=vocab_src.get_stoi()['<blank>'])
    
    train_iter, valid_iter, test_iter = datasets.Multi30k(language_pair=("de", "en"))
    
    train_iter_map = to_map_style_dataset(train_iter)
    valid_iter_map = to_map_style_dataset(valid_iter)
    
    train_dataloader = DataLoader(train_iter_map,
                                  batch_size,
                                  collate_fn=collate_fn)

    valid_dataloader = DataLoader(valid_iter_map,
                                  batch_size,
                                  collate_fn=collate_fn)

    return train_dataloader, valid_dataloader


# train_dataloader, valid_dataloader = create_dataloaders(vocab_src, vocab_tgt, spacy_de, spacy_en, batch_size=512, max_padding=128)

## Training

In [None]:
spacy_de, spacy_en = load_tokenizers()
vocab_src, vocab_tgt = load_vocab(spacy_de, spacy_en)
train_dataloader, valid_dataloader = create_dataloaders(vocab_src, vocab_tgt, spacy_de, spacy_en, batch_size=128, max_padding=72)

src_vocab_size = len(vocab_src)
tgt_vocab_size = len(vocab_tgt)
pad_idx = vocab_src['<blank>']
d_model = 512
init_lr = 1.0
warmup = 3000
num_epochs = 8

model = make_model(src_vocab_size, tgt_vocab_size, d_model)

criterion = LabelSmoothingKL(vocab_size=tgt_vocab_size, padding_idx=pad_idx, smoothing=0.1)

optimizer = optim.Adam(model.parameters(), lr=init_lr, betas=(0.9, 0.98), eps=1e-9)
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                           lr_lambda=lambda step: scheduler(step, d_model, factor=1, warmup=warmup))

train_losses = []
val_losses = []
for epoch in range(num_epochs):
    start = time.time()
    model.train()
    train_loss = train_epoch(epoch,
        train_dataloader,
        model,
        ComputeLoss(criterion),
        optimizer,
        lr_scheduler,
        pad_idx
    )
    train_losses.append(train_loss.item())
    
    model.eval()
    val_loss = eval_epoch(epoch,
        valid_dataloader,
        model,
        ComputeLoss(criterion),
        pad_idx
    )

    val_losses.append(val_loss.item())
    print('-' * 59)
    print('| End of epoch {:3d} | Train Loss: {:8.3f} | Val Loss: {:8.3f} | time: {} |'
          .format(epoch, train_loss, val_loss, timedelta(seconds=time.time()-start)))
    print('-' * 59)

Finished.
Vocabulary sizes:
8185
6291
| Epoch   0 |     0/  227 batches | Loss:   7.61 | Tokens:  1792 | Learning Rate: 2.7e-07 | Time: 0:01:03.702847 |
| Epoch   0 |     1/  227 batches | Loss:   7.61 | Tokens:  1795 | Learning Rate: 5.4e-07 | Time: 0:01:49.522915 |
| Epoch   0 |     2/  227 batches | Loss:   7.60 | Tokens:  1791 | Learning Rate: 8.1e-07 | Time: 0:02:35.399620 |
| Epoch   0 |     3/  227 batches | Loss:   7.59 | Tokens:  1813 | Learning Rate: 1.1e-06 | Time: 0:03:24.495143 |
| Epoch   0 |     4/  227 batches | Loss:   7.58 | Tokens:  1807 | Learning Rate: 1.3e-06 | Time: 0:04:13.745857 |
| Epoch   0 |     5/  227 batches | Loss:   7.57 | Tokens:  1824 | Learning Rate: 1.6e-06 | Time: 0:05:10.148724 |
| Epoch   0 |     6/  227 batches | Loss:   7.55 | Tokens:  1794 | Learning Rate: 1.9e-06 | Time: 0:06:00.251952 |
| Epoch   0 |     7/  227 batches | Loss:   7.52 | Tokens:  1796 | Learning Rate: 2.2e-06 | Time: 0:06:50.500950 |
| Epoch   0 |     8/  227 batches | Loss: 

| Epoch   0 |    71/  227 batches | Loss:   6.29 | Tokens:  1755 | Learning Rate: 1.9e-05 | Time: 0:57:12.522978 |
| Epoch   0 |    72/  227 batches | Loss:   6.23 | Tokens:  1710 | Learning Rate: 2.0e-05 | Time: 0:57:58.643895 |
| Epoch   0 |    73/  227 batches | Loss:   6.26 | Tokens:  1669 | Learning Rate: 2.0e-05 | Time: 0:58:44.575481 |
| Epoch   0 |    74/  227 batches | Loss:   6.25 | Tokens:  1764 | Learning Rate: 2.0e-05 | Time: 0:59:29.325651 |
| Epoch   0 |    75/  227 batches | Loss:   6.20 | Tokens:  1747 | Learning Rate: 2.0e-05 | Time: 1:00:17.824932 |
| Epoch   0 |    76/  227 batches | Loss:   6.21 | Tokens:  1781 | Learning Rate: 2.1e-05 | Time: 1:01:03.923181 |
| Epoch   0 |    77/  227 batches | Loss:   6.15 | Tokens:  1721 | Learning Rate: 2.1e-05 | Time: 1:01:48.339299 |
| Epoch   0 |    78/  227 batches | Loss:   6.18 | Tokens:  1723 | Learning Rate: 2.1e-05 | Time: 1:02:33.121294 |
| Epoch   0 |    79/  227 batches | Loss:   6.20 | Tokens:  1710 | Learning Rate

| Epoch   0 |   143/  227 batches | Loss:   5.02 | Tokens:  1860 | Learning Rate: 3.9e-05 | Time: 1:54:38.563361 |
| Epoch   0 |   144/  227 batches | Loss:   5.01 | Tokens:  1867 | Learning Rate: 3.9e-05 | Time: 1:55:21.806534 |
| Epoch   0 |   145/  227 batches | Loss:   5.06 | Tokens:  1832 | Learning Rate: 3.9e-05 | Time: 1:56:07.363344 |
| Epoch   0 |   146/  227 batches | Loss:   5.03 | Tokens:  1888 | Learning Rate: 4.0e-05 | Time: 1:56:55.639392 |
| Epoch   0 |   147/  227 batches | Loss:   4.96 | Tokens:  1829 | Learning Rate: 4.0e-05 | Time: 1:57:39.127458 |
| Epoch   0 |   148/  227 batches | Loss:   4.91 | Tokens:  1764 | Learning Rate: 4.0e-05 | Time: 1:58:23.846070 |
| Epoch   0 |   149/  227 batches | Loss:   4.91 | Tokens:  1790 | Learning Rate: 4.0e-05 | Time: 1:59:08.282601 |
| Epoch   0 |   150/  227 batches | Loss:   4.97 | Tokens:  1889 | Learning Rate: 4.1e-05 | Time: 1:59:50.764932 |
| Epoch   0 |   151/  227 batches | Loss:   4.89 | Tokens:  1870 | Learning Rate

| Epoch   0 |   215/  227 batches | Loss:   4.44 | Tokens:  1911 | Learning Rate: 5.8e-05 | Time: 2:50:04.846523 |
| Epoch   0 |   216/  227 batches | Loss:   4.33 | Tokens:  1840 | Learning Rate: 5.8e-05 | Time: 2:50:48.618488 |
| Epoch   0 |   217/  227 batches | Loss:   4.53 | Tokens:  1967 | Learning Rate: 5.9e-05 | Time: 2:51:35.303857 |
| Epoch   0 |   218/  227 batches | Loss:   4.52 | Tokens:  1920 | Learning Rate: 5.9e-05 | Time: 2:52:25.899691 |
| Epoch   0 |   219/  227 batches | Loss:   4.50 | Tokens:  1933 | Learning Rate: 5.9e-05 | Time: 2:53:16.116080 |
| Epoch   0 |   220/  227 batches | Loss:   4.57 | Tokens:  2004 | Learning Rate: 5.9e-05 | Time: 2:54:04.743613 |
| Epoch   0 |   221/  227 batches | Loss:   4.43 | Tokens:  1910 | Learning Rate: 6.0e-05 | Time: 2:54:54.010412 |
| Epoch   0 |   222/  227 batches | Loss:   4.38 | Tokens:  1899 | Learning Rate: 6.0e-05 | Time: 2:55:39.853488 |
| Epoch   0 |   223/  227 batches | Loss:   4.31 | Tokens:  1864 | Learning Rate

| Epoch   1 |    52/  227 batches | Loss:   4.10 | Tokens:  1794 | Learning Rate: 7.5e-05 | Time: 0:41:27.650749 |
| Epoch   1 |    53/  227 batches | Loss:   4.02 | Tokens:  1767 | Learning Rate: 7.6e-05 | Time: 0:42:13.001234 |
| Epoch   1 |    54/  227 batches | Loss:   4.01 | Tokens:  1792 | Learning Rate: 7.6e-05 | Time: 0:43:01.840516 |
| Epoch   1 |    55/  227 batches | Loss:   4.02 | Tokens:  1734 | Learning Rate: 7.6e-05 | Time: 0:43:49.752062 |
| Epoch   1 |    56/  227 batches | Loss:   3.99 | Tokens:  1780 | Learning Rate: 7.6e-05 | Time: 0:44:38.497504 |
| Epoch   1 |    57/  227 batches | Loss:   4.02 | Tokens:  1833 | Learning Rate: 7.7e-05 | Time: 0:45:32.900157 |
| Epoch   1 |    58/  227 batches | Loss:   3.97 | Tokens:  1793 | Learning Rate: 7.7e-05 | Time: 0:46:19.623019 |
| Epoch   1 |    59/  227 batches | Loss:   3.93 | Tokens:  1720 | Learning Rate: 7.7e-05 | Time: 0:47:05.898271 |
| Epoch   1 |    60/  227 batches | Loss:   3.97 | Tokens:  1798 | Learning Rate

| Epoch   1 |   124/  227 batches | Loss:   3.80 | Tokens:  1835 | Learning Rate: 9.5e-05 | Time: 1:42:29.223696 |
| Epoch   1 |   125/  227 batches | Loss:   3.80 | Tokens:  1915 | Learning Rate: 9.5e-05 | Time: 1:43:21.192478 |
| Epoch   1 |   126/  227 batches | Loss:   3.68 | Tokens:  1790 | Learning Rate: 9.5e-05 | Time: 1:44:13.072019 |
| Epoch   1 |   127/  227 batches | Loss:   3.81 | Tokens:  1831 | Learning Rate: 9.5e-05 | Time: 1:45:07.947999 |
| Epoch   1 |   128/  227 batches | Loss:   3.75 | Tokens:  1776 | Learning Rate: 9.6e-05 | Time: 1:46:04.544400 |
| Epoch   1 |   129/  227 batches | Loss:   3.86 | Tokens:  1890 | Learning Rate: 9.6e-05 | Time: 1:46:56.144266 |
| Epoch   1 |   130/  227 batches | Loss:   3.72 | Tokens:  1800 | Learning Rate: 9.6e-05 | Time: 1:47:52.231985 |
| Epoch   1 |   131/  227 batches | Loss:   3.78 | Tokens:  1787 | Learning Rate: 9.7e-05 | Time: 1:48:42.702585 |
| Epoch   1 |   132/  227 batches | Loss:   3.62 | Tokens:  1809 | Learning Rate

| Epoch   1 |   196/  227 batches | Loss:   3.65 | Tokens:  2014 | Learning Rate: 1.1e-04 | Time: 2:43:46.575538 |
| Epoch   1 |   197/  227 batches | Loss:   3.61 | Tokens:  1950 | Learning Rate: 1.1e-04 | Time: 2:44:34.883324 |
| Epoch   1 |   198/  227 batches | Loss:   3.73 | Tokens:  1901 | Learning Rate: 1.1e-04 | Time: 2:45:24.653200 |
| Epoch   1 |   199/  227 batches | Loss:   3.70 | Tokens:  2041 | Learning Rate: 1.1e-04 | Time: 2:46:08.933757 |
| Epoch   1 |   200/  227 batches | Loss:   3.73 | Tokens:  1937 | Learning Rate: 1.2e-04 | Time: 2:46:57.235560 |
| Epoch   1 |   201/  227 batches | Loss:   3.83 | Tokens:  1934 | Learning Rate: 1.2e-04 | Time: 2:47:44.926133 |
| Epoch   1 |   202/  227 batches | Loss:   3.75 | Tokens:  1939 | Learning Rate: 1.2e-04 | Time: 2:48:32.102944 |
| Epoch   1 |   203/  227 batches | Loss:   3.92 | Tokens:  2015 | Learning Rate: 1.2e-04 | Time: 2:49:21.270428 |
| Epoch   1 |   204/  227 batches | Loss:   3.62 | Tokens:  1894 | Learning Rate

| Epoch   2 |    33/  227 batches | Loss:   3.16 | Tokens:  1670 | Learning Rate: 1.3e-04 | Time: 0:37:05.142927 |
| Epoch   2 |    34/  227 batches | Loss:   3.14 | Tokens:  1771 | Learning Rate: 1.3e-04 | Time: 0:37:50.819362 |
| Epoch   2 |    35/  227 batches | Loss:   3.28 | Tokens:  1737 | Learning Rate: 1.3e-04 | Time: 0:38:35.562682 |
| Epoch   2 |    36/  227 batches | Loss:   3.06 | Tokens:  1693 | Learning Rate: 1.3e-04 | Time: 0:39:23.465552 |
| Epoch   2 |    37/  227 batches | Loss:   3.22 | Tokens:  1748 | Learning Rate: 1.3e-04 | Time: 0:40:10.093829 |
| Epoch   2 |    38/  227 batches | Loss:   3.12 | Tokens:  1842 | Learning Rate: 1.3e-04 | Time: 0:40:55.311882 |
| Epoch   2 |    39/  227 batches | Loss:   3.22 | Tokens:  1784 | Learning Rate: 1.3e-04 | Time: 0:41:40.684777 |
| Epoch   2 |    40/  227 batches | Loss:   3.37 | Tokens:  1747 | Learning Rate: 1.3e-04 | Time: 0:42:25.861506 |
| Epoch   2 |    41/  227 batches | Loss:   3.14 | Tokens:  1806 | Learning Rate

| Epoch   2 |   105/  227 batches | Loss:   3.06 | Tokens:  1740 | Learning Rate: 1.5e-04 | Time: 1:31:57.850025 |
| Epoch   2 |   106/  227 batches | Loss:   2.88 | Tokens:  1695 | Learning Rate: 1.5e-04 | Time: 1:32:43.753243 |
| Epoch   2 |   107/  227 batches | Loss:   3.00 | Tokens:  1766 | Learning Rate: 1.5e-04 | Time: 1:33:31.979247 |
| Epoch   2 |   108/  227 batches | Loss:   3.08 | Tokens:  1750 | Learning Rate: 1.5e-04 | Time: 1:34:19.004465 |
| Epoch   2 |   109/  227 batches | Loss:   2.98 | Tokens:  1725 | Learning Rate: 1.5e-04 | Time: 1:35:03.579237 |
| Epoch   2 |   110/  227 batches | Loss:   2.98 | Tokens:  1700 | Learning Rate: 1.5e-04 | Time: 1:35:48.570527 |
| Epoch   2 |   111/  227 batches | Loss:   2.93 | Tokens:  1678 | Learning Rate: 1.5e-04 | Time: 1:36:33.997020 |
| Epoch   2 |   112/  227 batches | Loss:   3.03 | Tokens:  1741 | Learning Rate: 1.5e-04 | Time: 1:37:19.456103 |
| Epoch   2 |   113/  227 batches | Loss:   2.92 | Tokens:  1659 | Learning Rate

| Epoch   2 |   177/  227 batches | Loss:   2.92 | Tokens:  1848 | Learning Rate: 1.7e-04 | Time: 2:26:55.260376 |
| Epoch   2 |   178/  227 batches | Loss:   3.01 | Tokens:  1800 | Learning Rate: 1.7e-04 | Time: 2:27:41.042916 |
| Epoch   2 |   179/  227 batches | Loss:   2.79 | Tokens:  1850 | Learning Rate: 1.7e-04 | Time: 2:28:23.346765 |
| Epoch   2 |   180/  227 batches | Loss:   2.87 | Tokens:  1816 | Learning Rate: 1.7e-04 | Time: 2:29:09.629965 |
| Epoch   2 |   181/  227 batches | Loss:   2.83 | Tokens:  1889 | Learning Rate: 1.7e-04 | Time: 2:29:52.875020 |
| Epoch   2 |   182/  227 batches | Loss:   2.98 | Tokens:  1778 | Learning Rate: 1.7e-04 | Time: 2:30:36.982042 |
| Epoch   2 |   183/  227 batches | Loss:   2.80 | Tokens:  1814 | Learning Rate: 1.7e-04 | Time: 2:31:24.277908 |
| Epoch   2 |   184/  227 batches | Loss:   3.05 | Tokens:  1888 | Learning Rate: 1.7e-04 | Time: 2:32:09.854997 |
| Epoch   2 |   185/  227 batches | Loss:   3.12 | Tokens:  1855 | Learning Rate

| Epoch   3 |    14/  227 batches | Loss:   2.87 | Tokens:  1782 | Learning Rate: 1.9e-04 | Time: 0:11:11.657789 |
| Epoch   3 |    15/  227 batches | Loss:   2.85 | Tokens:  1825 | Learning Rate: 1.9e-04 | Time: 0:11:55.000855 |
| Epoch   3 |    16/  227 batches | Loss:   2.75 | Tokens:  1788 | Learning Rate: 1.9e-04 | Time: 0:12:40.571961 |
| Epoch   3 |    17/  227 batches | Loss:   2.82 | Tokens:  1793 | Learning Rate: 1.9e-04 | Time: 0:13:27.777149 |
| Epoch   3 |    18/  227 batches | Loss:   2.94 | Tokens:  1935 | Learning Rate: 1.9e-04 | Time: 0:14:11.391490 |
| Epoch   3 |    19/  227 batches | Loss:   2.81 | Tokens:  1785 | Learning Rate: 1.9e-04 | Time: 0:14:55.296054 |
| Epoch   3 |    20/  227 batches | Loss:   2.88 | Tokens:  1748 | Learning Rate: 1.9e-04 | Time: 0:15:42.419009 |
| Epoch   3 |    21/  227 batches | Loss:   2.81 | Tokens:  1762 | Learning Rate: 1.9e-04 | Time: 0:16:23.383437 |
| Epoch   3 |    22/  227 batches | Loss:   2.84 | Tokens:  1734 | Learning Rate

| Epoch   3 |    86/  227 batches | Loss:   2.71 | Tokens:  1714 | Learning Rate: 2.1e-04 | Time: 1:05:25.284606 |
| Epoch   3 |    87/  227 batches | Loss:   2.60 | Tokens:  1705 | Learning Rate: 2.1e-04 | Time: 1:06:10.676192 |
| Epoch   3 |    88/  227 batches | Loss:   2.70 | Tokens:  1684 | Learning Rate: 2.1e-04 | Time: 1:06:56.464717 |
| Epoch   3 |    89/  227 batches | Loss:   2.55 | Tokens:  1671 | Learning Rate: 2.1e-04 | Time: 1:07:38.974014 |
| Epoch   3 |    90/  227 batches | Loss:   2.85 | Tokens:  1713 | Learning Rate: 2.1e-04 | Time: 1:08:22.211361 |
| Epoch   3 |    91/  227 batches | Loss:   2.64 | Tokens:  1691 | Learning Rate: 2.1e-04 | Time: 1:09:10.931047 |
| Epoch   3 |    92/  227 batches | Loss:   2.78 | Tokens:  1741 | Learning Rate: 2.1e-04 | Time: 1:09:55.163733 |
| Epoch   3 |    93/  227 batches | Loss:   2.57 | Tokens:  1663 | Learning Rate: 2.1e-04 | Time: 1:10:40.005788 |
| Epoch   3 |    94/  227 batches | Loss:   2.60 | Tokens:  1686 | Learning Rate

| Epoch   3 |   158/  227 batches | Loss:   2.57 | Tokens:  1810 | Learning Rate: 2.3e-04 | Time: 1:59:46.647369 |
| Epoch   3 |   159/  227 batches | Loss:   2.78 | Tokens:  1859 | Learning Rate: 2.3e-04 | Time: 2:00:31.097170 |
| Epoch   3 |   160/  227 batches | Loss:   2.64 | Tokens:  1835 | Learning Rate: 2.3e-04 | Time: 2:01:15.170282 |
| Epoch   3 |   161/  227 batches | Loss:   2.64 | Tokens:  1820 | Learning Rate: 2.3e-04 | Time: 2:02:03.145849 |
| Epoch   3 |   162/  227 batches | Loss:   2.63 | Tokens:  1789 | Learning Rate: 2.3e-04 | Time: 2:02:47.502204 |
| Epoch   3 |   163/  227 batches | Loss:   2.81 | Tokens:  1895 | Learning Rate: 2.3e-04 | Time: 2:03:33.069322 |
| Epoch   3 |   164/  227 batches | Loss:   2.71 | Tokens:  1853 | Learning Rate: 2.3e-04 | Time: 2:04:19.061302 |
| Epoch   3 |   165/  227 batches | Loss:   2.63 | Tokens:  1799 | Learning Rate: 2.3e-04 | Time: 2:05:00.910363 |
| Epoch   3 |   166/  227 batches | Loss:   2.61 | Tokens:  1819 | Learning Rate

In [None]:
plt.plot(train_losses, label='train_loss')
plt.plot(val_losses, label='val_loss')
plt.legend()