[training] fix: registration of out_channels in the control flux scripts. #10367
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The free memory changes look good, but I don't think the change for out_channels is required. LMK if I'm missing some context.
@@ -795,7 +795,7 @@ def main(args):
         flux_transformer.x_embedder = new_linear

     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
Not completely sure if this is required.
import torch
from diffusers import FluxTransformer2DModel
model_id = "hf-internal-testing/tiny-flux-pipe"
transformer = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float16)
print(transformer.config)
in_channels = transformer.config.in_channels
new_in_channels = in_channels * 2
transformer.register_to_config(in_channels=new_in_channels)
print(transformer.config)
out_channels won't be updated here. I tested with the big Flux transformer model as well.
I don't think your reproduction is fully representative of what we're doing here. Let's understand in more detail:
from diffusers import FluxTransformer2DModel
import torch

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", torch_dtype=torch.bfloat16
)
with torch.no_grad():
    initial_input_channels = flux_transformer.config.in_channels
    new_linear = torch.nn.Linear(
        flux_transformer.x_embedder.in_features * 2,
        flux_transformer.x_embedder.out_features,
        bias=flux_transformer.x_embedder.bias is not None,
        dtype=flux_transformer.dtype,
        device=flux_transformer.device,
    )
    # Skip the layer weight copying for brevity #
    flux_transformer.x_embedder = new_linear

flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
print(flux_transformer.config.in_channels)
print(flux_transformer.config.out_channels)
It prints:
8
None
So, what happens is that when we call save_pretrained() on the flux_transformer, the config.json is:
{'_class_name': 'FluxTransformer2DModel', '_diffusers_version': '0.33.0.dev0', '_name_or_path': 'hf-internal-testing/tiny-flux-pipe', 'attention_head_dim': 16, 'axes_dims_rope': [4, 4, 8], 'guidance_embeds': False, 'in_channels': 8, 'joint_attention_dim': 32, 'num_attention_heads': 2, 'num_layers': 1, 'num_single_layers': 1, 'out_channels': None, 'patch_size': 1, 'pooled_projection_dim': 32}
And now when we call from_config() with the path where it was serialized, out_channels comes out like so:
...
flux_transformer = FluxTransformer2DModel.from_config(path)
print(flux_transformer.out_channels)
# prints 8
out_channels should be 4 here as we don't want that doubled. So, when we try to call from_pretrained(), it errors out:
ValueError: Cannot load /tmp/tmp22cr7b_d because proj_out.bias expected shape torch.Size([8]), but got torch.Size([4]). If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example.
So, we need to explicitly register the out_channels, too.
LMK if it's unclear.
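For reference, here is a minimal sketch of the fixed registration plus the save/reload round-trip (assuming the same tiny-flux-pipe checkpoint and a temporary directory; weight copying is again skipped for brevity, so only the shapes are meaningful):

import tempfile

import torch
from diffusers import FluxTransformer2DModel

flux_transformer = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer", torch_dtype=torch.bfloat16
)
with torch.no_grad():
    initial_input_channels = flux_transformer.config.in_channels
    new_linear = torch.nn.Linear(
        flux_transformer.x_embedder.in_features * 2,
        flux_transformer.x_embedder.out_features,
        bias=flux_transformer.x_embedder.bias is not None,
        dtype=flux_transformer.dtype,
        device=flux_transformer.device,
    )
    # Weight copying skipped for brevity, as above.
    flux_transformer.x_embedder = new_linear

# Register both fields so the serialized config keeps the original output width.
flux_transformer.register_to_config(
    in_channels=initial_input_channels * 2, out_channels=initial_input_channels
)

with tempfile.TemporaryDirectory() as path:
    flux_transformer.save_pretrained(path)
    reloaded = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
    print(reloaded.config.in_channels, reloaded.config.out_channels)  # expected: 8 4

With out_channels pinned to the original value, proj_out is rebuilt with the right shape on reload, so the size-mismatch error above should not appear.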
Oh okay, now it makes sense 👍 It's the out_channels = out_channels or in_channels fallback in the transformer init where the problem originates when using the expanded linear config; now it's set to the correct value. Got it.
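For illustration, the fallback referred to here behaves roughly like the simplified sketch below (not the actual FluxTransformer2DModel.__init__, just the out_channels = out_channels or in_channels pattern):

def resolve_out_channels(in_channels, out_channels=None):
    # Simplified stand-in for the fallback in the transformer init:
    # a None out_channels resolves to whatever in_channels is.
    return out_channels or in_channels

print(resolve_out_channels(8))     # 8: with out_channels=None in the saved config, the doubled in_channels wins
print(resolve_out_channels(8, 4))  # 4: correct once out_channels is registered explicitly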
What does this PR do?
Additionally registers the out_channels; otherwise, it would be overwritten with the newly set in_channels value.