In [1]:
import torch
import torch.nn as nn
from attention import AbsolutePositionalEncoding as PositionalEncoding

class TransformerConfig:
    def __init__(self, d_model=256, num_heads=4, num_layers=6, vocab_size=10000, max_len=5000, dropout=0.1):
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.dropout = dropout

class TransformerEncoder(nn.Module):
    """
    A simple Transformer encoder block.
    """
    def __init__(self, config):
        super().__init__()
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoding = PositionalEncoding(config)

class Transformer(nn.Module):
    """
    A simple Transformer model, as in Vaswani et al. (2017).
    """
    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)

In [11]:
'''
Absolute Positional Encoding
'''

import torch
import math

max_len = 100
d_model = 12
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=float).unsqueeze(1)
print("position.shape", position.shape, position)
every_other_dim = torch.arange(0, d_model, 2, dtype=float)
print("every_other_dim.shape", every_other_dim.shape, every_other_dim)
'''
a = 2i/d_model
-ln(10000^a) = -a ln(10000) 
exp(ln(10000^-a)) = 10000^(-a) 
'''
div_term = torch.exp((math.log(10000.0)) * -every_other_dim / d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
pe.shape, pe[0][1]

# pe[0]



position.shape torch.Size([100, 1]) tensor([[ 0.],
        [ 1.],
        [ 2.],
        [ 3.],
        [ 4.],
        [ 5.],
        [ 6.],
        [ 7.],
        [ 8.],
        [ 9.],
        [10.],
        [11.],
        [12.],
        [13.],
        [14.],
        [15.],
        [16.],
        [17.],
        [18.],
        [19.],
        [20.],
        [21.],
        [22.],
        [23.],
        [24.],
        [25.],
        [26.],
        [27.],
        [28.],
        [29.],
        [30.],
        [31.],
        [32.],
        [33.],
        [34.],
        [35.],
        [36.],
        [37.],
        [38.],
        [39.],
        [40.],
        [41.],
        [42.],
        [43.],
        [44.],
        [45.],
        [46.],
        [47.],
        [48.],
        [49.],
        [50.],
        [51.],
        [52.],
        [53.],
        [54.],
        [55.],
        [56.],
        [57.],
        [58.],
        [59.],
        [60.],
        [61.],
        [62.],
        [63.],
    

(torch.Size([1, 100, 12]),
 tensor([8.4147e-01, 5.4030e-01, 2.1378e-01, 9.7688e-01, 4.6399e-02, 9.9892e-01,
         9.9998e-03, 9.9995e-01, 2.1544e-03, 1.0000e+00, 4.6416e-04, 1.0000e+00]))

In [166]:
import torch
import torch.nn as nn
import math


class RelativeMultiHeadSelfAttention(nn.Module):
    """
    An example implementation of multi-head self-attention with
    relative position representations, following Shaw et al. (2018).
    """
    def __init__(self, d_model=256, num_heads=4, max_len=512):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.max_len = max_len

        # Projections for the usual Q, K, V
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj   = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.out_proj   = nn.Linear(d_model, d_model)

        self.relative_key_embeddings = nn.Embedding(2 * max_len + 1, self.head_dim)
        self.relative_value_embeddings = nn.Embedding(2 * max_len + 1, self.head_dim)
    
    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        q = self.query_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # --- Content-based (standard) attention score ---
        # qk shape: (batch_size, num_heads, seq_len, seq_len)
        qk = torch.matmul(q, k.transpose(-2, -1))

        # --- Incorporate relative position (Key) ---
        # Build a matrix of relative position offsets for each pair (i, j)
        pos_ids = torch.arange(seq_len).unsqueeze(1) - torch.arange(seq_len)

        # If sequence is seq_len, then pos_ids is in range [-(seq_len), (seq_len)]
        # Shifting by max_len puts the range in [0, 2 * max_len + 1]
        # which is the range of the relative key embeddings
        pos_ids = pos_ids + (self.max_len)
        pos_ids = pos_ids.clamp(0, 2*self.max_len + 1)                      # shape (seq_len, seq_len)

        rel_k = self.relative_key_embeddings(pos_ids)                       # shape (seq_len, seq_len, head_dim)
        q_r = q.unsqueeze(3)                                                # shape (batch_size, num_heads, seq_len, 1, head_dim)
        rel_k = rel_k.unsqueeze(0).unsqueeze(0)                             # shape (1, 1, seq_len, seq_len, head_dim)
        qk_r = torch.matmul(q_r, rel_k.transpose(-2, -1)).squeeze(-2)       # shape (batch_size, num_heads, seq_len, seq_len)

        qk = qk + qk_r
        qk_r = qk / math.sqrt(self.head_dim)
        probs = torch.softmax(qk, dim=-1)                                  # shape (batch_size, num_heads, seq_len, seq_len)

        attn = torch.matmul(probs, v) # scores torch.Size([2, 4, 12, 4])          

        # attn = torch.matmul(scores, v) # scores torch.Size([2, 4, 12, 4])          
        # print("attn", attn.shape)                          

        rel_v = self.relative_value_embeddings(pos_ids)                     # shape (seq_len, seq_len, head_dim)
        print("rel_v", rel_v.shape)

        # Typically we can do a matrix multiplication of the attention probabilities and the values
        # For relative embeddings, we have a separate embedding for each possible distance, per head
        # So, instead of a single matrix with size (seq_len, head_dim), we have a matrix with size (seq_len, seq_len, head_dim)
        # α(i,j) * aV(i,j)
        # Step 1 - element-wise multiply each attention weight by the corresponding relative embedding
        probs_r = probs.unsqueeze(-1)                                       # shape (batch_size, num_heads, seq_len, seq_len, 1)
        rel_v = rel_v.unsqueeze(0).unsqueeze(0)                             # shape (1, 1, seq_len, seq_len, dim_head)
        attn_r = probs_r * rel_v                                            # shape (batch_size, num_heads, seq_len, seq_len, dim_head)
        # Step 2 - sum over the dimension j
        attn_r = attn_r.sum(dim=3)                                           # shape (batch_size, num_heads, seq_len, dim_head)
        
        attn = attn + attn_r
        # Reshape to the original dimensions and consolidate all of the heads
        out = attn.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        # And linearly project
        out = self.out_proj(out)
        return out

        

In [168]:
batch_size = 2
seq_len = 12
d_model = 16
num_heads = 4

# Dummy input
x = torch.randn(batch_size, seq_len, d_model)  # (B, L, d_model)

# Create the module
attn = RelativeMultiHeadSelfAttention(d_model, num_heads, max_len=seq_len)

# Forward pass
output = attn(x)
print("output", output.shape, output)

rel_v torch.Size([12, 12, 4])
output torch.Size([2, 12, 16]) tensor([[[ 7.0756e-02,  5.5792e-01, -3.5239e-01, -3.8813e-02,  3.1118e-02,
          -7.2510e-02, -2.5472e-01,  2.1883e-01,  4.3866e-01, -6.3810e-01,
          -4.8520e-01,  2.7380e-01,  1.7937e-02, -1.3297e-01, -2.8839e-01,
          -2.7720e-01],
         [ 2.2258e-01, -2.6789e-01,  1.1867e-01,  2.2818e-01, -6.2169e-01,
          -1.3462e-01,  1.5293e-01,  9.9110e-02,  1.6223e-01, -2.0867e-01,
           2.0280e-01,  8.6187e-01,  6.1669e-02,  2.2743e-01,  3.1709e-01,
           2.9780e-01],
         [-7.3058e-01, -5.7413e-01, -6.7851e-01,  3.4520e-02,  3.1706e-01,
           3.0778e-01,  3.0434e-01, -5.7832e-01,  1.0315e+00,  2.6030e-01,
          -7.0750e-02,  6.3350e-01, -2.5737e-01, -5.5300e-01,  5.7763e-01,
           1.5898e-01],
         [ 9.9492e-02, -1.2708e-01,  3.8734e-01,  6.8186e-01, -8.5880e-01,
          -8.3040e-01, -4.8952e-01, -1.4007e-03,  3.2401e-01, -5.1098e-01,
          -3.6689e-01,  6.2249e-01,  7.215