In [1]:
import torch
import torch.nn as nn
import numpy as np

## Scaled Dot Product Attention

### PyTorch 2.0 Version

In [2]:
### PyTorch Version
from torch.nn.functional import scaled_dot_product_attention

### Self-implemented version

In [3]:
import math

class DotProductAttention(nn.Module):
    '''Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

    Args:
        dropout_p: probability of an attention weight to be zeroed. Default: 0., i.e. no dropout

    After every forward pass the softmax-normalised weights (before dropout)
    are available as ``self.attn_weights``.
    '''

    def __init__(self, dropout_p=0.):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, queries, keys, values):
        # The scaling factor is the square root of the size of the last query axis.
        scale = math.sqrt(queries.shape[-1])
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / scale

        # Normalise over the last axis so each query's weights sum to one,
        # and keep them on the module for later inspection.
        weights = torch.softmax(logits, dim=-1)
        self.attn_weights = weights

        # Dropout acts on the weights only; the stored attn_weights stay untouched.
        return torch.matmul(self.dropout(weights), values)

In [4]:
### Test Case
# Fix the global RNG seed so the random inputs below (and everything derived
# from them) are identical on every "Restart Kernel -> Run All".
torch.manual_seed(0)

# Index 0 is the batch axis; the last axis of queries and keys must match.
queries = torch.normal(0, 1, (2,  1, 2))
keys    = torch.normal(0, 1, (2, 10, 2))
values  = torch.normal(0, 1, (2, 10, 4))

print('queries.shape:', queries.shape)
print('keys.shape:   ', keys.shape   )
print('values.shape: ', values.shape )

queries.shape: torch.Size([2, 1, 2])
keys.shape:    torch.Size([2, 10, 2])
values.shape:  torch.Size([2, 10, 4])


In [5]:
# Instantiate the self-implemented attention and switch it to evaluation mode;
# eval() returns the module itself, so its repr is shown as the cell output.
attentionMe = DotProductAttention()
attentionMe.eval()

DotProductAttention(
  (dropout): Dropout(p=0.0, inplace=False)
)

In [6]:
# Time one forward pass of the self-implemented module against PyTorch's
# built-in function on the same inputs.
%timeit attentionMe(queries, keys, values)
%timeit scaled_dot_product_attention(queries, keys, values)

90.2 µs ± 12.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
89.9 µs ± 3.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [7]:
# Run the self-implemented attention once and keep the result for the
# comparison with the PyTorch implementation.
resultMe = attentionMe(queries, keys, values)
resultMe.shape

torch.Size([2, 1, 4])

In [8]:
# check properties of attention weights
# The weights come out of a softmax over the last axis, so every row must sum to 1.
# Tensor.sum() performs the reduction in a single call, whereas Python's built-in
# sum() iterates over the tensor one element at a time.
print('shape of attention weights', attentionMe.attn_weights.shape)
print('sum of attention weights of the 1st item in the batch', attentionMe.attn_weights[0, 0, :].sum())
print('sum of attention weights of the 2nd item in the batch', attentionMe.attn_weights[1, 0, :].sum())

shape of attention weights torch.Size([2, 1, 10])
sum of attention weights of the 1st item in the batch tensor(1.0000)
sum of attention weights of the 2nd item in the batch tensor(1.0000)


In [9]:
# Same inputs through PyTorch's built-in scaled_dot_product_attention, as the reference result.
resultPT = scaled_dot_product_attention(queries, keys, values)
resultPT.shape

torch.Size([2, 1, 4])

In [10]:
# The self-implemented and built-in results should agree within torch.allclose's default tolerances.
are_close = torch.allclose(resultMe, resultPT)
are_close

True

## MultiHeadAttention

### Self-implemented with for loop

In [11]:
class MultiHeadAttention_forloop(nn.Module):
    '''Multi-head self-attention written with explicit Python loops.

    The embedding axis is cut into ``num_heads`` contiguous slices of width
    ``embed_dim // num_heads``. Every head has its own query/key/value
    projection that is applied to its slice only; the head outputs are
    concatenated again and passed through the output projection ``Wo``.

    Args:
        embed_dim: embedded dimension of each token in a sequence; must be divisible by num_heads
        num_heads: number of attention heads
        bias: whether the linear layers use a bias term. Default: False
    '''

    def __init__(self, embed_dim, num_heads, bias=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        assert embed_dim % num_heads == 0, f"Can't divide dimension {embed_dim} into {num_heads} heads"

        self.d_head = embed_dim // num_heads

        def per_head_projections():
            # One (d_head -> d_head) linear layer per head.
            return nn.ModuleList([nn.Linear(self.d_head, self.d_head, bias=bias) for _ in range(num_heads)])

        # Creation order (Wq, Wk, Wv, Wo) determines both the random
        # initialisation sequence and the state_dict layout.
        self.Wq = per_head_projections()
        self.Wk = per_head_projections()
        self.Wv = per_head_projections()
        self.Wo = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        # sequences: (N, seq_length, embed_dim), where embed_dim = token dimension
        scale = self.d_head ** 0.5
        per_sequence = []
        for tokens in sequences:  # tokens: (seq_length, embed_dim)
            head_outputs = []
            for head, (proj_q, proj_k, proj_v) in enumerate(zip(self.Wq, self.Wk, self.Wv)):
                # Each head only sees its own slice of the embedding axis.
                start = head * self.d_head
                chunk = tokens[:, start:start + self.d_head]
                q, k, v = proj_q(chunk), proj_k(chunk), proj_v(chunk)

                weights = self.softmax(torch.matmul(q, k.transpose(0, 1)) / scale)
                head_outputs.append(torch.matmul(weights, v))
            # Put the heads back side by side along the embedding axis.
            per_sequence.append(torch.cat(head_outputs, dim=1))

        # Re-assemble the batch axis and apply the output projection.
        return self.Wo(torch.stack(per_sequence))


### Self-Implemented Vectorized

In [12]:
class MultiHeadAttention_vectorized(nn.Module):
    '''Multi-head attention that processes all heads in one batched call.

    Queries, keys and values are each projected with their own linear layer
    (``Wq``, ``Wk``, ``Wv``), split into ``num_heads`` heads that are folded
    into the batch axis, run through ``DotProductAttention``, merged back and
    passed through the output projection ``Wo``.

    Args:
        embed_dim: embedded dimension of each token; must be divisible by num_heads
        num_heads: number of attention heads
        dropout_p: dropout probability applied to the attention weights. Default: 0.
        bias: whether the linear layers use a bias term. Default: False
    '''

    def __init__(self, embed_dim, num_heads, dropout_p=0., bias=False):
        super().__init__()

        # Fail early with a clear message instead of a reshape error in forward().
        assert embed_dim % num_heads == 0, f"Can't divide dimension {embed_dim} into {num_heads} heads"

        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout_p)

        self.Wq = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.Wk = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.Wv = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.Wo = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(self, queries, keys, values, need_weights=False, average_attn_weights=False):
        '''Compute multi-head attention.

        Args:
            queries, keys, values: tensors of shape (batch_size, n_seq, embed_dim)
            need_weights: if True, also return the attention weights
            average_attn_weights: if True (and need_weights), average the weights over the heads

        Returns:
            The projected attention output, or ``(output, attn_weights)`` when
            ``need_weights`` is True. ``attn_weights`` has one matrix per head
            (head axis at dim 1) unless ``average_attn_weights`` is True.
        '''
        # Each input gets its own projection: Wq for queries, Wk for keys, Wv for values.
        queries = self.transpose_qkv(self.Wq(queries))
        keys    = self.transpose_qkv(self.Wk(keys))
        values  = self.transpose_qkv(self.Wv(values))

        output = self.attention(queries, keys, values) # (batch_size*num_heads, n_seq, embed_dim/num_heads)
        output_concat = self.transpose_output(output)  # (batch_size, n_seq, embed_dim)
        projected = self.Wo(output_concat)

        if not need_weights:
            return projected

        # The attention module stores weights with heads folded into the batch
        # axis; unfold them so that dim 1 indexes the head.
        attn_weights = self.attention.attn_weights
        attn_weights = attn_weights.reshape(-1, self.num_heads, attn_weights.shape[1], attn_weights.shape[2])
        if average_attn_weights:
            attn_weights = attn_weights.mean(dim=1)
        return projected, attn_weights

    def transpose_qkv(self, X):
        '''Transposition for parallel computation of multiple attention heads.
            input  X.shape = (batch_size, n_seq, embed_dim)
            output X.shape = (batch_size*num_heads, n_seq, embed_dim/num_heads)
        '''
        X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
        X = X.permute(0, 2, 1, 3)
        return X.reshape(-1, X.shape[2], X.shape[3])

    def transpose_output(self, X):
        '''Reverse the operation of transpose_qkv.
            input  X.shape = (batch_size*num_heads, n_seq, embed_dim/num_heads)
            output X.shape = (batch_size, n_seq, embed_dim)
        '''
        X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
        X = X.permute(0, 2, 1, 3)
        return X.reshape(X.shape[0], X.shape[1], -1)

In [13]:
### Test Case
# Dimensions of the toy input shared by the multi-head attention variants.
batch_size = 128
n_seq = 32
embed_dim = 16
num_heads = 2

# Uniform random input laid out as (batch_size, n_seq, embed_dim).
x = torch.rand((batch_size, n_seq, embed_dim))
x.shape

torch.Size([128, 32, 16])

In [14]:
# Build the loop-based variant, put it in evaluation mode and time a full forward pass over x.
model_att_forloop = MultiHeadAttention_forloop(embed_dim, num_heads) ; model_att_forloop.eval()
%timeit model_att_forloop(x)

109 ms ± 4.46 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [15]:
# Build the vectorized variant, put it in evaluation mode and time it on the same input.
# x is passed as queries, keys and values, i.e. self-attention.
model_att_vec = MultiHeadAttention_vectorized(embed_dim, num_heads) ; model_att_vec.eval()
%timeit model_att_vec(x, x, x)

1.79 ms ± 22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [16]:
# The loop-based variant returns a tensor with the same shape as its input.
out_forloop = model_att_forloop(x)
out_forloop.shape

torch.Size([128, 32, 16])

In [17]:
# need_weights=True additionally returns the attention weights;
# average_attn_weights=False keeps one weight matrix per head (head axis at dim 1).
out_me, att_weights_me = model_att_vec(x, x, x, need_weights=True, average_attn_weights=False)
print('out_me.shape:', out_me.shape)
print('att_weights_me.shape:', att_weights_me.shape)

out_me.shape: torch.Size([128, 32, 16])
att_weights_me.shape: torch.Size([128, 2, 32, 32])


In [18]:
# average_attn_weights=True averages the weights over the heads, removing the head axis.
out_me, att_weights_me = model_att_vec(x, x, x, need_weights=True, average_attn_weights=True)
print('out_me.shape:', out_me.shape)
print('att_weights_me.shape:', att_weights_me.shape)

out_me.shape: torch.Size([128, 32, 16])
att_weights_me.shape: torch.Size([128, 32, 32])


### PyTorch 2.0 MultiHeadAttention

In [19]:
multihead_attn_PT = nn.MultiheadAttention(embed_dim, num_heads) ; multihead_attn_PT.eval()
%timeit multihead_attn_PT(x, x, x, need_weights=False)

2.23 ms ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
# Re-create the PyTorch module with batch_first=True and request the
# head-averaged attention weights alongside the output.
# NOTE(review): unlike the timing cell, this instance is not switched to eval mode —
# presumably harmless while dropout is left at its default; confirm.
multihead_attn_PT = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
out_PT, att_weights_PT_avg = multihead_attn_PT(x, x, x, need_weights=True, average_attn_weights=True)
print('out_PT.shape:', out_PT.shape)
print('att_weights_PT_avg.shape:', att_weights_PT_avg.shape)

out_PT.shape: torch.Size([128, 32, 16])
att_weights_PT_avg.shape: torch.Size([128, 32, 32])


In [21]:
# Deliberately reuse the existing multihead_attn_PT instance (do not re-create it):
# the per-head weights must come from the same parameters as att_weights_PT_avg
# for the two to be comparable.
out_PT, att_weights_PT = multihead_attn_PT(x, x, x, need_weights=True, average_attn_weights=False)
print('out_PT.shape:', out_PT.shape)
print('att_weights_PT.shape:', att_weights_PT.shape)

out_PT.shape: torch.Size([128, 32, 16])
att_weights_PT.shape: torch.Size([128, 2, 32, 32])


In [22]:
# Averaging the per-head weights over the head axis (dim=1) should reproduce
# what average_attn_weights=True returned.
torch.allclose(att_weights_PT.mean(dim=1), att_weights_PT_avg)

True

## ------ END ------------

### References

- https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c