Check failed in collective_pipeliner when using gradient accumulation with non-unrolled loop #22210
Labels
better_errors: Improve the error reporting
needs info: More information is required to diagnose & prioritize the issue.
Description
Hi, I have the following setup:
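Roughly, the setup is FSDP-style sharded training under `jax.jit`: layer weights are sharded across devices and all-gathered on use, and gradients come back via reduce-scatter. The sketch below is a simplified stand-in, not the real code; the mesh axis name, toy model, shapes, and `train_step` are illustrative.

```python
# Simplified stand-in for the setup: parameters sharded across devices
# (FSDP-style), so XLA inserts weight all-gathers on use and gradient
# reduce-scatters in the backward pass. Model, shapes, and names are toy.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("fsdp",))
param_sharding = NamedSharding(mesh, P("fsdp", None))  # shard weights on dim 0
batch_sharding = NamedSharding(mesh, P("fsdp", None))  # shard batch on dim 0

def loss_fn(params, batch):
    x = batch
    for w in params:
        x = jnp.tanh(x @ w)
    return jnp.mean(x ** 2)

@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss

keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = [jax.random.normal(k, (1024, 1024)) for k in keys]
params = jax.device_put(params, param_sharding)
batch = jax.device_put(jnp.ones((8, 1024)), batch_sharding)  # divisible by device count
params, loss = train_step(params, batch)
```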
I'm using the following flags:
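Approximately, the flag set looks like the snippet below. The double-buffering flag and the latency hiding scheduler are the ones that matter for the problem described further down; the pipelined-collective flags are assumptions about the rest of the set, not a confirmed list.

```python
# Set before importing jax. The double-buffering and latency-hiding-scheduler
# flags are the ones discussed below; the pipelined-collective flags are an
# approximation of the rest of the set.
import os
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_while_loop_double_buffering=true",
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_pipelined_reduce_scatter=true",
])
```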
This works correctly and indeed hides the layers' weight all-gathers and the gradient reduce-scatters behind computation.
Problems start to arise when I try to use gradient accumulation in this setup. It is implemented like this:
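A simplified sketch of the implementation, assuming the batch carries a leading `num_minibatches_in_batch` dimension: gradients are accumulated with `jax.lax.fori_loop` and the update is applied once per batch. The real train step differs in details, but the loop structure is the point.

```python
# Gradient accumulation via jax.lax.fori_loop (simplified). loss_fn is the
# same toy loss as in the setup sketch above.
import jax
import jax.numpy as jnp

num_minibatches_in_batch = 4  # gradient accumulation factor

def loss_fn(params, minibatch):
    x = minibatch
    for w in params:
        x = jnp.tanh(x @ w)
    return jnp.mean(x ** 2)

@jax.jit
def train_step(params, batch):
    # batch leaves have a leading dim of size num_minibatches_in_batch
    grads_init = jax.tree_util.tree_map(jnp.zeros_like, params)

    def accumulate(i, carry):
        grads_acc, loss_acc = carry
        minibatch = jax.tree_util.tree_map(lambda x: x[i], batch)
        loss, grads = jax.value_and_grad(loss_fn)(params, minibatch)
        grads_acc = jax.tree_util.tree_map(jnp.add, grads_acc, grads)
        return grads_acc, loss_acc + loss

    grads, loss = jax.lax.fori_loop(
        0, num_minibatches_in_batch, accumulate,
        (grads_init, jnp.zeros((), jnp.float32)),
    )
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches_in_batch, grads)
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss / num_minibatches_in_batch
```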
When I set the gradient accumulation factor (`num_minibatches_in_batch` in this snippet) to a value greater than 1, I get the following error during compilation:
Here is the `--xla_dump_to` result:
xla_dump.tgz
One important fact here is that if I set `unroll` in `jax.lax.fori_loop` to True, then there is no compilation error and everything works. But this obviously leads to additional memory usage proportional to the gradient accumulation factor, so this hack doesn't seem to be viable.
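Continuing the sketch above, the workaround amounts to just:

```python
# Workaround from the sketch above: fully unrolling the accumulation loop
# avoids the compilation failure, but replicates the loop body (and its live
# buffers) num_minibatches_in_batch times.
grads, loss = jax.lax.fori_loop(
    0, num_minibatches_in_batch, accumulate,
    (grads_init, jnp.zeros((), jnp.float32)),
    unroll=True,
)
```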
My hypothesis is that when using `--xla_gpu_enable_while_loop_double_buffering=true` with pipelined collectives and the latency hiding scheduler, the XLA compiler tries to double buffer this fori_loop, which is actually undesired behavior.

Basically, there are two problems:
I've tested this on JAX 0.4.29 and 0.4.30.
System info (python version, jaxlib version, accelerator, etc.)