In [14]:
# Use below line for demo in external colabs
!pip install -q torchdata torchtext spacy==3.2 portalocker altair GPUtil
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm
!pip install -q git+https://github.com/nikitakapitan/transflate.git

In [37]:
import warnings
warnings.filterwarnings('ignore')

import torch
import transflate

from transflate.data.token import load_tokenizers
from transflate.data.vocab import load_vocab

from transflate.data.dataloader import create_dataloaders
from transflate.main import make_model
from transflate.output import check_outputs

from torch.utils.data import DataLoader


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Get the spaCy German and English tokenizers, then derive the source (German)
# and target (English) vocabularies from them.
spacy_de, spacy_en = load_tokenizers()
vocab_src, vocab_tgt = load_vocab(
    spacy_en=spacy_en,
    spacy_de=spacy_de,
)

Building German Vocabulary ...
Building English Vocabulary ...
Finished.
Vocabulary sizes:
len: SRC=8315 TGT=6384


In [14]:
# Data and model hyper-parameters. `architecture` has to match the configuration
# the checkpoint was trained with, otherwise `load_state_dict` reports mismatched keys.
data_setup = {
    'max_padding': 128,  # sequences are padded to this many token ids
}

architecture = {
    'src_vocab_len': len(vocab_src),
    'tgt_vocab_len': len(vocab_tgt),
    'N': 6,           # number of repeated layers ("loop")
    'd_model': 512,   # embedding size
    'd_ff': 2048,
    'h': 8,
    'p_dropout': 0.1,
}

# Pretrained checkpoint, relative to the notebook's working directory.
MODEL_PATH = "../../multi30k_model_final.pt"

model = make_model(
    src_vocab_len=architecture['src_vocab_len'],
    tgt_vocab_len=architecture['tgt_vocab_len'],
    N=architecture['N'],
    d_model=architecture['d_model'],
    d_ff=architecture['d_ff'],
    h=architecture['h'],
    dropout=architecture['p_dropout'],
)

load_result = model.load_state_dict(
    torch.load(MODEL_PATH, map_location=torch.device("cpu"))
)
# The model is only used for inference in this notebook: switch to eval mode so
# dropout (p_dropout above) does not make the decoded translations non-deterministic.
model.eval()
load_result



<All keys matched successfully>

In [35]:
# Input: one hand-written German sentence. It is wrapped as a one-element dataset
# of (source, target) pairs with an empty target, so the same collate function
# used for the real data can tokenize, numericalize and pad it.
raw_text = "Vier Jungen spielen mit einem großen Hund im Hof"
print('Step.0 Raw text: ', raw_text)
text_pairs = [(raw_text, "")]


def tokenize_de(text):
    """Split German text into a list of spaCy token strings."""
    return [tok.text for tok in spacy_de.tokenizer(text)]


def tokenize_en(text):
    """Split English text into a list of spaCy token strings."""
    return [tok.text for tok in spacy_en.tokenizer(text)]


def collate_fn(batch):
    """Collate a list of (src, tgt) string pairs via transflate.collate_batch (CPU, padded to max_padding)."""
    return transflate.collate_batch(
        batch=batch,
        src_pipeline=tokenize_de,
        tgt_pipeline=tokenize_en,
        src_vocab=vocab_src,
        tgt_vocab=vocab_tgt,
        device=torch.device("cpu"),
        max_padding=data_setup['max_padding'],
        pad_id=vocab_src.get_stoi()["<blank>"],
    )


text_dataloader = DataLoader(text_pairs, collate_fn=collate_fn)
# Element 0 of the collated batch is the source id tensor.
print('Step.1 Processed text: \n', next(iter(text_dataloader))[0])

Step.0 Raw text:  Vier Jungen spielen mit einem großen Hund im Hof
Step.1 Processed text: 
 tensor([[  0, 128,  92,  58,  10,   6,  80,  33,  22, 433,   1,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2]])


In [41]:
# Check outputs: greedy-decode the sentence and map the ids back to English tokens.
# Special-token ids are read from the vocabularies instead of being hard-coded.
pad_idx = vocab_tgt.get_stoi()["<blank>"]
assert pad_idx == vocab_src.get_stoi()["<blank>"], "src/tgt vocabularies disagree on the <blank> id"
start_symbol = vocab_tgt.get_stoi()["<s>"]
eos_string = "</s>"
max_len = 72  # maximum decode length passed to greedy_decode

b = next(iter(text_dataloader))
rb = transflate.data.Batch.Batch(src=b[0], tgt=b[1], pad=pad_idx)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    model_out = transflate.output.greedy_decode(
        model, rb.src, rb.src_mask, max_len=max_len, start_symbol=start_symbol
    )[0]

# Drop padding, then cut everything after the first end-of-sentence marker.
model_txt = (
    " ".join([vocab_tgt.get_itos()[x] for x in model_out if x != pad_idx]).split(eos_string, 1)[0]
    + eos_string
)

print('Model output: ', model_txt) # '<s> Four boys are playing with a large dog in the yard . </s>'


Model output:  <s> Four boys are playing with a large dog in the yard . </s>


# Breakdown: `run_model_example`, step by step
## Step 1/3 : create dataloader

In [21]:
# Human-readable names for tensor dimensions, used when printing shapes below.
# The keys come from the config cells, so they stay correct if the config changes.
class _DimNames(dict):
    """dict that returns the dimension size itself when it has no known name."""

    def __missing__(self, key):
        # Unknown sizes (e.g. a target length > 1 in later decoding steps) print
        # as plain numbers instead of raising KeyError.
        return key


mapa = _DimNames({
    1: 1,
    data_setup['max_padding']: 'max_padding',
    architecture['d_model']: 'd_model',
})

In [12]:
# Build the train/validation dataloaders with batch_size=1 (one example per batch).
# Only the validation loader is used below. The train loader is bound to a named
# throwaway rather than `_`, so IPython's `_` (last output) is not clobbered.
_train_dataloader, valid_dataloader = create_dataloaders(
    device=torch.device("cpu"),
    vocab_src=vocab_src,
    vocab_tgt=vocab_tgt,
    spacy_de=spacy_de,
    spacy_en=spacy_en,
    batch_size=1,
    max_padding=data_setup['max_padding'],
    is_distributed=False,
)

## Step 2/3 : create the model and load its trained weights

In [13]:
# output.run_model_example step 2/3 : create the model and load its trained weights

from transflate.main import make_model


# Only N is given explicitly; presumably make_model's defaults are d_model=512, d_ff=2048, h=8 — TODO confirm.
model = make_model(len(vocab_src), len(vocab_tgt), N=6)
load_result = model.load_state_dict(
    torch.load("../../multi30k_model_final.pt", map_location=torch.device("cpu"))
)
model.eval()  # inference only: disable dropout so decoding is deterministic
load_result

<All keys matched successfully>

## Step 3/3 : break down `check_outputs`

In [22]:
# output.run_model_example step 3/3 : check_outputs, unrolled for the first example
# (check_outputs iterates over several examples; only the first one is shown here).

from transflate.data.Batch import Batch

pad_idx = vocab_tgt.get_stoi()["<blank>"]
eos_string = "</s>"

b = next(iter(valid_dataloader))  # first validation batch (a single example)
rb = Batch(src=b[0], tgt=b[1], pad=pad_idx)

# Convert ids back to tokens, skipping padding.
src_tokens = [vocab_src.get_itos()[x] for x in rb.src[0] if x != pad_idx]
tgt_tokens = [vocab_tgt.get_itos()[x] for x in rb.tgt[0] if x != pad_idx]

print(f"Source text (Input) {src_tokens}")
print(f"Target Text (Ground Truth) {tgt_tokens}")

Source text (Input) ['<s>', 'Ein', 'Mann', 'im', 'mittleren', 'Alter', 'legt', 'am', 'Knie', 'eines', 'jüngeren', '<unk>', ',', 'der', 'auf', 'einem', '<unk>', 'sitzt', ',', 'einen', '<unk>', 'an', '.', '</s>']
Target Text (Ground Truth) ['<s>', 'A', 'middle', '-', 'aged', 'man', 'is', 'taping', 'up', 'the', 'knee', 'of', 'a', 'younger', 'football', 'player', 'who', 'is', 'sitting', 'on', 'a', '<unk>', 'table', '.', '</s>']


## Step 3, sub-step 1/3 : `output.greedy_decode`

In [23]:
from transflate.helper import following_mask

# First step of greedy decoding, unrolled for the batch `rb` built in the previous cell.
src = rb.src
print('src.shape=', [mapa[e] for e in src.shape])
src_mask = rb.src_mask
print('src_mask.shape=', [mapa[e] for e in src_mask.shape])
max_len = 72  # greedy decoding runs over range(max_len - 1) steps; only step 0 is shown
start_symbol = vocab_tgt.get_stoi()["<s>"]

# Inference only: no autograd graph is needed.
with torch.no_grad():
    memory = model.encode(src, src_mask)
    print('memory.shape=', [mapa[e] for e in memory.shape])

    # Decoder input so far: just the start symbol, shape (1, 1), same dtype as src.
    tgt = torch.full((1, 1), start_symbol).type_as(src)

    # Step 0: decode with a mask over the (length-1) target prefix.
    out = model.decode(memory, src_mask, tgt, following_mask(tgt.size(1)).type_as(src))
print('out.shape=', [mapa[e] for e in out.shape])

src.shape= [1, 'max_padding']
src_mask.shape= [1, 1, 'max_padding']
memory.shape= [1, 'max_padding', 'd_model']
out.shape= [1, 1, 'd_model']
