Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def text_encoder_attn_modules(text_encoder):
_import_structure["autoencoder"] = ["FromOriginalVAEMixin"]

_import_structure["controlnet"] = ["FromOriginalControlNetMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["unet"] = ["FromOriginalUNetMixin", "UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
_import_structure["single_file"] = ["FromSingleFileMixin"]
Expand All @@ -72,7 +72,7 @@ def text_encoder_attn_modules(text_encoder):
if is_torch_available():
from .autoencoder import FromOriginalVAEMixin
from .controlnet import FromOriginalControlNetMixin
from .unet import UNet2DConditionLoadersMixin
from .unet import FromOriginalUNetMixin, UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

if is_transformers_available():
Expand Down
55 changes: 55 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
if is_accelerate_available():
from accelerate import init_empty_weights

from ..models.modeling_utils import load_model_dict_into_meta

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

CONFIG_URLS = {
Expand Down Expand Up @@ -1274,6 +1276,59 @@ def create_text_encoder_from_open_clip_checkpoint(
return text_model


def create_diffusers_unet_from_stable_cascade(
cls,
pretrained_model_link_or_path,
config,
resume_download,
force_download,
proxies,
token,
cache_dir,
local_files_only,
revision,
torch_dtype,
**kwargs,
):
checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)

if config is None:
config = infer_stable_cascade_single_file_config(checkpoint)
model_config = cls.load_config(**config, **kwargs)
else:
model_config = config

ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(model_config, **kwargs)

diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)

if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)

else:
model.load_state_dict(diffusers_format_checkpoint)

if torch_dtype is not None:
model.to(torch_dtype)

return model


def create_diffusers_unet_model_from_ldm(
pipeline_class_name,
original_config,
Expand Down
92 changes: 50 additions & 42 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
set_weights_and_activate_adapters,
)
from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
create_diffusers_unet_from_stable_cascade,
create_diffusers_unet_model_from_ldm,
fetch_ldm_config_and_checkpoint,
)
from .utils import AttnProcsLayers

Expand All @@ -66,6 +66,8 @@
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

COMPATIBLE_SINGLE_FILE_CLASSES = ["StableCascadeUNet", "UNet2DConditionModel"]


class UNet2DConditionLoadersMixin:
"""
Expand Down Expand Up @@ -912,8 +914,9 @@ class FromOriginalUNetMixin:
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`StableCascadeUNet`] from pretrained StableCascadeUNet weights saved in the original `.ckpt` or
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
Instantiate a UNet from pretrained weights saved in the original `.ckpt`, `.bin`, or
`.safetensors` format. The model is set in evaluation mode (`model.eval()`) by default.
Currently supported checkpoints: StableCascade, SDXL, SD, Playground v2.5, etc.

Parameters:
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
Expand Down Expand Up @@ -952,8 +955,10 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):

"""
class_name = cls.__name__
if class_name != "StableCascadeUNet":
raise ValueError("FromOriginalUNetMixin is currently only compatible with StableCascadeUNet")
if class_name not in COMPATIBLE_SINGLE_FILE_CLASSES:
raise ValueError(
f"FromOriginalUNetMixin is currently only compatible with {', '.join(COMPATIBLE_SINGLE_FILE_CLASSES)}"
)

config = kwargs.pop("config", None)
resume_download = kwargs.pop("resume_download", False)
Expand All @@ -965,39 +970,42 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)

checkpoint = load_single_file_model_checkpoint(
pretrained_model_link_or_path,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
cache_dir=cache_dir,
local_files_only=local_files_only,
revision=revision,
)

if config is None:
config = infer_stable_cascade_single_file_config(checkpoint)
model_config = cls.load_config(**config, **kwargs)
else:
model_config = config

ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model = cls.from_config(model_config, **kwargs)

diffusers_format_checkpoint = convert_stable_cascade_unet_single_file_to_diffusers(checkpoint)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)

if class_name == "StableCascadeUNet":
return create_diffusers_unet_from_stable_cascade(
cls,
pretrained_model_link_or_path,
config,
resume_download,
force_download,
proxies,
token,
cache_dir,
local_files_only,
revision,
torch_dtype,
**kwargs,
)
else:
model.load_state_dict(diffusers_format_checkpoint)

if torch_dtype is not None:
model.to(torch_dtype)

return model
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=kwargs.get("pipeline_class_name", None),
original_config_file=kwargs.get("original_config_file", None),
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
cache_dir=cache_dir,
)
return create_diffusers_unet_model_from_ldm(
pipeline_class_name=kwargs.get("pipeline_class_name", None),
original_config=original_config,
checkpoint=checkpoint,
num_in_channels=kwargs.get("num_in_channels", 4),
upcast_attention=kwargs.get("upcast_attention", None),
extract_ema=kwargs.get("upcast_attention", False),
image_size=kwargs.get("image_size", None),
torch_dtype=torch_dtype,
model_type=kwargs.pop("model_type", None),
)["unet"]
6 changes: 4 additions & 2 deletions src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders import FromOriginalUNetMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (
Expand Down Expand Up @@ -66,7 +66,9 @@ class UNet2DConditionOutput(BaseOutput):
sample: torch.FloatTensor = None


class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
class UNet2DConditionModel(
ModelMixin, ConfigMixin, FromOriginalUNetMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.
Expand Down