
Conversation

sayakpaul (Member)

What does this PR do?

Code to test:

import torch
from diffusers import ModularPipeline

repo_id = "black-forest-labs/FLUX.1-dev"

# Build the modular pipeline, load its components in bf16, and move it to GPU.
pipe = ModularPipeline.from_pretrained(repo_id)
pipe.load_components(torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# print(pipe)

output = pipe(
    prompt="a dog sitting by the sea waiting for its companion to come",
    guidance_scale=4.5,
    num_inference_steps=28,
    max_sequence_length=512,
    generator=torch.manual_seed(0),
)
output.values["images"][0].save("modular_flux_image.png")

Result:

(generated image)

sayakpaul requested review from asomoza and yiyixuxu on October 2, 2025, 12:49
return [
    InputParam("prompt"),
    InputParam("prompt_2"),
    InputParam("max_sequence_length", type_hint=int, default=512, required=False),
sayakpaul (Member Author)

Took the liberty of supporting this input from the user so that Schnell can also work (see the sketch below). In another PR, I will harmonize the steps in Flux Modular that involve repetition along the batch-size dimension (similar to Qwen).

Cc: @yiyixuxu
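
A minimal sketch of how Schnell could exercise this input. The repo id, the shorter sequence length, and the Schnell-specific settings below are illustrative assumptions, not something this PR tests:

import torch
from diffusers import ModularPipeline

# Hypothetical: load FLUX.1-schnell through the same modular entry point.
pipe = ModularPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
pipe.load_components(torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

output = pipe(
    prompt="a dog sitting by the sea waiting for its companion to come",
    guidance_scale=0.0,  # Schnell is guidance-distilled
    num_inference_steps=4,  # few-step distilled model
    max_sequence_length=256,  # Schnell's T5 limit; the reason this input matters
    generator=torch.manual_seed(0),
)
output.values["images"][0].save("modular_flux_schnell_image.png")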

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

asomoza (Member) commented Oct 5, 2025

Thanks! I think we need to move FluxAutoBeforeDenoiseStep inside a denoise step wrapper here, because when we use modular diffusers with separate steps, we usually only run the denoise step. This is done for SDXL here, so the SDXL auto blocks don't have a separate before_denoise step.

Right now, if I use Flux this way, I get the following error because of it:

ValueError: Required input 'timesteps' is missing
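
To see why, it helps to inspect which sub-blocks the Flux auto pipeline exposes. A minimal sketch, assuming the dict-like blocks.sub_blocks access used in the repro below:

from diffusers import ModularPipeline

# Sketch: list the sub-blocks of the Flux modular pipeline. If before_denoise
# sits as a sibling of denoise instead of being wrapped inside it, running the
# denoise step on its own never computes `timesteps`, hence the error above.
pipe = ModularPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
print(list(pipe.blocks.sub_blocks.keys()))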

Code to test it:

import torch
from diffusers import ModularPipeline, ComponentsManager

repo_id = "black-forest-labs/FLUX.1-dev"
device = "cuda"

components = ComponentsManager()
components.enable_auto_cpu_offload(device=device)

# Pop out only the text_encoder step and run it as a standalone node.
text_blocks = ModularPipeline.from_pretrained(repo_id, components_manager=components).blocks.sub_blocks.pop("text_encoder")
text_encoder_node = text_blocks.init_pipeline(repo_id, components_manager=components)
text_encoder_node.load_components(torch_dtype=torch.bfloat16)

prompt = "a dog sitting by the sea waiting for its companion to come"

text_state = text_encoder_node(prompt=prompt)
text_embeddings = text_state.get_by_kwargs("denoiser_input_fields")

# Pop out only the denoise step; this is where the ValueError is raised.
denoise_blocks = ModularPipeline.from_pretrained(repo_id).blocks.sub_blocks.pop("denoise")
denoise_node = denoise_blocks.init_pipeline(repo_id, components_manager=components)
denoise_node.load_components(torch_dtype=torch.bfloat16)

denoise_state = denoise_node(
    **text_embeddings,
    guidance_scale=4.5,
    num_inference_steps=28,
    max_sequence_length=512,
    generator=torch.Generator(device=device).manual_seed(0),
)

sayakpaul (Member Author)

@asomoza we should now be good to go!

Test code (with a small update):
import torch
from diffusers import ModularPipeline, ComponentsManager

repo_id = "black-forest-labs/FLUX.1-dev"
device = "cuda"

components = ComponentsManager()
components.enable_auto_cpu_offload(device=device)

text_blocks = ModularPipeline.from_pretrained(repo_id, components_manager=components).blocks.sub_blocks.pop("text_encoder")
text_encoder_node = text_blocks.init_pipeline(repo_id, components_manager=components)
text_encoder_node.load_components(torch_dtype=torch.bfloat16)

prompt = "a dog sitting by the sea waiting for its companion to come"

# We should provide text-embedding-related inputs (`max_sequence_length`, for
# example) to this step because they are used during text embedding
# preparation and NOT in the denoising step.
text_state = text_encoder_node(prompt=prompt, max_sequence_length=512)
text_embeddings = text_state.get_by_kwargs("denoiser_input_fields")

denoise_blocks = ModularPipeline.from_pretrained(repo_id).blocks.sub_blocks.pop("denoise")
denoise_node = denoise_blocks.init_pipeline(repo_id, components_manager=components)
denoise_node.load_components(torch_dtype=torch.bfloat16)

print(f"{text_embeddings.keys()=}")

denoise_state = denoise_node(
    **text_embeddings,
    guidance_scale=4.5,
    num_inference_steps=28,
    generator=torch.Generator(device=device).manual_seed(0),
)
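
To turn the denoised latents into an image, a third node would run the decode step. A sketch under two assumptions: that the Flux modular blocks expose a decode sub-block (as the SDXL ones do), and that the result travels under the "latents" and "images" state keys:

# Assumption: a "decode" sub-block exists, mirroring the SDXL modular layout.
decode_blocks = ModularPipeline.from_pretrained(repo_id).blocks.sub_blocks.pop("decode")
decode_node = decode_blocks.init_pipeline(repo_id, components_manager=components)
decode_node.load_components(torch_dtype=torch.bfloat16)

# Assumption: the denoise state stores its result under "latents".
decode_state = decode_node(latents=denoise_state.values["latents"])
decode_state.values["images"][0].save("modular_flux_node_image.png")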

asomoza (Member) left a comment

Thanks, LGTM after the minor fix; it works with the nodes.

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
sayakpaul (Member Author)

Merging after confirming with @asomoza!

sayakpaul merged commit 7f3e9b8 into main on Oct 6, 2025
15 of 17 checks passed
sayakpaul deleted the fix-flux-mellon branch on October 6, 2025, 07:45
