
[WIP] Make compel work with SD-XL #41

Merged · 22 commits · Jul 18, 2023

Conversation

patrickvonplaten (Contributor)

@patrickvonplaten patrickvonplaten commented Jul 10, 2023

This PR makes it possible to use the Compel library with SD-XL. It is a first proposal.

The reason I went for this design is that the two text encoders even use different tokenizers (they don't match 100%: the pad token ID differs).

The following seems to work (I have not tested it thoroughly):

```python
from diffusers import DiffusionPipeline
from compel import Compel
import torch

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")

compel = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    use_penultimate_clip_layer=True,
    use_penultimate_layer_norm=False,
    requires_pooled=[False, True],
)

# up-weight ("++") or down-weight ("--") "ball"
prompt = "a cat playing with a ball++ in the forest"
# prompt = "a cat playing with a ball-- in the forest"
conditioning, pooled = compel(prompt)

# generate image
image = pipeline(prompt_embeds=conditioning, pooled_prompt_embeds=pooled, num_inference_steps=30).images[0]
```

For "ball++" I'm getting:
ball_plus_plus

For "ball--" I'm getting:
ball_minus_minus

@patrickvonplaten patrickvonplaten changed the title Add sd xl Make compel work with SD-XL Jul 10, 2023
@patrickvonplaten patrickvonplaten changed the title Make compel work with SD-XL [WIP] Make compel work with SD-XL Jul 10, 2023
@patrickvonplaten patrickvonplaten mentioned this pull request Jul 10, 2023
@damian0815
Owner

thanks @patrickvonplaten

because people are going to ask - how does this play with >75 tokens? especially considering the need to run pad_conditioning_tensors_to_same_length when the positive prompt is >75 tokens but the negative is <= 75
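For context, this is what same-length padding has to do: when the positive prompt is chunked past 75 tokens but the negative prompt isn't, the two conditioning tensors end up with different sequence lengths and must be padded to match before the pipeline can use them. A minimal dependency-free sketch (the helper below is a hypothetical illustration using nested lists, not Compel's actual `pad_conditioning_tensors_to_same_length`, which operates on torch tensors):

```python
# Hypothetical illustration of padding two conditioning "tensors"
# (nested [seq_len][dim] lists) to the same sequence length.

def pad_to_same_length(cond_a, cond_b, pad_vector):
    """Pad the shorter of two [seq_len][dim] embeddings with copies of
    pad_vector so both end up with the same sequence length."""
    target = max(len(cond_a), len(cond_b))

    def pad(cond):
        return cond + [list(pad_vector)] * (target - len(cond))

    return pad(cond_a), pad(cond_b)

# positive prompt chunked to 154 "token embeddings", negative only 77
positive = [[0.1, 0.2]] * 154
negative = [[0.3, 0.4]] * 77
pos_padded, neg_padded = pad_to_same_length(positive, negative, [0.0, 0.0])
assert len(pos_padded) == len(neg_padded) == 154
```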

@damian0815
Owner

damian0815 commented Jul 11, 2023

@patrickvonplaten i had a closer look and there's a couple of issues. first of all, this simply won't work:

```python
pooled = self.conditioning_provider.maybe_get_pooled([text])
```

this will pass all of the Compel syntax markers straight into CLIP, which is not good.

instead i think this should be designed as follows:

  1. on EmbeddingsProvider, replace the use_penultimate_clip_layer, use_penultimate_layer_norm and requires_pooled arguments to __init__ with an enum ReturnedEmbeddingsType which has options PooledOutputs, PenultimateHiddenStatesNonNormalized, PenultimateHiddenStatesNormalized and LastHiddenStatesNormalized.
  2. for SDXL support, make 3 EmbeddingsProviders, not just 2 - one for the Refiner text encoder passing i guess PenultimateHiddenStatesNormalized, and two for the Base text encoder, one passing PenultimateHiddenStatesNonNormalized and the other passing PooledOutputs (i may have got those wrong but you get the idea i hope). this introduces a performance penalty because we'll be hitting the text encoders three times instead of only twice when we could be re-using the outputs, but i think it's worth it to avoid turning the logic into spaghetti.

doing number 2. at least will solve the problem of having to make a separate call to get the pooled outputs, and you'll get syntax handling on the pooled outputs the same as for the non-pooled outputs.
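A minimal sketch of the proposed enum, with member names taken from the comment above (this is a design illustration, not Compel's actual code):

```python
from enum import Enum, auto

class ReturnedEmbeddingsType(Enum):
    """Which embeddings an EmbeddingsProvider should return (proposed design)."""
    POOLED_OUTPUTS = auto()
    PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED = auto()
    PENULTIMATE_HIDDEN_STATES_NORMALIZED = auto()
    LAST_HIDDEN_STATES_NORMALIZED = auto()

# an EmbeddingsProvider would then take a single argument instead of three
# boolean flags, e.g. (hypothetical signature):
# EmbeddingsProvider(..., returned_embeddings_type=ReturnedEmbeddingsType.POOLED_OUTPUTS)
```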

@patrickvonplaten
Contributor Author

1.) For the pooled embedding vector I don't think it's a problem that all syntax markers are passed, since it's a pooled vector, not a sequential hidden-states vector. E.g. for an input sentence:

A cat AND a ball

where AND is a syntax marker IMO it's better to do:

pooled = pooled_vector_of("A cat AND a ball")

compared to:

pooled = torch.mean([pooled_vector_of("A cat"), pooled_vector_of("A ball")])

or some other "merging" operation.

Pooling vectors only need to get the "general" gist of the input text, they don't need to have contextualized embeddings for every token in the input.

The way I understood it: "A cat AND a ball" will be split into multiple forward passes for the CLIP cross-attention embeddings, which is not what we want to do for pooling IMO. For the cross-attention embeddings this is fine / good, because there we do torch.concat(...), which is not "distribution-destructive", unlike merging two pooled vectors IMO.

What do you think?
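To make the distinction concrete, here is a dependency-free sketch. The functions and toy "embeddings" below are hypothetical stand-ins, not real CLIP outputs: concatenating per-fragment sequence embeddings preserves every token's vector, while averaging per-fragment pooled vectors collapses everything into one fixed-size vector.

```python
# Toy stand-ins for CLIP outputs: each "token embedding" is a list of floats.
def sequence_embeddings_of(fragment):
    # hypothetical: one 2-d vector per whitespace-separated token
    return [[float(i), float(i) + 0.5] for i, _ in enumerate(fragment.split())]

def pooled_vector_of(fragment):
    # hypothetical pooling: mean over token vectors -> a single 2-d vector
    seq = sequence_embeddings_of(fragment)
    return [sum(vec[d] for vec in seq) / len(seq) for d in range(2)]

# cross-attention path: concatenation keeps all per-token information
concat = sequence_embeddings_of("A cat") + sequence_embeddings_of("a ball")
assert len(concat) == 4  # 2 tokens + 2 tokens, nothing lost

# pooled path, merged per-fragment: averaging two pooled vectors yields one
# fixed-size vector, losing the per-fragment distinction entirely
merged = [(a + b) / 2 for a, b in zip(pooled_vector_of("A cat"),
                                      pooled_vector_of("a ball"))]
assert len(merged) == 2  # a single fixed-size vector, regardless of input
```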

2.) Regarding sequences that run over the max limit, I'd just truncate / cut them for now. We could revisit later if we think it makes sense, but the original implementation also cuts for now (see here).
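For illustration, truncation here just means cutting the token sequence at the model's maximum length while keeping the special begin/end tokens. A minimal sketch (the helper is hypothetical; the token IDs follow CLIP's conventions but are illustrative here):

```python
BOS, EOS, MAX_LEN = 49406, 49407, 77  # CLIP-style special tokens; illustrative

def truncate_token_ids(token_ids, max_len=MAX_LEN, eos=EOS):
    """Cut a [bos, ...tokens..., eos] sequence down to max_len,
    preserving the final eos marker."""
    if len(token_ids) <= max_len:
        return token_ids
    return token_ids[: max_len - 1] + [eos]

ids = [BOS] + list(range(1, 100)) + [EOS]  # 101 ids, over the limit
out = truncate_token_ids(ids)
assert len(out) == 77 and out[0] == BOS and out[-1] == EOS
```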

RE:

> instead i think this should be designed as follows:
>
> 1. on EmbeddingsProvider, replace the use_penultimate_clip_layer, use_penultimate_layer_norm and requires_pooled arguments to `__init__` with an enum ReturnedEmbeddingsType which has options PooledOutputs, PenultimateHiddenStatesNonNormalized, PenultimateHiddenStatesNormalized and LastHiddenStatesNormalized.
> 2. for SDXL support, make 3 EmbeddingsProviders, not just 2 - one for the Refiner text encoder passing i guess PenultimateHiddenStatesNormalized, and two for the Base text encoder, one passing PenultimateHiddenStatesNonNormalized and the other passing PooledOutputs (i may have got those wrong but you get the idea i hope). this introduces a performance penalty because we'll be hitting the text encoders three times instead of only twice when we could be re-using the outputs, but i think it's worth it to avoid turning the logic into spaghetti.
>
> doing number 2. at least will solve the problem of having to make a separate call to get the pooled outputs, and you'll get syntax handling on the pooled outputs the same as for the non-pooled outputs.

I don't fully understand. Do you maybe just want to take over the PR? Feel free to open a new one if this one is too far away from what you were thinking.

@patrickvonplaten
Contributor Author

Note that we're only talking about "long prompt parsing" of the pooled vector, where it's much less clear how we could average multiple pooled vectors. The cross-attention vectors are "long prompt parsed" just like before. We cannot provide a pooled prompt vector of arbitrary length by definition (contrary to the cross-attention vectors).

@damian0815 damian0815 merged commit 3308c50 into damian0815:main Jul 18, 2023