From f666908a90a337e067bb143561af438e07e9b257 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 12 Aug 2024 12:24:41 +0300 Subject: [PATCH 001/109] a From aabac0a2b20b0da868ea0b25a7ec42351a8af3a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 27 Aug 2024 15:25:22 +0300 Subject: [PATCH 002/109] refactor: add `ff_act_fn` parameter to `UNet2DConditionModel` and `get_down_block` for FF layers in attention --- src/diffusers/models/unets/unet_2d_blocks.py | 4 ++++ src/diffusers/models/unets/unet_2d_condition.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 93a0a82cdcff..4be938ae35d6 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -49,6 +49,7 @@ def get_down_block( add_downsample: bool, resnet_eps: float, resnet_act_fn: str, + ff_act_fn: str = "geglu", transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, resnet_groups: Optional[int] = None, @@ -136,6 +137,7 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + ff_act_fn=ff_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, @@ -1158,6 +1160,7 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads: int = 1, @@ -1209,6 +1212,7 @@ def __init__( only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, + activation_fn=ff_act_fn, ) ) else: diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 9a168bd22c93..8991dd2a9db0 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -191,6 +191,7 @@ def __init__( mid_block_scale_factor: float = 1, dropout: float = 0.0, act_fn: str = "silu", + ff_act_fn: str = "geglu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, @@ -363,6 +364,7 @@ def __init__( add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, + ff_act_fn=ff_act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], From 279d61387df1d219d6ccd81c7aee415370fe117d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Sep 2024 19:07:09 +0300 Subject: [PATCH 003/109] Study as an ordinary UNet model --- examples/community/matryoshka.py | 425 ++++++++++++++++++ src/diffusers/models/attention.py | 40 +- src/diffusers/models/attention_processor.py | 3 + .../models/transformers/transformer_2d.py | 29 +- src/diffusers/models/unets/unet_2d_blocks.py | 65 +++ .../models/unets/unet_2d_condition.py | 77 +++- 6 files changed, 617 insertions(+), 22 deletions(-) create mode 100644 examples/community/matryoshka.py diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py new file mode 100644 index 000000000000..2cf0f6ecacb2 --- /dev/null +++ b/examples/community/matryoshka.py @@ -0,0 +1,425 @@ +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All rights reserved. + +import inspect +from dataclasses import dataclass, field +from typing import List, Optional, Union + +import numpy as np +import torch +from packaging import version +from torch import nn +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import UNet2DConditionModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + logging, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import DiffusionPipeline + + >>> pipe = DiffusionPipeline.from_pretrained("A/B", torch_dtype=torch.float16, variant="fp16", + ... custom_pipeline="matryoshka",) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + >>> image + ``` +""" + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class NestedUNetConfig: # UNetConfig + inner_config = field( + default={"nesting": True}, # UNetConfig(nesting=True), + metadata={"help": "inner unet used as middle blocks"}, + ) + skip_mid_blocks: bool = field(default=True) + skip_cond_emb: bool = field(default=True) + skip_inner_unet_input: bool = field( + default=False, + metadata={"help": "If enabled, the inner unet only received the downsampled image, no features."}, + ) + skip_normalization: bool = field( + default=False, + ) + initialize_inner_with_pretrained: str = field( + default=None, + metadata={ + "help": ( + "Initialize the inner unet with pretrained vision model ", + "Provide the vision_model_path", + ) + }, + ) + freeze_inner_unet: bool = field(default=False) + interp_conditioning: bool = field( + default=False, + ) + + +@dataclass +class Nested2UNetConfig(NestedUNetConfig): + inner_config: NestedUNetConfig = field( + default=NestedUNetConfig(nesting=True, initialize_inner_with_pretrained=None) + ) + + +@dataclass +class Nested3UNetConfig(Nested2UNetConfig): + inner_config: Nested2UNetConfig = field( + default=Nested2UNetConfig(nesting=True, initialize_inner_with_pretrained=None) + ) + + +@dataclass +class Nested4UNetConfig(Nested3UNetConfig): + inner_config: Nested3UNetConfig = field( + default=Nested3UNetConfig(nesting=True, initialize_inner_with_pretrained=None) + ) + + +class NestedUNet( # UNet, + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + """ + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__(..., in_channels=3, out_channels=3) + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + unet.config.inner_config.conditioning_feature_dim = unet.config.conditioning_feature_dim + if getattr(unet.config.inner_config, "inner_config", None) is None: + self.inner_unet = UNet2DConditionModel( + in_channels=3, + out_channels=3, + block_out_channels=(256, 512, 768), + ff_act_fn="gelu", + cross_attention_dim=2048, + resnet_time_scale_shift="scale_shift", + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"), + ) + self.inner_unet # Arrange inner attributes + else: + self.inner_unet = NestedUNet(3, 3, unet.config.inner_config) + + if not unet.config.skip_inner_unet_input: + self.in_adapter = nn.Conv2d( + unet.config.resolution_channels[-1], + unet.config.inner_config.resolution_channels[0], + kernel_size=3, + padding=1, + bias=True, + ) + else: + self.in_adapter = None + self.out_adapter = nn.Conv2d( + unet.config.inner_config.resolution_channels[0], + unet.config.resolution_channels[-1], + kernel_size=3, + padding=1, + bias=True, + ) + + self.is_temporal = [unet.config.temporal_mode and (not unet.config.temporal_spatial_ds)] + if hasattr(self.inner_unet, "is_temporal"): + self.is_temporal += self.inner_unet.is_temporal + + nest_ratio = int(2 ** (len(unet.config.resolution_channels) - 1)) + if self.is_temporal[0]: + nest_ratio = int(np.sqrt(nest_ratio)) + if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": + self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio + else: + self.nest_ratio = [nest_ratio] + + if self.config.interp_conditioning: + self.interp_layer1 = nn.Linear(self.temporal_dim // 4, self.temporal_dim) + self.interp_layer2 = nn.Linear(self.temporal_dim, self.temporal_dim) + + @property + def model_type(self): + return "nested_unet" + + def forward_conditioning(self, *args, **kwargs): + return self.inner_unet.forward_conditioning(*args, **kwargs) + + def forward_denoising(self, x_t, times, cond_emb=None, conditioning=None, cond_mask=None, micros={}): + # 1. time embedding + temb = self.create_temporal_embedding(times) + if cond_emb is not None: + temb = temb + cond_emb + if self.conditions is not None: + temb = temb + self.forward_micro_conditioning(times, micros) + + # 2. input layer (normalize the input) + if self._config.nesting: + x_t, x_feat = x_t + bsz = [x.size(0) for x in x_t] + bh, bl = bsz[0], bsz[1] + x_t_low, x_t = x_t[1:], x_t[0] + x = self.forward_input_layer(x_t, normalize=(not self.config.skip_normalization)) + if self._config.nesting: + x = x + x_feat + + # 3. downsample blocks in the outer layers + x, skip_activations = self.forward_downsample( + x, + temb[:bh], + conditioning[:bh], + cond_mask[:bh] if cond_mask is not None else cond_mask, + ) + + # 4. run inner unet + x_inner = self.in_adapter(x) if self.in_adapter is not None else None + x_inner = ( + torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner + ) # pad zeros for low-resolutions + x_low, x_inner = self.inner_unet.forward_denoising( + (x_t_low, x_inner), times, cond_emb, conditioning, cond_mask, micros + ) + x_inner = self.out_adapter(x_inner) + x = x + x_inner[:bh] if bh < bl else x + x_inner + + # 5. upsample blocks in the outer layers + x = self.forward_upsample( + x, + temb[:bh], + conditioning[:bh], + cond_mask[:bh] if cond_mask is not None else cond_mask, + skip_activations, + ) + + # 6. output layer + x_out = self.forward_output_layer(x) + + # 7. outpupt both low and high-res output + if isinstance(x_low, list): + out = [x_out] + x_low + else: + out = [x_out, x_low] + if self._config.nesting: + return out, x + return out diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 7766442f7133..fd8f2cdcf517 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -252,17 +252,21 @@ def __init__( attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, + cross_attention_norm: Optional[str] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, + norm_num_groups: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + ff_norm_type: str = "group_norm", norm_eps: float = 1e-5, final_dropout: bool = False, attention_type: str = "default", + attention_pre_only: bool = False, positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, @@ -326,6 +330,8 @@ def __init__( ada_norm_bias, "rms_norm", ) + elif norm_type == "layer_norm_matryoshka": + self.norm1 = None else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) @@ -337,7 +343,9 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, + norm_num_groups=norm_num_groups if norm_type == "layer_norm_matryoshka" else None, out_bias=attention_out_bias, + pre_only=attention_pre_only, ) # 2. Cross-Attn @@ -356,12 +364,15 @@ def __init__( ada_norm_bias, "rms_norm", ) + elif norm_type == "layer_norm_matryoshka": + self.norm2 = None else: self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, + cross_attention_norm=cross_attention_norm, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, @@ -389,7 +400,7 @@ def __init__( elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - elif norm_type == "layer_norm_i2vgen": + elif norm_type in ("layer_norm_i2vgen", "layer_norm_matryoshka"): self.norm3 = None self.ff = FeedForward( @@ -399,6 +410,7 @@ def __init__( final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, + norm_type=ff_norm_type, ) # 4. Fuser @@ -453,6 +465,8 @@ def forward( ).chunk(6, dim=1) norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + elif self.norm_type == "layer_norm_matryoshka": + norm_hidden_states = hidden_states else: raise ValueError("Incorrect norm used") @@ -475,7 +489,10 @@ def forward( elif self.norm_type == "ada_norm_single": attn_output = gate_msa * attn_output - hidden_states = attn_output + hidden_states + if self.norm_type != "layer_norm_matryoshka": + hidden_states = attn_output + hidden_states + else: + cross_attention_kwargs["self_attn_output"] = attn_output if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) @@ -489,8 +506,8 @@ def forward( norm_hidden_states = self.norm2(hidden_states, timestep) elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: norm_hidden_states = self.norm2(hidden_states) - elif self.norm_type == "ada_norm_single": - # For PixArt norm2 isn't applied here: + elif self.norm_type in ("ada_norm_single", "layer_norm_matryoshka"): + # For PixArt and Matryoshka norm2 isn't applied here: # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 norm_hidden_states = hidden_states elif self.norm_type == "ada_norm_continuous": @@ -507,14 +524,17 @@ def forward( attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) - hidden_states = attn_output + hidden_states + if self.norm_type == "layer_norm_matryoshka": + hidden_states = hidden_states + attn_output # 4. Feed-forward # i2vgen doesn't have this norm 🤷‍♂️ if self.norm_type == "ada_norm_continuous": norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif not self.norm_type == "ada_norm_single": + elif not self.norm_type in ("ada_norm_single", "layer_norm_matryoshka"): norm_hidden_states = self.norm3(hidden_states) + else: + norm_hidden_states = hidden_states if self.norm_type == "ada_norm_zero": norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] @@ -1147,6 +1167,7 @@ def __init__( final_dropout: bool = False, inner_dim=None, bias: bool = True, + norm_type: str = None, ): super().__init__() if inner_dim is None: @@ -1165,10 +1186,13 @@ def __init__( act_fn = SwiGLU(dim, inner_dim, bias=bias) self.net = nn.ModuleList([]) + if norm_type == "group_norm_matryoshka": + self.net.append(nn.GroupNorm(32, dim)) # project in self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) + if norm_type != "group_norm_matryoshka": + # project dropout + self.net.append(nn.Dropout(dropout)) # project out self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 9f9bc5a46e10..618059c7b02e 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2368,6 +2368,9 @@ def __call__( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) + if kwargs.get("self_attn_output", False): + hidden_states = kwargs.pop("self_attn_output") + hidden_states + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index c7c19e4582c6..6de7aa7bf3e8 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -78,6 +78,7 @@ def __init__( dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, + cross_attention_norm: Optional[str] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, @@ -89,9 +90,12 @@ def __init__( double_self_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' + ff_norm_type: str = None, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, attention_type: str = "default", + attention_context_pre_only: bool = None, + attention_pre_only: bool = False, caption_channels: int = None, interpolation_scale: float = None, use_additional_conditions: Optional[bool] = None, @@ -172,10 +176,14 @@ def __init__( self._init_patched_inputs(norm_type=norm_type) def _init_continuous_input(self, norm_type): - self.norm = torch.nn.GroupNorm( - num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True - ) - if self.use_linear_projection: + if self.use_linear_projection != "no_projection": + self.norm = torch.nn.GroupNorm( + num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True + ) + if self.use_linear_projection == "no_projection": + self.norm = None + self.proj_in = None + elif self.use_linear_projection: self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim) else: self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0) @@ -188,22 +196,27 @@ def _init_continuous_input(self, norm_type): self.config.attention_head_dim, dropout=self.config.dropout, cross_attention_dim=self.config.cross_attention_dim, + cross_attention_norm=self.config.cross_attention_norm, activation_fn=self.config.activation_fn, num_embeds_ada_norm=self.config.num_embeds_ada_norm, + norm_num_groups=self.config.norm_num_groups, attention_bias=self.config.attention_bias, only_cross_attention=self.config.only_cross_attention, double_self_attention=self.config.double_self_attention, upcast_attention=self.config.upcast_attention, norm_type=norm_type, + ff_norm_type=self.config.ff_norm_type, norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, attention_type=self.config.attention_type, + attention_pre_only=self.config.attention_pre_only, ) for _ in range(self.config.num_layers) ] ) - - if self.use_linear_projection: + if self.use_linear_projection == "no_projection": + self.proj_out = None + elif self.use_linear_projection: self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels) else: self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0) @@ -480,7 +493,9 @@ def _operate_on_continuous_inputs(self, hidden_states): batch, _, height, width = hidden_states.shape hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: + if self.use_linear_projection == "no_projection": + inner_dim = hidden_states.shape[1] + elif not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 4be938ae35d6..2b221083d4c3 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -50,6 +50,8 @@ def get_down_block( resnet_eps: float, resnet_act_fn: str, ff_act_fn: str = "geglu", + norm_type: str = "layer_norm", + ff_norm_type: str = None, transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, resnet_groups: Optional[int] = None, @@ -61,6 +63,8 @@ def get_down_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, @@ -138,9 +142,12 @@ def get_down_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, @@ -148,6 +155,8 @@ def get_down_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, ) elif down_block_type == "SimpleCrossAttnDownBlock2D": if cross_attention_dim is None: @@ -258,6 +267,9 @@ def get_mid_block( resnet_eps: float, resnet_act_fn: str, resnet_groups: int, + ff_act_fn: str = "geglu", + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", output_scale_factor: float = 1.0, transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, @@ -268,6 +280,8 @@ def get_mid_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, resnet_skip_time_act: bool = False, cross_attention_norm: Optional[str] = None, attention_head_dim: Optional[int] = 1, @@ -281,15 +295,21 @@ def get_mid_block( dropout=dropout, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, output_scale_factor=output_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, resnet_groups=resnet_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": return UNetMidBlock2DSimpleCrossAttn( @@ -336,6 +356,9 @@ def get_up_block( add_upsample: bool, resnet_eps: float, resnet_act_fn: str, + ff_act_fn: str = "geglu", + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", resolution_idx: Optional[int] = None, transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, @@ -347,6 +370,8 @@ def get_up_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, @@ -409,8 +434,12 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, @@ -418,6 +447,8 @@ def get_up_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, ) elif up_block_type == "SimpleCrossAttnUpBlock2D": if cross_attention_dim is None: @@ -755,16 +786,22 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", num_attention_heads: int = 1, output_scale_factor: float = 1.0, cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, upcast_attention: bool = False, attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, ): super().__init__() @@ -809,10 +846,16 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, ) ) else: @@ -1163,8 +1206,11 @@ def __init__( ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, output_scale_factor: float = 1.0, downsample_padding: int = 1, add_downsample: bool = True, @@ -1173,6 +1219,8 @@ def __init__( only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, ): super().__init__() resnets = [] @@ -1207,11 +1255,16 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, activation_fn=ff_act_fn, ) ) @@ -2406,10 +2459,14 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, output_scale_factor: float = 1.0, add_upsample: bool = True, dual_cross_attention: bool = False, @@ -2417,6 +2474,8 @@ def __init__( only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, ): super().__init__() resnets = [] @@ -2454,11 +2513,17 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, ) ) else: diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 8991dd2a9db0..848c811611e0 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from calendar import c from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union +from cv2 import add import torch import torch.nn as nn import torch.utils.checkpoint @@ -53,7 +55,34 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class MatryoshkaCombinedTimestepTextEmbedding(nn.Module): + def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim): + super().__init__() + self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(cross_attention_dim, time_embed_dim) + + def forward(self, emb, encoder_hidden_states, added_cond_kwargs): + conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) + masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) + if conditioning_mask is None or not masked_cross_attention: + y = encoder_hidden_states.mean(dim=1) + else: + y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( + dim=1, keepdim=True + ) + if not masked_cross_attention: + conditioning_mask = None + cond_emb = self.cond_emb(y) + cond_emb = cond_emb + emb + micro = added_cond_kwargs.get('micro_conditioning_scale', None) + if micro is not None: + temb = self.add_time_proj(micro) + temb_micro_conditioning = self.add_embedding(temb) + + cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning + return cond_emb, conditioning_mask @dataclass class UNet2DConditionOutput(BaseOutput): @@ -192,6 +221,8 @@ def __init__( dropout: float = 0.0, act_fn: str = "silu", ff_act_fn: str = "geglu", + norm_type: str = "layer_norm", + ff_norm_type: str = None, norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, @@ -220,6 +251,10 @@ def __init__( conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + masked_cross_attention: bool = False, + micro_conditioning_scale: int = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, @@ -279,6 +314,11 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) + self.cond_layers = nn.Sequential( + nn.Linear(timestep_input_dim, time_embed_dim), + nn.Linear(time_embed_dim, time_embed_dim), + ) + self._set_encoder_hid_proj( encoder_hid_dim_type, cross_attention_dim=cross_attention_dim, @@ -298,7 +338,7 @@ def __init__( self._set_add_embedding( addition_embed_type, addition_embed_type_num_heads=addition_embed_type_num_heads, - addition_time_embed_dim=addition_time_embed_dim, + addition_time_embed_dim=timestep_input_dim, cross_attention_dim=cross_attention_dim, encoder_hid_dim=encoder_hid_dim, flip_sin_to_cos=flip_sin_to_cos, @@ -365,6 +405,8 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], @@ -375,6 +417,8 @@ def __init__( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, @@ -390,9 +434,12 @@ def __init__( in_channels=block_out_channels[-1], resnet_eps=norm_eps, resnet_act_fn=act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, resnet_groups=norm_num_groups, output_scale_factor=mid_block_scale_factor, - transformer_layers_per_block=transformer_layers_per_block[-1], + transformer_layers_per_block=transformer_layers_per_block[-1] if norm_type != "layer_norm_matryoshka" else 1, num_attention_heads=num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], dual_cross_attention=dual_cross_attention, @@ -401,6 +448,8 @@ def __init__( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, resnet_skip_time_act=resnet_skip_time_act, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[-1], @@ -448,6 +497,9 @@ def __init__( add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, resolution_idx=i, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], @@ -458,6 +510,8 @@ def __init__( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, @@ -664,6 +718,10 @@ def _set_add_embedding( self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) + elif addition_embed_type == "matryoshka": + self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( + addition_time_embed_dim, cross_attention_dim, time_embed_dim + ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use @@ -956,6 +1014,8 @@ def get_aug_embed( aug_emb = None if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "matryoshka": + aug_emb = self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: @@ -1151,7 +1211,14 @@ def forward( else: emb = emb + class_emb - aug_emb = self.get_aug_embed( + added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention + added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + aug_emb, cond_mask = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) if self.config.addition_embed_type == "image_hint": @@ -1163,10 +1230,6 @@ def forward( if self.time_embed_act is not None: emb = self.time_embed_act(emb) - encoder_hidden_states = self.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) - # 2. pre-process sample = self.conv_in(sample) From 5f5bd0815ca226258bee18f3170768c2854942dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Sep 2024 19:07:53 +0300 Subject: [PATCH 004/109] make style --- src/diffusers/models/attention.py | 2 +- src/diffusers/models/unets/unet_2d_condition.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index fd8f2cdcf517..ab5eff49d760 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -531,7 +531,7 @@ def forward( # i2vgen doesn't have this norm 🤷‍♂️ if self.norm_type == "ada_norm_continuous": norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif not self.norm_type in ("ada_norm_single", "layer_norm_matryoshka"): + elif self.norm_type not in ("ada_norm_single", "layer_norm_matryoshka"): norm_hidden_states = self.norm3(hidden_states) else: norm_hidden_states = hidden_states diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 848c811611e0..6895b109b689 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from calendar import c from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union -from cv2 import add import torch import torch.nn as nn import torch.utils.checkpoint @@ -55,6 +53,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class MatryoshkaCombinedTimestepTextEmbedding(nn.Module): def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim): super().__init__() @@ -76,7 +76,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): cond_emb = self.cond_emb(y) cond_emb = cond_emb + emb - micro = added_cond_kwargs.get('micro_conditioning_scale', None) + micro = added_cond_kwargs.get("micro_conditioning_scale", None) if micro is not None: temb = self.add_time_proj(micro) temb_micro_conditioning = self.add_embedding(temb) @@ -84,6 +84,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning return cond_emb, conditioning_mask + @dataclass class UNet2DConditionOutput(BaseOutput): """ @@ -439,7 +440,9 @@ def __init__( ff_norm_type=ff_norm_type, resnet_groups=norm_num_groups, output_scale_factor=mid_block_scale_factor, - transformer_layers_per_block=transformer_layers_per_block[-1] if norm_type != "layer_norm_matryoshka" else 1, + transformer_layers_per_block=transformer_layers_per_block[-1] + if norm_type != "layer_norm_matryoshka" + else 1, num_attention_heads=num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], dual_cross_attention=dual_cross_attention, From bfd8b9dcb3570f5d67b8f63c114d20c9bbe6fd5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Sep 2024 19:08:41 +0300 Subject: [PATCH 005/109] make fix-copies --- .../versatile_diffusion/modeling_text_unet.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 23dac5abd0c3..a122c6ea1138 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1915,10 +1915,14 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, output_scale_factor: float = 1.0, add_upsample: bool = True, dual_cross_attention: bool = False, @@ -1926,6 +1930,8 @@ def __init__( only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, ): super().__init__() resnets = [] @@ -1963,11 +1969,17 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, ) ) else: @@ -2246,16 +2258,22 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", num_attention_heads: int = 1, output_scale_factor: float = 1.0, cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, upcast_attention: bool = False, attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, ): super().__init__() @@ -2300,10 +2318,16 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, ) ) else: From eaef0377e4e401f024fea3e15c23c19482057642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Sep 2024 22:55:41 +0300 Subject: [PATCH 006/109] Up --- src/diffusers/models/attention.py | 3 ++- src/diffusers/models/attention_processor.py | 10 +++++++--- src/diffusers/models/embeddings.py | 5 ++++- src/diffusers/models/unets/unet_2d_condition.py | 9 ++------- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index ab5eff49d760..17539a70a11e 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -477,7 +477,7 @@ def forward( cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - attn_output = self.attn1( + attn_output, query = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, @@ -493,6 +493,7 @@ def forward( hidden_states = attn_output + hidden_states else: cross_attention_kwargs["self_attn_output"] = attn_output + cross_attention_kwargs["query"] = query if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 618059c7b02e..afc30d2a6346 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -217,7 +217,8 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + if not self.is_cross_attention: + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes @@ -2339,7 +2340,10 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - query = attn.to_q(hidden_states) + if kwargs.get("query", False): + query = kwargs.pop("query") + else: + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states @@ -2387,7 +2391,7 @@ def __call__( hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states + return hidden_states if not kwargs.get("self_attn_output", False) else hidden_states, query class StableAudioAttnProcessor2_0: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index eb5067c37700..1fd99ffbb1af 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -724,7 +724,10 @@ def __init__( else: self.cond_proj = None - self.act = get_activation(act_fn) + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) if out_dim is not None: time_embed_dim_out = out_dim diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index e895f647ff86..52a98fb03418 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -60,7 +60,7 @@ def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim) super().__init__() self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) - self.add_embedding = TimestepEmbedding(cross_attention_dim, time_embed_dim) + self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim, act_fn=None) def forward(self, emb, encoder_hidden_states, added_cond_kwargs): conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) @@ -79,7 +79,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): micro = added_cond_kwargs.get("micro_conditioning_scale", None) if micro is not None: temb = self.add_time_proj(micro) - temb_micro_conditioning = self.add_embedding(temb) + temb_micro_conditioning = self.add_timestep_embedder(temb) cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning return cond_emb, conditioning_mask @@ -315,11 +315,6 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) - self.cond_layers = nn.Sequential( - nn.Linear(timestep_input_dim, time_embed_dim), - nn.Linear(time_embed_dim, time_embed_dim), - ) - self._set_encoder_hid_proj( encoder_hid_dim_type, cross_attention_dim=cross_attention_dim, From 99d9099ae45045cf91ebdd1e24c95efafcfec50f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 6 Sep 2024 18:51:47 +0300 Subject: [PATCH 007/109] Up --- src/diffusers/models/activations.py | 11 ++++++-- src/diffusers/models/attention.py | 8 +++++- src/diffusers/models/attention_processor.py | 26 +++++++++++-------- .../models/transformers/transformer_2d.py | 9 +++++-- .../models/unets/unet_2d_condition.py | 5 ++-- 5 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index fb24a36bae75..3b56698565f8 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -85,8 +85,15 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor: return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) def forward(self, hidden_states): - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) + if hidden_states.ndim == 4: + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(-1, channels) + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2) + else: + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) return hidden_states diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 17539a70a11e..3d30f1b5c3e2 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1205,5 +1205,11 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for module in self.net: - hidden_states = module(hidden_states) + if isinstance(module, nn.Linear) and hidden_states.ndim == 4: + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(-1, channels) + hidden_states = module(hidden_states) + hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2) + else: + hidden_states = module(hidden_states) return hidden_states diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index afc30d2a6346..1554a73de632 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -18,6 +18,7 @@ import torch import torch.nn.functional as F from torch import nn +from zmq import has from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, logging @@ -2307,6 +2308,8 @@ def __call__( self, attn: Attention, hidden_states: torch.Tensor, + self_attn_output: Optional[torch.Tensor] = None, + query: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, @@ -2340,9 +2343,7 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if kwargs.get("query", False): - query = kwargs.pop("query") - else: + if query is None: query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -2356,7 +2357,8 @@ def __call__( inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if self_attn_output is None: + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) @@ -2372,16 +2374,15 @@ def __call__( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - if kwargs.get("self_attn_output", False): - hidden_states = kwargs.pop("self_attn_output") + hidden_states hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) + if hasattr(attn, "to_out"): + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) @@ -2389,9 +2390,12 @@ def __call__( if attn.residual_connection: hidden_states = hidden_states + residual + if self_attn_output is not None: + hidden_states = self_attn_output + hidden_states + hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states if not kwargs.get("self_attn_output", False) else hidden_states, query + return hidden_states if self_attn_output is not None else (hidden_states, query) class StableAudioAttnProcessor2_0: diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 6de7aa7bf3e8..7bcee91290ab 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -216,7 +216,7 @@ def _init_continuous_input(self, norm_type): ) if self.use_linear_projection == "no_projection": self.proj_out = None - elif self.use_linear_projection: + elif self.use_linear_projection is not None: self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels) else: self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0) @@ -491,7 +491,8 @@ def custom_forward(*inputs): def _operate_on_continuous_inputs(self, hidden_states): batch, _, height, width = hidden_states.shape - hidden_states = self.norm(hidden_states) + if self.norm is not None: + hidden_states = self.norm(hidden_states) if self.use_linear_projection == "no_projection": inner_dim = hidden_states.shape[1] @@ -527,6 +528,10 @@ def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, times return hidden_states, encoder_hidden_states, timestep, embedded_timestep def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): + + if self.proj_out is None: + return hidden_states + residual + if not self.use_linear_projection: hidden_states = ( hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 52a98fb03418..6d45ff548683 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -78,8 +78,8 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): micro = added_cond_kwargs.get("micro_conditioning_scale", None) if micro is not None: - temb = self.add_time_proj(micro) - temb_micro_conditioning = self.add_timestep_embedder(temb) + temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype)) + temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype)) cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning return cond_emb, conditioning_mask @@ -1209,6 +1209,7 @@ def forward( else: emb = emb + class_emb + added_cond_kwargs = added_cond_kwargs or {} added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale From 56e61f095e8cd2e682bb405f673057c8c01e1949 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 12 Sep 2024 20:56:25 +0300 Subject: [PATCH 008/109] Fix timestep embedding conditioning in `MatryoshkaCombinedTimestepTextEmbedding` --- src/diffusers/models/unets/unet_2d_condition.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 6d45ff548683..dbd7fafab084 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -60,7 +60,7 @@ def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim) super().__init__() self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) - self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim, act_fn=None) + self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) def forward(self, emb, encoder_hidden_states, added_cond_kwargs): conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) @@ -74,12 +74,11 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if not masked_cross_attention: conditioning_mask = None cond_emb = self.cond_emb(y) - cond_emb = cond_emb + emb micro = added_cond_kwargs.get("micro_conditioning_scale", None) if micro is not None: - temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype)) - temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype)) + temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) + temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning return cond_emb, conditioning_mask From 376500ab90e13d037a651d3b611a2bd7048a94a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 12 Sep 2024 20:57:48 +0300 Subject: [PATCH 009/109] make style --- src/diffusers/models/attention_processor.py | 2 -- src/diffusers/models/transformers/transformer_2d.py | 1 - 2 files changed, 3 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 1554a73de632..edbdf71f6255 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -18,7 +18,6 @@ import torch import torch.nn.functional as F from torch import nn -from zmq import has from ..image_processor import IPAdapterMaskProcessor from ..utils import deprecate, logging @@ -2374,7 +2373,6 @@ def __call__( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index 7bcee91290ab..a19f2998e85e 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -528,7 +528,6 @@ def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, times return hidden_states, encoder_hidden_states, timestep, embedded_timestep def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): - if self.proj_out is None: return hidden_states + residual From 8c4dcb3992a9232ec0d8e5007162f287d61dd3e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Sep 2024 16:55:03 +0300 Subject: [PATCH 010/109] Revert; cuz I should have created (probably) a new attention processor for Matryoshka models --- src/diffusers/models/attention_processor.py | 25 +++++++-------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index edbdf71f6255..9f9bc5a46e10 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -217,8 +217,7 @@ def __init__( f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" ) - if not self.is_cross_attention: - self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.only_cross_attention: # only relevant for the `AddedKVProcessor` classes @@ -2307,8 +2306,6 @@ def __call__( self, attn: Attention, hidden_states: torch.Tensor, - self_attn_output: Optional[torch.Tensor] = None, - query: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, @@ -2342,8 +2339,7 @@ def __call__( if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - if query is None: - query = attn.to_q(hidden_states) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states @@ -2356,8 +2352,7 @@ def __call__( inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - if self_attn_output is None: - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) @@ -2376,11 +2371,10 @@ def __call__( hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - if hasattr(attn, "to_out"): - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) if input_ndim == 4: hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) @@ -2388,12 +2382,9 @@ def __call__( if attn.residual_connection: hidden_states = hidden_states + residual - if self_attn_output is not None: - hidden_states = self_attn_output + hidden_states - hidden_states = hidden_states / attn.rescale_output_factor - return hidden_states if self_attn_output is not None else (hidden_states, query) + return hidden_states class StableAudioAttnProcessor2_0: From ef38541ff68c534a3fbb680ee722d0796ee26308 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 15 Sep 2024 11:38:55 +0300 Subject: [PATCH 011/109] Revert to create your own custom transformer block --- src/diffusers/models/attention.py | 51 ++++++------------------------- 1 file changed, 10 insertions(+), 41 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e0e00257f48c..84db0d061768 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -252,21 +252,17 @@ def __init__( attention_head_dim: int, dropout=0.0, cross_attention_dim: Optional[int] = None, - cross_attention_norm: Optional[str] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, - norm_num_groups: Optional[int] = None, attention_bias: bool = False, only_cross_attention: bool = False, double_self_attention: bool = False, upcast_attention: bool = False, norm_elementwise_affine: bool = True, norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' - ff_norm_type: str = "group_norm", norm_eps: float = 1e-5, final_dropout: bool = False, attention_type: str = "default", - attention_pre_only: bool = False, positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, @@ -330,8 +326,6 @@ def __init__( ada_norm_bias, "rms_norm", ) - elif norm_type == "layer_norm_matryoshka": - self.norm1 = None else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) @@ -343,9 +337,7 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, - norm_num_groups=norm_num_groups if norm_type == "layer_norm_matryoshka" else None, out_bias=attention_out_bias, - pre_only=attention_pre_only, ) # 2. Cross-Attn @@ -364,15 +356,12 @@ def __init__( ada_norm_bias, "rms_norm", ) - elif norm_type == "layer_norm_matryoshka": - self.norm2 = None else: self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, - cross_attention_norm=cross_attention_norm, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, @@ -400,7 +389,7 @@ def __init__( elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]: self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) - elif norm_type in ("layer_norm_i2vgen", "layer_norm_matryoshka"): + elif norm_type == "layer_norm_i2vgen": self.norm3 = None self.ff = FeedForward( @@ -410,7 +399,6 @@ def __init__( final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, - norm_type=ff_norm_type, ) # 4. Fuser @@ -465,8 +453,6 @@ def forward( ).chunk(6, dim=1) norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa - elif self.norm_type == "layer_norm_matryoshka": - norm_hidden_states = hidden_states else: raise ValueError("Incorrect norm used") @@ -477,7 +463,7 @@ def forward( cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} gligen_kwargs = cross_attention_kwargs.pop("gligen", None) - attn_output, query = self.attn1( + attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, attention_mask=attention_mask, @@ -489,11 +475,7 @@ def forward( elif self.norm_type == "ada_norm_single": attn_output = gate_msa * attn_output - if self.norm_type != "layer_norm_matryoshka": - hidden_states = attn_output + hidden_states - else: - cross_attention_kwargs["self_attn_output"] = attn_output - cross_attention_kwargs["query"] = query + hidden_states = attn_output + hidden_states if hidden_states.ndim == 4: hidden_states = hidden_states.squeeze(1) @@ -507,8 +489,8 @@ def forward( norm_hidden_states = self.norm2(hidden_states, timestep) elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: norm_hidden_states = self.norm2(hidden_states) - elif self.norm_type in ("ada_norm_single", "layer_norm_matryoshka"): - # For PixArt and Matryoshka norm2 isn't applied here: + elif self.norm_type == "ada_norm_single": + # For PixArt norm2 isn't applied here: # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 norm_hidden_states = hidden_states elif self.norm_type == "ada_norm_continuous": @@ -525,17 +507,14 @@ def forward( attention_mask=encoder_attention_mask, **cross_attention_kwargs, ) + hidden_states = attn_output + hidden_states - if self.norm_type == "layer_norm_matryoshka": - hidden_states = hidden_states + attn_output # 4. Feed-forward # i2vgen doesn't have this norm 🤷‍♂️ if self.norm_type == "ada_norm_continuous": norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) - elif self.norm_type not in ("ada_norm_single", "layer_norm_matryoshka"): + elif not self.norm_type == "ada_norm_single": norm_hidden_states = self.norm3(hidden_states) - else: - norm_hidden_states = hidden_states if self.norm_type == "ada_norm_zero": norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] @@ -1186,7 +1165,6 @@ def __init__( final_dropout: bool = False, inner_dim=None, bias: bool = True, - norm_type: str = None, ): super().__init__() if inner_dim is None: @@ -1205,13 +1183,10 @@ def __init__( act_fn = SwiGLU(dim, inner_dim, bias=bias) self.net = nn.ModuleList([]) - if norm_type == "group_norm_matryoshka": - self.net.append(nn.GroupNorm(32, dim)) # project in self.net.append(act_fn) - if norm_type != "group_norm_matryoshka": - # project dropout - self.net.append(nn.Dropout(dropout)) + # project dropout + self.net.append(nn.Dropout(dropout)) # project out self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout @@ -1223,11 +1198,5 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." deprecate("scale", "1.0.0", deprecation_message) for module in self.net: - if isinstance(module, nn.Linear) and hidden_states.ndim == 4: - batch_size, channels, height, width = hidden_states.shape - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(-1, channels) - hidden_states = module(hidden_states) - hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2) - else: - hidden_states = module(hidden_states) + hidden_states = module(hidden_states) return hidden_states From 19d6c178a0b646615c173297bf5e199534010442 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Sep 2024 21:02:17 +0300 Subject: [PATCH 012/109] Init template for the pipeline --- examples/community/matryoshka.py | 973 +++++++++++++++++++++++++------ 1 file changed, 806 insertions(+), 167 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 2cf0f6ecacb2..db416ee0837b 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1,45 +1,50 @@ -# For licensing see accompanying LICENSE file. -# Copyright (C) 2024 Apple Inc. All rights reserved. +# #TODO Licensed under the Apache License, Version 2.0 or MIT? import inspect -from dataclasses import dataclass, field -from typing import List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from packaging import version -from torch import nn from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.configuration_utils import FrozenDict -from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import ( - FromSingleFileMixin, - IPAdapterMixin, - StableDiffusionLoraLoaderMixin, - TextualInversionLoaderMixin, -) -from diffusers.models import UNet2DConditionModel -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import ImageProjection, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( + USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, ) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch - >>> from diffusers import DiffusionPipeline + >>> from diffusers import MatryoshkaPipeline - >>> pipe = DiffusionPipeline.from_pretrained("A/B", torch_dtype=torch.float16, variant="fp16", - ... custom_pipeline="matryoshka",) + >>> pipe = MatryoshkaPipeline.from_pretrained("A/B", torch_dtype=torch.float16, variant="fp16") >>> pipe = pipe.to("cuda") >>> prompt = "a photo of an astronaut riding a horse on mars" @@ -122,58 +127,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -@dataclass -class NestedUNetConfig: # UNetConfig - inner_config = field( - default={"nesting": True}, # UNetConfig(nesting=True), - metadata={"help": "inner unet used as middle blocks"}, - ) - skip_mid_blocks: bool = field(default=True) - skip_cond_emb: bool = field(default=True) - skip_inner_unet_input: bool = field( - default=False, - metadata={"help": "If enabled, the inner unet only received the downsampled image, no features."}, - ) - skip_normalization: bool = field( - default=False, - ) - initialize_inner_with_pretrained: str = field( - default=None, - metadata={ - "help": ( - "Initialize the inner unet with pretrained vision model ", - "Provide the vision_model_path", - ) - }, - ) - freeze_inner_unet: bool = field(default=False) - interp_conditioning: bool = field( - default=False, - ) - - -@dataclass -class Nested2UNetConfig(NestedUNetConfig): - inner_config: NestedUNetConfig = field( - default=NestedUNetConfig(nesting=True, initialize_inner_with_pretrained=None) - ) - - -@dataclass -class Nested3UNetConfig(Nested2UNetConfig): - inner_config: Nested2UNetConfig = field( - default=Nested2UNetConfig(nesting=True, initialize_inner_with_pretrained=None) - ) - - -@dataclass -class Nested4UNetConfig(Nested3UNetConfig): - inner_config: Nested3UNetConfig = field( - default=Nested3UNetConfig(nesting=True, initialize_inner_with_pretrained=None) - ) - - -class NestedUNet( # UNet, +class MatryoshkaPipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, @@ -181,7 +135,7 @@ class NestedUNet( # UNet, IPAdapterMixin, FromSingleFileMixin, ): - """ + r""" Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods @@ -309,117 +263,802 @@ def __init__( self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) - unet.config.inner_config.conditioning_feature_dim = unet.config.conditioning_feature_dim - if getattr(unet.config.inner_config, "inner_config", None) is None: - self.inner_unet = UNet2DConditionModel( - in_channels=3, - out_channels=3, - block_out_channels=(256, 512, 768), - ff_act_fn="gelu", - cross_attention_dim=2048, - resnet_time_scale_shift="scale_shift", - down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"), - up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"), + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", ) - self.inner_unet # Arrange inner attributes + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype else: - self.inner_unet = NestedUNet(3, 3, unet.config.inner_config) - - if not unet.config.skip_inner_unet_input: - self.in_adapter = nn.Conv2d( - unet.config.resolution_channels[-1], - unet.config.inner_config.resolution_channels[0], - kernel_size=3, - padding=1, - bias=True, + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - self.in_adapter = None - self.out_adapter = nn.Conv2d( - unet.config.inner_config.resolution_channels[0], - unet.config.resolution_channels[-1], - kernel_size=3, - padding=1, - bias=True, - ) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) - self.is_temporal = [unet.config.temporal_mode and (not unet.config.temporal_spatial_ds)] - if hasattr(self.inner_unet, "is_temporal"): - self.is_temporal += self.inner_unet.is_temporal + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) - nest_ratio = int(2 ** (len(unet.config.resolution_channels) - 1)) - if self.is_temporal[0]: - nest_ratio = int(np.sqrt(nest_ratio)) - if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": - self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - self.nest_ratio = [nest_ratio] + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents - if self.config.interp_conditioning: - self.interp_layer1 = nn.Linear(self.temporal_dim // 4, self.temporal_dim) - self.interp_layer2 = nn.Linear(self.temporal_dim, self.temporal_dim) + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt @property def model_type(self): return "nested_unet" - def forward_conditioning(self, *args, **kwargs): - return self.inner_unet.forward_conditioning(*args, **kwargs) - - def forward_denoising(self, x_t, times, cond_emb=None, conditioning=None, cond_mask=None, micros={}): - # 1. time embedding - temb = self.create_temporal_embedding(times) - if cond_emb is not None: - temb = temb + cond_emb - if self.conditions is not None: - temb = temb + self.forward_micro_conditioning(times, micros) - - # 2. input layer (normalize the input) - if self._config.nesting: - x_t, x_feat = x_t - bsz = [x.size(0) for x in x_t] - bh, bl = bsz[0], bsz[1] - x_t_low, x_t = x_t[1:], x_t[0] - x = self.forward_input_layer(x_t, normalize=(not self.config.skip_normalization)) - if self._config.nesting: - x = x + x_feat - - # 3. downsample blocks in the outer layers - x, skip_activations = self.forward_downsample( - x, - temb[:bh], - conditioning[:bh], - cond_mask[:bh] if cond_mask is not None else cond_mask, + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when + using zero terminal SNR. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas ) - # 4. run inner unet - x_inner = self.in_adapter(x) if self.in_adapter is not None else None - x_inner = ( - torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner - ) # pad zeros for low-resolutions - x_low, x_inner = self.inner_unet.forward_denoising( - (x_t_low, x_inner), times, cond_emb, conditioning, cond_mask, micros + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, ) - x_inner = self.out_adapter(x_inner) - x = x + x_inner[:bh] if bh < bl else x + x_inner - - # 5. upsample blocks in the outer layers - x = self.forward_upsample( - x, - temb[:bh], - conditioning[:bh], - cond_mask[:bh] if cond_mask is not None else cond_mask, - skip_activations, + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) + else None ) - # 6. output layer - x_out = self.forward_output_layer(x) + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - # 7. outpupt both low and high-res output - if isinstance(x_low, list): - out = [x_out] + x_low + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] else: - out = [x_out, x_low] - if self._config.nesting: - return out, x - return out + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From 7d1a0ab31aaf5366fd4a56a94cbff41cb8158c0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Sep 2024 21:11:03 +0300 Subject: [PATCH 013/109] Add `MatryoshkaTransformerBlock` and `MatryoshkaFeedForward` classes --- examples/community/matryoshka.py | 164 +++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index db416ee0837b..cb0b89abfcc0 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch +from torch import nn from packaging import version from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast @@ -27,6 +28,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0 +from diffusers.models.activations import GELU if is_torch_xla_available(): @@ -126,6 +129,167 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps +# Copied from diffusers.models.attention._chunked_feed_forward +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + +class MatryoshkaTransformerBlock(nn.Module): + r""" + Matryoshka Transformer block. + + Parameters: + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = True, + attention_type: str = "default", + attention_ff_inner_dim: Optional[int] = None, + ): + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.cross_attention_dim = cross_attention_dim + + # Define 3 blocks. + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + norm_num_groups=32 or None, + bias=True, + upcast_attention=upcast_attention, + pre_only=True, + processor=FusedAttnProcessor2_0(), + ) + self.attn1.fuse_projections() + + # 2. Cross-Attn + if cross_attention_dim is not None and cross_attention_dim > 0: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + cross_attention_norm="layer_norm", + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=True, + upcast_attention=upcast_attention, + pre_only=True, + processor=FusedAttnProcessor2_0(), + ) + self.attn2.fuse_projections() + # self.attn2.to_q = None + + self.proj_out = nn.Linear(dim, dim) + + if attention_ff_inner_dim is not None: + # 3. Feed-forward + self.ff = MatryoshkaFeedForward( + dim, + inner_dim=attention_ff_inner_dim, + ) + else: + self.ff = None + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # 1. Self-Attention + batch_size, channels, *spatial_dims = hidden_states.shape + + attn_output, query = self.attn1( + hidden_states, + **cross_attention_kwargs, + ) + cross_attention_kwargs["self_attention_output"] = attn_output + cross_attention_kwargs["self_attention_query"] = query + + # 2. Cross-Attention + if self.cross_attention_dim is not None and self.cross_attention_dim > 0: + attn_output_cond = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = attn_output + attn_output_cond + + attn_output = attn_output.reshape(batch_size, channels, *spatial_dims) + attn_output = self.proj_out(attn_output) + hidden_states = hidden_states + attn_output + + if self.ff is not None: + # 4. Feed-forward + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + +class MatryoshkaFeedForward(nn.Module): + r""" + A feed-forward layer for the Matryoshka models. + + Parameters:""" + + def __init__( + self, + dim: int, + ): + super().__init__() + + self.group_norm = nn.GroupNorm(32, dim) + self.linear_gelu = GELU(dim, dim * 4) + self.linear_out = nn.Linear(dim * 4, dim) + + def forward(self, x): + return self.linear_out(self.linear_gelu(self.group_norm(x))) + + class MatryoshkaPipeline( DiffusionPipeline, From 5754bc66d076d8fb7dd4e1502ac8e4c3fdb7184f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Sep 2024 21:18:33 +0300 Subject: [PATCH 014/109] Revert --- src/diffusers/models/activations.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 3b56698565f8..fb24a36bae75 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -85,15 +85,8 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor: return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) def forward(self, hidden_states): - if hidden_states.ndim == 4: - batch_size, channels, height, width = hidden_states.shape - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(-1, channels) - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2) - else: - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) return hidden_states From bc1f68ba596ce82f930c07235313b1682289078b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Sep 2024 21:19:42 +0300 Subject: [PATCH 015/109] Add `GELU` activation function module --- examples/community/matryoshka.py | 35 +++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index cb0b89abfcc0..fe84e904ed15 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -29,7 +29,6 @@ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0 -from diffusers.models.activations import GELU if is_torch_xla_available(): @@ -270,6 +269,40 @@ def forward( return hidden_states +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + if hidden_states.ndim == 4: + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(-1, channels) + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2) + else: + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + class MatryoshkaFeedForward(nn.Module): r""" A feed-forward layer for the Matryoshka models. From 23f4ced0237638a5424d7af34844f109994d5bc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 10:42:53 +0300 Subject: [PATCH 016/109] Revert --- .../models/transformers/transformer_2d.py | 35 +++++-------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py index a19f2998e85e..c7c19e4582c6 100644 --- a/src/diffusers/models/transformers/transformer_2d.py +++ b/src/diffusers/models/transformers/transformer_2d.py @@ -78,7 +78,6 @@ def __init__( dropout: float = 0.0, norm_num_groups: int = 32, cross_attention_dim: Optional[int] = None, - cross_attention_norm: Optional[str] = None, attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, @@ -90,12 +89,9 @@ def __init__( double_self_attention: bool = False, upcast_attention: bool = False, norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen' - ff_norm_type: str = None, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, attention_type: str = "default", - attention_context_pre_only: bool = None, - attention_pre_only: bool = False, caption_channels: int = None, interpolation_scale: float = None, use_additional_conditions: Optional[bool] = None, @@ -176,14 +172,10 @@ def __init__( self._init_patched_inputs(norm_type=norm_type) def _init_continuous_input(self, norm_type): - if self.use_linear_projection != "no_projection": - self.norm = torch.nn.GroupNorm( - num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True - ) - if self.use_linear_projection == "no_projection": - self.norm = None - self.proj_in = None - elif self.use_linear_projection: + self.norm = torch.nn.GroupNorm( + num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True + ) + if self.use_linear_projection: self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim) else: self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0) @@ -196,27 +188,22 @@ def _init_continuous_input(self, norm_type): self.config.attention_head_dim, dropout=self.config.dropout, cross_attention_dim=self.config.cross_attention_dim, - cross_attention_norm=self.config.cross_attention_norm, activation_fn=self.config.activation_fn, num_embeds_ada_norm=self.config.num_embeds_ada_norm, - norm_num_groups=self.config.norm_num_groups, attention_bias=self.config.attention_bias, only_cross_attention=self.config.only_cross_attention, double_self_attention=self.config.double_self_attention, upcast_attention=self.config.upcast_attention, norm_type=norm_type, - ff_norm_type=self.config.ff_norm_type, norm_elementwise_affine=self.config.norm_elementwise_affine, norm_eps=self.config.norm_eps, attention_type=self.config.attention_type, - attention_pre_only=self.config.attention_pre_only, ) for _ in range(self.config.num_layers) ] ) - if self.use_linear_projection == "no_projection": - self.proj_out = None - elif self.use_linear_projection is not None: + + if self.use_linear_projection: self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels) else: self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0) @@ -491,12 +478,9 @@ def custom_forward(*inputs): def _operate_on_continuous_inputs(self, hidden_states): batch, _, height, width = hidden_states.shape - if self.norm is not None: - hidden_states = self.norm(hidden_states) + hidden_states = self.norm(hidden_states) - if self.use_linear_projection == "no_projection": - inner_dim = hidden_states.shape[1] - elif not self.use_linear_projection: + if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) @@ -528,9 +512,6 @@ def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, times return hidden_states, encoder_hidden_states, timestep, embedded_timestep def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): - if self.proj_out is None: - return hidden_states + residual - if not self.use_linear_projection: hidden_states = ( hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() From a2ca8efffb9c56fe75167bf1a48c6b691dde90c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 11:29:45 +0300 Subject: [PATCH 017/109] Revert --- src/diffusers/models/unets/unet_2d_blocks.py | 69 ----------------- .../models/unets/unet_2d_condition.py | 77 ++----------------- 2 files changed, 7 insertions(+), 139 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 2b221083d4c3..93a0a82cdcff 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -49,9 +49,6 @@ def get_down_block( add_downsample: bool, resnet_eps: float, resnet_act_fn: str, - ff_act_fn: str = "geglu", - norm_type: str = "layer_norm", - ff_norm_type: str = None, transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, resnet_groups: Optional[int] = None, @@ -63,8 +60,6 @@ def get_down_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, @@ -141,13 +136,9 @@ def get_down_block( add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, @@ -155,8 +146,6 @@ def get_down_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, ) elif down_block_type == "SimpleCrossAttnDownBlock2D": if cross_attention_dim is None: @@ -267,9 +256,6 @@ def get_mid_block( resnet_eps: float, resnet_act_fn: str, resnet_groups: int, - ff_act_fn: str = "geglu", - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", output_scale_factor: float = 1.0, transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, @@ -280,8 +266,6 @@ def get_mid_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, resnet_skip_time_act: bool = False, cross_attention_norm: Optional[str] = None, attention_head_dim: Optional[int] = 1, @@ -295,21 +279,15 @@ def get_mid_block( dropout=dropout, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, output_scale_factor=output_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, resnet_groups=resnet_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": return UNetMidBlock2DSimpleCrossAttn( @@ -356,9 +334,6 @@ def get_up_block( add_upsample: bool, resnet_eps: float, resnet_act_fn: str, - ff_act_fn: str = "geglu", - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", resolution_idx: Optional[int] = None, transformer_layers_per_block: int = 1, num_attention_heads: Optional[int] = None, @@ -370,8 +345,6 @@ def get_up_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, @@ -434,12 +407,8 @@ def get_up_block( add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, @@ -447,8 +416,6 @@ def get_up_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, ) elif up_block_type == "SimpleCrossAttnUpBlock2D": if cross_attention_dim is None: @@ -786,22 +753,16 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", num_attention_heads: int = 1, output_scale_factor: float = 1.0, cross_attention_dim: int = 1280, - cross_attention_norm: Optional[str] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, upcast_attention: bool = False, attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, ): super().__init__() @@ -846,16 +807,10 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, ) ) else: @@ -1203,14 +1158,10 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_pre_norm: bool = True, - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, - cross_attention_norm: Optional[str] = None, output_scale_factor: float = 1.0, downsample_padding: int = 1, add_downsample: bool = True, @@ -1219,8 +1170,6 @@ def __init__( only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, ): super().__init__() resnets = [] @@ -1255,17 +1204,11 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, ) ) else: @@ -2459,14 +2402,10 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_pre_norm: bool = True, - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, - cross_attention_norm: Optional[str] = None, output_scale_factor: float = 1.0, add_upsample: bool = True, dual_cross_attention: bool = False, @@ -2474,8 +2413,6 @@ def __init__( only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, ): super().__init__() resnets = [] @@ -2513,17 +2450,11 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, ) ) else: diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index dbd7fafab084..4f55df32b738 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -55,35 +55,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class MatryoshkaCombinedTimestepTextEmbedding(nn.Module): - def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim): - super().__init__() - self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) - self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) - - def forward(self, emb, encoder_hidden_states, added_cond_kwargs): - conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) - masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) - if conditioning_mask is None or not masked_cross_attention: - y = encoder_hidden_states.mean(dim=1) - else: - y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( - dim=1, keepdim=True - ) - if not masked_cross_attention: - conditioning_mask = None - cond_emb = self.cond_emb(y) - - micro = added_cond_kwargs.get("micro_conditioning_scale", None) - if micro is not None: - temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) - temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) - - cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning - return cond_emb, conditioning_mask - - @dataclass class UNet2DConditionOutput(BaseOutput): """ @@ -220,9 +191,6 @@ def __init__( mid_block_scale_factor: float = 1, dropout: float = 0.0, act_fn: str = "silu", - ff_act_fn: str = "geglu", - norm_type: str = "layer_norm", - ff_norm_type: str = None, norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, @@ -251,10 +219,6 @@ def __init__( conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, - masked_cross_attention: bool = False, - micro_conditioning_scale: int = None, class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, @@ -333,7 +297,7 @@ def __init__( self._set_add_embedding( addition_embed_type, addition_embed_type_num_heads=addition_embed_type_num_heads, - addition_time_embed_dim=timestep_input_dim, + addition_time_embed_dim=addition_time_embed_dim, cross_attention_dim=cross_attention_dim, encoder_hid_dim=encoder_hid_dim, flip_sin_to_cos=flip_sin_to_cos, @@ -399,9 +363,6 @@ def __init__( add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], @@ -412,8 +373,6 @@ def __init__( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, @@ -429,14 +388,9 @@ def __init__( in_channels=block_out_channels[-1], resnet_eps=norm_eps, resnet_act_fn=act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, resnet_groups=norm_num_groups, output_scale_factor=mid_block_scale_factor, - transformer_layers_per_block=transformer_layers_per_block[-1] - if norm_type != "layer_norm_matryoshka" - else 1, + transformer_layers_per_block=transformer_layers_per_block[-1], num_attention_heads=num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], dual_cross_attention=dual_cross_attention, @@ -445,8 +399,6 @@ def __init__( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, resnet_skip_time_act=resnet_skip_time_act, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[-1], @@ -494,9 +446,6 @@ def __init__( add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, resolution_idx=i, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], @@ -507,8 +456,6 @@ def __init__( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, @@ -714,10 +661,6 @@ def _set_add_embedding( self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) - elif addition_embed_type == "matryoshka": - self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( - addition_time_embed_dim, cross_attention_dim, time_embed_dim - ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use @@ -1012,8 +955,6 @@ def get_aug_embed( aug_emb = None if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) - elif self.config.addition_embed_type == "matryoshka": - aug_emb = self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: @@ -1208,15 +1149,7 @@ def forward( else: emb = emb + class_emb - added_cond_kwargs = added_cond_kwargs or {} - added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention - added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale - - encoder_hidden_states = self.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) - - aug_emb, cond_mask = self.get_aug_embed( + aug_emb = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) if self.config.addition_embed_type == "image_hint": @@ -1228,6 +1161,10 @@ def forward( if self.time_embed_act is not None: emb = self.time_embed_act(emb) + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + # 2. pre-process sample = self.conv_in(sample) From bcd8939902d9c59c69b4df5ad18b49fa0342becd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 11:32:52 +0300 Subject: [PATCH 018/109] make fix-copies --- .../versatile_diffusion/modeling_text_unet.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index af547d62963a..3937e87f63c9 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -1915,14 +1915,10 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_pre_norm: bool = True, - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, - cross_attention_norm: Optional[str] = None, output_scale_factor: float = 1.0, add_upsample: bool = True, dual_cross_attention: bool = False, @@ -1930,8 +1926,6 @@ def __init__( only_cross_attention: bool = False, upcast_attention: bool = False, attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, ): super().__init__() resnets = [] @@ -1969,17 +1963,11 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, ) ) else: @@ -2258,22 +2246,16 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", num_attention_heads: int = 1, output_scale_factor: float = 1.0, cross_attention_dim: int = 1280, - cross_attention_norm: Optional[str] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, upcast_attention: bool = False, attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, ): super().__init__() @@ -2318,16 +2300,10 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, norm_num_groups=resnet_groups_out, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, ) ) else: From e014e3e077fd018170c4092be2bbdce88fd159d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 11:36:12 +0300 Subject: [PATCH 019/109] All in one file --- examples/community/matryoshka.py | 2108 +++++++++++++++++++++++++++++- 1 file changed, 2097 insertions(+), 11 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index fe84e904ed15..60f93e4e3bfe 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1,34 +1,72 @@ # #TODO Licensed under the Apache License, Version 2.0 or MIT? import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from torch import nn +import torch.nn as nn +import torch.utils.checkpoint from packaging import version +from torch.nn import functional as F from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.configuration_utils import FrozenDict +from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import ImageProjection, UNet2DConditionModel +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + PeftAdapterMixin, + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, + UNet2DConditionLoadersMixin, +) +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from diffusers.models.downsampling import Downsample2D +from diffusers.models.embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.resnet import ResnetBlock2D +from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D +from diffusers.models.upsampling import Upsample2D +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( USE_PEFT_BACKEND, + BaseOutput, deprecate, + is_torch_version, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) -from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0 +from diffusers.utils.torch_utils import apply_freeu, randn_tensor if is_torch_xla_available(): @@ -128,6 +166,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps + # Copied from diffusers.models.attention._chunked_feed_forward def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): # "feed_forward_chunk_size" can be used to save memory @@ -143,6 +182,2052 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: ) return ff_output + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + ff_act_fn: str = "geglu", + norm_type: str = "layer_norm", + ff_norm_type: str = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + ) + + +def get_mid_block( + mid_block_type: str, + temb_channels: int, + in_channels: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + ff_act_fn: str = "geglu", + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", + output_scale_factor: float = 1.0, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + mid_block_only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + resnet_skip_time_act: bool = False, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = 1, + dropout: float = 0.0, +): + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + ) + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + ff_act_fn: str = "geglu", + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + ) + + +class MatryoshkaCombinedTimestepTextEmbedding(nn.Module): + def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim): + super().__init__() + self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) + self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) + + def forward(self, emb, encoder_hidden_states, added_cond_kwargs): + conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) + masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) + if conditioning_mask is None or not masked_cross_attention: + y = encoder_hidden_states.mean(dim=1) + else: + y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( + dim=1, keepdim=True + ) + if not masked_cross_attention: + conditioning_mask = None + cond_emb = self.cond_emb(y) + + micro = added_cond_kwargs.get("micro_conditioning_scale", None) + if micro is not None: + temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) + temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) + + cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning + return cond_emb, conditioning_mask + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.Tensor = None + + +class UNet2DConditionModel( + ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin +): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + ff_act_fn: str = "geglu", + norm_type: str = "layer_norm", + ff_norm_type: str = None, + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + masked_cross_attention: bool = False, + micro_conditioning_scale: int = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + + # class embedding + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) + + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=timestep_input_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1] + if norm_type != "layer_norm_matryoshka" + else 1, + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + ff_act_fn=ff_act_fn, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + def _check_config( + self, + down_block_types: Tuple[str], + up_block_types: Tuple[str], + only_cross_attention: Union[bool, Tuple[bool]], + block_out_channels: Tuple[int], + layers_per_block: Union[int, Tuple[int]], + cross_attention_dim: Union[int, Tuple[int]], + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], + reverse_transformer_layers_per_block: bool, + attention_head_dim: int, + num_attention_heads: Optional[Union[int, Tuple[int]]], + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: Optional[str], + cross_attention_dim: Union[int, Tuple[int]], + encoder_hid_dim: Optional[int], + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." + ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: Optional[str], + act_fn: str, + num_class_embeds: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: Optional[int], + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: Optional[int], + encoder_hid_dim: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "matryoshka": + self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( + addition_time_embed_dim, cross_attention_dim, time_embed_dim + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError( + f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." + ) + + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def get_time_embed( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] + ) -> Optional[torch.Tensor]: + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb + + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb + + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> Optional[torch.Tensor]: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "matryoshka": + aug_emb = self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: + encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) + + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + added_cond_kwargs = added_cond_kwargs or {} + added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention + added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + aug_emb, cond_mask = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated + # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. + if cross_attention_kwargs is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + lora_scale = cross_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + MatryoshkaTransformerBlock( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + additional_residuals: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", + resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, + resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + attentions.append( + MatryoshkaTransformerBlock( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + MatryoshkaTransformerBlock( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + class MatryoshkaTransformerBlock(nn.Module): r""" Matryoshka Transformer block. @@ -269,6 +2354,7 @@ def forward( return hidden_states + class GELU(nn.Module): r""" GELU activation function with tanh approximation support with `approximate="tanh"`. @@ -303,6 +2389,7 @@ def forward(self, hidden_states): hidden_states = self.gelu(hidden_states) return hidden_states + class MatryoshkaFeedForward(nn.Module): r""" A feed-forward layer for the Matryoshka models. @@ -323,7 +2410,6 @@ def forward(self, x): return self.linear_out(self.linear_gelu(self.group_norm(x))) - class MatryoshkaPipeline( DiffusionPipeline, StableDiffusionMixin, From f264b9f3ecafb1af4ccde51706d31c8c6bee7424 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 15:35:33 +0300 Subject: [PATCH 020/109] Up --- examples/community/matryoshka.py | 3896 +++++++++++++++--------------- 1 file changed, 1932 insertions(+), 1964 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 60f93e4e3bfe..3b0103af3734 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -94,6 +94,7 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and @@ -108,6 +109,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, @@ -183,2231 +185,2197 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: return ff_output -def get_down_block( - down_block_type: str, - num_layers: int, - in_channels: int, - out_channels: int, - temb_channels: int, - add_downsample: bool, - resnet_eps: float, - resnet_act_fn: str, - ff_act_fn: str = "geglu", - norm_type: str = "layer_norm", - ff_norm_type: str = None, - transformer_layers_per_block: int = 1, - num_attention_heads: Optional[int] = None, - resnet_groups: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - downsample_padding: Optional[int] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: float = 1.0, - cross_attention_norm: Optional[str] = None, - attention_head_dim: Optional[int] = None, - downsample_type: Optional[str] = None, - dropout: float = 0.0, -): - # If attn head dim is not defined, we default it to the number of heads - if attention_head_dim is None: - logger.warning( - f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." - ) - attention_head_dim = num_attention_heads - - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type - if down_block_type == "DownBlock2D": - return DownBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif down_block_type == "CrossAttnDownBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") - return CrossAttnDownBlock2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - dropout=dropout, - add_downsample=add_downsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - resnet_groups=resnet_groups, - downsample_padding=downsample_padding, - cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - ) - - -def get_mid_block( - mid_block_type: str, - temb_channels: int, - in_channels: int, - resnet_eps: float, - resnet_act_fn: str, - resnet_groups: int, - ff_act_fn: str = "geglu", - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", - output_scale_factor: float = 1.0, - transformer_layers_per_block: int = 1, - num_attention_heads: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - mid_block_only_cross_attention: bool = False, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, - resnet_skip_time_act: bool = False, - cross_attention_norm: Optional[str] = None, - attention_head_dim: Optional[int] = 1, - dropout: float = 0.0, -): - if mid_block_type == "UNetMidBlock2DCrossAttn": - return UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - temb_channels=temb_channels, - dropout=dropout, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - output_scale_factor=output_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - num_attention_heads=num_attention_heads, - resnet_groups=resnet_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - ) - - -def get_up_block( - up_block_type: str, - num_layers: int, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - add_upsample: bool, - resnet_eps: float, - resnet_act_fn: str, - ff_act_fn: str = "geglu", - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", - resolution_idx: Optional[int] = None, - transformer_layers_per_block: int = 1, - num_attention_heads: Optional[int] = None, - resnet_groups: Optional[int] = None, - cross_attention_dim: Optional[int] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: float = 1.0, - cross_attention_norm: Optional[str] = None, - attention_head_dim: Optional[int] = None, - upsample_type: Optional[str] = None, - dropout: float = 0.0, -) -> nn.Module: - # If attn head dim is not defined, we default it to the number of heads - if attention_head_dim is None: - logger.warning( - f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." - ) - attention_head_dim = num_attention_heads - - up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type - if up_block_type == "UpBlock2D": - return UpBlock2D( - num_layers=num_layers, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - resnet_groups=resnet_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - ) - elif up_block_type == "CrossAttnUpBlock2D": - if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") - return CrossAttnUpBlock2D( - num_layers=num_layers, - transformer_layers_per_block=transformer_layers_per_block, - in_channels=in_channels, - out_channels=out_channels, - prev_output_channel=prev_output_channel, - temb_channels=temb_channels, - resolution_idx=resolution_idx, - dropout=dropout, - add_upsample=add_upsample, - resnet_eps=resnet_eps, - resnet_act_fn=resnet_act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - resnet_groups=resnet_groups, - cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - num_attention_heads=num_attention_heads, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - ) - - -class MatryoshkaCombinedTimestepTextEmbedding(nn.Module): - def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim): - super().__init__() - self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) - self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) - - def forward(self, emb, encoder_hidden_states, added_cond_kwargs): - conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) - masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) - if conditioning_mask is None or not masked_cross_attention: - y = encoder_hidden_states.mean(dim=1) - else: - y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( - dim=1, keepdim=True - ) - if not masked_cross_attention: - conditioning_mask = None - cond_emb = self.cond_emb(y) - - micro = added_cond_kwargs.get("micro_conditioning_scale", None) - if micro is not None: - temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) - temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) - - cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning - return cond_emb, conditioning_mask - - -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - The output of [`UNet2DConditionModel`]. - - Args: - sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.Tensor = None - - -class UNet2DConditionModel( - ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin -): - r""" - A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample - shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `True`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or - `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, normalization and activation layers is skipped in post-processing. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling - blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for - [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - encoder_hid_dim (`int`, *optional*, defaults to None): - If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` - dimension to `cross_attention_dim`. - encoder_hid_dim_type (`str`, *optional*, defaults to `None`): - If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text - embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - num_attention_heads (`int`, *optional*): - The number of attention heads. If not defined, defaults to `attention_head_dim` - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to `None`): - Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or - "text". "text" will use the `TextTimeEmbedding` layer. - addition_time_embed_dim: (`int`, *optional*, defaults to `None`): - Dimension for the timestep embeddings. - num_class_embeds (`int`, *optional*, defaults to `None`): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, defaults to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - time_embedding_dim (`int`, *optional*, defaults to `None`): - An optional override for the dimension of the projected time embedding. - time_embedding_act_fn (`str`, *optional*, defaults to `None`): - Optional activation function to use only once on the time embeddings before they are passed to the rest of - the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. - timestep_post_act (`str`, *optional*, defaults to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, defaults to `None`): - The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. - conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. - projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - `class_embed_type="projection"`. Required when `class_embed_type="projection"`. - class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time - embeddings with the class embeddings. - mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): - Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If - `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the - `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` - otherwise. - """ - - _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] - - @register_to_config +class CrossAttnDownBlock2D(nn.Module): def __init__( self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: Union[int, Tuple[int]] = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, + in_channels: int, + out_channels: int, + temb_channels: int, dropout: float = 0.0, - act_fn: str = "silu", - ff_act_fn: str = "geglu", + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, norm_type: str = "layer_norm", - ff_norm_type: str = None, - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, dual_cross_attention: bool = False, use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, + only_cross_attention: bool = False, upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: float = 1.0, - time_embedding_type: str = "positional", - time_embedding_dim: Optional[int] = None, - time_embedding_act_fn: Optional[str] = None, - timestep_post_act: Optional[str] = None, - time_cond_proj_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", attention_pre_only: bool = False, attention_bias: bool = False, - masked_cross_attention: bool = False, - micro_conditioning_scale: int = None, - class_embeddings_concat: bool = False, - mid_block_only_cross_attention: Optional[bool] = None, - cross_attention_norm: Optional[str] = None, - addition_embed_type_num_heads: int = 64, ): super().__init__() + resnets = [] + attentions = [] - self.sample_size = sample_size + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers - if num_attention_heads is not None: - raise ValueError( - "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + MatryoshkaTransformerBlock( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + ) ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. - num_attention_heads = num_attention_heads or attention_head_dim + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None - # Check inputs - self._check_config( - down_block_types=down_block_types, - up_block_types=up_block_types, - only_cross_attention=only_cross_attention, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - cross_attention_dim=cross_attention_dim, - transformer_layers_per_block=transformer_layers_per_block, - reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, - attention_head_dim=attention_head_dim, - num_attention_heads=num_attention_heads, - ) + self.gradient_checkpointing = False - # input - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + additional_residuals: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () - # time - time_embed_dim, timestep_input_dim = self._set_time_proj( - time_embedding_type, - block_out_channels=block_out_channels, - flip_sin_to_cos=flip_sin_to_cos, - freq_shift=freq_shift, - time_embedding_dim=time_embedding_dim, - ) + blocks = list(zip(self.resnets, self.attentions)) - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: - self._set_encoder_hid_proj( - encoder_hid_dim_type, - cross_attention_dim=cross_attention_dim, - encoder_hid_dim=encoder_hid_dim, - ) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) - # class embedding - self._set_class_embedding( - class_embed_type, - act_fn=act_fn, - num_class_embeds=num_class_embeds, - projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, - time_embed_dim=time_embed_dim, - timestep_input_dim=timestep_input_dim, - ) + return custom_forward - self._set_add_embedding( - addition_embed_type, - addition_embed_type_num_heads=addition_embed_type_num_heads, - addition_time_embed_dim=timestep_input_dim, - cross_attention_dim=cross_attention_dim, - encoder_hid_dim=encoder_hid_dim, - flip_sin_to_cos=flip_sin_to_cos, - freq_shift=freq_shift, - projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, - time_embed_dim=time_embed_dim, - ) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - if time_embedding_act_fn is None: - self.time_embed_act = None - else: - self.time_embed_act = get_activation(time_embedding_act_fn) + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) + output_states = output_states + (hidden_states,) - if isinstance(only_cross_attention, bool): - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = only_cross_attention + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) - only_cross_attention = [only_cross_attention] * len(down_block_types) + output_states = output_states + (hidden_states,) - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = False + return hidden_states, output_states - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", + resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, + resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, + ): + super().__init__() - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels - if isinstance(layers_per_block, int): - layers_per_block = [layers_per_block] * len(down_block_types) + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + # support for variable transformer layers per block if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - if class_embeddings_concat: - # The time embeddings are concatenated with the class embeddings. The dimension of the - # time embeddings passed to the down, middle, and up blocks is twice the dimension of the - # regular time embeddings - blocks_time_embed_dim = time_embed_dim * 2 - else: - blocks_time_embed_dim = time_embed_dim + transformer_layers_per_block = [transformer_layers_per_block] * num_layers - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 + resnet_groups_out = resnet_groups_out or resnet_groups - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=blocks_time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, ) - self.down_blocks.append(down_block) - - # mid - self.mid_block = get_mid_block( - mid_block_type, - temb_channels=blocks_time_embed_dim, - in_channels=block_out_channels[-1], - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - resnet_groups=norm_num_groups, - output_scale_factor=mid_block_scale_factor, - transformer_layers_per_block=transformer_layers_per_block[-1] - if norm_type != "layer_norm_matryoshka" - else 1, - num_attention_heads=num_attention_heads[-1], - cross_attention_dim=cross_attention_dim[-1], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - mid_block_only_cross_attention=mid_block_only_cross_attention, - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - resnet_skip_time_act=resnet_skip_time_act, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[-1], - dropout=dropout, - ) - - # count how many layers upsample the images - self.num_upsamplers = 0 + ] + attentions = [] - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_num_attention_heads = list(reversed(num_attention_heads)) - reversed_layers_per_block = list(reversed(layers_per_block)) - reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = ( - list(reversed(transformer_layers_per_block)) - if reverse_transformer_layers_per_block is None - else reverse_transformer_layers_per_block - ) - only_cross_attention = list(reversed(only_cross_attention)) + for i in range(num_layers): + attentions.append( + MatryoshkaTransformerBlock( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + self.gradient_checkpointing = False - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - transformer_layers_per_block=reversed_transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=blocks_time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - ff_act_fn=ff_act_fn, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - resolution_idx=i, - resnet_groups=norm_num_groups, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - dropout=dropout, - ) - self.up_blocks.append(up_block) + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: - # out - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) - self.conv_act = get_activation(act_fn) + return custom_forward - else: - self.conv_norm_out = None - self.conv_act = None + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) + return hidden_states - self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) - def _check_config( +class CrossAttnUpBlock2D(nn.Module): + def __init__( self, - down_block_types: Tuple[str], - up_block_types: Tuple[str], - only_cross_attention: Union[bool, Tuple[bool]], - block_out_channels: Tuple[int], - layers_per_block: Union[int, Tuple[int]], - cross_attention_dim: Union[int, Tuple[int]], - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], - reverse_transformer_layers_per_block: bool, - attention_head_dim: int, - num_attention_heads: Optional[Union[int, Tuple[int]]], + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + ff_act_fn: str = "geglu", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + norm_type: str = "layer_norm", + ff_norm_type: str = "group_norm", + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + cross_attention_norm: Optional[str] = None, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + attention_pre_only: bool = False, + attention_bias: bool = False, ): - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) + super().__init__() + resnets = [] + attentions = [] - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) - if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: - for layer_number_per_block in transformer_layers_per_block: - if isinstance(layer_number_per_block, list): - raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") - - def _set_time_proj( - self, - time_embedding_type: str, - block_out_channels: int, - flip_sin_to_cos: bool, - freq_shift: float, - time_embedding_dim: int, - ) -> Tuple[int, int]: - if time_embedding_type == "fourier": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + attentions.append( + MatryoshkaTransformerBlock( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + ff_norm_type=ff_norm_type, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + attention_bias=attention_bias, + activation_fn=ff_act_fn, + ) ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: - raise ValueError( - f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." - ) + self.upsamplers = None - return time_embed_dim, timestep_input_dim + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx - def _set_encoder_hid_proj( + def forward( self, - encoder_hid_dim_type: Optional[str], - cross_attention_dim: Union[int, Tuple[int]], - encoder_hid_dim: Optional[int], - ): - if encoder_hid_dim_type is None and encoder_hid_dim is not None: - encoder_hid_dim_type = "text_proj" - self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + hidden_states: torch.Tensor, + res_hidden_states_tuple: Tuple[torch.Tensor, ...], + temb: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - if encoder_hid_dim is None and encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." - ) + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) - if encoder_hid_dim_type == "text_proj": - self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) - elif encoder_hid_dim_type == "text_image_proj": - # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` - self.encoder_hid_proj = TextImageProjection( - text_embed_dim=encoder_hid_dim, - image_embed_dim=cross_attention_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - self.encoder_hid_proj = ImageProjection( - image_embed_dim=encoder_hid_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." - ) - else: - self.encoder_hid_proj = None + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] - def _set_class_embedding( - self, - class_embed_type: Optional[str], - act_fn: str, - num_class_embeds: Optional[int], - projection_class_embeddings_input_dim: Optional[int], - time_embed_dim: int, - timestep_input_dim: int, - ): - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif class_embed_type == "simple_projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, ) - self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - def _set_add_embedding( + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class MatryoshkaTransformerBlock(nn.Module): + r""" + Matryoshka Transformer block. + + Parameters: + """ + + def __init__( self, - addition_embed_type: str, - addition_embed_type_num_heads: int, - addition_time_embed_dim: Optional[int], - flip_sin_to_cos: bool, - freq_shift: float, - cross_attention_dim: Optional[int], - encoder_hid_dim: Optional[int], - projection_class_embeddings_input_dim: Optional[int], - time_embed_dim: int, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = True, + attention_type: str = "default", + attention_ff_inner_dim: Optional[int] = None, ): - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim - - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "matryoshka": - self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( - addition_time_embed_dim, cross_attention_dim, time_embed_dim - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif addition_embed_type == "image": - # Kandinsky 2.2 - self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type == "image_hint": - # Kandinsky 2.2 ControlNet - self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type is not None: - raise ValueError( - f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." - ) + super().__init__() + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.cross_attention_dim = cross_attention_dim - def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): - if attention_type in ["gated", "gated-text-image"]: - positive_len = 768 - if isinstance(cross_attention_dim, int): - positive_len = cross_attention_dim - elif isinstance(cross_attention_dim, (list, tuple)): - positive_len = cross_attention_dim[0] + # Define 3 blocks. + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + norm_num_groups=32 or None, + bias=True, + upcast_attention=upcast_attention, + pre_only=True, + processor=FusedAttnProcessor2_0(), + ) + self.attn1.fuse_projections() - feature_type = "text-only" if attention_type == "gated" else "text-image" - self.position_net = GLIGENTextBoundingboxProjection( - positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + # 2. Cross-Attn + if cross_attention_dim is not None and cross_attention_dim > 0: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + cross_attention_norm="layer_norm", + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=True, + upcast_attention=upcast_attention, + pre_only=True, + processor=FusedAttnProcessor2_0(), ) + self.attn2.fuse_projections() + # self.attn2.to_q = None - @property - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor() + self.proj_out = nn.Linear(dim, dim) - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + if attention_ff_inner_dim is not None: + # 3. Feed-forward + self.ff = MatryoshkaFeedForward( + dim, + inner_dim=attention_ff_inner_dim, + ) + else: + self.ff = None - return processors + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) + # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim - return processors + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.Tensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. + # 1. Self-Attention + batch_size, channels, *spatial_dims = hidden_states.shape - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. + attn_output, query = self.attn1( + hidden_states, + **cross_attention_kwargs, + ) + cross_attention_kwargs["self_attention_output"] = attn_output + cross_attention_kwargs["self_attention_query"] = query - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. + # 2. Cross-Attention + if self.cross_attention_dim is not None and self.cross_attention_dim > 0: + attn_output_cond = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = attn_output + attn_output_cond - """ - count = len(self.attn_processors.keys()) + attn_output = attn_output.reshape(batch_size, channels, *spatial_dims) + attn_output = self.proj_out(attn_output) + hidden_states = hidden_states + attn_output - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) + if self.ff is not None: + # 4. Feed-forward + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(hidden_states) - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) + hidden_states = ff_output + hidden_states - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + return hidden_states - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) +class GELU(nn.Module): + r""" + GELU activation function with tanh approximation support with `approximate="tanh"`. - self.set_attn_processor(processor) + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ - def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): - r""" - Enable sliced attention computation. + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] + def forward(self, hidden_states): + if hidden_states.ndim == 4: + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(-1, channels) + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2) + else: + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) +class MatryoshkaFeedForward(nn.Module): + r""" + A feed-forward layer for the Matryoshka models. - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) + Parameters:""" - num_sliceable_layers = len(sliceable_head_dims) + def __init__( + self, + dim: int, + ): + super().__init__() - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] + self.group_norm = nn.GroupNorm(32, dim) + self.linear_gelu = GELU(dim, dim * 4) + self.linear_out = nn.Linear(dim * 4, dim) - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + def forward(self, x): + return self.linear_out(self.linear_gelu(self.group_norm(x))) - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." - ) - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + norm_type: str = "layer_norm", + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + attention_pre_only: bool = False, + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + norm_type=norm_type, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + ) - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) +def get_mid_block( + mid_block_type: str, + temb_channels: int, + in_channels: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + norm_type: str = "layer_norm", + output_scale_factor: float = 1.0, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + mid_block_only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + attention_pre_only: bool = False, + resnet_skip_time_act: bool = False, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = 1, + dropout: float = 0.0, +): + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + norm_type=norm_type, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + ) - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + norm_type: str = "layer_norm", + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + attention_pre_only: bool = False, + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads - The suffixes after the scaling factors represent the stage blocks where they are being applied. + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + norm_type=norm_type, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + cross_attention_norm=cross_attention_norm, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + ) - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that - are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate the "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate the "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - for i, upsample_block in enumerate(self.up_blocks): - setattr(upsample_block, "s1", s1) - setattr(upsample_block, "s2", s2) - setattr(upsample_block, "b1", b1) - setattr(upsample_block, "b2", b2) +class MatryoshkaCombinedTimestepTextEmbedding(nn.Module): + def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim): + super().__init__() + self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) + self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) - def disable_freeu(self): - """Disables the FreeU mechanism.""" - freeu_keys = {"s1", "s2", "b1", "b2"} - for i, upsample_block in enumerate(self.up_blocks): - for k in freeu_keys: - if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: - setattr(upsample_block, k, None) + def forward(self, emb, encoder_hidden_states, added_cond_kwargs): + conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) + masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) + if conditioning_mask is None or not masked_cross_attention: + y = encoder_hidden_states.mean(dim=1) + else: + y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( + dim=1, keepdim=True + ) + if not masked_cross_attention: + conditioning_mask = None + cond_emb = self.cond_emb(y) - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. + micro = added_cond_kwargs.get("micro_conditioning_scale", None) + if micro is not None: + temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) + temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) - + cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning + return cond_emb, conditioning_mask - This API is 🧪 experimental. - - """ - self.original_attn_processors = None +@dataclass +class NestedUNet2DConditionOutput(BaseOutput): + """ + The output of [`NestedUNet2DConditionOutput`]. - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ - self.original_attn_processors = self.attn_processors + sample: torch.Tensor = None - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - self.set_attn_processor(FusedAttnProcessor2_0()) +class NestedUNet2DConditionModel( + ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin +): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). - + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unets.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unets.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ - This API is 🧪 experimental. + _supports_gradient_checkpointing = True + _no_split_modules = ["MatryoshkaTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] - + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_type: str = "layer_norm", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + attention_pre_only: bool = False, + masked_cross_attention: bool = False, + micro_conditioning_scale: int = None, + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) + self.sample_size = sample_size - def get_time_embed( - self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] - ) -> Optional[torch.Tensor]: - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim - t_emb = self.time_proj(timesteps) - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - return t_emb + # Check inputs + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) - def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: - class_emb = None - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) - # `Timesteps` does not contain any weights and will always return f32 tensors - # there might be better ways to encapsulate this. - class_labels = class_labels.to(dtype=sample.dtype) + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) - class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) - return class_emb + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) - def get_aug_embed( - self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] - ) -> Optional[torch.Tensor]: - aug_emb = None - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - elif self.config.addition_embed_type == "matryoshka": - aug_emb = self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs) - elif self.config.addition_embed_type == "text_image": - # Kandinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) + # class embedding + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) - image_embs = added_cond_kwargs.get("image_embeds") - text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) - aug_emb = self.add_embedding(text_embs, image_embs) - elif self.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - elif self.config.addition_embed_type == "image": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - aug_emb = self.add_embedding(image_embs) - elif self.config.addition_embed_type == "image_hint": - # Kandinsky 2.2 ControlNet - style - if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - hint = added_cond_kwargs.get("hint") - aug_emb = self.add_embedding(image_embs, hint) - return aug_emb + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=timestep_input_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) - def process_encoder_hidden_states( - self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] - ) -> torch.Tensor: - if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": - # Kandinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) - if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: - encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention - image_embeds = added_cond_kwargs.get("image_embeds") - image_embeds = self.encoder_hid_proj(image_embeds) - encoder_hidden_states = (encoder_hidden_states, image_embeds) - return encoder_hidden_states + only_cross_attention = [only_cross_attention] * len(down_block_types) - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - The [`UNet2DConditionModel`] forward method. + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False - Args: - sample (`torch.Tensor`): - The noisy input tensor with the following shape `(batch, channel, height, width)`. - timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. - encoder_hidden_states (`torch.Tensor`): - The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): - Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed - through the `self.time_embedding` layer to obtain the timestep embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. - down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): - A tuple of tensors that if specified are added to the residuals of down unet blocks. - mid_block_additional_residual: (`torch.Tensor`, *optional*): - A tensor that if specified is added to the residual of the middle unet block. - down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) - encoder_attention_mask (`torch.Tensor`): - A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If - `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, - which adds large negative values to the attention scores corresponding to "discard" tokens. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) - Returns: - [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, - otherwise a `tuple` is returned where the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - for dim in sample.shape[-2:]: - if dim % default_overall_up_factor != 0: - # Forward upsample size to force interpolation output size. - forward_upsample_size = True - break + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 - # 1. time - t_emb = self.get_time_embed(sample=sample, timestep=timestep) - emb = self.time_embedding(t_emb, timestep_cond) + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + norm_type=norm_type, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) - class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) - if class_emb is not None: - if self.config.class_embeddings_concat: - emb = torch.cat([emb, class_emb], dim=-1) - else: - emb = emb + class_emb + # mid + self.mid_block = get_mid_block( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + norm_type=norm_type, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1] + if norm_type != "layer_norm_matryoshka" + else 1, + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) - added_cond_kwargs = added_cond_kwargs or {} - added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention - added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + # count how many layers upsample the images + self.num_upsamplers = 0 - encoder_hidden_states = self.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block ) + only_cross_attention = list(reversed(only_cross_attention)) - aug_emb, cond_mask = self.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) - if self.config.addition_embed_type == "image_hint": - aug_emb, hint = aug_emb - sample = torch.cat([sample, hint], dim=1) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 - emb = emb + aug_emb if aug_emb is not None else emb + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - if self.time_embed_act is not None: - emb = self.time_embed_act(emb) + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False - # 2. pre-process - sample = self.conv_in(sample) + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + norm_type=norm_type, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + attention_pre_only=attention_pre_only, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) - # 2.5 GLIGEN position net - if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: - cross_attention_kwargs = cross_attention_kwargs.copy() - gligen_args = cross_attention_kwargs.pop("gligen") - cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) - # 3. down - # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated - # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. - if cross_attention_kwargs is not None: - cross_attention_kwargs = cross_attention_kwargs.copy() - lora_scale = cross_attention_kwargs.pop("scale", 1.0) else: - lora_scale = 1.0 + self.conv_norm_out = None + self.conv_act = None - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) - is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets - is_adapter = down_intrablock_additional_residuals is not None - # maintain backward compatibility for legacy usage, where - # T2I-Adapter and ControlNet both use down_block_additional_residuals arg - # but can only use one or the other - if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: - deprecate( - "T2I should not use down_block_additional_residuals", - "1.3.0", - "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ - and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ - for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", - standard_warn=False, - ) - down_intrablock_additional_residuals = down_block_additional_residuals - is_adapter = True + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - # For t2i-adapter CrossAttnDownBlock2D - additional_residuals = {} - if is_adapter and len(down_intrablock_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + def _check_config( + self, + down_block_types: Tuple[str], + up_block_types: Tuple[str], + only_cross_attention: Union[bool, Tuple[bool]], + block_out_channels: Tuple[int], + layers_per_block: Union[int, Tuple[int]], + cross_attention_dim: Union[int, Tuple[int]], + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], + reverse_transformer_layers_per_block: bool, + attention_head_dim: int, + num_attention_heads: Optional[Union[int, Tuple[int]]], + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - **additional_residuals, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - if is_adapter and len(down_intrablock_additional_residuals) > 0: - sample += down_intrablock_additional_residuals.pop(0) + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) - down_block_res_samples += res_samples + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) - if is_controlnet: - new_down_block_res_samples = () + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) - down_block_res_samples = new_down_block_res_samples + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) - # 4. mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - else: - sample = self.mid_block(sample, emb) + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") - # To support T2I-Adapter-XL - if ( - is_adapter - and len(down_intrablock_additional_residuals) > 0 - and sample.shape == down_intrablock_additional_residuals[0].shape - ): - sample += down_intrablock_additional_residuals.pop(0) + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - if is_controlnet: - sample = sample + mid_block_additional_residual + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 + return time_embed_dim, timestep_input_dim - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: Optional[str], + cross_attention_dim: Union[int, Tuple[int]], + encoder_hid_dim: Optional[int], + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." + ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: Optional[str], + act_fn: str, + num_class_embeds: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size, + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None - # 6. post-process - if self.conv_norm_out: - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: Optional[int], + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: Optional[int], + encoder_hid_dim: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "matryoshka": + self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( + addition_time_embed_dim, cross_attention_dim, time_embed_dim + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError( + f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." + ) - if not return_dict: - return (sample,) + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, (list, tuple)): + positive_len = cross_attention_dim[0] - return UNet2DConditionOutput(sample=sample) + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} -class CrossAttnDownBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - cross_attention_norm: Optional[str] = None, - output_scale_factor: float = 1.0, - downsample_padding: int = 1, - add_downsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, - ): - super().__init__() - resnets = [] - attentions = [] + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * num_layers + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) - attentions.append( - MatryoshkaTransformerBlock( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, - ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample2D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() else: - self.downsamplers = None + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) - self.gradient_checkpointing = False + self.set_attn_processor(processor) - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - additional_residuals: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): + r""" + Enable sliced attention computation. - output_states = () + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. - blocks = list(zip(self.resnets, self.attentions)) + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] - for i, (resnet, attn) in enumerate(blocks): - if self.training and self.gradient_checkpointing: + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) - return custom_forward + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + num_sliceable_layers = len(sliceable_head_dims) - # apply additional residuals to the output of the last pair of resnet and attention blocks - if i == len(blocks) - 1 and additional_residuals is not None: - hidden_states = hidden_states + additional_residuals + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] - output_states = output_states + (hidden_states,) + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) - output_states = output_states + (hidden_states,) + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - return hidden_states, output_states + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) -class UNetMidBlock2DCrossAttn(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - out_channels: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", - resnet_groups: int = 32, - resnet_groups_out: Optional[int] = None, - resnet_pre_norm: bool = True, - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", - num_attention_heads: int = 1, - output_scale_factor: float = 1.0, - cross_attention_dim: int = 1280, - cross_attention_norm: Optional[str] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, - ): - super().__init__() + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) - out_channels = out_channels or in_channels - self.in_channels = in_channels - self.out_channels = out_channels + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. - # support for variable transformer layers per block - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * num_layers + The suffixes after the scaling factors represent the stage blocks where they are being applied. - resnet_groups_out = resnet_groups_out or resnet_groups + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - # there is always at least one resnet - resnets = [ - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - groups_out=resnet_groups_out, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ] - attentions = [] + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. - for i in range(num_layers): - attentions.append( - MatryoshkaTransformerBlock( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - norm_num_groups=resnet_groups_out, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, - ) - ) - resnets.append( - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups_out, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) - ) + - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) + This API is 🧪 experimental. - self.gradient_checkpointing = False + + """ + self.original_attn_processors = None - def forward( - self, - hidden_states: torch.Tensor, - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if self.training and self.gradient_checkpointing: + self.original_attn_processors = self.attn_processors - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) - return custom_forward + self.set_attn_processor(FusedAttnProcessor2_0()) - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def get_time_embed( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] + ) -> Optional[torch.Tensor]: + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 else: - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - hidden_states = resnet(hidden_states, temb) + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) - return hidden_states + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb -class CrossAttnUpBlock2D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - resolution_idx: Optional[int] = None, - dropout: float = 0.0, - num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - cross_attention_norm: Optional[str] = None, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", - attention_pre_only: bool = False, - attention_bias: bool = False, - ): - super().__init__() - resnets = [] - attentions = [] + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * num_layers + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> Optional[torch.Tensor]: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "matryoshka": + aug_emb = self.add_embedding(emb, encoder_hidden_states, added_cond_kwargs) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) - ) - attentions.append( - MatryoshkaTransformerBlock( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) - ) - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None + if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: + encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) - self.gradient_checkpointing = False - self.resolution_idx = resolution_idx + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states def forward( self, - hidden_states: torch.Tensor, - res_hidden_states_tuple: Tuple[torch.Tensor, ...], - temb: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - upsample_size: Optional[int] = None, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) - - for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - hidden_states, res_hidden_states = apply_freeu( - self.resolution_idx, - hidden_states, - res_hidden_states, - s1=self.s1, - s2=self.s2, - b1=self.b1, - b2=self.b2, - ) - - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) + return_dict: bool = True, + ) -> Union[NestedUNet2DConditionOutput, Tuple]: + r""" + The [`NestedUNet2DConditionModel`] forward method. - return custom_forward + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain + tuple. - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(resnet), - hidden_states, - temb, - **ckpt_kwargs, - ) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + Returns: + [`~NestedUNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None - return hidden_states + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) -class MatryoshkaTransformerBlock(nn.Module): - r""" - Matryoshka Transformer block. + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - Parameters: - """ + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - cross_attention_dim: Optional[int] = None, - upcast_attention: bool = True, - attention_type: str = "default", - attention_ff_inner_dim: Optional[int] = None, - ): - super().__init__() - self.dim = dim - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - self.cross_attention_dim = cross_attention_dim + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) - # Define 3 blocks. - # 1. Self-Attn - self.attn1 = Attention( - query_dim=dim, - cross_attention_dim=None, - heads=num_attention_heads, - dim_head=attention_head_dim, - norm_num_groups=32 or None, - bias=True, - upcast_attention=upcast_attention, - pre_only=True, - processor=FusedAttnProcessor2_0(), + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + added_cond_kwargs = added_cond_kwargs or {} + added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention + added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - self.attn1.fuse_projections() - # 2. Cross-Attn - if cross_attention_dim is not None and cross_attention_dim > 0: - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - cross_attention_norm="layer_norm", - heads=num_attention_heads, - dim_head=attention_head_dim, - bias=True, - upcast_attention=upcast_attention, - pre_only=True, - processor=FusedAttnProcessor2_0(), - ) - self.attn2.fuse_projections() - # self.attn2.to_q = None + aug_emb, cond_mask = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) - self.proj_out = nn.Linear(dim, dim) + emb = emb + aug_emb if aug_emb is not None else emb - if attention_ff_inner_dim is not None: - # 3. Feed-forward - self.ff = MatryoshkaFeedForward( - dim, - inner_dim=attention_ff_inner_dim, - ) - else: - self.ff = None + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 + # 2. pre-process + sample = self.conv_in(sample) - # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - ) -> torch.Tensor: + # 3. down + # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated + # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. if cross_attention_kwargs is not None: - if cross_attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") - - # 1. Self-Attention - batch_size, channels, *spatial_dims = hidden_states.shape + cross_attention_kwargs = cross_attention_kwargs.copy() + lora_scale = cross_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 - attn_output, query = self.attn1( - hidden_states, - **cross_attention_kwargs, - ) - cross_attention_kwargs["self_attention_output"] = attn_output - cross_attention_kwargs["self_attention_query"] = query + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) - # 2. Cross-Attention - if self.cross_attention_dim is not None and self.cross_attention_dim > 0: - attn_output_cond = self.attn2( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, ) - attn_output = attn_output + attn_output_cond + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True - attn_output = attn_output.reshape(batch_size, channels, *spatial_dims) - attn_output = self.proj_out(attn_output) - hidden_states = hidden_states + attn_output + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) - if self.ff is not None: - # 4. Feed-forward - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size) + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) else: - ff_output = self.ff(hidden_states) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) - hidden_states = ff_output + hidden_states + down_block_res_samples += res_samples - return hidden_states + if is_controlnet: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) -class GELU(nn.Module): - r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. + down_block_res_samples = new_down_block_res_samples - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. - """ + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out, bias=bias) - self.approximate = approximate + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) - def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + if is_controlnet: + sample = sample + mid_block_additional_residual - def forward(self, hidden_states): - if hidden_states.ndim == 4: - batch_size, channels, height, width = hidden_states.shape - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(-1, channels) - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2) - else: - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] -class MatryoshkaFeedForward(nn.Module): - r""" - A feed-forward layer for the Matryoshka models. + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] - Parameters:""" + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) - def __init__( - self, - dim: int, - ): - super().__init__() + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) - self.group_norm = nn.GroupNorm(32, dim) - self.linear_gelu = GELU(dim, dim * 4) - self.linear_out = nn.Linear(dim * 4, dim) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) - def forward(self, x): - return self.linear_out(self.linear_gelu(self.group_norm(x))) + if not return_dict: + return (sample,) + + return NestedUNet2DConditionOutput(sample=sample) class MatryoshkaPipeline( @@ -2438,8 +2406,8 @@ class MatryoshkaPipeline( Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. + unet ([`NestedUNet2DConditionModel`]): + A `NestedUNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -2460,7 +2428,7 @@ def __init__( self, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, - unet: UNet2DConditionModel, + unet: NestedUNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, From 1a40f6881a17f03b51972f294422cbc776d823de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 17:41:14 +0300 Subject: [PATCH 021/109] Replace `MatryoshkaTransformerBlock` with `MatryoshkaTransformer2DModel` --- examples/community/matryoshka.py | 105 ++++++++++++++++--------------- 1 file changed, 53 insertions(+), 52 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 3b0103af3734..c3ba04a23233 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -12,7 +12,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config +from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config, LegacyConfigMixin from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import ( FromSingleFileMixin, @@ -47,7 +47,7 @@ Timesteps, ) from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.modeling_utils import ModelMixin, LegacyModelMixin from diffusers.models.resnet import ResnetBlock2D from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D from diffusers.models.upsampling import Upsample2D @@ -213,6 +213,7 @@ def __init__( attention_type: str = "default", attention_pre_only: bool = False, attention_bias: bool = False, + use_attention_ffn: bool = True, ): super().__init__() resnets = [] @@ -240,21 +241,14 @@ def __init__( ) ) attentions.append( - MatryoshkaTransformerBlock( + MatryoshkaTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block[i], + num_layers=transformer_layers_per_block[i], # ???? cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, - norm_type=norm_type, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, + use_attention_ffn=use_attention_ffn, ) ) self.attentions = nn.ModuleList(attentions) @@ -356,12 +350,10 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_groups_out: Optional[int] = None, resnet_pre_norm: bool = True, norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", num_attention_heads: int = 1, output_scale_factor: float = 1.0, cross_attention_dim: int = 1280, @@ -372,6 +364,7 @@ def __init__( attention_type: str = "default", attention_pre_only: bool = False, attention_bias: bool = False, + use_attention_ffn: bool = True, ): super().__init__() @@ -409,22 +402,14 @@ def __init__( for i in range(num_layers): attentions.append( - MatryoshkaTransformerBlock( + MatryoshkaTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block[i], + num_layers=transformer_layers_per_block[i], # ???? cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - norm_num_groups=resnet_groups_out, - use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, + use_attention_ffn=use_attention_ffn, ) ) resnets.append( @@ -516,11 +501,9 @@ def __init__( resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", - ff_act_fn: str = "geglu", resnet_groups: int = 32, resnet_pre_norm: bool = True, norm_type: str = "layer_norm", - ff_norm_type: str = "group_norm", num_attention_heads: int = 1, cross_attention_dim: int = 1280, cross_attention_norm: Optional[str] = None, @@ -533,6 +516,7 @@ def __init__( attention_type: str = "default", attention_pre_only: bool = False, attention_bias: bool = False, + use_attention_ffn: bool = True, ): super().__init__() resnets = [] @@ -563,23 +547,14 @@ def __init__( ) ) attentions.append( - MatryoshkaTransformerBlock( + MatryoshkaTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block[i], + num_layers=transformer_layers_per_block[i], # ???? cross_attention_dim=cross_attention_dim, - cross_attention_norm=cross_attention_norm, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, - norm_type=norm_type, - ff_norm_type=ff_norm_type, - attention_type=attention_type, - attention_pre_only=attention_pre_only, - attention_bias=attention_bias, - activation_fn=ff_act_fn, + use_attention_ffn=use_attention_ffn, ) ) self.attentions = nn.ModuleList(attentions) @@ -678,6 +653,40 @@ def custom_forward(*inputs): return hidden_states +class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin): + + _supports_gradient_checkpointing = True + _no_split_modules = ["MatryoshkaTransformerBlock"] + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, + use_attention_ffn: bool = True, + ): + super().__init__() + self.in_channels = self.config.num_attention_heads * self.config.attention_head_dim + + self.transformer_blocks = nn.ModuleList( + [ + MatryoshkaTransformerBlock( + self.in_channels, + self.config.num_attention_heads, + self.config.attention_head_dim, + cross_attention_dim=self.config.cross_attention_dim, + upcast_attention=self.config.upcast_attention, + use_attention_ffn=self.config.use_attention_ffn, + ) + for _ in range(self.config.num_layers) + ] + ) + + class MatryoshkaTransformerBlock(nn.Module): r""" Matryoshka Transformer block. @@ -691,9 +700,8 @@ def __init__( num_attention_heads: int, attention_head_dim: int, cross_attention_dim: Optional[int] = None, - upcast_attention: bool = True, - attention_type: str = "default", - attention_ff_inner_dim: Optional[int] = None, + upcast_attention: bool = False, + use_attention_ffn: bool = True, ): super().__init__() self.dim = dim @@ -734,12 +742,9 @@ def __init__( self.proj_out = nn.Linear(dim, dim) - if attention_ff_inner_dim is not None: + if use_attention_ffn: # 3. Feed-forward - self.ff = MatryoshkaFeedForward( - dim, - inner_dim=attention_ff_inner_dim, - ) + self.ff = MatryoshkaFeedForward(dim) else: self.ff = None @@ -939,7 +944,6 @@ def get_down_block( attention_pre_only=attention_pre_only, ) - def get_mid_block( mid_block_type: str, temb_channels: int, @@ -986,7 +990,6 @@ def get_mid_block( attention_pre_only=attention_pre_only, ) - def get_up_block( up_block_type: str, num_layers: int, @@ -1441,9 +1444,7 @@ def __init__( norm_type=norm_type, resnet_groups=norm_num_groups, output_scale_factor=mid_block_scale_factor, - transformer_layers_per_block=transformer_layers_per_block[-1] - if norm_type != "layer_norm_matryoshka" - else 1, + transformer_layers_per_block=1, num_attention_heads=num_attention_heads[-1], cross_attention_dim=cross_attention_dim[-1], dual_cross_attention=dual_cross_attention, From 221c9541cb0fc0d7342d4bc43b6ee1e31e69f879 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 17:42:19 +0300 Subject: [PATCH 022/109] make style --- examples/community/matryoshka.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index c3ba04a23233..715f92e2365c 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -12,7 +12,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config, LegacyConfigMixin +from diffusers.configuration_utils import ConfigMixin, FrozenDict, LegacyConfigMixin, register_to_config from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import ( FromSingleFileMixin, @@ -47,7 +47,7 @@ Timesteps, ) from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.modeling_utils import ModelMixin, LegacyModelMixin +from diffusers.models.modeling_utils import LegacyModelMixin, ModelMixin from diffusers.models.resnet import ResnetBlock2D from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D from diffusers.models.upsampling import Upsample2D @@ -654,7 +654,6 @@ def custom_forward(*inputs): class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin): - _supports_gradient_checkpointing = True _no_split_modules = ["MatryoshkaTransformerBlock"] @@ -944,6 +943,7 @@ def get_down_block( attention_pre_only=attention_pre_only, ) + def get_mid_block( mid_block_type: str, temb_channels: int, @@ -990,6 +990,7 @@ def get_mid_block( attention_pre_only=attention_pre_only, ) + def get_up_block( up_block_type: str, num_layers: int, From c75e7237f93411aea97ded28e6a30cfda80e8baa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 18:02:10 +0300 Subject: [PATCH 023/109] Refactor `MatryoshkaTransformer2DModel` to add `forward()`and add `MatryoshkaTransformer2DModelOutput` --- examples/community/matryoshka.py | 145 +++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 715f92e2365c..1860a5b41e92 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -653,6 +653,19 @@ def custom_forward(*inputs): return hidden_states +@dataclass +class MatryoshkaTransformer2DModelOutput(BaseOutput): + """ + The output of [`MatryoshkaTransformer2DModel`]. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`MatryoshkaTransformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: "torch.Tensor" # noqa: F821 + class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["MatryoshkaTransformerBlock"] @@ -670,6 +683,7 @@ def __init__( ): super().__init__() self.in_channels = self.config.num_attention_heads * self.config.attention_head_dim + self.gradient_checkpointing = False self.transformer_blocks = nn.ModuleList( [ @@ -685,6 +699,137 @@ def __init__( ] ) + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`MatryoshkaTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~MatryoshkaTransformer2DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. + """ + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size, _, height, width = hidden_states.shape + residual = hidden_states + # TODO: Do we need reshape here? + hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + # TODO: Do we need reshape here? + output = hidden_states + residual + + if not return_dict: + return (output,) + + return MatryoshkaTransformer2DModelOutput(sample=output) + class MatryoshkaTransformerBlock(nn.Module): r""" From e5db6e310b2a52e9cee3fa33c0ebdd21b629403b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 18:02:35 +0300 Subject: [PATCH 024/109] make style --- examples/community/matryoshka.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 1860a5b41e92..2010f6d673fb 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -666,6 +666,7 @@ class MatryoshkaTransformer2DModelOutput(BaseOutput): sample: "torch.Tensor" # noqa: F821 + class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin): _supports_gradient_checkpointing = True _no_split_modules = ["MatryoshkaTransformerBlock"] From 728fb42138e986e2114914af14ef64ffe1a546ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 17 Sep 2024 23:11:06 +0300 Subject: [PATCH 025/109] Up --- examples/community/matryoshka.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 2010f6d673fb..0966af89c7cc 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -245,7 +245,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block[i], # ???? + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_attention_ffn=use_attention_ffn, @@ -406,7 +406,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block[i], # ???? + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_attention_ffn=use_attention_ffn, @@ -551,7 +551,7 @@ def __init__( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, - num_layers=transformer_layers_per_block[i], # ???? + num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, upcast_attention=upcast_attention, use_attention_ffn=use_attention_ffn, From 464600dc529e95fbedf6d233ebcd75e4ceb7c1f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 18 Sep 2024 13:27:30 +0300 Subject: [PATCH 026/109] Remove redundant attention projections in `MatryoshkaTransformerBlock` --- examples/community/matryoshka.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0966af89c7cc..8235ead71380 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -868,6 +868,9 @@ def __init__( processor=FusedAttnProcessor2_0(), ) self.attn1.fuse_projections() + del self.attn1.to_q + del self.attn1.to_k + del self.attn1.to_v # 2. Cross-Attn if cross_attention_dim is not None and cross_attention_dim > 0: @@ -883,7 +886,9 @@ def __init__( processor=FusedAttnProcessor2_0(), ) self.attn2.fuse_projections() - # self.attn2.to_q = None + del self.attn2.to_q + del self.attn2.to_k + del self.attn2.to_v self.proj_out = nn.Linear(dim, dim) From 36d9d295fc08990b48bc98c2f66d21d274f57e8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 18 Sep 2024 19:08:01 +0300 Subject: [PATCH 027/109] Up --- examples/community/matryoshka.py | 193 ++++++++++++++++++++++++------- 1 file changed, 151 insertions(+), 42 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 8235ead71380..309d85a2f1eb 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -8,7 +8,6 @@ import torch.nn as nn import torch.utils.checkpoint from packaging import version -from torch.nn import functional as F from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback @@ -23,7 +22,7 @@ UNet2DConditionLoadersMixin, ) from diffusers.loaders.single_file_model import FromOriginalModelMixin -from diffusers.models.activations import get_activation +from diffusers.models.activations import GELU, get_activation from diffusers.models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -784,7 +783,7 @@ def forward( batch_size, _, height, width = hidden_states.shape residual = hidden_states # TODO: Do we need reshape here? - hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) + # hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) # 2. Blocks for block in self.transformer_blocks: @@ -865,7 +864,7 @@ def __init__( bias=True, upcast_attention=upcast_attention, pre_only=True, - processor=FusedAttnProcessor2_0(), + processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(), ) self.attn1.fuse_projections() del self.attn1.to_q @@ -883,7 +882,7 @@ def __init__( bias=True, upcast_attention=upcast_attention, pre_only=True, - processor=FusedAttnProcessor2_0(), + processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(), ) self.attn2.fuse_projections() del self.attn2.to_q @@ -928,10 +927,8 @@ def forward( attn_output, query = self.attn1( hidden_states, - **cross_attention_kwargs, + # **cross_attention_kwargs, ) - cross_attention_kwargs["self_attention_output"] = attn_output - cross_attention_kwargs["self_attention_query"] = query # 2. Cross-Attention if self.cross_attention_dim is not None and self.cross_attention_dim > 0: @@ -939,13 +936,15 @@ def forward( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - **cross_attention_kwargs, + self_attention_output=attn_output, + self_attention_query=query, + # **cross_attention_kwargs, ) - attn_output = attn_output + attn_output_cond - attn_output = attn_output.reshape(batch_size, channels, *spatial_dims) - attn_output = self.proj_out(attn_output) - hidden_states = hidden_states + attn_output + attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous() + attn_output_cond = self.proj_out(attn_output_cond) + attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims) + hidden_states = hidden_states + attn_output_cond if self.ff is not None: # 4. Feed-forward @@ -960,39 +959,145 @@ def forward( return hidden_states -class GELU(nn.Module): +class MatryoshkaFusedAttnProcessor1_0_or_2_0: r""" - GELU activation function with tanh approximation support with `approximate="tanh"`. + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses + fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. + For cross-attention modules, key and value projection matrices are fused. - Parameters: - dim_in (`int`): The number of channels in the input. - dim_out (`int`): The number of channels in the output. - approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. - bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + + + This API is currently 🧪 experimental in nature and can change in future. + + """ - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out, bias=bias) - self.approximate = approximate - - def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) - - def forward(self, hidden_states): - if hidden_states.ndim == 4: - batch_size, channels, height, width = hidden_states.shape - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(-1, channels) - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2) + # def __init__(self): + # if not hasattr(F, "scaled_dot_product_attention"): + # raise ImportError( + # "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x." + # ) + + # TODO: They seem to give different results; but nevertheless can I replace this with torch.nn.functional.scaled_dot_product_attention()? + def attention(self, q, k, v, num_heads, mask=None): + bs, width, length = q.shape + ch = width // num_heads + scale = 1 / torch.sqrt(torch.sqrt(torch.tensor(ch))) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).reshape(bs * num_heads, ch, length), + (k * scale).reshape(bs * num_heads, ch, -1), + ) # More stable with f16 than dividing afterwards + if mask is not None: + mask = mask.view(mask.size(0), 1, 1, mask.size(1)).repeat(1, num_heads, 1, 1).flatten(0, 1) + weight = weight.masked_fill(mask == 0, float("-inf")) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1)) + return a.reshape(bs, -1, length) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + self_attention_query: Optional[torch.Tensor] = None, + self_attention_output: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + # batch_size, sequence_length, _ = ( + # hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + # ) + + # if attention_mask is not None: + # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # # scaled_dot_product_attention expects attention_mask shape to be + # # (batch, heads, source_length, target_length) + # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) # .transpose(1, 2)).transpose(1, 2) + + # Reshape hidden_states to 2D tensor + hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1).contiguous() + # Now hidden_states.shape is [batch_size, height * width, channels] + + if encoder_hidden_states is None: + qkv = attn.to_qkv(hidden_states) + split_size = qkv.shape[-1] // 3 + query, key, value = torch.split(qkv, split_size, dim=-1) else: - hidden_states = self.proj(hidden_states) - hidden_states = self.gelu(hidden_states) - return hidden_states + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + if self_attention_query is not None: + query = self_attention_query + else: + query = attn.to_q(hidden_states) + + kv = attn.to_kv(encoder_hidden_states) + split_size = kv.shape[-1] // 2 + key, value = torch.split(kv, split_size, dim=-1) + + # inner_dim = key.shape[-1] + # head_dim = inner_dim // attn.heads + + # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if self_attention_output is None: + query = query.permute(0, 2, 1) + key = key.permute(0, 2, 1) + value = value.permute(0, 2, 1) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available + hidden_states = self.attention( + query, + key, + value, + mask=attention_mask, + num_heads=attn.heads, # , dropout_p=0.0, is_causal=False + ) + + # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if self_attention_output is not None: + hidden_states = hidden_states + self_attention_output + + if not attn.pre_only: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states if self_attention_output is not None else (hidden_states, query) class MatryoshkaFeedForward(nn.Module): @@ -1012,7 +1117,11 @@ def __init__( self.linear_out = nn.Linear(dim * 4, dim) def forward(self, x): - return self.linear_out(self.linear_gelu(self.group_norm(x))) + batch_size, channels, *spatial_dims = x.shape + x = x.view(batch_size, channels, -1).permute(0, 2, 1) + x = self.linear_out(self.linear_gelu(self.group_norm(x))) + x = x.permute(0, 2, 1).view(batch_size, channels, *spatial_dims) + return x def get_down_block( From 9e37e00c1c07dec219fc6e51d6e7cf095107973a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 18 Sep 2024 19:56:12 +0300 Subject: [PATCH 028/109] Fix shape issue --- examples/community/matryoshka.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 309d85a2f1eb..60488506945b 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1118,8 +1118,9 @@ def __init__( def forward(self, x): batch_size, channels, *spatial_dims = x.shape + x = self.group_norm(x) x = x.view(batch_size, channels, -1).permute(0, 2, 1) - x = self.linear_out(self.linear_gelu(self.group_norm(x))) + x = self.linear_out(self.linear_gelu(x)) x = x.permute(0, 2, 1).view(batch_size, channels, *spatial_dims) return x From b57318219f347d587c74c15aa6f6ab6715150222 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 18 Sep 2024 19:56:30 +0300 Subject: [PATCH 029/109] Up --- examples/community/matryoshka.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 60488506945b..b39ea49667d6 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -823,7 +823,8 @@ def custom_forward(*inputs): # 3. Output # TODO: Do we need reshape here? - output = hidden_states + residual + # output = hidden_states + residual + output = hidden_states if not return_dict: return (output,) From f35a8f9f3950755e9ab4d5ad4f50e31823233f6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 18 Sep 2024 19:57:15 +0300 Subject: [PATCH 030/109] make style --- examples/community/matryoshka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index b39ea49667d6..aea8cfa443ec 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -781,7 +781,7 @@ def forward( # 1. Input batch_size, _, height, width = hidden_states.shape - residual = hidden_states + # residual = hidden_states # TODO: Do we need reshape here? # hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) From 0f6bce5b19a1599ad622551490d7a0f734880000 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 18 Sep 2024 22:25:35 +0300 Subject: [PATCH 031/109] Up --- examples/community/matryoshka.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index aea8cfa443ec..4f2d8cddc9d9 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -779,13 +779,7 @@ def forward( encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # 1. Input - batch_size, _, height, width = hidden_states.shape - # residual = hidden_states - # TODO: Do we need reshape here? - # hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states) - - # 2. Blocks + # Blocks for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: @@ -821,9 +815,7 @@ def custom_forward(*inputs): class_labels=class_labels, ) - # 3. Output - # TODO: Do we need reshape here? - # output = hidden_states + residual + # Output output = hidden_states if not return_dict: @@ -948,7 +940,7 @@ def forward( hidden_states = hidden_states + attn_output_cond if self.ff is not None: - # 4. Feed-forward + # 3. Feed-forward if self._chunk_size is not None: # "feed_forward_chunk_size" can be used to save memory ff_output = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size) @@ -1055,12 +1047,6 @@ def __call__( split_size = kv.shape[-1] // 2 key, value = torch.split(kv, split_size, dim=-1) - # inner_dim = key.shape[-1] - # head_dim = inner_dim // attn.heads - - # query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if self_attention_output is None: query = query.permute(0, 2, 1) key = key.permute(0, 2, 1) @@ -1078,10 +1064,9 @@ def __call__( key, value, mask=attention_mask, - num_heads=attn.heads, # , dropout_p=0.0, is_causal=False + num_heads=attn.heads, ) - # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) if self_attention_output is not None: From 1d48420d454f0ae3d2cde46ce4fa86d67455a4aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 18 Sep 2024 22:31:28 +0300 Subject: [PATCH 032/109] Refactor condition embedding in `MatryoshkaCombinedTimestepTextEmbedding` --- examples/community/matryoshka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 4f2d8cddc9d9..61f967f98b82 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1345,8 +1345,8 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if micro is not None: temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) + cond_emb = cond_emb + temb_micro_conditioning - cond_emb = cond_emb if micro is None else cond_emb + temb_micro_conditioning return cond_emb, conditioning_mask From b476da952fe16e01a1acee8a776255f79be82bc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 19 Sep 2024 20:26:10 +0300 Subject: [PATCH 033/109] Adapt `DDIMScheduler` for `x_0` prediction by exploiting `gammas` --- examples/community/matryoshka.py | 506 +++++++++++++++++++++++++++++++ 1 file changed, 506 insertions(+) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 61f967f98b82..04f7fba3b0b2 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1,9 +1,11 @@ # #TODO Licensed under the Apache License, Version 2.0 or MIT? import inspect +import math from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import torch import torch.nn as nn import torch.utils.checkpoint @@ -54,6 +56,7 @@ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import ( USE_PEFT_BACKEND, BaseOutput, @@ -184,6 +187,509 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: return ff_output +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->MatryoshkaDDIM +class MatryoshkaDDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + if self.config.timestep_spacing == "matryoshka_style": + self.betas = torch.cat((torch.tensor([0]), betas_for_alpha_bar(num_train_timesteps))) + else: + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + log_alphas = torch.log(self.alphas) + self.gammas = torch.exp(torch.cumsum(log_alphas, dim=0)) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps) + .round()[::-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + elif self.config.timestep_spacing == "matryoshka_style": + step_ratio = (self.config.num_train_timesteps + 1) / (num_inference_steps + 1) + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1].copy().astype(np.int64) + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + + self.timesteps = torch.from_numpy(timesteps).to(device) + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[MatryoshkaDDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + gamma_t = self.gammas[timestep] + # gamma_last = self.gammas[prev_timestep] + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + if self.config.timestep_spacing == "matryoshka_style": + pred_original_sample = (gamma_t**0.5) * sample - ((1 - gamma_t) ** 0.5) * model_output + else: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return MatryoshkaDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement + # for the subsequent add_noise calls + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device) + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps + + class CrossAttnDownBlock2D(nn.Module): def __init__( self, From 6a978b26b32dedd5055a95d520d9ccb55a7a5352 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 20 Sep 2024 15:38:22 +0300 Subject: [PATCH 034/109] Fix `prev_timestep` index --- examples/community/matryoshka.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 04f7fba3b0b2..719518cb90f5 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -566,7 +566,10 @@ def step( # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + if self.config.timestep_spacing != "matryoshka_style": + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + else: + prev_timestep = self.timesteps[torch.nonzero(self.timesteps == timestep).item() + 1] # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] From 368e044a341720cfd75ce038a502d76745cb1668 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 20 Sep 2024 15:39:14 +0300 Subject: [PATCH 035/109] Up --- examples/community/matryoshka.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 719518cb90f5..b7d25e1295bc 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -380,8 +380,6 @@ def __init__( self.betas = rescale_zero_terminal_snr(self.betas) self.alphas = 1.0 - self.betas - log_alphas = torch.log(self.alphas) - self.gammas = torch.exp(torch.cumsum(log_alphas, dim=0)) self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) # At every step in ddim, we are looking into the previous alphas_cumprod @@ -574,8 +572,6 @@ def step( # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - gamma_t = self.gammas[timestep] - # gamma_last = self.gammas[prev_timestep] beta_prod_t = 1 - alpha_prod_t @@ -588,10 +584,7 @@ def step( pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": - if self.config.timestep_spacing == "matryoshka_style": - pred_original_sample = (gamma_t**0.5) * sample - ((1 - gamma_t) ** 0.5) * model_output - else: - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( From a146ae4f20be1bee1c8ac70d0d34d309cae457aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 20 Sep 2024 16:47:12 +0300 Subject: [PATCH 036/109] Fix normalization group size in `MatryoshkaTransformerBlock` --- examples/community/matryoshka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index b7d25e1295bc..cdf54dff3b84 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1355,7 +1355,7 @@ def __init__( cross_attention_dim=None, heads=num_attention_heads, dim_head=attention_head_dim, - norm_num_groups=32 or None, + norm_num_groups=32, bias=True, upcast_attention=upcast_attention, pre_only=True, From abbb3d4d0a625956855037bb0e157100e6cdd66a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 20 Sep 2024 17:26:41 +0300 Subject: [PATCH 037/109] Refactor class names --- examples/community/matryoshka.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index cdf54dff3b84..99bc838b26b7 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1853,9 +1853,9 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): @dataclass -class NestedUNet2DConditionOutput(BaseOutput): +class MatryoshkaUNet2DConditionOutput(BaseOutput): """ - The output of [`NestedUNet2DConditionOutput`]. + The output of [`MatryoshkaUNet2DConditionOutput`]. Args: sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`): @@ -1865,7 +1865,7 @@ class NestedUNet2DConditionOutput(BaseOutput): sample: torch.Tensor = None -class NestedUNet2DConditionModel( +class MatryoshkaUNet2DConditionModel( ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin ): r""" @@ -2865,7 +2865,7 @@ def forward( down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, - ) -> Union[NestedUNet2DConditionOutput, Tuple]: + ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]: r""" The [`NestedUNet2DConditionModel`] forward method. @@ -3126,7 +3126,7 @@ def forward( if not return_dict: return (sample,) - return NestedUNet2DConditionOutput(sample=sample) + return MatryoshkaUNet2DConditionOutput(sample=sample) class MatryoshkaPipeline( @@ -3151,14 +3151,12 @@ class MatryoshkaPipeline( - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. - unet ([`NestedUNet2DConditionModel`]): - A `NestedUNet2DConditionModel` to denoise the encoded image latents. + unet ([`MatryoshkaUNet2DConditionModel`]): + A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -3179,7 +3177,7 @@ def __init__( self, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, - unet: NestedUNet2DConditionModel, + unet: MatryoshkaUNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, From a09266e5a2511eb4000bd8929405cf3f17fe6990 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 20 Sep 2024 17:41:02 +0300 Subject: [PATCH 038/109] Add `NestedUNet2DConditionModel` template --- examples/community/matryoshka.py | 121 ++++++++++++++++++++++++++++++- 1 file changed, 120 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 99bc838b26b7..04f51d703f0e 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -72,7 +72,7 @@ if is_torch_xla_available(): - import torch_xla.core.xla_model as xm + import torch_xla.core.xla_model as xm # type: ignore XLA_AVAILABLE = True else: @@ -3129,6 +3129,125 @@ def forward( return MatryoshkaUNet2DConditionOutput(sample=sample) +class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): + def __init__(self, input_channels, output_channels, config): + super().__init__(input_channels, output_channels, config) + config.inner_config.conditioning_feature_dim = config.conditioning_feature_dim + if getattr(config.inner_config, "inner_config", None) is None: + self.inner_unet = MatryoshkaUNet2DConditionModel(input_channels, output_channels, config.inner_config) + else: + self.inner_unet = NestedUNet2DConditionModel(input_channels, output_channels, config.inner_config) + + if not config.skip_inner_unet_input: + self.in_adapter = nn.Conv2d( + config.resolution_channels[-1], + config.inner_config.resolution_channels[0], + kernel_size=3, + padding=1, + bias=True, + ) + else: + self.in_adapter = None + self.out_adapter = nn.Conv2d( + config.inner_config.resolution_channels[0], + config.resolution_channels[-1], + kernel_size=3, + padding=1, + bias=True, + ) + + self.is_temporal = [config.temporal_mode and (not config.temporal_spatial_ds)] + if hasattr(self.inner_unet, "is_temporal"): + self.is_temporal += self.inner_unet.is_temporal + + nest_ratio = int(2 ** (len(config.resolution_channels) - 1)) + if self.is_temporal[0]: + nest_ratio = int(np.sqrt(nest_ratio)) + if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": + self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio + else: + self.nest_ratio = [nest_ratio] + + if config.initialize_inner_with_pretrained is not None: + try: + self.inner_unet.load(config.initialize_inner_with_pretrained) + except Exception as e: + print("<-- load pretrained checkpoint error -->") + print(f"{e}") + + if config.freeze_inner_unet: + for p in self.inner_unet.parameters(): + p.requires_grad = False + if config.interp_conditioning: + self.interp_layer1 = nn.Linear(self.temporal_dim // 4, self.temporal_dim) + self.interp_layer2 = nn.Linear(self.temporal_dim, self.temporal_dim) + + @property + def model_type(self): + return "nested_unet" + + def forward_conditioning(self, *args, **kwargs): + return self.inner_unet.forward_conditioning(*args, **kwargs) + + def forward_denoising(self, x_t, times, cond_emb=None, conditioning=None, cond_mask=None, micros={}): + # 1. time embedding + temb = self.create_temporal_embedding(times) + if cond_emb is not None: + temb = temb + cond_emb + if self.conditions is not None: + temb = temb + self.forward_micro_conditioning(times, micros) + + # 2. input layer (normalize the input) + if self._config.nesting: + x_t, x_feat = x_t + bsz = [x.size(0) for x in x_t] + bh, bl = bsz[0], bsz[1] + x_t_low, x_t = x_t[1:], x_t[0] + x = self.forward_input_layer(x_t, normalize=(not self.config.skip_normalization)) + if self._config.nesting: + x = x + x_feat + + # 3. downsample blocks in the outer layers + x, skip_activations = self.forward_downsample( + x, + temb[:bh], + conditioning[:bh], + cond_mask[:bh] if cond_mask is not None else cond_mask, + ) + + # 4. run inner unet + x_inner = self.in_adapter(x) if self.in_adapter is not None else None + x_inner = ( + torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner + ) # pad zeros for low-resolutions + x_low, x_inner = self.inner_unet.forward_denoising( + (x_t_low, x_inner), times, cond_emb, conditioning, cond_mask, micros + ) + x_inner = self.out_adapter(x_inner) + x = x + x_inner[:bh] if bh < bl else x + x_inner + + # 5. upsample blocks in the outer layers + x = self.forward_upsample( + x, + temb[:bh], + conditioning[:bh], + cond_mask[:bh] if cond_mask is not None else cond_mask, + skip_activations, + ) + + # 6. output layer + x_out = self.forward_output_layer(x) + + # 7. outpupt both low and high-res output + if isinstance(x_low, list): + out = [x_out] + x_low + else: + out = [x_out, x_low] + if self._config.nesting: + return out, x + return out + + class MatryoshkaPipeline( DiffusionPipeline, StableDiffusionMixin, From 85241b358a835b74ef6db4951c177351e7517664 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 20 Sep 2024 18:55:26 +0300 Subject: [PATCH 039/109] Adapt `NestedUNet2DConditionModel` initialization and configuration --- examples/community/matryoshka.py | 58 +++++++++++++++++--------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 04f51d703f0e..82f396ae9f01 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -2024,6 +2024,7 @@ def __init__( mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, + nesting: Optional[int] = False, ): super().__init__() @@ -2290,6 +2291,8 @@ def __init__( self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + self.register_to_config(is_temporal=[]) + def _check_config( self, down_block_types: Tuple[str], @@ -3130,57 +3133,58 @@ def forward( class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): - def __init__(self, input_channels, output_channels, config): - super().__init__(input_channels, output_channels, config) - config.inner_config.conditioning_feature_dim = config.conditioning_feature_dim - if getattr(config.inner_config, "inner_config", None) is None: - self.inner_unet = MatryoshkaUNet2DConditionModel(input_channels, output_channels, config.inner_config) + """ + Nested UNet model with condition for image denoising. + """ + @register_to_config + def __init__(self, input_channels=3, output_channels=3, *args, **kwargs): + super().__init__(input_channels=3, output_channels=3, *args, **kwargs) + self.register_to_config(inner_config_conditioning_feature_dim=self.config.conditioning_feature_dim) + #self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim + + if getattr(self.config.inner_config.inner_config, None) is None: + self.inner_unet = MatryoshkaUNet2DConditionModel(input_channels, output_channels, self.config.inner_config) else: - self.inner_unet = NestedUNet2DConditionModel(input_channels, output_channels, config.inner_config) + self.inner_unet = NestedUNet2DConditionModel(input_channels, output_channels, self.config.inner_config) - if not config.skip_inner_unet_input: + if not self.config.skip_inner_unet_input: self.in_adapter = nn.Conv2d( - config.resolution_channels[-1], - config.inner_config.resolution_channels[0], + self.config.resolution_channels[-1], + self.config.inner_config.resolution_channels[0], kernel_size=3, padding=1, - bias=True, ) else: self.in_adapter = None self.out_adapter = nn.Conv2d( - config.inner_config.resolution_channels[0], - config.resolution_channels[-1], + self.config.inner_config.resolution_channels[0], + self.config.resolution_channels[-1], kernel_size=3, padding=1, - bias=True, ) - self.is_temporal = [config.temporal_mode and (not config.temporal_spatial_ds)] - if hasattr(self.inner_unet, "is_temporal"): - self.is_temporal += self.inner_unet.is_temporal + self.register_to_config(is_temporal=[self.config.temporal_mode and (not self.config.temporal_spatial_ds)]) + if hasattr(self.inner_unet.config, "is_temporal"): + self.register_to_config(is_temporal = self.config.is_temporal + self.inner_unet.config.is_temporal) - nest_ratio = int(2 ** (len(config.resolution_channels) - 1)) + nest_ratio = int(2 ** (len(self.config.resolution_channels) - 1)) if self.is_temporal[0]: nest_ratio = int(np.sqrt(nest_ratio)) if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": - self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio + self.register_to_config(nest_ratio=[nest_ratio * self.inner_unet.config.nest_ratio[0]] + self.inner_unet.config.nest_ratio) else: - self.nest_ratio = [nest_ratio] + self.register_to_config(nest_ratio=[nest_ratio]) - if config.initialize_inner_with_pretrained is not None: + if self.config.initialize_inner_with_pretrained is not None: try: - self.inner_unet.load(config.initialize_inner_with_pretrained) + self.inner_unet.from_pretrained(self.config.initialize_inner_with_pretrained) except Exception as e: print("<-- load pretrained checkpoint error -->") print(f"{e}") - if config.freeze_inner_unet: - for p in self.inner_unet.parameters(): - p.requires_grad = False - if config.interp_conditioning: - self.interp_layer1 = nn.Linear(self.temporal_dim // 4, self.temporal_dim) - self.interp_layer2 = nn.Linear(self.temporal_dim, self.temporal_dim) + # if self.config.interp_conditioning: # Seems False for all cases + # self.interp_layer1 = nn.Linear(self.temporal_dim // 4, self.temporal_dim) + # self.interp_layer2 = nn.Linear(self.temporal_dim, self.temporal_dim) @property def model_type(self): From b7df3bb49055d3fa5dc5ea5e15b3b62273a39a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 20 Sep 2024 18:56:03 +0300 Subject: [PATCH 040/109] make style --- examples/community/matryoshka.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 82f396ae9f01..70b925a51d65 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3136,11 +3136,12 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): """ Nested UNet model with condition for image denoising. """ + @register_to_config def __init__(self, input_channels=3, output_channels=3, *args, **kwargs): super().__init__(input_channels=3, output_channels=3, *args, **kwargs) self.register_to_config(inner_config_conditioning_feature_dim=self.config.conditioning_feature_dim) - #self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim + # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim if getattr(self.config.inner_config.inner_config, None) is None: self.inner_unet = MatryoshkaUNet2DConditionModel(input_channels, output_channels, self.config.inner_config) @@ -3165,13 +3166,15 @@ def __init__(self, input_channels=3, output_channels=3, *args, **kwargs): self.register_to_config(is_temporal=[self.config.temporal_mode and (not self.config.temporal_spatial_ds)]) if hasattr(self.inner_unet.config, "is_temporal"): - self.register_to_config(is_temporal = self.config.is_temporal + self.inner_unet.config.is_temporal) + self.register_to_config(is_temporal=self.config.is_temporal + self.inner_unet.config.is_temporal) nest_ratio = int(2 ** (len(self.config.resolution_channels) - 1)) if self.is_temporal[0]: nest_ratio = int(np.sqrt(nest_ratio)) if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": - self.register_to_config(nest_ratio=[nest_ratio * self.inner_unet.config.nest_ratio[0]] + self.inner_unet.config.nest_ratio) + self.register_to_config( + nest_ratio=[nest_ratio * self.inner_unet.config.nest_ratio[0]] + self.inner_unet.config.nest_ratio + ) else: self.register_to_config(nest_ratio=[nest_ratio]) From 22c148f195bc44cc5d8fd1786741f3daa1853c12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 20 Sep 2024 22:10:09 +0300 Subject: [PATCH 041/109] Add template of `forward` for `NestedUNet2DConditionModel` --- examples/community/matryoshka.py | 211 +++++++++++++++++++++++++++---- 1 file changed, 189 insertions(+), 22 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 70b925a51d65..fda32d433523 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3132,6 +3132,14 @@ def forward( return MatryoshkaUNet2DConditionOutput(sample=sample) +class NestedUNet2DConditionOutput(BaseOutput): + """ + Output type for the [`NestedUNet2DConditionModel`] model. + """ + + sample_and_x_low: list = [] + x: torch.Tensor = None + class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): """ Nested UNet model with condition for image denoising. @@ -3196,37 +3204,187 @@ def model_type(self): def forward_conditioning(self, *args, **kwargs): return self.inner_unet.forward_conditioning(*args, **kwargs) - def forward_denoising(self, x_t, times, cond_emb=None, conditioning=None, cond_mask=None, micros={}): - # 1. time embedding - temb = self.create_temporal_embedding(times) - if cond_emb is not None: - temb = temb + cond_emb - if self.conditions is not None: - temb = temb + self.forward_micro_conditioning(times, micros) + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]: + r""" + The [`NestedUNet2DConditionModel`] forward method. + + Args: + sample (`torch.Tensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~NestedUNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~NestedUNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~NestedUNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + added_cond_kwargs = added_cond_kwargs or {} + added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention + added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + aug_emb, cond_mask = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) # 2. input layer (normalize the input) - if self._config.nesting: - x_t, x_feat = x_t + # if self._config.nesting: + # x_t, x_feat = x_t bsz = [x.size(0) for x in x_t] bh, bl = bsz[0], bsz[1] x_t_low, x_t = x_t[1:], x_t[0] - x = self.forward_input_layer(x_t, normalize=(not self.config.skip_normalization)) - if self._config.nesting: - x = x + x_feat + if not self.config.skip_normalization: + x_t = x_t / x_t.std((1, 2, 3), keepdims=True) + x = self.conv_in(x_t) + # if self._config.nesting: + # x = x + x_feat # 3. downsample blocks in the outer layers x, skip_activations = self.forward_downsample( x, - temb[:bh], - conditioning[:bh], + emb[:bh], + encoder_hidden_states[:bh], cond_mask[:bh] if cond_mask is not None else cond_mask, ) + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples # 4. run inner unet x_inner = self.in_adapter(x) if self.in_adapter is not None else None x_inner = ( torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner ) # pad zeros for low-resolutions + # TODO: Add support for innerability for the inner unet x_low, x_inner = self.inner_unet.forward_denoising( (x_t_low, x_inner), times, cond_emb, conditioning, cond_mask, micros ) @@ -3242,17 +3400,26 @@ def forward_denoising(self, x_t, times, cond_emb=None, conditioning=None, cond_m skip_activations, ) - # 6. output layer - x_out = self.forward_output_layer(x) + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(x_low) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) - # 7. outpupt both low and high-res output + # 7. output both low and high-res output if isinstance(x_low, list): - out = [x_out] + x_low + out = [sample] + x_low else: - out = [x_out, x_low] - if self._config.nesting: - return out, x - return out + out = [sample, x_low] + if self.inner_unet.config.nesting: + return NestedUNet2DConditionOutput(sample_and_x_low=out, x=x) + if not return_dict: + return (out,) + return NestedUNet2DConditionOutput(sample_and_x_low=out) class MatryoshkaPipeline( From 651cd7687ecd9d865975ae7834bbbe5336f40d66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 21 Sep 2024 14:34:19 +0300 Subject: [PATCH 042/109] Refactor `NestedUNet2DConditionModel` forward method --- examples/community/matryoshka.py | 105 +++++++++++++++++++------------ 1 file changed, 66 insertions(+), 39 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index fda32d433523..45935efe0519 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1847,9 +1847,10 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if micro is not None: temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) - cond_emb = cond_emb + temb_micro_conditioning + cond_emb_micro = cond_emb + temb_micro_conditioning + return cond_emb_micro, conditioning_mask, cond_emb - return cond_emb, conditioning_mask + return cond_emb, conditioning_mask, cond_emb @dataclass @@ -2858,6 +2859,7 @@ def forward( sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, + aug_emb: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -2868,6 +2870,7 @@ def forward( down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, + from_nested: bool = False, ) -> Union[MatryoshkaUNet2DConditionOutput, Tuple]: r""" The [`NestedUNet2DConditionModel`] forward method. @@ -2969,13 +2972,14 @@ def forward( added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale - encoder_hidden_states = self.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) + if not from_nested: + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) - aug_emb, cond_mask = self.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) + aug_emb, encoder_attention_mask, _ = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) @@ -3039,7 +3043,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, + encoder_attention_mask=encoder_attention_mask, # cond_mask? **additional_residuals, ) else: @@ -3069,7 +3073,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, + encoder_attention_mask=encoder_attention_mask, # cond_mask? ) else: sample = self.mid_block(sample, emb) @@ -3106,7 +3110,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, + encoder_attention_mask=encoder_attention_mask, # cond_mask? ) else: sample = upsample_block( @@ -3324,7 +3328,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb, cond_mask = self.get_aug_embed( + aug_emb, cond_mask, cond_emb = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) if self.config.addition_embed_type == "image_hint": @@ -3338,23 +3342,47 @@ def forward( # 2. input layer (normalize the input) # if self._config.nesting: - # x_t, x_feat = x_t - bsz = [x.size(0) for x in x_t] + # sample, x_feat = sample + bsz = [x.size(0) for x in sample] bh, bl = bsz[0], bsz[1] - x_t_low, x_t = x_t[1:], x_t[0] + x_t_low, sample = sample[1:], sample[0] if not self.config.skip_normalization: - x_t = x_t / x_t.std((1, 2, 3), keepdims=True) - x = self.conv_in(x_t) + sample = sample / sample.std((1, 2, 3), keepdims=True) + sample = self.conv_in(sample) # if self._config.nesting: - # x = x + x_feat + # sample = sample + x_feat + + # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated + # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. + if cross_attention_kwargs is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + lora_scale = cross_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True # 3. downsample blocks in the outer layers - x, skip_activations = self.forward_downsample( - x, - emb[:bh], - encoder_hidden_states[:bh], - cond_mask[:bh] if cond_mask is not None else cond_mask, - ) down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: @@ -3365,11 +3393,11 @@ def forward( sample, res_samples = downsample_block( hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, + temb=emb[:bh], + encoder_hidden_states=encoder_hidden_states[:bh], attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, + encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask, **additional_residuals, ) else: @@ -3380,24 +3408,23 @@ def forward( down_block_res_samples += res_samples # 4. run inner unet - x_inner = self.in_adapter(x) if self.in_adapter is not None else None - x_inner = ( + x_inner = self.in_adapter(sample) if self.in_adapter is not None else None + x_inner = ( # TODO: What if x_inner is None? torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner ) # pad zeros for low-resolutions - # TODO: Add support for innerability for the inner unet - x_low, x_inner = self.inner_unet.forward_denoising( - (x_t_low, x_inner), times, cond_emb, conditioning, cond_mask, micros + x_low, x_inner = self.inner_unet( + (x_t_low, x_inner), timestep, aug_emb=cond_emb, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=cond_mask, from_nested=True ) x_inner = self.out_adapter(x_inner) - x = x + x_inner[:bh] if bh < bl else x + x_inner + sample = sample + x_inner[:bh] if bh < bl else sample + x_inner # 5. upsample blocks in the outer layers - x = self.forward_upsample( - x, - temb[:bh], - conditioning[:bh], + sample = self.forward_upsample( + sample, + emb[:bh], + encoder_hidden_states[:bh], cond_mask[:bh] if cond_mask is not None else cond_mask, - skip_activations, + down_block_res_samples, ) # 6. post-process @@ -3416,7 +3443,7 @@ def forward( else: out = [sample, x_low] if self.inner_unet.config.nesting: - return NestedUNet2DConditionOutput(sample_and_x_low=out, x=x) + return NestedUNet2DConditionOutput(sample_and_x_low=out, x=sample) if not return_dict: return (out,) return NestedUNet2DConditionOutput(sample_and_x_low=out) From ea60da33d5c2b3a3d452e4562f073fe7501183c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 21 Sep 2024 15:24:58 +0300 Subject: [PATCH 043/109] Refactor `NestedUNet2DConditionModel` forward method --- examples/community/matryoshka.py | 61 +++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 45935efe0519..a2c9bbd62fc6 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3141,9 +3141,10 @@ class NestedUNet2DConditionOutput(BaseOutput): Output type for the [`NestedUNet2DConditionModel`] model. """ - sample_and_x_low: list = [] + sample_out_x_low: list = [] x: torch.Tensor = None + class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): """ Nested UNet model with condition for image denoising. @@ -3364,7 +3365,6 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) - is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where @@ -3413,25 +3413,52 @@ def forward( torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner ) # pad zeros for low-resolutions x_low, x_inner = self.inner_unet( - (x_t_low, x_inner), timestep, aug_emb=cond_emb, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=cond_mask, from_nested=True + (x_t_low, x_inner), + timestep, + aug_emb=cond_emb, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=cond_mask, + from_nested=True, ) x_inner = self.out_adapter(x_inner) sample = sample + x_inner[:bh] if bh < bl else sample + x_inner # 5. upsample blocks in the outer layers - sample = self.forward_upsample( - sample, - emb[:bh], - encoder_hidden_states[:bh], - cond_mask[:bh] if cond_mask is not None else cond_mask, - down_block_res_samples, - ) + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb[:bh], + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states[:bh], + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask, # cond_mask? + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) # 6. post-process if self.conv_norm_out: - sample = self.conv_norm_out(x_low) - sample = self.conv_act(sample) - sample = self.conv_out(sample) + sample_out = self.conv_norm_out(sample) + sample_out = self.conv_act(sample_out) + sample_out = self.conv_out(sample_out) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer @@ -3439,14 +3466,14 @@ def forward( # 7. output both low and high-res output if isinstance(x_low, list): - out = [sample] + x_low + out = [sample_out] + x_low else: - out = [sample, x_low] + out = [sample_out, x_low] if self.inner_unet.config.nesting: - return NestedUNet2DConditionOutput(sample_and_x_low=out, x=sample) + return NestedUNet2DConditionOutput(sample_out_x_low=out, x=sample) if not return_dict: return (out,) - return NestedUNet2DConditionOutput(sample_and_x_low=out) + return NestedUNet2DConditionOutput(sample_out_x_low=out) class MatryoshkaPipeline( From e01421cdf2b6a90227c0f409e8e2baf0d25ff9dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 21 Sep 2024 16:56:50 +0300 Subject: [PATCH 044/109] Fix `NestedUNet2DConditionModel` initialization --- examples/community/matryoshka.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index a2c9bbd62fc6..7d754a5b98c2 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -2025,7 +2025,10 @@ def __init__( mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, + temporal_mode: bool = False, + temporal_spatial_ds: bool = False, nesting: Optional[int] = False, + inner_config: Optional[Dict] = None, ): super().__init__() @@ -3151,28 +3154,29 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): """ @register_to_config - def __init__(self, input_channels=3, output_channels=3, *args, **kwargs): - super().__init__(input_channels=3, output_channels=3, *args, **kwargs) - self.register_to_config(inner_config_conditioning_feature_dim=self.config.conditioning_feature_dim) + def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, *args, **kwargs): + super().__init__(*args, **kwargs) + #self.skip_inner_unet_input = skip_inner_unet_input + #self.register_to_config(kwargs['inner_config']['cross_attention_dim']=self.config.cross_attention_dim) # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim - if getattr(self.config.inner_config.inner_config, None) is None: - self.inner_unet = MatryoshkaUNet2DConditionModel(input_channels, output_channels, self.config.inner_config) + if getattr(self.config.inner_config, "inner_config", None) is None: + self.inner_unet = MatryoshkaUNet2DConditionModel(**self.config.inner_config) else: - self.inner_unet = NestedUNet2DConditionModel(input_channels, output_channels, self.config.inner_config) + self.inner_unet = NestedUNet2DConditionModel(**self.config.inner_config) if not self.config.skip_inner_unet_input: self.in_adapter = nn.Conv2d( - self.config.resolution_channels[-1], - self.config.inner_config.resolution_channels[0], + self.config.block_out_channels[-1], + self.config.inner_config['block_out_channels'][0], kernel_size=3, padding=1, ) else: self.in_adapter = None self.out_adapter = nn.Conv2d( - self.config.inner_config.resolution_channels[0], - self.config.resolution_channels[-1], + self.config.inner_config['block_out_channels'][0], + self.config.block_out_channels[-1], kernel_size=3, padding=1, ) @@ -3181,8 +3185,8 @@ def __init__(self, input_channels=3, output_channels=3, *args, **kwargs): if hasattr(self.inner_unet.config, "is_temporal"): self.register_to_config(is_temporal=self.config.is_temporal + self.inner_unet.config.is_temporal) - nest_ratio = int(2 ** (len(self.config.resolution_channels) - 1)) - if self.is_temporal[0]: + nest_ratio = int(2 ** (len(self.config.block_out_channels) - 1)) + if self.config.is_temporal[0]: nest_ratio = int(np.sqrt(nest_ratio)) if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": self.register_to_config( From 4d06f295f7fd14bf07bd178b824a11816a81b5e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 21 Sep 2024 16:57:15 +0300 Subject: [PATCH 045/109] Up --- examples/community/matryoshka.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 7d754a5b98c2..0941015a3155 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3156,8 +3156,8 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): @register_to_config def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, *args, **kwargs): super().__init__(*args, **kwargs) - #self.skip_inner_unet_input = skip_inner_unet_input - #self.register_to_config(kwargs['inner_config']['cross_attention_dim']=self.config.cross_attention_dim) + # self.skip_inner_unet_input = skip_inner_unet_input + # self.register_to_config(kwargs['inner_config']['cross_attention_dim']=self.config.cross_attention_dim) # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim if getattr(self.config.inner_config, "inner_config", None) is None: @@ -3168,14 +3168,14 @@ def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, *arg if not self.config.skip_inner_unet_input: self.in_adapter = nn.Conv2d( self.config.block_out_channels[-1], - self.config.inner_config['block_out_channels'][0], + self.config.inner_config["block_out_channels"][0], kernel_size=3, padding=1, ) else: self.in_adapter = None self.out_adapter = nn.Conv2d( - self.config.inner_config['block_out_channels'][0], + self.config.inner_config["block_out_channels"][0], self.config.block_out_channels[-1], kernel_size=3, padding=1, From 29fa2572a8c89fbd2329de57a8b7c45e49d9dad7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 22 Sep 2024 15:01:40 +0300 Subject: [PATCH 046/109] Generalize `MatryoshkaCombinedTimestepTextEmbedding` for nesting level 1 and 0 --- examples/community/matryoshka.py | 70 +++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0941015a3155..23d94a157c31 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1824,31 +1824,39 @@ def get_up_block( class MatryoshkaCombinedTimestepTextEmbedding(nn.Module): - def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim): + def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, type): super().__init__() - self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) + if type == "unet": + self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) + else: + self.cond_emb = None self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) def forward(self, emb, encoder_hidden_states, added_cond_kwargs): conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) - if conditioning_mask is None or not masked_cross_attention: - y = encoder_hidden_states.mean(dim=1) - else: - y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( - dim=1, keepdim=True - ) + if self.cond_emb is not None: + if conditioning_mask is None or not masked_cross_attention: + y = encoder_hidden_states.mean(dim=1) + else: + y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( + dim=1, keepdim=True + ) + cond_emb = self.cond_emb(y) + if not masked_cross_attention: conditioning_mask = None - cond_emb = self.cond_emb(y) micro = added_cond_kwargs.get("micro_conditioning_scale", None) if micro is not None: temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) - cond_emb_micro = cond_emb + temb_micro_conditioning - return cond_emb_micro, conditioning_mask, cond_emb + if self.cond_emb is not None: + cond_emb_micro = cond_emb + temb_micro_conditioning + return cond_emb_micro, conditioning_mask, cond_emb + else: + return temb_micro_conditioning, conditioning_mask, None return cond_emb, conditioning_mask, cond_emb @@ -2027,6 +2035,7 @@ def __init__( addition_embed_type_num_heads: int = 64, temporal_mode: bool = False, temporal_spatial_ds: bool = False, + skip_cond_emb: bool = False, nesting: Optional[int] = False, inner_config: Optional[Dict] = None, ): @@ -2077,7 +2086,7 @@ def __init__( ) self.time_embedding = TimestepEmbedding( - timestep_input_dim, + time_embedding_dim // 4 if time_embedding_dim is not None else timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, @@ -2477,7 +2486,8 @@ def _set_add_embedding( ) elif addition_embed_type == "matryoshka": self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( - addition_time_embed_dim, cross_attention_dim, time_embed_dim + self.config.time_embedding_dim // 4 if self.config.time_embedding_dim is not None else addition_time_embed_dim, + cross_attention_dim, time_embed_dim, self.model_type ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much @@ -2857,6 +2867,10 @@ def process_encoder_hidden_states( encoder_hidden_states = (encoder_hidden_states, image_embeds) return encoder_hidden_states + @property + def model_type(self) -> str: + return "unet" + def forward( self, sample: torch.Tensor, @@ -3329,13 +3343,23 @@ def forward( added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale - encoder_hidden_states = self.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) + if isinstance(self.inner_unet, MatryoshkaUNet2DConditionModel): + encoder_hidden_states = self.inner_unet.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + aug_emb, cond_mask, cond_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + elif isinstance(self.inner_unet, NestedUNet2DConditionModel): + encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + aug_emb, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) - aug_emb, cond_mask, cond_emb = self.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) @@ -3346,16 +3370,16 @@ def forward( emb = self.time_embed_act(emb) # 2. input layer (normalize the input) - # if self._config.nesting: - # sample, x_feat = sample + if self.config.nesting: + sample, x_feat = sample bsz = [x.size(0) for x in sample] bh, bl = bsz[0], bsz[1] x_t_low, sample = sample[1:], sample[0] if not self.config.skip_normalization: sample = sample / sample.std((1, 2, 3), keepdims=True) sample = self.conv_in(sample) - # if self._config.nesting: - # sample = sample + x_feat + if self.config.nesting: + sample = sample + x_feat # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. From 1a22767ce8b4ee76fcb45e158e2036e0823a8269 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 22 Sep 2024 15:02:07 +0300 Subject: [PATCH 047/109] make style --- examples/community/matryoshka.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 23d94a157c31..0861b354efc7 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -2486,8 +2486,12 @@ def _set_add_embedding( ) elif addition_embed_type == "matryoshka": self.add_embedding = MatryoshkaCombinedTimestepTextEmbedding( - self.config.time_embedding_dim // 4 if self.config.time_embedding_dim is not None else addition_time_embed_dim, - cross_attention_dim, time_embed_dim, self.model_type + self.config.time_embedding_dim // 4 + if self.config.time_embedding_dim is not None + else addition_time_embed_dim, + cross_attention_dim, + time_embed_dim, + self.model_type, ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much From d0fa5ca558c8429cc2887b9888c26f36e77b78b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 22 Sep 2024 21:15:39 +0300 Subject: [PATCH 048/109] Up --- examples/community/matryoshka.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0861b354efc7..4520f6c3aba0 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -7,9 +7,9 @@ import numpy as np import torch -import torch.nn as nn import torch.utils.checkpoint from packaging import version +from torch import nn from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback @@ -3174,8 +3174,6 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): @register_to_config def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, *args, **kwargs): super().__init__(*args, **kwargs) - # self.skip_inner_unet_input = skip_inner_unet_input - # self.register_to_config(kwargs['inner_config']['cross_attention_dim']=self.config.cross_attention_dim) # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim if getattr(self.config.inner_config, "inner_config", None) is None: @@ -3220,6 +3218,8 @@ def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, *arg print("<-- load pretrained checkpoint error -->") print(f"{e}") + # self.register_modules(inner_unet=self.inner_unet) + # if self.config.interp_conditioning: # Seems False for all cases # self.interp_layer1 = nn.Linear(self.temporal_dim // 4, self.temporal_dim) # self.interp_layer2 = nn.Linear(self.temporal_dim, self.temporal_dim) From 6b65f9fccfc1840fd4109b177a54b2e4f12c023d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 23 Sep 2024 22:36:15 +0300 Subject: [PATCH 049/109] Generalize time projection for different model types in `MatryoshkaUNet2DConditionModel` --- examples/community/matryoshka.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 4520f6c3aba0..e2877d46cc0f 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -2377,7 +2377,10 @@ def _set_time_proj( elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + if self.model_type == "unet": + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + elif self.model_type == "nested_unet": + self.time_proj = Timesteps(block_out_channels[0] * 4, flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( From 62db4b0d397acbfbcd5a59390b0e3cbcb3a30edc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 23 Sep 2024 22:37:03 +0300 Subject: [PATCH 050/109] Fix `cond_emb` usage --- examples/community/matryoshka.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index e2877d46cc0f..acbb4dfbb018 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1850,8 +1850,8 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): micro = added_cond_kwargs.get("micro_conditioning_scale", None) if micro is not None: - temb = self.add_time_proj(torch.tensor([micro], device=cond_emb.device, dtype=cond_emb.dtype)) - temb_micro_conditioning = self.add_timestep_embedder(temb.to(cond_emb.dtype)) + temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype)) + temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype)) if self.cond_emb is not None: cond_emb_micro = cond_emb + temb_micro_conditioning return cond_emb_micro, conditioning_mask, cond_emb From a57b5fcdecdbed3519eb7214f817a12ebb480e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 24 Sep 2024 20:13:56 +0300 Subject: [PATCH 051/109] Up --- examples/community/matryoshka.py | 106 +++++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 28 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index acbb4dfbb018..9007e163c374 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -7,6 +7,7 @@ import numpy as np import torch +import torch.nn.functional as F import torch.utils.checkpoint from packaging import version from torch import nn @@ -504,6 +505,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps).to(device) + def get_schedule_shifted(self, gammas, scale_factor=None): + if (scale_factor is not None) and (scale_factor > 1): # rescale noise schecule + snr = gammas / (1 - gammas) + scaled_snr = snr / scale_factor + gammas = 1 / (1 + 1 / scaled_snr) + return gammas + def step( self, model_output: torch.Tensor, @@ -513,6 +521,7 @@ def step( use_clipped_model_output: bool = False, generator=None, variance_noise: Optional[torch.Tensor] = None, + scales: Optional[list] = None, return_dict: bool = True, ) -> Union[MatryoshkaDDIMSchedulerOutput, Tuple]: """ @@ -573,6 +582,16 @@ def step( alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + if self.config.timestep_spacing == "matryoshka_style" and len(sample) == 2: + alpha_prod_t = [self.get_schedule_shifted(alpha_prod_t, s) for s in scales] + alpha_prod_t_prev = [self.get_schedule_shifted(alpha_prod_t_prev, s) for s in scales] + if sample is not None and alpha_prod_t[0].size(-1) != 1: + alpha_prod_t = torch.tensor([F.interpolate(g, im.size(-1), mode="nearest") + for g, im in zip(alpha_prod_t, sample)]) + alpha_prod_t_prev = torch.tensor([F.interpolate(g, im.size(-1), mode="nearest") + for g, im in zip(alpha_prod_t_prev, sample)]) + + beta_prod_t = 1 - alpha_prod_t # 3. compute predicted original sample from predicted noise also called @@ -594,7 +613,10 @@ def step( # 4. Clip or threshold "predicted x_0" if self.config.thresholding: - pred_original_sample = self._threshold_sample(pred_original_sample) + if len(sample) == 2: + pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample] + else: + pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range @@ -631,9 +653,9 @@ def step( prev_sample = prev_sample + variance if not return_dict: - return (prev_sample,) + return (list(prev_sample),) - return MatryoshkaDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + return MatryoshkaDDIMSchedulerOutput(prev_sample=list(prev_sample), pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( @@ -1484,7 +1506,7 @@ def attention(self, q, k, v, num_heads, mask=None): (k * scale).reshape(bs * num_heads, ch, -1), ) # More stable with f16 than dividing afterwards if mask is not None: - mask = mask.view(mask.size(0), 1, 1, mask.size(1)).repeat(1, num_heads, 1, 1).flatten(0, 1) + mask = mask.view(mask.size(0), 1, 1, mask.size(-1)).repeat(1, num_heads, 1, 1).flatten(0, 1) weight = weight.masked_fill(mask == 0, float("-inf")) weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1)) @@ -1828,7 +1850,7 @@ def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, super().__init__() if type == "unet": self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) - else: + elif type in ("inner_unet", "nested_unet"): self.cond_emb = None self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) @@ -1837,7 +1859,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) if self.cond_emb is not None: - if conditioning_mask is None or not masked_cross_attention: + if conditioning_mask is None: y = encoder_hidden_states.mean(dim=1) else: y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( @@ -1872,6 +1894,7 @@ class MatryoshkaUNet2DConditionOutput(BaseOutput): """ sample: torch.Tensor = None + sample_inner: torch.Tensor = None class MatryoshkaUNet2DConditionModel( @@ -2494,7 +2517,7 @@ def _set_add_embedding( else addition_time_embed_dim, cross_attention_dim, time_embed_dim, - self.model_type, + self.model_type if not self.config.nesting else "inner_" + self.model_type, ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much @@ -2883,7 +2906,7 @@ def forward( sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, - aug_emb: Optional[torch.Tensor] = None, + cond_emb: Optional[torch.Tensor] = None, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -2950,6 +2973,11 @@ def forward( forward_upsample_size = False upsample_size = None + if self.config.nesting: + sample, sample_feat = sample + if isinstance(sample, list) and len(sample) == 1: + sample = sample[0] + for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. @@ -2974,7 +3002,7 @@ def forward( # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary @@ -3001,20 +3029,22 @@ def forward( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb, encoder_attention_mask, _ = self.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) + aug_emb, encoder_attention_mask, _ = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) - emb = emb + aug_emb if aug_emb is not None else emb + emb = emb + aug_emb + cond_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) # 2. pre-process sample = self.conv_in(sample) + if self.config.nesting: + sample = sample + sample_feat # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: @@ -3144,9 +3174,11 @@ def forward( upsample_size=upsample_size, ) + sample_inner = sample + # 6. post-process if self.conv_norm_out: - sample = self.conv_norm_out(sample) + sample = self.conv_norm_out(sample_inner) sample = self.conv_act(sample) sample = self.conv_out(sample) @@ -3157,6 +3189,9 @@ def forward( if not return_dict: return (sample,) + if self.config.nesting: + return MatryoshkaUNet2DConditionOutput(sample=sample, sample_inner=sample_inner) + return MatryoshkaUNet2DConditionOutput(sample=sample) @@ -3165,8 +3200,9 @@ class NestedUNet2DConditionOutput(BaseOutput): Output type for the [`NestedUNet2DConditionModel`] model. """ - sample_out_x_low: list = [] - x: torch.Tensor = None + sample: list = [] + sample_inner: torch.Tensor = None + scales: list = [] class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): @@ -3175,7 +3211,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): """ @register_to_config - def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, *args, **kwargs): + def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, skip_normalization, *args, **kwargs): super().__init__(*args, **kwargs) # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim @@ -3346,16 +3382,16 @@ def forward( else: emb = emb + class_emb - added_cond_kwargs = added_cond_kwargs or {} - added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention - added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale - if isinstance(self.inner_unet, MatryoshkaUNet2DConditionModel): + added_cond_kwargs = added_cond_kwargs or {} + added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention + added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + encoder_hidden_states = self.inner_unet.process_encoder_hidden_states( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb, cond_mask, cond_emb = self.get_aug_embed( + aug_emb, cond_mask, cond_emb = self.inner_unet.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) elif isinstance(self.inner_unet, NestedUNet2DConditionModel): @@ -3376,6 +3412,17 @@ def forward( if self.time_embed_act is not None: emb = self.time_embed_act(emb) + scales = self.config.nest_ratio + [1] + if isinstance(sample, torch.Tensor): + out = [sample] + for s in scales[1:]: + ratio = scales[0] // s + sample_low = F.avg_pool2d(sample, ratio) * ratio + torch.manual_seed(0) + sample_low = sample_low.normal_() + out += [sample_low] + sample = out + # 2. input layer (normalize the input) if self.config.nesting: sample, x_feat = sample @@ -3384,6 +3431,8 @@ def forward( x_t_low, sample = sample[1:], sample[0] if not self.config.skip_normalization: sample = sample / sample.std((1, 2, 3), keepdims=True) + if isinstance(sample, list) and len(sample) == 1: + sample = sample[0] sample = self.conv_in(sample) if self.config.nesting: sample = sample + x_feat @@ -3447,14 +3496,15 @@ def forward( x_inner = ( # TODO: What if x_inner is None? torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner ) # pad zeros for low-resolutions - x_low, x_inner = self.inner_unet( + inner_unet_output = self.inner_unet( (x_t_low, x_inner), timestep, - aug_emb=cond_emb, + cond_emb=cond_emb, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=cond_mask, from_nested=True, ) + x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner x_inner = self.out_adapter(x_inner) sample = sample + x_inner[:bh] if bh < bl else sample + x_inner @@ -3504,11 +3554,11 @@ def forward( out = [sample_out] + x_low else: out = [sample_out, x_low] - if self.inner_unet.config.nesting: - return NestedUNet2DConditionOutput(sample_out_x_low=out, x=sample) + if self.config.nesting: + return NestedUNet2DConditionOutput(sample=out, sample_inner=sample, scales=scales) if not return_dict: - return (out,) - return NestedUNet2DConditionOutput(sample_out_x_low=out) + return (out, scales) + return NestedUNet2DConditionOutput(sample=out, scales=scales) class MatryoshkaPipeline( From db809dc9570bd23d1c57a0494cc429a39c8a2688 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 24 Sep 2024 20:15:59 +0300 Subject: [PATCH 052/109] style --- examples/community/matryoshka.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 9007e163c374..7cfb57f74524 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -586,11 +586,12 @@ def step( alpha_prod_t = [self.get_schedule_shifted(alpha_prod_t, s) for s in scales] alpha_prod_t_prev = [self.get_schedule_shifted(alpha_prod_t_prev, s) for s in scales] if sample is not None and alpha_prod_t[0].size(-1) != 1: - alpha_prod_t = torch.tensor([F.interpolate(g, im.size(-1), mode="nearest") - for g, im in zip(alpha_prod_t, sample)]) - alpha_prod_t_prev = torch.tensor([F.interpolate(g, im.size(-1), mode="nearest") - for g, im in zip(alpha_prod_t_prev, sample)]) - + alpha_prod_t = torch.tensor( + [F.interpolate(g, im.size(-1), mode="nearest") for g, im in zip(alpha_prod_t, sample)] + ) + alpha_prod_t_prev = torch.tensor( + [F.interpolate(g, im.size(-1), mode="nearest") for g, im in zip(alpha_prod_t_prev, sample)] + ) beta_prod_t = 1 - alpha_prod_t From ff301b61ace8da41db215204fc6b8a11109de5a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 25 Sep 2024 18:23:54 +0300 Subject: [PATCH 053/109] `Up` --- examples/community/matryoshka.py | 137 ++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 50 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 7cfb57f74524..6af83a180694 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -203,8 +203,8 @@ class MatryoshkaDDIMSchedulerOutput(BaseOutput): `pred_original_sample` can be used to preview progress or for guidance. """ - prev_sample: torch.Tensor - pred_original_sample: Optional[torch.Tensor] = None + prev_sample: list[torch.Tensor] + pred_original_sample: Optional[list[torch.Tensor]] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar @@ -582,16 +582,16 @@ def step( alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - if self.config.timestep_spacing == "matryoshka_style" and len(sample) == 2: - alpha_prod_t = [self.get_schedule_shifted(alpha_prod_t, s) for s in scales] - alpha_prod_t_prev = [self.get_schedule_shifted(alpha_prod_t_prev, s) for s in scales] - if sample is not None and alpha_prod_t[0].size(-1) != 1: - alpha_prod_t = torch.tensor( - [F.interpolate(g, im.size(-1), mode="nearest") for g, im in zip(alpha_prod_t, sample)] - ) - alpha_prod_t_prev = torch.tensor( - [F.interpolate(g, im.size(-1), mode="nearest") for g, im in zip(alpha_prod_t_prev, sample)] - ) + if self.config.timestep_spacing == "matryoshka_style" and len(model_output) == 2: + alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t.item(), s) for s in scales]) + alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev.item(), s) for s in scales]) + # if sample is not None:# and alpha_prod_t[0].size(-1) != 1: + # alpha_prod_t = torch.stack( + # [F.interpolate(g * torch.ones_like(im), im.size(-1), mode="nearest") for g, im in zip(alpha_prod_t, sample)] + # ) + # alpha_prod_t_prev = torch.stack( + # [F.interpolate(g, im.size(-1), mode="nearest") for g, im in zip(alpha_prod_t_prev, sample)] + # ) beta_prod_t = 1 - alpha_prod_t @@ -604,8 +604,15 @@ def step( pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": - pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + if len(model_output) == 2: + pred_original_sample = [] + pred_epsilon = [] + for m_o, s, a_p_t, b_p_t in zip(model_output, sample, alpha_prod_t, beta_prod_t): + pred_original_sample.append((a_p_t**0.5) * s - (b_p_t**0.5) * m_o) + pred_epsilon.append((a_p_t**0.5) * m_o + (b_p_t**0.5) * s) + else: + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" @@ -614,14 +621,17 @@ def step( # 4. Clip or threshold "predicted x_0" if self.config.thresholding: - if len(sample) == 2: + if len(model_output) == 2: pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample] else: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: - pred_original_sample = pred_original_sample.clamp( - -self.config.clip_sample_range, self.config.clip_sample_range - ) + if len(model_output) == 2: + pred_original_sample = [p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) for p_o_s in pred_original_sample] + else: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) # 5. compute variance: "sigma_t(η)" -> see formula (16) # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) @@ -633,10 +643,20 @@ def step( pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + if len(model_output) == 2: + pred_sample_direction = [] + for p_e, a_p_t_p in zip(pred_epsilon, alpha_prod_t_prev): + pred_sample_direction.append((1 - a_p_t_p - std_dev_t**2) ** (0.5) * p_e) + else: + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + if len(model_output) == 2: + prev_sample = [] + for p_o_s, p_s_d, a_p_t_p in zip(pred_original_sample, pred_sample_direction, alpha_prod_t_prev): + prev_sample.append(a_p_t_p ** (0.5) * p_o_s + p_s_d) + else: + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if eta > 0: if variance_noise is not None and generator is not None: @@ -654,9 +674,9 @@ def step( prev_sample = prev_sample + variance if not return_dict: - return (list(prev_sample),) + return (prev_sample,) - return MatryoshkaDDIMSchedulerOutput(prev_sample=list(prev_sample), pred_original_sample=pred_original_sample) + return MatryoshkaDDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise def add_noise( @@ -1662,6 +1682,7 @@ def get_down_block( resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, attention_head_dim: Optional[int] = None, + use_attention_ffn: bool = True, downsample_type: Optional[str] = None, dropout: float = 0.0, ): @@ -1713,6 +1734,7 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, attention_pre_only=attention_pre_only, + use_attention_ffn=use_attention_ffn, ) @@ -1790,6 +1812,7 @@ def get_up_block( resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, attention_head_dim: Optional[int] = None, + use_attention_ffn: bool = True, upsample_type: Optional[str] = None, dropout: float = 0.0, ) -> nn.Module: @@ -1843,6 +1866,7 @@ def get_up_block( resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, attention_pre_only=attention_pre_only, + use_attention_ffn=use_attention_ffn, ) @@ -1851,7 +1875,7 @@ def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, super().__init__() if type == "unet": self.cond_emb = nn.Linear(cross_attention_dim, time_embed_dim, bias=False) - elif type in ("inner_unet", "nested_unet"): + elif type == "nested_unet": self.cond_emb = None self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=False, downscale_freq_shift=0) self.add_timestep_embedder = TimestepEmbedding(addition_time_embed_dim, time_embed_dim) @@ -1859,7 +1883,7 @@ def __init__(self, addition_time_embed_dim, cross_attention_dim, time_embed_dim, def forward(self, emb, encoder_hidden_states, added_cond_kwargs): conditioning_mask = added_cond_kwargs.get("conditioning_mask", None) masked_cross_attention = added_cond_kwargs.get("masked_cross_attention", False) - if self.cond_emb is not None: + if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False): if conditioning_mask is None: y = encoder_hidden_states.mean(dim=1) else: @@ -1875,7 +1899,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if micro is not None: temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype)) temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype)) - if self.cond_emb is not None: + if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False): cond_emb_micro = cond_emb + temb_micro_conditioning return cond_emb_micro, conditioning_mask, cond_emb else: @@ -2032,6 +2056,7 @@ def __init__( attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, + use_attention_ffn: bool = True, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, @@ -2217,6 +2242,7 @@ def __init__( resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, + use_attention_ffn=use_attention_ffn, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) @@ -2304,6 +2330,7 @@ def __init__( resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, + use_attention_ffn=use_attention_ffn, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) @@ -2518,7 +2545,7 @@ def _set_add_embedding( else addition_time_embed_dim, cross_attention_dim, time_embed_dim, - self.model_type if not self.config.nesting else "inner_" + self.model_type, + self.model_type# if not self.config.nesting else "inner_" + self.model_type, ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much @@ -3024,6 +3051,7 @@ def forward( added_cond_kwargs = added_cond_kwargs or {} added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + added_cond_kwargs["from_nested"] = from_nested if not from_nested: encoder_hidden_states = self.process_encoder_hidden_states( @@ -3202,6 +3230,7 @@ class NestedUNet2DConditionOutput(BaseOutput): """ sample: list = [] + sample_low: torch.Tensor = None sample_inner: torch.Tensor = None scales: list = [] @@ -3341,6 +3370,27 @@ def forward( forward_upsample_size = False upsample_size = None + if self.config.nesting: + sample, sample_feat = sample + if isinstance(sample, list) and len(sample) == 1: + sample = sample[0] + + scales = self.config.nest_ratio + [1] + if isinstance(sample, torch.Tensor): + out = [sample] + for s in scales[1:]: + ratio = scales[0] // s + sample_low = F.avg_pool2d(sample, ratio) * ratio + torch.manual_seed(0) + sample_low = sample_low.normal_() + out += [sample_low] + sample = out + + # 2. input layer (normalize the input) + bsz = [x.size(0) for x in sample] + bh, bl = bsz[0], bsz[1] + x_t_low, sample = sample[1:], sample[0] + for dim in sample.shape[-2:]: if dim % default_overall_up_factor != 0: # Forward upsample size to force interpolation output size. @@ -3392,7 +3442,11 @@ def forward( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb, cond_mask, cond_emb = self.inner_unet.get_aug_embed( + aug_emb_inner_unet, cond_mask_inner_unet, cond_emb_inner_unet = self.inner_unet.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + aug_emb, cond_mask, cond_emb = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) elif isinstance(self.inner_unet, NestedUNet2DConditionModel): @@ -3408,35 +3462,18 @@ def forward( aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) - emb = emb + aug_emb if aug_emb is not None else emb + emb = emb + aug_emb + cond_emb_inner_unet if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) - scales = self.config.nest_ratio + [1] - if isinstance(sample, torch.Tensor): - out = [sample] - for s in scales[1:]: - ratio = scales[0] // s - sample_low = F.avg_pool2d(sample, ratio) * ratio - torch.manual_seed(0) - sample_low = sample_low.normal_() - out += [sample_low] - sample = out - - # 2. input layer (normalize the input) - if self.config.nesting: - sample, x_feat = sample - bsz = [x.size(0) for x in sample] - bh, bl = bsz[0], bsz[1] - x_t_low, sample = sample[1:], sample[0] if not self.config.skip_normalization: sample = sample / sample.std((1, 2, 3), keepdims=True) if isinstance(sample, list) and len(sample) == 1: sample = sample[0] sample = self.conv_in(sample) if self.config.nesting: - sample = sample + x_feat + sample = sample + sample_feat # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. @@ -3500,9 +3537,9 @@ def forward( inner_unet_output = self.inner_unet( (x_t_low, x_inner), timestep, - cond_emb=cond_emb, + cond_emb=cond_emb_inner_unet, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=cond_mask, + encoder_attention_mask=cond_mask_inner_unet, from_nested=True, ) x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner @@ -3558,8 +3595,8 @@ def forward( if self.config.nesting: return NestedUNet2DConditionOutput(sample=out, sample_inner=sample, scales=scales) if not return_dict: - return (out, scales) - return NestedUNet2DConditionOutput(sample=out, scales=scales) + return (out, sample_low, scales) + return NestedUNet2DConditionOutput(sample=out, sample_low=sample_low, scales=scales) class MatryoshkaPipeline( From 77732bbcbc3746979414c624f2abe4a95bfde5bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 25 Sep 2024 18:24:30 +0300 Subject: [PATCH 054/109] style --- examples/community/matryoshka.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 6af83a180694..23afbde81ac5 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -627,7 +627,10 @@ def step( pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: if len(model_output) == 2: - pred_original_sample = [p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) for p_o_s in pred_original_sample] + pred_original_sample = [ + p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) + for p_o_s in pred_original_sample + ] else: pred_original_sample = pred_original_sample.clamp( -self.config.clip_sample_range, self.config.clip_sample_range @@ -2545,7 +2548,7 @@ def _set_add_embedding( else addition_time_embed_dim, cross_attention_dim, time_embed_dim, - self.model_type# if not self.config.nesting else "inner_" + self.model_type, + self.model_type, # if not self.config.nesting else "inner_" + self.model_type, ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much From 154c1be7207c691c93b1f52b6daf5350994a2ecc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 26 Sep 2024 10:43:14 +0300 Subject: [PATCH 055/109] Refactor `NestedUNet2DConditionModel` to handle `sample_low` conditionally --- examples/community/matryoshka.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 23afbde81ac5..c9efda09b57e 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3379,6 +3379,7 @@ def forward( sample = sample[0] scales = self.config.nest_ratio + [1] + is_sample_low = False if isinstance(sample, torch.Tensor): out = [sample] for s in scales[1:]: @@ -3387,6 +3388,7 @@ def forward( torch.manual_seed(0) sample_low = sample_low.normal_() out += [sample_low] + is_sample_low = True sample = out # 2. input layer (normalize the input) @@ -3599,7 +3601,10 @@ def forward( return NestedUNet2DConditionOutput(sample=out, sample_inner=sample, scales=scales) if not return_dict: return (out, sample_low, scales) - return NestedUNet2DConditionOutput(sample=out, sample_low=sample_low, scales=scales) + if is_sample_low: + return NestedUNet2DConditionOutput(sample=out, sample_low=sample_low, scales=scales) + else: + return NestedUNet2DConditionOutput(sample=out, sample_low=None, scales=scales) class MatryoshkaPipeline( From b363cc179daa95fca58a1b303252482c128618ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 26 Sep 2024 11:52:28 +0300 Subject: [PATCH 056/109] Simplify --- examples/community/matryoshka.py | 36 ++++++-------------------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index c9efda09b57e..0d41a372ff48 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -583,15 +583,8 @@ def step( alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod if self.config.timestep_spacing == "matryoshka_style" and len(model_output) == 2: - alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t.item(), s) for s in scales]) - alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev.item(), s) for s in scales]) - # if sample is not None:# and alpha_prod_t[0].size(-1) != 1: - # alpha_prod_t = torch.stack( - # [F.interpolate(g * torch.ones_like(im), im.size(-1), mode="nearest") for g, im in zip(alpha_prod_t, sample)] - # ) - # alpha_prod_t_prev = torch.stack( - # [F.interpolate(g, im.size(-1), mode="nearest") for g, im in zip(alpha_prod_t_prev, sample)] - # ) + alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, s) for s in scales]) + alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev, s) for s in scales]) beta_prod_t = 1 - alpha_prod_t @@ -3232,10 +3225,8 @@ class NestedUNet2DConditionOutput(BaseOutput): Output type for the [`NestedUNet2DConditionModel`] model. """ - sample: list = [] - sample_low: torch.Tensor = None + sample: list = None sample_inner: torch.Tensor = None - scales: list = [] class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): @@ -3378,19 +3369,6 @@ def forward( if isinstance(sample, list) and len(sample) == 1: sample = sample[0] - scales = self.config.nest_ratio + [1] - is_sample_low = False - if isinstance(sample, torch.Tensor): - out = [sample] - for s in scales[1:]: - ratio = scales[0] // s - sample_low = F.avg_pool2d(sample, ratio) * ratio - torch.manual_seed(0) - sample_low = sample_low.normal_() - out += [sample_low] - is_sample_low = True - sample = out - # 2. input layer (normalize the input) bsz = [x.size(0) for x in sample] bh, bl = bsz[0], bsz[1] @@ -3598,13 +3576,11 @@ def forward( else: out = [sample_out, x_low] if self.config.nesting: - return NestedUNet2DConditionOutput(sample=out, sample_inner=sample, scales=scales) + return NestedUNet2DConditionOutput(sample=out, sample_inner=sample) if not return_dict: - return (out, sample_low, scales) - if is_sample_low: - return NestedUNet2DConditionOutput(sample=out, sample_low=sample_low, scales=scales) + return (out, ) else: - return NestedUNet2DConditionOutput(sample=out, sample_low=None, scales=scales) + return NestedUNet2DConditionOutput(sample=out) class MatryoshkaPipeline( From 028a685d2d5ee6bd67e5b472519b2366ea3a49fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 26 Sep 2024 12:11:48 +0300 Subject: [PATCH 057/109] Refactor `MatryoshkaDDIMScheduler` to use `alpha_prod` instead of `gammas` in `get_schedule_shifted` --- examples/community/matryoshka.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0d41a372ff48..1620f3e7d981 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -505,12 +505,12 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps).to(device) - def get_schedule_shifted(self, gammas, scale_factor=None): - if (scale_factor is not None) and (scale_factor > 1): # rescale noise schecule - snr = gammas / (1 - gammas) + def get_schedule_shifted(self, alpha_prod, scale_factor=None): + if (scale_factor is not None) and (scale_factor > 1): # rescale noise schedule + snr = alpha_prod / (1 - alpha_prod) scaled_snr = snr / scale_factor - gammas = 1 / (1 + 1 / scaled_snr) - return gammas + alpha_prod = 1 / (1 + 1 / scaled_snr) + return alpha_prod def step( self, From a5b3c37c008de214a525b453393a34cc84f6e53e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 26 Sep 2024 12:13:04 +0300 Subject: [PATCH 058/109] Refactor `MatryoshkaDDIMScheduler` to remove unused import and simplify return statement in `NestedUNet2DConditionModel` --- examples/community/matryoshka.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 1620f3e7d981..dee5197796fc 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -7,7 +7,6 @@ import numpy as np import torch -import torch.nn.functional as F import torch.utils.checkpoint from packaging import version from torch import nn @@ -3578,7 +3577,7 @@ def forward( if self.config.nesting: return NestedUNet2DConditionOutput(sample=out, sample_inner=sample) if not return_dict: - return (out, ) + return (out,) else: return NestedUNet2DConditionOutput(sample=out) From 30c6881a143716b3a2dbe003d18d05706b9a921a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 26 Sep 2024 19:24:48 +0300 Subject: [PATCH 059/109] Refactor `NestedUNet2DConditionModel` to handle `inner_config` conditionally --- examples/community/matryoshka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index dee5197796fc..0dcac4cb6640 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3238,7 +3238,7 @@ def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, skip super().__init__(*args, **kwargs) # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim - if getattr(self.config.inner_config, "inner_config", None) is None: + if "inner_config" not in self.config.inner_config: self.inner_unet = MatryoshkaUNet2DConditionModel(**self.config.inner_config) else: self.inner_unet = NestedUNet2DConditionModel(**self.config.inner_config) From e34fb48b6f7d603d564510ebd3897eb9266367b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 26 Sep 2024 19:58:59 +0300 Subject: [PATCH 060/109] Refactor `_set_time_proj` to handle with `micro_conditioning_scale` conditionally for nesting_levels 1 and 2 --- examples/community/matryoshka.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0dcac4cb6640..ac87b3702817 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -2425,8 +2425,10 @@ def _set_time_proj( if self.model_type == "unet": self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - elif self.model_type == "nested_unet": + elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 256: self.time_proj = Timesteps(block_out_channels[0] * 4, flip_sin_to_cos, freq_shift) + elif self.model_type == "nested_unet" and self.config.micro_conditioning_scale == 1024: + self.time_proj = Timesteps(block_out_channels[0] * 4 * 2, flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( @@ -3432,13 +3434,7 @@ def forward( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) elif isinstance(self.inner_unet, NestedUNet2DConditionModel): - encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) - - aug_emb, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) + # TODO: Implement for nesting_level=2 if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb From dd88c37452023480cab0ea6e12f4739dc581a5dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 27 Sep 2024 18:01:11 +0300 Subject: [PATCH 061/109] Generalize for `nesting_level=2` --- examples/community/matryoshka.py | 52 +++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index ac87b3702817..2fb0ab1d25fd 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -581,7 +581,7 @@ def step( alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod - if self.config.timestep_spacing == "matryoshka_style" and len(model_output) == 2: + if self.config.timestep_spacing == "matryoshka_style" and len(model_output) > 1: alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, s) for s in scales]) alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev, s) for s in scales]) @@ -596,7 +596,7 @@ def step( pred_original_sample = model_output pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": - if len(model_output) == 2: + if len(model_output) > 1: pred_original_sample = [] pred_epsilon = [] for m_o, s, a_p_t, b_p_t in zip(model_output, sample, alpha_prod_t, beta_prod_t): @@ -613,12 +613,12 @@ def step( # 4. Clip or threshold "predicted x_0" if self.config.thresholding: - if len(model_output) == 2: + if len(model_output) > 1: pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample] else: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: - if len(model_output) == 2: + if len(model_output) > 1: pred_original_sample = [ p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) for p_o_s in pred_original_sample @@ -638,7 +638,7 @@ def step( pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if len(model_output) == 2: + if len(model_output) > 1: pred_sample_direction = [] for p_e, a_p_t_p in zip(pred_epsilon, alpha_prod_t_prev): pred_sample_direction.append((1 - a_p_t_p - std_dev_t**2) ** (0.5) * p_e) @@ -646,7 +646,7 @@ def step( pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - if len(model_output) == 2: + if len(model_output) > 1: prev_sample = [] for p_o_s, p_s_d, a_p_t_p in zip(pred_original_sample, pred_sample_direction, alpha_prod_t_prev): prev_sample.append(a_p_t_p ** (0.5) * p_o_s + p_s_d) @@ -3300,6 +3300,8 @@ def forward( sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, + cond_emb: Optional[torch.Tensor] = None, + from_nested: bool = False, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -3417,30 +3419,50 @@ def forward( else: emb = emb + class_emb - if isinstance(self.inner_unet, MatryoshkaUNet2DConditionModel): + if self.inner_unet.model_type == "unet": added_cond_kwargs = added_cond_kwargs or {} added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale - encoder_hidden_states = self.inner_unet.process_encoder_hidden_states( + if not self.config.nesting: + encoder_hidden_states = self.inner_unet.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + aug_emb, cond_mask, _ = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + else: + aug_emb, cond_mask_inner_unet, _ = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + elif self.inner_unet.model_type == "nested_unet": + added_cond_kwargs = added_cond_kwargs or {} + added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention + added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + + encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb_inner_unet, cond_mask_inner_unet, cond_emb_inner_unet = self.inner_unet.get_aug_embed( + aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.inner_unet.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb, cond_mask, cond_emb = self.get_aug_embed( + aug_emb, cond_mask, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - elif isinstance(self.inner_unet, NestedUNet2DConditionModel): - # TODO: Implement for nesting_level=2 if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) - emb = emb + aug_emb + cond_emb_inner_unet if aug_emb is not None else emb + emb = emb + aug_emb + cond_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) @@ -3509,13 +3531,13 @@ def forward( # 4. run inner unet x_inner = self.in_adapter(sample) if self.in_adapter is not None else None - x_inner = ( # TODO: What if x_inner is None? + x_inner = ( torch.cat([x_inner, x_inner.new_zeros(bl - bh, *x_inner.size()[1:])], 0) if bh < bl else x_inner ) # pad zeros for low-resolutions inner_unet_output = self.inner_unet( (x_t_low, x_inner), timestep, - cond_emb=cond_emb_inner_unet, + cond_emb=cond_emb, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=cond_mask_inner_unet, from_nested=True, From a7c7c9ab9c264b109f984e8f13f08c011bfe0414 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 28 Sep 2024 15:57:40 +0300 Subject: [PATCH 062/109] Refactor `MatryoshkaUNet2DConditionModel` and `NestedUNet2DConditionModel` constructors --- examples/community/matryoshka.py | 113 +++++++++++++++++++++++++------ 1 file changed, 94 insertions(+), 19 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 2fb0ab1d25fd..85fbb8c356ad 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -2081,7 +2081,6 @@ def __init__( temporal_spatial_ds: bool = False, skip_cond_emb: bool = False, nesting: Optional[int] = False, - inner_config: Optional[Dict] = None, ): super().__init__() @@ -2350,7 +2349,7 @@ def __init__( self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) - self.register_to_config(is_temporal=[]) + self.is_temporal = [] def _check_config( self, @@ -3236,8 +3235,93 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel): """ @register_to_config - def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, skip_normalization, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + in_channels=3, + out_channels=3, + block_out_channels=(64, 128, 256), + cross_attention_dim=2048, + resnet_time_scale_shift="scale_shift", + down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D"), + up_block_types=("UpBlock2D", "UpBlock2D", "UpBlock2D"), + mid_block_type=None, + nesting=False, + flip_sin_to_cos=False, + transformer_layers_per_block=[0, 0, 0], + layers_per_block=[2, 2, 1], + masked_cross_attention=True, + micro_conditioning_scale=256, + addition_embed_type="matryoshka", + skip_normalization=True, + time_embedding_dim=1024, + skip_inner_unet_input=False, + temporal_mode=False, + temporal_spatial_ds=False, + initialize_inner_with_pretrained=None, + use_attention_ffn=False, + inner_config={}, + act_fn="silu", + addition_embed_type_num_heads=64, + addition_time_embed_dim=None, + attention_head_dim=8, + attention_pre_only=False, + attention_type="default", + center_input_sample=False, + class_embed_type=None, + class_embeddings_concat=False, + conv_in_kernel=3, + conv_out_kernel=3, + cross_attention_norm=None, + downsample_padding=1, + dropout=0.0, + dual_cross_attention=False, + encoder_hid_dim=None, + encoder_hid_dim_type=None, + freq_shift=0, + mid_block_only_cross_attention=None, + mid_block_scale_factor=1, + norm_eps=1e-05, + norm_num_groups=32, + norm_type="layer_norm", + num_attention_heads=None, + num_class_embeds=None, + only_cross_attention=False, + projection_class_embeddings_input_dim=None, + resnet_out_scale_factor=1.0, + resnet_skip_time_act=False, + reverse_transformer_layers_per_block=None, + sample_size=None, + skip_cond_emb=False, + time_cond_proj_dim=None, + time_embedding_act_fn=None, + time_embedding_type="positional", + timestep_post_act=None, + upcast_attention=False, + use_linear_projection=False, + is_temporal=None, + nest_ratio=None, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + cross_attention_dim=cross_attention_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + down_block_types=down_block_types, + up_block_types=up_block_types, + mid_block_type=mid_block_type, + nesting=nesting, + flip_sin_to_cos=flip_sin_to_cos, + transformer_layers_per_block=transformer_layers_per_block, + layers_per_block=layers_per_block, + masked_cross_attention=masked_cross_attention, + micro_conditioning_scale=micro_conditioning_scale, + addition_embed_type=addition_embed_type, + time_embedding_dim=time_embedding_dim, + temporal_mode=temporal_mode, + temporal_spatial_ds=temporal_spatial_ds, + use_attention_ffn=use_attention_ffn, + ) # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim if "inner_config" not in self.config.inner_config: @@ -3261,26 +3345,17 @@ def __init__(self, skip_inner_unet_input, initialize_inner_with_pretrained, skip padding=1, ) - self.register_to_config(is_temporal=[self.config.temporal_mode and (not self.config.temporal_spatial_ds)]) - if hasattr(self.inner_unet.config, "is_temporal"): - self.register_to_config(is_temporal=self.config.is_temporal + self.inner_unet.config.is_temporal) + self.is_temporal = [self.config.temporal_mode and (not self.config.temporal_spatial_ds)] + if hasattr(self.inner_unet, "is_temporal"): + self.is_temporal = self.is_temporal + self.inner_unet.is_temporal nest_ratio = int(2 ** (len(self.config.block_out_channels) - 1)) - if self.config.is_temporal[0]: + if self.is_temporal[0]: nest_ratio = int(np.sqrt(nest_ratio)) if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": - self.register_to_config( - nest_ratio=[nest_ratio * self.inner_unet.config.nest_ratio[0]] + self.inner_unet.config.nest_ratio - ) + self.nest_ratio=[nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio else: - self.register_to_config(nest_ratio=[nest_ratio]) - - if self.config.initialize_inner_with_pretrained is not None: - try: - self.inner_unet.from_pretrained(self.config.initialize_inner_with_pretrained) - except Exception as e: - print("<-- load pretrained checkpoint error -->") - print(f"{e}") + self.nest_ratio=[nest_ratio] # self.register_modules(inner_unet=self.inner_unet) From 0144e140bac00e8499c24cb656daa77c17c8720d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 28 Sep 2024 15:58:40 +0300 Subject: [PATCH 063/109] Cleansing --- examples/community/matryoshka.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 85fbb8c356ad..a02880e92ca8 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3367,9 +3367,6 @@ def __init__( def model_type(self): return "nested_unet" - def forward_conditioning(self, *args, **kwargs): - return self.inner_unet.forward_conditioning(*args, **kwargs) - def forward( self, sample: torch.Tensor, From 5e2e939573fbd73f43ec642cbf024326b0e83dda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 28 Sep 2024 17:09:25 +0300 Subject: [PATCH 064/109] Clean up the `NestedUNet2DConditionModel` constructor --- examples/community/matryoshka.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index a02880e92ca8..2f9d1e0ecced 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3259,7 +3259,6 @@ def __init__( temporal_spatial_ds=False, initialize_inner_with_pretrained=None, use_attention_ffn=False, - inner_config={}, act_fn="silu", addition_embed_type_num_heads=64, addition_time_embed_dim=None, @@ -3300,7 +3299,8 @@ def __init__( use_linear_projection=False, is_temporal=None, nest_ratio=None, - ): + inner_config={}, + ): super().__init__( in_channels=in_channels, out_channels=out_channels, @@ -3353,16 +3353,12 @@ def __init__( if self.is_temporal[0]: nest_ratio = int(np.sqrt(nest_ratio)) if self.inner_unet.config.nesting and self.inner_unet.model_type == "nested_unet": - self.nest_ratio=[nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio + self.nest_ratio = [nest_ratio * self.inner_unet.nest_ratio[0]] + self.inner_unet.nest_ratio else: - self.nest_ratio=[nest_ratio] + self.nest_ratio = [nest_ratio] # self.register_modules(inner_unet=self.inner_unet) - # if self.config.interp_conditioning: # Seems False for all cases - # self.interp_layer1 = nn.Linear(self.temporal_dim // 4, self.temporal_dim) - # self.interp_layer2 = nn.Linear(self.temporal_dim, self.temporal_dim) - @property def model_type(self): return "nested_unet" From 5fbba0e22e23922cd8468c27af064327fd96406a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 28 Sep 2024 17:51:57 +0300 Subject: [PATCH 065/109] No need for VAE --- examples/community/matryoshka.py | 35 +++++++++----------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 2f9d1e0ecced..7643e8ef9092 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3707,7 +3707,7 @@ class MatryoshkaPipeline( A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ - model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] @@ -3717,13 +3717,13 @@ def __init__( text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, unet: MatryoshkaUNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, + scheduler: MatryoshkaDDIMScheduler, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): - super().__init__(..., in_channels=3, out_channels=3) + super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( @@ -3798,8 +3798,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor(do_resize=False) self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( @@ -4099,17 +4098,6 @@ def run_safety_checker(self, image, device, dtype): ) return image, has_nsfw_concept - def decode_latents(self, latents): - deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) - - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -4200,8 +4188,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype shape = ( batch_size, num_channels_latents, - int(height) // self.vae_scale_factor, - int(width) // self.vae_scale_factor, + int(height), + int(width), ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -4321,9 +4309,9 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + height (`int`, *optional*, defaults to `self.unet.config.sample_size`): The height in pixels of the generated image. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + width (`int`, *optional*, defaults to `self.unet.config.sample_size`): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -4421,8 +4409,8 @@ def __call__( callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 0. Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size # to deal with lora scaling and other possible forward hooks # 1. Check inputs. Raise error if not correct @@ -4579,9 +4567,6 @@ def __call__( xm.mark_step() if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents From 293457046b2df068b12c957c9df53dccfc4dfd34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 28 Sep 2024 17:52:10 +0300 Subject: [PATCH 066/109] Up --- examples/community/matryoshka.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 7643e8ef9092..8b45a74cf6ee 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -188,7 +188,6 @@ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: @dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->MatryoshkaDDIM class MatryoshkaDDIMSchedulerOutput(BaseOutput): """ Output class for the scheduler's `step` function output. @@ -202,8 +201,8 @@ class MatryoshkaDDIMSchedulerOutput(BaseOutput): `pred_original_sample` can be used to preview progress or for guidance. """ - prev_sample: list[torch.Tensor] - pred_original_sample: Optional[list[torch.Tensor]] = None + prev_sample: Union[torch.Tensor, List[torch.Tensor]] + pred_original_sample: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar @@ -251,6 +250,7 @@ def alpha_bar_fn(t): return torch.tensor(betas, dtype=torch.float32) +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) @@ -337,7 +337,6 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin): [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ - _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config From f2f2f9c037de44413ca0441e8622fb2a096d3457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 28 Sep 2024 17:53:47 +0300 Subject: [PATCH 067/109] style --- examples/community/matryoshka.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 8b45a74cf6ee..9c16de5f6d44 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -55,7 +55,6 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import ( USE_PEFT_BACKEND, @@ -4566,7 +4565,7 @@ def __call__( xm.mark_step() if not output_type == "latent": - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image, has_nsfw_concept = self.run_safety_checker(latents, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None From 6ed3d6336d799639ad48f16ea02a9e9c8011bf23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 29 Sep 2024 16:50:04 +0300 Subject: [PATCH 068/109] Refactor `MatryoshkaUNet2DConditionModel` and `MatryoshkaPipeline` - Update `MatryoshkaUNet2DConditionModel` to include `cond_emb` in `get_aug_embed` method. - Remove the last timestep from `timesteps` in `MatryoshkaPipeline`. - Adjust the timestep index in `MatryoshkaPipeline` to `t - 1` instead of `t`. --- examples/community/matryoshka.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 9c16de5f6d44..e1994273ae15 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3052,7 +3052,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb, encoder_attention_mask, _ = self.get_aug_embed( + aug_emb, encoder_attention_mask, cond_emb = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) if self.config.addition_embed_type == "image_hint": @@ -4477,6 +4477,7 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) + timesteps = timesteps[:-1] # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -4524,7 +4525,7 @@ def __call__( # predict the noise residual noise_pred = self.unet( latent_model_input, - t, + t - 1, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, From 1df22a64a870500cb3eea85b1988e50ebb4b9f9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 29 Sep 2024 19:03:12 +0300 Subject: [PATCH 069/109] Remove safety checker --- examples/community/matryoshka.py | 70 ++++++++++---------------------- 1 file changed, 21 insertions(+), 49 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index e1994273ae15..0747b2efcdcc 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -68,6 +68,7 @@ unscale_lora_layers, ) from diffusers.utils.torch_utils import apply_freeu, randn_tensor +from PIL import Image if is_torch_xla_available(): @@ -2018,8 +2019,8 @@ class conditioning with `class_embed_type` equal to `None`. def __init__( self, sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, + in_channels: int = 3, + out_channels: int = 3, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, @@ -3666,6 +3667,19 @@ def forward( return NestedUNet2DConditionOutput(sample=out) +@dataclass +class MatryoshkaPipelineOutput(BaseOutput): + """ + Output class for Matryoshka pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[Image.Image], List[List[Image.Image]], np.ndarray, List[np.ndarray]] + class MatryoshkaPipeline( DiffusionPipeline, StableDiffusionMixin, @@ -3716,10 +3730,8 @@ def __init__( tokenizer: T5TokenizerFast, unet: MatryoshkaUNet2DConditionModel, scheduler: MatryoshkaDDIMScheduler, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, + feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, - requires_safety_checker: bool = True, ): super().__init__() @@ -3750,21 +3762,6 @@ def __init__( new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version @@ -3792,12 +3789,10 @@ def __init__( tokenizer=tokenizer, unet=unet, scheduler=scheduler, - safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.image_processor = VaeImageProcessor(do_resize=False) - self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( self, @@ -4082,20 +4077,6 @@ def prepare_ip_adapter_image_embeds( return ip_adapter_image_embeds - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -4565,23 +4546,14 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - if not output_type == "latent": - image, has_nsfw_concept = self.run_safety_checker(latents, device, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + image = latents - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: - return (image, has_nsfw_concept) + return (image,) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return MatryoshkaPipelineOutput(images=image) From a4be940d7e41c1fe19b86c0f0ceb64d7e49e0630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 29 Sep 2024 19:03:57 +0300 Subject: [PATCH 070/109] style --- examples/community/matryoshka.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0747b2efcdcc..874c2e331533 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -9,6 +9,7 @@ import torch import torch.utils.checkpoint from packaging import version +from PIL import Image from torch import nn from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast @@ -53,8 +54,6 @@ from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D from diffusers.models.upsampling import Upsample2D from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import ( USE_PEFT_BACKEND, @@ -68,7 +67,6 @@ unscale_lora_layers, ) from diffusers.utils.torch_utils import apply_freeu, randn_tensor -from PIL import Image if is_torch_xla_available(): @@ -3680,6 +3678,7 @@ class MatryoshkaPipelineOutput(BaseOutput): images: Union[List[Image.Image], List[List[Image.Image]], np.ndarray, List[np.ndarray]] + class MatryoshkaPipeline( DiffusionPipeline, StableDiffusionMixin, @@ -3762,7 +3761,6 @@ def __init__( new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") From ba39b8d2b0543c19d5a0c076c11cbc7b56ac5d42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 29 Sep 2024 20:54:43 +0300 Subject: [PATCH 071/109] Refactor 'NestedUNet2DConditionModel' to add 'sample_size' parameter at initialization --- examples/community/matryoshka.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 874c2e331533..0a4d3a97e316 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3318,6 +3318,7 @@ def __init__( temporal_mode=temporal_mode, temporal_spatial_ds=temporal_spatial_ds, use_attention_ffn=use_attention_ffn, + sample_size=sample_size, ) # self.config.inner_config.conditioning_feature_dim = self.config.conditioning_feature_dim From 319a4d630915802a2fcba341ad339c378cc73555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 30 Sep 2024 14:38:53 +0300 Subject: [PATCH 072/109] Up --- examples/community/matryoshka.py | 60 +++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0a4d3a97e316..bd03afc4e4eb 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -7,6 +7,7 @@ import numpy as np import torch +import torch.nn.functional as F import torch.utils.checkpoint from packaging import version from PIL import Image @@ -392,6 +393,8 @@ def __init__( self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.scales = None + def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -517,7 +520,6 @@ def step( use_clipped_model_output: bool = False, generator=None, variance_noise: Optional[torch.Tensor] = None, - scales: Optional[list] = None, return_dict: bool = True, ) -> Union[MatryoshkaDDIMSchedulerOutput, Tuple]: """ @@ -579,8 +581,8 @@ def step( alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod if self.config.timestep_spacing == "matryoshka_style" and len(model_output) > 1: - alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, s) for s in scales]) - alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev, s) for s in scales]) + alpha_prod_t = torch.tensor([self.get_schedule_shifted(alpha_prod_t, s) for s in self.scales]) + alpha_prod_t_prev = torch.tensor([self.get_schedule_shifted(alpha_prod_t_prev, s) for s in self.scales]) beta_prod_t = 1 - alpha_prod_t @@ -3051,9 +3053,14 @@ def forward( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb, encoder_attention_mask, cond_emb = self.get_aug_embed( - emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs - ) + aug_emb, encoder_attention_mask, cond_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + else: + aug_emb, encoder_attention_mask, _ = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) @@ -3295,7 +3302,6 @@ def __init__( upcast_attention=False, use_linear_projection=False, is_temporal=None, - nest_ratio=None, inner_config={}, ): super().__init__( @@ -3791,6 +3797,8 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) + if hasattr(unet, "nest_ratio"): + scheduler.scales = unet.nest_ratio + [1] self.image_processor = VaeImageProcessor(do_resize=False) def _encode_prompt( @@ -4162,7 +4170,9 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents( + self, batch_size, num_channels_latents, height, width, dtype, device, generator, scales, latents=None + ): shape = ( batch_size, num_channels_latents, @@ -4177,11 +4187,25 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if scales is not None: + out = [latents] + for s in scales[1:]: + ratio = scales[0] // s + sample_low = F.avg_pool2d(latents, ratio) * ratio + sample_low = sample_low.normal_(generator=generator) + out += [sample_low] + latents = out else: - latents = latents.to(device) + if scales is not None: + latents = [latent.to(device=device) for latent in latents] + else: + latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + if scales is not None: + latents = [latent * self.scheduler.init_noise_sigma for latent in latents] + else: + latents = latents * self.scheduler.init_noise_sigma return latents # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding @@ -4469,6 +4493,7 @@ def __call__( prompt_embeds.dtype, device, generator, + self.scheduler.scales, latents, ) @@ -4499,7 +4524,12 @@ def __call__( continue # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if self.do_classifier_free_guidance and isinstance(latents, list): + latent_model_input = [latent.repeat(2, 1, 1, 1) for latent in latents] + elif self.do_classifier_free_guidance: + latent_model_input = latents.repeat(2, 1, 1, 1) + else: + latent_model_input = latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual @@ -4514,7 +4544,10 @@ def __call__( )[0] # perform guidance - if self.do_classifier_free_guidance: + if isinstance(noise_pred, list) and self.do_classifier_free_guidance: + for i, (noise_pred_uncond, noise_pred_text) in enumerate(noise_pred): + noise_pred[i] = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) @@ -4547,6 +4580,9 @@ def __call__( image = latents + if self.scheduler.scales is not None: + image = image[0] + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models From 95a293c0c5e63eaccd07ddc247ee66f772a1cf84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 30 Sep 2024 19:05:09 +0300 Subject: [PATCH 073/109] Refactor 'MatryoshkaPipeline' to process multiple images for nesting_level > 0 --- examples/community/matryoshka.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index bd03afc4e4eb..0209e18ab83d 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -4580,10 +4580,12 @@ def __call__( image = latents - if self.scheduler.scales is not None: - image = image[0] - - image = self.image_processor.postprocess(image, output_type=output_type) + # if self.scheduler.scales is not None: + # image = image[0] + images = [] + for img in image: + images.append(self.image_processor.postprocess(img, output_type=output_type)) + # image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() @@ -4591,4 +4593,4 @@ def __call__( if not return_dict: return (image,) - return MatryoshkaPipelineOutput(images=image) + return MatryoshkaPipelineOutput(images=images) From b54d9ef720be612c58a18abffa448e4b402e7d28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 30 Sep 2024 19:29:04 +0300 Subject: [PATCH 074/109] revert the last --- examples/community/matryoshka.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0209e18ab83d..bd03afc4e4eb 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -4580,12 +4580,10 @@ def __call__( image = latents - # if self.scheduler.scales is not None: - # image = image[0] - images = [] - for img in image: - images.append(self.image_processor.postprocess(img, output_type=output_type)) - # image = self.image_processor.postprocess(image, output_type=output_type) + if self.scheduler.scales is not None: + image = image[0] + + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() @@ -4593,4 +4591,4 @@ def __call__( if not return_dict: return (image,) - return MatryoshkaPipelineOutput(images=images) + return MatryoshkaPipelineOutput(images=image) From c9f17bb7f38f1e31fc87bb8e1eaee929128897fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 30 Sep 2024 21:37:07 +0300 Subject: [PATCH 075/109] Refactor 'MatryoshkaDDIMScheduler' to handle multiple model outputs for variance calculation --- examples/community/matryoshka.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index bd03afc4e4eb..97dc3912b1f4 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -660,12 +660,24 @@ def step( ) if variance_noise is None: - variance_noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype - ) - variance = std_dev_t * variance_noise + if len(model_output) > 1: + variance_noise = [] + for m_o in model_output: + variance_noise.append( + randn_tensor( + m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype + ) + ) + else: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + if len(model_output) > 1: + prev_sample = [p_s + std_dev_t * v_n for v_n, p_s in zip(variance_noise, prev_sample)] + else: + variance = std_dev_t * variance_noise - prev_sample = prev_sample + variance + prev_sample = prev_sample + variance if not return_dict: return (prev_sample,) From 67a2917cafea469282c2373d43c0dbbd6cba3d31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Oct 2024 14:18:38 +0300 Subject: [PATCH 076/109] Refactor 'MatryoshkaPipeline' to remove unused 'model_type' property --- examples/community/matryoshka.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 97dc3912b1f4..900eb1647b14 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -4282,10 +4282,6 @@ def num_timesteps(self): def interrupt(self): return self._interrupt - @property - def model_type(self): - return "nested_unet" - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( From 60e9e773dce5c617e5d6f4dd62bec3299bac363d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Oct 2024 22:31:01 +0300 Subject: [PATCH 077/109] Fix masking --- examples/community/matryoshka.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 900eb1647b14..2507abff911c 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1534,7 +1534,7 @@ def attention(self, q, k, v, num_heads, mask=None): ) # More stable with f16 than dividing afterwards if mask is not None: mask = mask.view(mask.size(0), 1, 1, mask.size(-1)).repeat(1, num_heads, 1, 1).flatten(0, 1) - weight = weight.masked_fill(mask == 0, float("-inf")) + weight = weight.masked_fill(mask == -10_000, float("-inf")) weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1)) return a.reshape(bs, -1, length) @@ -1893,7 +1893,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if conditioning_mask is None: y = encoder_hidden_states.mean(dim=1) else: - y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( + y = (conditioning_mask.unsqueeze(-1).squeeze(0) + encoder_hidden_states).sum(dim=1) / (conditioning_mask.squeeze(0)/10_000 + 1).sum( dim=1, keepdim=True ) cond_emb = self.cond_emb(y) @@ -3059,6 +3059,7 @@ def forward( added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale added_cond_kwargs["from_nested"] = from_nested + added_cond_kwargs["conditioning_mask"] = encoder_attention_mask if not from_nested: encoder_hidden_states = self.process_encoder_hidden_states( @@ -3507,6 +3508,7 @@ def forward( added_cond_kwargs = added_cond_kwargs or {} added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + added_cond_kwargs["conditioning_mask"] = encoder_attention_mask if not self.config.nesting: encoder_hidden_states = self.inner_unet.process_encoder_hidden_states( @@ -3516,7 +3518,7 @@ def forward( aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - + added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention aug_emb, cond_mask, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) @@ -3529,6 +3531,7 @@ def forward( added_cond_kwargs = added_cond_kwargs or {} added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale + added_cond_kwargs["conditioning_mask"] = encoder_attention_mask encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs @@ -4025,7 +4028,7 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) - return prompt_embeds, negative_prompt_embeds + return prompt_embeds, negative_prompt_embeds, attention_mask def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype @@ -4458,7 +4461,7 @@ def __call__( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt_embeds, negative_prompt_embeds, encoder_attention_mask = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -4548,6 +4551,7 @@ def __call__( timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, + encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] From 261f135fa8458493eb684bc9f3d0e1d6fa0586ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 2 Oct 2024 22:33:58 +0300 Subject: [PATCH 078/109] style --- examples/community/matryoshka.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 2507abff911c..1cb6d97b6d8c 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -664,9 +664,7 @@ def step( variance_noise = [] for m_o in model_output: variance_noise.append( - randn_tensor( - m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype - ) + randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype) ) else: variance_noise = randn_tensor( @@ -1893,9 +1891,9 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if conditioning_mask is None: y = encoder_hidden_states.mean(dim=1) else: - y = (conditioning_mask.unsqueeze(-1).squeeze(0) + encoder_hidden_states).sum(dim=1) / (conditioning_mask.squeeze(0)/10_000 + 1).sum( - dim=1, keepdim=True - ) + y = (conditioning_mask.unsqueeze(-1).squeeze(0) + encoder_hidden_states).sum(dim=1) / ( + conditioning_mask.squeeze(0) / 10_000 + 1 + ).sum(dim=1, keepdim=True) cond_emb = self.cond_emb(y) if not masked_cross_attention: From 25b56a87586f4194f8a1b34d82eb9473c37e9b90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 3 Oct 2024 21:02:21 +0300 Subject: [PATCH 079/109] Fix and improve mask handling --- examples/community/matryoshka.py | 44 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 1cb6d97b6d8c..4fa84350119e 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1532,7 +1532,7 @@ def attention(self, q, k, v, num_heads, mask=None): ) # More stable with f16 than dividing afterwards if mask is not None: mask = mask.view(mask.size(0), 1, 1, mask.size(-1)).repeat(1, num_heads, 1, 1).flatten(0, 1) - weight = weight.masked_fill(mask == -10_000, float("-inf")) + weight = weight.masked_fill(mask == 0, float("-inf")) weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1)) return a.reshape(bs, -1, length) @@ -1891,10 +1891,12 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if conditioning_mask is None: y = encoder_hidden_states.mean(dim=1) else: - y = (conditioning_mask.unsqueeze(-1).squeeze(0) + encoder_hidden_states).sum(dim=1) / ( - conditioning_mask.squeeze(0) / 10_000 + 1 + y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / ( + conditioning_mask ).sum(dim=1, keepdim=True) cond_emb = self.cond_emb(y) + else: + cond_emb = None if not masked_cross_attention: conditioning_mask = None @@ -1904,10 +1906,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype)) temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype)) if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False): - cond_emb_micro = cond_emb + temb_micro_conditioning - return cond_emb_micro, conditioning_mask, cond_emb - else: - return temb_micro_conditioning, conditioning_mask, None + return temb_micro_conditioning, conditioning_mask, cond_emb return cond_emb, conditioning_mask, cond_emb @@ -3033,11 +3032,6 @@ def forward( attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -3072,6 +3066,11 @@ def forward( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) @@ -3932,16 +3931,16 @@ def encode_prompt( ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) + prompt_attention_mask = text_inputs.attention_mask.to(device) else: - attention_mask = None + prompt_attention_mask = None if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into @@ -4002,13 +4001,13 @@ def encode_prompt( ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) else: - attention_mask = None + negative_prompt_attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), - attention_mask=attention_mask, + attention_mask=negative_prompt_attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] @@ -4026,7 +4025,7 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) - return prompt_embeds, negative_prompt_embeds, attention_mask + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype @@ -4459,7 +4458,7 @@ def __call__( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds, encoder_attention_mask = self.encode_prompt( + prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask = self.encode_prompt( prompt, device, num_images_per_prompt, @@ -4476,6 +4475,7 @@ def __call__( # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( @@ -4549,7 +4549,7 @@ def __call__( timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, - encoder_attention_mask=encoder_attention_mask, + encoder_attention_mask=attention_masks, return_dict=False, )[0] From 38c5455de98bb8b59cb6246ade59453df9240eb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 3 Oct 2024 23:32:39 +0300 Subject: [PATCH 080/109] style --- examples/community/matryoshka.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 4fa84350119e..68d2e18328ab 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1891,9 +1891,9 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if conditioning_mask is None: y = encoder_hidden_states.mean(dim=1) else: - y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / ( - conditioning_mask - ).sum(dim=1, keepdim=True) + y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / (conditioning_mask).sum( + dim=1, keepdim=True + ) cond_emb = self.cond_emb(y) else: cond_emb = None @@ -4458,7 +4458,12 @@ def __call__( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) - prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask = self.encode_prompt( + ( + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) = self.encode_prompt( prompt, device, num_images_per_prompt, From 3ada184c69da4ece7fffcdd316f7a251568c7032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 4 Oct 2024 18:37:30 +0300 Subject: [PATCH 081/109] style --- examples/community/matryoshka.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 68d2e18328ab..92f9e46fff32 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1891,7 +1891,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if conditioning_mask is None: y = encoder_hidden_states.mean(dim=1) else: - y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / (conditioning_mask).sum( + y = (conditioning_mask.unsqueeze(-1) * encoder_hidden_states).sum(dim=1) / conditioning_mask.sum( dim=1, keepdim=True ) cond_emb = self.cond_emb(y) From 0efa0aec3d40c880352ca7dc5e946614c3697902 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 4 Oct 2024 18:37:46 +0300 Subject: [PATCH 082/109] Fix mask handling --- examples/community/matryoshka.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 92f9e46fff32..550065579808 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -4025,6 +4025,8 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) + if not do_classifier_free_guidance: + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, None return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): @@ -4481,6 +4483,9 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1) + else: + attention_masks = prompt_attention_mask if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( From 5737d95c4cb6883a3445fb72c22b74d521d10d3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 4 Oct 2024 19:43:27 +0300 Subject: [PATCH 083/109] Refactor attention mask handling in Matryoshka models --- examples/community/matryoshka.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 550065579808..3d64bc876ad7 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3481,11 +3481,6 @@ def forward( attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -3542,6 +3537,11 @@ def forward( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) @@ -4483,10 +4483,11 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) - prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1) else: attention_masks = prompt_attention_mask + prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, From b691f16e3780a6cb8c9f18681d760a7dca14c67f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 5 Oct 2024 00:15:06 +0300 Subject: [PATCH 084/109] Refactor attention mask handling in Matryoshka models --- examples/community/matryoshka.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 3d64bc876ad7..ac66f2bcd218 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1905,8 +1905,8 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): if micro is not None: temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype)) temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype)) - if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False): - return temb_micro_conditioning, conditioning_mask, cond_emb + # if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False): + return temb_micro_conditioning, conditioning_mask, cond_emb return cond_emb, conditioning_mask, cond_emb @@ -3507,11 +3507,11 @@ def forward( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.get_aug_embed( + aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention - aug_emb, cond_mask, _ = self.get_aug_embed( + aug_emb, __, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) else: @@ -3623,7 +3623,7 @@ def forward( timestep, cond_emb=cond_emb, encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=cond_mask_inner_unet, + encoder_attention_mask=cond_mask, from_nested=True, ) x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner From ccdee3586a0ecf40535a8d317d341d3e4a15eb14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 5 Oct 2024 22:57:37 +0300 Subject: [PATCH 085/109] Fix mask handling for `nesting_level=2` --- examples/community/matryoshka.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index ac66f2bcd218..6b04043560ff 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3515,7 +3515,7 @@ def forward( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) else: - aug_emb, cond_mask_inner_unet, _ = self.get_aug_embed( + aug_emb, cond_mask, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) @@ -3529,11 +3529,11 @@ def forward( encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.inner_unet.get_aug_embed( + aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) - aug_emb, cond_mask, _ = self.get_aug_embed( + aug_emb, __, _ = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) From 2efe7b0f7c15be4991d42e591d785ed8aeef47b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 5 Oct 2024 22:59:39 +0300 Subject: [PATCH 086/109] Attempt for scheduler usage generalization --- examples/community/matryoshka.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 6b04043560ff..e2937c13b7b6 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -4498,10 +4498,13 @@ def __call__( ) # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - timesteps = timesteps[:-1] + if isinstance(self.scheduler, MatryoshkaDDIMScheduler): + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + timesteps = timesteps[:-1] + else: + timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels @@ -4577,7 +4580,13 @@ def __call__( noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if self.scheduler.scales is not None and not isinstance(self.scheduler, MatryoshkaDDIMScheduler): + latents[0] = self.scheduler.step(noise_pred[0], t, latents[0], **extra_step_kwargs, return_dict=False)[0] + latents[1] = self.scheduler.inner_scheduler.step(noise_pred[1], t, latents[1], **extra_step_kwargs, return_dict=False)[0] + if len(latents) > 2: + latents[2] = self.scheduler.inner_scheduler.inner_scheduler.step(noise_pred[2], t, latents[2], **extra_step_kwargs, return_dict=False)[0] + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} From 5c90be9db6dbd26a5c285f37b0d8585b2a9422dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 5 Oct 2024 23:00:17 +0300 Subject: [PATCH 087/109] Equalize tokenizer usage fully --- examples/community/matryoshka.py | 99 +++++++++++++++++--------------- 1 file changed, 54 insertions(+), 45 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index e2937c13b7b6..b9ca3029423c 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3911,9 +3911,6 @@ def encode_prompt( text_inputs = self.tokenizer( prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids @@ -3935,23 +3932,6 @@ def encode_prompt( else: prompt_attention_mask = None - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: @@ -3959,13 +3939,6 @@ def encode_prompt( else: prompt_embeds_dtype = prompt_embeds.dtype - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] @@ -3991,34 +3964,70 @@ def encode_prompt( if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, return_tensors="pt", ) + uncond_input_ids = uncond_input.input_ids if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: negative_prompt_attention_mask = uncond_input.attention_mask.to(device) else: negative_prompt_attention_mask = None - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=negative_prompt_attention_mask, + if not do_classifier_free_guidance: + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + else: + max_len = max(len(text_input_ids[0]), len(uncond_input_ids[0])) + if len(text_input_ids[0]) < max_len: + text_input_ids = torch.cat( + [text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)], dim=1 + ) + prompt_attention_mask = torch.cat( + [ + prompt_attention_mask, + torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long), + ], + dim=1, + ) + elif len(uncond_input_ids[0]) < max_len: + uncond_input_ids = torch.cat( + [uncond_input_ids, torch.zeros(batch_size, max_len - len(uncond_input_ids[0]), dtype=torch.long)], + dim=1, + ) + negative_prompt_attention_mask = torch.cat( + [ + negative_prompt_attention_mask, + torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long), + ], + dim=1, + ) + cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0) + cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + prompt_embeds = self.text_encoder( + cfg_input_ids.to(device), + attention_mask=cfg_attention_mask, ) - negative_prompt_embeds = negative_prompt_embeds[0] + prompt_embeds = prompt_embeds[0] - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] + seq_len = prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) if self.text_encoder is not None: if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: @@ -4026,8 +4035,8 @@ def encode_prompt( unscale_lora_layers(self.text_encoder, lora_scale) if not do_classifier_free_guidance: - return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, None - return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask + return prompt_embeds, None, prompt_attention_mask, None + return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype @@ -4481,7 +4490,7 @@ def __call__( # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)]) attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) else: attention_masks = prompt_attention_mask From 31c73fa6e1d92ef4a89c3189deff57c67059e13d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 5 Oct 2024 23:01:22 +0300 Subject: [PATCH 088/109] style --- examples/community/matryoshka.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index b9ca3029423c..78565476cc96 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3996,7 +3996,8 @@ def encode_prompt( max_len = max(len(text_input_ids[0]), len(uncond_input_ids[0])) if len(text_input_ids[0]) < max_len: text_input_ids = torch.cat( - [text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)], dim=1 + [text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)], + dim=1, ) prompt_attention_mask = torch.cat( [ @@ -4025,8 +4026,6 @@ def encode_prompt( ) prompt_embeds = prompt_embeds[0] - seq_len = prompt_embeds.shape[1] - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) if self.text_encoder is not None: @@ -4590,10 +4589,16 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 if self.scheduler.scales is not None and not isinstance(self.scheduler, MatryoshkaDDIMScheduler): - latents[0] = self.scheduler.step(noise_pred[0], t, latents[0], **extra_step_kwargs, return_dict=False)[0] - latents[1] = self.scheduler.inner_scheduler.step(noise_pred[1], t, latents[1], **extra_step_kwargs, return_dict=False)[0] + latents[0] = self.scheduler.step( + noise_pred[0], t, latents[0], **extra_step_kwargs, return_dict=False + )[0] + latents[1] = self.scheduler.inner_scheduler.step( + noise_pred[1], t, latents[1], **extra_step_kwargs, return_dict=False + )[0] if len(latents) > 2: - latents[2] = self.scheduler.inner_scheduler.inner_scheduler.step(noise_pred[2], t, latents[2], **extra_step_kwargs, return_dict=False)[0] + latents[2] = self.scheduler.inner_scheduler.inner_scheduler.step( + noise_pred[2], t, latents[2], **extra_step_kwargs, return_dict=False + )[0] else: latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] From a21e1104c8f67b79f3d7e5f03bbd933a858535ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 6 Oct 2024 14:14:23 +0300 Subject: [PATCH 089/109] Up --- examples/community/matryoshka.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 78565476cc96..c4137fc84eb5 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3750,6 +3750,7 @@ def __init__( scheduler: MatryoshkaDDIMScheduler, feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, + trust_remote_code: bool = False, ): super().__init__() @@ -4002,7 +4003,7 @@ def encode_prompt( prompt_attention_mask = torch.cat( [ prompt_attention_mask, - torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long), + torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device), ], dim=1, ) @@ -4014,7 +4015,7 @@ def encode_prompt( negative_prompt_attention_mask = torch.cat( [ negative_prompt_attention_mask, - torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long), + torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long, device=device), ], dim=1, ) From bc073fcf13768778513c1212421cbabfc77d0690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 7 Oct 2024 11:28:50 +0300 Subject: [PATCH 090/109] Refactor `matryoshka.py` to include proper licensing and attribution --- examples/community/matryoshka.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index c4137fc84eb5..4ae1bfb73aa8 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1,4 +1,23 @@ -# #TODO Licensed under the Apache License, Version 2.0 or MIT? +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Based on [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111). +# Authors: Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind, Navdeep Jaitly +# Code: https://github.com/apple/ml-mdm with MIT license +# +# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). + import inspect import math From 33edbdd08616f89d9c59006051a918c0ba967b46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 7 Oct 2024 11:29:47 +0300 Subject: [PATCH 091/109] Refactor `matryoshka.py` to remove deprecated `_encode_prompt()` method --- examples/community/matryoshka.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 4ae1bfb73aa8..d0353a84f1a9 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3833,38 +3833,6 @@ def __init__( scheduler.scales = unet.nest_ratio + [1] self.image_processor = VaeImageProcessor(do_resize=False) - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - def encode_prompt( self, prompt, From 96a788cb716c46ae96583674123cce7aa6f8e034 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 7 Oct 2024 11:30:21 +0300 Subject: [PATCH 092/109] Refactor `matryoshka.py` to include nesting levels for the UNet model --- examples/community/matryoshka.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index d0353a84f1a9..c6fbed5a470e 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3765,11 +3765,12 @@ def __init__( self, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, - unet: MatryoshkaUNet2DConditionModel, scheduler: MatryoshkaDDIMScheduler, + unet: MatryoshkaUNet2DConditionModel = None, feature_extractor: CLIPImageProcessor = None, image_encoder: CLIPVisionModelWithProjection = None, trust_remote_code: bool = False, + nesting_level: int = 0, ): super().__init__() @@ -3800,10 +3801,10 @@ def __init__( new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version + is_unet_version_less_0_9_0 = hasattr(unet[0].config, "_diffusers_version") and version.parse( + version.parse(unet[0].config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_sample_size_less_64 = hasattr(unet[0].config, "sample_size") and unet[0].config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -3821,6 +3822,18 @@ def __init__( new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + if nesting_level == 0: + unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_0") + elif nesting_level == 1: + unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_1") + elif nesting_level == 2: + unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_2") + else: + raise ValueError("Nesting level should be 0, 1 or 2") + self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, From d4f2911157ee926ea5f506767818e1d497fecd99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 7 Oct 2024 12:13:38 +0300 Subject: [PATCH 093/109] Up --- examples/community/matryoshka.py | 66 ++++++++++++++++---------------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index c6fbed5a470e..e43abc239388 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3726,7 +3726,7 @@ class MatryoshkaPipeline( FromSingleFileMixin, ): r""" - Pipeline for text-to-image generation using Stable Diffusion. + Pipeline for text-to-image generation using Matryoshka Diffusion Models. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -3739,21 +3739,17 @@ class MatryoshkaPipeline( - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. + text_encoder ([`~transformers.T5EncoderModel`]): + Frozen text-encoder ([flan-t5-xl](https://huggingface.co/google/flan-t5-xl)). + tokenizer ([`~transformers.T5Tokenizer`]): + A `T5Tokenizer` to tokenize text. unet ([`MatryoshkaUNet2DConditionModel`]): A `MatryoshkaUNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details - about a model's potential harms. - feature_extractor ([`~transformers.CLIPImageProcessor`]): - A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + [`MatryoshkaDDIMScheduler`] and other schedulers with proper modifications, see an example usage in README.md. + feature_extractor ([`~transformers.`]): + A `AnImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. """ model_cpu_offload_seq = "text_encoder->image_encoder->unet" @@ -3774,6 +3770,18 @@ def __init__( ): super().__init__() + if nesting_level == 0: + unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_0") + elif nesting_level == 1: + unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_1") + elif nesting_level == 2: + unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_2") + else: + raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.") + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" @@ -3801,10 +3809,10 @@ def __init__( new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) - is_unet_version_less_0_9_0 = hasattr(unet[0].config, "_diffusers_version") and version.parse( - version.parse(unet[0].config._diffusers_version).base_version + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet[0].config, "sample_size") and unet[0].config.sample_size < 64 + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -3822,18 +3830,6 @@ def __init__( new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) - if nesting_level == 0: - unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_0") - elif nesting_level == 1: - unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_1") - elif nesting_level == 2: - unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_2") - else: - raise ValueError("Nesting level should be 0, 1 or 2") - self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, @@ -3924,7 +3920,7 @@ def encode_prompt( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" + "The following part of your input was truncated because FLAN-T5-XL for this pipeline can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) @@ -4403,8 +4399,8 @@ def __call__( Examples: Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + [`~MatryoshkaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~MatryoshkaPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. @@ -4511,10 +4507,11 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) - timesteps = timesteps[:-1] else: timesteps = self.scheduler.timesteps + timesteps = timesteps[:-1] + # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( @@ -4626,9 +4623,10 @@ def __call__( image = latents if self.scheduler.scales is not None: - image = image[0] - - image = self.image_processor.postprocess(image, output_type=output_type) + for i in range(len(image)): + image[i] = self.image_processor.postprocess(image[i], output_type=output_type) + else: + image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From 3bc6f80fe9501238e7a5b4733bd4d4fcd6eb2bb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 8 Oct 2024 16:58:59 +0300 Subject: [PATCH 094/109] Fix scaling issue for high resolutions --- examples/community/matryoshka.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index e43abc239388..989f02ad75d8 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -632,14 +632,14 @@ def step( # 4. Clip or threshold "predicted x_0" if self.config.thresholding: if len(model_output) > 1: - pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample] + pred_original_sample = [self._threshold_sample(p_o_s * scale) / scale for p_o_s, scale in zip(pred_original_sample, self.scales)] else: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: if len(model_output) > 1: pred_original_sample = [ - p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) - for p_o_s in pred_original_sample + (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale + for p_o_s, scale in zip(pred_original_sample, self.scales) ] else: pred_original_sample = pred_original_sample.clamp( @@ -4624,6 +4624,7 @@ def __call__( if self.scheduler.scales is not None: for i in range(len(image)): + image[i] = image[i] * self.scheduler.scales[i] image[i] = self.image_processor.postprocess(image[i], output_type=output_type) else: image = self.image_processor.postprocess(image, output_type=output_type) From e4259acc0227e096fbb7273bef9725dd27c5f5e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 8 Oct 2024 20:51:57 +0300 Subject: [PATCH 095/109] Add `self.change_nesting_level(int)` function --- examples/community/matryoshka.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 989f02ad75d8..b94f0c7ee786 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -20,6 +20,7 @@ import inspect +import gc import math from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -3830,6 +3831,9 @@ def __init__( new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + if hasattr(unet, "nest_ratio"): + scheduler.scales = unet.nest_ratio + [1] + self.register_modules( text_encoder=text_encoder, tokenizer=tokenizer, @@ -3838,10 +3842,32 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - if hasattr(unet, "nest_ratio"): - scheduler.scales = unet.nest_ratio + [1] + self.register_to_config(nesting_level=nesting_level) self.image_processor = VaeImageProcessor(do_resize=False) + def change_nesting_level(self, nesting_level: int): + if nesting_level == 0: + if hasattr(self.unet, "nest_ratio"): + self.scheduler.scales = None + self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_0").to(self.device) + self.config.nesting_level = 0 + elif nesting_level == 1: + self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_1").to(self.device) + self.config.nesting_level = 1 + self.scheduler.scales = self.unet.nest_ratio + [1] + elif nesting_level == 2: + self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + subfolder="unet/nesting_level_2").to(self.device) + self.config.nesting_level = 2 + self.scheduler.scales = self.unet.nest_ratio + [1] + else: + raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.") + + gc.collect() + torch.cuda.empty_cache() + def encode_prompt( self, prompt, From 942c54afff21705694fd3ea1e7bbc23924655d75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 10 Oct 2024 21:29:18 +0300 Subject: [PATCH 096/109] Refactor `matryoshka.py` to handle multiple model outputs in `MatryoshkaDDIMScheduler` for `use_clipped_model_output` param --- examples/community/matryoshka.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index b94f0c7ee786..b55bedfcf829 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -654,7 +654,12 @@ def step( if use_clipped_model_output: # the pred_epsilon is always re-derived from the clipped x_0 in Glide - pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + if len(model_output) > 1: + pred_epsilon = [] + for s, a_p_t, p_o_s, b_p_t in zip(sample, alpha_prod_t, pred_original_sample, beta_prod_t): + pred_epsilon.append((s - a_p_t ** (0.5) * p_o_s) / b_p_t ** (0.5)) + else: + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if len(model_output) > 1: From 737bca026309555ab178ed8b707ffed91e2ccb15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 10 Oct 2024 21:30:09 +0300 Subject: [PATCH 097/109] This model uses this. --- examples/community/matryoshka.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index b55bedfcf829..77ea694d6a14 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3802,18 +3802,18 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) + # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + # deprecation_message = ( + # f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + # " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + # " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + # " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + # " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + # ) + # deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + # new_config = dict(scheduler.config) + # new_config["clip_sample"] = False + # scheduler._internal_dict = FrozenDict(new_config) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version From aabba0802b4bb8473bb28c1321809ec0274c3094 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 10 Oct 2024 21:31:23 +0300 Subject: [PATCH 098/109] Move `extra_step_kwargs` --- examples/community/matryoshka.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 77ea694d6a14..eae8c5f96a5a 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -4533,11 +4533,16 @@ def __call__( self.do_classifier_free_guidance, ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 4. Prepare timesteps if isinstance(self.scheduler, MatryoshkaDDIMScheduler): timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) + + extra_step_kwargs |= {"use_clipped_model_output": True} else: timesteps = self.scheduler.timesteps @@ -4557,9 +4562,6 @@ def __call__( latents, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} From 360f57ed1a13e68c64e6710c944ffffa18b1054a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 10 Oct 2024 21:36:46 +0300 Subject: [PATCH 099/109] style --- examples/community/matryoshka.py | 49 +++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index eae8c5f96a5a..d5ee39d4a440 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -19,8 +19,8 @@ # Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz). -import inspect import gc +import inspect import math from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -633,7 +633,10 @@ def step( # 4. Clip or threshold "predicted x_0" if self.config.thresholding: if len(model_output) > 1: - pred_original_sample = [self._threshold_sample(p_o_s * scale) / scale for p_o_s, scale in zip(pred_original_sample, self.scales)] + pred_original_sample = [ + self._threshold_sample(p_o_s * scale) / scale + for p_o_s, scale in zip(pred_original_sample, self.scales) + ] else: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: @@ -3777,14 +3780,17 @@ def __init__( super().__init__() if nesting_level == 0: - unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_0") + unet = MatryoshkaUNet2DConditionModel.from_pretrained( + "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_0" + ) elif nesting_level == 1: - unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_1") + unet = NestedUNet2DConditionModel.from_pretrained( + "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_1" + ) elif nesting_level == 2: - unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_2") + unet = NestedUNet2DConditionModel.from_pretrained( + "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2" + ) else: raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.") @@ -3854,17 +3860,20 @@ def change_nesting_level(self, nesting_level: int): if nesting_level == 0: if hasattr(self.unet, "nest_ratio"): self.scheduler.scales = None - self.unet = MatryoshkaUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_0").to(self.device) + self.unet = MatryoshkaUNet2DConditionModel.from_pretrained( + "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_0" + ).to(self.device) self.config.nesting_level = 0 elif nesting_level == 1: - self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_1").to(self.device) + self.unet = NestedUNet2DConditionModel.from_pretrained( + "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_1" + ).to(self.device) self.config.nesting_level = 1 self.scheduler.scales = self.unet.nest_ratio + [1] elif nesting_level == 2: - self.unet = NestedUNet2DConditionModel.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - subfolder="unet/nesting_level_2").to(self.device) + self.unet = NestedUNet2DConditionModel.from_pretrained( + "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2" + ).to(self.device) self.config.nesting_level = 2 self.scheduler.scales = self.unet.nest_ratio + [1] else: @@ -4030,7 +4039,9 @@ def encode_prompt( prompt_attention_mask = torch.cat( [ prompt_attention_mask, - torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device), + torch.zeros( + batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long, device=device + ), ], dim=1, ) @@ -4042,7 +4053,12 @@ def encode_prompt( negative_prompt_attention_mask = torch.cat( [ negative_prompt_attention_mask, - torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long, device=device), + torch.zeros( + batch_size, + max_len - len(negative_prompt_attention_mask[0]), + dtype=torch.long, + device=device, + ), ], dim=1, ) @@ -4533,7 +4549,6 @@ def __call__( self.do_classifier_free_guidance, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 4. Prepare timesteps From 83262f83ef0666556bd9468933478739cd889ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 11 Oct 2024 15:59:44 +0300 Subject: [PATCH 100/109] Refactor optional components in `MatryoshkaPipeline` --- examples/community/matryoshka.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index d5ee39d4a440..860bc1f064f9 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3762,8 +3762,7 @@ class MatryoshkaPipeline( """ model_cpu_offload_seq = "text_encoder->image_encoder->unet" - _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] - _exclude_from_cpu_offload = ["safety_checker"] + _optional_components = ["unet", "feature_extractor", "image_encoder"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( From e5433792efa6f41f17b87cc8295d07e0cacbbf0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 11 Oct 2024 16:00:15 +0300 Subject: [PATCH 101/109] Simplify --- examples/community/matryoshka.py | 37 +++++++++----------------------- 1 file changed, 10 insertions(+), 27 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 860bc1f064f9..2eaf6e3c564a 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -4548,18 +4548,10 @@ def __call__( self.do_classifier_free_guidance, ) - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 4. Prepare timesteps - if isinstance(self.scheduler, MatryoshkaDDIMScheduler): - timesteps, num_inference_steps = retrieve_timesteps( - self.scheduler, num_inference_steps, device, timesteps, sigmas - ) - - extra_step_kwargs |= {"use_clipped_model_output": True} - else: - timesteps = self.scheduler.timesteps - + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) timesteps = timesteps[:-1] # 5. Prepare latent variables @@ -4576,6 +4568,10 @@ def __call__( latents, ) + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + extra_step_kwargs |= {"use_clipped_model_output": True} + # 6.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} @@ -4633,19 +4629,7 @@ def __call__( noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 - if self.scheduler.scales is not None and not isinstance(self.scheduler, MatryoshkaDDIMScheduler): - latents[0] = self.scheduler.step( - noise_pred[0], t, latents[0], **extra_step_kwargs, return_dict=False - )[0] - latents[1] = self.scheduler.inner_scheduler.step( - noise_pred[1], t, latents[1], **extra_step_kwargs, return_dict=False - )[0] - if len(latents) > 2: - latents[2] = self.scheduler.inner_scheduler.inner_scheduler.step( - noise_pred[2], t, latents[2], **extra_step_kwargs, return_dict=False - )[0] - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} @@ -4670,9 +4654,8 @@ def __call__( image = latents if self.scheduler.scales is not None: - for i in range(len(image)): - image[i] = image[i] * self.scheduler.scales[i] - image[i] = self.image_processor.postprocess(image[i], output_type=output_type) + for i, (img, scale) in enumerate(zip(image, self.scheduler.scales)): + image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0] else: image = self.image_processor.postprocess(image, output_type=output_type) From bd915850374496e39e4f657dcfc9154cc437dc67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 11 Oct 2024 17:22:51 +0300 Subject: [PATCH 102/109] =?UTF-8?q?Add=20=F0=9F=AA=86Matryoshka=20Diffusio?= =?UTF-8?q?n=20Models=20to=20community=20pipelines=20in=20`Readme.md`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/community/README.md | 66 ++++++++++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 10 deletions(-) diff --git a/examples/community/README.md b/examples/community/README.md index e51124e75956..ae4a855c26b4 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -73,7 +73,8 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif | Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) | | FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) | | AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) | -| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffsuion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) | +| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) | +| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. @@ -85,17 +86,17 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion ### Flux with CFG -Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md). +Know more about Flux [here](https://blackforestlabs.ai/announcing-black-forest-labs/). Since Flux doesn't use CFG, this implementation provides one, inspired by the [PuLID Flux adaptation](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md). Example usage: ```py from diffusers import DiffusionPipeline -import torch +import torch pipeline = DiffusionPipeline.from_pretrained( - "black-forest-labs/FLUX.1-dev", - torch_dtype=torch.bfloat16, + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.bfloat16, custom_pipeline="pipeline_flux_with_cfg" ) pipeline.enable_model_cpu_offload() @@ -103,10 +104,10 @@ prompt = "a watercolor painting of a unicorn" negative_prompt = "pink" img = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - true_cfg=1.5, - guidance_scale=3.5, + prompt=prompt, + negative_prompt=negative_prompt, + true_cfg=1.5, + guidance_scale=3.5, num_images_per_prompt=1, generator=torch.manual_seed(0) ).images[0] @@ -2656,7 +2657,7 @@ image with mask mech_painted.png -result: +result: @@ -4324,6 +4325,51 @@ image = pipe( A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab. +### 🪆Matryoshka Diffusion Models + +![🪆Matryoshka Diffusion Models](https://github.com/user-attachments/assets/bf90b53b-48c3-4769-a805-d9dfe4a7c572) + +The Abstract of the paper: +>Diffusion models are the _de-facto_ approach for generating high-quality images and videos but learning high-dimensional models remains a formidable task due to computational and optimization challenges. Existing methods often resort to training cascaded models in pixel space, or using a downsampled latent space of a separately trained auto-encoder. In this paper, we introduce Matryoshka Diffusion (MDM), **a novel framework for high-resolution image and video synthesis**. We propose a diffusion process that denoises inputs at multiple resolutions jointly and uses a **NestedUNet** architecture where features and parameters for small scale inputs are nested within those of the large scales. In addition, MDM enables a progressive training schedule from lower to higher resolutions which leads to significant improvements in optimization for high-resolution generation. We demonstrate the effectiveness of our approach on various benchmarks, including class-conditioned image generation, high-resolution text-to-image, and text-to-video applications. Remarkably, we can train a **_single pixel-space model_ at resolutions of up to 1024 × 1024 pixels**, demonstrating strong zero shot generalization using the **CC12M dataset, which contains only 12 million images**. Code and pre-trained checkpoints are released at https://github.com/apple/ml-mdm. + +- `64×64, nesting_level=0`: 1.719 GiB. With `50` DDIM inference steps: + +**64x64** +:-------------------------: +| bird_64 | + +- `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps: + +**64x64** | **256x256** +:-------------------------:|:-------------------------: +| 64x64 | 256x256 | + +- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible. With `250` DDIM inference steps: + +**64x64** | **256x256** | **1024x1024** +:-------------------------:|:-------------------------:|:-------------------------: +| 64x64 | 256x256 | 1024x1024 | + +```py +from diffusers import DiffusionPipeline +from diffusers.utils import make_image_grid + +# nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64 +pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + custom_pipeline="matryoshka", + nesting_level=0, + ).to("cuda") + +prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree" +prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed" +negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy" +image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images +make_image_grid(image, rows=1, cols=len(image)) + +# pipe.change_nesting_level() # 0, 1, or 2 +# 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively. +``` + # Perturbed-Attention Guidance [Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance) From 343034580c75b864ddd6509b016224c02ee2fafc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 11 Oct 2024 18:35:16 +0300 Subject: [PATCH 103/109] Update example usage --- examples/community/matryoshka.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 2eaf6e3c564a..0df78d8bbac0 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -102,15 +102,21 @@ EXAMPLE_DOC_STRING = """ Examples: ```py - >>> import torch - >>> from diffusers import MatryoshkaPipeline + >>> from diffusers import DiffusionPipeline + >>> from diffusers.utils import make_image_grid - >>> pipe = MatryoshkaPipeline.from_pretrained("A/B", torch_dtype=torch.float16, variant="fp16") - >>> pipe = pipe.to("cuda") + >>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64 + >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models", + >>> custom_pipeline="matryoshka").to("cuda") - >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt).images[0] - >>> image + >>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree" + >>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed" + >>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy" + >>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images + >>> make_image_grid(image, rows=1, cols=len(image)) + + >>> pipe.change_nesting_level() # 0, 1, or 2 + >>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively. ``` """ From 149e8b59183fa299bed456361969f1d1cd82a52f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 13 Oct 2024 13:15:34 +0300 Subject: [PATCH 104/109] Refactor `MatryoshkaTransformerBlock` to use `MatryoshkaFusedAttnProcessor2_0` --- examples/community/matryoshka.py | 88 ++++++++++---------------------- 1 file changed, 28 insertions(+), 60 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0df78d8bbac0..99eb39a23387 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1440,7 +1440,7 @@ def __init__( bias=True, upcast_attention=upcast_attention, pre_only=True, - processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(), + processor=MatryoshkaFusedAttnProcessor2_0(), ) self.attn1.fuse_projections() del self.attn1.to_q @@ -1458,7 +1458,7 @@ def __init__( bias=True, upcast_attention=upcast_attention, pre_only=True, - processor=MatryoshkaFusedAttnProcessor1_0_or_2_0(), + processor=MatryoshkaFusedAttnProcessor2_0(), ) self.attn2.fuse_projections() del self.attn2.to_q @@ -1517,7 +1517,6 @@ def forward( # **cross_attention_kwargs, ) - attn_output_cond = attn_output_cond.permute(0, 2, 1).contiguous() attn_output_cond = self.proj_out(attn_output_cond) attn_output_cond = attn_output_cond.permute(0, 2, 1).reshape(batch_size, channels, *spatial_dims) hidden_states = hidden_states + attn_output_cond @@ -1535,7 +1534,7 @@ def forward( return hidden_states -class MatryoshkaFusedAttnProcessor1_0_or_2_0: +class MatryoshkaFusedAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused. @@ -1548,28 +1547,12 @@ class MatryoshkaFusedAttnProcessor1_0_or_2_0: """ - # def __init__(self): - # if not hasattr(F, "scaled_dot_product_attention"): - # raise ImportError( - # "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x." - # ) - - # TODO: They seem to give different results; but nevertheless can I replace this with torch.nn.functional.scaled_dot_product_attention()? - def attention(self, q, k, v, num_heads, mask=None): - bs, width, length = q.shape - ch = width // num_heads - scale = 1 / torch.sqrt(torch.sqrt(torch.tensor(ch))) - weight = torch.einsum( - "bct,bcs->bts", - (q * scale).reshape(bs * num_heads, ch, length), - (k * scale).reshape(bs * num_heads, ch, -1), - ) # More stable with f16 than dividing afterwards - if mask is not None: - mask = mask.view(mask.size(0), 1, 1, mask.size(-1)).repeat(1, num_heads, 1, 1).flatten(0, 1) - weight = weight.masked_fill(mask == 0, float("-inf")) - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * num_heads, ch, -1)) - return a.reshape(bs, -1, length) + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x." + ) + def __call__( self, @@ -1593,26 +1576,12 @@ def __call__( input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - # batch_size, sequence_length, _ = ( - # hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - # ) - - # if attention_mask is not None: - # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # # scaled_dot_product_attention expects attention_mask shape to be - # # (batch, heads, source_length, target_length) - # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states) # .transpose(1, 2)).transpose(1, 2) + hidden_states = attn.group_norm(hidden_states) - # Reshape hidden_states to 2D tensor - hidden_states = hidden_states.view(batch_size, channel, height * width).permute(0, 2, 1).contiguous() - # Now hidden_states.shape is [batch_size, height * width, channels] + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2).contiguous() if encoder_hidden_states is None: qkv = attn.to_qkv(hidden_states) @@ -1630,10 +1599,18 @@ def __call__( split_size = kv.shape[-1] // 2 key, value = torch.split(kv, split_size, dim=-1) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + if self_attention_output is None: - query = query.permute(0, 2, 1) - key = key.permute(0, 2, 1) - value = value.permute(0, 2, 1) + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) @@ -1641,25 +1618,16 @@ def __call__( key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 if F.scaled_dot_product_attention() is available - hidden_states = self.attention( - query, - key, - value, - mask=attention_mask, - num_heads=attn.heads, + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) hidden_states = hidden_states.to(query.dtype) if self_attention_output is not None: hidden_states = hidden_states + self_attention_output - - if not attn.pre_only: - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) if attn.residual_connection: hidden_states = hidden_states + residual From ecca7e33860261dffd9773e86bee2ebe1fe62d4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 13 Oct 2024 13:16:16 +0300 Subject: [PATCH 105/109] style --- examples/community/matryoshka.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 99eb39a23387..70f9c2c9da50 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1553,7 +1553,6 @@ def __init__(self): "MatryoshkaFusedAttnProcessor2_0 requires PyTorch 2.x, to use it. Please upgrade PyTorch to > 2.x." ) - def __call__( self, attn: Attention, From 6fd62e09c09a76f25eb6807f9a78bf59507031c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 13 Oct 2024 13:29:21 +0300 Subject: [PATCH 106/109] simplify --- examples/community/matryoshka.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 70f9c2c9da50..7ef1438f7204 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -1909,7 +1909,7 @@ def forward(self, emb, encoder_hidden_states, added_cond_kwargs): # if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False): return temb_micro_conditioning, conditioning_mask, cond_emb - return cond_emb, conditioning_mask, cond_emb + return None, conditioning_mask, cond_emb @dataclass @@ -3137,7 +3137,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, # cond_mask? + encoder_attention_mask=encoder_attention_mask, **additional_residuals, ) else: @@ -3167,7 +3167,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, # cond_mask? + encoder_attention_mask=encoder_attention_mask, ) else: sample = self.mid_block(sample, emb) @@ -3204,7 +3204,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, # cond_mask? + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( @@ -3652,7 +3652,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, - encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask, # cond_mask? + encoder_attention_mask=cond_mask[:bh] if cond_mask is not None else cond_mask, ) else: sample = upsample_block( From 5009be12dccad653c72d05b5df98c7a0fcf1e593 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 13 Oct 2024 20:01:41 +0300 Subject: [PATCH 107/109] Add `trust_remote_code=True` requirement for custom components --- examples/community/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/community/README.md b/examples/community/README.md index ae4a855c26b4..0043084ba6c1 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -4358,6 +4358,7 @@ from diffusers.utils import make_image_grid pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models", custom_pipeline="matryoshka", nesting_level=0, + trust_remote_code=False, # One needs to give permission for this code to run ).to("cuda") prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree" From 4c3ba487370c80159945541305e504d85f784930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 13 Oct 2024 21:28:03 +0300 Subject: [PATCH 108/109] revert --- src/diffusers/models/embeddings.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index e3fd69031075..c250df29afbe 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -734,10 +734,7 @@ def __init__( else: self.cond_proj = None - if act_fn is None: - self.act = None - else: - self.act = get_activation(act_fn) + self.act = get_activation(act_fn) if out_dim is not None: time_embed_dim_out = out_dim From 1b756d188ad68cd1fa0eb60c298dd3e6d505b4c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Mon, 14 Oct 2024 12:04:20 +0300 Subject: [PATCH 109/109] Update README.md --- examples/community/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/community/README.md b/examples/community/README.md index 0043084ba6c1..267c8f4bb904 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -4356,7 +4356,6 @@ from diffusers.utils import make_image_grid # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64 pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - custom_pipeline="matryoshka", nesting_level=0, trust_remote_code=False, # One needs to give permission for this code to run ).to("cuda")