[refactor embeddings] pixart-alpha #6212
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
    aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder
)
conditioning = timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
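For context, a minimal self-contained sketch of the conditioning pattern in the quoted diff: embed the resolution and aspect-ratio values, concatenate them, and add the result to the timestep embedding. This is illustrative only; the real diffusers module uses sinusoidal projections and MLP embedders, and the layer sizes below are made up.

import torch
import torch.nn as nn


class SizeConditioningSketch(nn.Module):
    """Toy version of the resolution/aspect-ratio conditioning path."""

    def __init__(self, embedding_dim: int = 8):
        super().__init__()
        # Half of the embedding comes from (height, width), half from the aspect ratio.
        self.resolution_embedder = nn.Linear(2, embedding_dim // 2)
        self.aspect_ratio_embedder = nn.Linear(1, embedding_dim // 2)

    def forward(self, timesteps_emb, resolution, aspect_ratio):
        resolution = self.resolution_embedder(resolution)
        aspect_ratio = self.aspect_ratio_embedder(aspect_ratio)
        # Same width as timesteps_emb, so the two can simply be summed.
        return timesteps_emb + torch.cat([resolution, aspect_ratio], dim=1)


batch_size, embedding_dim = 2, 8
cond = SizeConditioningSketch(embedding_dim)
out = cond(
    torch.randn(batch_size, embedding_dim),           # timestep embedding
    torch.tensor([[512.0, 512.0], [768.0, 512.0]]),   # (height, width) per sample
    torch.tensor([[1.0], [1.5]]),                      # aspect ratio per sample
)
print(out.shape)  # torch.Size([2, 8])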
very nice refactor!
LGTM if the SLOW tests pass. Could you please run the slow tests with these changes as well?
if do_classifier_free_guidance:
    resolution = torch.cat([resolution, resolution], dim=0)
    aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
Seems like a new addition?
not really - it gets duplicated later inside the embedding:
diffusers/src/diffusers/models/embeddings.py, line 758 in 6976cab:
    if size.shape[0] != batch_size:
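A minimal sketch of the duplication being referred to (an illustrative helper, not the actual diffusers code): if the size-condition tensor has fewer rows than the, possibly CFG-doubled, batch, it is repeated to match.

import torch


def broadcast_size_to_batch(size: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Repeat a per-image size condition so it lines up with the current batch,
    # e.g. after the batch has been doubled for classifier-free guidance.
    if size.shape[0] != batch_size:
        if batch_size % size.shape[0] != 0:
            raise ValueError("size rows must evenly divide batch_size")
        size = size.repeat(batch_size // size.shape[0], 1)
    return size


resolution = torch.tensor([[512.0, 512.0]])           # one (height, width) row
print(broadcast_size_to_batch(resolution, 4).shape)   # torch.Size([4, 2])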
Very nice <3
@sayakpaul
fast tests fail because there are some randomly initialized weights in some components. I think we need to put torch.manual_seed(0) before creating each component, e.g.
vae = AutoencoderKL()
Should we open a new PR that only updates the tests, and I rebase this one after that? I'm not comfortable updating the tests directly from this PR since I also updated the code.
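A minimal sketch of that suggestion, assuming a small test helper that builds the dummy components (the helper name is made up; only the AutoencoderKL() call comes from the comment above):

import torch
from diffusers import AutoencoderKL


def get_dummy_vae():
    # Seed right before constructing the component so its randomly
    # initialized weights are deterministic across test runs.
    torch.manual_seed(0)
    return AutoencoderKL()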
But that wasn't the case before. Wonder what changed.
@@ -235,7 +235,7 @@ def __init__(
         self.caption_projection = None
         if caption_channels is not None:
-            self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
+            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
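For readers unfamiliar with the module, a rough sketch of what a caption/text projection of this kind typically looks like: a small MLP mapping text-encoder features to the transformer's hidden size. Only the constructor signature (in_features, hidden_size) comes from the diff above; the internals here are assumptions.

import torch.nn as nn


class TextProjectionSketch(nn.Module):
    def __init__(self, in_features: int, hidden_size: int):
        super().__init__()
        # Project caption features (e.g. T5 hidden states) to the transformer width.
        self.linear_1 = nn.Linear(in_features, hidden_size)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))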
Might actually be worth breaking Transformer2D up into a dedicated one for PixArt.
For a future PR, yeah? I am happy to work on it once this is merged.
Definitely for a future PR.
But I think we should refactor the transformers and UNets after we clean up all the lower-level classes, and make such decisions for all models/pipelines at once so it will be consistent.
Yes 100 percent!
pixart-alpha Co-authored-by: yiyixuxu <yixu310@gmail,com>
Part of my embedding refactor, separated by model/pipeline so it is easier to work with.
This PR focuses on the embeddings that are only used in PixArt-Alpha, i.e. CombinedTimestepSizeEmbeddings and CaptionProjection: they are renamed to PixArtAlphaCombinedTimestepSizeEmbeddings and PixArtAlphaTextProjection so it is clear that these embeddings are only used in PixArt-Alpha.
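Hypothetical illustration (not part of this PR): if backward compatibility were a concern, the old names could be kept as simple aliases of the renamed classes, e.g.

from diffusers.models.embeddings import (
    PixArtAlphaCombinedTimestepSizeEmbeddings,
    PixArtAlphaTextProjection,
)

# Old names kept importable as aliases so downstream code keeps working.
CombinedTimestepSizeEmbeddings = PixArtAlphaCombinedTimestepSizeEmbeddings
CaptionProjection = PixArtAlphaTextProjection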