24 changes: 24 additions & 0 deletions docs/source/en/using-diffusers/loading.md
@@ -112,6 +112,30 @@ print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)

If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
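
A minimal sketch of this fallback, reusing the Wan2.2 checkpoint from the example below: only the transformer is listed in the dictionary and no `default` key is given, so every other component loads in `torch.float32`.

```py
import torch
from diffusers import DiffusionPipeline

# only the transformer is pinned to bfloat16; with no "default" key,
# the remaining components (vae, text_encoder, ...) fall back to float32
pipe = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype={"transformer": torch.bfloat16},
)
print(pipe.transformer.dtype, pipe.vae.dtype)  # (torch.bfloat16, torch.float32)
```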

### Parallel loading

Large models are often [sharded](../training/distributed_inference#model-sharding) into smaller files so that they are easier to load. Diffusers supports loading shards in parallel to speed up the loading process.

Set the environment variables below to enable parallel loading.

- Set `HF_ENABLE_PARALLEL_LOADING` to `"YES"` to enable parallel loading of shards.
- Set `HF_PARALLEL_LOADING_WORKERS` to configure the number of parallel threads to use when loading shards. More workers load a model faster but use more memory.

The `device_map` argument should be set to `"cuda"` to pre-allocate a large chunk of memory based on the model size. This substantially reduces load time because warming up the memory allocator upfront avoids many smaller allocator calls later during loading.

```py
import os
import torch
from diffusers import DiffusionPipeline

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
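# optionally tune the number of parallel workers; more threads load shards
# faster but use more memory ("8" is an illustrative value, not a recommendation)
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "8"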
pipeline = DiffusionPipeline.from_pretrained(
"Wan-AI/Wan2.2-I2V-A14B-Diffusers",
torch_dtype=torch.bfloat16,
device_map="cuda"
)
```

> **Review comment (Member):** Do we want to talk a bit about `device_map`? The motivation behind passing `device_map` essentially comes from this PR: #11904. I have also tried providing a bit of motivation behind adding it at the pipeline-level here: #12122 (comment).

### Local pipeline

To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.