From 63b64a69c6c958907d0ee46ecf304d95947c4d35 Mon Sep 17 00:00:00 2001 From: okotaku Date: Mon, 6 Nov 2023 08:24:36 +0000 Subject: [PATCH 1/3] Support maybe_raise_or_warn for peft --- src/diffusers/pipelines/pipeline_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8baafbaef115..27e2b2eba50e 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -43,6 +43,7 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, + USE_PEFT_BACKEND, WEIGHTS_NAME, BaseOutput, deprecate, @@ -288,9 +289,15 @@ def maybe_raise_or_warn( # Dynamo wraps the original model in a private class. # I didn't find a public API to get the original class. sub_model = passed_class_obj[name] - model_cls = sub_model.__class__ if is_compiled_module(sub_model): - model_cls = sub_model._orig_mod.__class__ + sub_model = sub_model._orig_mod + + model_cls = sub_model.__class__ + if USE_PEFT_BACKEND: + from peft import PeftModel + + if isinstance(sub_model, PeftModel): + model_cls = sub_model.base_model.model.__class__ if not issubclass(model_cls, expected_class_obj): raise ValueError( From 2bbf67eee8c8303a78d5b3f4991fdcf5cec451a4 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 9 Nov 2023 23:37:38 +0000 Subject: [PATCH 2/3] fix by comment --- src/diffusers/pipelines/pipeline_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 27e2b2eba50e..1958f3a732e1 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -43,13 +43,13 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, SAFETENSORS_WEIGHTS_NAME, - USE_PEFT_BACKEND, WEIGHTS_NAME, BaseOutput, deprecate, get_class_from_dynamic_module, is_accelerate_available, is_accelerate_version, + is_peft_available, is_torch_version, is_transformers_available, logging, @@ -293,7 +293,7 @@ def maybe_raise_or_warn( sub_model = sub_model._orig_mod model_cls = sub_model.__class__ - if USE_PEFT_BACKEND: + if is_peft_available(): from peft import PeftModel if isinstance(sub_model, PeftModel): From 4877ea616dbde37b27efe349a60a05b9d078a50f Mon Sep 17 00:00:00 2001 From: okotaku Date: Tue, 14 Nov 2023 00:37:39 +0000 Subject: [PATCH 3/3] unwrap function --- src/diffusers/pipelines/pipeline_utils.py | 25 +++++++++++++++-------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1958f3a732e1..6d5d4492ae20 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -272,6 +272,20 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, ) +def _unwrap_model(model): + """Unwraps a model.""" + if is_compiled_module(model): + model = model._orig_mod + + if is_peft_available(): + from peft import PeftModel + + if isinstance(model, PeftModel): + model = model.base_model.model + + return model + + def maybe_raise_or_warn( library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module ): @@ -289,15 +303,8 @@ def maybe_raise_or_warn( # Dynamo wraps the original model in a private class. # I didn't find a public API to get the original class. sub_model = passed_class_obj[name] - if is_compiled_module(sub_model): - sub_model = sub_model._orig_mod - - model_cls = sub_model.__class__ - if is_peft_available(): - from peft import PeftModel - - if isinstance(sub_model, PeftModel): - model_cls = sub_model.base_model.model.__class__ + unwrapped_sub_model = _unwrap_model(sub_model) + model_cls = unwrapped_sub_model.__class__ if not issubclass(model_cls, expected_class_obj): raise ValueError(