In [1]:
from torchtext.vocab import GloVe
# Load GloVe word vectors (the '6B' set, 50 dimensions per word).
# Indexing the result with a token string, e.g. embedding_glove['cat'], yields that token's vector.
embedding_glove = GloVe(name='6B', dim=50)

# Single Head Self Attention

In [2]:
import torch
import torch.nn.functional as F
from torch import nn
# Look up the GloVe vector of every token and stack them into a (tokens, dim) matrix.
X = torch.stack([
    embedding_glove[token]
    for token in ('the', 'cat', 'walks', 'on', 'the', 'street')
])
print(X.shape)
# Prepend a batch axis of size 1: (tokens, dim) -> (1, tokens, dim).
X = X.unsqueeze(0)
print(X.shape)

torch.Size([6, 50])
torch.Size([1, 6, 50])


Every input vector 𝐱i is used in three different ways in the self attention operation:

1. It is compared to every other vector to establish the weights for its own output $𝐲_{i}$. (Query)
2. It is compared to every other vector to establish the weights for the output of the j-th vector $𝐲_{j}$. (Key)
3. It is used as part of the weighted sum to compute each output vector once the weights have been established. (Value)

In other words, we add three k×k weight matrices 𝐖q, 𝐖k, 𝐖v and compute three linear transformations of each 𝐱i, one for each of the three roles it plays in the self attention:
$$𝐪_{i}=𝐖_{q}𝐱_{i}\;\;\;𝐤_{i}=𝐖_{k}𝐱_{i}\;\;\;𝐯_{i}=𝐖_{v}𝐱_{i}$$


In [3]:

# Fix the RNG seed so the randomly initialised projection weights, and every
# result derived from them, are identical on each Restart & Run All.
torch.manual_seed(0)

# Bias-free k x k linear maps W_q, W_k and W_v, where k = X.shape[2] is the
# embedding size of the input vectors.
queries = nn.Linear(X.shape[2], X.shape[2], bias=False)
keys = nn.Linear(X.shape[2], X.shape[2], bias=False)
values = nn.Linear(X.shape[2], X.shape[2], bias=False)

In [4]:
# Apply the three projections to the input. The layer names are rebound to the
# resulting tensors, so this cell can only run once after the layers are defined.
queries, keys, values = queries(X), keys(X), values(X)

$$w′_{ij}=𝐪_{i}^T𝐤_{j}$$

The softmax function can be sensitive to very large input values. These kill the gradient, and slow down learning, or cause it to stop altogether. Since the average value of the dot product grows with the embedding dimension k, it helps to scale the dot product back a little to stop the inputs to the softmax function from growing too large. In the code below, the queries and the keys are each divided by $k^{1/4}$, which is equivalent to dividing their dot product by $\sqrt{k}$:

$$w′_{ij}=\frac{𝐪_{i}^T𝐤_{j}}{\sqrt{k}}$$

$$w_{ij}=\mathrm{softmax}(w′_{ij})$$

In [5]:
# Divide queries and keys by k**(1/4) each, so that their dot product ends up
# divided by sqrt(k). The scaled tensors are kept under their own names so that
# re-running this cell does not scale `queries`/`keys` a second time.
scale = queries.shape[2] ** (1/4)
queries_scaled = queries / scale
keys_scaled = keys / scale

# (1, t, k) x (1, k, t) -> (1, t, t): one raw score per (query, key) pair.
raw_weights = torch.bmm(queries_scaled, keys_scaled.transpose(1, 2))
print(raw_weights.shape)

# Normalise over the key axis so that the weights of each query sum to 1.
weights = F.softmax(raw_weights, dim=2)

torch.Size([1, 6, 6])


We apply the self attention to the values:
$$𝐲_{i}=\sum_{j}w_{ij}𝐯_{j}$$

In [6]:
# Weighted sum of the value vectors: (1, t, t) @ (1, t, k) -> (1, t, k).
out = weights @ values

`out` now holds the output vectors of the self-attention operation, one per input token. This approach gives the self-attention layer some controllable parameters, and allows it to modify the incoming vectors accordingly.

## Understanding of .contiguous()

In [7]:
# Swapping the last two axes only swaps their strides; the result is a view
# that is no longer contiguous.
g = torch.transpose(weights, 1, 2)
print(g.stride(), g.is_contiguous(), sep='\n')

(36, 1, 6)
False


In [8]:
# .contiguous() returns a tensor with the same values in the default
# row-major layout, so the strides are back in descending order.
g = g.contiguous()
print(g.stride(), g.is_contiguous(), sep='\n')

(36, 6, 1)
True


# Multi-Headed Self Attention

In [9]:
# Rebuild the input from scratch: one GloVe vector per token, stacked into a
# (tokens, dim) matrix.
X = torch.stack([
    embedding_glove[token]
    for token in ('the', 'cat', 'walks', 'on', 'the', 'street')
])
print(X.shape)
# Prepend a batch axis of size 1: (tokens, dim) -> (1, tokens, dim).
X = X.unsqueeze(0)
print(X.shape)

torch.Size([6, 50])
torch.Size([1, 6, 50])


In [10]:
# Number of attention heads.
num_heads = 8
# Linear layer (with bias) that maps the concatenated head outputs,
# num_heads * dim features, back down to dim features.
unify_heads = nn.Linear(in_features=num_heads * X.size(2), out_features=X.size(2))

In [11]:
# Three bias-free projections from dim to num_heads * dim features, created in
# the order query, key, value: each one yields a dim-sized vector per head.
queries, keys, values = (
    nn.Linear(X.size(2), num_heads * X.size(2), bias=False)
    for _ in range(3)
)

In [12]:
# Apply the three projections to the input. The layer names are rebound to the
# resulting (1, tokens, num_heads * dim) tensors, so this cell can only run
# once after the layers are defined.
queries, keys, values = queries(X), keys(X), values(X)

In [13]:
# Split the last axis into (head, feature): (1, t, h*k) -> (1, t, h, k).
head_shape = (1, X.shape[1], num_heads, X.shape[2])

print("Before ",queries.shape)
queries = queries.view(head_shape)
print("After ", queries.shape)

keys = keys.view(head_shape)
values = values.view(head_shape)

Before  torch.Size([1, 6, 400])
After  torch.Size([1, 6, 8, 50])


In [14]:
# Move the head axis in front of the token axis and fold it into the batch
# axis, so that one bmm call can process every head: (1, t, h, k) -> (h, t, k).
print("Before ",queries.shape)
queries = queries.permute(0, 2, 1, 3).reshape(num_heads, X.shape[1], X.shape[2])
print("After ", queries.shape)

keys = keys.permute(0, 2, 1, 3).reshape(num_heads, X.shape[1], X.shape[2])
values = values.permute(0, 2, 1, 3).reshape(num_heads, X.shape[1], X.shape[2])

Before  torch.Size([1, 6, 8, 50])
After  torch.Size([8, 6, 50])


In [15]:
# Divide queries and keys by k**(1/4) each (k = X.shape[2], the per-head
# feature size), so that their dot product ends up divided by sqrt(k). The
# scaled tensors are kept under their own names so that re-running this cell
# does not scale `queries`/`keys` a second time.
scale = X.shape[2] ** (1/4)
queries_scaled = queries / scale
keys_scaled = keys / scale

# (h, t, k) x (h, k, t) -> (h, t, t): one score matrix per head.
raw_weights = torch.bmm(queries_scaled, keys_scaled.transpose(1, 2))
print(raw_weights.shape)

# Normalise over the key axis so that the weights of each query sum to 1.
weights = F.softmax(raw_weights, dim=2)

torch.Size([8, 6, 6])


In [16]:
# Per-head weighted sum of the value vectors: (h, t, t) @ (h, t, k) -> (h, t, k).
out = weights @ values

In [17]:
# `out` is (num_heads, t, k) with the heads folded into the batch axis.
# Restore the head axis first, then move it next to the feature axis so that
# each token's head outputs sit side by side before flattening:
# (h, t, k) -> (1, h, t, k) -> (1, t, h, k) -> (1, t, h*k).
# Transposing the 3-D tensor directly would swap the token and feature axes,
# and the final view would have the right shape but scrambled contents.
assert out.shape == (num_heads, X.shape[1], X.shape[2]), \
    "expected per-head outputs of shape (num_heads, t, k); re-run the previous cell first"
out = (
    out.view(1, num_heads, X.shape[1], X.shape[2])
    .transpose(1, 2)
    .contiguous()
    .view(1, X.shape[1], num_heads * X.shape[2])
)

In [18]:
# Project the concatenated head outputs (num_heads * dim features) back down to dim features.
y = unify_heads(out)

In [19]:
# The result has the same shape as the input X: one dim-sized vector per token.
y.size()

torch.Size([1, 6, 50])