-
Notifications
You must be signed in to change notification settings - Fork 583
[DSV3] Optimized Router Gemm #2019
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
da99870
9773aea
e2302d2
7b6ae9f
4d308a1
6686135
1866926
af00b1b
b1f00e9
c2885de
d80a50c
56d7a3f
ade20b8
2a7ba01
57fe1a7
dd1ee6d
2cc1771
bedc580
0de835e
837d76f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| #include "flashinfer/gemm/dsv3_router_gemm.cuh" | ||
| #include "tvm_ffi_utils.h" | ||
|
|
||
| namespace flashinfer::trtllm_dsv3_router_gemm { | ||
// Launches the specialized DSV3 router GEMM kernel.
//
// Grid layout: one thread block per expert (kNumExperts blocks of 128
// threads). `output` receives float logits; `mat_a` (activations) and
// `mat_b` (router weights) are read-only inputs of element type T.
// When `use_pdl` is true, the launch opts into programmatic dependent
// launch via the programmatic-stream-serialization attribute.
template <typename T, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeRouterGemm(float* output, T const* mat_a, T const* mat_b, cudaStream_t stream,
                      bool use_pdl = false) {
  // Each thread moves 16 bytes per vectorized access -> elements per thread.
  constexpr int kVecElems = 16 / sizeof(T);
  constexpr int kThreadsPerBlock = 128;

  cudaLaunchAttribute launch_attrs[1];
  launch_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  launch_attrs[0].val.programmaticStreamSerializationAllowed = use_pdl;

  cudaLaunchConfig_t launch_config;
  launch_config.gridDim = kNumExperts;
  launch_config.blockDim = kThreadsPerBlock;
  launch_config.dynamicSmemBytes = 0;
  launch_config.stream = stream;
  launch_config.attrs = launch_attrs;
  launch_config.numAttrs = 1;

  auto status = cudaLaunchKernelEx(
      &launch_config,
      router_gemm_kernel<T, kThreadsPerBlock, kVecElems, kNumTokens, kNumExperts, kHiddenDim>,
      output, mat_a, mat_b);
  TVM_FFI_ICHECK(status == cudaSuccess)
      << "cudaLaunchKernelEx failed with error code " << cudaGetErrorString(status);
}
|
|
||
// Explicit instantiations of the launcher for every supported token count
// (1..16) at the fixed DSV3 router shape: 256 experts, 7168 hidden dim.
#define FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(NUM_TOKENS)              \
  template void invokeRouterGemm<__nv_bfloat16, NUM_TOKENS, 256, 7168>( \
      float*, __nv_bfloat16 const*, __nv_bfloat16 const*, cudaStream_t, bool);

FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(1)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(2)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(3)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(4)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(5)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(6)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(7)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(8)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(9)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(10)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(11)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(12)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(13)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(14)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(15)
FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM(16)

#undef FLASHINFER_INSTANTIATE_DSV3_ROUTER_GEMM
|
|
||
// Compile-time dispatcher: maps a runtime `num_tokens` in [kBegin, kEnd] onto
// the matching compile-time specialization of invokeRouterGemm by walking the
// candidate token counts one at a time.
template <int kBegin, int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller {
  static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
                     __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) {
    if (num_tokens != kBegin) {
      // Not this candidate: try the next token count in the chain.
      LoopUnroller<kBegin + 1, kEnd, kNumExperts, kHiddenDim>::unroll(
          num_tokens, output, input, weights, stream, launch_with_pdl);
      return;
    }
    invokeRouterGemm<__nv_bfloat16, kBegin, kNumExperts, kHiddenDim>(output, input, weights,
                                                                    stream, launch_with_pdl);
  }
};
|
|
||
// Base case of the dispatch chain: the last supported token count. Any
// `num_tokens` that reaches this point without matching is unsupported.
template <int kEnd, int kNumExperts, int kHiddenDim>
struct LoopUnroller<kEnd, kEnd, kNumExperts, kHiddenDim> {
  static void unroll(int num_tokens, float* output, __nv_bfloat16 const* input,
                     __nv_bfloat16 const* weights, cudaStream_t stream, bool launch_with_pdl) {
    if (num_tokens != kEnd) {
      // Exhausted every specialization without a match.
      throw std::invalid_argument("Invalid num_tokens, only supports 1 to 16");
    }
    invokeRouterGemm<__nv_bfloat16, kEnd, kNumExperts, kHiddenDim>(output, input, weights, stream,
                                                                  launch_with_pdl);
  }
};
|
|
||
// TVM-FFI entry point for the DSV3 router GEMM: out = mat_a @ mat_b.
//
// mat_a: [num_tokens, hidden_dim] bfloat16, row-major, densely packed.
// mat_b: [hidden_dim, num_experts] bfloat16, column-major, densely packed.
// out:   [num_tokens, num_experts] float32, row-major, densely packed.
//
// The specialized kernel only supports 1..16 tokens at the DSV3 shape
// (256 experts, 7168 hidden dim); any other configuration raises
// NotImplementedError. `launch_with_pdl` enables programmatic dependent
// launch for the kernel.
void dsv3_router_gemm_op(TensorView mat_a, TensorView mat_b, TensorView out, bool launch_with_pdl) {
  // Compile-time shape of the specialized kernel (distinct names from the
  // runtime extents below to avoid confusion).
  constexpr int kSupportedNumExperts = 256;
  constexpr int kSupportedHiddenDim = 7168;
  constexpr int kMinNumTokens = 1;
  constexpr int kMaxNumTokens = 16;

  // Validate rank before reading sizes/strides.
  TVM_FFI_ICHECK(mat_a.dim() == 2 && mat_b.dim() == 2) << "mat_a and mat_b must be 2D tensors";

  int const num_tokens = mat_a.sizes()[0];
  int const hidden_dim = mat_a.sizes()[1];
  int const num_experts = mat_b.sizes()[1];

  // The GEMM contraction dimensions must agree.
  TVM_FFI_ICHECK(mat_b.sizes()[0] == hidden_dim) << "mat_b first dimension must equal hidden_dim";

  // Layout checks: the kernel assumes densely packed tensors, so the leading
  // strides must match the logical extents, not just the innermost stride.
  TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1)
      << "mat_a and out must be row-major";
  TVM_FFI_ICHECK(mat_b.strides()[0] == 1) << "mat_b must be column-major";
  TVM_FFI_ICHECK(mat_a.strides()[0] == hidden_dim) << "mat_a must have stride(0) == hidden_dim";
  TVM_FFI_ICHECK(mat_b.strides()[1] == hidden_dim) << "mat_b must have stride(1) == hidden_dim";
  TVM_FFI_ICHECK(out.strides()[0] == num_experts) << "out must have stride(0) == num_experts";

  // Check both inputs (not just mat_a) are bfloat16 and the output is fp32.
  bool const use_custom_kernel =
      num_tokens >= kMinNumTokens && num_tokens <= kMaxNumTokens &&
      num_experts == kSupportedNumExperts && hidden_dim == kSupportedHiddenDim &&
      encode_dlpack_dtype(mat_a.dtype()) == bfloat16_code &&
      encode_dlpack_dtype(mat_b.dtype()) == bfloat16_code &&
      encode_dlpack_dtype(out.dtype()) == float32_code;

  if (use_custom_kernel) {
    auto stream = get_stream(mat_a.device());
    LoopUnroller<kMinNumTokens, kMaxNumTokens, kSupportedNumExperts, kSupportedHiddenDim>::unroll(
        num_tokens, reinterpret_cast<float*>(out.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_a.data_ptr()),
        reinterpret_cast<__nv_bfloat16 const*>(mat_b.data_ptr()), stream, launch_with_pdl);
  } else {
    TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input tensor size";
  }
}
|
|
||
// Export the operator through TVM FFI under the name `dsv3_router_gemm_op`
// so it can be invoked from the Python bindings.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(dsv3_router_gemm_op,
                              flashinfer::trtllm_dsv3_router_gemm::dsv3_router_gemm_op);

}  // namespace flashinfer::trtllm_dsv3_router_gemm
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
"""Public re-export of the DSV3 router GEMM entry point."""

from flashinfer.gemm import mm_M1_16_K7168_N256

__all__ = [
    "mm_M1_16_K7168_N256",
]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
"""Public GEMM API surface: re-exports from ``gemm_base`` plus the DSV3
router GEMM. Explicit ``name as name`` aliases mark each symbol as an
intentional re-export for type checkers."""

from .gemm_base import (
    SegmentGEMMWrapper as SegmentGEMMWrapper,
    batch_deepgemm_fp8_nt_groupwise as batch_deepgemm_fp8_nt_groupwise,
    bmm_fp8 as bmm_fp8,
    gemm_fp8_nt_blockscaled as gemm_fp8_nt_blockscaled,
    gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise,
    group_deepgemm_fp8_nt_groupwise as group_deepgemm_fp8_nt_groupwise,
    group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise,
    group_gemm_mxfp4_nt_groupwise as group_gemm_mxfp4_nt_groupwise,
    mm_fp4 as mm_fp4,
    mm_fp8 as mm_fp8,
    tgv_gemm_sm100 as tgv_gemm_sm100,
)
from .routergemm_dsv3 import mm_M1_16_K7168_N256 as mm_M1_16_K7168_N256

__all__ = [
    "SegmentGEMMWrapper",
    "bmm_fp8",
    "mm_fp4",
    "mm_fp8",
    "tgv_gemm_sm100",
    "group_gemm_mxfp4_nt_groupwise",
    "batch_deepgemm_fp8_nt_groupwise",
    "group_deepgemm_fp8_nt_groupwise",
    "gemm_fp8_nt_blockscaled",
    "gemm_fp8_nt_groupwise",
    "group_gemm_fp8_nt_groupwise",
    "mm_M1_16_K7168_N256",
]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `constexpr` variables `kNumExperts` and `kHiddenDim` shadow the function-scope variables with similar names initialized from the input tensors. This can be confusing. Consider renaming them to clarify their purpose as compile-time constants for the specialized kernel. You'll also need to update their usage on lines 136 and 143.