From 646e55bdbefc9713ea6a97190375d750c9a3e041 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 7 Dec 2025 11:38:58 +0100 Subject: [PATCH 1/2] support step-distilled --- .../convert_hunyuan_video1_5_to_diffusers.py | 26 +++++++++++++++++-- .../transformer_hunyuan_video15.py | 21 ++++++++++++--- .../pipeline_hunyuan_video1_5_image2video.py | 10 +++++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 38226f684a6d..713f1d620e4c 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -69,6 +69,11 @@ "target_size": 960, "task_type": "i2v", }, + "480p_i2v_step_distilled": { + "target_size": 640, + "task_type": "i2v", + "use_meanflow": True, + }, } SCHEDULER_CONFIGS = { @@ -93,6 +98,9 @@ "720p_i2v_distilled": { "shift": 7.0, }, + "480p_i2v_step_distilled": { + "shift": 7.0, + }, } GUIDANCE_CONFIGS = { @@ -117,6 +125,9 @@ "720p_i2v_distilled": { "guidance_scale": 1.0, }, + "480p_i2v_step_distilled": { + "guidance_scale": 1.0, + }, } @@ -126,7 +137,7 @@ def swap_scale_shift(weight): return new_weight -def convert_hyvideo15_transformer_to_diffusers(original_state_dict): +def convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=None): """ Convert HunyuanVideo 1.5 original checkpoint to Diffusers format. """ @@ -142,6 +153,17 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict): ) converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias") + if config.use_meanflow: + + converted_state_dict["time_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop( + "time_r_in.mlp.0.weight" + ) + converted_state_dict["time_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop("time_r_in.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop( + "time_r_in.mlp.2.weight" + ) + converted_state_dict["time_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop("time_r_in.mlp.2.bias") + # 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = ( original_state_dict.pop("txt_in.t_embedder.mlp.0.weight") @@ -627,7 +649,7 @@ def convert_transformer(args): config = TRANSFORMER_CONFIGS[args.transformer_type] with init_empty_weights(): transformer = HunyuanVideo15Transformer3DModel(**config) - state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict) + state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=transformer.config) transformer.load_state_dict(state_dict, strict=True, assign=True) return transformer diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py index 76a02cb1a886..293ba996ea98 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py @@ -184,19 +184,32 @@ class HunyuanVideo15TimeEmbedding(nn.Module): The dimension of the output embedding. 
""" - def __init__(self, embedding_dim: int): + def __init__(self, embedding_dim: int, use_meanflow: bool = False): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.use_meanflow = use_meanflow + self.time_proj_r = None + self.timestep_embedder_r = None + if use_meanflow: + self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + def forward( self, timestep: torch.Tensor, + timestep_r: Optional[torch.Tensor] = None, ) -> torch.Tensor: timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype)) + if timestep_r is not None: + timesteps_proj_r = self.time_proj_r(timestep_r) + timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype)) + timesteps_emb = timesteps_emb + timesteps_emb_r + return timesteps_emb @@ -567,6 +580,7 @@ def __init__( # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205 target_size: int = 640, # did not name sample_size since it is in pixel spaces task_type: str = "i2v", + use_meanflow: bool = False, ) -> None: super().__init__() @@ -582,7 +596,7 @@ def __init__( ) self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim) - self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim) + self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow=use_meanflow) self.cond_type_embed = nn.Embedding(3, inner_dim) @@ -612,6 +626,7 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, + timestep_r: Optional[torch.LongTensor] = None, encoder_hidden_states_2: Optional[torch.Tensor] = None, encoder_attention_mask_2: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, @@ -643,7 +658,7 @@ def forward( image_rotary_emb = self.rope(hidden_states) # 2. 
Conditional embeddings - temb = self.time_embed(timestep) + temb = self.time_embed(timestep, timestep_r=timestep_r) hidden_states = self.x_embedder(hidden_states) diff --git a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py index 9e9f20c79eba..8c555eabba11 100644 --- a/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py @@ -852,6 +852,15 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + if self.transformer.config.use_meanflow: + if i == len(timesteps) - 1: + timestep_r = torch.tensor([0.0], device=device) + else: + timestep_r = timesteps[i + 1] + timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype) + else: + timestep_r = None + # Step 1: Collect model inputs needed for the guidance method # conditional inputs should always be first element in the tuple guider_inputs = { @@ -893,6 +902,7 @@ def __call__( hidden_states=latent_model_input, image_embeds=image_embeds, timestep=timestep, + timestep_r=timestep_r, attention_kwargs=self.attention_kwargs, return_dict=False, **cond_kwargs, From 9cc3601a3acdfdd0911ead713cb9e1eafc3a17de Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 7 Dec 2025 11:40:39 +0100 Subject: [PATCH 2/2] style --- scripts/convert_hunyuan_video1_5_to_diffusers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/convert_hunyuan_video1_5_to_diffusers.py b/scripts/convert_hunyuan_video1_5_to_diffusers.py index 713f1d620e4c..89e5cdb16956 100644 --- a/scripts/convert_hunyuan_video1_5_to_diffusers.py +++ b/scripts/convert_hunyuan_video1_5_to_diffusers.py @@ -154,15 +154,18 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=None) converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias") if config.use_meanflow: - converted_state_dict["time_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop( "time_r_in.mlp.0.weight" ) - converted_state_dict["time_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop("time_r_in.mlp.0.bias") + converted_state_dict["time_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop( + "time_r_in.mlp.0.bias" + ) converted_state_dict["time_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop( "time_r_in.mlp.2.weight" ) - converted_state_dict["time_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop("time_r_in.mlp.2.bias") + converted_state_dict["time_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop( + "time_r_in.mlp.2.bias" + ) # 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
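
Reviewer note on the new meanflow conditioning: the step-distilled checkpoint config ("480p_i2v_step_distilled", use_meanflow=True) adds a second timestep embedder that is fed the next timestep in the schedule (timestep_r, falling back to 0.0 on the final denoising step), and its embedding is summed with the regular timestep embedding. Below is a minimal sketch of that behaviour, not part of the patch: it assumes the patch is applied, imports the class from its internal (non-public) module path, and uses an arbitrary embedding_dim purely for illustration.

    import torch
    from diffusers.models.transformers.transformer_hunyuan_video15 import (
        HunyuanVideo15TimeEmbedding,
    )

    # embedding_dim is arbitrary here; the real transformer passes its inner_dim.
    embed = HunyuanVideo15TimeEmbedding(embedding_dim=128, use_meanflow=True)

    # Current timestep t and the "next" timestep r used by the meanflow branch.
    # In the pipeline loop, r is timesteps[i + 1], or 0.0 at the last step.
    t = torch.tensor([999.0])
    r = torch.tensor([979.0])

    temb_meanflow = embed(t, timestep_r=r)  # timestep_embedder(t) + timestep_embedder_r(r)
    temb_plain = embed(t)                   # r branch skipped when timestep_r is None
    print(temb_meanflow.shape, temb_plain.shape)  # both torch.Size([1, 128])

Non-meanflow checkpoints keep use_meanflow=False (the default), so existing configs load unchanged and the forward pass with timestep_r=None reduces to the previous behaviour.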