[KDA] fix internal output_final_state wrapper issue in SM90 #66

Merged
KevinZeng08 merged 5 commits into inclusionAI:main from yechenzhi:fix_internal_output_state
May 9, 2026

Conversation

@yechenzhi
Contributor

@yechenzhi yechenzhi commented May 8, 2026

📌 Description

This PR is a follow-up to #63.

#63 fixed the Python wrapper behavior of Hopper KDA fused prefill so that final_state is returned as None when output_final_state=False. However, that PR only changed the Python-side return value. The underlying C++/CUDA path still allocated an output_state buffer and passed a non-null ptr_output_state to the kernel, so the final state was still written internally even when the caller did not request it.

This PR passes output_final_state from the Python wrapper to the C++ API, and avoids allocating or storing the final state when it is not requested.

Specifically, this PR:

  • updates the C++ API to accept output_final_state;
  • allocates output_state only when output_final_state=True;
  • passes nullptr as ptr_output_state when final state output is not requested;
  • skips the final kv_store() in the SM90 KDA mainloop when ptr_output_state == nullptr;
  • adds a regression test to verify that output_final_state=False returns None and produces the same output tensor as output_final_state=True.

This avoids unnecessary final-state allocation and the final global-memory store while preserving the output tensor.
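A minimal sketch of the host-side shape of this change, assuming a [B, T, H, D] tensor layout and a hypothetical launcher `launch_kda_sm90` (the real entry point lives in csrc/api/kda_sm90.cu); only the names output_final_state, ptr_output_state, and kv_store come from the PR itself:

```cpp
#include <torch/extension.h>
#include <optional>
#include <tuple>

// Hypothetical launcher standing in for the real SM90 dispatch; inside the
// mainloop epilogue the final kv_store() is guarded by
// `if (ptr_output_state != nullptr)`, as described above.
void launch_kda_sm90(const torch::Tensor& q, const torch::Tensor& k,
                     const torch::Tensor& v, const torch::Tensor& g,
                     torch::Tensor& o, void* ptr_output_state);

std::tuple<torch::Tensor, std::optional<torch::Tensor>> kda_fwd_prefill(
    const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
    const torch::Tensor& g, bool output_final_state) {
  auto o = torch::empty_like(v);

  // Allocate the final-state buffer only when the caller asked for it;
  // otherwise hand the kernel a null pointer so it skips the final store.
  std::optional<torch::Tensor> output_state;
  void* ptr_output_state = nullptr;
  if (output_final_state) {
    const auto B = q.size(0), H = q.size(2);  // assumed [B, T, H, D] layout
    const auto K = q.size(3), V = v.size(3);
    output_state =
        torch::zeros({B, H, K, V}, q.options().dtype(torch::kFloat32));
    ptr_output_state = output_state->data_ptr();
  }

  launch_kda_sm90(q, k, v, g, o, ptr_output_state);
  return {o, output_state};  // std::nullopt surfaces as None in Python
}
```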

🔍 Related Issues

mentioned here.

🚀 Pull Request Checklist

Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing.

Tested with:
pytest tests/test_kda_fused_fwd.py

⚡ Performance

python benchmarks/bench_kda_fused_fwd.py

before:
1: output_final_state = True

[Device] NVIDIA H800 PCIe  compute capability sm90  →  using cula.kda.hopper_fused_fwd.cula_kda_prefill

====================================================================================================
 Fixed-Length Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================

====================================================================================================
 Varlen Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================


==============================================================================================================
                  BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)
                  cuLA sm90 fully-fused vs FLA Triton
                  H=64  D=128  dtype=bf16  safe_gate=True  has_init_state=False
                  Warmup=25  Iters=100
==============================================================================================================

  [Fixed-Length]
  ──────────────────────────────────────────────────────────────────────────────────────────
    B       T  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────
    1     512  │    0.000019    0.007692    0.000008  │     0.7245      0.2834     2.56x
    1    1024  │    0.000022    0.008021    0.000010  │     0.7165      0.3142     2.28x
    1    4096  │    0.000019    0.008523    0.000008  │     1.6661      1.1345     1.47x
    1    8192  │    0.000021    0.010204    0.000009  │     3.2977      2.2818     1.45x
    1   16384  │    0.000019    0.008368    0.000008  │     6.5983      4.4572     1.48x
    2     512  │    0.000022    0.008021    0.000010  │     0.7180      0.3628     1.98x
    2    1024  │    0.000020    0.007353    0.000009  │     0.8813      0.6292     1.40x
    2    4096  │    0.000021    0.010204    0.000009  │     3.3218      2.2665     1.47x
    2    8192  │    0.000019    0.008368    0.000008  │     6.7096      4.4716     1.50x
    2   16384  │    0.000019    0.007353    0.000008  │    13.4159      8.7355     1.54x
  ──────────────────────────────────────────────────────────────────────────────────────────

  [Varlen]
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         Config  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
       uniform 10seqs T=4096 [409..415] avg=409  │    0.000019    0.011364    0.000008  │     1.7035      1.1006     1.55x
        random 10seqs T=4096 [24..1201] avg=409  │    0.000019    0.008523    0.000008  │     1.6985      1.0122     1.68x
       skewed 10seqs T=4096 [227..2053] avg=409  │    0.000019    0.008523    0.000008  │     1.6890      0.9857     1.71x
       uniform 20seqs T=4096 [204..220] avg=204  │    0.000019    0.008523    0.000008  │     1.8191      1.4444     1.26x
          random 20seqs T=4096 [5..787] avg=204  │    0.000019    0.011364    0.000008  │     1.7835      1.0948     1.63x
       skewed 20seqs T=4096 [107..2063] avg=204  │    0.000019    0.011364    0.000008  │     1.7408      1.1191     1.56x
       uniform 10seqs T=8192 [819..821] avg=819  │    0.000021    0.010204    0.000009  │     3.2902      1.8610     1.77x
        random 10seqs T=8192 [48..2401] avg=819  │    0.000021    0.010204    0.000009  │     3.3108      1.8550     1.78x
       skewed 10seqs T=8192 [455..4097] avg=819  │    0.000021    0.010204    0.000009  │     3.3159      1.8308     1.81x
       uniform 20seqs T=8192 [409..421] avg=409  │    0.000020    0.010204    0.000009  │     3.3826      2.2154     1.53x
         random 20seqs T=8192 [9..1574] avg=409  │    0.000020    0.010204    0.000009  │     3.3713      1.9197     1.76x
       skewed 20seqs T=8192 [215..4107] avg=409  │    0.000020    0.010256    0.000009  │     3.3757      1.9554     1.73x
   uniform 10seqs T=16384 [1638..1642] avg=1638  │    0.000019    0.008368    0.000008  │     6.5568      3.4048     1.93x
      random 10seqs T=16384 [95..4802] avg=1638  │    0.000019    0.008368    0.000008  │     6.5580      3.4761     1.89x
     skewed 10seqs T=16384 [910..8194] avg=1638  │    0.000019    0.008368    0.000008  │     6.5515      3.3333     1.97x
      uniform 20seqs T=16384 [819..823] avg=819  │    0.000019    0.008368    0.000008  │     6.6029      3.7340     1.77x
       random 20seqs T=16384 [19..3147] avg=819  │    0.000019    0.008368    0.000008  │     6.5960      3.4868     1.89x
      skewed 20seqs T=16384 [431..8195] avg=819  │    0.000019    0.008403    0.000008  │     6.5702      3.4492     1.90x
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================

2: output_final_state = False

[Device] NVIDIA H800 PCIe  compute capability sm90  →  using cula.kda.hopper_fused_fwd.cula_kda_prefill

====================================================================================================
 Fixed-Length Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================

====================================================================================================
 Varlen Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================


==============================================================================================================
                  BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)
                  cuLA sm90 fully-fused vs FLA Triton
                  H=64  D=128  dtype=bf16  safe_gate=True  has_init_state=False
                  Warmup=25  Iters=100
==============================================================================================================

  [Fixed-Length]
  ──────────────────────────────────────────────────────────────────────────────────────────
    B       T  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────
    1     512  │    0.000019    0.007692    0.000008  │     0.7226      0.2873     2.51x
    1    1024  │    0.000022    0.008021    0.000010  │     0.7123      0.3140     2.27x
    1    4096  │    0.000019    0.008523    0.000008  │     1.6578      1.1286     1.47x
    1    8192  │    0.000021    0.010204    0.000009  │     3.2904      2.2521     1.46x
    1   16384  │    0.000019    0.008368    0.000008  │     6.5959      4.4610     1.48x
    2     512  │    0.000022    0.008021    0.000010  │     0.6983      0.3517     1.99x
    2    1024  │    0.000020    0.007353    0.000009  │     0.8698      0.6290     1.38x
    2    4096  │    0.000021    0.010204    0.000009  │     3.3259      2.2705     1.46x
    2    8192  │    0.000019    0.008368    0.000008  │     6.7053      4.4761     1.50x
    2   16384  │    0.000019    0.007353    0.000008  │    13.4593      8.7126     1.54x
  ──────────────────────────────────────────────────────────────────────────────────────────

  [Varlen]
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         Config  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
       uniform 10seqs T=4096 [409..415] avg=409  │    0.000019    0.011364    0.000008  │     1.6610      1.1007     1.51x
        random 10seqs T=4096 [24..1201] avg=409  │    0.000019    0.008523    0.000008  │     1.6489      1.0130     1.63x
       skewed 10seqs T=4096 [227..2053] avg=409  │    0.000019    0.008523    0.000008  │     1.6413      0.9868     1.66x
       uniform 20seqs T=4096 [204..220] avg=204  │    0.000019    0.008523    0.000008  │     1.7251      1.4417     1.20x
          random 20seqs T=4096 [5..787] avg=204  │    0.000019    0.011364    0.000008  │     1.6955      1.0956     1.55x
       skewed 20seqs T=4096 [107..2063] avg=204  │    0.000019    0.011364    0.000008  │     1.6552      1.1234     1.47x
       uniform 10seqs T=8192 [819..821] avg=819  │    0.000021    0.010204    0.000009  │     3.2432      1.8628     1.74x
        random 10seqs T=8192 [48..2401] avg=819  │    0.000021    0.010204    0.000009  │     3.2664      1.8515     1.76x
       skewed 10seqs T=8192 [455..4097] avg=819  │    0.000021    0.010204    0.000009  │     3.2711      1.8314     1.79x
       uniform 20seqs T=8192 [409..421] avg=409  │    0.000020    0.010204    0.000009  │     3.2875      2.2031     1.49x
         random 20seqs T=8192 [9..1574] avg=409  │    0.000020    0.010204    0.000009  │     3.2906      1.9254     1.71x
       skewed 20seqs T=8192 [215..4107] avg=409  │    0.000020    0.010256    0.000009  │     3.2773      1.9557     1.68x
   uniform 10seqs T=16384 [1638..1642] avg=1638  │    0.000019    0.008368    0.000008  │     6.5063      3.4007     1.91x
      random 10seqs T=16384 [95..4802] avg=1638  │    0.000019    0.008368    0.000008  │     6.4758      3.4772     1.86x
     skewed 10seqs T=16384 [910..8194] avg=1638  │    0.000019    0.008368    0.000008  │     6.5240      3.3306     1.96x
      uniform 20seqs T=16384 [819..823] avg=819  │    0.000019    0.008368    0.000008  │     6.4837      3.7402     1.73x
       random 20seqs T=16384 [19..3147] avg=819  │    0.000019    0.008368    0.000008  │     6.5310      3.5019     1.86x
      skewed 20seqs T=16384 [431..8195] avg=819  │    0.000019    0.008403    0.000008  │     6.4862      3.4524     1.88x
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
==============================================================================================================

after:

1: output_final_state = True

[Device] NVIDIA H800 PCIe  compute capability sm90  →  using cula.kda.hopper_fused_fwd.cula_kda_prefill

====================================================================================================
 Fixed-Length Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================

====================================================================================================
 Varlen Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================


==============================================================================================================
                  BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)
                  cuLA sm90 fully-fused vs FLA Triton
                  H=64  D=128  dtype=bf16  safe_gate=True  has_init_state=False
                  Warmup=25  Iters=100
==============================================================================================================

  [Fixed-Length]
  ──────────────────────────────────────────────────────────────────────────────────────────
    B       T  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────
    1     512  │    0.000019    0.007692    0.000008  │     0.7338      0.2845     2.58x
    1    1024  │    0.000022    0.008021    0.000010  │     0.7264      0.3161     2.30x
    1    4096  │    0.000019    0.008523    0.000008  │     1.6662      1.1320     1.47x
    1    8192  │    0.000021    0.010204    0.000009  │     3.2996      2.2640     1.46x
    1   16384  │    0.000019    0.008368    0.000008  │     6.6189      4.4580     1.48x
    2     512  │    0.000022    0.008021    0.000010  │     0.7168      0.3510     2.04x
    2    1024  │    0.000020    0.007353    0.000009  │     0.8810      0.6302     1.40x
    2    4096  │    0.000021    0.010204    0.000009  │     3.3290      2.2683     1.47x
    2    8192  │    0.000019    0.008368    0.000008  │     6.7120      4.4719     1.50x
    2   16384  │    0.000019    0.007353    0.000008  │    13.4336      8.7589     1.53x
  ──────────────────────────────────────────────────────────────────────────────────────────

  [Varlen]
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         Config  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
       uniform 10seqs T=4096 [409..415] avg=409  │    0.000019    0.011364    0.000008  │     1.7064      1.1039     1.55x
        random 10seqs T=4096 [24..1201] avg=409  │    0.000019    0.008523    0.000008  │     1.7010      1.0118     1.68x
       skewed 10seqs T=4096 [227..2053] avg=409  │    0.000019    0.008523    0.000008  │     1.6935      0.9855     1.72x
       uniform 20seqs T=4096 [204..220] avg=204  │    0.000019    0.008523    0.000008  │     1.8211      1.4417     1.26x
          random 20seqs T=4096 [5..787] avg=204  │    0.000019    0.011364    0.000008  │     1.7834      1.0969     1.63x
       skewed 20seqs T=4096 [107..2063] avg=204  │    0.000019    0.011364    0.000008  │     1.7422      1.1298     1.54x
       uniform 10seqs T=8192 [819..821] avg=819  │    0.000021    0.010204    0.000009  │     3.2886      1.8640     1.76x
        random 10seqs T=8192 [48..2401] avg=819  │    0.000021    0.010204    0.000009  │     3.3087      1.8534     1.79x
       skewed 10seqs T=8192 [455..4097] avg=819  │    0.000021    0.010204    0.000009  │     3.3164      1.8312     1.81x
       uniform 20seqs T=8192 [409..421] avg=409  │    0.000020    0.010204    0.000009  │     3.3816      2.2141     1.53x
         random 20seqs T=8192 [9..1574] avg=409  │    0.000020    0.010204    0.000009  │     3.3760      1.9246     1.75x
       skewed 20seqs T=8192 [215..4107] avg=409  │    0.000020    0.010256    0.000009  │     3.3764      1.9501     1.73x
   uniform 10seqs T=16384 [1638..1642] avg=1638  │    0.000019    0.008368    0.000008  │     6.5511      3.4041     1.92x
      random 10seqs T=16384 [95..4802] avg=1638  │    0.000019    0.008368    0.000008  │     6.5635      3.4809     1.89x
     skewed 10seqs T=16384 [910..8194] avg=1638  │    0.000019    0.008368    0.000008  │     6.5311      3.3224     1.97x
      uniform 20seqs T=16384 [819..823] avg=819  │    0.000019    0.008368    0.000008  │     6.5863      3.7379     1.76x
       random 20seqs T=16384 [19..3147] avg=819  │    0.000019    0.008368    0.000008  │     6.6000      3.4784     1.90x
      skewed 20seqs T=16384 [431..8195] avg=819  │    0.000019    0.008403    0.000008  │     6.5688      3.4395     1.91x
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================

2: output_final_state = False

[Device] NVIDIA H800 PCIe  compute capability sm90  →  using cula.kda.hopper_fused_fwd.cula_kda_prefill

====================================================================================================
 Fixed-Length Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================

====================================================================================================
 Varlen Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================


==============================================================================================================
                  BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)
                  cuLA sm90 fully-fused vs FLA Triton
                  H=64  D=128  dtype=bf16  safe_gate=True  has_init_state=False
                  Warmup=25  Iters=100
==============================================================================================================

  [Fixed-Length]
  ──────────────────────────────────────────────────────────────────────────────────────────
    B       T  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────
    1     512  │    0.000019    0.007692    0.000008  │     0.7232      0.3457     2.09x
    1    1024  │    0.000022    0.008021    0.000010  │     0.7560      0.3233     2.34x
    1    4096  │    0.000019    0.008523    0.000008  │     1.6876      1.1215     1.50x
    1    8192  │    0.000021    0.010204    0.000009  │     3.2877      2.2564     1.46x
    1   16384  │    0.000019    0.008368    0.000008  │     6.6044      4.4438     1.49x
    2     512  │    0.000022    0.008021    0.000010  │     0.7270      0.3366     2.16x
    2    1024  │    0.000020    0.007353    0.000009  │     0.8686      0.6167     1.41x
    2    4096  │    0.000021    0.010204    0.000009  │     3.3357      2.2605     1.48x
    2    8192  │    0.000019    0.008368    0.000008  │     6.6871      4.4606     1.50x
    2   16384  │    0.000019    0.007353    0.000008  │    13.4401      8.7195     1.54x
  ──────────────────────────────────────────────────────────────────────────────────────────

  [Varlen]
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         Config  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
       uniform 10seqs T=4096 [409..415] avg=409  │    0.000019    0.011364    0.000008  │     1.6597      1.0504     1.58x
        random 10seqs T=4096 [24..1201] avg=409  │    0.000019    0.008523    0.000008  │     1.6547      0.9786     1.69x
       skewed 10seqs T=4096 [227..2053] avg=409  │    0.000019    0.008523    0.000008  │     1.6441      0.9521     1.73x
       uniform 20seqs T=4096 [204..220] avg=204  │    0.000019    0.008523    0.000008  │     1.7286      1.3447     1.29x
          random 20seqs T=4096 [5..787] avg=204  │    0.000019    0.011364    0.000008  │     1.6961      1.0298     1.65x
       skewed 20seqs T=4096 [107..2063] avg=204  │    0.000019    0.011364    0.000008  │     1.6574      1.0063     1.65x
       uniform 10seqs T=8192 [819..821] avg=819  │    0.000021    0.010204    0.000009  │     3.2471      1.8187     1.79x
        random 10seqs T=8192 [48..2401] avg=819  │    0.000021    0.010204    0.000009  │     3.2708      1.8295     1.79x
       skewed 10seqs T=8192 [455..4097] avg=819  │    0.000021    0.010204    0.000009  │     3.2701      1.8005     1.82x
       uniform 20seqs T=8192 [409..421] avg=409  │    0.000020    0.010204    0.000009  │     3.2918      2.1226     1.55x
         random 20seqs T=8192 [9..1574] avg=409  │    0.000020    0.010204    0.000009  │     3.2901      1.8700     1.76x
       skewed 20seqs T=8192 [215..4107] avg=409  │    0.000020    0.010256    0.000009  │     3.2880      1.8955     1.73x
   uniform 10seqs T=16384 [1638..1642] avg=1638  │    0.000019    0.008368    0.000008  │     6.4912      3.3560     1.93x
      random 10seqs T=16384 [95..4802] avg=1638  │    0.000019    0.008368    0.000008  │     6.5376      3.4580     1.89x
     skewed 10seqs T=16384 [910..8194] avg=1638  │    0.000019    0.008368    0.000008  │     6.5152      3.2913     1.98x
      uniform 20seqs T=16384 [819..823] avg=819  │    0.000019    0.008368    0.000008  │     6.4979      3.6403     1.78x
       random 20seqs T=16384 [19..3147] avg=819  │    0.000019    0.008368    0.000008  │     6.5160      3.4194     1.91x
      skewed 20seqs T=16384 [431..8195] avg=819  │    0.000019    0.008403    0.000008  │     6.4658      3.3754     1.92x
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================

@yechenzhi yechenzhi marked this pull request as ready for review May 8, 2026 02:30
@yechenzhi yechenzhi changed the title from "Fix internal output state" to "fix internal output_final_state wrapper issue" May 8, 2026

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces an output_final_state flag in the kda_fwd_prefill path, making the final-state tensor allocation and storage optional across the C++, Python, and CUDA layers. The changes include logic to handle null pointers in the kernel and a new test case to verify the functionality. Feedback was provided regarding the robustness of the C++ API, suggesting that the output_final_state flag should be respected even when a buffer is explicitly provided, to prevent unintended memory stores.
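That hardening could look like the following sketch (resolve_state_ptr is a hypothetical helper, not code from this PR): the flag alone, rather than buffer presence, decides whether the kernel receives a writable pointer.

```cpp
#include <torch/extension.h>
#include <optional>

// Hypothetical helper: derive the kernel's state pointer from the flag,
// not from buffer presence, so a caller-supplied buffer is never written
// when output_final_state is false.
void* resolve_state_ptr(bool output_final_state,
                        std::optional<torch::Tensor>& output_state) {
  if (!output_final_state) {
    return nullptr;  // kernel skips the final kv_store() on nullptr
  }
  TORCH_CHECK(output_state.has_value(),
              "output_final_state=true requires an output_state buffer");
  return output_state->data_ptr();
}
```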

Comment thread on csrc/api/kda_sm90.cu (outdated)
@yechenzhi
Contributor Author

yechenzhi commented May 8, 2026

I reran the benchmark after this change. The previous suspicious short-sequence result seems to have been benchmark noise.

For output_final_state=False, the B=1, T=512 case now looks normal again:

BEFORE: output_final_state=False, B=1, T=512: cuLA 0.2873 ms, speedup 2.51x
AFTER:  output_final_state=False, B=1, T=512: cuLA 0.2768 ms, speedup 2.55x
[Device] NVIDIA H800 PCIe  compute capability sm90  →  using cula.kda.hopper_fused_fwd.cula_kda_prefill

====================================================================================================
 Fixed-Length Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================

====================================================================================================
 Varlen Benchmark: cuLA fully-fused (sm90) vs FLA Triton
====================================================================================================


==============================================================================================================
                  BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)
                  cuLA sm90 fully-fused vs FLA Triton
                  H=64  D=128  dtype=bf16  safe_gate=True  has_init_state=False
                  Warmup=25  Iters=100
==============================================================================================================

  [Fixed-Length]
  ──────────────────────────────────────────────────────────────────────────────────────────
    B       T  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────
    1     512  │    0.000019    0.007692    0.000008  │     0.7052      0.2768     2.55x
    1    1024  │    0.000022    0.008021    0.000010  │     0.7017      0.3067     2.29x
    1    4096  │    0.000019    0.008523    0.000008  │     1.6621      1.1303     1.47x
    1    8192  │    0.000021    0.010204    0.000009  │     3.2954      2.2477     1.47x
    1   16384  │    0.000019    0.008368    0.000008  │     6.6039      4.4833     1.47x
    2     512  │    0.000022    0.008021    0.000010  │     0.7041      0.3353     2.10x
    2    1024  │    0.000020    0.007353    0.000009  │     0.8729      0.6163     1.42x
    2    4096  │    0.000021    0.010204    0.000009  │     3.3384      2.2709     1.47x
    2    8192  │    0.000019    0.008368    0.000008  │     6.7160      4.4961     1.49x
    2   16384  │    0.000019    0.007353    0.000008  │    13.4425      8.9255     1.51x
  ──────────────────────────────────────────────────────────────────────────────────────────

  [Varlen]
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         Config  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────
       uniform 10seqs T=4096 [409..415] avg=409  │    0.000019    0.011364    0.000008  │     1.6596      1.0484     1.58x
        random 10seqs T=4096 [24..1201] avg=409  │    0.000019    0.008523    0.000008  │     1.6573      0.9819     1.69x
       skewed 10seqs T=4096 [227..2053] avg=409  │    0.000019    0.008523    0.000008  │     1.6488      0.9384     1.76x
       uniform 20seqs T=4096 [204..220] avg=204  │    0.000019    0.008523    0.000008  │     1.7308      1.3167     1.31x
          random 20seqs T=4096 [5..787] avg=204  │    0.000019    0.011364    0.000008  │     1.6986      1.0294     1.65x
       skewed 20seqs T=4096 [107..2063] avg=204  │    0.000019    0.011364    0.000008  │     1.6610      1.0183     1.63x
       uniform 10seqs T=8192 [819..821] avg=819  │    0.000021    0.010204    0.000009  │     3.2418      1.8101     1.79x
        random 10seqs T=8192 [48..2401] avg=819  │    0.000021    0.010204    0.000009  │     3.2776      1.8335     1.79x
       skewed 10seqs T=8192 [455..4097] avg=819  │    0.000021    0.010204    0.000009  │     3.2636      1.7947     1.82x
       uniform 20seqs T=8192 [409..421] avg=409  │    0.000020    0.010204    0.000009  │     3.3044      2.1047     1.57x
         random 20seqs T=8192 [9..1574] avg=409  │    0.000020    0.010204    0.000009  │     3.2923      1.8747     1.76x
       skewed 20seqs T=8192 [215..4107] avg=409  │    0.000020    0.010256    0.000009  │     3.2990      1.8807     1.75x
   uniform 10seqs T=16384 [1638..1642] avg=1638  │    0.000019    0.008368    0.000008  │     6.5054      3.3705     1.93x
      random 10seqs T=16384 [95..4802] avg=1638  │    0.000019    0.008368    0.000008  │     6.5291      3.4617     1.89x
     skewed 10seqs T=16384 [910..8194] avg=1638  │    0.000019    0.008368    0.000008  │     6.4762      3.2935     1.97x
      uniform 20seqs T=16384 [819..823] avg=819  │    0.000019    0.008368    0.000008  │     6.5000      3.6399     1.79x
       random 20seqs T=16384 [19..3147] avg=819  │    0.000019    0.008368    0.000008  │     6.5626      3.4420     1.91x
      skewed 20seqs T=16384 [431..8195] avg=819  │    0.000019    0.008403    0.000008  │     6.4959      3.3694     1.93x
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================

@KevinZeng08
Collaborator

Hi, it seems that #64 has similar state output logic; are there any differences?

@yechenzhi
Contributor Author

yechenzhi commented May 9, 2026

> Hi, it seems that #64 has similar state output logic; are there any differences?

From what I can see, #64 only partially handles the output-state path. In particular:

  1. It uses need_output_state_buffer, but I do not see a clear output_final_state propagation into the C++ API in that PR.
  2. It still constructs output_state by default with torch::zeros(...), and the C++ return type still appears to be std::tuple<torch::Tensor, torch::Tensor> rather than std::tuple<torch::Tensor, OptionalTensor>. So it may skip the kernel store, but it does not avoid the allocation or adopt C++-level optional return semantics.
  3. The output-state change is coupled with the GVA work and does not seem to have a focused regression test for output_final_state=False.

This PR is narrower: it explicitly passes output_final_state into C++, avoids constructing output_state when it is false, passes nullptr to the kernel, and skips kv_store().
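Sketched as C++ signatures (both hypothetical, reconstructing the description above), the difference in return semantics is:

```cpp
#include <torch/extension.h>
#include <optional>
#include <tuple>

// As described for #64: a state tensor is always allocated, so the binding
// returns two concrete tensors even when the caller does not want the state.
std::tuple<torch::Tensor, torch::Tensor>
prefill_always_allocates(const torch::Tensor& q /* ... */);

// This PR: the state is optional at the C++ boundary, so
// output_final_state=false round-trips to Python as None with no allocation.
std::tuple<torch::Tensor, std::optional<torch::Tensor>>
prefill_optional_state(const torch::Tensor& q /* ... */,
                       bool output_final_state);
```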

@KevinZeng08
Collaborator

KevinZeng08 commented May 9, 2026

Got it, thanks. I will review this PR.

Collaborator

@KevinZeng08 KevinZeng08 left a comment


LGTM

@KevinZeng08 KevinZeng08 requested a review from cherhh May 9, 2026 02:28
@KevinZeng08 KevinZeng08 changed the title from "fix internal output_final_state wrapper issue" to "[KDA] fix internal output_final_state wrapper issue in SM90" May 9, 2026
@KevinZeng08 KevinZeng08 merged commit a6d9572 into inclusionAI:main May 9, 2026
@yechenzhi yechenzhi deleted the fix_internal_output_state branch May 9, 2026 04:05