In [None]:
import torch
from torch.nn import functional as F
from transformer import Transformer, BOS_IDX, EOS_IDX, block_size, device

In [None]:
# Build the model and restore the trained weights for inference.
# (An earlier run loaded 'checkpoints-t0/checkpoint-ep05.pt' instead.)
model = Transformer()
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# on a GPU machine still loads when only a CPU is available.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust; consider weights_only=True on PyTorch versions that support it.
model.load_state_dict(torch.load('checkpoint-ep05.pt', map_location=device))
model.to(device)
# Switch to evaluation mode before running any generation.
model.eval()

In [None]:
from tokenizers import Tokenizer
# Language code -> Tokenizer, each restored from its serialized JSON file.
tokenizers = dict()
for lang in ('en', 'es'):
    tokenizer_path = f"tokenizer-{lang}.json"
    tokenizers[lang] = Tokenizer.from_file(tokenizer_path)

In [None]:
def tensor_transform(token_ids):
    """Wrap a token-id sequence in BOS/EOS markers and add a batch dimension.

    Args:
        token_ids: flat sequence of integer token ids; may be empty.

    Returns:
        A ``torch.long`` tensor of shape ``(1, len(token_ids) + 2)`` laid out
        as ``[BOS_IDX, *token_ids, EOS_IDX]``.
    """
    # Pin the dtype on every piece: torch.tensor([]) defaults to float32, which
    # would otherwise promote the concatenated result away from long when
    # token_ids is empty.
    bos = torch.tensor([BOS_IDX], dtype=torch.long)
    ids = torch.tensor(token_ids, dtype=torch.long)
    eos = torch.tensor([EOS_IDX], dtype=torch.long)
    # torch.cat only accepts the 1-D pieces above, so the result is always 1-D
    # and the batch dimension can be added unconditionally.
    return torch.cat((bos, ids, eos)).unsqueeze(0)

In [None]:
def generate(model, idx_enc, greedy=False, max_new_tokens=None):
    """Autoregressively decode target token ids for a batch of encoder inputs.

    Every sequence starts with ``BOS_IDX``; one token is appended per step and
    the whole running target sequence is fed back to the model each time. Once
    a sequence has emitted ``EOS_IDX`` every later position is forced to
    ``EOS_IDX``, and decoding stops early when all sequences have ended.

    Args:
        model: module called as ``model(idx, idx_enc)`` and returning a
            ``(logits, _)`` pair; ``logits[:, -1, :]`` is used as the
            next-token scores.
        idx_enc: encoder-side token ids. Only ``idx_enc.shape[0]`` (the batch
            size B) is inspected here; it is passed to the model unchanged.
            Assumed to already live on the model's device -- TODO confirm.
        greedy: if True pick the arg-max token at each step, otherwise sample
            from the softmax distribution.
        max_new_tokens: maximum number of decoding steps. ``None`` (default)
            uses ``block_size``. Larger values make the running sequence
            longer than ``block_size``, which the model presumably does not
            support -- verify before raising it.

    Returns:
        A ``torch.long`` tensor of shape ``(B, T)`` with
        ``T <= max_new_tokens + 1`` whose first column is ``BOS_IDX``.
    """
    if max_new_tokens is None:
        max_new_tokens = block_size
    B = idx_enc.shape[0]
    device = next(model.parameters()).device
    idx = torch.full((B, 1), BOS_IDX, dtype=torch.long, device=device)
    # Inference only: model.eval() does not turn autograd off, so disable it
    # here to avoid building a graph over every decoding step.
    with torch.no_grad():
        for step in range(max_new_tokens):
            # get the predictions and keep only the last time step
            logits, _ = model(idx, idx_enc)
            logits = logits[:, -1, :]  # (B, C)
            if greedy:
                idx_next = torch.argmax(logits, dim=-1)  # (B,)
            else:
                probs = F.softmax(logits, dim=-1)  # (B, C)
                # squeeze only the sample dim so B == 1 keeps its batch axis
                idx_next = torch.multinomial(probs, num_samples=1).squeeze(-1)  # (B,)
            # everything after a sequence's first EOS stays EOS
            finished = idx[:, -1] == EOS_IDX
            idx_next = idx_next.masked_fill(finished, EOS_IDX)
            # append the chosen token to the running sequence
            idx = torch.cat((idx, idx_next.unsqueeze(1)), dim=1)  # (B, T+1)
            if torch.all(idx[:, -1] == EOS_IDX):
                break
    return idx

In [None]:
# English example sentences; the cell below encodes each one with the 'en'
# tokenizer, runs it through the model and decodes the result with the 'es'
# tokenizer.
sents = [
    "What a beautiful day!",
    "This is the first deep learning model that I built from scratch.",
    "You'll have to address any questions to my commanding officer.",
    "The results of the project are clear.",
    "I mean, my dad does business with them, or he raised money for them.",
]

In [None]:
# Translate every example with greedy decoding and print source/target pairs.
for sent in sents:
    # Encode the source text and frame it as a single-sentence batch.
    encoding = tokenizers['en'].encode(sent)
    toks = tensor_transform(encoding.ids).to(device)
    toks_es = generate(model, toks, greedy=True)
    # Only one sentence per batch, so decode the first (and only) row.
    sent_es = tokenizers['es'].decode(toks_es[0].tolist())
    # Source line, translation line, then a blank separator line.
    print(sent, sent_es, "", sep="\n")