In [14]:
# Use below line for demo in external colabs
!pip install -q torchdata==0.3.0 torchtext==0.12 spacy==3.2 GPUtil
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm
!pip install -q git+https://github.com/nikitakapitan/transformers.git

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch

from transformers.data.token import load_tokenizers
from transformers.data.vocab import load_vocab

from transformers.main import make_model
from transformers.output import check_outputs

from torch.utils.data import DataLoader
from transformers.data.token import tokenize
from transformers.data.Batch import collate_batch

from google.colab import drive
drive.mount('/content/drive')

%load_ext autoreload
%autoreload 2

In [None]:
!cp -r drive/MyDrive/multi30k_model_final.pt multi30k_model_final.pt

In [2]:
# Tokenizers first, then the source/target vocabularies built on top of them.
# (presumably German -> source, English -> target, going by the names; confirm
# against load_tokenizers / load_vocab.)
spacy_de, spacy_en = load_tokenizers()
vocab_src, vocab_tgt = load_vocab(
    spacy_de=spacy_de,
    spacy_en=spacy_en,
)

Finished.
Vocabulary sizes:
len: SRC=8315 TGT=6384


In [3]:
# The sentence to translate -- edit this to try your own input.
YOUR_GERMAN_SENTENCE = "Drei Hunde in schwarzen Jacken kaufen Milch in der Innenstadt"

# Batching options read by collate_fn.
data_setup = {"max_padding": 128}

# Hyperparameters forwarded to make_model. The weights are later restored with
# load_state_dict, so these values have to describe the same network that
# produced the checkpoint.
architecture = {
    "src_vocab_len": len(vocab_src),
    "tgt_vocab_len": len(vocab_tgt),
    "N": 6,  # "loop" -- presumably the number of stacked layers; confirm in make_model
    "d_model": 512,  # embedding size
    "d_ff": 2048,
    "h": 8,
    "p_dropout": 0.1,
}

def collate_fn(batch):
    """Collate function for the DataLoader: delegate to ``collate_batch``.

    Binds the notebook-level tokenizers, vocabularies and padding settings so
    that the DataLoader only has to supply ``batch``. Returns whatever
    ``collate_batch`` returns.
    """

    def tokenize_src(text):
        # Source-side text goes through the spacy_de tokenizer.
        return tokenize(text, spacy_de)

    def tokenize_tgt(text):
        # Target-side text goes through the spacy_en tokenizer.
        return tokenize(text, spacy_en)

    # Padding index is looked up from the source vocabulary's "<blank>" entry.
    blank_id = vocab_src.get_stoi()["<blank>"]

    return collate_batch(
        batch=batch,
        src_pipeline=tokenize_src,
        tgt_pipeline=tokenize_tgt,
        src_vocab=vocab_src,
        tgt_vocab=vocab_tgt,
        device=None,
        max_padding=data_setup['max_padding'],
        pad_id=blank_id,
    )

# One-item dataset of (source, target) text; there is no reference translation,
# so the target slot carries the placeholder string 'None'.
phrase = DataLoader(
    dataset=[(YOUR_GERMAN_SENTENCE, 'None')],
    collate_fn=collate_fn,
)

# Rebuild the network from the hyperparameters in `architecture`.
model = make_model(
    src_vocab_len=architecture['src_vocab_len'],
    tgt_vocab_len=architecture['tgt_vocab_len'],
    N=architecture['N'],
    d_model=architecture['d_model'],
    d_ff=architecture['d_ff'],
    h=architecture['h'],
    dropout=architecture['p_dropout'],
)

# Restore the trained weights on CPU so the demo also runs without a GPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(
    torch.load("multi30k_model_final.pt", map_location=torch.device("cpu"))
)

# This notebook only runs inference: switch to eval mode so dropout is disabled
# and repeated translations of the same sentence are deterministic.
model.eval()

# Run the model on the sentence. Gradients are not needed for inference, so
# autograd tracking is turned off for the duration of the call.
# NOTE(review): `phrase` holds a single sentence while n_examples=2 -- confirm
# how check_outputs draws examples (it may process the same sentence twice).
with torch.no_grad():
    example_data = check_outputs(
        phrase, model, vocab_src, vocab_tgt, n_examples=2
    )