perf: optimize grad_weight accumulation with addmm#1239
Conversation
| if ce_weight.stride(-1) != 1: | ||
| ce_weight = ce_weight.contiguous() | ||
|
|
||
| IS_TORCH2P12 = Version(torch.__version__.split("+")[0]) >= Version("2.12.0") |
There was a problem hiding this comment.
This could probably be located globally, so it doesn't need to run every forward pass.
There was a problem hiding this comment.
Thanks for pointing this out. I’ll move the PyTorch version check to the module-level constants so it is only evaluated once instead of on every forward pass.
| grad_weight, | ||
| grad_logits_chunk.t(), | ||
| _input_chunk.to(dtype=grad_logits_chunk.t().dtype), | ||
| out_dtype=torch.float32, |
There was a problem hiding this comment.
I think technically out_dtype is available earlier than 2.12, I just didnt do the work to track down which version it was introduced in. I think I remember it existing in 2.10 as well.
There was a problem hiding this comment.
Thanks for pointing this out. The exact version wasn’t carefully verified here. After checking the PyTorch docs and source tags, I found that out_dtype was added to torch.addmm in PyTorch 2.8.0 for fp16/bf16 CUDA inputs with fp32 output accumulation. I’ll lower the version guard from 2.12.0 to 2.8.0.
| torch.addmm( | ||
| grad_weight, | ||
| grad_logits_chunk.t(), | ||
| _input_chunk.to(dtype=grad_logits_chunk.t().dtype), |
There was a problem hiding this comment.
Why is this cast now necessary?
There was a problem hiding this comment.
Thanks for asking. The cast is needed because out_dtype only controls the output/accumulation dtype; torch.addmm still requires mat1 and mat2 to have the same input dtype.
I tested this on PyTorch 2.12.0 + CUDA. With out_dtype=torch.float32, addmm still fails for fp16 x fp32 and bf16 x fp32 inputs:
Half and Float -> RuntimeError: mat1 and mat2 must have the same dtype
BFloat16 and Float -> RuntimeError: mat1 and mat2 must have the same dtype
After casting mat2 to match mat1, both fp16 and bf16 paths succeed and write into a fp32 output buffer. In the AMP path here, grad_logits_chunk is the low-precision operand while _input_chunk can remain fp32, so this cast aligns _input_chunk with grad_logits_chunk and keeps the multiply in fp16/bf16 while accumulating into fp32.
| grad_logits_chunk.t().to(grad_weight.dtype), | ||
| _input_chunk.to(grad_weight.dtype), |
There was a problem hiding this comment.
So the input, chunk, and weight all need the same dtype on this path? The desired behavior is typically to multiply in bf16 and then to accumulate in fp32. Doing the multiply in fp32 as well would be pretty slow, so I dont think this is advisable.
There was a problem hiding this comment.
Thanks for pointing this out. I tested the fp32-operand fallback and found that your concern was correct: although it reduces peak memory compared with the old mm(...).float() path, it is slower because the matmul itself runs in fp32.
old: mm(lowp, lowp).float() 23.5 ms, peak 5688 MiB
old fallback: addmm(fp32, fp32, out=fp32) 27.4 ms, peak 3640 MiB
fast: addmm(lowp, lowp, out_dtype=fp32) 13.4 ms, peak 2632 MiB
I updated the logic so this path no longer promotes both operands to fp32. The addmm(..., out_dtype=torch.float32, out=grad_weight) path is now only used when out_dtype is supported, grad_weight is fp32, and grad_logits_chunk is fp16/bf16. Since addmm(..., out=...) does not autocast the operands, and out_dtype only controls the output dtype, _input_chunk is explicitly cast to grad_logits_chunk’s dtype. This keeps the matmul in fp16/bf16 while writing directly into the fp32 accumulation buffer. For unsupported cases, the code now falls back to the original mm(...).float() behavior.
Summary
This PR optimizes
grad_weightaccumulation in fused linear cross entropy by replacing:with an in-place
torch.addmm(..., out=grad_weight)-based accumulation.For PyTorch >= 2.12 on CUDA, when accumulating fp16/bf16 operands into a fp32
grad_weight, this uses:This avoids materializing the full
[V, H]intermediate fromtorch.mm(...).float().For older PyTorch versions where
out_dtypeis unavailable, this falls back to explicitly aligning the operand dtypes withgrad_weight.dtypebefore calling:The fallback is not zero-temp because the fp32 operand casts are materialized, but the temporary memory scales with the operand shapes, roughly
chunk_size * (V + H), instead of the fullV * Houtput intermediate.Fixes #1232.
Memory Benchmark
I benchmarked a 128k-vocab case with
V=131072,H=4096,chunk_size=2048, bf16 inputs, and fp32grad_weighton an NVIDIA GeForce RTX 4090.So on PyTorch 2.12+, the
out_dtypepath removes the large[V, H]peak allocation in this configuration. On PyTorch 2.1.2, the compatibility fallback reduces the extra peak memory from about 3.0 GiB to about 1.03 GiB.Testing Done
make testto ensure correctnessmake checkstyleto ensure code stylemake test-convergenceto ensure convergenceAdditional targeted testing:
Passed. This covers the fused linear cross entropy paths affected by this change.