[training] fix: registration of out_channels in the control flux scripts. #10367

Merged
sayakpaul merged 4 commits into main from fix-control-flux-registration on Dec 24, 2024

Conversation

sayakpaul
Member

What does this PR do?

Additionally registers out_channels; otherwise, it would be overwritten by the newly set in_channels value.

@sayakpaul sayakpaul requested a review from a-r-r-o-w December 24, 2024 03:24
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@a-r-r-o-w a-r-r-o-w left a comment

The free memory changes look good, but I don't think the change for out_channels is required. LMK if I'm missing some context.

@@ -795,7 +795,7 @@ def main(args):
 flux_transformer.x_embedder = new_linear

 assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
Member

Not completely sure if this is required.

import torch
from diffusers import FluxTransformer2DModel

model_id = "hf-internal-testing/tiny-flux-pipe"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16)
print(transformer.config)

in_channels = transformer.config.in_channels
new_in_channels = in_channels * 2

transformer.register_to_config(in_channels=new_in_channels)
print(transformer.config)

out_channels won't be updated here. Tested with the big Flux transformer model as well.

Member Author

I don't think your reproduction is fully representative of what we're doing here. Let's understand in more detail:

from diffusers import FluxTransformer2DModel
import torch 

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", torch_dtype=torch.bfloat16
)
with torch.no_grad():
    initial_input_channels = flux_transformer.config.in_channels
    new_linear = torch.nn.Linear(
        flux_transformer.x_embedder.in_features * 2,
        flux_transformer.x_embedder.out_features,
        bias=flux_transformer.x_embedder.bias is not None,
        dtype=flux_transformer.dtype,
        device=flux_transformer.device,
    )
    # Skip the layer weight copying for brevity #
    flux_transformer.x_embedder = new_linear

flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
print(flux_transformer.config.in_channels)
print(flux_transformer.config.out_channels)

It prints:

8
None

So, what happens is that when we call save_pretrained() on the flux_transformer, the config.json is:

{'_class_name': 'FluxTransformer2DModel', '_diffusers_version': '0.33.0.dev0', '_name_or_path': 'hf-internal-testing/tiny-flux-pipe', 'attention_head_dim': 16, 'axes_dims_rope': [4, 4, 8], 'guidance_embeds': False, 'in_channels': 8, 'joint_attention_dim': 32, 'num_attention_heads': 2, 'num_layers': 1, 'num_single_layers': 1, 'out_channels': None, 'patch_size': 1, 'pooled_projection_dim': 32}

And when we then call from_config() with the path where it was serialized, out_channels comes out like so:

...

flux_transformer = FluxTransformer2DModel.from_config(path)
print(flux_transformer.out_channels)
# prints 8

out_channels should be 4 here, as we don't want it doubled. So, when we try to call from_pretrained(), it errors out:

ValueError: Cannot load /tmp/tmp22cr7b_d  because proj_out.bias expected shape torch.Size([8]), but got torch.Size([4]). If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example.

So, we need to explicitly register the out_channels, too.
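For completeness, here is a minimal round-trip sketch of the fix, assuming the same tiny test checkpoint as above (the weight-copying step is skipped for brevity, since it doesn't affect the config round trip):

import tempfile

import torch
from diffusers import FluxTransformer2DModel

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", torch_dtype=torch.bfloat16
)
with torch.no_grad():
    initial_input_channels = flux_transformer.config.in_channels
    # Double the input features of x_embedder, as the control training script does.
    flux_transformer.x_embedder = torch.nn.Linear(
        flux_transformer.x_embedder.in_features * 2,
        flux_transformer.x_embedder.out_features,
        bias=flux_transformer.x_embedder.bias is not None,
        dtype=flux_transformer.dtype,
        device=flux_transformer.device,
    )

# Register both values so the serialized config stays consistent with the checkpoint.
flux_transformer.register_to_config(
    in_channels=initial_input_channels * 2, out_channels=initial_input_channels
)

with tempfile.TemporaryDirectory() as path:
    flux_transformer.save_pretrained(path)
    reloaded = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
    print(reloaded.config.in_channels, reloaded.config.out_channels)  # expected: 8 4 for the tiny model

Without registering out_channels, the reload trips over the proj_out shape mismatch shown in the ValueError above.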

LMK if it's unclear.

Member

Oh okay, now it makes sense 👍 The problem originates from the out_channels = out_channels or in_channels fallback in the transformer init when using the expanded linear config, so now it's set to the correct value. Got it.
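
In other words, the fallback behaves roughly like this (a paraphrase for illustration, not the exact diffusers source):

# Paraphrased fallback, for illustration only (not the exact diffusers source).
def resolve_out_channels(in_channels: int, out_channels=None) -> int:
    # With out_channels left as None in the config, the model silently
    # inherits the (possibly doubled) in_channels value.
    return out_channels or in_channels

print(resolve_out_channels(in_channels=8, out_channels=None))  # 8 -> proj_out built too large
print(resolve_out_channels(in_channels=8, out_channels=4))     # 4 -> matches the checkpoint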

@sayakpaul sayakpaul merged commit 825979d into main Dec 24, 2024
12 checks passed
@sayakpaul sayakpaul deleted the fix-control-flux-registration branch December 24, 2024 16:14