In [None]:
# Uncomment the line below when running this demo in an external Colab.
# It installs the `transformers` package from nikitakapitan's GitHub repo; it is
# imported as `transformers`, so it may clash with an installed Hugging Face
# `transformers` — TODO confirm in a fresh environment.
# %pip install -q git+https://github.com/nikitakapitan/transformers.git

In [1]:
import torch
import torch.nn as nn
import math
from transformers.main import make_model

from transformers.Embeddings import Embeddings
from transformers.PositionalEncoding import PositionalEncoding
from transformers.MultiHeadedAttention import MultiHeadedAttention
from transformers.helper import following_mask
from transformers.LayerNorm import LayerNorm

%load_ext autoreload
%autoreload 2

In [2]:
D_MODEL = 32
D_FF = 2048
N = 2  # layers
H = 8  # heads
DROPOUT = 0.1  # dropout probability for the step-by-step modules built below
SEED = 0  # fixes random weight init so Restart & Run All reproduces the outputs

n = 10  # tokens in input

# Map each dimension size to a readable label for the shape printouts below.
# Built from the constants so the labels stay right if a constant changes;
# the sizes must be distinct, otherwise a later entry overwrites an earlier one.
mapa = {1: 'b', N: 'layers', H: 'h', D_MODEL // H: 'd_head',
        n: 'n', D_MODEL: 'emb/d_model', D_FF: 'd_ff'}


FRENCH = 11  # SOURCE: number of all words in French vocab
ENGLISH = 11  # TARGET: number of all words in English vocab

torch.manual_seed(SEED)

# INITIALIZE
# NOTE(review): make_model is not given D_MODEL / D_FF / H / DROPOUT, so test_model
# presumably uses make_model's own defaults for them — TODO confirm. The
# step-by-step walkthrough below builds its own modules from the constants above.
test_model = make_model(src_vocab=FRENCH, tgt_vocab=ENGLISH, N=N)
test_model.eval()  # eval mode: disables dropout during inference

# INFERENCE
src = torch.arange(1, n + 1).unsqueeze(0)  # token ids 1..n, shape (1, n), int64
src_mask = torch.ones(1, 1, n)  # all ones: no source position is masked out
print(src.shape, src_mask.shape)

torch.Size([1, 10]) torch.Size([1, 1, 10])


  self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))


# Encode

In [3]:
# Run the whole encoder once; `memory` is the encoder output that decode() consumes below.
memory = test_model.encode(src, src_mask)

# Breakdown of **Model.encode(src, src_mask)**

In [4]:
# Step 0: the encoder input is the raw batch of token ids.
x = src
x.size()

torch.Size([1, 10])

## Step 1 Embeddings

In [None]:
# Token ids -> dense d_model vectors.
# Read from `src` (not `x`) so re-running this cell never tries to embed an
# already-embedded tensor; on the first run `x` is `src` anyway.
emb = Embeddings(vocab=FRENCH, d_model=D_MODEL)
x = emb(src)
x.shape

## Step 2 PositionalEncoding(d_model, dropout, max_len)

In [None]:
MAX_LEN = 5000  # max_len argument; presumably the longest sequence the encoding covers
pos_enc = PositionalEncoding(d_model=D_MODEL, dropout=DROPOUT, max_len=MAX_LEN)
# Start from the embeddings of `src` rather than the current `x`, so re-running
# this cell adds the positional encoding exactly once.
x = pos_enc(emb(src))
x.shape

+ Start Step 3 ResidualConnection (input 3)
## Step 4 MultiHeadedAttention

In [None]:
attn = MultiHeadedAttention(d_model=D_MODEL, h=H)  # library module; its internals are unpacked by hand below

### Step 4.1 Query, Key, Value

In [None]:
# --- MultiHeadedAttention.__init__: per-head size and the four projections ---
h = H
d_head = D_MODEL // h
# built in the order query, key, value, output (same RNG draws as four separate calls)
q_fc, k_fc, v_fc, final_fc = (nn.Linear(D_MODEL, D_MODEL) for _ in range(4))
dropout = nn.Dropout(p=DROPOUT)

# --- MultiHeadedAttention.forward: project the input into Query, Key, Value ---
input_3 = x
mask = src_mask.unsqueeze(1)  # extra axis so the mask broadcasts over the head axis
print('mask shape=', mask.shape)
n_batches = input_3.shape[0]  # 1

print('input.shape=', list(map(mapa.__getitem__, input_3.shape)))
query, key, value = (proj(input_3) for proj in (q_fc, k_fc, v_fc))

print('query.shape=', list(map(mapa.__getitem__, query.shape)))

### Step 4.2 Split to H heads

In [None]:
# Split d_model into h heads of size d_head: (b, n, d_model) -> (b, h, n, d_head).
# Project from `input_3` again instead of reshaping `query`/`key`/`value` in place,
# so re-running this cell does not try to re-split already-split tensors.
# `-1` lets view() infer the sequence length instead of relying on the global `n`.
query, key, value = (
    fc(input_3).view(n_batches, -1, h, d_head).transpose(1, 2)
    for fc in (q_fc, k_fc, v_fc)
)
query.shape

### Step 4.3 Attention

In [None]:
# def attention: scaled dot-product attention on every head at once.
key_t = key.transpose(-2, -1)
print('query.shape=', [mapa[e] for e in query.shape])
print('key_t.shape=', [mapa[e] for e in key_t.shape])

scores = torch.matmul(query, key_t) / math.sqrt(d_head)
print('scores.shape=', [mapa[e] for e in scores.shape])

# masked_fill needs a boolean mask that marks the positions to HIDE. `mask` comes
# from torch.ones(...) (float); positions equal to 0 are the ones to mask out.
# Passing `mask` directly fails on a float mask and would hide the kept positions.
scores = scores.masked_fill(mask == 0, -1e9)

p_attn = scores.softmax(dim=-1)
print('p_attn.shape=', [mapa[e] for e in p_attn.shape])

# The reference implementation applies `p_attn = dropout(p_attn)` here; skipped in this walkthrough.

print('value.shape=', [mapa[e] for e in value.shape])
headed_context = torch.matmul(p_attn, value)
print('headed_context.shape=', [mapa[e] for e in headed_context.shape])

# Merge the heads back: (b, h, n, d_head) -> (b, n, h * d_head).
context = headed_context.transpose(1, 2).contiguous().view(n_batches, -1, h * d_head)
print('context.shape=', [mapa[e] for e in context.shape])

output_4 = final_fc(context)
print('output_4.shape=', [mapa[e] for e in output_4.shape])

## Step 5 LayerNorm

In [None]:
# LayerNorm sized to D_MODEL, applied to the attention output.
# NOTE(review): the reference SublayerConnection normalises the sublayer *input*
# (x + sublayer(norm(x))); confirm which order this library's ResidualConnection uses.
norm = LayerNorm(D_MODEL)
output_5 = norm(output_4)
print('output_5.shape=', list(map(mapa.__getitem__, output_5.shape)))

## Step 3 ResidualConnection (output 3)

In [None]:
# end ResidualConnection: add the Step 3 input (skip path) to the sublayer output
output_3 = output_5 + input_3
print('output_3.shape=', list(map(mapa.__getitem__, output_3.shape)))

+ Start Step 6 ResidualConnection (input 6)

## Step 7 Position-wise FeedForward

In [None]:
# Position-wise feed-forward: expand D_MODEL -> D_FF, ReLU, project back to D_MODEL.
w_1 = nn.Linear(D_MODEL, D_FF)
w_2 = nn.Linear(D_FF, D_MODEL)

fc1 = torch.relu(w_1(output_3))  # (+ dropout in the full module)
print('fc1.shape=', list(map(mapa.__getitem__, fc1.shape)))
output_7 = w_2(fc1)
print('output_7.shape=', list(map(mapa.__getitem__, output_7.shape)))

## Step 8 LayerNorm

In [None]:
# LayerNorm sized to D_MODEL, applied to the feed-forward output.
norm = LayerNorm(D_MODEL)
output_8 = norm(output_7)
print('output_8.shape=', list(map(mapa.__getitem__, output_8.shape)))

## Step 6 ResidualConnection (output 6)

In [None]:
# end ResidualConnection: add the Step 6 input (skip path) to the sublayer output
output_6 = output_8 + output_3
print('output_6.shape=', [mapa[e] for e in output_6.shape])

## Step 9: Repeat Steps 3–8 N times (once per encoder layer)

# Decode

In [5]:
# Greedy decoding: start from token 0 and repeatedly append the highest-scoring next token.
N_NEW_TOKENS = n - 1  # new tokens to generate, so the result has n tokens in total

tgt = torch.zeros(1, 1).type_as(src)

with torch.no_grad():  # inference only: no autograd graph needed
    for _ in range(N_NEW_TOKENS):
        # target mask sized to the current target length
        tgt_mask = following_mask(tgt.size(1)).type_as(src)

        out = test_model.decode(memory, src_mask, tgt, tgt_mask)

        # score only the last position: the prediction for the next token
        prob = test_model.generator(out[:, -1])

        # prob is presumably (batch, vocab); argmax over dim 1 gives (batch,), and
        # unsqueeze(1) makes it (batch, 1) so it concatenates along the sequence
        # axis for any batch size (unsqueeze(0) only works when batch == 1).
        next_word = torch.argmax(prob, dim=1).unsqueeze(1)

        tgt = torch.cat([tgt, next_word], dim=1)

print("Example Untrained Model Prediction:", tgt)

Example Untrained Model Prediction: tensor([[0, 2, 7, 2, 2, 2, 2, 2, 2, 2]])


# Breakdown of **Model.decode(memory, src_mask, tgt, tgt_mask)**

# Breakdown: **Predict the next word**

In [6]:
# One greedy step on its own: predict the first token after the start token 0.
tgt = torch.zeros(1, 1).type_as(src)
tgt_mask = following_mask(tgt.size(1)).type_as(src)

with torch.no_grad():
    # Decode this fresh one-token target instead of reusing the `out` left over
    # from the decoding loop above (which belongs to a 9-token target).
    out = test_model.decode(memory, src_mask, tgt, tgt_mask)
    prob = test_model.generator(out[:, -1])

# argmax over dim 1 -> (batch,); unsqueeze(1) -> (batch, 1) for concatenation
next_word = torch.argmax(prob, dim=1).unsqueeze(1)

tgt = torch.cat([tgt, next_word], dim=1)
tgt


tensor([[0, 2]])