Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 51 additions & 105 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch.nn as nn
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
Expand All @@ -32,7 +33,6 @@
UNet2DConditionModel,
apply_faster_cache,
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
Expand Down Expand Up @@ -2244,80 +2244,6 @@ def test_layerwise_casting_inference(self):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)[0]

@require_torch_accelerator
def test_group_offloading_inference(self):
    """Check that group offloading does not change the pipeline output.

    A baseline pipeline is run fully on `torch_device`. Fresh pipelines are then
    built with group offloading enabled on the supported components, once with
    `block_level` offloading (one block per group) and once with `leaf_level`
    offloading, and each output must match the baseline within `atol=1e-4`.

    The test is skipped when `self.test_group_offloading` is falsy.
    """
    if not self.test_group_offloading:
        # Report a skip instead of returning silently, so a disabled test is
        # not counted as a pass in the test report.
        self.skipTest("`test_group_offloading` is disabled hence skipping.")

    def create_pipe():
        # Re-seed before building so every pipeline gets identical dummy weights.
        torch.manual_seed(0)
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        return pipe

    def enable_group_offload_on_component(pipe, group_offloading_kwargs):
        # We intentionally don't test VAEs here. This is because some tests enable tiling on the VAE. If
        # tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
        # the layers is not traced correctly. This causes errors. To apply group offloading to a VAE, a
        # warmup forward pass (even with dummy small inputs) is recommended.
        for component_name in [
            "text_encoder",
            "text_encoder_2",
            "text_encoder_3",
            "transformer",
            "unet",
            "controlnet",
        ]:
            if not hasattr(pipe, component_name):
                continue
            component = getattr(pipe, component_name)
            # Components may opt out through `_supports_group_offloading = False`.
            if not getattr(component, "_supports_group_offloading", True):
                continue
            if hasattr(component, "enable_group_offload"):
                # For diffusers ModelMixin implementations
                component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs)
            else:
                # For other models not part of diffusers
                apply_group_offloading(
                    component, onload_device=torch.device(torch_device), **group_offloading_kwargs
                )
            # Every module that carries a hook registry must have the group offloading hook.
            self.assertTrue(
                all(
                    module._diffusers_hook.get_hook("group_offloading") is not None
                    for module in component.modules()
                    if hasattr(module, "_diffusers_hook")
                )
            )
        # The components excluded from offloading are moved to the accelerator directly.
        for component_name in ["vae", "vqvae", "image_encoder"]:
            component = getattr(pipe, component_name, None)
            if isinstance(component, torch.nn.Module):
                component.to(torch_device)

    def run_forward(pipe):
        # Re-seed so every run starts from the same global RNG state.
        torch.manual_seed(0)
        inputs = self.get_dummy_inputs(torch_device)
        return pipe(**inputs)[0]

    def to_numpy(output):
        # Tensor outputs are converted for `np.allclose`; anything else is passed through unchanged.
        if torch.is_tensor(output):
            return output.detach().cpu().numpy()
        return output

    pipe = create_pipe().to(torch_device)
    output_without_group_offloading = to_numpy(run_forward(pipe))

    offloading_configs = (
        {"offload_type": "block_level", "num_blocks_per_group": 1},
        {"offload_type": "leaf_level"},
    )
    for group_offloading_kwargs in offloading_configs:
        pipe = create_pipe()
        enable_group_offload_on_component(pipe, group_offloading_kwargs)
        output_with_group_offloading = to_numpy(run_forward(pipe))

        self.assertTrue(
            np.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4),
            f"Output with {group_offloading_kwargs['offload_type']} group offloading differs from the baseline.",
        )

def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
Expand Down Expand Up @@ -2364,7 +2290,7 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4
self.assertLess(max_diff, expected_max_difference)

@require_torch_accelerator
def test_pipeline_level_group_offloading_sanity_checks(self):
def test_group_offloading_sanity_checks(self):
components = self.get_dummy_components()
pipe: DiffusionPipeline = self.pipeline_class(**components)

Expand Down Expand Up @@ -2394,41 +2320,61 @@ def test_pipeline_level_group_offloading_sanity_checks(self):
component_device = next(component.parameters())[0].device
self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)

@parameterized.expand([("block_level"), ("leaf_level")])
@require_torch_accelerator
def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe: DiffusionPipeline = self.pipeline_class(**components)
def test_group_offloading_inference(self, offload_type: str = "block_level"):
if not self.test_group_offloading:
pytest.skip("`test_group_offloading` is disabled hence skipping.")

for name, component in pipe.components.items():
if hasattr(component, "_supports_group_offloading"):
if not component._supports_group_offloading:
pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
def create_pipe():
torch.manual_seed(0)
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
return pipe

# Regular inference.
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
torch.manual_seed(0)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = torch.manual_seed(0)
out = pipe(**inputs)[0]
def enable_group_offload(pipe, group_offloading_kwargs):
# We intentionally don't test VAEs here. This is because some tests enable tiling on the VAE. If
# tiling is enabled and a forward pass is run, when accelerator streams are used, the execution order of
# the layers is not traced correctly. This causes errors. To apply group offloading to a VAE, a
# warmup forward pass (even with dummy small inputs) is recommended.
exclude_modules = {"vae", "vqvae", "image_encoder"}
exclude_modules = list(exclude_modules & set(pipe.components.keys()))
pipe.enable_group_offload(
exclude_modules=exclude_modules, onload_device=torch_device, **group_offloading_kwargs
)
for component_name, component in pipe.components.items():
if component_name in exclude_modules:
continue
elif isinstance(component, torch.nn.Module):
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in component.modules()
if hasattr(module, "_diffusers_hook")
)

pipe.to("cpu")
del pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(torch_device)
return pipe(**inputs)[0]

# Inference with offloading
pipe: DiffusionPipeline = self.pipeline_class(**components)
offload_device = "cpu"
pipe.enable_group_offload(
onload_device=torch_device,
offload_device=offload_device,
offload_type="leaf_level",
)
pipe.set_progress_bar_config(disable=None)
inputs["generator"] = torch.manual_seed(0)
out_offload = pipe(**inputs)[0]
pipe = create_pipe().to(torch_device)
output_without_group_offloading = run_forward(pipe)

max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
self.assertLess(max_diff, expected_max_difference)
pipe = create_pipe()
if offload_type == "block_level":
offloading_kwargs = {"offload_type": "block_level", "num_blocks_per_group": 1}
else:
offloading_kwargs = {"offload_type": "leaf_level"}
enable_group_offload(pipe, offloading_kwargs)

output_with_group_offloading = run_forward(pipe)

if torch.is_tensor(output_without_group_offloading):
output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
output_with_group_offloading = output_with_group_offloading.detach().cpu().numpy()

assert np.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4)


@is_staging_test
Expand Down
Loading