In [23]:
# Okay, it actually shouldn't be that hard to test whether this is getting the same results.  I don't know if I have it in my right now, though.
import torch
class RoPE:
    _cache = {}
    @classmethod
    def _populate_cache(cls, dim, seq_len, device, cache_key):
        dim, period = cache_key
        inv_freq = 1.0 / (period ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
        t = torch.arange(seq_len, device = device, dtype = torch.int64).type_as(inv_freq)
        with torch.autocast(device_type = device.type, enabled=False):
            freqs = torch.outer(t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().to(torch.get_default_dtype())
            sin = emb.sin().to(torch.get_default_dtype())
        cls._cache[cache_key] = (seq_len, cos, sin)
        
    @classmethod
    def _get_cached_sin_con(cls, dim, seq_len, device, period = 10_000):
        cache_key = (dim, period)
        if cache_key not in cls._cache or seq_len > cls._cache[cache_key][0]:
            cls._populate_cache(dim, seq_len, device, cache_key)
        _, cos, sin = cls._cache[cache_key]
        return cos[:seq_len].to(device), sin[:seq_len].to(device)

    @classmethod
    def embed(cls, x, period = 10_000, head_size = None):
        seq_len = x.size(-2)
        device = x.device
        dim = head_size if head_size is not None else x.size(-1)
        cos, sin = cls._get_cached_sin_con(dim, seq_len, device, period = period)
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        rotated = torch.cat((-x2, x1), dim=-1)
        embedded = (x * cos) + (rotated * sin)
        return embedded

torch.Size([7, 256]) torch.Size([7, 256])
torch.Size([3, 4, 7, 128]) torch.Size([3, 4, 7, 128])


torch.Size([3, 4, 7, 256])

In [29]:
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.mistral.modeling_mistral import MistralRotaryEmbedding

head_dim = 64
num_heads = 4
llama_rope = LlamaRotaryEmbedding(head_dim)
mistral_rope = MistralRotaryEmbedding(head_dim)

In [39]:
seq_len = 7
x = torch.randn(3, num_heads, seq_len, head_dim)
cos_llama, sin_llama = llama_rope(x, seq_len = seq_len)
cos_mistral, sin_mistral = mistral_rope(x, seq_len = seq_len)
cos_me, sin_me = RoPE._get_cached_sin_con(head_dim, seq_len, x.device)
## All seem to be equal; so far so good.

In [44]:
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    print(cos.shape)
    cos = cos.unsqueeze(unsqueeze_dim)
    print(cos.shape)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

q_llama, k_llama = apply_rotary_pos_emb(x, x, cos_llama, sin_llama)
#q_me, k_me = RoPE.embed(x), RoPE.embed(x)


torch.Size([7, 64]) torch.Size([7, 64])
torch.Size([3, 4, 7, 32]) torch.Size([3, 4, 7, 32])
torch.Size([7, 64]) torch.Size([7, 64])
torch.Size([3, 4, 7, 32]) torch.Size([3, 4, 7, 32])
