Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,14 +905,14 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):

class FromOriginalUNetMixin:
"""
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`].
"""

@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.

Parameters:
Expand Down Expand Up @@ -951,6 +951,10 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
Can be used to overwrite load and saveable variables of the model.

"""
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
Comment on lines +954 to +956
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Errors should be caught as early as possible.


config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
Expand All @@ -961,10 +965,6 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)

class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")

checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
Expand Down