make flux ready for mellon #12419
Conversation
```python
return [
    InputParam("prompt"),
    InputParam("prompt_2"),
    InputParam("max_sequence_length", type_hint=int, default=512, required=False),
```
Took the liberty of supporting this input from the user so that Schnell can also work. In another PR, I will harmonize the steps in Flux Modular that include repetition along the batch size dimension (similar to Qwen).
Cc: @yiyixuxu
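For illustration, here is a minimal sketch of how a user might set this input when running Schnell, which caps the T5 sequence length at 256. The repo id and the 256 cap are assumptions based on the standard Flux pipelines; the block-popping flow mirrors the test snippets below:

```python
# Hypothetical sketch: run the Flux text-encoder block with FLUX.1-schnell and a
# user-provided `max_sequence_length` (Schnell is assumed to cap it at 256).
import torch
from diffusers import ModularPipeline, ComponentsManager

repo_id = "black-forest-labs/FLUX.1-schnell"  # assumed Schnell checkpoint
components = ComponentsManager()

text_blocks = ModularPipeline.from_pretrained(
    repo_id, components_manager=components
).blocks.sub_blocks.pop("text_encoder")
text_encoder_node = text_blocks.init_pipeline(repo_id, components_manager=components)
text_encoder_node.load_components(torch_dtype=torch.bfloat16)

# Thanks to the new `InputParam`, this no longer has to be the default 512.
text_state = text_encoder_node(prompt="a photo of a cat", max_sequence_length=256)
```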
Thanks. I think we need to move the FluxAutoBeforeDenoiseStep inside a denoise step wrapper here, because when we use modular diffusers with separate steps, we usually only use the `denoise` block. Right now, if I use Flux this way, I get an error because of this:
Code to test it:

```python
import torch
from diffusers import ModularPipeline, ComponentsManager

repo_id = "black-forest-labs/FLUX.1-dev"
device = "cuda"

components = ComponentsManager()
components.enable_auto_cpu_offload(device=device)

# Pop out only the text-encoder block and run it as a standalone node.
text_blocks = ModularPipeline.from_pretrained(repo_id, components_manager=components).blocks.sub_blocks.pop("text_encoder")
text_encoder_node = text_blocks.init_pipeline(repo_id, components_manager=components)
text_encoder_node.load_components(torch_dtype=torch.bfloat16)

prompt = "a dog sitting by the sea waiting for its companion to come"
text_state = text_encoder_node(prompt=prompt)
text_embeddings = text_state.get_by_kwargs("denoiser_input_fields")

# Pop out only the denoise block and feed it the text embeddings.
denoise_blocks = ModularPipeline.from_pretrained(repo_id).blocks.sub_blocks.pop("denoise")
denoise_node = denoise_blocks.init_pipeline(repo_id, components_manager=components)
denoise_node.load_components(torch_dtype=torch.bfloat16)

denoise_state = denoise_node(
    **text_embeddings,
    guidance_scale=4.5,
    num_inference_steps=28,
    max_sequence_length=512,
    generator=torch.Generator(device=device).manual_seed(0),
)
```
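For context, a rough sketch of what such a wrapper could look like. Class and attribute names here are assumptions based on the modular-pipelines block API, not the actual implementation in this PR:

```python
# Hypothetical sketch: group the before-denoise preparation and the denoise loop
# into one sequential sub-block, so that popping "denoise" still prepares latents
# and timesteps. `FluxDenoiseLoopStep` is an assumed name for the inner loop, and
# imports of the Flux block classes are omitted.
from diffusers.modular_pipelines import SequentialPipelineBlocks

class FluxDenoiseStepWrapper(SequentialPipelineBlocks):
    block_classes = [FluxAutoBeforeDenoiseStep, FluxDenoiseLoopStep]
    block_names = ["before_denoise", "denoise"]

    @property
    def description(self):
        return "Prepares latents/timesteps, then runs the denoising loop."
```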
@asomoza we should now be good to go! Test code (with a small update):

```python
import torch
from diffusers import ModularPipeline, ComponentsManager

repo_id = "black-forest-labs/FLUX.1-dev"
device = "cuda"

components = ComponentsManager()
components.enable_auto_cpu_offload(device=device)

text_blocks = ModularPipeline.from_pretrained(repo_id, components_manager=components).blocks.sub_blocks.pop("text_encoder")
text_encoder_node = text_blocks.init_pipeline(repo_id, components_manager=components)
text_encoder_node.load_components(torch_dtype=torch.bfloat16)

prompt = "a dog sitting by the sea waiting for its companion to come"

# We should provide text-embedding-related inputs (`max_sequence_length`, for example)
# to the following step, as they are used during the text-embedding preparation
# step and NOT in the denoising step.
text_state = text_encoder_node(prompt=prompt, max_sequence_length=512)
text_embeddings = text_state.get_by_kwargs("denoiser_input_fields")

denoise_blocks = ModularPipeline.from_pretrained(repo_id).blocks.sub_blocks.pop("denoise")
denoise_node = denoise_blocks.init_pipeline(repo_id, components_manager=components)
denoise_node.load_components(torch_dtype=torch.bfloat16)

print(f"{text_embeddings.keys()=}")
denoise_state = denoise_node(
    **text_embeddings,
    guidance_scale=4.5,
    num_inference_steps=28,
    generator=torch.Generator(device=device).manual_seed(0),
)
```
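As a small follow-up to the `print` above, one way to see exactly what gets forwarded to the denoise node (plain dict iteration, so nothing Flux-specific is assumed):

```python
# Illustrative: list each forwarded denoiser input field and its shape, if any.
for name, value in text_embeddings.items():
    print(name, getattr(value, "shape", type(value)))
```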
Thanks, LGTM after the minor fix; it works with the nodes.
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
Merging after confirming with @asomoza!
What does this PR do?
Makes Flux in Modular Diffusers ready for Mellon.
Code to test:
Result: