diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index d5fa22548a15..4bdae21907da 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -485,17 +485,19 @@ def register_modules(self, **kwargs): if module is None: register_dict = {name: (None, None)} else: - # register the original module, not the dynamo compiled one + # register the config from the original module, not the dynamo compiled one if is_compiled_module(module): - module = module._orig_mod + not_compiled_module = module._orig_mod + else: + not_compiled_module = module - library = module.__module__.split(".")[0] + library = not_compiled_module.__module__.split(".")[0] # check if the module is a pipeline module - module_path_items = module.__module__.split(".") + module_path_items = not_compiled_module.__module__.split(".") pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None - path = module.__module__.split(".") + path = not_compiled_module.__module__.split(".") is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) # if library is not in LOADABLE_CLASSES, then it is a custom module. @@ -504,10 +506,10 @@ def register_modules(self, **kwargs): if is_pipeline_module: library = pipeline_dir elif library not in LOADABLE_CLASSES: - library = module.__module__ + library = not_compiled_module.__module__ # retrieve class_name - class_name = module.__class__.__name__ + class_name = not_compiled_module.__class__.__name__ register_dict = {name: (library, class_name)} diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 5af3a6c16b40..cd3700d0ccdf 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -61,6 +61,7 @@ CONFIG_NAME, WEIGHTS_NAME, floats_tensor, + is_compiled_module, nightly, require_torch_2, slow, @@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): scheduler = DDPMScheduler(num_train_timesteps=10) ddpm = DDPMPipeline(model, scheduler) + + # previous diffusers versions stripped compilation off + # compiled modules + assert is_compiled_module(ddpm.unet) + ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None)