Skip to content

[NPU] Add NPU optimized Attention Residual Kernel#1210

Merged
Tcc0403 merged 2 commits into
linkedin:mainfrom
lowdy1:attn_res
May 9, 2026
Merged

[NPU] Add NPU optimized Attention Residual Kernel#1210
Tcc0403 merged 2 commits into
linkedin:mainfrom
lowdy1:attn_res

Conversation

@lowdy1
Copy link
Copy Markdown
Contributor

@lowdy1 lowdy1 commented Apr 30, 2026

Summary

This PR introduces both tiled and non-tiled implementations of the attention residual forward and backward Triton kernels, enabling efficient execution when the hidden dimension (D) is large and would otherwise cause UB overflow.

It also adds get_optimal_block_d(...), which leverages compute_default_tiling_strategy to automatically select a safe BLOCK_D and prevent on-chip memory overflow.

Test with:
python -m pytest ./test/transformers/test_attn_res.py -v
python ./benchmark/scripts/benchmark_attn_res.py

Hardware Type: Atlas 800I A2

  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@lowdy1
Copy link
Copy Markdown
Contributor Author

lowdy1 commented Apr 30, 2026

Benchmark forward:
attn_res_speed_forward_token_length
Benchmark backward:
attn_res_speed_backward_token_length
Benchmark full:
attn_res_speed_full_token_length

@lowdy1
Copy link
Copy Markdown
Contributor Author

lowdy1 commented Apr 30, 2026

**************************************
     BENCHMARKING SPEED for ATTN_RES
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "attn_res",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      2.8819398880004883,
      5.667540073394775,
      11.246179580688477,
      22.42807960510254
    ],
    "y_values_20": [
      2.8778719902038574,
      5.667540073394775,
      11.246179580688477,
      22.42807960510254
    ],
    "y_values_80": [
      2.8860080242156982,
      5.667540073394775,
      11.246179580688477,
      22.42807960510254
    ],
    "timestamp": "2026-04-30 08:12:59",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "pytorch",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      14.994500160217285,
      29.991540908813477,
      60.286678314208984,
      126.10389709472656
    ],
    "y_values_20": [
      14.994500160217285,
      29.991540908813477,
      60.286678314208984,
      126.10389709472656
    ],
    "y_values_80": [
      14.994500160217285,
      29.991540908813477,
      60.286678314208984,
      126.10389709472656
    ],
    "timestamp": "2026-04-30 08:13:01",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      0.9125200510025024,
      1.7604800462722778,
      3.519040107727051,
      6.974160194396973
    ],
    "y_values_20": [
      0.9114120006561279,
      1.7598319053649902,
      3.5165679454803467,
      6.974160194396973
    ],
    "y_values_80": [
      0.9148439764976501,
      1.7624120712280273,
      3.521512031555176,
      6.974160194396973
    ],
    "timestamp": "2026-04-30 08:13:01",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "pytorch",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      4.38323974609375,
      8.889080047607422,
      17.790620803833008,
      41.15327835083008
    ],
    "y_values_20": [
      4.382195949554443,
      8.889080047607422,
      17.790620803833008,
      41.15327835083008
    ],
    "y_values_80": [
      4.384284019470215,
      8.889080047607422,
      17.790620803833008,
      41.15327835083008
    ],
    "timestamp": "2026-04-30 08:13:02",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "liger",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      1.5567200183868408,
      3.0423898696899414,
      6.044000148773193,
      11.948840141296387
    ],
    "y_values_20": [
      1.5510119199752808,
      3.037775993347168,
      6.044000148773193,
      11.948840141296387
    ],
    "y_values_80": [
      1.5624759197235107,
      3.047003984451294,
      6.044000148773193,
      11.948840141296387
    ],
    "timestamp": "2026-04-30 08:13:02",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "pytorch",
    "metric_name": "speed",
    "metric_unit": "ms",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      10.344120025634766,
      20.295299530029297,
      40.95214080810547,
      81.65383911132812
    ],
    "y_values_20": [
      10.344120025634766,
      20.295299530029297,
      40.95214080810547,
      81.65383911132812
    ],
    "y_values_80": [
      10.344120025634766,
      20.295299530029297,
      40.95214080810547,
      81.65383911132812
    ],
    "timestamp": "2026-04-30 08:13:04",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  }
]
**************************************
     BENCHMARKING MEMORY for ATTN_RES
**************************************
********** Benchmark Data **********
[
  {
    "kernel_name": "attn_res",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      208.13134765625,
      416.19384765625,
      832.31884765625,
      1664.56884765625
    ],
    "y_values_20": [
      208.13134765625,
      416.19384765625,
      832.31884765625,
      1664.56884765625
    ],
    "y_values_80": [
      208.13134765625,
      416.19384765625,
      832.31884765625,
      1664.56884765625
    ],
    "timestamp": "2026-04-30 08:13:04",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "pytorch",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      1104.10205078125,
      2208.16455078125,
      4416.28955078125,
      8832.5390625
    ],
    "y_values_20": [
      1104.10205078125,
      2208.16455078125,
      4416.28955078125,
      8832.5390625
    ],
    "y_values_80": [
      1104.10205078125,
      2208.16455078125,
      4416.28955078125,
      8832.5390625
    ],
    "timestamp": "2026-04-30 08:13:04",
    "kernel_operation_mode": "full",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      72.0810546875,
      144.1435546875,
      288.2685546875,
      576.5185546875
    ],
    "y_values_20": [
      72.0810546875,
      144.1435546875,
      288.2685546875,
      576.5185546875
    ],
    "y_values_80": [
      72.0810546875,
      144.1435546875,
      288.2685546875,
      576.5185546875
    ],
    "timestamp": "2026-04-30 08:13:04",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "pytorch",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      528.16455078125,
      1056.28955078125,
      2112.53955078125,
      4225.03955078125
    ],
    "y_values_20": [
      528.16455078125,
      1056.28955078125,
      2112.53955078125,
      4225.03955078125
    ],
    "y_values_80": [
      528.16455078125,
      1056.28955078125,
      2112.53955078125,
      4225.03955078125
    ],
    "timestamp": "2026-04-30 08:13:04",
    "kernel_operation_mode": "forward",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "liger",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      208.13134765625,
      416.19384765625,
      832.31884765625,
      1664.56884765625
    ],
    "y_values_20": [
      208.13134765625,
      416.19384765625,
      832.31884765625,
      1664.56884765625
    ],
    "y_values_80": [
      208.13134765625,
      416.19384765625,
      832.31884765625,
      1664.56884765625
    ],
    "timestamp": "2026-04-30 08:13:05",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  },
  {
    "kernel_name": "attn_res",
    "kernel_provider": "pytorch",
    "metric_name": "memory",
    "metric_unit": "MB",
    "gpu_name": "Ascend910B4",
    "x_name": "T",
    "x_label": "sequence length",
    "x_values": [
      1024,
      2048,
      4096,
      8192
    ],
    "y_values_50": [
      1104.10205078125,
      2208.16455078125,
      4416.28955078125,
      8832.5390625
    ],
    "y_values_20": [
      1104.10205078125,
      2208.16455078125,
      4416.28955078125,
      8832.5390625
    ],
    "y_values_80": [
      1104.10205078125,
      2208.16455078125,
      4416.28955078125,
      8832.5390625
    ],
    "timestamp": "2026-04-30 08:13:05",
    "kernel_operation_mode": "backward",
    "extra_benchmark_config_str": "{\"N\": 8, \"bsz\": 1, \"hidden_size\": 4096, \"dtype\": \"torch.bfloat16\", \"eps\": 1e-06}",
    "liger_version": "0.7.0"
  }
]

from test.utils import supports_bfloat16

from liger_kernel.ops.attn_res import LigerAttnResFunction
from liger_kernel.ops import LigerAttnResFunction
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking in — is this PR ready for review? We're down to just this one remaining test case that needs fixing.

Image

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please help review this? @Tcc0403

Copy link
Copy Markdown
Collaborator

@Tcc0403 Tcc0403 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Tcc0403 Tcc0403 added this pull request to the merge queue May 9, 2026
Merged via the queue into linkedin:main with commit 5e5a48a May 9, 2026
5 of 7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants