In [1]:
import torch
import torch.nn as nn

In [2]:
class RoPE(nn.Module):
    """
    Applies rotary positional embeddings (RoPE) to the last dimension of a tensor.

    Adjacent feature pairs ``(x[..., 2i], x[..., 2i + 1])`` are rotated by the
    angle ``position * base ** (-2i / dim)``. The sequence axis is the
    second-to-last axis, so inputs may be ``(batch, seq_len, dim)`` or carry
    extra leading axes, e.g. ``(batch, num_heads, seq_len, dim)``.
    """

    def __init__(self, dim, base=10000):
        """
        dim (int): feature dimension (must be even)
        base (int): RoPE base (default 10000)
        """
        super().__init__()
        assert dim % 2 == 0, "RoPE requires an even dimension"

        self.dim = dim
        self.num_pairs = dim // 2

        # One inverse frequency per feature pair: base ** (-2i / dim), i = 0..num_pairs-1.
        inv_freq = 1.0 / (base ** (torch.arange(0, self.num_pairs) * 2 / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        """
        x: (..., seq_len, dim), e.g. (batch, seq_len, dim)
        returns: Tensor of the same shape and dtype as x with RoPE applied
        raises: ValueError if x has fewer than 2 dims or its last dim != self.dim
        """
        if x.dim() < 2 or x.shape[-1] != self.dim:
            raise ValueError(
                f"expected input of shape (..., seq_len, {self.dim}), got {tuple(x.shape)}"
            )
        *lead, seq_len, _ = x.shape

        # Compute angles in at least float32: if the module was cast to half
        # precision, integer positions above 2048 are not exactly representable.
        compute_dtype = torch.promote_types(self.inv_freq.dtype, torch.float32)
        positions = torch.arange(seq_len, device=x.device, dtype=compute_dtype)
        angles = torch.outer(positions, self.inv_freq.to(compute_dtype))  # (seq_len, num_pairs)

        # Cast back so the result keeps x's dtype (no silent fp16/bf16 -> fp32 upcast).
        # (seq_len, num_pairs) broadcasts over all leading axes of x.
        sin = angles.sin().to(x.dtype)
        cos = angles.cos().to(x.dtype)

        # reshape (not view) so non-contiguous inputs, e.g. transposed heads, work.
        pairs = x.reshape(*lead, seq_len, self.num_pairs, 2)
        x1, x2 = pairs[..., 0], pairs[..., 1]  # each (..., seq_len, num_pairs)

        x_rot = torch.stack(
            [
                x1 * cos - x2 * sin,
                x1 * sin + x2 * cos,
            ],
            dim=-1,
        )

        # (..., seq_len, num_pairs, 2) -> (..., seq_len, dim)
        return x_rot.flatten(-2)

In [4]:
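"""
Usage: create one RoPE module and apply it to both queries and keys
(values are not rotated); d_out is the query/key feature dimension.
rope = RoPE(dim=d_out)
queries = rope(queries)
keys = rope(keys)
"""

'\nUsage: create one RoPE module and apply it to both queries and keys\n(values are not rotated); d_out is the query/key feature dimension.\nrope = RoPE(dim=d_out)\nqueries = rope(queries)\nkeys = rope(keys)\n'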

In [9]:
# Alternative way: incorporating RoPE directly in the CausalAttention class itself
class CausalAttention(nn.Module):
    """
    Single-head causal self-attention with Rotary Positional Embeddings (RoPE).

    RoPE is applied to queries and keys only; values are left unrotated.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False, rope_base=10000):
        """
        d_in (int): input feature dimension
        d_out (int): query/key/value dimension (must be even for RoPE)
        context_length (int): maximum number of tokens the causal mask supports
        dropout (float): dropout probability applied to the attention weights
        qkv_bias (bool): whether the Q/K/V projections use a bias term
        rope_base (int): RoPE base (default 10000)
        """
        super().__init__()
        assert d_out % 2 == 0, "RoPE requires d_out to be even"

        self.d_out = d_out
        self.num_pairs = d_out // 2

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark "future" positions to be hidden.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

        # RoPE inverse frequencies: rope_base ** (-2i / d_out) for each feature pair i.
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, self.num_pairs) * 2 / d_out))
        self.register_buffer("inv_freq", inv_freq)

    def apply_rope(self, x):
        """
        Rotate adjacent feature pairs of x by position-dependent angles.

        x : (batch, seq_len, d_out)
        returns: rotated x, same shape and dtype
        """
        b, seq_len, d = x.shape

        # Compute angles in at least float32: after .half(), integer positions
        # above 2048 would otherwise be rounded and get wrong angles.
        compute_dtype = torch.promote_types(self.inv_freq.dtype, torch.float32)
        positions = torch.arange(seq_len, device=x.device, dtype=compute_dtype)
        angles = torch.outer(positions, self.inv_freq.to(compute_dtype))  # (seq_len, num_pairs)

        # Cast back to x's dtype; (seq_len, num_pairs) broadcasts over the batch axis.
        sin = angles.sin().to(x.dtype)
        cos = angles.cos().to(x.dtype)

        # Split the feature axis into 2D pairs; reshape tolerates non-contiguous x.
        pairs = x.reshape(b, seq_len, self.num_pairs, 2)
        x1, x2 = pairs[..., 0], pairs[..., 1]

        x_rotated = torch.stack(
            [
                x1 * cos - x2 * sin,
                x1 * sin + x2 * cos,
            ],
            dim=-1,
        )
        return x_rotated.reshape(b, seq_len, d)

    def forward(self, x):
        """
        x: (batch_size, num_tokens, d_in)
        returns: Context vectors, shape (batch_size, num_tokens, d_out)
        raises: ValueError if num_tokens exceeds the context_length of the mask
        """
        num_tokens = x.shape[1]
        if num_tokens > self.mask.shape[0]:
            raise ValueError(
                f"input has {num_tokens} tokens but context_length is {self.mask.shape[0]}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        queries = self.apply_rope(queries)
        keys = self.apply_rope(keys)

        attn_scores = queries @ keys.transpose(1, 2)
        # In-place is safe: attn_scores is a fresh tensor from the matmul above.
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )

        # Scaled dot-product attention; the diagonal is never masked, so no row is all -inf.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec