Merged
152 changes: 152 additions & 0 deletions csrc/dsv3_router_gemm.cu
@@ -0,0 +1,152 @@
#include "flashinfer/gemm/dsv3_router_gemm.cuh"
#include "tvm_ffi_utils.h"

namespace flashinfer::trtllm_dsv3_router_gemm {
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream,
bool use_pdl = false) {
constexpr int VPT = 16 / sizeof(T);
constexpr int kBlockSize = 128;
cudaLaunchConfig_t config;
config.gridDim = kNumExperts;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = use_pdl;
config.numAttrs = 1;
config.attrs = attrs;
auto status = cudaLaunchKernelEx(
&config, router_gemm_kernel<T, kBlockSize, VPT, kNumTokens, kNumExperts, kHiddenDim>, output,
mat_a, mat_b);
TVM_FFI_ICHECK(status == cudaSuccess)
<< "cudaLaunchKernelEx failed with error code " << cudaGetErrorString(status);
}

template void invokeRouterGemm<__nv_bfloat16, 1, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 2, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 3, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 4, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 5, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 6, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 7, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 8, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 9, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 10, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 11, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 12, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 13, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 14, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 15, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template void invokeRouterGemm<__nv_bfloat16, 16, 256, 7168>(float*, __nv_bfloat16 const*,
__nv_bfloat16 const*, cudaStream_t,
bool);

template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
__nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) {
if (num_tokens == kBegin) {
invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights,
stream, launch_with_pdl);
} else {
LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(
num_tokens, output, input, weights, stream, launch_with_pdl);
}
}
};

template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
__nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) {
if (num_tokens == kEnd) {
invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream,
launch_with_pdl);
} else {
throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
}
}
};

void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) {
int const num_tokens = mat_a.sizes()[0];
int const num_experts = mat_b.sizes()[1];
int const hidden_dim = mat_a.sizes()[1];
auto const out_dtype_ = out.dtype();
auto const data_type = mat_a.dtype();
constexpr int kNumExperts = 256;
constexpr int kHiddenDim = 7168;
Comment on lines +124 to +125
Contributor (severity: medium)

The constexpr constants kNumExperts and kHiddenDim are easy to confuse with the function-scope variables num_experts and hidden_dim that are initialized from the input tensors. Consider renaming them to make clear that they are compile-time constants for the specialized kernel. You'll also need to update their usage on lines 136 and 143.

  constexpr int kExpectedNumExperts = 256;
  constexpr int kExpectedHiddenDim = 7168;

std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]};
TVM_FFI_ICHECK(mat_a.dim() == 2 && mat_b.dim() == 2) << "mat_a and mat_b must be 2D tensors";
TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1)
<< "mat_a and out must be row-major";
TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
auto stream = get_stream(mat_a.device());
bool use_custom_kernel = false;
if (num_tokens >= 1 && num_tokens <= 16 && num_experts == kNumExperts &&
hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
encode_dlpack_dtype(out_dtype_) == float32_code) {
use_custom_kernel = true;
}

if (use_custom_kernel) {
LoopUnroller<1, 16, kNumExperts, kHiddenDim>::unroll(
num_tokens, reinterpret_cast<float*>(out.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl);
} else {
Comment on lines +128 to +144
Contributor

⚠️ Potential issue | πŸ”΄ Critical

Enforce leading strides to match the kernel layout.

The kernel indexes mat_a, mat_b, and out with hard-coded strides (m_idx * kHiddenDim, n_idx * kHiddenDim, m * kNumExperts). With the current guard, a perfectly legal PyTorch view such as out = base[:, :256] or mat_a = base[:, :7168] passes (unit inner stride) but has a larger leading stride. In those cases we’ll march with the wrong pitch and either read garbage or write into the wrong rows. Please reject such views up front.

   TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1)
       << "mat_a and out must be row-major";
   TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_a.strides()[0] == hidden_dim)
+      << "mat_a must have stride(0) == hidden_dim";
+  TVM_FFI_ICHECK(mat_b.strides()[1] == hidden_dim)
+      << "mat_b must have stride(1) == hidden_dim";
+  TVM_FFI_ICHECK(out.strides()[0] == num_experts)
+      << "out must have stride(0) == num_experts";
πŸ€– Prompt for AI Agents
In csrc/dsv3_router_gemm.cu around lines 128 to 144, the current guards only
enforce inner-unit strides but do not ensure the leading (row/column) strides
match the kernel's hard-coded pitches; add explicit checks that
mat_a.strides()[0] == kHiddenDim (row-major pitch), mat_b.strides()[1] ==
kHiddenDim (column-major pitch), and out.strides()[0] == kNumExperts (output row
pitch) and fail early with TVM_FFI_ICHECK if any of these don't hold so views
with larger leading strides are rejected before launching the custom kernel.
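
To make the failure mode concrete, here is a minimal PyTorch sketch (illustrative only, not part of this diff; the tensor names are made up) showing how a sliced view keeps a unit inner stride, and therefore passes the current guard, while its leading stride no longer matches the kernel's hard-coded pitch:

import torch

num_tokens, num_experts, hidden_dim = 4, 256, 7168

# Contiguous tensors: leading strides equal the pitches the kernel assumes.
mat_a = torch.empty(num_tokens, hidden_dim, dtype=torch.bfloat16)
out = torch.empty(num_tokens, num_experts, dtype=torch.float32)
print(mat_a.stride())  # (7168, 1) -> stride(0) == hidden_dim
print(out.stride())    # (256, 1)  -> stride(0) == num_experts

# Sliced views: the inner stride is still 1, so the existing guard passes,
# but the leading stride is twice the pitch the kernel indexes with.
a_base = torch.empty(num_tokens, 2 * hidden_dim, dtype=torch.bfloat16)
mat_a_view = a_base[:, :hidden_dim]   # strides (14336, 1), not (7168, 1)
o_base = torch.empty(num_tokens, 2 * num_experts, dtype=torch.float32)
out_view = o_base[:, :num_experts]    # strides (512, 1), not (256, 1)

The extra TVM_FFI_ICHECKs suggested above reject exactly these views before the kernel is launched.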

TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input tensor size";
Comment on lines +133 to +145
Contributor

⚠️ Potential issue | πŸ”΄ Critical

Add missing mat_b validation before launching the kernel

dsv3_router_gemm_op currently accepts any dtype/leading dimension for mat_b, yet it unconditionally reinterprets the storage as __nv_bfloat16 and assumes exactly hidden_dim entries per expert. Supplying a float32 weight matrix or a column-major tensor with fewer than 7168 rows will silently corrupt results and can read past the allocated buffer. Please reject these configurations up front.

@@
-  auto const out_dtype_ = out.dtype();
-  auto const data_type = mat_a.dtype();
+  auto const out_dtype_ = out.dtype();
+  auto const data_type = mat_a.dtype();
+  auto const mat_b_dtype = mat_b.dtype();
@@
-  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
+  TVM_FFI_ICHECK(mat_b.sizes()[0] == hidden_dim)
+      << "mat_b first dimension must equal hidden_dim";
+  TVM_FFI_ICHECK(encode_dlpack_dtype(mat_b_dtype) == bfloat16_code)
+      << "mat_b must be bfloat16";
@@
-      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
-      encode_dlpack_dtype(out_dtype_) == float32_code) {
+      hidden_dim == kHiddenDim && encode_dlpack_dtype(data_type) == bfloat16_code &&
+      encode_dlpack_dtype(out_dtype_) == float32_code) {
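
For reference, a minimal PyTorch sketch (again illustrative only, not part of this diff) of the layouts the fast path expects, including a column-major bfloat16 mat_b whose leading dimension equals hidden_dim:

import torch

num_tokens, num_experts, hidden_dim = 8, 256, 7168

# Activations: row-major bfloat16 [num_tokens, hidden_dim].
mat_a = torch.empty(num_tokens, hidden_dim, dtype=torch.bfloat16)

# Router weights: the kernel reads mat_b as column-major [hidden_dim, num_experts],
# i.e. stride(0) == 1 and stride(1) == hidden_dim. Transposing a contiguous
# [num_experts, hidden_dim] weight gives that layout without a copy.
weight = torch.empty(num_experts, hidden_dim, dtype=torch.bfloat16)
mat_b = weight.t()  # shape (7168, 256), strides (1, 7168)
assert mat_b.stride(0) == 1 and mat_b.size(0) == hidden_dim

# Output: row-major float32 logits [num_tokens, num_experts].
out = torch.empty(num_tokens, num_experts, dtype=torch.float32)

A float32 weight, or a column-major view with fewer than hidden_dim rows, would be reinterpreted as bfloat16 or read past its allocation by the kernel; the suggested dtype and shape checks reject both before launch.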

}
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(dsv3_router_gemm_op,
flashinfer::trtllm_dsv3_router_gemm::dsv3_router_gemm_op);

} // namespace flashinfer::trtllm_dsv3_router_gemm
5 changes: 5 additions & 0 deletions flashinfer/dsv3_ops/__init__.py
@@ -0,0 +1,5 @@
from flashinfer.gemm import mm_M1_16_K7168_N256

__all__ = [
"mm_M1_16_K7168_N256",
]
34 changes: 34 additions & 0 deletions flashinfer/gemm/__init__.py
@@ -0,0 +1,34 @@
from .gemm_base import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm_base import bmm_fp8 as bmm_fp8
from .gemm_base import mm_fp4 as mm_fp4
from .gemm_base import mm_fp8 as mm_fp8
from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100
from .gemm_base import group_gemm_mxfp4_nt_groupwise as group_gemm_mxfp4_nt_groupwise
from .gemm_base import (
batch_deepgemm_fp8_nt_groupwise as batch_deepgemm_fp8_nt_groupwise,
)
from .gemm_base import (
group_deepgemm_fp8_nt_groupwise as group_deepgemm_fp8_nt_groupwise,
)
from .gemm_base import gemm_fp8_nt_blockscaled as gemm_fp8_nt_blockscaled
from .gemm_base import gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise
from .gemm_base import group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise

from .routergemm_dsv3 import (
mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256,
)

__all__ = [
"SegmentGEMMWrapper",
"bmm_fp8",
"mm_fp4",
"mm_fp8",
"tgv_gemm_sm100",
"group_gemm_mxfp4_nt_groupwise",
"batch_deepgemm_fp8_nt_groupwise",
"group_deepgemm_fp8_nt_groupwise",
"gemm_fp8_nt_blockscaled",
"gemm_fp8_nt_groupwise",
"group_gemm_fp8_nt_groupwise",
"mm_M1_16_K7168_N256",
]
36 changes: 18 additions & 18 deletions flashinfer/gemm.py β†’ flashinfer/gemm/gemm_base.py
@@ -22,19 +22,19 @@
from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm
import torch

-from .autotuner import (
+from ..autotuner import (
AutoTuner,
ConstraintSpec,
DynamicTensorSpec,
OptimizationProfile,
TunableRunner,
TuningConfig,
)
-from .fused_moe.utils import (
+from ..fused_moe.utils import (
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
)
-from .utils import (
+from ..utils import (
get_native_fp4_dtype,
is_sm100a_supported,
is_sm100f_supported,
@@ -44,16 +44,16 @@
backend_requirement,
supported_compute_capability,
)
-from .jit.gemm import gen_gemm_sm90_module
-from .jit.gemm import gen_gemm_module
-from .jit.gemm import gen_gemm_sm100_module
-from .jit.gemm import gen_gemm_sm120_module
-from .jit.gemm import gen_gemm_sm120_module_cutlass_fp4
-from .jit.gemm import gen_gemm_sm100_module_cutlass_fp4
-from .jit.gemm import gen_gemm_sm100_module_cutlass_fp8
-from .jit.gemm import gen_trtllm_gen_gemm_module
-from .jit.gemm import gen_tgv_gemm_sm10x_module
-from .jit.gemm import gen_deepgemm_sm100_module
+from ..jit.gemm import gen_gemm_sm90_module
+from ..jit.gemm import gen_gemm_module
+from ..jit.gemm import gen_gemm_sm100_module
+from ..jit.gemm import gen_gemm_sm120_module
+from ..jit.gemm import gen_gemm_sm120_module_cutlass_fp4
+from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp4
+from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8
+from ..jit.gemm import gen_trtllm_gen_gemm_module
+from ..jit.gemm import gen_tgv_gemm_sm10x_module
+from ..jit.gemm import gen_deepgemm_sm100_module


CUDNN_AVAILABLE = False
Expand All @@ -70,8 +70,8 @@
raise


-from .jit.cubin_loader import setup_cubin_loader
-from .utils import (
+from ..jit.cubin_loader import setup_cubin_loader
+from ..utils import (
_get_cache_buf,
determine_gemm_backend,
get_indptr,
@@ -733,7 +733,7 @@ def launch_compute_sm80_group_gemm_args(
w_stride_data = torch.empty(batch_size, dtype=ld_type, device=device)
y_stride_data = torch.empty(batch_size, dtype=ld_type, device=device)

-from .triton.gemm import compute_sm80_group_gemm_args
+from ..triton.gemm import compute_sm80_group_gemm_args

compute_sm80_group_gemm_args[(batch_size,)](
all_problems,
@@ -795,7 +795,7 @@ def launch_compute_sm90_group_gemm_args(
w_stride_data = torch.empty(batch_size, dtype=stride_type, device=device)
y_stride_data = torch.empty(batch_size, dtype=stride_type, device=device)

-from .triton.gemm import compute_sm90_group_gemm_args
+from ..triton.gemm import compute_sm90_group_gemm_args

compute_sm90_group_gemm_args[(batch_size,)](
all_problems,
@@ -2822,7 +2822,7 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise(
def pad_indptr_to_multiple_of_4(
m_indptr: torch.Tensor,
):
-from .triton.gemm import compute_padding_mapping
+from ..triton.gemm import compute_padding_mapping

batch_size = m_indptr.shape[0] - 1
m = m_indptr[1:] - m_indptr[:-1]