20 changes: 19 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
@@ -50,6 +50,7 @@
get_class_from_dynamic_module,
is_accelerate_available,
is_accelerate_version,
is_compiled_module,
is_safetensors_available,
is_torch_version,
is_transformers_available,
@@ -255,7 +256,14 @@ def maybe_raise_or_warn(
if class_candidate is not None and issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate

if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
# Dynamo wraps the original model in a private class.
# There is no public API to retrieve the original class.
sub_model = passed_class_obj[name]
model_cls = sub_model.__class__
if is_compiled_module(sub_model):
model_cls = sub_model._orig_mod.__class__

if not issubclass(model_cls, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
@@ -419,6 +427,10 @@ def register_modules(self, **kwargs):
if module is None:
register_dict = {name: (None, None)}
else:
# register the original module, not the dynamo compiled one
if is_compiled_module(module):
module = module._orig_mod

library = module.__module__.split(".")[0]

# check if the module is a pipeline module
@@ -484,6 +496,12 @@ def is_saveable_module(name, value):
sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__

# Dynamo wraps the original model in a private class.
# There is no public API to retrieve the original class.
if is_compiled_module(sub_model):
sub_model = sub_model._orig_mod
model_cls = sub_model.__class__

save_method_name = None
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():
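A minimal sketch (not part of this diff) of the mismatch the checks above guard against, assuming PyTorch >= 2.0: once a model is wrapped by torch.compile(), its class no longer passes the issubclass check, and the original class is only reachable through the private _orig_mod attribute.

import torch
from diffusers import UNet2DModel

unet = UNet2DModel()
compiled = torch.compile(unet)

print(type(compiled))                               # torch._dynamo.eval_frame.OptimizedModule
print(issubclass(compiled.__class__, UNet2DModel))  # False
print(compiled._orig_mod.__class__ is UNet2DModel)  # True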
3 changes: 2 additions & 1 deletion src/diffusers/utils/__init__.py
@@ -73,7 +73,7 @@
from .logging import get_logger
from .outputs import BaseOutput
from .pil_utils import PIL_INTERPOLATION
from .torch_utils import randn_tensor
from .torch_utils import is_compiled_module, randn_tensor


if is_torch_available():
@@ -85,6 +85,7 @@
nightly,
parse_flag_from_env,
print_tensor_test,
require_torch_2,
require_torch_gpu,
skip_mps,
slow,
10 changes: 10 additions & 0 deletions src/diffusers/utils/testing_utils.py
@@ -24,6 +24,7 @@
is_onnx_available,
is_opencv_available,
is_torch_available,
is_torch_version,
)
from .logging import get_logger

@@ -164,6 +165,15 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


def require_torch_2(test_case):
"""
Decorator marking a test that requires PyTorch 2. These tests are skipped when PyTorch 2 is not installed.
"""
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
test_case
)


def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
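For illustration only (not part of this diff), the new decorator is used like the existing require_* helpers; the test class and test name below are hypothetical.

import unittest

import torch

from diffusers.utils.testing_utils import require_torch_2


class CompileSmokeTest(unittest.TestCase):
    @require_torch_2
    def test_compiled_callable(self):
        # skipped automatically when the installed PyTorch is older than 2.0
        fn = torch.compile(lambda x: x * 2)
        self.assertTrue(torch.allclose(fn(torch.ones(2)), torch.ones(2) * 2))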
9 changes: 8 additions & 1 deletion src/diffusers/utils/torch_utils.py
@@ -17,7 +17,7 @@
from typing import List, Optional, Tuple, Union

from . import logging
from .import_utils import is_torch_available
from .import_utils import is_torch_available, is_torch_version


if is_torch_available():
@@ -68,3 +68,10 @@ def randn_tensor(
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

return latents


def is_compiled_module(module):
"""Check whether the module was compiled with torch.compile()"""
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
return False
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
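A short usage sketch of the new helper (not part of this diff), assuming PyTorch >= 2.0:

import torch

from diffusers.utils import is_compiled_module

layer = torch.nn.Linear(4, 4)
compiled = torch.compile(layer)

assert not is_compiled_module(layer)     # plain nn.Module
assert is_compiled_module(compiled)      # wrapped in torch._dynamo.eval_frame.OptimizedModule
assert compiled._orig_mod is layer       # the original module is still reachable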
16 changes: 16 additions & 0 deletions tests/test_modeling_common.py
@@ -28,6 +28,7 @@
from diffusers.models.attention_processor import AttnProcessor
from diffusers.training_utils import EMAModel
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import require_torch_gpu


class ModelUtilsTest(unittest.TestCase):
@@ -168,6 +169,21 @@ def test_from_save_pretrained_variant(self):
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")

@require_torch_gpu
def test_from_save_pretrained_dynamo(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict)
model.to(torch_device)
model = torch.compile(model)

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)

assert new_model.__class__ == self.model_class

def test_from_save_pretrained_dtype(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

47 changes: 44 additions & 3 deletions tests/test_pipelines.py
@@ -54,7 +54,16 @@
logging,
)
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
from diffusers.utils import (
CONFIG_NAME,
WEIGHTS_NAME,
floats_tensor,
is_flax_available,
nightly,
require_torch_2,
slow,
torch_device,
)
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu


@@ -966,9 +975,41 @@ def test_from_save_pretrained(self):
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
schedular = DDPMScheduler(num_train_timesteps=10)
scheduler = DDPMScheduler(num_train_timesteps=10)

ddpm = DDPMPipeline(model, scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)

with tempfile.TemporaryDirectory() as tmpdirname:
ddpm.save_pretrained(tmpdirname)
new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
new_ddpm.to(torch_device)

generator = torch.Generator(device=torch_device).manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

generator = torch.Generator(device=torch_device).manual_seed(0)
new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images

assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"

@require_torch_2
def test_from_save_pretrained_dynamo(self):
# 1. Load models
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
model = torch.compile(model)
scheduler = DDPMScheduler(num_train_timesteps=10)

ddpm = DDPMPipeline(model, schedular)
ddpm = DDPMPipeline(model, scheduler)
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
