# Part 1 - Understanding attention

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn

In [2]:
print(torch.__version__)

1.10.0


# Build our own attention function, single head

In [3]:
### write our own attention function; single head

def my_attention(X, W_Q, W_K, W_V):
    """
    Single-head scaled dot-product self-attention.

    input:
    - matrix X, size: num_words_in_input x embedding_dim
      (leading batch dimensions are also accepted:
       ... x num_words_in_input x embedding_dim)
    - matrix W_Q, size: embedding_dim x embedding_dim
    - matrix W_K, size: embedding_dim x embedding_dim
    - matrix W_V, size: embedding_dim x embedding_dim
    output: tuple (Z, Q, K, V)
    - matrix Z, size: num_words_in_input x embedding_dim (plus any batch dimensions)
    - matrices Q, K, V: the query, key and value projections of X
    """

    # get query, key, value
    Q = X @ W_Q
    K = X @ W_K
    V = X @ W_V

    # scale the scores by sqrt(d_k), where d_k is the size of one key vector
    d_k = K.size(-1)

    # transpose only the last two axes so that batched inputs work as well;
    # for a 2-D K this is the same as K.T
    scores = (Q @ K.transpose(-2, -1)) / d_k ** 0.5

    # get attention: softmax over the key axis (the last one), then weight the values
    weights = nn.functional.softmax(scores, dim=-1)
    Z = weights @ V

    return (Z, Q, K, V)

In [4]:
### sanity check: are the shapes consistent?
num_words_in_input_dim = 8
embedding_dim = 4

# random input and projection weights, drawn in the order X, W_Q, W_K, W_V
X_test = torch.rand(num_words_in_input_dim, embedding_dim)
W_Q_test, W_K_test, W_V_test = (
    torch.rand(embedding_dim, embedding_dim) for _ in range(3)
)

# run through; Z must have one row per input word and embedding_dim columns
Z_test, Q_test, K_test, V_test = my_attention(X_test, W_Q_test, W_K_test, W_V_test)
print(Z_test.shape)

torch.Size([8, 4])


In [5]:
### compare with the PyTorch function for single attention head

# the PyTorch function expects a leading batch dimension:
# (1, num_words_in_input, embedding_dim)
qkv_batched = tuple(t.view(1, num_words_in_input_dim, embedding_dim)
                    for t in (Q_test, K_test, V_test))

# `_scaled_dot_product_attention` is a private (underscore-prefixed) helper that
# returns (output, attention_weights). Newer PyTorch releases expose a public
# `scaled_dot_product_attention` that returns only the output, so use that one
# when it is available and fall back to the private helper otherwise.
if hasattr(nn.functional, "scaled_dot_product_attention"):
    Z_test_PT = nn.functional.scaled_dot_product_attention(
        *qkv_batched, attn_mask=None, dropout_p=0.)
else:
    Z_test_PT, _ = nn.functional._scaled_dot_product_attention(
        *qkv_batched, attn_mask=None, dropout_p=0.)

print(Z_test_PT.size())
print(Z_test_PT)
print(Z_test)
# allclose broadcasts the (1, n, d) PyTorch result against our (n, d) result
print(torch.allclose(Z_test_PT, Z_test))

### WIN -> allowed to replace

torch.Size([1, 8, 4])
tensor([[[1.2124, 1.1662, 1.0024, 1.0880],
         [1.0774, 1.0575, 0.8936, 0.9834],
         [1.0694, 1.0527, 0.8875, 0.9782],
         [1.0644, 1.0506, 0.8834, 0.9755],
         [1.0938, 1.0761, 0.9070, 1.0000],
         [1.1007, 1.0772, 0.9124, 1.0024],
         [1.0363, 1.0285, 0.8611, 0.9533],
         [1.0900, 1.0684, 0.9034, 0.9937]]])
tensor([[1.2124, 1.1662, 1.0024, 1.0880],
        [1.0774, 1.0575, 0.8936, 0.9834],
        [1.0694, 1.0527, 0.8875, 0.9782],
        [1.0644, 1.0506, 0.8834, 0.9755],
        [1.0938, 1.0761, 0.9070, 1.0000],
        [1.1007, 1.0772, 0.9124, 1.0024],
        [1.0363, 1.0285, 0.8611, 0.9533],
        [1.0900, 1.0684, 0.9034, 0.9937]])
True


# Build one attention block - single head

By *block* we mean the attention function together with its projection weights, followed by a linear output layer.

In [6]:
### build a NN module for our attention mechanism

class MySingleheadAttention(nn.Module):
    """
    Single-head attention block: scaled dot-product attention (computed by
    PyTorch) followed by a linear output layer.

    The weight and bias of the output layer are passed in as tensors instead of
    being created here, so they can be taken from an existing module.
    """

    def __init__(self, W_linear, b_linear):
        """
        input:
        - W_linear: weight of the output layer, applied as Z @ W_linear.T
        - b_linear: bias added after the output layer
        """
        super().__init__()

        self.W_linear = W_linear
        self.b_linear = b_linear

    @staticmethod
    def _with_batch_dim(t):
        """Add a leading batch dimension of size 1 to a 2-D tensor; leave others as is."""
        return t.unsqueeze(0) if t.dim() == 2 else t

    @staticmethod
    def _attention(Q, K, V):
        """Scaled dot-product attention without mask or dropout; returns only the output."""
        # Prefer the public function of newer PyTorch releases; older releases
        # only have the private helper, which returns (output, attention_weights).
        if hasattr(nn.functional, "scaled_dot_product_attention"):
            return nn.functional.scaled_dot_product_attention(
                Q, K, V, attn_mask=None, dropout_p=0.)
        Z, _ = nn.functional._scaled_dot_product_attention(
            Q, K, V, attn_mask=None, dropout_p=0.)
        return Z

    def forward(self, Q, K, V):
        """
        input:
        - matrices Q, K, V (already projected), size: num_words_in_input x embedding_dim
          (an existing leading batch dimension is kept)
        output:
        - tensor of size 1 x num_words_in_input x (rows of W_linear)
        """
        # Shapes are read from the inputs themselves rather than from notebook
        # globals, so the module keeps working when those globals change.
        Q, K, V = (self._with_batch_dim(t) for t in (Q, K, V))

        Z = self._attention(Q, K, V)

        # Linear layer
        out = Z @ self.W_linear.T + self.b_linear
        return out

In [7]:

# nn.MultiheadAttention applies a linear output layer after the attention head,
# so its result differs from the bare attention function above
multihead_attn = nn.MultiheadAttention(
    embed_dim=embedding_dim,
    num_heads=1,
    batch_first=True,
    bias=False,
)

# self-attention: query, key and value are all the input with a batch dimension added
out_test_PT, _ = multihead_attn(
    X_test.unsqueeze(0),  # query
    X_test.unsqueeze(0),  # key
    X_test.unsqueeze(0),  # value
    need_weights=False,
)

# the Q, K and V projection weights are stacked row-wise in one matrix;
# split it into three equal parts
in_proj_weight_attn = multihead_attn.in_proj_weight
print(in_proj_weight_attn.shape)
in_proj_weight_Q, in_proj_weight_K, in_proj_weight_V = in_proj_weight_attn.chunk(3, dim=0)

# weight of the output layer; the module was built with bias=False,
# so no output-layer bias is extracted here
W_Linear_test = multihead_attn.out_proj.weight
print(W_Linear_test.shape)

torch.Size([12, 4])
torch.Size([4, 4])


In [8]:
# regenerate Q_test, K_test, V_test with the weights used by the pytorch function.
# Note the transpose of W_Q, W_K and W_V due to the way linear layers work in PyTorch
Z_test, Q_test, K_test, V_test = my_attention(
    X_test,
    in_proj_weight_Q.T,
    in_proj_weight_K.T,
    in_proj_weight_V.T,
)

# our block, using PyTorch's output-layer weight and a zero bias
single_head_attention_test = MySingleheadAttention(W_Linear_test, torch.zeros(embedding_dim))
out_test = single_head_attention_test(Q_test, K_test, V_test)

print(out_test)
print(out_test_PT)
print(torch.allclose(out_test, out_test_PT))
# it results in the same thing -> WIN -> allowed to replace

tensor([[[ 0.1113, -0.0103, -0.3275, -0.2489],
         [ 0.1111, -0.0108, -0.3256, -0.2471],
         [ 0.1117, -0.0110, -0.3273, -0.2485],
         [ 0.1101, -0.0094, -0.3244, -0.2465],
         [ 0.1110, -0.0101, -0.3263, -0.2478],
         [ 0.1117, -0.0111, -0.3273, -0.2484],
         [ 0.1101, -0.0093, -0.3245, -0.2465],
         [ 0.1099, -0.0101, -0.3226, -0.2447]]], grad_fn=<AddBackward0>)
tensor([[[ 0.1113, -0.0103, -0.3275, -0.2489],
         [ 0.1111, -0.0108, -0.3256, -0.2471],
         [ 0.1117, -0.0110, -0.3273, -0.2485],
         [ 0.1101, -0.0094, -0.3244, -0.2465],
         [ 0.1110, -0.0101, -0.3263, -0.2478],
         [ 0.1117, -0.0111, -0.3273, -0.2484],
         [ 0.1101, -0.0093, -0.3245, -0.2465],
         [ 0.1099, -0.0101, -0.3226, -0.2447]]], grad_fn=<TransposeBackward0>)
True


# Build one attention block - multihead

In [43]:
### build a NN module for our multihead attention mechanism

class MyMultiheadAttention(nn.Module):
    """
    Multi-head attention block: Q, K and V are split along the embedding axis
    into num_heads chunks, scaled dot-product attention (computed by PyTorch) is
    applied to each chunk, the results are concatenated again and passed
    through a linear output layer.

    The weight and bias of the output layer are passed in as tensors instead of
    being created here, so they can be taken from an existing module.
    """

    def __init__(self, W_linear, b_linear, num_heads=4):
        """
        input:
        - W_linear: weight of the output layer, applied as Z @ W_linear.T
        - b_linear: bias added after the output layer
        - num_heads: number of chunks the embedding axis is split into
        """
        super().__init__()

        self.W_linear = W_linear
        self.b_linear = b_linear
        self.num_heads = num_heads

    @staticmethod
    def _with_batch_dim(t):
        """Add a leading batch dimension of size 1 to a 2-D tensor; leave others as is."""
        return t.unsqueeze(0) if t.dim() == 2 else t

    @staticmethod
    def _attention(Q, K, V):
        """Scaled dot-product attention without mask or dropout; returns only the output."""
        # Prefer the public function of newer PyTorch releases; older releases
        # only have the private helper, which returns (output, attention_weights).
        if hasattr(nn.functional, "scaled_dot_product_attention"):
            return nn.functional.scaled_dot_product_attention(
                Q, K, V, attn_mask=None, dropout_p=0.)
        Z, _ = nn.functional._scaled_dot_product_attention(
            Q, K, V, attn_mask=None, dropout_p=0.)
        return Z

    def forward(self, Q, K, V):
        """
        input:
        - matrices Q, K, V (already projected), size: num_words_in_input x embedding_dim
          (an existing leading batch dimension is kept)
        output:
        - tensor of size 1 x num_words_in_input x (rows of W_linear)
        raises:
        - ValueError if the embedding size of Q, K or V is not divisible by num_heads
        """
        # An uneven split would give heads of different sizes (or fewer heads
        # than requested), so reject it up front with a clear message.
        for name, t in (("Q", Q), ("K", K), ("V", V)):
            if t.size(-1) % self.num_heads != 0:
                raise ValueError(
                    f"embedding size of {name} ({t.size(-1)}) must be divisible "
                    f"by num_heads ({self.num_heads})"
                )

        # chunk Q, K, V in num_heads along embedding size
        Q_chunks = Q.chunk(self.num_heads, dim=-1)
        K_chunks = K.chunk(self.num_heads, dim=-1)
        V_chunks = V.chunk(self.num_heads, dim=-1)

        # apply attention to each chunk; the head size comes from the chunk
        # itself rather than from notebook globals
        Z_heads = [
            self._attention(
                self._with_batch_dim(Q_head),
                self._with_batch_dim(K_head),
                self._with_batch_dim(V_head),
            )
            for Q_head, K_head, V_head in zip(Q_chunks, K_chunks, V_chunks)
        ]

        # concatenate all chunks along embedding size
        Z = torch.cat(Z_heads, dim=-1)

        # Linear layer
        out = Z @ self.W_linear.T + self.b_linear
        return out

In [44]:
### compare with the PyTorch function for multiple attention heads

# PyTorch doc:
# num_heads – Number of parallel attention heads. Note that embed_dim will be split
# across num_heads (i.e. each head will have dimension embed_dim // num_heads)

# set embed_dim to 12 to not confuse the number of attention heads with embed dim
embedding_dim = 12
num_attention_heads = 4

# fresh random input for the new embedding size
X_test = torch.rand(num_words_in_input_dim, embedding_dim)

# nn.MultiheadAttention applies a linear output layer after the attention heads,
# so its result differs from the bare attention function
multihead_attn = nn.MultiheadAttention(
    embed_dim=embedding_dim,
    num_heads=num_attention_heads,
    batch_first=True,
    bias=False,
)

# add the batch dimension; the same tensor serves as query, key and value
X_test_heads = X_test.unsqueeze(0)
out_test_PT, _ = multihead_attn(
    X_test_heads, X_test_heads, X_test_heads, need_weights=False
)

# the Q, K and V projection weights are stacked row-wise in one matrix;
# split it into three equal parts
in_proj_weight_attn = multihead_attn.in_proj_weight
print(in_proj_weight_attn.shape)
in_proj_weight_Q, in_proj_weight_K, in_proj_weight_V = in_proj_weight_attn.chunk(3, dim=0)

# weight of the output layer; the module was built with bias=False,
# so no output-layer bias is extracted here
W_Linear_test = multihead_attn.out_proj.weight
print(W_Linear_test.shape)
print(out_test_PT.shape)

torch.Size([36, 12])
torch.Size([12, 12])
torch.Size([1, 8, 12])


In [45]:
# regenerate Q_test, K_test, V_test with the weights used by the pytorch function.
# Note the transpose of W_Q, W_K and W_V due to the way linear layers work in PyTorch
Z_test, Q_test, K_test, V_test = my_attention(
    X_test,
    in_proj_weight_Q.T,
    in_proj_weight_K.T,
    in_proj_weight_V.T,
)

# our block, using PyTorch's output-layer weight and a zero bias
multi_head_attention_test = MyMultiheadAttention(
    W_Linear_test,
    torch.zeros(embedding_dim),
    num_heads=num_attention_heads,
)
out_test = multi_head_attention_test(Q_test, K_test, V_test)

print(out_test)
print(out_test_PT)
print(torch.allclose(out_test, out_test_PT))
# it results in the same thing -> WIN -> we can now use PyTorch's MultiheadAttention

tensor([[[ 0.2658, -0.1326, -0.0226,  0.3819, -0.2050, -0.1240, -0.0373,
           0.2997, -0.4118, -0.3684, -0.2174,  0.3167],
         [ 0.2661, -0.1293, -0.0220,  0.3800, -0.2064, -0.1232, -0.0362,
           0.2978, -0.4111, -0.3658, -0.2163,  0.3180],
         [ 0.2662, -0.1279, -0.0221,  0.3793, -0.2056, -0.1232, -0.0364,
           0.2974, -0.4100, -0.3645, -0.2167,  0.3193],
         [ 0.2659, -0.1314, -0.0231,  0.3828, -0.2068, -0.1230, -0.0362,
           0.2986, -0.4123, -0.3678, -0.2179,  0.3173],
         [ 0.2654, -0.1282, -0.0220,  0.3808, -0.2047, -0.1249, -0.0354,
           0.2993, -0.4129, -0.3650, -0.2172,  0.3199],
         [ 0.2664, -0.1292, -0.0210,  0.3815, -0.2046, -0.1247, -0.0369,
           0.3000, -0.4126, -0.3658, -0.2181,  0.3207],
         [ 0.2657, -0.1299, -0.0229,  0.3779, -0.2068, -0.1225, -0.0375,
           0.2971, -0.4092, -0.3659, -0.2143,  0.3171],
         [ 0.2642, -0.1309, -0.0221,  0.3813, -0.2017, -0.1266, -0.0352,
           0.3014, -0.41

# Build one encoder block

Here, one encoder block is a self-attention block with a residual connection and layer normalization, followed by a feed-forward network (two linear layers with a ReLU in between) with another residual connection and another layer normalization.

In [60]:
# define our own encoder Block
class MyEncoderBlock(nn.Module):
    """
    One encoder block: multi-head self-attention and a two-layer feed-forward
    network. Each of the two is followed by a residual connection and a layer
    normalization (the normalization is applied after the residual addition).

    input:
    - num_heads: number of attention heads
    - embedding_dim: size of the embedding (last) axis of the input
    - dim_feedforward: hidden size of the feed-forward network;
      when None it defaults to 4*embedding_dim
    """

    def __init__(self, num_heads=4, embedding_dim=12, dim_feedforward=None):
        super().__init__()

        # keep the previous behaviour (hidden size 4*embedding_dim) by default
        if dim_feedforward is None:
            dim_feedforward = 4 * embedding_dim

        self.multi_head_attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                             num_heads=num_heads, batch_first=True, bias=False)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.linear1 = nn.Linear(embedding_dim, dim_feedforward)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim_feedforward, embedding_dim)

    def forward(self, X):
        """
        input:
        - tensor X, size: batch x num_words_in_input x embedding_dim
          (the attention module is created with batch_first=True)
        output:
        - tensor of the same size as X
        """
        # self-attention: X serves as query, key and value
        attn, _ = self.multi_head_attention(X, X, X, need_weights=False)
        # first residual connection + layer normalization
        n1 = self.norm1(X + attn)

        # feed-forward network: linear -> ReLU -> linear
        x = self.linear2(self.relu(self.linear1(n1)))
        # second residual connection + layer normalization
        n2 = self.norm2(x + n1)
        return n2

In [62]:
### smoke test: does a forward pass through our block run without errors?
my_encoder_block = MyEncoderBlock(num_heads=4, embedding_dim=12)

# add the batch dimension expected by the block
X_test_heads = X_test.unsqueeze(0)

# the output should have the same shape as the input
print(my_encoder_block(X_test_heads).shape)

torch.Size([1, 8, 12])


In [65]:
### compare to PyTorch implementation

# dropout is switched off because our own block has none
encoder_layer = nn.TransformerEncoderLayer(
    d_model=embedding_dim,
    nhead=num_attention_heads,
    dropout=0.,
    batch_first=True,
)

# only the output shape is checked here
out_test_PT = encoder_layer(X_test_heads)
print(out_test_PT.shape)

torch.Size([1, 8, 12])


In [66]:
### comparison by visual inspection of the PyTorch implementation: https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
# we conclude that our block is close enough; copying all the weights over by hand for a numerical comparison would be too much of a pain