fix: use state_dict pattern in ZImageControlNet.from_transformer to prevent weight sharing #13128
Conversation
It is not intended to train these modules; see the original training script https://github.com/aigc-apps/VideoX-Fun/blob/main/scripts/z_image_fun/train_turbo_control_2.1.sh#L35
Literally can't train those modules, or the controlnet will never converge.
@hlky @bghira I appreciate the feedback. I took a look, but I think there may be some confusion: that script uses ZImageControlTransformer2DModel from videox_fun.models, which is a different class/codebase from diffusers. This PR addresses ZImageControlNetModel.from_transformer(). The diffusers from_transformer() method creates a controlnet from a transformer for inference/experimentation. The shallow copy causes silent weight corruption if the controlnet is modified (even accidentally). FluxControlNet and QwenImageControlNet both use load_state_dict for this same pattern. Happy to add a docstring note that these modules are typically frozen during training workflows, but the API itself should be safe by default. See:
…revent weight sharing

`ZImageControlNetModel.from_transformer()` used direct assignment to copy modules from the transformer, creating shared references. Modifying controlnet weights inadvertently mutated the original transformer. This replaces the shallow copies with fresh module instantiation + `load_state_dict()`, matching the established pattern used by `FluxControlNetModel` and `QwenImageControlNetModel`.

For each module category, use the appropriate approach (see the sketch after this list):

- `nn.Module`: create a fresh module, then `load_state_dict()`
- `nn.Parameter`: clone the data tensor
- `RopeEmbedder`: fresh instance from `transformer.config`
- Scalar (`t_scale`): unchanged; an immutable float is safe

Fixes huggingface#13077
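A minimal sketch of that per-category handling, assuming the controlnet already constructs its own fresh submodules in `__init__` (attribute names are taken from this PR; the actual diffusers implementation may differ in detail):

```python
import torch.nn as nn

def copy_from_transformer(controlnet, transformer):
    # nn.Module: keep the freshly built module, copy only the tensor data.
    controlnet.t_embedder.load_state_dict(transformer.t_embedder.state_dict())

    # nn.Parameter: clone the underlying storage so nothing is shared.
    controlnet.x_pad_token = nn.Parameter(transformer.x_pad_token.data.clone())
    controlnet.cap_pad_token = nn.Parameter(transformer.cap_pad_token.data.clone())

    # RopeEmbedder: rebuilt from transformer.config in the real code; it holds
    # no learnable weights, so a fresh instance is enough (constructor args
    # omitted here because they depend on the model config).

    # Scalar float (t_scale): immutable, safe to assign directly.
    controlnet.t_scale = transformer.t_scale
    return controlnet
```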
…ence

Adds regression tests to verify that from_transformer creates independent weight copies, preventing modifications to the controlnet from affecting the original transformer. Tests cover all shared modules:

- t_embedder, cap_embedder, all_x_embedder
- noise_refiner, context_refiner
- x_pad_token, cap_pad_token
- rope_embedder instance independence
- t_scale correct value
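The core check looks roughly like this (a sketch only; the fixtures and exact assertions in the test file differ):

```python
import torch

def assert_weights_independent(transformer, controlnet, name="t_embedder"):
    """Perturb a controlnet module in place and verify the transformer copy is untouched."""
    original = {k: v.clone() for k, v in getattr(transformer, name).state_dict().items()}

    # Mutate the controlnet's copy of the module.
    with torch.no_grad():
        for p in getattr(controlnet, name).parameters():
            p.add_(1.0)

    # The transformer must be unaffected.
    for k, v in getattr(transformer, name).state_dict().items():
        assert torch.equal(v, original[k]), f"{name}.{k} was mutated through the controlnet"
```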
Closing this PR; I opened #13136, which takes a different approach based on the feedback from @hlky and @bghira and on maintainer intent. Instead of copying modules, the new PR keeps the shared references but explicitly freezes them with `requires_grad_(False)`. This preserves the memory-efficient design while preventing accidental training of shared modules.
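A rough sketch of that freezing approach (module names are taken from this PR; the exact list used in #13136 is an assumption):

```python
def freeze_shared_modules(controlnet):
    # Modules shared with the transformer stay as references, but are frozen so
    # training the controlnet cannot update the transformer's weights.
    shared_modules = (
        controlnet.t_embedder,
        controlnet.cap_embedder,
        controlnet.noise_refiner,
        controlnet.context_refiner,
    )
    for module in shared_modules:
        module.requires_grad_(False)

    # Shared nn.Parameter references are frozen the same way.
    controlnet.x_pad_token.requires_grad_(False)
    controlnet.cap_pad_token.requires_grad_(False)
```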
Summary
`ZImageControlNetModel.from_transformer()` uses direct assignment to copy modules from the transformer, creating shared references. Modifying controlnet weights inadvertently mutates the original transformer. This PR replaces the shallow copies with fresh module instantiation + `load_state_dict()`, matching the established pattern used by `FluxControlNetModel` and `QwenImageControlNetModel`.

Problem
The current implementation:
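In rough paraphrase (attribute names from the table below; not the verbatim source), the direct-assignment pattern looks like:

```python
# Paraphrased sketch of the pattern this PR replaces: each attribute on the
# controlnet points at the very same object owned by the transformer.
controlnet.t_embedder = transformer.t_embedder        # shared nn.Module
controlnet.x_pad_token = transformer.x_pad_token      # shared nn.Parameter
controlnet.rope_embedder = transformer.rope_embedder  # shared helper object
controlnet.t_scale = transformer.t_scale              # plain float (harmless)
```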
This creates shared references; both objects point to the same `nn.Module` instances. Training the controlnet corrupts the transformer weights.

Solution
For each module category, use the approach appropriate to its type:
| Category | Before | After |
| --- | --- | --- |
| `nn.Module` (t_embedder, cap_embedder, noise_refiner, etc.) | `controlnet.x = transformer.x` | Fresh module + `load_state_dict(transformer.x.state_dict())` |
| `nn.Parameter` (x_pad_token, cap_pad_token) | `controlnet.x = transformer.x` | `nn.Parameter(transformer.x.data.clone())` |
| `RopeEmbedder` (not `nn.Module`, no learnable weights) | `controlnet.x = transformer.x` | Fresh instance from `transformer.config` |
| Scalar (`t_scale`) | `controlnet.x = transformer.x` | Unchanged; an immutable float is safe to share |

This follows the exact pattern used by:
- `FluxControlNetModel.from_transformer`
- `QwenImageControlNetModel.from_transformer`

Why not `copy.deepcopy()`?

`deepcopy` technically fixes the symptom but is the wrong tool here: the `state_dict` approach copies only the tensor data, which is exactly what we want.

Testing
Added 10 regression tests in `tests/models/controlnets/test_models_controlnet_z_image.py` that verify the copied modules and parameters are independent of the transformer (see the test commit above for the full list).

The reproduction from #13077 now passes; modifying controlnet weights no longer affects the transformer:
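The original reproduction snippet is not included here; a rough reconstruction (model instantiation, the import path, and the exact `from_transformer` arguments are assumptions) is:

```python
import torch
from diffusers import ZImageControlNetModel  # import path assumed

# Assumes `transformer` is an already-instantiated Z-Image transformer.
controlnet = ZImageControlNetModel.from_transformer(transformer)

# Perturb a controlnet weight in place.
with torch.no_grad():
    next(controlnet.t_embedder.parameters()).add_(1.0)

# Before this fix the transformer's t_embedder changed too; now it stays intact.
assert not torch.equal(
    next(controlnet.t_embedder.parameters()),
    next(transformer.t_embedder.parameters()),
)
```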
Fixes #13077
cc: @yiyixuxu @JerryWu-code