1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -32,6 +32,7 @@
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnetxs"] = ["ControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
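The new entry registers SparseControlNetModel in the lazy-import table so it can be imported from diffusers.models. A minimal sanity check (hedged: the no-argument constructor is assumed to work because the conversion script later in this PR calls SparseControlNetModel() with defaults):

from diffusers.models import SparseControlNetModel

# Instantiate with the default config (same call the conversion script below uses)
model = SparseControlNetModel()
print(sum(p.numel() for p in model.parameters()))  # rough parameter-count check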
703 changes: 703 additions & 0 deletions src/diffusers/models/controlnet_sparsectrl.py

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions src/diffusers/models/unet_3d_blocks.py
@@ -54,6 +54,7 @@ def get_down_block(
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    temporal_num_attention_heads: int = 8,
    temporal_double_self_attention: bool = True,
    temporal_max_seq_length: int = 32,
    transformer_layers_per_block: int = 1,
) -> Union[
Expand Down Expand Up @@ -112,6 +113,7 @@ def get_down_block(
            resnet_time_scale_shift=resnet_time_scale_shift,
            temporal_num_attention_heads=temporal_num_attention_heads,
            temporal_max_seq_length=temporal_max_seq_length,
            temporal_double_self_attention=temporal_double_self_attention,
        )
    elif down_block_type == "CrossAttnDownBlockMotion":
        if cross_attention_dim is None:
@@ -135,6 +137,7 @@
            resnet_time_scale_shift=resnet_time_scale_shift,
            temporal_num_attention_heads=temporal_num_attention_heads,
            temporal_max_seq_length=temporal_max_seq_length,
            temporal_double_self_attention=temporal_double_self_attention,
        )
    elif down_block_type == "DownBlockSpatioTemporal":
        # added for SDV
@@ -946,6 +949,7 @@ def __init__(
        temporal_num_attention_heads: int = 1,
        temporal_cross_attention_dim: Optional[int] = None,
        temporal_max_seq_length: int = 32,
        temporal_double_self_attention: bool = True,
    ):
        super().__init__()
        resnets = []
@@ -978,6 +982,7 @@
                    positional_embeddings="sinusoidal",
                    num_positional_embeddings=temporal_max_seq_length,
                    attention_head_dim=out_channels // temporal_num_attention_heads,
                    double_self_attention=temporal_double_self_attention,
                )
            )

@@ -1080,6 +1085,7 @@ def __init__(
        temporal_cross_attention_dim: Optional[int] = None,
        temporal_num_attention_heads: int = 8,
        temporal_max_seq_length: int = 32,
        temporal_double_self_attention: bool = True,
    ):
        super().__init__()
        resnets = []
@@ -1144,6 +1150,7 @@
                    positional_embeddings="sinusoidal",
                    num_positional_embeddings=temporal_max_seq_length,
                    attention_head_dim=out_channels // temporal_num_attention_heads,
                    double_self_attention=temporal_double_self_attention,
                )
            )

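The flag added above is forwarded as double_self_attention to the temporal transformer blocks. In diffusers' BasicTransformerBlock, that argument decides whether the second attention layer (attn2) cross-attends to encoder states or acts as a second self-attention; the PR threads it through so the new ControlNet can presumably turn the second self-attention off. A condensed sketch mirroring the library's branch (illustrative, not code from this PR):

from diffusers.models.attention_processor import Attention

def build_attn2(dim, heads, dim_head, cross_attention_dim, double_self_attention):
    # Mirrors the branch in diffusers' BasicTransformerBlock: with
    # double_self_attention=True the encoder dimension is dropped and
    # attn2 becomes a second self-attention over the hidden states.
    return Attention(
        query_dim=dim,
        cross_attention_dim=cross_attention_dim if not double_self_attention else None,
        heads=heads,
        dim_head=dim_head,
    )

attn2 = build_attn2(320, 8, 40, cross_attention_dim=768, double_self_attention=True)
assert attn2.to_k.in_features == 320  # keys come from hidden states, not encoder states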
@@ -0,0 +1,51 @@
import argparse

import torch
from safetensors.torch import save_file


def convert_motion_module(original_state_dict):
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        # positional encodings are skipped (diffusers recreates sinusoidal embeddings at init)
        if "pos_encoder" in k:
            continue
        converted_state_dict[
            k.replace(".norms.0", ".norm1")
            .replace(".norms.1", ".norm2")
            .replace(".ff_norm", ".norm3")
            .replace(".attention_blocks.0", ".attn1")
            .replace(".attention_blocks.1", ".attn2")
            .replace(".temporal_transformer", "")
        ] = v

    return converted_state_dict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    state_dict = torch.load(args.ckpt_path, map_location="cpu")

    # some checkpoints nest the weights under a "state_dict" key
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    conv_state_dict = convert_motion_module(state_dict)

    # namespace the converted keys under "unet." and drop any non-tensor entries
    output_dict = {}
    for module_name, params in conv_state_dict.items():
        if not isinstance(params, torch.Tensor):
            continue
        output_dict[f"unet.{module_name}"] = params

    save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
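To make the key mapping concrete, a small hedged demonstration (the input key is made up in the AnimateDiff checkpoint style; it only exercises the replace rules of convert_motion_module above):

import torch

old_key = (
    "down_blocks.0.motion_modules.0.temporal_transformer"
    ".transformer_blocks.0.attention_blocks.0.to_q.weight"
)
converted = convert_motion_module({old_key: torch.zeros(1)})
print(next(iter(converted)))
# down_blocks.0.motion_modules.0.transformer_blocks.0.attn1.to_q.weight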
@@ -0,0 +1,51 @@
import argparse

import torch

from diffusers import MotionAdapter


def convert_motion_module(original_state_dict):
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        # positional encodings are skipped and recreated by MotionAdapter at init
        if "pos_encoder" in k:
            continue
        converted_state_dict[
            k.replace(".norms.0", ".norm1")
            .replace(".norms.1", ".norm2")
            .replace(".ff_norm", ".norm3")
            .replace(".attention_blocks.0", ".attn1")
            .replace(".attention_blocks.1", ".attn2")
            .replace(".temporal_transformer", "")
        ] = v

    return converted_state_dict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--use_motion_mid_block", action="store_true")
    parser.add_argument("--motion_max_seq_length", type=int, default=32)

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    # some checkpoints nest the weights under a "state_dict" key
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    conv_state_dict = convert_motion_module(state_dict)
    adapter = MotionAdapter(
        use_motion_mid_block=args.use_motion_mid_block, motion_max_seq_length=args.motion_max_seq_length
    )
    # skip loading position embeddings
    adapter.load_state_dict(conv_state_dict, strict=False)
    adapter.save_pretrained(args.output_path)
    adapter.save_pretrained(args.output_path, variant="fp16")
Review comment (Member):
For fp16 variant, we should also set the dtype to torch.float16. I don't see that being used. This issue persists here too: https://huggingface.co/guoyww/animatediff-motion-adapter-v1-5-2/tree/main.

Both the fp16 and non-fp16 variants have the same size which shouldn't be the case.

Am I missing out on anything?
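A hedged sketch of the fix the comment asks for, using the names from the script above: cast the adapter before writing the fp16 variant so the two saved files actually differ in size (save_pretrained's variant argument and the nn.Module-style .to() are standard diffusers API):

import torch

adapter.save_pretrained(args.output_path)  # full-precision weights first
adapter.to(dtype=torch.float16)  # cast so the fp16 variant is really half precision
adapter.save_pretrained(args.output_path, variant="fp16")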

@@ -0,0 +1,49 @@
import argparse

import torch

from diffusers.models import SparseControlNetModel


def convert_sparse_cntrl_module(original_state_dict):
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        # positional encodings are skipped and recreated at model init
        if "pos_encoder" in k:
            continue
        converted_state_dict[
            k.replace(".norms.0", ".norm1")
            .replace(".norms.1", ".norm2")
            .replace(".ff_norm", ".norm3")
            .replace(".attention_blocks.0", ".attn1")
            .replace(".attention_blocks.1", ".attn2")
            .replace(".temporal_transformer", "")
        ] = v

    return converted_state_dict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--motion_max_seq_length", type=int, default=32)

    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    # some checkpoints nest the weights under a "state_dict" key
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    conv_state_dict = convert_sparse_cntrl_module(state_dict)
    controlnet = SparseControlNetModel()

    # skip loading position embeddings
    controlnet.load_state_dict(conv_state_dict, strict=False)
    controlnet.save_pretrained(args.output_path)
    controlnet.save_pretrained(args.output_path, variant="fp16")
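Once converted, the checkpoint can be reloaded the usual way (a hedged sketch with an illustrative path; the fp16-dtype caveat from the review comment above applies to this script's variant="fp16" save as well):

from diffusers.models import SparseControlNetModel

controlnet = SparseControlNetModel.from_pretrained("path/to/output")  # illustrative path
controlnet_fp16 = SparseControlNetModel.from_pretrained("path/to/output", variant="fp16")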