diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a71bc7d864a1..d2b4a0de915b 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -359,6 +359,8 @@ title: HunyuanDiT2DModel - local: api/models/hunyuanimage_transformer_2d title: HunyuanImageTransformer2DModel + - local: api/models/hunyuan_video15_transformer_3d + title: HunyuanVideo15Transformer3DModel - local: api/models/hunyuan_video_transformer_3d title: HunyuanVideoTransformer3DModel - local: api/models/latte_transformer3d @@ -433,6 +435,8 @@ title: AutoencoderKLHunyuanImageRefiner - local: api/models/autoencoder_kl_hunyuan_video title: AutoencoderKLHunyuanVideo + - local: api/models/autoencoder_kl_hunyuan_video15 + title: AutoencoderKLHunyuanVideo15 - local: api/models/autoencoderkl_ltx_video title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_magvit @@ -652,6 +656,8 @@ title: Framepack - local: api/pipelines/hunyuan_video title: HunyuanVideo + - local: api/pipelines/hunyuan_video15 + title: HunyuanVideo1.5 - local: api/pipelines/i2vgenxl title: I2VGen-XL - local: api/pipelines/kandinsky5_video diff --git a/docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md b/docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md new file mode 100644 index 000000000000..e82fe31380a5 --- /dev/null +++ b/docs/source/en/api/models/autoencoder_kl_hunyuan_video15.md @@ -0,0 +1,36 @@ + + +# AutoencoderKLHunyuanVideo15 + +The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5) by Tencent. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLHunyuanVideo15 + +vae = AutoencoderKLHunyuanVideo15.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", subfolder="vae", torch_dtype=torch.float32) + +# make sure to enable tiling to avoid OOM +vae.enable_tiling() +``` + +## AutoencoderKLHunyuanVideo15 + +[[autodoc]] AutoencoderKLHunyuanVideo15 + - decode + - encode + - all + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/hunyuan_video15_transformer_3d.md b/docs/source/en/api/models/hunyuan_video15_transformer_3d.md new file mode 100644 index 000000000000..5ad4c6f4643f --- /dev/null +++ b/docs/source/en/api/models/hunyuan_video15_transformer_3d.md @@ -0,0 +1,30 @@ + + +# HunyuanVideo15Transformer3DModel + +A Diffusion Transformer model for 3D video-like data used in [HunyuanVideo1.5](https://github.com/Tencent/HunyuanVideo1-1.5). + +The model can be loaded with the following code snippet. + +```python +from diffusers import HunyuanVideo15Transformer3DModel + +transformer = HunyuanVideo15Transformer3DModel.from_pretrained("hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v" subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## HunyuanVideo15Transformer3DModel + +[[autodoc]] HunyuanVideo15Transformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/hunyuan_video15.md b/docs/source/en/api/pipelines/hunyuan_video15.md new file mode 100644 index 000000000000..d86b9f37b25a --- /dev/null +++ b/docs/source/en/api/pipelines/hunyuan_video15.md @@ -0,0 +1,120 @@ + + + +# HunyuanVideo-1.5 + +HunyuanVideo-1.5 is a lightweight yet powerful video generation model that achieves state-of-the-art visual quality and motion coherence with only 8.3 billion parameters, enabling efficient inference on consumer-grade GPUs. This achievement is built upon several key components, including meticulous data curation, an advanced DiT architecture with selective and sliding tile attention (SSTA), enhanced bilingual understanding through glyph-aware text encoding, progressive pre-training and post-training, and an efficient video super-resolution network. Leveraging these designs, we developed a unified framework capable of high-quality text-to-video and image-to-video generation across multiple durations and resolutions. Extensive experiments demonstrate that this compact and proficient model establishes a new state-of-the-art among open-source models. + +You can find all the original HunyuanVideo checkpoints under the [Tencent](https://huggingface.co/tencent) organization. + +> [!TIP] +> Click on the HunyuanVideo models in the right sidebar for more examples of video generation tasks. +> +> The examples below use a checkpoint from [hunyuanvideo-community](https://huggingface.co/hunyuanvideo-community) because the weights are stored in a layout compatible with Diffusers. + +The example below demonstrates how to generate a video optimized for memory or inference speed. + + + + +Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. + + +```py +import torch +from diffusers import AutoModel, HunyuanVideo15Pipeline +from diffusers.utils import export_to_video + + +pipeline = HunyuanVideo15Pipeline.from_pretrained( + "HunyuanVideo-1.5-Diffusers-480p_t2v", + torch_dtype=torch.bfloat16, +) + +# model-offloading and tiling +pipeline.enable_model_cpu_offload() +pipeline.vae.enable_tiling() + +prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys." +video = pipeline(prompt=prompt, num_frames=61, num_inference_steps=30).frames[0] +export_to_video(video, "output.mp4", fps=15) +``` + +## Notes + +- HunyuanVideo1.5 use attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently. + + - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3` + - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen` + - **Other GPUs:** `sage_hub` + +Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend. + + +```py +pipe.transformer.set_attention_backend("flash_hub") # or your preferred backend +``` + +- [`HunyuanVideo15Pipeline`] use guider and does not take `guidance_scale` parameter at runtime. + +You can check the default guider configuration using `pipe.guider`: + +```py +>>> pipe.guider +ClassifierFreeGuidance { + "_class_name": "ClassifierFreeGuidance", + "_diffusers_version": "0.36.0.dev0", + "enabled": true, + "guidance_rescale": 0.0, + "guidance_scale": 6.0, + "start": 0.0, + "stop": 1.0, + "use_original_formulation": false +} + +State: + step: None + num_inference_steps: None + timestep: None + count_prepared: 0 + enabled: True + num_conditions: 2 +``` + +To update guider configuration, you can run `pipe.guider = pipe.guider.new(...)` + +```py +pipe.guider = pipe.guider.new(guidance_scale=5.0) +``` + +Read more on Guider [here](../../modular_diffusers/guiders). + + + +## HunyuanVideo15Pipeline + +[[autodoc]] HunyuanVideo15Pipeline + - all + - __call__ + +## HunyuanVideo15ImageToVideoPipeline + +[[autodoc]] HunyuanVideo15ImageToVideoPipeline + - all + - __call__ + +## HunyuanVideo15PipelineOutput + +[[autodoc]] pipelines.hunyuan_video1_5.pipeline_output.HunyuanVideo15PipelineOutput diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py new file mode 100644 index 000000000000..38226f684a6d --- /dev/null +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -0,0 +1,850 @@ +import argparse +import json +import os +import pathlib + +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download, snapshot_download +from safetensors.torch import load_file +from transformers import ( + AutoModel, + AutoTokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, +) + +from diffusers import ( + AutoencoderKLHunyuanVideo15, + ClassifierFreeGuidance, + FlowMatchEulerDiscreteScheduler, + HunyuanVideo15ImageToVideoPipeline, + HunyuanVideo15Pipeline, + HunyuanVideo15Transformer3DModel, +) + + +# to convert only transformer +""" +python scripts/convert_hunyuan_video1_5_to_diffusers.py \ + --original_state_dict_repo_id tencent/HunyuanVideo-1.5\ + --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers/transformer\ + --transformer_type 480p_t2v +""" + +# to convert full pipeline +""" +python scripts/convert_hunyuan_video1_5_to_diffusers.py \ + --original_state_dict_repo_id tencent/HunyuanVideo-1.5\ + --output_path /fsx/yiyi/HunyuanVideo-1.5-Diffusers \ + --save_pipeline \ + --byt5_path /fsx/yiyi/hy15/text_encoder/Glyph-SDXL-v2\ + --transformer_type 480p_t2v +""" + + +TRANSFORMER_CONFIGS = { + "480p_t2v": { + "target_size": 640, + "task_type": "i2v", + }, + "720p_t2v": { + "target_size": 960, + "task_type": "t2v", + }, + "720p_i2v": { + "target_size": 960, + "task_type": "i2v", + }, + "480p_t2v_distilled": { + "target_size": 640, + "task_type": "t2v", + }, + "480p_i2v_distilled": { + "target_size": 640, + "task_type": "i2v", + }, + "720p_i2v_distilled": { + "target_size": 960, + "task_type": "i2v", + }, +} + +SCHEDULER_CONFIGS = { + "480p_t2v": { + "shift": 5.0, + }, + "480p_i2v": { + "shift": 5.0, + }, + "720p_t2v": { + "shift": 9.0, + }, + "720p_i2v": { + "shift": 7.0, + }, + "480p_t2v_distilled": { + "shift": 5.0, + }, + "480p_i2v_distilled": { + "shift": 5.0, + }, + "720p_i2v_distilled": { + "shift": 7.0, + }, +} + +GUIDANCE_CONFIGS = { + "480p_t2v": { + "guidance_scale": 6.0, + }, + "480p_i2v": { + "guidance_scale": 6.0, + }, + "720p_t2v": { + "guidance_scale": 6.0, + }, + "720p_i2v": { + "guidance_scale": 6.0, + }, + "480p_t2v_distilled": { + "guidance_scale": 1.0, + }, + "480p_i2v_distilled": { + "guidance_scale": 1.0, + }, + "720p_i2v_distilled": { + "guidance_scale": 1.0, + }, +} + + +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_hyvideo15_transformer_to_diffusers(original_state_dict): + """ + Convert HunyuanVideo 1.5 original checkpoint to Diffusers format. + """ + converted_state_dict = {} + + # 1. time_embed.timestep_embedder <- time_in + converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop( + "time_in.mlp.0.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop( + "time_in.mlp.2.weight" + ) + converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias") + + # 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.0.weight") + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop( + "txt_in.t_embedder.mlp.0.bias" + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = ( + original_state_dict.pop("txt_in.t_embedder.mlp.2.weight") + ) + converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop( + "txt_in.t_embedder.mlp.2.bias" + ) + + # 3. context_embedder.time_text_embed.text_embedder <- txt_in.c_embedder + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop( + "txt_in.c_embedder.linear_1.weight" + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop( + "txt_in.c_embedder.linear_1.bias" + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop( + "txt_in.c_embedder.linear_2.weight" + ) + converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop( + "txt_in.c_embedder.linear_2.bias" + ) + + # 4. context_embedder.proj_in <- txt_in.input_embedder + converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight") + converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias") + + # 5. context_embedder.token_refiner <- txt_in.individual_token_refiner + num_refiner_blocks = 2 + for i in range(num_refiner_blocks): + block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}." + orig_prefix = f"txt_in.individual_token_refiner.blocks.{i}." + + # norm1 + converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(f"{orig_prefix}norm1.weight") + converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(f"{orig_prefix}norm1.bias") + + # Split self_attn_qkv into to_q, to_k, to_v + qkv_weight = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.weight") + qkv_bias = original_state_dict.pop(f"{orig_prefix}self_attn_qkv.bias") + q, k, v = torch.chunk(qkv_weight, 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0) + + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias + + # self_attn_proj -> attn.to_out.0 + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"{orig_prefix}self_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"{orig_prefix}self_attn_proj.bias" + ) + + # norm2 + converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(f"{orig_prefix}norm2.weight") + converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(f"{orig_prefix}norm2.bias") + + # mlp -> ff + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(f"{orig_prefix}mlp.fc2.bias") + + # adaLN_modulation -> norm_out + converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}adaLN_modulation.1.weight" + ) + converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}adaLN_modulation.1.bias" + ) + + # 6. context_embedder_2 <- byt5_in + converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight") + converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias") + converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight") + converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias") + converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight") + converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias") + converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight") + converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias") + + # 7. image_embedder <- vision_in + converted_state_dict["image_embedder.norm_in.weight"] = original_state_dict.pop("vision_in.proj.0.weight") + converted_state_dict["image_embedder.norm_in.bias"] = original_state_dict.pop("vision_in.proj.0.bias") + converted_state_dict["image_embedder.linear_1.weight"] = original_state_dict.pop("vision_in.proj.1.weight") + converted_state_dict["image_embedder.linear_1.bias"] = original_state_dict.pop("vision_in.proj.1.bias") + converted_state_dict["image_embedder.linear_2.weight"] = original_state_dict.pop("vision_in.proj.3.weight") + converted_state_dict["image_embedder.linear_2.bias"] = original_state_dict.pop("vision_in.proj.3.bias") + converted_state_dict["image_embedder.norm_out.weight"] = original_state_dict.pop("vision_in.proj.4.weight") + converted_state_dict["image_embedder.norm_out.bias"] = original_state_dict.pop("vision_in.proj.4.bias") + + # 8. x_embedder <- img_in + converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight") + converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias") + + # 9. cond_type_embed <- cond_type_embedding + converted_state_dict["cond_type_embed.weight"] = original_state_dict.pop("cond_type_embedding.weight") + + # 10. transformer_blocks <- double_blocks + num_layers = 54 + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + orig_prefix = f"double_blocks.{i}." + + # norm1 (img_mod) + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mod.linear.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mod.linear.bias" + ) + + # norm1_context (txt_mod) + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mod.linear.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mod.linear.bias" + ) + + # img attention (to_q, to_k, to_v) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_v.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_v.bias" + ) + + # img attention qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_q_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_k_norm.weight" + ) + + # img attention output projection + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"{orig_prefix}img_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"{orig_prefix}img_attn_proj.bias" + ) + + # txt attention (add_q_proj, add_k_proj, add_v_proj) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_q.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_q.bias" + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k.bias" + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_v.weight" + ) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_v.bias" + ) + + # txt attention qk norm + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_q_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_k_norm.weight" + ) + + # txt attention output projection + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_attn_proj.bias" + ) + + # norm2 and norm2_context (these don't have weights in the original, they're LayerNorm with elementwise_affine=False) + # So we skip them + + # img_mlp -> ff + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop( + f"{orig_prefix}img_mlp.fc2.bias" + ) + + # txt_mlp -> ff_context + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc1.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc1.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop( + f"{orig_prefix}txt_mlp.fc2.bias" + ) + + # 11. norm_out and proj_out <- final_layer + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.bias") + ) + converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") + + return converted_state_dict + + +def convert_hunyuan_video_15_vae_checkpoint_to_diffusers( + original_state_dict, block_out_channels=[128, 256, 512, 1024, 1024], layers_per_block=2 +): + converted = {} + + # 1. Encoder + # 1.1 conv_in + converted["encoder.conv_in.conv.weight"] = original_state_dict.pop("encoder.conv_in.conv.weight") + converted["encoder.conv_in.conv.bias"] = original_state_dict.pop("encoder.conv_in.conv.bias") + + # 1.2 Down blocks + for down_block_index in range(len(block_out_channels)): # 0 to 4 + # ResNet blocks + for resnet_block_index in range(layers_per_block): # 0 to 1 + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm1.gamma") + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = ( + original_state_dict.pop( + f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.weight" + ) + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.bias") + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm2.gamma") + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = ( + original_state_dict.pop( + f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.weight" + ) + ) + converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.bias") + ) + + # Downsample (if exists) + if f"encoder.down.{down_block_index}.downsample.conv.conv.weight" in original_state_dict: + converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.weight"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.weight") + ) + converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.bias"] = ( + original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.bias") + ) + + # 1.3 Mid block + converted["encoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm1.gamma") + converted["encoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop( + "encoder.mid.block_1.conv1.conv.weight" + ) + converted["encoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop( + "encoder.mid.block_1.conv1.conv.bias" + ) + converted["encoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm2.gamma") + converted["encoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop( + "encoder.mid.block_1.conv2.conv.weight" + ) + converted["encoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop( + "encoder.mid.block_1.conv2.conv.bias" + ) + + converted["encoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm1.gamma") + converted["encoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop( + "encoder.mid.block_2.conv1.conv.weight" + ) + converted["encoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop( + "encoder.mid.block_2.conv1.conv.bias" + ) + converted["encoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm2.gamma") + converted["encoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop( + "encoder.mid.block_2.conv2.conv.weight" + ) + converted["encoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop( + "encoder.mid.block_2.conv2.conv.bias" + ) + + # Attention block + converted["encoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("encoder.mid.attn_1.norm.gamma") + converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight") + converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias") + converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight") + converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias") + converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight") + converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias") + converted["encoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop( + "encoder.mid.attn_1.proj_out.weight" + ) + converted["encoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop( + "encoder.mid.attn_1.proj_out.bias" + ) + + # 1.4 Encoder output + converted["encoder.norm_out.gamma"] = original_state_dict.pop("encoder.norm_out.gamma") + converted["encoder.conv_out.conv.weight"] = original_state_dict.pop("encoder.conv_out.conv.weight") + converted["encoder.conv_out.conv.bias"] = original_state_dict.pop("encoder.conv_out.conv.bias") + + # 2. Decoder + # 2.1 conv_in + converted["decoder.conv_in.conv.weight"] = original_state_dict.pop("decoder.conv_in.conv.weight") + converted["decoder.conv_in.conv.bias"] = original_state_dict.pop("decoder.conv_in.conv.bias") + + # 2.2 Mid block + converted["decoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm1.gamma") + converted["decoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop( + "decoder.mid.block_1.conv1.conv.weight" + ) + converted["decoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop( + "decoder.mid.block_1.conv1.conv.bias" + ) + converted["decoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm2.gamma") + converted["decoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop( + "decoder.mid.block_1.conv2.conv.weight" + ) + converted["decoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop( + "decoder.mid.block_1.conv2.conv.bias" + ) + + converted["decoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm1.gamma") + converted["decoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop( + "decoder.mid.block_2.conv1.conv.weight" + ) + converted["decoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop( + "decoder.mid.block_2.conv1.conv.bias" + ) + converted["decoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm2.gamma") + converted["decoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop( + "decoder.mid.block_2.conv2.conv.weight" + ) + converted["decoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop( + "decoder.mid.block_2.conv2.conv.bias" + ) + + # Decoder attention block + converted["decoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("decoder.mid.attn_1.norm.gamma") + converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight") + converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias") + converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight") + converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias") + converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight") + converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias") + converted["decoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop( + "decoder.mid.attn_1.proj_out.weight" + ) + converted["decoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop( + "decoder.mid.attn_1.proj_out.bias" + ) + + # 2.3 Up blocks + for up_block_index in range(len(block_out_channels)): # 0 to 5 + # ResNet blocks + for resnet_block_index in range(layers_per_block + 1): # 0 to 2 (decoder has 3 resnets per level) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm1.gamma") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.weight") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.bias") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm2.gamma") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.weight") + ) + converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = ( + original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.bias") + ) + + # Upsample (if exists) + if f"decoder.up.{up_block_index}.upsample.conv.conv.weight" in original_state_dict: + converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.weight"] = original_state_dict.pop( + f"decoder.up.{up_block_index}.upsample.conv.conv.weight" + ) + converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.bias"] = original_state_dict.pop( + f"decoder.up.{up_block_index}.upsample.conv.conv.bias" + ) + + # 2.4 Decoder output + converted["decoder.norm_out.gamma"] = original_state_dict.pop("decoder.norm_out.gamma") + converted["decoder.conv_out.conv.weight"] = original_state_dict.pop("decoder.conv_out.conv.weight") + converted["decoder.conv_out.conv.bias"] = original_state_dict.pop("decoder.conv_out.conv.bias") + + return converted + + +def load_sharded_safetensors(dir: pathlib.Path): + file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) + state_dict = {} + for path in file_paths: + state_dict.update(load_file(path)) + return state_dict + + +def load_original_transformer_state_dict(args): + if args.original_state_dict_repo_id is not None: + model_dir = snapshot_download( + args.original_state_dict_repo_id, + repo_type="model", + allow_patterns="transformer/" + args.transformer_type + "/*", + ) + elif args.original_state_dict_folder is not None: + model_dir = pathlib.Path(args.original_state_dict_folder) + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`") + model_dir = pathlib.Path(model_dir) + model_dir = model_dir / "transformer" / args.transformer_type + return load_sharded_safetensors(model_dir) + + +def load_original_vae_state_dict(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download( + repo_id=args.original_state_dict_repo_id, filename="vae/diffusion_pytorch_model.safetensors" + ) + elif args.original_state_dict_folder is not None: + model_dir = pathlib.Path(args.original_state_dict_folder) + ckpt_path = model_dir / "vae/diffusion_pytorch_model.safetensors" + else: + raise ValueError("Please provide either `original_state_dict_repo_id` or `original_state_dict_folder`") + + original_state_dict = load_file(ckpt_path) + return original_state_dict + + +def convert_transformer(args): + original_state_dict = load_original_transformer_state_dict(args) + + config = TRANSFORMER_CONFIGS[args.transformer_type] + with init_empty_weights(): + transformer = HunyuanVideo15Transformer3DModel(**config) + state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict) + transformer.load_state_dict(state_dict, strict=True, assign=True) + + return transformer + + +def convert_vae(args): + original_state_dict = load_original_vae_state_dict(args) + with init_empty_weights(): + vae = AutoencoderKLHunyuanVideo15() + state_dict = convert_hunyuan_video_15_vae_checkpoint_to_diffusers(original_state_dict) + vae.load_state_dict(state_dict, strict=True, assign=True) + return vae + + +def load_mllm(): + print(" loading from Qwen/Qwen2.5-VL-7B-Instruct") + text_encoder = AutoModel.from_pretrained( + "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True + ) + if hasattr(text_encoder, "language_model"): + text_encoder = text_encoder.language_model + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", padding_side="right") + return text_encoder, tokenizer + + +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/hyvideo/models/text_encoders/byT5/__init__.py#L89 +def add_special_token( + tokenizer, + text_encoder, + add_color=True, + add_font=True, + multilingual=True, + color_ann_path="assets/color_idx.json", + font_ann_path="assets/multilingual_10-lang_idx.json", +): + """ + Add special tokens for color and font to tokenizer and text encoder. + + Args: + tokenizer: Huggingface tokenizer. + text_encoder: Huggingface T5 encoder. + add_color (bool): Whether to add color tokens. + add_font (bool): Whether to add font tokens. + color_ann_path (str): Path to color annotation JSON. + font_ann_path (str): Path to font annotation JSON. + multilingual (bool): Whether to use multilingual font tokens. + """ + with open(font_ann_path, "r") as f: + idx_font_dict = json.load(f) + with open(color_ann_path, "r") as f: + idx_color_dict = json.load(f) + + if multilingual: + font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict] + else: + font_token = [f"" for i in range(len(idx_font_dict))] + color_token = [f"" for i in range(len(idx_color_dict))] + additional_special_tokens = [] + if add_color: + additional_special_tokens += color_token + if add_font: + additional_special_tokens += font_token + + tokenizer.add_tokens(additional_special_tokens, special_tokens=True) + # Set mean_resizing=False to avoid PyTorch LAPACK dependency + text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False) + + +def load_byt5(args): + """ + Load ByT5 encoder with Glyph-SDXL-v2 weights and save in HuggingFace format. + """ + + # 1. Load base tokenizer and encoder + tokenizer = AutoTokenizer.from_pretrained("google/byt5-small") + + # Load as T5EncoderModel + encoder = T5EncoderModel.from_pretrained("google/byt5-small") + + byt5_checkpoint_path = os.path.join(args.byt5_path, "checkpoints/byt5_model.pt") + color_ann_path = os.path.join(args.byt5_path, "assets/color_idx.json") + font_ann_path = os.path.join(args.byt5_path, "assets/multilingual_10-lang_idx.json") + + # 2. Add special tokens + add_special_token( + tokenizer=tokenizer, + text_encoder=encoder, + add_color=True, + add_font=True, + color_ann_path=color_ann_path, + font_ann_path=font_ann_path, + multilingual=True, + ) + + # 3. Load Glyph-SDXL-v2 checkpoint + print(f"\n3. Loading Glyph-SDXL-v2 checkpoint: {byt5_checkpoint_path}") + checkpoint = torch.load(byt5_checkpoint_path, map_location="cpu") + + # Handle different checkpoint formats + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + + # add 'encoder.' prefix to the keys + # Remove 'module.text_tower.encoder.' prefix if present + cleaned_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("module.text_tower.encoder."): + new_key = "encoder." + key[len("module.text_tower.encoder.") :] + cleaned_state_dict[new_key] = value + else: + new_key = "encoder." + key + cleaned_state_dict[new_key] = value + + # 4. Load weights + missing_keys, unexpected_keys = encoder.load_state_dict(cleaned_state_dict, strict=False) + if unexpected_keys: + raise ValueError(f"Unexpected keys: {unexpected_keys}") + if "shared.weight" in missing_keys: + print(" Missing shared.weight as expected") + missing_keys.remove("shared.weight") + if missing_keys: + raise ValueError(f"Missing keys: {missing_keys}") + + return encoder, tokenizer + + +def load_siglip(): + image_encoder = SiglipVisionModel.from_pretrained( + "black-forest-labs/FLUX.1-Redux-dev", subfolder="image_encoder", torch_dtype=torch.bfloat16 + ) + feature_extractor = SiglipImageProcessor.from_pretrained( + "black-forest-labs/FLUX.1-Redux-dev", subfolder="feature_extractor" + ) + return image_encoder, feature_extractor + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--original_state_dict_repo_id", type=str, default=None, help="Path to original hub_id for the model" + ) + parser.add_argument( + "--original_state_dict_folder", type=str, default=None, help="Local folder name of the original state dict" + ) + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model(s) should be saved") + parser.add_argument("--transformer_type", type=str, default="480p_i2v", choices=list(TRANSFORMER_CONFIGS.keys())) + parser.add_argument( + "--byt5_path", + type=str, + default=None, + help=( + "path to the downloaded byt5 checkpoint & assets. " + "Note: They use Glyph-SDXL-v2 as byt5 encoder. You can download from modelscope like: " + "`modelscope download --model AI-ModelScope/Glyph-SDXL-v2 --local_dir ./ckpts/text_encoder/Glyph-SDXL-v2` " + "or manually download following the instructions on " + "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/910da2a829c484ea28982e8cff3bbc2cacdf1681/checkpoints-download.md. " + "The path should point to the Glyph-SDXL-v2 folder which should contain an `assets` folder and a `checkpoints` folder, " + "like: Glyph-SDXL-v2/assets/... and Glyph-SDXL-v2/checkpoints/byt5_model.pt" + ), + ) + parser.add_argument("--save_pipeline", action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + if args.save_pipeline and args.byt5_path is None: + raise ValueError("Please provide --byt5_path when saving pipeline") + + transformer = None + + transformer = convert_transformer(args) + if not args.save_pipeline: + transformer.save_pretrained(args.output_path, safe_serialization=True) + else: + task_type = transformer.config.task_type + + vae = convert_vae(args) + + text_encoder, tokenizer = load_mllm() + text_encoder_2, tokenizer_2 = load_byt5(args) + + flow_shift = SCHEDULER_CONFIGS[args.transformer_type]["shift"] + scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) + + guidance_scale = GUIDANCE_CONFIGS[args.transformer_type]["guidance_scale"] + guider = ClassifierFreeGuidance(guidance_scale=guidance_scale) + + if task_type == "i2v": + image_encoder, feature_extractor = load_siglip() + pipeline = HunyuanVideo15ImageToVideoPipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + guider=guider, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + elif task_type == "t2v": + pipeline = HunyuanVideo15Pipeline( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + guider=guider, + scheduler=scheduler, + ) + else: + raise ValueError(f"Task type {task_type} is not supported") + + pipeline.save_pretrained(args.output_path, safe_serialization=True) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f2d1840da222..02dd42e4a580 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -190,6 +190,7 @@ "AutoencoderKLHunyuanImage", "AutoencoderKLHunyuanImageRefiner", "AutoencoderKLHunyuanVideo", + "AutoencoderKLHunyuanVideo15", "AutoencoderKLLTXVideo", "AutoencoderKLMagvit", "AutoencoderKLMochi", @@ -225,6 +226,7 @@ "HunyuanDiT2DModel", "HunyuanDiT2DMultiControlNetModel", "HunyuanImageTransformer2DModel", + "HunyuanVideo15Transformer3DModel", "HunyuanVideoFramepackTransformer3DModel", "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", @@ -481,6 +483,8 @@ "HunyuanImagePipeline", "HunyuanImageRefinerPipeline", "HunyuanSkyreelsImageToVideoPipeline", + "HunyuanVideo15ImageToVideoPipeline", + "HunyuanVideo15Pipeline", "HunyuanVideoFramepackPipeline", "HunyuanVideoImageToVideoPipeline", "HunyuanVideoPipeline", @@ -909,6 +913,7 @@ AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, + AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -944,6 +949,7 @@ HunyuanDiT2DModel, HunyuanDiT2DMultiControlNetModel, HunyuanImageTransformer2DModel, + HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, I2VGenXLUNet, @@ -1170,6 +1176,8 @@ HunyuanImagePipeline, HunyuanImageRefinerPipeline, HunyuanSkyreelsImageToVideoPipeline, + HunyuanVideo15ImageToVideoPipeline, + HunyuanVideo15Pipeline, HunyuanVideoFramepackPipeline, HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 09b2b731b5c4..8b60b269324f 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -39,6 +39,7 @@ _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"] _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"] + _import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] @@ -96,6 +97,7 @@ _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] + _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] @@ -147,6 +149,7 @@ AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, AutoencoderKLHunyuanVideo, + AutoencoderKLHunyuanVideo15, AutoencoderKLLTXVideo, AutoencoderKLMagvit, AutoencoderKLMochi, @@ -199,6 +202,7 @@ HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, + HunyuanVideo15Transformer3DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, Kandinsky5Transformer3DModel, diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 0c247b76d039..3660e8d1d3ac 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -282,6 +282,7 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke backend = AttentionBackendName(backend) _check_attention_backend_requirements(backend) + _maybe_download_kernel_for_backend(backend) old_backend = _AttentionBackendRegistry._active_backend _AttentionBackendRegistry._active_backend = backend diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 470979ad33a7..56df27f93cd7 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -8,6 +8,7 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner +from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_magvit import AutoencoderKLMagvit from .autoencoder_kl_mochi import AutoencoderKLMochi diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py new file mode 100644 index 000000000000..4b1beb74a3bc --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanvideo15.py @@ -0,0 +1,967 @@ +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanVideo15CausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + return self.conv(hidden_states) + + +class HunyuanVideo15RMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class HunyuanVideo15AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = HunyuanVideo15RMS_norm(in_channels, images=False) + + self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1) + + @staticmethod + def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None): + """Prepare a causal attention mask for 3D videos. + + Args: + n_frame (int): Number of frames (temporal length). + n_hw (int): Product of height and width. + dtype: Desired mask dtype. + device: Device for the mask. + batch_size (int, optional): If set, expands for batch. + + Returns: + torch.Tensor: Causal attention mask. + """ + seq_len = n_frame * n_hw + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // n_hw + mask[i, : (i_frame + 1) * n_hw] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + def forward(self, x: torch.Tensor) -> torch.Tensor: + identity = x + + x = self.norm(x) + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + batch_size, channels, frames, height, width = query.shape + + query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous() + + attention_mask = self.prepare_causal_attention_mask( + frames, height * width, query.dtype, query.device, batch_size=batch_size + ) + + x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + # batch_size, 1, frames * height * width, channels + + x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3) + x = self.proj_out(x) + + return x + identity + + +class HunyuanVideo15Upsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2 + self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels * factor, kernel_size=3) + + self.add_temporal_upsample = add_temporal_upsample + self.repeats = factor * out_channels // in_channels + + @staticmethod + def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w) + + Args: + tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w) + r1: temporal upsampling factor + r2: height upsampling factor + r3: width upsampling factor + """ + b, packed_c, f, h, w = tensor.shape + factor = r1 * r2 * r3 + c = packed_c // factor + + tensor = tensor.view(b, r1, r2, r3, c, f, h, w) + tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3) + return tensor.reshape(b, c, f * r1, h * r2, w * r3) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_upsample else 1 + h = self.conv(x) + if self.add_temporal_upsample: + h_first = h[:, :, :1, :, :] + h_first = self._dcae_upsample_rearrange(h_first, r1=1, r2=2, r3=2) + h_first = h_first[:, : h_first.shape[1] // 2] + h_next = h[:, :, 1:, :, :] + h_next = self._dcae_upsample_rearrange(h_next, r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = self._dcae_upsample_rearrange(x_first, r1=1, r2=2, r3=2) + x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1) + + x_next = x[:, :, 1:, :, :] + x_next = self._dcae_upsample_rearrange(x_next, r1=r1, r2=2, r3=2) + x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = torch.cat([x_first, x_next], dim=2) + + else: + h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2) + return h + shortcut + + +class HunyuanVideo15Downsample(nn.Module): + def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True): + super().__init__() + factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2 + self.conv = HunyuanVideo15CausalConv3d(in_channels, out_channels // factor, kernel_size=3) + + self.add_temporal_downsample = add_temporal_downsample + self.group_size = factor * in_channels // out_channels + + @staticmethod + def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2): + """ + Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w) + + This packs spatial/temporal dimensions into channels (opposite of upsample) + """ + b, c, packed_f, packed_h, packed_w = tensor.shape + f, h, w = packed_f // r1, packed_h // r2, packed_w // r3 + + tensor = tensor.view(b, c, f, r1, h, r2, w, r3) + tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6) + return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w) + + def forward(self, x: torch.Tensor): + r1 = 2 if self.add_temporal_downsample else 1 + h = self.conv(x) + if self.add_temporal_downsample: + h_first = h[:, :, :1, :, :] + h_first = self._dcae_downsample_rearrange(h_first, r1=1, r2=2, r3=2) + h_first = torch.cat([h_first, h_first], dim=1) + h_next = h[:, :, 1:, :, :] + h_next = self._dcae_downsample_rearrange(h_next, r1=r1, r2=2, r3=2) + h = torch.cat([h_first, h_next], dim=2) + + # shortcut computation + x_first = x[:, :, :1, :, :] + x_first = self._dcae_downsample_rearrange(x_first, r1=1, r2=2, r3=2) + B, C, T, H, W = x_first.shape + x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2) + x_next = x[:, :, 1:, :, :] + x_next = self._dcae_downsample_rearrange(x_next, r1=r1, r2=2, r3=2) + B, C, T, H, W = x_next.shape + x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + shortcut = torch.cat([x_first, x_next], dim=2) + else: + h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2) + shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2) + B, C, T, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2) + + return h + shortcut + + +class HunyuanVideo15ResnetBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + non_linearity: str = "swish", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = HunyuanVideo15RMS_norm(in_channels, images=False) + self.conv1 = HunyuanVideo15CausalConv3d(in_channels, out_channels, kernel_size=3) + + self.norm2 = HunyuanVideo15RMS_norm(out_channels, images=False) + self.conv2 = HunyuanVideo15CausalConv3d(out_channels, out_channels, kernel_size=3) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return hidden_states + residual + + +class HunyuanVideo15MidBlock(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int = 1, + add_attention: bool = True, + ) -> None: + super().__init__() + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ] + attentions = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append(HunyuanVideo15AttnBlock(in_channels)) + else: + attentions.append(None) + + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=in_channels, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideo15DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + downsample_out_channels: Optional[int] = None, + add_temporal_downsample: int = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=in_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if downsample_out_channels is not None: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideo15Downsample( + out_channels, + out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideo15UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + upsample_out_channels: Optional[int] = None, + add_temporal_upsample: bool = True, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideo15ResnetBlock( + in_channels=input_channels, + out_channels=out_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if upsample_out_channels is not None: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideo15Upsample( + out_channels, + out_channels=upsample_out_channels, + add_temporal_upsample=add_temporal_upsample, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideo15Encoder3D(nn.Module): + r""" + 3D vae encoder for HunyuanImageRefiner. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 64, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 16, + downsample_match_channel: bool = True, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.group_size = block_out_channels[-1] // self.out_channels + + self.conv_in = HunyuanVideo15CausalConv3d(in_channels, block_out_channels[0], kernel_size=3) + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + input_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + add_spatial_downsample = i < np.log2(spatial_compression_ratio) + output_channel = block_out_channels[i] + if not add_spatial_downsample: + down_block = HunyuanVideo15DownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + downsample_out_channels=None, + add_temporal_downsample=False, + ) + input_channel = output_channel + else: + add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio) + downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel + down_block = HunyuanVideo15DownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + downsample_out_channels=downsample_out_channels, + add_temporal_downsample=add_temporal_downsample, + ) + input_channel = downsample_out_channels + + self.down_blocks.append(down_block) + + self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[-1]) + + self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + batch_size, _, frame, height, width = hidden_states.shape + short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2) + + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + hidden_states += short_cut + + return hidden_states + + +class HunyuanVideo15Decoder3D(nn.Module): + r""" + Causal decoder for 3D video-like data used for HunyuanImage-1.5 Refiner. + """ + + def __init__( + self, + in_channels: int = 32, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128), + layers_per_block: int = 2, + spatial_compression_ratio: int = 16, + temporal_compression_ratio: int = 4, + upsample_match_channel: bool = True, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.in_channels = in_channels + self.out_channels = out_channels + self.repeat = block_out_channels[0] // self.in_channels + + self.conv_in = HunyuanVideo15CausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3) + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = HunyuanVideo15MidBlock(in_channels=block_out_channels[0]) + + # up + input_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + output_channel = block_out_channels[i] + + add_spatial_upsample = i < np.log2(spatial_compression_ratio) + add_temporal_upsample = i < np.log2(temporal_compression_ratio) + if add_spatial_upsample or add_temporal_upsample: + upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel + up_block = HunyuanVideo15UpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + upsample_out_channels=upsample_out_channels, + add_temporal_upsample=add_temporal_upsample, + ) + input_channel = upsample_out_channels + else: + up_block = HunyuanVideo15UpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + upsample_out_channels=None, + add_temporal_upsample=False, + ) + input_channel = output_channel + + self.up_blocks.append(up_block) + + # out + self.norm_out = HunyuanVideo15RMS_norm(block_out_channels[-1], images=False) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideo15CausalConv3d(block_out_channels[-1], out_channels, kernel_size=3) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + # post-process + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for + HunyuanVideo-1.5. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 32, + block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024), + layers_per_block: int = 2, + spatial_compression_ratio: int = 16, + temporal_compression_ratio: int = 4, + downsample_match_channel: bool = True, + upsample_match_channel: bool = True, + scaling_factor: float = 1.03682, + ) -> None: + super().__init__() + + self.encoder = HunyuanVideo15Encoder3D( + in_channels=in_channels, + out_channels=latent_channels * 2, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + downsample_match_channel=downsample_match_channel, + ) + + self.decoder = HunyuanVideo15Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + layers_per_block=layers_per_block, + temporal_compression_ratio=temporal_compression_ratio, + spatial_compression_ratio=spatial_compression_ratio, + upsample_match_channel=upsample_match_channel, + ) + + self.spatial_compression_ratio = spatial_compression_ratio + self.temporal_compression_ratio = temporal_compression_ratio + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal tile height and width in latent space + self.tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio + self.tile_overlap_factor = 0.25 + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_latent_min_height: Optional[int] = None, + tile_latent_min_width: Optional[int] = None, + tile_overlap_factor: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_latent_min_height (`int`, *optional*): + The minimum height required for a latent to be separated into tiles across the height dimension. + tile_latent_min_width (`int`, *optional*): + The minimum width required for a latent to be separated into tiles across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_latent_min_height = tile_latent_min_height or self.tile_latent_min_height + self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width + self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + x = self.encoder(x) + return x + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + _, _, _, height, width = z.shape + + if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + return self.tiled_decode(z) + + dec = self.decoder(z) + + return dec + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, _, height, width = x.shape + + overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192 + blend_height = int(self.tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2 + blend_width = int(self.tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2 + row_limit_height = self.tile_latent_min_height - blend_height # 8 - 2 = 6 + row_limit_width = self.tile_latent_min_width - blend_width # 8 - 2 = 6 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + moments = torch.cat(result_rows, dim=-2) + + return moments + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + _, _, _, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6 + blend_height = int(self.tile_sample_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64 + blend_width = int(self.tile_sample_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64 + row_limit_height = self.tile_sample_min_height - blend_height # 256 - 64 = 192 + row_limit_width = self.tile_sample_min_width - blend_width # 256 - 64 = 192 + + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=-2) + + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d0794dc321a8..67a555b676d5 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -29,6 +29,7 @@ from .transformer_flux2 import Flux2Transformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel + from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_hunyuanimage import HunyuanImageTransformer2DModel from .transformer_kandinsky import Kandinsky5Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py new file mode 100644 index 000000000000..b870b15dad96 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -0,0 +1,836 @@ +# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.loaders import FromOriginalModelMixin + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..attention_processor import Attention, AttentionProcessor +from ..cache_utils import CacheMixin +from ..embeddings import ( + CombinedTimestepTextProjEmbeddings, + TimestepEmbedding, + Timesteps, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class HunyuanVideo15AttnProcessor2_0: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HunyuanVideo15AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # 1. QKV projections + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + # 2. QK normalization + query = attn.norm_q(query) + key = attn.norm_k(key) + + # 3. Rotational positional embeddings applied to latent stream + if image_rotary_emb is not None: + from ..embeddings import apply_rotary_emb + + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + # 4. Encoder condition QKV projection and normalization + if encoder_hidden_states is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + encoder_query = encoder_query.unflatten(2, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(2, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(2, (attn.heads, -1)) + + if attn.norm_added_q is not None: + encoder_query = attn.norm_added_q(encoder_query) + if attn.norm_added_k is not None: + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([query, encoder_query], dim=1) + key = torch.cat([key, encoder_key], dim=1) + value = torch.cat([value, encoder_value], dim=1) + + batch_size, seq_len, heads, dim = query.shape + attention_mask = F.pad(attention_mask, (seq_len - attention_mask.shape[1], 0), value=True) + attention_mask = attention_mask.bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + attention_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + + # 5. Attention + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + # 6. Output projection + if encoder_hidden_states is not None: + hidden_states, encoder_hidden_states = ( + hidden_states[:, : -encoder_hidden_states.shape[1]], + hidden_states[:, -encoder_hidden_states.shape[1] :], + ) + + if getattr(attn, "to_out", None) is not None: + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + if getattr(attn, "to_add_out", None) is not None: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + + +class HunyuanVideo15PatchEmbed(nn.Module): + def __init__( + self, + patch_size: Union[int, Tuple[int, int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + super().__init__() + + patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.proj(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC + return hidden_states + + +class HunyuanVideo15AdaNorm(nn.Module): + def __init__(self, in_features: int, out_features: Optional[int] = None) -> None: + super().__init__() + + out_features = out_features or 2 * in_features + self.linear = nn.Linear(in_features, out_features) + self.nonlinearity = nn.SiLU() + + def forward( + self, temb: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + temb = self.linear(self.nonlinearity(temb)) + gate_msa, gate_mlp = temb.chunk(2, dim=1) + gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1) + return gate_msa, gate_mlp + + +class HunyuanVideo15TimeEmbedding(nn.Module): + r""" + Time embedding for HunyuanVideo 1.5. + + Supports standard timestep embedding and optional reference timestep embedding for MeanFlow-based super-resolution + models. + + Args: + embedding_dim (`int`): + The dimension of the output embedding. + """ + + def __init__(self, embedding_dim: int): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + + def forward( + self, + timestep: torch.Tensor, + ) -> torch.Tensor: + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype)) + + return timesteps_emb + + +class HunyuanVideo15IndividualTokenRefinerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_width_ratio: str = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate) + + self.norm_out = HunyuanVideo15AdaNorm(hidden_size, 2 * hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + ) + + gate_msa, gate_mlp = self.norm_out(temb) + hidden_states = hidden_states + attn_output * gate_msa + + ff_output = self.ff(self.norm2(hidden_states)) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +class HunyuanVideo15IndividualTokenRefiner(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + self.refiner_blocks = nn.ModuleList( + [ + HunyuanVideo15IndividualTokenRefinerBlock( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> None: + self_attn_mask = None + if attention_mask is not None: + batch_size = attention_mask.shape[0] + seq_len = attention_mask.shape[1] + attention_mask = attention_mask.to(hidden_states.device).bool() + self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1) + self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) + self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() + + for block in self.refiner_blocks: + hidden_states = block(hidden_states, temb, self_attn_mask) + + return hidden_states + + +class HunyuanVideo15TokenRefiner(nn.Module): + def __init__( + self, + in_channels: int, + num_attention_heads: int, + attention_head_dim: int, + num_layers: int, + mlp_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + attention_bias: bool = True, + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.time_text_embed = CombinedTimestepTextProjEmbeddings( + embedding_dim=hidden_size, pooled_projection_dim=in_channels + ) + self.proj_in = nn.Linear(in_channels, hidden_size, bias=True) + self.token_refiner = HunyuanVideo15IndividualTokenRefiner( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_layers=num_layers, + mlp_width_ratio=mlp_ratio, + mlp_drop_rate=mlp_drop_rate, + attention_bias=attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + if attention_mask is None: + pooled_projections = hidden_states.mean(dim=1) + else: + original_dtype = hidden_states.dtype + mask_float = attention_mask.float().unsqueeze(-1) + pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) + pooled_projections = pooled_projections.to(original_dtype) + + temb = self.time_text_embed(timestep, pooled_projections) + hidden_states = self.proj_in(hidden_states) + hidden_states = self.token_refiner(hidden_states, temb, attention_mask) + + return hidden_states + + +class HunyuanVideo15RotaryPosEmbed(nn.Module): + def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.rope_dim = rope_dim + self.theta = theta + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size] + + axes_grids = [] + for i in range(len(rope_sizes)): + # Note: The following line diverges from original behaviour. We create the grid on the device, whereas + # original implementation creates it on CPU and then moves it to device. This results in numerical + # differences in layerwise debugging outputs, but visually it is the same. + grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32) + axes_grids.append(grid) + grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T] + grid = torch.stack(grid, dim=0) # [3, W, H, T] + + freqs = [] + for i in range(3): + freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True) + freqs.append(freq) + + freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2) + freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2) + return freqs_cos, freqs_sin + + +class HunyuanVideo15ByT5TextProjection(nn.Module): + def __init__(self, in_features: int, hidden_size: int, out_features: int): + super().__init__() + self.norm = nn.LayerNorm(in_features) + self.linear_1 = nn.Linear(in_features, hidden_size) + self.linear_2 = nn.Linear(hidden_size, hidden_size) + self.linear_3 = nn.Linear(hidden_size, out_features) + self.act_fn = nn.GELU() + + def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm(encoder_hidden_states) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_3(hidden_states) + return hidden_states + + +class HunyuanVideo15ImageProjection(nn.Module): + def __init__(self, in_channels: int, hidden_size: int): + super().__init__() + self.norm_in = nn.LayerNorm(in_channels) + self.linear_1 = nn.Linear(in_channels, in_channels) + self.act_fn = nn.GELU() + self.linear_2 = nn.Linear(in_channels, hidden_size) + self.norm_out = nn.LayerNorm(hidden_size) + + def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: + hidden_states = self.norm_in(image_embeds) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.norm_out(hidden_states) + return hidden_states + + +class HunyuanVideo15TransformerBlock(nn.Module): + def __init__( + self, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float, + qk_norm: str = "rms_norm", + ) -> None: + super().__init__() + + hidden_size = num_attention_heads * attention_head_dim + + self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm") + + self.attn = Attention( + query_dim=hidden_size, + cross_attention_dim=None, + added_kv_proj_dim=hidden_size, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=hidden_size, + context_pre_only=False, + bias=True, + processor=HunyuanVideo15AttnProcessor2_0(), + qk_norm=qk_norm, + eps=1e-6, + ) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + *args, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Input normalization + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + + # 2. Joint attention + attn_output, context_attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=freqs_cis, + ) + + # 3. Modulation and residual connection + hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1) + + norm_hidden_states = self.norm2(hidden_states) + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + context_ff_output = self.ff_context(norm_encoder_hidden_states) + + hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + + return hidden_states, encoder_hidden_states + + +class HunyuanVideo15Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): + r""" + A Transformer model for video-like data used in [HunyuanVideo1.5](https://huggingface.co/tencent/HunyuanVideo1.5). + + Args: + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + num_layers (`int`, defaults to `20`): + The number of layers of dual-stream blocks to use. + num_refiner_layers (`int`, defaults to `2`): + The number of layers of refiner blocks to use. + mlp_ratio (`float`, defaults to `4.0`): + The ratio of the hidden layer size to the input size in the feedforward network. + patch_size (`int`, defaults to `2`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + qk_norm (`str`, defaults to `rms_norm`): + The normalization to use for the query and key projections in the attention layers. + guidance_embeds (`bool`, defaults to `True`): + Whether to use guidance embeddings in the model. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + pooled_projection_dim (`int`, defaults to `768`): + The dimension of the pooled projection of the text embeddings. + rope_theta (`float`, defaults to `256.0`): + The value of theta to use in the RoPE layer. + rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions of the axes to use in the RoPE layer. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"] + _no_split_modules = [ + "HunyuanVideo15TransformerBlock", + "HunyuanVideo15PatchEmbed", + "HunyuanVideo15TokenRefiner", + ] + _repeated_blocks = [ + "HunyuanVideo15TransformerBlock", + "HunyuanVideo15PatchEmbed", + "HunyuanVideo15TokenRefiner", + ] + + @register_to_config + def __init__( + self, + in_channels: int = 65, + out_channels: int = 32, + num_attention_heads: int = 16, + attention_head_dim: int = 128, + num_layers: int = 54, + num_refiner_layers: int = 2, + mlp_ratio: float = 4.0, + patch_size: int = 1, + patch_size_t: int = 1, + qk_norm: str = "rms_norm", + text_embed_dim: int = 3584, + text_embed_2_dim: int = 1472, + image_embed_dim: int = 1152, + rope_theta: float = 256.0, + rope_axes_dim: Tuple[int, ...] = (16, 56, 56), + # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205 + target_size: int = 640, # did not name sample_size since it is in pixel spaces + task_type: str = "i2v", + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Latent and condition embedders + self.x_embedder = HunyuanVideo15PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim) + self.image_embedder = HunyuanVideo15ImageProjection(image_embed_dim, inner_dim) + + self.context_embedder = HunyuanVideo15TokenRefiner( + text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers + ) + self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim) + + self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim) + + self.cond_type_embed = nn.Embedding(3, inner_dim) + + # 2. RoPE + self.rope = HunyuanVideo15RotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta) + + # 3. Dual stream transformer blocks + + self.transformer_blocks = nn.ModuleList( + [ + HunyuanVideo15TransformerBlock( + num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + ) + for _ in range(num_layers) + ] + ) + + # 5. Output projection + self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + encoder_hidden_states_2: Optional[torch.Tensor] = None, + encoder_attention_mask_2: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1. RoPE + image_rotary_emb = self.rope(hidden_states) + + # 2. Conditional embeddings + temb = self.time_embed(timestep) + + hidden_states = self.x_embedder(hidden_states) + + # qwen text embedding + encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask) + + encoder_hidden_states_cond_emb = self.cond_type_embed( + torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long) + ) + encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb + + # byt5 text embedding + encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2) + + encoder_hidden_states_2_cond_emb = self.cond_type_embed( + torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long) + ) + encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb + + # image embed + encoder_hidden_states_3 = self.image_embedder(image_embeds) + is_t2v = torch.all(image_embeds == 0) + if is_t2v: + encoder_hidden_states_3 = encoder_hidden_states_3 * 0.0 + encoder_attention_mask_3 = torch.zeros( + (batch_size, encoder_hidden_states_3.shape[1]), + dtype=encoder_attention_mask.dtype, + device=encoder_attention_mask.device, + ) + else: + encoder_attention_mask_3 = torch.ones( + (batch_size, encoder_hidden_states_3.shape[1]), + dtype=encoder_attention_mask.dtype, + device=encoder_attention_mask.device, + ) + encoder_hidden_states_3_cond_emb = self.cond_type_embed( + 2 + * torch.ones_like( + encoder_hidden_states_3[:, :, 0], + dtype=torch.long, + ) + ) + encoder_hidden_states_3 = encoder_hidden_states_3 + encoder_hidden_states_3_cond_emb + + # reorder and combine text tokens: combine valid tokens first, then padding + encoder_attention_mask = encoder_attention_mask.bool() + encoder_attention_mask_2 = encoder_attention_mask_2.bool() + encoder_attention_mask_3 = encoder_attention_mask_3.bool() + new_encoder_hidden_states = [] + new_encoder_attention_mask = [] + + for text, text_mask, text_2, text_mask_2, image, image_mask in zip( + encoder_hidden_states, + encoder_attention_mask, + encoder_hidden_states_2, + encoder_attention_mask_2, + encoder_hidden_states_3, + encoder_attention_mask_3, + ): + # Concatenate: [valid_image, valid_byt5, valid_mllm, invalid_image, invalid_byt5, invalid_mllm] + new_encoder_hidden_states.append( + torch.cat( + [ + image[image_mask], # valid image + text_2[text_mask_2], # valid byt5 + text[text_mask], # valid mllm + image[~image_mask], # invalid image + torch.zeros_like(text_2[~text_mask_2]), # invalid byt5 (zeroed) + torch.zeros_like(text[~text_mask]), # invalid mllm (zeroed) + ], + dim=0, + ) + ) + + # Apply same reordering to attention masks + new_encoder_attention_mask.append( + torch.cat( + [ + image_mask[image_mask], + text_mask_2[text_mask_2], + text_mask[text_mask], + image_mask[~image_mask], + text_mask_2[~text_mask_2], + text_mask[~text_mask], + ], + dim=0, + ) + ) + + encoder_hidden_states = torch.stack(new_encoder_hidden_states) + encoder_attention_mask = torch.stack(new_encoder_attention_mask) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + ) + + else: + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block( + hidden_states, + encoder_hidden_states, + temb, + encoder_attention_mask, + image_rotary_emb, + ) + + # 5. Output projection + hidden_states = self.norm_out(hidden_states, temb) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p_h, p_w + ) + hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) + hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b4043cd146b4..cf86456642eb 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -243,6 +243,7 @@ "HunyuanVideoImageToVideoPipeline", "HunyuanVideoFramepackPipeline", ] + _import_structure["hunyuan_video1_5"] = ["HunyuanVideo15Pipeline", "HunyuanVideo15ImageToVideoPipeline"] _import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"] _import_structure["kandinsky"] = [ "KandinskyCombinedPipeline", @@ -665,6 +666,7 @@ HunyuanVideoImageToVideoPipeline, HunyuanVideoPipeline, ) + from .hunyuan_video1_5 import HunyuanVideo15ImageToVideoPipeline, HunyuanVideo15Pipeline from .hunyuandit import HunyuanDiTPipeline from .i2vgen_xl import I2VGenXLPipeline from .kandinsky import ( diff --git a/src/diffusers/pipelines/hunyuan_video1_5/__init__.py b/src/diffusers/pipelines/hunyuan_video1_5/__init__.py new file mode 100644 index 000000000000..846320f4ace0 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_hunyuan_video1_5"] = ["HunyuanVideo15Pipeline"] + _import_structure["pipeline_hunyuan_video1_5_image2video"] = ["HunyuanVideo15ImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline + from .pipeline_hunyuan_video1_5_image2video import HunyuanVideo15ImageToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py new file mode 100644 index 000000000000..82817365b6a5 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/image_processor.py @@ -0,0 +1,103 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from ...configuration_utils import register_to_config +from ...video_processor import VideoProcessor + + +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L20 +def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0): + num_patches = round((base_size / patch_size) ** 2) + assert max_ratio >= 1.0 + crop_size_list = [] + wp, hp = num_patches, 1 + while wp > 0: + if max(wp, hp) / min(wp, hp) <= max_ratio: + crop_size_list.append((wp * patch_size, hp * patch_size)) + if (hp + 1) * wp <= num_patches: + hp += 1 + else: + wp -= 1 + return crop_size_list + + +# copied from https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5/blob/main/hyvideo/utils/data_utils.py#L38 +def get_closest_ratio(height: float, width: float, ratios: list, buckets: list): + """ + Get the closest ratio in the buckets. + + Args: + height (float): video height + width (float): video width + ratios (list): video aspect ratio + buckets (list): buckets generated by `generate_crop_size_list` + + Returns: + the closest size in the buckets and the corresponding ratio + """ + aspect_ratio = float(height) / float(width) + diff_ratios = ratios - aspect_ratio + + if aspect_ratio >= 1: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0] + else: + indices = [(index, x) for index, x in enumerate(diff_ratios) if x >= 0] + + closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0] + closest_size = buckets[closest_ratio_id] + closest_ratio = ratios[closest_ratio_id] + + return closest_size, closest_ratio + + +class HunyuanVideo15ImageProcessor(VideoProcessor): + r""" + Image/video processor to preproces/postprocess the reference image/generatedvideo for the HunyuanVideo1.5 model. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. + vae_scale_factor (`int`, *optional*, defaults to `16`): + VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of + this factor. + vae_latent_channels (`int`, *optional*, defaults to `32`): + VAE latent channels. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + """ + + @register_to_config + def __init__( + self, + do_resize: bool = True, + vae_scale_factor: int = 16, + vae_latent_channels: int = 32, + do_convert_rgb: bool = True, + ): + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + do_convert_rgb=do_convert_rgb, + ) + + def calculate_default_height_width(self, height: int, width: int, target_size: int): + crop_size_list = generate_crop_size_list(base_size=target_size, patch_size=self.config.vae_scale_factor) + aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list]) + height, width = get_closest_ratio(height, width, aspect_ratios, crop_size_list)[0] + + return height, width diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py new file mode 100644 index 000000000000..00a703939004 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py @@ -0,0 +1,837 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel + +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import HunyuanVideo15ImageProcessor +from .pipeline_output import HunyuanVideo15PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideo15Pipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_t2v" + >>> pipe = HunyuanVideo15Pipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (List[str]): Input text. + system_message (str): System message. + + Returns: + List[Dict[str, Any]]: List of chat conversation. + """ + + template = [ + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] + + return template + + +def extract_glyph_texts(prompt: str) -> List[str]: + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r"\"(.*?)\"|“(.*?)”" + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + + if result: + formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". " + else: + formatted_result = None + + return formatted_result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15Pipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo1.5. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`HunyuanVideo15Transformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo15`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`ClassifierFreeGuidance`]): + [ClassifierFreeGuidance]for classifier free guidance. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanVideo15Transformer3DModel, + vae: AutoencoderKLHunyuanVideo15, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + guider: ClassifierFreeGuidance, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + guider=guider, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 + self.vision_states_dim = ( + self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + ) + self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 + # fmt: off + self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on + self.prompt_template_encode_start_idx = 108 + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 256 + self.vision_num_semantic_tokens = 729 + self.default_aspect_ratio = (16, 9) # (width: height) + + @staticmethod + def _get_mllm_prompt_embeds( + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + prompt: Union[str, List[str]], + device: torch.device, + tokenizer_max_length: int = 1000, + num_hidden_layers_to_skip: int = 2, + # fmt: off + system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start: int = 108, + ) -> Tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + @staticmethod + def _get_byt5_prompt_embeds( + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + device: torch.device, + tokenizer_max_length: int = 256, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + glyph_text_embeds = glyph_text_embeds.to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0) + + return prompt_embeds, prompt_embeds_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + batch_size: int = 1, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_embeds_mask_2: Optional[torch.Tensor] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + system_message=self.system_message, + crop_start=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + ): + if height is None and width is not None: + raise ValueError("If `width` is provided, `height` also have to be provided.") + elif width is None and height is not None: + raise ValueError("If `height` is provided, `width` also have to be provided.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_cond_latents_and_mask(self, latents, dtype: Optional[torch.dtype], device: Optional[torch.device]): + """ + Prepare conditional latents and mask for t2v generation. + + Args: + latents: Main latents tensor (B, C, F, H, W) + + Returns: + tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v + """ + batch, channels, frames, height, width = latents.shape + + cond_latents_concat = torch.zeros(batch, channels, frames, height, width, dtype=dtype, device=device) + + mask_concat = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device) + + return cond_latents_concat, mask_concat + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 121, + num_inference_steps: int = 50, + sigmas: List[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_embeds_mask_2: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + height (`int`, *optional*): + The height in pixels of the generated video. + width (`int`, *optional*): + The width in pixels of the generated video. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings from the second text encoder. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the second text encoder. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings from the second text encoder. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between "np", "pt", or "latent". + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~HunyuanVideo15PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated videos. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + if height is None and width is None: + height, width = self.video_processor.calculate_default_height_width( + self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size + ) + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + if self.guider._enabled and self.guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + prompt=negative_prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 5. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + self.num_channels_latents, + height, + width, + num_frames, + self.transformer.dtype, + device, + generator, + latents, + ) + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask(latents, self.transformer.dtype, device) + image_embeds = torch.zeros( + batch_size, + self.vision_num_semantic_tokens, + self.vision_states_dim, + dtype=self.transformer.dtype, + device=device, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = self.guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + image_embeds=image_embeds, + timestep=timestep, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + self.guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + # 8. decode the latents to video and postprocess + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideo15PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py new file mode 100644 index 000000000000..9e9f20c79eba --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -0,0 +1,950 @@ +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import ( + ByT5Tokenizer, + Qwen2_5_VLTextModel, + Qwen2Tokenizer, + SiglipImageProcessor, + SiglipVisionModel, + T5EncoderModel, +) + +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLHunyuanVideo15, HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import HunyuanVideo15ImageProcessor +from .pipeline_output import HunyuanVideo15PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideo15ImageToVideoPipeline + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo-1.5-480p_i2v" + >>> pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG") + + >>> output = pipe( + ... prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.", + ... image=image, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input +def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (List[str]): Input text. + system_message (str): System message. + + Returns: + List[Dict[str, Any]]: List of chat conversation. + """ + + template = [ + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] + + return template + + +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts +def extract_glyph_texts(prompt: str) -> List[str]: + """ + Extract glyph texts from prompt using regex pattern. + + Args: + prompt: Input prompt string + + Returns: + List of extracted glyph texts + """ + pattern = r"\"(.*?)\"|“(.*?)”" + matches = re.findall(pattern, prompt) + result = [match[0] or match[1] for match in matches] + result = list(dict.fromkeys(result)) if len(result) > 1 else result + + if result: + formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". " + else: + formatted_result = None + + return formatted_result + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using HunyuanVideo1.5. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`HunyuanVideo15Transformer3DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo15`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer]. + text_encoder_2 ([`T5EncoderModel`]): + [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) + variant. + tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer] + guider ([`ClassifierFreeGuidance`]): + [ClassifierFreeGuidance]for classifier free guidance. + image_encoder ([`SiglipVisionModel`]): + [SiglipVisionModel](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipVisionModel) + variant. + feature_extractor ([`SiglipImageProcessor`]): + [SiglipImageProcessor](https://huggingface.co/docs/transformers/en/model_doc/siglip#transformers.SiglipImageProcessor) + variant. + """ + + model_cpu_offload_seq = "image_encoder->text_encoder->transformer->vae" + + def __init__( + self, + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + transformer: HunyuanVideo15Transformer3DModel, + vae: AutoencoderKLHunyuanVideo15, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: T5EncoderModel, + tokenizer_2: ByT5Tokenizer, + guider: ClassifierFreeGuidance, + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + guider=guider, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 + self.video_processor = HunyuanVideo15ImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_resize=False, do_convert_rgb=True + ) + self.target_size = self.transformer.config.target_size if getattr(self, "transformer", None) else 640 + self.vision_states_dim = ( + self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + ) + self.num_channels_latents = self.vae.config.latent_channels if hasattr(self, "vae") else 32 + # fmt: off + self.system_message = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on + self.prompt_template_encode_start_idx = 108 + self.tokenizer_max_length = 1000 + self.tokenizer_2_max_length = 256 + self.vision_num_semantic_tokens = 729 + + @staticmethod + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds + def _get_mllm_prompt_embeds( + text_encoder: Qwen2_5_VLTextModel, + tokenizer: Qwen2Tokenizer, + prompt: Union[str, List[str]], + device: torch.device, + tokenizer_max_length: int = 1000, + num_hidden_layers_to_skip: int = 2, + # fmt: off + system_message: str = "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start: int = 108, + ) -> Tuple[torch.Tensor, torch.Tensor]: + prompt = [prompt] if isinstance(prompt, str) else prompt + + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + @staticmethod + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds + def _get_byt5_prompt_embeds( + tokenizer: ByT5Tokenizer, + text_encoder: T5EncoderModel, + prompt: Union[str, List[str]], + device: torch.device, + tokenizer_max_length: int = 256, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + glyph_text_embeds = glyph_text_embeds.to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + prompt_embeds = torch.cat(prompt_embeds_list, dim=0) + prompt_embeds_mask = torch.cat(prompt_embeds_mask_list, dim=0) + + return prompt_embeds, prompt_embeds_mask + + @staticmethod + def _get_image_latents( + vae: AutoencoderKLHunyuanVideo15, + image_processor: HunyuanVideo15ImageProcessor, + image: PIL.Image.Image, + height: int, + width: int, + device: torch.device, + ) -> torch.Tensor: + vae_dtype = vae.dtype + image_tensor = image_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype) + image_tensor = image_tensor.unsqueeze(2) + image_latents = retrieve_latents(vae.encode(image_tensor), sample_mode="argmax") + image_latents = image_latents * vae.config.scaling_factor + return image_latents + + @staticmethod + def _get_image_embeds( + image_encoder: SiglipVisionModel, + feature_extractor: SiglipImageProcessor, + image: PIL.Image.Image, + device: torch.device, + ) -> torch.Tensor: + image_encoder_dtype = next(image_encoder.parameters()).dtype + image = feature_extractor.preprocess(images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True) + image = image.to(device=device, dtype=image_encoder_dtype) + image_enc_hidden_states = image_encoder(**image).last_hidden_state + + return image_enc_hidden_states + + def encode_image( + self, + image: PIL.Image.Image, + batch_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + image_embeds = self._get_image_embeds( + image_encoder=self.image_encoder, + feature_extractor=self.feature_extractor, + image=image, + device=device, + ) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(device=device, dtype=dtype) + return image_embeds + + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + batch_size: int = 1, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_embeds_mask_2: Optional[torch.Tensor] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + batch_size (`int`): + batch size of prompts, defaults to 1 + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input + argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input + argument using self.tokenizer_2 and self.text_encoder_2. + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_mllm_prompt_embeds( + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_max_length, + system_message=self.system_message, + crop_start=self.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = self._get_byt5_prompt_embeds( + tokenizer=self.tokenizer_2, + text_encoder=self.text_encoder_2, + prompt=prompt, + device=device, + tokenizer_max_length=self.tokenizer_2_max_length, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_videos_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_videos_per_prompt, seq_len_2) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + def check_inputs( + self, + prompt, + image: PIL.Image.Image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + negative_prompt_embeds_2=None, + negative_prompt_embeds_mask_2=None, + ): + if not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `PIL.Image.Image` but is {type(image)}") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) + + if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None: + raise ValueError( + "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`." + ) + if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None: + raise ValueError( + "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`." + ) + + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def prepare_cond_latents_and_mask( + self, + latents: torch.Tensor, + image: PIL.Image.Image, + batch_size: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + ): + """ + Prepare conditional latents and mask for t2v generation. + + Args: + latents: Main latents tensor (B, C, F, H, W) + + Returns: + tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v + """ + + batch, channels, frames, height, width = latents.shape + + image_latents = self._get_image_latents( + vae=self.vae, + image_processor=self.video_processor, + image=image, + height=height, + width=width, + device=device, + ) + + latent_condition = image_latents.repeat(batch_size, 1, frames, 1, 1) + latent_condition[:, :, 1:, :, :] = 0 + latent_condition = latent_condition.to(device=device, dtype=dtype) + + latent_mask = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device) + latent_mask[:, :, 0, :, :] = 1.0 + + return latent_condition, latent_mask + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PIL.Image.Image, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + num_frames: int = 121, + num_inference_steps: int = 50, + sigmas: List[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_embeds_mask_2: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PIL.Image.Image`): + The input image to condition video generation on. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. + num_frames (`int`, defaults to `121`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_prompt_embeds_mask (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated text embeddings from the second text encoder. Can be used to easily tweak text inputs. + prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for prompt embeddings from the second text encoder. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings from the second text encoder. + negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*): + Pre-generated mask for negative prompt embeddings from the second text encoder. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between "np", "pt", or "latent". + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideo15PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + Examples: + + Returns: + [`~HunyuanVideo15PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideo15PipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated videos. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + negative_prompt_embeds_2=negative_prompt_embeds_2, + negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + height, width = self.video_processor.calculate_default_height_width( + height=image.size[1], width=image.size[0], target_size=self.target_size + ) + image = self.video_processor.resize(image, height=height, width=width, resize_mode="crop") + + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode image + image_embeds = self.encode_image( + image=image, + batch_size=batch_size * num_videos_per_prompt, + device=device, + dtype=self.transformer.dtype, + ) + + # 4. Encode input prompt + prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt( + prompt=prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + prompt_embeds_2=prompt_embeds_2, + prompt_embeds_mask_2=prompt_embeds_mask_2, + ) + + if self.guider._enabled and self.guider.num_conditions > 1: + ( + negative_prompt_embeds, + negative_prompt_embeds_mask, + negative_prompt_embeds_2, + negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + prompt=negative_prompt, + device=device, + dtype=self.transformer.dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + prompt_embeds_2=negative_prompt_embeds_2, + prompt_embeds_mask_2=negative_prompt_embeds_mask_2, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 6. Prepare latent variables + latents = self.prepare_latents( + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + num_frames=num_frames, + dtype=self.transformer.dtype, + device=device, + generator=generator, + latents=latents, + ) + + cond_latents_concat, mask_concat = self.prepare_cond_latents_and_mask( + latents=latents, + image=image, + batch_size=batch_size * num_videos_per_prompt, + height=height, + width=width, + dtype=self.transformer.dtype, + device=device, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, cond_latents_concat, mask_concat], dim=1) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + # Step 1: Collect model inputs needed for the guidance method + # conditional inputs should always be first element in the tuple + guider_inputs = { + "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds), + "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask), + "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2), + "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2), + } + + # Step 2: Update guider's internal state for this denoising step + self.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t) + + # Step 3: Prepare batched model inputs based on the guidance method + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = self.guider.prepare_inputs(guider_inputs) + # Step 4: Run the denoiser for each batch + # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.). + # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred. + for guider_state_batch in guider_state: + self.guider.prepare_models(self.transformer) + + # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states) + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + # e.g. "pred_cond"/"pred_uncond" + context_name = getattr(guider_state_batch, self.guider._identifier_key) + with self.transformer.cache_context(context_name): + # Run denoiser and store noise prediction in this batch + guider_state_batch.noise_pred = self.transformer( + hidden_states=latent_model_input, + image_embeds=image_embeds, + timestep=timestep, + attention_kwargs=self.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + # Cleanup model (e.g., remove hooks) + self.guider.cleanup_models(self.transformer) + + # Step 5: Combine predictions using the guidance method + # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm. + # Continuing the CFG example, the guider receives: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0 + # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1 + # ] + # And extracts predictions using the __guidance_identifier__: + # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond + # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond + # Then applies CFG formula: + # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond) + noise_pred = self.guider(guider_state)[0] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return HunyuanVideo15PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py new file mode 100644 index 000000000000..441164db5a09 --- /dev/null +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class HunyuanVideo15PipelineOutput(BaseOutput): + r""" + Output class for HunyuanVideo1.5 pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 73854b38190e..fe9a4b30f0c1 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -468,6 +468,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLHunyuanVideo15(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLLTXVideo(metaclass=DummyObject): _backends = ["torch"] @@ -993,6 +1008,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class HunyuanVideo15Transformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e6cf26a12544..65306b839015 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1142,6 +1142,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class HunyuanVideo15ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class HunyuanVideo15Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class HunyuanVideoFramepackPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_hunyuan_1_5.py b/tests/models/transformers/test_models_transformer_hunyuan_1_5.py new file mode 100644 index 000000000000..021fcdc9cfbf --- /dev/null +++ b/tests/models/transformers/test_models_transformer_hunyuan_1_5.py @@ -0,0 +1,100 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HunyuanVideo15Transformer3DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase): + model_class = HunyuanVideo15Transformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + text_embed_dim = 16 + text_embed_2_dim = 8 + image_embed_dim = 12 + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 1 + height = 8 + width = 8 + sequence_length = 6 + sequence_length_2 = 4 + image_sequence_length = 3 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device) + encoder_hidden_states_2 = torch.randn( + (batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device + ) + encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device) + encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device) + # All zeros for inducing T2V path in the model. + image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "encoder_attention_mask": encoder_attention_mask, + "encoder_hidden_states_2": encoder_hidden_states_2, + "encoder_attention_mask_2": encoder_attention_mask_2, + "image_embeds": image_embeds, + } + + @property + def input_shape(self): + return (4, 1, 8, 8) + + @property + def output_shape(self): + return (4, 1, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 8, + "num_layers": 2, + "num_refiner_layers": 1, + "mlp_ratio": 2.0, + "patch_size": 1, + "patch_size_t": 1, + "text_embed_dim": self.text_embed_dim, + "text_embed_2_dim": self.text_embed_2_dim, + "image_embed_dim": self.image_embed_dim, + "rope_axes_dim": (2, 2, 4), + "target_size": 16, + "task_type": "t2v", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"HunyuanVideo15Transformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/hunyuan_video1_5/__init__.py b/tests/pipelines/hunyuan_video1_5/__init__.py new file mode 100644 index 000000000000..8fb044d9cf83 --- /dev/null +++ b/tests/pipelines/hunyuan_video1_5/__init__.py @@ -0,0 +1 @@ +# Copyright 2025 The HuggingFace Team. diff --git a/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py b/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py new file mode 100644 index 000000000000..993c7ef6e4bb --- /dev/null +++ b/tests/pipelines/hunyuan_video1_5/test_hunyuan_1_5.py @@ -0,0 +1,187 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import ByT5Tokenizer, Qwen2_5_VLTextConfig, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLHunyuanVideo15, + FlowMatchEulerDiscreteScheduler, + HunyuanVideo15Pipeline, + HunyuanVideo15Transformer3DModel, +) +from diffusers.guiders import ClassifierFreeGuidance + +from ...testing_utils import enable_full_determinism +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class HunyuanVideo15PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = HunyuanVideo15Pipeline + params = frozenset( + [ + "prompt", + "negative_prompt", + "height", + "width", + "prompt_embeds", + "prompt_embeds_mask", + "negative_prompt_embeds", + "negative_prompt_embeds_mask", + "prompt_embeds_2", + "prompt_embeds_mask_2", + "negative_prompt_embeds_2", + "negative_prompt_embeds_mask_2", + ] + ) + batch_params = ["prompt", "negative_prompt"] + required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"]) + test_attention_slicing = False + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = False + supports_dduf = False + + def get_dummy_components(self, num_layers: int = 1): + torch.manual_seed(0) + transformer = HunyuanVideo15Transformer3DModel( + in_channels=9, + out_channels=4, + num_attention_heads=2, + attention_head_dim=8, + num_layers=num_layers, + num_refiner_layers=1, + mlp_ratio=2.0, + patch_size=1, + patch_size_t=1, + text_embed_dim=16, + text_embed_2_dim=32, + image_embed_dim=12, + rope_axes_dim=(2, 2, 4), + target_size=16, + task_type="t2v", + ) + + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo15( + in_channels=3, + out_channels=3, + latent_channels=4, + block_out_channels=(16, 16), + layers_per_block=1, + spatial_compression_ratio=4, + temporal_compression_ratio=2, + downsample_match_channel=False, + upsample_match_channel=False, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + torch.manual_seed(0) + qwen_config = Qwen2_5_VLTextConfig( + **{ + "hidden_size": 16, + "intermediate_size": 16, + "num_hidden_layers": 2, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "rope_scaling": { + "mrope_section": [1, 1, 2], + "rope_type": "default", + "type": "default", + }, + "rope_theta": 1000000.0, + } + ) + text_encoder = Qwen2_5_VLTextModel(qwen_config) + tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration") + + torch.manual_seed(0) + text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer_2 = ByT5Tokenizer() + + guider = ClassifierFreeGuidance(guidance_scale=1.0) + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "text_encoder_2": text_encoder_2.eval(), + "tokenizer": tokenizer, + "tokenizer_2": tokenizer_2, + "guider": guider, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "monkey", + "generator": generator, + "num_inference_steps": 2, + "height": 16, + "width": 16, + "num_frames": 9, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + result = pipe(**inputs) + video = result.frames + + generated_video = video[0] + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + + # fmt: off + expected_slice = torch.tensor([0.4296, 0.5549, 0.3088, 0.9115, 0.5049, 0.7926, 0.5549, 0.8618, 0.5091, 0.5075, 0.7117, 0.5292, 0.7053, 0.4864, 0.5206, 0.3878]) + # fmt: on + + self.assertTrue( + torch.abs(generated_slice - expected_slice).max() < 1e-3, + f"output_slice: {generated_slice}, expected_slice: {expected_slice}", + ) + + @unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.") + def test_encode_prompt_works_in_isolation(self): + pass + + @unittest.skip("Needs to be revisited.") + def test_inference_batch_consistent(self): + super().test_inference_batch_consistent() + + @unittest.skip("Needs to be revisited.") + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical()