In [None]:
import torch
import numpy as np
from typing import Tuple

In [None]:
# PyTorch implementation of apply_scaling
def apply_scaling(
    freqs: torch.Tensor,
    scale_factor: float = 8,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> torch.Tensor:
    """Rescale rotary frequencies band-by-band according to their wavelength.

    Each frequency ``f`` has wavelength ``2 * pi / f`` and falls into one of
    three bands:

    * wavelength <  ``old_context_len / high_freq_factor``: returned unchanged;
    * wavelength >  ``old_context_len / low_freq_factor``: divided by
      ``scale_factor``;
    * otherwise: a linear blend of the scaled and unscaled frequency, weighted
      by where the wavelength sits between the two thresholds.

    The computation is fully vectorised, so ``freqs`` may have any shape.

    Args:
        freqs: Tensor of frequencies.
        scale_factor: Divisor applied to low-frequency components.
            The default values were obtained from grid search.
        low_freq_factor: Sets the long-wavelength threshold.
        high_freq_factor: Sets the short-wavelength threshold.
        old_context_len: Reference context length (original llama3 length).

    Returns:
        A tensor with the same shape, dtype and device as ``freqs``.

    Raises:
        ValueError: If ``low_freq_factor == high_freq_factor``, because the
            interpolation weight would require a division by zero.
    """
    if low_freq_factor == high_freq_factor:
        raise ValueError(
            "low_freq_factor and high_freq_factor must differ to interpolate "
            "between the two frequency bands"
        )

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * np.pi / freqs

    # Candidate values for the low-frequency band and the transition band.
    # Both are computed for every element; torch.where picks the right one.
    scaled = freqs / scale_factor
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    blended = (1 - smooth) * freqs / scale_factor + smooth * freqs

    # Apply the high-frequency test last so it takes priority, matching an
    # if / elif / else chain that checks it first.
    result = torch.where(wavelen > low_freq_wavelen, scaled, blended)
    result = torch.where(wavelen < high_freq_wavelen, freqs, result)
    return result.to(freqs.dtype)


def precompute_freqs_cis(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False
) -> torch.Tensor:
    """Build the table of unit-magnitude complex rotary factors.

    Args:
        dim: Feature dimension; one frequency is produced per pair of
            features (``dim // 2`` frequencies).
        end: Number of positions to tabulate (positions ``0 .. end - 1``).
        theta: Base of the geometric frequency progression.
        use_scaled: When true, the frequencies are passed through
            ``apply_scaling`` before the table is built.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` whose entry ``[p, k]`` is
        ``exp(i * p * freq_k)``.
    """
    # Exponents 0/dim, 2/dim, 4/dim, ... for the geometric progression.
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)

    if use_scaled:
        inv_freq = apply_scaling(inv_freq)

    # angles[p, k] = p * inv_freq[k]
    angles = torch.outer(positions, inv_freq)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)  # complex64


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by the complex factors in ``freqs_cis``.

    Adjacent pairs along the last axis of ``xq`` / ``xk`` are treated as the
    real and imaginary parts of one complex number, multiplied by
    ``freqs_cis``, and unpacked back to real pairs.

    Args:
        xq: Query tensor whose last axis has even length.
        xk: Key tensor whose last axis has even length.
        freqs_cis: Complex tensor of shape ``(xq.shape[1], xq.shape[-1] // 2)``
            (enforced by ``reshape_for_broadcast``).
            NOTE(review): it is reshaped against ``xq`` only, so ``xk`` is
            presumably expected to be broadcast-compatible with that shape.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq``/``xk``.
    """
    # The arithmetic is done in float32; results are cast back at the end.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # view_as_real appends a trailing axis of size 2; flatten(-2) merges it
    # back into the feature axis regardless of how many leading axes there are
    # (a fixed start_dim of 3 would only be correct for 4-D inputs).
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

## Verify `apply_scaling`

In [None]:
# # Input tensor
# # freqs_np = np.array([0.1, 1.0, 10.0, 100.0, 1000.0], dtype=np.float32)
# freqs_np = np.array([1.0000, 0.9822, 0.9647, 0.9475, 0.9306, 0.9140, 0.8977, 0.8817, 0.8660,
#         0.8505], dtype=np.float32)

# freqs_torch = torch.tensor(freqs_np)
# scaled_freqs_torch = apply_scaling(freqs_torch).numpy()

# print("PyTorch output:", scaled_freqs_torch)


## Verify `precompute_freqs_cis`

In [None]:
# # Input parameters for testing
# dim = 16
# end = 10
# theta = 10000.0
# use_scaled = True

# # PyTorch version
# freqs_cis_torch = precompute_freqs_cis(dim, end, theta, use_scaled).numpy()

# # Compare outputs
# # print("PyTorch output:", freqs_cis_torch)
# print(freqs_cis_torch.shape )


## Verify `reshape_for_broadcast` 

In [None]:
# # Input tensors for testing
# np.random.seed(42)
# x_shape = (2, 4, 8)
# x_np = np.random.randn(*x_shape).astype(np.float32)
# freqs_cis_np = np.random.randn(x_shape[1], x_shape[-1]).astype(np.complex64)
# print(x_np)
# print(freqs_cis_np)

# x_torch = torch.tensor(x_np)
# freqs_cis_torch = torch.tensor(freqs_cis_np)

# reshaped_torch = reshape_for_broadcast(freqs_cis_torch, x_torch).numpy()

# print("PyTorch output:", reshaped_torch)

## Verify `apply_rotary_emb` 

In [None]:
# np.random.seed(42)

# # Input tensors for testing
# xq_shape = (1, 10, 4, 16)  # (batch_size, seq_len, num_heads, dim)
# xk_shape = (1, 10, 4, 16)

# xq_np = np.random.randn(*xq_shape).astype(np.float32)
# xk_np = np.random.randn(*xk_shape).astype(np.float32)
# freqs_cis_np = np.random.randn(xq_shape[1], xq_shape[-1]//2).astype(np.complex64)

# xq_torch = torch.tensor(xq_np)
# xk_torch = torch.tensor(xk_np)
# freqs_cis_torch = torch.tensor(freqs_cis_np)

# # PyTorch version
# xq_out_torch, xk_out_torch = apply_rotary_emb(xq_torch, xk_torch, freqs_cis_torch)

# # Compare outputs
# print("PyTorch xq output:", xq_out_torch)
# print("PyTorch xk output:", xk_out_torch)


In [None]:
# Build a 10x10 matrix of ones, keep its lower triangle, and add two leading
# singleton axes so the mask has shape (1, 1, 10, 10).
ones = torch.ones(10, 10)
ones_tril = ones.tril()[None, None]
print(ones.shape)
print(ones_tril.shape)
ones_tril

In [None]:
# Pick entries 0, 1 and 2 along axis 2 of the mask built in the previous cell.
indices = torch.arange(3)
torch.index_select(ones_tril, 2, indices)