From 8b7ea8110bfdfcea4097596ca34ab849340a4d8a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Thu, 20 Nov 2025 22:13:15 +0100 Subject: [PATCH 01/34] add --- .../convert_hunyuan_video1_5_to_diffusers.py | 382 +++++++ src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoder_kl_hunyuanvideo15.py | 968 ++++++++++++++++++ src/diffusers/models/transformers/__init__.py | 1 + .../transformer_hunyuan_video15.py | 937 +++++++++++++++++ 7 files changed, 2297 insertions(+) create mode 100644 scripts/convert_hunyuan_video1_5_to_diffusers.py create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py create mode 100644 src/diffusers/models/transformers/transformer_hunyuan_video15.py diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py new file mode 100644 index 000000000000..59df02d1901f --- /dev/null +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -0,0 +1,382 @@ +""" +python scripts/convert_hunyuan_video1_5_to_diffusers.py \ + --original_state_dict_folder /raid/yiyi/new-model-vid \ + --output_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \ + --transformer_type 480p_i2v \ + --dtype fp32 +""" + +import argparse +from typing import Any, Dict + +import torch +from accelerate import init_empty_weights +from safetensors.torch import load_file +from huggingface_hub import snapshot_download + +import pathlib +from diffusers import HunyuanVideo15Transformer3DModel + +TRANSFORMER_CONFIGS = { + "480p_i2v": { + "in_channels": 65, + "out_channels": 32, + "num_attention_heads": 16, + "attention_head_dim": 128, + "num_layers": 54, + "num_refiner_layers": 2, + "mlp_ratio": 4.0, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm", + "text_embed_dim": 3584, + "text_embed_2_dim": 1472, + "image_embed_dim": 1152, + "rope_theta": 256.0, + "rope_axes_dim": (16, 56, 56), + "use_meanflow": False, + }, +} + +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_hyvideo15_transformer_to_diffusers(original_state_dict): + """ + Convert HunyuanVideo 1.5 original checkpoint to Diffusers format. + """ + converted_state_dict = {} + + # 1. time_embed.timestep_embedder <- time_in + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop( + "time_in.mlp.0.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( + "time_in.mlp.0.bias" + ) + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop( + "time_in.mlp.2.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( + "time_in.mlp.2.bias" + ) + + # 2. 
context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.0.weight") + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.0.bias") + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.2.weight") + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.2.bias") + ) + + # 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = ( + original_state_dict.pop("txt_in.c_embedder.linear_1.weight") + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = ( + original_state_dict.pop("txt_in.c_embedder.linear_1.bias") + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = ( + original_state_dict.pop("txt_in.c_embedder.linear_2.weight") + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = ( + original_state_dict.pop("txt_in.c_embedder.linear_2.bias") + ) + + # 4. context_embedder.proj_in <- txt_in.input_embedder + converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop( + "txt_in.input_embedder.weight" + ) + converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias") + + # 5. context_embedder.token_refiner <- txt_in.individual_token_refiner + num_refiner_blocks = 2 + for i in range(num_refiner_blocks): + block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}." + orig_prefix = f"txt_in.individual_token_refiner.blocks.{i}." 
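+        # Note: the original refiner blocks store the attention input projection as a single fused
+        # `self_attn_qkv` linear layer; the split below assumes its weight stacks the q, k and v
+        # matrices along the output dimension (shape `(3 * dim, dim)`), so `torch.chunk(..., 3, dim=0)`
+        # recovers the separate `to_q` / `to_k` / `to_v` parameters expected by diffusers' `Attention`.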
+ + # norm1 + converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(f"{orig_prefix}norm1.weight") + converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(f"{orig_prefix}norm1.bias") + + # Split self_attn_qkv into to_q, to_k, to_v + qkv_weight = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.weight") + qkv_bias = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.bias") + q, k, v = torch.chunk(qkv_weight, 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0) + + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias + + # self_attn_proj -> attn.to_out.0 + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"{orig_prefix}self_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"{orig_prefix}self_attn_proj.bias" + ) + + # norm2 + converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(f"{orig_prefix}norm2.weight") + converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(f"{orig_prefix}norm2.bias") + + # mlp -> ff + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(f"{orig_prefix}mlp.fc2.bias") + + # adaLN_modulation -> norm_out + converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}adaLN_modulation.1.weight" + ) + converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}adaLN_modulation.1.bias" + ) + + # 6. context_embedder_2 <- byt5_in + converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight") + converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias") + converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight") + converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias") + converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight") + converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias") + converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight") + converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias") + + # 7. 
image_embedder <- vision_in + converted_state_dict["image_embedder.norm_in.weight"] = original_state_dict.pop("vision_in.proj.0.weight") + converted_state_dict["image_embedder.norm_in.bias"] = original_state_dict.pop("vision_in.proj.0.bias") + converted_state_dict["image_embedder.linear_1.weight"] = original_state_dict.pop("vision_in.proj.1.weight") + converted_state_dict["image_embedder.linear_1.bias"] = original_state_dict.pop("vision_in.proj.1.bias") + converted_state_dict["image_embedder.linear_2.weight"] = original_state_dict.pop("vision_in.proj.3.weight") + converted_state_dict["image_embedder.linear_2.bias"] = original_state_dict.pop("vision_in.proj.3.bias") + converted_state_dict["image_embedder.norm_out.weight"] = original_state_dict.pop("vision_in.proj.4.weight") + converted_state_dict["image_embedder.norm_out.bias"] = original_state_dict.pop("vision_in.proj.4.bias") + + # 8. x_embedder <- img_in + converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight") + converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias") + + # 9. cond_type_embed <- cond_type_embedding + converted_state_dict["cond_type_embed.weight"] = original_state_dict.pop("cond_type_embedding.weight") + + # 10. transformer_blocks <- double_blocks + num_layers = 54 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + orig_prefix = f"double_blocks.{i}." + + # norm1 (img_mod) + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mod.linear.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mod.linear.bias" + ) + + # norm1_context (txt_mod) + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mod.linear.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mod.linear.bias" + ) + + # img attention (to_q, to_k, to_v) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_v.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_v.bias" + ) + + # img attention qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k_norm.weight" + ) + + # img attention output projection + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_proj.bias" + ) + + # txt attention (add_q_proj, add_k_proj, add_v_proj) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop( + 
f"{orig_prefix}txt_attn_q.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_q.bias" + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k.bias" + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_v.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_v.bias" + ) + + # txt attention qk norm + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_q_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k_norm.weight" + ) + + # txt attention output projection + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_proj.bias" + ) + + # norm2 and norm2_context (these don't have weights in the original, they're LayerNorm with elementwise_affine=False) + # So we skip them + + # img_mlp -> ff + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc2.bias" + ) + + # txt_mlp -> ff_context + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc2.bias" + ) + + # 11. 
norm_out and proj_out <- final_layer + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop( + "final_layer.adaLN_modulation.1.weight" + )) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.adaLN_modulation.1.bias")) + converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + + return converted_state_dict + + +def load_sharded_safetensors(dir: pathlib.Path): + file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) + state_dict = {} + for path in file_paths: + state_dict.update(load_file(path)) + return state_dict + + +def load_original_state_dict(args): + if args.original_state_dict_repo_id is not None: + model_dir = snapshot_download( + args.original_state_dict_repo_id, + repo_type="model", + allow_patterns="transformer/" + args.transformer_type + "/*" + ) + elif args.original_state_dict_folder is not None: + model_dir = pathlib.Path(args.original_state_dict_folder) + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`") + model_dir = pathlib.Path(model_dir) + model_dir = model_dir / "transformer" / args.transformer_type + return load_sharded_safetensors(model_dir) + +def convert_transformer(args): + original_state_dict = load_original_state_dict(args) + + config = TRANSFORMER_CONFIGS[args.transformer_type] + with init_empty_weights(): + transformer = HunyuanVideo15Transformer3DModel(**config) + state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict) + transformer.load_state_dict(state_dict, strict=True, assign=True) + + return transformer + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model" + ) + parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Folder name of the original state dict") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") + parser.add_argument( + "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()) + ) + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + transformer = convert_transformer(args) + transformer = transformer.to(dtype=dtype) + transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index cd7a2cb581b7..eb9929cf2c99 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -189,6 +189,7 @@ "AutoencoderKLHunyuanImage", "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", + "AutoencoderKLHunyuanVideo15", "AutoencoderKLLTXVideo", "AutoencoderKLMagvit", "AutoencoderKLMochi", @@ -223,6 +224,7 @@ "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", "HunyuanImageTransformer2DModel", + "HunyuanVideo15Transformer3DModel", "HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", @@ -903,6 +905,7 @@ AutoencoderKLHunyuanImage, 
AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, + AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -939,6 +942,7 @@ HunyuanImageTransformer2DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, + HunyuanVideo15Transformer3DModel, I2VGenXLUNet, Kandinsky3UNet, Kandinsky5Transformer3DModel, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index b42e981f71a9..aa80c93f2b10 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -38,6 +38,7 @@ _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] + _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] @@ -83,6 +84,7 @@ _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] + _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] @@ -143,6 +145,7 @@ AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, + AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -191,6 +194,7 @@ DualTransformer2DModel, EasyAnimateTransformer3DModel, FluxTransformer2DModel, + HunyuanVideo15Transformer3DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index edfaabb070c5..6d39785a86b4 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -7,6 +7,7 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner +from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py new file mode 100644 index 000000000000..2f05172a97d3 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -0,0 +1,968 @@ +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanVideo15CausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanVideo15RMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. 
+ """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class HunyuanVideo15AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = HunyuanVideo15RMS_norm(in_channels, images=False) + + self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1) + + @staticmethod + def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None): + """Prepare a causal attention mask for 3D videos. + + Args: + n_frame (int): Number of frames (temporal length). + n_hw (int): Product of height and width. + dtype: Desired mask dtype. + device: Device for the mask. + batch_size (int, optional): If set, expands for batch. + + Returns: + torch.Tensor: Causal attention mask. + """ + seq_len = n_frame * n_hw + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // n_hw + mask[i, : (i_frame + 1) * n_hw] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + x = self.norm(x) + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + batch_size, channels, frames, height, width = query.shape + + query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + + attention_mask = self.prepare_causal_attention_mask(frames, height * width, query.dtype, query.device, batch_size=batch_size) + + x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + # batch_size, 1, frames * height * width, channels + + x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3) + x = self.proj_out(x) + + return x + identity + + +class HunyuanVideo15Upsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2 + self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels * factor, kernel_size=3) + + self.add_temporal_upsample = add_temporal_upsample + self.repeats = factor * out_channels // in_channels + + @staticmethod + def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w) + + Args: + tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w) + r1: temporal upsampling factor + r2: height upsampling factor + r3: width upsampling factor + """ + b, packed_c, f, h, w = tensor.shape + factor 
= r1 * r2 * r3 + c = packed_c // factor + + tensor = tensor.view(b, r1, r2, r3, c, f, h, w) + tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3) + return tensor.reshape(b, c, f * r1, h * r2, w * r3) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_upsample else 1 + h = self.conv(x) + if self.add_temporal_upsample: + h_first = h[:, :, :1, :, :] + h_first = self._dcae_upsample_rearrange(h_first, r1=1, r2=2, r3=2) + h_first = h_first[:, : h_first.shape[1] // 2] + h_next = h[:, :, 1:, :, :] + h_next = self._dcae_upsample_rearrange(h_next, r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2) + x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1) + + x_next = x[:, :, 1:, :, :] + x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2) + x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = torch.cat([x_first, x_next], dim=2) + + else: + h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2) + return h + shortcut + + +class HunyuanVideo15Downsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2 + assert out_channels % factor == 0 + # self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3) + + self.add_temporal_downsample = add_temporal_downsample + self.group_size = factor * in_channels // out_channels + + @staticmethod + def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w) + + This packs spatial/temporal dimensions into channels (opposite of upsample) + """ + b, c, packed_f, packed_h, packed_w = tensor.shape + f, h, w = packed_f // r1, packed_h // r2, packed_w // r3 + + tensor = tensor.view(b, c, f, r1, h, r2, w, r3) + tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6) + return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_downsample else 1 + h = self.conv(x) + if self.add_temporal_downsample: + h_first = h[:, :, :1, :, :] + h_first = self._dcae_downsample_rearrange(h_first, r1=1, r2=2, r3=2) + h_first = torch.cat([h_first, h_first], dim=1) + h_next = h[:, :, 1:, :, :] + h_next = self._dcae_downsample_rearrange(h_next, r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = self._dcae_downsample_rearrange(x_first, r1=1, r2=2, r3=2) + B, C, T, H, W = x_first.shape + x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2) + x_next = x[:, :, 1:, :, :] + x_next = self._dcae_downsample_rearrange(x_next, r1=r1, r2=2, r3=2) + B, C, T, H, W = x_next.shape + x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + shortcut = torch.cat([x_first, x_next], dim=2) + else: + h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2) + B, C, T, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + + return h + shortcut + + +class 
HunyuanVideo15ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = HunyuanVideo15RMS_norm(in_channels, images=False) + self.conv1 = HunyuanVideo15CausalConv3d(in_channels, out_channels, kernel_size=3) + + self.norm2 = HunyuanVideo15RMS_norm(out_channels, images=False) + self.conv2 = HunyuanVideo15CausalConv3d(out_channels, out_channels, kernel_size=3) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual + + +class HunyuanVideo15MidBlock(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int = 1, + add_attention: bool = True, + ) -> None: + super().__init__() + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append(HunyuanVideo15AttnBlock(in_channels)) + else: + attentions.append(None) + + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideo15DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + downsample_out_channels: Optional[int] = None, + add_temporal_downsample: int = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if downsample_out_channels is not None: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideo15Downsample( + out_channels, + out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideo15UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: 
int = 1, + upsample_out_channels: Optional[int] = None, + add_temporal_upsample: bool = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=input_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if upsample_out_channels is not None: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideo15Upsample( + out_channels, + out_channels=upsample_out_channels, + add_temporal_upsample=add_temporal_upsample, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideo15Encoder3D(nn.Module): + r""" + 3D vae encoder for HunyuanImageRefiner. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 64, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 16, + downsample_match_channel: bool = True, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.group_size = block_out_channels[-1] // self.out_channels + + self.conv_in = HunyuanVideo15CausalConv3d(in_channels, block_out_channels[0], kernel_size=3) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + input_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + add_spatial_downsample = i < np.log2(spatial_compression_ratio) + output_channel = block_out_channels[i] + if not add_spatial_downsample: + down_block = HunyuanVideo15DownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + downsample_out_channels=None, + add_temporal_downsample=False, + ) + input_channel = output_channel + else: + add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio) + downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel + down_block = HunyuanVideo15DownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + downsample_out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + input_channel = downsample_out_channels + + self.down_blocks.append(down_block) + + self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[-1]) + + self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, 
hidden_states)
+        else:
+            for down_block in self.down_blocks:
+                hidden_states = down_block(hidden_states)
+
+            hidden_states = self.mid_block(hidden_states)
+
+        # short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
+        batch_size, _, frame, height, width = hidden_states.shape
+        short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
+
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        hidden_states += short_cut
+
+        return hidden_states
+
+
+class HunyuanVideo15Decoder3D(nn.Module):
+    r"""
+    Causal decoder for 3D video-like data used for HunyuanVideo 1.5.
+    """
+
+    def __init__(
+        self,
+        in_channels: int = 32,
+        out_channels: int = 3,
+        block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
+        layers_per_block: int = 2,
+        spatial_compression_ratio: int = 16,
+        temporal_compression_ratio: int = 4,
+        upsample_match_channel: bool = True,
+    ):
+        super().__init__()
+        self.layers_per_block = layers_per_block
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.repeat = block_out_channels[0] // self.in_channels
+
+        self.conv_in = HunyuanVideo15CausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
+        self.up_blocks = nn.ModuleList([])
+
+        # mid
+        self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[0])
+
+        # up
+        input_channel = block_out_channels[0]
+        for i in range(len(block_out_channels)):
+            output_channel = block_out_channels[i]
+
+            add_spatial_upsample = i < np.log2(spatial_compression_ratio)
+            add_temporal_upsample = i < np.log2(temporal_compression_ratio)
+            if add_spatial_upsample or add_temporal_upsample:
+                upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
+                up_block = HunyuanVideo15UpBlock3D(
+                    num_layers=self.layers_per_block + 1,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    upsample_out_channels=upsample_out_channels,
+                    add_temporal_upsample=add_temporal_upsample,
+                )
+                input_channel = upsample_out_channels
+            else:
+                up_block = HunyuanVideo15UpBlock3D(
+                    num_layers=self.layers_per_block + 1,
+                    in_channels=input_channel,
+                    out_channels=output_channel,
+                    upsample_out_channels=None,
+                    add_temporal_upsample=False,
+                )
+                input_channel = output_channel
+
+            self.up_blocks.append(up_block)
+
+        # out
+        self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False)
+        self.conv_act = nn.SiLU()
+        self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # Channel-repeat shortcut around conv_in, mirroring the grouped-mean shortcut at the encoder output.
+        shortcut = hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
+        hidden_states = self.conv_in(hidden_states) + shortcut
+
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+
+            for up_block in self.up_blocks:
+                hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
+        else:
+            hidden_states = self.mid_block(hidden_states)
+
+            for up_block in self.up_blocks:
+                hidden_states = up_block(hidden_states)
+
+        # post-process
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.conv_act(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+        return hidden_states
+
+
+class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
+    r"""
+    A VAE model with KL loss for
encoding videos into latents and decoding latent representations into videos. Used for + HunyuanVideo-1.5. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 32, + block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + spatial_compression_ratio: int = 16, + temporal_compression_ratio: int = 4, + downsample_match_channel: bool = True, + upsample_match_channel: bool = True, + scaling_factor: float = 1.03682, + ) -> None: + super().__init__() + + self.encoder = HunyuanVideo15Encoder3D( + in_channels=in_channels, + out_channels=latent_channels * 2, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + downsample_match_channel=downsample_match_channel, + ) + + self.decoder = HunyuanVideo15Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + upsample_match_channel=upsample_match_channel, + ) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal tile height and width in latent space + self.tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio + self.tile_overlap_factor = 0.25 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_latent_min_height: Optional[int] = None, + tile_latent_min_width: Optional[int] = None, + tile_overlap_factor: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_latent_min_height (`int`, *optional*): + The minimum height required for a latent to be separated into tiles across the height dimension. 
+ tile_latent_min_width (`int`, *optional*): + The minimum width required for a latent to be separated into tiles across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = tile_latent_min_height or self.tile_latent_min_height + self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width + self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z) + + dec = self.decoder(z) + + return dec + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, _, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + blend_height = int(self.tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2 + blend_width = int(self.tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2 + row_limit_height = self.tile_latent_min_height - blend_height # 8 - 2 = 6 + row_limit_width = self.tile_latent_min_width - blend_width # 8 - 2 = 6 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + moments = torch.cat(result_rows, dim=-2) + + return moments + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + + _, _, _, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64 + blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64 + row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192 + row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=-2) + + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 826469237fb1..ee408acd2368 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -44,3 +44,4 @@ from .transformer_wan import WanTransformer3DModel from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel + from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py new file mode 100644 index 000000000000..8fb2ea451a7b --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -0,0 +1,937 @@ +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.loaders import FromOriginalModelMixin + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention, AttentionProcessor +from ..cache_utils import CacheMixin +from ..embeddings import ( + CombinedTimestepTextProjEmbeddings, + TimestepEmbedding, + Timesteps, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanVideo15AttnProcessor2_0: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HunyuanVideo15AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if attn.add_q_proj is None and encoder_hidden_states is not None: + assert False # YiYi Notes: remove this condition if this code path is never used + hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) + + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + # 2. QK normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + else: + assert False + # YiYi Notes: remove this condition if this code path is never used + if attn.norm_k is not None: + key = attn.norm_k(key) + else: + assert False + # YiYi Notes: remove this condition if this code path is never used + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + if attn.add_q_proj is None and encoder_hidden_states is not None: + assert False # YiYi Notes: remove this condition if this code path is never used + query = torch.cat( + [ + apply_rotary_emb( + query[:, : -encoder_hidden_states.shape[1]], + image_rotary_emb, + sequence_dim=1, + ), + query[:, -encoder_hidden_states.shape[1] :], + ], + dim=1, + ) + key = torch.cat( + [ + apply_rotary_emb( + key[:, : -encoder_hidden_states.shape[1]], + image_rotary_emb, + sequence_dim=1, + ), + key[:, -encoder_hidden_states.shape[1] :], + ], + dim=1, + ) + else: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + # 4. 
Encoder condition QKV projection and normalization + if attn.add_q_proj is not None and encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=1) + key = torch.cat([key, encoder_key], dim=1) + value = torch.cat([value, encoder_value], dim=1) + + else: + assert False # YiYi Notes: remove this condition if this code path is never used + + + batch_size, seq_len, heads, dim = query.shape + print(f" query.shape: {query.shape}") + print(f" attention_mask.shape: {attention_mask.shape}") + attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True) + print(f" attention_mask.shape: {attention_mask.shape}") + attention_mask = attention_mask.bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + attention_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + + # 5. Attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + print(f" hidden_states.shape: {hidden_states.shape}") + print(f" hidden_states[0,:10,:3]: {hidden_states[0,:10,:3]}") + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 6. 
Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanVideoPatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class HunyuanVideoAdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideo15TimeEmbedding(nn.Module): + r""" + Time embedding for HunyuanVideo 1.5. + + Supports standard timestep embedding and optional reference timestep embedding + for MeanFlow-based super-resolution models. + + Args: + embedding_dim (`int`): + The dimension of the output embedding. + use_meanflow (`bool`, defaults to `False`): + Whether to support reference timestep embeddings for temporal consistency. + Set to `True` for super-resolution models. 
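+            When enabled, `forward` embeds the reference timestep `timestep_r` through a second
+            sinusoidal-projection/MLP branch and adds the result to the primary timestep embedding.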
+ """ + def __init__( + self, + embedding_dim: int, + use_meanflow: bool = False, + ): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + self.use_meanflow = use_meanflow + + self.time_proj_r = None + self.timestep_embedder_r = None + if use_meanflow: + self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + + def forward( + self, + timestep: torch.Tensor, + timestep_r: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype)) + + if timestep_r is not None: + timesteps_proj_r = self.time_proj_r(timestep_r) + timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype)) + timesteps_emb = timesteps_emb + timesteps_emb_r + + return timesteps_emb + + +class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideoIndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideoIndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + # YiYi TODO convert 1D mask to 4d Bx1xLxL + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + 
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideoTokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideoIndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + print(f" temb(time_text_embed).shape: {temb.shape}, {temb[0,:10]}") + hidden_states = self.proj_in(hidden_states) + print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,:3,:3]}") + print(f" temb: {temb.shape}, {temb[0,:10]}") + print(f" attention_mask: {attention_mask.shape}, {attention_mask[0,:3]}, {attention_mask.abs().sum()}") + print(f" -> token_refiner") + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + print(f" hidden_states(token_refiner) {hidden_states.shape}, {hidden_states[0,:3,:3]}") + + return hidden_states + + +class HunyuanVideoRotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(3): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. 
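+            # Illustrative example (hypothetical shapes, not from the original repository): a latent of shape
+            # (B, C, 16, 30, 52) with patch sizes (1, 1, 1) yields rope_sizes = [16, 30, 52]; each axis i
+            # contributes `rope_dim[i]` rotary frequencies, and the cos/sin tables built below cover all
+            # 16 * 30 * 52 flattened (frame, height, width) positions.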
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + return freqs_cos, freqs_sin + + +# Copied from diffusers.models.transformers.transformer_hunyuanimage.HunyuanImageByT5TextProjection +class HunyuanVideo15ByT5TextProjection(nn.Module): + def __init__(self, in_features: int, hidden_size: int, out_features: int): + super().__init__() + self.norm = nn.LayerNorm(in_features) + self.linear_1 = nn.Linear(in_features, hidden_size) + self.linear_2 = nn.Linear(hidden_size, hidden_size) + self.linear_3 = nn.Linear(hidden_size, out_features) + self.act_fn = nn.GELU() + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(encoder_hidden_states) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_3(hidden_states) + return hidden_states + + +class HunyuanVideo15ImageProjection(nn.Module): + def __init__(self, in_channels: int, hidden_size: int): + super().__init__() + self.norm_in = nn.LayerNorm(in_channels) + self.linear_1 = nn.Linear(in_channels, in_channels) + self.act_fn = nn.GELU() + self.linear_2 = nn.Linear(in_channels, hidden_size) + self.norm_out = nn.LayerNorm(hidden_size) + + def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm_in(image_embeds) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.norm_out(hidden_states) + return hidden_states + + +class HunyuanVideoTransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideoAttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. 
Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + print(f" norm_hidden_states(norm1).shape: {norm_hidden_states.shape}, {norm_hidden_states[0,:10,:3]}") + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + print(f" norm_encoder_hidden_states(norm1_context).shape: {norm_encoder_hidden_states.shape}, {norm_encoder_hidden_states[0,:10,:3]}") + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + print(f" attn_output.shape: {attn_output.shape}, {attn_output[0,:10,:3]}") + print(f" context_attn_output.shape: {context_attn_output.shape}, {context_attn_output[0,:10,:3]}") + + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + print(f" hidden_states(ff): {hidden_states.shape}, {hidden_states[0,:10,:3]}") + print(f" encoder_hidden_states(ff): {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + + return hidden_states, encoder_hidden_states + + +class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): + r""" + A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). + + Args: + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of dual-stream blocks to use. + num_refiner_layers (`int`, defaults to `2`): + The number of layers of refiner blocks to use. + mlp_ratio (`float`, defaults to `4.0`): + The ratio of the hidden layer size to the input size in the feedforward network. + patch_size (`int`, defaults to `2`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + qk_norm (`str`, defaults to `rms_norm`): + The normalization to use for the query and key projections in the attention layers. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings in the model. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. 
+ pooled_projection_dim (`int`, defaults to `768`): + The dimension of the pooled projection of the text embeddings. + rope_theta (`float`, defaults to `256.0`): + The value of theta to use in the RoPE layer. + rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions of the axes to use in the RoPE layer. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] + _no_split_modules = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] + _repeated_blocks = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + num_attention_heads: int = 24, + attention_head_dim: int = 128, + num_layers: int = 20, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 1, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + text_embed_dim: int = 3584, + text_embed_2_dim: int = 1472, + image_embed_dim: int = 1152, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int, ...] = (16, 56, 56), + use_meanflow: bool = False, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim) + + self.context_embedder = HunyuanVideoTokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim) + + self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow) + + self.cond_type_embed = nn.Embedding(3, inner_dim) + + # 2. RoPE + self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideoTransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + encoder_attention_mask_2: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + timestep_r: Optional[torch.LongTensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. 
Conditional embeddings + temb = self.time_embed(timestep, timestep_r=timestep_r) + + hidden_states = self.x_embedder(hidden_states) + + # qwen text embedding + print(f" encoder_hidden_states(qwen).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + print(f" timestep: {timestep}, {timestep[:10]}") + print(f" encoder_attention_mask: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.abs().sum()}") + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + print(f" encoder_hidden_states(token_refiner).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + + encoder_hidden_states_cond_emb = self.cond_type_embed( + torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long) + ) + encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb + print(f" encoder_hidden_states(+ cond_emb).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + + # byt5 text embedding + encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2) + print(f" encoder_hidden_states_2(byt5).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}") + + encoder_hidden_states_2_cond_emb = self.cond_type_embed( + torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long) + ) + encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb + print(f" encoder_hidden_states_2(+ cond_emb).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}") + + # image embed + encoder_hidden_states_3 = self.image_embedder(image_embeds) + print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") + is_t2v = torch.all(image_embeds == 0) + if is_t2v: + encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0 + encoder_attention_mask_3 = torch.zeros( + (batch_size, encoder_hidden_states_3.shape[1]), + dtype=encoder_attention_mask.dtype, + device=encoder_attention_mask.device, + ) + print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") + print(f" encoder_attention_mask_3: {encoder_attention_mask_3.shape}, {encoder_attention_mask_3[0,:10]}, {encoder_attention_mask_3.abs().sum()}") + else: + encoder_attention_mask_3 = torch.ones( + (batch_size, encoder_hidden_states_3.shape[1]), + dtype=encoder_attention_mask.dtype, + device=encoder_attention_mask.device, + ) + encoder_hidden_states_3_cond_emb = self.cond_type_embed( + 2 * torch.ones_like( + encoder_hidden_states_3[:, :, 0], + dtype=torch.long, + ) + ) + encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb + + print(f" encoder_hidden_states_3(+ cond_emb).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") + + + # reorder and combine text tokens: combine valid tokens first, then padding + encoder_attention_mask = encoder_attention_mask.bool() + encoder_attention_mask_2 = encoder_attention_mask_2.bool() + encoder_attention_mask_3 = encoder_attention_mask_3.bool() + new_encoder_hidden_states = [] + new_encoder_attention_mask = [] + + for text, text_mask, text_2, text_mask_2, image, image_mask in zip( + encoder_hidden_states, + encoder_attention_mask, + encoder_hidden_states_2, + encoder_attention_mask_2, + encoder_hidden_states_3, + encoder_attention_mask_3, + ): + # Concatenate: [valid_image, valid_byt5, valid_mllm, invalid_image, invalid_byt5, invalid_mllm] + 
new_encoder_hidden_states.append( + torch.cat( + [ + image[image_mask], # valid image + text_2[text_mask_2], # valid byt5 + text[text_mask], # valid mllm + image[~image_mask], # invalid image + torch.zeros_like(text_2[~text_mask_2]), # invalid byt5 (zeroed) + torch.zeros_like(text[~text_mask]), # invalid mllm (zeroed) + ], + dim=0, + ) + ) + + # Apply same reordering to attention masks + new_encoder_attention_mask.append( + torch.cat( + [ + image_mask[image_mask], + text_mask_2[text_mask_2], + text_mask[text_mask], + image_mask[~image_mask], + text_mask_2[~text_mask_2], + text_mask[~text_mask], + ], + dim=0, + ) + ) + + encoder_hidden_states = torch.stack(new_encoder_hidden_states) + encoder_attention_mask = torch.stack(new_encoder_attention_mask) + + print(f" hidden_states.shape: {hidden_states.shape}, {hidden_states[0,:3,:3]}") + print(f" encoder_hidden_states.shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") + print(f" encoder_attention_mask.shape: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.dtype}, {encoder_attention_mask.sum()}") + print(f" image_rotary_emb: {image_rotary_emb[0].shape}, {image_rotary_emb[1].shape}, {image_rotary_emb[0][:3,:10]}, {image_rotary_emb[1][:3,:10]}") + print(f" temb.shape: {temb.shape}, {temb[0,:10]}") + + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + ) + + # 5. 
Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p_h, p_w + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) From 56d57c3e9b0f9b074005d2d5448d4ae04b78b5fb Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 22 Nov 2025 02:06:17 +0100 Subject: [PATCH 02/34] add first pipeline draft --- .../transformer_hunyuan_video15.py | 5 +- .../pipelines/hunyuan_video1_5/__init__.py | 48 + .../hunyuan_video1_5/image_processor.py | 97 ++ .../pipeline_hunyuan_video1_5.py | 879 ++++++++++++++++++ .../hunyuan_video1_5/pipeline_output.py | 23 + 5 files changed, 1051 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/hunyuan_video1_5/__init__.py create mode 100644 src/diffusers/pipelines/hunyuan_video1_5/image_processor.py create mode 100644 src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py create mode 100644 src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index 8fb2ea451a7b..86c3a3565f57 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -514,7 +514,7 @@ def __init__( out_dim=hidden_size, context_pre_only=False, bias=True, - processor=HunyuanVideoAttnProcessor2_0(), + processor=HunyuanVideo15AttnProcessor2_0(), qk_norm=qk_norm, eps=1e-6, ) @@ -645,6 +645,9 @@ def __init__( rope_theta: float = 256.0, rope_axes_dim: Tuple[int, ...] 
= (16, 56, 56), use_meanflow: bool = False, + # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205 + target_size: int = 640, # did not name sample_size since it is in pixel spaces + task_type: str = "i2v", ) -> None: super().__init__() diff --git a/src/diffusers/pipelines/hunyuan_video1_5/__init__.py b/src/diffusers/pipelines/hunyuan_video1_5/__init__.py new file mode 100644 index 000000000000..09bffb88353c --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py new file mode 100644 index 000000000000..5963dd43bd2e --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -0,0 +1,97 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from ...video_processor import VideoProcessor +from ...configuration_utils import register_to_config + +# Copied from hyvideo/utils/data_utils.py +def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): + num_patches = round((base_size / patch_size) ** 2) + assert max_ratio >= 1.0 + crop_size_list = [] + wp, hp = num_patches, 1 + while wp > 0: + if max(wp, hp) / min(wp, hp) <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + if (hp + 1) * wp <= num_patches: + hp += 1 + else: + wp -= 1 + return crop_size_list + +# Copied from hyvideo/utils/data_utils.py +def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): + """ + Get the closest ratio in the buckets. 
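+    The aspect ratio is computed as `height / width`; candidate buckets are restricted to ratios no larger
+    than the input ratio when `height >= width` (and to strictly larger ratios otherwise), and the candidate
+    with the smallest absolute difference is returned.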
+ + Args: + height (float): video height + width (float): video width + ratios (list): video aspect ratio + buckets (list): buckets generated by `generate_crop_size_list` + + Returns: + the closest size in the buckets and the corresponding ratio + """ + aspect_ratio = float(height) / float(width) + diff_ratios = ratios - aspect_ratio + + if aspect_ratio >= 1: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0] + else: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0] + + closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0] + closest_size = buckets[closest_ratio_id] + closest_ratio = ratios[closest_ratio_id] + + return closest_size, closest_ratio + +class HunyuanVideo15ImageProcessor(VideoProcessor): + r""" + Image/video processor to preproces/postprocess the reference image/generatedvideo for the HunyuanVideo1.5 model. + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `16`): + VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `32`): + VAE latent channels. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels + ) + + + def calculate_default_height_width(self, height: int, width: int, target_size: int): + + crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.vae_scale_factor) + aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) + height, width = get_closest_ratio(height, width, aspect_ratios, crop_size_list)[0] + + return height, width \ No newline at end of file diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py new file mode 100644 index 000000000000..1a2aa517921b --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -0,0 +1,879 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
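+
+# Conditioning overview for this pipeline:
+#   * the prompt is wrapped in a chat template with a fixed system message and encoded with Qwen2.5-VL;
+#     the leading template tokens (`crop_start`) are cropped from the returned hidden states,
+#   * text inside straight or curly double quotes is extracted as glyph text and encoded with ByT5, e.g.
+#     extract_glyph_texts('A sign that says "OPEN"') returns 'Text "OPEN". ' (illustrative input),
+#   * a `guider` component (classifier-free guidance) is registered alongside the usual VAE/transformer/scheduler.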
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import re + +import numpy as np +import torch +from transformers import Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel, ByT5Tokenizer + +from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from .image_processor import HunyuanVideo15ImageProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideo15PipelineOutput +from ...guiders import ClassifierFreeGuidance + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideo15Pipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo15" + >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +def format_text_input(prompt: List[str], system_message: str + ) -> List[Dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (List[str]): Input text. + system_message (str): System message. + + Returns: + List[Dict[str, Any]]: List of chat conversation. + """ + + template = [ + [ + { + 'role': 'system', + 'content': system_message}, + {'role': 'user', 'content': p if p else " "} + ] + for p in prompt] + + return template + + +def extract_glyph_texts(prompt: str) -> List[str]: + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r'\"(.*?)\"|“(.*?)”' + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + + if result: + formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". " + else: + formatted_result = None + + return formatted_result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15Pipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo1.5. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`HunyuanVideo15Transformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo15`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`ClassifierFreeGuidance`]): + [ClassifierFreeGuidance]for classifier free guidance. 
+ """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanVideo15Transformer3DModel, + vae: AutoencoderKLHunyuanVideo15, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + guider: ClassifierFreeGuidance, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + guider=guider, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 + self.vision_states_dim = self.transformer.config.vision_states_dim if getattr(self, "transformer", None) else 729 + # fmt: off + self.system_message ="You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on + self.prompt_template_encode_start_idx = 108 + self.tokenizer_max_length = 1000 + self.text_encoder_2_max_length = 256 + self.vision_num_semantic_tokens = 729 + + + @staticmethod + def _get_mllm_prompt_embeds( + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + tokenizer_max_length: int = 1000, + num_hidden_layers_to_skip: int = 2, + # fmt: off + system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. 
camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start: int = 108, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=False, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + + @staticmethod + def _get_byt5_prompt_embeds( + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + tokenizer_max_length: int = 256, + ): + + prompt = [prompt] if isinstance(prompt, str) else prompt + + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=dtype + ) + glyph_text_embeds_mask = torch.zeros( + (1, tokenizer_max_length), device=device, dtype=torch.int64 + ) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + glyph_text_embeds = glyph_text_embeds.to(dtype=dtype, device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0) + + return prompt_embeds, prompt_embeds_mask + + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + batch_size: int = 1, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_embeds_mask_2: Optional[torch.Tensor] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. 
+            prompt_embeds_2 (`torch.Tensor`, *optional*):
+                Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+                argument using self.tokenizer_2 and self.text_encoder_2.
+            prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+                Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+                argument using self.tokenizer_2 and self.text_encoder_2.
+        """
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        if prompt is None:
+            prompt = [""] * batch_size
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt_embeds is None:
+            prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds(
+                tokenizer=self.tokenizer,
+                text_encoder=self.text_encoder,
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                tokenizer_max_length=self.tokenizer_max_length,
+                system_message=self.system_message,
+                crop_start=self.prompt_template_encode_start_idx,
+            )
+
+        if prompt_embeds_2 is None:
+            prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds(
+                tokenizer=self.tokenizer_2,
+                text_encoder=self.text_encoder_2,
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                tokenizer_max_length=self.text_encoder_2_max_length,
+            )
+
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len)
+
+        _, seq_len_2, _ = prompt_embeds_2.shape
+        prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1)
+        prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2)
+
+        return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
+
+    def check_inputs(
+        self,
+        prompt,
+        height,
+        width,
+        negative_prompt=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        prompt_embeds_mask=None,
+        negative_prompt_embeds_mask=None,
+        prompt_embeds_2=None,
+        prompt_embeds_mask_2=None,
+        negative_prompt_embeds_2=None,
+        negative_prompt_embeds_mask_2=None,
+    ):
+
+        if height is None and width is not None:
+            raise ValueError(
+                "If `width` is provided, `height` also has to be provided."
+            )
+        elif width is None and height is not None:
+            raise ValueError(
+                "If `height` is provided, `width` also has to be provided."
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + + def prepare_cond_latents_and_mask(self, latents): + """ + Prepare conditional latents and mask for t2v generation. 
+
+        Args:
+            latents: Main latents tensor (B, C, F, H, W)
+
+        Returns:
+            tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v
+        """
+        batch, channels, frames, height, width = latents.shape
+
+        cond_latents_concat = torch.zeros(
+            batch, channels, frames, height, width,
+            device=latents.device,
+            dtype=latents.dtype
+        )
+
+        mask_concat = torch.zeros(
+            batch, 1, frames, height, width,
+            device=latents.device
+        )
+
+        return cond_latents_concat, mask_concat
+
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def attention_kwargs(self):
+        return self._attention_kwargs
+
+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: int = 129,
+        num_inference_steps: int = 50,
+        sigmas: List[float] = None,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_embeds_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+        prompt_embeds_2: Optional[torch.Tensor] = None,
+        prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
+        negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+                be used instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+                not greater than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+            height (`int`, defaults to `720`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `1280`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `129`):
+                The number of frames in the generated video.
+            num_inference_steps (`int`, defaults to `50`):
+                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            true_cfg_scale (`float`, *optional*, defaults to 1.0):
+                True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+                `negative_prompt` is provided.
+            guidance_scale (`float`, defaults to `6.0`):
+                Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+                a model to generate videos more aligned with `prompt` at the expense of lower video quality.
+
+                Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
+                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple.
+            attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference, with the following arguments: `callback_on_step_end(self:
+                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
`callback_kwargs` will include a
+                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~HunyuanVideo15PipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated video frames.
+        """
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt=prompt,
+            height=height,
+            width=width,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_embeds_mask=prompt_embeds_mask,
+            negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+            prompt_embeds_2=prompt_embeds_2,
+            prompt_embeds_mask_2=prompt_embeds_mask_2,
+            negative_prompt_embeds_2=negative_prompt_embeds_2,
+            negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+        )
+
+        if height is None and width is None:
+            height, width = self.video_processor.calculate_default_height_width(height, width, self.target_size)
+
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+
+        device = self._execution_device
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # 3. Encode input prompt
+        prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            dtype=self.transformer.dtype,
+            batch_size=batch_size,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            prompt_embeds_mask=prompt_embeds_mask,
+            prompt_embeds_2=prompt_embeds_2,
+            prompt_embeds_mask_2=prompt_embeds_mask_2,
+        )
+
+        if self.guider._enabled and self.guider.num_conditions >1 :
+            negative_prompt_embeds, negative_prompt_embeds_mask, negative_prompt_embeds_2, negative_prompt_embeds_mask_2 = self.encode_prompt(
+                prompt=negative_prompt,
+                device=device,
+                dtype=self.transformer.dtype,
+                batch_size=batch_size,
+                num_videos_per_prompt=num_videos_per_prompt,
+                prompt_embeds=negative_prompt_embeds,
+                prompt_embeds_mask=negative_prompt_embeds_mask,
+                prompt_embeds_2=negative_prompt_embeds_2,
+                prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+            )
+
+        # 4. Prepare timesteps
+        sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+        # 5.
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents) + vision_states = torch.zeros(batch_size, self.vision_num_semantic_tokens, self.vision_states_dim).to(latents.device) + + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = self.guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. 
"pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latents, + image_embeds=vision_states, + timestep=timestep, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + self.guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + self.vae.enable_tiling() + video = self.vae.decode(latents, return_dict=False, generator=generator)[0] + self.vae.disable_tiling() + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideo15PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py new file mode 100644 index 000000000000..3adb54e1fbed --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideo15PipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo1.5 pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + frames: torch.Tensor From b282ac1510029b1f3604ee569517f6c624d2a1ad Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 22 Nov 2025 02:06:42 +0100 Subject: [PATCH 03/34] add text encoders to conversion script --- .../convert_hunyuan_video1_5_to_diffusers.py | 379 +++++++++++++++++- 1 file changed, 369 insertions(+), 10 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 59df02d1901f..35b76fcf4663 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -1,21 +1,32 @@ """ python scripts/convert_hunyuan_video1_5_to_diffusers.py \ --original_state_dict_folder /raid/yiyi/new-model-vid \ - --output_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \ + --output_transformer_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \ --transformer_type 480p_i2v \ --dtype fp32 """ +""" +python scripts/convert_hunyuan_video1_5_to_diffusers.py \ + --original_state_dict_folder /raid/yiyi/new-model-vid \ + --output_vae_path /raid/yiyi/hunyuanvideo15-vae \ + --dtype fp32 +""" + import argparse from typing import Any, Dict import torch from accelerate import init_empty_weights from safetensors.torch import load_file -from huggingface_hub import snapshot_download +from huggingface_hub import snapshot_download, hf_hub_download import pathlib -from diffusers import HunyuanVideo15Transformer3DModel +from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15 +from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer + +import json +import argparse TRANSFORMER_CONFIGS = { "480p_i2v": { @@ -316,6 +327,193 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict): return converted_state_dict +def convert_hunyuan_video_15_vae_checkpoint_to_diffusers( + original_state_dict, block_out_channels=[128, 256, 512, 1024, 1024], layers_per_block=2 +): + converted = {} + + # 1. 
Encoder + # 1.1 conv_in + converted["encoder.conv_in.conv.weight"] = original_state_dict.pop("encoder.conv_in.conv.weight") + converted["encoder.conv_in.conv.bias"] = original_state_dict.pop("encoder.conv_in.conv.bias") + + # 1.2 Down blocks + for down_block_index in range(len(block_out_channels)): # 0 to 4 + # ResNet blocks + for resnet_block_index in range(layers_per_block): # 0 to 1 + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm1.gamma") + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = ( + original_state_dict.pop( + f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.weight" + ) + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.bias") + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm2.gamma") + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = ( + original_state_dict.pop( + f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.weight" + ) + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.bias") + ) + + # Downsample (if exists) + if f"encoder.down.{down_block_index}.downsample.conv.conv.weight" in original_state_dict: + converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.weight"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.weight") + ) + converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.bias"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.bias") + ) + + # 1.3 Mid block + converted["encoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm1.gamma") + converted["encoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop( + "encoder.mid.block_1.conv1.conv.weight" + ) + converted["encoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop( + "encoder.mid.block_1.conv1.conv.bias" + ) + converted["encoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm2.gamma") + converted["encoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop( + "encoder.mid.block_1.conv2.conv.weight" + ) + converted["encoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop( + "encoder.mid.block_1.conv2.conv.bias" + ) + + converted["encoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm1.gamma") + converted["encoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop( + "encoder.mid.block_2.conv1.conv.weight" + ) + converted["encoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop( + "encoder.mid.block_2.conv1.conv.bias" + ) + converted["encoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm2.gamma") + converted["encoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop( + 
"encoder.mid.block_2.conv2.conv.weight" + ) + converted["encoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop( + "encoder.mid.block_2.conv2.conv.bias" + ) + + # Attention block + converted["encoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("encoder.mid.attn_1.norm.gamma") + converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight") + converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias") + converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight") + converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias") + converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight") + converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias") + converted["encoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop( + "encoder.mid.attn_1.proj_out.weight" + ) + converted["encoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop( + "encoder.mid.attn_1.proj_out.bias" + ) + + # 1.4 Encoder output + converted["encoder.norm_out.gamma"] = original_state_dict.pop("encoder.norm_out.gamma") + converted["encoder.conv_out.conv.weight"] = original_state_dict.pop("encoder.conv_out.conv.weight") + converted["encoder.conv_out.conv.bias"] = original_state_dict.pop("encoder.conv_out.conv.bias") + + # 2. Decoder + # 2.1 conv_in + converted["decoder.conv_in.conv.weight"] = original_state_dict.pop("decoder.conv_in.conv.weight") + converted["decoder.conv_in.conv.bias"] = original_state_dict.pop("decoder.conv_in.conv.bias") + + # 2.2 Mid block + converted["decoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm1.gamma") + converted["decoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop( + "decoder.mid.block_1.conv1.conv.weight" + ) + converted["decoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop( + "decoder.mid.block_1.conv1.conv.bias" + ) + converted["decoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm2.gamma") + converted["decoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop( + "decoder.mid.block_1.conv2.conv.weight" + ) + converted["decoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop( + "decoder.mid.block_1.conv2.conv.bias" + ) + + converted["decoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm1.gamma") + converted["decoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop( + "decoder.mid.block_2.conv1.conv.weight" + ) + converted["decoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop( + "decoder.mid.block_2.conv1.conv.bias" + ) + converted["decoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm2.gamma") + converted["decoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop( + "decoder.mid.block_2.conv2.conv.weight" + ) + converted["decoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop( + "decoder.mid.block_2.conv2.conv.bias" + ) + + # Decoder attention block + converted["decoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("decoder.mid.attn_1.norm.gamma") + converted["decoder.mid_block.attentions.0.to_q.weight"] = 
original_state_dict.pop("decoder.mid.attn_1.q.weight") + converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias") + converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight") + converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias") + converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight") + converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias") + converted["decoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop( + "decoder.mid.attn_1.proj_out.weight" + ) + converted["decoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop( + "decoder.mid.attn_1.proj_out.bias" + ) + + # 2.3 Up blocks + for up_block_index in range(len(block_out_channels)): # 0 to 5 + # ResNet blocks + for resnet_block_index in range(layers_per_block + 1): # 0 to 2 (decoder has 3 resnets per level) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm1.gamma") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.weight") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.bias") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm2.gamma") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.weight") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.bias") + ) + + # Upsample (if exists) + if f"decoder.up.{up_block_index}.upsample.conv.conv.weight" in original_state_dict: + converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.weight"] = original_state_dict.pop( + f"decoder.up.{up_block_index}.upsample.conv.conv.weight" + ) + converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.bias"] = original_state_dict.pop( + f"decoder.up.{up_block_index}.upsample.conv.conv.bias" + ) + + # 2.4 Decoder output + converted["decoder.norm_out.gamma"] = original_state_dict.pop("decoder.norm_out.gamma") + converted["decoder.conv_out.conv.weight"] = original_state_dict.pop("decoder.conv_out.conv.weight") + converted["decoder.conv_out.conv.bias"] = original_state_dict.pop("decoder.conv_out.conv.bias") + + return converted + def load_sharded_safetensors(dir: pathlib.Path): file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) state_dict = {} @@ -324,7 +522,7 @@ def load_sharded_safetensors(dir: pathlib.Path): return state_dict -def load_original_state_dict(args): +def load_original_transformer_state_dict(args): if args.original_state_dict_repo_id is not None: model_dir = snapshot_download( args.original_state_dict_repo_id, @@ -339,8 +537,23 @@ def 
load_original_state_dict(args): model_dir = model_dir / "transformer" / args.transformer_type return load_sharded_safetensors(model_dir) +def load_original_vae_state_dict(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download( + repo_id=args.original_state_dict_repo_id, + filename= "vae/diffusion_pytorch_model.safetensors" + ) + elif args.original_state_dict_folder is not None: + model_dir = pathlib.Path(args.original_state_dict_folder) + ckpt_path = model_dir / "vae/diffusion_pytorch_model.safetensors" + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`") + + original_state_dict = load_file(ckpt_path) + return original_state_dict + def convert_transformer(args): - original_state_dict = load_original_state_dict(args) + original_state_dict = load_original_transformer_state_dict(args) config = TRANSFORMER_CONFIGS[args.transformer_type] with init_empty_weights(): @@ -350,13 +563,152 @@ def convert_transformer(args): return transformer +def convert_vae(args): + original_state_dict = load_original_vae_state_dict(args) + with init_empty_weights(): + vae = AutoencoderKLHunyuanVideo15() + state_dict = convert_hunyuan_video_15_vae_checkpoint_to_diffusers(original_state_dict) + vae.load_state_dict(state_dict, strict=True, assign=True) + return vae + +def save_text_encoder(output_path): + text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", low_cpu_mem_usage=True) + if hasattr(text_encoder, 'language_model'): + text_encoder = text_encoder.language_model + + + text_encoder.save_pretrained(output_path + "/text_encoder") + + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right") + tokenizer.save_pretrained(output_path + "/tokenizer") + + +def add_special_token( + tokenizer, + text_encoder, + add_color=True, + add_font=True, + multilingual=True, + color_ann_path='assets/color_idx.json', + font_ann_path='assets/multilingual_10-lang_idx.json', +): + """ + Add special tokens for color and font to tokenizer and text encoder. + + Args: + tokenizer: Huggingface tokenizer. + text_encoder: Huggingface T5 encoder. + add_color (bool): Whether to add color tokens. + add_font (bool): Whether to add font tokens. + color_ann_path (str): Path to color annotation JSON. + font_ann_path (str): Path to font annotation JSON. + multilingual (bool): Whether to use multilingual font tokens. + """ + with open(font_ann_path, 'r') as f: + idx_font_dict = json.load(f) + with open(color_ann_path, 'r') as f: + idx_color_dict = json.load(f) + + if multilingual: + font_token = [f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict] + else: + font_token = [f'' for i in range(len(idx_font_dict))] + color_token = [f'' for i in range(len(idx_color_dict))] + additional_special_tokens = [] + if add_color: + additional_special_tokens += color_token + if add_font: + additional_special_tokens += font_token + + tokenizer.add_tokens(additional_special_tokens, special_tokens=True) + # Set mean_resizing=False to avoid PyTorch LAPACK dependency + text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False) + + +def save_text_encoder_2( + byt5_base_path, + byt5_checkpoint_path, + color_ann_path, + font_ann_path, + output_path, + multilingual=True +): + """ + Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format. 
+ + Args: + byt5_base_path: Path to base byt5-small model (e.g., "google/byt5-small") + byt5_checkpoint_path: Path to Glyph-SDXL-v2 checkpoint (byt5_model.pt) + color_ann_path: Path to color_idx.json + font_ann_path: Path to multilingual_10-lang_idx.json + output_path: Where to save the converted model + multilingual: Whether to use multilingual font tokens + """ + + + # 1. Load base tokenizer and encoder + tokenizer = AutoTokenizer.from_pretrained(byt5_base_path) + + # Load as T5EncoderModel + encoder = T5EncoderModel.from_pretrained(byt5_base_path) + + # 2. Add special tokens + add_special_token( + tokenizer, + encoder, + color_ann_path=color_ann_path, + font_ann_path=font_ann_path, + multilingual=multilingual + ) + + # 3. Load Glyph-SDXL-v2 checkpoint + print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}") + checkpoint = torch.load(byt5_checkpoint_path, map_location='cpu') + + # Handle different checkpoint formats + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + # add 'encoder.' prefix to the keys + # Remove 'module.text_tower.encoder.' prefix if present + cleaned_state_dict = {} + for key, value in state_dict.items(): + if key.startswith('module.text_tower.encoder.'): + new_key = 'encoder.' + key[len('module.text_tower.encoder.'):] + cleaned_state_dict[new_key] = value + else: + new_key = 'encoder.' + key + cleaned_state_dict[new_key] = value + + + # 4. Load weights + missing_keys, unexpected_keys = encoder.load_state_dict(cleaned_state_dict, strict=False) + if unexpected_keys: + raise ValueError(f"Unexpected keys: {unexpected_keys}") + if "shared.weight" in missing_keys: + print(f" Missing shared.weight as expected") + missing_keys.remove("shared.weight") + if missing_keys: + raise ValueError(f"Missing keys: {missing_keys}") + + + # Save encoder + encoder.save_pretrained(output_path + "/text_encoder_2") + + # Save tokenizer + tokenizer.save_pretrained(output_path + "/tokenizer_2") + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model" ) - parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Folder name of the original state dict") - parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict") + parser.add_argument("--output_vae_path", type=str, default=None, help="Path where converted VAE should be saved") + parser.add_argument("--output_transformer_path", type=str, default=None, help="Path where converted transformer should be saved") parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") parser.add_argument( "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()) @@ -377,6 +729,13 @@ def get_args(): transformer = None dtype = DTYPE_MAPPING[args.dtype] - transformer = convert_transformer(args) - transformer = transformer.to(dtype=dtype) - transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if args.output_transformer_path is not None: + transformer = convert_transformer(args) + transformer = transformer.to(dtype=dtype) + transformer.save_pretrained(args.output_transformer_path, safe_serialization=True) + + if args.output_vae_path is not None: + vae = 
convert_vae(args) + vae = vae.to(dtype=dtype) + vae.save_pretrained(args.output_vae_path, safe_serialization=True) + From 5732d60db366a29c7364c0ef136d8ac42137468a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 22 Nov 2025 06:37:03 +0100 Subject: [PATCH 04/34] fix a bit more, remove print lines --- src/diffusers/__init__.py | 2 + .../transformer_hunyuan_video15.py | 36 ----------------- src/diffusers/pipelines/__init__.py | 2 + .../pipeline_hunyuan_video1_5.py | 40 ++++++++++++------- 4 files changed, 30 insertions(+), 50 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eb9929cf2c99..eaaac1838d90 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -482,6 +482,7 @@ "HunyuanVideoFramepackPipeline", "HunyuanVideoImageToVideoPipeline", "HunyuanVideoPipeline", + "HunyuanVideo15Pipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -1168,6 +1169,7 @@ HunyuanVideoFramepackPipeline, HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, + HunyuanVideo15Pipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index 86c3a3565f57..c26b43e19ce4 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -140,10 +140,7 @@ def __call__( batch_size, seq_len, heads, dim = query.shape - print(f" query.shape: {query.shape}") - print(f" attention_mask.shape: {attention_mask.shape}") attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True) - print(f" attention_mask.shape: {attention_mask.shape}") attention_mask = attention_mask.bool() self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) @@ -160,8 +157,6 @@ def __call__( backend=self._attention_backend, parallel_config=self._parallel_config, ) - print(f" hidden_states.shape: {hidden_states.shape}") - print(f" hidden_states[0,:10,:3]: {hidden_states[0,:10,:3]}") hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) @@ -407,14 +402,8 @@ def forward( pooled_projections = pooled_projections.to(original_dtype) temb = self.time_text_embed(timestep, pooled_projections) - print(f" temb(time_text_embed).shape: {temb.shape}, {temb[0,:10]}") hidden_states = self.proj_in(hidden_states) - print(f" hidden_states: {hidden_states.shape}, {hidden_states[0,:3,:3]}") - print(f" temb: {temb.shape}, {temb[0,:10]}") - print(f" attention_mask: {attention_mask.shape}, {attention_mask[0,:3]}, {attention_mask.abs().sum()}") - print(f" -> token_refiner") hidden_states = self.token_refiner(hidden_states, temb, attention_mask) - print(f" hidden_states(token_refiner) {hidden_states.shape}, {hidden_states[0,:3,:3]}") return hidden_states @@ -537,11 +526,9 @@ def forward( ) -> Tuple[torch.Tensor, torch.Tensor]: # 1. 
Input normalization norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) - print(f" norm_hidden_states(norm1).shape: {norm_hidden_states.shape}, {norm_hidden_states[0,:10,:3]}") norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - print(f" norm_encoder_hidden_states(norm1_context).shape: {norm_encoder_hidden_states.shape}, {norm_encoder_hidden_states[0,:10,:3]}") # 2. Joint attention attn_output, context_attn_output = self.attn( @@ -550,8 +537,6 @@ def forward( attention_mask=attention_mask, image_rotary_emb=freqs_cis, ) - print(f" attn_output.shape: {attn_output.shape}, {attn_output[0,:10,:3]}") - print(f" context_attn_output.shape: {context_attn_output.shape}, {context_attn_output[0,:10,:3]}") # 3. Modulation and residual connection @@ -570,8 +555,6 @@ def forward( hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output - print(f" hidden_states(ff): {hidden_states.shape}, {hidden_states[0,:10,:3]}") - print(f" encoder_hidden_states(ff): {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") return hidden_states, encoder_hidden_states @@ -791,31 +774,23 @@ def forward( hidden_states = self.x_embedder(hidden_states) # qwen text embedding - print(f" encoder_hidden_states(qwen).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") - print(f" timestep: {timestep}, {timestep[:10]}") - print(f" encoder_attention_mask: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.abs().sum()}") encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) - print(f" encoder_hidden_states(token_refiner).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") encoder_hidden_states_cond_emb = self.cond_type_embed( torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long) ) encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb - print(f" encoder_hidden_states(+ cond_emb).shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") # byt5 text embedding encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2) - print(f" encoder_hidden_states_2(byt5).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}") encoder_hidden_states_2_cond_emb = self.cond_type_embed( torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long) ) encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb - print(f" encoder_hidden_states_2(+ cond_emb).shape: {encoder_hidden_states_2.shape}, {encoder_hidden_states_2[0,:10,:3]}") # image embed encoder_hidden_states_3 = self.image_embedder(image_embeds) - print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") is_t2v = torch.all(image_embeds == 0) if is_t2v: encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0 @@ -824,8 +799,6 @@ def forward( dtype=encoder_attention_mask.dtype, device=encoder_attention_mask.device, ) - print(f" encoder_hidden_states_3(image).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") - print(f" encoder_attention_mask_3: {encoder_attention_mask_3.shape}, {encoder_attention_mask_3[0,:10]}, {encoder_attention_mask_3.abs().sum()}") else: encoder_attention_mask_3 = torch.ones( (batch_size, 
encoder_hidden_states_3.shape[1]), @@ -840,9 +813,6 @@ def forward( ) encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb - print(f" encoder_hidden_states_3(+ cond_emb).shape: {encoder_hidden_states_3.shape}, {encoder_hidden_states_3[0,:10,:3]}") - - # reorder and combine text tokens: combine valid tokens first, then padding encoder_attention_mask = encoder_attention_mask.bool() encoder_attention_mask_2 = encoder_attention_mask_2.bool() @@ -891,12 +861,6 @@ def forward( encoder_hidden_states = torch.stack(new_encoder_hidden_states) encoder_attention_mask = torch.stack(new_encoder_attention_mask) - print(f" hidden_states.shape: {hidden_states.shape}, {hidden_states[0,:3,:3]}") - print(f" encoder_hidden_states.shape: {encoder_hidden_states.shape}, {encoder_hidden_states[0,:10,:3]}") - print(f" encoder_attention_mask.shape: {encoder_attention_mask.shape}, {encoder_attention_mask[0,:10]}, {encoder_attention_mask.dtype}, {encoder_attention_mask.sum()}") - print(f" image_rotary_emb: {image_rotary_emb[0].shape}, {image_rotary_emb[1].shape}, {image_rotary_emb[0][:3,:10]}, {image_rotary_emb[1][:3,:10]}") - print(f" temb.shape: {temb.shape}, {temb[0,:10]}") - # 4. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 69bb14b98edc..fe84f5c7ca85 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -242,6 +242,7 @@ "HunyuanVideoImageToVideoPipeline", "HunyuanVideoFramepackPipeline", ] + _import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"] _import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", @@ -662,6 +663,7 @@ HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, ) + from .hunyuan_video1_5 import HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index 1a2aa517921b..378f5570230b 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -27,6 +27,7 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import HunyuanVideo15PipelineOutput from ...guiders import ClassifierFreeGuidance +from ...utils.torch_utils import randn_tensor if is_torch_xla_available(): @@ -225,7 +226,7 @@ def __init__( self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 - self.vision_states_dim = self.transformer.config.vision_states_dim if getattr(self, "transformer", None) else 729 + self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 # fmt: off self.system_message ="You are a helpful assistant. Describe the video by detailing the following aspects: \ 1. The main content and theme of the video. 
\ @@ -236,8 +237,9 @@ def __init__( # fmt: on self.prompt_template_encode_start_idx = 108 self.tokenizer_max_length = 1000 - self.text_encoder_2_max_length = 256 + self.tokenizer_2_max_length = 256 self.vision_num_semantic_tokens = 729 + self.default_aspect_ratio = (16, 9) # (width: height) @staticmethod @@ -282,7 +284,7 @@ def _get_mllm_prompt_embeds( prompt_embeds = text_encoder( input_ids=text_input_ids, attention_mask=prompt_attention_mask, - output_hidden_states=False, + output_hidden_states=True, ).hidden_states[-(num_hidden_layers_to_skip + 1)] prompt_embeds = prompt_embeds.to(dtype=dtype) @@ -521,7 +523,7 @@ def prepare_latents( return latents - def prepare_cond_latents_and_mask(self, latents): + def prepare_cond_latents_and_mask(self, latents, dtype: Optional[torch.dtype], device: Optional[torch.device]): """ Prepare conditional latents and mask for t2v generation. @@ -535,13 +537,14 @@ def prepare_cond_latents_and_mask(self, latents): cond_latents_concat = torch.zeros( batch, channels, frames, height, width, - device=latents.device, - dtype=latents.dtype + dtype=dtype, + device=device ) mask_concat = torch.zeros( batch, 1, frames, height, width, - device=latents.device + dtype=dtype, + device=device ) return cond_latents_concat, mask_concat @@ -702,7 +705,7 @@ def __call__( ) if height is None and width is None: - height, width = self.video_processor.calculate_default_height_width(height, width, self.target_size) + height, width = self.video_processor.calculate_default_height_width(self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size) self._attention_kwargs = attention_kwargs self._current_timestep = None @@ -761,8 +764,19 @@ def __call__( generator, latents, ) - cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents) - vision_states = torch.zeros(batch_size, self.vision_num_semantic_tokens, self.vision_states_dim).to(latents.device) + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, torch.float32, device) + image_embeds = torch.zeros( + batch_size, + self.vision_num_semantic_tokens, + self.vision_states_dim, + dtype=torch.float32, + device=device + ) + + image_embeds = image_embeds.to(self.transformer.dtype) + latents=latents.to(self.transformer.dtype) + cond_latents_concat=cond_latents_concat.to(self.transformer.dtype) + mask_concat=mask_concat.to(self.transformer.dtype) # 7. 
Denoising loop @@ -817,8 +831,8 @@ def __call__( with self.transformer.cache_context(context_name): # Run denoiser and store noise prediction in this batch guider_state_batch.noise_pred = self.transformer( - hidden_states=latents, - image_embeds=vision_states, + hidden_states=latent_model_input, + image_embeds=image_embeds, timestep=timestep, attention_kwargs=self.attention_kwargs, return_dict=False, @@ -863,9 +877,7 @@ def __call__( if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor - self.vae.enable_tiling() video = self.vae.decode(latents, return_dict=False, generator=generator)[0] - self.vae.disable_tiling() video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From 76bb607bc062b3ec02224e95c7d644d0148a3530 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 23 Nov 2025 05:43:01 +0100 Subject: [PATCH 05/34] fix more, system prompt etc --- .../pipeline_hunyuan_video1_5.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index 378f5570230b..3464853add8f 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -227,13 +227,14 @@ def __init__( self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + self.num_channels_latents = self.vae.latent_channels if hasattr(self, "vae") else 32 # fmt: off - self.system_message ="You are a helpful assistant. Describe the video by detailing the following aspects: \ - 1. The main content and theme of the video. \ - 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ - 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ - 4. background environment, light, style and atmosphere. \ - 5. camera angles, movements, and transitions used in the video." + self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." # fmt: on self.prompt_template_encode_start_idx = 108 self.tokenizer_max_length = 1000 @@ -253,11 +254,11 @@ def _get_mllm_prompt_embeds( num_hidden_layers_to_skip: int = 2, # fmt: off system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ - 1. The main content and theme of the video. \ - 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ - 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ - 4. background environment, light, style and atmosphere. \ - 5. camera angles, movements, and transitions used in the video.", + 1. The main content and theme of the video. \ + 2. 
The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", # fmt: on crop_start: int = 108, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -286,12 +287,13 @@ def _get_mllm_prompt_embeds( attention_mask=prompt_attention_mask, output_hidden_states=True, ).hidden_states[-(num_hidden_layers_to_skip + 1)] - prompt_embeds = prompt_embeds.to(dtype=dtype) if crop_start is not None and crop_start > 0: prompt_embeds = prompt_embeds[:, crop_start:] prompt_attention_mask = prompt_attention_mask[:, crop_start:] + prompt_embeds = prompt_embeds.to(dtype=dtype) + return prompt_embeds, prompt_attention_mask @@ -578,7 +580,7 @@ def __call__( negative_prompt: Union[str, List[str]] = None, height: Optional[int] = None, width: Optional[int] = None, - num_frames: int = 129, + num_frames: int = 121, num_inference_steps: int = 50, sigmas: List[float] = None, num_videos_per_prompt: Optional[int] = 1, @@ -752,10 +754,9 @@ def __call__( timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, - num_channels_latents, + self.num_channels_latents, height, width, num_frames, @@ -877,7 +878,7 @@ def __call__( if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor - video = self.vae.decode(latents, return_dict=False, generator=generator)[0] + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From c739ee9cedbea0ecdd00eed6d8b90af51d577aa1 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Wed, 26 Nov 2025 07:38:16 +0000 Subject: [PATCH 06/34] update conversion script --- .../convert_hunyuan_video1_5_to_diffusers.py | 142 ++++++++++++------ 1 file changed, 95 insertions(+), 47 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 35b76fcf4663..2694ac2834fe 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -1,16 +1,19 @@ """ python scripts/convert_hunyuan_video1_5_to_diffusers.py \ - --original_state_dict_folder /raid/yiyi/new-model-vid \ - --output_transformer_path /raid/yiyi/hunyuanvideo15-480p_i2v-diffusers \ + --original_state_dict_repo_id tencent/HunyuanVideo-1.5\ + --output_path /fsx/yiyi/hy15/480p_i2v\ --transformer_type 480p_i2v \ --dtype fp32 """ """ python scripts/convert_hunyuan_video1_5_to_diffusers.py \ - --original_state_dict_folder /raid/yiyi/new-model-vid \ - --output_vae_path /raid/yiyi/hunyuanvideo15-vae \ - --dtype fp32 + --original_state_dict_repo_id tencent/HunyuanVideo-1.5\ + --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \ + --dtype bf16 \ + --save_pipeline \ + --byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\ + --transformer_type 480p_i2v """ import argparse @@ -22,11 +25,12 @@ from huggingface_hub import snapshot_download, hf_hub_download import pathlib -from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15 +from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, 
FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer import json import argparse +import os TRANSFORMER_CONFIGS = { "480p_i2v": { @@ -49,6 +53,20 @@ }, } +SCHEDULER_CONFIGS = { + "480p_i2v": { + "shift": 5.0, + }, +} + +GUIDANCE_CONFIGS = { + "480p_i2v": { + "guidance_scale": 6.0, + "embedded_guidance_scale": None, + }, + + } + def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) @@ -571,18 +589,16 @@ def convert_vae(args): vae.load_state_dict(state_dict, strict=True, assign=True) return vae -def save_text_encoder(output_path): +def load_mllm(): + print(f" loading from Qwen/Qwen2.5-VL-7B-Instruct") text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", low_cpu_mem_usage=True) if hasattr(text_encoder, 'language_model'): text_encoder = text_encoder.language_model - - - text_encoder.save_pretrained(output_path + "/text_encoder") - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right") - tokenizer.save_pretrained(output_path + "/tokenizer") + return text_encoder, tokenizer +#copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89 def add_special_token( tokenizer, text_encoder, @@ -625,42 +641,36 @@ def add_special_token( text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False) -def save_text_encoder_2( - byt5_base_path, - byt5_checkpoint_path, - color_ann_path, - font_ann_path, - output_path, - multilingual=True -): + + +def load_byt5(args): """ Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format. - - Args: - byt5_base_path: Path to base byt5-small model (e.g., "google/byt5-small") - byt5_checkpoint_path: Path to Glyph-SDXL-v2 checkpoint (byt5_model.pt) - color_ann_path: Path to color_idx.json - font_ann_path: Path to multilingual_10-lang_idx.json - output_path: Where to save the converted model - multilingual: Whether to use multilingual font tokens """ - + # 1. Load base tokenizer and encoder - tokenizer = AutoTokenizer.from_pretrained(byt5_base_path) + tokenizer = AutoTokenizer.from_pretrained("google/byt5-small") # Load as T5EncoderModel - encoder = T5EncoderModel.from_pretrained(byt5_base_path) + encoder = T5EncoderModel.from_pretrained("google/byt5-small") + byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt") + color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json") + font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json") + # 2. Add special tokens add_special_token( - tokenizer, - encoder, + tokenizer=tokenizer, + text_encoder=encoder, + add_color=True, + add_font=True, color_ann_path=color_ann_path, font_ann_path=font_ann_path, - multilingual=multilingual + multilingual=True, ) + # 3. Load Glyph-SDXL-v2 checkpoint print(f"\n3. 
Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}") checkpoint = torch.load(byt5_checkpoint_path, map_location='cpu') @@ -694,11 +704,7 @@ def save_text_encoder_2( raise ValueError(f"Missing keys: {missing_keys}") - # Save encoder - encoder.save_pretrained(output_path + "/text_encoder_2") - - # Save tokenizer - tokenizer.save_pretrained(output_path + "/tokenizer_2") + return encoder, tokenizer def get_args(): @@ -707,12 +713,26 @@ def get_args(): "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model" ) parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict") - parser.add_argument("--output_vae_path", type=str, default=None, help="Path where converted VAE should be saved") - parser.add_argument("--output_transformer_path", type=str, default=None, help="Path where converted transformer should be saved") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved") parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") parser.add_argument( "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()) ) + parser.add_argument( + "--byt5_path", + type=str, + default=None, + help=( + "path to the downloaded byt5 checkpoint & assets. " + "Note: They use Glyph-SDXL-v2 as byt5 encoder. You can download from modelscope like: " + "`modelscope download --model AI-ModelScope/Glyph-SDXL-v2 --local_dir ./ckpts/text_encoder/Glyph-SDXL-v2` " + "or manually download following the instructions on " + "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/checkpoints-download.md. " + "The path should point to the Glyph-SDXL-v2 folder which should contain an `assets` folder and a `checkpoints` folder, " + "like: Glyph-SDXL-v2/assets/... 
and Glyph-SDXL-v2/checkpoints/byt5_model.pt" + ), + ) + parser.add_argument("--save_pipeline", action="store_true") return parser.parse_args() @@ -726,16 +746,44 @@ def get_args(): if __name__ == "__main__": args = get_args() + if args.save_pipeline and args.byt5_path is None: + raise ValueError("Please provide --byt5_path when saving pipeline") + transformer = None dtype = DTYPE_MAPPING[args.dtype] - if args.output_transformer_path is not None: - transformer = convert_transformer(args) - transformer = transformer.to(dtype=dtype) - transformer.save_pretrained(args.output_transformer_path, safe_serialization=True) + transformer = convert_transformer(args) + transformer = transformer.to(dtype=dtype) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True) + else: - if args.output_vae_path is not None: vae = convert_vae(args) vae = vae.to(dtype=dtype) - vae.save_pretrained(args.output_vae_path, safe_serialization=True) + + + text_encoder, tokenizer = load_mllm() + text_encoder_2, tokenizer_2 = load_byt5(args) + text_encoder = text_encoder.to(dtype=dtype) + text_encoder_2 = text_encoder_2.to(dtype=dtype) + + flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"] + scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) + + guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"] + guider = ClassifierFreeGuidance(guidance_scale=guidance_scale) + + pipeline = HunyuanVideo15Pipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + guider=guider, + scheduler=scheduler, + ) + pipeline.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + + From 2f6914d57aea1ee6bdc6aef6d7cd55ba4022b67c Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Wed, 26 Nov 2025 07:38:30 +0000 Subject: [PATCH 07/34] up up --- .../pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index 3464853add8f..97c4bd3c3df8 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -227,7 +227,7 @@ def __init__( self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 - self.num_channels_latents = self.vae.latent_channels if hasattr(self, "vae") else 32 + self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 # fmt: off self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ 1. The main content and theme of the video. 
\ @@ -594,7 +594,7 @@ def __call__( prompt_embeds_mask_2: Optional[torch.Tensor] = None, negative_prompt_embeds_2: Optional[torch.Tensor] = None, negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None, - output_type: Optional[str] = "pil", + output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ): From a0b2fe02b05fcda990740eda76240383870ea14f Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Thu, 27 Nov 2025 05:24:49 +0000 Subject: [PATCH 08/34] update conversion script: remove dtype, always keep same precision as original checkpoint --- .../convert_hunyuan_video1_5_to_diffusers.py | 54 +++++++++++-------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 2694ac2834fe..c5f9515c6be7 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -1,19 +1,17 @@ """ python scripts/convert_hunyuan_video1_5_to_diffusers.py \ --original_state_dict_repo_id tencent/HunyuanVideo-1.5\ - --output_path /fsx/yiyi/hy15/480p_i2v\ - --transformer_type 480p_i2v \ - --dtype fp32 + --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers/transformer\ + --transformer_type 480p_t2v """ """ python scripts/convert_hunyuan_video1_5_to_diffusers.py \ --original_state_dict_repo_id tencent/HunyuanVideo-1.5\ --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \ - --dtype bf16 \ --save_pipeline \ --byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\ - --transformer_type 480p_i2v + --transformer_type 480p_t2v """ import argparse @@ -51,12 +49,33 @@ "rope_axes_dim": (16, 56, 56), "use_meanflow": False, }, + "480p_t2v": { + "in_channels": 65, + "out_channels": 32, + "num_attention_heads": 16, + "attention_head_dim": 128, + "num_layers": 54, + "num_refiner_layers": 2, + "mlp_ratio": 4.0, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm", + "text_embed_dim": 3584, + "text_embed_2_dim": 1472, + "image_embed_dim": 1152, + "rope_theta": 256.0, + "rope_axes_dim": (16, 56, 56), + "use_meanflow": False, + }, } SCHEDULER_CONFIGS = { "480p_i2v": { "shift": 5.0, }, + "480p_t2v": { + "shift": 5.0, + }, } GUIDANCE_CONFIGS = { @@ -64,6 +83,10 @@ "guidance_scale": 6.0, "embedded_guidance_scale": None, }, + "480p_t2v": { + "guidance_scale": 6.0, + "embedded_guidance_scale": None, + }, } @@ -555,6 +578,7 @@ def load_original_transformer_state_dict(args): model_dir = model_dir / "transformer" / args.transformer_type return load_sharded_safetensors(model_dir) + def load_original_vae_state_dict(args): if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download( @@ -570,6 +594,7 @@ def load_original_vae_state_dict(args): original_state_dict = load_file(ckpt_path) return original_state_dict + def convert_transformer(args): original_state_dict = load_original_transformer_state_dict(args) @@ -581,6 +606,7 @@ def convert_transformer(args): return transformer + def convert_vae(args): original_state_dict = load_original_vae_state_dict(args) with init_empty_weights(): @@ -591,7 +617,7 @@ def convert_vae(args): def load_mllm(): print(f" loading from Qwen/Qwen2.5-VL-7B-Instruct") - text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", low_cpu_mem_usage=True) + text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16,low_cpu_mem_usage=True) if hasattr(text_encoder, 'language_model'): text_encoder = 
text_encoder.language_model tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right") @@ -641,8 +667,6 @@ def add_special_token( text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False) - - def load_byt5(args): """ Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format. @@ -714,7 +738,6 @@ def get_args(): ) parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved") - parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.") parser.add_argument( "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()) ) @@ -736,13 +759,6 @@ def get_args(): return parser.parse_args() -DTYPE_MAPPING = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, -} - - if __name__ == "__main__": args = get_args() @@ -750,22 +766,16 @@ def get_args(): raise ValueError("Please provide --byt5_path when saving pipeline") transformer = None - dtype = DTYPE_MAPPING[args.dtype] transformer = convert_transformer(args) - transformer = transformer.to(dtype=dtype) if not args.save_pipeline: transformer.save_pretrained(args.output_path, safe_serialization=True) else: - vae = convert_vae(args) - vae = vae.to(dtype=dtype) text_encoder, tokenizer = load_mllm() text_encoder_2, tokenizer_2 = load_byt5(args) - text_encoder = text_encoder.to(dtype=dtype) - text_encoder_2 = text_encoder_2.to(dtype=dtype) flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"] scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) From db0127cb9dcc901942884d72dc489188d4852ec8 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Thu, 27 Nov 2025 05:25:28 +0000 Subject: [PATCH 09/34] fix --- src/diffusers/pipelines/hunyuan_video1_5/image_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py index 5963dd43bd2e..29a0f065fa90 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -90,7 +90,7 @@ def __init__( def calculate_default_height_width(self, height: int, width: int, target_size: int): - crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.vae_scale_factor) + crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.config.vae_scale_factor) aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) height, width = get_closest_ratio(height, width, aspect_ratios, crop_size_list)[0] From 38c42b4de1b98bee31cb35d98aa0bf3398a960cb Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Thu, 27 Nov 2025 22:15:08 +0000 Subject: [PATCH 10/34] conversion scripts --- .../convert_hunyuan_video1_5_to_diffusers.py | 95 +++++++++++++++++-- 1 file changed, 88 insertions(+), 7 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index c5f9515c6be7..7546a909dfdd 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -31,6 +31,26 @@ import os TRANSFORMER_CONFIGS = { + "480p_t2v": { + "in_channels": 65, + "out_channels": 32, + "num_attention_heads": 16, + 
"attention_head_dim": 128, + "num_layers": 54, + "num_refiner_layers": 2, + "mlp_ratio": 4.0, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm", + "text_embed_dim": 3584, + "text_embed_2_dim": 1472, + "image_embed_dim": 1152, + "rope_theta": 256.0, + "rope_axes_dim": (16, 56, 56), + "use_meanflow": False, + "target_size": 640, + "task_type": "t2v", + }, "480p_i2v": { "in_channels": 65, "out_channels": 32, @@ -48,8 +68,31 @@ "rope_theta": 256.0, "rope_axes_dim": (16, 56, 56), "use_meanflow": False, + "target_size": 640, + "task_type": "i2v", }, - "480p_t2v": { + "720p_t2v": { + "in_channels": 65, + "out_channels": 32, + "num_attention_heads": 16, + "attention_head_dim": 128, + "num_layers": 54, + "num_refiner_layers": 2, + "mlp_ratio": 4.0, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm", + "text_embed_dim": 3584, + "text_embed_2_dim": 1472, + "image_embed_dim": 1152, + "rope_theta": 256.0, + "rope_axes_dim": (16, 56, 56), + "use_meanflow": False, + "target_size": 960, + "task_type": "t2v", + }, + "720p_i2v": {}, + "480p_t2v_distilled": { "in_channels": 65, "out_channels": 32, "num_attention_heads": 16, @@ -66,29 +109,67 @@ "rope_theta": 256.0, "rope_axes_dim": (16, 56, 56), "use_meanflow": False, + "target_size": 640, + "task_type": "t2v", }, + "480p_i2v_distilled": {}, + "720p_t2v_distilled": {}, + "720p_i2v_distilled": {}, } SCHEDULER_CONFIGS = { + "480p_t2v": { + "shift": 5.0, + }, "480p_i2v": { "shift": 5.0, }, - "480p_t2v": { + "720p_t2v": { + "shift": 9.0, + }, + "720p_i2v": { + "shift": 7.0, + }, + "480p_t2v_distilled": { + "shift": 5.0, + }, + "480p_i2v_distilled": { "shift": 5.0, }, + "720p_t2v_distilled": { + "shift": 9.0, + }, + "720p_i2v_distilled": { + "shift": 7.0, + }, } GUIDANCE_CONFIGS = { + "480p_t2v": { + "guidance_scale": 6.0, + }, "480p_i2v": { "guidance_scale": 6.0, - "embedded_guidance_scale": None, }, - "480p_t2v": { + "720p_t2v": { "guidance_scale": 6.0, - "embedded_guidance_scale": None, }, - - } + "720p_i2v": { + "guidance_scale": 6.0, + }, + "480p_t2v_distilled": { + "guidance_scale": 1.0, + }, + "480p_i2v_distilled": { + "guidance_scale": 1.0, + }, + "720p_t2v_distilled": { + "guidance_scale": 1.0, + }, + "720p_i2v_distilled": { + "guidance_scale": 1.0, + }, +} def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) From 090ceb5d4f3204e412fbe927b3ced9f63c215988 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sat, 29 Nov 2025 00:33:58 +0000 Subject: [PATCH 11/34] remove dtype from the _get_ encodeing methods --- .../pipeline_hunyuan_video1_5.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index 97c4bd3c3df8..7e56616b4a47 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -248,8 +248,7 @@ def _get_mllm_prompt_embeds( text_encoder: Qwen2_5_VLTextModel, tokenizer: Qwen2Tokenizer, prompt: Union[str, List[str]], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device, tokenizer_max_length: int = 1000, num_hidden_layers_to_skip: int = 2, # fmt: off @@ -292,8 +291,6 @@ def _get_mllm_prompt_embeds( prompt_embeds = prompt_embeds[:, crop_start:] prompt_attention_mask = prompt_attention_mask[:, crop_start:] - prompt_embeds = prompt_embeds.to(dtype=dtype) - return prompt_embeds, 
prompt_attention_mask @@ -302,8 +299,7 @@ def _get_byt5_prompt_embeds( tokenizer: ByT5Tokenizer, text_encoder: T5EncoderModel, prompt: Union[str, List[str]], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + device: torch.device, tokenizer_max_length: int = 256, ): @@ -317,7 +313,7 @@ def _get_byt5_prompt_embeds( for glyph_text in glyph_texts: if glyph_text is None: glyph_text_embeds = torch.zeros( - (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=dtype + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype ) glyph_text_embeds_mask = torch.zeros( (1, tokenizer_max_length), device=device, dtype=torch.int64 @@ -336,7 +332,7 @@ def _get_byt5_prompt_embeds( input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask.float(), )[0] - glyph_text_embeds = glyph_text_embeds.to(dtype=dtype, device=device) + glyph_text_embeds = glyph_text_embeds.to(device=device) glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) prompt_embeds_list.append(glyph_text_embeds) @@ -397,7 +393,6 @@ def encode_prompt( text_encoder=self.text_encoder, prompt=prompt, device=device, - dtype=dtype, tokenizer_max_length=self.tokenizer_max_length, system_message=self.system_message, crop_start=self.prompt_template_encode_start_idx, @@ -409,7 +404,6 @@ def encode_prompt( text_encoder=self.text_encoder_2, prompt=prompt, device=device, - dtype=dtype, tokenizer_max_length=self.tokenizer_2_max_length, ) @@ -425,6 +419,11 @@ def encode_prompt( prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 def check_inputs( From e3301cbda40f7dca7f51e116d3c68b653466dafd Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sat, 29 Nov 2025 00:49:18 +0000 Subject: [PATCH 12/34] add i2v pipeline --- .../convert_hunyuan_video1_5_to_diffusers.py | 59 +- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 4 +- .../pipelines/hunyuan_video1_5/__init__.py | 2 + .../hunyuan_video1_5/image_processor.py | 4 +- .../pipeline_hunyuan_video1_5.py | 12 +- .../pipeline_hunyuan_video1_5_image2video.py | 975 ++++++++++++++++++ 7 files changed, 1032 insertions(+), 26 deletions(-) create mode 100644 src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 7546a909dfdd..1fabb62922db 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -1,3 +1,4 @@ +# to convert only transformer """ python scripts/convert_hunyuan_video1_5_to_diffusers.py \ --original_state_dict_repo_id tencent/HunyuanVideo-1.5\ --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers/transformer\ --transformer_type 480p_t2v """ +# to convert full pipeline """ python scripts/convert_hunyuan_video1_5_to_diffusers.py \ --original_state_dict_repo_id tencent/HunyuanVideo-1.5\ --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \ --save_pipeline \ --byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\ --transformer_type 480p_t2v """ import argparse @@ -23,8 +25,8 @@ from huggingface_hub import snapshot_download, hf_hub_download import pathlib -from diffusers import HunyuanVideo15Transformer3DModel, 
AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline -from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer +from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline, HunyuanVideo15Image2VideoPipeline, HunyuanVideo15Text2VideoPipeline +from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor import json import argparse @@ -812,6 +814,16 @@ def load_byt5(args): return encoder, tokenizer +def load_siglip(): + image_encoder = SiglipVisionModel.from_pretrained( + "black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16 + ) + feature_extractor = SiglipImageProcessor.from_pretrained( + "black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor" + ) + return image_encoder, feature_extractor + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -852,8 +864,9 @@ def get_args(): if not args.save_pipeline: transformer.save_pretrained(args.output_path, safe_serialization=True) else: - vae = convert_vae(args) + task_type = transformer.config.task_type + vae = convert_vae(args) text_encoder, tokenizer = load_mllm() text_encoder_2, tokenizer_2 = load_byt5(args) @@ -864,17 +877,35 @@ def get_args(): guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"] guider = ClassifierFreeGuidance(guidance_scale=guidance_scale) - pipeline = HunyuanVideo15Pipeline( - vae=vae, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - guider=guider, - scheduler=scheduler, - ) - pipeline.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB") + if task_type == "i2v": + image_encoder, feature_extractor = load_siglip() + pipeline = HunyuanVideo15Image2VideoPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + guider=guider, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + elif task_type == "t2v": + pipeline = HunyuanVideo15Text2VideoPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + guider=guider, + scheduler=scheduler, + ) + else: + raise ValueError(f"Task type {task_type} is not supported") + + pipeline.save_pretrained(args.output_path, safe_serialization=True) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index eaaac1838d90..3533bb9c516d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -483,6 +483,7 @@ "HunyuanVideoImageToVideoPipeline", "HunyuanVideoPipeline", "HunyuanVideo15Pipeline", + "HunyuanVideo15ImageToVideoPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -1170,6 +1171,7 @@ HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, HunyuanVideo15Pipeline, + HunyuanVideo15ImageToVideoPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index fe84f5c7ca85..a6a85ea03eb5 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -242,7 +242,7 @@ 
"HunyuanVideoImageToVideoPipeline", "HunyuanVideoFramepackPipeline", ] - _import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"] + _import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline", "HunyuanVideo15ImageToVideoPipeline"] _import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", @@ -663,7 +663,7 @@ HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, ) - from .hunyuan_video1_5 import HunyuanVideo15Pipeline + from .hunyuan_video1_5 import HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video1_5/__init__.py b/src/diffusers/pipelines/hunyuan_video1_5/__init__.py index 09bffb88353c..846320f4ace0 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/__init__.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"] + _import_structure["pipeline_hunyuan_video1_5_image2video"] = ["HunyuanVideo15ImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -33,6 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline + from .pipeline_hunyuan_video1_5_image2video import HunyuanVideo15ImageToVideoPipeline else: import sys diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py index 29a0f065fa90..6e3e818c7d83 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -80,11 +80,13 @@ def __init__( do_resize: bool = True, vae_scale_factor: int = 16, vae_latent_channels: int = 32, + do_convert_rgb: bool = True, ): super().__init__( do_resize=do_resize, vae_scale_factor=vae_scale_factor, - vae_latent_channels=vae_latent_channels + vae_latent_channels=vae_latent_channels, + do_convert_rgb=do_convert_rgb, ) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index 7e56616b4a47..4650b911406b 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -759,26 +759,20 @@ def __call__( height, width, num_frames, - torch.float32, + self.transformer.dtype, device, generator, latents, ) - cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, torch.float32, device) + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, self.transformer.dtype, device) image_embeds = torch.zeros( batch_size, self.vision_num_semantic_tokens, self.vision_states_dim, - dtype=torch.float32, + dtype=self.transformer.dtype, device=device ) - image_embeds = image_embeds.to(self.transformer.dtype) - latents=latents.to(self.transformer.dtype) - cond_latents_concat=cond_latents_concat.to(self.transformer.dtype) - mask_concat=mask_concat.to(self.transformer.dtype) - - # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py new file mode 100644 index 000000000000..361a465b8640 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -0,0 +1,975 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import re + +import numpy as np +import torch +from transformers import Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor + +from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from .image_processor import HunyuanVideo15ImageProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import HunyuanVideo15PipelineOutput +from ...guiders import ClassifierFreeGuidance +from ...utils.torch_utils import randn_tensor + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideo15Pipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo15" + >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +def format_text_input(prompt: List[str], system_message: str + ) -> List[Dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (List[str]): Input text. + system_message (str): System message. + + Returns: + List[Dict[str, Any]]: List of chat conversation. + """ + + template = [ + [ + { + 'role': 'system', + 'content': system_message}, + {'role': 'user', 'content': p if p else " "} + ] + for p in prompt] + + return template + + +def extract_glyph_texts(prompt: str) -> List[str]: + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r'\"(.*?)\"|“(.*?)”' + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + + if result: + formatted_result = ". 
".join([f'Text "{text}"' for text in result]) + ". " + else: + formatted_result = None + + return formatted_result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15Image2VideoPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using HunyuanVideo1.5. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`HunyuanVideo15Transformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo15`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`ClassifierFreeGuidance`]): + [ClassifierFreeGuidance]for classifier free guidance. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanVideo15Transformer3DModel, + vae: AutoencoderKLHunyuanVideo15, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + guider: ClassifierFreeGuidance, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + guider=guider, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True) + self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 + self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 + # fmt: off + self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. 
Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on + self.prompt_template_encode_start_idx = 108 + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 256 + self.vision_num_semantic_tokens = 729 + + + @staticmethod + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds + def _get_mllm_prompt_embeds( + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + tokenizer_max_length: int = 1000, + num_hidden_layers_to_skip: int = 2, + # fmt: off + system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start: int = 108, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + + return prompt_embeds, prompt_attention_mask + + + @staticmethod + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds + def _get_byt5_prompt_embeds( + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + tokenizer_max_length: int = 256, + ): + + prompt = [prompt] if isinstance(prompt, str) else prompt + + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros( + (1, tokenizer_max_length), device=device, dtype=torch.int64 + ) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + glyph_text_embeds = glyph_text_embeds.to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + 
prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0) + + return prompt_embeds, prompt_embeds_mask + + + @staticmethod + def _get_vae_image_latents( + vae: AutoencoderKLHunyuanVideo15, + image_processor: HunyuanVideo15ImageProcessor, + image: PIL.Image.Image, + height: int, + width: int, + device: torch.device, + ) -> torch.Tensor: + + vae_dtype = vae.dtype + image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype) + image_latents = retrieve_latents(vae.encode(image_tensor), sample_mode="argmax") + image_latents = image_latents * vae.config.scaling_factor + return image_latents + + + @staticmethod + def _get_image_embeds( + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + image: PIL.Image.Image, + device: torch.device, + ) -> torch.Tensor: + + image_encoder_dtype = next(image_encoder.parameters()).dtype + image = feature_extractor.preprocess( + images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True + ) + image = image.to(device=device, dtype=image_encoder_dtype) + image_enc_hidden_states = image_encoder(**image).last_hidden_state + + return image_enc_hidden_states + + def encode_image( + self, + image: PIL.Image.Image, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + + image_embeds = self._get_image_embeds( + image_encoder=self.image_encoder, + feature_extractor=self.feature_extractor, + image=image, + device=device, + ) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(device=device, dtype=dtype) + return image_embeds + + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + batch_size: int = 1, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_embeds_mask_2: Optional[torch.Tensor] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. 
+ """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + system_message=self.system_message, + crop_start=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) + + prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) + prompt_embeds_mask = prompt_embeds_mask.to(device=device, dtype=dtype) + prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dtype) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(device=device, dtype=dtype) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + image: PIL.Image.Image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + ): + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `PIL.Image.Image` but is {type(image)}") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." 
+ ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + + def prepare_cond_latents_and_mask( + self, + latents: torch.Tensor, + image: PIL.Image.Image, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + ): + """ + Prepare conditional latents and mask for t2v generation. 
+ + Args: + latents: Main latents tensor (B, C, F, H, W) + + Returns: + tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v + """ + + batch, channels, frames, height, width = latents.shape + + image_latents = self._get_vae_image_latents( + vae=self.vae, + image_processor=self.video_processor, + image=image, + height=height, + width=width, + device=device, + ) + + latent_condition = image_latents.repeat(batch_size, 1, frames, 1, 1) + latent_condition[:,:,1:, :, :] = 0 + latent_condition = latent_condition.to(device=device, dtype=dtype) + + latent_mask = torch.zeros( + batch, 1, frames, height, width, + dtype=dtype, + device=device + ) + latent_mask[:,:, 0, :, :] = 1.0 + + return latent_condition, latent_mask + + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + num_frames: int = 121, + num_inference_steps: int = 50, + sigmas: List[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_embeds_mask_2: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + guidance_scale (`float`, defaults to `6.0`): + Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + + height, width = self.video_processor.calculate_default_height_width(height=image.size[1], width=image.size[0], target_size=self.target_size) + image = self.video_processor.resize(image, height=height, width=width, resize_mode="crop") + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + if self.guider._enabled and self.guider.num_conditions >1 : + negative_prompt_embeds, negative_prompt_embeds_mask, negative_prompt_embeds_2, negative_prompt_embeds_mask_2 = self.encode_prompt( + prompt=negative_prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 5. 
Prepare latent variables + latents = self.prepare_latents( + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=self.transformer.dtype, + device=device, + generator=generator, + latents=latents, + ) + + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask( + latents =latenets, + image=image, + batch_size=batch_size * num_videos_per_prompt, + height=height, + width=width, + dtype=self.transformer.dtype, + device=device + ) + image_embeds = self.encode_image( + image=image, + batch_size=batch_size * num_videos_per_prompt, + device=device, + dtype=self.transformer.dtype, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = self.guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. 
"pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + image_embeds=image_embeds, + timestep=timestep, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + self.guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideo15PipelineOutput(frames=video) From 753d4075f9020074ec4562cad3afd732fcb3663d Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sat, 29 Nov 2025 01:59:44 +0000 Subject: [PATCH 13/34] add image to video pipeline --- scripts/convert_hunyuan_video1_5_to_diffusers.py | 4 ++-- .../hunyuan_video1_5/pipeline_hunyuan_video1_5.py | 3 +-- .../pipeline_hunyuan_video1_5_image2video.py | 14 ++++++++------ 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 1fabb62922db..cb2fb64c76d1 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -25,7 +25,7 @@ from huggingface_hub import snapshot_download, hf_hub_download import pathlib -from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline, HunyuanVideo15Image2VideoPipeline, 
HunyuanVideo15Text2VideoPipeline +from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor import json @@ -879,7 +879,7 @@ def get_args(): if task_type == "i2v": image_encoder, feature_extractor = load_siglip() - pipeline = HunyuanVideo15Image2VideoPipeline( + pipeline = HunyuanVideo15ImageToVideoPipeline( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index 4650b911406b..ee25031f60c4 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -313,8 +313,7 @@ def _get_byt5_prompt_embeds( for glyph_text in glyph_texts: if glyph_text is None: glyph_text_embeds = torch.zeros( - (1, tokenizer_max_length, text_encoder.config.d_model), device=device, text_encoder.dtype - ) + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype) glyph_text_embeds_mask = torch.zeros( (1, tokenizer_max_length), device=device, dtype=torch.int64 ) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py index 361a465b8640..1bce43f086d2 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -15,8 +15,9 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union import re - +import PIL import numpy as np + import torch from transformers import Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor @@ -182,7 +183,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class HunyuanVideo15Image2VideoPipeline(DiffusionPipeline): +class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline): r""" Pipeline for image-to-video generation using HunyuanVideo1.5. 
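One of the hunks above fixes a `torch.zeros(...)` call where `text_encoder.dtype` was passed positionally after keyword arguments (a `SyntaxError`); the same fix lands in the image-to-video file below. A small sketch of the corrected zero-fill path used when a prompt carries no glyph text; the 1472/256 sizes mirror the ByT5 settings used elsewhere in this PR and are illustrative here:

```python
import torch

# When no glyph text is present, the ByT5 branch substitutes all-zero embeddings
# and an all-zero attention mask of the shape the encoder would otherwise produce.
d_model, tokenizer_max_length = 1472, 256
device, dtype = torch.device("cpu"), torch.float32

glyph_text_embeds = torch.zeros((1, tokenizer_max_length, d_model), device=device, dtype=dtype)
glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64)
```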
@@ -318,7 +319,7 @@ def _get_byt5_prompt_embeds( tokenizer: ByT5Tokenizer, text_encoder: T5EncoderModel, prompt: Union[str, List[str]], - device: Optional[torch.device] = None, + device: torch.device, tokenizer_max_length: int = 256, ): @@ -332,7 +333,7 @@ def _get_byt5_prompt_embeds( for glyph_text in glyph_texts: if glyph_text is None: glyph_text_embeds = torch.zeros( - (1, tokenizer_max_length, text_encoder.config.d_model), device=device, text_encoder.dtype + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype ) glyph_text_embeds_mask = torch.zeros( (1, tokenizer_max_length), device=device, dtype=torch.int64 @@ -373,8 +374,9 @@ def _get_vae_image_latents( device: torch.device, ) -> torch.Tensor: - vae_dtype = self.vae.dtype + vae_dtype = vae.dtype image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype) + image_tensor = image_tensor.unsqueeze(2) image_latents = retrieve_latents(vae.encode(image_tensor), sample_mode="argmax") image_latents = image_latents * vae.config.scaling_factor return image_latents @@ -848,7 +850,7 @@ def __call__( ) cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask( - latents =latenets, + latents =latents, image=image, batch_size=batch_size * num_videos_per_prompt, height=height, From 0687a407687265879f5f5b25c40fc3a3f31bcbea Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sat, 29 Nov 2025 05:36:22 +0000 Subject: [PATCH 14/34] remove use_meanflow --- .../transformer_hunyuan_video15.py | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index c26b43e19ce4..41aa7ca45d47 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -230,21 +230,12 @@ class HunyuanVideo15TimeEmbedding(nn.Module): def __init__( self, embedding_dim: int, - use_meanflow: bool = False, ): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - self.use_meanflow = use_meanflow - - self.time_proj_r = None - self.timestep_embedder_r = None - if use_meanflow: - self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - def forward( self, @@ -612,11 +603,11 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin @register_to_config def __init__( self, - in_channels: int = 16, - out_channels: int = 16, - num_attention_heads: int = 24, + in_channels: int = 65, + out_channels: int = 32, + num_attention_heads: int = 16, attention_head_dim: int = 128, - num_layers: int = 20, + num_layers: int = 54, num_refiner_layers: int = 2, mlp_ratio: float = 4.0, patch_size: int = 1, @@ -627,7 +618,6 @@ def __init__( image_embed_dim: int = 1152, rope_theta: float = 256.0, rope_axes_dim: Tuple[int, ...] 
= (16, 56, 56), - use_meanflow: bool = False, # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205 target_size: int = 640, # did not name sample_size since it is in pixel spaces task_type: str = "i2v", @@ -646,7 +636,7 @@ def __init__( ) self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim) - self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow) + self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim) self.cond_type_embed = nn.Embedding(3, inner_dim) From c22915d6c43cab5d5d62f22fdc9b310b4fae6074 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sat, 29 Nov 2025 05:36:33 +0000 Subject: [PATCH 15/34] up up --- .../convert_hunyuan_video1_5_to_diffusers.py | 90 +++---------------- 1 file changed, 12 insertions(+), 78 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index cb2fb64c76d1..0e493ae99e98 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -34,89 +34,29 @@ TRANSFORMER_CONFIGS = { "480p_t2v": { - "in_channels": 65, - "out_channels": 32, - "num_attention_heads": 16, - "attention_head_dim": 128, - "num_layers": 54, - "num_refiner_layers": 2, - "mlp_ratio": 4.0, - "patch_size": 1, - "patch_size_t": 1, - "qk_norm": "rms_norm", - "text_embed_dim": 3584, - "text_embed_2_dim": 1472, - "image_embed_dim": 1152, - "rope_theta": 256.0, - "rope_axes_dim": (16, 56, 56), - "use_meanflow": False, - "target_size": 640, - "task_type": "t2v", - }, - "480p_i2v": { - "in_channels": 65, - "out_channels": 32, - "num_attention_heads": 16, - "attention_head_dim": 128, - "num_layers": 54, - "num_refiner_layers": 2, - "mlp_ratio": 4.0, - "patch_size": 1, - "patch_size_t": 1, - "qk_norm": "rms_norm", - "text_embed_dim": 3584, - "text_embed_2_dim": 1472, - "image_embed_dim": 1152, - "rope_theta": 256.0, - "rope_axes_dim": (16, 56, 56), - "use_meanflow": False, "target_size": 640, "task_type": "i2v", }, "720p_t2v": { - "in_channels": 65, - "out_channels": 32, - "num_attention_heads": 16, - "attention_head_dim": 128, - "num_layers": 54, - "num_refiner_layers": 2, - "mlp_ratio": 4.0, - "patch_size": 1, - "patch_size_t": 1, - "qk_norm": "rms_norm", - "text_embed_dim": 3584, - "text_embed_2_dim": 1472, - "image_embed_dim": 1152, - "rope_theta": 256.0, - "rope_axes_dim": (16, 56, 56), - "use_meanflow": False, "target_size": 960, "task_type": "t2v", }, - "720p_i2v": {}, + "720p_i2v": { + "target_size": 960, + "task_type": "i2v", + }, "480p_t2v_distilled": { - "in_channels": 65, - "out_channels": 32, - "num_attention_heads": 16, - "attention_head_dim": 128, - "num_layers": 54, - "num_refiner_layers": 2, - "mlp_ratio": 4.0, - "patch_size": 1, - "patch_size_t": 1, - "qk_norm": "rms_norm", - "text_embed_dim": 3584, - "text_embed_2_dim": 1472, - "image_embed_dim": 1152, - "rope_theta": 256.0, - "rope_axes_dim": (16, 56, 56), - "use_meanflow": False, "target_size": 640, "task_type": "t2v", }, - "480p_i2v_distilled": {}, - "720p_t2v_distilled": {}, - "720p_i2v_distilled": {}, + "480p_i2v_distilled": { + "target_size": 640, + "task_type": "i2v", + }, + "720p_i2v_distilled": { + "target_size": 960, + "task_type": "i2v", + }, } SCHEDULER_CONFIGS = { @@ -138,9 +78,6 @@ "480p_i2v_distilled": { "shift": 5.0, }, - "720p_t2v_distilled": { - "shift": 9.0, - }, "720p_i2v_distilled": { "shift": 7.0, }, @@ -165,9 +102,6 @@ 
"480p_i2v_distilled": { "guidance_scale": 1.0, }, - "720p_t2v_distilled": { - "guidance_scale": 1.0, - }, "720p_i2v_distilled": { "guidance_scale": 1.0, }, From f9cb82b64f031eb50e5d49c9833c0aced203f2fb Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 30 Nov 2025 18:38:33 +0000 Subject: [PATCH 16/34] a few small fix: proprocess, cpu_offloading, attention backend --- src/diffusers/models/attention_dispatch.py | 1 + src/diffusers/pipelines/hunyuan_video1_5/image_processor.py | 2 +- .../pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 8504504981a3..7face061b2ff 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -278,6 +278,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke backend = AttentionBackendName(backend) _check_attention_backend_requirements(backend) + _maybe_download_kernel_for_backend(backend) old_backend = _AttentionBackendRegistry._active_backend _AttentionBackendRegistry._active_backend = backend diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py index 6e3e818c7d83..efeb0e2a5fdd 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -52,7 +52,7 @@ def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): if aspect_ratio >= 1: indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0] else: - indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0] + indices = [(index, x) for index, x in enumerate(diff_ratios) if x >= 0] closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0] closest_size = buckets[closest_ratio_id] diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index ee25031f60c4..7e23d2da3e82 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -195,7 +195,7 @@ class HunyuanVideo15Pipeline(DiffusionPipeline): [ClassifierFreeGuidance]for classifier free guidance. 
""" - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( From e319d7207a70818f56c172245171b70e9f5b1780 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 30 Nov 2025 19:24:00 +0000 Subject: [PATCH 17/34] simplify transformer --- .../transformer_hunyuan_video15.py | 106 +++++------------- 1 file changed, 27 insertions(+), 79 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index 41aa7ca45d47..86ee2104475b 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -59,9 +59,6 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if attn.add_q_proj is None and encoder_hidden_states is not None: - assert False # YiYi Notes: remove this condition if this code path is never used - hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1) # 1. QKV projections query = attn.to_q(hidden_states) @@ -73,51 +70,17 @@ def __call__( value = value.unflatten(2, (attn.heads, -1)) # 2. QK normalization - if attn.norm_q is not None: - query = attn.norm_q(query) - else: - assert False - # YiYi Notes: remove this condition if this code path is never used - if attn.norm_k is not None: - key = attn.norm_k(key) - else: - assert False - # YiYi Notes: remove this condition if this code path is never used + query = attn.norm_q(query) + key = attn.norm_k(key) # 3. Rotational positional embeddings applied to latent stream if image_rotary_emb is not None: from ..embeddings import apply_rotary_emb - - if attn.add_q_proj is None and encoder_hidden_states is not None: - assert False # YiYi Notes: remove this condition if this code path is never used - query = torch.cat( - [ - apply_rotary_emb( - query[:, : -encoder_hidden_states.shape[1]], - image_rotary_emb, - sequence_dim=1, - ), - query[:, -encoder_hidden_states.shape[1] :], - ], - dim=1, - ) - key = torch.cat( - [ - apply_rotary_emb( - key[:, : -encoder_hidden_states.shape[1]], - image_rotary_emb, - sequence_dim=1, - ), - key[:, -encoder_hidden_states.shape[1] :], - ], - dim=1, - ) - else: - query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) - key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) # 4. 
Encoder condition QKV projection and normalization - if attn.add_q_proj is not None and encoder_hidden_states is not None: + if encoder_hidden_states is not None: encoder_query = attn.add_q_proj(encoder_hidden_states) encoder_key = attn.add_k_proj(encoder_hidden_states) encoder_value = attn.add_v_proj(encoder_hidden_states) @@ -134,10 +97,6 @@ def __call__( query = torch.cat([query, encoder_query], dim=1) key = torch.cat([key, encoder_key], dim=1) value = torch.cat([value, encoder_value], dim=1) - - else: - assert False # YiYi Notes: remove this condition if this code path is never used - batch_size, seq_len, heads, dim = query.shape attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True) @@ -178,7 +137,7 @@ def __call__( return hidden_states, encoder_hidden_states -class HunyuanVideoPatchEmbed(nn.Module): +class HunyuanVideo15PatchEmbed(nn.Module): def __init__( self, patch_size: Union[int, Tuple[int, int, int]] = 16, @@ -196,7 +155,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -class HunyuanVideoAdaNorm(nn.Module): +class HunyuanVideo15AdaNorm(nn.Module): def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: super().__init__() @@ -223,9 +182,6 @@ class HunyuanVideo15TimeEmbedding(nn.Module): Args: embedding_dim (`int`): The dimension of the output embedding. - use_meanflow (`bool`, defaults to `False`): - Whether to support reference timestep embeddings for temporal consistency. - Set to `True` for super-resolution models. """ def __init__( self, @@ -240,20 +196,15 @@ def __init__( def forward( self, timestep: torch.Tensor, - timestep_r: Optional[torch.Tensor] = None, ) -> torch.Tensor: timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype)) - if timestep_r is not None: - timesteps_proj_r = self.time_proj_r(timestep_r) - timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype)) - timesteps_emb = timesteps_emb + timesteps_emb_r return timesteps_emb -class HunyuanVideoIndividualTokenRefinerBlock(nn.Module): +class HunyuanVideo15IndividualTokenRefinerBlock(nn.Module): def __init__( self, num_attention_heads: int, @@ -278,7 +229,7 @@ def __init__( self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) - self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size) + self.norm_out = HunyuanVideo15AdaNorm(hidden_size, 2 * hidden_size) def forward( self, @@ -303,7 +254,7 @@ def forward( return hidden_states -class HunyuanVideoIndividualTokenRefiner(nn.Module): +class HunyuanVideo15IndividualTokenRefiner(nn.Module): def __init__( self, num_attention_heads: int, @@ -317,7 +268,7 @@ def __init__( self.refiner_blocks = nn.ModuleList( [ - HunyuanVideoIndividualTokenRefinerBlock( + HunyuanVideo15IndividualTokenRefinerBlock( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, mlp_width_ratio=mlp_width_ratio, @@ -336,7 +287,6 @@ def forward( ) -> None: self_attn_mask = None if attention_mask is not None: - # YiYi TODO convert 1D mask to 4d Bx1xLxL batch_size = attention_mask.shape[0] seq_len = attention_mask.shape[1] attention_mask = attention_mask.to(hidden_states.device).bool() @@ -350,7 +300,7 @@ def forward( return hidden_states -class HunyuanVideoTokenRefiner(nn.Module): +class HunyuanVideo15TokenRefiner(nn.Module): def __init__( self, 
in_channels: int, @@ -369,7 +319,7 @@ def __init__( embedding_dim=hidden_size, pooled_projection_dim=in_channels ) self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) - self.token_refiner = HunyuanVideoIndividualTokenRefiner( + self.token_refiner = HunyuanVideo15IndividualTokenRefiner( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, num_layers=num_layers, @@ -399,7 +349,7 @@ def forward( return hidden_states -class HunyuanVideoRotaryPosEmbed(nn.Module): +class HunyuanVideo15RotaryPosEmbed(nn.Module): def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: super().__init__() @@ -432,7 +382,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return freqs_cos, freqs_sin -# Copied from diffusers.models.transformers.transformer_hunyuanimage.HunyuanImageByT5TextProjection class HunyuanVideo15ByT5TextProjection(nn.Module): def __init__(self, in_features: int, hidden_size: int, out_features: int): super().__init__() @@ -470,7 +419,7 @@ def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: return hidden_states -class HunyuanVideoTransformerBlock(nn.Module): +class HunyuanVideo15TransformerBlock(nn.Module): def __init__( self, num_attention_heads: int, @@ -552,7 +501,7 @@ def forward( class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" - A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo). + A Transformer model for video-like data used in [HunyuanVideo1.5](https://huggingface.co/tencent/HunyuanVideo1.5). Args: in_channels (`int`, defaults to `16`): @@ -590,14 +539,14 @@ class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] _no_split_modules = [ - "HunyuanVideoTransformerBlock", - "HunyuanVideoPatchEmbed", - "HunyuanVideoTokenRefiner", + "HunyuanVideo15TransformerBlock", + "HunyuanVideo15PatchEmbed", + "HunyuanVideo15TokenRefiner", ] _repeated_blocks = [ - "HunyuanVideoTransformerBlock", - "HunyuanVideoPatchEmbed", - "HunyuanVideoTokenRefiner", + "HunyuanVideo15TransformerBlock", + "HunyuanVideo15PatchEmbed", + "HunyuanVideo15TokenRefiner", ] @register_to_config @@ -628,10 +577,10 @@ def __init__( out_channels = out_channels or in_channels # 1. Latent and condition embedders - self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.x_embedder = HunyuanVideo15PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim) - self.context_embedder = HunyuanVideoTokenRefiner( + self.context_embedder = HunyuanVideo15TokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim) @@ -641,13 +590,13 @@ def __init__( self.cond_type_embed = nn.Embedding(3, inner_dim) # 2. RoPE - self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + self.rope = HunyuanVideo15RotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) # 3. 
Dual stream transformer blocks self.transformer_blocks = nn.ModuleList( [ - HunyuanVideoTransformerBlock( + HunyuanVideo15TransformerBlock( num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm ) for _ in range(num_layers) @@ -730,7 +679,6 @@ def forward( encoder_hidden_states_2: Optional[torch.Tensor] = None, encoder_attention_mask_2: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, - timestep_r: Optional[torch.LongTensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: @@ -759,7 +707,7 @@ def forward( image_rotary_emb = self.rope(hidden_states) # 2. Conditional embeddings - temb = self.time_embed(timestep, timestep_r=timestep_r) + temb = self.time_embed(timestep) hidden_states = self.x_embedder(hidden_states) From e1940341ffc75a210c1a212f9d05678d1f25e74f Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 30 Nov 2025 19:46:38 +0000 Subject: [PATCH 18/34] clean up a bit more pipelines --- .../hunyuan_video1_5/image_processor.py | 6 +- .../pipeline_hunyuan_video1_5.py | 92 ++++-------- .../pipeline_hunyuan_video1_5_image2video.py | 138 ++++++++---------- 3 files changed, 97 insertions(+), 139 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py index efeb0e2a5fdd..d6fe62c6ff41 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -17,7 +17,7 @@ from ...video_processor import VideoProcessor from ...configuration_utils import register_to_config -# Copied from hyvideo/utils/data_utils.py +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L20 def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): num_patches = round((base_size / patch_size) ** 2) assert max_ratio >= 1.0 @@ -32,7 +32,7 @@ def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): wp -= 1 return crop_size_list -# Copied from hyvideo/utils/data_utils.py +# copied fromhttps://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38 def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): """ Get the closest ratio in the buckets. @@ -72,6 +72,8 @@ class HunyuanVideo15ImageProcessor(VideoProcessor): this factor. vae_latent_channels (`int`, *optional*, defaults to `32`): VAE latent channels. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. 
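`generate_crop_size_list` and `get_closest_ratio` above implement the aspect-ratio bucketing used by this processor. The sketch below shows only the selection idea with a hand-picked bucket list; the real code enumerates buckets from patch counts and uses the sign-split comparison fixed in an earlier patch:

```python
# Pick the (height, width) bucket whose aspect ratio is closest to the input image's.
def closest_bucket(height: int, width: int, buckets: list[tuple[int, int]]) -> tuple[int, int]:
    target = height / width
    return min(buckets, key=lambda hw: abs(hw[0] / hw[1] - target))


buckets = [(480, 848), (544, 736), (624, 624), (736, 544), (848, 480)]
print(closest_bucket(720, 1280, buckets))  # -> (480, 848), the closest 16:9-style bucket
```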
""" @register_to_config diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index 7e23d2da3e82..d9d7fc5a37fc 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -47,7 +47,7 @@ >>> from diffusers import HunyuanVideo15Pipeline >>> from diffusers.utils import export_to_video - >>> model_id = "hunyuanvideo-community/HunyuanVideo15" + >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_t2v" >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) >>> pipe.vae.enable_tiling() >>> pipe.to("cuda") @@ -196,7 +196,6 @@ class HunyuanVideo15Pipeline(DiffusionPipeline): """ model_cpu_offload_seq = "text_encoder->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, @@ -550,10 +549,6 @@ def prepare_cond_latents_and_mask(self, latents, dtype: Optional[torch.dtype], d return cond_latents_concat, mask_concat - @property - def guidance_scale(self): - return self._guidance_scale - @property def num_timesteps(self): return self._num_timesteps @@ -601,91 +596,67 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is - not greater than `1`). - negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - height (`int`, defaults to `720`): - The height in pixels of the generated image. - width (`int`, defaults to `1280`): - The width in pixels of the generated image. - num_frames (`int`, defaults to `129`): + `negative_prompt_embeds` instead. + height (`int`, *optional*): + The height in pixels of the generated video. + width (`int`, *optional*): + The width in pixels of the generated video. + num_frames (`int`, defaults to `121`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the + The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. - true_cfg_scale (`float`, *optional*, defaults to 1.0): - True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and - `negative_prompt` is provided. - guidance_scale (`float`, defaults to `6.0`): - Embedded guiddance scale is enabled by setting `guidance_scale` > 1. 
Higher `guidance_scale` encourages - a model to generate images more aligned with `prompt` at the expense of lower image quality. - - Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to - the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. + The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings from the second text encoder. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the second text encoder. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings from the second text encoder. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between "np", "pt", or "latent". return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple. 
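Tied to the `output_type` entry above: requesting `"latent"` skips the decode step at the end of `__call__`, so the caller reproduces it manually. A sketch of that manual decode, assuming `pipe` is an already-loaded `HunyuanVideo15Pipeline`; attribute names follow the decode block later in this file:

```python
import torch

# Generate raw latents, then run the pipeline's own decode + postprocess steps by hand.
latents = pipe(prompt="a cat walks on the grass, realistic", output_type="latent").frames
latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
with torch.no_grad():
    video = pipe.vae.decode(latents, return_dict=False)[0]
video = pipe.video_processor.postprocess_video(video, output_type="np")
```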
attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): - A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of - each denoising step during the inference. with the following arguments: `callback_on_step_end(self: - DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a - list of all tensors as specified by `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: - [`~HunyuanVideoPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned - where the first element is a list with the generated images and the second element is a list of `bool`s - indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + [`~HunyuanVideo15PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated videos. """ # 1. Check inputs. Raise error if not correct @@ -867,7 +838,8 @@ def __call__( xm.mark_step() self._current_timestep = None - + + # 8. decode the latents to video and postprocess if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py index 1bce43f086d2..18bd590cfc83 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -45,23 +45,28 @@ Examples: ```python >>> import torch - >>> from diffusers import HunyuanVideo15Pipeline + >>> from diffusers import HunyuanVideo15ImageToVideoPipeline >>> from diffusers.utils import export_to_video - >>> model_id = "hunyuanvideo-community/HunyuanVideo15" - >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_i2v" + >>> pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.float16) >>> pipe.vae.enable_tiling() >>> pipe.to("cuda") + >>> image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG" + ... ) + >>> output = pipe( - ... prompt="A cat walks on the grass, realistic", + ... prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. 
The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + ... image=image, ... num_inference_steps=50, ... ).frames[0] - >>> export_to_video(output, "output.mp4", fps=15) + >>> export_to_video(output, "output.mp4", fps=24) ``` """ - +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input def format_text_input(prompt: List[str], system_message: str ) -> List[Dict[str, Any]]: """ @@ -87,6 +92,7 @@ def format_text_input(prompt: List[str], system_message: str return template +# Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.extract_glyph_text def extract_glyph_texts(prompt: str) -> List[str]: """ Extract glyph texts from prompt using regex pattern. @@ -207,10 +213,15 @@ class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline): tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] guider ([`ClassifierFreeGuidance`]): [ClassifierFreeGuidance]for classifier free guidance. + image_encoder ([`SiglipVisionModel`]): + [SiglipVisionModel](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipVisionModel) + variant. + feature_extractor ([`SiglipImageProcessor`]): + [SiglipImageProcessor](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipImageProcessor) + variant. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds"] + model_cpu_offload_seq = "image_encoder->text_encoder->transformer->vae" def __init__( self, @@ -365,7 +376,7 @@ def _get_byt5_prompt_embeds( @staticmethod - def _get_vae_image_latents( + def _get_image_latents( vae: AutoencoderKLHunyuanVideo15, image_processor: HunyuanVideo15ImageProcessor, image: PIL.Image.Image, @@ -613,7 +624,7 @@ def prepare_cond_latents_and_mask( batch, channels, frames, height, width = latents.shape - image_latents = self._get_vae_image_latents( + image_latents = self._get_image_latents( vae=self.vae, image_processor=self.video_processor, image=image, @@ -636,10 +647,6 @@ def prepare_cond_latents_and_mask( return latent_condition, latent_mask - @property - def guidance_scale(self): - return self._guidance_scale - @property def num_timesteps(self): return self._num_timesteps @@ -685,92 +692,66 @@ def __call__( The call function to the pipeline for generation. Args: + image (`PIL.Image.Image`): + The input image to condition video generation on. prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds` instead. - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is - not greater than `1`). 
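`extract_glyph_texts`, shown earlier in this hunk, pulls the quoted spans out of a prompt so they can be routed to the ByT5 glyph encoder. The regex below only illustrates that idea; it is not the exact pattern the helper uses:

```python
import re


def extract_quoted(prompt: str) -> list[str]:
    # Grab every double-quoted span; these are the strings the glyph branch would encode.
    return re.findall(r'"([^"]+)"', prompt)


print(extract_quoted('A neon sign that reads "OPEN 24 HOURS" above a red door.'))
# -> ['OPEN 24 HOURS']
```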
- negative_prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and - `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - height (`int`, defaults to `720`): - The height in pixels of the generated image. - width (`int`, defaults to `1280`): - The width in pixels of the generated image. - num_frames (`int`, defaults to `129`): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + num_frames (`int`, defaults to `121`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the + The number of denoising steps. More denoising steps usually lead to a higher quality video at the expense of slower inference. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. - true_cfg_scale (`float`, *optional*, defaults to 1.0): - True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and - `negative_prompt` is provided. - guidance_scale (`float`, defaults to `6.0`): - Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages - a model to generate images more aligned with `prompt` at the expense of lower image quality. - - Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to - the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. + The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` - input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings from the second text encoder. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the second text encoder. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings from the second text encoder. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between "np", "pt", or "latent". return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): - A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of - each denoising step during the inference. with the following arguments: `callback_on_step_end(self: - DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a - list of all tensors as specified by `callback_on_step_end_tensor_inputs`. - callback_on_step_end_tensor_inputs (`List`, *optional*): - The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list - will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the - `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: - [`~HunyuanVideoPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned - where the first element is a list with the generated images and the second element is a list of `bool`s - indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + [`~HunyuanVideo15PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated videos. """ # 1. Check inputs. Raise error if not correct @@ -806,7 +787,15 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # 3. Encode input prompt + # 3. 
Encode image + image_embeds = self.encode_image( + image=image, + batch_size=batch_size * num_videos_per_prompt, + device=device, + dtype=self.transformer.dtype, + ) + + # 4. Encode input prompt prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( prompt=prompt, device=device, @@ -832,11 +821,11 @@ def __call__( prompt_embeds_mask_2=negative_prompt_embeds_mask_2, ) - # 4. Prepare timesteps + # 5. Prepare timesteps sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) - # 5. Prepare latent variables + # 6. Prepare latent variables latents = self.prepare_latents( batch_size=batch_size * num_videos_per_prompt, num_channels_latents=self.num_channels_latents, @@ -858,12 +847,6 @@ def __call__( dtype=self.transformer.dtype, device=device ) - image_embeds = self.encode_image( - image=image, - batch_size=batch_size * num_videos_per_prompt, - device=device, - dtype=self.transformer.dtype, - ) # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -961,6 +944,7 @@ def __call__( self._current_timestep = None + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] From 5029dbf763d841b6f08d152b6d94b21329dd26b4 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 30 Nov 2025 20:14:36 +0000 Subject: [PATCH 19/34] style --- .../convert_hunyuan_video1_5_to_diffusers.py | 169 +++++++++--------- src/diffusers/__init__.py | 10 +- src/diffusers/models/__init__.py | 4 +- .../autoencoder_kl_hunyuanvideo15.py | 6 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformer_hunyuan_video15.py | 47 +++-- src/diffusers/pipelines/__init__.py | 2 +- .../hunyuan_video1_5/image_processor.py | 12 +- .../pipeline_hunyuan_video1_5.py | 105 +++++------ .../pipeline_hunyuan_video1_5_image2video.py | 133 +++++++------- .../hunyuan_video1_5/pipeline_output.py | 3 - 11 files changed, 233 insertions(+), 260 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 0e493ae99e98..38226f684a6d 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -1,3 +1,30 @@ +import argparse +import json +import os +import pathlib + +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download, snapshot_download +from safetensors.torch import load_file +from transformers import ( + AutoModel, + AutoTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, +) + +from diffusers import ( + AutoencoderKLHunyuanVideo15, + ClassifierFreeGuidance, + FlowMatchEulerDiscreteScheduler, + HunyuanVideo15ImageToVideoPipeline, + HunyuanVideo15Pipeline, + HunyuanVideo15Transformer3DModel, +) + + # to convert only transformer """ python scripts/convert_hunyuan_video1_5_to_diffusers.py \ @@ -16,21 +43,6 @@ --transformer_type 480p_t2v """ -import argparse -from typing import Any, Dict - -import torch -from accelerate import init_empty_weights -from safetensors.torch import load_file -from huggingface_hub import snapshot_download, hf_hub_download - -import pathlib -from diffusers import HunyuanVideo15Transformer3DModel, AutoencoderKLHunyuanVideo15, FlowMatchEulerDiscreteScheduler, ClassifierFreeGuidance, 
HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline -from transformers import AutoModel, AutoTokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor - -import json -import argparse -import os TRANSFORMER_CONFIGS = { "480p_t2v": { @@ -107,6 +119,7 @@ }, } + def swap_scale_shift(weight): shift, scale = weight.chunk(2, dim=0) new_weight = torch.cat([scale, shift], dim=0) @@ -123,48 +136,42 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict): converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop( "time_in.mlp.0.weight" ) - converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( - "time_in.mlp.0.bias" - ) + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.mlp.0.bias") converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop( "time_in.mlp.2.weight" ) - converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( - "time_in.mlp.2.bias" - ) + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias") # 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = ( original_state_dict.pop("txt_in.t_embedder.mlp.0.weight") ) - converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = ( - original_state_dict.pop("txt_in.t_embedder.mlp.0.bias") + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( + "txt_in.t_embedder.mlp.0.bias" ) converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = ( original_state_dict.pop("txt_in.t_embedder.mlp.2.weight") ) - converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = ( - original_state_dict.pop("txt_in.t_embedder.mlp.2.bias") + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( + "txt_in.t_embedder.mlp.2.bias" ) # 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder - converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = ( - original_state_dict.pop("txt_in.c_embedder.linear_1.weight") + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop( + "txt_in.c_embedder.linear_1.weight" ) - converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = ( - original_state_dict.pop("txt_in.c_embedder.linear_1.bias") + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop( + "txt_in.c_embedder.linear_1.bias" ) - converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = ( - original_state_dict.pop("txt_in.c_embedder.linear_2.weight") + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop( + "txt_in.c_embedder.linear_2.weight" ) - converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = ( - original_state_dict.pop("txt_in.c_embedder.linear_2.bias") + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop( + "txt_in.c_embedder.linear_2.bias" ) # 4. 
context_embedder.proj_in <- txt_in.input_embedder - converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop( - "txt_in.input_embedder.weight" - ) + converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight") converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias") # 5. context_embedder.token_refiner <- txt_in.individual_token_refiner @@ -375,10 +382,12 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict): ) # 11. norm_out and proj_out <- final_layer - converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop( - "final_layer.adaLN_modulation.1.weight" - )) - converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.adaLN_modulation.1.bias")) + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.bias") + ) converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") @@ -572,6 +581,7 @@ def convert_hunyuan_video_15_vae_checkpoint_to_diffusers( return converted + def load_sharded_safetensors(dir: pathlib.Path): file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) state_dict = {} @@ -583,9 +593,9 @@ def load_sharded_safetensors(dir: pathlib.Path): def load_original_transformer_state_dict(args): if args.original_state_dict_repo_id is not None: model_dir = snapshot_download( - args.original_state_dict_repo_id, + args.original_state_dict_repo_id, repo_type="model", - allow_patterns="transformer/" + args.transformer_type + "/*" + allow_patterns="transformer/" + args.transformer_type + "/*", ) elif args.original_state_dict_folder is not None: model_dir = pathlib.Path(args.original_state_dict_folder) @@ -599,8 +609,7 @@ def load_original_transformer_state_dict(args): def load_original_vae_state_dict(args): if args.original_state_dict_repo_id is not None: ckpt_path = hf_hub_download( - repo_id=args.original_state_dict_repo_id, - filename= "vae/diffusion_pytorch_model.safetensors" + repo_id=args.original_state_dict_repo_id, filename="vae/diffusion_pytorch_model.safetensors" ) elif args.original_state_dict_folder is not None: model_dir = pathlib.Path(args.original_state_dict_folder) @@ -632,24 +641,27 @@ def convert_vae(args): vae.load_state_dict(state_dict, strict=True, assign=True) return vae + def load_mllm(): - print(f" loading from Qwen/Qwen2.5-VL-7B-Instruct") - text_encoder = AutoModel.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16,low_cpu_mem_usage=True) - if hasattr(text_encoder, 'language_model'): + print(" loading from Qwen/Qwen2.5-VL-7B-Instruct") + text_encoder = AutoModel.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True + ) + if hasattr(text_encoder, "language_model"): text_encoder = text_encoder.language_model tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right") return text_encoder, tokenizer -#copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89 +# copied from 
https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89 def add_special_token( tokenizer, text_encoder, add_color=True, add_font=True, multilingual=True, - color_ann_path='assets/color_idx.json', - font_ann_path='assets/multilingual_10-lang_idx.json', + color_ann_path="assets/color_idx.json", + font_ann_path="assets/multilingual_10-lang_idx.json", ): """ Add special tokens for color and font to tokenizer and text encoder. @@ -663,16 +675,16 @@ def add_special_token( font_ann_path (str): Path to font annotation JSON. multilingual (bool): Whether to use multilingual font tokens. """ - with open(font_ann_path, 'r') as f: + with open(font_ann_path, "r") as f: idx_font_dict = json.load(f) - with open(color_ann_path, 'r') as f: + with open(color_ann_path, "r") as f: idx_color_dict = json.load(f) if multilingual: - font_token = [f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict] + font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict] else: - font_token = [f'' for i in range(len(idx_font_dict))] - color_token = [f'' for i in range(len(idx_color_dict))] + font_token = [f"" for i in range(len(idx_font_dict))] + color_token = [f"" for i in range(len(idx_color_dict))] additional_special_tokens = [] if add_color: additional_special_tokens += color_token @@ -688,14 +700,13 @@ def load_byt5(args): """ Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format. """ - # 1. Load base tokenizer and encoder tokenizer = AutoTokenizer.from_pretrained("google/byt5-small") - + # Load as T5EncoderModel encoder = T5EncoderModel.from_pretrained("google/byt5-small") - + byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt") color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json") font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json") @@ -710,48 +721,45 @@ def load_byt5(args): font_ann_path=font_ann_path, multilingual=True, ) - - + # 3. Load Glyph-SDXL-v2 checkpoint print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}") - checkpoint = torch.load(byt5_checkpoint_path, map_location='cpu') - + checkpoint = torch.load(byt5_checkpoint_path, map_location="cpu") + # Handle different checkpoint formats - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: state_dict = checkpoint - - # add 'encoder.' prefix to the keys + + # add 'encoder.' prefix to the keys # Remove 'module.text_tower.encoder.' prefix if present cleaned_state_dict = {} for key, value in state_dict.items(): - if key.startswith('module.text_tower.encoder.'): - new_key = 'encoder.' + key[len('module.text_tower.encoder.'):] + if key.startswith("module.text_tower.encoder."): + new_key = "encoder." + key[len("module.text_tower.encoder.") :] cleaned_state_dict[new_key] = value else: - new_key = 'encoder.' + key + new_key = "encoder." + key cleaned_state_dict[new_key] = value - - + # 4. 
Load weights missing_keys, unexpected_keys = encoder.load_state_dict(cleaned_state_dict, strict=False) if unexpected_keys: raise ValueError(f"Unexpected keys: {unexpected_keys}") if "shared.weight" in missing_keys: - print(f" Missing shared.weight as expected") + print(" Missing shared.weight as expected") missing_keys.remove("shared.weight") if missing_keys: raise ValueError(f"Missing keys: {missing_keys}") - - + return encoder, tokenizer def load_siglip(): image_encoder = SiglipVisionModel.from_pretrained( "black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16 - ) + ) feature_extractor = SiglipImageProcessor.from_pretrained( "black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor" ) @@ -763,11 +771,11 @@ def get_args(): parser.add_argument( "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model" ) - parser.add_argument("--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict") - parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved") parser.add_argument( - "--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys()) + "--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict" ) + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved") + parser.add_argument("--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())) parser.add_argument( "--byt5_path", type=str, @@ -826,7 +834,7 @@ def get_args(): feature_extractor=feature_extractor, ) elif task_type == "t2v": - pipeline = HunyuanVideo15Text2VideoPipeline( + pipeline = HunyuanVideo15Pipeline( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, @@ -840,6 +848,3 @@ def get_args(): raise ValueError(f"Task type {task_type} is not supported") pipeline.save_pretrained(args.output_path, safe_serialization=True) - - - diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index e0c219a43f78..02dd42e4a580 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -483,11 +483,11 @@ "HunyuanImagePipeline", "HunyuanImageRefinerPipeline", "HunyuanSkyreelsImageToVideoPipeline", + "HunyuanVideo15ImageToVideoPipeline", + "HunyuanVideo15Pipeline", "HunyuanVideoFramepackPipeline", "HunyuanVideoImageToVideoPipeline", "HunyuanVideoPipeline", - "HunyuanVideo15Pipeline", - "HunyuanVideo15ImageToVideoPipeline", "I2VGenXLPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -949,9 +949,9 @@ HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, HunyuanImageTransformer2DModel, + HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, - HunyuanVideo15Transformer3DModel, I2VGenXLUNet, Kandinsky3UNet, Kandinsky5Transformer3DModel, @@ -1176,11 +1176,11 @@ HunyuanImagePipeline, HunyuanImageRefinerPipeline, HunyuanSkyreelsImageToVideoPipeline, + HunyuanVideo15ImageToVideoPipeline, + HunyuanVideo15Pipeline, HunyuanVideoFramepackPipeline, HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, - HunyuanVideo15Pipeline, - HunyuanVideo15ImageToVideoPipeline, I2VGenXLPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 0f335bec37bc..8b60b269324f 100755 --- a/src/diffusers/models/__init__.py +++ 
b/src/diffusers/models/__init__.py @@ -85,7 +85,6 @@ _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] - _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] @@ -98,6 +97,7 @@ _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] + _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] @@ -199,10 +199,10 @@ EasyAnimateTransformer3DModel, Flux2Transformer2DModel, FluxTransformer2DModel, - HunyuanVideo15Transformer3DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, + HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py index 2f05172a97d3..7d6a636a240b 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -141,7 +141,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() - attention_mask = self.prepare_causal_attention_mask(frames, height * width, query.dtype, query.device, batch_size=batch_size) + attention_mask = self.prepare_causal_attention_mask( + frames, height * width, query.dtype, query.device, batch_size=batch_size + ) x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) @@ -196,7 +198,7 @@ def forward(self, x: torch.Tensor): x_first = x[:, :, :1, :, :] x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2) x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1) - + x_next = x[:, :, 1:, :, :] x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2) x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 0294e2785cc4..67a555b676d5 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -29,6 +29,7 @@ from .transformer_flux2 import Flux2Transformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel + from 
.transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel @@ -45,5 +46,4 @@ from .transformer_wan import WanTransformer3DModel from .transformer_wan_animate import WanAnimateTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel - from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_z_image import ZImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index 86ee2104475b..8f191e75009a 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -59,7 +59,6 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # 1. QKV projections query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) @@ -76,6 +75,7 @@ def __call__( # 3. Rotational positional embeddings applied to latent stream if image_rotary_emb is not None: from ..embeddings import apply_rotary_emb + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) @@ -175,14 +175,15 @@ def forward( class HunyuanVideo15TimeEmbedding(nn.Module): r""" Time embedding for HunyuanVideo 1.5. - - Supports standard timestep embedding and optional reference timestep embedding - for MeanFlow-based super-resolution models. - + + Supports standard timestep embedding and optional reference timestep embedding for MeanFlow-based super-resolution + models. + Args: embedding_dim (`int`): The dimension of the output embedding. """ + def __init__( self, embedding_dim: int, @@ -192,7 +193,6 @@ def __init__( self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) - def forward( self, timestep: torch.Tensor, @@ -200,7 +200,6 @@ def forward( timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype)) - return timesteps_emb @@ -469,7 +468,7 @@ def forward( norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb ) - + # 2. Joint attention attn_output, context_attn_output = self.attn( hidden_states=norm_hidden_states, @@ -478,7 +477,6 @@ def forward( image_rotary_emb=freqs_cis, ) - # 3. Modulation and residual connection hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) @@ -568,7 +566,7 @@ def __init__( rope_theta: float = 256.0, rope_axes_dim: Tuple[int, ...] = (16, 56, 56), # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205 - target_size: int = 640, # did not name sample_size since it is in pixel spaces + target_size: int = 640, # did not name sample_size since it is in pixel spaces task_type: str = "i2v", ) -> None: super().__init__() @@ -579,7 +577,7 @@ def __init__( # 1. 
Latent and condition embedders self.x_embedder = HunyuanVideo15PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim) - + self.context_embedder = HunyuanVideo15TokenRefiner( text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers ) @@ -668,8 +666,7 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - - + def forward( self, hidden_states: torch.Tensor, @@ -733,10 +730,10 @@ def forward( if is_t2v: encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0 encoder_attention_mask_3 = torch.zeros( - (batch_size, encoder_hidden_states_3.shape[1]), - dtype=encoder_attention_mask.dtype, - device=encoder_attention_mask.device, - ) + (batch_size, encoder_hidden_states_3.shape[1]), + dtype=encoder_attention_mask.dtype, + device=encoder_attention_mask.device, + ) else: encoder_attention_mask_3 = torch.ones( (batch_size, encoder_hidden_states_3.shape[1]), @@ -744,7 +741,8 @@ def forward( device=encoder_attention_mask.device, ) encoder_hidden_states_3_cond_emb = self.cond_type_embed( - 2 * torch.ones_like( + 2 + * torch.ones_like( encoder_hidden_states_3[:, :, 0], dtype=torch.long, ) @@ -759,17 +757,17 @@ def forward( new_encoder_attention_mask = [] for text, text_mask, text_2, text_mask_2, image, image_mask in zip( - encoder_hidden_states, - encoder_attention_mask, - encoder_hidden_states_2, - encoder_attention_mask_2, - encoder_hidden_states_3, + encoder_hidden_states, + encoder_attention_mask, + encoder_hidden_states_2, + encoder_attention_mask_2, + encoder_hidden_states_3, encoder_attention_mask_3, ): # Concatenate: [valid_image, valid_byt5, valid_mllm, invalid_image, invalid_byt5, invalid_mllm] new_encoder_hidden_states.append( torch.cat( - [ + [ image[image_mask], # valid image text_2[text_mask_2], # valid byt5 text[text_mask], # valid mllm @@ -799,7 +797,6 @@ def forward( encoder_hidden_states = torch.stack(new_encoder_hidden_states) encoder_attention_mask = torch.stack(new_encoder_attention_mask) - # 4. 
Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: for block in self.transformer_blocks: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index ffa4cb4b7d03..cf86456642eb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -666,7 +666,7 @@ HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, ) - from .hunyuan_video1_5 import HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline + from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py index d6fe62c6ff41..b0c778235257 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -14,8 +14,9 @@ import numpy as np -from ...video_processor import VideoProcessor from ...configuration_utils import register_to_config +from ...video_processor import VideoProcessor + # copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L20 def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): @@ -32,6 +33,7 @@ def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): wp -= 1 return crop_size_list + # copied fromhttps://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38 def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): """ @@ -60,9 +62,11 @@ def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): return closest_size, closest_ratio + class HunyuanVideo15ImageProcessor(VideoProcessor): r""" Image/video processor to preproces/postprocess the reference image/generatedvideo for the HunyuanVideo1.5 model. + Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept @@ -91,11 +95,9 @@ def __init__( do_convert_rgb=do_convert_rgb, ) - def calculate_default_height_width(self, height: int, width: int, target_size: int): - crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.config.vae_scale_factor) aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) height, width = get_closest_ratio(height, width, aspect_ratios, crop_size_list)[0] - - return height, width \ No newline at end of file + + return height, width diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py index d9d7fc5a37fc..00a703939004 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -13,21 +13,21 @@ # limitations under the License. 
import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union import re +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch -from transformers import Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel, ByT5Tokenizer +from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel +from ...guiders import ClassifierFreeGuidance from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring -from .image_processor import HunyuanVideo15ImageProcessor +from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .image_processor import HunyuanVideo15ImageProcessor from .pipeline_output import HunyuanVideo15PipelineOutput -from ...guiders import ClassifierFreeGuidance -from ...utils.torch_utils import randn_tensor if is_torch_xla_available(): @@ -61,8 +61,7 @@ """ -def format_text_input(prompt: List[str], system_message: str - ) -> List[Dict[str, Any]]: +def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]: """ Apply text to template. @@ -75,13 +74,8 @@ def format_text_input(prompt: List[str], system_message: str """ template = [ - [ - { - 'role': 'system', - 'content': system_message}, - {'role': 'user', 'content': p if p else " "} - ] - for p in prompt] + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] return template @@ -89,14 +83,14 @@ def format_text_input(prompt: List[str], system_message: str def extract_glyph_texts(prompt: str) -> List[str]: """ Extract glyph texts from prompt using regex pattern. - + Args: prompt: Input prompt string - + Returns: List of extracted glyph texts """ - pattern = r'\"(.*?)\"|“(.*?)”' + pattern = r"\"(.*?)\"|“(.*?)”" matches = re.findall(pattern, prompt) result = [match[0] or match[1] for match in matches] result = list(dict.fromkeys(result)) if len(result) > 1 else result @@ -225,7 +219,9 @@ def __init__( self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 - self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + self.vision_states_dim = ( + self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + ) self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 # fmt: off self.system_message = "You are a helpful assistant. 
Describe the video by detailing the following aspects: \ @@ -239,8 +235,7 @@ def __init__( self.tokenizer_max_length = 1000 self.tokenizer_2_max_length = 256 self.vision_num_semantic_tokens = 729 - self.default_aspect_ratio = (16, 9) # (width: height) - + self.default_aspect_ratio = (16, 9) # (width: height) @staticmethod def _get_mllm_prompt_embeds( @@ -260,8 +255,6 @@ def _get_mllm_prompt_embeds( # fmt: on crop_start: int = 108, ) -> Tuple[torch.Tensor, torch.Tensor]: - - prompt = [prompt] if isinstance(prompt, str) else prompt prompt = format_text_input(prompt, system_message) @@ -292,7 +285,6 @@ def _get_mllm_prompt_embeds( return prompt_embeds, prompt_attention_mask - @staticmethod def _get_byt5_prompt_embeds( tokenizer: ByT5Tokenizer, @@ -301,7 +293,6 @@ def _get_byt5_prompt_embeds( device: torch.device, tokenizer_max_length: int = 256, ): - prompt = [prompt] if isinstance(prompt, str) else prompt glyph_texts = [extract_glyph_texts(p) for p in prompt] @@ -312,10 +303,9 @@ def _get_byt5_prompt_embeds( for glyph_text in glyph_texts: if glyph_text is None: glyph_text_embeds = torch.zeros( - (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype) - glyph_text_embeds_mask = torch.zeros( - (1, tokenizer_max_length), device=device, dtype=torch.int64 + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) else: txt_tokens = tokenizer( glyph_text, @@ -341,7 +331,6 @@ def _get_byt5_prompt_embeds( return prompt_embeds, prompt_embeds_mask - def encode_prompt( self, prompt: Union[str, List[str]], @@ -438,16 +427,11 @@ def check_inputs( prompt_embeds_mask_2=None, negative_prompt_embeds_2=None, negative_prompt_embeds_mask_2=None, - ): - + ): if height is None and width is not None: - raise ValueError( - "If `width` is provided, `height` also have to be provided." - ) + raise ValueError("If `width` is provided, `height` also have to be provided.") elif width is None and height is not None: - raise ValueError( - "If `height` is provided, `width` also have to be provided." - ) + raise ValueError("If `height` is provided, `width` also have to be provided.") if prompt is not None and prompt_embeds is not None: raise ValueError( @@ -521,33 +505,23 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents - def prepare_cond_latents_and_mask(self, latents, dtype: Optional[torch.dtype], device: Optional[torch.device]): """ Prepare conditional latents and mask for t2v generation. 
- + Args: latents: Main latents tensor (B, C, F, H, W) - + Returns: tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v """ batch, channels, frames, height, width = latents.shape - - cond_latents_concat = torch.zeros( - batch, channels, frames, height, width, - dtype=dtype, - device=device - ) - - mask_concat = torch.zeros( - batch, 1, frames, height, width, - dtype=dtype, - device=device - ) - - return cond_latents_concat, mask_concat + cond_latents_concat = torch.zeros(batch, channels, frames, height, width, dtype=dtype, device=device) + + mask_concat = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device) + + return cond_latents_concat, mask_concat @property def num_timesteps(self): @@ -655,8 +629,8 @@ def __call__( Returns: [`~HunyuanVideo15PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is returned - where the first element is a list with the generated videos. + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated videos. """ # 1. Check inputs. Raise error if not correct @@ -676,7 +650,9 @@ def __call__( ) if height is None and width is None: - height, width = self.video_processor.calculate_default_height_width(self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size) + height, width = self.video_processor.calculate_default_height_width( + self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size + ) self._attention_kwargs = attention_kwargs self._current_timestep = None @@ -705,8 +681,13 @@ def __call__( prompt_embeds_mask_2=prompt_embeds_mask_2, ) - if self.guider._enabled and self.guider.num_conditions >1 : - negative_prompt_embeds, negative_prompt_embeds_mask, negative_prompt_embeds_2, negative_prompt_embeds_mask_2 = self.encode_prompt( + if self.guider._enabled and self.guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( prompt=negative_prompt, device=device, dtype=self.transformer.dtype, @@ -736,11 +717,11 @@ def __call__( ) cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, self.transformer.dtype, device) image_embeds = torch.zeros( - batch_size, - self.vision_num_semantic_tokens, + batch_size, + self.vision_num_semantic_tokens, self.vision_states_dim, dtype=self.transformer.dtype, - device=device + device=device, ) # 7. Denoising loop @@ -838,7 +819,7 @@ def __call__( xm.mark_step() self._current_timestep = None - + # 8. decode the latents to video and postprocess if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py index 18bd590cfc83..1af717bc59af 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -13,22 +13,29 @@ # limitations under the License. 
import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union import re -import PIL -import numpy as np +from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL import torch -from transformers import Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel, ByT5Tokenizer, SiglipVisionModel, SiglipImageProcessor +from transformers import ( + ByT5Tokenizer, + Qwen2_5_VLTextModel, + Qwen2Tokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, +) +from ...guiders import ClassifierFreeGuidance from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring -from .image_processor import HunyuanVideo15ImageProcessor +from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +from .image_processor import HunyuanVideo15ImageProcessor from .pipeline_output import HunyuanVideo15PipelineOutput -from ...guiders import ClassifierFreeGuidance -from ...utils.torch_utils import randn_tensor if is_torch_xla_available(): @@ -53,9 +60,7 @@ >>> pipe.vae.enable_tiling() >>> pipe.to("cuda") - >>> image = load_image( - ... "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG" - ... ) + >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG") >>> output = pipe( ... prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", @@ -66,9 +71,9 @@ ``` """ + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input -def format_text_input(prompt: List[str], system_message: str - ) -> List[Dict[str, Any]]: +def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]: """ Apply text to template. @@ -81,13 +86,8 @@ def format_text_input(prompt: List[str], system_message: str """ template = [ - [ - { - 'role': 'system', - 'content': system_message}, - {'role': 'user', 'content': p if p else " "} - ] - for p in prompt] + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] return template @@ -96,14 +96,14 @@ def format_text_input(prompt: List[str], system_message: str def extract_glyph_texts(prompt: str) -> List[str]: """ Extract glyph texts from prompt using regex pattern. 
- + Args: prompt: Input prompt string - + Returns: List of extracted glyph texts """ - pattern = r'\"(.*?)\"|“(.*?)”' + pattern = r"\"(.*?)\"|“(.*?)”" matches = re.findall(pattern, prompt) result = [match[0] or match[1] for match in matches] result = list(dict.fromkeys(result)) if len(result) > 1 else result @@ -129,6 +129,7 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -253,9 +254,13 @@ def __init__( self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 - self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True) + self.video_processor = HunyuanVideo15ImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True + ) self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 - self.vision_states_dim = self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + self.vision_states_dim = ( + self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + ) self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 # fmt: off self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ @@ -270,7 +275,6 @@ def __init__( self.tokenizer_2_max_length = 256 self.vision_num_semantic_tokens = 729 - @staticmethod # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds def _get_mllm_prompt_embeds( @@ -290,8 +294,6 @@ def _get_mllm_prompt_embeds( # fmt: on crop_start: int = 108, ) -> Tuple[torch.Tensor, torch.Tensor]: - - prompt = [prompt] if isinstance(prompt, str) else prompt prompt = format_text_input(prompt, system_message) @@ -320,10 +322,8 @@ def _get_mllm_prompt_embeds( prompt_embeds = prompt_embeds[:, crop_start:] prompt_attention_mask = prompt_attention_mask[:, crop_start:] - return prompt_embeds, prompt_attention_mask - @staticmethod # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds def _get_byt5_prompt_embeds( @@ -333,7 +333,6 @@ def _get_byt5_prompt_embeds( device: torch.device, tokenizer_max_length: int = 256, ): - prompt = [prompt] if isinstance(prompt, str) else prompt glyph_texts = [extract_glyph_texts(p) for p in prompt] @@ -346,9 +345,7 @@ def _get_byt5_prompt_embeds( glyph_text_embeds = torch.zeros( (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype ) - glyph_text_embeds_mask = torch.zeros( - (1, tokenizer_max_length), device=device, dtype=torch.int64 - ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) else: txt_tokens = tokenizer( glyph_text, @@ -374,17 +371,15 @@ def _get_byt5_prompt_embeds( return prompt_embeds, prompt_embeds_mask - @staticmethod def _get_image_latents( - vae: AutoencoderKLHunyuanVideo15, + vae: AutoencoderKLHunyuanVideo15, image_processor: HunyuanVideo15ImageProcessor, image: PIL.Image.Image, height: int, width: int, device: torch.device, ) -> torch.Tensor: - vae_dtype = vae.dtype 
image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype) image_tensor = image_tensor.unsqueeze(2) @@ -392,7 +387,6 @@ def _get_image_latents( image_latents = image_latents * vae.config.scaling_factor return image_latents - @staticmethod def _get_image_embeds( image_encoder: SiglipVisionModel, @@ -400,11 +394,8 @@ def _get_image_embeds( image: PIL.Image.Image, device: torch.device, ) -> torch.Tensor: - image_encoder_dtype = next(image_encoder.parameters()).dtype - image = feature_extractor.preprocess( - images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True - ) + image = feature_extractor.preprocess(images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True) image = image.to(device=device, dtype=image_encoder_dtype) image_enc_hidden_states = image_encoder(**image).last_hidden_state @@ -417,7 +408,6 @@ def encode_image( device: torch.device, dtype: torch.dtype, ) -> torch.Tensor: - image_embeds = self._get_image_embeds( image_encoder=self.image_encoder, feature_extractor=self.feature_extractor, @@ -427,7 +417,7 @@ def encode_image( image_embeds = image_embeds.repeat(batch_size, 1, 1) image_embeds = image_embeds.to(device=device, dtype=dtype) return image_embeds - + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt def encode_prompt( self, @@ -524,7 +514,7 @@ def check_inputs( prompt_embeds_mask_2=None, negative_prompt_embeds_2=None, negative_prompt_embeds_mask_2=None, - ): + ): if not isinstance(image, PIL.Image.Image): raise ValueError(f"`image` has to be of type `PIL.Image.Image` but is {type(image)}") @@ -601,23 +591,22 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents - def prepare_cond_latents_and_mask( - self, - latents: torch.Tensor, + self, + latents: torch.Tensor, image: PIL.Image.Image, batch_size: int, height: int, width: int, - dtype: torch.dtype, + dtype: torch.dtype, device: torch.device, ): """ Prepare conditional latents and mask for t2v generation. - + Args: latents: Main latents tensor (B, C, F, H, W) - + Returns: tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v """ @@ -632,20 +621,15 @@ def prepare_cond_latents_and_mask( width=width, device=device, ) - + latent_condition = image_latents.repeat(batch_size, 1, frames, 1, 1) - latent_condition[:,:,1:, :, :] = 0 + latent_condition[:, :, 1:, :, :] = 0 latent_condition = latent_condition.to(device=device, dtype=dtype) - - latent_mask = torch.zeros( - batch, 1, frames, height, width, - dtype=dtype, - device=device - ) - latent_mask[:,:, 0, :, :] = 1.0 - - return latent_condition, latent_mask + latent_mask = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device) + latent_mask[:, :, 0, :, :] = 1.0 + + return latent_condition, latent_mask @property def num_timesteps(self): @@ -750,8 +734,8 @@ def __call__( Returns: [`~HunyuanVideo15PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is returned - where the first element is a list with the generated videos. + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated videos. """ # 1. Check inputs. 
Raise error if not correct @@ -769,8 +753,9 @@ def __call__( negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, ) - - height, width = self.video_processor.calculate_default_height_width(height=image.size[1], width=image.size[0], target_size=self.target_size) + height, width = self.video_processor.calculate_default_height_width( + height=image.size[1], width=image.size[0], target_size=self.target_size + ) image = self.video_processor.resize(image, height=height, width=width, resize_mode="crop") self._attention_kwargs = attention_kwargs @@ -808,8 +793,13 @@ def __call__( prompt_embeds_mask_2=prompt_embeds_mask_2, ) - if self.guider._enabled and self.guider.num_conditions >1 : - negative_prompt_embeds, negative_prompt_embeds_mask, negative_prompt_embeds_2, negative_prompt_embeds_mask_2 = self.encode_prompt( + if self.guider._enabled and self.guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( prompt=negative_prompt, device=device, dtype=self.transformer.dtype, @@ -837,15 +827,15 @@ def __call__( generator=generator, latents=latents, ) - + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask( - latents =latents, + latents=latents, image=image, batch_size=batch_size * num_videos_per_prompt, height=height, width=width, - dtype=self.transformer.dtype, - device=device + dtype=self.transformer.dtype, + device=device, ) # 7. Denoising loop @@ -944,7 +934,6 @@ def __call__( self._current_timestep = None - if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py index 3adb54e1fbed..441164db5a09 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py @@ -1,8 +1,5 @@ from dataclasses import dataclass -from typing import List, Union -import numpy as np -import PIL.Image import torch from diffusers.utils import BaseOutput From 8aa458ed4615c40535622be32a36eaa4cf9d279d Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 30 Nov 2025 20:16:59 +0000 Subject: [PATCH 20/34] copies --- .../pipeline_hunyuan_video1_5_image2video.py | 12 ++++---- src/diffusers/utils/dummy_pt_objects.py | 30 +++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 30 +++++++++++++++++++ 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py index 1af717bc59af..9e9f20c79eba 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -92,7 +92,7 @@ def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, return template -# Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.extract_glyph_text +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts def extract_glyph_texts(prompt: str) -> List[str]: """ Extract glyph texts from prompt using regex pattern. 
@@ -281,7 +281,7 @@ def _get_mllm_prompt_embeds( text_encoder: Qwen2_5_VLTextModel, tokenizer: Qwen2Tokenizer, prompt: Union[str, List[str]], - device: Optional[torch.device] = None, + device: torch.device, tokenizer_max_length: int = 1000, num_hidden_layers_to_skip: int = 2, # fmt: off @@ -494,10 +494,10 @@ def encode_prompt( prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) - prompt_embeds = prompt_embeds.to(device=device, dtype=dtype) - prompt_embeds_mask = prompt_embeds_mask.to(device=device, dtype=dtype) - prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dtype) - prompt_embeds_mask_2 = prompt_embeds_mask_2.to(device=device, dtype=dtype) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 73854b38190e..fe9a4b30f0c1 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -468,6 +468,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLHunyuanVideo15(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -993,6 +1008,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HunyuanVideo15Transformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e6cf26a12544..65306b839015 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1142,6 +1142,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanVideo15ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HunyuanVideo15Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + 
requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanVideoFramepackPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 50abf504a1c4b8912113151dad2fab3e325626a2 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 30 Nov 2025 20:38:31 +0000 Subject: [PATCH 21/34] add docs --- .../models/autoencoder_kl_hunyuan_video15.md | 36 +++++++++ .../models/hunyuan_video15_transformer_3d.md | 30 ++++++++ .../en/api/pipelines/hunyuan_video15.md | 77 +++++++++++++++++++ 3 files changed, 143 insertions(+) create mode 100644 docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md create mode 100644 docs/source/en/api/models/hunyuan_video15_transformer_3d.md create mode 100644 docs/source/en/api/pipelines/hunyuan_video15.md diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md new file mode 100644 index 000000000000..e82fe31380a5 --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md @@ -0,0 +1,36 @@ + + +# AutoencoderKLHunyuanVideo15 + +The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5) by Tencent. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLHunyuanVideo15 + +vae = AutoencoderKLHunyuanVideo15.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="vae", torch_dtype=torch.float32) + +# make sure to enable tiling to avoid OOM +vae.enable_tiling() +``` + +## AutoencoderKLHunyuanVideo15 + +[[autodoc]] AutoencoderKLHunyuanVideo15 + - decode + - encode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/hunyuan_video15_transformer_3d.md b/docs/source/en/api/models/hunyuan_video15_transformer_3d.md new file mode 100644 index 000000000000..5ad4c6f4643f --- /dev/null +++ b/docs/source/en/api/models/hunyuan_video15_transformer_3d.md @@ -0,0 +1,30 @@ + + +# HunyuanVideo15Transformer3DModel + +A Diffusion Transformer model for 3D video-like data used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5). + +The model can be loaded with the following code snippet. + +```python +from diffusers import HunyuanVideo15Transformer3DModel + +transformer = HunyuanVideo15Transformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v" subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## HunyuanVideo15Transformer3DModel + +[[autodoc]] HunyuanVideo15Transformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md new file mode 100644 index 000000000000..2d74088355c1 --- /dev/null +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -0,0 +1,77 @@ + + +
+<div class="flex flex-wrap space-x-1">
+  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+</div>
+ +# HunyuanVideo-1.5 + +HunyuanVideo-1.5 is a lightweight yet powerful video generation model that achieves state-of-the-art visual quality and motion coherence with only 8.3 billion parameters, enabling efficient inference on consumer-grade GPUs. This achievement is built upon several key components, including meticulous data curation, an advanced DiT architecture with selective and sliding tile attention (SSTA), enhanced bilingual understanding through glyph-aware text encoding, progressive pre-training and post-training, and an efficient video super-resolution network. Leveraging these designs, we developed a unified framework capable of high-quality text-to-video and image-to-video generation across multiple durations and resolutions. Extensive experiments demonstrate that this compact and proficient model establishes a new state-of-the-art among open-source models. + +You can find all the original HunyuanVideo checkpoints under the [Tencent](https://huggingface.co/tencent) organization. + +> [!TIP] +> Click on the HunyuanVideo models in the right sidebar for more examples of video generation tasks. +> +> The examples below use a checkpoint from [hunyuanvideo-community](https://huggingface.co/hunyuanvideo-community) because the weights are stored in a layout compatible with Diffusers. + +The example below demonstrates how to generate a video optimized for memory or inference speed. + + + + +Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. + + +```py +import torch +from diffusers import AutoModel, HunyuanVideo15Pipeline +from diffusers.utils import export_to_video + + +pipeline = HunyuanVideo15Pipeline.from_pretrained( + "HunyuanVideo-1.5-Diffusers-480p_t2v", + torch_dtype=torch.bfloat16, +) + +# model-offloading and tiling +pipeline.enable_model_cpu_offload() +pipeline.vae.enable_tiling() + +prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys." 
+video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0] +export_to_video(video, "output.mp4", fps=15) +``` + + +## HunyuanVideo15Pipeline + +[[autodoc]] HunyuanVideo15Pipeline + - all + - __call__ + +## HunyuanVideo15ImageToVideoPipeline + +[[autodoc]] HunyuanVideo15ImageToVideoPipeline + - all + - __call__ + +## HunyuanVideo15PipelineOutput + +[[autodoc]] pipelines.hunyuan_video1_5.pipeline_output.HunyuanVideo15PipelineOutput From 7aeab3f847393e646d502d859b85c12e76cc9c87 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 30 Nov 2025 20:39:49 +0000 Subject: [PATCH 22/34] add to toctree --- docs/source/en/_toctree.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a71bc7d864a1..eaa440066319 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -361,6 +361,8 @@ title: HunyuanImageTransformer2DModel - local: api/models/hunyuan_video_transformer_3d title: HunyuanVideoTransformer3DModel + - local: api/models/hunyuan_video15_transformer_3d + title: HunyuanVideo15Transformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel - local: api/models/ltx_video_transformer3d @@ -433,6 +435,8 @@ title: AutoencoderKLHunyuanImageRefiner - local: api/models/autoencoder_kl_hunyuan_video title: AutoencoderKLHunyuanVideo + - local: api/models/autoencoder_kl_hunyuan_video15 + title: AutoencoderKLHunyuanVideo15 - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_magvit @@ -652,6 +656,8 @@ title: Framepack - local: api/pipelines/hunyuan_video title: HunyuanVideo + - local: api/pipelines/hunyuan_video15 + title: HunyuanVideo1.5 - local: api/pipelines/i2vgenxl title: I2VGen-XL - local: api/pipelines/kandinsky5_video From c3f45982b6f93294a6ce6a307409c27d7947d5d0 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Sun, 30 Nov 2025 20:40:31 +0000 Subject: [PATCH 23/34] up --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index eaa440066319..d2b4a0de915b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -359,10 +359,10 @@ title: HunyuanDiT2DModel - local: api/models/hunyuanimage_transformer_2d title: HunyuanImageTransformer2DModel - - local: api/models/hunyuan_video_transformer_3d - title: HunyuanVideoTransformer3DModel - local: api/models/hunyuan_video15_transformer_3d title: HunyuanVideo15Transformer3DModel + - local: api/models/hunyuan_video_transformer_3d + title: HunyuanVideoTransformer3DModel - local: api/models/latte_transformer3d title: LatteTransformer3DModel - local: api/models/ltx_video_transformer3d From 54f008e30b57306b6e22477769f2635c21b12a7b Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 30 Nov 2025 13:17:31 -1000 Subject: [PATCH 24/34] Update docs/source/en/api/pipelines/hunyuan_video15.md --- docs/source/en/api/pipelines/hunyuan_video15.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md index 2d74088355c1..f7c7147f80ce 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -12,13 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -
-<div class="flex flex-wrap space-x-1">
-  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
-</div>
# HunyuanVideo-1.5 From 237d318e85136c36dfc9551f6681f636f79c7c30 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 30 Nov 2025 13:22:43 -1000 Subject: [PATCH 25/34] Apply suggestions from code review Co-authored-by: Sayak Paul --- .../models/autoencoders/autoencoder_kl_hunyuanvideo15.py | 5 +---- .../models/transformers/transformer_hunyuan_video15.py | 7 ++----- .../pipelines/hunyuan_video1_5/image_processor.py | 2 +- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py index 7d6a636a240b..4b1beb74a3bc 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -215,8 +215,6 @@ class HunyuanVideo15Downsample(nn.Module): def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True): super().__init__() factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2 - assert out_channels % factor == 0 - # self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3) self.add_temporal_downsample = add_temporal_downsample @@ -531,7 +529,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.mid_block(hidden_states) - # short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2) batch_size, _, frame, height, width = hidden_states.shape short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2) @@ -546,7 +543,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class HunyuanVideo15Decoder3D(nn.Module): r""" - Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner. + Causal decoder for 3D video-like data used for HunyuanImage-1.5 Refiner. """ def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index 8f191e75009a..b870b15dad96 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -184,10 +184,7 @@ class HunyuanVideo15TimeEmbedding(nn.Module): The dimension of the output embedding. """ - def __init__( - self, - embedding_dim: int, - ): + def __init__(self, embedding_dim: int): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) @@ -362,7 +359,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] axes_grids = [] - for i in range(3): + for i in range(len(rope_sizes)): # Note: The following line diverges from original behaviour. We create the grid on the device, whereas # original implementation creates it on CPU and then moves it to device. This results in numerical # differences in layerwise debugging outputs, but visually it is the same. 
diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py index b0c778235257..82817365b6a5 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -34,7 +34,7 @@ def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): return crop_size_list -# copied fromhttps://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38 +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38 def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): """ Get the closest ratio in the buckets. From d7f399d1b261021effacb7587aa455da7a1fefe5 Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Mon, 1 Dec 2025 00:33:00 +0000 Subject: [PATCH 26/34] add a notes on the doc about attention backend --- docs/source/en/api/pipelines/hunyuan_video15.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md index f7c7147f80ce..531741d92426 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -52,6 +52,18 @@ video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0] export_to_video(video, "output.mp4", fps=15) ``` +## Notes + +- HunyuanVideo1.5 use attention masks with avariable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently. + + - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3` + - **A100/A800/RTX 4090:** `flash` or `flash_varlen` + - **Other GPUs:** `sage` + +```py +pipe.transformer.set_attention_backend("flash_varlen") # or your preferred backend +``` + ## HunyuanVideo15Pipeline From bdfab30766bf0bddf3aa5dd3487e18bfc263b78d Mon Sep 17 00:00:00 2001 From: "yiyi@huggingface.co" Date: Mon, 1 Dec 2025 00:35:30 +0000 Subject: [PATCH 27/34] up --- docs/source/en/api/pipelines/hunyuan_video15.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md index 531741d92426..9a9bdcb352bd 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -60,6 +60,9 @@ export_to_video(video, "output.mp4", fps=15) - **A100/A800/RTX 4090:** `flash` or `flash_varlen` - **Other GPUs:** `sage` +Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend. + + ```py pipe.transformer.set_attention_backend("flash_varlen") # or your preferred backend ``` From 2c018f8be623994a853be5731b215a80e802d122 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 30 Nov 2025 14:43:51 -1000 Subject: [PATCH 28/34] Update docs/source/en/api/pipelines/hunyuan_video15.md --- docs/source/en/api/pipelines/hunyuan_video15.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md index 9a9bdcb352bd..033c8490a024 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -54,7 +54,7 @@ export_to_video(video, "output.mp4", fps=15) ## Notes -- HunyuanVideo1.5 use attention masks with avariable-length sequences. 
For best performance, we recommend using an attention backend that handles padding efficiently. +- HunyuanVideo1.5 use attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently. - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3` - **A100/A800/RTX 4090:** `flash` or `flash_varlen` From c7154707097f78e66a8cf7390950d856c8e4f5f2 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 1 Dec 2025 02:06:13 +0100 Subject: [PATCH 29/34] add a note on changing guidance_scale on doc --- .../en/api/pipelines/hunyuan_video15.md | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md index 033c8490a024..fcd75b2aa974 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -67,6 +67,41 @@ Refer to the [Attention backends](../../optimization/attention_backends) guide f pipe.transformer.set_attention_backend("flash_varlen") # or your preferred backend ``` +- [`HunyuanVideo15Pipeline`] use guider and does not take `guidance_scale` parameter at runtime. + +You can check the default guider configuration using `pipe.guider`: + +```py +>>> pipe.guider +ClassifierFreeGuidance { + "_class_name": "ClassifierFreeGuidance", + "_diffusers_version": "0.36.0.dev0", + "enabled": true, + "guidance_rescale": 0.0, + "guidance_scale": 6.0, + "start": 0.0, + "stop": 1.0, + "use_original_formulation": false +} + +State: + step: None + num_inference_steps: None + timestep: None + count_prepared: 0 + enabled: True + num_conditions: 2 +``` + +To update guider configuration, you can run `pipe.guider = pipe.guider.new(...)` + +```py +pipe.guider = pipe.guider.new(guidance_scale=5.0) +``` + +Read more on Guider [here](../../modular_diffusers/guider). + + ## HunyuanVideo15Pipeline From 0dae8f956de3c2d3c63a904f4f90342bc6e49769 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 30 Nov 2025 15:10:16 -1000 Subject: [PATCH 30/34] Apply suggestions from code review --- docs/source/en/api/pipelines/hunyuan_video15.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md index fcd75b2aa974..b99114d792e6 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -57,8 +57,8 @@ export_to_video(video, "output.mp4", fps=15) - HunyuanVideo1.5 use attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently. - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3` - - **A100/A800/RTX 4090:** `flash` or `flash_varlen` - - **Other GPUs:** `sage` + - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen` + - **Other GPUs:** `sage_hub` Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend. 
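
A minimal sketch of applying the per-GPU recommendation above, assuming a pipeline `pipe` has already been loaded; the compute-capability check is only an illustration and is not part of the pipeline API:

```py
import torch

# Hopper GPUs (H100/H800) report compute capability 9.x; prefer the
# FlashAttention-3 hub kernel there, otherwise fall back to flash_hub.
major, _ = torch.cuda.get_device_capability()
backend = "_flash_3_hub" if major >= 9 else "flash_hub"
pipe.transformer.set_attention_backend(backend)
```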
From 5989014cfe7d5010eb39d99f4d1597c3cda27ae9 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 30 Nov 2025 15:21:17 -1000 Subject: [PATCH 31/34] Update docs/source/en/api/pipelines/hunyuan_video15.md --- docs/source/en/api/pipelines/hunyuan_video15.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md index b99114d792e6..4e35232211a8 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -64,7 +64,7 @@ Refer to the [Attention backends](../../optimization/attention_backends) guide f ```py -pipe.transformer.set_attention_backend("flash_varlen") # or your preferred backend +pipe.transformer.set_attention_backend("flash_hub") # or your preferred backend ``` - [`HunyuanVideo15Pipeline`] use guider and does not take `guidance_scale` parameter at runtime. From 404d3fa9a5b8b82d55ca45d7177fdefebefcfd26 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 30 Nov 2025 15:24:18 -1000 Subject: [PATCH 32/34] Update docs/source/en/api/pipelines/hunyuan_video15.md --- docs/source/en/api/pipelines/hunyuan_video15.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md index 4e35232211a8..d86b9f37b25a 100644 --- a/docs/source/en/api/pipelines/hunyuan_video15.md +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -99,7 +99,7 @@ To update guider configuration, you can run `pipe.guider = pipe.guider.new(...)` pipe.guider = pipe.guider.new(guidance_scale=5.0) ``` -Read more on Guider [here](../../modular_diffusers/guider). +Read more on Guider [here](../../modular_diffusers/guiders). From 0869b22796d5ab8def4cb2447600e864a1913499 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 1 Dec 2025 11:30:54 +0530 Subject: [PATCH 33/34] tests for Hunyuan 1.5 (#12759) * start tests. * up * up * style. * up --- .../test_models_transformer_hunyuan_1_5.py | 100 +++++++++ tests/pipelines/hunyuan_video1_5/__init__.py | 1 + .../hunyuan_video1_5/test_hunyuan_1_5.py | 195 ++++++++++++++++++ 3 files changed, 296 insertions(+) create mode 100644 tests/models/transformers/test_models_transformer_hunyuan_1_5.py create mode 100644 tests/pipelines/hunyuan_video1_5/__init__.py create mode 100644 tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py diff --git a/tests/models/transformers/test_models_transformer_hunyuan_1_5.py b/tests/models/transformers/test_models_transformer_hunyuan_1_5.py new file mode 100644 index 000000000000..021fcdc9cfbf --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hunyuan_1_5.py @@ -0,0 +1,100 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import HunyuanVideo15Transformer3DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideo15Transformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + text_embed_dim = 16 + text_embed_2_dim = 8 + image_embed_dim = 12 + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 1 + height = 8 + width = 8 + sequence_length = 6 + sequence_length_2 = 4 + image_sequence_length = 3 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device) + encoder_hidden_states_2 = torch.randn( + (batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device + ) + encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device) + encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device) + # All zeros for inducing T2V path in the model. + image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "encoder_hidden_states_2": encoder_hidden_states_2, + "encoder_attention_mask_2": encoder_attention_mask_2, + "image_embeds": image_embeds, + } + + @property + def input_shape(self): + return (4, 1, 8, 8) + + @property + def output_shape(self): + return (4, 1, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 8, + "num_layers": 2, + "num_refiner_layers": 1, + "mlp_ratio": 2.0, + "patch_size": 1, + "patch_size_t": 1, + "text_embed_dim": self.text_embed_dim, + "text_embed_2_dim": self.text_embed_2_dim, + "image_embed_dim": self.image_embed_dim, + "rope_axes_dim": (2, 2, 4), + "target_size": 16, + "task_type": "t2v", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideo15Transformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/hunyuan_video1_5/__init__.py b/tests/pipelines/hunyuan_video1_5/__init__.py new file mode 100644 index 000000000000..8fb044d9cf83 --- /dev/null +++ b/tests/pipelines/hunyuan_video1_5/__init__.py @@ -0,0 +1 @@ +# Copyright 2025 The HuggingFace Team. diff --git a/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py b/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py new file mode 100644 index 000000000000..2d8cc8f257f6 --- /dev/null +++ b/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py @@ -0,0 +1,195 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import ByT5Tokenizer, Qwen2_5_VLTextConfig, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLHunyuanVideo15, + FlowMatchEulerDiscreteScheduler, + HunyuanVideo15Pipeline, + HunyuanVideo15Transformer3DModel, +) +from diffusers.guiders import ClassifierFreeGuidance + +from ...testing_utils import enable_full_determinism +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanVideo15Pipeline + params = frozenset( + [ + "prompt", + "negative_prompt", + "height", + "width", + "prompt_embeds", + "prompt_embeds_mask", + "negative_prompt_embeds", + "negative_prompt_embeds_mask", + "prompt_embeds_2", + "prompt_embeds_mask_2", + "negative_prompt_embeds_2", + "negative_prompt_embeds_mask_2", + ] + ) + batch_params = ["prompt", "negative_prompt"] + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict" + ] + ) + test_attention_slicing = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = False + supports_dduf = False + + def get_dummy_components(self, num_layers: int = 1): + torch.manual_seed(0) + transformer = HunyuanVideo15Transformer3DModel( + in_channels=9, + out_channels=4, + num_attention_heads=2, + attention_head_dim=8, + num_layers=num_layers, + num_refiner_layers=1, + mlp_ratio=2.0, + patch_size=1, + patch_size_t=1, + text_embed_dim=16, + text_embed_2_dim=32, + image_embed_dim=12, + rope_axes_dim=(2, 2, 4), + target_size=16, + task_type="t2v", + ) + + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo15( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(16, 16), + layers_per_block=1, + spatial_compression_ratio=4, + temporal_compression_ratio=2, + downsample_match_channel=False, + upsample_match_channel=False, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + torch.manual_seed(0) + qwen_config = Qwen2_5_VLTextConfig( + **{ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + } + ) + text_encoder = Qwen2_5_VLTextModel(qwen_config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer_2 = ByT5Tokenizer() + + guider = ClassifierFreeGuidance(guidance_scale=1.0) + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "text_encoder_2": text_encoder_2.eval(), + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "guider": guider, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if 
str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "monkey", + "generator": generator, + "num_inference_steps": 2, + "height": 16, + "width": 16, + "num_frames": 9, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + result = pipe(**inputs) + video = result.frames + + generated_video = video[0] + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + + # fmt: off + expected_slice = torch.tensor([0.4296, 0.5549, 0.3088, 0.9115, 0.5049, 0.7926, 0.5549, 0.8618, 0.5091, 0.5075, 0.7117, 0.5292, 0.7053, 0.4864, 0.5206, 0.3878]) + # fmt: on + + self.assertTrue( + torch.abs(generated_slice - expected_slice).max() < 1e-3, + f"output_slice: {generated_slice}, expected_slice: {expected_slice}", + ) + + @unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip("Needs to be revisited.") + def test_inference_batch_consistent(self): + super().test_inference_batch_consistent() + + @unittest.skip("Needs to be revisited.") + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical() + From 6bfb75a2fe37d4be178545a25280526662c6baa7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 1 Dec 2025 06:03:48 +0000 Subject: [PATCH 34/34] Apply style fixes --- tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py b/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py index 2d8cc8f257f6..993c7ef6e4bb 100644 --- a/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py +++ b/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py @@ -51,14 +51,7 @@ class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) batch_params = ["prompt", "negative_prompt"] - required_optional_params = frozenset( - [ - "num_inference_steps", - "generator", - "latents", - "return_dict" - ] - ) + required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"]) test_attention_slicing = False test_xformers_attention = False test_layerwise_casting = True @@ -192,4 +185,3 @@ def test_inference_batch_consistent(self): @unittest.skip("Needs to be revisited.") def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical() -
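
For quick reference, a minimal end-to-end sketch that ties together the usage example, the attention-backend note, and the guider note from the docs added in this series. The checkpoint id is a placeholder (substitute the released HunyuanVideo-1.5 weights), and the dtype/device choices are illustrative:

```py
import torch
from diffusers import HunyuanVideo15Pipeline
from diffusers.utils import export_to_video

# Placeholder repo id -- replace with the actual released HunyuanVideo-1.5 checkpoint.
model_id = "<org>/HunyuanVideo-1.5-480p-t2v-diffusers"

pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Attention backend: see the Notes section above for per-GPU recommendations.
pipe.transformer.set_attention_backend("flash_hub")

# Guidance is controlled through the guider, not a guidance_scale call argument.
pipe.guider = pipe.guider.new(guidance_scale=5.0)

prompt = "A cat walks on the grass, realistic style."
video = pipe(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0]
export_to_video(video, "output.mp4", fps=15)
```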