In [33]:
import torch
from llmspeech.model import GPT
from einops import rearrange, repeat
from torch.cuda.amp import autocast
from torch import Tensor
# from rotary_embedding_torch import apply_rotary_emb

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# top-to-bottom on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Checkpoint file to load. `step` is the text between the last "-" and the
# following "." in the file name (here "055000").
name = "gpt-small-055000.pt"
step = name.split("-")[-1].split(".")[0]
# Load via `name` rather than repeating the literal, so `name`/`step` can never
# disagree with the checkpoint that is actually loaded.
model = GPT.from_huggingface(name).eval().to(device)

In [4]:
# Dimensions of the random query tensor: batch, heads, sequence length, per-head width.
B = 1
n_heads = model.config.n_heads
max_seqlen = 4096
d_head = model.config.d_model // model.config.n_heads

# Random queries laid out as (B, n_heads, max_seqlen, d_head).
q = torch.randn(B, n_heads, max_seqlen, d_head, device=device)

# Reference result: rotate q with the rotary module of the model's first block.
q_lucidrains = model.decoder.blocks[0].attn.rotary_emb.rotate_queries_or_keys(q)

In [5]:
q_lucidrains.shape

torch.Size([1, 12, 4096, 64])

In [6]:
q_lucidrains

tensor([[[[ 1.4490, -0.8268, -1.0788,  ...,  0.2822, -0.0291,  0.7546],
          [ 0.3623,  2.1878,  1.0621,  ...,  0.0234, -0.4768,  0.5508],
          [-0.5561,  1.6592, -0.5241,  ..., -1.2569,  1.9518, -2.2922],
          ...,
          [ 0.5861,  2.5088, -0.9223,  ...,  0.5935,  0.1566, -1.1205],
          [-0.3012, -1.2545, -0.5004,  ..., -0.6181, -2.0841,  0.3057],
          [ 0.1696,  0.4944,  0.5403,  ...,  0.5116,  0.5030, -0.4345]],

         [[-0.2307,  1.1640,  0.3676,  ...,  1.4544,  0.8257,  1.1746],
          [-1.1515, -0.5102,  0.0189,  ..., -0.5825,  1.0733,  0.3655],
          [-0.1857, -0.8186, -0.2061,  ...,  0.4051,  0.6016,  1.1461],
          ...,
          [ 0.0291, -0.4775, -0.1134,  ..., -1.1242, -1.1331, -0.5373],
          [-1.6571, -0.2224,  0.6021,  ..., -0.6981,  0.1160,  0.6479],
          [-0.7649, -0.9822,  0.2265,  ...,  0.8639,  0.7711, -0.9251]],

         [[ 0.1778,  0.7321, -0.5939,  ..., -0.2232, -0.4939, -0.3253],
          [ 1.1385,  0.2855,  

In [7]:
model.decoder.blocks[0].attn.rotary_emb.cache_if_possible

True

In [8]:
model.decoder.blocks[0].attn.rotary_emb.default_seq_dim

-2

In [9]:
# Grab the angle table stored on the first block's rotary module for inspection.
# Presumably it was populated by the rotate_queries_or_keys call above — TODO confirm.
cached_freqs = model.decoder.blocks[0].attn.rotary_emb.cached_freqs

In [10]:
cached_freqs.shape

torch.Size([4096, 32])

In [11]:
cached_freqs

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0000e+00, 5.6234e-01,  ..., 3.1623e-04, 1.7783e-04,
         1.7783e-04],
        [2.0000e+00, 2.0000e+00, 1.1247e+00,  ..., 6.3246e-04, 3.5566e-04,
         3.5566e-04],
        ...,
        [4.0930e+03, 4.0930e+03, 2.3017e+03,  ..., 1.2943e+00, 7.2785e-01,
         7.2785e-01],
        [4.0940e+03, 4.0940e+03, 2.3022e+03,  ..., 1.2946e+00, 7.2803e-01,
         7.2803e-01],
        [4.0950e+03, 4.0950e+03, 2.3028e+03,  ..., 1.2950e+00, 7.2821e-01,
         7.2821e-01]], device='cuda:0')

In [12]:
cached_freqs.dtype

torch.float32

In [35]:
def rotate_half(x):
    """Map every consecutive feature pair ``(a, b)`` on the last axis to ``(-b, a)``.

    The last dimension of ``x`` is viewed as pairs of adjacent elements; each
    pair is swapped with the second element negated, and the result is
    flattened back to the input's layout.
    """
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    swapped = torch.stack((-second, first), dim=-1)
    return rearrange(swapped, "... d r -> ... (d r)")


@autocast(enabled=False)
def apply_rotary_emb(freqs: Tensor, t: Tensor, start_index: int = 0, scale=1.0, seq_dim=-2):
    """Rotate a slice of the last (feature) dimension of ``t`` by the angles in ``freqs``.

    Args:
        freqs: Angle table. Its first dimension indexes sequence positions and
            its last dimension is the number of features that get rotated.
        t: Tensor to rotate, with features on the last dimension and the
            sequence on ``seq_dim``.
        start_index: Index of the first feature that is rotated; features
            before it, and after ``start_index + freqs.shape[-1]``, pass through.
        scale: Multiplier applied to the rotated features.
        seq_dim: Dimension of ``t`` holding the sequence.

    Returns:
        The pass-through and rotated features concatenated on the last
        dimension, cast back to the dtype ``t`` had on entry.

    Raises:
        AssertionError: if ``t`` is longer than the table along ``seq_dim``, or
            if the table rotates more features than ``t`` has.
    """
    dtype = t.dtype
    seq_len = t.shape[seq_dim]
    n_positions = freqs.size(0)
    assert (
        seq_len <= n_positions
    ), f"sequence length {seq_len} exceeds the {n_positions} positions available in freqs"

    # Keep the last `seq_len` rows of the table, as the original 3-D path did,
    # but for inputs of any rank: previously a 4-D `t` shorter than the table
    # either failed to broadcast or (for seq_len == 1) silently broadcast to
    # the full table length. Slicing from an explicit start index also handles
    # seq_len == 0, where `freqs[-0:]` would have returned the whole table.
    freqs = freqs[n_positions - seq_len :]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert (
        rot_dim <= t.shape[-1]
    ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    # Split the features into [untouched | rotated | untouched].
    t_left, t, t_right = (
        t[..., :start_index],
        t[..., start_index:end_index],
        t[..., end_index:],
    )
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    out = torch.cat((t_left, t, t_right), dim=-1)

    return out.type(dtype)


def precompute(dim: int, max_seqlen: int, theta=10000, dtype=torch.float32):
    """Build the table of rotary angles for positions ``0 .. max_seqlen - 1``.

    Args:
        dim: Number of features to rotate; the table's last dimension has
            ``2 * (dim // 2)`` entries (each frequency appears twice in a row).
        max_seqlen: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.
        dtype: Kept for call compatibility. Positions are now always generated
            in float32: the previous code created them in ``dtype`` before
            casting to float32, which rounded large indices for half-precision
            dtypes. Results are unchanged for any dtype that represents the
            indices exactly, including the default.

    Returns:
        A float32 tensor of shape ``(max_seqlen, 2 * (dim // 2))`` whose row
        ``p`` holds ``p * inv_freq`` with every frequency repeated for the two
        members of its feature pair.
    """
    # One inverse frequency per rotated feature pair: theta ** (-2i / dim).
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freqs = 1.0 / (theta ** exponents)

    # Build positions directly in the frequency dtype (float32) so that every
    # index is represented exactly.
    positions = torch.arange(max_seqlen, dtype=inv_freqs.dtype)

    # Outer product: row p, column f -> p * inv_freqs[f].
    freqs = torch.einsum("..., f -> ... f", positions, inv_freqs)

    # Duplicate each column in place ([a, b] -> [a, a, b, b]) so both members
    # of a feature pair share the same angle.
    return freqs.repeat_interleave(2, dim=-1)


def get_seq_pos(self, seq_len, device, dtype, offset=0):
    """Return positions ``offset .. offset + seq_len - 1`` divided by ``self.interpolate_factor``.

    Args:
        self: Any object exposing an ``interpolate_factor`` attribute.
        seq_len: Number of positions to generate.
        device: Device on which the position tensor is created.
        dtype: Dtype of the generated position tensor.
        offset: Value added to every position before the division.
    """
    positions = torch.arange(seq_len, device=device, dtype=dtype)
    shifted = positions + offset
    return shifted / self.interpolate_factor

In [36]:
# Rebuild the angle table with the local precompute(): the rotated width is
# half the head dimension, and the resulting table is moved to `device`.
lucidrains_freqs = precompute(d_head // 2, max_seqlen, dtype=torch.float32).to(device)

In [37]:
lucidrains_freqs.shape

torch.Size([4096, 32])

In [38]:
q.shape

torch.Size([1, 12, 4096, 64])

In [39]:
# Rotate the same random queries with the local apply_rotary_emb and the locally built table.
q_rotated = apply_rotary_emb(lucidrains_freqs, q)

In [40]:
q_rotated

tensor([[[[ 1.4490, -0.8268, -1.0788,  ...,  0.2822, -0.0291,  0.7546],
          [ 0.3623,  2.1878,  1.0621,  ...,  0.0234, -0.4768,  0.5508],
          [-0.5561,  1.6592, -0.5241,  ..., -1.2569,  1.9518, -2.2922],
          ...,
          [ 0.5861,  2.5088, -0.9223,  ...,  0.5935,  0.1566, -1.1205],
          [-0.3012, -1.2545, -0.5004,  ..., -0.6181, -2.0841,  0.3057],
          [ 0.1696,  0.4944,  0.5403,  ...,  0.5116,  0.5030, -0.4345]],

         [[-0.2307,  1.1640,  0.3676,  ...,  1.4544,  0.8257,  1.1746],
          [-1.1515, -0.5102,  0.0189,  ..., -0.5825,  1.0733,  0.3655],
          [-0.1857, -0.8186, -0.2061,  ...,  0.4051,  0.6016,  1.1461],
          ...,
          [ 0.0291, -0.4775, -0.1134,  ..., -1.1242, -1.1331, -0.5373],
          [-1.6571, -0.2224,  0.6021,  ..., -0.6981,  0.1160,  0.6479],
          [-0.7649, -0.9822,  0.2265,  ...,  0.8639,  0.7711, -0.9251]],

         [[ 0.1778,  0.7321, -0.5939,  ..., -0.2232, -0.4939, -0.3253],
          [ 1.1385,  0.2855,  

In [32]:
# Check that the local implementation reproduces the library's rotated queries.
# NOTE(review): this cell was last executed as In [32], i.e. before the
# definitions and calls above were re-run as In [35]-[40]; re-run it so the
# check covers the current code.
torch.testing.assert_close(q_rotated, q_lucidrains)