In [None]:
from functools import partial
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def compute_mixed_cis(
    freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int
):
    """Turn per-axis frequencies and grid coordinates into unit complex phases.

    For every position the phase angle is ``t_x * freqs[0] + t_y * freqs[1]``
    and the result is ``exp(i * angle)`` built with ``torch.polar``.

    Args:
        freqs: Tensor indexed as ``freqs[0]`` (x-axis) and ``freqs[1]``
            (y-axis); ``freqs.shape[1]`` is used as the depth. The last axis
            is split evenly across ``num_heads`` by the ``view`` below.
        t_x: 1-D tensor of x coordinates, one entry per position.
        t_y: 1-D tensor of y coordinates, same length as ``t_x``.
        num_heads: Number of heads the last frequency axis is divided into.

    Returns:
        Complex tensor laid out as (depth, num_heads, N, dim).
    """
    num_positions = t_x.shape[0]
    depth = freqs.shape[1]

    def _axis_angles(coords: torch.Tensor, axis_freqs: torch.Tensor):
        # Outer product of coordinates and frequencies, one slab per depth
        # entry, then split the frequency axis into heads.
        outer = coords.unsqueeze(-1) @ axis_freqs.unsqueeze(-2)
        per_head = outer.view(depth, num_positions, num_heads, -1)
        return per_head.permute(0, 2, 1, 3)  # (depth, num_heads, N, dim)

    # No float 16 for this range
    with torch.cuda.amp.autocast(enabled=False):
        angles_x = _axis_angles(t_x, freqs[0])
        angles_y = _axis_angles(t_y, freqs[1])
        unit_magnitude = torch.ones_like(angles_x)
        mixed = torch.polar(unit_magnitude, angles_x + angles_y)

    return mixed


def reshape_for_broadcast_old(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must match either the last two or the last three dimensions
    of ``x``; every leading dimension of ``x`` is replaced by 1 in the
    returned view.

    Args:
        freqs_cis: Tensor whose shape equals ``x.shape[-2:]`` or
            ``x.shape[-3:]``.
        x: Tensor the result has to broadcast with (at least 2-D).

    Returns:
        A view of ``freqs_cis`` with ``x.ndim`` dimensions.

    Raises:
        ValueError: If ``freqs_cis`` matches neither supported layout.
            Previously this case surfaced as an ``UnboundLocalError`` (or an
            ``IndexError`` for 2-D ``x``) with no hint about the shapes.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    x_dims = tuple(x.shape)
    freq_dims = tuple(freqs_cis.shape)
    # Try the two-trailing-dims layout first, then the three-trailing-dims one.
    for kept in (2, 3):
        if ndim >= kept and freq_dims == x_dims[-kept:]:
            shape = [1] * (ndim - kept) + list(x_dims[-kept:])
            return freqs_cis.view(*shape)
    raise ValueError(
        f"freqs_cis shape {freq_dims} matches neither the last two nor the "
        f"last three dimensions of x with shape {x_dims}"
    )


def apply_rotary_emb_old(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
):
    """Rotate queries and keys by multiplying with complex phases.

    The last dimension of each input is read as consecutive (real, imag)
    pairs, multiplied by ``freqs_cis`` in the complex domain, and flattened
    back from dimension 3 onward.

    Args:
        xq: Query tensor; its last dimension must be even.
        xk: Key tensor with the same layout as ``xq``.
        freqs_cis: Complex phases in a layout accepted by
            ``reshape_for_broadcast_old`` relative to the paired view of
            ``xq``.

    Returns:
        ``(xq_out, xk_out)`` cast back to the dtype and device of the
        respective inputs.
    """

    def _to_complex(t: torch.Tensor):
        # Group the last axis into (real, imag) pairs in float32.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = _to_complex(xq)
    k_complex = _to_complex(xk)
    phases = reshape_for_broadcast_old(freqs_cis, q_complex)

    q_rotated = torch.view_as_real(q_complex * phases).flatten(3)
    k_rotated = torch.view_as_real(k_complex * phases).flatten(3)

    q_result = q_rotated.type_as(xq).to(xq.device)
    k_result = k_rotated.type_as(xk).to(xk.device)
    return q_result, k_result


def reshape_for_broadcast_new(
    freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int
):
    """
    View a table of 2x2 rotation matrices so it broadcasts against ``x``.

    The result has the same number of dimensions as ``x``. It keeps the
    sizes of ``x`` at ``seq_dim`` and at the third-from-last axis, ends in
    ``(2, 2)``, and is 1 everywhere else.

    Args:
        freqs_cis (torch.Tensor): Tensor of shape
            ``(x.shape[seq_dim], x.shape[-3], 2, 2)``.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        seq_dim (int): Non-negative index of the sequence axis of ``x``.

    Returns:
        torch.Tensor: View of ``freqs_cis`` with ``x.ndim`` dimensions.
    """
    ndim = x.ndim
    assert 0 <= seq_dim < ndim
    assert freqs_cis.shape == (
        x.shape[seq_dim],
        x.shape[-3],
        2,
        2,
    ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    # Only the sequence axis and the pair axis keep their real extent.
    kept_axes = {seq_dim, ndim - 3}
    leading = [
        size if axis in kept_axes else 1
        for axis, size in enumerate(x.shape[:-2])
    ]
    return freqs_cis.view(*leading, 2, 2)


def apply_rotary_emb_new(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with real-valued 2x2 rotation matrices.

    Each consecutive pair in the last dimension is multiplied by its rotation
    matrix (broadcast multiply, then a sum over axis 5), and the result is
    flattened back from dimension 3 onward.

    Args:
        xq: Query tensor; the inline shape notes use B S H D.
        xk: Key tensor with the same layout as ``xq``.
        seq_dim: Index of the sequence axis, passed to
            ``reshape_for_broadcast_new``.
        freqs_cis: Rotation matrices; the inline shape notes use S D/2 2 2.

    Returns:
        ``(xq_out, xk_out)`` cast back to the dtypes of ``xq`` and ``xk``.
    """

    def _as_pairs(t: torch.Tensor):
        # B S H D -> B S H D/2 1 2
        return t.reshape(*t.shape[:-1], -1, 1, 2)

    q_pairs = _as_pairs(xq)
    k_pairs = _as_pairs(xk)

    # S D/2 2 2 -> 1 S 1 D/2 2 2
    rotation = reshape_for_broadcast_new(freqs_cis, q_pairs, seq_dim).float()

    # Row-vector times matrix, written as broadcast multiply + reduction.
    q_rotated = (q_pairs * rotation).sum(5).flatten(3)
    k_rotated = (k_pairs * rotation).sum(5).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)


def init_t_xy(end_x: int, end_y: int):
    """Return the x and y coordinates of a flattened ``end_x`` by ``end_y`` grid.

    Positions are numbered ``0 .. end_x * end_y - 1``; position ``t`` gets
    ``x = t % end_x`` and ``y = floor(t / end_x)``.

    Args:
        end_x: Grid extent along x.
        end_y: Grid extent along y.

    Returns:
        ``(t_x, t_y)``: two float32 tensors of length ``end_x * end_y``.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    x_coord = flat_index % end_x
    y_coord = torch.div(flat_index, end_x, rounding_mode="floor")
    return x_coord.float(), y_coord.float()

In [None]:
# Coordinates for a 3x3 grid flattened to 9 positions:
# t_x[i] = i % 3 and t_y[i] = i // 3, both float32.
t_x, t_y = init_t_xy(3, 3)

In [None]:
# Scratch reproduction of the x-frequency vector that
# init_random_2d_freqs_old builds for one head, with the rotation angle
# fixed to zero.
theta = 100
dim = 64
# One magnitude per group of four channels: theta ** (-k / dim), k = 0, 4, 8, ...
mag = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
angles = torch.zeros(1)
# With angles == 0 the first half is mag * cos(0) and the second half is
# mag * cos(pi / 2), which is ~0 up to float rounding.
fx = torch.cat(
    [mag * torch.cos(angles), mag * torch.cos(torch.pi / 2 + angles)], dim=-1
)
# Outer product of the grid x coordinates with fx, regrouped into a
# (3, 3, -1) block and then with its first two axes swapped. The view
# requires t_x to hold 9 entries (the 3x3 grid).
(t_x.unsqueeze(-1) @ fx.unsqueeze(-2)).view(3, 3, -1).permute(1, 0, 2)

In [None]:
mag

In [None]:
fx[: len(fx) // 2]

In [None]:
fx[len(fx) // 2 :]

In [None]:
mag

In [None]:
fx.shape

In [None]:
t_x

In [None]:
fx

In [None]:
def init_random_2d_freqs_new(dim: int, end: int, theta: float = 100.0):
    """Precompute 2x2 rotation matrices for positions ``0 .. end - 1``.

    Despite the name, nothing here is random: the random per-head rotation
    of the older variant is not applied.

    Args:
        dim: Channel dimension; must be divisible by 4. One frequency is
            produced per group of four channels.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency schedule.

    Returns:
        Float tensor of shape ``(end, dim // 4, 2, 2)`` whose trailing
        matrix is ``[[cos, -sin], [sin, cos]]`` of ``position * magnitude``.
    """
    assert dim % 4 == 0, "dim must be divisible by 4"
    # Per the original author's note: each pair of numbers shares one
    # magnitude (real / imaginary parts), and the step of 4 together with the
    # smaller default theta reflects the reduced per-axis position count.
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    magnitudes = 1 / (theta ** exponents)
    positions = torch.arange(end, device=magnitudes.device)
    angles = torch.outer(positions, magnitudes).float()

    cos_a = torch.cos(angles)
    sin_a = torch.sin(angles)
    top_row = torch.stack((cos_a, -sin_a), dim=-1)
    bottom_row = torch.stack((sin_a, cos_a), dim=-1)
    return torch.stack((top_row, bottom_row), dim=-2)


def init_random_2d_freqs_old(
    dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True
):
    """Build per-head x/y frequency vectors, optionally randomly rotated.

    Every head shares one magnitude schedule. Each head picks an angle
    (uniform in ``[0, 2*pi)`` when ``rotate`` is true, otherwise 0) and
    projects the magnitudes onto x with cosines and onto y with sines, once
    at that angle and once a quarter turn further.

    Args:
        dim: Channel dimension; one magnitude per group of four channels.
        num_heads: Number of heads, i.e. number of angles drawn.
        theta: Base of the geometric magnitude schedule.
        rotate: Draw a random angle per head when true; use angle 0 otherwise.

    Returns:
        Tensor of shape ``(2, num_heads, 2 * (dim // 4))``; index 0 holds the
        x frequencies and index 1 the y frequencies.
    """
    mag = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))

    def _head_freqs():
        # Exactly one RNG draw per head, taken in head order.
        angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
        quarter_turn_on = torch.pi / 2 + angles
        head_fx = torch.cat(
            [mag * torch.cos(angles), mag * torch.cos(quarter_turn_on)],
            dim=-1,
        )
        head_fy = torch.cat(
            [mag * torch.sin(angles), mag * torch.sin(quarter_turn_on)],
            dim=-1,
        )
        return head_fx, head_fy

    per_head = [_head_freqs() for _ in range(num_heads)]
    all_fx = torch.stack([pair[0] for pair in per_head], dim=0)
    all_fy = torch.stack([pair[1] for pair in per_head], dim=0)
    return torch.stack([all_fx, all_fy], dim=0)

In [None]:
# init_random_2d_freqs_new asserts dim % 4 == 0, so the previous dim=10 made
# this cell raise AssertionError on every run. 12 is the nearest valid value;
# new_freq then has shape (10, 3, 2, 2).
new_freq = init_random_2d_freqs_new(12, 10, 10000)

In [None]:
new_freq.shape

In [None]:
t_x

In [None]:
t_x.shape

In [None]:
torch.stack(
    (new_freq[t_x.long(), :, 1, :], new_freq[t_y.long(), :, 0, :]), dim=2
).shape

In [None]:
# Randomly rotated x/y frequencies for 3 heads at dim=16 (default theta);
# the result has shape (2, 3, 8). Unseeded, so values change on every run.
freqs = init_random_2d_freqs_old(16, 3)

In [None]:
freqs.shape

In [None]:
# Same call as a few cells above: redraws the random angles and rebinds
# `freqs`, then shows its shape, (2, 3, 8).
freqs = init_random_2d_freqs_old(16, 3)
freqs.shape

In [None]:
freqs.shape

In [None]:
# NOTE: rebinds `freqs` to a different kind of tensor: 2x2 rotation matrices
# for 10 positions, shape (10, 4, 2, 2), instead of the (2, heads, K)
# x/y frequency layout produced by init_random_2d_freqs_old.
freqs = init_random_2d_freqs_new(16, 10)
freqs.shape

In [None]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """
    Precompute real-valued 2x2 rotation matrices for rotary embeddings.

    One inverse frequency ``theta ** (-k / dim)`` is produced for each even
    ``k`` in ``[0, dim)``. Every position in ``0 .. end - 1`` is multiplied
    by every inverse frequency to get an angle, and each angle is expanded
    into the matrix ``[[cos, -sin], [sin, cos]]``.

    Note that despite the ``cis`` in the name, the result is a real tensor
    of rotation matrices, not a complex tensor.

    Args:
        dim (int): Channel dimension; ``dim // 2`` frequencies are produced.
        end (int): Number of positions to precompute.
        theta (float, optional): Base of the geometric frequency schedule.
            Defaults to 10000.0.

    Returns:
        torch.Tensor: Float tensor of shape ``(end, dim // 2, 2, 2)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()

    cos_a = angles.cos()
    sin_a = angles.sin()
    # Assemble the matrix row by row; the trailing axes end up as (2, 2).
    top_row = torch.stack((cos_a, -sin_a), dim=-1)
    bottom_row = torch.stack((sin_a, cos_a), dim=-1)
    return torch.stack((top_row, bottom_row), dim=-2)

In [None]:
# Build the frequency tensor the way a multi-layer model would (one entry per
# layer), here with a single layer, then evaluate it on a 5x2 grid.
freqs = []
for i, _ in enumerate([1]):
    freqs.append(init_random_2d_freqs_old(dim=16, num_heads=3, theta=10000))
# Stack layers on axis 1 and merge heads with frequencies: (2, 1, 3 * 8).
freqs = torch.stack(freqs, dim=1).view(2, len([1]), -1)

# freqs = init_random_2d_freqs(16, 3, 10000, False)
t_x, t_y = init_t_xy(5, 2)
# compute_mixed_cis returns (depth, num_heads, N, dim); here (1, 3, 10, 8).
freqs_cis = compute_mixed_cis(freqs, t_x, t_y, num_heads=3)

# Drop the leading depth axis (size 1 because there is a single layer);
# the three heads remain, giving (3, 10, 8).
freqs_cis = freqs_cis.squeeze(0)

In [None]:
# FIXME: this cell was left unfinished. The dangling attribute access below
# is a SyntaxError and stops Restart & Run All, so it is commented out. `x`
# is assigned later in the notebook via torch.randn(1, 3, dim); finish this
# statement or delete the cell.
# x = torch.

In [None]:
freqs.shape

In [None]:
freqs_cis

In [None]:
freqs

In [None]:
freqs.shape

In [None]:
freqs_cis.shape

In [None]:
freqs_cis

In [None]:
# 1-D rotary table as 2x2 rotation matrices for 10 positions at dim=16;
# the result has shape (10, 8, 2, 2).
freqs_2 = precompute_freqs_cis(16, 10, 10000)

In [None]:
freqs_2.shape

In [None]:
freqs_2

In [None]:
freqs

In [None]:
# Tiny hand-sized example (dim=4, 3 positions) of the rotation-matrix
# formulation; this is the body of precompute_freqs_cis written out inline.
# NOTE: rebinds the module-level `theta`, `dim` and `freqs` used by earlier cells.
theta = 10000.0
dim = 4
end = 3

# Inverse frequencies theta ** (-k / dim) for k = 0, 2.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
# Angle for every (position, frequency) pair: shape (end, dim // 2).
freqs = torch.outer(t, freqs).float()

cos, sin = freqs.cos(), freqs.sin()

# [[cos, -sin], [sin, cos]] per angle: shape (end, dim // 2, 2, 2).
freq_new = torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)

In [None]:
# Same angles as the rotation-matrix cell, but expressed in the complex
# formulation: freq_ = exp(i * angle) with unit magnitude, shape (end, dim // 2).
# Relies on `theta`, `dim` and `end` from the preceding cell.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()

freq_ = torch.polar(torch.ones_like(freqs), freqs)

In [None]:
# Random test input of shape (1, 3, dim), using whatever `dim` is currently
# bound to (4 when the cells run top to bottom). Unseeded, so the values
# differ on every run.
x = torch.randn(1, 3, dim)

In [None]:
x

In [None]:
# Read the last axis of x as consecutive (real, imag) pairs: the complex view
# has half as many entries on its last axis as x.
xq_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))

In [None]:
xq_

In [None]:
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x`` (two-argument form).

    ``freqs_cis`` must match either the last two or the last three dimensions
    of ``x``; every leading dimension of ``x`` is replaced by 1 in the
    returned view.

    Note that a later cell of this notebook redefines ``reshape_for_broadcast``
    with an extra ``seq_dim`` parameter, which shadows this definition once
    that cell has run.

    Args:
        freqs_cis: Tensor whose shape equals ``x.shape[-2:]`` or
            ``x.shape[-3:]``.
        x: Tensor the result has to broadcast with (at least 2-D).

    Returns:
        A view of ``freqs_cis`` with ``x.ndim`` dimensions.

    Raises:
        ValueError: If ``freqs_cis`` matches neither supported layout.
            Previously this case surfaced as an ``UnboundLocalError`` (or an
            ``IndexError`` for 2-D ``x``) with no hint about the shapes.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    x_dims = tuple(x.shape)
    freq_dims = tuple(freqs_cis.shape)
    # Prefer the two-trailing-dims layout, then the three-trailing-dims one.
    for kept in (2, 3):
        if ndim >= kept and freq_dims == x_dims[-kept:]:
            shape = [1] * (ndim - kept) + list(x_dims[-kept:])
            return freqs_cis.view(*shape)
    raise ValueError(
        f"freqs_cis shape {freq_dims} matches neither the last two nor the "
        f"last three dimensions of x with shape {x_dims}"
    )

In [None]:
freq_

In [None]:
# Give the complex phases a leading singleton axis so they broadcast against
# xq_. This needs the two-argument reshape_for_broadcast; it fails if the
# later three-argument redefinition is the one currently in the kernel.
freq_ready = reshape_for_broadcast(freq_, xq_)

In [None]:
freq_ready

In [None]:
freqs_cis

In [None]:
torch.view_as_real(freq_ready * xq_).flatten(3)

In [None]:
# Insert a singleton axis at position 2 (the inline shape notes below treat
# it as the head axis H), so x_temp has four dimensions.
x_temp = x.unsqueeze(2)
x_temp.shape

In [None]:
def reshape_for_broadcast(
    freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int
):
    """View a table of 2x2 rotation matrices so it broadcasts against ``x``.

    This three-argument definition replaces the earlier two-argument
    ``reshape_for_broadcast`` in the notebook namespace.

    Args:
        freqs_cis: Tensor of shape ``(x.shape[seq_dim], x.shape[-3], 2, 2)``.
        x: Target tensor for broadcasting compatibility.
        seq_dim: Non-negative index of the sequence axis of ``x``.

    Returns:
        View of ``freqs_cis`` with ``x.ndim`` dimensions: the sizes of ``x``
        at ``seq_dim`` and at the third-from-last axis, a trailing
        ``(2, 2)``, and 1 everywhere else.
    """
    ndim = x.ndim
    assert 0 <= seq_dim < ndim
    assert freqs_cis.shape == (
        x.shape[seq_dim],
        x.shape[-3],
        2,
        2,
    ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    # Start from all-ones and restore only the axes that carry real extent.
    pair_axis = ndim - 3
    leading = []
    for axis, size in enumerate(x.shape[:-2]):
        keeps_extent = axis == seq_dim or axis == pair_axis
        leading.append(size if keeps_extent else 1)
    return freqs_cis.view(*leading, 2, 2)


# Apply the rotation-matrix formulation by hand (the body of
# apply_rotary_emb_new for a single tensor), for comparison with the complex
# result computed earlier. NOTE: rebinds the module-level `xq_` and `freqs_cis`.
xq_ = x_temp.reshape(*x_temp.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
freqs_cis = reshape_for_broadcast(
    freq_new, xq_, 1
).float()  # S D/2 2 2 -> 1 S 1 D/2 2 2
# Broadcast multiply + sum over axis 5 performs the per-pair vector-matrix
# product; flatten(3) merges the pairs back into the channel axis.
xq_out = (xq_ * freqs_cis).sum(5).flatten(3)