New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
vq diffusion classifier free sampling #1294
vq diffusion classifier free sampling #1294
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
d0d5beb
to
4ee1e06
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
4ee1e06
to
cac658d
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
cac658d
to
f5fbbbe
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
f5fbbbe
to
8746de8
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
8746de8
to
fe8db41
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
fe8db41
to
bfc4459
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
bfc4459
to
40dc3ff
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
40dc3ff
to
10e2ea3
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
10e2ea3
to
08984ab
Compare
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
"https://huggingface.co/datasets/williamberman/misc/resolve/main" | ||
"/vq_diffusion/teddy_bear_pool_classifier_free_sampling.png" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be moved to the huggingface testing dataset. FWIW you might have to also regenerate the image because I get a different image on VQDiffusionPipelineIntegrationtests#test_vq_diffusion
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool :-) I'll move it!
src/diffusers/models/embeddings.py
Outdated
class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin): | ||
""" | ||
Utility class for storing learned text embeddings for classifier free sampling | ||
""" | ||
|
||
@register_to_config | ||
def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None): | ||
super().__init__() | ||
|
||
self.learnable = learnable | ||
|
||
if self.learnable: | ||
assert hidden_size is not None, "learnable=True requires `hidden_size` to be set" | ||
assert length is not None, "learnable=True requires `length` to be set" | ||
|
||
embeddings = torch.zeros(length, hidden_size) | ||
else: | ||
embeddings = None | ||
|
||
self.embeddings = nn.Parameter(embeddings) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this is the preferred way to add the learned embeddings to the pipeline. An alternative might be to add the additional vector to the scheduler instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's very model specific, so moving it to the pipeline here directly :-)
Think that's a bit cleaner! The model works much better now though - thanks!
@@ -64,6 +65,7 @@ def __init__( | |||
tokenizer: CLIPTokenizer, | |||
transformer: Transformer2DModel, | |||
scheduler: VQDiffusionScheduler, | |||
learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's definitely the right way to do it - it's quite specific to vq-diffusion IMO though, so will move it here :-)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
Very nice job @williamberman ! |
* vq diffusion classifier free sampling * correct * uP Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Adds classifier free sampling to VQ diffusion. This results in significantly better image quality.
The pipeline now has a default guidance_scale of 5.0
Additionally, the ithq dataset uses a learned parameter for the classifier free embeddings. We modify the convert script to add this parameter to the ported model. Weights will have to be reuploaded
Prompts: "teddy bear playing in the pool" and "horse"
Diffusers VQ diffusion with classifier free sampling
Diffusers VQ diffusion without classifier free sampling
Original VQ diffusion implementation with classifier free sampling
Original VQ diffusion implementation without classifier free sampling