flux2: dual-GPU model-parallel + transformers 5.8 compat + Mistral-on-CPU data_process #1434
genno-whittlery wants to merge 3 commits into
Conversation
Three changes, all gated on environment variables so single-GPU users see no behavior difference unless they opt in.

1) examples/flux2/model_training/flux2_dual_gpu_diffsynth.py (new): ~170 LOC helper. Splits Flux2DiT.single_transformer_blocks at the midpoint across cuda:0/cuda:1. Registers forward_pre_hook on every cuda:1 single block (not just the boundary) because Flux2DiT.forward passes loop-level constants (temb_mod_params, image_rotary_emb, joint_attention_kwargs) into every iteration -- a boundary-only hook would leave subsequent blocks receiving cuda:0 tensors. The norm_out hook bridges activations back to cuda:0 for the final output layers.

2) examples/flux2/model_training/train.py: when FLUX2_DUAL_GPU=true, force CPU model load (so the ~60 GB bf16 transformer doesn't pre-allocate on cuda:0 before split), run torchao Float8WeightOnlyConfig quantize_ with a filter that excludes lora_A/lora_B Linear submodules from the quant pass (otherwise their requires_grad is stripped and backward fails), then call enable_flux2_dual_gpu(model.pipe.dit) after PEFT has injected LoRA so layer .to(device) carries LoRA params with their base layers.

3) diffsynth/diffusion/runner.py:
   - launch_training_task: skip model.to(accelerator.device) when FLUX2_DUAL_GPU=true (it would undo the split) and use device_placement=[False, True, True, True] so accelerate doesn't re-place the model.
   - launch_data_process_task: separate DIFFSYNTH_DATA_PROCESS_ON_CPU env-var-gated branch for users without a 48 GB card -- Mistral-24B is too big for 32 GB; combined with --initialize_model_on_cpu this lets the TE/VAE feature caching step run on CPU.

Bonus: diffsynth/models/flux2_text_encoder.py: transformers 5.8 trimmed Mistral3ForConditionalGeneration's positional args (dropped output_attentions, output_hidden_states, return_dict, cache_position from the signature; they're now TransformersKwargs-only). The existing code passed 15 positional args and crashed with "takes from 1 to 11 positional arguments but 15 were given". Rewrote to pass everything by keyword and re-inject output_hidden_states/output_attentions via **kwargs so the existing get_mistral_3_small_prompt_embeds caller still receives a populated output.hidden_states. Independent fix; single-GPU users benefit too.

Validated end-to-end on 2x RTX 5090 with sumi v8 data + the unchanged recipe from examples/flux2/model_training/lora/FLUX.2-dev.sh:

- sft:data_process (with --initialize_model_on_cpu + the env var): 3 sumi images cached on CPU-offloaded Mistral-24B in 43s.
- sft:train (FLUX2_DUAL_GPU=true): 15/15 steps at 2.69 s/it sustained, 40s wall-clock. Distribution lands 20.7 GB cuda:0 / 12.6 GB cuda:1 after fp8 weight-only quant. LoRA checkpoint epoch-0.safetensors (270 MB at rank 32) saved.

The same patch shape has been validated on three other FLUX.2 LoRA trainers (ai-toolkit, musubi-tuner, OneTrainer) -- documented at https://github.com/genno-whittlery/flux2-dual-gpu-lora -- with the same 20.7/12.6 GB distribution shape across all four.
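To make the hook arrangement in (1) concrete, here is a minimal sketch, assuming PyTorch 2.x's register_forward_pre_hook(..., with_kwargs=True); the names enable_flux2_dual_gpu, single_transformer_blocks, and norm_out follow the description above, while the bodies are illustrative rather than the exact patch:

```python
import torch

def _to_device(obj, device):
    # Move tensors (including tensors nested in tuples/lists/dicts) to `device`;
    # anything else passes through untouched.
    if torch.is_tensor(obj):
        return obj.to(device, non_blocking=True)
    if isinstance(obj, (list, tuple)):
        return type(obj)(_to_device(x, device) for x in obj)
    if isinstance(obj, dict):
        return {k: _to_device(v, device) for k, v in obj.items()}
    return obj

def enable_flux2_dual_gpu(dit, dev0="cuda:0", dev1="cuda:1"):
    blocks = dit.single_transformer_blocks
    split_at = len(blocks) // 2

    dit.to(dev0)                      # everything starts on cuda:0 ...
    for block in blocks[split_at:]:   # ... then the second half moves to cuda:1
        block.to(dev1)
        # Hook on EVERY cuda:1 block, not just the boundary: the DiT forward
        # re-passes loop-level constants (temb_mod_params, image_rotary_emb,
        # joint_attention_kwargs) into each block call, so later blocks would
        # otherwise still receive cuda:0 tensors.
        block.register_forward_pre_hook(
            lambda mod, args, kwargs: (_to_device(args, dev1), _to_device(kwargs, dev1)),
            with_kwargs=True,
        )

    # Bridge activations back to cuda:0 before the final output layers.
    dit.norm_out.register_forward_pre_hook(
        lambda mod, args, kwargs: (_to_device(args, dev0), _to_device(kwargs, dev0)),
        with_kwargs=True,
    )
```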
Code Review
This pull request introduces dual-GPU model parallelism for FLUX.2 training, enabling LoRA training on pairs of 24GB consumer GPUs. It includes a new helper module to split the transformer across devices using forward pre-hooks and integrates fp8 weight-only quantization via torchao to reduce memory usage. Additionally, the text encoder's forward method was updated for compatibility with transformers 5.8. Feedback indicates that the return_dict and cache_position arguments in the text encoder are currently captured but not passed to the superclass, which could lead to them being silently ignored.
if output_hidden_states is not None:
    kwargs.setdefault("output_hidden_states", output_hidden_states)
if output_attentions is not None:
    kwargs.setdefault("output_attentions", output_attentions)
The arguments return_dict and cache_position are captured in the forward method signature but are not passed to the super().forward call, nor are they injected into kwargs. This means that if a caller explicitly provides these arguments, they will be silently ignored.
Since the PR description mentions that these parameters were trimmed from the positional signature in transformers 5.8 and should now be passed via keyword arguments, they should be handled in the same way as output_hidden_states and output_attentions to ensure compatibility and correctness.
if output_hidden_states is not None:
    kwargs.setdefault("output_hidden_states", output_hidden_states)
if output_attentions is not None:
    kwargs.setdefault("output_attentions", output_attentions)
if return_dict is not None:
    kwargs.setdefault("return_dict", return_dict)
if cache_position is not None:
    kwargs.setdefault("cache_position", cache_position)

…args

Address gemini-code-assist review comment on PR modelscope#1434 (modelscope#1434). The previous patch captured return_dict and cache_position in the forward signature (for backward compat with older transformers versions where they were positional) but didn't forward them to super().forward(). For 5.8+ callers explicitly passing these args, they would be silently dropped.

Now forwarded the same way as output_hidden_states / output_attentions: via kwargs.setdefault so explicit kwargs in the call still win, and None values don't pollute **kwargs.

In transformers 5.8 these two are no-ops (return_dict always True; cache_position computed internally), so forwarding them only matters for older transformers versions -- but it's the correct behavior either way, and a cheap fix.
Thanks @gemini-code-assist — addressed in 00e85cc.
Thanks for the update, @genno-whittlery. That approach for
Off by default -- no behavior change unless WAN_DUAL_GPU=true is set.

Wan 2.2 14B variants (I2V-A14B, T2V-A14B, S2V-14B, etc.) are ~28 GB in bf16 -- the weights fit on one 32 GB consumer card with fp8 quant, but video training activations at 480x832x49 frames + gradient checkpointing routinely push the actual step over 32 GB even on a 14B model. Splitting the transformer blocks across two GPUs gives training-step headroom that single-GPU users can't otherwise reach without dropping resolution or frame count.

What changed:

- examples/wanvideo/model_training/wan_dual_gpu_diffsynth.py (new): ~150 LOC helper. Splits WanModel.blocks at the midpoint across cuda:0/cuda:1. Registers forward_pre_hook on every cuda:1 block (not just the boundary -- Wan's forward passes loop-level constants context / t_mod / freqs positionally to each iteration, so a boundary-only hook would leave subsequent blocks receiving cuda:0 tensors). Bridges activations back to cuda:0 at the head module. Also explicitly moves WanModel.freqs (a tuple of plain CPU tensors, not registered buffers) so .to(device) doesn't miss them; see the short sketch after this message.
- examples/wanvideo/model_training/train.py: forces CPU model load when WAN_DUAL_GPU=true (so the bf16 transformer doesn't pre-allocate on cuda:0 before split), runs torchao Float8WeightOnlyConfig quantize_ with the same LoRA-skip filter used by the FLUX.2 port (skips lora_A/lora_B Linear submodules -- otherwise their requires_grad is stripped and backward fails), then calls enable_wan_dual_gpu(model.pipe.dit) after PEFT has injected LoRA so block.to(device) carries LoRA params with their base layers. Also sets FLUX2_DUAL_GPU=true after distribute so the existing runner.py branch from PR modelscope#1434 catches the device_placement=[False, ...] case in accelerator.prepare without needing a parallel WAN_DUAL_GPU branch there.

Depends on modelscope#1435 (patchify fix). The current main has a broken WanModel.patchify that returns the wrong shape and arity; Wan training fails immediately at the first forward call regardless of dual-GPU. Once modelscope#1435 lands, both single-GPU and dual-GPU Wan training paths work.

Validated locally on 2x RTX 5090 with a synthetic 8-layer WanModel (same architecture shape as real Wan 2.2, miniaturized to fit a quick smoke test): forward + backward complete across the cross-device split, output round-trips to the original (B, C, T, H, W) shape, LoRA gradients land on both cuda:0 and cuda:1 (proving cross-device autograd).

Same patch shape as the validated FLUX.2 port in PR modelscope#1434 from this account. Both share the runner.py model-parallel branch.
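On the freqs detail above, a short illustration of why an explicit move is needed: plain tensor attributes that are neither parameters nor registered buffers are skipped by Module.to(). The helper name and attribute handling below are illustrative, not the actual WanModel code:

```python
import torch

def move_unregistered_freqs(dit, device="cuda:0"):
    # freqs is described as a tuple of plain CPU tensors, so .to(device) on the
    # module never touches it; reassign the attribute with moved copies.
    freqs = getattr(dit, "freqs", None)
    if isinstance(freqs, (tuple, list)):
        dit.freqs = type(freqs)(f.to(device) for f in freqs)
    elif torch.is_tensor(freqs):
        dit.freqs = freqs.to(device)
```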
Per @gemini-code-assist review on modelscope#1436: model-neutral env var name so a single signal controls dual-GPU mode for both FLUX.2 and Wan paths.

This change:

- FLUX2_DUAL_GPU -> DIFFSYNTH_DUAL_GPU
- FLUX2_DUAL_GPU_SPLIT_AT -> DIFFSYNTH_DUAL_GPU_SPLIT_AT

Function names (enable_flux2_dual_gpu) and the helper module name (flux2_dual_gpu_diffsynth.py) keep the model-specific naming since they're the entry points users explicitly import for FLUX.2 training.

Coordinated with PR modelscope#1436, which adopts the same DIFFSYNTH_DUAL_GPU gate on the Wan side. Once both merge, one env var enables dual-GPU for either model family; runner.py recognizes the gate regardless of which family the training script targets.
Per @gemini-code-assist review on modelscope#1436: use a model-neutral env var name.

This change:

- WAN_DUAL_GPU -> DIFFSYNTH_DUAL_GPU
- WAN_DUAL_GPU_SPLIT_AT -> DIFFSYNTH_DUAL_GPU_SPLIT_AT

The same generic name will be adopted in modelscope#1434 (flux2 dual-GPU) so a single env var controls dual-GPU mode for both Wan and FLUX.2 paths.

Helper function name (is_dual_gpu_enabled) and helper module file name (wan_dual_gpu_diffsynth.py) kept Wan-specific to make the import call sites at the entry point clearer.

Also dropped the now-redundant os.environ["FLUX2_DUAL_GPU"] = "true" re-broadcast in train.py — the launcher-set DIFFSYNTH_DUAL_GPU is inherited by the same process and read by both is_dual_gpu_enabled() and runner.py's gate (after modelscope#1434 lands its matching rename). Added a comment block making the implicit runtime dependency on modelscope#1434's runner.py change explicit.
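A sketch of the shared gate after the rename; is_dual_gpu_enabled is the helper name mentioned above for the Wan side, and the exact truthy values it accepts aren't shown in this thread, so the parsing below is an assumption:

```python
import os

def is_dual_gpu_enabled() -> bool:
    # One model-neutral switch covering both the FLUX.2 and Wan dual-GPU paths.
    return os.environ.get("DIFFSYNTH_DUAL_GPU", "").lower() == "true"

def dual_gpu_split_index(num_blocks: int) -> int:
    # Optional override of the midpoint split via DIFFSYNTH_DUAL_GPU_SPLIT_AT.
    return int(os.environ.get("DIFFSYNTH_DUAL_GPU_SPLIT_AT", num_blocks // 2))
```

The same check can then be read by runner.py's gate and by either training entry point, which is the point of the rename.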
Summary
Three independent improvements to `examples/flux2/model_training/` and the underlying flux2 pipeline. All three are gated on environment variables / option flags so single-GPU users see no behavior change unless they opt in.

- Dual-GPU FLUX.2 LoRA training (`DIFFSYNTH_DUAL_GPU=true`).
- Mistral-on-CPU `sft:data_process` (`DIFFSYNTH_DATA_PROCESS_ON_CPU=true`), so cards without 48+ GB VRAM can still produce the feature cache.
- `transformers` 5.8 compatibility for `Flux2TextEncoder.forward` (independent fix, single-GPU users benefit too).

Hardware target (why this matters most for 24 GB users)
The FLUX.2 transformer is ~60 GB in bf16. The bulk of the consumer-GPU population (24 GB cards: RTX 3090, RTX 4090, RTX A5000) cannot fit it on a single card even with fp8 weight-only quant + activation offload. RTX 5090 (32 GB) handles the single-card recipe, but supply is constrained and prices remain well above MSRP — most users who can train FLUX.2 LoRAs do so on the 24 GB stack.
So this isn't really about "faster than the single-card path" — it's about making FLUX.2 LoRA training reachable from the 2×-24-GB market at all, which is where the consumer 2-GPU configurations actually live (used 3090s, second-gen 4090 builds, A5000 workstations).
What changed
Dual-GPU FLUX.2 LoRA training
- `examples/flux2/model_training/flux2_dual_gpu_diffsynth.py` (new, ~170 LOC): splits `Flux2DiT.single_transformer_blocks` at the midpoint across cuda:0/cuda:1. Registers `forward_pre_hook` on every cuda:1 single block — `Flux2DiT.forward` passes loop-level constants (`temb_mod_params`, `image_rotary_emb`, `joint_attention_kwargs`) to each iteration, so a boundary-only hook would leave subsequent blocks receiving cuda:0 tensors.
- `examples/flux2/model_training/train.py`: forces CPU model load when `DIFFSYNTH_DUAL_GPU=true` (the ~60 GB bf16 transformer would otherwise pre-allocate on cuda:0 before split), runs `torchao.quantize_` with `Float8WeightOnlyConfig` + a `filter_fn` that excludes `lora_A`/`lora_B` Linear submodules (otherwise their `requires_grad` is stripped and backward fails at the loss), then calls `enable_flux2_dual_gpu(model.pipe.dit)` after PEFT has injected LoRA so `block.to(device)` carries LoRA params with their base layers.
- `diffsynth/diffusion/runner.py` (`launch_training_task`): skips `model.to(accelerator.device)` and uses `device_placement=[False, True, True, True]` in `accelerator.prepare` when `DIFFSYNTH_DUAL_GPU=true`; a sketch follows this list. The env var name is intentionally model-neutral (per @gemini-code-assist review on #1436, "wan: add dual-GPU model-parallel path for Wan 2.x LoRA training (depends on #1435)") so the same gate covers both FLUX.2 and Wan dual-GPU paths.
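For concreteness, a sketch of that runner.py gate, assuming accelerate's `Accelerator.prepare(..., device_placement=[...])` keyword and the four-object prepare order implied by the flag list; the real call site in launch_training_task may differ:

```python
import os
from accelerate import Accelerator

def prepare_for_training(accelerator: Accelerator, model, optimizer, dataloader, scheduler):
    dual_gpu = os.environ.get("DIFFSYNTH_DUAL_GPU", "").lower() == "true"
    if not dual_gpu:
        # Single-GPU path, unchanged: the whole model follows the accelerator device.
        model.to(accelerator.device)
    # device_placement=False for the model keeps the hand-placed dual-GPU split intact;
    # optimizer / dataloader / scheduler are still placed by accelerate as usual.
    return accelerator.prepare(
        model, optimizer, dataloader, scheduler,
        device_placement=[False, True, True, True] if dual_gpu else None,
    )
```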
Mistral-on-CPU `sft:data_process`

- `diffsynth/diffusion/runner.py` (`launch_data_process_task`): env-var-gated CPU-offload branch (`DIFFSYNTH_DATA_PROCESS_ON_CPU=true`). Combined with `--initialize_model_on_cpu`, this lets users on cards without 48 GB VRAM run the feature-cache step on CPU (~15 s per 256-token caption at 512² — slow, but a one-time pre-process). Without this, Mistral-3-Small-24B's ~48 GB bf16 weights OOM on any consumer 32 GB card. A minimal gate sketch follows below.
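A sketch of the gate only, with a hypothetical helper name (data_process_device); the real launch_data_process_task does more than this:

```python
import os
import torch

def data_process_device(accelerator) -> torch.device:
    # Mistral-3-Small-24B's ~48 GB bf16 weights won't fit on a 32 GB card, so the
    # env var keeps the feature-caching models (TE/VAE) on CPU for this one-time step.
    if os.environ.get("DIFFSYNTH_DATA_PROCESS_ON_CPU", "").lower() == "true":
        return torch.device("cpu")
    return accelerator.device
```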
transformers 5.8 compat

`transformers` 5.8 trimmed `Mistral3ForConditionalGeneration.forward`'s positional signature (removed `output_attentions`, `output_hidden_states`, `return_dict`, `cache_position` — they're `TransformersKwargs`-only now). The existing wrapper passes 15 positional args and dies with:

    takes from 1 to 11 positional arguments but 15 were given

Rewrote `Flux2TextEncoder.forward` to pass everything by keyword and re-inject `output_hidden_states`/`output_attentions` via `**kwargs` so the existing `get_mistral_3_small_prompt_embeds` caller still receives a populated `output.hidden_states`.
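A condensed sketch of the keyword-only call pattern, assuming the wrapper subclasses Mistral3ForConditionalGeneration (the super().forward() discussion in the review above implies this) and abbreviating the signature relative to the real wrapper:

```python
from transformers import Mistral3ForConditionalGeneration

class Flux2TextEncoder(Mistral3ForConditionalGeneration):
    def forward(self, input_ids=None, attention_mask=None, position_ids=None,
                past_key_values=None, inputs_embeds=None, labels=None, use_cache=None,
                output_attentions=None, output_hidden_states=None,
                return_dict=None, cache_position=None, **kwargs):
        # transformers 5.8 accepts these only through TransformersKwargs, so re-inject
        # them by keyword; setdefault lets an explicit kwarg from the caller win.
        if output_hidden_states is not None:
            kwargs.setdefault("output_hidden_states", output_hidden_states)
        if output_attentions is not None:
            kwargs.setdefault("output_attentions", output_attentions)
        if return_dict is not None:
            kwargs.setdefault("return_dict", return_dict)
        if cache_position is not None:
            kwargs.setdefault("cache_position", cache_position)
        # Everything else also goes by keyword, so upstream trimming of the
        # positional signature no longer breaks the call.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            **kwargs,
        )
```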
Validation

End-to-end on 2× RTX 5090 with sumi v8 data + the unchanged recipe from `examples/flux2/model_training/lora/FLUX.2-dev.sh`:

- `sft:data_process` (with `--initialize_model_on_cpu` + `DIFFSYNTH_DATA_PROCESS_ON_CPU=true`): 3 sumi-v8 images cached on CPU-offloaded Mistral-24B in 43 s.
- `sft:train` (with `DIFFSYNTH_DUAL_GPU=true`): 15/15 steps at 2.69 s/it sustained, 40 s wall-clock. Distribution lands at 20.7 GB cuda:0 / 12.6 GB cuda:1 after fp8 weight-only quant. LoRA checkpoint (270 MB at rank 32) saved.

Cross-vendor (2× 3090 / 2× 4090) validation pending — happy to run if a maintainer has access to those configs and wants concrete numbers before merge. Architecture math suggests a comfortable fit; the patch shape has been validated across four FLUX.2 LoRA trainers (ai-toolkit, musubi-tuner, OneTrainer, DiffSynth-Studio) with consistent ~21 / ~13 GB distribution.
Context
Cross-trainer write-up at https://github.com/genno-whittlery/flux2-dual-gpu-lora; DiffSynth-specific porting walkthrough at https://github.com/genno-whittlery/flux2-dual-gpu-lora/blob/main/docs/porting-diffsynth-studio.md.
The LoRA-skip `filter_fn` on `quantize_` was a novel finding from this port — diffusers' reference script earlier hit `TorchaoLoraLinear`'s constructor incompatibility because it was quantizing LoRA submodules too. Skipping `lora_A`/`lora_B` keeps `requires_grad` and routes PEFT through its normal `LoraLinear` instead of the broken torchao dispatcher. A sketch of the filter is below.

Companion PR #1436 ports the same dual-GPU pattern to Wan video LoRA training and shares the `DIFFSYNTH_DUAL_GPU` env var with this PR.
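A sketch of that filter, assuming torchao's `quantize_(model, config, filter_fn=...)` signature where filter_fn receives (module, fully qualified name), as in recent torchao releases; quantize_base_linears_only is an illustrative name, not the one in the patch:

```python
import torch.nn as nn
from torchao.quantization import Float8WeightOnlyConfig, quantize_

def quantize_base_linears_only(module: nn.Module, fqn: str) -> bool:
    # Quantize plain Linear layers but leave PEFT's lora_A / lora_B adapters alone:
    # quantizing them strips requires_grad and breaks the backward pass.
    return isinstance(module, nn.Linear) and "lora_A" not in fqn and "lora_B" not in fqn

def quantize_dit_skip_lora(dit: nn.Module) -> None:
    # Skipped modules simply stay in bf16; only the frozen base Linears shrink to fp8.
    quantize_(dit, Float8WeightOnlyConfig(), filter_fn=quantize_base_linears_only)
```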
Test plan

- `sft:data_process` on 2× RTX 5090 with Mistral-on-CPU — produces valid cached .pth files (verified shapes: latents (1, 1024, 128) bf16, prompt_embeds (1, 512, 15360) bf16)
- `sft:train` 15-step run, dual-GPU, fp8 weight-only — completes, LoRA saves
- `transformers` 5.8 — Mistral3 forward succeeds, `output.hidden_states` populated for `get_mistral_3_small_prompt_embeds`
- Single-GPU path unchanged (env vars unset: `train.py`/`runner.py` take existing branches)
- `flux2` examples + `flux2_text_encoder.py` + `runner.py` (gated branches only)