CUDA: add fp kernel for larger batch size MoE #16512
Conversation
Without having looked at the kernel in a lot of detail, I am getting the impression that you did not target specific occupancies. If you look at the CUDA documentation you'll find that the amount of SRAM per SM is not the same across generations. For high compute utilization my recommendation would be to target either a single CUDA block with 256 threads or 2 CUDA blocks with 128 threads each. The reasoning behind this is that with maximum register use you can run a total of 256 threads in parallel. With a single CUDA block you can then also use larger tile sizes and get higher arithmetic intensity, but at the cost of underutilization whenever you call __syncthreads.
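For reference, a minimal sketch of pinning a kernel to one of those launch configurations via __launch_bounds__ (the kernel name, template parameters, and arguments are placeholders for illustration, not the actual mmf kernel):

```cpp
// Hypothetical sketch, not the PR's kernel: __launch_bounds__(128, 2) asks the
// compiler to cap register use so that 2 blocks of 128 threads can be resident
// per SM; a single 256-thread block per SM would instead be __launch_bounds__(256, 1).
template <int rows_per_block, int cols_per_block>
__global__ void __launch_bounds__(128, 2) mmf_moe_sketch(
        const float * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst,
        const int ncols_x) {
    // tile loads and the FMA loop would go here
}
```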
I did try 4 warps instead of 8 when ne01 is larger; it improves performance by about 20-30%, but cuBLAS is still faster by a larger margin. I only tested on a 3090 and a 4090. On the 3090 the effect of this change is much more pronounced, so much so that it reaches parity with cuBLAS but doesn't beat it. Hence, to keep things simple for the first PR of this kernel, I didn't branch on the architecture. However, if you think there's something worth exploring further regarding the number of warps and the architecture, I'd be happy to do so.
The biggest problem that this kernel has is that it does not improve upon the arithmetic intensity of the preexisting kernel: for optimal compute efficiency you want to do as many floating-point operations as possible per byte of data loaded. For the memory pattern I originally implemented this was not a priority: each warp loads only its own data so that there is no need to synchronize warps until the end. As a consequence the arithmetic intensity and the maximum achievable compute utilization are low, but the batch size provides a hard limit for that anyway. But for large batch sizes the memory pattern should be different: warps should cooperate to load data into SRAM, synchronize, have multiple warps read the same data from SRAM multiple times while doing the matrix multiplication, synchronize again, and then load the next data.
How do you want to proceed with this PR? Do you want to merge it with the current memory pattern or implement a different one for large batch sizes? Even without a different memory pattern, compacting the expert IDs is obviously beneficial.
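For illustration, a minimal sketch of the cooperative SRAM pattern described above (tile size, thread layout of 16x16 = 256 threads, names, and the plain row-major A*B layout are assumptions, not the actual GGML kernel):

```cpp
// Every thread in the block helps fill the shared-memory tiles, the block
// synchronizes, each loaded element is then reused TILE times from SRAM
// (higher arithmetic intensity), and only then is the next tile loaded.
#define TILE 16

__global__ void tiled_matmul_sketch(const float * A, const float * B, float * C,
                                    const int M, const int N, const int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    const int row = blockIdx.y*TILE + threadIdx.y;
    const int col = blockIdx.x*TILE + threadIdx.x;

    float acc = 0.0f;
    for (int k0 = 0; k0 < K; k0 += TILE) {
        // cooperative load: each thread fetches one element of each tile from global memory
        As[threadIdx.y][threadIdx.x] = (row < M && k0 + threadIdx.x < K) ? A[row*K + k0 + threadIdx.x]   : 0.0f;
        Bs[threadIdx.y][threadIdx.x] = (col < N && k0 + threadIdx.y < K) ? B[(k0 + threadIdx.y)*N + col] : 0.0f;
        __syncthreads();

        // each element in SRAM is read TILE times across the block
        for (int k = 0; k < TILE; ++k) {
            acc += As[threadIdx.y][k]*Bs[k][threadIdx.x];
        }
        __syncthreads();
    }
    if (row < M && col < N) {
        C[row*N + col] = acc;
    }
}
```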
void ggml_cuda_launch_mmq_ids_helper(
        const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
        int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
I think this kernel should be moved to a separate file like mmid.cu.
For the current memory pattern the number of warps is essentially negligible if the matrices are sufficiently large. Without having looked at NSight Compute, I suspect that the difference between 4 and 8 warps comes from tail effects, where some performance is lost at the end when the GPU runs out of work. I would say not to focus on that until the kernel can be pushed to output tiles of at least 64x64; then it would make sense to look into e.g. a stream-k decomposition to reduce tail effects (this is particularly relevant for datacenter GPUs with lots of SMs).
To be clear: my previous comments about occupancy would specifically apply to a kernel with large output tiles of at least 64x64; that is when you'll have to start thinking about SRAM vs. register limits.
Can you explain what exactly you mean by "double-buffering"? I would have thought you meant the use of asynchronous data copies, like is for example being done in [...]. For synchronous copies you can also look into [...]. Overall, please tell me whether you intend to make further changes to the kernel in this PR prior to merging.
The y-tiles have two ping-pong buffers: the global load can be issued into one buffer while the previous buffer is being computed on, without stalling the warp. This works through instruction-level parallelism. It's not as efficient as
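For illustration, a rough sketch of that ping-pong pattern under assumptions (names, chunk size, and data layout are illustrative, not the PR's code):

```cpp
// Two register buffers alternate so that the global load for chunk i+1 is already
// in flight while chunk i is being consumed; with no data dependency between the
// load and the previous buffer's math, the hardware can overlap the load latency
// with the arithmetic via instruction-level parallelism, without cp.async.
__global__ void pingpong_reduce_sketch(const float * __restrict__ y, float * __restrict__ dst, const int nchunks) {
    constexpr int vals_per_chunk = 4;
    const int tid    = blockIdx.x*blockDim.x + threadIdx.x;
    const int stride = gridDim.x*blockDim.x;

    float buf[2][vals_per_chunk];

    // prologue: issue the global load for the first chunk
    #pragma unroll
    for (int j = 0; j < vals_per_chunk; ++j) {
        buf[0][j] = y[(0*stride + tid)*vals_per_chunk + j];
    }

    float acc = 0.0f;
    for (int i = 0; i < nchunks; ++i) {
        const int cur = i & 1;

        if (i + 1 < nchunks) {
            // issue the load for the next chunk into the other buffer
            #pragma unroll
            for (int j = 0; j < vals_per_chunk; ++j) {
                buf[cur ^ 1][j] = y[((i + 1)*stride + tid)*vals_per_chunk + j];
            }
        }

        // consume the current buffer while the next load is still in flight
        #pragma unroll
        for (int j = 0; j < vals_per_chunk; ++j) {
            acc += buf[cur][j];
        }
    }

    dst[tid] = acc;
}
```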
Let me try loading data into shared memory for a 64x64 tile to improve the arithmetic intensity. If that is not helpful then we can merge as is, as it's still useful for a large variety of models.
This PR adds a new kernel in mmf for larger batch sizes for MoE. It leverages mmq_ids_helper and adds double-buffering for the gather of src1 cols based on ids_src_sorted.
It is currently faster than the cuBLAS fallback up to n_tokens = 512 when ne01 <= 1024; beyond that it is only faster for n_tokens <= 128. Supporting larger ne01 would require a bigger rewrite where a CTA loads multiple tiles and operates on them. cuBLAS seems to switch to a 128x128 tile once ne01 >= 1024, which keeps the tensor cores better utilized. Other things I tried that didn't help: increasing rows per block and cols per block, double-buffering the x tiles, and operating on multiple tiles in one kernel.
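For illustration, a rough sketch of what a gather based on the sorted ids could look like (the kernel name, arguments, and tile layout are assumptions, not the PR's exact code; the double-buffering is omitted for brevity):

```cpp
// Each tile row looks up its compacted src1 row index once (as produced by the
// ids helper) and then streams that row into the contiguous tile used by the matmul.
__global__ void gather_src1_sketch(const float * __restrict__ src1,
                                   const int32_t * __restrict__ ids_src_sorted,
                                   float * __restrict__ y_tile,
                                   const int ncols, const int nrows_tile) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;  // row of the gathered tile
    if (row >= nrows_tile) {
        return;
    }
    const int src_row = ids_src_sorted[row];  // src1 row assigned to this expert/token pair

    // coalesced copy of one src1 row into the contiguous y tile
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        y_tile[row*ncols + col] = src1[src_row*ncols + col];
    }
}
```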
Models with ne01 <= 1024 seem to be all of the latest Qwen models, so this should be helpful for them, provided someone wants to run them at original precision and the default ubatch size.
On an A100 for qwen3:
For a smaller MoE model like lfm2moe, where ne01 = 1792 (on a 3090):
Others like granite-4 also benefit.