Reference: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html

How to use: <br>
Initialize with: <br>
sequence length N (N) <br>
input dimension D (input_dim) <br>
output dimension E (embed_dim) <br>
number of attention heads (num_heads) <br>
--> embed_dim must be divisible by num_heads (embed_dim % num_heads == 0)!

In [13]:
import os
import numpy as np
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [14]:
class MultiheadAttention(nn.Module):
    """Multi-head self-attention followed by an output projection and LayerNorm.

    The input of shape ``[batch, seq_len, input_dim]`` is projected to queries,
    keys and values for all heads with a single linear layer, scaled
    dot-product attention is applied per head, the heads are concatenated and
    projected with ``o_proj``, and the result is layer-normalized.

    Note:
        The layer norm is built with ``normalized_shape=[N, embed_dim]``, so it
        normalizes over both the sequence and the embedding axes and the module
        only accepts inputs whose sequence length equals ``N``.

    Args:
        N: Sequence length the module is built for.
        input_dim: Size of the last dimension of the input (``D``).
        embed_dim: Size of the last dimension of the output (``E``); must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, N, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One stacked projection produces Q, K and V for every head at once.
        # Both projections are bias-free.
        self.qkv_proj = nn.Linear(input_dim, 3 * embed_dim, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.layer_norm = nn.LayerNorm([N, embed_dim])

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize both projection weights with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.qkv_proj.weight)
        nn.init.xavier_uniform_(self.o_proj.weight)

    @staticmethod
    def scaled_dot_product(q, k, v, mask=None):
        """Compute ``softmax(q @ k^T / sqrt(d_k)) @ v``.

        Args:
            q: Query tensor; its last dimension is used as ``d_k``.
            k: Key tensor; its last two dimensions are transposed before the
                product with ``q``.
            v: Value tensor, multiplied by the attention weights.
            mask: Optional tensor broadcastable to the attention logits.
                Positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(values, attention)`` with the attended values and the
            attention weights (softmax over the last dimension).
        """
        d_k = q.size()[-1]
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits / math.sqrt(d_k)
        if mask is not None:
            # Use the most negative value representable in the logits' dtype:
            # a hard-coded constant such as -9e15 does not fit in float16.
            fill_value = torch.finfo(attn_logits.dtype).min
            attn_logits = attn_logits.masked_fill(mask == 0, fill_value)
        attention = F.softmax(attn_logits, dim=-1)
        values = torch.matmul(attention, v)
        return values, attention

    def forward(self, x, mask=None, return_attention=False):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch, seq_len, input_dim]``.
            mask: Optional mask forwarded to :meth:`scaled_dot_product`; it
                must broadcast against logits of shape
                ``[batch, num_heads, seq_len, seq_len]``.
            return_attention: If true, also return the attention weights.

        Returns:
            The layer-normalized output of shape
            ``[batch, seq_len, embed_dim]``; when ``return_attention`` is true,
            a tuple of that output and the attention weights of shape
            ``[batch, num_heads, seq_len, seq_len]``.
        """
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_proj(x)

        # Give every head a contiguous slice of size 3 * head_dim, then split
        # that slice into its Q, K and V thirds.
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # [Batch, Head, SeqLen, 3 * HeadDim]
        q, k, v = qkv.chunk(3, dim=-1)

        values, attention = self.scaled_dot_product(q, k, v, mask=mask)

        # Move heads next to their features and concatenate them.
        values = values.permute(0, 2, 1, 3)  # [Batch, SeqLen, Head, HeadDim]
        values = values.reshape(batch_size, seq_length, self.embed_dim)
        o = self.o_proj(values)

        o = self.layer_norm(o)

        if return_attention:
            return o, attention
        return o

In [15]:
# Smoke test: run a random batch through the layer and print the output shape.

B, N = 4, 20  # batch size, sequence length
D, E = 6, 9   # input dimension, embedding dimension (also the output dimension)

matrix = torch.rand(B, N, D)

att = MultiheadAttention(N, input_dim=D, embed_dim=E, num_heads=3)

Z = att(matrix)
print(Z.size())

torch.Size([4, 20, 9])
