Skip to content

Commit bfe94a3

Browse files
authored
[Enhacne] Support maybe_raise_or_warn for peft (#5653)
* Support maybe_raise_or_warn for peft * fix by comment * unwrap function
1 parent c9c5436 commit bfe94a3

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
get_class_from_dynamic_module,
5050
is_accelerate_available,
5151
is_accelerate_version,
52+
is_peft_available,
5253
is_torch_version,
5354
is_transformers_available,
5455
logging,
@@ -270,6 +271,20 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
270271
)
271272

272273

274+
def _unwrap_model(model):
275+
"""Unwraps a model."""
276+
if is_compiled_module(model):
277+
model = model._orig_mod
278+
279+
if is_peft_available():
280+
from peft import PeftModel
281+
282+
if isinstance(model, PeftModel):
283+
model = model.base_model.model
284+
285+
return model
286+
287+
273288
def maybe_raise_or_warn(
274289
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
275290
):
@@ -287,9 +302,8 @@ def maybe_raise_or_warn(
287302
# Dynamo wraps the original model in a private class.
288303
# I didn't find a public API to get the original class.
289304
sub_model = passed_class_obj[name]
290-
model_cls = sub_model.__class__
291-
if is_compiled_module(sub_model):
292-
model_cls = sub_model._orig_mod.__class__
305+
unwrapped_sub_model = _unwrap_model(sub_model)
306+
model_cls = unwrapped_sub_model.__class__
293307

294308
if not issubclass(model_cls, expected_class_obj):
295309
raise ValueError(

0 commit comments

Comments
 (0)