
refactor: extract init/forward function in UNet2DConditionModel #6478

Merged
merged 4 commits into huggingface:main from ultranity:refactor_unet on Jan 19, 2024

Conversation

ultranity
Contributor

What does this PR do?

The current UNet2DConditionModel mixes several variants into one very long implementation, even though it is widely used. Extracting a few stage functions should help developers and researchers understand the code and make it easier to hack on (a rough sketch of the resulting structure is shown after the list below).

  • refactor UNet2DConditionModel to improve readability and extensibility
    • extract function in init
      • _check_config
      • _set_encoder_hid_proj
      • _set_class_embeder
      • _set_add_embeder
      • _set_pos_net_if_use_gligen
      • get_mid_block
    • extract function in forward
      • get_time_embed
      • get_class_embed
      • get_aug_embed
      • process_encoder_hidden_states
      • maybe: GLIGEN / ControlNet / ControlNet-XS / T2I-Adapter
  • Add a new function get_mid_block() to unet_2d_blocks.py, analogous to get_up_block/get_down_block
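
For orientation, here is a hedged, toy-scale sketch of the staged structure described above. The helper names and call order come from the lists in this description; the class name TinyCondUNet, the dimensions, and the layer bodies are placeholders invented for illustration and are not the diffusers implementation.

import torch
from torch import nn


class TinyCondUNet(nn.Module):
    """Toy stand-in for UNet2DConditionModel that only illustrates the staged structure."""

    def __init__(self, sample_dim: int = 4, time_dim: int = 8, cond_dim: int = 16):
        super().__init__()
        self._check_config(sample_dim, time_dim, cond_dim)   # validate the configuration up front
        self._set_encoder_hid_proj(cond_dim, sample_dim)     # sets self.encoder_hid_proj
        self.time_embedding = nn.Linear(1, time_dim)         # placeholder time embedding
        self.mid_block = self.get_mid_block(sample_dim)      # mirrors get_up_block/get_down_block

    def _check_config(self, sample_dim, time_dim, cond_dim):
        if min(sample_dim, time_dim, cond_dim) <= 0:
            raise ValueError("all dimensions must be positive")

    def _set_encoder_hid_proj(self, cond_dim, sample_dim):
        self.encoder_hid_proj = nn.Linear(cond_dim, sample_dim)

    def get_mid_block(self, sample_dim):
        return nn.Linear(sample_dim, sample_dim)  # placeholder for the real mid block

    # forward is decomposed into small, individually overridable stages
    def get_time_embed(self, timestep):
        return self.time_embedding(timestep.float().view(-1, 1))

    def process_encoder_hidden_states(self, encoder_hidden_states):
        return self.encoder_hid_proj(encoder_hidden_states)

    def forward(self, sample, timestep, encoder_hidden_states):
        emb = self.get_time_embed(timestep)                               # stage: time embedding
        cond = self.process_encoder_hidden_states(encoder_hidden_states)  # stage: conditioning projection
        return self.mid_block(sample) + emb.mean() + cond.mean()          # stage: core block


unet = TinyCondUNet()
out = unet(torch.randn(2, 4), torch.tensor([10, 20]), torch.randn(2, 16))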

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@patrickvonplaten and @sayakpaul

- Add new function get_mid_block() to unet_2d_blocks.py
@sayakpaul
Member

Thanks for your contributions!

We have already started refactoring UNet and it will be cleaner and cleaner in the coming days.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ultranity
Contributor Author

Thanks for your contributions!

We have already started refactoring UNet and it will be cleaner and cleaner in the coming days.

cool, is there any relevant issue or PR there?

@sayakpaul
Member

sayakpaul commented Jan 8, 2024

There are several actually. We have started taking a bottom-up approach here. So, we're also refactoring many other building blocks such as the embeddings class, the ResNet2D class, etc. Some relevant PRs:

Does this help?

Cc: @patrickvonplaten @DN6 and @yiyixuxu here.

Contributor

@patrickvonplaten patrickvonplaten left a comment


That's actually a very nice refactor in my opinion - @DN6 @yiyixuxu @sayakpaul can you take a look here?

Collaborator

@yiyixuxu yiyixuxu left a comment


thanks a lot for working on this! I love this PR!

@@ -240,6 +240,59 @@ def get_down_block(
raise ValueError(f"{down_block_type} does not exist.")


def get_mid_block(
    mid_block_type, block_out_channels, mid_block_scale_factor, dropout, act_fn,
    norm_num_groups, norm_eps, cross_attention_dim, transformer_layers_per_block,
    attention_head_dim, num_attention_heads, dual_cross_attention, use_linear_projection,
    upcast_attention, resnet_time_scale_shift, resnet_skip_time_act, attention_type,
    mid_block_only_cross_attention, cross_attention_norm, blocks_time_embed_dim,
):
Collaborator


Can we add type hints, like we do for get_up_block and get_down_block?
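
For illustration, a type-annotated version of that signature might look like the sketch below; the parameter names are copied from the diff above, but every annotation is an assumption and the merged code may differ.

from typing import Optional, Tuple, Union

from torch import nn


# Hypothetical sketch: parameter names match the diff above, type hints are assumed.
def get_mid_block(
    mid_block_type: str,
    block_out_channels: Tuple[int, ...],
    mid_block_scale_factor: float,
    dropout: float,
    act_fn: str,
    norm_num_groups: int,
    norm_eps: float,
    cross_attention_dim: Optional[int],
    transformer_layers_per_block: Union[int, Tuple[int, ...]],
    attention_head_dim: int,
    num_attention_heads: Optional[int],
    dual_cross_attention: bool,
    use_linear_projection: bool,
    upcast_attention: bool,
    resnet_time_scale_shift: str,
    resnet_skip_time_act: bool,
    attention_type: str,
    mid_block_only_cross_attention: bool,
    cross_attention_norm: Optional[str],
    blocks_time_embed_dim: int,
) -> Optional[nn.Module]:
    ...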

encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
time_embed_dim, timestep_input_dim = self._set_time_embed_layer(flip_sin_to_cos, freq_shift, block_out_channels, act_fn, time_embedding_type, time_embedding_dim, timestep_post_act, time_cond_proj_dim)
Collaborator


Let's separate this function into _set_time_proj and _set_time_embedding, and make sure we have consistent naming logic across all these methods, i.e.

  • _set_time_proj() sets self.time_proj layer,
  • _set_time_embedding() set self.time_embedding layer,
  • _set_encoder_hid_proj() set self.encoder_hid_proj layer
  • _set_class_embedding() sets self.class_embedding layer
  • _set_add_embedding() sets self.add_embedding layer

IMO this is especially important for functions that update things in place: it helps our users get a good sense of what has been updated without having to go into the code.
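
As a hypothetical illustration of that convention (the bodies below are toy placeholders, not the diffusers layers), each helper would assign exactly one attribute and nothing else:

from torch import nn


class _TimeLayerHelpers:
    """Toy illustration of the one-attribute-per-_set_* convention proposed above."""

    def _set_time_proj(self, block_out_channels, time_embedding_dim=None):
        # assigns self.time_proj only and returns the derived dimensions
        timestep_input_dim = block_out_channels[0]
        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
        self.time_proj = nn.Identity()  # placeholder for the sinusoidal time projection
        return time_embed_dim, timestep_input_dim

    def _set_time_embedding(self, timestep_input_dim, time_embed_dim):
        # assigns self.time_embedding only
        self.time_embedding = nn.Linear(timestep_input_dim, time_embed_dim)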

Member


This also gives a nice chronology of the layers to the readers.

Member


Could it also make sense to add a _ at the end of the function names to denote that the functions do things in-place?

Collaborator


I think _set is a nice prefix to denote in-place operations. A _ at the end is a bit unconventional.

Contributor Author


  • Currently there is only one implementation of the time_embedding layer, so is _set_time_embedding really necessary?
  • The _set prefix also acts as an indicator that these functions are for internal use and should not be called directly by users.
  • A trailing _ actually follows the PyTorch naming style for in-place operations, but it might not be necessary if we already have the prefix?

self.mid_block = None
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
self.mid_block = get_mid_block(
    mid_block_type, block_out_channels, mid_block_scale_factor, dropout, act_fn,
    norm_num_groups, norm_eps, cross_attention_dim, transformer_layers_per_block,
    attention_head_dim, num_attention_heads, dual_cross_attention, use_linear_projection,
    upcast_attention, resnet_time_scale_shift, resnet_skip_time_act, attention_type,
    mid_block_only_cross_attention, cross_attention_norm, blocks_time_embed_dim,
)
Collaborator


❤️

for layer_number_per_block in transformer_layers_per_block:
    if isinstance(layer_number_per_block, list):
        raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
self._check_config(
    down_block_types, up_block_types, only_cross_attention, block_out_channels, layers_per_block,
    cross_attention_dim, transformer_layers_per_block, reverse_transformer_layers_per_block,
    attention_head_dim, num_attention_heads,
)
Member


Nice 🚀

Member

@sayakpaul sayakpaul left a comment


Uff, what a lovely PR! My eyes feel better now!

Collaborator

@DN6 DN6 left a comment


Looking very nice!

@ultranity ultranity changed the title from "[WIP] refactor: extract init/forward function in UNet2DConditionModel" to "refactor: extract init/forward function in UNet2DConditionModel" on Jan 14, 2024
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
Contributor


Suggested change
encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)
if self.encoder_hid_proj is not None:
    encoder_hidden_states = self.process_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)

Adding a short if statement here makes it a bit easier to understand IMO

Contributor Author


Keeping process_encoder_hidden_states (or renaming it to process_added_cond?) without an if statement outside could make it more easily extendable for future changes, IMO?
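
For illustration, keeping the guard inside the method could look roughly like the hypothetical sketch below; the text_proj branch and the concatenating branch are based on the snippets quoted above, but this is not the merged implementation.

import torch


# Hypothetical sketch, written as a method of the model (assumes self.encoder_hid_proj
# and self.config.encoder_hid_dim_type are set in __init__); not the merged code.
def process_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs):
    if self.encoder_hid_proj is None:
        # Plain SD models fall straight through, so callers never need an outer if statement.
        return encoder_hidden_states

    if self.config.encoder_hid_dim_type == "text_proj":
        encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
    else:
        # e.g. image-embedding variants: project and concatenate, as in the diff above
        image_embeds = added_cond_kwargs.get("image_embeds")
        image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
        encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
    return encoder_hidden_states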

Contributor


Actually I'd rename the function to project_encoder_hidden_states, because we linearly project the encoder hidden states here. Then, personally, I think it's much better to have the if statement, because most of the SD models do not use this function (it's really only the IF models). If we have it in an if statement, people reading the code who know SD will see directly that this function is not applied.

So I guess the following would be ideal for me:

        if self.encoder_hid_proj is not None:
            encoder_hidden_states = self.project_encoder_hidden_states(encoder_hidden_states, added_cond_kwargs)

Contributor

@patrickvonplaten patrickvonplaten left a comment


This is a super cool PR @ultranity! It's very important to keep diffusers readable and to fight bloated code.

I just have some minor suggestions - overall very happy to merge this one soon!

@ultranity
Contributor Author

About the last failing code quality check: since versatile_diffusion is deprecated, I'd prefer to remove the `Copy from` tag directly.

What do you think, @patrickvonplaten?

@patrickvonplaten
Contributor

patrickvonplaten commented Jan 17, 2024

About the last failing code quality check: since versatile_diffusion is deprecated, I'd prefer to remove the `Copy from` tag directly.

What do you think, @patrickvonplaten?

Agree 100%! Could you maybe remove the Copy-from from versatile diffusion?

@patrickvonplaten
Contributor

Great job @ultranity!

@patrickvonplaten patrickvonplaten merged commit c544196 into huggingface:main Jan 19, 2024
14 checks passed
@ultranity ultranity deleted the refactor_unet branch January 20, 2024 04:11
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…ingface#6478)

* - extract function for stage in UNet2DConditionModel init & forward
- Add new function get_mid_block() to unet_2d_blocks.py

* add type hint to get_mid_block aligned with get_up_block and get_down_block; rename _set_xxx function

* add type hint and  use keyword arguments

* remove `copy from` in versatile diffusion