In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import Rope 

In [48]:
class RopeAttention(nn.Module):

    """A placeholder for the RopeAttention module.
    This module currently does not implement any functionality.
    It is intended to be a stub for future development.
    """

    def __init__(self, d_model, n_heads, kv_latent_dim):
        super().__init__()
        self.d_model = d_model # Dimension of embeddings
        self.n_heads = n_heads # Number of heads
        self.dh = d_model // n_heads # dimensions of heads

        self.rope = Rope.Rope(d_model, 20000) # RoPE instance

        self.W_dq = nn.Linear(d_model, kv_latent_dim, bias = False) # Query down projection
        self.W_dkv = nn.Linear(d_model, kv_latent_dim, bias=False) # Down projection
        
        self.W_uk = nn.Linear(kv_latent_dim, d_model, bias = False) # Up projection to Keys
        self.W_uv = nn.Linear(kv_latent_dim, d_model, bias = False) # Up projection to values
        self.W_uq = nn.Linear(kv_latent_dim, d_model, bias = False) # Up projection to queries
        
        self.W_qr = nn.Linear(kv_latent_dim, d_model, bias = False) # Query projection for RoPE
        self.W_kr = nn.Linear(d_model, self.dh, bias = False) # Key projection for RoPE
        
        self.W_o = nn.Linear(d_model, d_model, bias = False) # Output projection
        self.ln = nn.LayerNorm(kv_latent_dim) # Layer norm

    def forward(self, x, kv_cache=None, kr_cache=None, past_length=0):
        B, S, D = x.size() # Batch size, sequence length, and embedding dimension

        # Query down projection and attention scores
        c_q = self.ln(self.W_dq(x)) # (B, S, kv_latent_dim) - down projection

        #### WITHOUT ROPE ####

        # queries up projection, first 
        q_c = self.W_uq(c_q).view(B, S, self.n_heads, self.dh) # (B, S, num_heads, head_dim)
        attn_scores = torch.zeros(B, S, self.n_heads, self.dh, device=x.device) # (B, S, num_heads)

        # Keys and values down projection
        new_c_kv = self.ln(self.W_dkv(x)) # (B, S, kv_latent_dim) - down projection

        # update cache
        if kv_cache is None:
            c_kv = new_c_kv
        else:
            c_kv = torch.cat([kv_cache, new_c_kv], dim=1)
        
        S_full = c_kv.size(1) # number of tokens in total cache

        # keys and values up projection
        k_c = self.W_uk(c_kv).view(B, S_full, self.n_heads, self.dh) # (B, S_full, num_heads, head_dim)
        v_c = self.W_uv(c_kv).view(B, S_full, self.n_heads, self.dh) # (B, S_full, num_heads, head_dim)

        #### WITH ROPE ####

        # queries up projection
        q_r = self.rope(self.W_qr(c_q)).view(B, S, self.n_heads, self.dh) # (B, S, num_heads, head_dim)

        # Keys up projection
        new_kr_cache = self.ln(self.W_kr(x)) # (B, S, dh) - down projection
        if kr_cache is None:
            k_r = new_kr_cache
        else:
            k_r = torch.cat([kr_cache, new_kr_cache], dim=1)

        # Multiple heads for keys
        k_r = torch.stack([k_r] * self.n_heads, dim=2) # (B, S_full, num_heads, head_dim)

        #### JOINING ####
        # Concatenate queries and keys
        q = torch.cat([q_c, q_r], dim=3) # (B, S, num_heads, head_dim)
        k = torch.cat([k_c, k_r], dim=3) # (B, S_full, num_heads, head_dim)

        # Compute attention scores
        for i in range(self.n_heads):
            attn_scores[:, i, :, :] = torch.matmul(q[:, :, i, :], k[:, :, i, :].transpose(2, 3))

        # Mask, softmax, and dropout
        attn_scores = attn_scores / (self.dh ** 0.5)
        mask = torch.tril(torch.ones((S, S_full), device=x.device), diagonal=past_length)
        attn_scores = attn_scores.masked_fill(mask.view(1, 1, S, S_full) == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Context vector using attention weights and values
        out_heads = []
        for i in range(self.n_heads):
            context_h = torch.matmul(attn_weights[:, :, i, :], v_c[:, :, i, :])
            out_heads.append(context_h)
        
        out_heads = torch.stack(out_heads, dim=2) # (B, S, num_heads, head_dim)
        out = out_heads.view(B, S, self.d_model)

        out = self.W_o(out) # (B, S, d_model) - output projection

        return out, c_kv, k_r  # Return output, key-value cache, and key RoPE cache

In [49]:
test = RopeAttention(d_model=32, n_heads=4, kv_latent_dim=16)


In [52]:
test(torch.randn(2, 16, 32))

RuntimeError: Given normalized_shape=[16], expected input with shape [*, 16], but got input of size[2, 16, 8]