
Optimize Mamba2 memory usage by replacing broadcast with einsum#41561

Open
SohamWalam11 wants to merge 7 commits into huggingface:main from SohamWalam11:fix-mamba2-output-discrepancy

Conversation


@SohamWalam11 SohamWalam11 commented Oct 14, 2025

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@SohamWalam11
Author

This pull request resolves a bug where the Mamba2 model's output was inconsistent between its optimized 'fast' inference path and its standard path. The discrepancy caused multiple test cases to fail. By debugging the low-level implementation, I identified and fixed a numerical inconsistency, ensuring the model's output is now reliable and all validation tests pass.

@Rocketknight1
Member

I will always support a beautiful einsum instead of a wasteful and ugly intermediate array 👼

But @SohamWalam11 I'm a bit confused, I ran those tests on main and the output difference between train and eval was 0. Can you clarify if this is a PR to optimize memory usage or to fix numerical discrepancies?

@SohamWalam11
Author

@Rocketknight1 Thank you for reviewing.
This PR focuses on memory optimization rather than correcting numerical discrepancies. The goal is to eliminate unnecessary intermediate tensors and reduce the overall memory footprint while maintaining functional equivalence.

Summary of Changes

Memory usage reduced from gigabytes to megabytes in large-scale runs.

Replaced inefficient broadcasting with a compact `torch.einsum` contraction.

@Rocketknight1
Member

cc @ArthurZucker - I'll have to do some testing, but this seems like something we should accept if you're okay with it. The old code does an element-wise broadcast-multiplication and then a sum on two tensors, and that seems like code that should clearly be replaced by a matmul or einsum because it creates a huge intermediate array!

We could also use a similar trick to eliminate the computation of M_intermediate below. Those intermediate arrays are massive and I don't think the compiler compiles them out.

@SohamWalam11
Author

Thanks, @ArthurZucker. I agree — the previous element-wise broadcast-multiplication was unnecessarily creating large intermediate arrays, which led to excessive memory usage without numerical benefits.

The updated implementation replaces that pattern with a direct einsum, which performs the same operation in a single, memory-efficient step. I’ll also look into applying a similar optimization to the M_intermediate computation as suggested — that should further reduce memory overhead and improve runtime efficiency.

I’ll run additional benchmarks and share the profiling results here once done.

@vasqu
Contributor

vasqu commented Oct 14, 2025

These operations were deliberately kept as plain torch ops without einsum. If you really want an efficient implementation, there is no way around the kernels. I have a complete einsum version at https://github.com/vasqu/mamba2-torch/blob/main/tests/ssd_minimal.py for example

cc @molbap

@Rocketknight1
Member

If einsum is totally off the table, then I think at least one of these could become a transpose-matmul-transpose, which would at least not create the intermediate array (transpose just alters the stride metadata, so no extra memory is allocated).

@vasqu
Contributor

vasqu commented Oct 14, 2025

Totally open to any improvements. Maybe I sounded too harsh; what I wanted to say is that:

  • we need convincing numbers on einsum vs torch (or other changes)
  • everyone needs to be on board with this, especially @ArthurZucker :D

@molbap
Contributor

molbap commented Oct 15, 2025

Yes, the mamba2 algorithms are based on nd tensor contractions and, as said above, the array-based impl is deliberate. You're welcome to optimize it, of course; we already merged some (non-einsum) optimizations for the array ops in a couple of mamba-like models, so if an equivalent formulation without einsum is better, I'm all for it!

@Rocketknight1
Member

🔥 ✊ the people demand einsums ✊ 🔥

@vasqu
Contributor

vasqu commented Oct 15, 2025

https://github.com/vasqu/mamba2-torch/blob/main/tests/ssd_minimal.py has everything with einsums, just needs benchmarking now :P

Author

@SohamWalam11 SohamWalam11 left a comment


These are the proposed changes.

Collaborator

@ArthurZucker ArthurZucker left a comment


Hey, as always for performance-boost claims we need to see benchmarks.
@Rocketknight1 maybe we should just provide a general script for this?

Anyway, we're fully on board with better perf; even for the "slow" path it would help MPS, CPU, and other hardware without kernels. BUT we need good enough motivation 😉

@Rocketknight1
Member

Rocketknight1 commented Oct 16, 2025

@ArthurZucker I tried quickly but I can't even get a performance benchmark at a normal seq_len because of this line in the original, lol:

G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)

This is an outer multiplication of two attention-like tensors, creating a (b, c, l, s, h, n) intermediate when the final output we want is only (b, c, l, s, h). G_intermediate is massive! It's n times larger than the actual tensor we want, so my GPU goes OOM on this line! einsum computes the output tensor directly without creating G_intermediate and works fine.

I can keep attempting, maybe I can get a benchmark on a different GPU or with shorter seq lengths, but the memory issue is enough to justify the change imo.
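To make the memory argument concrete, here's a small NumPy sketch (toy shapes, not the model's real sizes) showing that the broadcast-then-sum path and the einsum produce the same G, while the broadcast path materializes an intermediate that is n times larger than the result:

```python
import numpy as np

# Toy sizes standing in for (batch, chunks, chunk_len, chunk_len, heads, state)
b, c, l, s, h, n = 2, 2, 3, 3, 4, 5
rng = np.random.default_rng(0)
C = rng.standard_normal((b, c, l, h, n))
B = rng.standard_normal((b, c, s, h, n))

# Broadcast path: materializes a (b, c, l, s, h, n) intermediate,
# n times larger than the (b, c, l, s, h) result we actually want
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]
G_broadcast = G_intermediate.sum(axis=-1)

# einsum path: contracts over n directly, no oversized intermediate
G_einsum = np.einsum('bclhn,bcshn->bclsh', C, B)

assert G_intermediate.size == n * G_einsum.size
assert np.allclose(G_broadcast, G_einsum)
```

The same relationship holds for `torch.einsum` with real model shapes; only there the factor-of-n intermediate is what drives the OOM.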

@ArthurZucker
Collaborator

Yeah, that makes sense. IDK if you were training or not; it would probably have been worse.
Fine with a super well-commented einsum that explains what we go from and to!

@molbap
Contributor

molbap commented Oct 17, 2025

yeah, I remember this line @Rocketknight1; it took me some time to figure out how to reproduce this computation without a tensor contraction, and I couldn't without an intermediate 😅 (maybe there's a way?). However, I was able to run it and measure memory usage on a standard machine, so I'm surprised it OOMs immediately

@Rocketknight1
Member

Rocketknight1 commented Oct 17, 2025

@molbap I think you'd have to transpose the inputs to (b, c, h, l, n) and (b, c, h, n, s), then do a matmul to get (b, c, h, l, s) and finally transpose that to get the (b, c, l, s, h) you want, lol. It would be extremely confusing and probably not very performant!
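That transpose-matmul-transpose route can be sketched like this (NumPy for brevity; toy shapes, with `C` and `B` as stand-in random tensors), and it does match the einsum, if confusingly:

```python
import numpy as np

b, c, l, s, h, n = 2, 2, 3, 3, 4, 5  # toy sizes
rng = np.random.default_rng(0)
C = rng.standard_normal((b, c, l, h, n))
B = rng.standard_normal((b, c, s, h, n))

# (b, c, l, h, n) -> (b, c, h, l, n) and (b, c, s, h, n) -> (b, c, h, n, s);
# transposes only change stride metadata, so no copy is made here
C_t = C.transpose(0, 1, 3, 2, 4)
B_t = B.transpose(0, 1, 3, 4, 2)

# Batched matmul contracts the shared n axis: result is (b, c, h, l, s)
G_mm = C_t @ B_t

# Back to the (b, c, l, s, h) layout the surrounding code expects
G = G_mm.transpose(0, 1, 3, 4, 2)

assert np.allclose(G, np.einsum('bclhn,bcshn->bclsh', C, B))
```

No (b, c, l, s, h, n) intermediate is ever materialized, but the axis bookkeeping is exactly the kind of thing einsum exists to hide.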

@Rocketknight1
Member

Rocketknight1 commented Oct 17, 2025

Also looking at this more, this whole block (not just G) feels like one or two enormous tensor contractions - I'm going to see if I can replace it all with a single big einsum to completely wipe out all of those intermediates, and then leave the original path in as comments. That should make it possible to follow for people who don't like einsums, but should give us some nice memory + performance savings.

@vasqu
Contributor

vasqu commented Oct 17, 2025

@Rocketknight1 see https://github.com/vasqu/mamba2-torch/blob/main/tests/ssd_minimal.py; the torch-only version was mostly derived from this. It's complete einsum with the fewest intermediate tensors

@Rocketknight1
Member

Rocketknight1 commented Oct 20, 2025

@ArthurZucker permission to revert to the no-intermediates big einsum approach to save memory + time and add the non-einsum path as comments instead? 👼

Collaborator

@ArthurZucker ArthurZucker left a comment


Yes let's merge but revert the test please

@SohamWalam11
Copy link
Copy Markdown
Author

@ArthurZucker The tests have all passed. Should I share a screenshot of that in particular?

  # Contraction of C and B to get G (attention-weights like)
- G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)
- G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)
+ G = torch.einsum('bclhn,bcshn->bclsh', C, B)
Contributor


maxi-nit, I think einsum discards spaces; could we add some here for readability? Feels like

G = torch.einsum('b c l h n, b c s h n -> b c l s h', C, B)

is more readable than

G = torch.einsum('bclhn,bcshn->bclsh', C, B)
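The assumption behind this nit holds at least for NumPy, whose einsum explicitly strips whitespace from the subscript string before parsing, so the two spellings are interchangeable (toy shapes here, not the model's):

```python
import numpy as np

rng = np.random.default_rng(0)
C = rng.standard_normal((2, 2, 3, 4, 5))  # toy (b, c, l, h, n)
B = rng.standard_normal((2, 2, 3, 4, 5))  # toy (b, c, s, h, n)

# Compact and spaced subscripts describe the same contraction;
# np.einsum removes spaces before interpreting the subscripts
compact = np.einsum('bclhn,bcshn->bclsh', C, B)
spaced = np.einsum('b c l h n, b c s h n -> b c l s h', C, B)

assert np.array_equal(compact, spaced)
```

PyTorch's `torch.einsum` accepts the spaced form as well, so the readability tweak is free.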

Comment thread tests/models/mamba2/test_mamba2_consistency.py Outdated
@Rocketknight1
Member

Since I got approval, I committed the full einsum approach to speed things up and hugely cut memory usage!

  decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
- decay_chunk = decay_chunk.transpose(1, 3)
- new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
+ new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
Contributor


same nit as above, maybe a smol space or two around there to make it more readable?

Member


Yes, added!
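The decay_chunk rewrite discussed above can be sanity-checked with a small NumPy sketch (toy shapes; `z` and `c` both index chunks, so they have the same length) comparing the transpose-broadcast-sum path against the single einsum:

```python
import numpy as np

b, h, c, p, n = 2, 4, 3, 5, 6  # toy sizes
z = c  # output chunk axis has the same length as the input chunk axis
rng = np.random.default_rng(0)
decay_chunk = rng.standard_normal((b, h, z, c))
states = rng.standard_normal((b, c, h, p, n))

# Old path: transpose to (b, c, z, h), broadcast up to the huge
# (b, c, z, h, p, n) intermediate, then reduce over the chunk axis c
dc_t = decay_chunk.transpose(0, 3, 2, 1)
new_states_bcast = (dc_t[..., None, None] * states[:, :, None, ...]).sum(axis=1)

# New path: one contraction over c, no (b, c, z, h, p, n) intermediate
new_states_einsum = np.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)

assert new_states_bcast.shape == (b, z, h, p, n)
assert np.allclose(new_states_bcast, new_states_einsum)
```

As with G, the einsum skips an intermediate that is a full factor of p * n larger than the broadcast inputs.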

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: mamba2

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@vasqu vasqu left a comment


I'm good with this; would be nice if we could have small benches just to confirm that it is indeed worth it.

Also, @molbap meant the letters to have a space in between, I think. Might be too much given the number of letters. I don't have a strong opinion there.

@Rocketknight1
Copy link
Copy Markdown
Member

That might add linebreaks and I'm not sure it'd end up clearer!

@molbap
Contributor

molbap commented Nov 19, 2025

I stayed vague because just "some space" is good haha. Many letters indeed. Just wanted some oxygen in the middle of the einsum 👀
though +1 for benches

@Rocketknight1
Member

Rocketknight1 commented Nov 19, 2025

@SohamWalam11 are you still willing to run some benchmarks to check on the latest version? It's just four big einsums now, so hopefully it should be a lot more performant

@SohamWalam11
Author

Yes @Rocketknight1


6 participants