[Inference] Add fused rotary kernel and get cos cache kernel #5302

Merged · 4 commits · Jan 24, 2024
7 changes: 5 additions & 2 deletions colossalai/kernel/triton/__init__.py
@@ -11,11 +11,12 @@
from .context_attn_unpad import context_attention_unpadded
from .flash_decoding import flash_decoding_attention
from .flash_decoding_utils import FDIntermTensors

from .rms_layernorm import rms_layernorm
from .fused_rotary_embedding import fused_rotary_embedding
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
from .no_pad_rotary_embedding import rotary_embedding
from .rms_layernorm import rms_layernorm
from .rotary_cache_copy import get_xine_cache
from .softmax import softmax

__all__ = [
@@ -27,4 +28,6 @@
"gptq_fused_linear_triton",
"rotary_embedding",
"FDIntermTensors",
"fused_rotary_embedding",
"get_xine_cache",
]
182 changes: 182 additions & 0 deletions colossalai/kernel/triton/fused_rotary_embedding.py
@@ -0,0 +1,182 @@
import torch
import triton
import triton.language as tl


@triton.jit
def fused_rotary_emb(
q,
k,
cos_cache,
sin_cache,
cumsum_lengths,
q_token_stride,
q_head_stride,
k_token_stride,
k_head_stride,
head_dim_stride,
cos_token_stride,
cos_dim_stride,
q_total_tokens,
Q_HEAD_NUM: tl.constexpr,
K_HEAD_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
N_ELEMENTS: tl.constexpr,
):
block_head_index = tl.program_id(0)
block_group_index = tl.program_id(1)
group_token_index = tl.program_id(2)
idx = block_group_index * BLOCK_SIZE + group_token_index

# recover the token's position within its original sequence from the cumulative sequence lengths
cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
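# e.g. with lengths [3, 5], cumsum_lengths is [3, 8]; the flattened token index idx = 5 lies in the
# second sequence, and its in-sequence position is ori_seq_idx = 5 - 3 = 2, the row used to
# index the cos/sin cache below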
cos = tl.load(
cos_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride
) # [1,HEAD_DIM//2]
sin = tl.load(sin_cache + ori_seq_idx * cos_token_stride + tl.arange(0, HEAD_DIM // 2) * cos_dim_stride)

cur_head_range = block_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

off_q0 = (
idx * q_token_stride
+ cur_head_range[None, :, None] * q_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_q1 = (
idx * q_token_stride
+ cur_head_range[None, :, None] * q_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)

off_k0 = (
idx * k_token_stride
+ cur_head_range[None, :, None] * k_head_stride
+ dim_range0[None, None, :] * head_dim_stride
)
off_k1 = (
idx * k_token_stride
+ cur_head_range[None, :, None] * k_head_stride
+ dim_range1[None, None, :] * head_dim_stride
)

q_0 = tl.load(
q + off_q0,
mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
other=0.0,
)

q_1 = tl.load(
q + off_q1,
mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
other=0.0,
)

k_0 = tl.load(
k + off_k0,
mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
other=0.0,
)

k_1 = tl.load(
k + off_k1,
mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
other=0.0,
)

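# rotate the two halves of each head vector:
#   out0 = x0 * cos - x1 * sin
#   out1 = x0 * sin + x1 * cos
# applied independently to q and to k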
out_q0 = q_0 * cos - q_1 * sin
out_q1 = q_0 * sin + q_1 * cos

out_k0 = k_0 * cos - k_1 * sin
out_k1 = k_0 * sin + k_1 * cos
# write the rotated halves back in place
tl.store(
q + off_q0,
out_q0,
mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
)
tl.store(
q + off_q1,
out_q1,
mask=((cur_head_range[None, :, None] < Q_HEAD_NUM) & (idx < q_total_tokens)),
)

tl.store(
k + off_k0,
out_k0,
mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
)
tl.store(
k + off_k1,
out_k1,
mask=((cur_head_range[None, :, None] < K_HEAD_NUM) & (idx < q_total_tokens)),
)


@torch.no_grad()
def fused_rotary_embedding(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
lengths,
):
"""
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
lengths: length of each sequence, [num_seqs]
Note: q and k are rotated in place.
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
BLOCK_SIZE = 16
cumsum_lens = torch.cumsum(lengths, dim=0)

grid = (triton.cdiv(q_head_num, BLOCK_HEAD), triton.cdiv(q_total_tokens, BLOCK_SIZE), BLOCK_SIZE)
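# grid above: axis 0 covers blocks of heads; axes 1 and 2 together enumerate tokens as
# block_group_index * BLOCK_SIZE + group_token_index, matching idx in the kernel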

if head_dim >= 128:
num_warps = 8
else:
num_warps = 4

q_token_stride = q.stride(0)
q_head_stride = q.stride(1)
head_dim_stride = q.stride(2)

k_token_stride = k.stride(0)
k_head_stride = k.stride(1)

k_head_num = k.shape[1]

cos_token_stride = cos.stride(0)
cos_dim_stride = cos.stride(1)

fused_rotary_emb[grid](
q,
k,
cos,
sin,
cumsum_lens,
q_token_stride,
q_head_stride,
k_token_stride,
k_head_stride,
head_dim_stride,
cos_token_stride,
cos_dim_stride,
q_total_tokens,
Q_HEAD_NUM=q_head_num,
K_HEAD_NUM=k_head_num,
HEAD_DIM=head_dim,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_SIZE=BLOCK_SIZE,
N_ELEMENTS=triton.next_power_of_2(lengths.numel()),  # cumsum_lens has num_seqs entries
num_warps=num_warps,
)
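For reference, a minimal usage sketch of fused_rotary_embedding (not part of the diff). It assumes a CUDA device; the head sizes, sequence lengths, and the 10000 RoPE base below are illustrative, and the half-width cos/sin cache matches the tl.arange(0, HEAD_DIM // 2) indexing in the kernel:

import torch
from colossalai.kernel.triton import fused_rotary_embedding

head_dim, num_heads = 64, 8
lengths = torch.tensor([3, 5], device="cuda")  # two sequences, 8 tokens in total
total_tokens = int(lengths.sum())
q = torch.randn(total_tokens, num_heads, head_dim, device="cuda")
k = torch.randn(total_tokens, num_heads, head_dim, device="cuda")

# position-indexed cos/sin cache; the kernel reads head_dim // 2 values per position
max_position_len = 2048
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
freqs = torch.outer(torch.arange(max_position_len, device="cuda").float(), inv_freq)
cos, sin = freqs.cos(), freqs.sin()  # [max_position_len, head_dim // 2]

fused_rotary_embedding(q, k, cos, sin, lengths)  # rotates q and k in place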
7 changes: 4 additions & 3 deletions colossalai/kernel/triton/no_pad_rotary_embedding.py
@@ -98,11 +98,12 @@ def rotary_embedding(
Args:
q: query tensor, [total_tokens, head_num, head_dim]
k: key tensor, [total_tokens, head_num, head_dim]
cos: cosine for rotary embedding, [total_tokens, head_dim]
sin: sine for rotary embedding, [total_tokens, head_dim]
cos: cosine for rotary embedding, [max_position_len, head_dim]
sin: sine for rotary embedding, [max_position_len, head_dim]
lengths: length of each sequence, [num_seqs]
"""
q_total_tokens, q_head_num, head_dim = q.shape
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
BLOCK_TOKENS = 8
grid = (triton.cdiv(q_head_num, BLOCK_HEAD), 2 * triton.cdiv(q_total_tokens, BLOCK_TOKENS))
110 changes: 110 additions & 0 deletions colossalai/kernel/triton/rotary_cache_copy.py
@@ -0,0 +1,110 @@
import torch
import triton
import triton.language as tl


@triton.jit
def prefill_cache_kernel(
cache,
cumsum_lengths,
output,
cache_stride,
hidden_stride,
total_length,
HIDDEN_DIM: tl.constexpr,
N_ELEMENTS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
idx0 = tl.program_id(axis=0)
idx1 = tl.program_id(axis=1)
idx = idx0 * BLOCK_SIZE + idx1

# recover the token's position within its original sequence from the cumulative sequence lengths
cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))
ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))
_cache = tl.load(cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride)
tl.store(output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, _cache, mask=idx < total_length)


@triton.jit
def decoding_cache_kernel(
cache,
lengths,
output,
cache_stride,
hidden_stride,
HIDDEN_DIM: tl.constexpr,
NUM_SEQS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
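# each program handles BLOCK_SIZE sequences, gathering for each the cache row at the index stored
# in lengths (the wrapper passes length - 1, i.e. the last token's position)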
idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=0)  # [BLOCK_SIZE,]; other=0 keeps masked lanes in-bounds
_cache = tl.load(cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride)
tl.store(
output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),
_cache,
mask=idx[:, None] < NUM_SEQS,
)


@torch.no_grad()
def get_xine_cache(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False):
"""
Transform cos/sin cache into no pad sequence, with two different modes.
Args:
lengths: shape(num_seqs,), stores lenghth of each sequence.
cache: shape(max_rotary_position(e.g.2048), head_dim), cos/sin cache constrcuted in model.
is_prompts: bool, mark if in prefill mode.
For prefill mode:
cos/sin cache for each sequence is equal to its length.
For decoding mode:
cos/sin cache is only needed for the last token.
"""

_, hidden_dim = cache.shape
num_seqs = lengths.numel()

BLOCK_SIZE = 16
if hidden_dim >= 128:
num_warps = 8
else:
num_warps = 4

cache_stride = cache.stride(0)
hidden_stride = cache.stride(1)

if is_prompts:
total_length = lengths.sum().item()
cumsum_lens = torch.cumsum(lengths, dim=0)
output = torch.empty((total_length, hidden_dim), dtype=cache.dtype, device=cache.device)
grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE)
prefill_cache_kernel[grid](
cache,
cumsum_lens,
output,
cache_stride,
hidden_stride,
total_length,
HIDDEN_DIM=hidden_dim,
N_ELEMENTS=triton.next_power_of_2(num_seqs),
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
else:
# BUG: a memory access error occurs when a deepcopy of lengths is used in place of lengths
nlengths = torch.as_tensor(lengths) - 1
output = torch.empty((num_seqs, hidden_dim), dtype=cache.dtype, device=cache.device)
grid = (triton.cdiv(num_seqs, BLOCK_SIZE),)
decoding_cache_kernel[grid](
cache,
nlengths,
output,
cache_stride,
hidden_stride,
HIDDEN_DIM=hidden_dim,
NUM_SEQS=num_seqs,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)

return output
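For clarity, here is a pure-PyTorch reference for what get_xine_cache is expected to return in the two modes (an illustrative sketch, not part of the diff; the cache shape, values, and CUDA device below are assumptions):

import torch
from colossalai.kernel.triton import get_xine_cache

def xine_cache_reference(lengths: torch.Tensor, cache: torch.Tensor, is_prompts: bool = False):
    if is_prompts:
        # prefill: rows 0 .. length - 1 of the cache for every sequence, concatenated without padding
        return torch.cat([cache[: int(l)] for l in lengths], dim=0)
    # decoding: one row per sequence, at the last token's position (length - 1)
    return cache[(lengths - 1).long()]

lengths = torch.tensor([3, 5], device="cuda")
cache = torch.randn(2048, 64, device="cuda")  # e.g. a cos cache of shape [max_rotary_position, head_dim]
out_prefill = get_xine_cache(lengths, cache, is_prompts=True)  # expected shape (8, 64)
out_decode = get_xine_cache(lengths, cache)                    # expected shape (2, 64)
# these are expected to match the reference gathers above:
# torch.equal(out_prefill, xine_cache_reference(lengths, cache, is_prompts=True))
# torch.equal(out_decode, xine_cache_reference(lengths, cache))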