In [14]:
# Use below line for demo in external colabs
!pip install -q torchdata torchtext spacy==3.2 portalocker altair GPUtil
!python -m spacy download de_core_news_sm
!python -m spacy download en_core_web_sm
!pip install -q git+https://github.com/nikitakapitan/transflate.git

In [2]:
import warnings
warnings.filterwarnings('ignore')

import torch
import transflate

from transflate.data.token import load_tokenizers
from transflate.data.vocab import load_vocab

from transflate.data.dataloader import create_dataloaders

from transflate.main import make_model
from transflate.output import check_outputs

from torch.utils.data import DataLoader


%load_ext autoreload
%autoreload 2

In [3]:
# Load the two spaCy pipelines used for tokenisation (the German one is used
# for the source text and the English one for the target text further below).
spacy_de, spacy_en = load_tokenizers()
# Obtain the source/target vocabularies from those tokenizers; the helper
# prints the vocabulary sizes.
# NOTE(review): presumably vocab_src is German and vocab_tgt English — confirm
# against transflate.data.vocab.load_vocab.
vocab_src, vocab_tgt = load_vocab(spacy_de=spacy_de, spacy_en=spacy_en)

Finished.
Vocabulary sizes:
len: SRC=8315 TGT=6384


In [4]:
# Data-side settings: every source sequence is padded to this fixed length.
data_setup = {
    'max_padding' : 128,
}

# Hyper-parameters forwarded to `make_model`; they must match the ones the
# checkpoint was trained with, otherwise `load_state_dict` fails on mismatched
# keys/shapes.
architecture = {
        'src_vocab_len' : len(vocab_src),
        'tgt_vocab_len' : len(vocab_tgt),
        'N' : 6, # loop (presumably the number of stacked layers — confirm in make_model)
        'd_model' : 512, # emb
        'd_ff' : 2048,
        'h' : 8,
        'p_dropout' : 0.1
    }

model = make_model(
    src_vocab_len=architecture['src_vocab_len'],
    tgt_vocab_len=architecture['tgt_vocab_len'],
    N=architecture['N'],
    d_model=architecture['d_model'],
    d_ff=architecture['d_ff'],
    h=architecture['h'],
    dropout=architecture['p_dropout'],
    )

# This notebook only runs inference: switch to eval mode so dropout is
# disabled and decoding is deterministic. Done before loading the weights so
# the load result remains the cell's displayed value.
model.eval()

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load("../../multi30k_model_final.pt", map_location=torch.device("cpu"))
)



<All keys matched successfully>

In [5]:
# input text
text = "Vier Jungen spielen mit einem großen Hund im Hof"
print('Step.0 Raw text: ', text)
# The dataset is a single pair whose second element is an empty string
# (presumably (source, target) — only the source is needed for translation).
text = [(text, "")]


def tokenize_de(text):
    """Return the token strings produced by the German spaCy tokenizer."""
    return [tok.text for tok in spacy_de.tokenizer(text)]


def tokenize_en(text):
    """Return the token strings produced by the English spaCy tokenizer."""
    return [tok.text for tok in spacy_en.tokenizer(text)]


def collate_fn(x):
    """Turn a list of raw text pairs into a padded batch on the CPU."""
    return transflate.data.Batch.collate_batch(
        batch=x,
        src_pipeline=tokenize_de,
        tgt_pipeline=tokenize_en,
        src_vocab=vocab_src,
        tgt_vocab=vocab_tgt,
        device=torch.device("cpu"),
        max_padding=data_setup['max_padding'],
        pad_id=vocab_src.get_stoi()["<blank>"],
    )


text_dataloader = DataLoader(text, collate_fn=collate_fn)
# Only one batch exists; its first element is the processed source tensor.
first_batch = next(iter(text_dataloader))
print('Step.1 Processed text: \n', first_batch[0])

Step.0 Raw text:  Vier Jungen spielen mit einem großen Hund im Hof
Step.1 Processed text: 
 tensor([[  0, 128,  92,  58,  10,   6,  80,  33,  22, 433,   1,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
           2,   2]])


In [6]:
# check outputs
# Use the same padding id the collate function padded with (the "<blank>"
# entry of the source vocabulary) instead of a hard-coded number.
pad_idx = vocab_src.get_stoi()["<blank>"]
eos_string = "</s>"

b = next(iter(text_dataloader))
rb = transflate.data.Batch.Batch(src=b[0], tgt=b[1], pad=pad_idx)

# Decode in eval mode and without autograd: dropout would otherwise make the
# translation vary between runs, and gradients are not needed for inference.
model.eval()
with torch.no_grad():
    model_out = transflate.output.greedy_decode(model, rb.src, rb.src_mask, max_len=72, start_symbol=0)[0]

# Build the index -> token table once instead of once per generated token.
tgt_itos = vocab_tgt.get_itos()
# Drop padding ids, join the tokens and cut everything after the first "</s>".
model_txt = (" ".join(tgt_itos[x] for x in model_out if x != pad_idx).split(eos_string, 1)[0] + eos_string)

print('Model output: ', model_txt) # reference: '<s> Four boys are playing with a large dog in the yard . </s>'


Model output:  <s> Four boys are playing with a large dog in a yard . </s>


### Break-Down : Input Preprocessing

In [None]:
# Walk through the preprocessing of one sentence by hand, printing each stage.
# The name `text` is reused on purpose: the f-strings below echo it.
text = "Vier Jungen spielen mit einem großen Hund im Hof"
print(f"Step.0 {text=}")
print('*-*-* Start DATA part *-*-*')

# Step 1: raw string -> list of token strings (German spaCy tokenizer).
text = [tok.text for tok in spacy_de.tokenizer(text)]
print(f"Step.1 {text=}")

# Step 2: token strings -> vocabulary indices.
text = vocab_src(text)
print(f"Step.2 {text=}")

# Step 3: wrap the sequence in begin/end-of-sentence ids.
bs_id = torch.tensor([0])   # 0 index for <s>
eos_id = torch.tensor([1])  # 1 index for </s>
text = torch.cat((bs_id, torch.tensor(text), eos_id), dim=0)
print(f"Step.3 {text=}")

# Step 4: right-pad with the "<blank>" id up to the fixed sequence length.
n_pad = data_setup['max_padding'] - len(text)
text = [
    torch.nn.functional.pad(text, (0, n_pad), value=vocab_src.get_stoi()["<blank>"])
]
print(f"Step.4 {text[0][:42]=}")

# Stack the single padded sequence to add the batch dimension.
text = torch.stack(text, dim=0)
print(text.shape)
dl = DataLoader(text)
print('*-*-* End DATA part *-*-*')

### Break-Down : Encoder

In [7]:
# Hand-rolled walk-through of one encoder layer, printing the tensor shape
# after every step. All layers here are freshly initialised (not the trained
# weights), so only the shapes are meaningful.
src = next(iter(dl))

# Padding mask: True on real tokens, False on "<blank>" padding, with shape
# (batch, 1, n_tokens). It must be a boolean tensor derived from the token ids
# (i.e. before `src` is overwritten by embeddings): `masked_fill` only accepts
# boolean masks, and an all-ones mask would hide every position.
pad_id = vocab_src.get_stoi()["<blank>"]
src_mask = (src != pad_id).unsqueeze(-2)
print(f"Step.0 Input Tensor {src.shape=}")

src_emb = torch.nn.Embedding(architecture['src_vocab_len'], architecture['d_model'])
src = src_emb(src)
print(f"Step.1 Embedding {src.shape=}")

src_pos_enc = transflate.PositionalEncoding.PositionalEncoding(
    d_model=architecture['d_model'], dropout=architecture['p_dropout'])
src = src_pos_enc(src)
print(f"Step.2 Positional Encoder {src.shape=}")

residual_src1 = src.clone()  # start residual 1
print(f"Step.3 Residual Copy {residual_src1.shape=}")

### +-+-+-+ Multi-Headed Attention +-+-+-+

# Query-Key-Value
n_heads = architecture['h']
d_head = architecture['d_model'] // n_heads

q_fc = torch.nn.Linear(architecture['d_model'], architecture['d_model'])
k_fc = torch.nn.Linear(architecture['d_model'], architecture['d_model'])
v_fc = torch.nn.Linear(architecture['d_model'], architecture['d_model'])
final_fc = torch.nn.Linear(architecture['d_model'], architecture['d_model'])
# Created for completeness; attention dropout is deliberately not applied in
# this walk-through.
attn_dropuot = torch.nn.Dropout(p=architecture['p_dropout'])

attn_from = src
attn_to = src # self-attn
value = src
# Extra head dimension so the mask broadcasts over (batch, heads, query, key).
mask = src_mask.unsqueeze(1)

query = q_fc(attn_from)
key = k_fc(attn_to)
value = v_fc(value)
print(f"Step.4.1 Query-Key-Value {query.shape=} {key.shape=} {value.shape=}")

# split to n_heads: (batch, tokens, d_model) -> (batch, heads, tokens, d_head)
n_batches = src.size(0)
n_tokens = src.size(1)

query = query.view(n_batches, n_tokens, n_heads, d_head).transpose(1, 2)
key = key.view(n_batches, n_tokens, n_heads, d_head).transpose(1, 2)
value = value.view(n_batches, n_tokens, n_heads, d_head).transpose(1, 2)
print(f"Step.4.2 Split to {n_heads} heads\n \t{query.shape=}\n \t{key.shape=}\n \t{value.shape=}")

# Attention: scaled dot-product; padded key positions get a large negative
# score so that softmax assigns them (almost) zero weight.
key_transpose = key.transpose(-2, -1)
scores = torch.matmul(query, key_transpose) / (d_head**0.5)
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = scores.softmax(dim=-1)

headed_context = torch.matmul(p_attn, value)

# Merge the heads back: (batch, heads, tokens, d_head) -> (batch, tokens, d_model)
context = headed_context.transpose(1,2).contiguous().view(n_batches, n_tokens, n_heads * d_head)

src = final_fc(context)

print(f"Step.4.3 Attention \n \t{key_transpose.shape=}\n \t{scores.shape=}\n \t{p_attn.shape=}\n \
    \t{headed_context.shape=}\n \t{context.shape=}\n \t{src.shape=}")


norm1 = transflate.LayerNorm.LayerNorm(architecture['d_model'])
src = norm1(src)
print(f"Step.5 Layer Norm {src.shape=}")

src = residual_src1 + src    # end residual 1
residual_src2 = src.clone()  # start residual 2
print(f"Step.6 Residual Add {residual_src2.shape=}")


# Feed Forward
w_1 = torch.nn.Linear(architecture['d_model'], architecture['d_ff'])
w_2 = torch.nn.Linear(architecture['d_ff'], architecture['d_model'])
# Created for completeness; feed-forward dropout is deliberately not applied
# in this walk-through.
fc_dropuot = torch.nn.Dropout(p=architecture['p_dropout'])

fc1 = w_1(src).relu()
src = w_2(fc1)
print(f"Step.7 Feed Forward {src.shape=}")

norm2 = transflate.LayerNorm.LayerNorm(architecture['d_model'])
src = norm2(src)
print(f"Step.8 Layer Norm {src.shape=}")

src = residual_src2 + src    # end residual 2
print(f"Step.9 Final Shape: {src.shape=}")



Step.0 text='Vier Jungen spielen mit einem großen Hund im Hof'
*-*-* Start DATA part *-*-*
Step.1 text=['Vier', 'Jungen', 'spielen', 'mit', 'einem', 'großen', 'Hund', 'im', 'Hof']
Step.2 text=[128, 92, 58, 10, 6, 80, 33, 22, 433]
Step.3 text=tensor([  0, 128,  92,  58,  10,   6,  80,  33,  22, 433,   1])
Step.4 text[0][:42]=tensor([  0, 128,  92,  58,  10,   6,  80,  33,  22, 433,   1,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2])
torch.Size([1, 128])
*-*-* End DATA part *-*-*
=-=-=-= Start MODEL part =-=-=-=
Step.0 Input Tensor src.shape=torch.Size([1, 128])
Step.1 Embedding src.shape=torch.Size([1, 128, 512])
Step.2 Positional Encoder src.shape=torch.Size([1, 128, 512])
Step.3 Layer Norm residual_src1.shape=torch.Size([1, 128, 512])
Step.4.1 Query-Key-Value query.shape=torch.Size([1, 128, 512]) key.shape=torch.Size([1, 128, 512]) value.shape=torch.Size([1, 128

### Break-Down : Decoder