diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 8f33b506827a..40a7a4a2e98e 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -50,6 +50,7 @@
     get_class_from_dynamic_module,
     is_accelerate_available,
     is_accelerate_version,
+    is_compiled_module,
     is_safetensors_available,
     is_torch_version,
     is_transformers_available,
@@ -255,7 +256,14 @@ def maybe_raise_or_warn(
         if class_candidate is not None and issubclass(class_obj, class_candidate):
             expected_class_obj = class_candidate
 
-        if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+        # Dynamo wraps the original model in a private class.
+        # I didn't find a public API to get the original class.
+        sub_model = passed_class_obj[name]
+        model_cls = sub_model.__class__
+        if is_compiled_module(sub_model):
+            model_cls = sub_model._orig_mod.__class__
+
+        if not issubclass(model_cls, expected_class_obj):
             raise ValueError(
                 f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
                 f" {expected_class_obj}"
@@ -419,6 +427,10 @@ def register_modules(self, **kwargs):
             if module is None:
                 register_dict = {name: (None, None)}
             else:
+                # register the original module, not the dynamo compiled one
+                if is_compiled_module(module):
+                    module = module._orig_mod
+
                 library = module.__module__.split(".")[0]
 
                 # check if the module is a pipeline module
@@ -484,6 +496,12 @@ def is_saveable_module(name, value):
             sub_model = getattr(self, pipeline_component_name)
             model_cls = sub_model.__class__
 
+            # Dynamo wraps the original model in a private class.
+            # I didn't find a public API to get the original class.
+            if is_compiled_module(sub_model):
+                sub_model = sub_model._orig_mod
+                model_cls = sub_model.__class__
+
             save_method_name = None
             # search for the model's base class in LOADABLE_CLASSES
             for library_name, library_classes in LOADABLE_CLASSES.items():
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index d803b053be71..30d559d19244 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -73,7 +73,7 @@
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
-from .torch_utils import randn_tensor
+from .torch_utils import is_compiled_module, randn_tensor
 
 
 if is_torch_available():
@@ -85,6 +85,7 @@
         nightly,
         parse_flag_from_env,
         print_tensor_test,
+        require_torch_2,
         require_torch_gpu,
         skip_mps,
         slow,
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 7a3b8029f828..b2b2e45f23bd 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -24,6 +24,7 @@
     is_onnx_available,
     is_opencv_available,
     is_torch_available,
+    is_torch_version,
 )
 from .logging import get_logger
 
@@ -164,6 +165,15 @@ def require_torch(test_case):
     return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
 
 
+def require_torch_2(test_case):
+    """
+    Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
+    """
+    return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
+        test_case
+    )
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 113e64c16bac..b9815cbceede 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -17,7 +17,7 @@
 from typing import List, Optional, Tuple, Union
 
 from . import logging
-from .import_utils import is_torch_available
+from .import_utils import is_torch_available, is_torch_version
 
 
 if is_torch_available():
@@ -68,3 +68,10 @@ def randn_tensor(
         latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
 
     return latents
+
+
+def is_compiled_module(module):
+    """Check whether the module was compiled with torch.compile()"""
+    if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
+        return False
+    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index e880950a7914..7dc649baa182 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -28,6 +28,7 @@
 from diffusers.models.attention_processor import AttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
+from diffusers.utils.testing_utils import require_torch_gpu
 
 
 class ModelUtilsTest(unittest.TestCase):
@@ -168,6 +169,21 @@ def test_from_save_pretrained_variant(self):
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
 
+    @require_torch_gpu
+    def test_from_save_pretrained_dynamo(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model = torch.compile(model)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            new_model = self.model_class.from_pretrained(tmpdirname)
+            new_model.to(torch_device)
+
+        assert new_model.__class__ == self.model_class
+
     def test_from_save_pretrained_dtype(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 9f0c9b1a4e19..9d3baa7f8f3c 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -54,7 +54,16 @@
     logging,
 )
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device
+from diffusers.utils import (
+    CONFIG_NAME,
+    WEIGHTS_NAME,
+    floats_tensor,
+    is_flax_available,
+    nightly,
+    require_torch_2,
+    slow,
+    torch_device,
+)
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu
 
 
@@ -966,9 +975,41 @@ def test_from_save_pretrained(self):
             down_block_types=("DownBlock2D", "AttnDownBlock2D"),
             up_block_types=("AttnUpBlock2D", "UpBlock2D"),
         )
-        schedular = DDPMScheduler(num_train_timesteps=10)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+
+        ddpm = DDPMPipeline(model, scheduler)
+        ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ddpm.save_pretrained(tmpdirname)
+            new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
+            new_ddpm.to(torch_device)
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
+
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images
+
+        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
+
+    @require_torch_2
+    def test_from_save_pretrained_dynamo(self):
+        # 1. Load models
+        model = UNet2DModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=3,
+            out_channels=3,
+            down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+            up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+        )
+        model = torch.compile(model)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
 
-        ddpm = DDPMPipeline(model, schedular)
+        ddpm = DDPMPipeline(model, scheduler)
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
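For reference, a minimal sketch of the behavior this change relies on, assuming PyTorch >= 2.0; the output directory name is purely illustrative. torch.compile() returns a torch._dynamo.eval_frame.OptimizedModule wrapper whose private _orig_mod attribute holds the original model, which is what is_compiled_module() detects and the pipeline code unwraps before validating, registering, and saving components.

import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel

model = UNet2DModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
compiled = torch.compile(model)  # requires PyTorch >= 2.0; compilation itself is lazy

# The wrapper is not a UNet2DModel, but the original module stays reachable
# through the private `_orig_mod` attribute.
assert isinstance(compiled, torch._dynamo.eval_frame.OptimizedModule)
assert compiled._orig_mod is model

# With the unwrapping added in pipeline_utils.py, the pipeline registers and
# saves the original UNet2DModel even though it was handed the compiled wrapper.
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(compiled, scheduler)
ddpm.save_pretrained("./ddpm-compiled")  # hypothetical output directory
restored = DDPMPipeline.from_pretrained("./ddpm-compiled")
assert restored.unet.__class__ is UNet2DModel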