[DSV3] Optimized Router Gemm #2019
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a BF16→FP32 DSv3 router GEMM: device kernel header, CUDA host launcher with templated/unrolled token dispatch and TVM FFI export, Python loader/validation and JIT spec, tests, and test-script bytecode-cleanup steps.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant PyAPI as routergemm_dsv3.py
    participant Validate as _mm_M1_16_K7168_N256_shape_checks
    participant ModuleCache as get_dsv3_router_gemm_module
    participant JIT as gen_dsv3_router_gemm_module
    participant FFI as TVM FFI: dsv3_router_gemm_op
    participant CUDA as router_gemm_kernel
    User->>PyAPI: mm_M1_16_K7168_N256(mat_a,mat_b,out,launch_with_pdl)
    PyAPI->>Validate: check shapes, dtypes, strides, device
    alt validation fails
        Validate-->>PyAPI: raise ValueError
        PyAPI-->>User: error
    else
        PyAPI->>ModuleCache: obtain compiled module
        alt first load
            ModuleCache->>JIT: build JitSpec (dsv3_router_gemm.cu)
            JIT-->>ModuleCache: compiled module
            ModuleCache->>FFI: register TVM FFI op
        end
        PyAPI->>FFI: call dsv3_router_gemm_op(mat_a,mat_b,out,launch_with_pdl)
        FFI->>CUDA: select templated invoke via LoopUnroller (tokens 1..16)
        CUDA->>CUDA: vector loads, bf16→float, FMA, warp/shared reductions
        CUDA-->>FFI: write out (fp32)
        FFI-->>PyAPI: return (out mutated)
        PyAPI-->>User: done
    end
```
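The token-count dispatch shown in the diagram can be sketched in plain Python (hypothetical names; the real code resolves the choice at compile time through the templated `LoopUnroller`, so each supported token count gets a fully specialized kernel):

```python
# Sketch of the 1..16 token dispatch performed by LoopUnroller.
# Each supported token count maps to a kernel specialized for that count.
MIN_TOKENS, MAX_TOKENS = 1, 16

def make_specialized_kernel(num_tokens):
    # Stand-in for invokeRouterGemm<bf16, num_tokens, kNumExperts, kHiddenDim>.
    def kernel(mat_a=None, mat_b=None):
        return f"router_gemm<tokens={num_tokens}>"
    return kernel

DISPATCH = {n: make_specialized_kernel(n) for n in range(MIN_TOKENS, MAX_TOKENS + 1)}

def dispatch(num_tokens, mat_a=None, mat_b=None):
    # Out-of-range token counts are rejected, mirroring the host launcher.
    try:
        return DISPATCH[num_tokens](mat_a, mat_b)
    except KeyError:
        raise ValueError("Invalid num_tokens, only supports 1 to 16") from None
```

The dict lookup here plays the role of the C++ template recursion: both turn a runtime `num_tokens` into a statically specialized code path.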
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: ❌ failed checks (1 warning), ✅ passed checks (2 passed)
Summary of Changes

Hello @nvmbreughe, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request integrates a highly optimized router GEMM operation, specifically tailored for DeepSeek-V3 model architectures. By leveraging a kernel ported from TRT-LLM, it significantly boosts performance for common token counts, offering substantial speedups. The changes include a new CUDA kernel, its Python binding, and thorough testing to ensure reliability and efficiency.

Highlights
Code Review
This pull request introduces an optimized router GEMM kernel for DeepSeek-V3, ported from TRT-LLM, which shows significant performance improvements. The changes include the CUDA kernel implementation, Python bindings, and corresponding tests. My review focuses on improving code clarity, fixing a bug related to a kernel launch parameter, and removing dead or redundant code. Overall, the changes are well-structured and the addition is valuable.
```cpp
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
                     __nv_bfloat16 const* weights, cudaStream_t stream) {
    if (num_tokens == kEnd) {
      invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights,
                                                                    stream);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }
};
```
This specialization of LoopUnroller::unroll also needs to be updated to accept and forward the launch_with_pdl parameter to invokeRouterGemm.
```cpp
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
                     __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) {
    if (num_tokens == kEnd) {
      invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights,
                                                                    stream, launch_with_pdl);
    } else {
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
  }
};
```
```cpp
constexpr int kNumExperts = 256;
constexpr int kHiddenDim = 7168;
```
The constexpr variables kNumExperts and kHiddenDim shadow the function-scope variables with similar names initialized from the input tensors. This can be confusing. Consider renaming them to clarify their purpose as compile-time constants for the specialized kernel. You'll also need to update their usage on lines 136 and 143.
```cpp
constexpr int kExpectedNumExperts = 256;
constexpr int kExpectedHiddenDim = 7168;
```
flashinfer/gemm/routergemm_dsv3.py
Outdated
```python
def _dvs3_router_gemm_shape_checks(mat_a, mat_b, out, launch_with_pdl, bias):
    # Dimension checks
    if mat_a.dim() != 2:
        raise ValueError("mat_a must be a 2D tensor")
    if mat_b.dim() != 2:
        raise ValueError("mat_b must be a 2D tensor")
    if out.dim() != 2:
        raise ValueError("out must be a 2D tensor")
    if bias is not None:
        raise ValueError("bias is not supported yet")

    # Stride checks (check these before dimension checks to give better error messages)
    if mat_a.stride(1) != 1:
        raise ValueError("mat_a must be row-major")
    if out.stride(1) != 1:
        raise ValueError("out must be row-major")
    if mat_b.stride(0) != 1:
        raise ValueError("mat_b must be column-major")

    if mat_a.shape[1] != mat_b.shape[0]:
        raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]")
    if out.shape[0] != mat_a.shape[0]:
        raise ValueError("out.shape[0] must be equal to mat_a.shape[0]")
    if out.shape[1] != mat_b.shape[1]:
        raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")

    # Problem size checks
    expected_hidden_dim = 7168
    expected_num_experts = 256
    min_tokens = 1
    max_tokens = 16
    if mat_a.shape[0] < min_tokens or mat_a.shape[0] > max_tokens:
        raise ValueError(
            f"mat_a.shape[0] (num_tokens) must be between {min_tokens} and {max_tokens}"
        )
    if mat_a.shape[1] != expected_hidden_dim:
        raise ValueError(
            f"mat_a.shape[1] (hidden_dim) must be equal to {expected_hidden_dim}"
        )
    if mat_b.shape[1] != expected_num_experts:
        raise ValueError(
            f"mat_b.shape[1] (num_experts) must be equal to {expected_num_experts}"
        )

    # Data type checks
    if mat_a.dtype != torch.bfloat16:
        raise ValueError("mat_a must be a bfloat16 tensor")
    if mat_b.dtype != torch.bfloat16:
        raise ValueError("mat_b must be a bfloat16 tensor")
    if out.dtype != torch.float32:
        raise ValueError("out must be a float32 tensor")

    return True
```
This helper function can be improved for clarity and correctness:

- There is a typo in the function name: `_dvs3_router_gemm_shape_checks` should be `_dsv3_router_gemm_shape_checks`.
- The `launch_with_pdl` parameter is unused and can be removed from the signature.
- The function returns `True`, but this value is never used. The function should not have a return statement.
- The comment on line 21 is misleading, as stride checks are performed after dimension checks.

Please also update the call to this function on line 97.
```python
def _dsv3_router_gemm_shape_checks(mat_a, mat_b, out, bias):
    # Dimension checks
    if mat_a.dim() != 2:
        raise ValueError("mat_a must be a 2D tensor")
    if mat_b.dim() != 2:
        raise ValueError("mat_b must be a 2D tensor")
    if out.dim() != 2:
        raise ValueError("out must be a 2D tensor")
    if bias is not None:
        raise ValueError("bias is not supported yet")
    # Stride checks
    if mat_a.stride(1) != 1:
        raise ValueError("mat_a must be row-major")
    if out.stride(1) != 1:
        raise ValueError("out must be row-major")
    if mat_b.stride(0) != 1:
        raise ValueError("mat_b must be column-major")
    if mat_a.shape[1] != mat_b.shape[0]:
        raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]")
    if out.shape[0] != mat_a.shape[0]:
        raise ValueError("out.shape[0] must be equal to mat_a.shape[0]")
    if out.shape[1] != mat_b.shape[1]:
        raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")
    # Problem size checks
    expected_hidden_dim = 7168
    expected_num_experts = 256
    min_tokens = 1
    max_tokens = 16
    if not (min_tokens <= mat_a.shape[0] <= max_tokens):
        raise ValueError(
            f"mat_a.shape[0] (num_tokens) must be between {min_tokens} and {max_tokens}"
        )
    if mat_a.shape[1] != expected_hidden_dim:
        raise ValueError(
            f"mat_a.shape[1] (hidden_dim) must be equal to {expected_hidden_dim}"
        )
    if mat_b.shape[1] != expected_num_experts:
        raise ValueError(
            f"mat_b.shape[1] (num_experts) must be equal to {expected_num_experts}"
        )
    # Data type checks
    if mat_a.dtype != torch.bfloat16:
        raise ValueError("mat_a must be a bfloat16 tensor")
    if mat_b.dtype != torch.bfloat16:
        raise ValueError("mat_b must be a bfloat16 tensor")
    if out.dtype != torch.float32:
        raise ValueError("out must be a float32 tensor")
```

```cpp
__device__ __forceinline__ void fma(float2& d, float2 const& a, float2 const& b, float2 const& c) {
  asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n"
               : "=l"(reinterpret_cast<uint64_t&>(d))
               : "l"(reinterpret_cast<uint64_t const&>(a)),
                 "l"(reinterpret_cast<uint64_t const&>(b)),
                 "l"(reinterpret_cast<uint64_t const&>(c)));
}
```
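For readers unfamiliar with the inline PTX above: `fma.rn.f32x2` performs two independent single-precision fused multiply-adds on a packed `float2` pair with round-to-nearest, i.e. `d.x = a.x*b.x + c.x` and `d.y = a.y*b.y + c.y`. A plain-Python model of the semantics (ignoring the single-rounding property of a true hardware FMA):

```python
def fma_f32x2(a, b, c):
    # a, b, c are (x, y) pairs of floats, mirroring CUDA float2.
    # A real FMA rounds once per lane; this model rounds after the
    # multiply as well, so it is only a semantic sketch.
    return (a[0] * b[0] + c[0], a[1] * b[1] + c[1])
```

Packing two lanes into one 64-bit register operand is why the asm constraints reinterpret each `float2` as a `uint64_t`.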
```cpp
int const warpSize = 32;
int const warpId = tid / warpSize;
int const laneId = tid % warpSize;
```
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- csrc/dsv3_router_gemm.cu (1 hunks)
- flashinfer/jit/__init__.py (1 hunks)
- flashinfer/jit/dsv3_optimizations.py (1 hunks)
- flashinfer/model_optimizations/dsv3/routergemm.py (1 hunks)
- include/flashinfer/gemm/dsv3_router_gemm.cuh (1 hunks)
- tests/model_optimizations/test_dsv3_router_gemm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
flashinfer/jit/__init__.py (1)
- flashinfer/jit/dsv3_optimizations.py (1): gen_dsv3_router_gemm_module (5-11)

tests/model_optimizations/test_dsv3_router_gemm.py (2)
- csrc/dsv3_router_gemm.cu (6): num_tokens (94-103), num_tokens (94-95), num_tokens (108-116), num_tokens (108-109), dsv3_router_gemm_op (119-150), dsv3_router_gemm_op (119-120)
- flashinfer/model_optimizations/dsv3/routergemm.py (2): dsv3_router_gemm_op (73-80), dsv3_router_gemm_op (90-100)

flashinfer/jit/dsv3_optimizations.py (1)
- flashinfer/jit/core.py (2): JitSpec (213-312), gen_jit_spec (315-381)

csrc/dsv3_router_gemm.cu (2)
- csrc/tvm_ffi_utils.h (2): get_stream (272-274), encode_dlpack_dtype (29-31)
- flashinfer/model_optimizations/dsv3/routergemm.py (2): dsv3_router_gemm_op (73-80), dsv3_router_gemm_op (90-100)

flashinfer/model_optimizations/dsv3/routergemm.py (3)
- flashinfer/jit/dsv3_optimizations.py (1): gen_dsv3_router_gemm_module (5-11)
- flashinfer/jit/core.py (1): build_and_load (300-312)
- csrc/dsv3_router_gemm.cu (2): dsv3_router_gemm_op (119-150), dsv3_router_gemm_op (119-120)
🪛 Ruff (0.14.2)
flashinfer/model_optimizations/dsv3/routergemm.py
10-10: Unused function argument: launch_with_pdl
(ARG001)
13-13: Avoid specifying long messages outside the exception class
(TRY003)
15-15: Avoid specifying long messages outside the exception class
(TRY003)
17-17: Avoid specifying long messages outside the exception class
(TRY003)
19-19: Avoid specifying long messages outside the exception class
(TRY003)
23-23: Avoid specifying long messages outside the exception class
(TRY003)
25-25: Avoid specifying long messages outside the exception class
(TRY003)
27-27: Avoid specifying long messages outside the exception class
(TRY003)
30-30: Avoid specifying long messages outside the exception class
(TRY003)
32-32: Avoid specifying long messages outside the exception class
(TRY003)
34-34: Avoid specifying long messages outside the exception class
(TRY003)
42-44: Avoid specifying long messages outside the exception class
(TRY003)
46-48: Avoid specifying long messages outside the exception class
(TRY003)
50-52: Avoid specifying long messages outside the exception class
(TRY003)
56-56: Avoid specifying long messages outside the exception class
(TRY003)
58-58: Avoid specifying long messages outside the exception class
(TRY003)
60-60: Avoid specifying long messages outside the exception class
(TRY003)
Force-pushed from 550a0f6 to 75fcd9a (Compare)
/bot run
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- csrc/dsv3_router_gemm.cu (1 hunks)
- flashinfer/dsv3_ops/__init__.py (1 hunks)
- flashinfer/gemm/gemm_base.py (3 hunks)
- flashinfer/gemm/routergemm_dsv3.py (1 hunks)
- flashinfer/jit/__init__.py (1 hunks)
- flashinfer/jit/dsv3_optimizations.py (1 hunks)
- include/flashinfer/gemm/dsv3_router_gemm.cuh (1 hunks)
- tests/model_optimizations/test_dsv3_router_gemm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/dsv3_router_gemm.cu
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer/jit/__init__.py (1)
- flashinfer/jit/dsv3_optimizations.py (1): gen_dsv3_router_gemm_module (5-11)

flashinfer/dsv3_ops/__init__.py (1)
- flashinfer/gemm/routergemm_dsv3.py (2): routergemm_dsv3_hidden_7168_experts_256_tokens_lt16 (73-80), routergemm_dsv3_hidden_7168_experts_256_tokens_lt16 (90-100)

tests/model_optimizations/test_dsv3_router_gemm.py (2)
- flashinfer/gemm/routergemm_dsv3.py (2): routergemm_dsv3_hidden_7168_experts_256_tokens_lt16 (73-80), routergemm_dsv3_hidden_7168_experts_256_tokens_lt16 (90-100)
- csrc/dsv3_router_gemm.cu (4): num_tokens (94-103), num_tokens (94-95), num_tokens (108-116), num_tokens (108-109)

flashinfer/gemm/routergemm_dsv3.py (2)
- flashinfer/jit/dsv3_optimizations.py (1): gen_dsv3_router_gemm_module (5-11)
- flashinfer/jit/core.py (1): build_and_load (300-312)

flashinfer/gemm/gemm_base.py (4)
- flashinfer/autotuner.py (6): AutoTuner (335-784), ConstraintSpec (86-97), DynamicTensorSpec (41-82), OptimizationProfile (168-183), TunableRunner (194-247), TuningConfig (101-141)
- flashinfer/fused_moe/utils.py (2): get_last_power_of_2_num_tokens_buckets (206-215), last_positive_power_of_2 (183-188)
- flashinfer/jit/gemm/core.py (9): gen_gemm_sm90_module (468-507), gen_gemm_module (37-46), gen_gemm_sm100_module (193-269), gen_gemm_sm120_module (272-354), gen_gemm_sm120_module_cutlass_fp4 (97-140), gen_gemm_sm100_module_cutlass_fp4 (49-94), gen_gemm_sm100_module_cutlass_fp8 (143-190), gen_trtllm_gen_gemm_module (357-390), gen_tgv_gemm_sm10x_module (393-465)
- flashinfer/jit/cubin_loader.py (1): setup_cubin_loader (222-242)

flashinfer/jit/dsv3_optimizations.py (1)
- flashinfer/jit/core.py (2): JitSpec (213-312), gen_jit_spec (315-381)
🪛 Ruff (0.14.3)
flashinfer/gemm/routergemm_dsv3.py
10-10: Unused function argument: launch_with_pdl
(ARG001)
13-13: Avoid specifying long messages outside the exception class
(TRY003)
15-15: Avoid specifying long messages outside the exception class
(TRY003)
17-17: Avoid specifying long messages outside the exception class
(TRY003)
19-19: Avoid specifying long messages outside the exception class
(TRY003)
23-23: Avoid specifying long messages outside the exception class
(TRY003)
25-25: Avoid specifying long messages outside the exception class
(TRY003)
27-27: Avoid specifying long messages outside the exception class
(TRY003)
30-30: Avoid specifying long messages outside the exception class
(TRY003)
32-32: Avoid specifying long messages outside the exception class
(TRY003)
34-34: Avoid specifying long messages outside the exception class
(TRY003)
42-44: Avoid specifying long messages outside the exception class
(TRY003)
46-48: Avoid specifying long messages outside the exception class
(TRY003)
50-52: Avoid specifying long messages outside the exception class
(TRY003)
56-56: Avoid specifying long messages outside the exception class
(TRY003)
58-58: Avoid specifying long messages outside the exception class
(TRY003)
60-60: Avoid specifying long messages outside the exception class
(TRY003)
```python
    # Stride checks (check these before dimension checks to give better error messages)
    if mat_a.stride(1) != 1:
        raise ValueError("mat_a must be row-major")
    if out.stride(1) != 1:
        raise ValueError("out must be row-major")
    if mat_b.stride(0) != 1:
        raise ValueError("mat_b must be column-major")

    if mat_a.shape[1] != mat_b.shape[0]:
        raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]")
    if out.shape[0] != mat_a.shape[0]:
        raise ValueError("out.shape[0] must be equal to mat_a.shape[0]")
    if out.shape[1] != mat_b.shape[1]:
        raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")

    # Problem size checks
    expected_hidden_dim = 7168
    expected_num_experts = 256
    min_tokens = 1
    max_tokens = 16
    if mat_a.shape[0] < min_tokens or mat_a.shape[0] > max_tokens:
        raise ValueError(
            f"mat_a.shape[0] (num_tokens) must be between {min_tokens} and {max_tokens}"
        )
    if mat_a.shape[1] != expected_hidden_dim:
        raise ValueError(
            f"mat_a.shape[1] (hidden_dim) must be equal to {expected_hidden_dim}"
        )
    if mat_b.shape[1] != expected_num_experts:
        raise ValueError(
            f"mat_b.shape[1] (num_experts) must be equal to {expected_num_experts}"
        )
```
Guard against strided views that break the kernel layout
The CUDA kernel indexes A/B/out with hard-coded strides (m_idx * kHiddenDim, n_idx * kHiddenDim, m * kNumExperts). If a caller passes a legal 2D bf16 view that shares storage—e.g. mat_a = base[:, :7168] or out = torch.empty(num_tokens, num_experts + 8)[:, :num_experts]—all current checks still pass (stride(1) == 1 / stride(0) == 1). Once inside the kernel we then march with the wrong stride, reading or writing across rows/columns and returning garbage. Please block these cases explicitly.
Apply a stride check for the remaining axes before launching:

```diff
@@
     if mat_a.shape[1] != mat_b.shape[0]:
         raise ValueError("mat_a.shape[1] must be equal to mat_b.shape[0]")
@@
     if out.shape[1] != mat_b.shape[1]:
         raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")
+    if mat_a.stride(0) != expected_hidden_dim:
+        raise ValueError("mat_a must be contiguous row-major with stride(0) == hidden_dim")
+    if mat_b.stride(1) != expected_hidden_dim:
+        raise ValueError("mat_b must be column-major with stride(1) == hidden_dim")
+    if out.stride(0) != expected_num_experts:
+        raise ValueError("out must be contiguous row-major with stride(0) == num_experts")
```

Without these guards we silently corrupt data on perfectly legal PyTorch views, which is a correctness bug.
Committable suggestion skipped: line range outside the PR's diff.
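The failure mode is easy to see with plain stride arithmetic, no torch required. For a row-major base of shape `(4, 7176)` (a hypothetical padded allocation), the column slice `base[:, :7168]` shares storage with the base: its shape shrinks but its strides stay those of the base, so `stride(1) == 1` passes the existing check while `stride(0) == 7176` breaks the kernel's hard-coded row stride of 7168:

```python
def c_contiguous_strides(shape):
    # Element strides of a C-contiguous (row-major) array:
    # the last axis has stride 1, each earlier axis multiplies up.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)

# Slicing columns (base[:, :7168]) keeps the base's strides unchanged,
# so the view below has shape (4, 7168) but strides (7176, 1).
base_shape = (4, 7176)
view_strides = c_contiguous_strides(base_shape)
```

The unit-stride check (`view_strides[1] == 1`) passes, yet `view_strides[0] != 7168`, which is exactly the case the suggested guards reject.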
🧰 Tools
🪛 Ruff (0.14.3)
23-23: Avoid specifying long messages outside the exception class
(TRY003)
25-25: Avoid specifying long messages outside the exception class
(TRY003)
27-27: Avoid specifying long messages outside the exception class
(TRY003)
30-30: Avoid specifying long messages outside the exception class
(TRY003)
32-32: Avoid specifying long messages outside the exception class
(TRY003)
34-34: Avoid specifying long messages outside the exception class
(TRY003)
42-44: Avoid specifying long messages outside the exception class
(TRY003)
46-48: Avoid specifying long messages outside the exception class
(TRY003)
50-52: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In flashinfer/gemm/routergemm_dsv3.py around lines 21 to 52, the kernel assumes
fixed hard-coded row/column strides (kHiddenDim and kNumExperts) but the current
checks only verify unit strides and allow strided views that break layout; add
explicit stride equality checks: require mat_a.stride(0) == expected_hidden_dim,
mat_b.stride(1) == expected_hidden_dim, and out.stride(0) ==
expected_num_experts (in addition to the existing unit-stride checks) and raise
ValueError with a clear message if any of these mismatch so views that share
storage but have different layout are rejected before launch.
[FAILED] Pipeline #37879756: 1/17 passed

/bot run

[CANCELING] Pipeline #37883369: canceled

Force-pushed from 6ea619b to f1704b9 (Compare)

/bot run
Actionable comments posted: 2
♻️ Duplicate comments (5)

include/flashinfer/gemm/dsv3_router_gemm.cuh (4)

- 23-29: Remove unused function or explain its purpose. This function was previously flagged as unused. If it's intended for future use or external callers, please add a comment explaining its purpose. Otherwise, remove it to avoid dead code.
- 51-52: Protect against dropping tail K tiles. This issue was previously flagged. Integer division silently drops any remaining elements when kHiddenDim is not divisible by VPT * kBlockSize, causing incorrect GEMM results.
- 64-64: Remove commented-out duplicate declaration. Line 64 is a duplicate of line 65 and should be removed, as previously noted.
- 107-109: Use the existing kWarpSize constant consistently. The warpSize variable is redundant since kWarpSize is already defined at line 48. Using the constant consistently improves readability, as previously noted.

flashinfer/gemm/routergemm_dsv3.py (1)

- 16-66: Address previously flagged issues in shape validation. Multiple issues were previously identified in this function:
  - Typo in function name: _dvs3 should be _dsv3 (note the order of 'v' and 's').
  - Unused parameter: launch_with_pdl is never used in the function body.
  - Unnecessary return: the function returns True, but this value is never checked.
  - Misleading comment: line 25 says "check these before dimension checks" but dimension checks (lines 18-23) actually come first.
  - Missing stride checks: the kernel assumes fixed strides but only one stride per tensor is validated, allowing strided views that would corrupt results.
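The tail-K-tile concern above is the classic floor-vs-ceiling division issue: `kHiddenDim / (VPT * kBlockSize)` drops a partial tile whenever the divisor does not divide the dimension evenly. A quick sketch of the difference (values are illustrative, not taken from the kernel's actual launch configuration):

```python
def num_tiles_floor(hidden_dim, vpt, block_size):
    # What plain integer division gives: a partial tail tile is dropped.
    return hidden_dim // (vpt * block_size)

def num_tiles_ceil(hidden_dim, vpt, block_size):
    # Ceiling division covers the tail (the kernel must then guard its loads).
    per_tile = vpt * block_size
    return (hidden_dim + per_tile - 1) // per_tile
```

For hidden_dim = 7168 with, say, VPT = 8 and a 128-thread block (per_tile = 1024), both agree at 7 tiles, so the fixed DSv3 shape happens to be safe; the hazard only bites if the constants change to a non-divisible combination.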
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- csrc/dsv3_router_gemm.cu (1 hunks)
- flashinfer/dsv3_ops/__init__.py (1 hunks)
- flashinfer/gemm/gemm_base.py (3 hunks)
- flashinfer/gemm/routergemm_dsv3.py (1 hunks)
- flashinfer/jit/__init__.py (1 hunks)
- flashinfer/jit/dsv3_optimizations.py (1 hunks)
- include/flashinfer/gemm/dsv3_router_gemm.cuh (1 hunks)
- tests/model_optimizations/test_dsv3_router_gemm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/gemm/gemm_base.py
🔇 Additional comments (1)

flashinfer/gemm/routergemm_dsv3.py (1)

- 92-107: LGTM with noted dependencies. The public API function correctly validates compute capability and delegates to the backend module. Note that effectiveness depends on the shape validation function addressing the previously flagged stride check issues.
```cpp
  if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
      encode_dlpack_dtype(out_dtype_) == float32_code) {
    use_custom_kernel = true;
  }

  if (use_custom_kernel) {
    LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll(
        num_tokens, reinterpret_cast<float*>(out.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl);
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input tensor size";
```
Add missing mat_b validation before launching the kernel
dsv3_router_gemm_op currently accepts any dtype/leading dimension for mat_b, yet it unconditionally reinterprets the storage as __nv_bfloat16 and assumes exactly hidden_dim entries per expert. Supplying a float32 weight matrix or a column-major tensor with fewer than 7168 rows will silently corrupt results and can read past the allocated buffer. Please reject these configurations up front.
```diff
@@
-  auto const out_dtype_ = out.dtype();
-  auto const data_type = mat_a.dtype();
+  auto const out_dtype_ = out.dtype();
+  auto const data_type = mat_a.dtype();
+  auto const mat_b_dtype = mat_b.dtype();
@@
-  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_b.sizes()[0] == hidden_dim)
+      << "mat_b first dimension must equal hidden_dim";
+  TVM_FFI_ICHECK(encode_dlpack_dtype(mat_b_dtype) == bfloat16_code)
+      << "mat_b must be bfloat16";
@@
-  hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
-  encode_dlpack_dtype(out_dtype_) == float32_code) {
+  hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
+  encode_dlpack_dtype(out_dtype_) == float32_code) {
```

```cpp
template <int VPT>
__device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
  __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));

#pragma unroll
  for (int i = 0; i < VPT; i++) {
    dst[i] = __bfloat162float(bf16_ptr[i]);
  }
}
```
🛠️ Refactor suggestion | 🟠 Major
Add compile-time check for VPT parameter.
The function assumes VPT == 8 since a uint4 contains exactly 8 bfloat16 values (16 bytes ÷ 2 bytes/bf16). If called with a different VPT, the loop will either under-read or overflow the buffer.
Apply this diff to enforce the invariant:
```diff
 template <int VPT>
 __device__ __forceinline__ void bf16_uint4_to_float8(uint4 const& vec, float* dst) {
+  static_assert(VPT == 8, "bf16_uint4_to_float8 requires VPT == 8 (uint4 holds exactly 8 bf16 values)");
   __nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));
```
__nv_bfloat16* bf16_ptr = reinterpret_cast<__nv_bfloat16*>(const_cast<uint4*>(&vec));🤖 Prompt for AI Agents
In include/flashinfer/gemm/dsv3_router_gemm.cuh around lines 32 to 40, the
template assumes VPT equals 8 (uint4 holds eight bfloat16s) but doesn't enforce
it; add a compile-time check such as a static_assert that VPT == 8 with a clear
message so any misuse fails to compile, ensuring the loop neither under-reads
nor overflows the uint4 buffer.
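The VPT == 8 invariant follows from sizes alone: a uint4 is 16 bytes and a bfloat16 is 2 bytes, so one vector load carries exactly eight bf16 values. The bf16→fp32 widening that `__bfloat162float` performs is also just bit placement: the 16 stored bits become the top half of the fp32 word. A stdlib-only model (assumes IEEE-754 single precision):

```python
import struct

# sizeof(uint4) / sizeof(__nv_bfloat16) = 16 / 2 = 8 values per vector load.
VPT = 16 // 2

def bf16_bits_to_float(bits16):
    # bf16 -> fp32 widening: shift the 16 stored bits into the high half
    # of a 32-bit word and reinterpret as a float. Exact, no rounding.
    return struct.unpack("<f", struct.pack("<I", bits16 << 16))[0]
```

Since widening is exact, the kernel's bf16→float conversion loses no precision; rounding only happened when the weights were first cast down to bf16.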
[FAILED] Pipeline #37897792: 1/17 passed
Actionable comments posted: 1
🧹 Nitpick comments (3)
scripts/task_test_nightly_build.sh (1)
8-20: Optional refactor: Consolidate duplicated cache cleanup logic.The identical 3-line cache cleanup sequence is now replicated across all 5 test scripts (task_test_blackwell_kernels.sh, task_test_single_node_comm_kernels.sh, task_test_jit_cache_package_build_import.sh, task_test_multi_node_comm_kernels.sh, task_test_nightly_build.sh). Consider extracting to a common shell function or separate cleanup utility to reduce duplication and simplify future maintenance.
# Example: scripts/lib/cleanup_python_cache.sh cleanup_python_cache() { echo "Cleaning Python bytecode cache..." find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true find . -type f -name '*.pyc' -delete 2>/dev/null || true echo "Cache cleaned." echo "" }Then source and invoke from each test script:
source scripts/lib/cleanup_python_cache.sh cleanup_python_cacheflashinfer/gemm/__init__.py (1)
21-35: Consider sorting__all__for consistency.Static analysis suggests applying isort-style sorting to the
__all__list for better maintainability and consistency with project conventions.tests/gemm/test_tgv_gemm.py (1)
9-9: Consider exposing_match_sm_versionas a public utility or refactoring test.The import path update is correct. However, tests importing a private symbol (leading underscore) suggests it may deserve public status. If tests consistently need to check SM version compatibility, consider either making this a public utility function or refactoring tests to rely on exception handling from public APIs instead.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- flashinfer/gemm/__init__.py (1 hunks)
- scripts/task_test_blackwell_kernels.sh (1 hunks)
- scripts/task_test_jit_cache_package_build_import.sh (1 hunks)
- scripts/task_test_multi_node_comm_kernels.sh (1 hunks)
- scripts/task_test_nightly_build.sh (1 hunks)
- scripts/task_test_single_node_comm_kernels.sh (1 hunks)
- tests/gemm/test_mm_fp4.py (1 hunks)
- tests/gemm/test_tgv_gemm.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/gemm/__init__.py (1)
- flashinfer/gemm/gemm_base.py (9): SegmentGEMMWrapper (845-1098), mm_fp4 (1862-2013), mm_fp8 (1586-1693), tgv_gemm_sm100 (583-654), batch_deepgemm_fp8_nt_groupwise (2905-3037), group_deepgemm_fp8_nt_groupwise (2775-2902), gemm_fp8_nt_blockscaled (2418-2443), gemm_fp8_nt_groupwise (2110-2273), group_gemm_fp8_nt_groupwise (2446-2605)

tests/gemm/test_tgv_gemm.py (1)
- flashinfer/gemm/gemm_base.py (1): _match_sm_version (90-93)
🪛 Ruff (0.14.3)
flashinfer/gemm/__init__.py
21-35: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
scripts/task_test_blackwell_kernels.sh (1)
9-14: Cache cleanup placement is appropriate for pre-test import freshness.

The bytecode cache cleanup is correctly positioned before pytest execution and includes proper error suppression. This addresses stale imports from the module refactoring (`gemm.py` → `gemm/gemm_base.py`) mentioned in the PR description.
scripts/task_test_single_node_comm_kernels.sh (1)
8-13: Pre-install cleanup ensures stale bytecode doesn't interfere with fresh imports.

Cache cleanup is positioned correctly before `pip install -e .`, preventing stale `.pyc` files from the old module structure from shadowing new imports.

scripts/task_test_jit_cache_package_build_import.sh (1)
31-35: Early cleanup ensures detection and installation workflows start with clean Python imports.

Bytecode cache removal precedes the Python detection logic (line 39), ensuring CUDA architecture detection and subsequent installations operate on fresh imports without stale module references.
scripts/task_test_multi_node_comm_kernels.sh (1)
8-13: Pre-install cleanup maintains consistency with other test scripts.

Cache cleanup is positioned before `pip install -e .`, aligned with the standardized approach across the test suite for handling stale imports from module refactoring.

scripts/task_test_nightly_build.sh (1)
15-20: Cache cleanup placed early in nightly workflow for comprehensive environment reset.

Bytecode cleanup is positioned at script entry before GPU diagnostics and multi-phase installations, ensuring no stale imports affect the full nightly build and test cycle.
tests/gemm/test_mm_fp4.py (1)
12-12: Import path change is correct but compensates for missing re-export.

The import path was updated to bypass the package-level `__init__.py` because `CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR` isn't re-exported there (despite being in `__all__`). Once the missing import is added to `flashinfer/gemm/__init__.py`, this could revert to the cleaner package-level import: `from flashinfer.gemm import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR`.
Force-pushed 20421aa to a08e0a8
/bot run
Force-pushed 3a88a7d to 837d76f
Actionable comments posted: 1
♻️ Duplicate comments (2)
csrc/dsv3_router_gemm.cu (1)
123-140: Reject non-bfloat16 `mat_b` before launch.

Line 141 reinterprets `mat_b.data_ptr()` as `__nv_bfloat16 const*`, but we never verify that `mat_b` actually carries bf16 elements. A caller can legally pass a float32 or fp8 column-major tensor (the stride checks still pass), and we would then reinterpret 4-byte scalars as bf16 pairs, silently corrupting the GEMM inputs. Please gate the custom-kernel path on `mat_b.dtype()` and emit an explicit check before the cast.

```diff
-  auto const out_dtype_ = out.dtype();
-  auto const data_type = mat_a.dtype();
+  auto const out_dtype_ = out.dtype();
+  auto const data_type = mat_a.dtype();
+  auto const mat_b_dtype = mat_b.dtype();
@@
-  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(encode_dlpack_dtype(mat_b_dtype) == bfloat16_code)
+      << "mat_b must be bfloat16";
@@
-  if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
-      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
-      encode_dlpack_dtype(out_dtype_) == float32_code) {
+  if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
+      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
+      encode_dlpack_dtype(mat_b_dtype) == bfloat16_code &&
+      encode_dlpack_dtype(out_dtype_) == float32_code) {
```

flashinfer/gemm/routergemm_dsv3.py (1)
24-63: Reject strided views that break the kernel's layout assumptions.

The CUDA kernel uses fixed pitches (`hidden_dim` / `num_experts`) when it walks A, B, and out. Today we only check the unit inner stride, so a caller can pass `mat_a = base[:, :7168]`, `mat_b = base[:, :256]`, or `out = base[:, :256]` (all valid PyTorch views with `stride(0)` greater than expected). Those configurations sail through this validator and then the kernel reads/writes with the wrong pitch. Please enforce the leading strides explicitly.

```diff
 if mat_b.stride(0) != 1:
     raise ValueError("mat_b must be column-major")
@@
 if out.shape[1] != mat_b.shape[1]:
     raise ValueError("out.shape[1] must be equal to mat_b.shape[1]")
+if mat_a.stride(0) != expected_hidden_dim:
+    raise ValueError(
+        "mat_a must be contiguous row-major with stride(0) == hidden_dim"
+    )
+if mat_b.stride(1) != expected_hidden_dim:
+    raise ValueError(
+        "mat_b must be column-major with stride(1) == hidden_dim"
+    )
+if out.stride(0) != expected_num_experts:
+    raise ValueError(
+        "out must be contiguous row-major with stride(0) == num_experts"
+    )
```
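The failure mode is easy to reproduce outside the kernel. A small NumPy sketch (NumPy's `.strides` reports bytes while PyTorch's `Tensor.stride()` reports elements, so we divide by the item size) shows how a sliced view keeps a unit inner stride while its leading stride still reflects the parent's width:

```python
import numpy as np

# Parent buffer wider than the kernel's expected hidden_dim of 7168.
base = np.zeros((4, 8192), dtype=np.float32)
view = base[:, :7168]  # legal view with exactly the shape a validator expects

# Convert NumPy's byte strides to element strides for comparison.
elem_strides = tuple(s // view.itemsize for s in view.strides)

# The inner stride is 1, so a unit-stride check alone passes...
assert elem_strides[1] == 1
# ...but the leading stride is still the parent's 8192, not 7168, so a
# kernel that advances rows by a hard-coded 7168 would read the wrong data.
assert elem_strides[0] == 8192
assert elem_strides[0] != 7168
```

This is exactly the case the suggested `stride(0)` / `stride(1)` checks reject up front.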
🧹 Nitpick comments (2)
flashinfer/gemm/__init__.py (1)
1-34: LGTM! Public API correctly established for GEMM module.

All symbols in `__all__` are properly imported and re-exported. The new `mm_M1_16_K7168_N256` is correctly integrated from the DSv3 router GEMM implementation.

The static analysis tool suggests sorting `__all__` alphabetically for consistency. While this is purely stylistic, it can improve maintainability:

```diff
 __all__ = [
-    "SegmentGEMMWrapper",
-    "bmm_fp8",
-    "mm_fp4",
-    "mm_fp8",
-    "tgv_gemm_sm100",
-    "group_gemm_mxfp4_nt_groupwise",
     "batch_deepgemm_fp8_nt_groupwise",
-    "group_deepgemm_fp8_nt_groupwise",
+    "bmm_fp8",
     "gemm_fp8_nt_blockscaled",
     "gemm_fp8_nt_groupwise",
+    "group_deepgemm_fp8_nt_groupwise",
+    "group_gemm_mxfp4_nt_groupwise",
     "group_gemm_fp8_nt_groupwise",
     "mm_M1_16_K7168_N256",
+    "mm_fp4",
+    "mm_fp8",
+    "SegmentGEMMWrapper",
+    "tgv_gemm_sm100",
 ]
```

tests/model_optimizations/test_dsv3_router_gemm.py (1)
8-28: LGTM! Comprehensive positive test coverage.

The test properly validates the DSv3 router GEMM operation across various token counts (1-16) and PDL modes, with appropriate SM100-only gating and numerical accuracy checks (cosine similarity > 0.99).
Line 23 uses `torch.randn` for the output tensor, which pre-fills it with random values. While the operation should overwrite these, using `torch.empty` would be more conventional and slightly more efficient:

```diff
-    out = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)
+    out = torch.empty(num_tokens, num_experts, device="cuda", dtype=torch.float32)
```
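For reference, the cosine-similarity acceptance criterion the test applies can be expressed in plain NumPy. This is a sketch of the check, not the test's actual code, which runs on CUDA tensors:

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Flattened cosine similarity between two result tensors."""
    a = a.ravel().astype(np.float64)
    b = b.ravel().astype(np.float64)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

rng = np.random.default_rng(0)
ref = rng.standard_normal((16, 256))          # reference GEMM output
out = ref + 1e-3 * rng.standard_normal((16, 256))  # kernel output with small error

sim = cosine_similarity(ref, out)
assert sim > 0.99  # the test's acceptance threshold
```

A near-identical result with small elementwise noise easily clears the 0.99 bar, while a wrong-pitch read (as in the stride discussion above) would not.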
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (17)
- `csrc/dsv3_router_gemm.cu` (1 hunks)
- `flashinfer/dsv3_ops/__init__.py` (1 hunks)
- `flashinfer/gemm/__init__.py` (1 hunks)
- `flashinfer/gemm/gemm_base.py` (6 hunks)
- `flashinfer/gemm/routergemm_dsv3.py` (1 hunks)
- `flashinfer/jit/__init__.py` (1 hunks)
- `flashinfer/jit/dsv3_optimizations.py` (1 hunks)
- `include/flashinfer/gemm/dsv3_router_gemm.cuh` (1 hunks)
- `scripts/task_test_blackwell_kernels.sh` (1 hunks)
- `scripts/task_test_jit_cache_package_build_import.sh` (1 hunks)
- `scripts/task_test_multi_node_comm_kernels.sh` (1 hunks)
- `scripts/task_test_nightly_build.sh` (1 hunks)
- `scripts/task_test_single_node_comm_kernels.sh` (1 hunks)
- `tests/gemm/test_group_gemm.py` (2 hunks)
- `tests/gemm/test_mm_fp4.py` (1 hunks)
- `tests/gemm/test_tgv_gemm.py` (1 hunks)
- `tests/model_optimizations/test_dsv3_router_gemm.py` (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (6)
- scripts/task_test_blackwell_kernels.sh
- flashinfer/jit/__init__.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/dsv3_optimizations.py
- scripts/task_test_multi_node_comm_kernels.sh
- tests/gemm/test_group_gemm.py
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer/dsv3_ops/__init__.py (1)
flashinfer/gemm/routergemm_dsv3.py (2)
- `mm_M1_16_K7168_N256` (75-81)
- `mm_M1_16_K7168_N256` (89-134)
csrc/dsv3_router_gemm.cu (1)
csrc/tvm_ffi_utils.h (2)
- `get_stream` (272-274)
- `encode_dlpack_dtype` (29-31)
tests/model_optimizations/test_dsv3_router_gemm.py (3)
flashinfer/gemm/routergemm_dsv3.py (2)
- `mm_M1_16_K7168_N256` (75-81)
- `mm_M1_16_K7168_N256` (89-134)

flashinfer/utils.py (1)
- `get_compute_capability` (252-255)

csrc/dsv3_router_gemm.cu (4)
- `num_tokens` (93-102)
- `num_tokens` (93-94)
- `num_tokens` (107-115)
- `num_tokens` (107-108)
flashinfer/gemm/routergemm_dsv3.py (4)
flashinfer/jit/dsv3_optimizations.py (1)
- `gen_dsv3_router_gemm_module` (5-11)

flashinfer/utils.py (2)
- `supported_compute_capability` (773-853)
- `backend_requirement` (856-1131)

flashinfer/jit/core.py (1)
- `build_and_load` (300-312)

csrc/dsv3_router_gemm.cu (2)
- `dsv3_router_gemm_op` (118-147)
- `dsv3_router_gemm_op` (118-118)
tests/gemm/test_tgv_gemm.py (1)
flashinfer/gemm/gemm_base.py (1)
- `_match_sm_version` (90-93)
flashinfer/gemm/__init__.py (2)
flashinfer/gemm/gemm_base.py (9)
- `SegmentGEMMWrapper` (829-1082)
- `mm_fp4` (1846-1997)
- `mm_fp8` (1570-1677)
- `tgv_gemm_sm100` (567-638)
- `batch_deepgemm_fp8_nt_groupwise` (2981-3113)
- `group_deepgemm_fp8_nt_groupwise` (2851-2978)
- `gemm_fp8_nt_blockscaled` (2494-2519)
- `gemm_fp8_nt_groupwise` (2186-2349)
- `group_gemm_fp8_nt_groupwise` (2522-2681)

flashinfer/gemm/routergemm_dsv3.py (2)
- `mm_M1_16_K7168_N256` (75-81)
- `mm_M1_16_K7168_N256` (89-134)
🪛 Ruff (0.14.3)
flashinfer/gemm/routergemm_dsv3.py
14-14: Unused function argument: launch_with_pdl
(ARG001)
17-17: Avoid specifying long messages outside the exception class
(TRY003)
19-19: Avoid specifying long messages outside the exception class
(TRY003)
21-21: Avoid specifying long messages outside the exception class
(TRY003)
25-25: Avoid specifying long messages outside the exception class
(TRY003)
27-27: Avoid specifying long messages outside the exception class
(TRY003)
29-29: Avoid specifying long messages outside the exception class
(TRY003)
32-32: Avoid specifying long messages outside the exception class
(TRY003)
34-34: Avoid specifying long messages outside the exception class
(TRY003)
36-36: Avoid specifying long messages outside the exception class
(TRY003)
44-46: Avoid specifying long messages outside the exception class
(TRY003)
48-50: Avoid specifying long messages outside the exception class
(TRY003)
52-54: Avoid specifying long messages outside the exception class
(TRY003)
58-58: Avoid specifying long messages outside the exception class
(TRY003)
60-60: Avoid specifying long messages outside the exception class
(TRY003)
62-62: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/gemm/__init__.py
21-34: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (7)
scripts/task_test_single_node_comm_kernels.sh (1)
8-14: LGTM! Defensive cleanup to prevent stale imports.

The bytecode cache cleanup is a sensible precaution before package installation, especially given the module refactoring in this PR (e.g., `flashinfer/gemm.py` → `flashinfer/gemm/gemm_base.py`). The error suppression with `|| true` prevents spurious failures when caches don't exist.

scripts/task_test_jit_cache_package_build_import.sh (1)
31-36: LGTM! Consistent cleanup pattern.

Same defensive cleanup as in other test scripts, appropriately placed before CUDA architecture detection and package installation.

scripts/task_test_nightly_build.sh (1)
scripts/task_test_nightly_build.sh (1)
15-21: LGTM! Consistent cleanup pattern.

The bytecode cache cleanup follows the same pattern established in other test scripts and is appropriately placed before package installation steps.
tests/gemm/test_mm_fp4.py (1)
12-12: LGTM! Import path updated for module restructuring.

The import now uses the more specific `flashinfer.gemm.gemm_base` module, which is consistent with the PR's refactoring of the GEMM module structure.

flashinfer/dsv3_ops/__init__.py (1)
1-5: LGTM! Clean public API for DSv3 operations.

The module correctly establishes a focused API surface for DSv3-specific operations by re-exporting `mm_M1_16_K7168_N256` from the underlying GEMM implementation.

tests/gemm/test_tgv_gemm.py (1)
9-9: LGTM! Import path updated for module restructuring.

The import path update to `flashinfer.gemm.gemm_base` aligns with the GEMM module refactoring in this PR.

tests/model_optimizations/test_dsv3_router_gemm.py (1)
31-137: LGTM! Thorough negative test coverage.

The negative tests comprehensively validate input constraints including:
- Boundary conditions for num_tokens [1-16], num_experts [256], and hidden_dim [7168]
- Data type requirements (bfloat16 inputs, float32 output)
- Layout requirements (column-major mat_b)
All error cases correctly use `pytest.raises(ValueError, match=...)` to validate both exception type and message content.
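The negative-test pattern can be sketched without pytest using a stand-in validator. The names below are illustrative only; the real tests exercise `mm_M1_16_K7168_N256` through `pytest.raises`:

```python
# Hypothetical stand-in for the real shape validator's num_tokens check.
def check_num_tokens(num_tokens: int) -> None:
    if not 1 <= num_tokens <= 16:
        raise ValueError("num_tokens must be between 1 and 16")

def expect_value_error(fn, *args) -> bool:
    """Mimic pytest.raises(ValueError): True iff fn(*args) raises ValueError."""
    try:
        fn(*args)
    except ValueError:
        return True
    return False

# Out-of-range token counts must be rejected; in-range ones accepted.
assert expect_value_error(check_num_tokens, 0)
assert expect_value_error(check_num_tokens, 17)
assert not expect_value_error(check_num_tokens, 8)
```

Checking both the exception type and (in the real tests, via `match=`) the message keeps the error contract stable across refactors.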
```cpp
  TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1)
      << "mat_a and out must be row-major";
  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
  auto stream = get_stream(mat_a.device());
  bool use_custom_kernel = false;
  if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
      encode_dlpack_dtype(out_dtype_) == float32_code) {
    use_custom_kernel = true;
  }

  if (use_custom_kernel) {
    LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll(
        num_tokens, reinterpret_cast<float*>(out.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl);
  } else {
```
Enforce leading strides to match the kernel layout.
The kernel indexes `mat_a`, `mat_b`, and `out` with hard-coded strides (`m_idx * kHiddenDim`, `n_idx * kHiddenDim`, `m * kNumExperts`). With the current guard, a perfectly legal PyTorch view such as `out = base[:, :256]` or `mat_a = base[:, :7168]` passes (unit inner stride) but has a larger leading stride. In those cases we'll march with the wrong pitch and either read garbage or write into the wrong rows. Please reject such views up front.
```diff
   TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1)
       << "mat_a and out must be row-major";
   TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_a.strides()[0] == hidden_dim)
+      << "mat_a must have stride(0) == hidden_dim";
+  TVM_FFI_ICHECK(mat_b.strides()[1] == hidden_dim)
+      << "mat_b must have stride(1) == hidden_dim";
+  TVM_FFI_ICHECK(out.strides()[0] == num_experts)
+      << "out must have stride(0) == num_experts";
```

🤖 Prompt for AI Agents
In csrc/dsv3_router_gemm.cu around lines 128 to 144, the current guards only
enforce inner-unit strides but do not ensure the leading (row/column) strides
match the kernel's hard-coded pitches; add explicit checks that
mat_a.strides()[0] == kHiddenDim (row-major pitch), mat_b.strides()[1] ==
kHiddenDim (column-major pitch), and out.strides()[0] == kNumExperts (output row
pitch) and fail early with TVM_FFI_ICHECK if any of these don't hold so views
with larger leading strides are rejected before launching the custom kernel.
yzh119
left a comment
LGTM
bkryu
left a comment
LGTM!
cyx-6
left a comment
Looks great!
📌 Description
This PR:
From my measurements (num tokens = [1,2,4,8,16]), speedups were observed between 1.36 and 1.82x on B200.
Both positive and negative tests were added to test the behavior.
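For readers unfamiliar with the shapes in the op name, the router GEMM is a plain `[num_tokens, 7168] x [7168, 256]` matmul with bf16 inputs accumulated into fp32. A NumPy sketch of the reference computation (fp32 stands in for bf16, since NumPy has no bfloat16; layout and shapes follow the validator's documented constraints):

```python
import numpy as np

# Reference shapes for the DSv3 router GEMM: A is [num_tokens, hidden_dim]
# row-major, B is [hidden_dim, num_experts] column-major, out is fp32.
num_tokens, hidden_dim, num_experts = 4, 7168, 256

rng = np.random.default_rng(0)
mat_a = rng.standard_normal((num_tokens, hidden_dim)).astype(np.float32)
mat_b = np.asfortranarray(  # column-major, matching the kernel's expectation
    rng.standard_normal((hidden_dim, num_experts)).astype(np.float32)
)

out = mat_a @ mat_b  # fp32 accumulation, shape [num_tokens, num_experts]
assert out.shape == (num_tokens, num_experts)
assert out.dtype == np.float32
```

The test suite compares the kernel against the equivalent torch matmul and accepts results with cosine similarity above 0.99.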
Breaking Change: Refactored gemm module structure
ACTION REQUIRED: Delete stale `flashinfer/gemm.py` file

The `gemm.py` file has been refactored into a package: `flashinfer/gemm.py` → `flashinfer/gemm/gemm_base.py`

After pulling this change, run:

```bash
git clean -fd flashinfer/
# OR manually: rm flashinfer/flashinfer/gemm.py
```

This is backward compatible - no import changes needed.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Integration
JIT / Packaging
Tests
Chores