https://arxiv.org/pdf/1803.02155.pdf
https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a

In [1]:
import math
import torch
from torch import nn

In [3]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with learned relative position embeddings.

    The attention logit between query position ``i`` and key position ``j``
    is ``q_i . (k_j + Er[h, max_len - 1 - (i - j)]) / sqrt(d_head)``: a
    content term plus a term that depends only on the distance ``i - j``.

    ``Er`` stores one embedding per head for each distance ``0 .. max_len - 1``
    (the last row is distance 0). Only non-negative distances are
    representable, so keys after the query are masked out and the layer is
    causal.

    Args:
        d_model: Size of the input and output feature dimension.
        num_heads: Number of attention heads; must divide ``d_model``.
        max_len: Longest sequence (largest relative distance + 1) supported.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, max_len=1024, dropout=0.1):
        super().__init__()
        d_head, remainder = divmod(d_model, num_heads)
        if remainder:
            raise ValueError(
                "incompatible `d_model` and `num_heads`"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_head
        self.max_len = max_len
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.query = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Er[h, max_len - 1 - d] is head h's embedding for relative distance d.
        self.Er = nn.Parameter(
            torch.randn((num_heads, max_len, d_head))
        )

    @staticmethod
    def _skew(rel_logits):
        """Realign ``q @ Er^T`` so entry ``[i, j]`` holds distance ``i - j``.

        On input, column ``r`` of row ``i`` is the logit for distance
        ``seq_len - 1 - r``. Padding one zero column on the left and
        reinterpreting the buffer with rows and columns swapped shifts row
        ``i`` by the right amount for every ``j <= i``. Entries with
        ``j > i`` are meaningless and must be masked by the caller.
        """
        padded = nn.functional.pad(rel_logits, (1, 0))
        batch_size, num_heads, rows, cols = padded.shape
        return padded.reshape(batch_size, num_heads, cols, rows)[:, :, 1:, :]

    def forward(self, x):
        """Apply causal relative-position self-attention.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` is larger than ``max_len``.
        """
        batch_size, seq_len, _ = x.shape
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds `max_len` {self.max_len}"
            )

        def split_heads(t):
            # (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_head)
            # Explicit d_head (rather than -1) keeps empty sequences reshapeable.
            return t.reshape(
                batch_size, seq_len, self.num_heads, self.d_head
            ).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # Content logits: q_i . k_j
        content = torch.matmul(q, k.transpose(-2, -1))
        # content.shape = (batch_size, num_heads, seq_len, seq_len)

        # Position logits: q_i . Er[distance i - j]. Only the last seq_len
        # rows of Er (distances seq_len - 1 .. 0) are needed.
        er_t = self.Er[:, self.max_len - seq_len:, :].transpose(-2, -1)
        # er_t.shape = (num_heads, d_head, seq_len); broadcasts over batch
        position = self._skew(torch.matmul(q, er_t))

        attn = (content + position) / math.sqrt(self.d_head)

        # Mask keys that lie after the query: Er has no embedding for them and
        # the skewed position logits there are not meaningful.
        index = torch.arange(seq_len, device=x.device)
        future = index[None, :] > index[:, None]
        attn = attn.masked_fill(future, float("-inf"))

        attn = self.dropout(torch.softmax(attn, dim=-1))
        out = torch.matmul(attn, v)
        # out.shape = (batch_size, num_heads, seq_len, d_head)
        return out.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

$$
z_i = \sum_{j = 1}^n \alpha_{ij} (x_j W^V + a_{ij}^V)
$$

$$
e_{ij} = \frac{x_i W^Q (x_j W^K + a_{ij}^K)^\top}{\sqrt{d_z}}
$$

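Apart from the two added relative-position terms, $a_{ij}^V$ and $a_{ij}^K$, the computation is unchanged: the attention weights are still a softmax over the scores.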

$$
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k = 1}^n \exp(e_{ik})}
$$