In [1]:
import sys
sys.path.append('..')
from model.modeling_buddygpt import RotaryEmbedding

In [None]:
from model.modeling_buddygpt import apply_rotary_pos_emb
import torch

# Smoke test of the project's RoPE implementation from model.modeling_buddygpt
# (RotaryEmbedding is imported in the first cell, apply_rotary_pos_emb just above).
# NOTE(review): this cell has never been executed (In [None]); the call signatures
# rope(x, seq_len) and apply_rotary_pos_emb(q, k, cos, sin) belong to that module,
# which is not visible here — confirm them there.
rope = RotaryEmbedding(1024)
# NOTE(review): seq_len and head are both 16, so if the project code broadcasts cos/sin
# along the wrong axis this cell will still run; use distinct sizes to catch that.
q, k = torch.randn(1, 16, 16, 1024), torch.randn(1, 16, 16, 1024) # batch_size, seq_len, head, head_dim
cos, sin = rope(q, 16)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q.shape, k.shape




In [2]:
import torch
import torch.nn.functional as F

# Split the last axis into two halves and gate the first with SiLU of the second
# (the x1 * silu(x2) product used by SwiGLU-style feed-forward layers).
x = torch.randn(1, 16, 16, 128)
x1, x2 = torch.chunk(x, chunks=2, dim=-1)   # two halves of the last axis
print(x1.shape, x2.shape)

x1.mul(F.silu(x2)).shape


torch.Size([1, 16, 16, 64]) torch.Size([1, 16, 16, 64])


torch.Size([1, 16, 16, 64])

In [13]:
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    """Minimal rotary position embedding (RoPE) with a precomputed angle table.

    Args:
        dim: size of the rotated (last) feature axis.
        seq_len: number of positions to precompute (default 2048). Longer inputs are
            handled by computing the table on the fly in forward().
        theta: base used for the per-pair frequencies, default 100000
            (note: the RoPE paper / LLaMA use 10000).
    """

    def __init__(self, dim, seq_len=2048, theta=100000):
        super().__init__()
        self.dim = dim
        self.theta = theta
        # NOTE: despite its name this buffer holds the (seq_len, dim) table of angles
        # position * inv_freq, not the inverse frequencies themselves. The name is kept
        # so existing state_dict keys keep loading.
        self.register_buffer('inv_freq', self._angles(seq_len))

    def _angles(self, seq_len, device=None):
        """Return the (seq_len, dim) angle table t * inv_freq, duplicated across both halves."""
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim))
        freqs = torch.outer(t, inv_freq)
        # Both halves share the same angles so rotate_half pairs element i with i + dim/2.
        return torch.cat((freqs, freqs), dim=-1)

    def forward(self, x):
        """Return (cos, sin), each of shape (seq_len, dim), on x's device.

        seq_len is read from x.shape[1]. If it exceeds the precomputed table, the angles
        are computed on the fly instead of silently returning a truncated table.
        """
        seq_len = x.shape[1]
        if seq_len <= self.inv_freq.shape[0]:
            angles = self.inv_freq[:seq_len]
        else:
            angles = self._angles(seq_len, device=self.inv_freq.device)
        cos = torch.cos(angles).to(x.device)
        sin = torch.sin(angles).to(x.device)
        return cos, sin
    

def rotate_half(x):
    """Map the two halves (x1, x2) of the last axis to (-x2, x1).

    Element i is paired with element i + d/2, which matches the angle table whose two
    halves are identical.
    """
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=None):
    """Apply rotary position embedding to q and k: x' = x * cos + rotate_half(x) * sin.

    Args:
        q, k: query/key tensors whose last axis is head_dim.
        cos, sin: angle tables, e.g. (seq_len, head_dim) as returned by RotaryEmbedding.
        unsqueeze_dim: optional axis at which a size-1 dimension is inserted into cos/sin
            before broadcasting. Without it, cos/sin broadcast against the LAST axes of
            q/k, so a (seq_len, head_dim) table lines up with axis -2. That is right for
            (batch, seq, head_dim) or (batch, heads, seq, head_dim) inputs. For
            (batch, seq, heads, head_dim) inputs it applies positions along the heads
            axis (silently when heads == seq_len), so pass unsqueeze_dim=1 there.

    Returns:
        (q_rotated, k_rotated).
    """
    if unsqueeze_dim is not None:
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    return q, k



# Demo: rotate random q/k with the implementation defined above.
# NOTE(review): forward() takes seq_len from x.shape[1] (16 here), but the resulting
# (16, 32) cos/sin broadcast against the LAST two axes of q/k, i.e. axis 2. That only
# lines up because axes 1 and 2 are both 16; for (batch, seq, heads, head_dim) inputs
# with heads != seq the positions would be applied along the wrong axis.
# q and k are also reused by the next cell for comparison.
q, k = torch.randn(1, 16, 16, 32), torch.randn(1, 16, 16, 32)
cos, sin = RotaryEmbedding(32)(q)
q0, k0 = apply_rotary_pos_emb(q, k, cos, sin)
print(q0)
print(k0)

tensor([[[[-1.1041, -2.6843, -0.8212,  ..., -1.3711, -1.2855, -0.5729],
          [-0.7209,  1.4998,  0.1335,  ..., -1.8381, -0.1215,  1.0133],
          [-1.4280, -1.6072,  2.2555,  ...,  0.1596, -0.4409, -1.5473],
          ...,
          [ 0.4730,  0.7564, -0.1985,  ..., -0.2968,  0.6742, -0.4883],
          [ 0.3063, -0.0946, -1.4471,  ...,  1.7337,  1.9658,  0.1674],
          [-0.6822, -0.9864, -0.5844,  ..., -2.5665,  0.1781, -1.8472]],

         [[ 1.5176,  0.1779,  1.2010,  ..., -0.6780,  0.2506, -1.4933],
          [-0.3756,  1.2085,  0.4143,  ...,  0.3214, -1.1774, -0.0921],
          [ 0.4268, -0.1386,  0.3934,  ...,  0.2723,  0.1576,  0.0272],
          ...,
          [ 0.9377,  0.4411,  0.9757,  ..., -0.0926, -0.7825, -0.5336],
          [ 0.8833,  1.3063, -0.0897,  ...,  1.1448,  0.8922, -1.3845],
          [ 1.2689, -0.3078, -0.4864,  ..., -0.4094,  0.8023, -0.1769]],

         [[-0.1308, -0.7072, -0.0487,  ...,  1.6820,  0.7393, -0.4771],
          [-0.8641, -0.5326,  

In [14]:
class RotaryEmbedding(nn.Module):
    """Rotary position embedding with a precomputed cos/sin cache."""

    def __init__(self, dim, max_position_embeddings=2048, base=100000, device=None):
        """Rotary position embedding.

        Args:
            dim (int): size of the rotated (head) dimension.
            max_position_embeddings (int): number of positions precomputed up front,
                default 2048. forward() grows the cache if a longer sequence is requested.
            base (int): base used to compute the inverse frequencies, default 100000
                (note: the RoPE paper / LLaMA / Mistral use 10000).
            device: device on which the inverse frequencies (and initial cache) are built.
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[j] = 1 / base^(2j / dim): one frequency per rotated pair.
        # Non-persistent buffer, so it is not written to state_dict.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build the cos/sin cache eagerly so the module supports torch.jit.trace.
        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)

    def _set_cos_sin_cache(self, seq_len, device):
        """(Re)build cos_cached / sin_cached, each of shape (seq_len, dim)."""
        self.max_seq_len_cached = seq_len
        # Positions 0..seq_len-1 with the same dtype as inv_freq.
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        # Angle for every (position, frequency) pair.
        freqs = torch.outer(t, self.inv_freq)

        # Unlike the paper's interleaved layout, both halves share the same angles;
        # combined with rotate_half this gives the same rotation.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x, seq_len=None):
        """Return (cos, sin), each of shape (seq_len, dim), cast to x.dtype.

        Args:
            x: tensor supplying dtype/device (and, when seq_len is None, the sequence
                length from x.shape[-2]); per the original note laid out as
                [bs, num_attention_heads, seq_len, head_size].
            seq_len (int, optional): number of positions to return.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            # Grow the cache. _set_cos_sin_cache takes no dtype argument (passing one
            # raised TypeError); the cast to x.dtype happens below, so the cache stays
            # in full precision.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )



# Adapted from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to q and k.

    Args:
        q (`torch.Tensor`): query tensor.
        k (`torch.Tensor`): key tensor.
        cos (`torch.Tensor`): cosine part of the rotary embedding, indexed by position on axis 0.
        sin (`torch.Tensor`): sine part of the rotary embedding, same layout as `cos`.
        position_ids (`torch.Tensor`, *optional*): position index of each token in q/k, e.g.
            offset positions when using a KV cache. When None, the whole table is used with
            a leading size-1 axis added.
        unsqueeze_dim (`int`, *optional*, defaults to 1): axis at which a size-1 dimension is
            inserted into the selected cos/sin so they broadcast against q and k. With the
            selected cos/sin shaped [batch_size, seq_len, head_dim], use 1 when q/k are
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated (q_embed, k_embed).
    """
    def rotate_half(x):
        """Map the two halves (first, second) of the last axis to (-second, first)."""
        half = x.shape[-1] // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat((-second, first), dim=-1)

    # Pick the rows for the requested positions, or add a leading axis of size 1.
    if position_ids is None:
        cos_sel, sin_sel = cos.unsqueeze(0), sin.unsqueeze(0)
    else:
        cos_sel, sin_sel = cos[position_ids], sin[position_ids]
    cos_b = cos_sel.unsqueeze(unsqueeze_dim)
    sin_b = sin_sel.unsqueeze(unsqueeze_dim)

    # Treat each (x_i, x_{i+d/2}) pair as x_i + i*x_{i+d/2} and multiply by
    # e^{i*theta} = cos(theta) + i*sin(theta):
    #   x_i'       = x_i * cos(theta) - x_{i+d/2} * sin(theta)
    #   x_{i+d/2}' = x_{i+d/2} * cos(theta) + x_i * sin(theta)
    # which is exactly x * cos + rotate_half(x) * sin.
    q_embed = q * cos_b + rotate_half(q) * sin_b
    k_embed = k * cos_b + rotate_half(k) * sin_b
    return q_embed, k_embed

# Re-rotate the q, k left over from the previous cell with the mistral-style version
# above, so q1/k1 can be compared with q0/k0 printed there.
# NOTE(review): this depends on q and k from the previous cell; defining fresh random
# q, k here would make the cell self-contained but break that comparison.
# With position_ids=None, the (16, 32) cos/sin become (1, 1, 16, 32) and broadcast along
# axis 2 of q/k, i.e. q/k are treated as (batch, heads, seq, head_dim).
cos, sin = RotaryEmbedding(32)(q, 16)
q1, k1 = apply_rotary_pos_emb(q, k, cos, sin)
q1, k1

(tensor([[[[-1.1041, -2.6843, -0.8212,  ..., -1.3711, -1.2855, -0.5729],
           [-0.7209,  1.4998,  0.1335,  ..., -1.8381, -0.1215,  1.0133],
           [-1.4280, -1.6072,  2.2555,  ...,  0.1596, -0.4409, -1.5473],
           ...,
           [ 0.4730,  0.7564, -0.1985,  ..., -0.2968,  0.6742, -0.4883],
           [ 0.3063, -0.0946, -1.4471,  ...,  1.7337,  1.9658,  0.1674],
           [-0.6822, -0.9864, -0.5844,  ..., -2.5665,  0.1781, -1.8472]],
 
          [[ 1.5176,  0.1779,  1.2010,  ..., -0.6780,  0.2506, -1.4933],
           [-0.3756,  1.2085,  0.4143,  ...,  0.3214, -1.1774, -0.0921],
           [ 0.4268, -0.1386,  0.3934,  ...,  0.2723,  0.1576,  0.0272],
           ...,
           [ 0.9377,  0.4411,  0.9757,  ..., -0.0926, -0.7825, -0.5336],
           [ 0.8833,  1.3063, -0.0897,  ...,  1.1448,  0.8922, -1.3845],
           [ 1.2689, -0.3078, -0.4864,  ..., -0.4094,  0.8023, -0.1769]],
 
          [[-0.1308, -0.7072, -0.0487,  ...,  1.6820,  0.7393, -0.4771],
           [-

In [15]:
import torch
import torch.nn.functional as F

# Cross-entropy on a single example.
logits = torch.tensor([[2.0, 0.5, 0.3]])   # raw (unnormalised) model scores, shape (1, 3)
target = torch.tensor([0])                 # index of the true class

# log_softmax over the classes followed by negative log-likelihood, averaged over the
# batch (reduction="mean" is the default, spelled out here for clarity).
loss = F.cross_entropy(input=logits, target=target, reduction="mean")
print(loss)

tensor(0.3406)
