vulkan: Add State Space Model (SSM) Operations Support #16463
Thanks for this contribution!
```glsl
warp_sdata[warp_offset + lane] = val;
barrier();

if (lane < 16) warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + 16];
```
This seems like it's assuming a subgroup size of 32 (also at line 37).
Do I understand correctly that this doesn't actually rely on a subgroup size of 32, but it's splitting the workgroup into groups of 32 and just reducing those (and it looks like some reduction across groups of 32 has already happened)?
Sorry, I've missed this one. Yeah, I don't think it would work with a size != 32. I need to think this one through some more.
Do you have any suggestions on what I could do here?
I think this may work because you're not relying on SubgroupInvocationId or SubgroupID; you've just split the workgroup into groups of 32. Maybe we can just test it on AMD (with wave64) and Intel and verify that it works.
It works on Intel, but I am worried about all these settings that we made configurable. I haven't really tried how it behaves with different values of the constants we defined. Or is the assumption that these values should not be tweaked from vulkan-shaders-gen.cpp without also changing the implementation in the shader?
You could change this to a while loop that will handle any power-of-two value of WARP_SIZE. We do want to allow the spec constants to be changeable but it's fine to have limitations like "must be a power of two".
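For illustration, a minimal sketch of such a loop, reusing the names from the snippet above (`warp_sdata`, `warp_offset`, `lane`) and assuming `WARP_SIZE` is a power of two:

```glsl
// Sketch only: tree reduction over one WARP_SIZE-wide group of the workgroup.
// Each iteration halves the number of active lanes, so the same code handles
// 8-, 32-, or 64-wide groups.
uint offset = WARP_SIZE / 2;
while (offset > 0) {
    if (lane < offset) {
        warp_sdata[warp_offset + lane] += warp_sdata[warp_offset + lane + offset];
    }
    barrier();
    offset >>= 1;
}
```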
Hmm, wave64 AMD and wave8 llvmpipe are failing one test here, possibly due to this. All other tests are passing.

```
[SSM_SCAN] NMSE = 31335529439335960.000000000 > 0.000000100 SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4): FAIL
```

I think Intel also has a subgroup size of 32, so it wouldn't be a good test for this.
```glsl
    return warp_sdata[warp_offset];
}

void main() {
```
Do all threads always load/store in bounds? In the host code there was some rounding up going on, which suggests maybe some threads don't correspond to in-bounds locations.
Ping on this one. I don't really understand what this shader does and which locations it should be accessing.
I've tried to follow what the CUDA shader does. I'll spend more time on it and see if there is anything I can improve about memory access, and make sure all the assumptions in the code are checked.
I've addressed the comments and pushed a new version. The results are even better now:
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp:

```cpp
string_to_spv("ssm_scan_f32_d16", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_f32_d128", "ssm_scan.comp", {{"A_TYPE", "float"}});
string_to_spv("ssm_scan_f32_d256", "ssm_scan.comp", {{"A_TYPE", "float"}});
```
These three are all identical now, you only need one.
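That is, a single variant would cover all three (a sketch; the shorter name is illustrative):

```cpp
// One shader variant is enough since the three calls only differ in name.
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
```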
You could change this to a while loop that will handle any power-of-two value of WARP_SIZE. We do want to allow the spec constants to be changeable but it's fine to have limitations like "must be a power of two".
I've completely replaced the reduce-sum code with a subgroup-based version. In the last version I've also renamed a few things.
Be aware that not all devices support subgroup commands. If there's a performance advantage to using them, you can do that, but it would still need a fallback to using a shared memory reduction. If there isn't a performance advantage, just use the shared memory reduction for compatibility.
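As an illustration only (not code from this PR), a shader could compile both paths and let the host pick one. `USE_SUBGROUP_ADD` and `BLOCK_SIZE` are hypothetical names, and note the two variants reduce over different scopes (subgroup vs. whole workgroup), which real code would have to reconcile:

```glsl
// USE_SUBGROUP_ADD: hypothetical define the host would set only on devices
// that support subgroup arithmetic. BLOCK_SIZE: hypothetical workgroup size.
#ifdef USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
float reduce_sum(float val) {
    return subgroupAdd(val);        // sums across the subgroup in one instruction
}
#else
shared float sdata[BLOCK_SIZE];
float reduce_sum(float val) {
    const uint tid = gl_LocalInvocationID.x;
    sdata[tid] = val;
    barrier();
    for (uint offset = BLOCK_SIZE / 2; offset > 0; offset >>= 1) {
        if (tid < offset) {
            sdata[tid] += sdata[tid + offset];
        }
        barrier();
    }
    return sdata[0];                // sums across the whole workgroup
}
#endif
```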
With the subgroup code I get:

If I revert to the reduction loop, I get:

I've reverted to the version with a for loop. We can look at the subgroup optimization later.
Could it be approved again?
I'll do a proper review soon.
It works fine now.
```glsl
};

void main() {
    const uint global_thread_id = gl_WorkGroupID.x * gl_WorkGroupSize.x + gl_LocalInvocationID.x;
```
In Vulkan you can shorten this to gl_GlobalInvocationID.x
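The built-in is specified to compute exactly this value:

```glsl
// gl_GlobalInvocationID == gl_WorkGroupID * gl_WorkGroupSize + gl_LocalInvocationID
const uint global_thread_id = gl_GlobalInvocationID.x;
```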
```glsl
const int stride_dt = int(src2_nb1) / 4;
const int stride_B = int(src4_nb2) / 4;
const int stride_C = int(src5_nb2) / 4;
const int stride_y = int(n_head * d_head);
```
Why use int everywhere? It leads to a lot of casting, and the values don't look like they can/should be negative. Indices should be uints.
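For example, since the push constants are already declared uint32_t, the strides above could be written without casts (a sketch; assumes n_head and d_head are unsigned as well):

```glsl
const uint stride_dt = src2_nb1 / 4;
const uint stride_B  = src4_nb2 / 4;
const uint stride_C  = src5_nb2 / 4;
const uint stride_y  = n_head * d_head;
```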
```glsl
    state[j] = s0[s0_base_idx + j * D_STATE + tid];
}

if (tid >= D_STATE) {
```
Isn't D_STATE the workgroup size as well? If that's the case, this can't be true.
I've added it only to make sure, when reading the code, that there can't be any OOB access, but I agree it is superfluous. I'll drop it.
```glsl
float dt_soft_plus = dt[dt_base_idx + i * stride_dt];
dt_soft_plus = softplus(dt_soft_plus);
```
Suggested change:

```diff
-float dt_soft_plus = dt[dt_base_idx + i * stride_dt];
-dt_soft_plus = softplus(dt_soft_plus);
+const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
```
Not sure how much of a difference this kind of stuff makes, but with the large variety of Vulkan compilers, I prefer to make it as easy as possible for them with const.
```glsl
int lane = tid % SUBGROUP_SIZE;

warp_sdata[tid] = y;
```
Just curious, I don't understand or know the algorithm. Is the switch to the second shared buffer here to save on a barrier? It should be possible to continue the reduction with only stateC, or am I missing something?
Of course it also makes replacing the step with subgroup operations easier.
No reason, I've tried different ways before I got it to work. Your version is much better, thanks for the suggestion.
Thanks for the review, pushed an updated version.
Updated results on Intel Arc and NVIDIA L40S:

Intel Arc:

```
ggml_vulkan: 0 = Intel(R) Arc(tm) Graphics (MTL) (Intel open-source Mesa driver) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none
build: bc07349 (6756)

ggml_vulkan: 0 = Intel(R) Arc(tm) Graphics (MTL) (Intel open-source Mesa driver) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none
build: a4d94598e (6769)
```

NVIDIA:

```
ggml_vulkan: 0 = NVIDIA L40S (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
build: 554fd57 (6766)

ggml_vulkan: 0 = NVIDIA L40S (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
build: a4d94598e (6769)
```
ggml/src/ggml-vulkan/ggml-vulkan.cpp (outdated):
```cpp
};
struct vk_op_ssm_scan_push_constants {
    uint32_t src0_nb2, src0_nb3, src1_nb2, src1_nb3;
    uint32_t src2_nb1, src2_nb2, src3_nb1;
```
These are usually named like src2_nb1 -> nb21.
ggml/src/ggml-vulkan/ggml-vulkan.cpp (outdated):
```cpp
const uint32_t nr  = src0->ne[1];
const uint32_t n_t = dst->ne[1];
const uint32_t n_s = dst->ne[2];
elements = { CEIL_DIV(nr, 32), n_t, n_s };
```
Seems like you could use wg_denoms for this one (make wg_denoms = {32, 1, 1} and not explicitly use CEIL_DIV here).
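A sketch of that suggestion (hypothetical; it relies on the dispatch path dividing the element counts by the pipeline's wg_denoms, so the explicit CEIL_DIV disappears):

```cpp
// With wg_denoms = {32, 1, 1} set at pipeline creation, the dispatch logic
// performs the ceil-division itself.
const uint32_t nr  = src0->ne[1];
const uint32_t n_t = dst->ne[1];
const uint32_t n_s = dst->ne[2];
elements = { nr, n_t, n_s };
```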
ggml/src/ggml-vulkan/ggml-vulkan.cpp (outdated):
```cpp
if (ggml_is_quantized(op->src[0]->type) || ggml_is_quantized(op->src[1]->type) || ggml_is_quantized(op->src[2]->type)) {
    return false;
}
if (op->src[3] && ggml_is_quantized(op->src[3]->type)) {
```
You could maybe do these ggml_is_quantized checks in a loop or using any_of.
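For instance, a loop over the sources (a sketch; slightly broader than the original in that it checks every present source, and GGML_MAX_SRC is ggml's source-array size):

```cpp
// Reject the op if any present source tensor is quantized.
for (int i = 0; i < GGML_MAX_SRC; i++) {
    if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
        return false;
    }
}
```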
@jeffbolznv thanks again, fixed the last comments.
LGTM. All my comments have been addressed. I haven't had a chance to actually run it yet, I can try to do that tomorrow, but I don't want to block the change.
@0cc4m would you mind another look?
I ran the backend tests locally and they passed and had no validation errors. I noticed one test was unsupported, and I thought this was supported in earlier versions of the change. Was it intentional to remove it?
Oh, there's a failure in the lavapipe CI:

lavapipe uses an unusual subgroup size, that might be related.
Yeah, I've dropped support for Mamba 1 from this PR as I haven't got it to work yet. I can add it as a follow-up PR.
I did a quick experiment and changing spec constant 1 from device->subgroup_size to 32 fixed the lavapipe failure.
How do I run this locally? In what CI task is it happening?
It's the ubuntu-24-cmake-vulkan CI job. I think you'll need either Linux or WSL to run lavapipe (I run it on Windows with WSL). You'll probably need to set the env var GGML_VK_VISIBLE_DEVICES=0 (assuming this is enumerated as the first device) to enable lavapipe.
Never mind, I see it now. Taking a look.
To simplify the review, I post here the patch I've applied on top of the previous version:

```diff
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
index 06aa10bfe..daebb5dc0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
@@ -96,7 +96,7 @@ void main() {
             barrier();
         }
 
-        [[unroll]] for (uint j = 0; j < SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
+        [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
             const uint idx = (tid % SUBGROUP_SIZE) +
                              D_STATE * (tid / SUBGROUP_SIZE) +
                              j * D_STATE * (D_STATE / SUBGROUP_SIZE);
@@ -104,13 +104,13 @@ void main() {
             uint lane = tid % SUBGROUP_SIZE;
 
             [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
-                if (lane < offset) {
+                if (idx < SPLIT_H * D_STATE) {
                     stateC[idx] += stateC[idx + offset];
                 }
                 barrier();
             }
 
-            if (tid % SUBGROUP_SIZE == 0) {
+            if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
                 const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
                 d[y_base_idx + i * stride_y + k] = stateC[idx];
             }
```

The test passes locally now.
Passes for me, too.
Wow, that really is a game changer for Vulkan on NVIDIA cards. Here is an example from my 4060 Mobile GPU:

I understand that one needs special SSM/Mamba(?) LLM models in order to get this. Does anyone know how to specifically search for these kinds of models on Hugging Face? Thanks so much for your work on this, giuseppe, and please get this into upstream. Regards,
It works as intended, I just have some comments about the supports_op code.
ggml/src/ggml-vulkan/ggml-vulkan.cpp (outdated):
```cpp
const uint32_t MAX_D_STATE = 256;

size_t stateC_size = SPLIT_H * MAX_D_STATE * sizeof(float);
size_t warp_sdata_size = MAX_D_STATE * sizeof(float);
```
You forgot to update this calculation.
ggml/src/ggml-vulkan/ggml-vulkan.cpp (outdated):
```cpp
const uint32_t SPLIT_H = 16;
const uint32_t MAX_D_STATE = 256;

size_t stateC_size = SPLIT_H * MAX_D_STATE * sizeof(float);
```
Shouldn't this be using d_state instead of MAX_D_STATE, since the smaller shader may fit even if the large one does not?
Add State Space Model scan operation to the Vulkan backend. Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Add State Space Model conv operation to the Vulkan backend. Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
@0cc4m fixed the last comments and also added the new operations to
Thank you, looks good now.
@0cc4m thanks! CI is green now
implement SSM scan and SSM conv for Vulkan.

Intel Arc:

```
ggml_vulkan: 0 = Intel(R) Arc(tm) Graphics (MTL) (Intel open-source Mesa driver) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none
build: bc07349 (6756)

ggml_vulkan: 0 = Intel(R) Arc(tm) Graphics (MTL) (Intel open-source Mesa driver) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 65536 | int dot: 1 | matrix cores: none
build: a4d94598e (6769)
```

NVIDIA:

```
ggml_vulkan: 0 = NVIDIA L40S (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
build: 554fd57 (6766)

ggml_vulkan: 0 = NVIDIA L40S (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
build: a4d94598e (6769)
```