Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/flux-control/train_control_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def main(args):
flux_transformer.x_embedder = new_linear

assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not completely sure if this is required.

import torch
from diffusers import FluxTransformer2DModel

model_id = "hf-internal-testing/tiny-flux-pipe"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16)
print(transformer.config)

in_channels = transformer.config.in_channels
new_in_channels = in_channels * 2

transformer.register_to_config(in_channels=new_in_channels)
print(transformer.config)

out_channels won't be updated here. Tested with the big flux transformer model as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think your reproduction is fully representative of what we're doing here. Let's understand in more details:

from diffusers import FluxTransformer2DModel
import torch 

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", torch_dtype=torch.bfloat16
)
with torch.no_grad():
    initial_input_channels = flux_transformer.config.in_channels
    new_linear = torch.nn.Linear(
        flux_transformer.x_embedder.in_features * 2,
        flux_transformer.x_embedder.out_features,
        bias=flux_transformer.x_embedder.bias is not None,
        dtype=flux_transformer.dtype,
        device=flux_transformer.device,
    )
    # Skip the layer weight copying for brevity #
    flux_transformer.x_embedder = new_linear

flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
print(flux_transformer.config.in_channels)
print(flux_transformer.config.out_channels)

It prints:

8
None

So, what happens is that when we call save_pretrained() on the flux_transformer, the config.json is:

{'_class_name': 'FluxTransformer2DModel', '_diffusers_version': '0.33.0.dev0', '_name_or_path': 'hf-internal-testing/tiny-flux-pipe', 'attention_head_dim': 16, 'axes_dims_rope': [4, 4, 8], 'guidance_embeds': False, 'in_channels': 8, 'joint_attention_dim': 32, 'num_attention_heads': 2, 'num_layers': 1, 'num_single_layers': 1, 'out_channels': None, 'patch_size': 1, 'pooled_projection_dim': 32}

And now when we call from_config() with the path where it was serialized, we get out_channels to be like so:

...

flux_transformer = FluxTransformer2DModel.from_config(path)
print(flux_transformer.out_channels)
# prints 8

out_channels should be 4 here as we don't want that doubled. So, when we try to call from_pretrained(), it errors out:

ValueError: Cannot load /tmp/tmp22cr7b_d  because proj_out.bias expected shape torch.Size([8]), but got torch.Size([4]). If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example.

So, we need to explicitly register the out_channels, too.

LMK if it's unclear.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh okay, now it makes sense 👍 It's the out_channels = out_channels or in_channels in transformer init where the problem originates from, when using expanded linear config, so now it's set to correct value, got it


def unwrap_model(model):
model = accelerator.unwrap_model(model)
Expand Down Expand Up @@ -1166,6 +1166,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
flux_transformer.to(torch.float32)
flux_transformer.save_pretrained(args.output_dir)

del flux_transformer
del text_encoding_pipeline
del vae
free_memory()

# Run a final round of validation.
image_logs = None
if args.validation_prompt is not None:
Expand Down
7 changes: 6 additions & 1 deletion examples/flux-control/train_control_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def main(args):
flux_transformer.x_embedder = new_linear

assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

if args.train_norm_layers:
for name, param in flux_transformer.named_parameters():
Expand Down Expand Up @@ -1319,6 +1319,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
transformer_lora_layers=transformer_lora_layers,
)

del flux_transformer
del text_encoding_pipeline
del vae
free_memory()

# Run a final round of validation.
image_logs = None
if args.validation_prompt is not None:
Expand Down
Loading