perf: optimize grad_weight accumulation with addmm #1239

Open
maskyuanzh wants to merge 5 commits into linkedin:main from maskyuanzh:fix-fused-linear-ce-addmm

maskyuanzh commented on May 26, 2026

Summary

This PR optimizes grad_weight accumulation in fused linear cross entropy by replacing:

grad_weight += torch.mm(...).float()

with an in-place torch.addmm(..., out=grad_weight)-based accumulation.

For PyTorch >= 2.12 on CUDA, when accumulating fp16/bf16 operands into an fp32 grad_weight, this uses:

torch.addmm(..., out_dtype=torch.float32, out=grad_weight)

This avoids materializing the full [V, H] intermediate from torch.mm(...).float().

For older PyTorch versions where out_dtype is unavailable, this falls back to explicitly aligning the operand dtypes with grad_weight.dtype before calling:

torch.addmm(..., out=grad_weight)

The fallback is not zero-temp because the fp32 operand casts are materialized, but the temporary memory scales with the operand shapes, roughly chunk_size * (V + H), instead of the full V * H output intermediate.
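As a rough illustration of the accumulation described above, the sketch below checks that the `torch.addmm(..., out=grad_weight)` form produces the same result as `grad_weight += torch.mm(...).float()` while writing into the existing buffer. It is a minimal sketch, not the PR's code: the shapes are hypothetical, everything runs in fp32 on CPU, and it does not exercise `out_dtype` or the low-precision path.

```python
import torch

torch.manual_seed(0)

# Hypothetical small shapes; the PR's benchmark uses V=131072, H=4096.
chunk_size, V, H = 8, 16, 4

grad_logits_chunk = torch.randn(chunk_size, V)  # [chunk, V]
input_chunk = torch.randn(chunk_size, H)        # [chunk, H]
initial = torch.randn(V, H)                     # pre-existing accumulated gradient

# Baseline: materialize the [V, H] product, then add it to the accumulator.
grad_weight_ref = initial.clone()
grad_weight_ref += torch.mm(grad_logits_chunk.t(), input_chunk).float()

# addmm form: computes grad_weight + mat1 @ mat2 and writes it into `out`.
grad_weight = initial.clone()
buffer_ptr = grad_weight.data_ptr()
torch.addmm(grad_weight, grad_logits_chunk.t(), input_chunk, out=grad_weight)

same_buffer = grad_weight.data_ptr() == buffer_ptr
matches = torch.allclose(grad_weight, grad_weight_ref, atol=1e-5)
```

Here `same_buffer` confirms the accumulator was updated in place, and `matches` confirms the two forms agree within floating-point tolerance.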

Fixes #1232.

Memory Benchmark

I benchmarked a 128k-vocab case with V=131072, H=4096, chunk_size=2048, bf16 inputs, and fp32 grad_weight on an NVIDIA GeForce RTX 4090.

PyTorch 2.1.2+cu121:
  old mm(...).float():                     extra peak 3072 MiB
  fallback fp32 operands + addmm(out=...): extra peak 1056 MiB
  proposed auto path:                      extra peak 1056 MiB

PyTorch 2.12.0+cu126:
  old mm(...).float():                     extra peak 3072 MiB
  addmm(out_dtype=torch.float32, out=...): extra peak 0 MiB
  fallback fp32 operands + addmm(out=...): extra peak 1056 MiB
  proposed auto path:                      extra peak 0 MiB

So on PyTorch 2.12+, the out_dtype path removes the large [V, H] peak allocation in this configuration. On PyTorch 2.1.2, the compatibility fallback reduces the extra peak memory from about 3.0 GiB to about 1.03 GiB.
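The reported peaks can be cross-checked with simple byte arithmetic. The sketch below assumes that the old path keeps both the bf16 `[V, H]` product and its fp32 copy alive at the same time, and that the fallback's only temporaries are the fp32 copies of the two operands; these are assumptions about where the memory goes, not measurements.

```python
MIB = 1024 ** 2
V, H, chunk_size = 131072, 4096, 2048
BF16_BYTES, FP32_BYTES = 2, 4

# Old path: bf16 [V, H] output of mm(...) plus the fp32 copy made by .float().
old_extra_mib = (V * H * BF16_BYTES + V * H * FP32_BYTES) // MIB

# Fallback path: fp32 copies of the [chunk, V] and [chunk, H] operands.
fallback_extra_mib = (chunk_size * (V + H) * FP32_BYTES) // MIB

print(old_extra_mib, fallback_extra_mib)  # 3072 1056
```

Under those assumptions the arithmetic reproduces the 3072 MiB and 1056 MiB figures in the benchmark.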

Testing Done

  • Hardware Type: NVIDIA GeForce RTX 4090
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Additional targeted testing:

pytest -q test/transformers/test_fused_linear_cross_entropy.py

Passed. This covers the fused linear cross entropy paths affected by this change.

maskyuanzh changed the title from "Optimize grad_weight accumulation with addmm" to "perf: optimize grad_weight accumulation with addmm" on May 26, 2026
if ce_weight.stride(-1) != 1:
    ce_weight = ce_weight.contiguous()

IS_TORCH2P12 = Version(torch.__version__.split("+")[0]) >= Version("2.12.0")
Contributor

This could probably be located globally, so it doesn't need to run every forward pass.

Author

Thanks for pointing this out. I’ll move the PyTorch version check to the module-level constants so it is only evaluated once instead of on every forward pass.
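A module-level guard along these lines would be evaluated once at import time. This is a minimal sketch, assuming the `packaging` library is available; the helper name `supports_addmm_out_dtype` and the constant `MIN_OUT_DTYPE_VERSION` are hypothetical, and the threshold simply mirrors the `2.12.0` value in the quoted line.

```python
from packaging.version import Version

# Hypothetical constant; the threshold mirrors the version quoted in the diff.
MIN_OUT_DTYPE_VERSION = Version("2.12.0")


def supports_addmm_out_dtype(torch_version: str) -> bool:
    """Return True if the given version string meets the threshold."""
    # Drop the local build tag (for example "+cu126") before comparing.
    return Version(torch_version.split("+")[0]) >= MIN_OUT_DTYPE_VERSION


# In a real module this would be computed once from torch.__version__, e.g.
#   IS_TORCH2P12 = supports_addmm_out_dtype(torch.__version__)
new_enough = supports_addmm_out_dtype("2.12.0+cu126")
too_old = supports_addmm_out_dtype("2.1.2+cu121")
```

Because the result is a plain boolean constant, the forward pass only needs to read it.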

grad_weight,
grad_logits_chunk.t(),
_input_chunk.to(dtype=grad_logits_chunk.t().dtype),
out_dtype=torch.float32,
Contributor

I think technically out_dtype is available earlier than 2.12; I just didn't do the work to track down which version it was introduced in. I think I remember it existing in 2.10 as well.

Author

Thanks for pointing this out. The exact version wasn’t carefully verified here. After checking the PyTorch docs and source tags, I found that out_dtype was added to torch.addmm in PyTorch 2.8.0 for fp16/bf16 CUDA inputs with fp32 output accumulation. I’ll lower the version guard from 2.12.0 to 2.8.0.

torch.addmm(
    grad_weight,
    grad_logits_chunk.t(),
    _input_chunk.to(dtype=grad_logits_chunk.t().dtype),
Contributor

Why is this cast now necessary?

Author

Thanks for asking. The cast is needed because out_dtype only controls the output/accumulation dtype; torch.addmm still requires mat1 and mat2 to have the same input dtype.

I tested this on PyTorch 2.12.0 + CUDA. With out_dtype=torch.float32, addmm still fails for fp16 x fp32 and bf16 x fp32 inputs:

Half and Float     -> RuntimeError: mat1 and mat2 must have the same dtype
BFloat16 and Float -> RuntimeError: mat1 and mat2 must have the same dtype

After casting mat2 to match mat1, both fp16 and bf16 paths succeed and write into an fp32 output buffer. In the AMP path here, grad_logits_chunk is the low-precision operand while _input_chunk can remain fp32, so this cast aligns _input_chunk with grad_logits_chunk and keeps the multiply in fp16/bf16 while accumulating into fp32.
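The same-dtype requirement can be seen in a small CPU sketch. It uses plain `torch.mm` with hypothetical shapes rather than `addmm` with `out_dtype`, so it only illustrates why the operand cast is needed and does not reproduce the fp32-output behavior described above.

```python
import torch

torch.manual_seed(0)

# Hypothetical small shapes for illustration.
grad_logits_chunk = torch.randn(8, 16).to(torch.bfloat16)  # low-precision operand
input_chunk = torch.randn(8, 4)                             # still fp32

# Mixed bf16 x fp32 operands are rejected: matmul does not type-promote.
try:
    torch.mm(grad_logits_chunk.t(), input_chunk)
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True

# Align the second operand with the first, then multiply in bf16.
aligned_input = input_chunk.to(dtype=grad_logits_chunk.dtype)
product = torch.mm(grad_logits_chunk.t(), aligned_input)
```

After the cast, the multiply runs with matching bf16 operands; in this sketch the result is also bf16, since no fp32 output buffer is involved.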

Comment on lines +230 to +231
grad_logits_chunk.t().to(grad_weight.dtype),
_input_chunk.to(grad_weight.dtype),
Contributor

So the input, chunk, and weight all need the same dtype on this path? The desired behavior is typically to multiply in bf16 and then accumulate in fp32. Doing the multiply in fp32 as well would be pretty slow, so I don't think this is advisable.

Author

Thanks for pointing this out. I tested the fp32-operand fallback and found that your concern was correct: although it reduces peak memory compared with the old mm(...).float() path, it is slower because the matmul itself runs in fp32.

old: mm(lowp, lowp).float()               23.5 ms, peak 5688 MiB
old fallback: addmm(fp32, fp32, out=fp32) 27.4 ms, peak 3640 MiB
fast: addmm(lowp, lowp, out_dtype=fp32)   13.4 ms, peak 2632 MiB

I updated the logic so this path no longer promotes both operands to fp32. The addmm(..., out_dtype=torch.float32, out=grad_weight) path is now only used when out_dtype is supported, grad_weight is fp32, and grad_logits_chunk is fp16/bf16. Since addmm(..., out=...) does not autocast the operands, and out_dtype only controls the output dtype, _input_chunk is explicitly cast to grad_logits_chunk’s dtype. This keeps the matmul in fp16/bf16 while writing directly into the fp32 accumulation buffer. For unsupported cases, the code now falls back to the original mm(...).float() behavior.
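The selection rule described in this reply can be sketched as a small pure function. The function name, the path labels, and the use of dtype names as strings are all hypothetical simplifications for illustration; the actual change operates on `torch.dtype` values inside the forward pass.

```python
# Hypothetical stand-ins for torch.float16 / torch.bfloat16.
LOW_PRECISION = {"float16", "bfloat16"}


def pick_accumulation_path(
    out_dtype_supported: bool,
    grad_weight_dtype: str,
    grad_logits_dtype: str,
) -> str:
    """Choose how a chunk's contribution is added to grad_weight."""
    if (
        out_dtype_supported
        and grad_weight_dtype == "float32"
        and grad_logits_dtype in LOW_PRECISION
    ):
        # Multiply in fp16/bf16 and write straight into the fp32 buffer.
        return "addmm_out_dtype"
    # Everything else keeps the original mm(...).float() behavior.
    return "mm_float"


fast = pick_accumulation_path(True, "float32", "bfloat16")
old_torch = pick_accumulation_path(False, "float32", "bfloat16")
full_precision = pick_accumulation_path(True, "float32", "float32")
```

Only the first call satisfies all three conditions; the other two fall back to the original path.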

Development

Successfully merging this pull request may close these issues.

Mem and compute inefficiency in fused_linear_cross_entropy_foward