In [27]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

# Compute self attention

In transformer models, attention gives each token context from the rest of the sequence, which helps the model understand how different words relate to each other and form meaningful sentences. According to Wikipedia’s description, “the attention layer can access all previous states and weigh them according to a learned measure of relevance, providing relevant information about far-away tokens.”

According to “Attention Is All You Need”:

> An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
>
> We call our particular attention “Scaled Dot-Product Attention”. The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.
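In matrix form, the paper writes this as $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$. Below is a minimal NumPy sketch of that formula; the helper names `softmax` and `sdp_attention` and the random shapes are illustrative choices, not taken from the paper.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtracting the max keeps exp() from overflowing; the result is unchanged.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def sdp_attention(Q, K, V):
    """Q: (t_q, d_k), K: (t_k, d_k), V: (t_k, d_v) -> output (t_q, d_v), weights (t_q, t_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # compatibility of every query with every key
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ V, weights         # weighted sum of the values

rng = np.random.default_rng(0)
Q = rng.standard_normal((3, 4))
K = rng.standard_normal((5, 4))
V = rng.standard_normal((5, 2))
out, w = sdp_attention(Q, K, V)
print(out.shape, w.shape)  # (3, 2) (3, 5)
```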

Consider an input sequence of $t$ vectors, each of dimension $k$.

<img src="self_attention.png" alt="drawing" width="600"/>

To compute the attention:
1. Compute the dot product of every pair of input vectors, i.e. the matrix product $XX^\top$, to obtain the raw weights.
2. Apply the softmax function to each row of these weights to normalize them.
3. Multiply the normalized weights by the input vectors and sum them: each output is a weighted sum of all the inputs.

In [35]:
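t = 10  # sequence length
k = 5   # dimension of each input vector
X = torch.rand(t, k)

# 1. Raw weights: dot product of every pair of input vectors (t x t)
raw_weights = torch.mm(X, X.transpose(1, 0))
# 2. Row-wise softmax so that the weights for each output sum to 1
weights = F.softmax(raw_weights, dim=1)

# 3. Each output row is the weighted sum of the input vectors
attention = torch.mm(weights, X)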

In [40]:
attention

tensor([[0.5358, 0.4512, 0.4647, 0.6081, 0.5569],
        [0.5740, 0.4100, 0.4658, 0.5725, 0.6087],
        [0.5454, 0.3678, 0.4956, 0.5380, 0.6225],
        [0.6130, 0.3393, 0.4469, 0.4723, 0.6416],
        [0.6616, 0.3441, 0.4753, 0.4769, 0.6055],
        [0.6666, 0.3788, 0.4809, 0.4748, 0.5598],
        [0.6471, 0.3382, 0.4975, 0.4762, 0.5990],
        [0.5740, 0.3639, 0.5166, 0.5478, 0.5777],
        [0.6477, 0.4312, 0.4594, 0.5385, 0.6006],
        [0.6579, 0.3818, 0.5083, 0.4755, 0.5660]])
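As a sanity check on the three steps, the sketch below recomputes the same unscaled self-attention in NumPy (on its own random input, so the numbers differ from the output above) and compares the matrix form with an explicit loop over $\mathbf{y}_i = \sum_j w_{ij}\mathbf{x}_j$.

```python
import numpy as np

rng = np.random.default_rng(0)
t, k = 10, 5
X = rng.random((t, k))

# Vectorized: raw weights X X^T, row-wise softmax, weighted sum of the inputs.
raw = X @ X.T
w = np.exp(raw - raw.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)
Y = w @ X

# Explicit loop over the outputs: y_i = sum_j w_ij x_j
Y_loop = np.zeros_like(X)
for i in range(t):
    for j in range(t):
        Y_loop[i] += w[i, j] * X[j]

print(np.allclose(Y, Y_loop))  # True
```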

Every input vector $\mathbf{x}_i$ is used in three different ways in the self-attention operation:

- It is compared to every other vector to establish the weights for its own output $\mathbf{y}_i$.
- It is compared to every other vector to establish the weights for the output of the $j$-th vector $\mathbf{y}_j$.
- It is used as part of the weighted sum to compute each output vector once the weights have been established.

These roles are called the query, the key and the value, respectively.

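To make each role easier to learn, we derive a new vector for each one by applying a linear transformation to the original input vector. In other words, we add three $k \times k$ weight matrices $\mathbf{W}_q$, $\mathbf{W}_k$, $\mathbf{W}_v$ and compute three linear transformations of each $\mathbf{x}_i$, one for each role in the self-attention:
- $\mathbf{q}_i = \mathbf{W}_q \mathbf{x}_i$
- $\mathbf{k}_i = \mathbf{W}_k \mathbf{x}_i$
- $\mathbf{v}_i = \mathbf{W}_v \mathbf{x}_i$

The raw weights, their softmax over $j$, and the outputs are then:

$$w'_{ij} = \mathbf{q}_i^\top \mathbf{k}_j$$

$$w_{ij} = \mathrm{softmax}_j(w'_{ij}) = \frac{\exp(w'_{ij})}{\sum_{j'} \exp(w'_{ij'})}$$

$$\mathbf{y}_i = \sum_j w_{ij} \mathbf{v}_j$$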

<img src="key-query-value.png" alt="drawing" width="600"/>
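The sketch below applies these formulas in NumPy, assuming the inputs are stored as the rows of a $t \times k$ matrix `X` (so computing $\mathbf{q}_i = \mathbf{W}_q\mathbf{x}_i$ for all $i$ at once becomes `X @ W_q.T`). Like the equations above, it leaves out the $\sqrt{d_k}$ scaling; the random matrices only stand in for learned weights.

```python
import numpy as np

rng = np.random.default_rng(1)
t, k = 6, 4
X = rng.standard_normal((t, k))  # rows are the input vectors x_i
W_q, W_k, W_v = (rng.standard_normal((k, k)) for _ in range(3))

# q_i = W_q x_i for every row at once: Q = X W_q^T (same for K and V)
Q, K, V = X @ W_q.T, X @ W_k.T, X @ W_v.T

raw = Q @ K.T                                # w'_ij = q_i^T k_j
w = np.exp(raw - raw.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)            # softmax over j
Y = w @ V                                    # y_i = sum_j w_ij v_j

# Check one output against the per-vector formulas.
i = 2
q_i = W_q @ X[i]
raw_i = np.array([q_i @ (W_k @ X[j]) for j in range(t)])
w_i = np.exp(raw_i - raw_i.max())
w_i /= w_i.sum()
y_i = sum(w_i[j] * (W_v @ X[j]) for j in range(t))
print(np.allclose(Y[i], y_i))  # True
```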

# Multi-head attention

The simplest way to understand multi-head self-attention is to see it as a small number of copies of the self-attention mechanism applied in parallel, each with its own query, key and value transformations.  
Each head receives low-dimensional queries, keys and values. If the input vectors have $k = 256$ dimensions and we have $h = 4$ attention heads, we multiply the input vectors by a $256 \times 64$ matrix to project them down to a sequence of 64-dimensional vectors. For every head, we do this three times: once each for the queries, the keys and the values.  
Projecting the input into several lower-dimensional spaces gives multiple representations of the same vector; at the end, the outputs of the heads are concatenated so that all these representations are gathered in one place.

<img src="multihead_attention.png" alt="drawing" width="800"/>

This requires $3h$ matrices of size $k \times \frac{k}{h}$. In total, this gives us $3h \cdot k \cdot \frac{k}{h} = 3k^2$ parameters to compute the inputs to the multi-head self-attention: the same as for single-head self-attention.
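The count can be checked with the numbers used above ($k = 256$, $h = 4$):

```python
k, h = 256, 4
s = k // h  # 64 dimensions per head

# One k x s matrix per head for each of queries, keys and values
multi_head_params = 3 * h * (k * s)
# Three k x k matrices for single-head self-attention
single_head_params = 3 * k * k

print(multi_head_params, single_head_params)  # 196608 196608
```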

We can even implement this with just three k×k matrix multiplications as in the single-head self-attention. The only extra operation we need is to slice the resulting sequence of vectors into chunks.

<img src="compute_query_sequentially.png" alt="drawing" width="800"/>

<img src="compute_query_in_once.png" alt="drawing" width="800"/>
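The equivalence between the two figures can be checked numerically. In this NumPy sketch (random data, with `W_q` written as a right-multiplied $k \times k$ matrix), projecting with one $k \times k$ matrix and then splitting the result into $h$ chunks gives the same queries as giving each head its own $k \times (k/h)$ block of columns.

```python
import numpy as np

rng = np.random.default_rng(2)
t, k, h = 7, 8, 4
s = k // h
X = rng.standard_normal((t, k))
W_q = rng.standard_normal((k, k))  # one k x k query matrix shared by all heads

# All heads at once: one matrix product, then slice the columns into h chunks.
Q_all = X @ W_q                    # (t, k)
Q_heads = Q_all.reshape(t, h, s)   # Q_heads[:, i] holds head i's queries

# Head by head: head i uses its own k x s block of W_q.
for i in range(h):
    Q_i = X @ W_q[:, i * s:(i + 1) * s]  # (t, s)
    assert np.allclose(Q_heads[:, i], Q_i)
print("sliced projection == per-head projections")
```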


How to build a multi-head self-attention module:
1. Choose a number of heads $h$ that divides the input dimension $k$; each head then works on $s = k/h$ dimensions.
2. Define the linear transformations for the queries, keys and values (one $k \times k$ matrix each, covering all the heads).
3. Apply these transformations to the input to obtain the queries, keys and values.
4. Reshape the queries, keys and values from $b \times t \times k$ to $b \times t \times h \times s$, adding a dimension that indexes the heads so that each head's input can be accessed separately.
5. Fold the heads into the batch dimension ($bh \times t \times s$), since every head performs the same operation.
6. Compute the raw weights $w'_{ij} = \mathbf{q}_i^\top \mathbf{k}_j$ and scale them by $1/\sqrt{s}$, the dimension of each head's queries and keys (the paper's $\sqrt{d_k}$). Without this, large dot products push the softmax into regions where its gradient is tiny, which slows down learning or stops it altogether.
7. Apply the softmax function over $j$.
8. Multiply the attention weights by the values.
9. Reshape to concatenate the heads and get back to $b \times t \times k$.
10. Apply the `unifyheads` linear layer and return the result.
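Steps 4, 5 and 9 are pure reshaping, so they can be checked in isolation. This NumPy sketch uses small illustrative sizes and confirms that folding the heads into the batch and unfolding them again returns the original tensor. NumPy's `reshape` copies when it has to; in PyTorch, `.view` after `.transpose` needs the `.contiguous()` call used in the class below.

```python
import numpy as np

b, t, k, h = 2, 3, 8, 4
s = k // h
x = np.arange(b * t * k, dtype=float).reshape(b, t, k)

# Step 4: split the last dimension into h heads of size s.
x4 = x.reshape(b, t, h, s)
# Step 5: move the heads next to the batch, then fold the two together.
x5 = x4.transpose(0, 2, 1, 3).reshape(b * h, t, s)
# x5[bi * h + hi] holds head hi of batch element bi.
assert np.array_equal(x5[1 * h + 2], x[1, :, 2 * s:3 * s])

# Step 9: undo the folding and concatenate the heads back to b x t x k.
x9 = x5.reshape(b, h, t, s).transpose(0, 2, 1, 3).reshape(b, t, h * s)
print(np.array_equal(x9, x))  # True
```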

In [124]:
class MultHeadsSelfAttentionOpt(nn.Module):
    # 1. Choose a number of heads that divides the input dimension k
    def __init__(self, k, heads=4, mask=False):  # mask is accepted but not implemented
        super().__init__()
        assert k % heads == 0, "k must be divisible by the number of heads"

        self.k = k
        self.heads = heads

        # 2. Linear transformations for queries, keys and values (all heads at once);
        #    bias=False because we only want the weight matrices
        self.to_queries = nn.Linear(k, k, bias=False)
        self.to_keys    = nn.Linear(k, k, bias=False)
        self.to_values  = nn.Linear(k, k, bias=False)

        # Applied after the multi-head self-attention operation
        self.unifyheads = nn.Linear(k, k)

    def forward(self, x):
        b, t, k = x.size()  # training is done in batches: batch x sequence x features
        h = self.heads
        s = self.k // h     # dimension per head

        # 3. Apply the linear transformations to obtain the queries, keys and values
        query = self.to_queries(x)
        key = self.to_keys(x)
        value = self.to_values(x)

        # 4. Split the last dimension into h heads of size s
        queries = query.view(b, t, h, s)
        keys = key.view(b, t, h, s)
        values = value.view(b, t, h, s)

        # 5. Fold the heads into the batch dimension: the same operation is applied to each head
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t, s)

        # 6. Raw weights w'_ij = q_i^T k_j, scaled by sqrt(s) (the per-head d_k)
        weights_raw = torch.bmm(queries, keys.transpose(1, 2)) / s ** 0.5

        # 7. Softmax over the key dimension (shape: b*h x t x t)
        weights = F.softmax(weights_raw, dim=2)

        # 8. Weighted sum of the values
        self_attentions = torch.bmm(weights, values).view(b, h, t, s)

        # 9. Move the heads back next to the features and concatenate: b x t x k
        self_attention_formatted = self_attentions.transpose(1, 2).contiguous().view(b, t, s * h)

        # 10. Apply unifyheads and return the result
        return self.unifyheads(self_attention_formatted)
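PyTorch 2.0 and later also provide `torch.nn.functional.scaled_dot_product_attention`, which accepts tensors with extra leading dimensions, so a `(b, h, t, s)` layout works without folding the heads into the batch. Its default scale is $1/\sqrt{E}$ with $E$ the last dimension, i.e. $1/\sqrt{s}$ per head. This sketch assumes PyTorch ≥ 2.0 and random, already-projected inputs, and compares the built-in with the manual `bmm` route for the attention step only (no projections and no `unifyheads`).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, t, s = 2, 4, 5, 16

q = torch.randn(b, h, t, s)
k = torch.randn(b, h, t, s)
v = torch.randn(b, h, t, s)

# Manual version: fold heads into the batch, scale by sqrt(s), softmax, weighted sum.
qf, kf, vf = (x.reshape(b * h, t, s) for x in (q, k, v))
w = F.softmax(torch.bmm(qf, kf.transpose(1, 2)) / math.sqrt(s), dim=2)
manual = torch.bmm(w, vf).view(b, h, t, s)

# Built-in (PyTorch >= 2.0): works on (..., t, s) tensors, so no folding is needed.
builtin = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, builtin, atol=1e-5))
```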


In [146]:
class OneHeadSelfAttention(nn.Module):
    def __init__(self, k, low_dim):
        super().__init__()
        self.k = k
        self.low_dim = low_dim
        # 1. Linear transformation that reduces the input dimensionality (k -> low_dim);
        #    bias=False because we only want the weight matrix
        self.to_reduce_dim = nn.Linear(k, low_dim, bias=False)
        # 2. Linear transformations for queries, keys and values (bias-free)
        self.to_queries = nn.Linear(low_dim, low_dim, bias=False)
        self.to_keys    = nn.Linear(low_dim, low_dim, bias=False)
        self.to_values  = nn.Linear(low_dim, low_dim, bias=False)

    def forward(self, x):
        # x: batch x sequence x k
        # 3. Reduce the dimensionality of the input
        low_dim_x = self.to_reduce_dim(x)

        # 4. Obtain the queries, keys and values (each batch x sequence x low_dim)
        query = self.to_queries(low_dim_x)
        key = self.to_keys(low_dim_x)
        value = self.to_values(low_dim_x)

        # 5. Raw weights w'_ij = q_i^T k_j, scaled by sqrt(low_dim)
        weights_raw = torch.bmm(query, key.transpose(1, 2)) / self.low_dim ** 0.5

        # 6. Softmax over the key dimension (batch x sequence x sequence)
        weights = F.softmax(weights_raw, dim=2)

        # 7. Weighted sum of the values
        return torch.bmm(weights, value)
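Because `to_reduce_dim` and `to_queries` are both bias-free linear layers, applying one after the other is the same as applying a single $k \times$ `low_dim` matrix, their product. The NumPy sketch below checks this, with random matrices standing in for the layer weights (written as right-multiplied). The difference from three independent $k \times$ `low_dim` projections is that here the reduction is shared by queries, keys and values, which also changes the parameter count.

```python
import numpy as np

rng = np.random.default_rng(3)
t, k, low_dim = 5, 8, 2
X = rng.standard_normal((t, k))
W_reduce = rng.standard_normal((k, low_dim))   # stands in for to_reduce_dim
W_q = rng.standard_normal((low_dim, low_dim))  # stands in for to_queries

two_step = (X @ W_reduce) @ W_q
one_step = X @ (W_reduce @ W_q)                # a single k x low_dim matrix
print(np.allclose(two_step, one_step))  # True

# Shared reduction + three low_dim x low_dim matrices vs three independent k x low_dim ones
print("parameters:", W_reduce.size + 3 * W_q.size, "vs", 3 * k * low_dim)  # parameters: 28 vs 48
```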
    


In [162]:
class MultiHeadSelfAttention(nn.Module):
    # 8. Choose a number of heads that divides the input dimension k
    def __init__(self, k, heads=4):
        super().__init__()
        assert k % heads == 0, "k must be divisible by the number of heads"

        self.k = k
        self.heads = heads

        # 9. Instantiate OneHeadSelfAttention once per head. nn.ModuleList (unlike a plain
        #    Python list) registers the heads, so their parameters are trained and moved by .to()
        self.list_heads = nn.ModuleList(
            OneHeadSelfAttention(k, k // heads) for _ in range(heads)
        )

        # Applied after the multi-head self-attention operation
        self.unifyheads = nn.Linear(k, k)

    def forward(self, x):
        # 10.-11. Run every head and concatenate the outputs along the feature dimension
        concatenated = torch.cat([one_head(x) for one_head in self.list_heads], dim=2)

        # 12. Final linear transformation
        return self.unifyheads(concatenated)
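Sub-modules kept in a plain Python list are not registered with the parent `nn.Module`, so `parameters()`, the optimizer and `.to(device)` do not see them; `nn.ModuleList` registers them. A minimal illustration with two hypothetical toy classes:

```python
import torch
from torch import nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = [nn.Linear(4, 2) for _ in range(3)]  # NOT registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList(nn.Linear(4, 2) for _ in range(3))  # registered

# Each Linear(4, 2) has 4*2 weights + 2 biases = 10 parameters
n_plain = sum(p.numel() for p in PlainList().parameters())
n_list = sum(p.numel() for p in WithModuleList().parameters())
print(n_plain, n_list)  # 0 30
```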
          

In [103]:
class OneHeadSelfAttentionQKV(nn.Module):
    def __init__(self, k, low_dim):
        super().__init__()
        self.k = k
        self.low_dim = low_dim
        # 1. Linear transformation that reduces the input dimensionality (k -> low_dim);
        #    bias=False because we only want the weight matrix
        self.to_reduce_dim = nn.Linear(k, low_dim, bias=False)
        # 2. Linear transformations for queries, keys and values (bias-free)
        self.to_queries = nn.Linear(low_dim, low_dim, bias=False)
        self.to_keys    = nn.Linear(low_dim, low_dim, bias=False)
        self.to_values  = nn.Linear(low_dim, low_dim, bias=False)

    def forward(self, Q, K, V):
        # 3. Reduce the dimensionality of each input (the same reduction is shared)
        low_dim_Q = self.to_reduce_dim(Q)
        low_dim_K = self.to_reduce_dim(K)
        low_dim_V = self.to_reduce_dim(V)

        # 4. Obtain the queries, keys and values
        query = self.to_queries(low_dim_Q)
        key = self.to_keys(low_dim_K)
        value = self.to_values(low_dim_V)

        # 5. Raw weights w'_ij = q_i^T k_j, scaled by sqrt(low_dim)
        weights_raw = torch.bmm(query, key.transpose(1, 2)) / self.low_dim ** 0.5

        # 6. Softmax over the key dimension (batch x query length x key length)
        weights = F.softmax(weights_raw, dim=2)

        # 7. Weighted sum of the values
        return torch.bmm(weights, value)
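With separate `Q`, `K` and `V` arguments, the queries may come from a different sequence than the keys and values (as in encoder-decoder attention); only `K` and `V` need the same length. A NumPy sketch of the resulting shapes, with illustrative sizes:

```python
import numpy as np

rng = np.random.default_rng(4)
t_q, t_k, d = 3, 7, 4
Q = rng.standard_normal((t_q, d))  # queries from one sequence
K = rng.standard_normal((t_k, d))  # keys and values from another sequence
V = rng.standard_normal((t_k, d))

scores = Q @ K.T / np.sqrt(d)      # (t_q, t_k): one row per query
w = np.exp(scores - scores.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)
out = w @ V                        # (t_q, d): one output per query
print(w.shape, out.shape)  # (3, 7) (3, 4)
```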
    


In [126]:
class MultiHeadSelfAttentionQKV(nn.Module):
    # 8. Choose a number of heads that divides the input dimension k
    def __init__(self, k, heads=4):
        super().__init__()
        assert k % heads == 0, "k must be divisible by the number of heads"

        self.k = k
        self.heads = heads

        # 9. Instantiate OneHeadSelfAttentionQKV once per head; nn.ModuleList registers
        #    the heads so that their parameters are trained and moved by .to()
        self.list_heads = nn.ModuleList(
            OneHeadSelfAttentionQKV(k, k // heads) for _ in range(heads)
        )

        # Applied after the multi-head attention operation
        self.unifyheads = nn.Linear(k, k)

    def forward(self, Q, K, V):
        # 10.-11. Run every head and concatenate the outputs along the feature dimension
        concatenated = torch.cat([one_head(Q, K, V) for one_head in self.list_heads], dim=2)

        # 12. Final linear transformation
        return self.unifyheads(concatenated)
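For comparison, PyTorch ships a built-in module, `torch.nn.MultiheadAttention`. It is not identical to the classes in this notebook (for example, its input projections have biases by default), so this is only a sketch of its usage with `batch_first=True`, which gives the same batch × sequence × feature layout used here:

```python
import torch
from torch import nn

torch.manual_seed(0)
b, t, k, heads = 2, 10, 256, 4
X = torch.rand(b, t, k)

# batch_first=True: inputs and outputs are (batch, sequence, feature)
mha = nn.MultiheadAttention(embed_dim=k, num_heads=heads, batch_first=True)
out, attn = mha(X, X, X)  # query, key, value
print(out.shape)   # torch.Size([2, 10, 256])
print(attn.shape)  # torch.Size([2, 10, 10]), averaged over the heads by default
```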
          

In [130]:
class MultHeadsSelfAttentionOptQKV(nn.Module):
    # 1. Choose a number of heads that divides the input dimension k
    def __init__(self, k, heads=4, mask=False):  # mask is accepted but not implemented
        super().__init__()
        assert k % heads == 0, "k must be divisible by the number of heads"

        self.k = k
        self.heads = heads

        # 2. Linear transformations for queries, keys and values (all heads at once);
        #    bias=False because we only want the weight matrices
        self.to_queries = nn.Linear(k, k, bias=False)
        self.to_keys    = nn.Linear(k, k, bias=False)
        self.to_values  = nn.Linear(k, k, bias=False)

        # Applied after the multi-head attention operation
        self.unifyheads = nn.Linear(k, k)

    def forward(self, Q, K, V):
        b, t_q, _ = Q.size()  # batch x query length x features
        t_k = K.size(1)       # keys/values may have a different length than the queries
        h = self.heads
        s = self.k // h       # dimension per head

        # 3. Apply the linear transformations to obtain the queries, keys and values
        query = self.to_queries(Q)
        key = self.to_keys(K)
        value = self.to_values(V)

        # 4. Split the last dimension into h heads of size s
        queries = query.view(b, t_q, h, s)
        keys = key.view(b, t_k, h, s)
        values = value.view(b, t_k, h, s)

        # 5. Fold the heads into the batch dimension: the same operation is applied to each head
        keys = keys.transpose(1, 2).contiguous().view(b * h, t_k, s)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t_q, s)
        values = values.transpose(1, 2).contiguous().view(b * h, t_k, s)

        # 6. Raw weights w'_ij = q_i^T k_j, scaled by sqrt(s) (the per-head d_k)
        weights_raw = torch.bmm(queries, keys.transpose(1, 2)) / s ** 0.5

        # 7. Softmax over the key dimension (shape: b*h x t_q x t_k)
        weights = F.softmax(weights_raw, dim=2)

        # 8. Weighted sum of the values
        self_attentions = torch.bmm(weights, values).view(b, h, t_q, s)

        # 9. Move the heads back next to the features and concatenate: b x t_q x k
        self_attention_formatted = self_attentions.transpose(1, 2).contiguous().view(b, t_q, s * h)

        # 10. Apply unifyheads and return the result
        return self.unifyheads(self_attention_formatted)
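Step 5 relies on every head doing the same computation on its own slice of the features. This NumPy sketch checks, for the attention step alone (random already-projected inputs; the sizes and the helpers `softmax_rows` and `fold` are illustrative), that the folded batched version gives the same result as looping over the heads and concatenating.

```python
import numpy as np

def softmax_rows(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(5)
b, t, k, h = 2, 5, 8, 4
s = k // h
q, kk, v = (rng.standard_normal((b, t, k)) for _ in range(3))  # already projected

def fold(x):
    # (b, t, k) -> (b*h, t, s): heads folded into the batch
    return x.reshape(b, t, h, s).transpose(0, 2, 1, 3).reshape(b * h, t, s)

qf, kf, vf = fold(q), fold(kk), fold(v)
w = softmax_rows(qf @ kf.transpose(0, 2, 1) / np.sqrt(s))
folded = (w @ vf).reshape(b, h, t, s).transpose(0, 2, 1, 3).reshape(b, t, k)

# Loop: run each head on its own slice of the features, then concatenate.
per_head = []
for i in range(h):
    sl = slice(i * s, (i + 1) * s)
    wi = softmax_rows(q[..., sl] @ kk[..., sl].transpose(0, 2, 1) / np.sqrt(s))
    per_head.append(wi @ v[..., sl])
looped = np.concatenate(per_head, axis=-1)

print(np.allclose(folded, looped))  # True
```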


In [137]:
X = torch.rand(32, 1000, 256)

In [138]:
A = OneHeadSelfAttentionQKV(256, 64)(X, X, X)
A.size()

torch.Size([32, 1000, 64])

In [139]:
%%time
B = MultiHeadSelfAttentionQKV(k=256, heads=4)(X, X, X)
B.size()

CPU times: user 1.95 s, sys: 285 ms, total: 2.23 s
Wall time: 1.58 s


torch.Size([32, 1000, 256])

In [140]:
%%time
MultHeadsSelfAttentionOptQKV(256, 4)(X, X, X).size()

CPU times: user 1.87 s, sys: 586 ms, total: 2.45 s
Wall time: 1.54 s


torch.Size([32, 1000, 256])