In [5]:
# Uncomment the line below to install this package when running in an external Colab notebook
#  %pip install -q git+https://github.com/nikitakapitan/transformers.git

In [6]:
import torch
import torch.nn as nn
import math
from transformers.main import make_model

from transformers.Embeddings import Embeddings
from transformers.PositionalEncoding import PositionalEncoding
from transformers.MultiHeadedAttention import MultiHeadedAttention
from transformers.helper import following_mask

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Hyper-parameters for the step-by-step breakdown below.
D_MODEL = 32
D_FF = 2048
DROPOUT = 0.1
N = 2  # layers
H = 8  # heads
SEED = 0  # fixes the random weight initialisation of every module built below

m = 10  # tokens in input

FRENCH = 11   # SOURCE: number of all words in the French vocab
ENGLISH = 11  # TARGET: number of all words in the English vocab

torch.manual_seed(SEED)

# INITIALIZE
# NOTE(review): D_MODEL, D_FF, H and DROPOUT are not passed to make_model, so the
# model uses make_model's own values for them — confirm that is intended.
test_model = make_model(src_vocab=FRENCH, tgt_vocab=ENGLISH, N=N)
test_model.eval()  # eval mode: disables dropout for inference

# INFERENCE
src = torch.arange(1, m + 1).unsqueeze(0)  # token ids 1..m as int64, shape (1, m)
src_mask = torch.ones(1, 1, m)  # all ones: presumably no source position is masked out
print(src.shape, src_mask.shape)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    memory = test_model.encode(src, src_mask)


torch.Size([1, 10]) torch.Size([1, 1, 10])


In [8]:
# Greedy decoding with the untrained model: start from token 0 and repeatedly
# append the highest-scoring next token until the target reaches max_tgt_len.
max_tgt_len = m  # total target length, including the start token
ys = torch.zeros(1, 1).type_as(src)

with torch.no_grad():
    for _ in range(max_tgt_len - 1):
        tgt_mask = following_mask(ys.size(1)).type_as(src)
        out = test_model.decode(memory, src_mask, ys, tgt_mask)
        prob = test_model.generator(out[:, -1])  # scores for the last position only
        next_word = prob.argmax(dim=1).item()
        # new_full keeps ys's dtype/device, so the concatenation stays integer-typed
        ys = torch.cat([ys, ys.new_full((1, 1), next_word)], dim=1)

print("Example Untrained Model Prediction:", ys)

Example Untrained Model Prediction: tensor([[0, 9, 0, 9, 0, 5, 9, 9, 9, 9]])


# Breaking down **Model.encode(src, src_mask)**

## Embeddings(vocab, d_model)

In [9]:
# Raw token ids that enter the encoder: one row per batch item, one column per token.
x = src
x.size()

torch.Size([1, 10])

In [10]:
emb = Embeddings(vocab=FRENCH, d_model=D_MODEL)
# Embed `src` (not `x`) so re-running this cell never feeds already-embedded
# float vectors back into the embedding lookup.
x = emb(src)
x.shape

torch.Size([1, 10, 32])

## PositionalEncoding(d_model, dropout, max_len)

In [11]:
MAX_LEN = 5000
pos_enc = PositionalEncoding(d_model=D_MODEL, dropout=DROPOUT, max_len=MAX_LEN)
# Start from the embeddings of `src` rather than the current `x`, so re-running
# this cell does not add the positional encoding a second time.
# NOTE(review): pos_enc is not switched to eval mode, so its dropout is
# presumably active here — harmless for a shape walkthrough.
x = pos_enc(emb(src))
x.shape

torch.Size([1, 10, 32])

## MultiHeadedAttention

In [16]:
# Library module, instantiated for reference; the cells below rebuild its
# layers by hand to expose every intermediate tensor.
attn = MultiHeadedAttention(d_model=D_MODEL, h=H)

In [17]:
# MultiHeadedAttention.__init__
# Floor division would silently drop dimensions and make the head split fail later.
assert D_MODEL % H == 0, "D_MODEL must be divisible by H"
d_k = D_MODEL // H  # width of each head
h = H
q_fc = nn.Linear(D_MODEL, D_MODEL)
k_fc = nn.Linear(D_MODEL, D_MODEL)
v_fc = nn.Linear(D_MODEL, D_MODEL)
final_fc = nn.Linear(D_MODEL, D_MODEL)
dropout = nn.Dropout(p=DROPOUT)

# NOTE(review): `input` shadows the Python builtin of the same name; kept as-is
# because later cells may refer to it.
input = x
mask = src_mask

# MultiHeadedAttention.forward : compute Query, Key, Value
mask = mask.unsqueeze(1)  # add a head axis so the mask broadcasts over all heads
print(mask.shape)
n_batches = input.size(0)  # 1

query = q_fc(input)
key = k_fc(input)
value = v_fc(input)
print('query shape=', query.shape)

torch.Size([1, 1, 1, 10])
query shape= torch.Size([1, 10, 32])


Split the projected query, key and value into `H` heads, each of width `d_k = D_MODEL // H`.

In [18]:
# split data into H heads: (batch, seq, D_MODEL) -> (batch, h, seq, d_k)
def split_heads(t: torch.Tensor) -> torch.Tensor:
    """Reshape a projected tensor into h heads of width d_k, with heads on axis 1.

    The sequence length is inferred (-1) rather than taken from the global `m`.
    """
    return t.view(t.size(0), -1, h, d_k).transpose(1, 2)


# Recompute from the projections so re-running this cell never calls .view on
# already-split (transposed, non-contiguous) tensors, which raises.
query = split_heads(q_fc(input))
key = split_heads(k_fc(input))
value = split_heads(v_fc(input))
query.shape

torch.Size([1, 8, 10, 4])

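### Scaled dot-product attention

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$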

In [19]:
# def attention
key_t = key.transpose(-2, -1)
print('query.shape=', query.shape)
print('key_t.shape=', key_t.shape)

scores = torch.matmul(query, key_t) / math.sqrt(d_k)
print('scores.shape=', scores.shape)

# Blank out the positions where the mask is 0, not where it is non-zero.
# src_mask is all ones, so filling on `mask` itself would overwrite every score
# with -1e9 and make softmax uniform. `mask == 0` is also the boolean mask that
# masked_fill requires.
scores = scores.masked_fill(mask == 0, -1e9)

p_attn = scores.softmax(dim=-1)
print('p_attn.shape=', p_attn.shape)
# each row of attention weights must be a probability distribution
assert torch.allclose(p_attn.sum(dim=-1), torch.ones_like(p_attn[..., 0]))

# The library's attention applies `dropout` to p_attn at this point when a
# dropout module is given; it is skipped here to keep the walkthrough deterministic.

print('value.shape=', value.shape)
context = torch.matmul(p_attn, value)
print('context.shape=', context.shape)
# In the library function, (context, p_attn) is the return value.

query.shape= torch.Size([1, 8, 10, 4])
key_t.shape= torch.Size([1, 8, 4, 10])
scores.shape= torch.Size([1, 8, 10, 10])
p_attn.shape= torch.Size([1, 8, 10, 10])
value.shape= torch.Size([1, 8, 10, 4])
context.shape= torch.Size([1, 8, 10, 4])
