Conversation

@DefTruth (Contributor) commented Nov 27, 2025

The expected number of dims for img_ids and txt_ids in the FLUX.2 context-parallel plan should be 3, not 2. In FLUX.2 these ids have shape [batch, seq_len, 4] (see the logs below), so the old plan (expecting 2 dims) never split them across ranks ("Expected input tensor to have 2 dimensions, but got 3 dimensions, split will not be applied."). With ulysses_degree=2 each rank then holds 2048 + 256 = 2304 query tokens, while the rotary embeddings are still built from the full 4096 + 512 = 4608 ids, which triggers the size-mismatch error below. With this PR the ids are split along the sequence dimension as well, and the rotary embedding length (2304) matches the per-rank sequence length.
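
A minimal sketch of the kind of change involved (not the actual diff in this PR), assuming Flux2Transformer2DModel declares a _cp_plan with ContextParallelInput entries like the other diffusers transformers; the import path and field names below are illustrative assumptions:

from diffusers.models._modeling_parallel import ContextParallelInput

_cp_plan = {
    # hidden_states / encoder_hidden_states are [batch, seq, dim] and are split along seq (dim 1)
    # ...
    # before: ids declared as 2-D [seq, ids_dim], so the 3-D FLUX.2 ids were never split
    # "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
    # "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
    # after: FLUX.2 ids are 3-D [batch, seq, ids_dim], so split along seq (dim 1) with expected_dims=3
    "img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
    "txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
}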

  • before this PR

Expected input tensor to have 2 dimensions, but got 3 dimensions, split will not be applied.

Expected input tensor to have 2 dimensions, but got 3 dimensions, split will not be applied.
hidden_states 0, mean: -0.0005601883167400956, std: 0.9917193651199341, shape: torch.Size([1, 2048, 128])
encoder_hidden_states 0, mean: -8.521552808815613e-05, std: 0.1560058295726776, shape: torch.Size([1, 256, 15360])
img_ids, mean: 15.75, std: 20.462167739868164, shape: torch.Size([1, 4096, 4])
txt_ids, mean: 63.875, std: 133.0626983642578, shape: torch.Size([1, 512, 4])
timestep, mean: 1.0, std: 0.0, shape: torch.Size([1])
guidance, mean: 4.0, std: 0.0, shape: torch.Size([1])
hidden_states 1, mean: -2.4117840439430438e-05, std: 0.24530985951423645, shape: torch.Size([1, 2048, 6144])
encoder_hidden_states 1, mean: -0.0007469747215509415, std: 0.8318553566932678, shape: torch.Size([1, 256, 6144])
concat_rotary_emb[0], mean: 0.7477867603302002, std: 0.5412695407867432, shape: torch.Size([4608, 128])
concat_rotary_emb[1], mean: 0.09602554887533188, std: 0.37231945991516113, shape: torch.Size([4608, 128])
double_stream_mod_img[0][0], mean: -0.006415137555450201, std: 1.1594423055648804, shape: torch.Size([1, 1, 6144])
double_stream_mod_txt[0][0], mean: 0.018749061971902847, std: 0.2851669490337372, shape: torch.Size([1, 1, 6144])
single_stream_mod[0], mean: 0.008495260030031204, std: 0.19870036840438843, shape: torch.Size([1, 1, 6144])

[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/dev/vipshop/cache-dit/examples/parallelism/run_flux2_cp.py", line 133, in <module>
[rank1]:     _ = run_pipe(warmup=True)
[rank1]:         ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/cache-dit/examples/parallelism/run_flux2_cp.py", line 118, in run_pipe
[rank1]:     image = pipe(
[rank1]:             ^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2.py", line 821, in __call__
[rank1]:     noise_pred = self.transformer(
[rank1]:                  ^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 189, in new_forward
[rank1]:     output = function_reference.forward(*args, **kwargs)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_flux2.py", line 879, in forward
[rank1]:     encoder_hidden_states, hidden_states = block(
[rank1]:                                            ^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_flux2.py", line 521, in forward
[rank1]:     attention_outputs = self.attn(
[rank1]:                         ^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_flux2.py", line 256, in forward
[rank1]:     return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/transformers/transformer_flux2.py", line 155, in __call__
[rank1]:     query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/models/embeddings.py", line 1232, in apply_rotary_emb
[rank1]:     out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
[rank1]:            ~~~~~~~~~~^~~~~
[rank1]: RuntimeError: The size of tensor a (2304) must match the size of tensor b (4608) at non-singleton dimension 1
  • with this PR
hidden_states 0, mean: -0.0005601883167400956, std: 0.9917193651199341, shape: torch.Size([1, 2048, 128])
encoder_hidden_states 0, mean: -8.521552808815613e-05, std: 0.1560058295726776, shape: torch.Size([1, 256, 15360])
img_ids, mean: 15.75, std: 20.462478637695312, shape: torch.Size([1, 2048, 4])
txt_ids, mean: 63.875, std: 133.07894897460938, shape: torch.Size([1, 256, 4])
timestep, mean: 1.0, std: 0.0, shape: torch.Size([1])
guidance, mean: 4.0, std: 0.0, shape: torch.Size([1])
hidden_states 1, mean: -2.4117840439430438e-05, std: 0.24530985951423645, shape: torch.Size([1, 2048, 6144])
encoder_hidden_states 1, mean: -0.0007469747215509415, std: 0.8318553566932678, shape: torch.Size([1, 256, 6144])
concat_rotary_emb[0], mean: 0.7477867603302002, std: 0.5412697196006775, shape: torch.Size([2304, 128])
concat_rotary_emb[1], mean: 0.09602555632591248, std: 0.37231960892677307, shape: torch.Size([2304, 128])
double_stream_mod_img[0][0], mean: -0.006415137555450201, std: 1.1594423055648804, shape: torch.Size([1, 1, 6144])
double_stream_mod_txt[0][0], mean: 0.018749061971902847, std: 0.2851669490337372, shape: torch.Size([1, 1, 6144])
single_stream_mod[0], mean: 0.008495260030031204, std: 0.19870036840438843, shape: torch.Size([1, 1, 6144])
  • test script
import os
import time
import torch
import torch.distributed as dist
from diffusers import Flux2Pipeline, Flux2Transformer2DModel
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import ContextParallelConfig

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
world_size = dist.get_world_size()
torch.cuda.set_device(device)

model_id = "black-forest-labs/FLUX.2-dev"
model_id = os.environ.get("FLUX_2_DIR", model_id)

pipe: Flux2Pipeline = Flux2Pipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        # 112/4 ~= 28GB total for text_encoder + transformer in 4-bit
        components_to_quantize=["text_encoder", "transformer"],
    ),
).to(device)

assert isinstance(pipe.transformer, Flux2Transformer2DModel)
pipe.transformer.set_attention_backend("native")
if world_size > 1:
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(ulysses_degree=world_size)
    )

pipe.set_progress_bar_config(disable=rank != 0)

prompt = (
    "Realistic macro photograph of a hermit crab using a soda can as its shell, "
    "partially emerging from the can, captured with sharp detail and natural colors, "
    "on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean "
    "waves in the background. The can has the text `BFL Diffusers` on it and it has a color "
    "gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
)

def run_pipe(warmup: bool = False):
    generator = torch.Generator("cpu").manual_seed(0)
    image = pipe(
        prompt=prompt,
        # 28 steps can be a good trade-off
        num_inference_steps=5 if warmup else 28,
        guidance_scale=4,
        generator=generator,
    ).images[0]
    return image

# warmup
_ = run_pipe(warmup=True)

start = time.time()
image = run_pipe()
end = time.time()

if rank == 0:
    time_cost = end - start
    save_path = f"flux2.ulysses{world_size}.png"
    print(f"Time cost: {time_cost:.2f}s")
    print(f"Saving image to {save_path}")
    image.save(save_path)

if dist.is_initialized():
    dist.destroy_process_group()
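
For reference, this script corresponds to examples/parallelism/run_flux2_cp.py from the traceback above and is launched with one process per GPU, e.g. torchrun --nproc_per_node=2 run_flux2_cp.py; with a single process, enable_parallelism is skipped and the pipeline runs without context parallelism.
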
Comparison images: Baseline w/ bnb_4bit (flux2 C0_Q1_bitsandbytes_4bit_NONE) vs. Ulysses 2 w/ bnb_4bit (flux2 C0_Q1_bitsandbytes_4bit_NONE_Ulysses2).

@DN6 @sayakpaul

@sayakpaul requested a review from DN6 on November 27, 2025 15:05
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
