In [None]:
!pip install -qU diffusers accelerate transformers huggingface_hub

In [None]:
from huggingface_hub import notebook_login
notebook_login()

# Working with big models

**Stable Diffusion XL (SDXL)** is not a single model but a collection of models. It has four model-level components:
* a variational autoencoder (VAE)
* two text encoders
* a UNet for denoising

Usually, the text encoders and the denoiser are much larger than the VAE.

When a text encoder checkpoint has multiple shards, like `T5-xxl` for SD3, the `Transformers` library handles it automatically; `Transformers` is a required dependency of Diffusers when using the `StableDiffusion3Pipeline`. More specifically, `Transformers` loads all of the shards into the requested model class and gets it ready for inference.

The denoiser checkpoint can also be split into multiple shards, and the `Accelerate` library makes it possible to run inference with it.
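
To make the idea of shards more concrete, here is a minimal sketch of how an index that maps parameter names to shard files could be read. It assumes the sharded checkpoint ships with a JSON index containing a `weight_map` entry; the parameter names, file names, and `total_size` below are made up for illustration, and `group_weights_by_shard` is a hypothetical helper, not a library function.

```python
from collections import defaultdict

# Hypothetical index contents, modeled on the JSON index file that
# accompanies a sharded checkpoint. All names and sizes are made up.
index = {
    "metadata": {"total_size": 123},
    "weight_map": {
        "conv_in.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "conv_in.bias": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "conv_out.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
    },
}


def group_weights_by_shard(weight_map):
    """Invert a {parameter name: shard file} map into {shard file: [names]}."""
    shards = defaultdict(list)
    for name, shard_file in weight_map.items():
        shards[shard_file].append(name)
    return dict(shards)


shards = group_weights_by_shard(index["weight_map"])
for shard_file, names in sorted(shards.items()):
    print(shard_file, names)
```

A loader can use a grouping like this to open each shard file once and copy the listed tensors into the model.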

For example, we can save a sharded checkpoint for the SDXL UNet:

In [None]:
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    subfolder='unet'
)

unet.save_pretrained(
    'sdxl-unet-sharded',
    max_shard_size='5GB',
)

The `fp32` variant of the SDXL UNet checkpoint is ~10.4GB, so setting the `max_shard_size` parameter to 5GB creates 3 shards.
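
The shard count follows from packing tensors into files that stay under the size limit. The sketch below is a simplified greedy packing, not necessarily the exact logic used when saving; `plan_shards` is a hypothetical helper and the tensor sizes (in MB) are invented so that they sum to roughly 10.4GB.

```python
def plan_shards(tensor_sizes, max_shard_size):
    """Greedily pack tensors, in order, into shards of at most max_shard_size.

    A single tensor larger than the limit still gets a shard of its own.
    """
    shards, current, current_size = [], [], 0
    for name, size in tensor_sizes.items():
        # Start a new shard when adding this tensor would exceed the limit.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)
    return shards


# Hypothetical tensor groups with sizes in MB (total: 10400 MB).
sizes = {
    "block_a": 2600,
    "block_b": 2200,
    "block_c": 2800,
    "block_d": 1900,
    "block_e": 900,
}
plan = plan_shards(sizes, max_shard_size=5000)
print(len(plan), plan)
```

With these made-up sizes, the plan contains three shards, matching the count above. Real checkpoints contain many more tensors, so the exact split depends on their individual sizes.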

After saving, we can load the sharded checkpoint and pass it to `StableDiffusionXLPipeline`:

In [None]:
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch

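# Load the sharded UNet from the Hub; the local 'sdxl-unet-sharded'
# directory saved above can be passed here instead.
unet = UNet2DConditionModel.from_pretrained(
    'sayakpaul/sdxl-unet-sharded',
    torch_dtype=torch.float16,
)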

pipeline = StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    unet=unet,
    torch_dtype=torch.float16,
).to('cuda')

In [None]:
image = pipeline(
    'a cute dog running on the grass',
    num_inference_steps=30,
).images[0]
image.save('dog-running.png')

## Device placement

On distributed setups, we can run inference across multiple GPUs with Accelerate.

With Accelerate, we can use the `device_map` argument to determine how the models of a pipeline are distributed across multiple devices. This is useful when more than one GPU is available.

For example, if we have two 8GB GPUs, then using `enable_model_cpu_offload()` may not work well because:
* it only works on a single GPU
* a single model might not fit on a single GPU

To make use of both GPUs, we can use the `"balanced"` device placement strategy, which splits the models across all available GPUs.

In [None]:
from diffusers import DiffusionPipeline
import torch

pipeline = DiffusionPipeline.from_pretrained(
    'stable-diffusion-v1-5/stable-diffusion-v1-5',
    torch_dtype=torch.float16,
    use_safetensors=True,
    device_map='balanced', # device placement strategy
)

We can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:

In [None]:
from diffusers import DiffusionPipeline
import torch

max_memory = {0: '1GB', 1: '1GB'}

pipeline = DiffusionPipeline.from_pretrained(
    'stable-diffusion-v1-5/stable-diffusion-v1-5',
    torch_dtype=torch.float16,
    use_safetensors=True,
    device_map='balanced',
    max_memory=max_memory,
)

If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
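
The effect of leaving a device out of `max_memory` can be sketched in a few lines. This is only an illustration of the rule described above; `parse_size` and `usable_devices` are hypothetical helpers, and treating `GB` as 10^9 bytes is an assumption of this sketch.

```python
def parse_size(size):
    """Convert a string such as '1GB' or '512MB' to bytes (decimal units assumed)."""
    units = {"GB": 10**9, "MB": 10**6, "KB": 10**3}
    for suffix, factor in units.items():
        if size.upper().endswith(suffix):
            return int(float(size[: -len(suffix)]) * factor)
    return int(size)  # plain number of bytes


def usable_devices(available_devices, max_memory):
    """Keep only devices listed in max_memory, with their budgets in bytes."""
    return {
        device: parse_size(max_memory[device])
        for device in available_devices
        if device in max_memory
    }


# Three GPUs are visible, but only GPUs 0 and 1 are given a budget.
budgets = usable_devices([0, 1, 2], {0: "1GB", 1: "1GB"})
print(budgets)
```

Here GPU 2 never appears in `budgets`, so no model would be placed on it.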

By default, `Diffusers` uses the maximum memory of all devices. If the models do not fit on the GPUs, they are offloaded to the CPU. If the CPU does not have enough memory, then we might see an error. In this case, we could fall back to `enable_sequential_cpu_offload()` or `enable_model_cpu_offload()`.
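
The fallback order described above (GPUs first, then CPU, then an error) can be illustrated with a small placement sketch. This is not the implementation used by `Diffusers` or `Accelerate`; it is one plausible greedy strategy, and `place_components`, the component sizes, and the budgets (all in MB) are hypothetical.

```python
def place_components(component_sizes, gpu_budgets, cpu_budget):
    """Assign each component to the GPU with the most free memory that fits it.

    Components that fit on no GPU go to the CPU; if the CPU is also too
    small, a MemoryError is raised.
    """
    free = dict(gpu_budgets)
    cpu_free = cpu_budget
    placement = {}
    # Place the largest components first so they get the emptiest devices.
    ordered = sorted(component_sizes.items(), key=lambda kv: kv[1], reverse=True)
    for name, size in ordered:
        fitting = [device for device in free if free[device] >= size]
        if fitting:
            device = max(fitting, key=lambda d: free[d])
            free[device] -= size
            placement[name] = device
        elif cpu_free >= size:
            cpu_free -= size
            placement[name] = "cpu"
        else:
            raise MemoryError(f"No device has enough memory for {name}")
    return placement


# Hypothetical component sizes and budgets, all in MB.
sizes = {"unet": 1700, "safety_checker": 600, "text_encoder": 250, "vae": 170}
placement = place_components(sizes, gpu_budgets={0: 2000, 1: 2000}, cpu_budget=4000)
print(placement)
```

Shrinking the GPU budgets in this sketch pushes the largest component to `"cpu"`, and setting `cpu_budget` too low raises the error, which mirrors the situation where the offloading methods become the better choice.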

Call `reset_device_map()` to reset the `device_map` of a pipeline. This is also necessary if we want to use methods like `to()`, `enable_sequential_cpu_offload()`, and `enable_model_cpu_offload()` on a pipeline that was device-mapped.

In [None]:
pipeline.reset_device_map()

Once a pipeline has been device-mapped, we can also access its device map via `hf_device_map`:

In [None]:
print(pipeline.hf_device_map)
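
To read the device map more easily, we can group components by device. The sketch below assumes the map is a dictionary of the form `{component name: device}`; the example values are made up, and `components_by_device` is a hypothetical helper.

```python
from collections import defaultdict


def components_by_device(device_map):
    """Group component names by the device they were placed on."""
    grouped = defaultdict(list)
    for component, device in device_map.items():
        grouped[device].append(component)
    return {device: sorted(names) for device, names in grouped.items()}


# Hypothetical device map; in practice, pass `pipeline.hf_device_map`.
example_map = {"unet": 1, "vae": 1, "safety_checker": 0, "text_encoder": 0}
grouped = components_by_device(example_map)
print(grouped)
```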