In [1]:
# Use below line for demo in external colabs
# !pip install -q torchdata==0.3.0 torchtext==0.12 spacy==3.2 altair GPUtil
# !python -m spacy download de_core_news_sm
# !python -m spacy download en_core_web_sm
# !pip install -q git+https://github.com/nikitakapitan/transformers.git

In [1]:
import torch
import torch.nn as nn
import math
from copy import deepcopy
from transformers.main import make_model

from transformers.Embeddings import Embeddings
from transformers.PositionalEncoding import PositionalEncoding
from transformers.MultiHeadedAttention import MultiHeadedAttention
from transformers.helper import following_mask
from transformers.LayerNorm import LayerNorm

%load_ext autoreload
%autoreload 2

In [3]:
D_MODEL = 32 # should be 2^k and k>= n_heads
D_FF = 2048
N = 2 #layers
H = 8 #heads

n = 10 #tokens in input
_dropout = 0.1
MAX_LEN = 5000

mapa = {1 : '1', 2 : '2', 3 : '3', 8 : 'h', 4 : 'd_head', 
        10 : 'n', 32 : 'emb', 2048 : 'd_ff' , 512: 'unkown'}


FRENCH = 11 #SOURCE : number of all words in French  vocab
ENGLISH  = 11 #TARGET : number of all words in English vocab

# INITIALIZE
test_model = make_model(src_vocab_len=FRENCH, tgt_vocab_len=ENGLISH, N=N, d_model=D_MODEL, d_ff = D_FF, h=H)
test_model.eval() # switch to eval (no DropOut & BatchNorm learning)

# INFERENCE
src = torch.LongTensor([range(1,n+1)])
src_mask = torch.ones(1,1,n)
print(src.shape, src_mask.shape)

torch.Size([1, 10]) torch.Size([1, 1, 10])


# Encode

In [4]:
memory = test_model.encode(src, src_mask)
print('memory.shape=', [mapa[e] for e in memory.shape])

memory.shape= ['1', 'n', 'emb']


# Down-brake **Model.encode(src, src_mask)**

In [5]:
src.shape, src.type

(torch.Size([1, 10]), <function Tensor.type>)

## Step 1 Embeddings

In [5]:
src_emb = Embeddings(vocab_len=FRENCH, d_model=D_MODEL)
src = src_emb(src)
src.shape

torch.Size([1, 10, 32])

## Step 2 PositionEncoding(d_model, max_len)

In [6]:
src_pos_enc = PositionalEncoding(d_model=D_MODEL, dropout=_dropout, max_len=MAX_LEN)
src = src_pos_enc(src)
src.shape #input_3

torch.Size([1, 10, 32])

### >> Start Step 3 : ResidualConnection

In [7]:
residual_src = src

## Step 4 MultiHeadedAttention

attn = MultiHeadedAttention(h=H, d_model=D_MODEL)
### Step 4.1 Query, Key, Value

In [8]:
# MultiHeadedAttention.__init__
d_head = D_MODEL // H
h = H
q_fc = nn.Linear(D_MODEL, D_MODEL)
k_fc = nn.Linear(D_MODEL, D_MODEL)
v_fc = nn.Linear(D_MODEL, D_MODEL)
final_fc = nn.Linear(D_MODEL, D_MODEL)
dropout = nn.Dropout(p=_dropout)

attn_from = src #input_3
attn_to = src
value = src
mask = src_mask

# MultiHeadedAttention.forward : compute Query, Key, Value
mask = mask.unsqueeze(1)
print('mask shape=',mask.shape)
n_batches = src.size(0) # 1

print('src.shape=', [mapa[e] for e in src.shape])
query = q_fc(attn_from)
key = k_fc(attn_to)
value = v_fc(value)

print('query.shape=', [mapa[e] for e in query.shape])
print('key.shape=', [mapa[e] for e in key.shape])
print('value.shape=', [mapa[e] for e in value.shape])

mask shape= torch.Size([1, 1, 1, 10])
src.shape= ['1', 'n', 'emb']
query.shape= ['1', 'n', 'emb']
key.shape= ['1', 'n', 'emb']
value.shape= ['1', 'n', 'emb']


### Step 4.2 Split to H heads

In [9]:
# split data into H heads.
query = query.view(n_batches, n, h, d_head) .transpose(1, 2)
key = key.view(n_batches, n, h, d_head).transpose(1, 2)
value = value.view(n_batches, n, h, d_head).transpose(1, 2)
query.shape

torch.Size([1, 8, 10, 4])

### Step 4.3 Attention

In [10]:
# def attention
key_t = key.transpose(-2, -1)
print('query.shape=', [mapa[e] for e in query.shape])
print('key_t.shape=', [mapa[e] for e in key_t.shape])

scores = torch.matmul(query, key_t) / math.sqrt(d_head)
print('scores.shape=', [mapa[e] for e in scores.shape])

scores = scores.masked_fill(mask, -1e9)

p_attn = scores.softmax(dim=-1)
print('p_attn.shape=', [mapa[e] for e in p_attn.shape])

# if dropout is not None:
#     p_attn = dropout(p_attn)

print('value.shape=', [mapa[e] for e in value.shape])
headed_context = torch.matmul(p_attn, value)
print('headed_context.shape=', [mapa[e] for e in headed_context.shape])

context = headed_context.transpose(1,2).contiguous().view(n_batches, n, h * d_head)
print('context.shape=', [mapa[e] for e in context.shape])

src = final_fc(context) # output_4
print('output_4.shape=', [mapa[e] for e in src.shape])

query.shape= ['1', 'h', 'n', 'd_head']
key_t.shape= ['1', 'h', 'd_head', 'n']
scores.shape= ['1', 'h', 'n', 'n']
p_attn.shape= ['1', 'h', 'n', 'n']
value.shape= ['1', 'h', 'n', 'd_head']
headed_context.shape= ['1', 'h', 'n', 'd_head']
context.shape= ['1', 'n', 'emb']
output_4.shape= ['1', 'n', 'emb']


## Step 5 LayerNorm

In [11]:
norm = LayerNorm(D_MODEL)
src = norm(src) #output_5
print('output_5.shape=', [mapa[e] for e in src.shape])

output_5.shape= ['1', 'n', 'emb']


## Step 3 Residual Connectnion (output 3)

In [12]:
# end ResidualConnection
src = residual_src + src
print('output_3.shape=', [mapa[e] for e in src.shape])

output_3.shape= ['1', 'n', 'emb']


### >> start Step 6 Residual Connecction (input 6)

In [13]:
residual_src = src

## Step 7 PositionalWiseFeedForward

In [14]:
w_1 = nn.Linear(D_MODEL, D_FF)
w_2 = nn.Linear(D_FF, D_MODEL)

fc1 = w_1(src).relu() # + DropOut
print('fc1.shape=', [mapa[e] for e in fc1.shape])
src = w_2(fc1) # output_7
print('output_7.shape=', [mapa[e] for e in src.shape])

fc1.shape= ['1', 'n', 'd_ff']
output_7.shape= ['1', 'n', 'emb']


## Step 8 LayerNorm

In [15]:
norm = LayerNorm(D_MODEL)
src = norm(src) #output_8
print('output_8.shape=', [mapa[e] for e in src.shape])

output_8.shape= ['1', 'n', 'emb']


## Step 6 ResidualConnection

In [16]:
src = residual_src + src
print('output_6.shape=', [mapa[e] for e in src.shape])

output_6.shape= ['1', 'n', 'emb']


## Step 9 : Repeat N times loop Step3-Step8

# Decode

In [17]:
# tgt = torch.zeros(1, 1).type(torch.LongTensor)

# for _ in range(9):
#     tgt_mask = following_mask(tgt.size(1)).type_as(src.data)

#     out = test_model.decode(memory, src_mask, tgt, tgt_mask)

#     prob = test_model.generator(out[:, -1])

#     next_word = torch.argmax(prob, dim=1).unsqueeze(0)

#     tgt=torch.cat([tgt, next_word],dim=1)

# print("Example Untrained Model Prediction:", tgt)

# Down-brake **Model.decode(memory, src_mask, tgt, tgt_mask)**

## Step 10 : initialize tgt

In [18]:
target = torch.zeros(1, 1).type(torch.LongTensor)

tgt_mask = following_mask(target.size(1)).type_as(src.data)
target.shape

torch.Size([1, 1])

## Step 11 : tgt_emb

In [20]:
tgt_emb = Embeddings(vocab_len=ENGLISH, d_model=D_MODEL)
tgt = tgt_emb(target)
tgt.shape

torch.Size([1, 1, 32])

## Step 12 : pos_emb

In [21]:
tgt_pos_enc = PositionalEncoding(d_model=D_MODEL, dropout=_dropout, max_len=MAX_LEN)
tgt = tgt_pos_enc(tgt)
tgt.shape # input_13


torch.Size([1, 1, 32])

### >>Start Step 13

In [22]:
residual_tgt = tgt
print('residual_tgt.shape=', [mapa[e] for e in residual_tgt.shape])

residual_tgt.shape= ['1', '1', 'emb']


## Step 14 : Masked Multi-head Attention

### Step 14.1 Query, Key, Value

In [23]:
# MultiHeadedAttention.__init__
d_head = D_MODEL // H
h = H
q_fc = nn.Linear(D_MODEL, D_MODEL)
k_fc = nn.Linear(D_MODEL, D_MODEL)
v_fc = nn.Linear(D_MODEL, D_MODEL)
final_fc = nn.Linear(D_MODEL, D_MODEL)
dropout = nn.Dropout(p=_dropout)

attn_from = tgt # input_13
attn_to = tgt
value = tgt
mask = tgt_mask

# MultiHeadedAttention.forward : compute Query, Key, Value
mask = mask.unsqueeze(1)
print('tgt_mask shape=',mask.shape)
n_batches = tgt.size(0) # 1

print('tgt.shape=', [mapa[e] for e in tgt.shape])
query = q_fc(attn_from)
key = k_fc(attn_to)
value = v_fc(value)

print('query.shape=', [mapa[e] for e in query.shape])
print('key.shape=', [mapa[e] for e in key.shape])
print('value.shape=', [mapa[e] for e in value.shape])

tgt_mask shape= torch.Size([1, 1, 1, 1])
tgt.shape= ['1', '1', 'emb']
query.shape= ['1', '1', 'emb']
key.shape= ['1', '1', 'emb']
value.shape= ['1', '1', 'emb']


### Step 14.2 Split to H heads

In [24]:
# split data into H heads.
n_tokens = query.size(1)
query = query.view(n_batches, n_tokens, h, d_head) .transpose(1, 2)
key = key.view(n_batches, n_tokens, h, d_head).transpose(1, 2)
value = value.view(n_batches, n_tokens, h, d_head).transpose(1, 2)
query.shape

torch.Size([1, 8, 1, 4])

### Step 14.3 Attention

In [25]:
# def attention
key_t = key.transpose(-2, -1)
print('query.shape=', [mapa[e] for e in query.shape])
print('key_t.shape=', [mapa[e] for e in key_t.shape])

scores = torch.matmul(query, key_t) / math.sqrt(d_head)
print('scores.shape=', [mapa[e] for e in scores.shape])

scores = scores.masked_fill(mask, -1e9)

p_attn = scores.softmax(dim=-1)
print('p_attn.shape=', [mapa[e] for e in p_attn.shape])

# if dropout is not None:
#     p_attn = dropout(p_attn)

print('value.shape=', [mapa[e] for e in value.shape])
headed_context = torch.matmul(p_attn, value)
print('headed_context.shape=', [mapa[e] for e in headed_context.shape])

context = headed_context.transpose(1,2).contiguous().view(n_batches, n_tokens, h * d_head)
print('context.shape=', [mapa[e] for e in context.shape])

tgt = final_fc(context) # output_14
print('output_14.shape=', [mapa[e] for e in tgt.shape])

query.shape= ['1', 'h', '1', 'd_head']
key_t.shape= ['1', 'h', 'd_head', '1']
scores.shape= ['1', 'h', '1', '1']
p_attn.shape= ['1', 'h', '1', '1']
value.shape= ['1', 'h', '1', 'd_head']
headed_context.shape= ['1', 'h', '1', 'd_head']
context.shape= ['1', '1', 'emb']
output_14.shape= ['1', '1', 'emb']


## Step 15 LayerNorm

In [26]:
norm = LayerNorm(D_MODEL)
tgt = norm(tgt) #output_15
print('output_15.shape=', [mapa[e] for e in tgt.shape])

output_15.shape= ['1', '1', 'emb']


## Step 13 ResidualConnection (output 13)

In [27]:
# end ResidualConnection
tgt = residual_tgt + tgt
print('output_13.shape=', [mapa[e] for e in tgt.shape])

output_13.shape= ['1', '1', 'emb']


### start Step 16 : ResidualConnection (input 16)

In [28]:
residual_tgt= tgt

## Step 17 : Multi-headed attention (2x)
### Step 17.1 Query, Key, Value
 First appearance of 'memory' variable

In [29]:
# MultiHeadedAttention.__init__
d_head = D_MODEL // H
h = H
q_fc = nn.Linear(D_MODEL, D_MODEL)
k_fc = nn.Linear(D_MODEL, D_MODEL)
v_fc = nn.Linear(D_MODEL, D_MODEL)
final_fc = nn.Linear(D_MODEL, D_MODEL)
dropout = nn.Dropout(p=_dropout)

attn_from = tgt    #(b, dyn, emb)
attn_to = memory   #(b, n, emb)
value = memory     #(b, n, emb)
mask = src_mask

# MultiHeadedAttention.forward : compute Query, Key, Value
mask = mask.unsqueeze(1)
print('mask shape=',mask.shape)
n_batches = src.size(0) # 1

print('memory.shape=', [mapa[e] for e in memory.shape])
print('tgt.shape=', [mapa[e] for e in tgt.shape])

query = q_fc(attn_from) #(b, dyn, emb)
key = k_fc(attn_to)     #(b, n, emb)
value = v_fc(value)     #(b, n, emb)

print('query.shape=', [mapa[e] for e in query.shape])
print('key.shape=', [mapa[e] for e in key.shape])
print('value.shape=', [mapa[e] for e in value.shape])

mask shape= torch.Size([1, 1, 1, 10])
memory.shape= ['1', 'n', 'emb']
tgt.shape= ['1', '1', 'emb']
query.shape= ['1', '1', 'emb']
key.shape= ['1', 'n', 'emb']
value.shape= ['1', 'n', 'emb']


### Step 17.2 Split to H heads

In [30]:
# split data into H heads.
n_tokens_from = attn_from.size(1)
n_tokens_to = attn_to.size(1)
n_tokens_value = value.size(1)

query = query.view(n_batches, n_tokens_from, h, d_head) .transpose(1, 2)
key = key.view(n_batches, n_tokens_to, h, d_head).transpose(1, 2)
value = value.view(n_batches, n_tokens_value, h, d_head).transpose(1, 2)

print('query.shape=', [mapa[e] for e in query.shape])
print('key.shape=', [mapa[e] for e in key.shape])
print('value.shape=', [mapa[e] for e in value.shape])

query.shape= ['1', 'h', '1', 'd_head']
key.shape= ['1', 'h', 'n', 'd_head']
value.shape= ['1', 'h', 'n', 'd_head']


### Step 17.3 Attention

In [31]:
# def attention
key_t = key.transpose(-2, -1)
print('query.shape=', [mapa[e] for e in query.shape])
print('key_t.shape=', [mapa[e] for e in key_t.shape])

scores = torch.matmul(query, key_t) / math.sqrt(d_head)
print('scores.shape=', [mapa[e] for e in scores.shape])

scores = scores.masked_fill(mask, -1e9)

p_attn = scores.softmax(dim=-1)
print('p_attn.shape=', [mapa[e] for e in p_attn.shape])

# if dropout is not None:
#     p_attn = dropout(p_attn)

print('value.shape=', [mapa[e] for e in value.shape])
headed_context = torch.matmul(p_attn, value)
print('headed_context.shape=', [mapa[e] for e in headed_context.shape])

context = headed_context.transpose(1,2).contiguous().view(n_batches, n_tokens_from, h * d_head)
print('context.shape=', [mapa[e] for e in context.shape])

tgt = final_fc(context) # output_4
print('tgt.shape=', [mapa[e] for e in tgt.shape])

query.shape= ['1', 'h', '1', 'd_head']
key_t.shape= ['1', 'h', 'd_head', 'n']
scores.shape= ['1', 'h', '1', 'n']
p_attn.shape= ['1', 'h', '1', 'n']
value.shape= ['1', 'h', 'n', 'd_head']
headed_context.shape= ['1', 'h', '1', 'd_head']
context.shape= ['1', '1', 'emb']
tgt.shape= ['1', '1', 'emb']


## Step 18 LayerNorm

In [32]:
norm = LayerNorm(D_MODEL)
tgt = norm(tgt) #output_18
print('output_18.shape=', [mapa[e] for e in tgt.shape])

output_18.shape= ['1', '1', 'emb']


## Step 16 : ResidualConnection (output 16)

In [33]:
tgt = residual_tgt + tgt
print('output_16.shape=', [mapa[e] for e in tgt.shape])

output_16.shape= ['1', '1', 'emb']


### >> Start Step 19 : ResidualCOnnection

In [34]:
residual_tgt = tgt

## Step 20 PositionalWiseFeedForward

In [35]:
w_1 = nn.Linear(D_MODEL, D_FF)
w_2 = nn.Linear(D_FF, D_MODEL)

tgt_fc1 = w_1(tgt).relu() # + DropOut
print('tgt_fc1.shape=', [mapa[e] for e in tgt_fc1.shape])
tgt = w_2(tgt_fc1) # output_20
print('output_20.shape=', [mapa[e] for e in tgt.shape])

tgt_fc1.shape= ['1', '1', 'd_ff']
output_20.shape= ['1', '1', 'emb']


## Step 21 LayerNorm

In [36]:
norm = LayerNorm(D_MODEL)
tgt = norm(tgt) #output_8
print('output_21.shape=', [mapa[e] for e in tgt.shape])

output_21.shape= ['1', '1', 'emb']


## Step 19 ResidualConnection (output 19)

In [37]:
tgt = residual_tgt+ tgt
print('output_19.shape=', [mapa[e] for e in tgt.shape])

output_19.shape= ['1', '1', 'emb']


## Step 22 : Generator & Update tgt

In [38]:
prob = test_model.generator(tgt[:, -1])
next_word = torch.argmax(prob, dim=1).unsqueeze(0)
print('tgt.shape=', [mapa[e] for e in tgt.shape])
print('prev target.shape=', [mapa[e] for e in target.shape])
print('next_word.shape=', [mapa[e] for e in next_word.shape])

target=torch.cat([target, next_word],dim=1)
print('curent target.shape=', [mapa[e] for e in target.shape])

tgt.shape= ['1', '1', 'emb']
prev target.shape= ['1', '1']
next_word.shape= ['1', '1']
curent target.shape= ['1', '2']


## Step 23 : Repeat N times loop step 13 -> step 22

In [39]:
tgt_mask = following_mask(target.size(1)).type_as(src.data)

tgt = tgt_emb(target)
tgt = tgt_pos_enc(tgt)

residual_tgt = tgt

attn_from = tgt 
attn_to = tgt
value = tgt
mask = tgt_mask.unsqueeze(1)
n_batches = tgt.size(0) 

query = q_fc(attn_from)
key = k_fc(attn_to)
value = v_fc(value)
n_tokens = query.size(1)

query = query.view(n_batches, n_tokens, h, d_head) .transpose(1, 2)
key = key.view(n_batches, n_tokens, h, d_head).transpose(1, 2)
value = value.view(n_batches, n_tokens, h, d_head).transpose(1, 2)

key_t = key.transpose(-2, -1)
scores = torch.matmul(query, key_t) / math.sqrt(d_head)
scores = scores.masked_fill(mask, -1e9)
p_attn = scores.softmax(dim=-1)
headed_context = torch.matmul(p_attn, value)
context = headed_context.transpose(1,2).contiguous().view(n_batches, n_tokens, h * d_head)
tgt = final_fc(context) 

tgt = norm(tgt) 
tgt = residual_tgt + tgt

residual_tgt= tgt

attn_from = tgt
attn_to = memory
value = memory

mask = src_mask.unsqueeze(1)
n_batches = src.size(0) 

query = q_fc(attn_from)
key = k_fc(attn_to)
value = v_fc(value)

n_tokens_from = attn_from.size(1)
n_tokens_to = attn_to.size(1)
n_tokens_value = value.size(1)

query = query.view(n_batches, n_tokens_from, h, d_head) .transpose(1, 2)
key = key.view(n_batches, n_tokens_to, h, d_head).transpose(1, 2)
value = value.view(n_batches, n_tokens_value, h, d_head).transpose(1, 2)

key_t = key.transpose(-2, -1)
scores = torch.matmul(query, key_t) / math.sqrt(d_head)
scores = scores.masked_fill(mask, -1e9)
p_attn = scores.softmax(dim=-1)
headed_context = torch.matmul(p_attn, value)
context = headed_context.transpose(1,2).contiguous().view(n_batches, n_tokens_from, h * d_head)
tgt = final_fc(context) 

tgt = norm(tgt)
tgt = residual_tgt + tgt

residual_tgt = tgt

tgt_fc1 = w_1(tgt).relu() 
tgt = w_2(tgt_fc1) 

tgt = norm(tgt)
tgt = residual_tgt+ tgt

prob = test_model.generator(tgt[:, -1])
next_word = torch.argmax(prob, dim=1).unsqueeze(0)

target=torch.cat([target, next_word],dim=1)

print('current target.shape=', [mapa[e] for e in target.shape])


current target.shape= ['1', '3']


##