In [2]:
import math
import torch
from einops import rearrange
from torch import einsum, nn

def find_correction_factor(num_rotations, dim, base=10000, max_position_embeddings=2048):
    """Return the (fractional) frequency index that makes `num_rotations` turns.

    With rotary inverse frequencies ``base ** (-2 * i / dim)``, index ``i``
    completes ``max_position_embeddings * base ** (-2 * i / dim) / (2 * pi)``
    full rotations over the context window. This function solves that
    relation for ``i``. The result is a float and is not clamped or rounded.
    """
    # Ratio between the context length and the total angle of `num_rotations` turns.
    length_per_turn = max_position_embeddings / (num_rotations * 2 * math.pi)
    numerator = dim * math.log(length_per_turn)
    denominator = 2 * math.log(base)
    return numerator / denominator


def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Return the integer ``(low, high)`` index bounds of the blending ramp.

    Each bound is the correction factor for the given rotation count, rounded
    outward (floor for the lower bound, ceil for the upper bound) and then
    clamped so that ``low >= 0`` and ``high <= dim - 1``.
    """
    lower_exact = find_correction_factor(low_rot, dim, base, max_position_embeddings)
    upper_exact = find_correction_factor(high_rot, dim, base, max_position_embeddings)

    # Clamp values just in case the exact factors fall outside the index range.
    lower_bound = max(math.floor(lower_exact), 0)
    upper_bound = min(math.ceil(upper_exact), dim - 1)
    return lower_bound, upper_bound


def linear_ramp_mask(min, max, dim):
    """Build a 1-D float32 ramp of length `dim` rising from 0 to 1.

    Entries at index ``<= min`` are 0, entries at index ``>= max`` are 1, and
    entries in between grow linearly. When ``min == max`` the upper bound is
    nudged by 0.001 so the division below never hits zero.
    """
    upper = max + 0.001 if min == max else max  # Prevent singularity

    positions = torch.arange(dim, dtype=torch.float32)
    return ((positions - min) / (upper - min)).clamp(0, 1)


def find_newbase_ntk(dim, base=10000, scale=1):
    """Return the rotary base rescaled as ``base * scale ** (dim / (dim - 2))``."""
    exponent = dim / (dim - 2)
    return base * scale ** exponent


class DynamicPartNTKScaledRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with dynamic "by parts" NTK frequency scaling.

    The module caches ``cos``/``sin`` tables of shape ``[seq_len, dim]``. When
    `forward` is asked for more positions than are cached, the inverse
    frequencies are recomputed by `ntk` with
    ``scale = seq_len / max_position_embeddings`` and the tables are rebuilt.
    Once rescaled, the new frequencies are also used for shorter sequences.

    Args:
        dim: Size of the rotary dimension (head size).
        max_position_embeddings: Number of positions cached at construction
            and the reference length for the dynamic scale.
        original_max_position_embeddings: Only used when `finetuned` is true,
            to derive the initial scale.
        base: Base of the rotary frequency schedule.
        ntk_factor: Weight of the NTK schedule in the NTK/linear blend.
        extrapolation_factor: Weight of the unscaled schedule in the final blend.
        finetuned: If true, start from frequencies scaled by
            ``max_position_embeddings / original_max_position_embeddings``.
        device: Device on which the frequency buffers are created.
    """

    def __init__(self, dim, max_position_embeddings=5, original_max_position_embeddings=5, base=10000, ntk_factor=1, extrapolation_factor=1, finetuned=False, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        self.ntk_factor = ntk_factor
        self.extrapolation_factor = extrapolation_factor
        self.max_position_embeddings = max_position_embeddings
        if finetuned:
            self.ntk(self.max_position_embeddings / original_max_position_embeddings, device)
        else:
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
            self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self._set_cos_sin_cache(
            self.max_seq_len_cached, self.inv_freq.device, torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the 2-D ``[seq_len, dim]`` cos/sin tables from `inv_freq`."""
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first `seq_len` positions.

        Args:
            x: Tensor laid out as [bs, num_attention_heads, seq_len, head_size];
                only its dtype, device and (when `seq_len` is omitted) its
                second-to-last dimension are used.
            seq_len: Number of positions required. Defaults to ``x.shape[-2]``.

        Returns:
            Two tensors of shape ``[seq_len, dim]`` cast to ``x.dtype``.
        """
        if seq_len is None:
            # The declared default used to crash on `None > int`; fall back to
            # the sequence axis of x instead.
            seq_len = x.shape[-2]

        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len

            # Dynamic scaling: stretch the frequencies by how far the request
            # exceeds the reference context length.
            self.ntk(seq_len / self.max_position_embeddings, x.device)

            # Keep the cache 2-D, exactly like the one built in `__init__`, so
            # the `[:seq_len]` slice below and position-id indexing by callers
            # keep addressing the position axis.
            self._set_cos_sin_cache(seq_len, x.device, x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

    def ntk(self, scale, device):
        """Recompute and register `inv_freq` for the given context `scale`.

        Three schedules are blended per frequency index: the unscaled one, the
        linearly interpolated one (divided by `scale`), and the NTK one (base
        replaced via `find_newbase_ntk`). Low indices (high frequencies) lean
        towards NTK and then towards the unscaled schedule; high indices keep
        the linear interpolation.
        """
        # Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
        # Do not change unless there is a good reason for doing so!
        beta_0 = 1.25
        beta_1 = 0.75
        gamma_0 = 16
        gamma_1 = 2

        # Exponent 2i/dim shared by all three schedules.
        exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim

        # Three RoPE extrapolation/interpolation methods
        inv_freq_base = 1.0 / (self.base ** exponents)
        inv_freq_linear = 1.0 / (scale * (self.base ** exponents))
        inv_freq_ntk = 1.0 / (find_newbase_ntk(self.dim, self.base, scale) ** exponents)

        current_dtype = inv_freq_ntk.dtype
        current_device = inv_freq_ntk.device

        # Combine NTK and Linear
        low, high = find_correction_range(
            beta_0, beta_1, self.dim, self.base, self.max_position_embeddings)
        inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).type(current_dtype).to(current_device)) * self.ntk_factor
        inv_freq = inv_freq_linear * (1 - inv_freq_mask) + inv_freq_ntk * inv_freq_mask

        # Combine Extrapolation and NTK and Linear
        low, high = find_correction_range(
            gamma_0, gamma_1, self.dim, self.base, self.max_position_embeddings)
        inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).type(current_dtype).to(current_device)) * self.extrapolation_factor
        inv_freq = inv_freq * (1 - inv_freq_mask) + inv_freq_base * inv_freq_mask

        self.register_buffer("inv_freq", inv_freq)

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [a, b]`` split at ``x.shape[-1] // 2`` the result is ``[-b, a]``.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply the rotary embedding to the query and key tensors.

    `cos` and `sin` are indexed by `position_ids` and given an extra axis at
    `unsqueeze_dim` so they broadcast against `q` and `k`. Each tensor ``t`` is
    then mapped to ``t * cos + rotate_half(t) * sin``.

    Returns:
        The tuple ``(q_embed, k_embed)``.
    """
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_at_pos) + (rotate_half(tensor) * sin_at_pos)

    return _rotate(q), _rotate(k)



# Demo: rotate random q/k tensors with the dynamic part-NTK rotary embedding.
# Tensor layout: (bs, head, length, dim)
q = torch.randn(2, 12, 10, 32)  # q=[q0, q1, .., qd-1]
k = torch.randn(2, 12, 10, 32)
v = torch.randn(2, 12, 10, 32)
# Positions 0..9, one identical row per batch entry.
position_ids = torch.arange(10).unsqueeze(0).repeat(2, 1)
print('q:', q[0, 0, 0])
print('k:', k[0, 0, 0])
rotary_emb = DynamicPartNTKScaledRotaryEmbedding(dim=32, max_position_embeddings=10)
cos, sin = rotary_emb(v, seq_len=10)
q_new, k_new = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print('q_new: ', q_new[0, 0, 0])
print('k_new: ', k_new[0, 0, 0])

q: tensor([-3.1032, -0.2769,  2.2583,  1.0565,  0.7202,  0.1029, -1.9237, -0.5518,
        -1.7343, -0.4385, -0.4145, -2.4887,  0.3950, -0.1805,  1.4258, -1.2152,
         0.0367, -1.6281,  0.4144,  0.0994, -0.9635,  1.0568,  0.2574, -0.1623,
        -1.3715, -0.1412,  0.5866,  0.1228, -0.2036, -0.3358,  0.5903, -0.4167])
k: tensor([ 1.6162e+00,  6.8729e-01, -1.1949e+00,  1.8411e+00, -4.9815e-01,
         6.7064e-01,  8.6041e-01,  4.7778e-01,  9.4828e-01, -1.2916e+00,
         1.6296e-01,  1.0296e+00,  7.9215e-01,  4.7766e-01,  6.1852e-01,
        -2.1944e-01, -3.2050e-01, -5.8963e-01, -1.5785e-03,  3.2309e-01,
        -1.0185e+00,  7.7100e-01,  1.6684e+00, -1.5183e-01, -8.4329e-01,
         7.7285e-01,  2.8210e-01, -5.9170e-02,  9.7809e-01, -5.1536e-02,
         4.2023e-01, -2.3006e-01])
q_new:  tensor([-3.1032, -0.2769,  2.2583,  1.0565,  0.7202,  0.1029, -1.9237, -0.5518,
        -1.7343, -0.4385, -0.4145, -2.4887,  0.3950, -0.1805,  1.4258, -1.2152,
         0.0367, -1.6281,  0.414

In [5]:
def get_mscale(scale=1):
    """Return the cos/sin magnitude multiplier for a context `scale`.

    The result is 1.0 for ``scale <= 1`` and ``0.1 * ln(scale) + 1.0`` otherwise.
    """
    return 1.0 if scale <= 1 else 0.1 * math.log(scale) + 1.0


class YaRNScaledRotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with YaRN frequency blending and magnitude scaling.

    `yarn` blends the unscaled ("extrapolation") inverse frequencies with the
    ones divided by `scale` ("interpolation") using a linear ramp over the
    frequency index, and derives `mscale`, a scalar multiplied into the cached
    ``cos``/``sin`` tables. The tables have shape ``[seq_len, dim]``.

    Args:
        dim: Size of the rotary dimension (head size).
        max_position_embeddings: Number of positions cached at construction.
        base: Base of the rotary frequency schedule.
        scale: Context extension factor applied to the interpolated frequencies.
        original_max_position_embeddings: Context length used to locate the ramp.
        extrapolation_factor: Weight of the unscaled schedule in the blend.
        attn_factor: Extra multiplier folded into `mscale`.
        beta_fast: Rotation count that defines the lower end of the ramp.
        beta_slow: Rotation count that defines the upper end of the ramp.
        finetuned: Accepted for signature compatibility; not used by this class.
        device: Device on which the frequency buffers are created.
    """

    def __init__(self, dim, max_position_embeddings=5, base=10000, scale=1, original_max_position_embeddings=2048, extrapolation_factor=1, attn_factor=1, beta_fast=32, beta_slow=1, finetuned=False, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scale = scale
        self.original_max_position_embeddings = original_max_position_embeddings
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        self.yarn(device)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        self._set_cos_sin_cache(
            self.max_seq_len_cached, self.inv_freq.device, torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """Rebuild the ``[seq_len, dim]`` cos/sin tables, pre-multiplied by `mscale`."""
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1).to(device)

        self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False)
        self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first `seq_len` positions.

        Args:
            x: Tensor laid out as [bs, num_attention_heads, seq_len, head_size];
                only its dtype, device and (when `seq_len` is omitted) its
                second-to-last dimension are used.
            seq_len: Number of positions required. Defaults to ``x.shape[-2]``.

        Returns:
            Two tensors of shape ``[seq_len, dim]`` cast to ``x.dtype``.
        """
        if seq_len is None:
            # The declared default used to crash on `None > int`; fall back to
            # the sequence axis of x instead.
            seq_len = x.shape[-2]

        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            self._set_cos_sin_cache(seq_len, x.device, x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

    def yarn(self, device):
        """Compute and register `inv_freq`, and set `mscale`."""
        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (self.scale * pos_freqs)

        low, high = find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
        # The mask is `extrapolation_factor` below `low` and 0 above `high`, so
        # low indices keep the unscaled frequencies and high indices are interpolated.
        inv_freq_mask = (1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

        self.register_buffer("inv_freq", inv_freq)
        self.mscale = float(get_mscale(self.scale) * self.attn_factor) # Get n-d magnitude scaling corrected for interpolation


# Demo: rotate random q/k tensors with the YaRN rotary embedding.
# Tensor layout: (bs, head, length, dim)
q = torch.randn(2, 12, 10, 32)  # q=[q0, q1, .., qd-1]
k = torch.randn(2, 12, 10, 32)
v = torch.randn(2, 12, 10, 32)
# Positions 0..9, one identical row per batch entry.
position_ids = torch.arange(10).unsqueeze(0).repeat(2, 1)
print('q:', q[0, 0, 0])
print('k:', k[0, 0, 0])
rotary_emb = YaRNScaledRotaryEmbedding(dim=32, max_position_embeddings=10)
cos, sin = rotary_emb(v, seq_len=10)
q_new, k_new = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print('q_new: ', q_new[0, 0, 0])
print('k_new: ', k_new[0, 0, 0])

q: tensor([ 0.4122, -0.9974, -0.4258,  0.8649,  0.2575,  0.9717, -1.8156, -0.1367,
         1.6512, -1.3428,  0.2563,  1.1431,  0.3470,  0.1834,  0.7185, -0.8874,
         0.5366,  0.4353, -0.9806, -0.9068,  1.0905,  0.1720,  0.8175,  1.0371,
        -0.1976,  1.3093, -0.9622, -0.7057, -0.8805, -1.0351, -0.5688, -0.2100])
k: tensor([-0.1960, -0.1691,  0.5125,  0.7086, -0.2739,  0.5460,  0.1985, -0.6967,
        -0.6941, -0.2064,  1.7953, -0.2109,  2.0027,  0.8383,  1.4107, -1.8471,
         1.2018, -0.2565,  1.5268, -1.7817,  0.9530,  0.5201,  0.3590,  0.7296,
        -0.0674,  0.4554,  0.4298,  0.7931,  0.5609, -0.2021,  0.0628,  0.2346])
q_new:  tensor([ 0.4122, -0.9974, -0.4258,  0.8649,  0.2575,  0.9717, -1.8156, -0.1367,
         1.6512, -1.3428,  0.2563,  1.1431,  0.3470,  0.1834,  0.7185, -0.8874,
         0.5366,  0.4353, -0.9806, -0.9068,  1.0905,  0.1720,  0.8175,  1.0371,
        -0.1976,  1.3093, -0.9622, -0.7057, -0.8805, -1.0351, -0.5688, -0.2100])
k_new:  tensor([-0.1960