diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index f537a3f44917..dbcf081b1f17 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import inspect import re from contextlib import nullcontext @@ -72,6 +73,17 @@ } +def _get_single_file_loadable_mapping_class(cls): + diffusers_module = importlib.import_module(__name__.split(".")[0]) + for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES: + loadable_class = getattr(diffusers_module, loadable_class_str) + + if issubclass(cls, loadable_class): + return loadable_class_str + + return None + + def _get_mapping_function_kwargs(mapping_fn, **kwargs): parameters = inspect.signature(mapping_fn).parameters @@ -149,8 +161,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ``` """ - class_name = cls.__name__ - if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + mapping_class_name = _get_single_file_loadable_mapping_class(cls) + # if class_name not in SINGLE_FILE_LOADABLE_CLASSES: + if mapping_class_name is None: raise ValueError( f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}" ) @@ -195,7 +208,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = revision=revision, ) - mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name] + mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name] checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"] if original_config: @@ -207,7 +220,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = if config_mapping_fn is None: raise ValueError( ( - f"`original_config` has been provided for {class_name} but no mapping function" + f"`original_config` has been provided for {mapping_class_name} but no mapping function" "was found to convert the original config to a Diffusers config in" "`diffusers.loaders.single_file_utils`" ) @@ -267,7 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ) if not diffusers_format_checkpoint: raise SingleFileComponentError( - f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint." + f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." ) ctx = init_empty_weights if is_accelerate_available() else nullcontext