diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8baafbaef115..6d5d4492ae20 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -49,6 +49,7 @@ get_class_from_dynamic_module, is_accelerate_available, is_accelerate_version, + is_peft_available, is_torch_version, is_transformers_available, logging, @@ -271,6 +272,20 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, ) +def _unwrap_model(model): + """Unwraps a model.""" + if is_compiled_module(model): + model = model._orig_mod + + if is_peft_available(): + from peft import PeftModel + + if isinstance(model, PeftModel): + model = model.base_model.model + + return model + + def maybe_raise_or_warn( library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module ): @@ -288,9 +303,8 @@ def maybe_raise_or_warn( # Dynamo wraps the original model in a private class. # I didn't find a public API to get the original class. sub_model = passed_class_obj[name] - model_cls = sub_model.__class__ - if is_compiled_module(sub_model): - model_cls = sub_model._orig_mod.__class__ + unwrapped_sub_model = _unwrap_model(sub_model) + model_cls = unwrapped_sub_model.__class__ if not issubclass(model_cls, expected_class_obj): raise ValueError(