Add support for CUMSUM and TRI for CUDA. #17584
base: master
Conversation
For cumsum we should use https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceScan.html and use this kernel as a fallback.
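For context, a minimal sketch of what the DeviceScan path looks like for a single contiguous array (names like d_in/d_out are placeholders, not from this PR; the two-phase call is how cub::DeviceScan queries its temporary storage):

```cpp
#include <cub/cub.cuh>

// Sketch only: inclusive prefix sum over one contiguous float array.
void device_cumsum(const float * d_in, float * d_out, int num_items, cudaStream_t stream) {
    void * d_temp_storage     = nullptr;
    size_t temp_storage_bytes = 0;
    // First call only computes the required temporary storage size.
    cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
    cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream);
    // Second call performs the actual scan.
    cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
    cudaFreeAsync(d_temp_storage, stream);
}
```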
I have a small optimization for the tri kernel (;

Benchmark results:
1. llama.cpp benchmark (50 runs each)
2. Profiler statistics, RTX 2070 (Nsight)
@@ -1,16 +1,7 @@
#include "tri.cuh"
#include "ggml.h"
-// Triangle type comparison - determines which elements to keep
-__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) {
- switch (type) {
- case GGML_TRI_TYPE_LOWER: return i < r;
- case GGML_TRI_TYPE_LOWER_DIAG: return i <= r;
- case GGML_TRI_TYPE_UPPER: return i > r;
- case GGML_TRI_TYPE_UPPER_DIAG: return i >= r;
- default: return false;
- }
-}
+
template<typename T>
static __global__ void tri_kernel(
@@ -31,10 +22,22 @@ static __global__ void tri_kernel(
const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+ // Optimization: Avoid control flow (switch) inside the hot loop.
+ // Map the 4 triangle types to a generic "split point" and "keep direction" logic.
+ // LOWER / UPPER_DIAG: Split at 'r' (i1). LOWER_DIAG / UPPER: Split at 'r + 1'.
+ int add_to_split = 0;
+ if (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) {
+ add_to_split = 1;
+ }
+ int64_t split_point = i1 + add_to_split;
+ bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
+
// Each thread processes elements at stride blockDim.x
for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
- dst_row[i0] = tri_compare(i0, i1, ttype)
- ? src_row[i0] : static_cast<T>(0.f);
+ // If prefix_keep is true, keep (i0 < split_point). Else, keep (i0 >= split_point).
+ bool keep = ((i0 < split_point) == prefix_keep);
+ dst_row[i0] = keep ? src_row[i0] : T(0);
}
}
ggml/src/ggml-cuda/cumsum.cu
// Load value and compute prefix sum within warp
float val = static_cast<float>(src_row[i0]);
val = warp_prefix_inclusive_sum(val);
dst_row[i0] = static_cast<T>(val);
It would be much preferable to store the temporary results in registers or shared memory rather than global memory.
Isn't val here already stored in a register though? I'm afraid I'll need some more guidance here.
dst_row is in global memory. With this code you are writing data to VRAM on this line, only to later read this data again, add a value to it, and write it back. So you have 3x as much I/O to the comparatively slow VRAM vs. the comparatively faster SRAM or registers where you could be storing it instead until you write the data once at the end of the kernel.
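For illustration, a hedged sketch of that register-carry pattern (assumes blockDim.x == WARP_SIZE and ne00 a multiple of WARP_SIZE; warp_prefix_inclusive_sum is the helper quoted above, the rest is illustrative):

```cpp
// Keep the running total of all previous chunks in a register,
// so each dst element is written to VRAM exactly once.
float carry = 0.0f;
for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += WARP_SIZE) {
    float val = static_cast<float>(src_row[i0]);
    val = warp_prefix_inclusive_sum(val);       // scan within this warp-sized chunk
    dst_row[i0] = static_cast<T>(carry + val);  // single write, no re-read of dst
    // Chunk total = last lane's inclusive sum; broadcast it to all lanes.
    carry += __shfl_sync(0xffffffff, val, WARP_SIZE - 1);
}
```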
Now I get it, thanks!
Regarding the implementation proposed by @wsbagnsv1: if one were to do something like that, the (in my opinion) correct way would be to calculate start and end points for the copy and zero regions, and then simply run two loops over those areas. If at all possible, a conditional statement inside the loop should be avoided. But that would potentially make the kernel less flexible if other patterns for …
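A sketch of that two-loop idea, reusing the split-point variables from the tri diff above (illustrative only, not the PR's final kernel):

```cpp
// Decide the copy/zero boundary once per row, then run branch-free loops.
const int64_t split_raw = i1 + add_to_split;
const int64_t split     = split_raw < ne00 ? split_raw : ne00;  // clamp to row length

if (prefix_keep) {  // LOWER / LOWER_DIAG: copy [0, split), zero [split, ne00)
    for (int64_t i0 = threadIdx.x; i0 < split; i0 += blockDim.x)
        dst_row[i0] = src_row[i0];
    for (int64_t i0 = split + threadIdx.x; i0 < ne00; i0 += blockDim.x)
        dst_row[i0] = T(0);
} else {            // UPPER / UPPER_DIAG: zero [0, split), copy [split, ne00)
    for (int64_t i0 = threadIdx.x; i0 < split; i0 += blockDim.x)
        dst_row[i0] = T(0);
    for (int64_t i0 = split + threadIdx.x; i0 < ne00; i0 += blockDim.x)
        dst_row[i0] = src_row[i0];
}
```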
Okay, when adding in @JohannesGaessler's remark about not calculating the comparison in kernel code, @wsbagnsv1's optimizations just flowed naturally, so I combined them. EDIT: nvm, had wrong strides
Okay, I implemented the double-loop algorithm. I think the cases that are now templated are the only ones that will be supported, so it's probably fine this way.
@gabe-l-hart I would be grateful if you could look at the HIP code fixes; I have no idea what I'm doing there (and no way to test either, aside from the CI).
Unfortunately, I'm not much use here as I also don't have any background with HIP. I just tried installing it on my GB10 device, but haven't had any luck. |
static __device__ __forceinline__ unsigned int get_warp_mask() {
#ifdef __HIP_PLATFORM_AMD__
    return __ballot(1); // HIP equivalent
I know basically nothing about HIP, but according to this doc, it seems like __activemask() should be supported? The main difference referenced there is the warp size of 64 vs. 32, which I could absolutely imagine being accidentally hard-coded somewhere.
Specifically, I see `#define WARP_SIZE 32` at the top of this file.
cc/ @IMbackK
WARP_SIZE is deprecated, and the remaining uses should only appear in places affecting performance, not correctness; the non-deprecated equivalent is ggml_cuda_get_physical_warp_size.
__activemask is indeed supported and works, but I will need to check for how long - will do that later.
We will need to change the return type of this and the kernel below; @pwilkin you can do so, or skip the kernel on HIP and I will fix it in a follow-up.
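For reference, a hedged sketch of what widening the mask type might look like (an assumption about the follow-up's shape, not what was actually merged):

```cpp
// AMD wavefronts are typically 64 lanes, NVIDIA warps 32, so the ballot
// mask needs a wider type on HIP. The typedef name here is illustrative.
#ifdef __HIP_PLATFORM_AMD__
typedef uint64_t lane_mask_t;
#else
typedef unsigned int lane_mask_t;
#endif

static __device__ __forceinline__ lane_mask_t get_warp_mask() {
#ifdef __HIP_PLATFORM_AMD__
    return __ballot(1);    // 64-bit mask of active lanes on HIP
#else
    return __activemask(); // 32-bit mask of active lanes on CUDA
#endif
}
```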
@IMbackK okay, I'll comment it out then and add a TODO; I'd prefer to leave it for someone who knows what they're doing rather than leave an untested vibe-coded patch :)
@pwilkin not sure if you missed my comment, but CUB should be superior for most cases.
Ah, completely forgot about that one! Yeah, will do.
All right, I implemented a CUB-compatible version per @am17an's request and removed the global memory access per @JohannesGaessler's request (I'd be lying if I said I figured all of that out on my own; it turns out the new DeepSeek 3.2 Speciale is quite good at both optimizing kernels and explaining them). After all the optimizations, the biggest case especially improved a lot; also, the fallback implementation is performance-wise very similar to the BlockScan implementation.
What I meant was to use the out-of-the-box function https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceScan.html for the prefix sum.
@am17an Yeah, I ended up using https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockScan.html instead, since DeviceScan can't be used inside kernels and is only for single-array cumulative sums. The function (InclusiveSum) is pretty much the same.
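For reference, a minimal BlockScan sketch (placeholder kernel, not the PR's code; one block scans one row of at most BLOCK_THREADS elements):

```cpp
#include <cub/cub.cuh>

template <int BLOCK_THREADS>
__global__ void row_cumsum(const float * in, float * out, int n) {
    using BlockScan = cub::BlockScan<float, BLOCK_THREADS>;
    __shared__ typename BlockScan::TempStorage temp;

    float val = threadIdx.x < n ? in[threadIdx.x] : 0.0f;
    // Inclusive prefix sum across the whole thread block.
    BlockScan(temp).InclusiveSum(val, val);
    if (threadIdx.x < n) {
        out[threadIdx.x] = val;
    }
}
```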
@am17an done
Now I need to wait for the HIP CI job to finish so that I know what to comment out :)
Okay, since we're not supporting F16/BF16 on the CPU anyway, I'll comment them out, since there are some errors on other platforms with the bfloat16 conversions.
Extracted and adapted kernels by @gabe-l-hart from #16623