From 5f1895fb830f2acebb3dee5798d426589ce9155f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 13 Mar 2024 09:02:39 +0530 Subject: [PATCH] clean residue from copy-pasting --- src/diffusers/loaders/unet.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index f89e004261f2..0a9544d0dbbe 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -905,14 +905,14 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): class FromOriginalUNetMixin: """ - Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`]. + Load pretrained UNet model weights saved in the `.ckpt` or `.safetensors` format into a [`StableCascadeUNet`]. """ @classmethod @validate_hf_hub_args def from_single_file(cls, pretrained_model_link_or_path, **kwargs): r""" - Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or + Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default. Parameters: @@ -951,6 +951,10 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): Can be used to overwrite load and saveable variables of the model. """ + class_name = cls.__name__ + if class_name != "StableCascadeUNet": + raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet") + config = kwargs.pop("config", None) resume_download = kwargs.pop("resume_download", False) force_download = kwargs.pop("force_download", False) @@ -961,10 +965,6 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) - class_name = cls.__name__ - if class_name != "StableCascadeUNet": - raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet") - checkpoint = load_single_file_model_checkpoint( pretrained_model_link_or_path, resume_download=resume_download,