In [1]:
import torch
import math

def create_rope_rotate_matrix(dim, max_seq_len, device=None, base=10000.0, dtype=None):
    """
    Create the full RoPE rotation matrix.

    Frequencies follow RoFormer (Su et al., 2021):
        theta_i = base ** (-2 * (i - 1) / dim),  i = 1, ..., dim // 2
    so the first pair of dimensions rotates at frequency 1 and the last at
    base ** (-(dim - 2) / dim).

    Args:
    dim (int): Embedding dimension (must be even)
    max_seq_len (int): Maximum sequence length
    device: torch device
    base (float): Frequency base (10000 in the original paper)
    dtype: Floating-point dtype of the result
        (default: torch.get_default_dtype())

    Returns:
    torch.Tensor: Rotation matrix of shape (max_seq_len, dim, dim). Slice [m]
    is block-diagonal; the 2x2 block acting on dimensions (2i, 2i+1) is
        [[cos(m*theta_i), -sin(m*theta_i)],
         [sin(m*theta_i),  cos(m*theta_i)]]
    """
    assert dim % 2 == 0, "Dimension must be even"
    if dtype is None:
        dtype = torch.get_default_dtype()

    # Exponents 0, 2/dim, ..., (dim-2)/dim -- starting at 0 (the paper's i-1)
    # so that the highest frequency is exactly 1.
    exponents = torch.arange(0, dim, 2, device=device, dtype=dtype) / dim
    theta = base ** (-exponents)  # (dim // 2,)

    # Outer product of positions and frequencies: m_theta[m, i] = m * theta_i
    positions = torch.arange(max_seq_len, device=device, dtype=dtype)
    m_theta = positions.unsqueeze(1) * theta.unsqueeze(0)  # (max_seq_len, dim // 2)

    cos_theta = torch.cos(m_theta)
    sin_theta = torch.sin(m_theta)

    # Fill all 2x2 rotation blocks at once via advanced indexing instead of
    # looping over dimension pairs in Python.
    rotate_matrix = torch.zeros(max_seq_len, dim, dim, device=device, dtype=dtype)
    even = torch.arange(0, dim, 2, device=device)
    odd = even + 1
    rotate_matrix[:, even, even] = cos_theta
    rotate_matrix[:, even, odd] = -sin_theta
    rotate_matrix[:, odd, even] = sin_theta
    rotate_matrix[:, odd, odd] = cos_theta

    return rotate_matrix

In [6]:
create_rope_rotate_matrix(dim=4, max_seq_len=1)

tensor([[[1., -0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., -0.],
         [0., 0., 0., 1.]]])