[NPU] Add group norm support on NPU #1144
Conversation
Hi @Tcc0403, please take a look. Thanks!
    else:
        dW_block = tl.where(mask, DY_block * x_hat, 0.0)
        dB_block = tl.where(mask, DY_block, 0.0)
        tl.atomic_add(DW_scratch_base + global_channel, dW_block, mask=mask)
        tl.atomic_add(DB_scratch_base + global_channel, dB_block, mask=mask)
L377 says `# Placeholder buffers (unused in kernel when COMPUTE_PARAM_GRAD=False)`, which contradicts what this block does. Shouldn't it be a no-op here?
Removed.
I originally kept it to preserve the kernel structure for potential future experiments, but it’s not needed in the current implementation.
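For context, the quantities that block accumulated reduce to simple per-channel sums of the upstream gradient and the normalized input. A minimal plain-Python sketch of the reference math (the nested-list layout and the function name are illustrative assumptions, not the kernel's actual data layout):

```python
# Reference computation of the GroupNorm parameter gradients that the
# atomic_add block accumulates: for each channel c,
#   dW[c] = sum over batch and spatial positions of dy * x_hat
#   dB[c] = sum over batch and spatial positions of dy
# Variable names mirror the kernel (DY_block, x_hat), but this is a
# plain-Python sketch, not the Triton implementation.

def groupnorm_param_grads(dy, x_hat):
    # dy, x_hat: nested lists shaped (batch, channels, spatial)
    batch = len(dy)
    channels = len(dy[0])
    spatial = len(dy[0][0])
    dW = [0.0] * channels
    dB = [0.0] * channels
    for n in range(batch):
        for c in range(channels):
            for i in range(spatial):
                dW[c] += dy[n][c][i] * x_hat[n][c][i]
                dB[c] += dy[n][c][i]
    return dW, dB
```

The per-channel reduction is why the kernel needs either atomics or a scratch buffer: multiple programs touch the same channel.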
    # Placeholder buffers (unused in kernel when COMPUTE_PARAM_GRAD=False)
    DW_scratch = torch.empty((1, 1), dtype=torch.float32, device=W.device)
    DB_scratch = torch.empty((1, 1), dtype=torch.float32, device=W.device)
Can the placeholder buffers be set to None in triton-ascend, to avoid accidental access in device code?
Hi @Tcc0403, changes applied. Appreciate another review when you have time, thanks!
Tcc0403
left a comment
LGTM! I left a comment about a potential improvement, but it can be done in another PR!
    if COMPUTE_PARAM_GRAD:
        if SINGLE_CHANNEL_TILE:
            dW_partial = tl.sum(tl.where(mask, DY_block * x_hat, 0.0), axis=1)
            dB_partial = tl.sum(tl.where(mask, DY_block, 0.0), axis=1)
            tl.atomic_add(DW_scratch_base + global_channel, dW_partial, mask=row_mask)
            tl.atomic_add(DB_scratch_base + global_channel, dB_partial, mask=row_mask)
I wonder if we can accumulate dw and db over grid loop and store it after, similar to
Liger-Kernel/src/liger_kernel/ops/rms_norm.py
Line 390 in db14ea8
With this approach, we can avoid using atomic_add and potentially handle the scenario where num_col_blocks > 1.
The solution is not trivial and not guaranteed to achieve better performance; leaving the comment here as a future work direction.
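The accumulate-then-store idea can be modeled outside Triton. A hedged plain-Python sketch of both strategies (the function names and the grid-stride assignment of row blocks to programs are illustrative assumptions, not the rms_norm kernel's actual scheme):

```python
# Two ways to reduce per-row-block partial sums of dW into one value.
# Baseline: one atomic add per row block. Alternative: each program
# accumulates its row blocks in a local register across its grid-loop
# iterations and contributes a single add at the end, cutting the
# number of atomic operations from len(partials) to num_programs.

def atomic_per_iteration(partials):
    # Models tl.atomic_add issued once per row block.
    dw = 0.0
    for p in partials:
        dw += p
    return dw

def accumulate_then_store(partials, num_programs):
    # Each program owns row blocks pid, pid + num_programs, ... and
    # sums them locally, then performs one store/atomic at the end.
    dw = 0.0
    for pid in range(num_programs):
        local = 0.0
        for i in range(pid, len(partials), num_programs):  # grid-stride loop
            local += partials[i]
        dw += local  # single add per program
    return dw
```

Both produce the same reduction; the win on hardware would come from fewer atomic operations, which is exactly the part that needs benchmarking.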
Thanks for the suggestion! I’ll look into this.
Summary
This PR introduces a functional GroupNorm operator for Ascend NPU.
Key improvements:
Fixes the `grid should be less than 65536!` and `ub overflow` errors that occur when the original GPU-oriented liger-kernel GroupNorm implementation is executed on NPU.
While the current implementation is still slower than the HuggingFace implementation in end-to-end benchmarks, it provides a stable and functional GroupNorm path for Ascend NPU.
This PR mainly focuses on correctness and NPU compatibility. Further kernel-level optimizations will be explored in follow-up work.
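A common way to stay under a hardware launch-grid limit like the one mentioned above is to cap the launched grid and let each program process multiple work items with a grid-stride loop. A plain-Python model of that mapping (the `MAX_GRID` constant and `launch`/`process` names are illustrative, not an NPU API):

```python
# Model of working around a launch-grid cap: if the total number of
# work items exceeds the limit, launch only `min(total, limit)`
# programs and have each one stride through the remaining items.
# Every item is processed exactly once regardless of the cap.

MAX_GRID = 65535  # illustrative launch-grid limit

def launch(total_items, process, max_grid=MAX_GRID):
    grid = min(total_items, max_grid)
    for pid in range(grid):          # one loop iteration per program
        item = pid
        while item < total_items:    # grid-stride loop inside the "kernel"
            process(item)
            item += grid
```

This keeps the kernel's per-item logic unchanged; only the mapping from program id to work items changes.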
Testing Done
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence