From 84e030cbf342bad4d7acfc4cb4d6313d4f35a25c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 7 Nov 2025 09:20:13 +0100 Subject: [PATCH 01/12] update, remove intermediaate_inputs --- .../modular_pipelines/wan/before_denoise.py | 67 +++------ .../modular_pipelines/wan/decoders.py | 35 ++--- .../modular_pipelines/wan/denoise.py | 127 ++++++++++++------ .../modular_pipelines/wan/encoders.py | 124 ++++++++--------- .../modular_pipelines/wan/modular_pipeline.py | 10 ++ src/diffusers/pipelines/auto_pipeline.py | 22 +++ 6 files changed, 198 insertions(+), 187 deletions(-) diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index d48f678edd59..cd46522fc098 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -112,11 +112,6 @@ def description(self) -> str: def inputs(self) -> List[InputParam]: return [ InputParam("num_videos_per_prompt", default=1), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "prompt_embeds", required=True, @@ -143,18 +138,6 @@ def intermediate_outputs(self) -> List[str]: type_hint=torch.dtype, description="Data type of model tensor inputs (determined by `prompt_embeds`)", ), - OutputParam( - "prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields - description="text embeddings used to guide the image generation", - ), - OutputParam( - "negative_prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields - description="negative text embeddings used to guide the image generation", - ), ] def check_inputs(self, components, block_state): @@ -215,26 +198,16 @@ def inputs(self) -> List[InputParam]: InputParam("sigmas"), ] - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam( - "num_inference_steps", - type_hint=int, - description="The number of denoising steps to perform at inference time", - ), - ] @torch.no_grad() def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.device = components._execution_device + device = components._execution_device block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( components.scheduler, block_state.num_inference_steps, - block_state.device, + device, block_state.timesteps, block_state.sigmas, ) @@ -246,10 +219,6 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe class WanPrepareLatentsStep(ModularPipelineBlocks): model_name = "wan" - @property - def expected_components(self) -> List[ComponentSpec]: - return [] - @property def description(self) -> str: return "Prepare latents step that prepares the latents for the text-to-video generation process" @@ -262,11 +231,6 @@ def inputs(self) -> List[InputParam]: InputParam("num_frames", type_hint=int), InputParam("latents", type_hint=Optional[torch.Tensor]), InputParam("num_videos_per_prompt", type_hint=int, default=1), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam("generator"), InputParam( "batch_size", @@ -337,27 +301,26 @@ def prepare_latents( @torch.no_grad() def __call__(self, 
components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 # Wan latents should be torch.float32 for best quality block_state.height = block_state.height or components.default_height block_state.width = block_state.width or components.default_width block_state.num_frames = block_state.num_frames or components.default_num_frames - block_state.device = components._execution_device - block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality - block_state.num_channels_latents = components.num_channels_latents - - self.check_inputs(components, block_state) block_state.latents = self.prepare_latents( components, - block_state.batch_size * block_state.num_videos_per_prompt, - block_state.num_channels_latents, - block_state.height, - block_state.width, - block_state.num_frames, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, + batch_size=block_state.batch_size * block_state.num_videos_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + num_frames=block_state.num_frames, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, ) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py index 8c751172d858..bb6f9c147cbd 100644 --- a/src/diffusers/modular_pipelines/wan/decoders.py +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -50,12 +50,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("output_type", default="pil"), - ] - - @property - def intermediate_inputs(self) -> List[str]: return [ InputParam( "latents", @@ -80,24 +74,21 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) vae_dtype = components.vae.dtype - if not block_state.output_type == "latent": - latents = block_state.latents - latents_mean = ( - torch.tensor(components.vae.config.latents_mean) - .view(1, components.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( - 1, components.vae.config.z_dim, 1, 1, 1 - ).to(latents.device, latents.dtype) - latents = latents / latents_std + latents_mean - latents = latents.to(vae_dtype) - block_state.videos = components.vae.decode(latents, return_dict=False)[0] - else: - block_state.videos = block_state.latents + latents = block_state.latents + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + latents = latents.to(vae_dtype) + block_state.videos = components.vae.decode(latents, return_dict=False)[0] block_state.videos = components.video_processor.postprocess_video( - block_state.videos, output_type=block_state.output_type + block_state.videos, output_type="np" ) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 
4f3ca80acc70..cdcbae72a97e 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -27,13 +27,14 @@ ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec from .modular_pipeline import WanModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name + class WanLoopDenoiser(ModularPipelineBlocks): model_name = "wan" @@ -61,11 +62,6 @@ def description(self) -> str: def inputs(self) -> List[Tuple[str, Any]]: return [ InputParam("attention_kwargs"), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ InputParam( "latents", required=True, @@ -78,14 +74,8 @@ def intermediate_inputs(self) -> List[str]: type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), - InputParam( - kwargs_type="denoiser_input_fields", - description=( - "All conditional model inputs that need to be prepared with guider. " - "It should contain prompt_embeds/negative_prompt_embeds. " - "Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ), - ), + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("negative_prompt_embeds", required=True, type_hint=torch.Tensor), ] @torch.no_grad() @@ -95,10 +85,7 @@ def __call__( # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) guider_inputs = { - "prompt_embeds": ( - getattr(block_state, "prompt_embeds", None), - getattr(block_state, "negative_prompt_embeds", None), - ), + "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds), } transformer_dtype = components.transformer.dtype @@ -118,16 +105,15 @@ def __call__( for guider_state_batch in guider_state: components.guider.prepare_models(components.transformer) cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} - prompt_embeds = cond_kwargs.pop("prompt_embeds") # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.transformer( hidden_states=block_state.latents.to(transformer_dtype), - timestep=t.flatten(), - encoder_hidden_states=prompt_embeds, + timestep=t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype), attention_kwargs=block_state.attention_kwargs, return_dict=False, + **cond_kwargs, )[0] components.guider.cleanup_models(components.transformer) @@ -154,19 +140,6 @@ def description(self) -> str: "object (e.g. 
`WanDenoiseLoopWrapper`)" ) - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [] - - @property - def intermediate_inputs(self) -> List[str]: - return [ - InputParam("generator"), - ] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): @@ -198,18 +171,11 @@ def description(self) -> str: @property def loop_expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 5.0}), - default_creation_method="from_config", - ), ComponentSpec("scheduler", UniPCMultistepScheduler), - ComponentSpec("transformer", WanTransformer3DModel), ] @property - def loop_intermediate_inputs(self) -> List[InputParam]: + def loop_inputs(self) -> List[InputParam]: return [ InputParam( "timesteps", @@ -246,6 +212,81 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state +# class Wan22DenoiseLoopWrapper(LoopSequentialPipelineBlocks): +# model_name = "wan" + +# @property +# def description(self) -> str: +# return ( +# "Pipeline block that iteratively denoise the latents over `timesteps`. " +# "The specific steps with each iteration can be customized with `sub_blocks` attributes" +# ) + +# @property +# def loop_expected_configs(self) -> List[ConfigSpec]: +# return [ +# ConfigSpec( +# "boundary_ratio", +# type_hint=float, +# description="The ratio of the total timesteps to use as the boundary for switching between transformers in two-stage denoising.", +# ), +# ] + +# @property +# def loop_expected_components(self) -> List[ComponentSpec]: +# return [ +# ComponentSpec("scheduler", UniPCMultistepScheduler), +# ] + +# @property +# def loop_inputs(self) -> List[InputParam]: +# return [ +# InputParam( +# "timesteps", +# required=True, +# type_hint=torch.Tensor, +# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", +# ), +# InputParam( +# "num_inference_steps", +# required=True, +# type_hint=int, +# description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", +# ), +# ] + +# @torch.no_grad() +# def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: +# block_state = self.get_block_state(state) + +# block_state.num_warmup_steps = max( +# len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 +# ) + +# block_state.boundary_timestep = components.config.boundary_ratio * components.scheduler.config.num_train_timesteps + +# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: +# for i, t in enumerate(block_state.timesteps): + +# if t > block_state.boundary_timestep: +# # hieh-noise stage +# block_state.current_model = components.transformer +# block_state.current_guider = components.guider +# else: +# # low-noise stage +# block_state.current_model = components.transformer_2 +# block_state.current_guider = components.guider_2 +# components, block_state = self.loop_step(components, block_state, i=i, t=t) +# if i == len(block_state.timesteps) - 1 or ( +# (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 +# ): +# progress_bar.update() + +# self.set_block_state(state, block_state) + +# return components, state + + class WanDenoiseStep(WanDenoiseLoopWrapper): block_classes = [ WanLoopDenoiser, @@ -261,5 +302,5 @@ def description(self) -> str: "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" " - `WanLoopDenoiser`\n" " - `WanLoopAfterDenoiser`\n" - "This block supports both text2vid tasks." + "This block supports text-to-video tasks." ) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index cb2fc242383c..4db79d17704e 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -51,6 +51,39 @@ def prompt_clean(text): return text + +def get_t5_prompt_embeds( + text_encoder: UMT5EncoderModel, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + max_sequence_length: int, + device: torch.device, +): + dtype = text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + return prompt_embeds + + class WanTextEncoderStep(ModularPipelineBlocks): model_name = "wan" @@ -71,16 +104,12 @@ def expected_components(self) -> List[ComponentSpec]: ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [] - @property def inputs(self) -> List[InputParam]: return [ InputParam("prompt"), InputParam("negative_prompt"), - InputParam("attention_kwargs"), + InputParam("max_sequence_length", default=512), ] @property @@ -107,47 +136,13 @@ def check_inputs(block_state): ): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") - @staticmethod - def 
_get_t5_prompt_embeds( - components, - prompt: Union[str, List[str]], - max_sequence_length: int, - device: torch.device, - ): - dtype = components.text_encoder.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - - text_inputs = components.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - return prompt_embeds - @staticmethod def encode_prompt( components, prompt: str, device: Optional[torch.device] = None, - num_videos_per_prompt: int = 1, prepare_unconditional_embeds: bool = True, negative_prompt: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 512, ): r""" @@ -158,21 +153,12 @@ def encode_prompt( prompt to be encoded device: (`torch.device`): torch device - num_videos_per_prompt (`int`): - number of videos that should be generated per prompt prepare_unconditional_embeds (`bool`): whether to use prepare unconditional embeddings or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. max_sequence_length (`int`, defaults to `512`): The maximum number of text tokens to be used for the generation process. """ @@ -180,10 +166,15 @@ def encode_prompt( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] - if prompt_embeds is None: - prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device) + prompt_embeds = get_t5_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + ) - if prepare_unconditional_embeds and negative_prompt_embeds is None: + if prepare_unconditional_embeds: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt @@ -199,18 +190,14 @@ def encode_prompt( " the batch size of `prompt`." 
) - negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds( - components, negative_prompt, max_sequence_length, device + negative_prompt_embeds = get_t5_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, ) - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - - if prepare_unconditional_embeds: - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - return prompt_embeds, negative_prompt_embeds @torch.no_grad() @@ -219,7 +206,6 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe block_state = self.get_block_state(state) self.check_inputs(block_state) - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 block_state.device = components._execution_device # Encode input prompt @@ -227,14 +213,12 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe block_state.prompt_embeds, block_state.negative_prompt_embeds, ) = self.encode_prompt( - components, - block_state.prompt, - block_state.device, - 1, - block_state.prepare_unconditional_embeds, - block_state.negative_prompt, - prompt_embeds=None, - negative_prompt_embeds=None, + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, ) # Add outputs diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py index e4adf3d151d6..cc7c99840085 100644 --- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -86,3 +86,13 @@ def num_channels_latents(self): if hasattr(self, "vae") and self.vae is not None: num_channels_latents = self.vae.config.z_dim return num_channels_latents + + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds \ No newline at end of file diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 8a32d4c367a3..2d7560cd6ad6 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -118,6 +118,7 @@ StableDiffusionXLPipeline, ) from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline +from .wan import WanPipeline, WanImageToVideoPipeline, WanVideoToVideoPipeline, WanVACEPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( @@ -214,6 +215,24 @@ ] ) +AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict( + [ + ("wan", WanPipeline), + ] +) + +AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict( + [ + ("wan", WanImageToVideoPipeline), + ] +) + +AUTO_VIDEO2VIDEO_PIPELINES_MAPPING = OrderedDict( + [ + ("wan", WanVideoToVideoPipeline), + ] +) + _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict( [ ("kandinsky", KandinskyPipeline), @@ -247,6 +266,9 @@ 
AUTO_TEXT2IMAGE_PIPELINES_MAPPING, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, AUTO_INPAINT_PIPELINES_MAPPING, + AUTO_TEXT2VIDEO_PIPELINES_MAPPING, + AUTO_IMAGE2VIDEO_PIPELINES_MAPPING, + AUTO_VIDEO2VIDEO_PIPELINES_MAPPING, _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING, _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING, _AUTO_INPAINT_DECODER_PIPELINES_MAPPING, From 921185c98de905269a58c77ed319ddb62a37610f Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 8 Nov 2025 02:53:43 +0100 Subject: [PATCH 02/12] support image2video --- .../guiders/adaptive_projected_guidance.py | 13 +- .../adaptive_projected_guidance_mix.py | 17 +- src/diffusers/guiders/auto_guidance.py | 9 + .../guiders/classifier_free_guidance.py | 10 +- .../classifier_free_zero_star_guidance.py | 10 +- .../guiders/frequency_decoupled_guidance.py | 8 + src/diffusers/guiders/guider_utils.py | 50 +++ .../guiders/perturbed_attention_guidance.py | 18 ++ src/diffusers/guiders/skip_layer_guidance.py | 18 ++ .../guiders/smoothed_energy_guidance.py | 18 ++ .../tangential_classifier_free_guidance.py | 10 +- .../modular_pipelines/wan/before_denoise.py | 260 +++++++++++++++- .../modular_pipelines/wan/decoders.py | 2 +- .../modular_pipelines/wan/denoise.py | 216 +++++++------ .../modular_pipelines/wan/encoders.py | 294 +++++++++++++++++- .../modular_pipelines/wan/modular_blocks.py | 162 +++++++--- .../modular_pipelines/wan/modular_pipeline.py | 8 + 17 files changed, 973 insertions(+), 150 deletions(-) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 492d10d2f108..48f1fd448351 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -88,6 +88,17 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/adaptive_projected_guidance_mix.py b/src/diffusers/guiders/adaptive_projected_guidance_mix.py index 732741fc927f..95511500a8bf 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance_mix.py +++ b/src/diffusers/guiders/adaptive_projected_guidance_mix.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -99,6 +99,21 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches + + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 4374f45aff7c..97156e3d220f 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -141,6 +141,15 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index d475b302263d..e54ba0dc4ac6 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -99,6 +99,14 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 1ea6bbb1c830..4d7ff12e304f 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -85,6 +85,14 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index cd542a43a429..6668b1adf2cb 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -225,6 +225,14 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batch = self._prepare_batch(data, tuple_idx, input_prediction) data_batches.append(data_batch) return data_batches + + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 71e4becfcdf3..b718956412eb 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -166,6 +166,9 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: def prepare_inputs(self, data: 
"BlockState") -> List["BlockState"]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.") + def __call__(self, data: List["BlockState"]) -> Any: if not all(hasattr(d, "noise_pred") for d in data): raise ValueError("Expected all data to have `noise_pred` attribute.") @@ -234,6 +237,53 @@ def _prepare_batch( data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) + + @classmethod + def _prepare_batch_from_block_state( + cls, + input_fields: Dict[str, Union[str, Tuple[str, str]]], + data: "BlockState", + tuple_index: int, + identifier: str, + ) -> "BlockState": + """ + Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the + `BaseGuidance` class. It prepares the batch based on the provided tuple index. + + Args: + input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once it is + prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used + to look up the required data provided for preparation. If a string is provided, it will be used as the + conditional data (or unconditional if used with a guidance method that requires it). If a tuple of + length 2 is provided, the first element must be the conditional data identifier and the second element + must be the unconditional data identifier or None. + data (`BlockState`): + The input data to be prepared. + tuple_index (`int`): + The index to use when accessing input fields that are tuples. + + Returns: + `BlockState`: The prepared batch of data. 
+ """ + from ..modular_pipelines.modular_pipeline import BlockState + + + data_batch = {} + for key, value in input_fields.items(): + try: + if isinstance(value, str): + data_batch[key] = getattr(data, value) + elif isinstance(value, tuple): + data_batch[key] = getattr(data, value[tuple_index]) + else: + # We've already checked that value is a string or a tuple of strings with length 2 + pass + except AttributeError: + logger.debug(f"`data` does not have attribute(s) {value}, skipping.") + data_batch[cls._identifier_key] = identifier + return BlockState(**data_batch) + @classmethod @validate_hf_hub_args def from_pretrained( diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py index 29341736e8d9..61e29aa350f1 100644 --- a/src/diffusers/guiders/perturbed_attention_guidance.py +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -186,6 +186,24 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batch = self._prepare_batch(data, tuple_idx, input_prediction) data_batches.append(data_batch) return data_batches + + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward def forward( diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index fa5b93b68009..493630ef2011 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -183,6 +183,24 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward( self, pred_cond: torch.Tensor, diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 7446b33f1250..ab69669d62c8 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -172,6 +172,24 @@ def prepare_inputs(self, data: Dict[str, 
Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward( self, pred_cond: torch.Tensor, diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index cfa3c4a61619..bd36cc24ddc2 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -74,6 +74,14 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index cd46522fc098..95c009e23f43 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple import torch @@ -34,6 +34,93 @@ # configuration of guider is. +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_videos_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times + - If batch size equals batch_size: repeat each element num_videos_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. 
+ batch_size (int): The base batch size (number of prompts) + num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_videos_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_videos_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_videos_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + +def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int) -> Tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by multiplying the latent num_frames/height/width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension. + Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension) + vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension. 
+ Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + Tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + if latents.ndim != 5: + raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}") + + _, _, num_latent_frames, latent_height, latent_width = latents.shape + + num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1 + height = latent_height * vae_scale_factor_spatial + width = latent_width * vae_scale_factor_spatial + + return num_frames, height, width + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -94,7 +181,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class WanInputStep(ModularPipelineBlocks): +class WanTextInputStep(ModularPipelineBlocks): model_name = "wan" @property @@ -177,6 +264,140 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state +class WanInputsDynamicStep(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + image_latent_inputs: List[str] = ["first_frame_latents"], + additional_batch_inputs: List[str] = ["image_embeds"], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (List[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be a single string or + list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"] + additional_batch_inputs (List[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. Examples: ["processed_mask_image"] + + Examples: + # Configure to process image_latents (default behavior) QwenImageInputsDynamicStep() + + # Configure to process multiple image latent inputs + QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"]) + + # Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep( + image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. 
For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(name="num_videos_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="num_frames"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate num_frames, height/width from latents + num_frames,height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial) + block_state.num_frames = block_state.num_frames or num_frames + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + + # 3. 
Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_videos_per_prompt=block_state.num_videos_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_videos_per_prompt=block_state.num_videos_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + class WanSetTimestepsStep(ModularPipelineBlocks): model_name = "wan" @@ -326,3 +547,38 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe self.set_block_state(state, block_state) return components, state + + +class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the last frame mask latents and add it to the latent condition" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), + InputParam("num_frames", type_hint=int), + ] + + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) + block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py index bb6f9c147cbd..254595d131ee 100644 --- a/src/diffusers/modular_pipelines/wan/decoders.py +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -29,7 +29,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class WanDecodeStep(ModularPipelineBlocks): +class WanImageVaeDecoderStep(ModularPipelineBlocks): model_name = "wan" @property diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index cdcbae72a97e..c7a8511808e4 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, List, Tuple +from typing import Any, List, Tuple, Dict import torch @@ -34,10 +34,79 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class WanLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + ] + + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = block_state.latents + return components, block_state + -class WanLoopDenoiser(ModularPipelineBlocks): +class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks): model_name = "wan" + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "first_frame_latents", + required=True, + type_hint=torch.Tensor, + description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1) + return components, block_state + +class WanLoopDenoiserDynamic(ModularPipelineBlocks): + model_name = "wan" + + # guider_input_fields maps the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.encoder_hidden_states) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + def __init__(self, guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}): + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -60,33 +129,33 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: - return [ + + inputs = [ InputParam("attention_kwargs"), - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", - ), InputParam( "num_inference_steps", required=True, type_hint=int, description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", ), - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), - InputParam("negative_prompt_embeds", required=True, type_hint=torch.Tensor), ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + @torch.no_grad() def __call__( self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: - # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) - # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_inputs = { - "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds), - } + transformer_dtype = components.transformer.dtype components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) @@ -99,18 +168,19 @@ def __call__( # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch # ] # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). - guider_state = components.guider.prepare_inputs(guider_inputs) + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) # run the denoiser for each guidance batch for guider_state_batch in guider_state: components.guider.prepare_models(components.transformer) - cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys()} # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.transformer( - hidden_states=block_state.latents.to(transformer_dtype), - timestep=t.expand(block_state.latents.shape[0]).to(block_state.latents.dtype), + hidden_states=block_state.latent_model_input.to(transformer_dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype), attention_kwargs=block_state.attention_kwargs, return_dict=False, **cond_kwargs, @@ -212,84 +282,14 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state -# class Wan22DenoiseLoopWrapper(LoopSequentialPipelineBlocks): -# model_name = "wan" - -# @property -# def description(self) -> str: -# return ( -# "Pipeline block that iteratively denoise the latents over `timesteps`. 
" -# "The specific steps with each iteration can be customized with `sub_blocks` attributes" -# ) - -# @property -# def loop_expected_configs(self) -> List[ConfigSpec]: -# return [ -# ConfigSpec( -# "boundary_ratio", -# type_hint=float, -# description="The ratio of the total timesteps to use as the boundary for switching between transformers in two-stage denoising.", -# ), -# ] - -# @property -# def loop_expected_components(self) -> List[ComponentSpec]: -# return [ -# ComponentSpec("scheduler", UniPCMultistepScheduler), -# ] - -# @property -# def loop_inputs(self) -> List[InputParam]: -# return [ -# InputParam( -# "timesteps", -# required=True, -# type_hint=torch.Tensor, -# description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.", -# ), -# InputParam( -# "num_inference_steps", -# required=True, -# type_hint=int, -# description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", -# ), -# ] - -# @torch.no_grad() -# def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: -# block_state = self.get_block_state(state) - -# block_state.num_warmup_steps = max( -# len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 -# ) - -# block_state.boundary_timestep = components.config.boundary_ratio * components.scheduler.config.num_train_timesteps - -# with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: -# for i, t in enumerate(block_state.timesteps): - -# if t > block_state.boundary_timestep: -# # hieh-noise stage -# block_state.current_model = components.transformer -# block_state.current_guider = components.guider -# else: -# # low-noise stage -# block_state.current_model = components.transformer_2 -# block_state.current_guider = components.guider_2 -# components, block_state = self.loop_step(components, block_state, i=i, t=t) -# if i == len(block_state.timesteps) - 1 or ( -# (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 -# ): -# progress_bar.update() - -# self.set_block_state(state, block_state) - -# return components, state - - class WanDenoiseStep(WanDenoiseLoopWrapper): block_classes = [ - WanLoopDenoiser, + WanLoopBeforeDenoiser, + WanLoopDenoiserDynamic( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), WanLoopAfterDenoiser, ] block_names = ["before_denoiser", "denoiser", "after_denoiser"] @@ -304,3 +304,27 @@ def description(self) -> str: " - `WanLoopAfterDenoiser`\n" "This block supports text-to-video tasks." ) + +class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanImage2VideoLoopBeforeDenoiser, + WanLoopDenoiserDynamic( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_hidden_states_image": "image_embeds", + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports image-to-video tasks." 
+ ) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index 4db79d17704e..8730112ed889 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -17,7 +17,7 @@ import regex as re import torch -from transformers import AutoTokenizer, UMT5EncoderModel +from transformers import AutoTokenizer, UMT5EncoderModel, CLIPImageProcessor, CLIPVisionModel from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance @@ -25,7 +25,11 @@ from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline - +from ...image_processor import PipelineImageInput +from ...video_processor import VideoProcessor +from ...models import AutoencoderKLWan +import PIL +import numpy as np if is_ftfy_available(): import ftfy @@ -51,7 +55,6 @@ def prompt_clean(text): return text - def get_t5_prompt_embeds( text_encoder: UMT5EncoderModel, tokenizer: AutoTokenizer, @@ -84,6 +87,84 @@ def get_t5_prompt_embeds( return prompt_embeds +def encode_image( + image: PipelineImageInput, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + device: Optional[torch.device] = None, +): + image = image_processor(images=image, return_tensors="pt").to(device) + image_embeds = image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + image: torch.Tensor, + vae: AutoencoderKLWan, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + num_frames: int = 81, + height: int = 480, + width: int = 832, + latent_channels: int = 16, +): + if not isinstance(image, torch.Tensor): + raise ValueError(f"Expected image to be a tensor, got {type(image)}.") + + if isinstance(generator, list) and len(generator) != image.shape[0]: + raise ValueError(f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image.shape[0]}.") + + # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width + if image.dim() == 4: + image = image.unsqueeze(2) + elif image.dim() != 5: + raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") + + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + + video_condition = video_condition.to(device=device, dtype=dtype) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(vae.encode(video_condition[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(image.shape[0]) + ] + latent_condition = torch.cat(latent_condition, dim=0) + else: + latent_condition = retrieve_latents(vae.encode(video_condition), 
sample_mode="argmax") + + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, latent_channels, 1, 1, 1) + .to(latent_condition.device, latent_condition.dtype) + ) + latents_std = ( + 1.0 / torch.tensor(vae.config.latents_std) + .view(1, latent_channels, 1, 1, 1) + .to(latent_condition.device, latent_condition.dtype) + ) + latent_condition = (latent_condition - latents_mean) * latents_std + + return latent_condition + + + class WanTextEncoderStep(ModularPipelineBlocks): model_name = "wan" @@ -224,3 +305,210 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe # Add outputs self.set_block_state(state, block_state) return components, state + + +class WanImageResizeDynamicStep(ModularPipelineBlocks): + model_name = "wan" + + def __init__(self, input_name: str = "image", output_name: str = "resized_image"): + """Create a configurable step for resizing images to the target area (height * width) while maintaining the aspect ratio. + + This block resizes an input image and exposes the resized result under configurable + input and output names. Use this when you need to wire the resize step to different image fields (e.g., + "image", "last_image") + + Args: + input_name (str, optional): Name of the image field to read from the + pipeline state. Defaults to "image". + output_name (str, optional): Name of the resized image field to write + back to the pipeline state. Defaults to "resized_image". + """ + if not isinstance(input_name, str) or not isinstance(output_name, str): + raise ValueError(f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}") + self._image_input_name = input_name + self._resized_image_output_name = output_name + super().__init__() + + @property + def description(self) -> str: + return f"Image Resize step that resize the {self._image_input_name} to the target area (height * width) while maintaining the aspect ratio." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(self._image_input_name, type_hint=PIL.Image.Image), + InputParam("height", type_hint=int, default=480), + InputParam("width", type_hint=int, default=832), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(self._resized_image_output_name, type_hint=PIL.Image.Image, description="The resized image"), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + max_area = block_state.height * block_state.width + + image = getattr(block_state, self._image_input_name) + + aspect_ratio = image.height / image.width + mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + resized_image = image.resize((width, height)) + setattr(block_state, self._resized_image_output_name, resized_image) + + self.set_block_state(state, block_state) + return components, state + + +class WanImageEncoderDynamicStep(ModularPipelineBlocks): + model_name = "wan" + + def __init__(self, input_name: str = "resized_image", output_name: str = "image_embeds"): + """Create a configurable step for encoding images to generate image embeddings. + + This block encodes an input image and exposes the generated embeddings under configurable + input and output names. 
Use this when you need to wire the encoder step to different image fields (e.g., + "resized_image") + """ + if not isinstance(input_name, str) or not isinstance(output_name, str): + raise ValueError(f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}") + self._image_input_name = input_name + self._image_embeds_output_name = output_name + super().__init__() + + @property + def description(self) -> str: + return f"Image Encoder step that generate {self._image_embeds_output_name} to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_processor", CLIPImageProcessor), + ComponentSpec("image_encoder", CLIPVisionModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(self._image_input_name, type_hint=PIL.Image.Image), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(self._image_embeds_output_name, type_hint=torch.Tensor, description="The image embeddings"), + ] + + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + image = getattr(block_state, self._image_input_name) + + image_embeds = encode_image( + image_processor=components.image_processor, + image_encoder=components.image_encoder, + image=image, + device=device, + ) + setattr(block_state, self._image_embeds_output_name, image_embeds) + self.set_block_state(state, block_state) + return components, state + + +class WanVaeImageEncoderDynamicStep(ModularPipelineBlocks): + model_name = "wan" + + def __init__(self, input_name: str = "resized_image", output_name: str = "first_frame_latents"): + """Create a configurable step for encoding images to generate image latents. + + This block encodes an input image and exposes the generated latents under configurable + input and output names. 
Use this when you need to wire the encoder step to different image fields (e.g., + "resized_image") + """ + if not isinstance(input_name, str) or not isinstance(output_name, str): + raise ValueError(f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}") + self._image_input_name = input_name + self._image_latents_output_name = output_name + super().__init__() + + @property + def description(self) -> str: + return f"Vae Image Encoder step that generate {self._image_latents_output_name} to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec("video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(self._image_input_name, type_hint=PIL.Image.Image), + InputParam("height"), + InputParam("width"), + InputParam("num_frames"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(self._image_latents_output_name, type_hint=torch.Tensor, description="The latent condition"), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." 
+ ) + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = getattr(block_state, self._image_input_name) + + device = components._execution_device + dtype = torch.float32 + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames or components.default_num_frames + + image_tensor = components.video_processor.preprocess( + image, height=height, width=width).to(device=device, dtype=dtype) + + latent_condition = encode_vae_image( + image=image_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + num_frames=num_frames, + height=height, + width=width, + latent_channels=components.num_channels_latents, + ) + + setattr(block_state, self._image_latents_output_name, latent_condition) + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index 5f4c1a983566..a618970be939 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -16,96 +16,156 @@ from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict from .before_denoise import ( - WanInputStep, + WanTextInputStep, WanPrepareLatentsStep, WanSetTimestepsStep, + WanInputsDynamicStep, + WanPrepareFirstFrameLatentsStep, ) -from .decoders import WanDecodeStep -from .denoise import WanDenoiseStep -from .encoders import WanTextEncoderStep +from .decoders import WanImageVaeDecoderStep +from .denoise import WanDenoiseStep, WanImage2VideoDenoiseStep +from .encoders import WanTextEncoderStep, WanImageResizeDynamicStep, WanImageEncoderDynamicStep, WanVaeImageEncoderDynamicStep logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# before_denoise: text2vid -class WanBeforeDenoiseStep(SequentialPipelineBlocks): +# text2vid +class WanCoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ - WanInputStep, + WanTextInputStep, WanSetTimestepsStep, WanPrepareLatentsStep, + WanDenoiseStep, ] - block_names = ["input", "set_timesteps", "prepare_latents"] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] @property def description(self): return ( - "Before denoise step that prepare the inputs for the denoise step.\n" + "denoise block that takes encoded conditions and runs the denoising process.\n" + "This is a sequential pipeline blocks:\n" - + " - `WanInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + " - `WanSetTimestepsStep` is used to set the timesteps\n" + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanDenoiseStep` is used to denoise the latents\n" ) -# before_denoise: all task (text2vid,) -class WanAutoBeforeDenoiseStep(AutoPipelineBlocks): +# image2video + +## iamge encoder +class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [WanImageResizeDynamicStep(input_name="image", output_name="resized_image"), WanImageEncoderDynamicStep(input_name="resized_image", output_name="image_embeds")] + block_names = ["image_resize", "image_encoder"] + + @property + def 
description(self): + return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" + + + +# vae encoder +class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [WanImageResizeDynamicStep(input_name="image", output_name="resized_image"), WanVaeImageEncoderDynamicStep(input_name="resized_image", output_name="first_frame_latents")] + block_names = ["image_resize", "vae_image_encoder"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + + +class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ - WanBeforeDenoiseStep, + WanTextInputStep, + WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanPrepareFirstFrameLatentsStep, + WanImage2VideoDenoiseStep, ] - block_names = ["text2vid"] - block_trigger_inputs = [None] + block_names = ["input", "additional_inputs", "set_timesteps", "prepare_latents", "prepare_first_frame_latents", "denoise"] @property def description(self): return ( - "Before denoise step that prepare the inputs for the denoise step.\n" - + "This is an auto pipeline block that works for text2vid.\n" - + " - `WanBeforeDenoiseStep` (text2vid) is used.\n" + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanInputsDynamicStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanPrepareConditionLatentsStep` is used to prepare the latent conditions\n" + + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" ) -# denoise: text2vid + +# auto blocks + +class WanAutoImageEncoderStep(AutoPipelineBlocks): + block_classes = [WanImage2VideoImageEncoderStep] + block_names = ["image_encoder"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ("Image Encoder step that encode the image to generate the image embeddings" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `image` is not provided, step will be skipped.") + +class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): + block_classes = [WanImage2VideoVaeImageEncoderStep] + block_names = ["vae_image_encoder"] + block_trigger_inputs = ["image"] + + @property + def description(self): + return ("Vae Image Encoder step that encode the image to generate the image latents" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `image` is not provided, step will be skipped.") + + class WanAutoDenoiseStep(AutoPipelineBlocks): block_classes = [ - WanDenoiseStep, + WanImage2VideoCoreDenoiseStep, + WanCoreDenoiseStep, ] - block_names = ["denoise"] - block_trigger_inputs = [None] + block_names = ["image2video", "text2video"] + block_trigger_inputs = ["first_frame_latents", None] @property def description(self) -> str: return ( "Denoise step that iteratively denoise the latents. 
" - "This is a auto pipeline block that works for text2vid tasks.." - " - `WanDenoiseStep` (denoise) for text2vid tasks." + "This is a auto pipeline block that works for text2video and image2video tasks." + " - `WanCoreDenoiseStep` (text2video) for text2vid tasks." + " - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks." + + " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n" + + " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n" ) -# decode: all task (text2img, img2img, inpainting) -class WanAutoDecodeStep(AutoPipelineBlocks): - block_classes = [WanDecodeStep] - block_names = ["non-inpaint"] - block_trigger_inputs = [None] - - @property - def description(self): - return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`" - - # text2vid class WanAutoBlocks(SequentialPipelineBlocks): block_classes = [ WanTextEncoderStep, - WanAutoBeforeDenoiseStep, + WanAutoImageEncoderStep, + WanAutoVaeImageEncoderStep, WanAutoDenoiseStep, - WanAutoDecodeStep, + WanImageVaeDecoderStep, ] block_names = [ "text_encoder", - "before_denoise", + "image_encoder", + "vae_image_encoder", "denoise", - "decoder", + "decode", ] @property @@ -119,26 +179,42 @@ def description(self): TEXT2VIDEO_BLOCKS = InsertableDict( [ ("text_encoder", WanTextEncoderStep), - ("input", WanInputStep), + ("input", WanTextInputStep), ("set_timesteps", WanSetTimestepsStep), ("prepare_latents", WanPrepareLatentsStep), ("denoise", WanDenoiseStep), - ("decode", WanDecodeStep), + ("decode", WanImageVaeDecoderStep), ] ) +IMAGE2VIDEO_BLOCKS = InsertableDict( + [ + ("image_resize", WanImageResizeDynamicStep()), + ("image_encoder", WanImage2VideoImageEncoderStep()), + ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), + ("input", WanTextInputStep), + ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"])), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("denoise", WanImage2VideoCoreDenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] + +) AUTO_BLOCKS = InsertableDict( [ ("text_encoder", WanTextEncoderStep), - ("before_denoise", WanAutoBeforeDenoiseStep), + ("image_encoder", WanAutoImageEncoderStep), + ("vae_image_encoder", WanAutoVaeImageEncoderStep), ("denoise", WanAutoDenoiseStep), - ("decode", WanAutoDecodeStep), + ("decode", WanImageVaeDecoderStep), ] ) ALL_BLOCKS = { "text2video": TEXT2VIDEO_BLOCKS, + "image2video": IMAGE2VIDEO_BLOCKS, "auto": AUTO_BLOCKS, } diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py index cc7c99840085..dae03608b9c3 100644 --- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -58,6 +58,14 @@ def default_sample_width(self): @property def default_sample_num_frames(self): return 21 + + @property + def patch_size_spatial(self): + patch_size_spatial = 2 + if hasattr(self, "transformer") and self.transformer is not None: + patch_size_spatial = self.transformer.config.patch_size[1] + return patch_size_spatial + @property def vae_scale_factor_spatial(self): From 846b5f98a61542b6dea7dbfd4caf75e7ded4b74a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 8 Nov 2025 08:13:20 +0100 Subject: [PATCH 03/12] revert dynamic steps to simplify --- .../modular_pipelines/wan/encoders.py | 134 +++++++++--------- .../modular_pipelines/wan/modular_blocks.py | 10 +- 2 files changed, 
72 insertions(+), 72 deletions(-) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index 8730112ed889..ebdd7c23c57d 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -21,7 +21,7 @@ from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance -from ...utils import is_ftfy_available, logging +from ...utils import is_ftfy_available, is_torchvision_available, logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline @@ -31,9 +31,13 @@ import PIL import numpy as np + if is_ftfy_available(): import ftfy +if is_torchvision_available(): + from torchvision import transforms + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -307,36 +311,17 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state -class WanImageResizeDynamicStep(ModularPipelineBlocks): +class WanImageResizeStep(ModularPipelineBlocks): model_name = "wan" - def __init__(self, input_name: str = "image", output_name: str = "resized_image"): - """Create a configurable step for resizing images to the target area (height * width) while maintaining the aspect ratio. - - This block resizes an input image and exposes the resized result under configurable - input and output names. Use this when you need to wire the resize step to different image fields (e.g., - "image", "last_image") - - Args: - input_name (str, optional): Name of the image field to read from the - pipeline state. Defaults to "image". - output_name (str, optional): Name of the resized image field to write - back to the pipeline state. Defaults to "resized_image". - """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError(f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}") - self._image_input_name = input_name - self._resized_image_output_name = output_name - super().__init__() - @property def description(self) -> str: - return f"Image Resize step that resize the {self._image_input_name} to the target area (height * width) while maintaining the aspect ratio." + return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio." 
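# Illustrative sketch (not part of the patch) of the target-area resize arithmetic used by the
# WanImageResizeStep below. mod_value=16 assumes the Wan defaults vae_scale_factor_spatial=8 and
# patch_size_spatial=2; in the block itself both values come from the pipeline components.
import math

def sketch_resize_dims(image_height, image_width, target_height=480, target_width=832, mod_value=16):
    max_area = target_height * target_width
    aspect_ratio = image_height / image_width
    # snap both dimensions down to multiples of mod_value so the latents align with the VAE/patch grid
    height = round(math.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(math.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return height, width

# e.g. a 1080x1920 frame with the default 480x832 target resolves to (464, 832)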
@property def inputs(self) -> List[InputParam]: return [ - InputParam(self._image_input_name, type_hint=PIL.Image.Image), + InputParam("image", type_hint=PIL.Image.Image, required=True), InputParam("height", type_hint=int, default=480), InputParam("width", type_hint=int, default=832), ] @@ -344,7 +329,7 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(self._resized_image_output_name, type_hint=PIL.Image.Image, description="The resized image"), + OutputParam("resized_image", type_hint=PIL.Image.Image), ] def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: @@ -352,38 +337,66 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe block_state = self.get_block_state(state) max_area = block_state.height * block_state.width - image = getattr(block_state, self._image_input_name) - + image = block_state.image aspect_ratio = image.height / image.width mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial - height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value - resized_image = image.resize((width, height)) - setattr(block_state, self._resized_image_output_name, resized_image) + block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + block_state.resized_image = image.resize((block_state.width, block_state.height)) self.set_block_state(state, block_state) return components, state -class WanImageEncoderDynamicStep(ModularPipelineBlocks): +class WanImageCropResizeStep(ModularPipelineBlocks): model_name = "wan" - def __init__(self, input_name: str = "resized_image", output_name: str = "image_embeds"): - """Create a configurable step for encoding images to generate image embeddings. - This block encodes an input image and exposes the generated embeddings under configurable - input and output names. Use this when you need to wire the encoder step to different image fields (e.g., - "resized_image") - """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError(f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}") - self._image_input_name = input_name - self._image_embeds_output_name = output_name - super().__init__() + @property + def description(self) -> str: + return "Image Resize step that resize the last_image to the same size of first frame image with center crop." 
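# Rough sketch (assumed helper, not the block's literal code) of the transform described above:
# scale last_image so it fully covers the first frame's resized dimensions, then center-crop the
# overflow away. Note that torchvision's center_crop expects (height, width).
from PIL import Image
from torchvision import transforms

def sketch_match_last_frame(last_image: Image.Image, target_height: int, target_width: int) -> Image.Image:
    ratio = max(target_width / last_image.width, target_height / last_image.height)
    scaled = last_image.resize((round(last_image.width * ratio), round(last_image.height * ratio)))
    return transforms.functional.center_crop(scaled, [target_height, target_width])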
+ + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image"), + InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("resized_last_image", type_hint=PIL.Image.Image), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + + block_state = self.get_block_state(state) + + height = block_state.resized_image.height + width = block_state.resized_image.width + image = block_state.last_image + + # Calculate resize ratio to match first frame dimensions + resize_ratio = max(width / image.width, height / image.height) + + # Resize the image + width = round(image.width * resize_ratio) + height = round(image.height * resize_ratio) + size = [width, height] + resized_image = transforms.functional.center_crop(image, size) + block_state.resized_last_image = resized_image + + self.set_block_state(state, block_state) + return components, state + + +class WanImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" @property def description(self) -> str: - return f"Image Encoder step that generate {self._image_embeds_output_name} to guide the video generation" + return "Image Encoder step that generate image_embeds to guide the video generation" @property def expected_components(self) -> List[ComponentSpec]: @@ -395,13 +408,13 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(self._image_input_name, type_hint=PIL.Image.Image), + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), ] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(self._image_embeds_output_name, type_hint=torch.Tensor, description="The image embeddings"), + OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), ] @@ -409,8 +422,8 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe block_state = self.get_block_state(state) device = components._execution_device - - image = getattr(block_state, self._image_input_name) + + image = block_state.resized_image image_embeds = encode_image( image_processor=components.image_processor, @@ -418,30 +431,17 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe image=image, device=device, ) - setattr(block_state, self._image_embeds_output_name, image_embeds) + block_state.image_embeds = image_embeds self.set_block_state(state, block_state) return components, state -class WanVaeImageEncoderDynamicStep(ModularPipelineBlocks): +class WanVaeImageEncoderStep(ModularPipelineBlocks): model_name = "wan" - def __init__(self, input_name: str = "resized_image", output_name: str = "first_frame_latents"): - """Create a configurable step for encoding images to generate image latents. - - This block encodes an input image and exposes the generated latents under configurable - input and output names. 
Use this when you need to wire the encoder step to different image fields (e.g., - "resized_image") - """ - if not isinstance(input_name, str) or not isinstance(output_name, str): - raise ValueError(f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}") - self._image_input_name = input_name - self._image_latents_output_name = output_name - super().__init__() - @property def description(self) -> str: - return f"Vae Image Encoder step that generate {self._image_latents_output_name} to guide the video generation" + return "Vae Image Encoder step that generate first_frame_latents to guide the video generation" @property def expected_components(self) -> List[ComponentSpec]: @@ -453,7 +453,7 @@ def expected_components(self) -> List[ComponentSpec]: @property def inputs(self) -> List[InputParam]: return [ - InputParam(self._image_input_name, type_hint=PIL.Image.Image), + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), InputParam("height"), InputParam("width"), InputParam("num_frames"), @@ -463,7 +463,7 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam(self._image_latents_output_name, type_hint=torch.Tensor, description="The latent condition"), + OutputParam("first_frame_latents", type_hint=torch.Tensor, description="The latent condition"), ] @staticmethod @@ -485,7 +485,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe block_state = self.get_block_state(state) self.check_inputs(components, block_state) - image = getattr(block_state, self._image_input_name) + image = block_state.resized_image device = components._execution_device dtype = torch.float32 @@ -509,6 +509,6 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe latent_channels=components.num_channels_latents, ) - setattr(block_state, self._image_latents_output_name, latent_condition) + block_state.first_frame_latents = latent_condition self.set_block_state(state, block_state) return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index a618970be939..83f94bc00c81 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -24,7 +24,7 @@ ) from .decoders import WanImageVaeDecoderStep from .denoise import WanDenoiseStep, WanImage2VideoDenoiseStep -from .encoders import WanTextEncoderStep, WanImageResizeDynamicStep, WanImageEncoderDynamicStep, WanVaeImageEncoderDynamicStep +from .encoders import WanTextEncoderStep, WanImageResizeStep, WanImageEncoderStep, WanVaeImageEncoderStep logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -57,7 +57,7 @@ def description(self): ## iamge encoder class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" - block_classes = [WanImageResizeDynamicStep(input_name="image", output_name="resized_image"), WanImageEncoderDynamicStep(input_name="resized_image", output_name="image_embeds")] + block_classes = [WanImageResizeStep, WanImageEncoderStep] block_names = ["image_resize", "image_encoder"] @property @@ -69,7 +69,7 @@ def description(self): # vae encoder class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" - block_classes = [WanImageResizeDynamicStep(input_name="image", output_name="resized_image"), 
WanVaeImageEncoderDynamicStep(input_name="resized_image", output_name="first_frame_latents")] + block_classes = [WanImageResizeStep, WanVaeImageEncoderStep] block_names = ["image_resize", "vae_image_encoder"] @property @@ -189,8 +189,8 @@ def description(self): IMAGE2VIDEO_BLOCKS = InsertableDict( [ - ("image_resize", WanImageResizeDynamicStep()), - ("image_encoder", WanImage2VideoImageEncoderStep()), + ("image_resize", WanImageResizeStep), + ("image_encoder", WanImage2VideoImageEncoderStep), ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), ("input", WanTextInputStep), ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"])), From 1589e75a00ed02f27759b1388b44ca7470af580c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 8 Nov 2025 08:42:33 +0100 Subject: [PATCH 04/12] refactor vae encoder block --- .../modular_pipelines/wan/before_denoise.py | 10 ++-- .../modular_pipelines/wan/denoise.py | 6 +- .../modular_pipelines/wan/encoders.py | 60 ++++++++----------- .../modular_pipelines/wan/modular_blocks.py | 6 +- 4 files changed, 36 insertions(+), 46 deletions(-) diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index 95c009e23f43..2a29c3d31934 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -269,7 +269,7 @@ class WanInputsDynamicStep(ModularPipelineBlocks): def __init__( self, - image_latent_inputs: List[str] = ["first_frame_latents"], + image_latent_inputs: List[str] = ["condition_latents"], additional_batch_inputs: List[str] = ["image_embeds"], ): """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" @@ -559,7 +559,7 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), + InputParam("condition_latents", type_hint=Optional[torch.Tensor]), InputParam("num_frames", type_hint=int), ] @@ -567,7 +567,7 @@ def inputs(self) -> List[InputParam]: def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape + batch_size, _, _, latent_height, latent_width = block_state.condition_latents.shape mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 @@ -577,8 +577,8 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width) mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) - block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) + mask_lat_size = mask_lat_size.to(block_state.condition_latents.device) + block_state.condition_latents = torch.concat([mask_lat_size, block_state.condition_latents], dim=1) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index c7a8511808e4..bfafa8c92b1c 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ 
b/src/diffusers/modular_pipelines/wan/denoise.py @@ -84,16 +84,16 @@ def inputs(self) -> List[InputParam]: description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), InputParam( - "first_frame_latents", + "condition_latents", required=True, type_hint=torch.Tensor, - description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.", + description="The condition latents to use for the denoising process. Can be generated in prepare_condition_latents step.", ), ] @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1) + block_state.latent_model_input = torch.cat([block_state.latents, block_state.condition_latents], dim=1) return components, block_state class WanLoopDenoiserDynamic(ModularPipelineBlocks): diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index ebdd7c23c57d..32917db3a400 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -117,55 +117,42 @@ def retrieve_latents( def encode_vae_image( - image: torch.Tensor, + video_tensor: torch.Tensor, vae: AutoencoderKLWan, generator: torch.Generator, device: torch.device, dtype: torch.dtype, - num_frames: int = 81, - height: int = 480, - width: int = 832, latent_channels: int = 16, ): - if not isinstance(image, torch.Tensor): - raise ValueError(f"Expected image to be a tensor, got {type(image)}.") + if not isinstance(video_tensor, torch.Tensor): + raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.") - if isinstance(generator, list) and len(generator) != image.shape[0]: - raise ValueError(f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image.shape[0]}.") + if isinstance(generator, list) and len(generator) != video_tensor.shape[0]: + raise ValueError(f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}.") - # preprocessed image should be a 4D tensor: batch_size, num_channels, height, width - if image.dim() == 4: - image = image.unsqueeze(2) - elif image.dim() != 5: - raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.") - - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 - ) - - video_condition = video_condition.to(device=device, dtype=dtype) + video_tensor = video_tensor.to(device=device, dtype=dtype) if isinstance(generator, list): - latent_condition = [ - retrieve_latents(vae.encode(video_condition[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(image.shape[0]) + video_latents = [ + retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(video_tensor.shape[0]) ] - latent_condition = torch.cat(latent_condition, dim=0) + video_latents = torch.cat(video_latents, dim=0) else: - latent_condition = retrieve_latents(vae.encode(video_condition), sample_mode="argmax") + video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax") latents_mean = ( torch.tensor(vae.config.latents_mean) .view(1, latent_channels, 1, 1, 1) - .to(latent_condition.device, latent_condition.dtype) + 
.to(video_latents.device, video_latents.dtype) ) latents_std = ( 1.0 / torch.tensor(vae.config.latents_std) .view(1, latent_channels, 1, 1, 1) - .to(latent_condition.device, latent_condition.dtype) + .to(video_latents.device, video_latents.dtype) ) - latent_condition = (latent_condition - latents_mean) * latents_std + video_latents = (video_latents - latents_mean) * latents_std - return latent_condition + return video_latents @@ -441,7 +428,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks): @property def description(self) -> str: - return "Vae Image Encoder step that generate first_frame_latents to guide the video generation" + return "Vae Image Encoder step that generate condition_latents to guide the video generation" @property def expected_components(self) -> List[ComponentSpec]: @@ -463,7 +450,7 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("first_frame_latents", type_hint=torch.Tensor, description="The latent condition"), + OutputParam("condition_latents", type_hint=torch.Tensor, description="The condition latents"), ] @staticmethod @@ -497,18 +484,21 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe image_tensor = components.video_processor.preprocess( image, height=height, width=width).to(device=device, dtype=dtype) - latent_condition = encode_vae_image( - image=image_tensor, + if image_tensor.dim() == 4: + image_tensor = image_tensor.unsqueeze(2) + + video_tensor = torch.cat( + [image_tensor, image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width)], dim=2 + ).to(device=device, dtype=dtype) + + block_state.condition_latents = encode_vae_image( + video_tensor=video_tensor, vae=components.vae, generator=block_state.generator, device=device, dtype=dtype, - num_frames=num_frames, - height=height, - width=width, latent_channels=components.num_channels_latents, ) - block_state.first_frame_latents = latent_condition self.set_block_state(state, block_state) return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index 83f94bc00c81..5a05018ebacc 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -81,7 +81,7 @@ def description(self): class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ WanTextInputStep, - WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"]), + WanInputsDynamicStep(image_latent_inputs=["condition_latents"]), WanSetTimestepsStep, WanPrepareLatentsStep, WanPrepareFirstFrameLatentsStep, @@ -137,7 +137,7 @@ class WanAutoDenoiseStep(AutoPipelineBlocks): WanCoreDenoiseStep, ] block_names = ["image2video", "text2video"] - block_trigger_inputs = ["first_frame_latents", None] + block_trigger_inputs = ["condition_latents", None] @property def description(self) -> str: @@ -193,7 +193,7 @@ def description(self): ("image_encoder", WanImage2VideoImageEncoderStep), ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), ("input", WanTextInputStep), - ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"])), + ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["condition_latents"])), ("set_timesteps", WanSetTimestepsStep), ("prepare_latents", WanPrepareLatentsStep), ("denoise", WanImage2VideoCoreDenoiseStep), From 
cb2d3b98fde9629a1ba4f150d0a52a4ade7b3790 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 8 Nov 2025 10:50:55 +0100 Subject: [PATCH 05/12] support flf2video! --- .../modular_pipelines/wan/before_denoise.py | 58 ++++++- .../modular_pipelines/wan/denoise.py | 98 ++++++++++-- .../modular_pipelines/wan/encoders.py | 141 +++++++++++++++++- .../modular_pipelines/wan/modular_blocks.py | 108 ++++++++++++-- 4 files changed, 368 insertions(+), 37 deletions(-) diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index 2a29c3d31934..222bdd82590f 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -23,6 +23,7 @@ from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline +from ...models import WanTransformer3DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -194,6 +195,12 @@ def description(self) -> str: "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" "have a final batch_size of batch_size * num_videos_per_prompt." ) + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("transformer", WanTransformer3DModel), + ] @property def inputs(self) -> List[InputParam]: @@ -223,7 +230,7 @@ def intermediate_outputs(self) -> List[str]: OutputParam( "dtype", type_hint=torch.dtype, - description="Data type of model tensor inputs (determined by `prompt_embeds`)", + description="Data type of model tensor inputs (determined by `transformer.dtype`)", ), ] @@ -242,7 +249,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe self.check_inputs(components, block_state) block_state.batch_size = block_state.prompt_embeds.shape[0] - block_state.dtype = block_state.prompt_embeds.dtype + block_state.dtype = components.transformer.dtype _, seq_len, _ = block_state.prompt_embeds.shape block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1) @@ -269,8 +276,8 @@ class WanInputsDynamicStep(ModularPipelineBlocks): def __init__( self, - image_latent_inputs: List[str] = ["condition_latents"], - additional_batch_inputs: List[str] = ["image_embeds"], + image_latent_inputs: List[str] = ["first_frame_latents"], + additional_batch_inputs: List[str] = [], ): """Initialize a configurable step that standardizes the inputs for the denoising step. 
It:\n" @@ -559,7 +566,7 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam("condition_latents", type_hint=Optional[torch.Tensor]), + InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), InputParam("num_frames", type_hint=int), ] @@ -567,7 +574,7 @@ def inputs(self) -> List[InputParam]: def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - batch_size, _, _, latent_height, latent_width = block_state.condition_latents.shape + batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 @@ -577,8 +584,43 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width) mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(block_state.condition_latents.device) - block_state.condition_latents = torch.concat([mask_lat_size, block_state.condition_latents], dim=1) + mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) + block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) self.set_block_state(state, block_state) return components, state + + +class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the last frame mask latents and add it to the latent condition" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]), + InputParam("num_frames", type_hint=int), + ] + + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames-1))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) + block_state.first_last_frame_latents = torch.concat([mask_lat_size, block_state.first_last_frame_latents], dim=1) + + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index bfafa8c92b1c..d99cafb4e94a 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -54,12 +54,18 @@ def inputs(self) -> List[InputParam]: type_hint=torch.Tensor, description="The initial latents to use for the denoising process. 
Can be generated in prepare_latent step.", ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), ] @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = block_state.latents + block_state.latent_model_input = block_state.latents.to(block_state.dtype) return components, block_state @@ -84,18 +90,67 @@ def inputs(self) -> List[InputParam]: description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), InputParam( - "condition_latents", + "first_frame_latents", required=True, type_hint=torch.Tensor, - description="The condition latents to use for the denoising process. Can be generated in prepare_condition_latents step.", + description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", ), ] @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = torch.cat([block_state.latents, block_state.condition_latents], dim=1) + block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(block_state.dtype) + block_state.image_embeds = block_state.image_embeds.to(block_state.dtype) return components, block_state + +class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "first_last_frame_latents", + required=True, + type_hint=torch.Tensor, + description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_last_frame_latents], dim=1).to(block_state.dtype) + block_state.image_embeds = block_state.image_embeds.to(block_state.dtype) + return components, block_state + + class WanLoopDenoiserDynamic(ModularPipelineBlocks): model_name = "wan" @@ -155,9 +210,6 @@ def inputs(self) -> List[Tuple[str, Any]]: def __call__( self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: - - transformer_dtype = components.transformer.dtype - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) # The guider splits model inputs into separate batches for conditional/unconditional predictions. 
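# Conceptual sketch (assumed helper, not the guider's real implementation) of how
# `guider_input_fields` is consumed under classifier-free guidance: a (cond, uncond) tuple is
# split across the two batches, while a single field name is shared by every batch.
def sketch_cfg_batches(block_state, guider_input_fields):
    cond, uncond = {}, {}
    for key, value in guider_input_fields.items():
        if isinstance(value, tuple):
            cond[key], uncond[key] = getattr(block_state, value[0]), getattr(block_state, value[1])
        else:
            cond[key] = uncond[key] = getattr(block_state, value)
    return [cond, uncond]

# e.g. {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
#       "encoder_hidden_states_image": "image_embeds"}
# yields one batch that reads prompt_embeds and one that reads negative_prompt_embeds,
# with image_embeds attached to both.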
@@ -179,8 +231,8 @@ def __call__( # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.transformer( - hidden_states=block_state.latent_model_input.to(transformer_dtype), - timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype), + hidden_states=block_state.latent_model_input.to(block_state.dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), attention_kwargs=block_state.attention_kwargs, return_dict=False, **cond_kwargs, @@ -300,6 +352,7 @@ def description(self) -> str: "Denoise step that iteratively denoise the latents. \n" "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopBeforeDenoiser`\n" " - `WanLoopDenoiser`\n" " - `WanLoopAfterDenoiser`\n" "This block supports text-to-video tasks." @@ -324,7 +377,34 @@ def description(self) -> str: "Denoise step that iteratively denoise the latents. \n" "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanImage2VideoLoopBeforeDenoiser`\n" " - `WanLoopDenoiser`\n" " - `WanLoopAfterDenoiser`\n" "This block supports image-to-video tasks." ) + + +class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanFLF2VLoopBeforeDenoiser, + WanLoopDenoiserDynamic( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_hidden_states_image": "image_embeds", + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanFLF2VLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports FLF2V tasks." 
+ ) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index 32917db3a400..832531428557 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -383,7 +383,7 @@ class WanImageEncoderStep(ModularPipelineBlocks): @property def description(self) -> str: - return "Image Encoder step that generate image_embeds to guide the video generation" + return "Image Encoder step that generate image_embeds based on first frame image to guide the video generation" @property def expected_components(self) -> List[ComponentSpec]: @@ -423,12 +423,59 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe return components, state +class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Encoder step that generate image_embeds based on first and last frame images to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_processor", CLIPImageProcessor), + ComponentSpec("image_encoder", CLIPVisionModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), + ] + + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + first_frame_image = block_state.resized_image + last_frame_image = block_state.resized_last_image + + image_embeds = encode_image( + image_processor=components.image_processor, + image_encoder=components.image_encoder, + image=[first_frame_image, last_frame_image], + device=device, + ) + block_state.image_embeds = image_embeds + self.set_block_state(state, block_state) + return components, state + + class WanVaeImageEncoderStep(ModularPipelineBlocks): model_name = "wan" @property def description(self) -> str: - return "Vae Image Encoder step that generate condition_latents to guide the video generation" + return "Vae Image Encoder step that generate condition_latents based on first frame image to guide the video generation" @property def expected_components(self) -> List[ComponentSpec]: @@ -450,7 +497,7 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("condition_latents", type_hint=torch.Tensor, description="The condition latents"), + OutputParam("first_frame_latents", type_hint=torch.Tensor, description="video latent representation with the first frame image condition"), ] @staticmethod @@ -491,7 +538,93 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe [image_tensor, image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width)], dim=2 ).to(device=device, dtype=dtype) - block_state.condition_latents = encode_vae_image( + block_state.first_frame_latents = encode_vae_image( + video_tensor=video_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + + 
self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on first and last frame images to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec("video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("num_frames"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("first_last_frame_latents", type_hint=torch.Tensor, description="video latent representation with the first and last frame images condition"), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." 
+ ) + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + first_frame_image = block_state.resized_image + last_frame_image = block_state.resized_last_image + + device = components._execution_device + dtype = torch.float32 + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames or components.default_num_frames + + first_image_tensor = components.video_processor.preprocess( + first_frame_image, height=height, width=width).to(device=device, dtype=dtype) + first_image_tensor = first_image_tensor.unsqueeze(2) + + last_image_tensor = components.video_processor.preprocess( + last_frame_image, height=height, width=width).to(device=device, dtype=dtype) + + last_image_tensor = last_image_tensor.unsqueeze(2) + + video_tensor = torch.cat( + [first_image_tensor, first_image_tensor.new_zeros(first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width), last_image_tensor], dim=2 + ).to(device=device, dtype=dtype) + + block_state.first_last_frame_latents = encode_vae_image( video_tensor=video_tensor, vae=components.vae, generator=block_state.generator, diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index 5a05018ebacc..41902bdd468a 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -21,10 +21,11 @@ WanSetTimestepsStep, WanInputsDynamicStep, WanPrepareFirstFrameLatentsStep, + WanPrepareFirstLastFrameLatentsStep, ) from .decoders import WanImageVaeDecoderStep -from .denoise import WanDenoiseStep, WanImage2VideoDenoiseStep -from .encoders import WanTextEncoderStep, WanImageResizeStep, WanImageEncoderStep, WanVaeImageEncoderStep +from .denoise import WanDenoiseStep, WanImage2VideoDenoiseStep, WanFLF2VDenoiseStep +from .encoders import WanTextEncoderStep, WanImageResizeStep, WanImageCropResizeStep, WanImageEncoderStep, WanVaeImageEncoderStep, WanFirstLastFrameImageEncoderStep, WanFirstLastFrameVaeImageEncoderStep logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -81,7 +82,7 @@ def description(self): class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ WanTextInputStep, - WanInputsDynamicStep(image_latent_inputs=["condition_latents"]), + WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"]), WanSetTimestepsStep, WanPrepareLatentsStep, WanPrepareFirstFrameLatentsStep, @@ -103,41 +104,93 @@ def description(self): ) +# FLF2v + +## iamge encoder +class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] + block_names = ["image_resize", "last_image_resize", "image_encoder"] + + @property + def description(self): + return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings" + + + +# vae encoder +class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep] + block_names = ["image_resize", "last_image_resize", "vae_image_encoder"] + + @property + def description(self): + return "FLF2V Vae Image Encoder step that resize and encode and encode 
the first and last frame images to generate the latent conditions" + + + +class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + WanTextInputStep, + WanInputsDynamicStep(image_latent_inputs=["first_last_frame_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanPrepareFirstLastFrameLatentsStep, + WanFLF2VDenoiseStep, + ] + block_names = ["input", "additional_inputs", "set_timesteps", "prepare_latents", "prepare_first_last_frame_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanInputsDynamicStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n" + + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" + ) # auto blocks class WanAutoImageEncoderStep(AutoPipelineBlocks): - block_classes = [WanImage2VideoImageEncoderStep] - block_names = ["image_encoder"] - block_trigger_inputs = ["image"] + block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] + block_names = ["flf2v_image_encoder", "image2video_image_encoder"] + block_trigger_inputs = ["last_image", "image"] @property def description(self): return ("Image Encoder step that encode the image to generate the image embeddings" + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." - + " - if `image` is not provided, step will be skipped.") + + " - if `last_image` or `image` is not provided, step will be skipped.") class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): - block_classes = [WanImage2VideoVaeImageEncoderStep] - block_names = ["vae_image_encoder"] - block_trigger_inputs = ["image"] + block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep] + block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"] + block_trigger_inputs = ["last_image", "image"] @property def description(self): return ("Vae Image Encoder step that encode the image to generate the image latents" + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided." + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided." 
- + " - if `image` is not provided, step will be skipped.") + + " - if `last_image` or `image` is not provided, step will be skipped.") class WanAutoDenoiseStep(AutoPipelineBlocks): block_classes = [ + WanFLF2VCoreDenoiseStep, WanImage2VideoCoreDenoiseStep, WanCoreDenoiseStep, ] - block_names = ["image2video", "text2video"] - block_trigger_inputs = ["condition_latents", None] + block_names = ["flf2v", "image2video", "text2video"] + block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None] @property def description(self) -> str: @@ -151,7 +204,7 @@ def description(self) -> str: ) -# text2vid +# auto blocks class WanAutoBlocks(SequentialPipelineBlocks): block_classes = [ WanTextEncoderStep, @@ -176,6 +229,9 @@ def description(self): ) +# presets + +# text2video TEXT2VIDEO_BLOCKS = InsertableDict( [ ("text_encoder", WanTextEncoderStep), @@ -193,10 +249,29 @@ def description(self): ("image_encoder", WanImage2VideoImageEncoderStep), ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), ("input", WanTextInputStep), - ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["condition_latents"])), + ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"])), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep), + ("denoise", WanImage2VideoDenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] + +) + + +FLF2V_BLOCKS = InsertableDict( + [ + ("image_resize", WanImageResizeStep), + ("last_image_resize", WanImageCropResizeStep), + ("image_encoder", WanFLF2VImageEncoderStep), + ("vae_image_encoder", WanFLF2VVaeImageEncoderStep), + ("input", WanTextInputStep), + ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_last_frame_latents"])), ("set_timesteps", WanSetTimestepsStep), ("prepare_latents", WanPrepareLatentsStep), - ("denoise", WanImage2VideoCoreDenoiseStep), + ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep), + ("denoise", WanFLF2VDenoiseStep), ("decode", WanImageVaeDecoderStep), ] @@ -216,5 +291,6 @@ def description(self): ALL_BLOCKS = { "text2video": TEXT2VIDEO_BLOCKS, "image2video": IMAGE2VIDEO_BLOCKS, + "flf2v": FLF2V_BLOCKS, "auto": AUTO_BLOCKS, } From 63f5521b52d52661e29d0fd17dc4125e65ba85e9 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 9 Nov 2025 21:28:39 +0100 Subject: [PATCH 06/12] add support for wan2.2 14B --- src/diffusers/__init__.py | 2 + src/diffusers/modular_pipelines/__init__.py | 4 +- .../modular_pipelines/modular_pipeline.py | 198 ++++++++-------- .../modular_pipelines/wan/__init__.py | 3 +- .../modular_pipelines/wan/before_denoise.py | 2 +- .../modular_pipelines/wan/denoise.py | 218 +++++++++++++++++- .../modular_pipelines/wan/modular_blocks.py | 179 ++++++++++++-- .../modular_pipelines/wan/modular_pipeline.py | 17 +- 8 files changed, 496 insertions(+), 127 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 572aad4bd3f1..33127e67468d 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -407,6 +407,7 @@ "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", "WanAutoBlocks", + "Wan22AutoBlocks", "WanModularPipeline", ] ) @@ -1088,6 +1089,7 @@ StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, WanAutoBlocks, + Wan22AutoBlocks, WanModularPipeline, ) from .pipelines import ( diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 
86ed735134ff..1d695ad8f800 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -45,7 +45,7 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] - _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] + _import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"] _import_structure["flux"] = [ "FluxAutoBlocks", "FluxModularPipeline", @@ -90,7 +90,7 @@ QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline - from .wan import WanAutoBlocks, WanModularPipeline + from .wan import WanAutoBlocks, Wan22AutoBlocks, WanModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 307698245e5b..d61e3ff4e960 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1441,6 +1441,8 @@ def __init__( pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, + modular_config_dict: Optional[Dict[str, Any]] = None, + config_dict: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -1492,23 +1494,8 @@ def __init__( - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as `_blocks_class_name` in the config dict """ - if blocks is None: - blocks_class_name = self.default_blocks_name - if blocks_class_name is not None: - diffusers_module = importlib.import_module("diffusers") - blocks_class = getattr(diffusers_module, blocks_class_name) - blocks = blocks_class() - else: - logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") - self.blocks = blocks - self._components_manager = components_manager - self._collection = collection - self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components} - self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} - - # update component_specs and config_specs from modular_repo - if pretrained_model_name_or_path is not None: + if modular_config_dict is None and config_dict is None and pretrained_model_name_or_path is not None: cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1524,52 +1511,60 @@ def __init__( "local_files_only": local_files_only, "revision": revision, } - # try to load modular_model_index.json - try: - config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs) - except EnvironmentError as e: - logger.debug(f"modular_model_index.json not found: {e}") - config_dict = None - - # update component_specs and config_specs based on modular_model_index.json - if config_dict is not None: - for name, value in config_dict.items(): - # all the components in modular_model_index.json are from_pretrained components - if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - component_spec = self._dict_to_component_spec(name, component_spec_dict) - component_spec.default_creation_method = "from_pretrained" - self._component_specs[name] = component_spec - - elif name in self._config_specs: - 
self._config_specs[name].default = value - - # if modular_model_index.json is not found, try to load model_index.json + + modular_config_dict, config_dict = self._load_pipeline_config(pretrained_model_name_or_path, **load_config_kwargs) + + if blocks is None: + if modular_config_dict is not None: + blocks_class_name = modular_config_dict.get("_blocks_class_name") + elif config_dict is not None: + blocks_class_name = self.get_default_blocks_name(config_dict) + else: + blocks_class_name = None + if blocks_class_name is not None: + diffusers_module = importlib.import_module("diffusers") + blocks_class = getattr(diffusers_module, blocks_class_name) + blocks = blocks_class() else: - logger.debug(" loading config from model_index.json") - try: - from diffusers import DiffusionPipeline - - config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs) - except EnvironmentError as e: - logger.debug(f" model_index.json not found in the repo: {e}") - config_dict = None - - # update component_specs and config_specs based on model_index.json - if config_dict is not None: - for name, value in config_dict.items(): - if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2: - library, class_name = value - component_spec_dict = { - "repo": pretrained_model_name_or_path, - "subfolder": name, - "type_hint": (library, class_name), - } - component_spec = self._dict_to_component_spec(name, component_spec_dict) - component_spec.default_creation_method = "from_pretrained" - self._component_specs[name] = component_spec - elif name in self._config_specs: - self._config_specs[name].default = value + logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") + + self.blocks = blocks + self._components_manager = components_manager + self._collection = collection + self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components} + self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} + + # update component_specs and config_specs from modular_repo + + + # update component_specs and config_specs based on modular_model_index.json + if modular_config_dict is not None: + for name, value in modular_config_dict.items(): + # all the components in modular_model_index.json are from_pretrained components + if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + component_spec.default_creation_method = "from_pretrained" + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + # if modular_model_index.json is not found, try to load model_index.json + elif config_dict is not None: + for name, value in config_dict.items(): + if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2: + library, class_name = value + component_spec_dict = { + "repo": pretrained_model_name_or_path, + "subfolder": name, + "type_hint": (library, class_name), + } + component_spec = self._dict_to_component_spec(name, component_spec_dict) + component_spec.default_creation_method = "from_pretrained" + self._component_specs[name] = component_spec + elif name in self._config_specs: + self._config_specs[name].default = value if len(kwargs) > 0: logger.warning(f"Unexpected input '{kwargs.keys()}' provided. 
This input will be ignored.") @@ -1601,6 +1596,37 @@ def default_call_parameters(self) -> Dict[str, Any]: params[input_param.name] = input_param.default return params + def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: + return self.default_blocks_name + + @classmethod + def _load_pipeline_config( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + **load_config_kwargs, + ): + + try: + # try to load modular_model_index.json + modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) + return modular_config_dict, None + + except EnvironmentError as e: + logger.debug(f" modular_model_index.json not found in the repo: {e}") + + try: + logger.debug(" try to load model_index.json") + from diffusers import DiffusionPipeline + from diffusers.pipelines.auto_pipeline import _get_model + + config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs) + return None, config_dict + + except EnvironmentError as e: + logger.debug(f" model_index.json not found in the repo: {e}") + + return None, None + @classmethod @validate_hf_hub_args def from_pretrained( @@ -1655,42 +1681,30 @@ def from_pretrained( "revision": revision, } - try: - # try to load modular_model_index.json - config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) - except EnvironmentError as e: - logger.debug(f" modular_model_index.json not found in the repo: {e}") - config_dict = None - - if config_dict is not None: - pipeline_class = _get_pipeline_class(cls, config=config_dict) + modular_config_dict, config_dict = cls._load_pipeline_config(pretrained_model_name_or_path, **load_config_kwargs) + + if modular_config_dict is not None: + pipeline_class = _get_pipeline_class(cls, config=modular_config_dict) + elif config_dict is not None: + from diffusers.pipelines.auto_pipeline import _get_model + logger.debug(" try to determine the modular pipeline class from model_index.json") + standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) + model_name = _get_model(standard_pipeline_class.__name__) + pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__) + diffusers_module = importlib.import_module("diffusers") + pipeline_class = getattr(diffusers_module, pipeline_class_name) else: - try: - logger.debug(" try to load model_index.json") - from diffusers import DiffusionPipeline - from diffusers.pipelines.auto_pipeline import _get_model - - config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs) - except EnvironmentError as e: - logger.debug(f" model_index.json not found in the repo: {e}") - - if config_dict is not None: - logger.debug(" try to determine the modular pipeline class from model_index.json") - standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) - model_name = _get_model(standard_pipeline_class.__name__) - pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__) - diffusers_module = importlib.import_module("diffusers") - pipeline_class = getattr(diffusers_module, pipeline_class_name) - else: - # there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components - pipeline_class = cls - pretrained_model_name_or_path = None + # there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components + pipeline_class = cls + 
pretrained_model_name_or_path = None pipeline = pipeline_class( blocks=blocks, pretrained_model_name_or_path=pretrained_model_name_or_path, components_manager=components_manager, collection=collection, + modular_config_dict=modular_config_dict, + config_dict=config_dict, **kwargs, ) return pipeline @@ -2134,7 +2148,9 @@ def load_components(self, names: Optional[Union[List[str], str]] = None, **kwarg logger.warning( f"\nFailed to create component {name}:\n" f"- Component spec: {spec}\n" - f"- load() called with kwargs: {component_load_kwargs}\n\n" + f"- load() called with kwargs: {component_load_kwargs}\n" + "If this component is not required for your workflow you can safely ignore this message.\n\n" + "Traceback:\n" f"{traceback.format_exc()}" ) diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py index 7b548e003c63..8926d5a27c41 100644 --- a/src/diffusers/modular_pipelines/wan/__init__.py +++ b/src/diffusers/modular_pipelines/wan/__init__.py @@ -28,7 +28,7 @@ "TEXT2VIDEO_BLOCKS", "WanAutoBeforeDenoiseStep", "WanAutoBlocks", - "WanAutoBlocks", + "Wan22AutoBlocks", "WanAutoDecodeStep", "WanAutoDenoiseStep", ] @@ -48,6 +48,7 @@ TEXT2VIDEO_BLOCKS, WanAutoBeforeDenoiseStep, WanAutoBlocks, + Wan22AutoBlocks, WanAutoDecodeStep, WanAutoDenoiseStep, ) diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index 222bdd82590f..d5ddaea5447e 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -249,7 +249,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe self.check_inputs(components, block_state) block_state.batch_size = block_state.prompt_embeds.shape[0] - block_state.dtype = components.transformer.dtype + block_state.dtype = block_state.prompt_embeds.dtype _, seq_len, _ = block_state.prompt_embeds.shape block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1) diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index d99cafb4e94a..67fc69b5549b 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -106,7 +106,6 @@ def inputs(self) -> List[InputParam]: @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(block_state.dtype) - block_state.image_embeds = block_state.image_embeds.to(block_state.dtype) return components, block_state @@ -147,16 +146,31 @@ def inputs(self) -> List[InputParam]: @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_last_frame_latents], dim=1).to(block_state.dtype) - block_state.image_embeds = block_state.image_embeds.to(block_state.dtype) return components, block_state -class WanLoopDenoiserDynamic(ModularPipelineBlocks): +class WanLoopDenoiser(ModularPipelineBlocks): model_name = "wan" - # guider_input_fields maps the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.encoder_hidden_states) - # to the corresponding (cond, uncond) fields on block_state. (e.g. 
block_state.prompt_embeds, block_state.negative_prompt_embeds)
-    def __init__(self, guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}):
+
+    def __init__(
+        self,
+        guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}
+    ):
+        """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1.
+
+        Args:
+            guider_input_fields: A dictionary that maps each argument expected by the denoiser model
+                (for example, "encoder_hidden_states") to data stored on 'block_state'.
+                The value can be either:
+
+                - A tuple of strings. For instance,
+                  {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")} tells
+                  the guider to read `block_state.prompt_embeds` and `block_state.negative_prompt_embeds` and
+                  pass them as the conditional and unconditional batches of 'encoder_hidden_states'.
+                - A string. For example, {"encoder_hidden_states_image": "image_embeds"} makes the guider
+                  forward `block_state.image_embeds` for both conditional and unconditional batches.
+        """
         if not isinstance(guider_input_fields, dict):
             raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
         self._guider_input_fields = guider_input_fields
@@ -227,6 +241,8 @@ def __call__(
             components.guider.prepare_models(components.transformer)
             cond_kwargs = guider_state_batch.as_dict()
             cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys()}
+            # cast the conditional model inputs to the compute dtype (rebuilt as a dict so the cast is kept)
+            cond_kwargs = {k: v.to(block_state.dtype) for k, v in cond_kwargs.items()}
 
             # Predict the noise residual
             # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
@@ -245,6 +261,141 @@ def __call__(
         return components, block_state
 
 
+class Wan22LoopDenoiser(ModularPipelineBlocks):
+    model_name = "wan"
+
+
+    def __init__(self, guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}):
+        """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2.
+
+        Args:
+            guider_input_fields: A dictionary that maps each argument expected by the denoiser model
+                (for example, "encoder_hidden_states") to data stored on `block_state`.
+                The value can be either:
+
+                - A tuple of strings. For instance,
+                  `{"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}` tells
+                  the guider to read `block_state.prompt_embeds` and `block_state.negative_prompt_embeds` and
+                  pass them as the conditional and unconditional batches of `encoder_hidden_states`.
+                - A string. For example, `{"encoder_hidden_states_image": "image_embeds"}` makes the guider
+                  forward `block_state.image_embeds` for both conditional and unconditional batches.
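+
+        For illustration, a sketch of how this block is composed inside the Wan2.2 denoise
+        steps defined later in this file (default text-guidance mapping shown):
+
+            Wan22LoopDenoiser(
+                guider_input_fields={
+                    "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
+                }
+            )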
+ """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec( + "guider_2", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 3.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", WanTransformer3DModel), + ComponentSpec("transformer_2", WanTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec( + name="boundary_ratio", + default=0.875, + description="The boundary ratio to divide the denoising loop into high noise and low noise stages.", + ), + ] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + + inputs = [ + InputParam("attention_kwargs"), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + + + @torch.no_grad() + def __call__( + self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + + boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps + if t >= boundary_timestep: + block_state.current_model = components.transformer + block_state.guider = components.guider + else: + block_state.current_model = components.transformer_2 + block_state.guider = components.guider_2 + + block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). 
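+        # `prepare_inputs_from_block_state` reads the fields named in `self._guider_input_fields`
+        # from `block_state` and returns one BlockState per guidance prediction (e.g. cond / uncond).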
+ guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + block_state.guider.prepare_models(block_state.current_model) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys()} + for cond in cond_kwargs.values(): + cond = cond.to(block_state.dtype) + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = block_state.current_model( + hidden_states=block_state.latent_model_input.to(block_state.dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + block_state.guider.cleanup_models(block_state.current_model) + + # Perform guidance + block_state.noise_pred = block_state.guider(guider_state)[0] + + return components, block_state + + class WanLoopAfterDenoiser(ModularPipelineBlocks): model_name = "wan" @@ -337,7 +488,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe class WanDenoiseStep(WanDenoiseLoopWrapper): block_classes = [ WanLoopBeforeDenoiser, - WanLoopDenoiserDynamic( + WanLoopDenoiser( guider_input_fields={ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), } @@ -358,10 +509,34 @@ def description(self) -> str: "This block supports text-to-video tasks." ) +class Wan22DenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanLoopBeforeDenoiser, + Wan22LoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopBeforeDenoiser`\n" + " - `Wan22LoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports text-to-video tasks for Wan2.2." + ) + class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper): block_classes = [ WanImage2VideoLoopBeforeDenoiser, - WanLoopDenoiserDynamic( + WanLoopDenoiser( guider_input_fields={ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), "encoder_hidden_states_image": "image_embeds", @@ -384,10 +559,35 @@ def description(self) -> str: ) +class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanImage2VideoLoopBeforeDenoiser, + Wan22LoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanImage2VideoLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports image-to-video tasks for Wan2.2." 
+ ) + + class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper): block_classes = [ WanFLF2VLoopBeforeDenoiser, - WanLoopDenoiserDynamic( + WanLoopDenoiser( guider_input_fields={ "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), "encoder_hidden_states_image": "image_embeds", diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index 41902bdd468a..7a597168e2b5 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -24,14 +24,14 @@ WanPrepareFirstLastFrameLatentsStep, ) from .decoders import WanImageVaeDecoderStep -from .denoise import WanDenoiseStep, WanImage2VideoDenoiseStep, WanFLF2VDenoiseStep +from .denoise import WanDenoiseStep, WanImage2VideoDenoiseStep, WanFLF2VDenoiseStep, Wan22DenoiseStep, Wan22Image2VideoDenoiseStep from .encoders import WanTextEncoderStep, WanImageResizeStep, WanImageCropResizeStep, WanImageEncoderStep, WanVaeImageEncoderStep, WanFirstLastFrameImageEncoderStep, WanFirstLastFrameVaeImageEncoderStep logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -# text2vid +# wan2.1 +# wan2.1: text2vid class WanCoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ WanTextInputStep, @@ -53,9 +53,8 @@ def description(self): ) -# image2video - -## iamge encoder +# wan2.1: image2video +## image encoder class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" block_classes = [WanImageResizeStep, WanImageEncoderStep] @@ -67,7 +66,7 @@ def description(self): -# vae encoder +## vae encoder class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" block_classes = [WanImageResizeStep, WanVaeImageEncoderStep] @@ -78,7 +77,7 @@ def description(self): return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" - +## denoise class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ WanTextInputStep, @@ -104,9 +103,9 @@ def description(self): ) -# FLF2v +# wan2.1: FLF2v -## iamge encoder +## image encoder class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] @@ -118,7 +117,7 @@ def description(self): -# vae encoder +## vae encoder class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep] @@ -129,7 +128,7 @@ def description(self): return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions" - +## denoise class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ WanTextInputStep, @@ -154,8 +153,8 @@ def description(self): + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" ) -# auto blocks - +# wan2.1: auto blocks +## image encoder class WanAutoImageEncoderStep(AutoPipelineBlocks): block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] block_names = ["flf2v_image_encoder", "image2video_image_encoder"] @@ -169,6 +168,7 @@ def description(self): + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." 
+ " - if `last_image` or `image` is not provided, step will be skipped.") +## vae encoder class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep] block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"] @@ -182,7 +182,7 @@ def description(self): + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided." + " - if `last_image` or `image` is not provided, step will be skipped.") - +## denoise class WanAutoDenoiseStep(AutoPipelineBlocks): block_classes = [ WanFLF2VCoreDenoiseStep, @@ -204,7 +204,7 @@ def description(self) -> str: ) -# auto blocks +# auto pipeline blocks class WanAutoBlocks(SequentialPipelineBlocks): block_classes = [ WanTextEncoderStep, @@ -229,9 +229,100 @@ def description(self): ) -# presets -# text2video +# wan22 +# wan2.2: text2vid + +## denoise +class Wan22CoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22DenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n" + ) + + +# wan2.2: image2video +## denoise +class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + WanTextInputStep, + WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanPrepareFirstFrameLatentsStep, + Wan22Image2VideoDenoiseStep, + ] + block_names = ["input", "additional_inputs", "set_timesteps", "prepare_latents", "prepare_first_frame_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanInputsDynamicStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanPrepareConditionLatentsStep` is used to prepare the latent conditions\n" + + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n" + ) + +class Wan22AutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + Wan22Image2VideoCoreDenoiseStep, Wan22CoreDenoiseStep, + ] + block_names = ["image2video", "text2video"] + block_trigger_inputs = ["first_frame_latents", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2video and image2video tasks." + " - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks." + " - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks." 
+ + " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n" + + " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n" + ) + +class Wan22AutoBlocks(SequentialPipelineBlocks): + block_classes = [ + WanTextEncoderStep, + WanAutoVaeImageEncoderStep, + Wan22AutoDenoiseStep, + WanImageVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "vae_image_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-video using Wan2.2.\n" + + "- for text-to-video generation, all you need to provide is `prompt`" + ) + +# presets for wan2.1 and wan2.2 +# YiYi Notes: should we move these to doc? +# wan2.1 TEXT2VIDEO_BLOCKS = InsertableDict( [ ("text_encoder", WanTextEncoderStep), @@ -287,10 +378,54 @@ def description(self): ] ) +# wan2.2 presets + +TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict( + [ + ("text_encoder", WanTextEncoderStep), + ("input", WanTextInputStep), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("denoise", Wan22DenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] +) + +IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict( + [ + ("image_resize", WanImageResizeStep), + ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), + ("input", WanTextInputStep), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("denoise", Wan22DenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] +) + +AUTO_BLOCKS_WAN22 = InsertableDict( + [ + ("text_encoder", WanTextEncoderStep), + ("vae_image_encoder", WanAutoVaeImageEncoderStep), + ("denoise", Wan22AutoDenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] +) + +# presets all blocks (wan and wan22) + + ALL_BLOCKS = { - "text2video": TEXT2VIDEO_BLOCKS, - "image2video": IMAGE2VIDEO_BLOCKS, - "flf2v": FLF2V_BLOCKS, - "auto": AUTO_BLOCKS, + "wan2.1":{ + "text2video": TEXT2VIDEO_BLOCKS, + "image2video": IMAGE2VIDEO_BLOCKS, + "flf2v": FLF2V_BLOCKS, + "auto": AUTO_BLOCKS, + }, + "wan2.2":{ + "text2video": TEXT2VIDEO_BLOCKS_WAN22, + "image2video": IMAGE2VIDEO_BLOCKS_WAN22, + "auto": AUTO_BLOCKS_WAN22, + } } diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py index dae03608b9c3..3dfbed185730 100644 --- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -17,6 +17,7 @@ from ...pipelines.pipeline_utils import StableDiffusionMixin from ...utils import logging from ..modular_pipeline import ModularPipeline +from typing import Optional, Dict, Any logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -35,6 +36,13 @@ class WanModularPipeline( default_blocks_name = "WanAutoBlocks" + # override the default_blocks_name in base class, which is just return self.default_blocks_name + def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: + if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: + return "Wan22AutoBlocks" + else: + return "WanAutoBlocks" + @property def default_height(self): return self.default_sample_height * self.vae_scale_factor_spatial @@ -103,4 +111,11 @@ def requires_unconditional_embeds(self): if hasattr(self, "guider") and self.guider is not None: requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 - return requires_unconditional_embeds \ No newline at end of file 
+ return requires_unconditional_embeds + + @property + def num_train_timesteps(self): + num_train_timesteps = 1000 + if hasattr(self, "scheduler") and self.scheduler is not None: + num_train_timesteps = self.scheduler.config.num_train_timesteps + return num_train_timesteps \ No newline at end of file From e4abfdbeded6ffa853a48defa95274c90edf9e7a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sun, 9 Nov 2025 22:02:35 +0100 Subject: [PATCH 07/12] style --- src/diffusers/__init__.py | 4 +- .../guiders/adaptive_projected_guidance.py | 4 +- .../adaptive_projected_guidance_mix.py | 2 - src/diffusers/guiders/auto_guidance.py | 5 +- .../guiders/classifier_free_guidance.py | 4 +- .../classifier_free_zero_star_guidance.py | 4 +- .../guiders/frequency_decoupled_guidance.py | 6 +- src/diffusers/guiders/guider_utils.py | 6 +- .../guiders/perturbed_attention_guidance.py | 6 +- src/diffusers/guiders/skip_layer_guidance.py | 4 +- .../guiders/smoothed_energy_guidance.py | 4 +- .../tangential_classifier_free_guidance.py | 4 +- src/diffusers/modular_pipelines/__init__.py | 2 +- .../modular_pipelines/modular_pipeline.py | 22 +-- .../modular_pipelines/wan/__init__.py | 4 +- .../modular_pipelines/wan/before_denoise.py | 58 ++++--- .../modular_pipelines/wan/decoders.py | 4 +- .../modular_pipelines/wan/denoise.py | 70 ++++----- .../modular_pipelines/wan/encoders.py | 146 +++++++++++------- .../modular_pipelines/wan/modular_blocks.py | 99 ++++++++---- .../modular_pipelines/wan/modular_pipeline.py | 9 +- src/diffusers/pipelines/auto_pipeline.py | 2 +- .../dummy_torch_and_transformers_objects.py | 15 ++ 23 files changed, 298 insertions(+), 186 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 33127e67468d..d60a1a870ff0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -406,8 +406,8 @@ "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", - "WanAutoBlocks", "Wan22AutoBlocks", + "WanAutoBlocks", "WanModularPipeline", ] ) @@ -1088,8 +1088,8 @@ QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, - WanAutoBlocks, Wan22AutoBlocks, + WanAutoBlocks, WanModularPipeline, ) from .pipelines import ( diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 48f1fd448351..8ec30d02d758 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -88,7 +88,9 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) diff --git a/src/diffusers/guiders/adaptive_projected_guidance_mix.py b/src/diffusers/guiders/adaptive_projected_guidance_mix.py index 95511500a8bf..bdc97bcf6269 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance_mix.py +++ b/src/diffusers/guiders/adaptive_projected_guidance_mix.py @@ -99,11 +99,9 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return 
data_batches - def prepare_inputs_from_block_state( self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] ) -> List["BlockState"]: - if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 97156e3d220f..b7f62e2f4a6e 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -141,8 +141,9 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: - + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index e54ba0dc4ac6..5e55d4d869c1 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -99,7 +99,9 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 4d7ff12e304f..23b492e51b02 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -85,7 +85,9 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index 6668b1adf2cb..4ec6e2d36da9 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -225,8 +225,10 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batch = self._prepare_batch(data, tuple_idx, input_prediction) data_batches.append(data_batch) return data_batches - - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + 
+ def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index b718956412eb..52cb0ce34980 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -166,7 +166,9 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None: def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.") def __call__(self, data: List["BlockState"]) -> Any: @@ -237,7 +239,6 @@ def _prepare_batch( data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) - @classmethod def _prepare_batch_from_block_state( cls, @@ -268,7 +269,6 @@ def _prepare_batch_from_block_state( """ from ..modular_pipelines.modular_pipeline import BlockState - data_batch = {} for key, value in input_fields.items(): try: diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py index 61e29aa350f1..f233e90ca410 100644 --- a/src/diffusers/guiders/perturbed_attention_guidance.py +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -186,8 +186,10 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batch = self._prepare_batch(data, tuple_idx, input_prediction) data_batches.append(data_batch) return data_batches - - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index 493630ef2011..e6109300d99c 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -183,7 +183,9 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index ab69669d62c8..6c3906e820e0 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -172,7 +172,9 @@ def prepare_inputs(self, data: Dict[str, 
Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index bd36cc24ddc2..76899c6e8494 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -74,7 +74,9 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> data_batches.append(data_batch) return data_batches - def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]: + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 1d695ad8f800..252b9f33dfe8 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -90,7 +90,7 @@ QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline - from .wan import WanAutoBlocks, Wan22AutoBlocks, WanModularPipeline + from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index d61e3ff4e960..c664e4a33366 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1512,8 +1512,10 @@ def __init__( "revision": revision, } - modular_config_dict, config_dict = self._load_pipeline_config(pretrained_model_name_or_path, **load_config_kwargs) - + modular_config_dict, config_dict = self._load_pipeline_config( + pretrained_model_name_or_path, **load_config_kwargs + ) + if blocks is None: if modular_config_dict is not None: blocks_class_name = modular_config_dict.get("_blocks_class_name") @@ -1536,7 +1538,6 @@ def __init__( # update component_specs and config_specs from modular_repo - # update component_specs and config_specs based on modular_model_index.json if modular_config_dict is not None: for name, value in modular_config_dict.items(): @@ -1605,28 +1606,26 @@ def _load_pipeline_config( pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **load_config_kwargs, ): - try: # try to load modular_model_index.json modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) return modular_config_dict, None - + except EnvironmentError as e: logger.debug(f" modular_model_index.json not found in the repo: {e}") try: logger.debug(" try to load model_index.json") from diffusers import DiffusionPipeline - from diffusers.pipelines.auto_pipeline import _get_model config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs) return None, 
config_dict - + except EnvironmentError as e: logger.debug(f" model_index.json not found in the repo: {e}") - + return None, None - + @classmethod @validate_hf_hub_args def from_pretrained( @@ -1681,12 +1680,15 @@ def from_pretrained( "revision": revision, } - modular_config_dict, config_dict = cls._load_pipeline_config(pretrained_model_name_or_path, **load_config_kwargs) + modular_config_dict, config_dict = cls._load_pipeline_config( + pretrained_model_name_or_path, **load_config_kwargs + ) if modular_config_dict is not None: pipeline_class = _get_pipeline_class(cls, config=modular_config_dict) elif config_dict is not None: from diffusers.pipelines.auto_pipeline import _get_model + logger.debug(" try to determine the modular pipeline class from model_index.json") standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) model_name = _get_model(standard_pipeline_class.__name__) diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py index 8926d5a27c41..d6d83866d184 100644 --- a/src/diffusers/modular_pipelines/wan/__init__.py +++ b/src/diffusers/modular_pipelines/wan/__init__.py @@ -26,9 +26,9 @@ "ALL_BLOCKS", "AUTO_BLOCKS", "TEXT2VIDEO_BLOCKS", + "Wan22AutoBlocks", "WanAutoBeforeDenoiseStep", "WanAutoBlocks", - "Wan22AutoBlocks", "WanAutoDecodeStep", "WanAutoDenoiseStep", ] @@ -46,9 +46,9 @@ ALL_BLOCKS, AUTO_BLOCKS, TEXT2VIDEO_BLOCKS, + Wan22AutoBlocks, WanAutoBeforeDenoiseStep, WanAutoBlocks, - Wan22AutoBlocks, WanAutoDecodeStep, WanAutoDenoiseStep, ) diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index d5ddaea5447e..71a01b8b9943 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -13,17 +13,17 @@ # limitations under the License. import inspect -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import torch +from ...models import WanTransformer3DModel from ...schedulers import UniPCMultistepScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline -from ...models import WanTransformer3DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -90,11 +90,14 @@ def repeat_tensor_to_batch_size( return input_tensor -def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int) -> Tuple[int, int]: + +def calculate_dimension_from_latents( + latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int +) -> Tuple[int, int]: """Calculate image dimensions from latent tensor dimensions. - This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by multiplying the latent num_frames/height/width - by the VAE scale factor. + This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by + multiplying the latent num_frames/height/width by the VAE scale factor. Args: latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. 
@@ -122,6 +125,7 @@ def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_tem return num_frames, height, width + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -195,7 +199,7 @@ def description(self) -> str: "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" "have a final batch_size of batch_size * num_videos_per_prompt." ) - + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -289,8 +293,9 @@ def __init__( Args: image_latent_inputs (List[str], optional): Names of image latent tensors to process. - In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be a single string or - list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"] + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], + ["control_image_latents"] additional_batch_inputs (List[str], optional): Names of additional conditional input tensors to expand batch size. These tensors will only have their batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. @@ -358,7 +363,6 @@ def inputs(self) -> List[InputParam]: return inputs - def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -369,12 +373,13 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe continue # 1. Calculate num_frames, height/width from latents - num_frames,height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial) + num_frames, height, width = calculate_dimension_from_latents( + image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial + ) block_state.num_frames = block_state.num_frames or num_frames block_state.height = block_state.height or height block_state.width = block_state.width or width - # 3. 
Expand batch size image_latent_tensor = repeat_tensor_to_batch_size( input_name=image_latent_input_name, @@ -426,7 +431,6 @@ def inputs(self) -> List[InputParam]: InputParam("sigmas"), ] - @torch.no_grad() def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -569,8 +573,7 @@ def inputs(self) -> List[InputParam]: InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), InputParam("num_frames", type_hint=int), ] - - + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -580,9 +583,13 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal) + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) mask_lat_size = mask_lat_size.transpose(1, 2) mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) @@ -604,23 +611,28 @@ def inputs(self) -> List[InputParam]: InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]), InputParam("num_frames", type_hint=int), ] - - + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) - mask_lat_size[:, :, list(range(1, block_state.num_frames-1))] = 0 + mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal) + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) mask_lat_size = mask_lat_size.transpose(1, 2) mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) - block_state.first_last_frame_latents = torch.concat([mask_lat_size, block_state.first_last_frame_latents], dim=1) + block_state.first_last_frame_latents = torch.concat( + [mask_lat_size, block_state.first_last_frame_latents], dim=1 + ) self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py index 254595d131ee..7cec318c1706 100644 --- 
a/src/diffusers/modular_pipelines/wan/decoders.py +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -87,9 +87,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState: latents = latents.to(vae_dtype) block_state.videos = components.vae.decode(latents, return_dict=False)[0] - block_state.videos = components.video_processor.postprocess_video( - block_state.videos, output_type="np" - ) + block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np") self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 67fc69b5549b..ab73b8ffab5e 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple, Dict +from typing import Any, Dict, List, Tuple import torch @@ -27,7 +27,7 @@ ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam from .modular_pipeline import WanModularPipeline @@ -61,7 +61,6 @@ def inputs(self) -> List[InputParam]: description="The dtype of the model inputs. Can be generated in input step.", ), ] - @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): @@ -105,7 +104,9 @@ def inputs(self) -> List[InputParam]: @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(block_state.dtype) + block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to( + block_state.dtype + ) return components, block_state @@ -145,31 +146,31 @@ def inputs(self) -> List[InputParam]: @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_last_frame_latents], dim=1).to(block_state.dtype) + block_state.latent_model_input = torch.cat( + [block_state.latents, block_state.first_last_frame_latents], dim=1 + ).to(block_state.dtype) return components, block_state class WanLoopDenoiser(ModularPipelineBlocks): model_name = "wan" - def __init__( - self, - guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")} - ): + self, + guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}, + ): """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1. Args: guider_input_fields: A dictionary that maps each argument expected by the denoiser model - (for example, "encoder_hidden_states") to data stored on 'block_state'. - The value can be either: - - - A tuple of strings. For instance, - {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")} tells - the guider to read `block_state.prompt_embeds` and `block_state.negative_prompt_embeds` and - pass them as the conditional and unconditional batches of 'encoder_hidden_states'. - - A string. 
For example, {"encoder_hidden_image": "image_embeds"} makes the guider - forward `block_state.image_embeds` for both conditional and unconditional batches. + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. """ if not isinstance(guider_input_fields, dict): raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") @@ -198,7 +199,6 @@ def description(self) -> str: @property def inputs(self) -> List[Tuple[str, Any]]: - inputs = [ InputParam("attention_kwargs"), InputParam( @@ -219,7 +219,6 @@ def inputs(self) -> List[Tuple[str, Any]]: inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) return inputs - @torch.no_grad() def __call__( self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor @@ -264,21 +263,22 @@ def __call__( class Wan22LoopDenoiser(ModularPipelineBlocks): model_name = "wan" - - def __init__(self, guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}): + def __init__( + self, + guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}, + ): """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2. Args: guider_input_fields: A dictionary that maps each argument expected by the denoiser model - (for example, "encoder_hidden_states") to data stored on `block_state`. - The value can be either: - - - A tuple of strings. For instance, - `{"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}` tells - the guider to read `block_state.prompt_embeds` and `block_state.negative_prompt_embeds` and - pass them as the conditional and unconditional batches of `encoder_hidden_states`. - - A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider - forward `block_state.image_embeds` for both conditional and unconditional batches. + (for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either: + + - A tuple of strings. For instance, `{"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")}` tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + `encoder_hidden_states`. + - A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. 
""" if not isinstance(guider_input_fields, dict): raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") @@ -324,7 +324,6 @@ def expected_configs(self) -> List[ConfigSpec]: @property def inputs(self) -> List[Tuple[str, Any]]: - inputs = [ InputParam("attention_kwargs"), InputParam( @@ -345,12 +344,10 @@ def inputs(self) -> List[Tuple[str, Any]]: inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) return inputs - @torch.no_grad() def __call__( self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: - boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps if t >= boundary_timestep: block_state.current_model = components.transformer @@ -413,7 +410,6 @@ def description(self) -> str: "object (e.g. `WanDenoiseLoopWrapper`)" ) - @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): # Perform scheduler step using the predicted output @@ -509,6 +505,7 @@ def description(self) -> str: "This block supports text-to-video tasks." ) + class Wan22DenoiseStep(WanDenoiseLoopWrapper): block_classes = [ WanLoopBeforeDenoiser, @@ -533,6 +530,7 @@ def description(self) -> str: "This block supports text-to-video tasks for Wan2.2." ) + class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper): block_classes = [ WanImage2VideoLoopBeforeDenoiser, @@ -607,4 +605,4 @@ def description(self) -> str: " - `WanLoopDenoiser`\n" " - `WanLoopAfterDenoiser`\n" "This block supports FLF2V tasks." - ) \ No newline at end of file + ) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index 832531428557..dc49df8eab8c 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -15,21 +15,21 @@ import html from typing import List, Optional, Union +import numpy as np +import PIL import regex as re import torch -from transformers import AutoTokenizer, UMT5EncoderModel, CLIPImageProcessor, CLIPVisionModel +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan from ...utils import is_ftfy_available, is_torchvision_available, logging +from ...video_processor import VideoProcessor from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline -from ...image_processor import PipelineImageInput -from ...video_processor import VideoProcessor -from ...models import AutoencoderKLWan -import PIL -import numpy as np if is_ftfy_available(): @@ -128,13 +128,16 @@ def encode_vae_image( raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.") if isinstance(generator, list) and len(generator) != video_tensor.shape[0]: - raise ValueError(f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}.") + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}." 
+ ) video_tensor = video_tensor.to(device=device, dtype=dtype) if isinstance(generator, list): video_latents = [ - retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax") for i in range(video_tensor.shape[0]) + retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(video_tensor.shape[0]) ] video_latents = torch.cat(video_latents, dim=0) else: @@ -145,17 +148,14 @@ def encode_vae_image( .view(1, latent_channels, 1, 1, 1) .to(video_latents.device, video_latents.dtype) ) - latents_std = ( - 1.0 / torch.tensor(vae.config.latents_std) - .view(1, latent_channels, 1, 1, 1) - .to(video_latents.device, video_latents.dtype) + latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, latent_channels, 1, 1, 1).to( + video_latents.device, video_latents.dtype ) video_latents = (video_latents - latents_mean) * latents_std return video_latents - class WanTextEncoderStep(ModularPipelineBlocks): model_name = "wan" @@ -235,13 +235,14 @@ def encode_prompt( The maximum number of text tokens to be used for the generation process. """ device = device or components._execution_device - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) prompt_embeds = get_t5_prompt_embeds( - text_encoder=components.text_encoder, - tokenizer=components.tokenizer, - prompt=prompt, + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, max_sequence_length=max_sequence_length, device=device, ) @@ -263,10 +264,10 @@ def encode_prompt( ) negative_prompt_embeds = get_t5_prompt_embeds( - text_encoder=components.text_encoder, - tokenizer=components.tokenizer, - prompt=negative_prompt, - max_sequence_length=max_sequence_length, + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, device=device, ) @@ -320,7 +321,6 @@ def intermediate_outputs(self) -> List[OutputParam]: ] def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) max_area = block_state.height * block_state.width @@ -338,7 +338,6 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe class WanImageCropResizeStep(ModularPipelineBlocks): model_name = "wan" - @property def description(self) -> str: return "Image Resize step that resize the last_image to the same size of first frame image with center crop." 
@@ -346,7 +345,9 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam("resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image"), + InputParam( + "resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image" + ), InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"), ] @@ -357,16 +358,15 @@ def intermediate_outputs(self) -> List[OutputParam]: ] def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - + height = block_state.resized_image.height width = block_state.resized_image.width image = block_state.last_image # Calculate resize ratio to match first frame dimensions resize_ratio = max(width / image.width, height / image.height) - + # Resize the image width = round(image.width * resize_ratio) height = round(image.height * resize_ratio) @@ -404,12 +404,11 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), ] - def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) device = components._execution_device - + image = block_state.resized_image image_embeds = encode_image( @@ -450,12 +449,11 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), ] - def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) device = components._execution_device - + first_frame_image = block_state.resized_image last_frame_image = block_state.resized_last_image @@ -481,9 +479,14 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKLWan), - ComponentSpec("video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), ] - + @property def inputs(self) -> List[InputParam]: return [ @@ -493,13 +496,17 @@ def inputs(self) -> List[InputParam]: InputParam("num_frames"), InputParam("generator"), ] - + @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("first_frame_latents", type_hint=torch.Tensor, description="video latent representation with the first frame image condition"), + OutputParam( + "first_frame_latents", + type_hint=torch.Tensor, + description="video latent representation with the first frame image condition", + ), ] - + @staticmethod def check_inputs(components, block_state): if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( @@ -513,8 +520,8 @@ def check_inputs(components, block_state): ): raise ValueError( f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." 
- ) - + ) + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) self.check_inputs(components, block_state) @@ -528,14 +535,19 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe width = block_state.width or components.default_width num_frames = block_state.num_frames or components.default_num_frames - image_tensor = components.video_processor.preprocess( - image, height=height, width=width).to(device=device, dtype=dtype) + image_tensor = components.video_processor.preprocess(image, height=height, width=width).to( + device=device, dtype=dtype + ) if image_tensor.dim() == 4: image_tensor = image_tensor.unsqueeze(2) video_tensor = torch.cat( - [image_tensor, image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width)], dim=2 + [ + image_tensor, + image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width), + ], + dim=2, ).to(device=device, dtype=dtype) block_state.first_frame_latents = encode_vae_image( @@ -562,9 +574,14 @@ def description(self) -> str: def expected_components(self) -> List[ComponentSpec]: return [ ComponentSpec("vae", AutoencoderKLWan), - ComponentSpec("video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), ] - + @property def inputs(self) -> List[InputParam]: return [ @@ -575,13 +592,17 @@ def inputs(self) -> List[InputParam]: InputParam("num_frames"), InputParam("generator"), ] - + @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam("first_last_frame_latents", type_hint=torch.Tensor, description="video latent representation with the first and last frame images condition"), + OutputParam( + "first_last_frame_latents", + type_hint=torch.Tensor, + description="video latent representation with the first and last frame images condition", + ), ] - + @staticmethod def check_inputs(components, block_state): if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( @@ -595,8 +616,8 @@ def check_inputs(components, block_state): ): raise ValueError( f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." 
- ) - + ) + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) self.check_inputs(components, block_state) @@ -611,17 +632,26 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe width = block_state.width or components.default_width num_frames = block_state.num_frames or components.default_num_frames - first_image_tensor = components.video_processor.preprocess( - first_frame_image, height=height, width=width).to(device=device, dtype=dtype) + first_image_tensor = components.video_processor.preprocess(first_frame_image, height=height, width=width).to( + device=device, dtype=dtype + ) first_image_tensor = first_image_tensor.unsqueeze(2) - - last_image_tensor = components.video_processor.preprocess( - last_frame_image, height=height, width=width).to(device=device, dtype=dtype) - + + last_image_tensor = components.video_processor.preprocess(last_frame_image, height=height, width=width).to( + device=device, dtype=dtype + ) + last_image_tensor = last_image_tensor.unsqueeze(2) video_tensor = torch.cat( - [first_image_tensor, first_image_tensor.new_zeros(first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width), last_image_tensor], dim=2 + [ + first_image_tensor, + first_image_tensor.new_zeros( + first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width + ), + last_image_tensor, + ], + dim=2, ).to(device=device, dtype=dtype) block_state.first_last_frame_latents = encode_vae_image( @@ -634,4 +664,4 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe ) self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index 7a597168e2b5..94d8fc6999b7 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -16,20 +16,35 @@ from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict from .before_denoise import ( - WanTextInputStep, - WanPrepareLatentsStep, - WanSetTimestepsStep, WanInputsDynamicStep, WanPrepareFirstFrameLatentsStep, WanPrepareFirstLastFrameLatentsStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, + WanTextInputStep, ) from .decoders import WanImageVaeDecoderStep -from .denoise import WanDenoiseStep, WanImage2VideoDenoiseStep, WanFLF2VDenoiseStep, Wan22DenoiseStep, Wan22Image2VideoDenoiseStep -from .encoders import WanTextEncoderStep, WanImageResizeStep, WanImageCropResizeStep, WanImageEncoderStep, WanVaeImageEncoderStep, WanFirstLastFrameImageEncoderStep, WanFirstLastFrameVaeImageEncoderStep +from .denoise import ( + Wan22DenoiseStep, + Wan22Image2VideoDenoiseStep, + WanDenoiseStep, + WanFLF2VDenoiseStep, + WanImage2VideoDenoiseStep, +) +from .encoders import ( + WanFirstLastFrameImageEncoderStep, + WanFirstLastFrameVaeImageEncoderStep, + WanImageCropResizeStep, + WanImageEncoderStep, + WanImageResizeStep, + WanTextEncoderStep, + WanVaeImageEncoderStep, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name + # wan2.1 # wan2.1: text2vid class WanCoreDenoiseStep(SequentialPipelineBlocks): @@ -65,7 +80,6 @@ def description(self): return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" - ## vae 
encoder class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" @@ -87,7 +101,14 @@ class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): WanPrepareFirstFrameLatentsStep, WanImage2VideoDenoiseStep, ] - block_names = ["input", "additional_inputs", "set_timesteps", "prepare_latents", "prepare_first_frame_latents", "denoise"] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "prepare_first_frame_latents", + "denoise", + ] @property def description(self): @@ -105,6 +126,7 @@ def description(self): # wan2.1: FLF2v + ## image encoder class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" @@ -116,7 +138,6 @@ def description(self): return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings" - ## vae encoder class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks): model_name = "wan" @@ -138,7 +159,14 @@ class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks): WanPrepareFirstLastFrameLatentsStep, WanFLF2VDenoiseStep, ] - block_names = ["input", "additional_inputs", "set_timesteps", "prepare_latents", "prepare_first_last_frame_latents", "denoise"] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "prepare_first_last_frame_latents", + "denoise", + ] @property def description(self): @@ -153,6 +181,7 @@ def description(self): + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" ) + # wan2.1: auto blocks ## image encoder class WanAutoImageEncoderStep(AutoPipelineBlocks): @@ -162,11 +191,14 @@ class WanAutoImageEncoderStep(AutoPipelineBlocks): @property def description(self): - return ("Image Encoder step that encode the image to generate the image embeddings" - + "This is an auto pipeline block that works for image2video tasks." - + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." - + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." - + " - if `last_image` or `image` is not provided, step will be skipped.") + return ( + "Image Encoder step that encode the image to generate the image embeddings" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + ## vae encoder class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): @@ -176,11 +208,14 @@ class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): @property def description(self): - return ("Vae Image Encoder step that encode the image to generate the image latents" - + "This is an auto pipeline block that works for image2video tasks." - + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided." - + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided." - + " - if `last_image` or `image` is not provided, step will be skipped.") + return ( + "Vae Image Encoder step that encode the image to generate the image latents" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." 
+ ) + ## denoise class WanAutoDenoiseStep(AutoPipelineBlocks): @@ -229,10 +264,10 @@ def description(self): ) - # wan22 # wan2.2: text2vid + ## denoise class Wan22CoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ @@ -266,7 +301,14 @@ class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): WanPrepareFirstFrameLatentsStep, Wan22Image2VideoDenoiseStep, ] - block_names = ["input", "additional_inputs", "set_timesteps", "prepare_latents", "prepare_first_frame_latents", "denoise"] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "prepare_first_frame_latents", + "denoise", + ] @property def description(self): @@ -281,9 +323,11 @@ def description(self): + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n" ) + class Wan22AutoDenoiseStep(AutoPipelineBlocks): block_classes = [ - Wan22Image2VideoCoreDenoiseStep, Wan22CoreDenoiseStep, + Wan22Image2VideoCoreDenoiseStep, + Wan22CoreDenoiseStep, ] block_names = ["image2video", "text2video"] block_trigger_inputs = ["first_frame_latents", None] @@ -299,6 +343,7 @@ def description(self) -> str: + " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n" ) + class Wan22AutoBlocks(SequentialPipelineBlocks): block_classes = [ WanTextEncoderStep, @@ -320,6 +365,7 @@ def description(self): + "- for text-to-video generation, all you need to provide is `prompt`" ) + # presets for wan2.1 and wan2.2 # YiYi Notes: should we move these to doc? # wan2.1 @@ -347,7 +393,6 @@ def description(self): ("denoise", WanImage2VideoDenoiseStep), ("decode", WanImageVaeDecoderStep), ] - ) @@ -365,7 +410,6 @@ def description(self): ("denoise", WanFLF2VDenoiseStep), ("decode", WanImageVaeDecoderStep), ] - ) AUTO_BLOCKS = InsertableDict( @@ -415,17 +459,16 @@ def description(self): # presets all blocks (wan and wan22) - ALL_BLOCKS = { - "wan2.1":{ + "wan2.1": { "text2video": TEXT2VIDEO_BLOCKS, "image2video": IMAGE2VIDEO_BLOCKS, "flf2v": FLF2V_BLOCKS, "auto": AUTO_BLOCKS, }, - "wan2.2":{ + "wan2.2": { "text2video": TEXT2VIDEO_BLOCKS_WAN22, "image2video": IMAGE2VIDEO_BLOCKS_WAN22, "auto": AUTO_BLOCKS_WAN22, - } + }, } diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py index 3dfbed185730..930b25e4b905 100644 --- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -13,11 +13,12 @@ # limitations under the License. 
+from typing import Any, Dict, Optional + from ...loaders import WanLoraLoaderMixin from ...pipelines.pipeline_utils import StableDiffusionMixin from ...utils import logging from ..modular_pipeline import ModularPipeline -from typing import Optional, Dict, Any logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -66,14 +67,13 @@ def default_sample_width(self): @property def default_sample_num_frames(self): return 21 - + @property def patch_size_spatial(self): patch_size_spatial = 2 if hasattr(self, "transformer") and self.transformer is not None: patch_size_spatial = self.transformer.config.patch_size[1] return patch_size_spatial - @property def vae_scale_factor_spatial(self): @@ -103,7 +103,6 @@ def num_channels_latents(self): num_channels_latents = self.vae.config.z_dim return num_channels_latents - @property def requires_unconditional_embeds(self): requires_unconditional_embeds = False @@ -118,4 +117,4 @@ def num_train_timesteps(self): num_train_timesteps = 1000 if hasattr(self, "scheduler") and self.scheduler is not None: num_train_timesteps = self.scheduler.config.num_train_timesteps - return num_train_timesteps \ No newline at end of file + return num_train_timesteps diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 2d7560cd6ad6..044d854390e4 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -117,8 +117,8 @@ StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) +from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline -from .wan import WanPipeline, WanImageToVideoPipeline, WanVideoToVideoPipeline, WanVACEPipeline AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e8209403de75..ed163e87a0c3 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -182,6 +182,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Wan22AutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class WanAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] From 0823d6728e28730fef53478596a000fa99e32c68 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Sun, 9 Nov 2025 11:27:02 -1000 Subject: [PATCH 08/12] Apply suggestions from code review --- .../modular_pipelines/modular_pipeline.py | 4 +--- .../modular_pipelines/wan/before_denoise.py | 17 ++++++++--------- src/diffusers/modular_pipelines/wan/denoise.py | 6 +++--- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index c664e4a33366..151adbbc0320 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1536,8 +1536,6 @@ def __init__( self._component_specs = {spec.name: deepcopy(spec) for spec in 
self.blocks.expected_components} self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} - # update component_specs and config_specs from modular_repo - # update component_specs and config_specs based on modular_model_index.json if modular_config_dict is not None: for name, value in modular_config_dict.items(): @@ -1551,7 +1549,7 @@ def __init__( elif name in self._config_specs: self._config_specs[name].default = value - # if modular_model_index.json is not found, try to load model_index.json + # if `modular_config_dict` is None (i.e. `modular_model_index.json` is not found), update based on `config_dict` (i.e. `model_index.json`) elif config_dict is not None: for name, value in config_dict.items(): if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2: diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index 71a01b8b9943..f4e32063cafe 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -294,21 +294,20 @@ def __init__( Args: image_latent_inputs (List[str], optional): Names of image latent tensors to process. In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be - a single string or list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], - ["control_image_latents"] + a single string or list of strings. Defaults to ["first_frame_latents"]. additional_batch_inputs (List[str], optional): Names of additional conditional input tensors to expand batch size. These tensors will only have their batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. - Defaults to []. Examples: ["processed_mask_image"] + Defaults to []. 
Examples: - # Configure to process image_latents (default behavior) QwenImageInputsDynamicStep() + # Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep() # Configure to process multiple image latent inputs - QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"]) + WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"]) - # Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep( - image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"] + # Configure to process image latents and additional batch inputs WanAdditionalInputsStep( + image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"] ) """ if not isinstance(image_latent_inputs, list): @@ -565,7 +564,7 @@ class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): @property def description(self) -> str: - return "step that prepares the last frame mask latents and add it to the latent condition" + return "step that prepares the masked first frame latents and add it to the latent condition" @property def inputs(self) -> List[InputParam]: @@ -603,7 +602,7 @@ class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): @property def description(self) -> str: - return "step that prepares the last frame mask latents and add it to the latent condition" + return "step that prepares the masked latents with first and last frames and add it to the latent condition" @property def inputs(self) -> List[InputParam]: diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index ab73b8ffab5e..f154e5cac4d0 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -502,7 +502,7 @@ def description(self) -> str: " - `WanLoopBeforeDenoiser`\n" " - `WanLoopDenoiser`\n" " - `WanLoopAfterDenoiser`\n" - "This block supports text-to-video tasks." + "This block supports text-to-video tasks for wan2.1." ) @@ -553,7 +553,7 @@ def description(self) -> str: " - `WanImage2VideoLoopBeforeDenoiser`\n" " - `WanLoopDenoiser`\n" " - `WanLoopAfterDenoiser`\n" - "This block supports image-to-video tasks." + "This block supports image-to-video tasks for wan2.1." ) @@ -604,5 +604,5 @@ def description(self) -> str: " - `WanFLF2VLoopBeforeDenoiser`\n" " - `WanLoopDenoiser`\n" " - `WanLoopAfterDenoiser`\n" - "This block supports FLF2V tasks." + "This block supports FLF2V tasks for wan2.1." 
)

From 3341e25a9278682c61fe0426d0f6eab029072892 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Sun, 9 Nov 2025 22:30:05 +0100
Subject: [PATCH 09/12] input dynamic step -> additional input step

---
 .../modular_pipelines/wan/before_denoise.py |  2 +-
 .../modular_pipelines/wan/modular_blocks.py | 22 +++++++++++-----------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py
index f4e32063cafe..e2f8d3e7d88b 100644
--- a/src/diffusers/modular_pipelines/wan/before_denoise.py
+++ b/src/diffusers/modular_pipelines/wan/before_denoise.py
@@ -275,7 +275,7 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
         return components, state
 
 
-class WanInputsDynamicStep(ModularPipelineBlocks):
+class WanAdditionalInputsStep(ModularPipelineBlocks):
     model_name = "wan"
 
     def __init__(
diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py
index 94d8fc6999b7..b3b70b2f9be1 100644
--- a/src/diffusers/modular_pipelines/wan/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py
@@ -16,7 +16,7 @@
 from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
 from ..modular_pipeline_utils import InsertableDict
 from .before_denoise import (
-    WanInputsDynamicStep,
+    WanAdditionalInputsStep,
     WanPrepareFirstFrameLatentsStep,
     WanPrepareFirstLastFrameLatentsStep,
     WanPrepareLatentsStep,
@@ -95,7 +95,7 @@ def description(self):
 class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
         WanTextInputStep,
-        WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"]),
+        WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
         WanSetTimestepsStep,
         WanPrepareLatentsStep,
         WanPrepareFirstFrameLatentsStep,
@@ -116,10 +116,10 @@ def description(self):
             "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
             + "This is a sequential pipeline blocks:\n"
             + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
-            + " - `WanInputsDynamicStep` is used to adjust the batch size of the latent conditions\n"
+            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
             + " - `WanSetTimestepsStep` is used to set the timesteps\n"
             + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
-            + " - `WanPrepareConditionLatentsStep` is used to prepare the latent conditions\n"
+            + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
             + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
         )
 
@@ -153,7 +153,7 @@ def description(self):
 class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
         WanTextInputStep,
-        WanInputsDynamicStep(image_latent_inputs=["first_last_frame_latents"]),
+        WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
         WanSetTimestepsStep,
         WanPrepareLatentsStep,
         WanPrepareFirstLastFrameLatentsStep,
@@ -174,7 +174,7 @@ def description(self):
             "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
             + "This is a sequential pipeline blocks:\n"
             + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
-            + " - `WanInputsDynamicStep` is used to adjust the batch size of the latent conditions\n"
+            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
             + " - `WanSetTimestepsStep` is used to set the timesteps\n"
             + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
             + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
@@ -295,7 +295,7 @@ def description(self):
 class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
     block_classes = [
         WanTextInputStep,
-        WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"]),
+        WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
         WanSetTimestepsStep,
         WanPrepareLatentsStep,
         WanPrepareFirstFrameLatentsStep,
@@ -316,10 +316,10 @@ def description(self):
             "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
             + "This is a sequential pipeline blocks:\n"
             + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
-            + " - `WanInputsDynamicStep` is used to adjust the batch size of the latent conditions\n"
+            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
             + " - `WanSetTimestepsStep` is used to set the timesteps\n"
             + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
-            + " - `WanPrepareConditionLatentsStep` is used to prepare the latent conditions\n"
+            + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
             + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
         )
 
@@ -386,7 +386,7 @@ def description(self):
         ("image_encoder", WanImage2VideoImageEncoderStep),
         ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
         ("input", WanTextInputStep),
-        ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_frame_latents"])),
+        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
         ("set_timesteps", WanSetTimestepsStep),
         ("prepare_latents", WanPrepareLatentsStep),
         ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
@@ -403,7 +403,7 @@ def description(self):
         ("image_encoder", WanFLF2VImageEncoderStep),
         ("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
         ("input", WanTextInputStep),
-        ("additional_inputs", WanInputsDynamicStep(image_latent_inputs=["first_last_frame_latents"])),
+        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
         ("set_timesteps", WanSetTimestepsStep),
         ("prepare_latents", WanPrepareLatentsStep),
         ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),

From 5a376436e171bc672cb4862ea7fed36b551a3a74 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Sun, 9 Nov 2025 22:39:07 +0100
Subject: [PATCH 10/12] up

---
 src/diffusers/modular_pipelines/wan/encoders.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
index dc49df8eab8c..c7491071a436 100644
--- a/src/diffusers/modular_pipelines/wan/encoders.py
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -190,6 +190,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
             OutputParam(
                 "prompt_embeds",
                 type_hint=torch.Tensor,
+                # YiYi TODO: we should change it to text_embeddings
                 kwargs_type="denoiser_input_fields",
                 description="text embeddings used to guide the image generation",
             ),

From 84d96f4d51af12e1820f31459efed82c54123c03 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Sun, 9 Nov 2025 23:00:26 +0100
Subject: [PATCH 11/12] fix init

---
 src/diffusers/modular_pipelines/wan/__init__.py | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py
index d6d83866d184..73f67c9afed2 100644
--- a/src/diffusers/modular_pipelines/wan/__init__.py
+++ b/src/diffusers/modular_pipelines/wan/__init__.py
@@ -21,16 +21,14 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 
 else:
+    _import_structure["decoders"] = ["WanImageVaeDecoderStep"]
     _import_structure["encoders"] = ["WanTextEncoderStep"]
     _import_structure["modular_blocks"] = [
         "ALL_BLOCKS",
-        "AUTO_BLOCKS",
-        "TEXT2VIDEO_BLOCKS",
         "Wan22AutoBlocks",
-        "WanAutoBeforeDenoiseStep",
         "WanAutoBlocks",
-        "WanAutoDecodeStep",
-        "WanAutoDenoiseStep",
+        "WanAutoImageEncoderStep",
+        "WanAutoVaeImageEncoderStep",
     ]
     _import_structure["modular_pipeline"] = ["WanModularPipeline"]
 
@@ -41,16 +39,14 @@
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
+        from .decoders import WanImageVaeDecoderStep
         from .encoders import WanTextEncoderStep
        from .modular_blocks import (
             ALL_BLOCKS,
-            AUTO_BLOCKS,
-            TEXT2VIDEO_BLOCKS,
             Wan22AutoBlocks,
-            WanAutoBeforeDenoiseStep,
             WanAutoBlocks,
-            WanAutoDecodeStep,
-            WanAutoDenoiseStep,
+            WanAutoImageEncoderStep,
+            WanAutoVaeImageEncoderStep,
         )
         from .modular_pipeline import WanModularPipeline
 else:

From 495354f554a333032b4d56399a155f3609d2e6e8 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Mon, 10 Nov 2025 05:09:17 +0100
Subject: [PATCH 12/12] update dtype

---
 src/diffusers/modular_pipelines/wan/denoise.py  | 16 ++++++++++------
 src/diffusers/modular_pipelines/wan/encoders.py |  1 -
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py
index f154e5cac4d0..2da36f52da87 100644
--- a/src/diffusers/modular_pipelines/wan/denoise.py
+++ b/src/diffusers/modular_pipelines/wan/denoise.py
@@ -239,9 +239,11 @@ def __call__(
         for guider_state_batch in guider_state:
             components.guider.prepare_models(components.transformer)
             cond_kwargs = guider_state_batch.as_dict()
-            cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys()}
-            for cond in cond_kwargs.values():
-                cond = cond.to(block_state.dtype)
+            cond_kwargs = {
+                k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
+                for k, v in cond_kwargs.items()
+                if k in self._guider_input_fields.keys()
+            }
 
             # Predict the noise residual
             # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
@@ -372,9 +374,11 @@ def __call__(
         for guider_state_batch in guider_state:
             block_state.guider.prepare_models(block_state.current_model)
             cond_kwargs = guider_state_batch.as_dict()
-            cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys()}
-            for cond in cond_kwargs.values():
-                cond = cond.to(block_state.dtype)
+            cond_kwargs = {
+                k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
+                for k, v in cond_kwargs.items()
+                if k in self._guider_input_fields.keys()
+            }
 
             # Predict the noise residual
             # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
index c7491071a436..dc49df8eab8c 100644
--- a/src/diffusers/modular_pipelines/wan/encoders.py
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -190,7 +190,6 @@ def intermediate_outputs(self) -> List[OutputParam]:
             OutputParam(
                 "prompt_embeds",
                 type_hint=torch.Tensor,
-                # YiYi TODO: we should change it to text_embeddings
                 kwargs_type="denoiser_input_fields",
                 description="text embeddings used to guide the image generation",
             ),
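
Note on PATCH 12 (update dtype): the removed loop, `for cond in cond_kwargs.values(): cond = cond.to(block_state.dtype)`, only rebound the local name `cond` and never wrote the cast tensor back into the dict, so the conditioning tensors reached the transformer in their original dtype; the dict comprehension actually stores the cast copies. A minimal standalone sketch of the difference (hypothetical tensor names and dtypes, not the pipeline code itself):

    import torch

    cond_kwargs = {"prompt_embeds": torch.randn(1, 4, dtype=torch.float32)}
    target_dtype = torch.float16

    # Old pattern: `cond` is rebound to the cast copy; the dict entry is untouched.
    for cond in cond_kwargs.values():
        cond = cond.to(target_dtype)
    assert cond_kwargs["prompt_embeds"].dtype == torch.float32  # still float32

    # New pattern (as in the patch): rebuild the dict with cast tensor values.
    cond_kwargs = {
        k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v
        for k, v in cond_kwargs.items()
    }
    assert cond_kwargs["prompt_embeds"].dtype == torch.float16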