[Optimize, NPU] Remove tl.where from _rms_norm_forward/backward_kernel_tiled()#1153

Merged
Tcc0403 merged 1 commit into linkedin:main from pt-ecosystem:main
Mar 20, 2026
Conversation

@zheliuyu (Contributor) commented Mar 19, 2026

Summary

When the mask has a large shape, tl.where is not NPU-friendly in triton-ascend and degrades kernel performance. When writing kernels, it is best to express the same logic without tl.where; doing so can yield significant performance improvements.

Will these changes affect accuracy? No: the mask has already been applied when loading X_block, so the masked-out lanes contribute nothing to tl.sum, and removing the redundant tl.where does not change the result.
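The reasoning above can be sketched outside Triton. Assuming the kernel loads X_block with something like tl.load(..., mask=mask, other=0.0), out-of-bounds lanes are already zero at load time, so a subsequent tl.where(mask, X_block, 0.0) before tl.sum is a no-op. A minimal NumPy analogue (function and variable names are hypothetical, for illustration only):

```python
import numpy as np

def load_block(x, offsets, block_size):
    """Mimic tl.load(x_ptr + offsets, mask=offsets < n, other=0.0):
    out-of-bounds lanes are filled with 0.0 at load time."""
    n = x.shape[0]
    mask = offsets < n
    block = np.zeros(block_size, dtype=x.dtype)
    block[mask] = x[offsets[mask]]
    return block, mask

x = np.arange(1.0, 6.0)   # 5 valid elements
offsets = np.arange(8)    # block size 8 > n, so 3 lanes are masked
block, mask = load_block(x, offsets, 8)

# Old pattern: re-mask before the reduction (the tl.where analogue).
old = np.sum(np.where(mask, block, 0.0))
# New pattern: the load already zeroed masked lanes, so sum directly.
new = np.sum(block)

assert old == new == 15.0
```

Since both reductions see identical data, the kernels can skip the extra elementwise select entirely, which is what this PR does.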

Testing Done

Accuracy first

The shapes in test_rms_norm.py are too small to trigger the _rms_norm_forward_kernel_tiled kernel, so we need a new configuration.

@pytest.mark.parametrize(
    "bs, sl, hd",
    [
        (2, 2048, 4096),
        (2, 2048, 8192),
        (2, 2048, 16384),
        (2, 2048, 32768),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-4, 1e-6),
    ],
)
@pytest.mark.parametrize(
    "reference, offset, casting_mode",
    [
        (LlamaRMSNorm, 0.0, "llama"),
        (GemmaRMSNorm, 1.0, "gemma"),
        pytest.param(
            BaseRMSNorm,
            0.0,
            "none",
            marks=pytest.mark.skipif(device == "npu", reason="Ascend NPU does not support this test"),
        ),
    ],
)
@pytest.mark.parametrize(
    "in_place",
    [
        True,
        False,
    ],
)
@pytest.mark.parametrize(
    "elementwise_affine",
    [
        True,
        False,
    ],
)

Env

[screenshot: environment details]

Results after code modification

[screenshot: accuracy test results after the change]

Benchmark test

The test cases in benchmark_rms_norm.py should keep the same shapes as those in test_rms_norm.py.

common_configs = {
    ...
    "x_values": [2**i for i in range(12, 16)],
    ...
}
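As a sanity check on the shapes, the x_values expression above expands to the same hidden sizes used in the accuracy parametrization:

```python
# x_values from the benchmark config: powers of two from 2**12 to 2**15.
x_values = [2**i for i in range(12, 16)]

# These match the hd values (4096..32768) in the test_rms_norm.py
# parametrization above, so benchmark and accuracy tests cover the
# same tiled-kernel shapes.
assert x_values == [4096, 8192, 16384, 32768]
```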

Before Optimization

forward
rms_norm_speed_forward

backward
rms_norm_speed_backward

full
rms_norm_speed_full

memory
rms_norm_memory_full

all_benchmark_data_raw.csv

After Optimization

forward
rms_norm_speed_forward

backward
rms_norm_speed_backward

full
rms_norm_speed_full

memory
rms_norm_memory_full

all_benchmark_data_optimized.csv

  • Hardware Type: Atlas 900 A2 PoD
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@zheliuyu zheliuyu changed the title [npu] Remove tl.where from _rms_norm_backward_kernel_tiled [Optimize, NPU] Remove tl.where from _rms_norm_forward/backward_kernel_tiled() Mar 20, 2026
@zheliuyu zheliuyu marked this pull request as ready for review March 20, 2026 02:54
@zheliuyu (Contributor, Author)

@Tcc0403 @TianHao324 PR is ready for review. Thanks.

@Tcc0403 (Collaborator) left a comment


Yes, one should avoid tl.where whenever possible. Thanks for pointing it out.

@Tcc0403 Tcc0403 added this pull request to the merge queue Mar 20, 2026
Merged via the queue into linkedin:main with commit 781083b Mar 20, 2026
5 of 7 checks passed
