In [1]:
# Enable IPython's autoreload extension in mode 2: all imported modules are
# reloaded before each cell executes, so edits to local source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.insert(0, "..")

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from transformers import AutoTokenizer

from model import UniversalTransformer

%matplotlib inline

In [4]:
# Load the pretrained GPT-2 tokenizer; GPT-2 uses BPE (byte-pair encoding).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the end-of-sequence token as the padding token
# (presumably because this tokenizer defines no pad token of its own -- TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

In [9]:
# Build the Universal Transformer. Source and target vocabularies are both
# sized from the same tokenizer, so the two vocab sizes are identical.
# NOTE(review): these hyperparameters must match the ones the checkpoint was
# trained with, otherwise load_state_dict below will fail on shape mismatches.
model = UniversalTransformer(
    source_vocab_size=tokenizer.vocab_size,
    target_vocab_size=tokenizer.vocab_size,
    d_model=512,
    n_heads=8,
    d_feedforward=2048,
    max_seq_len=100,
    max_time_step=10,
    halting_thresh=0.8
)

# Switch to evaluation mode so that train-only layers (e.g. dropout, if the
# model uses any) are disabled when the loaded weights are used for inference.
# Called before loading so that load_state_dict stays the cell's last
# expression and its key-matching report is still displayed.
model.eval()

# load checkpoint
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source. map_location="cpu" keeps
# all tensors on the CPU regardless of the device they were saved from.
checkpoint_path = "../checkpoints/latest.pt"
cp = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(cp["model_state_dict"])

<All keys matched successfully>

In [5]:
# Hand-picked sentence pairs used as demo inputs. Each entry holds a German
# sentence under "de" and an English sentence under "en".
samples = [
    {
        "de": "Im Parlament besteht der Wunsch nach einer Aussprache im Verlauf dieser Sitzungsperiode in den nächsten Tagen.",
        "en": "You have requested a debate on this subject in the course of the next few days, during this part-session.",
    },
    {
        "de": "Heute möchte ich Sie bitten - das ist auch der Wunsch einiger Kolleginnen und Kollegen -, allen Opfern der Stürme, insbesondere in den verschiedenen Ländern der Europäischen Union, in einer Schweigeminute zu gedenken.",
        "en": "In the meantime, I should like to observe a minute' s silence, as a number of Members have requested, on behalf of all the victims concerned, particularly those of the terrible storms, in the various countries of the European Union.",
    },
    {
        "de": "Ich bitte Sie, sich zu einer Schweigeminute zu erheben.",
        "en": "Please rise, then, for this minute' s silence.",
    },
]

In [16]:
# Take the English sentence of the first sample as the model input.
source = samples[0]["en"]

# Tokenize it as a batch of one, with the end-of-sequence token appended to
# the text. Anything longer than 100 tokens is truncated, and the ids are
# returned as a PyTorch tensor.
source_ids = tokenizer(
    [f"{source}{tokenizer.eos_token}"],
    truncation=True,
    max_length=100,
    return_tensors="pt",
)["input_ids"]

# Display the raw text next to its token ids.
(source, source_ids)

('You have requested a debate on this subject in the course of the next few days, during this part-session.',
 tensor([[ 1639,   423,  9167,   257,  4384,   319,   428,  2426,   287,   262,
           1781,   286,   262,  1306,  1178,  1528,    11,  1141,   428,   636,
             12, 29891,    13, 50256]]))