# Part 1 Model Building

In [1]:
import torch
import copy
import math
from torch import nn, Tensor
from torch.nn.functional import log_softmax, pad

In [2]:
def custom_repr(self):
    return "{} {}".format(self.size(), origin_repr(self))

origin_repr = torch.Tensor.__repr__
torch.Tensor.__repr__ = custom_repr

In [3]:
class DummyOptimizer(torch.optim.Optimizer):
    def __init__(self):
        self.param_groups = [{"lr": 0}]
    
    def step(self):
        None

    def zero_grad(self):
        None

class DummyScheduler:
    def step(self):
        None

In [4]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator
    
    def forward(self, src, tgt, src_mask, tgt_mask):
        return self.decode(
            self.encode(src, src_mask), src_mask,
            tgt, tgt_mask
        )

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
    
# class EncoderDecoder(nn.Module):
#     """
#     A standard Encoder-Decoder architecture. Base for this and many
#     other models.
#     """

#     def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
#         super(EncoderDecoder, self).__init__()
#         self.encoder = encoder
#         self.decoder = decoder
#         self.src_embed = src_embed
#         self.tgt_embed = tgt_embed
#         self.generator = generator

#     def forward(self, src, tgt, src_mask, tgt_mask):
#         "Take in and process masked src and target sequences."
#         return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

#     def encode(self, src, src_mask):
#         return self.encoder(self.src_embed(src), src_mask)

#     def decode(self, memory, src_mask, tgt, tgt_mask):
#         return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

class Generator(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return log_softmax(self.proj(x), dim=-1)

def clones(layer, N):
    return nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.layernorm = nn.LayerNorm(layer.d_model)
    
    def forward(self, src, src_mask):
        x = src
        for layer in self.layers:
            x = layer(x, src_mask)
        return self.layernorm(x)

class SubLayerConnection(nn.Module):
    def __init__(self, d_model, dropout):
        super().__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, layer):
        out = layer(self.layernorm(x))
        return self.dropout(out) + x

class EncoderLayer(nn.Module):
    def __init__(self, d_model, attn, ffn, dropout):
        super().__init__()
        self.attn = attn
        self.ffn = ffn
        self.d_model = d_model
        self.sublayers = clones(SubLayerConnection(d_model, dropout), 2)

    def forward(self, src, src_mask):
        x = self.sublayers[0](src, lambda x: self.attn(x, x, x, src_mask))
        x = self.sublayers[1](x, self.ffn)
        return x

class Decoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.layernorm = nn.LayerNorm(layer.d_model)

    def forward(self, tgt, memory, src_mask, tgt_mask):
        x = tgt
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.layernorm(x)

class DecoderLayer(nn.Module):
    def __init__(self, d_model, self_attn, src_attn, ffn, dropout):
        super().__init__()
        self.src_attn = src_attn
        self.self_attn = self_attn
        self.ffn = ffn
        self.d_model = d_model
        self.sublayers = clones(SubLayerConnection(d_model, dropout), 3)

    def forward(self, tgt, memory, src_mask, tgt_mask):
        x = self.sublayers[0](tgt, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayers[1](x, lambda x: self.src_attn(x, memory, memory, src_mask))
        x = self.sublayers[2](x, self.ffn)
        return x

def subsequent_mask(seq_len):
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.int8))
    mask = mask.unsqueeze(0)
    return mask == 1

def attention(q, k, v, mask=None, dropout=None):
    d_k =  q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask==0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    out = torch.matmul(p_attn, v)
    return out

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=None):
        super().__init__()
        assert d_model % h == 0
        self.dk = d_model // h
        self.h = h
        if dropout is not None:
            self.dropout = nn.Dropout(dropout)
        self.linears = clones(nn.Linear(d_model, d_model), 4)

    def forward(self, q, k, v, mask: torch.Tensor):
        batch_size = q.size()[0]
        q, k, v = [
            lin(x).reshape(batch_size, -1, self.h, self.dk).transpose(1, 2) 
            for x, lin in zip([q, k, v], self.linears)
        ]

        if mask is not None:
            mask = mask.unsqueeze(1)

        out = attention(q, k, v, mask)
        
        out = out.transpose(1, 2).reshape(batch_size, -1, self.h * self.dk)
        out = self.linears[-1](out)
        del q, k, v
        return out
    
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.lut = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)
    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.max_len = max_len
        self.d_model = d_model
        pos_embed = torch.ones(max_len, d_model)
        div_item = torch.exp(-math.log(10000) * torch.arange(0, d_model, 2) / d_model)
        pos = torch.arange(0, max_len, 1).unsqueeze(1)
        pos_embed[:, 0::2] = torch.sin(pos * div_item)
        pos_embed[:, 1::2] = torch.cos(pos * div_item)
        self.register_buffer('pos_embed', pos_embed.detach().clone())

    def forward(self, x):
        seq_len = x.size(1)
        x = x + self.pos_embed[:seq_len]
        return self.dropout(x)

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ffn, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear1(x).relu()
        return self.linear2(self.dropout(x))


In [5]:
def make_model(
        vocab_src,
        vocab_tgt,
        N=6,
        d_model=512,
        d_ffn=2048,
        dropout=0.1,
        h=8,
        max_len=128
) -> EncoderDecoder:
    from copy import deepcopy as c
    src_embed = nn.Sequential(Embeddings(d_model, vocab_src), PositionalEncoding(d_model, dropout, max_len))
    tgt_embed = nn.Sequential(Embeddings(d_model, vocab_tgt), PositionalEncoding(d_model, dropout, max_len))
    attn = MultiHeadedAttention(h, d_model, dropout)
    ffn = PositionwiseFeedForward(d_model, d_ffn)
    generator = Generator(d_model, vocab_tgt)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ffn), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ffn), dropout), N),
        src_embed,
        tgt_embed,
        generator
    )

    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return model

In [46]:
def greedy_decode(model: EncoderDecoder, src, src_mask, max_len, start_index):
    memory = model.encode(src, src_mask)
    tgt = torch.tensor([[start_index]], dtype=torch.int64).to(src.detach())
    for _ in range(max_len-1):
        out = model.decode(memory, src_mask, tgt, subsequent_mask(tgt.size(1)).to(src.detach()))
        out = model.generator(out)
        last_embed = out[0][-1]
        _, nxt = torch.max(last_embed, dim=-1)
        tgt = torch.concat(
            [tgt, torch.tensor([[nxt.item()]])],
            dim=1
        )
    # print(tgt)
    return tgt

In [7]:
def model_inference():
    V = 11
    model = make_model(V, V, 2)
    src = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=torch.int64)
    src_mask = torch.ones(1, 1, 10)
    model.eval()
    greedy_decode(model, src, src_mask, 10, 1)

for _ in range(10):
    model_inference()

torch.Size([1, 10]) tensor([[1, 5, 7, 8, 5, 7, 8, 5, 7, 8]])
torch.Size([1, 10]) tensor([[ 1,  3, 10, 10, 10, 10, 10, 10, 10, 10]])
torch.Size([1, 10]) tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
torch.Size([1, 10]) tensor([[1, 5, 6, 6, 6, 6, 6, 6, 6, 6]])
torch.Size([1, 10]) tensor([[1, 8, 2, 2, 2, 2, 2, 2, 2, 2]])
torch.Size([1, 10]) tensor([[1, 4, 2, 2, 2, 2, 2, 2, 2, 2]])
torch.Size([1, 10]) tensor([[1, 0, 9, 3, 9, 3, 9, 1, 1, 1]])
torch.Size([1, 10]) tensor([[1, 2, 2, 6, 2, 6, 2, 6, 2, 6]])
torch.Size([1, 10]) tensor([[1, 3, 0, 7, 3, 8, 1, 3, 8, 0]])
torch.Size([1, 10]) tensor([[1, 9, 7, 9, 7, 2, 8, 0, 0, 0]])


# Model Training

In [36]:
class Batch:
    def __init__(self, src, tgt, padding_idx=0):
        self.src = src
        self.src_mask = (src != padding_idx).unsqueeze(1).to(src.detach())
        self.tgt = tgt[:, :-1].detach().clone()
        self.tgt_y = tgt[:, 1:].detach().clone()
        self.tgt_mask = (self.tgt_y != padding_idx).unsqueeze(-2) & subsequent_mask(self.tgt_y.size(1)).to(src.detach())
        # self.tgt_mask = self.tgt_mask.to(src.detach())
        self.n_tokens = (self.tgt_y != padding_idx).sum().item()

In [40]:
class LabelSmoothing(nn.Module):
    def __init__(self, padding_idx=0, smoothing=0.1):
        super().__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.pad_idx = padding_idx
        self.smoothing = smoothing

    def forward(self, pred: Tensor, label: Tensor):
        true_dist = pred.detach().clone()
        categroy_size = pred.size(-1)
        true_dist.fill_(self.smoothing / (categroy_size-2))
        # true_dist.index_fill_(1, torch.tensor(self.pad_idx).to(pred), 0)
        true_dist[:, self.pad_idx] = 0
        true_dist.scatter_(1, label.unsqueeze(1), 1 - self.smoothing)
        mask = torch.nonzero(label == self.pad_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        return self.criterion(pred, true_dist)

In [10]:
class SimpleLoss:
    def __init__(self, crit, generator):
        self.crit = crit
        self.generator = generator

    def __call__(self, x, label, norm):
        logit = self.generator(x)
        sloss = self.crit(
            logit.reshape(-1, logit.size(-1)),
            label.reshape(-1)
        ) / norm
        return sloss.detach() * norm, sloss

In [11]:
def data_gen(vocab, batch_size, seq_len, batches):
    for _ in range(batches):
        data = torch.randint(1, vocab, (batch_size, seq_len)).to(torch.int64)
        data[:, 0] = 1
        src = data.clone()
        tgt = data.clone()
        yield Batch(src, tgt, 0)

In [12]:
import time

def run_epoch(
        data_iter,
        model: EncoderDecoder,
        optimizer,
        lr_scheduler,
        loss_compute,
        accum_interval,
        mode='train'
):
    accum_steps = 0
    total_tokens = 0
    total_loss = 0
    start_time = time.time()
    for i, batch in enumerate(data_iter):
        out = model(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
        loss, loss_node = loss_compute(out, batch.tgt_y, batch.n_tokens)
        total_loss += loss.item()
        if 'train' in mode:
            total_tokens += batch.n_tokens
            loss_node.backward()
            if i % accum_interval == 0:
                accum_steps += 1
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
            lr_scheduler.step()

        if i % 40 == 0 and 'train' in mode:
            lr = optimizer.param_groups[0]['lr']
            end_time = time.time()
            print(
                "Epoch Step: {:6d} | Accumulation Step: {:3d} | Loss: {:6.2f} | Tokens / Sec: {:7.1f} | Learning Rate: {:6.1e}".format(
                    i, accum_steps, loss_node.detach().item(), total_tokens / (end_time - start_time), lr
                )
            )
            start_time = time.time()
        del loss, loss_node
    return total_loss

In [13]:
def rate(step, warmup, d_model, factor=1.0):
    if step == 0:
        step = 1
    return d_model ** -0.5 * factor * min(step ** -0.5, warmup ** -1.5 * step)

In [14]:
def train_model():
    V = 11
    model = make_model(V, V, 2)
    d_model = 512
    epochs = 20
    batch_size = 80
    batches = 20
    seq_len = 10
    warm_up = 400
    accum_interval = 1
    loss_compute = SimpleLoss(
        LabelSmoothing(0, 0.0),
        model.generator
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: rate(x, warm_up, d_model, 1))
    for _ in range(epochs):
        model.train()
        run_epoch(
            data_gen(V, batch_size, seq_len, batches),
            model,
            optimizer,
            lr_scheduler,
            loss_compute,
            accum_interval,
            'train'
        )

        model.eval()
        run_epoch(
            data_gen(V, batch_size, seq_len, batches),
            model,
            DummyOptimizer(),
            DummyScheduler(),
            loss_compute,
            accum_interval,
            'eval'
        )

    src = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=torch.int64)
    src_mask = torch.ones(1, 1, 10)
    model.eval()
    greedy_decode(model, src, src_mask, 10, 1)

# train_model()

# Real Example

In [15]:
import os
import spacy
from torchtext import datasets
from torchtext.datasets import multi30k
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader


multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"
multi30k.URL["test"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/mmt16_task1_test.tar.gz"

multi30k.MD5["train"] = "20140d013d05dd9a72dfde46478663ba05737ce983f478f960c1123c6671be5e"
multi30k.MD5["valid"] = "a7aa20e9ebd5ba5adce7909498b94410996040857154dab029851af3a866da8c"
multi30k.MD5["test"] = "6d1ca1dba99e2c5dd54cae1226ff11c2551e6ce63527ebb072a1f70f72a5cd36"

In [16]:
def load_tokenizer():
    try:
        spacy_de, spacy_en = spacy.load('de_core_news_sm'), spacy.load('en_core_web_sm')
    except IOError:
        os.system('python -m spacy download en_core_web_sm')
        os.system('python -m spacy download de_core_news_sm')
        spacy_de, spacy_en = spacy.load('de_core_news_sm'), spacy.load('en_core_web_sm')
    
    return spacy_de, spacy_en

In [30]:
spacy_de, spacy_en = load_tokenizer()

def tokenize(tokenizer, text):
    return [tok.text for tok in tokenizer.tokenizer(text)]

def yield_tokens(dataset, tokenizer, index):
    for item in dataset:
        yield tokenizer(item[index])

def build_vocab():
    def tokenize_de(text):
        return tokenize(spacy_de, text)
    
    def tokenize_en(text):
        return tokenize(spacy_en, text)

    train, val, test = datasets.Multi30k(language_pair=("de", "en"))

    vocab_de = build_vocab_from_iterator(
        iterator=yield_tokens(train+val+test, tokenize_de, 0),
        min_freq=1,
        specials=['<s>', '</s>', '<blank>', '<unk>']
    )

    vocab_en = build_vocab_from_iterator(
        iterator=yield_tokens(train+val+test, tokenize_en, 1),
        min_freq=1,
        specials=['<s>', '</s>', '<blank>', '<unk>']
    )
    return vocab_de, vocab_en

def load_vocab():
    if os.path.exists('vocab.pt'):
        vocab_de, vocab_en = torch.load('vocab.pt')
    else:
        vocab_de, vocab_en = build_vocab()
        torch.save((vocab_de, vocab_en), 'vocab.pt')
    return vocab_de, vocab_en

vocab_de, vocab_en = load_vocab()

In [55]:
print(type(spacy_en('hello_world')))
print(type(spacy_en('hello_world')[0]))

<class 'spacy.tokens.doc.Doc'>
<class 'spacy.tokens.token.Token'>


In [32]:
def collate_batch(
        batch,
        pipline_de,
        pipline_en,
        vocab_de,
        vocab_en,
        device,
        max_len,
):
    start_index, end_index = vocab_de['<s>'], vocab_de['</s>']
    padding_idx = vocab_de['<blank>']
    src_list, tgt_list = [], []
    for from_to_tuple in batch:
        src, tgt = from_to_tuple
        src = torch.concat(
            [
                torch.tensor([start_index], device=device, dtype=torch.int64),
                torch.tensor(
                    vocab_de(pipline_de(src)),
                    dtype=torch.int64,
                    device=device
                ),
                torch.tensor([end_index], device=device, dtype=torch.int64)
            ],
            0
        )

        tgt = torch.concat(
            [
                torch.tensor([start_index], device=device, dtype=torch.int64),
                torch.tensor(
                    vocab_en(pipline_en(tgt)),
                    dtype=torch.int64,
                    device=device
                ),
                torch.tensor([end_index], device=device, dtype=torch.int64)
            ],
            0
        )

        src = pad(
            src,
            [0, max_len-src.size(0)],
            value=padding_idx
        )

        tgt = pad(
            tgt,
            [0, max_len-tgt.size(0)],
            value=padding_idx
        )

        src_list.append(src)
        tgt_list.append(tgt)
    return (torch.stack(src_list, dim=0), torch.stack(tgt_list, dim=0))

In [33]:
def create_dataloader(
        config,
        spacy_de,
        spacy_en,
        vocab_de,
        vocab_en,
        device,
):
    def tokenize_de(text):
        return tokenize(spacy_de, text)
    
    def tokenize_en(text):
        return tokenize(spacy_en, text)

    def collate_fn(batch):
        return collate_batch(
            batch,
            tokenize_de,
            tokenize_en,
            vocab_de,
            vocab_en,
            device,
            config['max_seqlen']
        )

    train, val, test = datasets.Multi30k(language_pair=('de', 'en'))

    train_dataloader = DataLoader(
        train,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=collate_fn
    )

    val_dataloader = DataLoader(
        val,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=collate_fn
    )
    return train_dataloader, val_dataloader

In [34]:
def train_worker(
        device,
        config,
        spacy_de,
        spacy_en,
        vocab_de,
        vocab_en
):
    d_model = 512
    padding_idx = vocab_de['<unk>']
    model = make_model(len(vocab_de), len(vocab_en), max_len=5000)
    model.to(device)

    criterion = LabelSmoothing(padding_idx, 0.1)
    criterion.to(device)

    train_dataloader, eval_dataloader = create_dataloader(
        config,
        spacy_de,
        spacy_en,
        vocab_de,
        vocab_en,
        device
    )

    optimizer = torch.optim.Adam(model.parameters(), config['base_lr'], betas=(0.9, 0.98), eps=1e-9)

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda = lambda step: rate(step, config['warmup'], d_model)
    )

    loss_compute = SimpleLoss(criterion, model.generator)

    for i in range(config['epochs']):
        model.train()
        run_epoch(
            (Batch(batch[0], batch[1], padding_idx) for batch in train_dataloader),
            model,
            optimizer,
            lr_scheduler,
            loss_compute,
            accum_interval=config['accum_interval'],
            mode='train'
        )

        torch.save(model.state_dict(), '{}_{:2d}.pt'.format(config['model_prefix'], i))

        model.eval()
        run_epoch(
            (Batch(batch[0], batch[1], padding_idx) for batch in eval_dataloader),
            model,
            DummyOptimizer(),
            DummyScheduler(),
            loss_compute,
            1,
            'eval'
        )
    torch.save(model.state_dict(), '{}_final.pt'.format(config['model_prefix']))

In [None]:
def load_trained_model():
    config = {
        'epochs': 8,
        'model_prefix': 'multi30k',
        'accum_interval': 10,
        'batch_size': 32,
        'warmup': 3000,
        'base_lr': 1.0,
        'max_seqlen': 72,
    }

    if not os.path.exists('{}_final.pt'.format(config['model_prefix'])):
        train_worker(0, config, spacy_de, spacy_en, vocab_de, vocab_en)
    model = make_model(len(vocab_de), len(vocab_en), max_len=5000)
    model.load_state_dict(torch.load('{}_final.pt'.format(config['model_prefix'])))
    return model

load_trained_model()

In [48]:
def check_outputs(
    valid_dataloader,
    model,
    vocab_src,
    vocab_tgt,
    n_examples=15,
    pad_idx=2,
    eos_string="</s>",
):
    results = [()] * n_examples
    for idx in range(n_examples):
        print("\nExample %d ========\n" % idx)
        b = next(iter(valid_dataloader))
        rb = Batch(b[0], b[1], pad_idx)
        greedy_decode(model, rb.src, rb.src_mask, 64, 0)[0]

        src_tokens = [
            vocab_src.get_itos()[x] for x in rb.src[0] if x != pad_idx
        ]
        tgt_tokens = [
            vocab_tgt.get_itos()[x] for x in rb.tgt[0] if x != pad_idx
        ]

        print(
            "Source Text (Input)        : "
            + " ".join(src_tokens).replace("\n", "")
        )
        print(
            "Target Text (Ground Truth) : "
            + " ".join(tgt_tokens).replace("\n", "")
        )
        model_out = greedy_decode(model, rb.src, rb.src_mask, 72, 0)[0]
        model_txt = (
            " ".join(
                [vocab_tgt.get_itos()[x] for x in model_out if x != pad_idx]
            ).split(eos_string, 1)[0]
            + eos_string
        )
        print("Model Output               : " + model_txt.replace("\n", ""))
        results[idx] = (rb, src_tokens, tgt_tokens, model_out, model_txt)
    return results


def run_model_example(n_examples=5):
    global vocab_de, vocab_en, spacy_de, spacy_en
    config = {
        'epochs': 8,
        'model_prefix': 'multi30k',
        'accum_interval': 10,
        'batch_size': 1,
        'warmup': 3000,
        'base_lr': 1.0,
        'max_seqlen': 72,
    }

    print("Preparing Data ...")
    _, valid_dataloader = create_dataloader(
        config,
        spacy_de,
        spacy_en,
        vocab_de,
        vocab_en,
        torch.device("cpu"),
    )

    print("Loading Trained Model ...")

    model = make_model(len(vocab_de), len(vocab_en), N=6, max_len=5000)
    model.load_state_dict(
        torch.load("{}_final.pt".format(config['model_prefix']), map_location=torch.device("cpu"))
    )

    print("Checking Model Outputs:")
    example_data = check_outputs(
        valid_dataloader, model, vocab_de, vocab_en, n_examples=n_examples
    )
    return model, example_data


_ = run_model_example()

Preparing Data ...
Loading Trained Model ...




Checking Model Outputs:






Source Text (Input)        : <s> Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen </s>
Target Text (Ground Truth) : <s> A group of men are loading cotton onto a truck </s>
Model Output               : <s> A group of men are picking up a truck while being pulled up on a truck . </s>


Source Text (Input)        : <s> Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen </s>
Target Text (Ground Truth) : <s> A group of men are loading cotton onto a truck </s>
Model Output               : <s> A group of men on a truck is hanging out of a truck . </s>


Source Text (Input)        : <s> Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen </s>
Target Text (Ground Truth) : <s> A group of men are loading cotton onto a truck </s>
Model Output               : <s> A group of men are spinning out of a truck 's hanging on a truck . </s>


Source Text (Input)        : <s> Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen </s>
Target Text (Ground Truth) : <s> A group of