
fix: use state_dict pattern in ZImageControlNet.from_transformer to prevent weight sharing#13128

Closed
jscaldwell55 wants to merge 2 commits into huggingface:main from jscaldwell55:fix/zimage-controlnet-from-transformer-shallow-copy

Conversation

jscaldwell55 commented Feb 11, 2026

Summary

ZImageControlNetModel.from_transformer() uses direct assignment to copy modules from the transformer, creating shared references. Modifying controlnet weights inadvertently mutates the original transformer. This PR replaces the shallow copies with fresh module instantiation + load_state_dict(), matching the established pattern used by FluxControlNetModel and QwenImageControlNetModel.

Problem

The current implementation:

controlnet.t_embedder = transformer.t_embedder
controlnet.noise_refiner = transformer.noise_refiner
# ... etc

This creates shared references; both objects point to the same nn.Module instances. Training the controlnet corrupts the transformer weights.
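Because the copy is shallow, the identity check below would hold on main (illustrative only; it assumes a controlnet and transformer set up as in the reproduction further down):

# Both attributes resolve to the very same nn.Module object
assert controlnet.t_embedder is transformer.t_embedder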

Solution

For each module category, use the approach appropriate to its type:

  • nn.Module (t_embedder, cap_embedder, noise_refiner, etc.): was controlnet.x = transformer.x; now create a fresh module, then load_state_dict(transformer.x.state_dict())
  • nn.Parameter (x_pad_token, cap_pad_token): was controlnet.x = transformer.x; now nn.Parameter(transformer.x.data.clone())
  • RopeEmbedder (not an nn.Module, no learnable weights): was controlnet.x = transformer.x; now a fresh instance built from transformer.config
  • Scalar (t_scale): unchanged; a plain float is immutable, so direct assignment is safe
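As a sketch, the fixed from_transformer body would look roughly like this under that scheme (build_t_embedder_from_config and rope_kwargs_from_config are hypothetical helpers standing in for the actual construction code, not real diffusers functions):

import torch.nn as nn

# nn.Module: build a fresh module of the same shape, then copy only the tensor data
fresh_t_embedder = build_t_embedder_from_config(transformer.config)  # hypothetical helper
fresh_t_embedder.load_state_dict(transformer.t_embedder.state_dict())
controlnet.t_embedder = fresh_t_embedder

# nn.Parameter: clone the underlying data so no storage is shared
controlnet.x_pad_token = nn.Parameter(transformer.x_pad_token.data.clone())

# RopeEmbedder: no learnable weights, so a fresh instance from the config is enough
controlnet.rope_embedder = RopeEmbedder(**rope_kwargs_from_config(transformer.config))  # hypothetical

# Scalar: an immutable float, plain assignment cannot alias mutable state
controlnet.t_scale = transformer.t_scale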

This follows the established pattern used by FluxControlNetModel and QwenImageControlNetModel.

Why not copy.deepcopy()?

deepcopy technically fixes the symptom but is the wrong tool here:

  • Copies internal PyTorch bookkeeping (hooks, grad_fn references, autograd metadata) that shouldn't be transferred
  • Higher peak memory - duplicates the entire object graph recursively
  • Not the established pattern anywhere in the diffusers controlnet codebase
  • The state_dict approach copies only the tensor data, which is exactly what we want
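To make the last point concrete, here is a standalone check (using a toy nn.Linear rather than the actual Z-Image modules) showing that load_state_dict copies values without sharing storage:

import torch
import torch.nn as nn

src = nn.Linear(4, 4)
dst = nn.Linear(4, 4)
dst.load_state_dict(src.state_dict())

assert torch.equal(dst.weight, src.weight)             # same values...
assert dst.weight.data_ptr() != src.weight.data_ptr()  # ...but separate storage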

Testing

Added 10 regression tests in tests/models/controlnets/test_models_controlnet_z_image.py that verify:

  • Each shared module is independent (modifying controlnet doesn't affect transformer)
  • Weights are correctly copied
  • t_scale has correct value

The reproduction from #13077 now passes; modifying controlnet weights no longer affects the transformer:

import torch
from diffusers import ZImageControlNetModel, ZImageTransformer2DModel

transformer = ZImageTransformer2DModel.from_pretrained(
    "Tongyi-MAI/Z-Image", subfolder="transformer", torch_dtype=torch.bfloat16,
)
controlnet = ZImageControlNetModel(
    control_layers_places=[0, 15, 29],
    control_refiner_layers_places=[0, 1],
    add_control_noise_refiner="control_noise_refiner",
    control_in_dim=16,
)
controlnet = ZImageControlNetModel.from_transformer(controlnet=controlnet, transformer=transformer)

# Modify controlnet weights
original_weight = transformer.t_embedder.mlp[0].weight.clone()
torch.nn.init.constant_(controlnet.t_embedder.mlp[0].weight, 42.0)

# Transformer should be unaffected
assert torch.equal(transformer.t_embedder.mlp[0].weight, original_weight), "Transformer weights were corrupted!"
print("✓ Transformer weights remain independent after controlnet modification")

Fixes #13077

cc: @yiyixuxu @JerryWu-code

jscaldwell55 force-pushed the fix/zimage-controlnet-from-transformer-shallow-copy branch from 0eecfaa to 66ff936 on February 12, 2026 at 00:02
hlky (Contributor) commented Feb 12, 2026

It is not intended to train these modules; see the original training script https://github.com/aigc-apps/VideoX-Fun/blob/main/scripts/z_image_fun/train_turbo_control_2.1.sh#L35
With the shallow copy we save on VRAM usage, and furthermore, if these modules were trained following this PR, how would they load afterwards? This is not handled.
IMO the simplest solution here is to just not train these modules: don't use .train(), and instead manually set requires_grad on the trainable control_* modules, as in the original training script.
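A minimal sketch of that suggestion (the control_* prefix is taken from the comment above; the exact parameter names inside ZImageControlNetModel are an assumption):

# Freeze everything shared with the transformer...
controlnet.requires_grad_(False)
# ...then re-enable gradients only on the controlnet-specific modules
for name, param in controlnet.named_parameters():
    if name.startswith("control_"):
        param.requires_grad_(True)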

bghira (Contributor) commented Feb 12, 2026

literally can't train those modules or the controlnet will never converge

jscaldwell55 (Author) commented Feb 12, 2026

@hlky @bghira I appreciate the feedback. I took a look, but I think there may be some confusion -- that script uses ZImageControlTransformer2DModel from videox_fun.models, which is a different class/codebase from diffusers. This PR addresses ZImageControlNetModel.from_transformer().

The diffusers from_transformer() method creates a controlnet from a transformer for inference/experimentation. The shallow copy causes silent weight corruption if the controlnet is modified (even accidentally). FluxControlNet and QwenImageControlNet both use load_state_dict for this same pattern.

Happy to add a docstring note that these modules are typically frozen during training workflows, but the API itself should be safe by default.

See the FluxControlNetModel and QwenImageControlNetModel implementations for the same pattern.

…revent weight sharing

ZImageControlNetModel.from_transformer() used direct assignment to copy
modules from the transformer, creating shared references. Modifying
controlnet weights inadvertently mutated the original transformer.

This replaces the shallow copies with fresh module instantiation +
load_state_dict(), matching the established pattern used by
FluxControlNetModel and QwenImageControlNetModel.

For each module category, use the appropriate approach:
- nn.Module: Create fresh module, then load_state_dict()
- nn.Parameter: Clone the data tensor
- RopeEmbedder: Fresh instance from transformer.config
- Scalar (t_scale): Unchanged, immutable float is safe

Fixes huggingface#13077
…ence

Adds regression tests to verify that from_transformer creates independent
weight copies, preventing modifications to the controlnet from affecting
the original transformer.

Tests cover all shared modules:
- t_embedder, cap_embedder, all_x_embedder
- noise_refiner, context_refiner
- x_pad_token, cap_pad_token
- rope_embedder instance independence
- t_scale correct value
jscaldwell55 force-pushed the fix/zimage-controlnet-from-transformer-shallow-copy branch from 66ff936 to 76db4a2 on February 12, 2026 at 18:40
jscaldwell55 (Author) commented:

Closing this PR; I have opened #13136, which takes a different approach based on feedback from @hlky and @bghira and on maintainer intent. Instead of copying modules, the new PR keeps the shared references but explicitly freezes them with `requires_grad_(False)`. This preserves the memory-efficient design while preventing accidental training of the shared modules.
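For completeness, with the frozen-shared-modules approach a training loop would hand only the still-trainable parameters to the optimizer, roughly like this (illustrative; the optimizer choice and learning rate are placeholders):

import torch

trainable = [p for p in controlnet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-5)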

jscaldwell55 deleted the fix/zimage-controlnet-from-transformer-shallow-copy branch on February 13, 2026 at 12:02