
Check failed in collective_pipeliner when using gradient accumulation with non-unrolled loop #22210

Closed
qGentry opened this issue Jul 1, 2024 · 3 comments
Labels: better_errors (Improve the error reporting), needs info (More information is required to diagnose & prioritize the issue.)

qGentry commented Jul 1, 2024

Description

Hi, I have the following setup:

  • A Transformer model with N layers, scanned over the input
  • Fully sharded data parallel (FSDP) sharding
  • Asynchronous communication (latency-hiding scheduler; pipelined all-gather, all-reduce, and reduce-scatter)

I'm using the following XLA flags:

--xla_gpu_graph_level=0 
--xla_gpu_enable_triton_gemm=false 
--xla_gpu_enable_command_buffer= 
--xla_gpu_enable_latency_hiding_scheduler=true 
--xla_gpu_enable_all_gather_combine_by_dim=false 
--xla_gpu_enable_reduce_scatter_combine_by_dim=false 
--xla_gpu_enable_pipelined_all_gather=true 
--xla_gpu_enable_pipelined_reduce_scatter=true 
--xla_gpu_enable_pipelined_all_reduce=true 
--xla_gpu_enable_pipelined_collectives=false 
--xla_gpu_enable_while_loop_double_buffering=true 
--xla_gpu_enable_highest_priority_async_stream=true 
--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute
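
For completeness, here is a minimal sketch of how these flags get passed in (assuming the standard XLA_FLAGS environment variable, set before the XLA GPU backend is initialized; the exact mechanism shouldn't matter for the repro):

    import os

    # The XLA flags above must be set before JAX initializes the GPU backend,
    # e.g. via the XLA_FLAGS environment variable (or exported in the shell).
    os.environ["XLA_FLAGS"] = " ".join([
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        "--xla_gpu_enable_pipelined_all_gather=true",
        "--xla_gpu_enable_pipelined_reduce_scatter=true",
        "--xla_gpu_enable_pipelined_all_reduce=true",
        "--xla_gpu_enable_while_loop_double_buffering=true",
        # ... plus the remaining flags from the list above ...
    ])

    import jax  # import JAX only after XLA_FLAGS is set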

This works correctly and indeed hides the layers' weight all-gathers and the gradient reduce-scatters behind computation.

Problems start to arise when I try to use gradient accumulation in this setup. It is implemented like this:

    # Zero-initialized gradient accumulator with the same pytree structure as the params.
    grads_sum = jax.tree_map(jnp.zeros_like, train_state.params)

    # Accumulate gradients over num_minibatches_in_batch minibatches.
    train_state, grads_sum = jax.lax.fori_loop(
        lower=0,
        upper=num_minibatches_in_batch,
        body_fun=_loop_body,
        init_val=(train_state, grads_sum),
        unroll=False,
    )

    mean_grads = jax.tree_map(lambda x: x / num_minibatches_in_batch, grads_sum)
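
For reference, _loop_body has roughly this shape (a simplified, illustrative sketch; loss_fn and minibatches stand in for my actual loss function and batch-slicing logic):

    def _loop_body(step, carry):
        # Compute the gradient on the step-th minibatch and add it to the
        # running sum; train_state is carried through unchanged here.
        train_state, grads_sum = carry
        minibatch = jax.tree_map(lambda x: x[step], minibatches)
        grads = jax.grad(loss_fn)(train_state.params, minibatch)
        grads_sum = jax.tree_map(jnp.add, grads_sum, grads)
        return train_state, grads_sum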

When I set the gradient accumulation factor (num_minibatches_in_batch in this snippet) to a value greater than 1, I get the following error during compilation:

2024-07-01 12:57:35.488299: F external/xla/xla/service/collective_pipeliner.cc:675] Check failed: last_cloned != nullptr (0 vs. nullptr)

Here is the --xla_dump_to result:
xla_dump.tgz

One important fact: if I set unroll in jax.lax.fori_loop to True, there is no compilation error and everything works. But this obviously leads to additional memory usage proportional to the gradient accumulation factor, so this hack doesn't seem viable.
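
Concretely, the variant that does compile is the same loop with unroll flipped:

    # Unrolling avoids the check failure, but (per the memory issue above)
    # duplicates the loop body num_minibatches_in_batch times in the program.
    train_state, grads_sum = jax.lax.fori_loop(
        lower=0,
        upper=num_minibatches_in_batch,
        body_fun=_loop_body,
        init_val=(train_state, grads_sum),
        unroll=True,
    )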

My hypothesis is that when --xla_gpu_enable_while_loop_double_buffering=true is combined with pipelined collectives and the latency-hiding scheduler, the XLA compiler tries to double-buffer this fori_loop, which is undesired behavior here.

Basically, there are two problems:

  • A compiler bug whose error is hard to trace back to the originating JAX code
  • If my hypothesis is correct, I would like a mechanism to disable while_loop_double_buffering for specific loops (like the gradient accumulation loop), or to enable it only for specific loops (like the layers loop); the only option I see today is the global flag shown below
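
As far as I can tell, the only knob available today is global, which would also turn off double buffering for the layers loop where it actually helps:

--xla_gpu_enable_while_loop_double_buffering=false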

I've tested this on JAX 0.4.29 and 0.4.30.

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.24.3
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ffisin-dev-8gpu', release='5.4.0-155-generic', version='#172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023', machine='x86_64')


$ nvidia-smi
Mon Jul  1 13:21:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   68C    P0             141W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:91:00.0 Off |                    0 |
| N/A   48C    P0             121W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:95:00.0 Off |                    0 |
| N/A   69C    P0             137W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:99:00.0 Off |                    0 |
| N/A   50C    P0             126W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:AB:00.0 Off |                    0 |
| N/A   68C    P0             142W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:AF:00.0 Off |                    0 |
| N/A   49C    P0             124W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   68C    P0             143W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:B7:00.0 Off |                    0 |
| N/A   48C    P0             121W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

qGentry commented Jul 1, 2024

Related XLA issue:
openxla/xla#14332


mattjj commented Jul 9, 2024

Thanks for raising this.

I think it's an XLA:GPU issue, and we don't have any way to fix it from JAX.

That said, the hard-to-parse error may be something we can get traction on from the JAX side... can you say a bit more about what would have helped in the error message? We attach Python source information to the HLO program, but it's up to XLA to raise errors that reference it... from JAX we could at least have told you which jitted function raised the compiler error, but I'm not sure we have other information to provide...

@mattjj mattjj self-assigned this Jul 9, 2024
@mattjj mattjj added the better_errors and needs info labels and removed the bug label on Jul 20, 2024

mattjj commented Jul 23, 2024

I think we should close this in favor of the XLA issue. Looks like it just got assigned yesterday!

@mattjj mattjj closed this as completed Jul 23, 2024