In [14]:
# Use below line for demo in external colabs
!pip install -q torchdata torchtext spacy==3.2 portalocker altair GPUtil
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm
!pip install -q git+https://github.com/nikitakapitan/transflate.git

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import transflate

from transflate.data.token import load_tokenizers
from transflate.data.vocab import load_vocab

from transflate.data.dataloader import create_dataloaders
from transflate.main import make_model
from transflate.output import check_outputs

from torch.utils.data import DataLoader


%load_ext autoreload
%autoreload 2

In [2]:
# spaCy pipelines: German for the source side, English for the target side.
(spacy_de, spacy_en) = load_tokenizers()
# Source (German) and target (English) vocabularies built with those pipelines.
(vocab_src, vocab_tgt) = load_vocab(
    spacy_de=spacy_de,
    spacy_en=spacy_en,
)

Finished.
Vocabulary sizes:
len: SRC=8315 TGT=6384


In [3]:
# Preprocessing settings shared by the cells below.
data_setup = {
    'max_padding': 128,  # length every token-id sequence is padded to
}

# Model hyper-parameters passed to transflate's make_model.
architecture = {
    'src_vocab_len': len(vocab_src),
    'tgt_vocab_len': len(vocab_tgt),
    'N': 6,            # number of stacked layers ("loop")
    'd_model': 512,    # embedding / model width
    'd_ff': 2048,      # presumably the feed-forward inner width -- confirm in make_model
    'h': 8,            # presumably the number of attention heads -- confirm in make_model
    'p_dropout': 0.1,
}

# Trained weights, relative to this notebook's working directory.
MODEL_CHECKPOINT_PATH = "../../multi30k_model_final.pt"

model = make_model(
    src_vocab_len=architecture['src_vocab_len'],
    tgt_vocab_len=architecture['tgt_vocab_len'],
    N=architecture['N'],
    d_model=architecture['d_model'],
    d_ff=architecture['d_ff'],
    h=architecture['h'],
    dropout=architecture['p_dropout'],
)

# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
load_result = model.load_state_dict(
    torch.load(MODEL_CHECKPOINT_PATH, map_location=torch.device("cpu"))
)
# This notebook only runs inference: nn.Module starts in training mode, so switch
# dropout off or the decoding cells below produce nondeterministic output.
model.eval()
load_result



<All keys matched successfully>

In [4]:
# Input: one German sentence, wrapped as a (source, target) pair. The target side
# is left empty because only the source is translated here.
raw_text = "Vier Jungen spielen mit einem großen Hund im Hof"
print('Step.0 Raw text: ', raw_text)
text_pairs = [(raw_text, "")]


def tokenize_de(text):
    """Split German text into token strings with spaCy's tokenizer component only."""
    return [tok.text for tok in spacy_de.tokenizer(text)]


def tokenize_en(text):
    """Split English text into token strings with spaCy's tokenizer component only."""
    return [tok.text for tok in spacy_en.tokenizer(text)]


def collate_fn(batch):
    """Collate a list of (src, tgt) string pairs into padded id tensors via transflate's collate_batch."""
    return transflate.data.Batch.collate_batch(
        batch=batch,
        src_pipeline=tokenize_de,
        tgt_pipeline=tokenize_en,
        src_vocab=vocab_src,
        tgt_vocab=vocab_tgt,
        device=torch.device("cpu"),
        max_padding=data_setup['max_padding'],
        pad_id=vocab_src.get_stoi()["<blank>"],
    )


text_dataloader = DataLoader(text_pairs, collate_fn=collate_fn)
# Only the first batch is needed; avoid materializing the whole loader.
print('Step.1 Processed text: \n', next(iter(text_dataloader))[0])

Step.0 Raw text:  Vier Jungen spielen mit einem großen Hund im Hof
Step.1 Processed text: 
 tensor([[  0, 128,  92,  58,  10,   6,  80,  33,  22, 433,   1,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2]])


In [5]:
# Check outputs: greedily decode the sentence batch built in the previous cell.
pad_idx = vocab_src.get_stoi()["<blank>"]  # same id collate_fn padded with
assert vocab_tgt.get_stoi()["<blank>"] == pad_idx, "src/tgt vocabularies disagree on the padding id"
bos_idx = vocab_tgt.get_stoi()["<s>"]      # decoding starts from the <s> token
eos_string = "</s>"
max_decode_len = 72

b = next(iter(text_dataloader))
rb = transflate.data.Batch.Batch(src=b[0], tgt=b[1], pad=pad_idx)

# Inference: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    model_out = transflate.output.greedy_decode(
        model, rb.src, rb.src_mask, max_len=max_decode_len, start_symbol=bos_idx
    )[0]

# Map ids back to tokens, drop padding, and cut everything after the first </s>.
tgt_itos = vocab_tgt.get_itos()
model_txt = (
    " ".join(tgt_itos[x] for x in model_out if x != pad_idx).split(eos_string, 1)[0]
    + eos_string
)

print('Model output: ', model_txt)  # '<s> Four boys are playing with a large dog in the yard . </s>'


Model output:  <s> Four boys are playing with a large dog in the yard . </s>


# Breakdown: `run_model_example`, step by step

In [27]:
# Reproduce collate_batch by hand for the source sentence, one stage per variable.
raw_text = "Vier Jungen spielen mit einem großen Hund im Hof"
print(f"Step.0 {raw_text=}")
tokens = [tok.text for tok in spacy_de.tokenizer(raw_text)]
print(f"Step.1 {tokens=}")
token_ids = vocab_src(tokens)
print(f"Step.2 {token_ids=}")

# Look special ids up in the vocabulary instead of hard-coding them (<s>=0, </s>=1 here).
src_stoi = vocab_src.get_stoi()
bs_id = torch.tensor([src_stoi["<s>"]])
eos_id = torch.tensor([src_stoi["</s>"]])
framed_ids = torch.cat([bs_id, torch.tensor(token_ids), eos_id])
print(f"Step.3 {framed_ids=}")

# Right-pad with the <blank> id up to max_padding.
padded_ids = torch.nn.functional.pad(
    input=framed_ids,
    pad=(0, data_setup['max_padding'] - len(framed_ids)),
    value=src_stoi["<blank>"],
)
print(f"Step.4 {padded_ids=}")


Step.0 text='Vier Jungen spielen mit einem großen Hund im Hof'
Step.1 text=['Vier', 'Jungen', 'spielen', 'mit', 'einem', 'großen', 'Hund', 'im', 'Hof']
Step.2 text=[128, 92, 58, 10, 6, 80, 33, 22, 433]
Step.3 text=tensor([  0, 128,  92,  58,  10,   6,  80,  33,  22, 433,   1])
Step.4 text=tensor([  0, 128,  92,  58,  10,   6,  80,  33,  22, 433,   1,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
        

In [33]:
# Quick check: a list wrapping one 5-element sequence becomes a (1, 5) tensor.
torch.Tensor([[*range(1, 6)]]).shape

torch.Size([1, 5])

## 3 - Sub-step 1/3: `output.greedy_decode`

In [23]:
# NOTE(review): was `from transformers.helper import ...`. `transformers` is not installed by
# the setup cell, and the name would resolve to Hugging Face `transformers` if present.
# The project package is `transflate`; confirm the helper module path there.
from transflate.helper import following_mask

# Readable names for known dimension sizes, so printed shapes show which config value
# each axis comes from. Unknown sizes are printed as plain integers.
# NOTE: if two config values ever coincide, the later entry wins.
mapa = {
    data_setup['max_padding']: 'max_padding',
    architecture['d_model']: 'd_model',
}


def named_shape(tensor):
    """Return ``tensor.shape`` as a list with known sizes replaced by their config names."""
    return [mapa.get(dim, dim) for dim in tensor.shape]


src = rb.src
print('src.shape=', named_shape(src))
src_mask = rb.src_mask
print('src_mask.shape=', named_shape(src_mask))
max_len = 72
start_symbol = 0

# Inference only: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    memory = model.encode(src, src_mask)
    print('memory.shape=', named_shape(memory))

    # The decoder input starts as a single start token.
    tgt = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)

    i = 0  # first iteration of greedy_decode's loop over range(max_len - 1)
    out = model.decode(memory, src_mask, tgt, following_mask(tgt.size(1)).type_as(src.data))
    print('out.shape=', named_shape(out))

src.shape= [1, 'max_padding']
src_mask.shape= [1, 1, 'max_padding']
memory.shape= [1, 'max_padding', 'd_model']
out.shape= [1, 1, 'd_model']
