29 changes: 27 additions & 2 deletions scripts/convert_hunyuan_video1_5_to_diffusers.py
@@ -69,6 +69,11 @@
"target_size": 960,
"task_type": "i2v",
},
"480p_i2v_step_distilled": {
"target_size": 640,
"task_type": "i2v",
"use_meanflow": True,
},
}

SCHEDULER_CONFIGS = {
@@ -93,6 +98,9 @@
"720p_i2v_distilled": {
"shift": 7.0,
},
"480p_i2v_step_distilled": {
"shift": 7.0,
},
}

GUIDANCE_CONFIGS = {
@@ -117,6 +125,9 @@
"720p_i2v_distilled": {
"guidance_scale": 1.0,
},
"480p_i2v_step_distilled": {
"guidance_scale": 1.0,
},
}
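
Taken together, the three config tables in this script give the new checkpoint variant the following settings (a summary sketch only; the combined dict and its name below do not exist in the script):

# Combined view of the "480p_i2v_step_distilled" entry across
# TRANSFORMER_CONFIGS, SCHEDULER_CONFIGS and GUIDANCE_CONFIGS above.
# STEP_DISTILLED_480P_SUMMARY is a hypothetical name, for illustration only.
STEP_DISTILLED_480P_SUMMARY = {
    "target_size": 640,     # transformer: target resolution (pixel space)
    "task_type": "i2v",     # transformer: image-to-video
    "use_meanflow": True,   # transformer: adds the extra timestep_r embedder
    "shift": 7.0,           # scheduler
    "guidance_scale": 1.0,  # guidance: effectively no CFG for the distilled model
}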


@@ -126,7 +137,7 @@ def swap_scale_shift(weight):
return new_weight


def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
def convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=None):
"""
Convert HunyuanVideo 1.5 original checkpoint to Diffusers format.
"""
@@ -142,6 +153,20 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
)
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias")

if config.use_meanflow:
Review comment (Member): Maybe slightly better?

Suggested change:
- if config.use_meanflow:
+ if config is not None and getattr(config, "use_meanflow", False):

converted_state_dict["time_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop(
"time_r_in.mlp.0.weight"
)
converted_state_dict["time_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop(
"time_r_in.mlp.0.bias"
)
converted_state_dict["time_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop(
"time_r_in.mlp.2.weight"
)
converted_state_dict["time_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop(
"time_r_in.mlp.2.bias"
)

# 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
@@ -627,7 +652,7 @@ def convert_transformer(args):
config = TRANSFORMER_CONFIGS[args.transformer_type]
with init_empty_weights():
transformer = HunyuanVideo15Transformer3DModel(**config)
state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict)
state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=transformer.config)
transformer.load_state_dict(state_dict, strict=True, assign=True)

return transformer
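
For reference, the meanflow-specific renames added to the converter above reduce to four key pairs; a minimal standalone sketch (the mapping and helper names below are made up for illustration, not part of the script):

# Mapping of original meanflow timestep-embedder keys ("time_r_in") to their
# Diffusers names ("time_embed.timestep_embedder_r"), as popped above.
MEANFLOW_KEY_MAP = {
    "time_r_in.mlp.0.weight": "time_embed.timestep_embedder_r.linear_1.weight",
    "time_r_in.mlp.0.bias": "time_embed.timestep_embedder_r.linear_1.bias",
    "time_r_in.mlp.2.weight": "time_embed.timestep_embedder_r.linear_2.weight",
    "time_r_in.mlp.2.bias": "time_embed.timestep_embedder_r.linear_2.bias",
}


def rename_meanflow_keys(state_dict):
    """Pop each original key and reinsert it under its Diffusers name."""
    for src, dst in MEANFLOW_KEY_MAP.items():
        if src in state_dict:
            state_dict[dst] = state_dict.pop(src)
    return state_dict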
21 changes: 18 additions & 3 deletions src/diffusers/models/transformers/transformer_hunyuan_video15.py
@@ -184,19 +184,32 @@ class HunyuanVideo15TimeEmbedding(nn.Module):
The dimension of the output embedding.
"""

def __init__(self, embedding_dim: int):
def __init__(self, embedding_dim: int, use_meanflow: bool = False):
super().__init__()

self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

self.use_meanflow = use_meanflow
self.time_proj_r = None
self.timestep_embedder_r = None
if use_meanflow:
self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
Review comment (Member, nit): Maybe time_proj_meanflow and timestep_embedder_meanflow are better names?

self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

def forward(
self,
timestep: torch.Tensor,
timestep_r: Optional[torch.Tensor] = None,
) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))

if timestep_r is not None:
timesteps_proj_r = self.time_proj_r(timestep_r)
timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
timesteps_emb = timesteps_emb + timesteps_emb_r

return timesteps_emb
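
A quick usage sketch of the updated time-embedding module, with an illustrative embedding size and assuming the class can be imported from the module path shown in this diff:

import torch

from diffusers.models.transformers.transformer_hunyuan_video15 import (
    HunyuanVideo15TimeEmbedding,
)

# Illustrative embedding_dim; the real model uses its own inner_dim.
embed = HunyuanVideo15TimeEmbedding(embedding_dim=128, use_meanflow=True)

t = torch.tensor([999.0])  # current timestep
r = torch.tensor([750.0])  # interval end (next timestep in the schedule)

# With timestep_r, the module sums the embeddings of t and r.
temb = embed(t, timestep_r=r)  # shape: (1, 128)

# Without timestep_r, it reduces to the original single-timestep embedding.
temb_plain = embed(t)  # shape: (1, 128)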


@@ -567,6 +580,7 @@ def __init__(
# YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205
target_size: int = 640, # did not name sample_size since it is in pixel spaces
task_type: str = "i2v",
use_meanflow: bool = False,
) -> None:
super().__init__()

@@ -582,7 +596,7 @@
)
self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim)

self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim)
self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow=use_meanflow)

self.cond_type_embed = nn.Embedding(3, inner_dim)

@@ -612,6 +626,7 @@ def forward(
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
timestep_r: Optional[torch.LongTensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
encoder_attention_mask_2: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
@@ -643,7 +658,7 @@ def forward(
image_rotary_emb = self.rope(hidden_states)

# 2. Conditional embeddings
temb = self.time_embed(timestep)
temb = self.time_embed(timestep, timestep_r=timestep_r)

hidden_states = self.x_embedder(hidden_states)

@@ -852,6 +852,15 @@ def __call__(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

if self.transformer.config.use_meanflow:
if i == len(timesteps) - 1:
timestep_r = torch.tensor([0.0], device=device)
else:
timestep_r = timesteps[i + 1]
timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
else:
timestep_r = None

# Step 1: Collect model inputs needed for the guidance method
# conditional inputs should always be first element in the tuple
guider_inputs = {
@@ -893,6 +902,7 @@
hidden_states=latent_model_input,
image_embeds=image_embeds,
timestep=timestep,
timestep_r=timestep_r,
attention_kwargs=self.attention_kwargs,
return_dict=False,
**cond_kwargs,
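
To illustrate the timestep_r logic added to the denoising loop above: for a meanflow (step-distilled) checkpoint, each step passes the next timestep in the schedule as the interval end, and 0 on the last step. A minimal sketch with made-up timestep values (the real ones come from the pipeline's scheduler):

import torch

# Illustrative schedule, not the pipeline's actual timesteps.
timesteps = torch.tensor([1000.0, 750.0, 500.0, 250.0])
batch_size = 2

for i, t in enumerate(timesteps):
    timestep = t.expand(batch_size)
    if i == len(timesteps) - 1:
        timestep_r = torch.tensor([0.0])  # last step: interval ends at 0
    else:
        timestep_r = timesteps[i + 1]     # otherwise: the next timestep
    timestep_r = timestep_r.expand(batch_size)
    print(i, timestep.tolist(), timestep_r.tolist())
# Prints pairs such as (1000.0, 750.0) ... (250.0, 0.0) for each batch element.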