In [None]:
"""
# RoPE 시각화 노트북

이 노트북은 Hugging Face `transformers`가 제공하는 RoPE 초기화 유틸을 사용해서
`position`(가로축)과 `hidden dim`(세로축) 기준의 heatmap을 그립니다.

리팩토링 목표

- 기본 RoPE(`default`)만 바로 실행 가능하게 제공
- 다른 RoPE 변형으로 바꾸기 쉽도록 구조를 정리
- 함수 내부에 함수를 정의하지 않음
- 유사 기능을 묶어서 재사용 가능하게 구성
"""

import math
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple, List

import matplotlib.pyplot as plt
import torch
from transformers import ROPE_INIT_FUNCTIONS


@dataclass
class RopeConfig:
    """Minimal config object handed to the `transformers` RoPE init functions.

    It only carries the attributes declared below; instances are passed as the
    first argument of an entry of ``ROPE_INIT_FUNCTIONS``.
    """

    rope_theta: float = 10000.0
    hidden_size: int = 4096
    num_attention_heads: int = 32
    # May be left as None; the run script in this file then resolves it as
    # hidden_size // num_attention_heads before building the config.
    head_dim: Optional[int] = None
    partial_rotary_factor: float = 1.0
    max_position_embeddings: int = 8192

    # Dict with the same layout as the RoPE scaling settings of transformers.
    # A fresh {"rope_type": "default"} dict is created for every instance.
    rope_scaling: Dict[str, Any] = field(default_factory=lambda: dict(rope_type="default"))


@dataclass
class RopeMethod:
    """One RoPE variant to visualize: a display label plus its init inputs."""

    # Label of the variant; used as the key of the per-method matrix dict.
    name: str
    # Key that is looked up in ROPE_INIT_FUNCTIONS.
    rope_type: str
    # Config associated with this variant.
    cfg: RopeConfig
    # Some rope types require seq_len; for "default" None is usually fine.
    seq_len: Optional[int] = field(default=None)


## theta, cos, sin
def build_position_emb(
    inv_freq: torch.Tensor,
    L: int,
    attention_factor: float,
    device: torch.device,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the phase, cos and sin tables for positions ``0 .. L-1``.

    ``inv_freq[i]`` is used as the phase increment per position of channel
    ``i`` (an angular step, despite its name): ``phase[p, i] = p * inv_freq[i]``.
    The per-channel phases are then concatenated with themselves along the
    last dimension, so the second half of every row repeats the first half.

    Args:
        inv_freq: 1-D tensor of per-channel phase increments, length ``D/2``.
        L: Number of positions to generate.
        attention_factor: Scalar multiplied onto the cos and sin tables.
        device: Device on which the tables are computed; ``inv_freq`` is moved
            there if it lives elsewhere.
        dtype: Kept for interface compatibility. The computation always runs
            in float32, because generating the position index in a
            low-precision dtype would round large positions.

    Returns:
        ``(phase, cos * attention_factor, sin * attention_factor)``, each a
        float32 tensor of shape ``(L, D)`` with ``D = 2 * len(inv_freq)``.
    """
    # Bring both operands onto the same device before combining them.
    step = inv_freq.to(device=device, dtype=torch.float32)  # (D/2,)
    positions = torch.arange(L, device=device, dtype=torch.float32)  # (L,)

    half_phase = positions[:, None] * step[None, :]  # (L, D/2)
    phase = torch.cat([half_phase, half_phase], dim=-1)  # (L, D)
    return phase, phase.cos() * attention_factor, phase.sin() * attention_factor


def postprocess_theta(theta: torch.Tensor, mode: str) -> torch.Tensor:
    """Convert raw phase angles into the representation selected by ``mode``.

    Supported modes:
        - ``"raw"``: return ``theta`` itself, untouched.
        - ``"phase_0_2pi"``: wrap into ``[0, 2*pi)`` with ``torch.remainder``.
        - ``"phase_negpi_pi"``: take the wrapped value, shift by ``pi``, wrap
          modulo ``2*pi`` again and shift back, which lands in ``[-pi, pi)``.

    Raises:
        ValueError: If ``mode`` is none of the above.
    """
    if mode == "raw":
        return theta

    full_turn = 2.0 * math.pi
    wrapped = torch.remainder(theta, full_turn)

    if mode == "phase_0_2pi":
        return wrapped

    if mode != "phase_negpi_pi":
        raise ValueError(f"알 수 없는 theta mode: {mode}")

    # [0, 2pi) -> [-pi, pi)
    return (wrapped + math.pi) % full_turn - math.pi


def restore_to_full_dim(
    rotary_mat: torch.Tensor,
    *,
    head_dim: int,
    rotary_dim: int,
    kind: str,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Pad a rotary table out to ``head_dim`` columns for visualization.

    Columns beyond ``rotary_dim`` are filled with the value that corresponds
    to a zero rotation angle:
        - cos: 1
        - sin: 0
        - theta: 0

    When ``rotary_dim == head_dim`` the input is returned as-is (no copy, no
    slicing, and ``kind`` is not checked).

    Args:
        rotary_mat: Table whose first dimension is the position axis; only its
            first ``rotary_dim`` columns are copied.
        head_dim: Number of columns of the padded result.
        rotary_dim: Number of leading columns taken from ``rotary_mat``.
        kind: ``"cos"``, ``"sin"`` or ``"theta"``; selects the fill value.
        device: Device of the padded result.
        dtype: Dtype of the padded result.

    Raises:
        ValueError: If padding is needed and ``kind`` is not recognized.
    """
    num_positions = int(rotary_mat.shape[0])

    if rotary_dim == head_dim:
        return rotary_mat

    # cos(0) == 1, while sin(0) and the angle itself are 0.
    allocators = {"cos": torch.ones, "sin": torch.zeros, "theta": torch.zeros}
    if kind not in allocators:
        raise ValueError(f"알 수 없는 kind: {kind}")

    padded = allocators[kind]((num_positions, head_dim), device=device, dtype=dtype)
    padded[:, :rotary_dim] = rotary_mat[:, :rotary_dim]
    return padded


## 4. heatmap 그리기
"""
matrix: (L, D) >> visualize: (D, L)

"""


def compute_color_limits(
    mats: Dict[str, torch.Tensor],
    *,
    symmetric: bool,
) -> Tuple[float, float]:
    """Return one ``(vmin, vmax)`` pair covering every matrix in ``mats``.

    With ``symmetric=True`` the range is centred on zero: ``(-v, v)`` where
    ``v`` is the largest absolute value among the overall minimum and maximum.
    """
    lows = [float(t.min().item()) for t in mats.values()]
    highs = [float(t.max().item()) for t in mats.values()]
    lo, hi = min(lows), max(highs)

    if not symmetric:
        return lo, hi

    bound = max(abs(lo), abs(hi))
    return -bound, bound


def plot_heatmap_grid(
    mats: Dict[str, torch.Tensor],
    *,
    title: str,
    vmin: float,
    vmax: float,
    cmap: str = "viridis",
    grid_cols: int = 3,
    origin: str = "lower",
) -> None:
    """Draw every matrix in ``mats`` as a heatmap on a shared-colorbar grid.

    Each matrix is transposed before drawing, so its first axis ends up on the
    x-axis ("position") and its second axis on the y-axis ("hidden dim").

    Args:
        mats: Mapping from subplot title to a 2-D tensor.
        title: Figure-level title.
        vmin: Lower color limit applied to every subplot.
        vmax: Upper color limit applied to every subplot.
        cmap: Matplotlib colormap name.
        grid_cols: Number of grid columns; values below 1 are treated as 1.
        origin: Passed through to ``imshow``.

    Raises:
        ValueError: If ``mats`` is empty.
    """
    if not mats:
        raise ValueError("mats가 비어 있어 그릴 heatmap이 없습니다.")

    names = list(mats)
    n = len(names)
    cols = max(1, int(grid_cols))
    rows = (n + cols - 1) // cols  # ceil(n / cols)

    # squeeze=False always yields a 2-D (rows, cols) array of Axes, so no
    # special-casing of single-row / single-column grids is needed.
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(cols * 5.0, rows * 4.0),
        constrained_layout=True,
        squeeze=False,
    )
    cells = list(axes.flat)  # row-major order: index i -> (i // cols, i % cols)

    im = None
    for ax, name in zip(cells, names):
        # imshow needs a host-side array; detach()/cpu() are no-ops for plain
        # CPU tensors and keep grad-tracking or accelerator tensors working.
        image = mats[name].detach().cpu().T.numpy()  # (L, D) -> (D, L)
        im = ax.imshow(
            image,
            aspect="auto",
            origin=origin,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap,
        )
        ax.set_title(name, fontsize=10)
        ax.set_xlabel("position")
        ax.set_ylabel("hidden dim")

    # Hide the unused trailing cells of the grid.
    for ax in cells[n:]:
        ax.axis("off")

    fig.suptitle(title, fontsize=12)

    # One colorbar shared by all drawn subplots (they use the same limits).
    fig.colorbar(im, ax=cells[:n], shrink=0.85)

    plt.show()


## 5. RoPE 방법별 행렬 만들기
"""
현재는 `default`만 준비합니다.

다른 방법으로 바꾸려면 `methods` 리스트에 `RopeMethod`를 추가하면 됩니다.
추가 시에도 HF `ROPE_INIT_FUNCTIONS`를 그대로 사용합니다.
"""


def build_rotary_matrices(
    methods: List[RopeMethod],
    *,
    kind: str,
    L: int,
    head_dim_show: int,
    rotary_dim_show: int,
    theta_mode: str,
    device: torch.device,
    dtype: torch.dtype,
) -> Dict[str, torch.Tensor]:
    """Build one matrix per RoPE method for the requested quantity.

    For every method, the matching entry of ``ROPE_INIT_FUNCTIONS`` is called
    with that method's own config, the phase / cos / sin tables are derived
    with ``build_position_emb``, the one selected by ``kind`` is picked, and
    the result is padded with ``restore_to_full_dim``.

    Args:
        methods: RoPE variants to evaluate.
        kind: ``"theta"``, ``"cos"`` or ``"sin"``.
        L: Number of positions.
        head_dim_show: Head dimension; the padded matrix gets
            ``head_dim_show // 2`` columns.
        rotary_dim_show: Number of leading columns copied from the rotary
            table when padding is needed.
        theta_mode: Mode forwarded to ``postprocess_theta`` (``kind="theta"``
            only).
        device: Device used for the computation.
        dtype: Dtype forwarded to the helpers.

    Returns:
        Dict mapping each method name to a detached CPU tensor.

    Raises:
        ValueError: If ``kind`` is not recognized (checked per method).
        KeyError: If a method's ``rope_type`` is not in ``ROPE_INIT_FUNCTIONS``.
    """
    matrices: Dict[str, torch.Tensor] = {}

    # The library calls the first return value `inv_freq`; here it is used as
    # the phase increment per position of each channel (an angular speed).

    for m in methods:
        # Use the method's own config so variants with different scaling
        # settings are actually initialized differently.
        inv_freq, attn_factor = ROPE_INIT_FUNCTIONS[m.rope_type](m.cfg, device, seq_len=m.seq_len)
        # Keep the tensor on the compute device; the positions are created
        # there as well, and mixing devices would fail.
        inv_freq = inv_freq.detach().to(device)
        attn_factor = float(attn_factor)

        theta, cos, sin = build_position_emb(inv_freq, L=L, attention_factor=attn_factor, device=device, dtype=dtype)

        if kind == "theta":
            rotary_mat = postprocess_theta(theta, mode=theta_mode)
        elif kind == "cos":
            rotary_mat = cos
        elif kind == "sin":
            rotary_mat = sin
        else:
            raise ValueError(f"알 수 없는 kind: {kind}")

        # The tables hold the per-channel values twice along the last axis, so
        # only half of the head dimension is shown.
        # NOTE(review): when rotary_dim_show == head_dim_show // 2 the helper
        # returns its input unchanged, i.e. still with both duplicated halves.
        full = restore_to_full_dim(
            rotary_mat,
            head_dim=head_dim_show // 2,
            rotary_dim=rotary_dim_show,
            kind=kind,
            device=device,
            dtype=dtype,
        )
        matrices[m.name] = full.detach().cpu()

    return matrices


def plot_kind(
    methods: List[RopeMethod],
    *,
    kind: str,
    L: int,
    head_dim: int,
    rotary_dim: int,
    rope_theta: float,
    partial_rotary_factor: float,
    theta_mode: str,
    global_scale: bool = True,
    symmetric_scale: bool = True,
    cmap: str = "viridis",
    origin: str = "lower",
    grid_cols: int = 3,
    device: torch.device,
    dtype: torch.dtype,
) -> None:
    """Compute the ``kind`` matrices for all methods and plot them as heatmaps.

    ``rope_theta`` and ``partial_rotary_factor`` only appear in the figure
    title. ``global_scale`` is accepted but not consulted: the color limits
    are always shared across all methods. ``symmetric_scale`` only takes
    effect for ``kind`` "cos" or "sin"; other kinds use the plain min/max.
    """
    per_method = build_rotary_matrices(
        methods,
        kind=kind,
        L=L,
        head_dim_show=head_dim,
        rotary_dim_show=rotary_dim,
        theta_mode=theta_mode,
        device=device,
        dtype=dtype,
    )

    # A zero-centred range is only meaningful for the signed cos/sin tables.
    use_symmetric = kind in ("cos", "sin") and bool(symmetric_scale)
    lo, hi = compute_color_limits(per_method, symmetric=use_symmetric)

    heading = (
        f"{kind} | head_dim={head_dim}, rotary_dim={rotary_dim}, "
        f"L={L}, rope_theta={rope_theta}, "
        f"partial_rotary_factor={partial_rotary_factor}"
    )

    plot_heatmap_grid(
        mats=per_method,
        title=heading,
        vmin=lo,
        vmax=hi,
        cmap=cmap,
        grid_cols=grid_cols,
        origin=origin,
    )


## 6. 실행 파라미터
"""
아래 셀만 수정하면 됩니다.
지금은 기본 RoPE(default)만 실행합니다.
"""
# Run parameters
device = torch.device("cpu")  # "cuda" is possible as well
dtype = torch.float32

# Model shape
rope_theta = 10000.0
hidden_size = 4096
num_attention_heads = 32
head_dim = None  # None -> hidden_size // num_attention_heads
partial_rotary_factor = 0.9  # 0 <= f <= 1

# Visualized position range
L = 8192 * 4

# How theta is displayed
theta_mode = "phase_negpi_pi"  # "raw", "phase_0_2pi", "phase_negpi_pi"

# Plot options
cmap = "viridis"
origin = "lower"
grid_cols = 2


# Validate with explicit raises: `assert` statements are stripped under
# `python -O`, and an exception instance used as the assert message is never
# raised itself.
if hidden_size % num_attention_heads != 0:
    raise ValueError(
        f"hidden_size({hidden_size})가 num_attention_heads({num_attention_heads})로 나누어 떨어지지 않습니다."
    )
if not 0 <= partial_rotary_factor <= 1:
    raise ValueError(f"partial_rotary_factor({partial_rotary_factor})는 0 이상 1 이하여야 합니다.")

# Derived from the head settings above
resolved_head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads
rotary_dim = int(resolved_head_dim * float(partial_rotary_factor))
# ceil(rotary_dim / 2); presumably the number of entries in the inv_freq
# tensor returned by the init function — TODO confirm.
resolved_rotary_dim = (rotary_dim + 1) // 2


# Base config
cfg = RopeConfig(
    rope_theta=rope_theta,
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    head_dim=resolved_head_dim,
    partial_rotary_factor=partial_rotary_factor,
    # Not very important for the default type, but kept to match the config layout.
    max_position_embeddings=L,
    rope_scaling={"rope_type": "default"},
)

methods: List[RopeMethod] = [
    RopeMethod(name="default", rope_type="default", cfg=cfg, seq_len=None),
]

print("사용 가능한 rope_type:", ", ".join(sorted(ROPE_INIT_FUNCTIONS)))
print(f"head_dim={resolved_head_dim}, rotary_dim={resolved_rotary_dim}")

# Show cos and sin as heatmaps, one figure per kind, with identical settings.
# Add "theta" to the tuple to plot the phase as well; plot_kind only applies
# symmetric_scale to "cos"/"sin", so the shared arguments work for it too.
for plotted_kind in ("cos", "sin"):
    plot_kind(
        methods,
        kind=plotted_kind,
        L=L,
        head_dim=resolved_head_dim,
        rotary_dim=resolved_rotary_dim,
        rope_theta=rope_theta,
        partial_rotary_factor=partial_rotary_factor,
        theta_mode=theta_mode,
        global_scale=True,
        symmetric_scale=True,
        cmap=cmap,
        origin=origin,
        grid_cols=grid_cols,
        device=device,
        dtype=dtype,
    )


## 8. 다른 RoPE 방식으로 확장하는 방법
"""
아래 예시는 실행하지 않는 예시 코드입니다.
`methods`에 `RopeMethod`를 추가하는 방식으로 확장할 수 있습니다.

주의

- transformers 버전에 따라 지원 rope_type이 다를 수 있습니다.
- 어떤 rope_type은 `seq_len` 또는 `cfg.max_position_embeddings` 같은 값이 중요합니다.
"""
# 예시 (실행하지 마세요)
#
# cfg_linear = RopeConfig(
#     rope_theta=rope_theta,
#     hidden_size=hidden_size,
#     num_attention_heads=num_attention_heads,
#     head_dim=head_dim,
#     partial_rotary_factor=partial_rotary_factor,
#     max_position_embeddings=L,
#     rope_scaling={"rope_type": "linear", "factor": 4.0},
# )
#
# methods = [
#     RopeMethod(name="default", rope_type="default", cfg=cfg),
#     RopeMethod(name="linear(f=4.0)", rope_type="linear", cfg=cfg_linear),
# ]
#
# plot_kind(... kind="cos" ...)
