From 57a4f3448a31500c5a319e3b55a75508e0020499 Mon Sep 17 00:00:00 2001 From: lcf Date: Mon, 10 Jun 2024 15:03:24 +0800 Subject: [PATCH 1/4] fix: ValueError when using FromOriginalModelMixin in subclasses #8440 (cherry picked from commit 92859978436acf844760fc0e992165b489d0180a) --- src/diffusers/loaders/single_file_model.py | 28 ++++++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index f06f4832740c..9c77fa682d01 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -67,6 +67,23 @@ } +def _is_subclass(src_cls, dst_cls_str): + """ + Find if src_cls is a subclass of dst_cls whose name is dst_cls_str + """ + for cls in src_cls.__mro__: + if cls.__name__ == dst_cls_str: + return True + return False + + +def _get_single_file_loadable_mapping_class(cls): + for dst_cls_str in SINGLE_FILE_LOADABLE_CLASSES: + if _is_subclass(cls, dst_cls_str): + return dst_cls_str + return None + + def _get_mapping_function_kwargs(mapping_fn, **kwargs): parameters = inspect.signature(mapping_fn).parameters @@ -144,8 +161,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ``` """ - class_name = cls.__name__ - if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + mapping_class_name = _get_single_file_loadable_mapping_class(cls) + # if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + if mapping_class_name is None: raise ValueError( f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}" ) @@ -190,7 +208,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = revision=revision, ) - mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name] + mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] if original_config: @@ -202,7 +220,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = if config_mapping_fn is None: raise ValueError( ( - f"`original_config` has been provided for {class_name} but no mapping function" + f"`original_config` has been provided for {mapping_class_name} but no mapping function" "was found to convert the original config to a Diffusers config in" "`diffusers.loaders.single_file_utils`" ) @@ -262,7 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ) if not diffusers_format_checkpoint: raise SingleFileComponentError( - f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." + f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." ) ctx = init_empty_weights if is_accelerate_available() else nullcontext From 520f8418876f3cfea4fd26fa549f02d511b24710 Mon Sep 17 00:00:00 2001 From: Luo Chaofan <79003314+fkcptlst@users.noreply.github.com> Date: Tue, 11 Jun 2024 20:46:06 +0800 Subject: [PATCH 2/4] Update src/diffusers/loaders/single_file_model.py Co-authored-by: Dhruv Nair --- src/diffusers/loaders/single_file_model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 9c77fa682d01..d18e7dc2e24a 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -78,6 +78,14 @@ def _is_subclass(src_cls, dst_cls_str): def _get_single_file_loadable_mapping_class(cls): + diffusers_module = importlib.import_module(__name__.split(".")[0]) + for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES: + loadable_class = getattr(diffusers_module, loadable_class_str) + + if issubclass(cls, loadable_class): + return loadable_class_str + + return None for dst_cls_str in SINGLE_FILE_LOADABLE_CLASSES: if _is_subclass(cls, dst_cls_str): return dst_cls_str From a962600b1b443814939ec0690b5796c5e7dfbab3 Mon Sep 17 00:00:00 2001 From: Luo Chaofan <79003314+fkcptlst@users.noreply.github.com> Date: Tue, 11 Jun 2024 20:54:58 +0800 Subject: [PATCH 3/4] Update single_file_model.py --- src/diffusers/loaders/single_file_model.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index d18e7dc2e24a..cf4c41488b67 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -67,16 +67,6 @@ } -def _is_subclass(src_cls, dst_cls_str): - """ - Find if src_cls is a subclass of dst_cls whose name is dst_cls_str - """ - for cls in src_cls.__mro__: - if cls.__name__ == dst_cls_str: - return True - return False - - def _get_single_file_loadable_mapping_class(cls): diffusers_module = importlib.import_module(__name__.split(".")[0]) for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES: @@ -86,10 +76,6 @@ def _get_single_file_loadable_mapping_class(cls): return loadable_class_str return None - for dst_cls_str in SINGLE_FILE_LOADABLE_CLASSES: - if _is_subclass(cls, dst_cls_str): - return dst_cls_str - return None def _get_mapping_function_kwargs(mapping_fn, **kwargs): From 37c312fa630ed81b84b8e64f54cc5f3f0feb061d Mon Sep 17 00:00:00 2001 From: Luo Chaofan <79003314+fkcptlst@users.noreply.github.com> Date: Tue, 11 Jun 2024 21:47:42 +0800 Subject: [PATCH 4/4] Update single_file_model.py --- src/diffusers/loaders/single_file_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index cf4c41488b67..67feaed83d19 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import inspect import re from contextlib import nullcontext