In [1]:
import torch
import triton
import triton.language as tl




## Copy (复制)

In [65]:
@triton.jit
def _copy(INPUT, OUTPUT,
          stride0, stride1, stride2, stride3,
          Z, H, M, D: tl.constexpr,
          BLOCK_M: tl.constexpr):
    """Copy one (BLOCK_M, D) tile of INPUT into OUTPUT.

    Launch grid: axis 0 enumerates row blocks along M, axis 1 enumerates the
    flattened (z, h) pair. OUTPUT is addressed with exactly the same base
    offset and strides as INPUT, so both buffers must share one memory layout.

    Args:
        INPUT, OUTPUT: base pointers of the source and destination buffers.
        stride0..stride3: element strides of INPUT along its four dimensions.
        Z, H, M: runtime sizes of the first three dimensions (Z is not read).
        D: compile-time size of the last dimension; used as a block dimension.
        BLOCK_M: compile-time number of rows handled by one program.
    """
    pid_m = tl.program_id(0)
    pid_zh = tl.program_id(1)

    # Un-flatten the second grid axis into (z, h) indices.
    z_idx = pid_zh // H
    h_idx = pid_zh % H
    slice_offset = z_idx * stride0 + h_idx * stride1
    first_row = pid_m * BLOCK_M

    src_block = tl.make_block_ptr(
        base=INPUT + slice_offset,
        shape=(M, D),
        strides=(stride2, stride3),
        offsets=(first_row, 0),
        block_shape=(BLOCK_M, D),
        order=(1, 0),
    )
    dst_block = tl.make_block_ptr(
        base=OUTPUT + slice_offset,
        shape=(M, D),
        strides=(stride2, stride3),
        offsets=(first_row, 0),
        block_shape=(BLOCK_M, D),
        order=(1, 0),
    )

    # Both dimensions are bounds-checked; out-of-range loads read as zero and
    # out-of-range stores are dropped.
    tile = tl.load(src_block, boundary_check=(1, 0), padding_option='zero')
    tl.store(dst_block, tile, boundary_check=(1, 0))
    

def copy(tensor):
    """Return a copy of a 4-D tensor produced by the `_copy` Triton kernel.

    Args:
        tensor: tensor of shape (Z, H, M, D). D is passed to the kernel as a
            block dimension, which Triton requires to be a power of two.

    Returns:
        A new tensor with the same shape, dtype and device as `tensor`.
    """
    Z, H, M, D = tensor.shape

    o = torch.empty_like(tensor)
    # The kernel addresses the output with the *input's* strides.
    # `empty_like` only reproduces those strides when the input is dense and
    # non-overlapping (e.g. contiguous or merely permuted). For anything else
    # (slices, expanded tensors) fall back to a contiguous copy so that both
    # buffers really share one layout.
    if o.stride() != tensor.stride():
        tensor = tensor.contiguous()
        o = torch.empty_like(tensor)

    BLOCK_M = 4
    # Axis 0: row blocks along M; axis 1: one program per (z, h) pair.
    grid = lambda meta: (triton.cdiv(M, BLOCK_M), Z * H)
    _copy[grid](
        tensor, o,
        tensor.stride(0), tensor.stride(1), tensor.stride(2), tensor.stride(3),
        Z, H, M, D,
        BLOCK_M=BLOCK_M,
    )

    return o

In [66]:
# Smoke test for `copy`: round-trip a small random fp16 tensor on the GPU.
device = 'cuda:0'
dtype = torch.float16
# (Z, H, M, D) sizes of the test tensor.
z,h,m,d = 8,8,128,4
a = torch.randn(z,h,m,d, device=device, dtype=dtype)
b = copy(a)
# Displayed as the cell result: True when the copy matches the source.
torch.allclose(a,b)

True

# rotate_half

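Definition: `rotate_half(x) = cat(-x2, x1)`, where `x1` and `x2` are the first and second halves of the last dimension of `x`.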

In [9]:
@triton.jit
def _rotate_half(X, Y,
                stride_b, stride_h, stride_n, stride_d,
                B,H,N,D:  tl.constexpr,
                BLOCK_N:  tl.constexpr,
                ):
    """Write cat(-x2, x1) of each row of X into Y for one (BLOCK_N, D) tile.

    x1 / x2 are the first / second halves of the last dimension.

    Launch grid: axis 0 enumerates the flattened (b, h) pair, axis 1
    enumerates row blocks along N. Y is addressed with the same base offset
    and strides as X, so both buffers must share one memory layout.

    Args:
        X, Y: base pointers of the source and destination buffers.
        stride_b, stride_h, stride_n, stride_d: element strides of X.
        B, H, N: runtime sizes of the first three dimensions (B is not read).
        D: compile-time size of the last dimension.
        BLOCK_N: compile-time number of rows handled by one program.
    """
    pid_bh = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    # Move both base pointers to the (b, h) slice handled by this program.
    offset = stride_b * pid_b + stride_h * pid_h
    X += offset
    Y += offset

    x_ptrs = tl.make_block_ptr(
        base=X,
        shape=(N,D),
        offsets=(BLOCK_N * pid_n, 0),
        strides=(stride_n, stride_d),
        block_shape=(BLOCK_N, D),
        order=(1,0)
    )
    y_ptrs = tl.make_block_ptr(
        base=Y,
        shape=(N,D),
        offsets=(BLOCK_N * pid_n, 0),
        strides=(stride_n, stride_d),
        block_shape=(BLOCK_N, D),
        order=(1,0)
    )

    # Only the row dimension can run past the tensor; D is covered exactly.
    x = tl.load(x_ptrs, boundary_check=(0,))

    # Swap the two halves of every row: reversing the whole row and then
    # reversing each half again yields [x2, x1].
    # `tl.reshape` is used instead of `tl.view` because `tl.view` does not
    # guarantee element order, and this trick depends on it.
    y = tl.flip(x, dim=1)
    y = tl.reshape(y, (BLOCK_N, 2, D//2))
    y = tl.flip(y, dim=2)
    y = tl.reshape(y, (BLOCK_N, D))

    # Negate the first half (which now holds x2) to obtain [-x2, x1].
    y = tl.where(D//2 <= tl.arange(0, D), y, -y)
    tl.store(y_ptrs, y, boundary_check=(0,))



def triton_rotate_half(x):
    """Compute `rotate_half(x)` (i.e. cat(-x2, x1) on the last dim) with Triton.

    Args:
        x: tensor of shape (B, H, N, D). D must be a power of two and a
            multiple of 32.

    Returns:
        A new tensor with the same shape, dtype and device as `x`.

    Raises:
        AssertionError: if D does not satisfy the constraint above.
    """
    B, H, N, D = x.shape
    # D is used as a block dimension and as the extent of `tl.arange`, both of
    # which must be powers of two; `D % 32 == 0` alone would let e.g. 96 pass.
    assert D % 32 == 0 and (D & (D - 1)) == 0, \
        f"D must be a power of two and a multiple of 32, got {D}"

    y = torch.empty_like(x)
    if x.numel() == 0:
        # Nothing to compute; also avoids BLOCK_N == 0 when N == 0.
        return y

    # The kernel addresses the output with the *input's* strides.
    # `empty_like` only reproduces those strides for dense, non-overlapping
    # inputs; otherwise work on a contiguous copy so both layouts match.
    if y.stride() != x.stride():
        x = x.contiguous()
        y = torch.empty_like(x)

    BLOCK_N = min(triton.next_power_of_2(N), 64)
    # Axis 0: one program per (b, h) pair; axis 1: row blocks along N.
    grid = lambda meta: (B*H, triton.cdiv(N, BLOCK_N))
    _rotate_half[grid](x, y,
                       *x.stride(),
                       B, H, N, D,
                       BLOCK_N,
                       num_warps=8, num_stages=1
    )
    return y

def rotate_half(x):
    """Swap the two halves of the last dimension and negate the new front half.

    With `x1 = x[..., :D//2]` and `x2 = x[..., D//2:]`, returns
    `cat(-x2, x1)` along the last dimension.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)


def rotate_half2(x):
    """Swap the two halves of the last dimension and negate the new back half.

    With `x1 = x[..., :D//2]` and `x2 = x[..., D//2:]`, returns
    `cat(x2, -x1)` along the last dimension. Note the sign is on the opposite
    half compared with `rotate_half`, which returns `cat(-x2, x1)`.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((back, -front), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    Each of `q` and `k` is mapped to `t * cos + rotate_half(t) * sin`, after
    `cos` and `sin` have been given an extra axis at `unsqueeze_dim` so they
    broadcast against `q` and `k`.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused; accepted only for signature compatibility.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which `cos` and `sin` are unsqueezed before broadcasting.
            For example, if `cos`/`sin` are [batch_size, seq_len, head_dim]
            and `q`/`k` are [batch_size, heads, seq_len, head_dim], use 1; if
            `q`/`k` are [batch_size, seq_len, heads, head_dim], use 2.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _embed(t):
        # Same formula for queries and keys.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _embed(q), _embed(k)

In [62]:
# Build a non-contiguous CUDA input by swapping dims 1 and 2 of a contiguous
# (4, 64, 128, 64) tensor; the displayed shape is (4, 128, 64, 64).
x = torch.randn(4,64 ,128, 64).cuda().transpose(1,2)
x.shape

torch.Size([4, 128, 64, 64])

In [23]:
# Check the Triton kernel against the PyTorch reference.
# Requires `x` from the cell that builds the transposed CUDA tensor; run that
# cell first (a fresh kernel otherwise raises NameError here).
y1 = triton_rotate_half(x)
y2 = rotate_half(x)
# Compare the full tensors rather than a single row, so that errors in any
# batch, head or row block are caught.
torch.allclose(y1, y2)

NameError: name 'x' is not defined

In [24]:
# Reference gradient via autograd: run the rotary embedding on random CPU
# tensors and back-propagate the sum of both outputs.
q = torch.randn(4, 8, 128, 64)
q.requires_grad_(True)
k = torch.randn_like(q)
k.requires_grad_(True)
cos = torch.randn(4, 128, 64)
sin = torch.randn_like(cos)
# NOTE: this rebinds y1 / y2, which earlier held the rotate_half results.
y1, y2 = apply_rotary_pos_emb(q,k,cos, sin)
(y1 + y2).sum().backward()
# Show the gradient of one (batch, head, position) row for comparison with the
# closed-form expression computed in a later cell.
q.grad[0][0][0]

tensor([-0.2572,  0.1183, -1.4267, -0.2122,  0.9832,  0.6093, -0.2586, -0.2049,
         3.0891, -0.2464,  0.9524,  0.7513, -2.7283, -1.1546, -1.1360, -0.9803,
        -0.2092,  0.6473,  3.5096, -0.4503, -2.0130,  0.5758,  0.7775,  0.9436,
         0.4559,  2.3994, -1.8098,  2.1824, -1.1935,  1.4013, -0.3436, -0.5028,
         1.5051, -1.1040,  0.3131,  1.1901,  1.3044, -0.9019, -1.2243, -1.1132,
         1.8181, -0.5418, -0.3229, -0.4776, -1.5151,  1.3182, -0.2117,  0.0931,
        -1.2636, -2.1016,  1.1561,  0.4012,  0.1271,  1.4285,  0.2166, -1.5709,
         0.2182, -1.3152,  1.0969, -3.0529,  1.8627,  2.0502,  0.7951,  1.5089])

In [32]:
q.grad.is_contiguous()

True

In [28]:
# Closed-form gradient of sum(q * cos + rotate_half(q) * sin) w.r.t. q:
# the q * cos term contributes cos, and the rotate_half term contributes
# rotate_half2(sin), i.e. cat(sin2, -sin1).
# The reshape splits the last dim into its two halves for display, and
# unsqueeze(1) adds a singleton axis in the heads position.
grad_q = (cos + rotate_half2(sin)).reshape(4, 128, 2, -1).unsqueeze(1)
grad_q[0]


tensor([[[[-2.5717e-01,  1.1831e-01, -1.4267e+00,  ...,  1.4013e+00,
           -3.4363e-01, -5.0278e-01],
          [ 1.5051e+00, -1.1040e+00,  3.1310e-01,  ...,  2.0502e+00,
            7.9514e-01,  1.5089e+00]],

         [[-1.1497e+00, -2.4273e-01, -4.7487e-01,  ...,  2.4568e+00,
           -2.3628e-02, -3.5988e+00],
          [ 6.0496e-01, -4.8472e-01,  2.6671e-01,  ..., -2.5075e+00,
           -1.0042e+00,  8.9769e-01]],

         [[ 6.8940e-01, -1.1194e+00, -1.7104e+00,  ..., -3.4532e-01,
           -3.2366e+00,  1.1030e+00],
          [ 1.2432e+00, -9.3682e-01, -1.2761e+00,  ...,  3.7805e-02,
            4.8026e-02,  7.3264e-01]],

         ...,

         [[ 2.6969e+00,  9.8474e-01,  7.4240e-01,  ..., -5.0544e-01,
            1.3220e-04,  1.3742e-01],
          [-1.5116e-01, -1.0517e-01, -7.2030e-02,  ..., -1.8224e+00,
            4.0523e-01,  5.1464e-01]],

         [[ 1.2590e+00, -1.1228e+00,  1.2757e+00,  ..., -1.1118e+00,
            5.6620e-01,  2.2423e-01],
          [-1.

In [4]:
# Contiguity experiment on small CPU tensors.
# NOTE: this rebinds `x`, which other cells use for the CUDA test input.
x = torch.randn(4,5,6,8)
y = torch.randn(4,5,6,8)
# Transposing dims 1 and 2 twice gives back a contiguous tensor (output: True).
x.transpose(1,2).transpose(1,2).is_contiguous()

True

In [8]:
# Concatenating along dim 2 produces a contiguous result (output: True).
torch.cat([x, y], dim=2).is_contiguous()

True

In [10]:
# The PyTorch reference builds its result with torch.cat, and the result is
# contiguous (output: True).
# NOTE: this rebinds `z`, which the copy smoke test used for a size.
z = rotate_half(x)
z.is_contiguous()

True