Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import re
from contextlib import nullcontext
Expand Down Expand Up @@ -72,6 +73,17 @@
}


def _get_single_file_loadable_mapping_class(cls):
diffusers_module = importlib.import_module(__name__.split(".")[0])
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
loadable_class = getattr(diffusers_module, loadable_class_str)

if issubclass(cls, loadable_class):
return loadable_class_str

return None


def _get_mapping_function_kwargs(mapping_fn, **kwargs):
parameters = inspect.signature(mapping_fn).parameters

Expand Down Expand Up @@ -149,8 +161,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
```
"""

class_name = cls.__name__
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
if mapping_class_name is None:
raise ValueError(
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
)
Expand Down Expand Up @@ -195,7 +208,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
revision=revision,
)

mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
if original_config:
Expand All @@ -207,7 +220,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
if config_mapping_fn is None:
raise ValueError(
(
f"`original_config` has been provided for {class_name} but no mapping function"
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
"was found to convert the original config to a Diffusers config in"
"`diffusers.loaders.single_file_utils`"
)
Expand Down Expand Up @@ -267,7 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
)
if not diffusers_format_checkpoint:
raise SingleFileComponentError(
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
)

ctx = init_empty_weights if is_accelerate_available() else nullcontext
Expand Down