flux2: dual-GPU model-parallel + transformers 5.8 compat + Mistral-on-CPU data_process #1434

Open

genno-whittlery wants to merge 3 commits into modelscope:main from genno-whittlery:dual-gpu-flux2

Conversation

@genno-whittlery commented May 11, 2026

Summary

Three independent improvements to examples/flux2/model_training/ and the underlying flux2 pipeline. All three are gated on environment variables / option flags so single-GPU users see no behavior change unless they opt in.

  1. Dual-GPU model-parallel for FLUX.2 LoRA training (DIFFSYNTH_DUAL_GPU=true).
  2. Mistral-on-CPU mode for sft:data_process (DIFFSYNTH_DATA_PROCESS_ON_CPU=true), so cards without 48+ GB VRAM can still produce the feature cache.
  3. transformers 5.8 compatibility for Flux2TextEncoder.forward (independent fix, single-GPU users benefit too).

Hardware target (why this matters most for 24 GB users)

The FLUX.2 transformer is ~60 GB in bf16. The bulk of the consumer-GPU population (24 GB cards: RTX 3090, RTX 4090, RTX A5000) cannot fit it on a single card even with fp8 weight-only quant + activation offload. RTX 5090 (32 GB) handles the single-card recipe, but supply is constrained and prices remain well above MSRP — most users who can train FLUX.2 LoRAs do so on the 24 GB stack.

| GPU class | Single-card FLUX.2 LoRA | Dual-GPU (this PR) |
| --- | --- | --- |
| 1× 32 GB (5090) | Fits with fp8 weight-only + gradient checkpointing; ~26-29 GB peak. Validated. | Validated: 20.7 GB cuda:0 / 12.6 GB cuda:1 at 2.69 s/it. |
| 1× 24 GB (3090 / 4090) | Doesn't fit even with all the low-VRAM tricks — fp8 weights alone are ~30 GB. | 2× 24 GB: ~21 GB cuda:0 + ~13 GB cuda:1 — comfortable headroom. |
| 1× 16 GB or below | Doesn't fit. | 2× 16 GB: would need fp8 on both halves; architecturally fine, not validated yet. |

So this isn't really about "faster than the single-card path" — it's about making FLUX.2 LoRA training reachable from the 2×-24-GB market at all, which is where the consumer 2-GPU configurations actually live (used 3090s, second-gen 4090 builds, A5000 workstations).

What changed

Dual-GPU FLUX.2 LoRA training

  • examples/flux2/model_training/flux2_dual_gpu_diffsynth.py (new, ~170 LOC): splits Flux2DiT.single_transformer_blocks at the midpoint across cuda:0/cuda:1. Registers a forward_pre_hook on every cuda:1 single block — Flux2DiT.forward passes loop-level constants (temb_mod_params, image_rotary_emb, joint_attention_kwargs) to each iteration, so a boundary-only hook would leave subsequent blocks receiving cuda:0 tensors. See the hook sketch after this list.

  • examples/flux2/model_training/train.py: forces CPU model load when DIFFSYNTH_DUAL_GPU=true (the ~60 GB bf16 transformer would otherwise pre-allocate on cuda:0 before split), runs torchao.quantize_ with Float8WeightOnlyConfig + a filter_fn that excludes lora_A / lora_B Linear submodules (otherwise their requires_grad is stripped and backward fails at the loss), then calls enable_flux2_dual_gpu(model.pipe.dit) after PEFT has injected LoRA so block.to(device) carries LoRA params with their base layers.

  • diffsynth/diffusion/runner.py (launch_training_task): skips model.to(accelerator.device) and uses device_placement=[False, True, True, True] in accelerator.prepare when DIFFSYNTH_DUAL_GPU=true; see the prepare sketch after this list. The env var name is intentionally model-neutral (per @gemini-code-assist review on #1436, "wan: add dual-GPU model-parallel path for Wan 2.x LoRA training", which depends on #1435) so the same gate covers both FLUX.2 and Wan dual-GPU paths.
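
For concreteness, a minimal sketch of the split-and-hook pattern from the first bullet (hedged: enable_dual_gpu_sketch and _mover are illustrative names, not the real helper's API; it assumes single_transformer_blocks is an nn.ModuleList and PyTorch >= 2.0 for with_kwargs=True):

    import torch

    def _mover(device):
        # Pre-hook: move every tensor in args/kwargs to `device`.
        def hook(module, args, kwargs):
            move = lambda x: x.to(device, non_blocking=True) if torch.is_tensor(x) else x
            return tuple(move(a) for a in args), {k: move(v) for k, v in kwargs.items()}
        return hook

    def enable_dual_gpu_sketch(dit, split_at=None):
        blocks = dit.single_transformer_blocks  # nn.ModuleList
        split_at = len(blocks) // 2 if split_at is None else split_at
        for i, block in enumerate(blocks):
            block.to("cuda:0" if i < split_at else "cuda:1")
            if i >= split_at:
                # Hook every cuda:1 block, not just the boundary one: the
                # forward loop re-passes cuda:0 constants (temb_mod_params,
                # image_rotary_emb, ...) into each iteration.
                block.register_forward_pre_hook(_mover("cuda:1"), with_kwargs=True)
        # Bridge activations back to cuda:0 for the final output layers.
        dit.norm_out.register_forward_pre_hook(_mover("cuda:0"), with_kwargs=True)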

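And the runner-side gate from the third bullet, sketched under the assumption that launch_training_task prepares (model, optimizer, dataloader, scheduler) in that order (variable names are illustrative):

    import os

    dual_gpu = os.environ.get("DIFFSYNTH_DUAL_GPU", "").lower() == "true"
    if not dual_gpu:
        model.to(accelerator.device)  # unchanged single-GPU path
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler,
        # False in the model slot keeps accelerate from re-placing the
        # manually split model; the other three args are placed as usual.
        device_placement=[False, True, True, True] if dual_gpu else None,
    )
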
Mistral-on-CPU sft:data_process

  • diffsynth/diffusion/runner.py (launch_data_process_task): env-var-gated CPU-offload branch (DIFFSYNTH_DATA_PROCESS_ON_CPU=true). Combined with --initialize_model_on_cpu, this lets users on cards without 48 GB VRAM run the feature-cache step on CPU (~15 s per 256-token caption at 512² — slow, but a one-time pre-process). Without this, Mistral-3-Small-24B's ~48 GB bf16 weights OOM on any consumer 32 GB card.
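
The gate itself is a small branch; a sketch of its shape (illustrative, not the literal runner.py diff):

    import os

    def data_process_device(default_device: str) -> str:
        # Opt-in CPU offload for the one-time feature-cache step; single-GPU
        # users see no change unless they set the env var.
        if os.environ.get("DIFFSYNTH_DATA_PROCESS_ON_CPU", "").lower() == "true":
            return "cpu"
        return default_device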

transformers 5.8 compat

transformers 5.8 trimmed Mistral3ForConditionalGeneration.forward's positional signature (removed output_attentions, output_hidden_states, return_dict, cache_position — they're TransformersKwargs-only now). The existing wrapper passes 15 positional args and dies with:

TypeError: Mistral3ForConditionalGeneration.forward() takes from 1 to 11
positional arguments but 15 were given

Rewrote Flux2TextEncoder.forward to pass everything by keyword and re-inject output_hidden_states / output_attentions via **kwargs so the existing get_mistral_3_small_prompt_embeds caller still receives a populated output.hidden_states.
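
In sketch form (argument list abridged to the relevant pieces; the subclassing of Mistral3ForConditionalGeneration is inferred from the traceback, and the real forward threads more parameters):

    from transformers import Mistral3ForConditionalGeneration

    class Flux2TextEncoder(Mistral3ForConditionalGeneration):
        def forward(self, input_ids=None, attention_mask=None,
                    output_hidden_states=None, output_attentions=None, **kwargs):
            # transformers 5.8 accepts these only via TransformersKwargs, so
            # re-inject them by keyword; setdefault lets explicit kwargs win.
            if output_hidden_states is not None:
                kwargs.setdefault("output_hidden_states", output_hidden_states)
            if output_attentions is not None:
                kwargs.setdefault("output_attentions", output_attentions)
            return super().forward(
                input_ids=input_ids, attention_mask=attention_mask, **kwargs
            )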

Validation

End-to-end on 2× RTX 5090 with sumi v8 data + the unchanged recipe from examples/flux2/model_training/lora/FLUX.2-dev.sh:

  • sft:data_process (with --initialize_model_on_cpu + DIFFSYNTH_DATA_PROCESS_ON_CPU=true): 3 sumi-v8 images cached on CPU-offloaded Mistral-24B in 43 s.
  • sft:train (with DIFFSYNTH_DUAL_GPU=true): 15/15 steps at 2.69 s/it sustained, 40 s wall-clock. Distribution lands at 20.7 GB cuda:0 / 12.6 GB cuda:1 after fp8 weight-only quant. LoRA checkpoint (270 MB at rank 32) saved.

Validation on 2× 24 GB pairs (2× 3090 / 2× 4090) is pending — happy to run it if a maintainer has access to those configs and wants concrete numbers before merge. The memory math suggests a comfortable fit, and the patch shape has been validated across four FLUX.2 LoRA trainers (ai-toolkit, musubi-tuner, OneTrainer, DiffSynth-Studio) with a consistent ~21 / ~13 GB distribution.

Context

Cross-trainer write-up at https://github.com/genno-whittlery/flux2-dual-gpu-lora; DiffSynth-specific porting walkthrough at https://github.com/genno-whittlery/flux2-dual-gpu-lora/blob/main/docs/porting-diffsynth-studio.md.

The LoRA-skip filter_fn on quantize_ was a novel finding from this port — diffusers' reference script earlier hit TorchaoLoraLinear's constructor incompatibility because it was quantizing LoRA submodules too. Skipping lora_A / lora_B keeps requires_grad and routes PEFT through its normal LoraLinear instead of the broken torchao dispatcher.
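
A minimal version of that filter, assuming torchao's quantize_(model, config, filter_fn) signature and PEFT's standard lora_A / lora_B module naming:

    import torch.nn as nn
    from torchao.quantization import Float8WeightOnlyConfig, quantize_

    def skip_lora(module: nn.Module, fqn: str) -> bool:
        # Quantize plain Linear layers only; leave LoRA adapters in bf16 so
        # they keep requires_grad and PEFT's normal LoraLinear routing.
        is_lora = ("lora_A" in fqn) or ("lora_B" in fqn)
        return isinstance(module, nn.Linear) and not is_lora

    # dit: the PEFT-wrapped Flux2DiT, after LoRA injection.
    quantize_(dit, Float8WeightOnlyConfig(), filter_fn=skip_lora)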

Companion PR #1436 ports the same dual-GPU pattern to Wan video LoRA training and shares the DIFFSYNTH_DUAL_GPU env var with this PR.

Test plan

  • sft:data_process on 2× RTX 5090 with Mistral-on-CPU — produces valid cached .pth files (verified shapes: latents (1,1024,128) bf16, prompt_embeds (1,512,15360) bf16); a shape-check sketch follows this list
  • sft:train 15-step run, dual-GPU, fp8 weight-only — completes, LoRA saves
  • transformers 5.8 — Mistral3 forward succeeds, output.hidden_states populated for get_mistral_3_small_prompt_embeds
  • single-GPU path unchanged (env vars unset → train.py / runner.py take existing branches)
  • 2× 24 GB validation (3090 / 4090) — pending; happy to run if a reviewer has hardware
  • downstream tasks (Wan 2.2 etc.) — regression run not yet done; this PR only touches flux2 examples + flux2_text_encoder.py + runner.py (gated branches only)
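
A quick way to sanity-check the cached features from the first item above (hedged sketch: the file path and dict keys are hypothetical; only the shapes and dtypes come from the validated run):

    import torch

    # Hypothetical cache file produced by sft:data_process.
    cache = torch.load("data/cache/000000.pth", map_location="cpu")
    assert cache["latents"].shape == (1, 1024, 128)
    assert cache["latents"].dtype == torch.bfloat16
    assert cache["prompt_embeds"].shape == (1, 512, 15360)
    assert cache["prompt_embeds"].dtype == torch.bfloat16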

Three changes, all gated on environment variables so single-GPU users
see no behavior difference unless they opt in.

1) examples/flux2/model_training/flux2_dual_gpu_diffsynth.py (new):
   ~170 LOC helper. Splits Flux2DiT.single_transformer_blocks at the
   midpoint across cuda:0/cuda:1. Registers forward_pre_hook on every
   cuda:1 single block (not just the boundary) because Flux2DiT.forward
   passes loop-level constants (temb_mod_params, image_rotary_emb,
   joint_attention_kwargs) into every iteration -- a boundary-only hook
   would leave subsequent blocks receiving cuda:0 tensors. The norm_out
   hook bridges activations back to cuda:0 for the final output layers.

2) examples/flux2/model_training/train.py:
   When FLUX2_DUAL_GPU=true, force CPU model load (so the ~60 GB bf16
   transformer doesn't pre-allocate on cuda:0 before split), run
   torchao Float8WeightOnlyConfig quantize_ with a filter that
   excludes lora_A/lora_B Linear submodules from the quant pass
   (otherwise their requires_grad is stripped and backward fails),
   then call enable_flux2_dual_gpu(model.pipe.dit) after PEFT has
   injected LoRA so layer .to(device) carries LoRA params with their
   base layers.

3) diffsynth/diffusion/runner.py:
   launch_training_task: skip model.to(accelerator.device) when
   FLUX2_DUAL_GPU=true (would undo the split), use
   device_placement=[False, True, True, True] so accelerate doesn't
   re-place the model.
   launch_data_process_task: separate DIFFSYNTH_DATA_PROCESS_ON_CPU
   env-var-gated branch for users without a 48 GB card -- Mistral-24B
   is too big for 32 GB; combined with --initialize_model_on_cpu this
   lets the TE/VAE feature caching step run on CPU.

Bonus: diffsynth/models/flux2_text_encoder.py:
   transformers 5.8 trimmed Mistral3ForConditionalGeneration's
   positional args (dropped output_attentions, output_hidden_states,
   return_dict, cache_position from the signature; they're now
   TransformersKwargs-only). The existing code passed 15 positional
   args and crashed with "takes from 1 to 11 positional arguments but
   15 were given". Rewrote to pass everything by keyword and re-inject
   output_hidden_states/output_attentions via **kwargs so the existing
   get_mistral_3_small_prompt_embeds caller still receives a populated
   output.hidden_states. Independent fix; single-GPU users benefit too.

Validated end-to-end on 2x RTX 5090 with sumi v8 data + the unchanged
recipe from examples/flux2/model_training/lora/FLUX.2-dev.sh:

- sft:data_process (with --initialize_model_on_cpu + the env var):
  3 sumi images cached on CPU-offloaded Mistral-24B in 43s.
- sft:train (FLUX2_DUAL_GPU=true): 15/15 steps at 2.69 s/it sustained,
  40s wall-clock. Distribution lands 20.7 GB cuda:0 / 12.6 GB cuda:1
  after fp8 weight-only quant. LoRA checkpoint epoch-0.safetensors
  (270 MB at rank 32) saved.

The same patch shape has been validated on three other FLUX.2 LoRA
trainers (ai-toolkit, musubi-tuner, OneTrainer) -- documented at
https://github.com/genno-whittlery/flux2-dual-gpu-lora -- with the same
20.7/12.6 GB distribution shape across all four.

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces dual-GPU model parallelism for FLUX.2 training, enabling LoRA training on pairs of 24GB consumer GPUs. It includes a new helper module to split the transformer across devices using forward pre-hooks and integrates fp8 weight-only quantization via torchao to reduce memory usage. Additionally, the text encoder's forward method was updated for compatibility with transformers 5.8. Feedback indicates that the return_dict and cache_position arguments in the text encoder are currently captured but not passed to the superclass, which could lead to them being silently ignored.

Comment on lines +61 to +64
    if output_hidden_states is not None:
        kwargs.setdefault("output_hidden_states", output_hidden_states)
    if output_attentions is not None:
        kwargs.setdefault("output_attentions", output_attentions)
Severity: high

The arguments return_dict and cache_position are captured in the forward method signature but are not passed to the super().forward call, nor are they injected into kwargs. This means that if a caller explicitly provides these arguments, they will be silently ignored.

Since the PR description mentions that these parameters were trimmed from the positional signature in transformers 5.8 and should now be passed via keyword arguments, they should be handled in the same way as output_hidden_states and output_attentions to ensure compatibility and correctness.

        if output_hidden_states is not None:
            kwargs.setdefault("output_hidden_states", output_hidden_states)
        if output_attentions is not None:
            kwargs.setdefault("output_attentions", output_attentions)
        if return_dict is not None:
            kwargs.setdefault("return_dict", return_dict)
        if cache_position is not None:
            kwargs.setdefault("cache_position", cache_position)

genno-whittlery added a commit to genno-whittlery/DiffSynth-Studio that referenced this pull request May 11, 2026
…args

Address gemini-code-assist review comment on PR modelscope#1434
(modelscope#1434).

The previous patch captured return_dict and cache_position in the
forward signature (for backward compat with older transformers
versions where they were positional) but didn't forward them to
super().forward(). For 5.8+ callers explicitly passing these args,
they would be silently dropped.

Now forwarded the same way as output_hidden_states / output_attentions:
via kwargs.setdefault so explicit kwargs in the call still win, and
None values don't pollute **kwargs.

In transformers 5.8 these two are no-ops (return_dict always True;
cache_position computed internally), so forwarding them only matters
for older transformers versions -- but it's the correct behavior either
way, and a cheap fix.
@genno-whittlery (Author) commented

Thanks @gemini-code-assist — addressed in 00e85cc. return_dict and cache_position now flow through kwargs.setdefault alongside output_hidden_states / output_attentions. In transformers 5.8 these two are no-ops at runtime (return_dict is always True, cache_position is computed internally) but forwarding them is the correct behavior across versions.

@gemini-code-assist (Contributor) commented

Thanks for the update, @genno-whittlery. That approach for transformers 5.8 compatibility looks correct and robust for cross-version support. The use of kwargs.setdefault ensures that the necessary arguments are passed through while maintaining compatibility with the newer TransformersKwargs pattern. The rest of the dual-GPU and CPU-offload logic also appears well-implemented and appropriately gated. Great work on these improvements.

genno-whittlery added a commit to genno-whittlery/DiffSynth-Studio that referenced this pull request May 11, 2026
Off by default -- no behavior change unless WAN_DUAL_GPU=true is set.

Wan 2.2 14B variants (I2V-A14B, T2V-A14B, S2V-14B, etc.) are ~28 GB
in bf16 -- the weights fit on one 32 GB consumer card with fp8 quant,
but video training activations at 480x832x49 frames + gradient
checkpointing routinely push the actual step over 32 GB even on a
14B model. Splitting the transformer blocks across two GPUs gives
training-step headroom that single-GPU users can't otherwise reach
without dropping resolution or frame count.

What changed:

- examples/wanvideo/model_training/wan_dual_gpu_diffsynth.py (new):
  ~150 LOC helper. Splits WanModel.blocks at the midpoint across
  cuda:0/cuda:1. Registers forward_pre_hook on every cuda:1 block
  (not just the boundary -- Wan's forward passes loop-level constants
  context / t_mod / freqs positionally to each iteration, so a
  boundary-only hook would leave subsequent blocks receiving cuda:0
  tensors). Bridges activations back to cuda:0 at the head module.
  Also explicitly moves WanModel.freqs (a tuple of plain CPU tensors,
  not registered buffers) so .to(device) doesn't miss them.

- examples/wanvideo/model_training/train.py: forces CPU model load
  when WAN_DUAL_GPU=true (so the bf16 transformer doesn't
  pre-allocate on cuda:0 before split), runs torchao
  Float8WeightOnlyConfig quantize_ with the same LoRA-skip filter
  used by the FLUX.2 port (skips lora_A/lora_B Linear submodules --
  otherwise their requires_grad is stripped and backward fails),
  then calls enable_wan_dual_gpu(model.pipe.dit) after PEFT has
  injected LoRA so block.to(device) carries LoRA params with their
  base layers. Also sets FLUX2_DUAL_GPU=true after distribute so the
  existing runner.py branch from PR modelscope#1434 catches the
  device_placement=[False, ...] case in accelerator.prepare without
  needing a parallel WAN_DUAL_GPU branch there.

Depends on modelscope#1435 (patchify fix). The current main has a broken
WanModel.patchify that returns the wrong shape and arity; Wan
training fails immediately at the first forward call regardless of
dual-GPU. Once modelscope#1435 lands, both single-GPU and dual-GPU Wan training
paths work.

Validated locally on 2x RTX 5090 with a synthetic 8-layer WanModel
(same architecture shape as real Wan 2.2, miniaturized to fit a
quick smoke test): forward + backward complete across the cross-
device split, output round-trips to the original (B, C, T, H, W)
shape, LoRA gradients land on both cuda:0 and cuda:1 (proving cross-
device autograd).

Same patch shape as the validated FLUX.2 port in PR modelscope#1434 from this
account. Both share the runner.py model-parallel branch.

genno-whittlery added a commit to genno-whittlery/DiffSynth-Studio that referenced this pull request May 11, 2026

Per @gemini-code-assist review on modelscope#1436: use a model-neutral env var
name so a single signal controls dual-GPU mode for both FLUX.2 and Wan paths.
This change:

- FLUX2_DUAL_GPU            -> DIFFSYNTH_DUAL_GPU
- FLUX2_DUAL_GPU_SPLIT_AT   -> DIFFSYNTH_DUAL_GPU_SPLIT_AT

Function names (enable_flux2_dual_gpu) and the helper module name
(flux2_dual_gpu_diffsynth.py) keep the model-specific naming since
they're the entry points users explicitly import for FLUX.2 training.

Coordinated with PR modelscope#1436 which adopts the same DIFFSYNTH_DUAL_GPU
gate on the Wan side. Once both merge, one env var enables dual-GPU
for either model family; runner.py recognizes the gate regardless of
which family the training script targets.
genno-whittlery added a commit to genno-whittlery/DiffSynth-Studio that referenced this pull request May 11, 2026
Per @gemini-code-assist review on modelscope#1436: use a model-neutral env var
name. This change:

- WAN_DUAL_GPU            -> DIFFSYNTH_DUAL_GPU
- WAN_DUAL_GPU_SPLIT_AT   -> DIFFSYNTH_DUAL_GPU_SPLIT_AT

The same generic name will be adopted in modelscope#1434 (flux2 dual-GPU) so a
single env var controls dual-GPU mode for both Wan and FLUX.2 paths.
Helper function name (is_dual_gpu_enabled) and helper module file
name (wan_dual_gpu_diffsynth.py) kept Wan-specific to make the
import call sites at the entry-point clearer.

Also dropped the now-redundant os.environ["FLUX2_DUAL_GPU"]="true"
re-broadcast in train.py — the launcher-set DIFFSYNTH_DUAL_GPU is
inherited by the same process and read by both is_dual_gpu_enabled()
and runner.py's gate (after modelscope#1434 lands its matching rename).
Added a comment block making the implicit runtime dependency on
modelscope#1434's runner.py change explicit.