From 58a95343d90002a15db505bdba1dc85adb9ba113 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 17:22:04 +0530 Subject: [PATCH 01/21] add legacy behaviour flag for autoencoder temporal --- .../autoencoders/autoencoder_kl_temporal_decoder.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py index 68d5a31e43c7..9b476042b13a 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -205,6 +205,7 @@ def __init__( sample_size: int = 32, scaling_factor: float = 0.18215, force_upcast: float = True, + use_legacy: bool = True, ): super().__init__() @@ -226,7 +227,7 @@ def __init__( layers_per_block=layers_per_block, ) - self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_legacy else None sample_size = ( self.config.sample_size[0] @@ -330,8 +331,11 @@ def encode( [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) + + if self.quant_conv is not None: + h = self.quant_conv(h) + + posterior = DiagonalGaussianDistribution(h) if not return_dict: return (posterior,) From 38f6b5357613e6d37102660be1da3d37d8f5c943 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 18:39:31 +0530 Subject: [PATCH 02/21] add camera proj and added_cond_kwargs to handle camera pose --- src/diffusers/models/attention.py | 23 +++++++++++++++++++ .../transformers/transformer_temporal.py | 4 ++++ src/diffusers/models/unets/unet_3d_blocks.py | 19 +++++++++++++++ .../unets/unet_spatio_temporal_condition.py | 10 +++++++- 4 files changed, 55 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index afb022c8d612..7d8c5a0bdb10 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -425,6 +425,7 @@ def __init__( num_attention_heads: int, attention_head_dim: int, cross_attention_dim: Optional[int] = None, + motionctrl_kwargs: Dict[str, Any] = None, ): super().__init__() self.is_res = dim == time_mix_inner_dim @@ -468,6 +469,14 @@ def __init__( self.norm3 = nn.LayerNorm(time_mix_inner_dim) self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + self.use_camera_projection = motionctrl_kwargs is not None + if motionctrl_kwargs is not None: + camera_pose_embed_dim = motionctrl_kwargs.get("camera_pose_embed_dim") + camera_pose_dim = motionctrl_kwargs.get("camera_pose_dim") + self.cc_projection = nn.Linear( + time_mix_inner_dim + camera_pose_embed_dim * camera_pose_dim, time_mix_inner_dim + ) + # let chunk size default to None self._chunk_size = None self._chunk_dim = None @@ -483,7 +492,13 @@ def forward( hidden_states: torch.FloatTensor, num_frames: int, encoder_hidden_states: Optional[torch.FloatTensor] = None, + added_cond_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.FloatTensor: + if self.use_camera_projection and (added_cond_kwargs is None or added_cond_kwargs.get("camera_pose") is None): + raise ValueError( + "When using camera pose projection (for MotionCtrl), `added_cond_kwargs` must contain `camera_pose`" + ) + # Notice that normalization is always applied before the real computation in 
the following blocks. # 0. Self-Attention batch_size = hidden_states.shape[0] @@ -510,6 +525,14 @@ def forward( attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) hidden_states = attn_output + hidden_states + # MotionCtrl specific + if self.use_camera_projection: + camera_pose: torch.FloatTensor = added_cond_kwargs.get("camera_pose") + camera_pose.repeat_interleave(seq_length, dim=0) # [batch_size * seq_length, num_frames, 12] + + hidden_states = torch.cat([hidden_states, camera_pose], dim=-1) + hidden_states = self.cc_projection(hidden_states) + # 3. Cross-Attention if self.attn2 is not None: norm_hidden_states = self.norm2(hidden_states) diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py index e5bc1226b4b5..671759a9ad6f 100644 --- a/src/diffusers/models/transformers/transformer_temporal.py +++ b/src/diffusers/models/transformers/transformer_temporal.py @@ -222,6 +222,7 @@ def __init__( out_channels: Optional[int] = None, num_layers: int = 1, cross_attention_dim: Optional[int] = None, + motionctrl_kwargs: Dict[str, Any] = None, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -257,6 +258,7 @@ def __init__( num_attention_heads, attention_head_dim, cross_attention_dim=cross_attention_dim, + motionctrl_kwargs=motionctrl_kwargs, ) for _ in range(num_layers) ] @@ -279,6 +281,7 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ): """ @@ -360,6 +363,7 @@ def forward( hidden_states_mix, num_frames=num_frames, encoder_hidden_states=time_context, + added_cond_kwargs=added_cond_kwargs, ) hidden_states = self.time_mixer( x_spatial=hidden_states, diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index a1d9e848c230..54800242c483 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -56,6 +56,7 @@ def get_down_block( temporal_num_attention_heads: int = 8, temporal_max_seq_length: int = 32, transformer_layers_per_block: int = 1, + motionctrl_kwargs: Dict[str, Any] = None, ) -> Union[ "DownBlock3D", "CrossAttnDownBlock3D", @@ -158,6 +159,7 @@ def get_down_block( add_downsample=add_downsample, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, + motionctrl_kwargs=motionctrl_kwargs, ) raise ValueError(f"{down_block_type} does not exist.") @@ -187,6 +189,7 @@ def get_up_block( temporal_max_seq_length: int = 32, transformer_layers_per_block: int = 1, dropout: float = 0.0, + motionctrl_kwargs: Dict[str, Any] = None, ) -> Union[ "UpBlock3D", "CrossAttnUpBlock3D", @@ -297,6 +300,7 @@ def get_up_block( cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, resolution_idx=resolution_idx, + motionctrl_kwargs=motionctrl_kwargs, ) raise ValueError(f"{up_block_type} does not exist.") @@ -1875,6 +1879,7 @@ def __init__( transformer_layers_per_block: Union[int, Tuple[int]] = 1, num_attention_heads: int = 1, cross_attention_dim: int = 1280, + motionctrl_kwargs: Dict[str, Any] = None, ): super().__init__() @@ -1904,6 +1909,7 @@ def __init__( in_channels=in_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, + motionctrl_kwargs=motionctrl_kwargs, ) ) @@ -1927,6 +1933,7 @@ def forward( temb: 
Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, image_only_indicator: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.FloatTensor: hidden_states = self.resnets[0]( hidden_states, @@ -1951,6 +1958,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] hidden_states = torch.utils.checkpoint.checkpoint( @@ -1965,6 +1973,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] hidden_states = resnet( @@ -2077,6 +2086,7 @@ def __init__( num_attention_heads: int = 1, cross_attention_dim: int = 1280, add_downsample: bool = True, + motionctrl_kwargs: Dict[str, Any] = None, ): super().__init__() resnets = [] @@ -2104,6 +2114,7 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, + motionctrl_kwargs=motionctrl_kwargs, ) ) @@ -2133,6 +2144,7 @@ def forward( temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, image_only_indicator: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: output_states = () @@ -2162,6 +2174,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] else: @@ -2174,6 +2187,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] @@ -2291,6 +2305,7 @@ def __init__( num_attention_heads: int = 1, cross_attention_dim: int = 1280, add_upsample: bool = True, + motionctrl_kwargs: Dict[str, Any] = None, ): super().__init__() resnets = [] @@ -2321,6 +2336,7 @@ def __init__( in_channels=out_channels, num_layers=transformer_layers_per_block[i], cross_attention_dim=cross_attention_dim, + motionctrl_kwargs=motionctrl_kwargs, ) ) @@ -2342,6 +2358,7 @@ def forward( temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, image_only_indicator: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.FloatTensor: for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -2373,6 +2390,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] else: @@ -2385,6 +2403,7 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 39a8009d5af9..f2c71d3231b5 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing 
import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -92,6 +92,7 @@ def __init__( transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20), num_frames: int = 25, + motionctrl_kwargs: Dict[str, Any] = None, ): super().__init__() @@ -178,6 +179,7 @@ def __init__( cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], resnet_act_fn="silu", + motionctrl_kwargs=motionctrl_kwargs, ) self.down_blocks.append(down_block) @@ -188,6 +190,7 @@ def __init__( transformer_layers_per_block=transformer_layers_per_block[-1], cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], + motionctrl_kwargs=motionctrl_kwargs, ) # count how many layers upsample the images @@ -229,6 +232,7 @@ def __init__( cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], resnet_act_fn="silu", + motionctrl_kwargs=motionctrl_kwargs, ) self.up_blocks.append(up_block) prev_output_channel = output_channel @@ -359,6 +363,7 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, added_time_ids: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: r""" @@ -436,6 +441,7 @@ def forward( temb=emb, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, ) else: sample, res_samples = downsample_block( @@ -452,6 +458,7 @@ def forward( temb=emb, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, ) # 5. up @@ -466,6 +473,7 @@ def forward( res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, image_only_indicator=image_only_indicator, + added_cond_kwargs=added_cond_kwargs, ) else: sample = upsample_block( From 84dd94b4b2cdf23064e1085885570179b31f7389 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 18:40:00 +0530 Subject: [PATCH 03/21] update conversion script to handle newer svd checkpoints --- scripts/convert_svd_to_diffusers.py | 233 ++++++++++++++++++++++------ 1 file changed, 187 insertions(+), 46 deletions(-) diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py index 3243ce294b26..8f423d639239 100644 --- a/scripts/convert_svd_to_diffusers.py +++ b/scripts/convert_svd_to_diffusers.py @@ -1,3 +1,14 @@ +import argparse + +import torch +import yaml +from safetensors.torch import load_file +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from yaml.loader import FullLoader + +from diffusers import StableVideoDiffusionPipeline +from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from diffusers.schedulers import EulerDiscreteScheduler from diffusers.utils import is_accelerate_available, logging @@ -7,28 +18,52 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): +def create_vae_diffusers_config(original_config, image_size: int): + r""" + Creates a vae config for diffusers based on the config of the LDM. """ - Creates a config for the diffusers based on the config of the LDM model. 
+ vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["decoder_config"]["params"] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + + vae_config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + } + + return vae_config + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + r""" + Creates a unet config for diffusers based on the config of the LDM. """ if controlnet: - unet_params = original_config.model.params.control_stage_config.params + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] else: - if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: - unet_params = original_config.model.params.unet_config.params + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] else: - unet_params = original_config.model.params.network_config.params + unet_params = original_config["model"]["params"]["network_config"]["params"] - vae_params = original_config.model.params.first_stage_config.params.encoder_config.params + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["decoder_config"]["params"] - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): block_type = ( "CrossAttnDownBlockSpatioTemporal" - if resolution in unet_params.attention_resolutions + if resolution in unet_params["attention_resolutions"] else "DownBlockSpatioTemporal" ) down_block_types.append(block_type) @@ -39,32 +74,32 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa for i in range(len(block_out_channels)): block_type = ( "CrossAttnUpBlockSpatioTemporal" - if resolution in unet_params.attention_resolutions + if resolution in unet_params["attention_resolutions"] else "UpBlockSpatioTemporal" ) up_block_types.append(block_type) resolution //= 2 - if unet_params.transformer_depth is not None: + if unet_params["transformer_depth"] is not None: transformer_layers_per_block = ( - unet_params.transformer_depth - if isinstance(unet_params.transformer_depth, int) - else list(unet_params.transformer_depth) + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) ) else: transformer_layers_per_block = 1 - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - head_dim = unet_params.num_heads if "num_heads" in unet_params else None + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable 
diffusion 2-base-512 and 2-768 if head_dim is None: - head_dim_mult = unet_params.model_channels // unet_params.num_head_channels - head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] class_embed_type = None addition_embed_type = None @@ -72,23 +107,25 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa projection_class_embeddings_input_dim = None context_dim = None - if unet_params.context_dim is not None: + if unet_params["context_dim"] is not None: context_dim = ( - unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] ) if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": + if unet_params["num_classes"] == "sequential": addition_time_embed_dim = 256 assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, + "in_channels": unet_params["in_channels"], "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, + "layers_per_block": unet_params["num_res_blocks"], "cross_attention_dim": context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, @@ -100,15 +137,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa } if "disable_self_attentions" in unet_params: - config["only_cross_attention"] = unet_params.disable_self_attentions + config["only_cross_attention"] = unet_params["disable_self_attentions"] - if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): - config["num_class_embeds"] = unet_params.num_classes + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] if controlnet: - config["conditioning_channels"] = unet_params.hint_channels + config["conditioning_channels"] = unet_params["hint_channels"] else: - config["out_channels"] = unet_params.out_channels + config["out_channels"] = unet_params["out_channels"] config["up_block_types"] = tuple(up_block_types) return config @@ -169,9 +206,6 @@ def assign_to_checkpoint( for replacement in additional_replacements: new_path = new_path.replace(replacement["old"], replacement["new"]) - if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight": - print("yeyy") - # proj_attn.weight has to be converted from conv 1D to linear is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) shape = old_checkpoint[path["old"]].shape @@ -289,16 +323,16 @@ def convert_ldm_unet_checkpoint( new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - if config["class_embed_type"] is None: - # No parameters to port - ... 
- elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - else: - raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + # if config["class_embed_type"] is None: + # # No parameters to port + # ... + # elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + # new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + # new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + # new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + # new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + # else: + # raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") # if config["addition_embed_type"] == "text_time": new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] @@ -499,7 +533,6 @@ def convert_ldm_unet_checkpoint( attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key] if len(attentions): paths = renew_attention_paths(attentions) - # import ipdb; ipdb.set_trace() meta_path = { "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", @@ -644,8 +677,8 @@ def convert_ldm_vae_checkpoint(checkpoint, config): new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"] - new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"] + new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.conv_out.time_mix_conv.weight"] + new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.conv_out.time_mix_conv.bias"] # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] @@ -728,3 +761,111 @@ def convert_ldm_vae_checkpoint(checkpoint, config): assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) return new_checkpoint + + +def read_config_file(filename): + # The yaml file contains annotations that certain values should + # loaded as tuples. 
+ with open(filename) as f: + original_config = yaml.load(f, FullLoader) + + return original_config + + +def load_original_state_dict(filename: str): + if filename.endswith("safetensors"): + state_dict = load_file(filename) + elif filename.endswith("ckpt"): + state_dict = torch.load(filename, mmap=True, map_location="cpu") + else: + raise ValueError("File type is not supported") + + if isinstance(state_dict, dict) and "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + + return state_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint to convert.", required=True) + parser.add_argument( + "--config_file", type=str, help="The config json file corresponding to the architecture.", required=True + ) + parser.add_argument("--output_path", default=None, type=str, help="Path to the output model.", required=True) + parser.add_argument("--sample_size", type=int, default=768, help="VAE sample size") + parser.add_argument( + "--use_legacy_autoencoder", + action="store_true", + default=False, + help="Whether or not to use the `quant_conv` layers from the original implementation (which is now legacy behaviour)", + ) + args = parser.parse_args() + + original_config = read_config_file(args.config_file) + state_dict = load_original_state_dict(args.checkpoint_path) + + vae_config = create_vae_diffusers_config(original_config, args.sample_size) + vae = AutoencoderKLTemporalDecoder(**vae_config, use_legacy=args.use_legacy_autoencoder) + vae_state_dict = convert_ldm_vae_checkpoint(state_dict, vae_config) + + remove = [] + for key in vae_state_dict.keys(): + # i'm sorry to hurt your eyes + if ("encoder" in key) or ( + "decoder" in key and "resnets" and (("temporal_res_block" in key) or ("time_mixer" in key)) + ): + remove.append(key) + + for key in remove: + vae_state_dict[key.replace("spatial_res_block.", "")] = vae_state_dict.pop(key) + + missing_keys, unexpected_keys = vae.load_state_dict(vae_state_dict) + if missing_keys: + logger.error("VAE conversion failed") + raise ValueError( + f"VAE conversion failed due to missing keys: {missing_keys}, and unexpected keys: {unexpected_keys}" + ) + if unexpected_keys: + vae.load_state_dict(vae_state_dict, strict=False) + logger.info(f"VAE conversion occured successfully but some unexpected keys were found: {unexpected_keys}") + logger.info("VAE conversion succeeded") + + unet_config = create_unet_diffusers_config(original_config, args.sample_size) + unet_state_dict = convert_ldm_unet_checkpoint(state_dict, unet_config) + unet = UNetSpatioTemporalConditionModel.from_config(unet_config) + missing_keys, unexpected_keys = unet.load_state_dict(unet_state_dict) + + if missing_keys: + logger.error("UNet conversion failed") + raise ValueError( + f"UNet conversion failed due to missing keys: {missing_keys}, and unexpected keys: {unexpected_keys}" + ) + if unexpected_keys: + unet.load_state_dict(unet_state_dict, strict=False) + logger.info(f"UNet conversion occured successfully but some unexpected keys were found: {unexpected_keys}") + logger.info("UNet conversion succeeded") + + image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + feature_extractor = CLIPImageProcessor() + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + use_karras_sigmas=True, + ) + + pipe = StableVideoDiffusionPipeline( + vae=vae, + 
image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + + pipe.save_pretrained(args.output_path) + + pipe.to(dtype=torch.float16) + pipe.save_pretrained(args.output_path, variant="fp16") From e17e45888ac18a40f47766519e05b91fc99dc238 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 18:44:40 +0530 Subject: [PATCH 04/21] reorganize imports --- scripts/convert_svd_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py index 8f423d639239..75d98ab543a3 100644 --- a/scripts/convert_svd_to_diffusers.py +++ b/scripts/convert_svd_to_diffusers.py @@ -1,10 +1,10 @@ import argparse +import yaml +from yaml.loader import FullLoader import torch -import yaml from safetensors.torch import load_file from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from yaml.loader import FullLoader from diffusers import StableVideoDiffusionPipeline from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel From b528f5571275eebdec5887af306bd257aa86b587 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 19:03:46 +0530 Subject: [PATCH 05/21] temporarily add motionctrl conversion to svd script --- scripts/convert_svd_to_diffusers.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py index 75d98ab543a3..e86301d26e2f 100644 --- a/scripts/convert_svd_to_diffusers.py +++ b/scripts/convert_svd_to_diffusers.py @@ -1,10 +1,10 @@ import argparse -import yaml -from yaml.loader import FullLoader import torch +import yaml from safetensors.torch import load_file from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from yaml.loader import FullLoader from diffusers import StableVideoDiffusionPipeline from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel @@ -801,6 +801,12 @@ def load_original_state_dict(filename: str): default=False, help="Whether or not to use the `quant_conv` layers from the original implementation (which is now legacy behaviour)", ) + parser.add_argument( + "--convert_motionctrl", + type="store_true", + default=False, + help="Whether or not converting motionctrl svd checkpoint.", + ) args = parser.parse_args() original_config = read_config_file(args.config_file) @@ -833,6 +839,12 @@ def load_original_state_dict(filename: str): logger.info("VAE conversion succeeded") unet_config = create_unet_diffusers_config(original_config, args.sample_size) + # This is temporally added to handle motionctrl + if parser.convert_motionctrl: + unet_config["motionctrl_kwargs"] = { + "camera_pose_embed_dim": 1, + "camera_pose_dim": 12, + } unet_state_dict = convert_ldm_unet_checkpoint(state_dict, unet_config) unet = UNetSpatioTemporalConditionModel.from_config(unet_config) missing_keys, unexpected_keys = unet.load_state_dict(unet_state_dict) From 5e470cd743e3d264f9e7368b40479e65cfd067d1 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 19:36:37 +0530 Subject: [PATCH 06/21] push_to_hub --- scripts/convert_svd_to_diffusers.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py index e86301d26e2f..8fda6c3cd659 100644 --- a/scripts/convert_svd_to_diffusers.py +++ b/scripts/convert_svd_to_diffusers.py @@ -799,13 +799,10 
@@ def load_original_state_dict(filename: str): "--use_legacy_autoencoder", action="store_true", default=False, - help="Whether or not to use the `quant_conv` layers from the original implementation (which is now legacy behaviour)", + help="Whether or not to use the `quant_conv` layers from the original implementation (which is now legacy behaviour).", ) parser.add_argument( - "--convert_motionctrl", - type="store_true", - default=False, - help="Whether or not converting motionctrl svd checkpoint.", + "--push_to_hub", action="store_true", default=False, help="Whether to push to huggingface hub or not." ) args = parser.parse_args() @@ -839,12 +836,10 @@ def load_original_state_dict(filename: str): logger.info("VAE conversion succeeded") unet_config = create_unet_diffusers_config(original_config, args.sample_size) - # This is temporally added to handle motionctrl - if parser.convert_motionctrl: - unet_config["motionctrl_kwargs"] = { - "camera_pose_embed_dim": 1, - "camera_pose_dim": 12, - } + unet_config["motionctrl_kwargs"] = { + "camera_pose_embed_dim": 1, + "camera_pose_dim": 12, + } unet_state_dict = convert_ldm_unet_checkpoint(state_dict, unet_config) unet = UNetSpatioTemporalConditionModel.from_config(unet_config) missing_keys, unexpected_keys = unet.load_state_dict(unet_state_dict) @@ -881,3 +876,7 @@ def load_original_state_dict(filename: str): pipe.to(dtype=torch.float16) pipe.save_pretrained(args.output_path, variant="fp16") + + if args.push_to_hub: + pipe.push_to_hub(args.output_path) + pipe.push_to_hub(args.output_path, variant="fp16") From c6c499afa25a6d2e08bc1fe113cdbc61abe487f3 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 19:42:09 +0530 Subject: [PATCH 07/21] fix push_to_hub fp32 --- scripts/convert_svd_to_diffusers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py index 8fda6c3cd659..859a91466c78 100644 --- a/scripts/convert_svd_to_diffusers.py +++ b/scripts/convert_svd_to_diffusers.py @@ -873,10 +873,12 @@ def load_original_state_dict(filename: str): ) pipe.save_pretrained(args.output_path) + if args.push_to_hub: + logger.info("Pushing float32 version to HF hub") + pipe.push_to_hub(args.output_path) pipe.to(dtype=torch.float16) pipe.save_pretrained(args.output_path, variant="fp16") - if args.push_to_hub: - pipe.push_to_hub(args.output_path) + logger.info("Pushing float16 version to HF hub") pipe.push_to_hub(args.output_path, variant="fp16") From 61454c0e824ad59fec200a9add954f912537f89d Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 20:14:06 +0530 Subject: [PATCH 08/21] begin pipeline --- ...eline_stable_video_motionctrl_diffusion.py | 671 ++++++++++++++++++ 1 file changed, 671 insertions(+) create mode 100644 src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py new file mode 100644 index 000000000000..966c4e6c8373 --- /dev/null +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py @@ -0,0 +1,671 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from ...schedulers import EulerDiscreteScheduler +from ...utils import BaseOutput, logging +from ...utils.torch_utils import is_compiled_module, randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid +def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + + elif output_type == "pt": + outputs = torch.stack(outputs) + + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]") + + return outputs + + +@dataclass +class StableVideoDiffusionPipelineOutput(BaseOutput): + r""" + Output class for zero-shot text-to-video pipeline. + + Args: + frames (`[List[PIL.Image.Image]`, `np.ndarray`]): + List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width, + num_channels)`. + """ + + frames: Union[List[PIL.Image.Image], np.ndarray] + + +class StableVideoDiffusionMotionCtrlPipeline(DiffusionPipeline): + r""" + Pipeline to generate video from an input image using [MotionCtrl](https://github.com/TencentARC/MotionCtrl). + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderKLTemporalDecoder`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)). + unet ([`UNetSpatioTemporalConditionModel`]): + A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents. + scheduler ([`EulerDiscreteScheduler`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
+ feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images. + """ + + model_cpu_offload_seq = "image_encoder->unet->vae" + _callback_tensor_inputs = ["latents"] + + def __init__( + self, + vae: AutoencoderKLTemporalDecoder, + image_encoder: CLIPVisionModelWithProjection, + unet: UNetSpatioTemporalConditionModel, + scheduler: EulerDiscreteScheduler, + feature_extractor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + image_encoder=image_encoder, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 + image = _resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.feature_extractor( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeddings = self.image_encoder(image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image: torch.Tensor, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + image = image.to(device=device) + image_latents = self.vae.encode(image).latent_dist.mode() + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + def _get_add_time_ids( + self, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + + passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + def decode_latents(self, latents, num_frames, decode_chunk_size=14): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.vae.config.scaling_factor * latents + + forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward + accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys()) + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + decode_kwargs = {} + if accepts_num_frames: + # we only pass num_frames_in if it's expected + decode_kwargs["num_frames"] = num_frames_in + + frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def check_inputs(self, image, height, width, num_frames, camera_poses): + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is" + f" {type(image)}" + ) + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + camera_pose_lengths = [len(pose) for pose in camera_poses] + if len(camera_poses) != num_frames: + raise ValueError(f"length of `camera_poses` must be equal to {num_frames=} but got {len(camera_poses)=}") + if not all(x == 12 for x in camera_pose_lengths): + raise ValueError(f"All camera poses must have 12 values but got {camera_pose_lengths}") + + def prepare_latents( + self, + batch_size, + num_frames, + num_channels_latents, + 
height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + if isinstance(self.guidance_scale, (int, float)): + return self.guidance_scale + return self.guidance_scale.max() > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + def __call__( + self, + image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], + camera_poses: List[List[float]], + height: int = 576, + width: int = 1024, + num_frames: Optional[int] = None, + num_inference_steps: int = 25, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: float = 0.02, + decode_chunk_size: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + return_dict: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`): + Image or images to guide image generation. If you provide a tensor, it needs to be compatible with + [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json). + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_frames (`int`, *optional*): + The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt` + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. 
Used for the classifier free guidance with last frame. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after generation. + Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. + noise_aug_strength (`float`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency + between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once + for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + + Returns: + [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list of list with the generated frames. 
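            Note: unlike [`StableVideoDiffusionPipeline`], `camera_poses` is a required argument here: one pose per
            generated frame, each a sequence of 12 values (in the MotionCtrl convention, a flattened 3x4 camera pose
            matrix); `check_inputs` validates both the number of poses and the per-pose length, so the example below,
            carried over from the base SVD pipeline, would also need to pass it.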
+ + Examples: + + ```py + from diffusers import StableVideoDiffusionPipeline + from diffusers.utils import load_image, export_to_video + + pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16") + pipe.to("cuda") + + image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200") + image = image.resize((1024, 576)) + + frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0] + export_to_video(frames, "generated.mp4", fps=7) + ``` + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + num_frames = num_frames if num_frames is not None else self.unet.config.num_frames + decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames + + # 1. Check inputs. Raise error if not correct + self.check_inputs(image, height, width, num_frames, camera_poses) + + # 2. Define call parameters + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list): + batch_size = len(image) + else: + batch_size = image.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + self._guidance_scale = max_guidance_scale + + # 3. Encode input image + image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance) + + # NOTE: Stable Diffusion Video was conditioned on fps - 1, which + # is why it is reduced here. + # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188 + fps = fps - 1 + + # 4. Encode input image using VAE + image = self.image_processor.preprocess(image, height=height, width=width).to(device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) + image = image + noise_aug_strength * noise + + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + self.vae.to(dtype=torch.float32) + + image_latents = self._encode_vae_image( + image, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + image_latents = image_latents.to(image_embeddings.dtype) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1) + + # 5. Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(device) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_frames, + num_channels_latents, + height, + width, + image_embeddings.dtype, + device, + generator, + latents, + ) + + # 7. Prepare guidance scale + guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0) + guidance_scale = guidance_scale.to(device, latents.dtype) + guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + self._guidance_scale = guidance_scale + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Concatenate image_latents over channels dimention + latent_model_input = torch.cat([latent_model_input, image_latents], dim=2) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + added_time_ids=added_time_ids, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + frames = self.decode_latents(latents, num_frames, decode_chunk_size) + frames = tensor2vid(frames, self.image_processor, output_type=output_type) + else: + frames = latents + + self.maybe_free_model_hooks() + + if not return_dict: + return frames + + return StableVideoDiffusionPipelineOutput(frames=frames) + + +# resizing utils +# TODO: clean up later +def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners) + return output + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out From 9670e57c5cfe9d446b355cc9924b0f90cd8417e5 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 20:15:52 +0530 Subject: [PATCH 09/21] added cond kwargs for camera pose --- .../pipeline_stable_video_motionctrl_diffusion.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py index 966c4e6c8373..556d8c89c109 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py @@ -241,7 +241,7 @@ def decode_latents(self, latents, num_frames, 
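The anti-aliasing helpers above implement a separable Gaussian blur: a 1-D kernel is applied along the width and then along the height with a grouped `conv2d`, which matches blurring once with the full 2-D Gaussian kernel. A standalone check of that equivalence (zero padding is used here for brevity, whereas the helper pads by reflection; sizes and sigmas are illustrative):

```py
import torch

def gaussian_1d(size, sigma):
    x = torch.arange(size, dtype=torch.float32) - size // 2
    k = torch.exp(-x.pow(2) / (2 * sigma**2))
    return k / k.sum()

img = torch.randn(1, 3, 32, 32)
kx = ky = 5
gx, gy = gaussian_1d(kx, 1.2), gaussian_1d(ky, 1.2)

# separable blur: row pass, then column pass, one kernel per channel (groups=C)
kernel_x = gx.view(1, 1, 1, kx).repeat(3, 1, 1, 1)
kernel_y = gy.view(1, 1, ky, 1).repeat(3, 1, 1, 1)
out = torch.nn.functional.conv2d(img, kernel_x, padding=(0, kx // 2), groups=3)
out = torch.nn.functional.conv2d(out, kernel_y, padding=(ky // 2, 0), groups=3)

# same result as a single pass with the outer-product 2-D kernel
kernel_2d = (gy[:, None] @ gx[None, :]).view(1, 1, ky, kx).repeat(3, 1, 1, 1)
ref = torch.nn.functional.conv2d(img, kernel_2d, padding=(ky // 2, kx // 2), groups=3)
assert torch.allclose(out, ref, atol=1e-5)
```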
decode_chunk_size=14): frames = frames.float() return frames - def check_inputs(self, image, height, width, num_frames, camera_poses): + def check_inputs(self, image, height, width, num_frames, camera_pose): if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) @@ -255,9 +255,9 @@ def check_inputs(self, image, height, width, num_frames, camera_poses): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - camera_pose_lengths = [len(pose) for pose in camera_poses] - if len(camera_poses) != num_frames: - raise ValueError(f"length of `camera_poses` must be equal to {num_frames=} but got {len(camera_poses)=}") + camera_pose_lengths = [len(pose) for pose in camera_pose] + if len(camera_pose) != num_frames: + raise ValueError(f"length of `camera_poses` must be equal to {num_frames=} but got {len(camera_pose)=}") if not all(x == 12 for x in camera_pose_lengths): raise ValueError(f"All camera poses must have 12 values but got {camera_pose_lengths}") @@ -316,7 +316,7 @@ def num_timesteps(self): def __call__( self, image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], - camera_poses: List[List[float]], + camera_pose: List[List[float]], height: int = 576, width: int = 1024, num_frames: Optional[int] = None, @@ -419,7 +419,7 @@ def __call__( decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames # 1. Check inputs. Raise error if not correct - self.check_inputs(image, height, width, num_frames, camera_poses) + self.check_inputs(image, height, width, num_frames, camera_pose) # 2. Define call parameters if isinstance(image, PIL.Image.Image): @@ -505,6 +505,8 @@ def __call__( self._guidance_scale = guidance_scale + added_cond_kwargs = {"camera_pose": camera_pose} + # 8. 
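Per the validation above, `camera_pose` is expected to carry one 12-value row per frame, presumably a flattened 3x4 [R|t] matrix (matching the reshape introduced in the next patch). A hedged sketch of a static-camera input for a 25-frame clip; the values are purely illustrative:

```py
num_frames = 25

# identity rotation plus zero translation, flattened row-major into 12 numbers
identity_rt = [1.0, 0.0, 0.0, 0.0,
               0.0, 1.0, 0.0, 0.0,
               0.0, 0.0, 1.0, 0.0]
camera_pose = [list(identity_rt) for _ in range(num_frames)]

# mirrors check_inputs: one pose per frame, 12 values per pose
assert len(camera_pose) == num_frames and all(len(p) == 12 for p in camera_pose)
```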
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -523,6 +525,7 @@ def __call__( t, encoder_hidden_states=image_embeddings, added_time_ids=added_time_ids, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] From eee6471f90e1bb624fba08055cc09a92f62f4622 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 20:30:58 +0530 Subject: [PATCH 10/21] handle relative camera pose --- ...eline_stable_video_motionctrl_diffusion.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py index 556d8c89c109..af9794d2618f 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py @@ -295,6 +295,28 @@ def prepare_latents( latents = latents * self.scheduler.init_noise_sigma return latents + def _to_relative_camera_pose( + self, camera_pose: np.ndarray, keyframe_index: int = 0, keyframe_zero: bool = False + ) -> np.ndarray: + camera_pose = camera_pose.reshape(-1, 3, 4) + rotation_dst = camera_pose[:, :, :3] + translation_dst = camera_pose[:, :, 3:] + + rotation_src = rotation_dst[keyframe_index : keyframe_index + 1].repeat(camera_pose.shape[0], axis=0) + translation_src = translation_dst[keyframe_index : keyframe_index + 1].repeat(camera_pose.shape[0], axis=0) + + rotation_src_inv = rotation_src.transpose(0, 2, 1) + rotation_rel = rotation_dst @ rotation_src_inv + translation_rel = translation_dst - rotation_rel @ translation_src + + rt_rel = np.concatenate([rotation_rel, translation_rel], axis=-1) + rt_rel = rt_rel.reshape(-1, 12) + + if keyframe_zero: + rt_rel[keyframe_index] = np.zeros_like(rt_rel[keyframe_index]) + + return rt_rel + @property def guidance_scale(self): return self._guidance_scale @@ -505,6 +527,9 @@ def __call__( self._guidance_scale = guidance_scale + camera_pose = np.array(camera_pose) + camera_pose = self._to_relative_camera_pose(camera_pose) + camera_pose = torch.FloatTensor(camera_pose).to(device=device, dtype=image_embeddings.dtype) added_cond_kwargs = {"camera_pose": camera_pose} # 8. Denoising loop From f3bc672051863ec25368f7435e2945992522ba5b Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 20:54:55 +0530 Subject: [PATCH 11/21] unsqueeze and repeat --- .../pipeline_stable_video_motionctrl_diffusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py index af9794d2618f..3800a6fc1ac0 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py @@ -530,6 +530,7 @@ def __call__( camera_pose = np.array(camera_pose) camera_pose = self._to_relative_camera_pose(camera_pose) camera_pose = torch.FloatTensor(camera_pose).to(device=device, dtype=image_embeddings.dtype) + camera_pose = camera_pose.unsqueeze(0).repeat(2, 1, 1) added_cond_kwargs = {"camera_pose": camera_pose} # 8. 
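The relative-pose conversion above can be sanity-checked numerically: for orthonormal rotations, the keyframe's pose relative to itself reduces to the identity rotation and zero translation. A standalone sketch that mirrors the same formula on synthetic poses:

```py
import numpy as np

def to_relative(camera_pose, keyframe_index=0):
    camera_pose = camera_pose.reshape(-1, 3, 4)
    R_dst, t_dst = camera_pose[:, :, :3], camera_pose[:, :, 3:]
    R_src = R_dst[keyframe_index : keyframe_index + 1].repeat(camera_pose.shape[0], axis=0)
    t_src = t_dst[keyframe_index : keyframe_index + 1].repeat(camera_pose.shape[0], axis=0)
    R_rel = R_dst @ R_src.transpose(0, 2, 1)  # R_src^-1 == R_src^T for rotation matrices
    t_rel = t_dst - R_rel @ t_src
    return np.concatenate([R_rel, t_rel], axis=-1).reshape(-1, 12)

# four frames of a camera translating along z with identity rotation
poses = np.stack(
    [np.hstack([np.eye(3), np.array([[0.0], [0.0], [float(i)]])]).reshape(12) for i in range(4)]
)
rel = to_relative(poses)

# frame 0 relative to itself: identity rotation, zero translation
assert np.allclose(rel[0], np.hstack([np.eye(3), np.zeros((3, 1))]).reshape(12))
```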
Denoising loop From bbafde8a5c591721eefb91dfcf56fb7d5092c8bf Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 20:55:30 +0530 Subject: [PATCH 12/21] rename class --- .../pipeline_stable_video_motionctrl_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py index 3800a6fc1ac0..c93520598bf1 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py @@ -76,7 +76,7 @@ class StableVideoDiffusionPipelineOutput(BaseOutput): frames: Union[List[PIL.Image.Image], np.ndarray] -class StableVideoDiffusionMotionCtrlPipeline(DiffusionPipeline): +class StableVideoMotionCtrlDiffusionPipeline(DiffusionPipeline): r""" Pipeline to generate video from an input image using [MotionCtrl](https://github.com/TencentARC/MotionCtrl). From 2c8bf06e27ab8f9022d9855581c470a1dd2662e5 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Sun, 4 Feb 2024 23:40:57 +0530 Subject: [PATCH 13/21] fix --- scripts/convert_svd_to_diffusers.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py index 859a91466c78..053726222f0e 100644 --- a/scripts/convert_svd_to_diffusers.py +++ b/scripts/convert_svd_to_diffusers.py @@ -801,6 +801,12 @@ def load_original_state_dict(filename: str): default=False, help="Whether or not to use the `quant_conv` layers from the original implementation (which is now legacy behaviour).", ) + parser.add_argument( + "--convert_motionctrl", + type="store_true", + default=False, + help="Whether or not converting motionctrl svd checkpoint.", + ) parser.add_argument( "--push_to_hub", action="store_true", default=False, help="Whether to push to huggingface hub or not." 
) @@ -836,10 +842,12 @@ def load_original_state_dict(filename: str): logger.info("VAE conversion succeeded") unet_config = create_unet_diffusers_config(original_config, args.sample_size) - unet_config["motionctrl_kwargs"] = { - "camera_pose_embed_dim": 1, - "camera_pose_dim": 12, - } + # This is temporally added to handle motionctrl + if parser.convert_motionctrl: + unet_config["motionctrl_kwargs"] = { + "camera_pose_embed_dim": 1, + "camera_pose_dim": 12, + } unet_state_dict = convert_ldm_unet_checkpoint(state_dict, unet_config) unet = UNetSpatioTemporalConditionModel.from_config(unet_config) missing_keys, unexpected_keys = unet.load_state_dict(unet_state_dict) From 4018daef2974cba7b300ee3164ebc4c9f2c0733c Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 6 Feb 2024 23:15:42 +0530 Subject: [PATCH 14/21] revert changes to scripts/convert_svd_to_diffusers.py --- scripts/convert_svd_to_diffusers.py | 254 +++++----------------------- 1 file changed, 46 insertions(+), 208 deletions(-) diff --git a/scripts/convert_svd_to_diffusers.py b/scripts/convert_svd_to_diffusers.py index 053726222f0e..3243ce294b26 100644 --- a/scripts/convert_svd_to_diffusers.py +++ b/scripts/convert_svd_to_diffusers.py @@ -1,14 +1,3 @@ -import argparse - -import torch -import yaml -from safetensors.torch import load_file -from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from yaml.loader import FullLoader - -from diffusers import StableVideoDiffusionPipeline -from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel -from diffusers.schedulers import EulerDiscreteScheduler from diffusers.utils import is_accelerate_available, logging @@ -18,52 +7,28 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def create_vae_diffusers_config(original_config, image_size: int): - r""" - Creates a vae config for diffusers based on the config of the LDM. - """ - vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["decoder_config"]["params"] - block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - - vae_config = { - "sample_size": image_size, - "in_channels": vae_params["in_channels"], - "out_channels": vae_params["out_ch"], - "down_block_types": tuple(down_block_types), - "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params["z_channels"], - "layers_per_block": vae_params["num_res_blocks"], - } - - return vae_config - - def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): - r""" - Creates a unet config for diffusers based on the config of the LDM. + """ + Creates a config for the diffusers based on the config of the LDM model. 
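For reference, a boolean switch in `argparse` is declared with `action="store_true"` (the `type=` parameter expects a callable converter, not the string `"store_true"`), and its value is read from the namespace returned by `parse_args`, not from the parser object. A minimal working sketch of such a flag:

```py
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--convert_motionctrl",
    action="store_true",  # defaults to False when the flag is absent
    help="Whether or not to convert a MotionCtrl SVD checkpoint.",
)

args = parser.parse_args(["--convert_motionctrl"])
if args.convert_motionctrl:  # read from the parsed namespace
    print("converting a MotionCtrl checkpoint")
```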
""" if controlnet: - unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + unet_params = original_config.model.params.control_stage_config.params else: - if ( - "unet_config" in original_config["model"]["params"] - and original_config["model"]["params"]["unet_config"] is not None - ): - unet_params = original_config["model"]["params"]["unet_config"]["params"] + if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: + unet_params = original_config.model.params.unet_config.params else: - unet_params = original_config["model"]["params"]["network_config"]["params"] + unet_params = original_config.model.params.network_config.params - vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["decoder_config"]["params"] + vae_params = original_config.model.params.first_stage_config.params.encoder_config.params - block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): block_type = ( "CrossAttnDownBlockSpatioTemporal" - if resolution in unet_params["attention_resolutions"] + if resolution in unet_params.attention_resolutions else "DownBlockSpatioTemporal" ) down_block_types.append(block_type) @@ -74,32 +39,32 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa for i in range(len(block_out_channels)): block_type = ( "CrossAttnUpBlockSpatioTemporal" - if resolution in unet_params["attention_resolutions"] + if resolution in unet_params.attention_resolutions else "UpBlockSpatioTemporal" ) up_block_types.append(block_type) resolution //= 2 - if unet_params["transformer_depth"] is not None: + if unet_params.transformer_depth is not None: transformer_layers_per_block = ( - unet_params["transformer_depth"] - if isinstance(unet_params["transformer_depth"], int) - else list(unet_params["transformer_depth"]) + unet_params.transformer_depth + if isinstance(unet_params.transformer_depth, int) + else list(unet_params.transformer_depth) ) else: transformer_layers_per_block = 1 - vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) - head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + head_dim = unet_params.num_heads if "num_heads" in unet_params else None use_linear_projection = ( - unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: - head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] - head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + head_dim_mult = unet_params.model_channels // unet_params.num_head_channels + head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] class_embed_type = None addition_embed_type = None @@ -107,25 +72,23 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa projection_class_embeddings_input_dim = None context_dim = None - if unet_params["context_dim"] is not None: + if unet_params.context_dim is not None: context_dim = ( - unet_params["context_dim"] - if isinstance(unet_params["context_dim"], int) - else 
unet_params["context_dim"][0] + unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] ) if "num_classes" in unet_params: - if unet_params["num_classes"] == "sequential": + if unet_params.num_classes == "sequential": addition_time_embed_dim = 256 assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + projection_class_embeddings_input_dim = unet_params.adm_in_channels config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params["in_channels"], + "in_channels": unet_params.in_channels, "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params["num_res_blocks"], + "layers_per_block": unet_params.num_res_blocks, "cross_attention_dim": context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, @@ -137,15 +100,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa } if "disable_self_attentions" in unet_params: - config["only_cross_attention"] = unet_params["disable_self_attentions"] + config["only_cross_attention"] = unet_params.disable_self_attentions - if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): - config["num_class_embeds"] = unet_params["num_classes"] + if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): + config["num_class_embeds"] = unet_params.num_classes if controlnet: - config["conditioning_channels"] = unet_params["hint_channels"] + config["conditioning_channels"] = unet_params.hint_channels else: - config["out_channels"] = unet_params["out_channels"] + config["out_channels"] = unet_params.out_channels config["up_block_types"] = tuple(up_block_types) return config @@ -206,6 +169,9 @@ def assign_to_checkpoint( for replacement in additional_replacements: new_path = new_path.replace(replacement["old"], replacement["new"]) + if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight": + print("yeyy") + # proj_attn.weight has to be converted from conv 1D to linear is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) shape = old_checkpoint[path["old"]].shape @@ -323,16 +289,16 @@ def convert_ldm_unet_checkpoint( new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - # if config["class_embed_type"] is None: - # # No parameters to port - # ... - # elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - # new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - # new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - # new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - # new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - # else: - # raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + if config["class_embed_type"] is None: + # No parameters to port + ... 
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") # if config["addition_embed_type"] == "text_time": new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] @@ -533,6 +499,7 @@ def convert_ldm_unet_checkpoint( attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key] if len(attentions): paths = renew_attention_paths(attentions) + # import ipdb; ipdb.set_trace() meta_path = { "old": f"output_blocks.{i}.1", "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", @@ -677,8 +644,8 @@ def convert_ldm_vae_checkpoint(checkpoint, config): new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.conv_out.time_mix_conv.weight"] - new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.conv_out.time_mix_conv.bias"] + new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"] + new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"] # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] @@ -761,132 +728,3 @@ def convert_ldm_vae_checkpoint(checkpoint, config): assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) conv_attn_to_linear(new_checkpoint) return new_checkpoint - - -def read_config_file(filename): - # The yaml file contains annotations that certain values should - # loaded as tuples. 
- with open(filename) as f: - original_config = yaml.load(f, FullLoader) - - return original_config - - -def load_original_state_dict(filename: str): - if filename.endswith("safetensors"): - state_dict = load_file(filename) - elif filename.endswith("ckpt"): - state_dict = torch.load(filename, mmap=True, map_location="cpu") - else: - raise ValueError("File type is not supported") - - if isinstance(state_dict, dict) and "state_dict" in state_dict.keys(): - state_dict = state_dict["state_dict"] - - return state_dict - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint to convert.", required=True) - parser.add_argument( - "--config_file", type=str, help="The config json file corresponding to the architecture.", required=True - ) - parser.add_argument("--output_path", default=None, type=str, help="Path to the output model.", required=True) - parser.add_argument("--sample_size", type=int, default=768, help="VAE sample size") - parser.add_argument( - "--use_legacy_autoencoder", - action="store_true", - default=False, - help="Whether or not to use the `quant_conv` layers from the original implementation (which is now legacy behaviour).", - ) - parser.add_argument( - "--convert_motionctrl", - type="store_true", - default=False, - help="Whether or not converting motionctrl svd checkpoint.", - ) - parser.add_argument( - "--push_to_hub", action="store_true", default=False, help="Whether to push to huggingface hub or not." - ) - args = parser.parse_args() - - original_config = read_config_file(args.config_file) - state_dict = load_original_state_dict(args.checkpoint_path) - - vae_config = create_vae_diffusers_config(original_config, args.sample_size) - vae = AutoencoderKLTemporalDecoder(**vae_config, use_legacy=args.use_legacy_autoencoder) - vae_state_dict = convert_ldm_vae_checkpoint(state_dict, vae_config) - - remove = [] - for key in vae_state_dict.keys(): - # i'm sorry to hurt your eyes - if ("encoder" in key) or ( - "decoder" in key and "resnets" and (("temporal_res_block" in key) or ("time_mixer" in key)) - ): - remove.append(key) - - for key in remove: - vae_state_dict[key.replace("spatial_res_block.", "")] = vae_state_dict.pop(key) - - missing_keys, unexpected_keys = vae.load_state_dict(vae_state_dict) - if missing_keys: - logger.error("VAE conversion failed") - raise ValueError( - f"VAE conversion failed due to missing keys: {missing_keys}, and unexpected keys: {unexpected_keys}" - ) - if unexpected_keys: - vae.load_state_dict(vae_state_dict, strict=False) - logger.info(f"VAE conversion occured successfully but some unexpected keys were found: {unexpected_keys}") - logger.info("VAE conversion succeeded") - - unet_config = create_unet_diffusers_config(original_config, args.sample_size) - # This is temporally added to handle motionctrl - if parser.convert_motionctrl: - unet_config["motionctrl_kwargs"] = { - "camera_pose_embed_dim": 1, - "camera_pose_dim": 12, - } - unet_state_dict = convert_ldm_unet_checkpoint(state_dict, unet_config) - unet = UNetSpatioTemporalConditionModel.from_config(unet_config) - missing_keys, unexpected_keys = unet.load_state_dict(unet_state_dict) - - if missing_keys: - logger.error("UNet conversion failed") - raise ValueError( - f"UNet conversion failed due to missing keys: {missing_keys}, and unexpected keys: {unexpected_keys}" - ) - if unexpected_keys: - unet.load_state_dict(unet_state_dict, strict=False) - logger.info(f"UNet conversion occured successfully but some 
unexpected keys were found: {unexpected_keys}") - logger.info("UNet conversion succeeded") - - image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - feature_extractor = CLIPImageProcessor() - scheduler = EulerDiscreteScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - use_karras_sigmas=True, - ) - - pipe = StableVideoDiffusionPipeline( - vae=vae, - image_encoder=image_encoder, - unet=unet, - scheduler=scheduler, - feature_extractor=feature_extractor, - ) - - pipe.save_pretrained(args.output_path) - if args.push_to_hub: - logger.info("Pushing float32 version to HF hub") - pipe.push_to_hub(args.output_path) - - pipe.to(dtype=torch.float16) - pipe.save_pretrained(args.output_path, variant="fp16") - if args.push_to_hub: - logger.info("Pushing float16 version to HF hub") - pipe.push_to_hub(args.output_path, variant="fp16") From 1db684ceab97ebb5dd1f3e2978fee0330489f189 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Tue, 6 Feb 2024 23:17:44 +0530 Subject: [PATCH 15/21] add conversion script for motionctrl --- .../convert_motionctrl_svd_to_diffusers.py | 884 ++++++++++++++++++ 1 file changed, 884 insertions(+) create mode 100644 scripts/convert_motionctrl_svd_to_diffusers.py diff --git a/scripts/convert_motionctrl_svd_to_diffusers.py b/scripts/convert_motionctrl_svd_to_diffusers.py new file mode 100644 index 000000000000..859a91466c78 --- /dev/null +++ b/scripts/convert_motionctrl_svd_to_diffusers.py @@ -0,0 +1,884 @@ +import argparse + +import torch +import yaml +from safetensors.torch import load_file +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from yaml.loader import FullLoader + +from diffusers import StableVideoDiffusionPipeline +from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel +from diffusers.schedulers import EulerDiscreteScheduler +from diffusers.utils import is_accelerate_available, logging + + +if is_accelerate_available(): + pass + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def create_vae_diffusers_config(original_config, image_size: int): + r""" + Creates a vae config for diffusers based on the config of the LDM. + """ + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["decoder_config"]["params"] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + + vae_config = { + "sample_size": image_size, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], + } + + return vae_config + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + r""" + Creates a unet config for diffusers based on the config of the LDM. 
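For intuition, a sketch of how the VAE block widths and the derived downsampling factor fall out of the first-stage config; the `ch`/`ch_mult` values below are typical SD-style numbers chosen for illustration, not read from the MotionCtrl checkpoint:

```py
vae_params = {"ch": 128, "ch_mult": [1, 2, 4, 4], "z_channels": 4, "num_res_blocks": 2}

block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

print(block_out_channels)  # [128, 256, 512, 512]
print(vae_scale_factor)    # 8, so a 1024x576 frame maps to 128x72 latents
```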
+ """ + if controlnet: + unet_params = original_config["model"]["params"]["control_stage_config"]["params"] + else: + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] + else: + unet_params = original_config["model"]["params"]["network_config"]["params"] + + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["decoder_config"]["params"] + + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnDownBlockSpatioTemporal" + if resolution in unet_params["attention_resolutions"] + else "DownBlockSpatioTemporal" + ) + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = ( + "CrossAttnUpBlockSpatioTemporal" + if resolution in unet_params["attention_resolutions"] + else "UpBlockSpatioTemporal" + ) + up_block_types.append(block_type) + resolution //= 2 + + if unet_params["transformer_depth"] is not None: + transformer_layers_per_block = ( + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) + ) + else: + transformer_layers_per_block = 1 + + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) + + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] + + class_embed_type = None + addition_embed_type = None + addition_time_embed_dim = None + projection_class_embeddings_input_dim = None + context_dim = None + + if unet_params["context_dim"] is not None: + context_dim = ( + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else unet_params["context_dim"][0] + ) + + if "num_classes" in unet_params: + if unet_params["num_classes"] == "sequential": + addition_time_embed_dim = 256 + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params["in_channels"], + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params["num_res_blocks"], + "cross_attention_dim": context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "addition_embed_type": addition_embed_type, + "addition_time_embed_dim": addition_time_embed_dim, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + "transformer_layers_per_block": transformer_layers_per_block, + } + + if "disable_self_attentions" in unet_params: + config["only_cross_attention"] = unet_params["disable_self_attentions"] + + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] + + if controlnet: + 
config["conditioning_channels"] = unet_params["hint_channels"] + else: + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) + + return config + + +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + config=None, + mid_block_suffix="", +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + if mid_block_suffix is not None: + mid_block_suffix = f".{mid_block_suffix}" + else: + mid_block_suffix = "" + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + new_item = new_item.replace("time_stack", "temporal_transformer_blocks") + + new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias") + new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight") + new_item = 
new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias") + new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight") + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = new_item.replace("time_stack.", "") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + if skip_extract_state_dict: + unet_state_dict = checkpoint + else: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + logger.warning( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + # if config["class_embed_type"] is None: + # # No parameters to port + # ... 
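The conversion is essentially string surgery on state-dict keys: a local pass rewrites layer names (the replacement table above), then a global `old -> new` substitution from the `meta_path` handed to `assign_to_checkpoint` moves the key into the diffusers block layout. A toy illustration with an assumed key and block indices:

```py
old_key = "input_blocks.1.0.in_layers.0.weight"

# local renaming, mirroring renew_resnet_paths
new_key = (
    old_key.replace("in_layers.0", "norm1")
    .replace("in_layers.2", "conv1")
    .replace("out_layers.0", "norm2")
    .replace("out_layers.3", "conv2")
    .replace("emb_layers.1", "time_emb_proj")
    .replace("skip_connection", "conv_shortcut")
)

# global renaming, mirroring the meta_path used for the first down block
new_key = new_key.replace("input_blocks.1.0", "down_blocks.0.resnets.0.spatial_res_block")

print(new_key)  # down_blocks.0.resnets.0.spatial_res_block.norm1.weight
```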
+ # elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + # new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + # new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + # new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + # new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + # else: + # raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + # if config["addition_embed_type"] == "text_time": + new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + spatial_resnets = [ + key + for key in input_blocks[i] + if f"input_blocks.{i}.0" in key + and ( + f"input_blocks.{i}.0.op" not in key + and f"input_blocks.{i}.0.time_stack" not in key + and f"input_blocks.{i}.0.time_mixer" not in key + ) + ] + temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key] + # import ipdb; ipdb.set_trace() + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(spatial_resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block", + } + assign_to_checkpoint( + paths, 
new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + paths = renew_resnet_paths(temporal_resnets) + meta_path = { + "old": f"input_blocks.{i}.0", + "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + # TODO resnet time_mixer.mix_factor + if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: + new_checkpoint[ + f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" + ] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + # import ipdb; ipdb.set_trace() + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key] + resnet_0_paths = renew_resnet_paths(resnet_0_spatial) + # import ipdb; ipdb.set_trace() + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block" + ) + + resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key] + resnet_0_paths = renew_resnet_paths(resnet_0_temporal) + assign_to_checkpoint( + resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block" + ) + + resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key] + resnet_1_paths = renew_resnet_paths(resnet_1_spatial) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block" + ) + + resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key] + resnet_1_paths = renew_resnet_paths(resnet_1_temporal) + assign_to_checkpoint( + resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block" + ) + + new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[ + "middle_block.0.time_mixer.mix_factor" + ] + new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[ + "middle_block.2.time_mixer.mix_factor" + ] + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + spatial_resnets = [ + key + for key in output_blocks[i] + if f"output_blocks.{i}.0" in key + and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key) + ] + # import ipdb; ipdb.set_trace() + + 
temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key] + + paths = renew_resnet_paths(spatial_resnets) + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + paths = renew_resnet_paths(temporal_resnets) + meta_path = { + "old": f"output_blocks.{i}.0", + "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict: + new_checkpoint[ + f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor" + ] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"] + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key] + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + spatial_layers = [ + layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer + ] + resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1) + # import ipdb; ipdb.set_trace() + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + temporal_layers = [ + layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in key + ] + resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1) + # import ipdb; ipdb.set_trace() + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join( + ["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]] + ) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[ + f"output_blocks.{str(i)}.0.time_mixer.mix_factor" + ] + + return new_checkpoint + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def 
renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # Temporal resnet + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = new_item.replace("time_stack.", "temporal_res_block.") + + # Spatial resnet + new_item = new_item.replace("conv1", "spatial_res_block.conv1") + new_item = new_item.replace("norm1", "spatial_res_block.norm1") + + new_item = new_item.replace("conv2", "spatial_res_block.conv2") + new_item = new_item.replace("norm2", "spatial_res_block.norm2") + + new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut") + + new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + keys = list(checkpoint.keys()) + vae_key = "first_stage_model." 
if any(k.startswith("first_stage_model.") for k in keys) else "" + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.conv_out.time_mix_conv.weight"] + new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.conv_out.time_mix_conv.bias"] + + # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + # new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + # new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, 
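`conv_attn_to_linear` relies on a 1x1 convolution being the same linear map applied at every spatial position, so keeping only `[:, :, 0, 0]` of the weight is enough. A short equivalence check with illustrative shapes:

```py
import torch

c_in, c_out, h, w = 8, 8, 4, 4
x = torch.randn(1, c_in, h, w)
conv_weight = torch.randn(c_out, c_in, 1, 1)

conv_out = torch.nn.functional.conv2d(x, conv_weight)

linear_weight = conv_weight[:, :, 0, 0]  # what conv_attn_to_linear keeps
linear_out = torch.einsum("oc,bchw->bohw", linear_weight, x)

assert torch.allclose(conv_out, linear_out, atol=1e-5)
```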
additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def read_config_file(filename): + # The yaml file contains annotations that certain values should + # loaded as tuples. + with open(filename) as f: + original_config = yaml.load(f, FullLoader) + + return original_config + + +def load_original_state_dict(filename: str): + if filename.endswith("safetensors"): + state_dict = load_file(filename) + elif filename.endswith("ckpt"): + state_dict = torch.load(filename, mmap=True, map_location="cpu") + else: + raise ValueError("File type is not supported") + + if isinstance(state_dict, dict) and "state_dict" in state_dict.keys(): + state_dict = state_dict["state_dict"] + + return state_dict + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint to convert.", required=True) + parser.add_argument( + "--config_file", type=str, help="The config json file corresponding to the architecture.", required=True + ) + parser.add_argument("--output_path", default=None, type=str, help="Path to the output model.", required=True) + parser.add_argument("--sample_size", type=int, default=768, help="VAE sample size") + parser.add_argument( + "--use_legacy_autoencoder", + action="store_true", + default=False, + help="Whether or not to use the `quant_conv` layers from the original implementation (which is now legacy behaviour).", + ) + parser.add_argument( + "--push_to_hub", action="store_true", default=False, help="Whether to push to huggingface hub or not." 
+    )
+    args = parser.parse_args()
+
+    original_config = read_config_file(args.config_file)
+    state_dict = load_original_state_dict(args.checkpoint_path)
+
+    vae_config = create_vae_diffusers_config(original_config, args.sample_size)
+    vae = AutoencoderKLTemporalDecoder(**vae_config, use_legacy=args.use_legacy_autoencoder)
+    vae_state_dict = convert_ldm_vae_checkpoint(state_dict, vae_config)
+
+    remove = []
+    for key in vae_state_dict.keys():
+        # The generic resnet renaming inserts a `spatial_res_block.` segment into every resnet key.
+        # Collect the keys where that segment does not belong (all encoder keys, and the decoder's
+        # temporal/time-mixer keys) so it can be stripped below.
+        if ("encoder" in key) or (
+            "decoder" in key and "resnets" in key and (("temporal_res_block" in key) or ("time_mixer" in key))
+        ):
+            remove.append(key)
+
+    for key in remove:
+        vae_state_dict[key.replace("spatial_res_block.", "")] = vae_state_dict.pop(key)
+
+    missing_keys, unexpected_keys = vae.load_state_dict(vae_state_dict)
+    if missing_keys:
+        logger.error("VAE conversion failed")
+        raise ValueError(
+            f"VAE conversion failed due to missing keys: {missing_keys}, and unexpected keys: {unexpected_keys}"
+        )
+    if unexpected_keys:
+        vae.load_state_dict(vae_state_dict, strict=False)
+        logger.info(f"VAE conversion occured successfully but some unexpected keys were found: {unexpected_keys}")
+    logger.info("VAE conversion succeeded")
+
+    unet_config = create_unet_diffusers_config(original_config, args.sample_size)
+    unet_config["motionctrl_kwargs"] = {
+        "camera_pose_embed_dim": 1,
+        "camera_pose_dim": 12,
+    }
+    unet_state_dict = convert_ldm_unet_checkpoint(state_dict, unet_config)
+    unet = UNetSpatioTemporalConditionModel.from_config(unet_config)
+    missing_keys, unexpected_keys = unet.load_state_dict(unet_state_dict)
+
+    if missing_keys:
+        logger.error("UNet conversion failed")
+        raise ValueError(
+            f"UNet conversion failed due to missing keys: {missing_keys}, and unexpected keys: {unexpected_keys}"
+        )
+    if unexpected_keys:
+        unet.load_state_dict(unet_state_dict, strict=False)
+        logger.info(f"UNet conversion occured successfully but some unexpected keys were found: {unexpected_keys}")
+    logger.info("UNet conversion succeeded")
+
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+    feature_extractor = CLIPImageProcessor()
+    scheduler = EulerDiscreteScheduler(
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        num_train_timesteps=1000,
+        use_karras_sigmas=True,
+    )
+
+    pipe = StableVideoDiffusionPipeline(
+        vae=vae,
+        image_encoder=image_encoder,
+        unet=unet,
+        scheduler=scheduler,
+        feature_extractor=feature_extractor,
+    )
+
+    pipe.save_pretrained(args.output_path)
+    if args.push_to_hub:
+        logger.info("Pushing float32 version to HF hub")
+        pipe.push_to_hub(args.output_path)
+
+    pipe.to(dtype=torch.float16)
+    pipe.save_pretrained(args.output_path, variant="fp16")
+    if args.push_to_hub:
+        logger.info("Pushing float16 version to HF hub")
+        pipe.push_to_hub(args.output_path, variant="fp16")

From e1e2beba78dcc9313c80b54e4631db1b86ca76fd Mon Sep 17 00:00:00 2001
From: a-r-r-o-w
Date: Wed, 7 Feb 2024 22:01:00 +0530
Subject: [PATCH 16/21] make model loading strict

---
 .../convert_motionctrl_svd_to_diffusers.py | 33 ++++----------------
 1 file changed, 7 insertions(+), 26 deletions(-)

diff --git a/scripts/convert_motionctrl_svd_to_diffusers.py b/scripts/convert_motionctrl_svd_to_diffusers.py
index 859a91466c78..d3fc3b82dda1 100644
--- a/scripts/convert_motionctrl_svd_to_diffusers.py
+++ b/scripts/convert_motionctrl_svd_to_diffusers.py
@@ -825,15 +825,8 @@ def load_original_state_dict(filename: str):
         vae_state_dict[key.replace("spatial_res_block.", "")] = 
vae_state_dict.pop(key) missing_keys, unexpected_keys = vae.load_state_dict(vae_state_dict) - if missing_keys: - logger.error("VAE conversion failed") - raise ValueError( - f"VAE conversion failed due to missing keys: {missing_keys}, and unexpected keys: {unexpected_keys}" - ) - if unexpected_keys: - vae.load_state_dict(vae_state_dict, strict=False) - logger.info(f"VAE conversion occured successfully but some unexpected keys were found: {unexpected_keys}") - logger.info("VAE conversion succeeded") + logger.info(f"[VAE] missing_keys: {missing_keys}") + logger.info(f"[VAE] unexpected_keys: {unexpected_keys}") unet_config = create_unet_diffusers_config(original_config, args.sample_size) unet_config["motionctrl_kwargs"] = { @@ -843,26 +836,14 @@ def load_original_state_dict(filename: str): unet_state_dict = convert_ldm_unet_checkpoint(state_dict, unet_config) unet = UNetSpatioTemporalConditionModel.from_config(unet_config) missing_keys, unexpected_keys = unet.load_state_dict(unet_state_dict) - - if missing_keys: - logger.error("UNet conversion failed") - raise ValueError( - f"UNet conversion failed due to missing keys: {missing_keys}, and unexpected keys: {unexpected_keys}" - ) - if unexpected_keys: - unet.load_state_dict(unet_state_dict, strict=False) - logger.info(f"UNet conversion occured successfully but some unexpected keys were found: {unexpected_keys}") + logger.info(f"[UNet] missing_keys: {missing_keys}") + logger.info(f"[UNet] unexpected_keys: {unexpected_keys}") logger.info("UNet conversion succeeded") - image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + original_svd_model_id = "stabilityai/stable-video-diffusion-img2vid-xt" + image_encoder = CLIPVisionModelWithProjection.from_pretrained(original_svd_model_id, subfolder="image_encoder") feature_extractor = CLIPImageProcessor() - scheduler = EulerDiscreteScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - use_karras_sigmas=True, - ) + scheduler = EulerDiscreteScheduler.from_pretrained(original_svd_model_id) pipe = StableVideoDiffusionPipeline( vae=vae, From d21bce653516d8aaa610d3f51d4712ebb7d29509 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Wed, 7 Feb 2024 22:01:32 +0530 Subject: [PATCH 17/21] rename script --- ...trl_svd_to_diffusers.py => convert_motionctrl_to_diffusers.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{convert_motionctrl_svd_to_diffusers.py => convert_motionctrl_to_diffusers.py} (100%) diff --git a/scripts/convert_motionctrl_svd_to_diffusers.py b/scripts/convert_motionctrl_to_diffusers.py similarity index 100% rename from scripts/convert_motionctrl_svd_to_diffusers.py rename to scripts/convert_motionctrl_to_diffusers.py From 4b66b4fe7f662d3959c05b56a4eb6ca48b1a9c74 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Wed, 7 Feb 2024 22:13:59 +0530 Subject: [PATCH 18/21] fix script --- scripts/convert_motionctrl_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_motionctrl_to_diffusers.py b/scripts/convert_motionctrl_to_diffusers.py index d3fc3b82dda1..794c3752ec62 100644 --- a/scripts/convert_motionctrl_to_diffusers.py +++ b/scripts/convert_motionctrl_to_diffusers.py @@ -843,7 +843,7 @@ def load_original_state_dict(filename: str): original_svd_model_id = "stabilityai/stable-video-diffusion-img2vid-xt" image_encoder = CLIPVisionModelWithProjection.from_pretrained(original_svd_model_id, subfolder="image_encoder") feature_extractor = 
CLIPImageProcessor() - scheduler = EulerDiscreteScheduler.from_pretrained(original_svd_model_id) + scheduler = EulerDiscreteScheduler.from_pretrained(original_svd_model_id, subfolder="scheduler") pipe = StableVideoDiffusionPipeline( vae=vae, From adb2cafffe0b2da825d90944c235f477ffc2db8d Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Wed, 7 Feb 2024 22:15:03 +0530 Subject: [PATCH 19/21] remove unnecessary lines in script --- scripts/convert_motionctrl_to_diffusers.py | 24 ---------------------- 1 file changed, 24 deletions(-) diff --git a/scripts/convert_motionctrl_to_diffusers.py b/scripts/convert_motionctrl_to_diffusers.py index 794c3752ec62..5469be021fe2 100644 --- a/scripts/convert_motionctrl_to_diffusers.py +++ b/scripts/convert_motionctrl_to_diffusers.py @@ -225,13 +225,6 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0): for old_item in old_list: new_item = old_item - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) new_item = new_item.replace("time_stack", "temporal_transformer_blocks") new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias") @@ -323,18 +316,6 @@ def convert_ldm_unet_checkpoint( new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - # if config["class_embed_type"] is None: - # # No parameters to port - # ... - # elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - # new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - # new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - # new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - # new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - # else: - # raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") - - # if config["addition_embed_type"] == "text_time": new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] @@ -680,11 +661,6 @@ def convert_ldm_vae_checkpoint(checkpoint, config): new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.conv_out.time_mix_conv.weight"] new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.conv_out.time_mix_conv.bias"] - # new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - # new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - # new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - # new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - # Retrieves the keys for the encoder down blocks only num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) down_blocks = { From a5a45b04d516627bd9c89be2d8ca3831b6cbeea2 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Wed, 7 Feb 2024 22:56:37 +0530 Subject: 
[PATCH 20/21] repeat_interleave is not an inplace operation dummy --- src/diffusers/models/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5280d586bc14..6e6192e44437 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -534,7 +534,7 @@ def forward( # MotionCtrl specific if self.use_camera_projection: camera_pose: torch.FloatTensor = added_cond_kwargs.get("camera_pose") - camera_pose.repeat_interleave(seq_length, dim=0) # [batch_size * seq_length, num_frames, 12] + camera_pose = camera_pose.repeat_interleave(seq_length, dim=0) # [batch_size * seq_length, num_frames, 12] hidden_states = torch.cat([hidden_states, camera_pose], dim=-1) hidden_states = self.cc_projection(hidden_states) From 7769f0e14c3321a1d9c0629ffd2bc3eae0245e24 Mon Sep 17 00:00:00 2001 From: a-r-r-o-w Date: Thu, 8 Feb 2024 00:13:13 +0530 Subject: [PATCH 21/21] add camera speed --- .../pipeline_stable_video_motionctrl_diffusion.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py index c93520598bf1..d0fa713f2456 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_motionctrl_diffusion.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -339,6 +339,7 @@ def __call__( self, image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor], camera_pose: List[List[float]], + camera_speed: float = 1.0, height: int = 576, width: int = 1024, num_frames: Optional[int] = None, @@ -527,10 +528,11 @@ def __call__( self._guidance_scale = guidance_scale - camera_pose = np.array(camera_pose) + camera_pose = np.array(camera_pose).reshape(-1, 3, 4) + camera_pose[:, :, -1] = camera_pose[:, :, -1] * np.array([3, 1, 4]) * camera_speed # rescale camera_pose = self._to_relative_camera_pose(camera_pose) camera_pose = torch.FloatTensor(camera_pose).to(device=device, dtype=image_embeddings.dtype) - camera_pose = camera_pose.unsqueeze(0).repeat(2, 1, 1) + camera_pose = camera_pose.repeat(2, 1, 1) added_cond_kwargs = {"camera_pose": camera_pose} # 8. Denoising loop
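# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the patch series): one way the
# converted MotionCtrl pipeline could be driven with per-frame camera poses and
# the `camera_speed` argument added above. The checkpoint path, image path,
# frame count, and pose values are placeholders, and the output handling
# assumes the usual SVD-style `.frames` result.
# ---------------------------------------------------------------------------
import torch

from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video, load_image

# `DiffusionPipeline` resolves whichever pipeline class the conversion script
# recorded in `model_index.json`, so no class name has to be hard-coded here.
pipe = DiffusionPipeline.from_pretrained(
    "path/to/converted-motionctrl", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

image = load_image("input.png").resize((1024, 576))

# One flattened 3x4 camera-to-world matrix (12 floats) per frame; the pipeline
# reshapes the list to (num_frames, 3, 4), rescales the translation column
# using `camera_speed`, and converts the matrices to relative poses.
identity_pose = [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
camera_pose = [list(identity_pose) for _ in range(14)]

frames = pipe(image, camera_pose=camera_pose, camera_speed=1.0, num_frames=14).frames[0]
export_to_video(frames, "output.mp4", fps=7)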