You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/source/en/training/distributed_inference.md
+45-51Lines changed: 45 additions & 51 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -12,51 +12,55 @@ specific language governing permissions and limitations under the License.
12
12
13
13
# Distributed inference
14
14
15
-
On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel.
15
+
Distributed inference splits the workload across multiple GPUs. It a useful technique for fitting larger models in memory and can process multiple prompts for higher throughput.
16
16
17
-
This guide will show you how to use 🤗 Accelerate and PyTorch Distributed for distributed inference.
17
+
This guide will show you how to use [Accelerate](https://huggingface.co/docs/accelerate/index) and [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html) for distributed inference.
18
18
19
-
## 🤗 Accelerate
19
+
## Accelerate
20
20
21
-
🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.
21
+
Accelerate is a library designed to simplify inference and training on multiple accelerators by handling the setup, allowing users to focus on their PyTorch code.
22
22
23
-
To begin, create a Python file and initialize an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process.
23
+
Install Accelerate with the following command.
24
24
25
-
Now use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes.
25
+
```bash
26
+
uv pip install accelerate
27
+
```
28
+
29
+
Initialize a [`accelerate.PartialState`] class in a Python file to create a distributed environment. The [`accelerate.PartialState`] class manages process management, device control and distribution, and process coordination.
30
+
31
+
Move the [`DiffusionPipeline`] to [`accelerate.PartialState.device`] to assign a GPU to each process.
Use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes.
37
46
47
+
```py
38
48
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
52
-
53
-
</Tip>
54
-
55
59
## PyTorch Distributed
56
60
57
-
PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)which enables data parallelism.
61
+
PyTorch [DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) enables [data parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=data_parallelism), which replicates the same model on each device, to process different batches of data in parallel.
58
62
59
-
To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]:
63
+
Import `torch.distributed` and `torch.multiprocessing`into a Python file to set up the distributed process group and to spawn the processes for inference on each GPU.
60
64
61
65
```py
62
66
import torch
@@ -65,20 +69,20 @@ import torch.multiprocessing as mp
You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size` or the number of processes participating. If you're running inference in parallel over 2 GPUs, then the`world_size` is 2.
77
+
Create a function for inference with [init_process_group](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group). This method creates a distributed environment with the backend type, the `rank` of the current process, and the `world_size` or number of processes participating (for example, 2 GPUs would be`world_size=2`).
74
78
75
-
Move the [`DiffusionPipeline`] to `rank` and use `get_rank` to assign a GPU to each process, where each process handles a different prompt:
79
+
Move the pipeline to `rank` and use `get_rank` to assign a GPU to each process. Each process handles a different prompt.
To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:
96
+
Use [mp.spawn](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to create the number of processes defined in `world_size`.
93
97
94
98
```py
95
99
defmain():
@@ -101,31 +105,26 @@ if __name__ == "__main__":
101
105
main()
102
106
```
103
107
104
-
Once you've completed the inference script, use the `--nproc_per_node` argument to specify the number of GPUs to use and call `torchrun` to run the script:
108
+
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
105
109
106
110
```bash
107
111
torchrun run_distributed.py --nproc_per_node=2
108
112
```
109
113
110
-
> [!TIP]
111
-
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
112
-
113
-
## Model sharding
114
+
## device_map
114
115
115
-
Modern diffusion systems such as [Flux](../api/pipelines/flux) are very large and have multiple models. For example, [Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) is made up of two text encoders - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) and [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - a [diffusion transformer](../api/models/flux_transformer), and a [VAE](../api/models/autoencoderkl). With a model this size, it can be challenging to run inference on consumer GPUs.
116
+
The `device_map` argument enables distributed inference by automatically placing model components on separate GPUs. This is especially useful when a model doesn't fit on a single GPU. You can use `device_map` to selectively load and unload the required model components at a given stage as shown in the example below (assumes two GPUs are available).
116
117
117
-
Model sharding is a technique that distributes models across GPUs when the models don't fit on a single GPU. The example below assumes two 16GB GPUs are available for inference.
118
-
119
-
Start by computing the text embeddings with the text encoders. Keep the text encoders on two GPUs by setting `device_map="balanced"`. The `balanced` strategy evenly distributes the model on all available GPUs. Use the `max_memory` parameter to allocate the maximum amount of memory for each text encoder on each GPU.
120
-
121
-
> [!TIP]
122
-
> **Only** load the text encoders for this step! The diffusion transformer and VAE are loaded in a later step to preserve memory.
118
+
Set `device_map="balanced"` to evenly distributes the text encoders on all available GPUs. You can use the `max_memory` argument to allocate a maximum amount of memory for each text encoder. Don't load any other pipeline components to avoid memory usage.
123
119
124
120
```py
125
121
from diffusers import FluxPipeline
126
122
import torch
127
123
128
-
prompt ="a photo of a dog with cat-like look"
124
+
prompt ="""
125
+
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
126
+
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
127
+
"""
129
128
130
129
pipeline = FluxPipeline.from_pretrained(
131
130
"black-forest-labs/FLUX.1-dev",
@@ -142,7 +141,7 @@ with torch.no_grad():
142
141
)
143
142
```
144
143
145
-
Once the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.
144
+
After the text embeddings are computed, remove them from the GPU to make space for the diffusion transformer.
146
145
147
146
```py
148
147
import gc
@@ -162,7 +161,7 @@ del pipeline
162
161
flush()
163
162
```
164
163
165
-
Load the diffusion transformer next which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto`strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It starts by distributing a model across the fastest device first (GPU) before moving to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
164
+
Set `device_map="auto"` to automatically distribute the model on the two GPUs. This strategy places a model on the fastest device first before placing a model on a slower device like a CPU or hard drive if needed. The trade-off of storing model parameters on slower devices is slower inference latency.
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.
179
+
> Run `pipeline.hf_device_map` to see how the various models are distributed across devices. This is useful for tracking model device placement. You can also call `hf_device_map` on the transformer model to see how it is distributed.
181
180
182
-
Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.
181
+
Add the transformer model to the pipeline and set the `output_type="latent"` to generate the latents.
183
182
184
183
```py
185
184
pipeline = FluxPipeline.from_pretrained(
@@ -206,21 +205,12 @@ latents = pipeline(
206
205
).images
207
206
```
208
207
209
-
Remove the pipeline and transformer from memory as they're no longer needed.
210
-
211
-
```py
212
-
del pipeline.transformer
213
-
del pipeline
214
-
215
-
flush()
216
-
```
217
-
218
-
Finally, decode the latents with the VAE into an image. The VAE is typically small enough to be loaded on a single GPU.
208
+
Remove the pipeline and transformer from memory and load a VAE to decode the latents. The VAE is typically small enough to be loaded on a single device.
219
209
220
210
```py
211
+
import torch
221
212
from diffusers import AutoencoderKL
222
213
from diffusers.image_processor import VaeImageProcessor
By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
229
+
## Resources
230
+
231
+
- Take a look at this [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for a minimal example of distributed inference with Accelerate.
232
+
- For more details, check out Accelerate's [Distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
233
+
- The `device_map` argument assign models or an entire pipeline to devices. Refer to the [device placement](../using-diffusers/loading#device-placement) docs for more information.
0 commit comments