Reference: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

How to use: <br>
initialize with: <br>
input dimension D (input_dim) <br>
output dimension E (embed_dim) <br>
number of attention heads (num_heads) <br>
Requirement: embed_dim must be divisible by num_heads (embed_dim % num_heads == 0).

In [2]:
import os
import numpy as np
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention followed by a parameter-free layer normalization.

    The input is projected to stacked queries, keys and values by one linear
    layer, scaled dot-product attention is applied per head, the heads are
    concatenated and passed through an output projection, and the result is
    normalized jointly over the sequence and embedding dimensions.

    Args:
        input_dim: Size of the last dimension of the input (D).
        embed_dim: Size of the last dimension of the output (E). Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack the Q, K and V weight matrices of all heads into a single
        # linear layer for efficiency. Both projections are bias-free.
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize both projection weights with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.xavier_uniform_(self.o_proj.weight)

    @staticmethod
    def scaled_dot_product(q, k, v, mask=None):
        """Compute scaled dot-product attention.

        Args:
            q: Queries; the size of the last dimension is used as the scale.
            k: Keys; the last two dimensions are transposed before the product.
            v: Values, combined using the attention weights.
            mask: Optional tensor broadcastable to the attention logits.
                Positions where ``mask == 0`` are filled with a large negative
                value so they receive (near-)zero weight after the softmax.

        Returns:
            A tuple ``(values, attention)``: the attention-weighted values and
            the softmax-normalized attention weights.
        """
        d_k = q.size()[-1]
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits / math.sqrt(d_k)
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
        attention = F.softmax(attn_logits, dim=-1)
        values = torch.matmul(attention, v)
        return values, attention

    def forward(self, x, mask=None, return_attention=False):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[Batch, SeqLen, input_dim]``.
            mask: Optional mask forwarded to :meth:`scaled_dot_product`; it
                must broadcast against logits of shape
                ``[Batch, Head, SeqLen, SeqLen]``.
            return_attention: If true, also return the attention weights.

        Returns:
            The normalized output of shape ``[Batch, SeqLen, embed_dim]``, or
            a tuple ``(output, attention)`` when ``return_attention`` is true.
        """
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x)

        # Separate Q, K, V from the stacked linear output.
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, Dims]
        q, k, v = qkv.chunk(3, dim=-1)

        # Determine value outputs.
        values, attention = self.scaled_dot_product(q, k, v, mask=mask)
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, Dims]
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        # Layer normalization over the last two dimensions (SeqLen, embed_dim).
        # The functional form is used instead of constructing nn.LayerNorm on
        # every call: a module created here would be unregistered, always live
        # on the CPU in the default dtype (failing for inputs on another device
        # or with another dtype), and its affine parameters would be reset to
        # ones/zeros each call. Without weight/bias the result is identical.
        o = F.layer_norm(o, tuple(o.shape[1:]))

        if return_attention:
            return o, attention
        return o

In [7]:
# Smoke test: run the custom attention layer on a random batch and print the result.
torch.manual_seed(0)

N = 20  # sequence length
E = 9  # embedding dimension; is also the output dimension!
D = 5  # input feature dimension
B = 4  # batch size

# NOTE(review): `input` shadows the Python builtin; the name is kept unchanged
# in case other cells refer to it.
input = torch.rand(B, N, D)

# Use E rather than a repeated literal so that changing E above takes effect.
custom_att = MultiheadAttention(input_dim=D, embed_dim=E, num_heads=3)

Z_custom = custom_att(input)

print("Custom implementation:")
print(Z_custom.shape)
print(Z_custom)


Custom implementation:
torch.Size([4, 20, 9])
tensor([[[-0.4028,  1.3057,  0.9940,  0.5412, -0.8579, -0.9394,  0.5938,
          -1.8596,  0.6290],
         [-0.4086,  1.3101,  1.0071,  0.5427, -0.8552, -0.9456,  0.5935,
          -1.8722,  0.6328],
         [-0.4031,  1.3023,  0.9912,  0.5414, -0.8571, -0.9367,  0.5914,
          -1.8563,  0.6271],
         [-0.4067,  1.3093,  1.0045,  0.5399, -0.8552, -0.9441,  0.5932,
          -1.8728,  0.6328],
         [-0.4021,  1.3021,  0.9905,  0.5391, -0.8575, -0.9374,  0.5905,
          -1.8614,  0.6293],
         [-0.4023,  1.3066,  0.9937,  0.5415, -0.8586, -0.9399,  0.5947,
          -1.8591,  0.6293],
         [-0.4017,  1.3030,  0.9905,  0.5386, -0.8573, -0.9369,  0.5911,
          -1.8603,  0.6290],
         [-0.4095,  1.3045,  1.0025,  0.5425, -0.8535, -0.9408,  0.5881,
          -1.8681,  0.6311],
         [-0.4060,  1.3057,  1.0005,  0.5398, -0.8550, -0.9412,  0.5919,
          -1.8681,  0.6300],
         [-0.4074,  1.3094,  1.0044,