Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
get_class_from_dynamic_module,
is_accelerate_available,
is_accelerate_version,
is_peft_available,
is_torch_version,
is_transformers_available,
logging,
Expand Down Expand Up @@ -271,6 +272,20 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
)


def _unwrap_model(model):
"""Unwraps a model."""
if is_compiled_module(model):
model = model._orig_mod

if is_peft_available():
from peft import PeftModel

if isinstance(model, PeftModel):
model = model.base_model.model

return model


def maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
Expand All @@ -288,9 +303,8 @@ def maybe_raise_or_warn(
# Dynamo wraps the original model in a private class.
# I didn't find a public API to get the original class.
sub_model = passed_class_obj[name]
model_cls = sub_model.__class__
if is_compiled_module(sub_model):
model_cls = sub_model._orig_mod.__class__
unwrapped_sub_model = _unwrap_model(sub_model)
model_cls = unwrapped_sub_model.__class__

if not issubclass(model_cls, expected_class_obj):
raise ValueError(
Expand Down