
[XLA:GPU] Fine-grained remat policy makes async/pipelined collectives execute in the main stream #14397

Closed
qGentry opened this issue Jul 3, 2024 · 12 comments


qGentry commented Jul 3, 2024

Description

Hi, I have the following setup:

  • Transformer model with N layers, scanned over the input
  • Fully sharded data parallel (FSDP) sharding
  • Asynchronous communication (latency-hiding scheduler; pipelined all-gather, all-reduce, and reduce-scatter)

I'm using the following flags:

--xla_gpu_graph_level=0 
--xla_gpu_enable_latency_hiding_scheduler=true 
--xla_gpu_enable_all_gather_combine_by_dim=false 
--xla_gpu_enable_reduce_scatter_combine_by_dim=false 
--xla_gpu_enable_pipelined_all_gather=true 
--xla_gpu_enable_pipelined_reduce_scatter=true 
--xla_gpu_enable_pipelined_all_reduce=true 
--xla_gpu_enable_pipelined_collectives=false 
--xla_gpu_enable_while_loop_double_buffering=true 
--xla_gpu_enable_highest_priority_async_stream=true 
--xla_gpu_all_reduce_combine_threshold_bytes=2147483648 
--xla_gpu_all_gather_combine_threshold_bytes=2147483648  
--xla_gpu_reduce_scatter_combine_threshold_bytes=2147483648
--xla_gpu_disable_async_collectives=collectivebroadcast,alltoall,collectivepermute
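
For completeness, here is a minimal sketch of how such flags can be passed via the XLA_FLAGS environment variable (a hypothetical snippet, not my actual launcher; it must run before the XLA GPU backend is initialized):

import os

# Sketch: join the flags above into XLA_FLAGS before JAX initializes its
# GPU backend, otherwise they are not picked up.
os.environ["XLA_FLAGS"] = " ".join([
    "--xla_gpu_graph_level=0",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_pipelined_all_gather=true",
    "--xla_gpu_enable_pipelined_reduce_scatter=true",
    "--xla_gpu_enable_pipelined_all_reduce=true",
    # ... remaining flags from the list above ...
])

import jax  # imported only after XLA_FLAGS is set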

To speed up the backward pass by fine-grained reduction of activation recomputation, I marked each dense layer's output in the transformer block with a specific name:

result = jax.lax.dot_general(
    inputs,
    kernel,
    dimension_numbers=((axis, contract_ind), ((), ())),
    precision=self.precision,
    preferred_element_type=self.accumulator_dtype,
)
result = jax.ad_checkpoint.checkpoint_name(result, self.activation_dot_name)

So, for example, in the attention layer I have "dot_attention_query", "dot_attention_key", "dot_attention_value", and "dot_attention_out".

I then apply a checkpoint policy to the scanned function; the policy takes the list of activation names to save for the backward pass:

def rematted_layer(layer):
    return nn.remat(
        layer,
        policy=jax.checkpoint_policies.save_only_these_names(
            *self.config.save_names_for_bwd
        ),
        prevent_cse=not self.config.scan,
    )

and then scan it over the embeddings:

apply_block = rematted_layer(apply_block)
apply_block = nn.scan(
    apply_block,
    length=self.config.num_layers,
    variable_axes={
        "params": 0,
    },
    variable_broadcast=False,
    split_rngs={"params": True},
    metadata_params={nn.PARTITION_NAME: "layers"},
)
block = TransformerBlock(
    name="scan",
    config=self.config.block,
)
embeddings, _ = apply_block(block, embeddings, None)
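
For reference, here is a minimal standalone sketch (independent of the Flax/scan setup above; the shapes and the single "dot_mlp_out" tag are made up for illustration) showing how checkpoint_name and save_only_these_names fit together:

import jax
import jax.numpy as jnp

def layer(x, w):
    y = x @ w
    # Tag the intermediate so a remat policy can refer to it by name.
    y = jax.ad_checkpoint.checkpoint_name(y, "dot_mlp_out")
    return jnp.tanh(y)

# Save only the named activation for the backward pass; everything else
# is recomputed.
layer_remat = jax.checkpoint(
    layer,
    policy=jax.checkpoint_policies.save_only_these_names("dot_mlp_out"),
)

x = jnp.ones((4, 8))
w = jnp.ones((8, 8))
grads = jax.grad(lambda w: layer_remat(x, w).sum())(w)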

If I set self.config.save_names_for_bwd to an empty list (which is basically equivalent to the "nothing_saveable" policy), then communication works correctly: all-gathers, reduce-scatters, and all-reduces are overlapped with computation, as can be seen in this Perfetto trace:
[Screenshot: Perfetto trace, 2024-07-03 14:33]
nothing_saveable.tgz

But as soon as I specify some names in self.config.save_names_for_bwd, for example:

    save_names_for_bwd:
      - dot_mlp_out
      - dot_attention_value
      - dot_attention_query
      - dot_attention_key

While these activations are indeed not recomputed during the backward pass, all communication is executed in the main stream without any overlap with computation:
[Screenshot: Perfetto trace, 2024-07-03 14:35]
save_only_these_names_trace.tgz

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.30
jaxlib: 0.4.30
numpy:  1.24.3
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='ffisin-dev-8gpu', release='5.4.0-155-generic', version='#172-Ubuntu SMP Fri Jul 7 16:10:02 UTC 2023', machine='x86_64')


$ nvidia-smi
Mon Jul  1 13:21:57 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:8D:00.0 Off |                    0 |
| N/A   68C    P0             141W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  | 00000000:91:00.0 Off |                    0 |
| N/A   48C    P0             121W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  | 00000000:95:00.0 Off |                    0 |
| N/A   69C    P0             137W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  | 00000000:99:00.0 Off |                    0 |
| N/A   50C    P0             126W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          On  | 00000000:AB:00.0 Off |                    0 |
| N/A   68C    P0             142W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          On  | 00000000:AF:00.0 Off |                    0 |
| N/A   49C    P0             124W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          On  | 00000000:B3:00.0 Off |                    0 |
| N/A   68C    P0             143W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          On  | 00000000:B7:00.0 Off |                    0 |
| N/A   48C    P0             121W / 700W |   1074MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+

qGentry commented Jul 3, 2024

corresponding JAX issue:
google/jax#22252

qGentry changed the title to "[XLA:GPU] Fine-grained remat policy makes async/pipelined collectives execute in the main stream" on Jul 3, 2024

cheshire commented Jul 8, 2024

@golechwierowicz could you take a look?

golechwierowicz commented:

In general, if the latency-hiding scheduler is not able to schedule any computation within an async collective, then that collective will be executed on the compute stream. Can you provide an HLO-based, single-host, 2-GPU repro?

You can do this via the --xla_dump_to= flag.
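
For example (a sketch; the dump directory is arbitrary):

import os

# Append the dump flag to any existing XLA_FLAGS so that HLO dumps
# (including *.before_optimizations.txt) are written for each compiled module.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/xla_dump"
)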


qGentry commented Jul 9, 2024

Hi @golechwierowicz

First of all, let me clarify my training setup:
I've implemented not just "vanilla" FSDP, which shards model parameters over the typical 'data' (or 'batch') axis, but rather a mix of HSDP and ZeRO-1. Classic HSDP replicates model parameters and optimizer state across nodes with slow interconnect (typically Ethernet or InfiniBand) and shards them within each node over the fast interconnect (typically NVLink). This is suboptimal, because we have to replicate the optimizer state even though we could shard it over the replica groups (similar to the ZeRO stage 1 configuration). So we replicate model parameters across nodes, shard model parameters within nodes, and shard the optimizer state across the entire cluster (all devices).
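
To illustrate, here is a hypothetical single-host 2x4 mesh (not my actual mesh code; the axis names and shapes are made up) that mirrors this layout:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2x4 mesh: 2 replica groups x 4 in-node shards (assumes 8 visible devices).
devices = np.array(jax.devices()[:8]).reshape(2, 4)
mesh = Mesh(devices, axis_names=("replica", "fsdp"))

# Model parameters: sharded along 'fsdp', replicated along 'replica'.
param_sharding = NamedSharding(mesh, P("fsdp"))

# Optimizer state: sharded across all devices (ZeRO-1 style).
opt_state_sharding = NamedSharding(mesh, P(("replica", "fsdp")))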

Last week I performed a lot of experiments and noticed that the reason behind this behavior might not be the fine-grained activation checkpointing, but rather some other odd XLA behavior.

I've tested the following setups: 7B, 30B, and 70B LLaMA-2 transformer pretraining on 256 GPUs (8 GPUs x 32 nodes).

For the 7B and 30B models, as I mentioned in the issue, asynchronous communication works correctly without activation checkpointing and is executed in the compute stream when I try to checkpoint anything.

But for the 70B model, all communication is executed in the compute stream even without selective activation checkpointing.

On the other hand, for the single-node repro I've tried the following setups:
With the 7B model sharded across all 8 devices, asynchronous communication works correctly with and without activation checkpointing. So, to simulate the big setup (multiple replicated parameter groups), I tried this configuration: model parameters are sharded across 4 devices with 2 replication groups, and the optimizer state is sharded across all 8 devices. This gave me the following repro:

xla_dump_selective_saveable.tgz
perfetto-trace-selective-saveable.tgz

xla_dump_nothing_saveable.tgz
perfetto-trace-nothing-saveable.tgz

In the single-node selective-checkpointing repro, I noticed that during the forward pass the weights' all-gathers are performed asynchronously, while all communication during the backward pass is executed in the compute stream, which is even weirder.

golechwierowicz commented:

8 GPUs is fine.

I don't see jit__unnamed_wrapped_function_ in the dump, which, judging from the trace, is the module of interest. Are you sure there were no errors related to producing the dump in the logs?

Alternatively, I can also work with a small single-node JAX-based repro.


qGentry commented Jul 10, 2024

Sorry, I see: it looks like cached modules are not dumped when the JAX compilation cache is enabled.
Here are the XLA dumps with caching disabled.
xla_dump_nothing_saveable.tgz
xla_dump_selective_saveable.tgz

golechwierowicz commented:

Unfortunately, I am missing some custom kernels needed to run it purely at the XLA level.

INTERNAL: Failed to capture gpu graph: [...jax/jaxlib/gpu/triton_kernels.cc] operation cuModuleLoadData(&module, module_image_.data()) failed: CUDA_ERROR_INVALID_SOURCE: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

I was, however, able to look at the intermediate passes and have an idea. Can you try the following flags in addition to the flags you already provide?

--xla_gpu_copy_insertion_use_region_analysis=false

and then

--xla_gpu_copy_insertion_use_region_analysis=false
--xla_gpu_enable_command_buffer=""

and see if it leads to an improvement?

If not, can you get rid of the custom kernels so I can debug this further? I want to run it via the multihost_hlo_runner tool available here, with the flags --num_replicas=1 --num_partitions=8 --use_spmd_partitioning, running jit__unnamed_wrapped_function_.before_optimizations.txt.


qGentry commented Jul 10, 2024

Thank you. I've tried the --xla_gpu_copy_insertion_use_region_analysis=false flag, and it fixes the issue of communication not being asynchronous (both with and without the --xla_gpu_enable_command_buffer="" flag). It's odd that I'm not seeing any speedup compared to the 'nothing_saveable' strategy, but that may be because of the small scale; I'll try it on the 256-GPU setup.

Here is the XLA dump with this flag enabled, in case you need it.
xla_dump_selective_saveable_no_region_analysis.tgz

Meanwhile, I'll try to reproduce this behavior without the custom Triton kernels.


qGentry commented Jul 10, 2024

Unfortunately, I can't reproduce this behavior without the FlashAttention Triton kernel: communication is performed asynchronously both with and without activation checkpointing. Maybe the reason is that I had to halve the batch size, because otherwise it wouldn't fit in memory.


qGentry commented Jul 10, 2024

Also, can you please elaborate on how exactly the --xla_gpu_copy_insertion_use_region_analysis=false flag fixes this problem?

golechwierowicz commented:

It prevents overly aggressive copy elision, which happens at the cost of inserting control predecessors/successors. These additional control-flow constraints break the flow of the latency-hiding scheduler (LHS). So disabling region analysis creates extra copies, but this is still beneficial because it unlocks compute-overlap opportunities for the LHS in the compiler. The XLA documentation in this regard is lacking, but I have some code pointers.

Copy insertion

// Copy insertion is a legalization HLO pass which inserts copies (kCopy
// instructions) to eliminate several kinds of problems in the HLO module.
//
// (1) Entry parameter or a constant live out of the entry computation. Entry
// computation arguments and constants have different lifetimes than the
// computation result and cannot share the same allocation. Parameters and
// constants live out of non-entry computations do not need copies.
//
// (2) Different values which are simultaneously live and which must be held
// in the same buffer. This can occur in while bodies. Specifically, the
// while loop state (the arguments to the while instruction) is updated
// in-place and the update may clobber the value from the previous
// iteration before the previous value is dead. Computations called from
// kCall instructions do not need such copies because kCall has no update
// in-place semantics.
//
// (3) The buffer set of the root instruction of the entry computation must be
// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
// InstructionAliasSet::IsDistinct return true.

Latency Hiding Scheduler

class DefaultSchedulerCore : public SchedulerCore {

golechwierowicz commented:

Also, for GPUs, this behavior should be the default in the next JAX release.

2df7755
