Skip to content

Commit

Permalink
feat: support Baichuan2 RMSNorm with a hidden size of 5120 (#5611)
Browse files: browse the repository at this point in the history
  • Loading branch information
SunflowerAries authored Apr 19, 2024
1 parent e37ee2f commit ccf7279
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/inference/benchmark_ops/benchmark_rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
ylabel="ms",
plot_name=f"RMSNorm benchmarking results",
args={"HIDDEN_SIZE": 1024},
args={"HIDDEN_SIZE": 5120},
)
]

Expand Down
6 changes: 6 additions & 0 deletions extensions/csrc/cuda/rms_layernorm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ void rms_layernorm(
case 2:
RMSNORM_LAUNCHER(2, block);
break;
case 3:
RMSNORM_LAUNCHER(3, block);
break;
case 4:
RMSNORM_LAUNCHER(4, block);
break;
Expand Down Expand Up @@ -321,6 +324,9 @@ void fused_add_rms_layernorm(
case 2:
FUSED_ADD_RMSNORM_LAUNCHER(2, block);
break;
case 3:
FUSED_ADD_RMSNORM_LAUNCHER(3, block);
break;
case 4:
FUSED_ADD_RMSNORM_LAUNCHER(4, block);
break;
Expand Down
4 changes: 2 additions & 2 deletions tests/test_infer/test_ops/cuda/test_rms_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


@pytest.mark.parametrize("M", [2, 4, 8, 16])
@pytest.mark.parametrize("N", [64, 128, 512])
@pytest.mark.parametrize("N", [64, 128, 512, 5120])
def test_rms_layernorm(M: int, N: int):
torch.manual_seed(123)
torch.cuda.empty_cache()
Expand Down Expand Up @@ -48,4 +48,4 @@ def test_rms_layernorm(M: int, N: int):


if __name__ == "__main__":
test_rms_layernorm(16, 512)
test_rms_layernorm(16, 5120)

0 comments on commit ccf7279

Please sign in to comment.