diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index 507acce1ebf6..4272066309a2 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -454,6 +454,9 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip block_state = self.get_block_state(state) block_state.device = components._execution_device + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + scheduler = components.scheduler transformer = components.transformer batch_size = block_state.batch_size * block_state.num_images_per_prompt @@ -659,8 +662,6 @@ def intermediate_outputs(self) -> List[OutputParam]: def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.height = block_state.height or components.default_height - block_state.width = block_state.width or components.default_width block_state.device = components._execution_device block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this? block_state.num_channels_latents = components.num_channels_latents diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py index 04b439f026a4..37895bddbf07 100644 --- a/src/diffusers/modular_pipelines/flux/modular_blocks.py +++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py @@ -148,8 +148,8 @@ def description(self): [ ("text_encoder", FluxTextEncoderStep), ("input", FluxInputStep), - ("set_timesteps", FluxSetTimestepsStep), ("prepare_latents", FluxPrepareLatentsStep), + ("set_timesteps", FluxSetTimestepsStep), ("denoise", FluxDenoiseStep), ("decode", FluxDecodeStep), ]