
Conversation

@nvmbreughe
Contributor

@nvmbreughe nvmbreughe commented Oct 31, 2025

📌 Description

This PR:

  • adds an optimized router GEMM for problem sizes such as those in DeepSeek-V3. It is ported over from TRTLLM.
  • serves as an example of API naming for specialized ops with narrow support surfaces.

In my measurements (num_tokens = [1, 2, 4, 8, 16]), speedups between 1.36x and 1.82x were observed on B200.

Both positive and negative tests were added to verify the behavior.
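
For concreteness, here is a minimal usage sketch of the new op as it ends up being exposed by this PR (public name mm_M1_16_K7168_N256 under flashinfer.dsv3_ops); the tensor construction and the launch_with_pdl keyword form are assumptions, not copied from the PR's tests:

import torch
from flashinfer.dsv3_ops import mm_M1_16_K7168_N256

num_tokens, hidden_dim, num_experts = 4, 7168, 256

# Row-major bf16 activations and a column-major bf16 weight matrix (stride(0) == 1 via .t()).
mat_a = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
mat_b = torch.randn(num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16).t()
out = torch.empty(num_tokens, num_experts, device="cuda", dtype=torch.float32)

mm_M1_16_K7168_N256(mat_a, mat_b, out, launch_with_pdl=False)  # writes fp32 logits into `out`

reference = mat_a.float() @ mat_b.float()  # fp32 reference of the same product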

Breaking Change: Refactored gemm module structure

ACTION REQUIRED: Delete stale flashinfer/gemm.py file

The gemm.py file has been refactored into a package:

  • flashinfer/gemm.py → flashinfer/gemm/gemm_base.py

After pulling this change, run:

git clean -fd flashinfer/
# OR manually:
rm flashinfer/gemm.py

This is backward compatible - no import changes needed.
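
As a quick illustration of that claim (a sketch; these symbols are ones referenced elsewhere in this PR, not an exhaustive list):

# Package-level imports keep working because flashinfer/gemm/__init__.py re-exports
# the public API that previously lived in flashinfer/gemm.py ...
from flashinfer.gemm import SegmentGEMMWrapper, mm_fp4, mm_fp8

# ... while tests and internal code can also target the new submodule directly.
from flashinfer.gemm.gemm_base import _match_sm_version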

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • High-performance DSv3 router GEMM (bf16 → float32) targeting tokens 1–16, 256 experts, 7168 hidden dim with optional serialized kernel launch.
  • Integration

    • Python-facing op with runtime shape/dtype/stride validation, backend loader, and registered custom-op entrypoint; public mm_M1_16_K7168_N256 exposed.
  • JIT / Packaging

    • Adds a JIT module generator and re-exports it for easy import.
  • Tests

    • Unit tests for positive paths and extensive validation/error cases.
  • Chores

    • Import-path adjustments and pre-test bytecode cache cleanup.

@coderabbitai
Contributor

coderabbitai bot commented Oct 31, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a BF16→FP32 DSv3 router GEMM: device kernel header, CUDA host launcher with templated/unrolled token dispatch and TVM FFI export, Python loader/validation and JIT spec, tests, and test-script bytecode-cleanup steps.
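
In reference semantics the op is a plain matmul with bf16 inputs up-cast to fp32 before accumulation; a short sketch (not the kernel's code):

import torch

def router_gemm_reference(mat_a: torch.Tensor, mat_b: torch.Tensor) -> torch.Tensor:
    # mat_a: (num_tokens, 7168) bf16, mat_b: (7168, 256) bf16 -> (num_tokens, 256) fp32
    return mat_a.float() @ mat_b.float()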

Changes

  • CUDA host binding & launcher (csrc/dsv3_router_gemm.cu): New TVM FFI-exported dsv3_router_gemm_op; templated invokeRouterGemm, LoopUnroller unrolling tokens 1–16, dtype/layout/stride validation, optional programmatic stream serialization (PDL), explicit bf16 instantiations, and TVM FFI error/logging.
  • CUDA device header / kernel (include/flashinfer/gemm/dsv3_router_gemm.cuh): New device helpers and templated router_gemm_kernel: PTX FMA helper, bf16 uint4→float conversion, vectorized loads, per-token FMA accumulation, warp/shared reductions, and output writes (templated on block/VPT/tokens/experts/hidden).
  • Python integration & op (flashinfer/gemm/routergemm_dsv3.py): New module with shape/dtype/stride validation _mm_M1_16_K7168_N256_shape_checks, cached backend loader get_dsv3_router_gemm_module() that builds/loads/registers the FFI op, and public mm_M1_16_K7168_N256 wrapper delegating to the FFI op (with backend requirement and registration).
  • JIT spec & exports (flashinfer/jit/dsv3_optimizations.py, flashinfer/jit/__init__.py): New gen_dsv3_router_gemm_module() returning a JitSpec for dsv3_router_gemm.cu; re-exported from the package init.
  • Public API exposure (flashinfer/dsv3_ops/__init__.py, flashinfer/gemm/__init__.py): Re-exported mm_M1_16_K7168_N256 into the package API and added it to __all__.
  • Tests, new (tests/model_optimizations/test_dsv3_router_gemm.py): Positive numeric tests for tokens [1,2,3,5,8,13,16] (experts=256, hidden=7168) with optional launch_with_pdl; negative tests asserting ValueError for invalid shapes/dtypes/strides; runtime gated by SM100.
  • Import adjustments (flashinfer/gemm/gemm_base.py, tests/gemm/*): Adjusted several imports to parent-package (..) paths; updated tests to import some symbols from flashinfer.gemm.gemm_base.
  • Test/script cleanup (scripts/task_test_*.sh): Added pre-test cleanup steps to remove __pycache__ directories and .pyc files across multiple test scripts to avoid stale bytecode.
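
To make the Python-side wiring above concrete, here is a schematic of the validate → cached-build → delegate pattern it describes (a simplified sketch; decorators, device checks, and custom-op registration are omitted, and the loaded module's attribute access is an assumption):

import functools
import torch

def _mm_M1_16_K7168_N256_shape_checks(mat_a, mat_b, out):
    # Simplified subset of the checks listed above.
    if mat_a.dtype != torch.bfloat16 or mat_b.dtype != torch.bfloat16:
        raise ValueError("mat_a and mat_b must be bfloat16 tensors")
    if out.dtype != torch.float32:
        raise ValueError("out must be a float32 tensor")
    if not (1 <= mat_a.shape[0] <= 16) or mat_a.shape[1] != 7168 or mat_b.shape[1] != 256:
        raise ValueError("only num_tokens in [1, 16], hidden_dim 7168, num_experts 256 are supported")

@functools.cache
def get_dsv3_router_gemm_module():
    # Builds dsv3_router_gemm.cu once via the JitSpec and reuses the loaded module afterwards.
    from flashinfer.jit import gen_dsv3_router_gemm_module
    return gen_dsv3_router_gemm_module().build_and_load()

def mm_M1_16_K7168_N256(mat_a, mat_b, out, launch_with_pdl=False):
    _mm_M1_16_K7168_N256_shape_checks(mat_a, mat_b, out)
    get_dsv3_router_gemm_module().dsv3_router_gemm_op(mat_a, mat_b, out, launch_with_pdl)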

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant User
    participant PyAPI as routergemm_dsv3.py
    participant Validate as _mm_M1_16_K7168_N256_shape_checks
    participant ModuleCache as get_dsv3_router_gemm_module
    participant JIT as gen_dsv3_router_gemm_module
    participant FFI as TVM FFI: dsv3_router_gemm_op
    participant CUDA as router_gemm_kernel

    User->>PyAPI: mm_M1_16_K7168_N256(mat_a,mat_b,out,launch_with_pdl)
    PyAPI->>Validate: check shapes, dtypes, strides, device
    alt validation fails
        Validate-->>PyAPI: raise ValueError
        PyAPI-->>User: error
    else
        PyAPI->>ModuleCache: obtain compiled module
        alt first load
            ModuleCache->>JIT: build JitSpec (dsv3_router_gemm.cu)
            JIT-->>ModuleCache: compiled module
            ModuleCache->>FFI: register TVM FFI op
        end
        PyAPI->>FFI: call dsv3_router_gemm_op(mat_a,mat_b,out,launch_with_pdl)
        FFI->>CUDA: select templated invoke via LoopUnroller (tokens 1..16)
        CUDA->>CUDA: vector loads, bf16→float, FMA, warp/shared reductions
        CUDA-->>FFI: write out (fp32)
        FFI-->>PyAPI: return (out mutated)
        PyAPI-->>User: done
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Areas needing extra attention:
    • include/flashinfer/gemm/dsv3_router_gemm.cuh — PTX FMA, bf16→float conversion, vectorized loads, reduction correctness and edge cases.
    • csrc/dsv3_router_gemm.cu — template dispatch, LoopUnroller correctness, explicit instantiations, PDL/stream serialization and TVM FFI error/logging.
    • flashinfer/gemm/routergemm_dsv3.py & tests — validation messages, decorator registration, hardware gating and test SM100 guard.
    • JIT spec paths and module/op registration names.

Suggested reviewers

  • djmmoss
  • cyx-6
  • yongwww
  • aleozlx
  • bkryu

Poem

🐰
I hopped through kernels, swift and spry,
bf16 whispers turned to float on high.
Tokens marched in threads aligned,
Warps combined the sums I mined.
A crunchy carrot for every GPU sky!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 6.67%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title '[DSV3] Optimized Router Gemm' clearly and specifically describes the main change: adding an optimized router GEMM implementation for DSV3 problem sizes.
  • Description check (✅ Passed): The PR description includes a detailed description of changes and reasoning. It covers what the PR does, mentions positive and negative tests, reports performance improvements, and clearly documents a breaking change requiring stale file deletion.

@gemini-code-assist
Contributor

Summary of Changes

Hello @nvmbreughe, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a highly optimized router GEMM operation, specifically tailored for DeepSeek-V3 model architectures. By leveraging a kernel ported from TRTLLM, it significantly boosts performance for common token counts, offering substantial speedups. The changes include a new CUDA kernel, its Python binding, and thorough testing to ensure reliability and efficiency.

Highlights

  • Performance Optimization: Introduces an optimized router GEMM kernel specifically for DeepSeek-V3 (DSV3) problem sizes, achieving 1.86x to 2.5x speedups for 1-16 tokens.
  • TRTLLM Port: The optimization is ported from TRTLLM, leveraging existing high-performance implementations for efficiency.
  • New CUDA Kernel: Adds a new CUDA kernel (dsv3_router_gemm.cuh and dsv3_router_gemm.cu) for the specialized GEMM operation, including custom FMA and bfloat16 conversion.
  • Python Integration: Provides a Python interface (routergemm.py) with comprehensive shape, stride, and data type checks for the new operation, integrating it with the JIT compilation system.
  • Comprehensive Testing: Includes both positive and negative tests to ensure the numerical correctness and robust input validation of the new dsv3_router_gemm_op.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review: /gemini review - Performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary - Provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist - Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help: /gemini help - Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces an optimized router GEMM kernel for Deep Seek-V3, ported from TRT-LLM, which shows significant performance improvements. The changes include the CUDA kernel implementation, Python bindings, and corresponding tests. My review focuses on improving code clarity, fixing a bug related to a kernel launch parameter, and removing dead or redundant code. Overall, the changes are well-structured and the addition is valuable.

Comment on lines 107 to 116
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
                     __nv_bfloat16 const* weights, cudaStream_t stream) {
    if (num_tokens == kEnd) {
      invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights,
                                                                     stream);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }
};
Contributor

high

This specialization of LoopUnroller::unroll also needs to be updated to accept and forward the launch_with_pdl parameter to invokeRouterGemm.

struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
                     __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) {
    if (num_tokens == kEnd) {
      invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights,
                                                                     stream, launch_with_pdl);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }
};
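
For readers less familiar with the template recursion above: functionally, LoopUnroller maps a runtime num_tokens onto one of sixteen compile-time specializations of invokeRouterGemm and rejects anything outside 1–16. A rough Python analogue of that dispatch (illustrative only, with hypothetical names; the CUDA side is not written this way):

def _make_specialized_gemm(num_tokens):
    # Stand-in for invokeRouterGemm<__nv_bfloat16, num_tokens, 256, 7168>.
    def run(output, activations, weights, stream, launch_with_pdl):
        ...  # the real work happens in the specialized CUDA kernel
    return run

_SPECIALIZATIONS = {n: _make_specialized_gemm(n) for n in range(1, 17)}

def unroll(num_tokens, output, activations, weights, stream, launch_with_pdl):
    try:
        kernel = _SPECIALIZATIONS[num_tokens]
    except KeyError:
        raise ValueError("Invalid num_tokens, only supports 1 to 16") from None
    kernel(output, activations, weights, stream, launch_with_pdl)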

Comment on lines +126 to +125
constexpr int kNumExperts = 256;
constexpr int kHiddenDim = 7168;
Contributor

medium

The constexpr variables kNumExperts and kHiddenDim shadow the function-scope variables with similar names initialized from the input tensors. This can be confusing. Consider renaming them to clarify their purpose as compile-time constants for the specialized kernel. You'll also need to update their usage on lines 136 and 143.

  constexpr int kExpectedNumExperts = 256;
  constexpr int kExpectedHiddenDim = 7168;

Comment on lines 10 to 62
def _dvs3_router_gemm_shape_checks(mat_a, mat_b, out, launch_with_pdl, bias):
    # Dimension checks
    if mat_a.dim() != 2:
        raise ValueError("mat_a must be a 2D tensor")
    if mat_b.dim() != 2:
        raise ValueError("mat_b must be a 2D tensor")
    if out.dim() != 2:
        raise ValueError("out must be a 2D tensor")
    if bias is not None:
        raise ValueError("bias is not supported yet")

    # Stride checks (check these before dimension checks to give better error messages)
    if mat_a.stride(1) != 1:
        raise ValueError("mat_a must be row-major")
    if out.stride(1) != 1:
        raise ValueError("out must be row-major")
    if mat_b.stride(0) != 1:
        raise ValueError("mat_b must be column-major")

    if mat_a.shape[1] != mat_b.shape[0]:
        raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]")
    if out.shape[0] != mat_a.shape[0]:
        raise ValueError("out.shape[0] must be equal to mat_a.shape[0]")
    if out.shape[1] != mat_b.shape[1]:
        raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")

    # Problem size checks
    expected_hidden_dim = 7168
    expected_num_experts = 256
    min_tokens = 1
    max_tokens = 16
    if mat_a.shape[0] < min_tokens or mat_a.shape[0] > max_tokens:
        raise ValueError(
            f"mat_a.shape[0] (num_tokens) must be between {min_tokens} and {max_tokens}"
        )
    if mat_a.shape[1] != expected_hidden_dim:
        raise ValueError(
            f"mat_a.shape[1] (hidden_dim) must be equal to {expected_hidden_dim}"
        )
    if mat_b.shape[1] != expected_num_experts:
        raise ValueError(
            f"mat_b.shape[1] (num_experts) must be equal to {expected_num_experts}"
        )

    # Data type checks
    if mat_a.dtype != torch.bfloat16:
        raise ValueError("mat_a must be a bfloat16 tensor")
    if mat_b.dtype != torch.bfloat16:
        raise ValueError("mat_b must be a bfloat16 tensor")
    if out.dtype != torch.float32:
        raise ValueError("out must be a float32 tensor")

    return True
Contributor

medium

This helper function can be improved for clarity and correctness:

  • There is a typo in the function name: _dvs3_router_gemm_shape_checks should be _dsv3_router_gemm_shape_checks.
  • The launch_with_pdl parameter is unused and can be removed from the signature.
  • The function returns True, but this value is never used. The function should not have a return statement.
  • The comment on line 21 is misleading as stride checks are performed after dimension checks.

Please also update the call to this function on line 97.

def _dsv3_router_gemm_shape_checks(mat_a, mat_b, out, bias):
    # Dimension checks
    if mat_a.dim() != 2:
        raise ValueError("mat_a must be a 2D tensor")
    if mat_b.dim() != 2:
        raise ValueError("mat_b must be a 2D tensor")
    if out.dim() != 2:
        raise ValueError("out must be a 2D tensor")
    if bias is not None:
        raise ValueError("bias is not supported yet")

    # Stride checks
    if mat_a.stride(1) != 1:
        raise ValueError("mat_a must be row-major")
    if out.stride(1) != 1:
        raise ValueError("out must be row-major")
    if mat_b.stride(0) != 1:
        raise ValueError("mat_b must be column-major")

    if mat_a.shape[1] != mat_b.shape[0]:
        raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]")
    if out.shape[0] != mat_a.shape[0]:
        raise ValueError("out.shape[0] must be equal to mat_a.shape[0]")
    if out.shape[1] != mat_b.shape[1]:
        raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")

    # Problem size checks
    expected_hidden_dim = 7168
    expected_num_experts = 256
    min_tokens = 1
    max_tokens = 16
    if not (min_tokens <= mat_a.shape[0] <= max_tokens):
        raise ValueError(
            f"mat_a.shape[0] (num_tokens) must be between {min_tokens} and {max_tokens}"
        )
    if mat_a.shape[1] != expected_hidden_dim:
        raise ValueError(
            f"mat_a.shape[1] (hidden_dim) must be equal to {expected_hidden_dim}"
        )
    if mat_b.shape[1] != expected_num_experts:
        raise ValueError(
            f"mat_b.shape[1] (num_experts) must be equal to {expected_num_experts}"
        )

    # Data type checks
    if mat_a.dtype != torch.bfloat16:
        raise ValueError("mat_a must be a bfloat16 tensor")
    if mat_b.dtype != torch.bfloat16:
        raise ValueError("mat_b must be a bfloat16 tensor")
    if out.dtype != torch.float32:
        raise ValueError("out must be a float32 tensor")

Comment on lines +23 to +29
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
               : "=l"(reinterpret_cast<uint64_t&>(d))
               : "l"(reinterpret_cast<uint64_t const&>(a)),
                 "l"(reinterpret_cast<uint64_t const&>(b)),
                 "l"(reinterpret_cast<uint64_t const&>(c)));
}
Contributor

medium

This fma function is defined but appears to be unused in this file. It should be removed to avoid dead code.

Comment on lines +107 to +109
int const warpSize = 32;
int const warpId = tid / warpSize;
int const laneId = tid % warpSize;
Contributor

medium

The warpSize constant is redundant as kWarpSize is already defined on line 48. Using kWarpSize consistently improves readability.

  int const warpId = tid / kWarpSize;
  int const laneId = tid % kWarpSize;

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f9cd034 and 550a0f6.

📒 Files selected for processing (6)
  • csrc/dsv3_router_gemm.cu (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/dsv3_optimizations.py (1 hunks)
  • flashinfer/model_optimizations/dsv3/routergemm.py (1 hunks)
  • include/flashinfer/gemm/dsv3_router_gemm.cuh (1 hunks)
  • tests/model_optimizations/test_dsv3_router_gemm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
flashinfer/jit/__init__.py (1)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_router_gemm_module (5-11)
tests/model_optimizations/test_dsv3_router_gemm.py (2)
csrc/dsv3_router_gemm.cu (6)
  • num_tokens (94-103)
  • num_tokens (94-95)
  • num_tokens (108-116)
  • num_tokens (108-109)
  • dsv3_router_gemm_op (119-150)
  • dsv3_router_gemm_op (119-120)
flashinfer/model_optimizations/dsv3/routergemm.py (2)
  • dsv3_router_gemm_op (73-80)
  • dsv3_router_gemm_op (90-100)
flashinfer/jit/dsv3_optimizations.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
csrc/dsv3_router_gemm.cu (2)
csrc/tvm_ffi_utils.h (2)
  • get_stream (272-274)
  • encode_dlpack_dtype (29-31)
flashinfer/model_optimizations/dsv3/routergemm.py (2)
  • dsv3_router_gemm_op (73-80)
  • dsv3_router_gemm_op (90-100)
flashinfer/model_optimizations/dsv3/routergemm.py (3)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_router_gemm_module (5-11)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/dsv3_router_gemm.cu (2)
  • dsv3_router_gemm_op (119-150)
  • dsv3_router_gemm_op (119-120)
🪛 Ruff (0.14.2)
flashinfer/model_optimizations/dsv3/routergemm.py
  • 10-10: Unused function argument: launch_with_pdl (ARG001)
  • 13-13, 15-15, 17-17, 19-19, 23-23, 25-25, 27-27, 30-30, 32-32, 34-34, 42-44, 46-48, 50-52, 56-56, 58-58, 60-60: Avoid specifying long messages outside the exception class (TRY003)

@nvmbreughe nvmbreughe marked this pull request as draft November 3, 2025 19:25
@nvmbreughe nvmbreughe force-pushed the mbreughe/dsv3_router branch from 550a0f6 to 75fcd9a Compare November 4, 2025 16:40
@nvmbreughe nvmbreughe marked this pull request as ready for review November 4, 2025 16:41
@nvmbreughe
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !109 has been created, and the CI pipeline #37879756 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 550a0f6 and 75fcd9a.

📒 Files selected for processing (8)
  • csrc/dsv3_router_gemm.cu (1 hunks)
  • flashinfer/dsv3_ops/__init__.py (1 hunks)
  • flashinfer/gemm/gemm_base.py (3 hunks)
  • flashinfer/gemm/routergemm_dsv3.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/dsv3_optimizations.py (1 hunks)
  • include/flashinfer/gemm/dsv3_router_gemm.cuh (1 hunks)
  • tests/model_optimizations/test_dsv3_router_gemm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/dsv3_router_gemm.cu
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer/jit/__init__.py (1)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_router_gemm_module (5-11)
flashinfer/dsv3_ops/__init__.py (1)
flashinfer/gemm/routergemm_dsv3.py (2)
  • routergemm_dsv3_hidden_7168_experts_256_tokens_lt16 (73-80)
  • routergemm_dsv3_hidden_7168_experts_256_tokens_lt16 (90-100)
tests/model_optimizations/test_dsv3_router_gemm.py (2)
flashinfer/gemm/routergemm_dsv3.py (2)
  • routergemm_dsv3_hidden_7168_experts_256_tokens_lt16 (73-80)
  • routergemm_dsv3_hidden_7168_experts_256_tokens_lt16 (90-100)
csrc/dsv3_router_gemm.cu (4)
  • num_tokens (94-103)
  • num_tokens (94-95)
  • num_tokens (108-116)
  • num_tokens (108-109)
flashinfer/gemm/routergemm_dsv3.py (2)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_router_gemm_module (5-11)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
flashinfer/gemm/gemm_base.py (4)
flashinfer/autotuner.py (6)
  • AutoTuner (335-784)
  • ConstraintSpec (86-97)
  • DynamicTensorSpec (41-82)
  • OptimizationProfile (168-183)
  • TunableRunner (194-247)
  • TuningConfig (101-141)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
flashinfer/jit/gemm/core.py (9)
  • gen_gemm_sm90_module (468-507)
  • gen_gemm_module (37-46)
  • gen_gemm_sm100_module (193-269)
  • gen_gemm_sm120_module (272-354)
  • gen_gemm_sm120_module_cutlass_fp4 (97-140)
  • gen_gemm_sm100_module_cutlass_fp4 (49-94)
  • gen_gemm_sm100_module_cutlass_fp8 (143-190)
  • gen_trtllm_gen_gemm_module (357-390)
  • gen_tgv_gemm_sm10x_module (393-465)
flashinfer/jit/cubin_loader.py (1)
  • setup_cubin_loader (222-242)
flashinfer/jit/dsv3_optimizations.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
🪛 Ruff (0.14.3)
flashinfer/gemm/routergemm_dsv3.py
  • 10-10: Unused function argument: launch_with_pdl (ARG001)
  • 13-13, 15-15, 17-17, 19-19, 23-23, 25-25, 27-27, 30-30, 32-32, 34-34, 42-44, 46-48, 50-52, 56-56, 58-58, 60-60: Avoid specifying long messages outside the exception class (TRY003)

Comment on lines +21 to +54
    # Stride checks (check these before dimension checks to give better error messages)
    if mat_a.stride(1) != 1:
        raise ValueError("mat_a must be row-major")
    if out.stride(1) != 1:
        raise ValueError("out must be row-major")
    if mat_b.stride(0) != 1:
        raise ValueError("mat_b must be column-major")

    if mat_a.shape[1] != mat_b.shape[0]:
        raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]")
    if out.shape[0] != mat_a.shape[0]:
        raise ValueError("out.shape[0] must be equal to mat_a.shape[0]")
    if out.shape[1] != mat_b.shape[1]:
        raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")

    # Problem size checks
    expected_hidden_dim = 7168
    expected_num_experts = 256
    min_tokens = 1
    max_tokens = 16
    if mat_a.shape[0] < min_tokens or mat_a.shape[0] > max_tokens:
        raise ValueError(
            f"mat_a.shape[0] (num_tokens) must be between {min_tokens} and {max_tokens}"
        )
    if mat_a.shape[1] != expected_hidden_dim:
        raise ValueError(
            f"mat_a.shape[1] (hidden_dim) must be equal to {expected_hidden_dim}"
        )
    if mat_b.shape[1] != expected_num_experts:
        raise ValueError(
            f"mat_b.shape[1] (num_experts) must be equal to {expected_num_experts}"
        )
Contributor

⚠️ Potential issue | 🟠 Major

Guard against strided views that break the kernel layout

The CUDA kernel indexes A/B/out with hard-coded strides (m_idx * kHiddenDim, n_idx * kHiddenDim, m * kNumExperts). If a caller passes a legal 2D bf16 view that shares storage—e.g. mat_a = base[:, :7168] or out = torch.empty(num_tokens, num_experts + 8)[:, :num_experts]—all current checks still pass (stride(1) == 1 / stride(0) == 1). Once inside the kernel we then march with the wrong stride, reading or writing across rows/columns and returning garbage. Please block these cases explicitly.

Apply a stride check for the remaining axes before launching:

@@
-    if mat_a.shape[1] != mat_b.shape[0]:
+    if mat_a.shape[1] != mat_b.shape[0]:
         raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]")
@@
     if out.shape[1] != mat_b.shape[1]:
         raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")
+    if mat_a.stride(0) != expected_hidden_dim:
+        raise ValueError("mat_a must be contiguous row-major with stride(0) == hidden_dim")
+    if mat_b.stride(1) != expected_hidden_dim:
+        raise ValueError("mat_b must be column-major with stride(1) == hidden_dim")
+    if out.stride(0) != expected_num_experts:
+        raise ValueError("out must be contiguous row-major with stride(0) == num_experts")

Without these guards we silently corrupt data on perfectly legal PyTorch views, which is a correctness bug.

Committable suggestion skipped: line range outside the PR's diff.
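
A small demonstration of the failure mode described above (assumes only torch; the tensor names are hypothetical): a sliced view keeps the unit inner stride that the current checks look at, but its leading stride no longer matches the pitch the kernel hard-codes.

import torch

base = torch.empty(16, 7168 + 64, dtype=torch.bfloat16)
mat_a = base[:, :7168]          # legal 2D bf16 view sharing storage with `base`

assert mat_a.shape == (16, 7168)
assert mat_a.stride(1) == 1     # passes the existing "row-major" check
assert mat_a.stride(0) == 7232  # but the kernel would walk rows with a pitch of 7168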

🧰 Tools
🪛 Ruff (0.14.3)
  • 23-23, 25-25, 27-27, 30-30, 32-32, 34-34, 42-44, 46-48, 50-52: Avoid specifying long messages outside the exception class (TRY003)

🤖 Prompt for AI Agents
In flashinfer/gemm/routergemm_dsv3.py around lines 21 to 52, the kernel assumes
fixed hard-coded row/column strides (kHiddenDim and kNumExperts) but the current
checks only verify unit strides and allow strided views that break layout; add
explicit stride equality checks: require mat_a.stride(0) == expected_hidden_dim,
mat_b.stride(1) == expected_hidden_dim, and out.stride(0) ==
expected_num_experts (in addition to the existing unit-stride checks) and raise
ValueError with a clear message if any of these mismatch so views that share
storage but have different layout are rejected before launch.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #37879756: 1/17 passed

@nvmbreughe
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !109 has been updated with latest changes, and the CI pipeline #37883369 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #37883369: canceled

@nvmbreughe nvmbreughe force-pushed the mbreughe/dsv3_router branch from 6ea619b to f1704b9 Compare November 4, 2025 22:32
@nvmbreughe
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !109 has been updated with latest changes, and the CI pipeline #37897792 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (5)
include/flashinfer/gemm/dsv3_router_gemm.cuh (4)

23-29: Remove unused function or explain its purpose.

This function was previously flagged as unused. If it's intended for future use or external callers, please add a comment explaining its purpose. Otherwise, remove it to avoid dead code.


51-52: Protect against dropping tail K tiles.

This issue was previously flagged. Integer division silently drops any remaining elements when kHiddenDim is not divisible by VPT * kBlockSize, causing incorrect GEMM results.


64-64: Remove commented-out duplicate declaration.

Line 64 is a duplicate of line 65 and should be removed as previously noted.


107-109: Use existing kWarpSize constant consistently.

The warpSize variable is redundant since kWarpSize is already defined at line 48. Using the constant consistently improves readability, as previously noted.

flashinfer/gemm/routergemm_dsv3.py (1)

16-66: Address previously flagged issues in shape validation.

Multiple issues were previously identified in this function:

  1. Typo in function name: _dvs3 should be _dsv3 (note the order of 'v' and 's')
  2. Unused parameter: launch_with_pdl is never used in the function body
  3. Unnecessary return: The function returns True but this value is never checked
  4. Misleading comment: Line 25 says "check these before dimension checks" but dimension checks (lines 18-23) actually come first
  5. Missing stride checks: The kernel assumes fixed strides but only one stride per tensor is validated, allowing strided views that would corrupt results
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6ea619b and f1704b9.

📒 Files selected for processing (8)
  • csrc/dsv3_router_gemm.cu (1 hunks)
  • flashinfer/dsv3_ops/__init__.py (1 hunks)
  • flashinfer/gemm/gemm_base.py (3 hunks)
  • flashinfer/gemm/routergemm_dsv3.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/dsv3_optimizations.py (1 hunks)
  • include/flashinfer/gemm/dsv3_router_gemm.cuh (1 hunks)
  • tests/model_optimizations/test_dsv3_router_gemm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/gemm/gemm_base.py
🔇 Additional comments (1)
flashinfer/gemm/routergemm_dsv3.py (1)

92-107: LGTM with noted dependencies.

The public API function correctly validates compute capability and delegates to the backend module. Note that effectiveness depends on the shape validation function addressing the previously flagged stride check issues.

Comment on lines +133 to +145
  if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
      encode_dlpack_dtype(out_dtype_) == float32_code) {
    use_custom_kernel = true;
  }

  if (use_custom_kernel) {
    LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll(
        num_tokens, reinterpret_cast<float*>(out.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl);
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input tensor size";
Contributor

⚠️ Potential issue | 🔴 Critical

Add missing mat_b validation before launching the kernel

dsv3_router_gemm_op currently accepts any dtype/leading dimension for mat_b, yet it unconditionally reinterprets the storage as __nv_bfloat16 and assumes exactly hidden_dim entries per expert. Supplying a float32 weight matrix or a column-major tensor with fewer than 7168 rows will silently corrupt results and can read past the allocated buffer. Please reject these configurations up front.

@@
-  auto const out_dtype_ = out.dtype();
-  auto const data_type = mat_a.dtype();
+  auto const out_dtype_ = out.dtype();
+  auto const data_type = mat_a.dtype();
+  auto const mat_b_dtype = mat_b.dtype();
@@
-  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_b.sizes()[0] == hidden_dim)
+      << "mat_b first dimension must equal hidden_dim";
+  TVM_FFI_ICHECK(encode_dlpack_dtype(mat_b_dtype) == bfloat16_code)
+      << "mat_b must be bfloat16";
@@
-      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
-      encode_dlpack_dtype(out_dtype_) == float32_code) {
+      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
+      encode_dlpack_dtype(out_dtype_) == float32_code) {

Comment on lines +32 to +40
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
  __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));

#pragma unroll
  for (int i = 0; i < VPT; i++) {
    dst[i] = __bfloat162float(bf16_ptr[i]);
  }
}
Contributor

🛠️ Refactor suggestion | 🟠 Major

Add compile-time check for VPT parameter.

The function assumes VPT == 8 since a uint4 contains exactly 8 bfloat16 values (16 bytes ÷ 2 bytes/bf16). If called with a different VPT, the loop will either under-read or overflow the buffer.

Apply this diff to enforce the invariant:

 template <int VPT>
 __device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
+  static_assert(VPT == 8, "bf16_uint4_to_float8 requires VPT == 8 (uint4 holds exactly 8 bf16 values)");
   __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));
🤖 Prompt for AI Agents
In include/flashinfer/gemm/dsv3_router_gemm.cuh around lines 32 to 40, the
template assumes VPT equals 8 (uint4 holds eight bfloat16s) but doesn't enforce
it; add a compile-time check such as a static_assert that VPT == 8 with a clear
message so any misuse fails to compile, ensuring the loop neither under-reads
nor overflows the uint4 buffer.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #37897792: 1/17 passed

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
scripts/task_test_nightly_build.sh (1)

8-20: Optional refactor: Consolidate duplicated cache cleanup logic.

The identical 3-line cache cleanup sequence is now replicated across all 5 test scripts (task_test_blackwell_kernels.sh, task_test_single_node_comm_kernels.sh, task_test_jit_cache_package_build_import.sh, task_test_multi_node_comm_kernels.sh, task_test_nightly_build.sh). Consider extracting to a common shell function or separate cleanup utility to reduce duplication and simplify future maintenance.

# Example: scripts/lib/cleanup_python_cache.sh
cleanup_python_cache() {
    echo "Cleaning Python bytecode cache..."
    find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
    find . -type f -name '*.pyc' -delete 2>/dev/null || true
    echo "Cache cleaned."
    echo ""
}

Then source and invoke from each test script:

source scripts/lib/cleanup_python_cache.sh
cleanup_python_cache
flashinfer/gemm/__init__.py (1)

21-35: Consider sorting __all__ for consistency.

Static analysis suggests applying isort-style sorting to the __all__ list for better maintainability and consistency with project conventions.

tests/gemm/test_tgv_gemm.py (1)

9-9: Consider exposing _match_sm_version as a public utility or refactoring test.

The import path update is correct. However, tests importing a private symbol (leading underscore) suggests it may deserve public status. If tests consistently need to check SM version compatibility, consider either making this a public utility function or refactoring tests to rely on exception handling from public APIs instead.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f1704b9 and 20421aa.

📒 Files selected for processing (8)
  • flashinfer/gemm/__init__.py (1 hunks)
  • scripts/task_test_blackwell_kernels.sh (1 hunks)
  • scripts/task_test_jit_cache_package_build_import.sh (1 hunks)
  • scripts/task_test_multi_node_comm_kernels.sh (1 hunks)
  • scripts/task_test_nightly_build.sh (1 hunks)
  • scripts/task_test_single_node_comm_kernels.sh (1 hunks)
  • tests/gemm/test_mm_fp4.py (1 hunks)
  • tests/gemm/test_tgv_gemm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/gemm_base.py (9)
  • SegmentGEMMWrapper (845-1098)
  • mm_fp4 (1862-2013)
  • mm_fp8 (1586-1693)
  • tgv_gemm_sm100 (583-654)
  • batch_deepgemm_fp8_nt_groupwise (2905-3037)
  • group_deepgemm_fp8_nt_groupwise (2775-2902)
  • gemm_fp8_nt_blockscaled (2418-2443)
  • gemm_fp8_nt_groupwise (2110-2273)
  • group_gemm_fp8_nt_groupwise (2446-2605)
tests/gemm/test_tgv_gemm.py (1)
flashinfer/gemm/gemm_base.py (1)
  • _match_sm_version (90-93)
🪛 Ruff (0.14.3)
flashinfer/gemm/__init__.py
  • 21-35: __all__ is not sorted; apply an isort-style sorting to __all__ (RUF022)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
scripts/task_test_blackwell_kernels.sh (1)

9-14: Cache cleanup placement is appropriate for pre-test import freshness.

The bytecode cache cleanup is correctly positioned before pytest execution and includes proper error suppression. This addresses stale imports from the module refactoring (gemm.py → gemm/gemm_base.py) mentioned in the PR description.

scripts/task_test_single_node_comm_kernels.sh (1)

8-13: Pre-install cleanup ensures stale bytecode doesn't interfere with fresh imports.

Cache cleanup positioned correctly before pip install -e ., preventing stale .pyc files from old module structure from shadowing new imports.

scripts/task_test_jit_cache_package_build_import.sh (1)

31-35: Early cleanup ensures detection and installation workflows start with clean Python imports.

Bytecode cache removal precedes Python detection logic (line 39), ensuring CUDA architecture detection and subsequent installations operate on fresh imports without stale module references.

scripts/task_test_multi_node_comm_kernels.sh (1)

8-13: Pre-install cleanup maintains consistency with other test scripts.

Cache cleanup positioned before pip install -e ., aligned with the standardized approach across the test suite for handling stale imports from module refactoring.

scripts/task_test_nightly_build.sh (1)

15-20: Cache cleanup placed early in nightly workflow for comprehensive environment reset.

Bytecode cleanup positioned at script entry before GPU diagnostics and multi-phase installations, ensuring no stale imports affect the full nightly build and test cycle.

tests/gemm/test_mm_fp4.py (1)

12-12: Import path change is correct but compensates for missing re-export.

The import path was updated to bypass the package-level __init__.py because CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR isn't re-exported there (despite being in __all__). Once the missing import is added to flashinfer/gemm/__init__.py, this could revert to the cleaner package-level import: from flashinfer.gemm import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR.

@nvmbreughe nvmbreughe force-pushed the mbreughe/dsv3_router branch from 20421aa to a08e0a8 Compare November 5, 2025 22:24
@nvmbreughe
Contributor Author

/bot run

@nvmbreughe nvmbreughe force-pushed the mbreughe/dsv3_router branch from 3a88a7d to 837d76f Compare November 7, 2025 17:42
@nvmbreughe nvmbreughe enabled auto-merge (squash) November 7, 2025 17:43
@nvmbreughe nvmbreughe disabled auto-merge November 7, 2025 17:44
@nvmbreughe nvmbreughe enabled auto-merge (squash) November 7, 2025 17:44
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
csrc/dsv3_router_gemm.cu (1)

123-140: Reject non-bfloat16 mat_b before launch.

Line 141 reinterprets mat_b.data_ptr() as __nv_bfloat16 const*, but we never verify that mat_b actually carries bf16 elements. A caller can legally pass a float32 or fp8 column-major tensor (the stride checks still pass), and we would then reinterpret 4‑byte scalars as bf16 pairs, silently corrupting the GEMM inputs. Please gate the custom‑kernel path on mat_b.dtype() and emit an explicit check before the cast.

-  auto const out_dtype_ = out.dtype();
-  auto const data_type = mat_a.dtype();
+  auto const out_dtype_ = out.dtype();
+  auto const data_type = mat_a.dtype();
+  auto const mat_b_dtype = mat_b.dtype();
@@
-  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(encode_dlpack_dtype(mat_b_dtype) == bfloat16_code)
+      << "mat_b must be bfloat16";
@@
-  if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
-      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
-      encode_dlpack_dtype(out_dtype_) == float32_code) {
+  if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
+      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
+      encode_dlpack_dtype(mat_b_dtype) == bfloat16_code &&
+      encode_dlpack_dtype(out_dtype_) == float32_code) {
flashinfer/gemm/routergemm_dsv3.py (1)

24-63: Reject strided views that break the kernel’s layout assumptions.

The CUDA kernel uses fixed pitches (hidden_dim / num_experts) when it walks A, B, and out. Today we only check the unit inner stride, so a caller can pass mat_a = base[:, :7168], mat_b = base[:, :256], or out = base[:, :256]—all valid PyTorch views with stride(0) > expected. Those configurations sail through this validator and then the kernel reads/writes with the wrong pitch. Please enforce the leading strides explicitly.

     if mat_b.stride(0) != 1:
         raise ValueError("mat_b must be column-major")
 
@@
     if out.shape[1] != mat_b.shape[1]:
         raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")
+    if mat_a.stride(0) != expected_hidden_dim:
+        raise ValueError(
+            "mat_a must be contiguous row-major with stride(0) == hidden_dim"
+        )
+    if mat_b.stride(1) != expected_hidden_dim:
+        raise ValueError(
+            "mat_b must be column-major with stride(1) == hidden_dim"
+        )
+    if out.stride(0) != expected_num_experts:
+        raise ValueError(
+            "out must be contiguous row-major with stride(0) == num_experts"
+        )
🧹 Nitpick comments (2)
flashinfer/gemm/__init__.py (1)

1-34: LGTM! Public API correctly established for GEMM module.

All symbols in __all__ are properly imported and re-exported. The new mm_M1_16_K7168_N256 is correctly integrated from the DSv3 router GEMM implementation.

The static analysis tool suggests sorting __all__ alphabetically for consistency. While this is purely stylistic, it can improve maintainability:

 __all__ = [
-    "SegmentGEMMWrapper",
-    "bmm_fp8",
-    "mm_fp4",
-    "mm_fp8",
-    "tgv_gemm_sm100",
-    "group_gemm_mxfp4_nt_groupwise",
     "batch_deepgemm_fp8_nt_groupwise",
-    "group_deepgemm_fp8_nt_groupwise",
+    "bmm_fp8",
     "gemm_fp8_nt_blockscaled",
     "gemm_fp8_nt_groupwise",
+    "group_deepgemm_fp8_nt_groupwise",
+    "group_gemm_mxfp4_nt_groupwise",
     "group_gemm_fp8_nt_groupwise",
     "mm_M1_16_K7168_N256",
+    "mm_fp4",
+    "mm_fp8",
+    "SegmentGEMMWrapper",
+    "tgv_gemm_sm100",
 ]
tests/model_optimizations/test_dsv3_router_gemm.py (1)

8-28: LGTM! Comprehensive positive test coverage.

The test properly validates the DSv3 router GEMM operation across various token counts (1-16) and PDL modes, with appropriate SM100-only gating and numerical accuracy checks (cosine similarity > 0.99).

Line 23 uses torch.randn for the output tensor, which pre-fills it with random values. While the operation should overwrite these, using torch.empty would be more conventional and slightly more efficient:

-    out = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
+    out = torch.empty(num_tokens, num_experts, device="cuda", dtype=torch.float32)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6975d22 and 837d76f.

📒 Files selected for processing (17)
  • csrc/dsv3_router_gemm.cu (1 hunks)
  • flashinfer/dsv3_ops/__init__.py (1 hunks)
  • flashinfer/gemm/__init__.py (1 hunks)
  • flashinfer/gemm/gemm_base.py (6 hunks)
  • flashinfer/gemm/routergemm_dsv3.py (1 hunks)
  • flashinfer/jit/__init__.py (1 hunks)
  • flashinfer/jit/dsv3_optimizations.py (1 hunks)
  • include/flashinfer/gemm/dsv3_router_gemm.cuh (1 hunks)
  • scripts/task_test_blackwell_kernels.sh (1 hunks)
  • scripts/task_test_jit_cache_package_build_import.sh (1 hunks)
  • scripts/task_test_multi_node_comm_kernels.sh (1 hunks)
  • scripts/task_test_nightly_build.sh (1 hunks)
  • scripts/task_test_single_node_comm_kernels.sh (1 hunks)
  • tests/gemm/test_group_gemm.py (2 hunks)
  • tests/gemm/test_mm_fp4.py (1 hunks)
  • tests/gemm/test_tgv_gemm.py (1 hunks)
  • tests/model_optimizations/test_dsv3_router_gemm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (6)
  • scripts/task_test_blackwell_kernels.sh
  • flashinfer/jit/__init__.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/dsv3_optimizations.py
  • scripts/task_test_multi_node_comm_kernels.sh
  • tests/gemm/test_group_gemm.py
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer/dsv3_ops/__init__.py (1)
flashinfer/gemm/routergemm_dsv3.py (2)
  • mm_M1_16_K7168_N256 (75-81)
  • mm_M1_16_K7168_N256 (89-134)
csrc/dsv3_router_gemm.cu (1)
csrc/tvm_ffi_utils.h (2)
  • get_stream (272-274)
  • encode_dlpack_dtype (29-31)
tests/model_optimizations/test_dsv3_router_gemm.py (3)
flashinfer/gemm/routergemm_dsv3.py (2)
  • mm_M1_16_K7168_N256 (75-81)
  • mm_M1_16_K7168_N256 (89-134)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
csrc/dsv3_router_gemm.cu (4)
  • num_tokens (93-102)
  • num_tokens (93-94)
  • num_tokens (107-115)
  • num_tokens (107-108)
flashinfer/gemm/routergemm_dsv3.py (4)
flashinfer/jit/dsv3_optimizations.py (1)
  • gen_dsv3_router_gemm_module (5-11)
flashinfer/utils.py (2)
  • supported_compute_capability (773-853)
  • backend_requirement (856-1131)
flashinfer/jit/core.py (1)
  • build_and_load (300-312)
csrc/dsv3_router_gemm.cu (2)
  • dsv3_router_gemm_op (118-147)
  • dsv3_router_gemm_op (118-118)
tests/gemm/test_tgv_gemm.py (1)
flashinfer/gemm/gemm_base.py (1)
  • _match_sm_version (90-93)
flashinfer/gemm/__init__.py (2)
flashinfer/gemm/gemm_base.py (9)
  • SegmentGEMMWrapper (829-1082)
  • mm_fp4 (1846-1997)
  • mm_fp8 (1570-1677)
  • tgv_gemm_sm100 (567-638)
  • batch_deepgemm_fp8_nt_groupwise (2981-3113)
  • group_deepgemm_fp8_nt_groupwise (2851-2978)
  • gemm_fp8_nt_blockscaled (2494-2519)
  • gemm_fp8_nt_groupwise (2186-2349)
  • group_gemm_fp8_nt_groupwise (2522-2681)
flashinfer/gemm/routergemm_dsv3.py (2)
  • mm_M1_16_K7168_N256 (75-81)
  • mm_M1_16_K7168_N256 (89-134)
🪛 Ruff (0.14.3)
flashinfer/gemm/routergemm_dsv3.py
  • 14-14: Unused function argument: launch_with_pdl (ARG001)
  • 17-17, 19-19, 21-21, 25-25, 27-27, 29-29, 32-32, 34-34, 36-36, 44-46, 48-50, 52-54, 58-58, 60-60, 62-62: Avoid specifying long messages outside the exception class (TRY003)
flashinfer/gemm/__init__.py
  • 21-34: __all__ is not sorted; apply an isort-style sorting to __all__ (RUF022)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (7)
scripts/task_test_single_node_comm_kernels.sh (1)

8-14: LGTM! Defensive cleanup to prevent stale imports.

The bytecode cache cleanup is a sensible precaution before package installation, especially given the module refactoring in this PR (e.g., flashinfer/gemm.py → flashinfer/gemm/gemm_base.py). The error suppression with || true prevents spurious failures when caches don't exist.

scripts/task_test_jit_cache_package_build_import.sh (1)

31-36: LGTM! Consistent cleanup pattern.

Same defensive cleanup as in other test scripts, appropriately placed before CUDA architecture detection and package installation.

scripts/task_test_nightly_build.sh (1)

15-21: LGTM! Consistent cleanup pattern.

The bytecode cache cleanup follows the same pattern established in other test scripts and is appropriately placed before package installation steps.

tests/gemm/test_mm_fp4.py (1)

12-12: LGTM! Import path updated for module restructuring.

The import now uses the more specific flashinfer.gemm.gemm_base module, which is consistent with the PR's refactoring of the GEMM module structure.

flashinfer/dsv3_ops/__init__.py (1)

1-5: LGTM! Clean public API for DSv3 operations.

The module correctly establishes a focused API surface for DSv3-specific operations by re-exporting mm_M1_16_K7168_N256 from the underlying GEMM implementation.

tests/gemm/test_tgv_gemm.py (1)

9-9: LGTM! Import path updated for module restructuring.

The import path update to flashinfer.gemm.gemm_base aligns with the GEMM module refactoring in this PR.

tests/model_optimizations/test_dsv3_router_gemm.py (1)

31-137: LGTM! Thorough negative test coverage.

The negative tests comprehensively validate input constraints including:

  • Boundary conditions for num_tokens [1-16], num_experts [256], and hidden_dim [7168]
  • Data type requirements (bfloat16 inputs, float32 output)
  • Layout requirements (column-major mat_b)

All error cases correctly use pytest.raises(ValueError, match=...) to validate both exception type and message content.
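
For instance, one such negative case might look like this (a sketch in the same style; the actual tests in the PR may differ in details such as SM gating and parametrization):

import pytest
import torch
from flashinfer.dsv3_ops import mm_M1_16_K7168_N256

def test_rejects_non_bf16_input():
    mat_a = torch.randn(4, 7168, device="cuda", dtype=torch.float32)  # wrong dtype on purpose
    mat_b = torch.randn(256, 7168, device="cuda", dtype=torch.bfloat16).t()
    out = torch.empty(4, 256, device="cuda", dtype=torch.float32)
    with pytest.raises(ValueError, match="mat_a must be a bfloat16 tensor"):
        mm_M1_16_K7168_N256(mat_a, mat_b, out)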

Comment on lines +128 to +144
  TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1)
      << "mat_a and out must be row-major";
  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
  auto stream = get_stream(mat_a.device());
  bool use_custom_kernel = false;
  if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
      encode_dlpack_dtype(out_dtype_) == float32_code) {
    use_custom_kernel = true;
  }

  if (use_custom_kernel) {
    LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll(
        num_tokens, reinterpret_cast<float*>(out.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl);
  } else {
Contributor

⚠️ Potential issue | 🔴 Critical

Enforce leading strides to match the kernel layout.

The kernel indexes mat_a, mat_b, and out with hard-coded strides (m_idx * kHiddenDim, n_idx * kHiddenDim, m * kNumExperts). With the current guard, a perfectly legal PyTorch view such as out = base[:, :256] or mat_a = base[:, :7168] passes (unit inner stride) but has a larger leading stride. In those cases we’ll march with the wrong pitch and either read garbage or write into the wrong rows. Please reject such views up front.

   TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1)
       << "mat_a and out must be row-major";
   TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_a.strides()[0] == hidden_dim)
+      << "mat_a must have stride(0) == hidden_dim";
+  TVM_FFI_ICHECK(mat_b.strides()[1] == hidden_dim)
+      << "mat_b must have stride(1) == hidden_dim";
+  TVM_FFI_ICHECK(out.strides()[0] == num_experts)
+      << "out must have stride(0) == num_experts";
🤖 Prompt for AI Agents
In csrc/dsv3_router_gemm.cu around lines 128 to 144, the current guards only
enforce inner-unit strides but do not ensure the leading (row/column) strides
match the kernel's hard-coded pitches; add explicit checks that
mat_a.strides()[0] == kHiddenDim (row-major pitch), mat_b.strides()[1] ==
kHiddenDim (column-major pitch), and out.strides()[0] == kNumExperts (output row
pitch) and fail early with TVM_FFI_ICHECK if any of these don't hold so views
with larger leading strides are rejected before launching the custom kernel.

Collaborator

@yzh119 yzh119 left a comment

LGTM

Collaborator

@bkryu bkryu left a comment

LGTM!

Collaborator

@cyx-6 cyx-6 left a comment

Looks great!

@nvmbreughe nvmbreughe merged commit c8f2b03 into flashinfer-ai:main Nov 7, 2025
4 checks passed