Optimize Mamba2 memory usage by replacing broadcast with einsum#41561
Optimize Mamba2 memory usage by replacing broadcast with einsum#41561SohamWalam11 wants to merge 7 commits intohuggingface:mainfrom
Conversation
|
This pull request resolves a critical bug where the Mamba2 model's output was inconsistent between its optimized 'fast' inference path and its standard path. The discrepancy caused multiple test cases to fail. By debugging the low-level implementation, I identified and fixed a numerical inconsistency, ensuring the model's output is now reliable and all validation tests pass |
|
I will always support a beautiful But @SohamWalam11 I'm a bit confused, I ran those tests on |
|
@Rocketknight1 Thank you for reviewing. Summary of Changes Memory usage reduced from gigabytes to megabytes in large-scale runs. Replaced inefficient broadcasting with a compact `torch.eins |
|
cc @ArthurZucker - I'll have to do some testing, but this seems like something we should accept if you're okay with it. The old code does an element-wise broadcast-multiplication and then a sum on two tensors, and that seems like code that should clearly be replaced by a matmul or einsum because it creates a huge intermediate array! We could also use a similar trick to eliminate the computation of |
|
Thanks, @ArthurZucker. I agree — the previous element-wise broadcast-multiplication was unnecessarily creating large intermediate arrays, which led to excessive memory usage without numerical benefits. The updated implementation replaces that pattern with a direct einsum, which performs the same operation in a single, memory-efficient step. I’ll also look into applying a similar optimization to the M_intermediate computation as suggested — that should further reduce memory overhead and improve runtime efficiency. I’ll run additional benchmarks and share the profiling results here once done. |
|
These operations were deliberately kept to torch without einsum. If you really want to have an efficient implementation, then there is no way around the kernels. I have a complete einsum version at https://github.com/vasqu/mamba2-torch/blob/main/tests/ssd_minimal.py for example cc @molbap |
|
If einsum is totally not allowed then I think at least one of these could become a transpose-matmul-transpose, which at least would not create the intermediate array (transpose just alters the stride metadata, so no extra memory is allocated) |
|
Totally open for any improvements. Maybe I sounded to harsh what I wanted to tell is that:
|
|
Yes, the mamba2 algorithms are based on nd tensor contractions and as said above, array-based impl is deliberate, welcome to optimize it of course, we already merged some (non-einsum) optimizations for the array ops in a couple mamba-like models so if an equivalent formulation without einsum is better, I'm all for it! |
|
🔥 ✊ the people demand einsums ✊ 🔥 |
|
https://github.com/vasqu/mamba2-torch/blob/main/tests/ssd_minimal.py has everything with einsums, just needs benchmarking now :P |
SohamWalam11
left a comment
There was a problem hiding this comment.
these are the proposed changes made
ArthurZucker
left a comment
There was a problem hiding this comment.
Hey, as always for performance boost claims we need to see benchmarks.
@Rocketknight1 maybe we should just provide a general script for this?
Anyways we are fully onboard better perfs, even for the "slow" path it would help mps and cpu and other hardwares without kernels. BUT we need good enough motivation 😉
|
@ArthurZucker I tried quickly but I can't even get a performance benchmark at a normal seq_len because of this line in the original, lol: G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n)This is an outer multiplication of two attention tensors, creating a I can keep attempting, maybe I can get a benchmark on a different GPU or with shorter seq lengths, but the memory issue is enough to justify the change imo. |
|
Yeah that makes sense, IDK if you were training or not, it would probably have been worse. |
|
yeah I remember this line @Rocketknight1, took me some time to figure out how to reproduce this computation without tensor contraction and couldn't without an intermediate 😅 (maybe there's a way?) however I was able to run it/measure memory usage on a standard machine, surprised it OOMs immediately |
|
@molbap I think you'd have to transpose the inputs to |
|
Also looking at this more, this whole block (not just |
|
@Rocketknight1 see https://github.com/vasqu/mamba2-torch/blob/main/tests/ssd_minimal.py, the torch only version was mostly derived from this. It's complete einsum with least intermediate tensors |
|
@ArthurZucker permission to revert to the no-intermediates big einsum approach to save memory + time and add the non-einsum path as comments instead? 👼 |
ArthurZucker
left a comment
There was a problem hiding this comment.
Yes let's merge but revert the test please
|
@ArthurZucker The tests has been passed perfectly should I share the screenshot of that peculiarly |
| # Contraction of C and B to get G (attention-weights like) | ||
| G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) | ||
| G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) | ||
| G = torch.einsum('bclhn,bcshn->bclsh', C, B) |
There was a problem hiding this comment.
maxi-nit, I think einsum discard spaces, could we add some here for readability? feels like
G = torch.einsum('b c l h n, b c s h n -> b c l s h', C, B)is more readable than
G = torch.einsum('bclhn,bcshn->bclsh', C, B)|
Since I got approval, I committed the full einsum approach to speed things up and hugely cut memory usage! |
| decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) | ||
| decay_chunk = decay_chunk.transpose(1, 3) | ||
| new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) | ||
| new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) |
There was a problem hiding this comment.
same nit as above, maybe a smol space or two around there to make it more readable?
|
[For maintainers] Suggested jobs to run (before merge) run-slow: mamba2 |
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
vasqu
left a comment
There was a problem hiding this comment.
Im good with this, would be nice if we could have small benches just to confirm that it is indeed worth it
Also, @molbap meant the letters to have a space in between I think. Might get too much given the amount of letters. Don't have a strong opinion there.
|
That might add linebreaks and I'm not sure it'd end up clearer! |
|
I stayed vague because just "some space" is good haha. many letters indeed. Just wanted some oxygen in the middle of the einsum 👀 |
|
@SohamWalam11 are you still willing to run some benchmarks to check on the latest version? It's just four big einsums now, so hopefully it should be a lot more performant |
|
Yes @Rocketknight1 |
What does this PR do?
Fixes # (issue)
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.