[NPU] Improvement performence for grpo_loss #1174

Merged: Tcc0403 merged 4 commits into linkedin:main from UserChen666:improvement on Apr 16, 2026

UserChen666 (Contributor) commented Mar 31, 2026

Summary

This PR implements Triton kernel-level performance optimizations for the grpo_loss operator.

Details

  1. Block Size (BLOCK_N) & Memory Coefficient Tuning
    Unified the default tiling by raising BLOCK_N from 2048 to 4096 in all kernels, so each loop iteration processes more elements, which reduces loop counts and kernel launch overhead.
    Lowered the memory multipliers to fit the NPU’s 192KB Unified Buffer (UB) capacity and improve memory utilization:
    Softmax: 6.0 → 3.0
    Forward: 10.0 → 4.0
    Backward: 12.0 → 8.0
  2. Computation Instruction Optimization (Reduce Divisions, Improve Instruction Efficiency)
    Precomputed inv_temp = 1.0 / TEMPERATURE once so that each in-loop division becomes a multiplication, reducing floating-point latency.
    Simplified the backward gradient expression to dlogp = dlogp * dloss * inv_temp, replacing the original chained division and lowering the number of floating-point operations.
  3. Loop & Compilation Optimizations
    Changed inner kernel loops from range to tl.static_range to give the compiler an explicit loop-unrolling hint, which improves instruction scheduling and pipeline efficiency.
    Cast INPUT_IDS to int32 explicitly to avoid implicit type-conversion overhead on the NPU.
  4. Masking & Memory Access Optimization
    Reused a single cols_mask variable for the memory-access masks, removing redundant mask computations and improving memory access throughput.
    Simplified the gradient calculation to (cols_idx - probs) * dlogp instead of tl.where branching, minimizing branching overhead.
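
The two backward-pass simplifications in items 2 and 4 can be checked numerically without an NPU. The following is a minimal pure-Python sketch, not the kernel code from this PR. It assumes that cols_idx acts as a 0/1 indicator of the selected token column and that the original code divided by the temperature inside the loop; the PR text does not show the original expression, so grad_branching is a hypothetical reference form, and all function names here are illustrative.

```python
import math


def log_softmax_selected(logits, idx, temperature):
    """Return (log p[idx], probs) for softmax(logits / temperature)."""
    inv_temp = 1.0 / temperature            # computed once, outside any loop
    scaled = [x * inv_temp for x in logits]
    m = max(scaled)                         # subtract the max for stability
    exps = [math.exp(s - m) for s in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]
    return scaled[idx] - m - math.log(z), probs


def grad_branching(probs, idx, dloss, dlogp, temperature):
    """Assumed reference form: per-column branch plus in-loop division."""
    out = []
    for j, p in enumerate(probs):
        if j == idx:                        # plays the role of tl.where
            g = (1.0 - p) * dlogp * dloss / temperature
        else:
            g = -p * dlogp * dloss / temperature
        out.append(g)
    return out


def grad_branch_free(probs, idx, dloss, dlogp, temperature):
    """Branch-free form: (indicator - probs) * scale, one hoisted reciprocal."""
    inv_temp = 1.0 / temperature
    scale = dlogp * dloss * inv_temp        # mirrors dlogp * dloss * inv_temp
    return [(float(j == idx) - p) * scale for j, p in enumerate(probs)]


if __name__ == "__main__":
    logits = [2.0, -1.0, 0.5, 3.0]
    _, probs = log_softmax_selected(logits, idx=2, temperature=0.7)
    a = grad_branching(probs, 2, dloss=0.3, dlogp=1.5, temperature=0.7)
    b = grad_branch_free(probs, 2, dloss=0.3, dlogp=1.5, temperature=0.7)
    print(max(abs(x - y) for x, y in zip(a, b)))
```

Both forms agree up to floating-point rounding, since multiplying by a precomputed reciprocal is not bit-identical to dividing. This is why a tolerance-based comparison is the appropriate check for this kind of rewrite.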

Testing Done

  • Hardware Type: Atlas A800I
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

UserChen666 changed the title from "Improvement performence for grpo_loss" to "【NPU】Improvement performence for grpo_loss" on Apr 1, 2026

UserChen666 (Contributor, Author) commented

@Tcc0403 please review, thank you.

UserChen666 (Contributor, Author) commented

@Tcc0403

UserChen666 changed the title from "【NPU】Improvement performence for grpo_loss" to "[NPU]Improvement performence for grpo_loss" on Apr 14, 2026

Tcc0403 (Collaborator) left a comment

LGTM

Tcc0403 added this pull request to the merge queue on Apr 16, 2026
Merged via the queue into linkedin:main with commit fcaae50 Apr 16, 2026
5 of 7 checks passed