From 2a6a996b8a652bbe8ae638f863b268986cc2a18d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 18 Dec 2023 07:02:00 +0000 Subject: [PATCH 1/8] update --- .../convert_animatediff_loras_to_diffusers.py | 51 +++++++++++++++++++ ...convert_motion_module_ckpt_to_diffusers.py | 51 +++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 src/diffusers/pipelines/animatediff/convert_animatediff_loras_to_diffusers.py create mode 100644 src/diffusers/pipelines/animatediff/convert_motion_module_ckpt_to_diffusers.py diff --git a/src/diffusers/pipelines/animatediff/convert_animatediff_loras_to_diffusers.py b/src/diffusers/pipelines/animatediff/convert_animatediff_loras_to_diffusers.py new file mode 100644 index 000000000000..509a7345793c --- /dev/null +++ b/src/diffusers/pipelines/animatediff/convert_animatediff_loras_to_diffusers.py @@ -0,0 +1,51 @@ +import argparse + +import torch +from safetensors.torch import save_file + + +def convert_motion_module(original_state_dict): + converted_state_dict = {} + for k, v in original_state_dict.items(): + if "pos_encoder" in k: + continue + + else: + converted_state_dict[ + k.replace(".norms.0", ".norm1") + .replace(".norms.1", ".norm2") + .replace(".ff_norm", ".norm3") + .replace(".attention_blocks.0", ".attn1") + .replace(".attention_blocks.1", ".attn2") + .replace(".temporal_transformer", "") + ] = v + + return converted_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + state_dict = torch.load(args.ckpt_path, map_location="cpu") + + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + + conv_state_dict = convert_motion_module(state_dict) + + # convert to new format + output_dict = {} + for module_name, params in conv_state_dict.items(): + if type(params) is not torch.Tensor: + continue + output_dict.update({f"unet.{module_name}": params}) + + save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors") diff --git a/src/diffusers/pipelines/animatediff/convert_motion_module_ckpt_to_diffusers.py b/src/diffusers/pipelines/animatediff/convert_motion_module_ckpt_to_diffusers.py new file mode 100644 index 000000000000..9c06a48e08d6 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/convert_motion_module_ckpt_to_diffusers.py @@ -0,0 +1,51 @@ +import argparse + +import torch + +from diffusers import MotionAdapter + + +def convert_motion_module(original_state_dict): + converted_state_dict = {} + for k, v in original_state_dict.items(): + if "pos_encoder" in k: + continue + + else: + converted_state_dict[ + k.replace(".norms.0", ".norm1") + .replace(".norms.1", ".norm2") + .replace(".ff_norm", ".norm3") + .replace(".attention_blocks.0", ".attn1") + .replace(".attention_blocks.1", ".attn2") + .replace(".temporal_transformer", "") + ] = v + + return converted_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--use_motion_mid_block", action="store_true") + parser.add_argument("--motion_max_seq_length", type=int, default=32) + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + state_dict = torch.load(args.ckpt_path, map_location="cpu") + if "state_dict" in state_dict.keys(): 
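+        # The loaded checkpoint may nest the actual weights under a top-level "state_dict" key; unwrap it if so.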
+ state_dict = state_dict["state_dict"] + + conv_state_dict = convert_motion_module(state_dict) + adapter = MotionAdapter( + use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length + ) + # skip loading position embeddings + adapter.load_state_dict(conv_state_dict, strict=False) + adapter.save_pretrained(args.output_path) + adapter.save_pretrained(args.output_path, variant="fp16") From ed5360dfcae1d2a71009da2e33f683f983af5e01 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 19 Dec 2023 03:47:17 +0000 Subject: [PATCH 2/8] update --- src/diffusers/models/controlnet_sparsectrl.py | 821 ++++++++++++++++++ 1 file changed, 821 insertions(+) create mode 100644 src/diffusers/models/controlnet_sparsectrl.py diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py new file mode 100644 index 000000000000..a37e4debdfdc --- /dev/null +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -0,0 +1,821 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalControlnetMixin +from ..utils import BaseOutput, logging +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from .modeling_utils import ModelMixin +from .unet_2d_condition import UNet2DConditionModel +from .unet_3d_blocks import ( + CrossAttnDownBlockMotion, + DownBlockMotion, + UNetMidBlockCrossAttnMotion, + UNetMidBlockMotion, + get_down_block, +) + + +class SparseControlNetConditioningEmbedding(nn.Module): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + batch_size, channels, num_frames, height, width = conditioning.shape + conditioning = conditioning.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + embedding = embedding.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + + return embedding + + +class SparseControlNetModel(ModelMixin, ConfigMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. 
+ freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. 
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( + conditioning_embedding_channels=block_out_channels[0], + block_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + controlnet_block = 
nn.Conv2d(output_channel, output_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + controlnet_block = zero_module(controlnet_block) + self.controlnet_mid_block = controlnet_block + + self.mid_block = UNetMidBlockCrossAttnMotion( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + mid_block_type=unet.config.mid_block_type, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + 
conditioning_embedding_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
+ """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. 
+ # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + conditioning_mask: Optional[torch.FloatTensor] = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. 
+ """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + controlnet_cond = torch.flip(controlnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + batch_size, channels, num_frames, height, width = sample.shape + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + sample = self.conv_in(sample) + + if self.concate_conditioning_mask: + controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) + + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + sample = sample + controlnet_cond + + # 3. 
down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 5. Control net blocks + + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(sample) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one + else: + down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample * conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + + + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module From 7d2498de87dca517cbb3a617a480135ebc548c13 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 21 Dec 2023 09:28:30 +0000 Subject: [PATCH 3/8] update --- src/diffusers/models/__init__.py | 1 + src/diffusers/models/controlnet_sparsectrl.py | 98 +- src/diffusers/models/unet_3d_blocks.py | 7 + ...y => convert_motion_loras_to_diffusers.py} | 0 .../convert_sparse_cntrl_to_diffusers.py | 49 + .../pipeline_animatediff_sparsecntrl.py | 1197 +++++++++++++++++ 6 files changed, 1322 insertions(+), 30 deletions(-) rename src/diffusers/pipelines/animatediff/{convert_animatediff_loras_to_diffusers.py => convert_motion_loras_to_diffusers.py} (100%) create mode 100644 src/diffusers/pipelines/animatediff/convert_sparse_cntrl_to_diffusers.py create mode 100644 src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7487bbf2f98e..e7169d31027d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -32,6 +32,7 @@ _import_structure["autoencoders.autoencoder_tiny"] = 
["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["controlnet"] = ["ControlNetModel"] + _import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnetxs"] = ["ControlNetXSModel"] _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["embeddings"] = ["ImageProjection"] diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index a37e4debdfdc..cd054fd16dfe 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -6,7 +6,6 @@ from torch.nn import functional as F from ..configuration_utils import ConfigMixin, register_to_config -from ..loaders import FromOriginalControlnetMixin from ..utils import BaseOutput, logging from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -17,16 +16,37 @@ ) from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin -from .unet_2d_condition import UNet2DConditionModel +from .unet_2d_condition import UNet2DConditionModel, UNetMidBlock2DCrossAttn from .unet_3d_blocks import ( CrossAttnDownBlockMotion, DownBlockMotion, - UNetMidBlockCrossAttnMotion, - UNetMidBlockMotion, get_down_block, ) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SparseControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + class SparseControlNetConditioningEmbedding(nn.Module): def __init__( self, @@ -141,7 +161,7 @@ class conditioning with `class_embed_type` equal to `None`. def __init__( self, in_channels: int = 4, - conditioning_channels: int = 3, + conditioning_channels: int = 4, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str, ...] 
= ( @@ -158,7 +178,7 @@ def __init__( act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, + cross_attention_dim: int = 768, transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, @@ -176,6 +196,11 @@ def __init__( conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), global_pool_conditions: bool = False, addition_embed_type_num_heads: int = 64, + motion_max_seq_length: int = 32, + motion_num_attention_heads: int = 8, + concate_conditioning_mask: bool = True, + use_simplified_condition_embedding: bool = True, + set_noisy_sample_input_to_zero: bool = False, ): super().__init__() @@ -299,10 +324,19 @@ def __init__( raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") # control net conditioning embedding - self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( - conditioning_embedding_channels=block_out_channels[0], - block_out_channels=conditioning_embedding_out_channels, - conditioning_channels=conditioning_channels, + # TODO: Use simple embeddings + + # self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( + # conditioning_embedding_channels=block_out_channels[0], + # block_out_channels=conditioning_embedding_out_channels, + # conditioning_channels=conditioning_channels, + # ) + if concate_conditioning_mask: + conditioning_channels = conditioning_channels + 1 + + self.concate_conditioning_mask = concate_conditioning_mask + self.controlnet_cond_embedding = nn.Conv2d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 ) self.down_blocks = nn.ModuleList([]) @@ -332,7 +366,6 @@ def __init__( down_block = get_down_block( down_block_type, num_layers=layers_per_block, - transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -342,12 +375,12 @@ def __init__( resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[i], - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, + dual_cross_attention=False, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + temporal_double_self_attention=False, ) self.down_blocks.append(down_block) @@ -368,19 +401,16 @@ def __init__( controlnet_block = zero_module(controlnet_block) self.controlnet_mid_block = controlnet_block - self.mid_block = UNetMidBlockCrossAttnMotion( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=mid_block_channel, + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, + dual_cross_attention=False, ) @classmethod @@ -601,7 +631,7 @@ def 
fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value: bool = False) -> None: - if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion)): module.gradient_checkpointing = value def forward( @@ -611,14 +641,15 @@ def forward( encoder_hidden_states: torch.Tensor, controlnet_cond: torch.FloatTensor, conditioning_scale: float = 1.0, - conditioning_mask: Optional[torch.FloatTensor] = None, class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + conditioning_mask: Optional[torch.FloatTensor] = None, guess_mode: bool = False, return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: + ) -> Union[SparseControlNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: """ The [`ControlNetModel`] forward method. @@ -658,6 +689,8 @@ def forward( If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ + batch_size, channels, num_generated_frames, height, width = sample.shape + # check channel order channel_order = self.config.controlnet_conditioning_channel_order @@ -737,16 +770,24 @@ def forward( # 2. pre-process batch_size, channels, num_frames, height, width = sample.shape + emb = emb.repeat_interleave(repeats=num_generated_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_generated_frames, dim=0) sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) sample = self.conv_in(sample) - if self.concate_conditioning_mask: controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) + batch_size, channels, num_frames, height, width = controlnet_cond.shape + controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape( + batch_size * num_frames, channels, height, width + ) + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) sample = sample + controlnet_cond + batch_size, channels, num_frames, height, width + # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: @@ -757,6 +798,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + num_frames=num_generated_frames, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -785,7 +827,6 @@ def forward( controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples - mid_block_res_sample = self.controlnet_mid_block(sample) # 6. 
scaling @@ -807,14 +848,11 @@ def forward( if not return_dict: return (down_block_res_samples, mid_block_res_sample) - return ControlNetOutput( + return SparseControlNetOutput( down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) - - - def zero_module(module): for p in module.parameters(): nn.init.zeros_(p) diff --git a/src/diffusers/models/unet_3d_blocks.py b/src/diffusers/models/unet_3d_blocks.py index e9c505c347b0..c4df4ef361f0 100644 --- a/src/diffusers/models/unet_3d_blocks.py +++ b/src/diffusers/models/unet_3d_blocks.py @@ -54,6 +54,7 @@ def get_down_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", temporal_num_attention_heads: int = 8, + temporal_double_self_attention=True, temporal_max_seq_length: int = 32, transformer_layers_per_block: int = 1, ) -> Union[ @@ -112,6 +113,7 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, + temporal_double_self_attention=temporal_double_self_attention, ) elif down_block_type == "CrossAttnDownBlockMotion": if cross_attention_dim is None: @@ -135,6 +137,7 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, temporal_num_attention_heads=temporal_num_attention_heads, temporal_max_seq_length=temporal_max_seq_length, + temporal_double_self_attention=temporal_double_self_attention, ) elif down_block_type == "DownBlockSpatioTemporal": # added for SDV @@ -946,6 +949,7 @@ def __init__( temporal_num_attention_heads: int = 1, temporal_cross_attention_dim: Optional[int] = None, temporal_max_seq_length: int = 32, + temporal_double_self_attention: bool = True, ): super().__init__() resnets = [] @@ -978,6 +982,7 @@ def __init__( positional_embeddings="sinusoidal", num_positional_embeddings=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, + double_self_attention=temporal_double_self_attention, ) ) @@ -1080,6 +1085,7 @@ def __init__( temporal_cross_attention_dim: Optional[int] = None, temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, + temporal_double_self_attention: bool = True, ): super().__init__() resnets = [] @@ -1144,6 +1150,7 @@ def __init__( positional_embeddings="sinusoidal", num_positional_embeddings=temporal_max_seq_length, attention_head_dim=out_channels // temporal_num_attention_heads, + double_self_attention=temporal_double_self_attention, ) ) diff --git a/src/diffusers/pipelines/animatediff/convert_animatediff_loras_to_diffusers.py b/src/diffusers/pipelines/animatediff/convert_motion_loras_to_diffusers.py similarity index 100% rename from src/diffusers/pipelines/animatediff/convert_animatediff_loras_to_diffusers.py rename to src/diffusers/pipelines/animatediff/convert_motion_loras_to_diffusers.py diff --git a/src/diffusers/pipelines/animatediff/convert_sparse_cntrl_to_diffusers.py b/src/diffusers/pipelines/animatediff/convert_sparse_cntrl_to_diffusers.py new file mode 100644 index 000000000000..1bb794d31440 --- /dev/null +++ b/src/diffusers/pipelines/animatediff/convert_sparse_cntrl_to_diffusers.py @@ -0,0 +1,49 @@ +import argparse + +import torch + +from diffusers.models import SparseControlNetModel + + +def convert_sparse_cntrl_module(original_state_dict): + converted_state_dict = {} + for k, v in original_state_dict.items(): + if "pos_encoder" in k: + continue + + else: + converted_state_dict[ + k.replace(".norms.0", ".norm1") + .replace(".norms.1", ".norm2") + 
.replace(".ff_norm", ".norm3") + .replace(".attention_blocks.0", ".attn1") + .replace(".attention_blocks.1", ".attn2") + .replace(".temporal_transformer", "") + ] = v + + return converted_state_dict + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--motion_max_seq_length", type=int, default=32) + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + state_dict = torch.load(args.ckpt_path, map_location="cpu") + if "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + + conv_state_dict = convert_sparse_cntrl_module(state_dict) + controlnet = SparseControlNetModel() + + # skip loading position embeddings + controlnet.load_state_dict(conv_state_dict, strict=False) + controlnet.save_pretrained(args.output_path) + controlnet.save_pretrained(args.output_path, variant="fp16") diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py new file mode 100644 index 000000000000..63e48917f4eb --- /dev/null +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py @@ -0,0 +1,1197 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel +from diffusers.models.controlnet_sparsectrl import SparseControlNetModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.models.unet_motion_model import MotionAdapter +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import is_compiled_module, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter + >>> from diffusers.pipelines import DiffusionPipeline + >>> from diffusers.schedulers import DPMSolverMultistepScheduler + >>> from PIL import Image + + >>> motion_id = "guoyww/animatediff-motion-adapter-v1-5-2" + >>> adapter = MotionAdapter.from_pretrained(motion_id) + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16) + >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) + + >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" + >>> pipe = DiffusionPipeline.from_pretrained( + ... model_id, + ... motion_adapter=adapter, + ... controlnet=controlnet, + ... vae=vae, + ... custom_pipeline="pipeline_animatediff_controlnet", + ... ).to(device="cuda", dtype=torch.float16) + >>> pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained( + ... model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1 + ... ) + >>> pipe.enable_vae_slicing() + + >>> conditioning_frames = [] + >>> for i in range(1, 16 + 1): + ... conditioning_frames.append(Image.open(f"frame_{i}.png")) + + >>> prompt = "astronaut in space, dancing" + >>> negative_prompt = "bad quality, worst quality, jpeg artifacts, ugly" + >>> result = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=512, + ... height=768, + ... conditioning_frames=conditioning_frames, + ... num_inference_steps=12, + ... 
).frames[0]
+
+        >>> from diffusers.utils import export_to_gif
+        >>> export_to_gif(result, "result.gif")
+    ```
+"""
+
+
+def tensor2vid(video: torch.Tensor, processor, output_type="np"):
+    # Based on:
+    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
+
+    batch_size, channels, num_frames, height, width = video.shape
+    outputs = []
+    for batch_idx in range(batch_size):
+        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
+        batch_output = processor.postprocess(batch_vid, output_type)
+
+        outputs.append(batch_output)
+
+    return outputs
+
+
+@dataclass
+class AnimateDiffControlNetPipelineOutput(BaseOutput):
+    frames: Union[torch.Tensor, np.ndarray]
+
+
+class AnimateDiffSparseControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
+    r"""
+    Pipeline for text-to-video generation.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+        tokenizer (`CLIPTokenizer`):
+            A [`~transformers.CLIPTokenizer`] to tokenize text.
+        unet ([`UNet2DConditionModel`]):
+            A [`UNet2DConditionModel`] used to create a UNetMotionModel to denoise the encoded video latents.
+        motion_adapter ([`MotionAdapter`]):
+            A [`MotionAdapter`] to be used in combination with `unet` to denoise the encoded video latents.
+        controlnet ([`SparseControlNetModel`]):
+            A [`SparseControlNetModel`] that provides additional conditioning to the `unet` from sparsely supplied
+            conditioning frames during the denoising process.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """ + + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["feature_extractor", "image_encoder"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + motion_adapter: MotionAdapter, + controlnet: Union[ControlNetModel, MultiControlNetModel], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], + feature_extractor: Optional[CLIPImageProcessor] = None, + image_encoder: Optional[CLIPVisionModelWithProjection] = None, + ): + super().__init__() + unet = UNetMotionModel.from_unet2d(unet, motion_adapter) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + motion_adapter=motion_adapter, + controlnet=controlnet, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + uncond_image_embeds = torch.zeros_like(image_embeds) + 
return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + + batch_size, channels, num_frames, height, width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + + image = self.vae.decode(latents).sample + video = ( + image[None, :] + .reshape( + ( + batch_size, + num_frames, + -1, + ) + + image.shape[2:] + ) + .permute(0, 2, 1, 3, 4) + ) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + image=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, SparseControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, SparseControlNetModel) + ): + if isinstance(image, list): + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." 
+ ) + + for control_ in image: + for image_ in control_: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, SparseControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, SparseControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if isinstance(self.controlnet, MultiControlNetModel): + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." 
+ ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + controlnet_images = image.unsqueeze(0) + + batch_size, num_frames, channels, height, width = controlnet_images.shape + controlnet_images = controlnet_images.reshape(batch_size * num_frames, channels, height, width) + controlnet_images = controlnet_images.to(device, dtype) + + with torch.no_grad(): + conditioning_frames = self.vae.encode(controlnet_images).latent_dist.sample() + + batch_frames, channels, height, width = conditioning_frames.shape + conditioning_frames = ( + conditioning_frames[:, None] + .reshape(batch_size, num_frames, channels, height, width) + .permute(0, 2, 1, 3, 4) + ) + + """ + batch_size, channels, _, height, width = conditioning_frames.shape + + controlnet_cond = torch.zeros((batch_size, channels, num_frames, height, width)) + + + controlnet_cond_shape = list(controlnet_images.shape) + controlnet_cond_shape[2] = video_length + controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device) + + controlnet_conditioning_mask_shape = list(controlnet_cond.shape) + controlnet_conditioning_mask_shape[1] = 1 + controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device) + + assert controlnet_images.shape[2] >= len(controlnet_image_index) + controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)] + controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 + """ + + """ + + rearrange(controlnet_images, "b f c h w -> b c f h w") + + + + + if controlnet.use_simplified_condition_embedding: + num_controlnet_images = controlnet_images.shape[2] + controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w") + controlnet_images = vae.encode(controlnet_images * 2. 
- 1.).latent_dist.sample() * 0.18215 + controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images) + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + """ + + return conditioning_frames + + def prepare_sparse_control_conditioning( + self, conditioning_frames, num_frames, controlnet_image_index, device, dtype + ): + assert conditioning_frames.shape[2] >= len(controlnet_image_index) + + batch_size, channels, _, height, width = conditioning_frames.shape + controlnet_cond = torch.zeros((batch_size, channels, num_frames, height, width), dtype=dtype, device=device) + controlnet_cond[:, :, controlnet_image_index] = conditioning_frames[:, :, : len(controlnet_image_index)] + + batch_size, _, num_frames, height, width = controlnet_cond.shape + controlnet_cond_mask = torch.zeros((batch_size, 1, num_frames, height, width), dtype=dtype, device=device) + controlnet_cond_mask[:, :, controlnet_image_index] = 1 + + controlnet_cond = controlnet_cond.to(device, dtype) + controlnet_cond_mask = controlnet_cond_mask.to(device, dtype) + + return controlnet_cond, controlnet_cond_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_frames: Optional[int] = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + conditioning_frames: Optional[List[PipelineImageInput]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + controlnet_image_index: list = [0], + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames that are generated. Defaults to 16 frames which at 8 frames per seconds + amounts to 2 seconds of video. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. Latents should be of shape + `(batch_size, num_channel, num_frames, height, width)`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): + Optional image input to work with IP Adapters. + conditioning_frames (`List[PipelineImageInput]`, *optional*): + The ControlNet input condition to provide guidance to the `unet` for generation. If multiple ControlNets + are specified, images must be passed as a list such that each element of the list can be correctly + batched for input to a single ControlNet. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or + `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                The ControlNet encoder tries to recognize the content of the input image even if you remove all
+                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`AnimateDiffControlNetPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`AnimateDiffControlNetPipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated frames.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        # 0.
Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_steps=callback_steps, + negative_prompt=negative_prompt, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image=conditioning_frames, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, SparseControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None: + image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + conditioning_frames = self.prepare_image( + image=conditioning_frames, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt * num_frames, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning( + conditioning_frames, batch_size, controlnet_image_index, dtype=controlnet.dtype, device=device + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if isinstance(controlnet, SparseControlNetModel) else keeps) + + # Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + """ + controlnet_images = conditioning_frames.to(latents.device) + controlnet_cond_shape = list(controlnet_images.shape) + controlnet_cond_shape[2] = num_frames + controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device) + controlnet_conditioning_mask_shape = list(controlnet_cond.shape) + controlnet_conditioning_mask_shape[1] = 1 + controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device) + assert controlnet_images.shape[2] >= len(controlnet_image_index) + controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)] + controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 + + if guess_mode and self.do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0) + + """ + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=controlnet_cond, + conditioning_mask=controlnet_cond_mask, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + return AnimateDiffControlNetPipelineOutput(frames=latents) + + # Post-processing + video_tensor = self.decode_latents(latents) + + if output_type == "pt": + video = video_tensor + else: + video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AnimateDiffControlNetPipelineOutput(frames=video) From 7565feebd084c17fb862032fd009c6fb85f40eae Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 21 Dec 2023 10:58:40 +0000 Subject: [PATCH 4/8] update --- src/diffusers/models/controlnet_sparsectrl.py | 18 ++++++++++++------ .../pipeline_animatediff_sparsecntrl.py | 6 ++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index cd054fd16dfe..ef5494dd6048 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -689,7 +689,7 @@ def forward( If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is 
returned, otherwise a tuple is returned where the first element is the sample tensor. """ - batch_size, channels, num_generated_frames, height, width = sample.shape + sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape # check channel order channel_order = self.config.controlnet_conditioning_channel_order @@ -770,11 +770,14 @@ def forward( # 2. pre-process batch_size, channels, num_frames, height, width = sample.shape - emb = emb.repeat_interleave(repeats=num_generated_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_generated_frames, dim=0) + emb = emb.repeat_interleave(repeats=sample_num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=sample_num_frames, dim=0) sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) sample = self.conv_in(sample) + batch_frames, channels, height, width = sample.shape + sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width) + if self.concate_conditioning_mask: controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1) @@ -782,11 +785,14 @@ def forward( controlnet_cond = controlnet_cond.permute(0, 2, 1, 3, 4).reshape( batch_size * num_frames, channels, height, width ) - controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + batch_frames, channels, height, width = controlnet_cond.shape + controlnet_cond = controlnet_cond[:, None].reshape(batch_size, num_frames, channels, height, width) + sample = sample + controlnet_cond - batch_size, channels, num_frames, height, width + batch_size, num_frames, channels, height, width = sample.shape + sample = sample.reshape(sample_batch_size * sample_num_frames, channels, height, width) # 3. down down_block_res_samples = (sample,) @@ -798,7 +804,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, - num_frames=num_generated_frames, + num_frames=sample_num_frames, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py index 63e48917f4eb..0022083aa9b6 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py @@ -112,7 +112,9 @@ class AnimateDiffControlNetPipelineOutput(BaseOutput): frames: Union[torch.Tensor, np.ndarray] -class AnimateDiffSparseControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin): +class AnimateDiffSparseControlNetPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin +): r""" Pipeline for text-to-video generation. @@ -1052,7 +1054,7 @@ def __call__( guess_mode=guess_mode, ) controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning( - conditioning_frames, batch_size, controlnet_image_index, dtype=controlnet.dtype, device=device + conditioning_frames, num_frames, controlnet_image_index, dtype=controlnet.dtype, device=device ) # 4. 
Prepare timesteps From 918a70bc42625fefd05f4f911686deb01d4452ef Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 21 Dec 2023 12:59:23 +0000 Subject: [PATCH 5/8] update --- src/diffusers/models/controlnet_sparsectrl.py | 48 ++----------------- 1 file changed, 5 insertions(+), 43 deletions(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index ef5494dd6048..8447fa9a6935 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -641,10 +641,8 @@ def forward( encoder_hidden_states: torch.Tensor, controlnet_cond: torch.FloatTensor, conditioning_scale: float = 1.0, - class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, conditioning_mask: Optional[torch.FloatTensor] = None, guess_mode: bool = False, @@ -690,6 +688,7 @@ def forward( returned where the first element is the sample tensor. """ sample_batch_size, sample_channels, sample_num_frames, sample_height, sample_width = sample.shape + sample = torch.zeros_like(sample).to(sample.device) # check channel order channel_order = self.config.controlnet_conditioning_channel_order @@ -723,58 +722,21 @@ def forward( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) - t_emb = self.time_proj(timesteps) - # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb - - if self.config.addition_embed_type is not None: - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - - elif self.config.addition_embed_type == "text_time": - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - - emb = emb + aug_emb if aug_emb is not None else emb # 2. 
pre-process batch_size, channels, num_frames, height, width = sample.shape - emb = emb.repeat_interleave(repeats=sample_num_frames, dim=0) - encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=sample_num_frames, dim=0) - sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0) + emb = emb.repeat_interleave(sample_num_frames, dim=0) + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width) sample = self.conv_in(sample) + batch_frames, channels, height, width = sample.shape sample = sample[:, None].reshape(sample_batch_size, sample_num_frames, channels, height, width) From e7894592eaf222cec64d50271d53c8386f7b7f7d Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 21 Dec 2023 14:58:08 +0000 Subject: [PATCH 6/8] update --- src/diffusers/models/controlnet_sparsectrl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index 8447fa9a6935..b29e46f3a56c 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -769,7 +769,7 @@ def forward( num_frames=sample_num_frames, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=sample_num_frames) down_block_res_samples += res_samples From 06e3dc0fdb7ee7d833b833f893765afd127b62bb Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 22 Dec 2023 09:00:46 +0000 Subject: [PATCH 7/8] update --- src/diffusers/models/controlnet_sparsectrl.py | 8 -- .../pipeline_animatediff_sparsecntrl.py | 89 ++----------------- 2 files changed, 5 insertions(+), 92 deletions(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index b29e46f3a56c..d71eaac95ef2 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -323,14 +323,6 @@ def __init__( elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - # control net conditioning embedding - # TODO: Use simple embeddings - - # self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding( - # conditioning_embedding_channels=block_out_channels[0], - # block_out_channels=conditioning_embedding_out_channels, - # conditioning_channels=conditioning_channels, - # ) if concate_conditioning_mask: conditioning_channels = conditioning_channels + 1 diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py index 0022083aa9b6..32903c999f5b 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py @@ -48,14 +48,14 @@ Examples: ```py >>> import torch - >>> from diffusers import AutoencoderKL, ControlNetModel, MotionAdapter + >>> from diffusers import AutoencoderKL, SparseControlNetModel, MotionAdapter >>> from diffusers.pipelines import DiffusionPipeline >>> from diffusers.schedulers import DPMSolverMultistepScheduler >>> from PIL import Image >>> motion_id = "guoyww/animatediff-motion-adapter-v1-5-2" >>> adapter = MotionAdapter.from_pretrained(motion_id) - >>> controlnet = 
ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16) + >>> controlnet = SparseControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16) >>> vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16) >>> model_id = "SG161222/Realistic_Vision_V5.1_noVAE" @@ -183,7 +183,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=True ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt @@ -705,18 +705,13 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_image( self, image, width, height, - batch_size, - num_images_per_prompt, device, dtype, - do_classifier_free_guidance=False, - guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) controlnet_images = image.unsqueeze(0) @@ -726,7 +721,7 @@ def prepare_image( controlnet_images = controlnet_images.to(device, dtype) with torch.no_grad(): - conditioning_frames = self.vae.encode(controlnet_images).latent_dist.sample() + conditioning_frames = self.vae.encode(controlnet_images).latent_dist.sample() * 0.18215 batch_frames, channels, height, width = conditioning_frames.shape conditioning_frames = ( @@ -735,51 +730,6 @@ def prepare_image( .permute(0, 2, 1, 3, 4) ) - """ - batch_size, channels, _, height, width = conditioning_frames.shape - - controlnet_cond = torch.zeros((batch_size, channels, num_frames, height, width)) - - - controlnet_cond_shape = list(controlnet_images.shape) - controlnet_cond_shape[2] = video_length - controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device) - - controlnet_conditioning_mask_shape = list(controlnet_cond.shape) - controlnet_conditioning_mask_shape[1] = 1 - controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device) - - assert controlnet_images.shape[2] >= len(controlnet_image_index) - controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)] - controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 - """ - - """ - - rearrange(controlnet_images, "b f c h w -> b c f h w") - - - - - if controlnet.use_simplified_condition_embedding: - num_controlnet_images = controlnet_images.shape[2] - controlnet_images = rearrange(controlnet_images, "b c f h w -> (b f) c h w") - controlnet_images = vae.encode(controlnet_images * 2. 
- 1.).latent_dist.sample() * 0.18215 - controlnet_images = rearrange(controlnet_images, "(b f) c h w -> b c f h w", f=num_controlnet_images) - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - """ - return conditioning_frames def prepare_sparse_control_conditioning( @@ -789,10 +739,8 @@ def prepare_sparse_control_conditioning( batch_size, channels, _, height, width = conditioning_frames.shape controlnet_cond = torch.zeros((batch_size, channels, num_frames, height, width), dtype=dtype, device=device) - controlnet_cond[:, :, controlnet_image_index] = conditioning_frames[:, :, : len(controlnet_image_index)] - - batch_size, _, num_frames, height, width = controlnet_cond.shape controlnet_cond_mask = torch.zeros((batch_size, 1, num_frames, height, width), dtype=dtype, device=device) + controlnet_cond[:, :, controlnet_image_index] = conditioning_frames[:, :, : len(controlnet_image_index)] controlnet_cond_mask[:, :, controlnet_image_index] = 1 controlnet_cond = controlnet_cond.to(device, dtype) @@ -1046,12 +994,8 @@ def __call__( image=conditioning_frames, width=width, height=height, - batch_size=batch_size * num_videos_per_prompt * num_frames, - num_images_per_prompt=num_videos_per_prompt, device=device, dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, ) controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning( conditioning_frames, num_frames, controlnet_image_index, dtype=controlnet.dtype, device=device @@ -1102,29 +1046,6 @@ def __call__( control_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds - """ - controlnet_images = conditioning_frames.to(latents.device) - controlnet_cond_shape = list(controlnet_images.shape) - controlnet_cond_shape[2] = num_frames - controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device) - controlnet_conditioning_mask_shape = list(controlnet_cond.shape) - controlnet_conditioning_mask_shape[1] = 1 - controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device) - assert controlnet_images.shape[2] >= len(controlnet_image_index) - controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)] - controlnet_conditioning_mask[:,:,controlnet_image_index] = 1 - - if guess_mode and self.do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. 
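Note: the commented-out prototype removed above is superseded by `prepare_sparse_control_conditioning`, which scatters the VAE-encoded conditioning frames into a zero tensor spanning the full clip and builds a matching single-channel mask. A minimal standalone sketch of that scatter pattern (function and tensor names here are illustrative, not the pipeline's own):

```py
import torch


def scatter_sparse_conditioning(cond_latents, num_frames, frame_index):
    # cond_latents: (batch, channels, num_cond_frames, height, width) VAE latents
    # of the sparse conditioning images; frame_index: which clip frames they control.
    batch, channels, _, height, width = cond_latents.shape
    cond = torch.zeros(batch, channels, num_frames, height, width, dtype=cond_latents.dtype)
    mask = torch.zeros(batch, 1, num_frames, height, width, dtype=cond_latents.dtype)
    # place each conditioning latent at its target frame; all other frames stay zero
    cond[:, :, frame_index] = cond_latents[:, :, : len(frame_index)]
    mask[:, :, frame_index] = 1.0
    return cond, mask


# e.g. condition only frame 0 of a 16-frame clip on a single encoded image
cond, mask = scatter_sparse_conditioning(torch.randn(1, 4, 1, 64, 64), num_frames=16, frame_index=[0])
```

Frames not listed in `frame_index` remain zero in both tensors, which is how the sparse ControlNet distinguishes conditioned from unconditioned positions.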
- control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - else: - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0) - - """ if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: From b46e167be7054a2b28d516cc3c6e3510364ceb16 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 22 Dec 2023 09:55:52 +0000 Subject: [PATCH 8/8] update --- src/diffusers/models/controlnet_sparsectrl.py | 116 -------------- .../pipeline_animatediff_sparsecntrl.py | 150 ++---------------- 2 files changed, 17 insertions(+), 249 deletions(-) diff --git a/src/diffusers/models/controlnet_sparsectrl.py b/src/diffusers/models/controlnet_sparsectrl.py index d71eaac95ef2..787809fb47e2 100644 --- a/src/diffusers/models/controlnet_sparsectrl.py +++ b/src/diffusers/models/controlnet_sparsectrl.py @@ -277,52 +277,6 @@ def __init__( else: self.encoder_hid_proj = None - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim - - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - if concate_conditioning_mask: conditioning_channels = conditioning_channels + 1 @@ -470,9 +424,6 @@ def from_unet( controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) - if controlnet.class_embedding: - controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) - controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) @@ -556,72 +507,6 @@ def set_default_attn_processor(self): self.set_attn_processor(processor, _remove_lora=True) - # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice - def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: - r""" - Enable sliced attention computation. - - When this option is enabled, the attention module splits the input tensor in slices to compute attention in - several steps. This is useful for saving some memory in exchange for a small decrease in speed. - - Args: - slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If - `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is - provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` - must be a multiple of `slice_size`. - """ - sliceable_head_dims = [] - - def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): - if hasattr(module, "set_attention_slice"): - sliceable_head_dims.append(module.sliceable_head_dim) - - for child in module.children(): - fn_recursive_retrieve_sliceable_dims(child) - - # retrieve number of attention layers - for module in self.children(): - fn_recursive_retrieve_sliceable_dims(module) - - num_sliceable_layers = len(sliceable_head_dims) - - if slice_size == "auto": - # half the attention head size is usually a good trade-off between - # speed and memory - slice_size = [dim // 2 for dim in sliceable_head_dims] - elif slice_size == "max": - # make smallest slice possible - slice_size = num_sliceable_layers * [1] - - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size - - if len(slice_size) != len(sliceable_head_dims): - raise ValueError( - f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" - f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 
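Note: with the class/addition embeddings dropped, the `from_unet` path touched above reduces to copying the submodules that the ControlNet shares with the source UNet. A rough sketch of that weight-copy pattern, assuming generic attribute names rather than the model's exact ones:

```py
import torch.nn as nn


def copy_shared_submodules(src: nn.Module, dst: nn.Module, names):
    # load_state_dict on each shared submodule; ControlNet-only pieces
    # (conditioning embedding, zero-initialized projection blocks) keep their fresh init
    for name in names:
        src_mod, dst_mod = getattr(src, name, None), getattr(dst, name, None)
        if src_mod is not None and dst_mod is not None:
            dst_mod.load_state_dict(src_mod.state_dict())


# e.g. copy_shared_submodules(unet, controlnet,
#                             ["conv_in", "time_proj", "time_embedding", "down_blocks", "mid_block"])
```

Only the trunk shared with the UNet is seeded from pretrained weights; everything specific to the ControlNet starts from its own initialization.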
- ) - - for i in range(len(slice_size)): - size = slice_size[i] - dim = sliceable_head_dims[i] - if size is not None and size > dim: - raise ValueError(f"size {size} has to be smaller or equal to {dim}.") - - # Recursively walk through all the children. - # Any children which exposes the set_attention_slice method - # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): - if hasattr(module, "set_attention_slice"): - module.set_attention_slice(slice_size.pop()) - - for child in module.children(): - fn_recursive_set_attention_slice(child, slice_size) - - reversed_slice_size = list(reversed(slice_size)) - for module in self.children(): - fn_recursive_set_attention_slice(module, reversed_slice_size) - def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion)): module.gradient_checkpointing = value @@ -779,7 +664,6 @@ def forward( sample = self.mid_block(sample, emb) # 5. Control net blocks - controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py index 32903c999f5b..1739e51236ab 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsecntrl.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect -from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -22,15 +21,14 @@ from PIL import Image from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from diffusers.image_processor import PipelineImageInput, VaeImageProcessor -from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel, UNetMotionModel -from diffusers.models.controlnet_sparsectrl import SparseControlNetModel -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.models.unet_motion_model import MotionAdapter -from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.schedulers import ( +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel +from ...models.controlnet_sparsectrl import SparseControlNetModel +from ...models.lora import adjust_lora_scale_text_encoder +from ...models.unet_motion_model import MotionAdapter +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, @@ -38,8 +36,9 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers -from diffusers.utils.torch_utils import is_compiled_module, randn_tensor +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import is_compiled_module, randn_tensor +from .pipeline_animatediff import 
AnimateDiffPipelineOutput logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -107,11 +106,6 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"): return outputs -@dataclass -class AnimateDiffControlNetPipelineOutput(BaseOutput): - frames: Union[torch.Tensor, np.ndarray] - - class AnimateDiffSparseControlNetPipeline( DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin ): @@ -154,7 +148,7 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, motion_adapter: MotionAdapter, - controlnet: Union[ControlNetModel, MultiControlNetModel], + controlnet: SparseControlNetModel, scheduler: Union[ DDIMScheduler, PNDMScheduler, @@ -498,8 +492,6 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, image=None, controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -542,15 +534,6 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # `prompt` needs more sophisticated handling when there are multiple - # conditionings. - if isinstance(self.controlnet, MultiControlNetModel): - if isinstance(prompt, list): - logger.warning( - f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" - " prompts. The conditionings will be fixed across the prompts." - ) - # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule @@ -565,26 +548,6 @@ def check_inputs( self.check_image(image_, prompt, prompt_embeds) else: self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if not isinstance(image, list): - raise TypeError("For multiple controlnets: `image` must be type `list`") - - # When `image` is a nested list: - # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) - elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif len(image) != len(self.controlnet.nets): - raise ValueError( - f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." 
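Note: with `MultiControlNetModel` support stripped, the `check_inputs` that survives is essentially the single-controlnet branch. A condensed sketch of what it now has to verify (a simplified stand-in, not the method itself):

```py
from typing import List, Union

import torch
from PIL import Image


def validate_single_controlnet_inputs(
    image: Union[Image.Image, torch.Tensor, List],
    conditioning_scale: float,
    height: int,
    width: int,
) -> None:
    # spatial dims must stay divisible by the VAE downsampling factor
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` must be divisible by 8, got {height} and {width}.")
    # with MultiControlNetModel support removed, a plain float scale is the only accepted form
    if not isinstance(conditioning_scale, float):
        raise TypeError("For a single ControlNet, `controlnet_conditioning_scale` must be a float.")
    # a list of images is still valid: each entry conditions one (sparse) frame
    frames = image if isinstance(image, list) else [image]
    if not all(isinstance(f, (Image.Image, torch.Tensor)) for f in frames):
        raise TypeError("Conditioning frames must be PIL images or tensors.")
```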
- ) - - for control_ in image: - for image_ in control_: - self.check_image(image_, prompt, prompt_embeds) else: assert False @@ -596,51 +559,9 @@ def check_inputs( ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): - if isinstance(controlnet_conditioning_scale, list): - if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): - raise ValueError( - "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" - " the same length as the number of controlnets" - ) else: assert False - if not isinstance(control_guidance_start, (tuple, list)): - control_guidance_start = [control_guidance_start] - - if not isinstance(control_guidance_end, (tuple, list)): - control_guidance_end = [control_guidance_end] - - if len(control_guidance_start) != len(control_guidance_end): - raise ValueError( - f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." - ) - - if isinstance(self.controlnet, MultiControlNetModel): - if len(control_guidance_start) != len(self.controlnet.nets): - raise ValueError( - f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." - ) - - for start, end in zip(control_guidance_start, control_guidance_end): - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, Image.Image) @@ -795,8 +716,6 @@ def __call__( controlnet_conditioning_scale: Union[float, List[float]] = 1.0, controlnet_image_index: list = [0], guess_mode: bool = False, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, clip_skip: Optional[int] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], @@ -870,7 +789,7 @@ def __call__( clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. - allback_on_step_end (`Callable`, *optional*): + callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. 
`callback_kwargs` will include a list of all tensors as specified by @@ -906,18 +825,6 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = ( - mult * [control_guidance_start], - mult * [control_guidance_end], - ) - # 0. Default height and width to unet height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor @@ -936,8 +843,6 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, image=conditioning_frames, controlnet_conditioning_scale=controlnet_conditioning_scale, - control_guidance_start=control_guidance_start, - control_guidance_end=control_guidance_end, ) self._guidance_scale = guidance_scale @@ -954,9 +859,6 @@ def __call__( device = self._execution_device - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - global_pool_conditions = ( controlnet.config.global_pool_conditions if isinstance(controlnet, SparseControlNetModel) @@ -1026,15 +928,6 @@ def __call__( # 7. 
Add image embeds for IP-Adapter added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None - # 7.1 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, SparseControlNetModel) else keeps) - # Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1044,23 +937,14 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - controlnet_cond_scale = controlnet_conditioning_scale - if isinstance(controlnet_cond_scale, list): - controlnet_cond_scale = controlnet_cond_scale[0] - cond_scale = controlnet_cond_scale * controlnet_keep[i] down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, t, - encoder_hidden_states=controlnet_prompt_embeds, + encoder_hidden_states=prompt_embeds, controlnet_cond=controlnet_cond, conditioning_mask=controlnet_cond_mask, - conditioning_scale=cond_scale, + conditioning_scale=controlnet_conditioning_scale, guess_mode=guess_mode, return_dict=False, ) @@ -1101,7 +985,7 @@ def __call__( callback(i, t, latents) if output_type == "latent": - return AnimateDiffControlNetPipelineOutput(frames=latents) + return AnimateDiffPipelineOutput(frames=latents) # Post-processing video_tensor = self.decode_latents(latents) @@ -1117,4 +1001,4 @@ def __call__( if not return_dict: return (video,) - return AnimateDiffControlNetPipelineOutput(frames=video) + return AnimateDiffPipelineOutput(frames=video)
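Note: after this final patch the denoising loop boils down to one sparse ControlNet call per step followed by the motion UNet call. A schematic of that per-step flow; the ControlNet-side argument names follow the hunk above, while the UNet-side residual keywords follow the usual diffusers ControlNet pipelines and are an assumption here:

```py
import torch


def denoise_step(unet, controlnet, scheduler, latents, t, prompt_embeds, cond, cond_mask, scale, guidance_scale):
    # classifier-free guidance: duplicate latents; prompt_embeds is assumed to be the
    # [negative, positive] concatenation, and cond / cond_mask already batch-matched to it
    latent_in = torch.cat([latents] * 2)
    latent_in = scheduler.scale_model_input(latent_in, t)

    # single sparse ControlNet call; the conditioning scale is applied inside the model
    down_res, mid_res = controlnet(
        latent_in,
        t,
        encoder_hidden_states=prompt_embeds,
        controlnet_cond=cond,
        conditioning_mask=cond_mask,
        conditioning_scale=scale,
        return_dict=False,
    )

    noise_pred = unet(
        latent_in,
        t,
        encoder_hidden_states=prompt_embeds,
        down_block_additional_residuals=down_res,
        mid_block_additional_residual=mid_res,
    ).sample

    uncond, cond_pred = noise_pred.chunk(2)
    noise_pred = uncond + guidance_scale * (cond_pred - uncond)
    return scheduler.step(noise_pred, t, latents).prev_sample
```

Because the `controlnet_keep` bookkeeping and `control_guidance_start`/`control_guidance_end` are gone, `controlnet_conditioning_scale` is applied uniformly across all denoising steps.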