[Inference] Adapt to Fused rotary (#5348)
* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix
CjhHa1 committed Feb 7, 2024
1 parent 35382a7 commit 9f4ab2e
Showing 5 changed files with 161 additions and 22 deletions.
5 changes: 2 additions & 3 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -282,11 +282,10 @@ def forward(
torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
)

rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

block_size = k_cache.size(-2)

if is_prompts:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
attn_output = context_attention_unpadded(
q=query_states,
k=key_states,
@@ -301,7 +300,7 @@ def forward(
sm_scale=sm_scale,
)
else:
copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], k_cache, block_tables, sequence_lengths)
copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
attn_output = flash_decoding_attention(
q=query_states,
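The net effect of this hunk: rotary embedding moves out of the shared path and into each branch. Prefill keeps the plain in-place rotation, while the decode branch switches to the fused call, which rotates q/k and also scatters the rotated key into the blocked k_cache, replacing the separate key-cache copy. A minimal sketch of the resulting control flow, assuming the tensor names used in this model file (the helper apply_rotary_and_cache is hypothetical, for illustration only):

from colossalai.kernel.triton import copy_kv_to_blocked_cache, rotary_embedding

def apply_rotary_and_cache(
    is_prompts, query_states, key_states, value_states,
    cos_sin, k_cache, v_cache, block_tables, sequence_lengths,
):
    if is_prompts:
        # Prefill: rotate q/k in place; the context attention path
        # handles writing keys/values into the cache.
        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
    else:
        # Decode: the fused kernel rotates q/k AND writes the rotated k
        # into the blocked k_cache, so only the value copy stays explicit.
        rotary_embedding(
            query_states, key_states, cos_sin[0], cos_sin[1],
            k_cache, block_tables, sequence_lengths,
        )
        copy_kv_to_blocked_cache(
            value_states, v_cache,
            kv_lengths=sequence_lengths, block_tables=block_tables,
        )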
1 change: 0 additions & 1 deletion colossalai/kernel/triton/kvcache_copy.py
@@ -75,7 +75,6 @@ def copy_kv_to_blocked_cache(
block_size = k_cache.size(-2)

num_warps = 8 if head_dim > 128 else 4

grid = (bsz, num_kv_heads)
_copy_to_kvcache_seqlen1_kernel[grid](
k,
136 changes: 128 additions & 8 deletions colossalai/kernel/triton/no_pad_rotary_embedding.py
@@ -222,11 +222,11 @@ def fused_rotary_embedding_kernel(
out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :] # total_tokens, head_num, head_dim

past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1
past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1

last_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride)
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens))
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride

kv_range0 = (
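The masked loads in this hunk are the fix: the v1 kernel processes tokens in tiles of BLOCK_TOKENS, so the final tile can include lanes beyond q_total_tokens, and the previously unmasked loads of context_lengths and the block table read out of bounds. A condensed restatement of the guard, using the kernel's own names:

# Guard every per-token load in the last (partial) token tile;
# lanes past q_total_tokens become inert instead of reading garbage.
token_mask = tokens_range < q_total_tokens
past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=token_mask) - 1
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=token_mask)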
@@ -274,6 +274,122 @@ def fused_rotary_embedding_kernel(
)


@triton.jit
def fused_rotary_embedding_kernel_v2(
q,
k,
cos,
sin,
kv_cache,
BLOCK_TABLES,
context_lengths,
q_token_stride,
q_head_stride,
k_token_stride,
k_head_stride,
head_dim_stride,
cos_token_stride,
cos_stride,
cacheb_stride,
cacheh_stride,
cachebs_stride,
cached_stride,
bts_stride,
btb_stride,
block_size,
q_total_tokens,
Q_HEAD_NUM: tl.constexpr,
K_HEAD_NUM: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
block_head_index = tl.program_id(0)
if block_head_index >= Q_HEAD_NUM:
return
block_token_index = tl.program_id(1)

dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride
off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride
off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride
off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride

loaded_q0 = tl.load(
q + off_q0,
)
loaded_q1 = tl.load(
q + off_q1,
)

loaded_k0 = tl.load(
k + off_k0,
)

loaded_k1 = tl.load(
k + off_k1,
)

off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride

loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)

out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos

out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim

past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1

last_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride
block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens))
offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride

kv_range0 = (
block_ids * cacheb_stride
+ block_head_index * cacheh_stride
+ offsets_in_last_block
+ dim_range0 * cached_stride
)
kv_range1 = (
block_ids * cacheb_stride
+ block_head_index * cacheh_stride
+ offsets_in_last_block
+ dim_range1 * cached_stride
)

tl.store(
kv_cache + kv_range0,
out_k0,
)
tl.store(
kv_cache + kv_range1,
out_k1,
)

# concat
tl.store(
q + off_q0,
out_q0,
)
tl.store(
q + off_q1,
out_q1,
)
tl.store(
k + off_k0,
out_k0,
)
tl.store(
k + off_k1,
out_k1,
)
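Compared with v1, this v2 kernel maps one program to one (head, token) pair, so the per-token tiling masks disappear; the only bounds check left is the early return when block_head_index lands in the power-of-two padding of the head axis. A sketch of the assumed launch layout, matching the wrapper below:

# Head axis padded to the next power of two; padded programs exit
# immediately at the guard near the top of the kernel.
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
# fused_rotary_embedding_kernel_v2[grid](q, k, cos, sin, kv_cache, ...)
# (full argument list appears in rotary_embedding below)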


@torch.no_grad()
def rotary_embedding(
q: torch.Tensor,
k: torch.Tensor,
@@ -297,12 +413,13 @@ def rotary_embedding(
assert q.size(0) == k.size(0)
BLOCK_HEAD = 4
BLOCK_TOKENS = 4
grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))

if head_dim >= 256:
if head_dim >= 1024:
num_warps = 32
elif head_dim >= 128:
elif head_dim >= 512:
num_warps = 16
elif head_dim >= 256:
num_warps = 8
else:
num_warps = 4

@@ -318,6 +435,10 @@ def rotary_embedding(
cos_token_stride = cos.stride(0)
cos_stride = cos.stride(1)
if k_cache == None:
grid = lambda META: (
triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
)
rotary_embedding_kernel[grid](
q,
k,
@@ -339,7 +460,8 @@ def rotary_embedding(
num_warps=num_warps,
)
else:
fused_rotary_embedding_kernel[grid](
grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
fused_rotary_embedding_kernel_v2[grid](
q,
k,
cos,
@@ -365,8 +487,6 @@ def rotary_embedding(
Q_HEAD_NUM=q_head_num,
K_HEAD_NUM=k_head_num,
HEAD_DIM=head_dim,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_TOKENS=BLOCK_TOKENS,
num_warps=num_warps,
)
return
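In sum, rotary_embedding now dispatches on k_cache: without a cache it launches the tiled v1 kernel; with a cache (plus block tables and sequence lengths) it launches the fused v2 kernel, and the num_warps heuristic shifts to larger head_dim thresholds. A hedged usage sketch of the plain path, assuming a CUDA device and illustrative shapes:

import torch
from colossalai.kernel.triton import rotary_embedding

T, H, D = 16, 32, 128  # tokens, heads, head dim (illustrative values)
q = torch.randn(T, H, D, dtype=torch.float16, device="cuda")
k = torch.randn(T, H, D, dtype=torch.float16, device="cuda")
cos = torch.randn(T, D // 2, dtype=torch.float16, device="cuda")
sin = torch.randn(T, D // 2, dtype=torch.float16, device="cuda")

rotary_embedding(q, k, cos, sin)  # rotates q and k in place (v1 kernel)
# rotary_embedding(q, k, cos, sin, k_cache, block_tables, kv_lengths)
# would additionally write the rotated k into the blocked cache (v2 kernel).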
1 change: 1 addition & 0 deletions examples/inference/run_benchmark.sh
@@ -1,4 +1,5 @@
ROOT=$(realpath $(dirname $0))
echo $ROOT
PY_SCRIPT=${ROOT}/benchmark_llama.py
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
mode=$1
40 changes: 30 additions & 10 deletions tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
@@ -3,7 +3,7 @@
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

from colossalai.kernel.triton import rotary_embedding
from colossalai.kernel.triton import copy_kv_to_blocked_cache, rotary_embedding
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2

try:
@@ -94,8 +94,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
x_names=["num_tokens"],
x_vals=[2**i for i in range(4, 11)],
line_arg="provider",
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
styles=[("red", "-"), ("blue", "-")],
ylabel="ms",
plot_name=f"rotary_emb-batch-{BATCH}",
@@ -110,23 +110,43 @@ def benchmark_rotary_emb(
num_tokens: int,
num_kv_heads: int,
):
BATCH_SIZE = 4
SEQ_LEN = num_tokens // BATCH_SIZE
max_num_blocks_per_seq = 8
block_size = 64
warmup = 10
rep = 100

head_dim = 128
head_dim = 256
dtype = torch.float16

q_shape = (num_tokens, num_kv_heads, head_dim)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (num_tokens, num_kv_heads, head_dim)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
cos_shape = (num_tokens, head_dim // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
v = torch.randn_like(k)
v_cache = torch.zeros_like(k_cache)
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v2(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
new_q = torch.randn_like(new_k)
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")

if provider == "torch_rotary_emb_func":
fn = lambda: torch_rotary_emb(q, cos, sin)
elif provider == "triton_rotary_emb_func":
fn = lambda: rotary_embedding(q, k, cos, sin)
if provider == "no_fused_rotary_emb_func":
fn = lambda: [
rotary_embedding(new_q, new_k, cos, sin),
copy_kv_to_blocked_cache(new_k, k_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables),
]
elif provider == "fused_triton_rotary_emb_func":
fn = lambda: rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
else:
raise ValueError("Undefined provider")

Expand All @@ -135,5 +155,5 @@ def benchmark_rotary_emb(


if __name__ == "__main__":
test_rotary_emb(4, 64, 32, 64, torch.float32)
# benchmark_rotary_emb.run(save_path=".",print_data=True)
# test_rotary_emb(4, 64, 32, 64, torch.float32)
benchmark_rotary_emb.run(save_path=".", print_data=True)
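As changed above, executing the file directly now runs the benchmark rather than the correctness test. A hedged reminder of what the two providers compare and how the run is assumed to be invoked (CUDA GPU with triton installed):

# "no_fused_rotary_emb_func":     rotary_embedding + copy_kv_to_blocked_cache
# "fused_triton_rotary_emb_func": a single fused rotary_embedding call
# Run with:
#   python tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py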
