[inference] add int8 rotary embedding kernel for smoothquant #4843

Merged · 62 commits · Sep 29, 2023

Changes from all commits
c7d6975
[shardformer] fix GPT2DoubleHeadsModel (#4703)
flybird11111 Sep 13, 2023
e2c0e7f
[hotfix] Fix import error: colossal.kernel without triton installed (…
yuanheng-zhao Sep 14, 2023
20190b4
[shardformer] to fix whisper test failed due to significant accuracy …
flybird11111 Sep 14, 2023
ce97790
[doc] fix llama2 code link (#4726)
binmakeswell Sep 14, 2023
f911d5b
[doc] Add user document for Shardformer (#4702)
Fridge003 Sep 15, 2023
8c2dda7
[format] applied code formatting on changed files in pull request 472…
github-actions[bot] Sep 15, 2023
50e5602
[doc] add shardformer support matrix/update tensor parallel documents…
Fridge003 Sep 15, 2023
e4fc57c
Optimized some syntax errors in the documentation and code under appl…
digger-yu Sep 15, 2023
4616263
[shardformer] update pipeline parallel document (#4725)
flybird11111 Sep 15, 2023
cd4e61d
[legacy] remove deterministic data loader test
ppt0011 Sep 15, 2023
6a03c93
[shardformer] update seq parallel document (#4730)
FoolPlayer Sep 15, 2023
608cffa
[example] add gpt2 HybridParallelPlugin example (#4653)
FoolPlayer Sep 15, 2023
73eb3e8
Merge pull request #4738 from ppt0011/main
ppt0011 Sep 15, 2023
451c346
[doc] polish shardformer doc (#4735)
Fridge003 Sep 15, 2023
ac27979
[shardformer] add custom policy in hybrid parallel plugin (#4718)
oahzxl Sep 15, 2023
4c4482f
[example] llama2 add fine-tune example (#4673)
flybird11111 Sep 15, 2023
d151dca
[doc] explaination of loading large pretrained models (#4741)
Fridge003 Sep 15, 2023
32e7f99
[kernel] update triton init #4740 (#4740)
oahzxl Sep 18, 2023
b5f9e37
[legacy] clean up legacy code (#4743)
ver217 Sep 18, 2023
3c6b831
[format] applied code formatting on changed files in pull request 474…
github-actions[bot] Sep 18, 2023
079bf3c
[misc] update pre-commit and run all files (#4752)
ver217 Sep 19, 2023
10513f2
[doc] explain suitable use case for each plugin
ppt0011 Sep 19, 2023
a04337b
[doc] put individual plugin explanation in front
ppt0011 Sep 19, 2023
e10d9f0
[doc] add model examples for each plugin
ppt0011 Sep 19, 2023
4d7537b
[doc] put native colossalai plugins first in description section
ppt0011 Sep 20, 2023
07c2e3d
Merge pull request #4757 from ppt0011/main
ppt0011 Sep 20, 2023
7b9b864
[chat]: update rm, add wandb and fix bugs (#4471)
CWHer Sep 20, 2023
c0a0337
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapp…
Fridge003 Sep 20, 2023
df66741
[bug] fix get_default_parser in examples (#4764)
Fridge003 Sep 21, 2023
66f3926
[doc] clean up outdated docs (#4765)
ver217 Sep 21, 2023
493a5ef
[doc] add shardformer doc to sidebar (#4768)
Fridge003 Sep 21, 2023
901ab1e
[chat]: add lora merge weights config (#4766)
CWHer Sep 21, 2023
3e05c07
[lazy] support torch 2.0 (#4763)
ver217 Sep 21, 2023
1e0e080
[bug] Fix the version check bug in colossalai run when generating the…
littsk Sep 22, 2023
946ab56
[feature] add gptq for inference (#4754)
Xu-Kai Sep 22, 2023
ce7ade3
[inference] chatglm2 infer demo (#4724)
CjhHa1 Sep 22, 2023
4146f1c
[release] update version (#4775)
ver217 Sep 22, 2023
74aa7d9
initial commit: add colossal llama 2 (#4784)
TongLi3701 Sep 24, 2023
ce77785
[feature] ColossalEval: Evaluation Pipeline for LLMs (#4786)
chengeharrison Sep 24, 2023
d512a4d
[doc] add llama2 domain-specific solution news (#4789)
binmakeswell Sep 25, 2023
26cd6d8
[fix] fix weekly runing example (#4787)
flybird11111 Sep 25, 2023
a2db755
[doc] polish shardformer doc (#4779)
Fridge003 Sep 26, 2023
64a08b2
[checkpointio] support unsharded checkpointIO for hybrid parallel (#4…
Fridge003 Sep 26, 2023
bd01467
update readme
TongLi3701 Sep 26, 2023
4965c0d
[lazy] support from_pretrained (#4801)
ver217 Sep 26, 2023
8cbce61
update
TongLi3701 Sep 26, 2023
62b6af1
Merge pull request #4805 from TongLi3701/docs/fix
Desperado-Jia Sep 26, 2023
b6cf0ac
[hotfix] change llama2 Colossal-LLaMA-2 script filename (#4800)
Chandler-Bing Sep 26, 2023
a227063
[misc] add last_epoch in CosineAnnealingWarmupLR (#4778)
hova88 Sep 26, 2023
da15fdb
[doc] add lazy init docs (#4808)
ver217 Sep 27, 2023
54b3ad8
[hotfix] fix norm type error in zero optimizer (#4795)
littsk Sep 27, 2023
11f1e42
[hotfix] Correct several erroneous code comments (#4794)
littsk Sep 27, 2023
fb46d05
[format] applied code formatting on changed files in pull request 459…
github-actions[bot] Sep 27, 2023
bbbcac2
fix format (#4815)
TongLi3701 Sep 27, 2023
be400a0
[chat] fix gemini strategy (#4698)
flybird11111 Sep 27, 2023
1fa8c5e
Update Qwen-7B results (#4821)
chengeharrison Sep 27, 2023
822051d
[doc] update slack link (#4823)
binmakeswell Sep 27, 2023
c3bef20
add autotune (#4822)
Xu-Kai Sep 28, 2023
ed06731
update Colossal (#4832)
TongLi3701 Sep 28, 2023
83f85c8
add int8 rotary embedding kernel
Xu-Kai Sep 29, 2023
b4b59d4
remove useless code
Xu-Kai Sep 29, 2023
7d20460
Merge branch 'feature/smoothquant' into feature/smoothquant
Xu-Kai Sep 29, 2023
2 changes: 2 additions & 0 deletions colossalai/kernel/triton/__init__.py
@@ -7,6 +7,7 @@
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .gptq_triton import gptq_fused_linear_triton
from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd
from .softmax import softmax
@@ -22,6 +23,7 @@
"rotary_embedding_fwd",
"token_attention_fwd",
"gptq_fused_linear_triton",
"int8_rotary_embedding_fwd",
]

except ImportError:
119 changes: 119 additions & 0 deletions colossalai/kernel/triton/int8_rotary_embedding_kernel.py
@@ -0,0 +1,119 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl


@triton.jit
def _rotary_kernel(
q,
input_scale,
output_scale,
Cos,
Sin,
q_bs_stride,
q_h_stride,
q_d_stride,
cos_bs_stride,
cos_d_stride,
total_len,
HEAD_NUM: tl.constexpr,
BLOCK_HEAD: tl.constexpr,
BLOCK_SEQ: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
current_head_index = tl.program_id(0)
current_seq_index = tl.program_id(1)

dim_range0 = tl.arange(0, HEAD_DIM // 2)
dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)

off_q0 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range0[None, None, :] * q_d_stride
)
off_q1 = (
current_seq_range[:, None, None] * q_bs_stride
+ current_head_range[None, :, None] * q_h_stride
+ dim_range1[None, None, :] * q_d_stride
)

off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride

q0 = tl.load(
q + off_q0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)
q1 = tl.load(
q + off_q1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
other=0.0,
)

cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
in_scale = tl.load(input_scale)
o_scale = tl.load(output_scale)

q0 = q0.to(tl.float32) * in_scale
q1 = q1.to(tl.float32) * in_scale

out0 = (q0 * cos - q1 * sin) / o_scale
out1 = (q0 * sin + q1 * cos) / o_scale

# out0 = out0.to(tl.int8)
# out1 = out1.to(tl.int8)

tl.store(
q + off_q0,
out0,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)
tl.store(
q + off_q1,
out1,
mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
)

return


@torch.no_grad()
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
total_len = q.shape[0]
head_num = q.shape[1]
head_dim = q.shape[2]
assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
BLOCK_HEAD = 4
BLOCK_SEQ = 32
grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
if head_dim >= 128:
num_warps = 8
else:
num_warps = 4

_rotary_kernel[grid](
q,
input_scale,
output_scale,
cos,
sin,
q.stride(0),
q.stride(1),
q.stride(2),
cos.stride(0),
cos.stride(1),
total_len,
HEAD_NUM=head_num,
BLOCK_HEAD=BLOCK_HEAD,
BLOCK_SEQ=BLOCK_SEQ,
HEAD_DIM=head_dim,
num_warps=num_warps,
num_stages=1,
)
return
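
For readers of the diff, the kernel above is numerically equivalent to the following PyTorch reference (a sketch added for this review, not part of the PR): dequantize the int8 query with input_scale, rotate the two halves of the head dimension with cos/sin, then requantize with output_scale and store back as int8. The helper name torch_int8_rotary_reference and the explicit round/clamp at the end are assumptions; the Triton kernel itself relies on the implicit narrowing cast that happens when the float results are stored back into the int8 tensor q.

# Reference sketch only; mirrors the dequantize -> rotate -> requantize steps of the
# Triton kernel above. `torch_int8_rotary_reference` is a hypothetical helper name.
import torch


def torch_int8_rotary_reference(q_int8, cos, sin, input_scale, output_scale):
    # q_int8: (total_len, head_num, head_dim) int8 tensor
    # cos, sin: (total_len, head_dim // 2) float tensors
    seq_len, _, dim = q_int8.shape
    q = q_int8.to(torch.float32) * input_scale               # dequantize
    q0, q1 = q[..., : dim // 2], q[..., dim // 2 :]          # split the head dimension
    cos = cos.view(seq_len, 1, dim // 2)
    sin = sin.view(seq_len, 1, dim // 2)
    out0 = (q0 * cos - q1 * sin) / output_scale              # rotate, then requantize
    out1 = (q0 * sin + q1 * cos) / output_scale
    out = torch.cat((out0, out1), dim=-1)
    return out.round().clamp(-128, 127).to(torch.int8)       # assumed int8 narrowing
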
59 changes: 59 additions & 0 deletions tests/test_smoothquant/test_rotary_embedding.py
@@ -0,0 +1,59 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm


import pytest
import torch
from packaging import version

try:
from colossalai.kernel.triton import int8_rotary_embedding_fwd

HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_rotary_emb(x, cos, sin):
seq_len, h, dim = x.shape
x0 = x[:, :, 0 : dim // 2]
x1 = x[:, :, dim // 2 : dim]
cos = cos.view((seq_len, 1, dim // 2))
sin = sin.view((seq_len, 1, dim // 2))
o0 = x0 * cos - x1 * sin
o1 = x0 * sin + x1 * cos
return torch.cat((o0, o1), dim=-1)


@pytest.mark.skipif(
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
SEQ_LEN = 1
HEAD_NUM = 32
HEAD_DIM = 128
dtype = torch.float
# create data
x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
cos_shape = (SEQ_LEN, HEAD_DIM // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
# forward pass
y_torch = torch_rotary_emb(x, cos, sin)

input_scale = torch.max(torch.abs(x)) / 127
output_scale = torch.max(torch.abs(y_torch)) / 127

x = x / input_scale
x = x.to(torch.int8)

int8_rotary_embedding_fwd(x, cos, sin, input_scale, output_scale)
y_triton = x.to(torch.float) * output_scale
assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)


if __name__ == "__main__":
test_rotary_emb()
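
Beyond the unit test, a likely integration point is the attention block of a smoothquant int8 model, where rotary embedding is applied to both the query and key projections. The sketch below is hypothetical: the helper name apply_int8_rope and the idea of carrying separate calibration scales for q and k are assumptions, not APIs introduced by this PR; only int8_rotary_embedding_fwd comes from the new kernel.

# Hypothetical usage sketch: rotate int8 query and key tensors in place with the new kernel.
# The scale tensors are assumed to come from smoothquant-style max-abs calibration (max(|x|) / 127).
from colossalai.kernel.triton import int8_rotary_embedding_fwd


def apply_int8_rope(q_int8, k_int8, cos, sin, q_in_scale, q_out_scale, k_in_scale, k_out_scale):
    # Each call dequantizes with its input scale, applies the rotation, and writes the
    # requantized int8 result back into the same tensor.
    int8_rotary_embedding_fwd(q_int8, cos, sin, q_in_scale, q_out_scale)
    int8_rotary_embedding_fwd(k_int8, cos, sin, k_in_scale, k_out_scale)
    return q_int8, k_int8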