draft docs on gpu performance tuning
Co-authored-by: Tao Wang <wangtao@google.com>
mattjj and wang12tao committed Oct 31, 2023
1 parent 20255dc commit 664e834
Showing 2 changed files with 111 additions and 1 deletion.
109 changes: 109 additions & 0 deletions docs/gpu_performance_tips.md
@@ -0,0 +1,109 @@
# GPU performance tips

This document focuses on performance tips for neural network workloads.

## Matmul precision

On recent GPU generations, such as the Nvidia A100 generation or later, it can
be a good idea to perform most computations in `bfloat16` precision. For
example, if using [Flax](https://github.com/google/flax), instantiate `Dense`
layers using `flax.linen.Dense(..., dtype=jax.numpy.bfloat16)`. Here are some
code examples, and a minimal sketch follows the list:
* In the [Flax LM1B
example](https://github.com/google/flax/tree/main/examples/lm1b), `Dense`
modules are [instantiated with a configurable
dtype](https://github.com/google/flax/blob/fd8fd76a4af5307a61f85bac98feab9b26d60db8/examples/lm1b/models.py#L188)
which [defaults](https://github.com/google/flax/blob/fd8fd76a4af5307a61f85bac98feab9b26d60db8/examples/lm1b/configs/default.py#L112) to
[bfloat16](https://github.com/google/flax/blob/c0087535d7f2e5bfcbf2a7be6825b9f5055a54c6/examples/lm1b/train.py#L431).
* In [MaxText](https://github.com/google/maxtext), `DenseGeneral` modules are
also [instantiated with a configurable
dtype](https://github.com/google/maxtext/blob/07dc6ce27ced1246407d0de311d4a0d6a9fd46d8/MaxText/layers.py#L592)
that [defaults to
bfloat16](https://github.com/google/maxtext/blob/07dc6ce27ced1246407d0de311d4a0d6a9fd46d8/MaxText/configs/base.yml#L41).
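
As a minimal sketch (not taken from either example above), a Flax module can
pass `dtype=jnp.bfloat16` to its `Dense` layers so that the matmuls run in
`bfloat16` while, by default, the parameters themselves stay in `float32`
(Flax's separate `param_dtype` argument controls parameter storage):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: int = 512

  @nn.compact
  def __call__(self, x):
    # dtype sets the computation dtype; params remain float32 by default.
    x = nn.Dense(self.features, dtype=jnp.bfloat16)(x)
    x = nn.relu(x)
    return nn.Dense(self.features, dtype=jnp.bfloat16)(x)

params = MLP().init(jax.random.PRNGKey(0), jnp.ones((1, 128)))
```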

## XLA performance flags

The existence and exact behavior of XLA flags may be `jaxlib`-version dependent.

As of `jaxlib==0.4.18` (released [Oct 6
2023](https://pypi.org/project/jaxlib/#history)), setting these XLA flags can
improve performance. Some are related to communication between GPUs, and so are
only relevant when running computations on multiple devices, while others are
related to code generation on each device.

Some of these may be set by default in future releases.

These flags can be set via the `XLA_FLAGS` shell environment variable. For
example, we can add this to the top of a Python file:
```python
import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)
```

For more examples, see also [XLA Flags recommended for Pax
training on Nvidia GPUs](https://github.com/NVIDIA/JAX-Toolbox/blob/main/rosetta/rosetta/projects/pax/README.md#xla-flags).


### Code generation flags

* **--xla_gpu_enable_triton_softmax_fusion** This flag enables an automatic
softmax fusion, based on pattern-matching backed by Triton code generation.
The default value is False.
* **--xla_gpu_triton_gemm_any** Use the Triton-based GEMM (matmul) emitter for
any GEMM that it supports. The default value is False.

### Communication flags

* **--xla_gpu_enable_async_collectives** This flag enables the collective ops
such as `AllReduce`, `AllGather`, `ReduceScatter` and `CollectivePermute` to
be asynchronous. Asynchronous communication can overlap cross-core
communication with computation. The default value is False.
* **--xla_gpu_enable_latency_hiding_scheduler** This flag enables latency hiding
schedulers to overlap asynchronous communication with computation efficiently.
The default value is False.
* **--xla_gpu_enable_pipelined_collectives** When using pipeline parallelism,
  this flag enables overlapping the (i+1)-th layer weight `AllGather` with the
  i-th layer computation. It also enables overlapping the (i+1)-th layer weight
  `Reduce`/`ReduceScatter` with the i-th layer's computation. The default value
  is False. **There are some bugs when this flag is turned on.**
* **--xla_gpu_collective_permute_decomposer_threshold** This flag is useful when
  performing [GSPMD pipelining](https://arxiv.org/abs/2105.04663). Setting a
  nonzero threshold decomposes `CollectivePermute`s into
  `CollectivePermuteReceiveDone` and `CollectivePermuteSendDone` pairs, so that
  computation can be performed between each corresponding
  `ReceiveDone`/`SendDone` pair and hence achieve more overlap. By default the
  threshold is 0 and there is no decomposition. Setting a threshold greater
  than 0, e.g. `--xla_gpu_collective_permute_decomposer_threshold=1024`,
  enables this feature.
* **--xla_gpu_all_gather_combine_threshold_bytes**
  **--xla_gpu_reduce_scatter_combine_threshold_bytes**
  **--xla_gpu_all_reduce_combine_threshold_bytes**
  These flags tune when to combine multiple small
  `AllGather`/`ReduceScatter`/`AllReduce` ops into one big
  `AllGather`/`ReduceScatter`/`AllReduce` op, to reduce time spent on
  cross-device communication. For example, for the `AllGather`/`ReduceScatter`
  thresholds on a Transformer-based workload, consider tuning them high enough
  to combine at least a Transformer layer's weight `AllGather`/`ReduceScatter`
  (see the sketch after this list). By default, the `combine_threshold_bytes`
  is set to 256.
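
As a hypothetical sketch combining the communication flags above, one might
set `XLA_FLAGS` as follows; the threshold values here are placeholders, not
recommendations, and should be tuned per workload:

```python
import os

# Illustrative only: enable async collectives and latency hiding, decompose
# CollectivePermutes for pipelining, and raise the combine thresholds to
# roughly one Transformer layer's worth of weights (128 MiB here).
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_collective_permute_decomposer_threshold=1024 '
    '--xla_gpu_all_gather_combine_threshold_bytes=134217728 '
    '--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 '
    '--xla_gpu_all_reduce_combine_threshold_bytes=134217728 '
)
```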

## NCCL flags

These NCCL flag values may be useful for single-host multi-device computations
on Nvidia GPUs:

```python
import os

os.environ.update({
    "NCCL_LL128_BUFFSIZE": "-2",
    "NCCL_LL_BUFFSIZE": "-2",
    "NCCL_PROTO": "SIMPLE,LL,LL128",
})
```

These NCCL flags can improve single-host communication speed; they don't appear
to be useful for multi-host communication yet.
3 changes: 2 additions & 1 deletion docs/user_guides.rst
@@ -14,6 +14,7 @@ or deployed codebases.
profiling
device_memory_profiling
debugging/index
gpu_performance_tips


.. toctree::
@@ -33,4 +34,4 @@ or deployed codebases.
errors
transfer_guard

pallas/index
pallas/index
