Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,17 +485,19 @@ def register_modules(self, **kwargs):
if module is None:
register_dict = {name: (None, None)}
else:
# register the original module, not the dynamo compiled one
# register the config from the original module, not the dynamo compiled one
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

if is_compiled_module(module):
module = module._orig_mod
not_compiled_module = module._orig_mod
else:
not_compiled_module = module

library = module.__module__.split(".")[0]
library = not_compiled_module.__module__.split(".")[0]

# check if the module is a pipeline module
module_path_items = module.__module__.split(".")
module_path_items = not_compiled_module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None

path = module.__module__.split(".")
path = not_compiled_module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)

# if library is not in LOADABLE_CLASSES, then it is a custom module.
Expand All @@ -504,10 +506,10 @@ def register_modules(self, **kwargs):
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = module.__module__
library = not_compiled_module.__module__

# retrieve class_name
class_name = module.__class__.__name__
class_name = not_compiled_module.__class__.__name__

register_dict = {name: (library, class_name)}

Expand Down
6 changes: 6 additions & 0 deletions tests/pipelines/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
CONFIG_NAME,
WEIGHTS_NAME,
floats_tensor,
is_compiled_module,
nightly,
require_torch_2,
slow,
Expand Down Expand Up @@ -99,6 +100,11 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
scheduler = DDPMScheduler(num_train_timesteps=10)

ddpm = DDPMPipeline(model, scheduler)

# previous diffusers versions stripped compilation off
# compiled modules
assert is_compiled_module(ddpm.unet)

Comment on lines +103 to +107
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten just added to existing pipeline dynamo test. Confirmed fails on main and passes on branch

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me!

ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)

Expand Down