From 7ba78d5ce5142feb277c23f7df50af4215a7507e Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 25 Jul 2025 13:29:11 +0200 Subject: [PATCH 01/12] support wan 2.2 i2v --- scripts/convert_wan_to_diffusers.py | 68 ++++++++++++++++- .../pipelines/wan/pipeline_wan_i2v.py | 76 +++++++++++++++---- 2 files changed, 128 insertions(+), 16 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 6d25cde071b1..a69a1a140a78 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -278,16 +278,62 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "Wan2.2-I2V-14B-720p": + config = { + "model_id": "Wan-AI/Wan2.2-I2V-A14B", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 36, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP + elif model_type == "Wan2.2-T2V-A14B": + config = { + "model_id": "Wan-AI/Wan2.2-T2V-A14B", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP + return config, RENAME_DICT, SPECIAL_KEYS_REMAP return config, RENAME_DICT, SPECIAL_KEYS_REMAP -def convert_transformer(model_type: str): +def convert_transformer(model_type: str, stage: str=None): config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type) diffusers_config = config["diffusers_config"] model_id = config["model_id"] model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model")) + if stage is not None: + model_dir = model_dir / stage + original_state_dict = load_sharded_safetensors(model_dir) with init_empty_weights(): @@ -533,7 +579,13 @@ def get_args(): if __name__ == "__main__": args = get_args() - transformer = convert_transformer(args.model_type) + if "Wan2.2" in args.model_type: + transformer = convert_transformer(args.model_type, stage="high_noise_model") + transformer_2 = convert_transformer(args.model_type, stage="low_noise_model") + else: + transformer = convert_transformer(args.model_type) + transformer_2 = None + vae = convert_vae() text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") @@ -547,7 +599,17 @@ def get_args(): dtype = DTYPE_MAPPING[args.dtype] transformer.to(dtype) - if "I2V" in args.model_type or "FLF2V" in args.model_type: + if "Wan2.2" and "I2V" in args.model_type: + pipe = WanImageToVideoPipeline( + transformer=transformer, + transformer_2=transformer_2, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + boundary_ratio=0.9, + ) + elif "I2V" in args.model_type or "FLF2V" in args.model_type: image_encoder = CLIPVisionModelWithProjection.from_pretrained( 
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 ) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index c71138a97dd9..c2299a2e4650 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -149,20 +149,32 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer_2 ([`WanTransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. + In two-stage denoising, `transformer` handles high-noise stages + and `transformer_2` handles low-noise stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. + When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < boundary_timestep. + If `None`, only `transformer` is used for the entire denoising process. """ - model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer_2", "image_encoder", "image_processor"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - image_encoder: CLIPVisionModel, - image_processor: CLIPImageProcessor, transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, + image_processor: CLIPImageProcessor=None, + image_encoder: CLIPVisionModel=None, + transformer_2: WanTransformer3DModel=None, + boundary_ratio: Optional[float] = None, ): super().__init__() @@ -174,7 +186,9 @@ def __init__( transformer=transformer, scheduler=scheduler, image_processor=image_processor, + transformer_2=transformer_2, ) + self.register_to_config(boundary_ratio=boundary_ratio) self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -325,6 +339,7 @@ def check_inputs( negative_prompt_embeds=None, image_embeds=None, callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, ): if image is not None and image_embeds is not None: raise ValueError( @@ -368,6 +383,12 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if self.config.boundary_ratio is not None and image_embeds is not None: + raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.") + def prepare_latents( self, image: PipelineImageInput, @@ -483,6 +504,7 @@ def __call__( num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, 
num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -527,6 +549,9 @@ def __call__( of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None, + uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -589,6 +614,7 @@ def __call__( negative_prompt_embeds, image_embeds, callback_on_step_end_tensor_inputs, + guidance_scale_2, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -598,7 +624,12 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -631,13 +662,15 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - if image_embeds is None: - if last_image is None: - image_embeds = self.encode_image(image, device) - else: - image_embeds = self.encode_image([image, last_image], device) - image_embeds = image_embeds.repeat(batch_size, 1, 1) - image_embeds = image_embeds.to(transformer_dtype) + + if self.config.boundary_ratio is None: + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -668,16 +701,33 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( + noise_pred = current_model( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, @@ -687,7 +737,7 @@ def __call__( )[0] if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + noise_uncond = current_model( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, @@ -695,7 +745,7 @@ def __call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] From f5da83c4a3f174edec084f9dc5897e7a870ceb1f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 27 Jul 2025 08:42:56 +0200 Subject: [PATCH 02/12] add t2v + vae2.2 --- .../models/autoencoders/autoencoder_kl_wan.py | 386 ++++++++++++++++-- src/diffusers/pipelines/wan/pipeline_wan.py | 74 +++- 2 files changed, 404 insertions(+), 56 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 49cefcd8a142..8ecde415dc25 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -34,6 +34,104 @@ CACHE_T = 2 +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class 
DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1:, :, :] + return x + class WanCausalConv3d(nn.Conv3d): r""" A custom 3D causal convolution layer with feature caching support. @@ -134,19 +232,23 @@ class WanResample(nn.Module): - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. """ - def __init__(self, dim: int, mode: str) -> None: + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: super().__init__() self.dim = dim self.mode = mode + + # default to dim //2 + if upsample_out_dim is None: + upsample_out_dim = dim // 2 # layers if mode == "upsample2d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) ) elif mode == "upsample3d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) ) self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -363,6 +465,48 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): return x +class WanResidualDownBlock(nn.Module): + + def __init__(self, + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=False, + down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + resnets = [] + for _ in range(num_res_blocks): + resnets.append(WanResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + self.downsampler = WanResample(out_dim, mode=mode) + else: + self.downsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for resnet in self.resnets: + x = resnet(x, feat_cache, feat_idx) + if self.downsampler is not None: + x = self.downsampler(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + class WanEncoder3d(nn.Module): r""" A 3D encoder module. 
@@ -380,6 +524,7 @@ class WanEncoder3d(nn.Module): def __init__( self, + in_channels: int = 3, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], @@ -388,6 +533,7 @@ def __init__( temperal_downsample=[True, True, False], dropout=0.0, non_linearity: str = "silu", + is_residual: bool = False, # wan 2.2 vae use a residual downblock ): super().__init__() self.dim = dim @@ -403,23 +549,35 @@ def __init__( scale = 1.0 # init block - self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1) + self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1) # downsample blocks self.down_blocks = nn.ModuleList([]) for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): # residual (+attention) blocks - for _ in range(num_res_blocks): - self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - self.down_blocks.append(WanAttentionBlock(out_dim)) - in_dim = out_dim - - # downsample block - if i != len(dim_mult) - 1: - mode = "downsample3d" if temperal_downsample[i] else "downsample2d" - self.down_blocks.append(WanResample(out_dim, mode=mode)) - scale /= 2.0 + if is_residual: + self.down_blocks.append( + WanResidualDownBlock( + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, + down_flag=i != len(dim_mult) - 1, + ) + ) + else: + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode)) + scale /= 2.0 # middle blocks self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1) @@ -469,6 +627,92 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): x = self.conv_out(x) return x +class WanResidualUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + temperal_upsample (bool): Whether to upsample on temporal dimension + up_flag (bool): Whether to upsample or not + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + temperal_upsample: bool = False, + up_flag: bool = False, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2, + ) + else: + self.avg_shortcut = None + + # create residual blocks + resnets = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + if up_flag: + upsample_mode = "upsample3d" if temperal_upsample else "upsample2d" + self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim) + else: + self.upsampler = None + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + """ + Forward pass through the upsampling block. 
+ + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + x_copy = x.clone() + + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsampler is not None: + if feat_cache is not None: + x = self.upsampler(x, feat_cache, feat_idx) + else: + x = self.upsampler(x) + + if self.avg_shortcut is not None: + x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk) + + return x class WanUpBlock(nn.Module): """ @@ -513,7 +757,7 @@ def __init__( self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None): """ Forward pass through the upsampling block. @@ -564,6 +808,8 @@ def __init__( temperal_upsample=[False, True, True], dropout=0.0, non_linearity: str = "silu", + out_channels: int = 3, + is_residual: bool = False, ): super().__init__() self.dim = dim @@ -577,7 +823,6 @@ def __init__( # dimensions dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2 ** (len(dim_mult) - 2) # init block self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1) @@ -589,36 +834,47 @@ def __init__( self.up_blocks = nn.ModuleList([]) for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): # residual (+attention) blocks - if i > 0: + if i > 0 and not is_residual: + # wan vae 2.1 in_dim = in_dim // 2 - # Determine if we need upsampling + # determine if we need upsampling + up_flag = i != len(dim_mult) - 1 + # determine upsampling mode, if not upsampling, set to None upsample_mode = None - if i != len(dim_mult) - 1: - upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" - + if up_flag and temperal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" # Create and add the upsampling block - up_block = WanUpBlock( - in_dim=in_dim, - out_dim=out_dim, - num_res_blocks=num_res_blocks, - dropout=dropout, - upsample_mode=upsample_mode, - non_linearity=non_linearity, - ) + if is_residual: + up_block = WanResidualUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + temperal_upsample=temperal_upsample[i] if up_flag else False, + up_flag= up_flag, + non_linearity=non_linearity, + ) + else: + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) self.up_blocks.append(up_block) - # Update scale for next iteration - if upsample_mode is not None: - scale *= 2.0 - # output blocks self.norm_out = WanRMS_norm(out_dim, images=False) - self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1) + self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1) self.gradient_checkpointing = False - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): ## conv1 if feat_cache is not None: idx = feat_idx[0] @@ -637,7 +893,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): ## upsamples for up_block in self.up_blocks: - x = up_block(x, feat_cache, feat_idx) + x = up_block(x, feat_cache, feat_idx, first_chunk = first_chunk) ## head x = self.norm_out(x) @@ -656,6 +912,44 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): return x +# YiYi TODO: refactor this +from 
einops import rearrange + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. @@ -671,6 +965,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): def __init__( self, base_dim: int = 96, + decoder_base_dim: Optional[int] = None, z_dim: int = 16, dim_mult: Tuple[int] = [1, 2, 4, 4], num_res_blocks: int = 2, @@ -713,6 +1008,10 @@ def __init__( 2.8251, 1.9160, ], + is_residual: bool = False, + in_channels: int = 3, + out_channels: int = 3, + patch_size: Optional[int] = None, ) -> None: super().__init__() @@ -720,14 +1019,17 @@ def __init__( self.temperal_downsample = temperal_downsample self.temperal_upsample = temperal_downsample[::-1] + if decoder_base_dim is None: + decoder_base_dim = base_dim + self.encoder = WanEncoder3d( - base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual ) self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1) self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) self.decoder = WanDecoder3d( - base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual ) self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) @@ -827,6 +1129,8 @@ def _encode(self, x: torch.Tensor): return self.tiled_encode(x) self.clear_cache() + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) iter_ = 1 + (num_frame - 1) // 4 for i in range(iter_): self._enc_conv_idx = [0] @@ -884,12 +1188,14 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): for i in range(num_frame): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True) else: out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) out = torch.clamp(out, min=-1.0, max=1.0) + if self.config.patch_size is not None: + out = unpatchify(out, patch_size=self.config.patch_size) self.clear_cache() if not return_dict: return (out,) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index d14dac91f14a..748b20c11238 100644 --- 
a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -112,10 +112,21 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + transformer_2 ([`WanTransformer3DModel`], *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. + If provided, enables two-stage denoising where `transformer` handles high-noise stages + and `transformer_2` handles low-noise stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. + When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < boundary_timestep. + If `None`, only `transformer` is used for the entire denoising process. """ - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _optional_components = ["transformer_2"] + def __init__( self, @@ -124,6 +135,8 @@ def __init__( transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, + transformer_2: Optional[WanTransformer3DModel] = None, + boundary_ratio: Optional[float] = None, ): super().__init__() @@ -133,8 +146,9 @@ def __init__( tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, + transformer_2=transformer_2, ) - + self.register_to_config(boundary_ratio=boundary_ratio) self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -270,6 +284,7 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -301,6 +316,9 @@ def check_inputs( not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.config.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") def prepare_latents( self, @@ -369,6 +387,7 @@ def __call__( num_frames: int = 81, num_inference_steps: int = 50, guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -407,6 +426,9 @@ def __call__( of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None, + uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -461,6 +483,7 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + guidance_scale_2, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -470,7 +493,11 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + if self.config.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 self._attention_kwargs = attention_kwargs self._current_timestep = None self._interrupt = False @@ -524,34 +551,49 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) + if self.config.boundary_ratio is not None: + boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + latent_model_input = latents.to(transformer_dtype) timestep = t.expand(latents.shape[0]) - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( + #with current_model.cache_context("cond"): + noise_pred = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + #with current_model.cache_context("uncond"): + noise_uncond = current_model( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] - - if self.do_classifier_free_guidance: - with self.transformer.cache_context("uncond"): - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] From 95a55f9dc0c14db8b4d10798e3ff51b3d53f32a4 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 27 Jul 2025 08:50:32 +0200 Subject: [PATCH 03/12] add conversion script for vae 2.2 --- 
scripts/convert_wan_to_diffusers.py | 315 +++++++++++++++++++++++++++- 1 file changed, 314 insertions(+), 1 deletion(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index a69a1a140a78..6964cd09246a 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -560,6 +560,305 @@ def convert_vae(): vae.load_state_dict(new_state_dict, strict=True, assign=True) return vae +vae22_diffusers_config = { + "base_dim": 160, + "z_dim": 48, + "is_residual": True, + "in_channels": 12, + "out_channels": 12, + "decoder_base_dim": 256, + "latents_mean":[ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + -0.0174, + -0.1838, + -0.1557, + -0.1382, + -0.0542, + -0.2813, + -0.0891, + -0.1570, + -0.0098, + -0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + -0.5109, + -0.2665, + -0.2108, + -0.2158, + -0.2502, + -0.2055, + -0.0322, + -0.1109, + -0.1567, + -0.0729, + -0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + -0.0117, + -0.0723, + -0.2839, + -0.2083, + -0.0520, + -0.3748, + -0.0152, + -0.1957, + -0.1433, + -0.2944, + -0.3573, + -0.0548, + -0.1681, + -0.0667, + ], + "latents_std":[ + -0.4765, + -1.0364, + -0.4514, + -1.1677, + -0.5313, + -0.4990, + -0.4818, + -0.5013, + -0.8158, + -1.0344, + -0.5894, + -1.0901, + -0.6885, + -0.6165, + -0.8454, + -0.4978, + -0.5759, + -0.3523, + -0.7135, + -0.6804, + -0.5833, + -1.4146, + -0.8986, + -0.5659, + -0.7069, + -0.5338, + -0.4889, + -0.4917, + -0.4069, + -0.4999, + -0.6866, + -0.4093, + -0.5709, + -0.6065, + -0.6415, + -0.4944, + -0.5726, + -1.2042, + -0.5458, + -1.6887, + -0.3971, + -1.0600, + -0.3943, + -0.5537, + -0.5444, + -0.4089, + -0.7468, + -0.7744, + ], +} + + +def convert_vae_22(): + vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth") + old_state_dict = torch.load(vae_ckpt_path, weights_only=True) + new_state_dict = {} + + # Create mappings for specific components + middle_key_mapping = { + # Encoder middle block + "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + # Decoder middle block + "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": 
"decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + # Create a mapping for attention blocks + attention_mapping = { + # Encoder middle attention + "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + # Decoder middle attention + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + # Create a mapping for the head components + head_mapping = { + # Encoder head + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + # Decoder head + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + # Create a mapping for the quant components + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + # Process each key in the state dict + for key, value in old_state_dict.items(): + # Handle middle block keys using the mapping + if key in middle_key_mapping: + new_key = middle_key_mapping[key] + new_state_dict[new_key] = value + # Handle attention blocks using the mapping + elif key in attention_mapping: + new_key = attention_mapping[key] + new_state_dict[new_key] = value + # Handle head keys using the mapping + elif key in head_mapping: + new_key = head_mapping[key] + new_state_dict[new_key] = value + # Handle quant keys using the mapping + elif key in quant_mapping: + new_key = quant_mapping[key] + new_state_dict[new_key] = value + # Handle encoder conv1 + elif key == "encoder.conv1.weight": + new_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + new_state_dict["encoder.conv_in.bias"] = value + # Handle decoder conv1 + elif key == "decoder.conv1.weight": + new_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + new_state_dict["decoder.conv_in.bias"] = value + # Handle encoder downsamples + elif key.startswith("encoder.downsamples."): + # Change encoder.downsamples to encoder.down_blocks + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + + # Handle residual blocks - change downsamples to resnets and rename components + if "residual" in new_key or "shortcut" in new_key: + # Change the second downsamples to resnets + new_key = new_key.replace(".downsamples.", ".resnets.") + + # Rename residual components + if 
".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + + # Handle resample blocks - change downsamples to downsampler and remove index + elif "resample" in new_key or "time_conv" in new_key: + # Change the second downsamples to downsampler and remove the index + parts = new_key.split(".") + # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample... + # We want to change it to: encoder.down_blocks.X.downsampler.resample... + if len(parts) >= 4 and parts[3] == "downsamples": + # Remove the index (parts[4]) and change downsamples to downsampler + new_parts = parts[:3] + ["downsampler"] + parts[5:] + new_key = ".".join(new_parts) + + new_state_dict[new_key] = value + + # Handle decoder upsamples + elif key.startswith("decoder.upsamples."): + # Change decoder.upsamples to decoder.up_blocks + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + + # Handle residual blocks - change upsamples to resnets and rename components + if "residual" in new_key or "shortcut" in new_key: + # Change the second upsamples to resnets + new_key = new_key.replace(".upsamples.", ".resnets.") + + # Rename residual components + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + + # Handle resample blocks - change upsamples to upsampler and remove index + elif "resample" in new_key or "time_conv" in new_key: + # Change the second upsamples to upsampler and remove the index + parts = new_key.split(".") + # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample... + # We want to change it to: encoder.down_blocks.X.downsampler.resample... 
+ if len(parts) >= 4 and parts[3] == "upsamples": + # Remove the index (parts[4]) and change upsamples to upsampler + new_parts = parts[:3] + ["upsampler"] + parts[5:] + new_key = ".".join(new_parts) + + new_state_dict[new_key] = value + else: + # Keep other keys unchanged + new_state_dict[key] = value + + with init_empty_weights(): + vae = AutoencoderKLWan(**vae22_config) + vae.load_state_dict(new_state_dict, strict=True, assign=True) + return vae + def get_args(): parser = argparse.ArgumentParser() @@ -586,7 +885,11 @@ def get_args(): transformer = convert_transformer(args.model_type) transformer_2 = None - vae = convert_vae() + if "Wan2.2" in args.model_type and "TI2V" in args.model_type: + vae = convert_vae_22() + else: + vae = convert_vae() + text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0 @@ -609,6 +912,16 @@ def get_args(): scheduler=scheduler, boundary_ratio=0.9, ) + elif "Wan2.2" and "T2V" in args.model_type: + pipe = WanPipeline( + transformer=transformer, + transformer_2=transformer_2, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + boundary_ratio=0.875, + ) elif "I2V" in args.model_type or "FLF2V" in args.model_type: image_encoder = CLIPVisionModelWithProjection.from_pretrained( "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 From bf2c6e0a0846b7e7334015e5923d30745e5a7b3b Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Jul 2025 12:22:09 +0200 Subject: [PATCH 04/12] add --- scripts/convert_wan_to_diffusers.py | 196 +++++++++++------- .../models/autoencoders/autoencoder_kl_wan.py | 6 +- .../models/transformers/transformer_wan.py | 51 ++++- src/diffusers/pipelines/wan/pipeline_wan.py | 43 ++-- 4 files changed, 193 insertions(+), 103 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 6964cd09246a..0a46ae80f097 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -320,7 +320,27 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: } RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP - return config, RENAME_DICT, SPECIAL_KEYS_REMAP + elif model_type == "Wan2.2-TI2V-5B": + config = { + "model_id": "Wan-AI/Wan2.2-TI2V-5B", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_channels": 48, + "num_attention_heads": 24, + "num_layers": 30, + "out_channels": 48, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT + SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP return config, RENAME_DICT, SPECIAL_KEYS_REMAP @@ -567,106 +587,110 @@ def convert_vae(): "in_channels": 12, "out_channels": 12, "decoder_base_dim": 256, + "scale_factor_temporal": 4, + "scale_factor_spatial": 16, + "patch_size": 2, "latents_mean":[ - -0.2289, - -0.0052, - -0.1323, - -0.2339, + -0.2289, + -0.0052, + -0.1323, + -0.2339, -0.2799, - -0.0174, - -0.1838, - -0.1557, + 0.0174, + 0.1838, + 0.1557, -0.1382, - -0.0542, - -0.2813, - -0.0891, - -0.1570, + 0.0542, + 0.2813, + 0.0891, + 0.1570, -0.0098, - -0.0375, + 0.0375, -0.1825, -0.2246, -0.1207, -0.0698, - -0.5109, - -0.2665, + 0.5109, + 0.2665, -0.2108, -0.2158, - -0.2502, + 
0.2502, -0.2055, -0.0322, - -0.1109, - -0.1567, + 0.1109, + 0.1567, -0.0729, - -0.0899, + 0.0899, -0.2799, -0.1230, -0.0313, -0.1649, - -0.0117, - -0.0723, + 0.0117, + 0.0723, -0.2839, -0.2083, -0.0520, - -0.3748, - -0.0152, - -0.1957, - -0.1433, + 0.3748, + 0.0152, + 0.1957, + 0.1433, -0.2944, - -0.3573, + 0.3573, -0.0548, -0.1681, -0.0667, ], - "latents_std":[ - -0.4765, - -1.0364, - -0.4514, - -1.1677, - -0.5313, - -0.4990, - -0.4818, - -0.5013, - -0.8158, - -1.0344, - -0.5894, - -1.0901, - -0.6885, - -0.6165, - -0.8454, - -0.4978, - -0.5759, - -0.3523, - -0.7135, - -0.6804, - -0.5833, - -1.4146, - -0.8986, - -0.5659, - -0.7069, - -0.5338, - -0.4889, - -0.4917, - -0.4069, - -0.4999, - -0.6866, - -0.4093, - -0.5709, - -0.6065, - -0.6415, - -0.4944, - -0.5726, - -1.2042, - -0.5458, - -1.6887, - -0.3971, - -1.0600, - -0.3943, - -0.5537, - -0.5444, - -0.4089, - -0.7468, - -0.7744, + "latents_std": [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, ], + "clip_output": False, } @@ -855,7 +879,7 @@ def convert_vae_22(): new_state_dict[key] = value with init_empty_weights(): - vae = AutoencoderKLWan(**vae22_config) + vae = AutoencoderKLWan(**vae22_diffusers_config) vae.load_state_dict(new_state_dict, strict=True, assign=True) return vae @@ -878,7 +902,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - if "Wan2.2" in args.model_type: + if "Wan2.2" in args.model_type and "TI2V" not in args.model_type: transformer = convert_transformer(args.model_type, stage="high_noise_model") transformer_2 = convert_transformer(args.model_type, stage="low_noise_model") else: @@ -892,7 +916,12 @@ def get_args(): text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16) tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") - flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0 + if "FLF2V" in args.model_type: + flow_shift = 16.0 + elif "TI2V" in args.model_type: + flow_shift = 5.0 + else: + flow_shift = 3.0 scheduler = UniPCMultistepScheduler( prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift ) @@ -902,7 +931,7 @@ def get_args(): dtype = DTYPE_MAPPING[args.dtype] transformer.to(dtype) - if "Wan2.2" and "I2V" in args.model_type: + if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type: pipe = WanImageToVideoPipeline( transformer=transformer, transformer_2=transformer_2, @@ -922,6 +951,15 @@ def get_args(): scheduler=scheduler, boundary_ratio=0.875, ) + elif "Wan2.2" and "TI2V" in args.model_type: + pipe = WanPipeline( + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + expand_timesteps=True, + ) elif "I2V" in args.model_type or "FLF2V" in args.model_type: image_encoder = CLIPVisionModelWithProjection.from_pretrained( "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 8ecde415dc25..0b1acadf9961 100644 --- 
a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1012,6 +1012,9 @@ def __init__( in_channels: int = 3, out_channels: int = 3, patch_size: Optional[int] = None, + scale_factor_temporal: Optional[int] = 4, + scale_factor_spatial: Optional[int] = 8, + clip_output: bool = True, ) -> None: super().__init__() @@ -1193,7 +1196,8 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) - out = torch.clamp(out, min=-1.0, max=1.0) + if self.config.clip_output: + out = torch.clamp(out, min=-1.0, max=1.0) if self.config.patch_size is not None: out = unpatchify(out, patch_size=self.config.patch_size) self.clear_cache() diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index bdb9201e62cf..621c8e31cd93 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -170,8 +170,11 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, - ): + timestep_seq_len: Optional[int] = None, + ): timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (1, timestep_seq_len)) time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: @@ -309,9 +312,24 @@ def forward( temb: torch.Tensor, rotary_emb: torch.Tensor, ) -> torch.Tensor: - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table + temb.float() - ).chunk(6, dim=1) + + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) # 1. 
Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) @@ -469,10 +487,22 @@ def forward( hidden_states = self.patch_embedding(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) + if timestep.ndim == 2: + ts_seq_len = timestep.shape[1] + timestep = timestep.flatten() # batch_size * seq_len + else: + ts_seq_len = None + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( - timestep, encoder_hidden_states, encoder_hidden_states_image + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len ) - timestep_proj = timestep_proj.unflatten(1, (6, -1)) + if ts_seq_len is not None: + # batch_size, seq_len, 6, inner_dim + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + else: + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) if encoder_hidden_states_image is not None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) @@ -488,7 +518,14 @@ def forward( hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) # 5. Output norm, projection & unpatchify - shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + if temb.ndim ==3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) # Move the shift and scale tensors to the same device as hidden_states. # When using multi-GPU inference via accelerate these will be on the diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 748b20c11238..28bf37e39462 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -137,6 +137,7 @@ def __init__( scheduler: FlowMatchEulerDiscreteScheduler, transformer_2: Optional[WanTransformer3DModel] = None, boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, # Wan2.2 ti2v ): super().__init__() @@ -149,8 +150,9 @@ def __init__( transformer_2=transformer_2, ) self.register_to_config(boundary_ratio=boundary_ratio) - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.register_to_config(expand_timesteps=expand_timesteps) + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _get_t5_prompt_embeds( @@ -547,6 +549,9 @@ def __call__( latents, ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + # 6. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -573,26 +578,32 @@ def __call__( current_guidance_scale = guidance_scale_2 latent_model_input = latents.to(transformer_dtype) - timestep = t.expand(latents.shape[0]) - - #with current_model.cache_context("cond"): - noise_pred = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + if self.config.expand_timesteps: + # seq_len: num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + # batch_size, seq_len + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) - if self.do_classifier_free_guidance: - #with current_model.cache_context("uncond"): - noise_uncond = current_model( + with current_model.cache_context("cond"): + noise_pred = current_model( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 From 27ce75b9847d75b4febba15242daf9eae476b049 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Jul 2025 12:28:23 +0200 Subject: [PATCH 05/12] add 5b t2v --- .../models/autoencoders/autoencoder_kl_wan.py | 23 ++++++++++--------- .../models/transformers/transformer_wan.py | 4 ++-- src/diffusers/pipelines/wan/pipeline_wan.py | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 0b1acadf9961..5ff969c5ee42 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -236,7 +236,7 @@ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: super().__init__() self.dim = dim self.mode = mode - + # default to dim //2 if upsample_out_dim is None: upsample_out_dim = dim // 2 @@ -524,7 +524,7 @@ class WanEncoder3d(nn.Module): def __init__( self, - in_channels: int = 3, + in_channels: int = 3, dim=128, z_dim=4, dim_mult=[1, 2, 4, 4], @@ -558,10 +558,10 @@ def __init__( if is_residual: self.down_blocks.append( WanResidualDownBlock( - in_dim, - out_dim, - dropout, - num_res_blocks, + in_dim, + out_dim, + dropout, + num_res_blocks, temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, down_flag=i != len(dim_mult) - 1, ) @@ -708,10 +708,10 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): x = self.upsampler(x, feat_cache, feat_idx) else: x = self.upsampler(x) - + if self.avg_shortcut is not None: x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk) - + return x class WanUpBlock(nn.Module): @@ -912,10 +912,9 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): return x -# YiYi TODO: refactor this -from einops import rearrange - def patchify(x, patch_size): + # YiYi TODO: refactor this + 
from einops import rearrange if patch_size == 1: return x if x.dim() == 4: @@ -935,6 +934,8 @@ def patchify(x, patch_size): def unpatchify(x, patch_size): + # YiYi TODO: refactor this + from einops import rearrange if patch_size == 1: return x diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 621c8e31cd93..7352a8a21a98 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -171,7 +171,7 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, timestep_seq_len: Optional[int] = None, - ): + ): timestep = self.timesteps_proj(timestep) if timestep_seq_len is not None: timestep = timestep.unflatten(0, (1, timestep_seq_len)) @@ -518,7 +518,7 @@ def forward( hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) # 5. Output norm, projection & unpatchify - if temb.ndim ==3: + if temb.ndim ==3: # batch_size, seq_len, inner_dim (wan 2.2 ti2v) shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) shift = shift.squeeze(2) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 28bf37e39462..65e4c1c344f9 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -318,7 +318,7 @@ def check_inputs( not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - + if self.config.boundary_ratio is None and guidance_scale_2 is not None: raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") From 5709f7e04d29569f5ee19f7aa60d185699c7e207 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Jul 2025 12:28:42 +0200 Subject: [PATCH 06/12] conversion script --- scripts/convert_wan_to_diffusers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 0a46ae80f097..e6a09e0d980f 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -764,7 +764,7 @@ def convert_vae_22(): "conv2.weight": "post_quant_conv.weight", "conv2.bias": "post_quant_conv.bias", } - + # Process each key in the state dict for key, value in old_state_dict.items(): # Handle middle block keys using the mapping @@ -797,12 +797,12 @@ def convert_vae_22(): elif key.startswith("encoder.downsamples."): # Change encoder.downsamples to encoder.down_blocks new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") - + # Handle residual blocks - change downsamples to resnets and rename components if "residual" in new_key or "shortcut" in new_key: # Change the second downsamples to resnets new_key = new_key.replace(".downsamples.", ".resnets.") - + # Rename residual components if ".residual.0.gamma" in new_key: new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") @@ -820,7 +820,7 @@ def convert_vae_22(): new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") elif ".shortcut.bias" in new_key: new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") - + # Handle resample blocks - change downsamples to downsampler and remove index elif "resample" in new_key or "time_conv" in new_key: # Change the second downsamples to 
downsampler and remove the index @@ -831,19 +831,19 @@ def convert_vae_22(): # Remove the index (parts[4]) and change downsamples to downsampler new_parts = parts[:3] + ["downsampler"] + parts[5:] new_key = ".".join(new_parts) - + new_state_dict[new_key] = value # Handle decoder upsamples elif key.startswith("decoder.upsamples."): # Change decoder.upsamples to decoder.up_blocks new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - + # Handle residual blocks - change upsamples to resnets and rename components if "residual" in new_key or "shortcut" in new_key: # Change the second upsamples to resnets new_key = new_key.replace(".upsamples.", ".resnets.") - + # Rename residual components if ".residual.0.gamma" in new_key: new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") @@ -861,7 +861,7 @@ def convert_vae_22(): new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") elif ".shortcut.bias" in new_key: new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") - + # Handle resample blocks - change upsamples to upsampler and remove index elif "resample" in new_key or "time_conv" in new_key: # Change the second upsamples to upsampler and remove the index @@ -872,7 +872,7 @@ def convert_vae_22(): # Remove the index (parts[4]) and change upsamples to upsampler new_parts = parts[:3] + ["upsampler"] + parts[5:] new_key = ".".join(new_parts) - + new_state_dict[new_key] = value else: # Keep other keys unchanged From 0ab2e4fd9156652083e445231df3eb0cdd90410d Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Jul 2025 20:07:31 +0200 Subject: [PATCH 07/12] refactor out reearrange --- scripts/convert_wan_to_diffusers.py | 217 +++++++++--------- .../models/autoencoders/autoencoder_kl_wan.py | 129 +++++++---- .../models/transformers/transformer_wan.py | 5 +- src/diffusers/pipelines/wan/pipeline_wan.py | 21 +- .../pipelines/wan/pipeline_wan_i2v.py | 27 +-- 5 files changed, 221 insertions(+), 178 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index e6a09e0d980f..599c90be57ce 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -344,7 +344,7 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: return config, RENAME_DICT, SPECIAL_KEYS_REMAP -def convert_transformer(model_type: str, stage: str=None): +def convert_transformer(model_type: str, stage: str = None): config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type) diffusers_config = config["diffusers_config"] @@ -580,115 +580,116 @@ def convert_vae(): vae.load_state_dict(new_state_dict, strict=True, assign=True) return vae + vae22_diffusers_config = { - "base_dim": 160, - "z_dim": 48, - "is_residual": True, - "in_channels": 12, - "out_channels": 12, - "decoder_base_dim": 256, - "scale_factor_temporal": 4, - "scale_factor_spatial": 16, - "patch_size": 2, - "latents_mean":[ - -0.2289, - -0.0052, - -0.1323, - -0.2339, - -0.2799, - 0.0174, - 0.1838, - 0.1557, - -0.1382, - 0.0542, - 0.2813, - 0.0891, - 0.1570, - -0.0098, - 0.0375, - -0.1825, - -0.2246, - -0.1207, - -0.0698, - 0.5109, - 0.2665, - -0.2108, - -0.2158, - 0.2502, - -0.2055, - -0.0322, - 0.1109, - 0.1567, - -0.0729, - 0.0899, - -0.2799, - -0.1230, - -0.0313, - -0.1649, - 0.0117, - 0.0723, - -0.2839, - -0.2083, - -0.0520, - 0.3748, - 0.0152, - 0.1957, - 0.1433, - -0.2944, - 0.3573, - -0.0548, - -0.1681, - -0.0667, + "base_dim": 160, + "z_dim": 48, + "is_residual": True, + "in_channels": 12, + "out_channels": 12, + 
"decoder_base_dim": 256, + "scale_factor_temporal": 4, + "scale_factor_spatial": 16, + "patch_size": 2, + "latents_mean": [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, ], - "latents_std": [ - 0.4765, - 1.0364, - 0.4514, - 1.1677, - 0.5313, - 0.4990, - 0.4818, - 0.5013, - 0.8158, - 1.0344, - 0.5894, - 1.0901, - 0.6885, - 0.6165, - 0.8454, - 0.4978, - 0.5759, - 0.3523, - 0.7135, - 0.6804, - 0.5833, - 1.4146, - 0.8986, - 0.5659, - 0.7069, - 0.5338, - 0.4889, - 0.4917, - 0.4069, - 0.4999, - 0.6866, - 0.4093, - 0.5709, - 0.6065, - 0.6415, - 0.4944, - 0.5726, - 1.2042, - 0.5458, - 1.6887, - 0.3971, - 1.0600, - 0.3943, - 0.5537, - 0.5444, - 0.4089, - 0.7468, - 0.7744, + "latents_std": [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, + 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, ], "clip_output": False, } diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index 5ff969c5ee42..608de25da598 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -35,7 +35,6 @@ class AvgDown3D(nn.Module): - def __init__( self, in_channels, @@ -89,7 +88,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DupUp3D(nn.Module): - def __init__( self, in_channels: int, @@ -129,9 +127,10 @@ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: x.size(6) * self.factor_s, ) if first_chunk: - x = x[:, :, self.factor_t - 1:, :, :] + x = x[:, :, self.factor_t - 1 :, :, :] return x + class WanCausalConv3d(nn.Conv3d): r""" A custom 3D causal convolution layer with feature caching support. 
@@ -244,11 +243,13 @@ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: # layers if mode == "upsample2d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), ) elif mode == "upsample3d": self.resample = nn.Sequential( - WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, upsample_out_dim, 3, padding=1) + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), ) self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) @@ -466,14 +467,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): class WanResidualDownBlock(nn.Module): - - def __init__(self, - in_dim, - out_dim, - dropout, - num_res_blocks, - temperal_downsample=False, - down_flag=False): + def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False): super().__init__() # Shortcut path with downsample @@ -507,6 +501,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): return x + self.avg_shortcut(x_copy) + class WanEncoder3d(nn.Module): r""" A 3D encoder module. @@ -533,7 +528,7 @@ def __init__( temperal_downsample=[True, True, False], dropout=0.0, non_linearity: str = "silu", - is_residual: bool = False, # wan 2.2 vae use a residual downblock + is_residual: bool = False, # wan 2.2 vae use a residual downblock ): super().__init__() self.dim = dim @@ -564,8 +559,8 @@ def __init__( num_res_blocks, temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, down_flag=i != len(dim_mult) - 1, - ) - ) + ) + ) else: for _ in range(num_res_blocks): self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) @@ -627,6 +622,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): x = self.conv_out(x) return x + class WanResidualUpBlock(nn.Module): """ A block that handles upsampling for the WanVAE decoder. @@ -714,6 +710,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): return x + class WanUpBlock(nn.Module): """ A block that handles upsampling for the WanVAE decoder. 
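All of the new residual blocks above share one shape: a main path (resnets plus an optional resampler) summed with an averaged shortcut of the block input, i.e. `return x + self.avg_shortcut(x_copy)`. A toy stand-in for that pattern, where plain `Conv3d`/`AvgPool3d` layers take the place of `WanResidualBlock`, `WanResample` and `AvgDown3D` (so the layer choices and dims here are placeholders, not the real configuration):

import torch
import torch.nn as nn

class ToyResidualDownBlock(nn.Module):
    # Illustrative stand-in for the Wan 2.2 residual down block pattern.
    def __init__(self, in_dim: int, out_dim: int, spatial_down: bool = True):
        super().__init__()
        stride = (1, 2, 2) if spatial_down else (1, 1, 1)
        self.resnets = nn.Sequential(
            nn.Conv3d(in_dim, out_dim, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv3d(out_dim, out_dim, kernel_size=3, padding=1),
        )
        self.downsampler = nn.Conv3d(out_dim, out_dim, kernel_size=3, stride=stride, padding=1)
        # shortcut path: pool + 1x1 conv so channels and resolution line up with the main path
        self.avg_shortcut = nn.Sequential(
            nn.AvgPool3d(kernel_size=stride, stride=stride),
            nn.Conv3d(in_dim, out_dim, kernel_size=1),
        )

    def forward(self, x):
        shortcut = self.avg_shortcut(x)
        x = self.downsampler(self.resnets(x))
        # the defining trait of the Wan 2.2 VAE blocks: main path + averaged shortcut
        return x + shortcut

# quick shape check
block = ToyResidualDownBlock(4, 8)
print(block(torch.randn(1, 4, 5, 16, 16)).shape)  # torch.Size([1, 8, 5, 8, 8])

`WanResidualUpBlock` mirrors the same structure with upsampling instead of pooling on both paths (plus the `first_chunk` argument threaded through to the shortcut).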
@@ -854,7 +851,7 @@ def __init__( num_res_blocks=num_res_blocks, dropout=dropout, temperal_upsample=temperal_upsample[i] if up_flag else False, - up_flag= up_flag, + up_flag=up_flag, non_linearity=non_linearity, ) else: @@ -893,7 +890,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): ## upsamples for up_block in self.up_blocks: - x = up_block(x, feat_cache, feat_idx, first_chunk = first_chunk) + x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk) ## head x = self.norm_out(x) @@ -913,20 +910,39 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): def patchify(x, patch_size): - # YiYi TODO: refactor this - from einops import rearrange if patch_size == 1: return x + if x.dim() == 4: - x = rearrange( - x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + # x shape: [batch_size, channels, height, width] + batch_size, channels, height, width = x.shape + + # Ensure height and width are divisible by patch_size + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})") + + # Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size] + x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size) + + # Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size] + x = x.permute(0, 1, 3, 5, 2, 4).contiguous() + x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size) + elif x.dim() == 5: - x = rearrange( - x, - "b c f (h q) (w r) -> b (c r q) f h w", - q=patch_size, - r=patch_size, - ) + # x shape: [batch_size, channels, frames, height, width] + batch_size, channels, frames, height, width = x.shape + + # Ensure height and width are divisible by patch_size + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})") + + # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size] + x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size) + + # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size] + x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous() + x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size) + else: raise ValueError(f"Invalid input shape: {x.shape}") @@ -934,23 +950,36 @@ def patchify(x, patch_size): def unpatchify(x, patch_size): - # YiYi TODO: refactor this - from einops import rearrange if patch_size == 1: return x if x.dim() == 4: - x = rearrange( - x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + # x shape: [b, (c * patch_size * patch_size), h, w] + batch_size, c_patches, height, width = x.shape + channels = c_patches // (patch_size * patch_size) + + # Reshape to [b, c, patch_size, patch_size, h, w] + x = x.view(batch_size, channels, patch_size, patch_size, height, width) + + # Rearrange to [b, c, h * patch_size, w * patch_size] + x = x.permute(0, 1, 4, 2, 5, 3).contiguous() + x = x.view(batch_size, channels, height * patch_size, width * patch_size) + elif x.dim() == 5: - x = rearrange( - x, - "b (c r q) f h w -> b c f (h q) (w r)", - q=patch_size, - r=patch_size, - ) + # x shape: [batch_size, (channels * patch_size * 
patch_size), frame, height, width] + batch_size, c_patches, frames, height, width = x.shape + channels = c_patches // (patch_size * patch_size) + + # Reshape to [b, c, patch_size, patch_size, f, h, w] + x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width) + + # Rearrange to [b, c, f, h * patch_size, w * patch_size] + x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous() + x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size) + return x + class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. @@ -1027,13 +1056,29 @@ def __init__( decoder_base_dim = base_dim self.encoder = WanEncoder3d( - in_channels=in_channels, dim=base_dim, z_dim=z_dim * 2, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_downsample=temperal_downsample, dropout=dropout, is_residual=is_residual + in_channels=in_channels, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, + is_residual=is_residual, ) self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1) self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) self.decoder = WanDecoder3d( - dim=decoder_base_dim, z_dim=z_dim, dim_mult=dim_mult, num_res_blocks=num_res_blocks, attn_scales=attn_scales, temperal_upsample=self.temperal_upsample, dropout=dropout, out_channels=out_channels, is_residual=is_residual + dim=decoder_base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temperal_upsample, + dropout=dropout, + out_channels=out_channels, + is_residual=is_residual, ) self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) @@ -1192,7 +1237,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True): for i in range(num_frame): self._conv_idx = [0] if i == 0: - out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True) + out = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True + ) else: out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) out = torch.cat([out, out_], 2) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 7352a8a21a98..eddf196718f8 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -312,7 +312,6 @@ def forward( temb: torch.Tensor, rotary_emb: torch.Tensor, ) -> torch.Tensor: - if temb.ndim == 4: # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( @@ -490,7 +489,7 @@ def forward( # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) if timestep.ndim == 2: ts_seq_len = timestep.shape[1] - timestep = timestep.flatten() # batch_size * seq_len + timestep = timestep.flatten() # batch_size * seq_len else: ts_seq_len = None @@ -518,7 +517,7 @@ def forward( hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) # 5. 
Output norm, projection & unpatchify - if temb.ndim ==3: + if temb.ndim == 3: # batch_size, seq_len, inner_dim (wan 2.2 ti2v) shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) shift = shift.squeeze(2) diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py index 65e4c1c344f9..f52bf33d810b 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan.py +++ b/src/diffusers/pipelines/wan/pipeline_wan.py @@ -113,21 +113,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin): vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. transformer_2 ([`WanTransformer3DModel`], *optional*): - Conditional Transformer to denoise the input latents during the low-noise stage. - If provided, enables two-stage denoising where `transformer` handles high-noise stages - and `transformer_2` handles low-noise stages. If not provided, only `transformer` is used. + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. boundary_ratio (`float`, *optional*, defaults to `None`): Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. - The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. - When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < boundary_timestep. - If `None`, only `transformer` is used for the entire denoising process. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. """ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] _optional_components = ["transformer_2"] - def __init__( self, tokenizer: AutoTokenizer, @@ -137,7 +136,7 @@ def __init__( scheduler: FlowMatchEulerDiscreteScheduler, transformer_2: Optional[WanTransformer3DModel] = None, boundary_ratio: Optional[float] = None, - expand_timesteps: bool = False, # Wan2.2 ti2v + expand_timesteps: bool = False, # Wan2.2 ti2v ): super().__init__() @@ -429,8 +428,9 @@ def __call__( `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. guidance_scale_2 (`float`, *optional*, defaults to `None`): - Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None, - uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None. + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -549,7 +549,6 @@ def __call__( latents, ) - mask = torch.ones(latents.shape, dtype=torch.float32, device=device) # 6. Denoising loop diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index c2299a2e4650..b075cf5ba014 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -150,14 +150,14 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. transformer_2 ([`WanTransformer3DModel`], *optional*): - Conditional Transformer to denoise the input latents during the low-noise stage. - In two-stage denoising, `transformer` handles high-noise stages - and `transformer_2` handles low-noise stages. If not provided, only `transformer` is used. + Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising, + `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only + `transformer` is used. boundary_ratio (`float`, *optional*, defaults to `None`): Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. - The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. - When provided, `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < boundary_timestep. - If `None`, only `transformer` is used for the entire denoising process. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. """ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae" @@ -171,9 +171,9 @@ def __init__( transformer: WanTransformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchEulerDiscreteScheduler, - image_processor: CLIPImageProcessor=None, - image_encoder: CLIPVisionModel=None, - transformer_2: WanTransformer3DModel=None, + image_processor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModel = None, + transformer_2: WanTransformer3DModel = None, boundary_ratio: Optional[float] = None, ): super().__init__() @@ -550,8 +550,9 @@ def __call__( `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. guidance_scale_2 (`float`, *optional*, defaults to `None`): - Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's `boundary_ratio` is not None, - uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None. + Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's + `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` + and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. 
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -624,7 +625,6 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) - if self.config.boundary_ratio is not None and guidance_scale_2 is None: guidance_scale_2 = guidance_scale @@ -662,7 +662,6 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - if self.config.boundary_ratio is None: if image_embeds is None: if last_image is None: @@ -706,7 +705,6 @@ def __call__( else: boundary_timestep = None - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: @@ -723,7 +721,6 @@ def __call__( current_model = self.transformer_2 current_guidance_scale = guidance_scale_2 - latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) From e78a4fa30ab9df012949fc390816b0bed5045fac Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Jul 2025 20:09:10 +0200 Subject: [PATCH 08/12] remove a copied from in skyreels --- src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py | 1 - src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index e742f4419893..8562a5eaf0e6 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -275,7 +275,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs def check_inputs( self, prompt, diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index 12bf727cae63..12be5efeccb2 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -316,7 +316,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs def check_inputs( self, prompt, From 97675c703679584635d08dc29b4ecdfddf57c0a7 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 28 Jul 2025 08:13:16 -1000 Subject: [PATCH 09/12] Apply suggestions from code review Co-authored-by: bagheera <59658056+bghira@users.noreply.github.com> --- src/diffusers/models/transformers/transformer_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index eddf196718f8..2c8b76912b34 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -325,7 +325,7 @@ def forward( c_scale_msa = c_scale_msa.squeeze(2) c_gate_msa = c_gate_msa.squeeze(2) else: - # temb: batch_size, 6, inner_dim + # temb: batch_size, 6, inner_dim (wan2.1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( self.scale_shift_table + temb.float() ).chunk(6, dim=1) From 6bb1677bd46a45d05583c6cda9050dd67dae4adf Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 28 Jul 2025 08:13:59 -1000 Subject: [PATCH 10/12] Update src/diffusers/models/transformers/transformer_wan.py --- 
src/diffusers/models/transformers/transformer_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index 2c8b76912b34..b6c01c13c12f 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -325,7 +325,7 @@ def forward( c_scale_msa = c_scale_msa.squeeze(2) c_gate_msa = c_gate_msa.squeeze(2) else: - # temb: batch_size, 6, inner_dim (wan2.1) + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( self.scale_shift_table + temb.float() ).chunk(6, dim=1) From 1ff7c99596599cfc2df9824c5428074e59dff114 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Jul 2025 21:48:19 +0200 Subject: [PATCH 11/12] fix fast tests --- tests/pipelines/wan/test_wan.py | 17 +++++++ .../pipelines/wan/test_wan_image_to_video.py | 45 +++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py index fdb2d298356e..a7e4e27813b3 100644 --- a/tests/pipelines/wan/test_wan.py +++ b/tests/pipelines/wan/test_wan.py @@ -85,12 +85,29 @@ def get_dummy_components(self): rope_max_seq_len=32, ) + torch.manual_seed(0) + transformer_2 = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + components = { "transformer": transformer, "vae": vae, "scheduler": scheduler, "text_encoder": text_encoder, "tokenizer": tokenizer, + "transformer_2": transformer_2, } return components diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py index 6edc0cc882f7..5fb913c2da2e 100644 --- a/tests/pipelines/wan/test_wan_image_to_video.py +++ b/tests/pipelines/wan/test_wan_image_to_video.py @@ -86,6 +86,23 @@ def get_dummy_components(self): image_dim=4, ) + torch.manual_seed(0) + transformer_2 = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + image_dim=4, + ) + torch.manual_seed(0) image_encoder_config = CLIPVisionConfig( hidden_size=4, @@ -109,6 +126,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "image_encoder": image_encoder, "image_processor": image_processor, + "transformer_2": transformer_2, } return components @@ -164,6 +182,10 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): pass + @unittest.skip("TODO: refactor this test: one component can be optional for certain checkpoints but not for others") + def test_save_load_optional_components(self): + pass + class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = WanImageToVideoPipeline @@ -218,6 +240,24 @@ def get_dummy_components(self): pos_embed_seq_len=2 * (4 * 4 + 1), ) + torch.manual_seed(0) + transformer_2 = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + 
qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + image_dim=4, + pos_embed_seq_len=2 * (4 * 4 + 1), + ) + torch.manual_seed(0) image_encoder_config = CLIPVisionConfig( hidden_size=4, @@ -241,6 +281,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "image_encoder": image_encoder, "image_processor": image_processor, + "transformer_2": transformer_2, } return components @@ -297,3 +338,7 @@ def test_attention_slicing_forward_pass(self): @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass") def test_inference_batch_single_identical(self): pass + + @unittest.skip("TODO: refactor this test: one component can be optional for certain checkpoints but not for others") + def test_save_load_optional_components(self): + pass From 4fd0333b17b35d37b96d88320bca36db23685283 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Jul 2025 21:57:52 +0200 Subject: [PATCH 12/12] style --- tests/pipelines/wan/test_wan_image_to_video.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py index 5fb913c2da2e..c693f4fcb247 100644 --- a/tests/pipelines/wan/test_wan_image_to_video.py +++ b/tests/pipelines/wan/test_wan_image_to_video.py @@ -182,7 +182,9 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): pass - @unittest.skip("TODO: refactor this test: one component can be optional for certain checkpoints but not for others") + @unittest.skip( + "TODO: refactor this test: one component can be optional for certain checkpoints but not for others" + ) def test_save_load_optional_components(self): pass @@ -339,6 +341,8 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): pass - @unittest.skip("TODO: refactor this test: one component can be optional for certain checkpoints but not for others") + @unittest.skip( + "TODO: refactor this test: one component can be optional for certain checkpoints but not for others" + ) def test_save_load_optional_components(self): pass