In [1]:
from torchtext.vocab import GloVe

# Pre-trained GloVe word vectors: the '6B' release at 50 dimensions per word.
# They are used below to embed the example sentence.
embedding_glove = GloVe(dim=50, name='6B')

# Basic Transformer Block
### Contents:
A self-attention layer, a layer normalization, a feed-forward layer (a single MLP applied independently to each vector), and another layer normalization. Residual connections are added around both the self-attention and the feed-forward layer, before each normalization.

![image](http://peterbloem.nl/files/transformers/transformer-block.svg)


## Input

In [2]:
import torch
import torch.nn.functional as F

# Embed the toy sentence word by word with the 50-d GloVe vectors.
tokens = ['the', 'cat', 'walks', 'on', 'the', 'street']
X = torch.stack([embedding_glove[token] for token in tokens])
print(X.shape)   # (num_words, embedding_dim)
# Prepend a batch axis of size 1 so X becomes (batch, num_words, embedding_dim).
X = X.unsqueeze(0)
print(X.shape)

torch.Size([6, 50])
torch.Size([1, 6, 50])


## Self-Attention

In [3]:
from torch import nn
num_heads = 8
num_words = X.shape[1]
num_dim = X.shape[2]
queries = nn.Linear(num_dim, num_heads*num_dim, bias=False)
keys = nn.Linear(num_dim, num_heads*num_dim, bias=False)
values = nn.Linear(num_dim, num_heads*num_dim, bias=False)
unify_heads = nn.Linear(num_heads*num_dim, num_dim, bias=False)

In [4]:
queries = queries(X).view(1, num_words, num_heads, num_dim)
queries = queries.transpose(1, 2).contiguous().view(1*num_heads, num_words, num_dim)
keys = keys(X).view(1, num_words, num_heads, num_dim)
keys = keys.transpose(1, 2).contiguous().view(1*num_heads, num_words, num_dim)
values = values(X).view(1, num_words, num_heads, num_dim)
values = values.transpose(1, 2).contiguous().view(1*num_heads, num_words, num_dim)

In [5]:
# Scale queries and keys by num_dim**(1/4) each. Together this equals the
# 1/sqrt(num_dim) factor of scaled dot-product attention.
# The scaled tensors get new names so the cell is idempotent: re-running it no
# longer divides the projected queries/keys a second time.
scale = num_dim ** (1/4)
scaled_queries = queries / scale
scaled_keys = keys / scale

# raw_weights[b, i, j]: dot product of word i's query with word j's key in
# (batch*head) slice b -> shape (batch*heads, num_words, num_words).
raw_weights = torch.bmm(scaled_queries, scaled_keys.transpose(1, 2))

# Normalise over the key axis so each row of attention weights sums to 1.
weights = F.softmax(raw_weights, dim=2)

In [6]:
# Derive the batch size from the input instead of hard-coding 1.
batch_size = X.shape[0]
# Weighted sum of the values for every head: (batch, heads, num_words, num_dim).
per_head = torch.bmm(weights, values).view(batch_size, num_heads, num_words, num_dim)
print(per_head.shape)
# Put the heads next to the feature axis and concatenate them into
# (batch, num_words, heads*num_dim), then project back to num_dim features.
out = per_head.transpose(1, 2).reshape(batch_size, num_words, num_heads*num_dim)
Y = unify_heads(out)
print(Y.shape)

torch.Size([1, 8, 6, 50])
torch.Size([1, 6, 50])


In [7]:
# Layers for the rest of the block:
# - a LayerNorm applied after the attention residual,
# - a feed-forward MLP applied to each word vector (expand, ReLU, project back),
# - a LayerNorm applied after the feed-forward residual.
# layer1 is still created before layer2, so the random initialisation draws from
# the RNG in the same order.
ff_hidden = 4 * num_dim  # width of the feed-forward hidden layer
norm1 = nn.LayerNorm(normalized_shape=num_dim)
layer1 = nn.Linear(in_features=num_dim, out_features=ff_hidden)
relu = nn.ReLU()
layer2 = nn.Linear(in_features=ff_hidden, out_features=num_dim)
norm2 = nn.LayerNorm(normalized_shape=num_dim)

In [8]:
# Sub-layer 1: residual around self-attention, then LayerNorm.
out1 = norm1(X + Y)
# Sub-layer 2: feed-forward MLP on the normalised vectors.
out2 = layer1(out1)   # expand to 4 * num_dim
out3 = relu(out2)
out4 = layer2(out3)   # project back to num_dim
# Residual around the feed-forward MLP, then LayerNorm.
final = norm2(out4 + out1)

In [9]:
print(final.size())  # the block keeps the input shape (batch, num_words, num_dim)

torch.Size([1, 6, 50])
