
Conversation

@sayakpaul
Member

@sayakpaul sayakpaul commented Oct 5, 2023

This PR fixes the SDXL pipeline so that users can:

  • Precompute the text embeddings.
  • Load the SDXL pipeline without the text encoders and reuse the precomputed text embeddings in the pipeline call. This can lead to memory savings, which is especially helpful on consumer GPUs.

Please follow #5301 (comment) for a full-fledged example.
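
In a nutshell, the intended flow looks roughly like this (a minimal sketch with a placeholder prompt and only the essential call arguments; the linked comment has the full-fledged version):

import torch
from diffusers import StableDiffusionXLPipeline

ckpt = "stabilityai/stable-diffusion-xl-base-1.0"

# Stage 1: compute the prompt embeddings once with the text encoders loaded.
pipe = StableDiffusionXLPipeline.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda")
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt("an astronaut riding a horse")
del pipe

# Stage 2: reload the pipeline without the text encoders and reuse the embeddings.
pipe = StableDiffusionXLPipeline.from_pretrained(
    ckpt,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.float16,
).to("cuda")
image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
).images[0]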

@bghira
Contributor

bghira commented Oct 5, 2023

Conversely, the pipeline fails when using text embeds while text_encoder (and text_encoder_2) are unavailable.

@sayakpaul
Member Author

sayakpaul commented Oct 6, 2023

Conversely, the pipeline fails when using text embeds while text_encoder (and text_encoder_2) are unavailable.

Could you be a bit more specific? Happy to try to fix it here.

Edit: Ah I know what you mean.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Oct 6, 2023

The documentation is not available anymore as the PR was closed or merged.

@sayakpaul
Member Author

sayakpaul commented Oct 6, 2023

Huh!

from diffusers import StableDiffusionXLPipeline
import torch 

prompt = "hey"
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=None,
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds
) = pipe.encode_prompt(prompt)

del pipe

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

call_args = dict(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_images_per_prompt=1,
    num_inference_steps=2
)
_ = pipe(**call_args)

Does not work:

/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py in __call__(self, prompt, prompt_2, height, width, num_inference_steps, denoising_end, guidance_scale, negative_prompt, negative_prompt_2, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, output_type, return_dict, callback, callback_steps, cross_attention_kwargs, guidance_rescale, original_size, crops_coords_top_left, target_size, negative_original_size, negative_crops_coords_top_left, negative_target_size, clip_skip)
    851         # 7. Prepare added time ids & embeddings
    852         add_text_embeds = pooled_prompt_embeds
--> 853         add_time_ids = self._get_add_time_ids(
    854             original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
    855         )

/usr/local/lib/python3.10/dist-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py in _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype)
    544 
    545         passed_add_embed_dim = (
--> 546             self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
    547         )
    548         expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

AttributeError: 'NoneType' object has no attribute 'config'

Looking into it.
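
For reference, the kind of guard that resolves this (a sketch of the direction the fix takes; treat the exact parameter name as illustrative rather than a verbatim quote of the diff) is to stop reaching into self.text_encoder_2 when it isn't loaded and instead derive the projection dimension from the precomputed pooled embeddings:

# Inside __call__, before building the added time ids:
if self.text_encoder_2 is None:
    # text encoders were not loaded; infer the dim from the precomputed pooled embeddings
    text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
    text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

add_time_ids = self._get_add_time_ids(
    original_size,
    crops_coords_top_left,
    target_size,
    dtype=prompt_embeds.dtype,
    text_encoder_projection_dim=text_encoder_projection_dim,
)

# ... with _get_add_time_ids() then using the passed-in value instead of self.text_encoder_2.config.projection_dim.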

@sayakpaul
Member Author

@patrickvonplaten let me know if the proposed solution works for you. If so, I will go ahead and propagate the changes along with adding a test.

@patrickvonplaten
Contributor

What is this needed for? I don't think one should be allowed to create a StableDiffusionPipeline instance without a unet; this doesn't make much sense.

If one wants to just encode the inputs, we should probably follow @williamberman's suggestion here and try to make the encode_prompt function a classmethod (not the no_grad naming, just making it a class method).

Overall I'm curious though to know when a pipeline without unet would be needed?

@sayakpaul
Member Author

sayakpaul commented Oct 9, 2023

Much like how we allow the IF pipeline to be loaded without a unet and use encode_prompt(). For SDXL, this could be a big memory saver since you essentially get the ability to initialize a pipeline without the text encoders or the UNet and reuse the computations.

Additionally, the example provided in #5301 (comment) should already help explain the use case.

@bghira
Contributor

bghira commented Oct 9, 2023

In SimpleTuner, we have to keep the text encoder "loaded" just so that the checks in the SDXL pipeline do not fail during validation, despite passing in the negative/positive embeds/conditionings.

This would acceptably solve the problem there, instead of consuming system memory to keep the model loaded. I'm sure there's another way I could nuke the weights from orbit, but this seems like it's in alignment with Kandinsky, DeepFloyd, and the other pipelines that we/I support.

@patrickvonplaten
Contributor

patrickvonplaten commented Oct 9, 2023

We don't allow loading IF without a unet. We allow loading IF without a text encoder, as the text encoder is in the optional components list here:

_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]

We should not allow SDXL to be loaded without a unet. The unet is the heart of SDXL and it makes no sense to load a pipeline without it (the pipeline can then not be used at all).

We can allow loading SDXL without text encoders (if we believe that this is not an edge case), which would mean we should do:

  • add text_encoder_2 to the optional components list (see the sketch right after this list)
  • add an additional class method of encode_prompt so that one can use the text encoder without a unet
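
For SDXL, the first bullet would amount to something along these lines in pipeline_stable_diffusion_xl.py (a sketch; the merged change may order or extend the list differently):

_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]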

Also cc @williamberman here

@sayakpaul
Member Author

sayakpaul commented Oct 9, 2023

We don't allow loading IF without a unet. We allow loading IF without a text encoder, as the text encoder is in the optional components list here:

From the blog post:

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", 
    text_encoder=text_encoder, # pass the previously instantiated 8bit text encoder
    unet=None, 
    device_map="auto"
)

I don't quite follow your reasoning or the design you are proposing, in particular the "edge case" part. If I am missing something, please elaborate a bit more on the APIs you're envisioning.

Why shouldn't a user be allowed to precompute the text embeddings (both pooled and non-pooled) with the text encoders (without loading the UNet) and then reuse them while calling the SDXL pipeline without loading the text encoders? This saves memory, if I am not mistaken.

The flow I have in mind for the users is exactly the one I showed in #5301 (comment). Are you essentially saying we should instead do the following?

from diffusers import StableDiffusionXLPipeline
import torch 

prompt = "hey"
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds
) = pipe.encode_prompt_class_method(prompt)

del pipe.text_encoder, pipe.text_encoder_2, pipe.tokenizer, pipe.tokenizer_2

call_args = dict(
    prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, num_images_per_prompt=1,
    num_inference_steps=2
)
_ = pipe(**call_args)

I am okay with it, but this would require another library-wide refactoring / addition of a class-based encode_prompt() method, I believe. It also kind of disregards the long discussion in #4140.

@sayakpaul
Member Author

sayakpaul commented Oct 12, 2023

@patrickvonplaten I thought about it a bit and here's the trade-off (developer-experience-wise) I am okay with.

Note that the use case we're trying to target here is to allow people to load a pipeline without the text encoders for memory savings. I have updated the PR title and the description accordingly.

So, the flow now becomes:

  1. Users make use of the encode_prompt() method from here with pre-loaded text encoders and corresponding tokenizers. This loading can be done in a couple of ways:
  • users can load them explicitly (e.g., text_encoder = CLIPTextModel.from_pretrained(...)), or
  • users can leverage an existing pipeline with all the components loaded, call encode_prompt() to get the embeddings for the prompt, delete the text encoders, and then do the regular pipeline call.
  2. Step 1 above gives us the text embeddings we need to make the actual pipeline call. Before making that call, we make sure to free the text encoders from memory.

In code, it looks something like this.

First, load up the text encoders along with their tokenizers:

import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer


pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
torch_dtype = torch.float16

# load the text encoders and tokenizers
text_encoder = CLIPTextModel.from_pretrained(pipe_id, subfolder="text_encoder", torch_dtype=torch_dtype).to("cuda")
tokenizer = CLIPTokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(pipe_id, subfolder="text_encoder_2", torch_dtype=torch_dtype).to("cuda")
tokenizer_2 = CLIPTokenizer.from_pretrained(pipe_id, subfolder="tokenizer_2")

Then repurpose the encode_prompt() method:

def encode_prompt(tokenizers, text_encoders, prompt: str, negative_prompt: str = None, num_images_per_prompt: int = 1):
    device = text_encoders[0].device 

    if isinstance(prompt, str):
        prompt = [prompt]
    batch_size = len(prompt)

    prompt_embeds_list = []
    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids

        prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds.hidden_states[-2]
        prompt_embeds_list.append(prompt_embeds)

    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    if negative_prompt is None:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    else:
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        
        negative_prompt_embeds_list = []
        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            uncond_input = tokenizer(
                negative_prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True)
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    bs_embed, seq_len, _ = prompt_embeds.shape
    
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # for classifier-free guidance
    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
    seq_len = negative_prompt_embeds.shape[1]

    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    
    # for classifier-free guidance
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Now call it with both tokenizers and text encoders:

prompt = "hey"
tokenizers = [tokenizer, tokenizer_2]
text_encoders = [text_encoder, text_encoder_2]

(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds
) = encode_prompt(tokenizers, text_encoders, prompt)

Delete the text encoders:

del text_encoder, text_encoder_2, tokenizer, tokenizer_2
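
Depending on the setup, del alone may not immediately release the VRAM, so it can also help to run garbage collection and clear the CUDA cache afterwards (an extra housekeeping step, not strictly part of the flow above):

import gc

gc.collect()
torch.cuda.empty_cache()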

Then do the pipeline call:

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    pipe_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
_ = pipe(
    prompt=None, 
    prompt_embeds=prompt_embeds, 
    negative_prompt_embeds=negative_prompt_embeds, 
    pooled_prompt_embeds=pooled_prompt_embeds, 
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds
)

This experience isn't too bad IMO, as we're still able to leverage parts of the library as per our needs. It also doesn't require us to maintain separate classmethod variants. Not to mention, this way users can fully operate in the embedding space, doing things like weighted prompt embeddings with compel, for example.
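
As a concrete example of operating purely in embedding space, weighted prompts with compel could slot into the same flow in place of the hand-rolled encode_prompt() above, while the tokenizers and text encoders are still loaded (a sketch based on compel's documented SDXL usage; double-check the compel docs for the exact arguments, and the negative embeddings would be produced the same way):

from compel import Compel, ReturnedEmbeddingsType

compel = Compel(
    tokenizer=[tokenizer, tokenizer_2],
    text_encoder=[text_encoder, text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

# "++" upweights the preceding token; the pooled output comes from the second text encoder
prompt_embeds, pooled_prompt_embeds = compel("a cat playing with a ball++ in the forest")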

WDYT?

@sayakpaul sayakpaul changed the title [WIP][Core] Fix/pipeline without unet [WIP][Core] Fix/pipeline without text encoders for SDXL Oct 12, 2023
@patrickvonplaten
Contributor

OK for me to allow loading the unet without the text encoders! Let's make sure though that in this case:

@sayakpaul sayakpaul marked this pull request as ready for review October 13, 2023 11:25
PipelineLatentTesterMixin,
PipelineKarrasSchedulerTesterMixin,
PipelineTesterMixin,
SDXLOptionalComponentsTesterMixin,
Member Author

New tester mixin to consolidate the testing of the optional components across SDXL and its derivative pipelines.

Contributor

@patrickvonplaten patrickvonplaten left a comment

Left some nits, but apart from this, the PR looks good!

@bghira
Contributor

bghira commented Oct 16, 2023

Wow, thank you so much @sayakpaul, this was a complicated one. I was going to solve it last month, but the amount of earth-moving you've done to make it happen is exactly why I couldn't get it there. Good work!

@sayakpaul
Member Author

@patrickvonplaten I had to do 38e16f8.

With the nits you suggested, the pipeline couldn't be instantiated properly.

import torch.nn as nn
import torch.nn.functional as F

from diffusers.utils import deprecate
Member Author

Had to be imported. @patrickvonplaten just as an FYI.

@sayakpaul sayakpaul merged commit 8b3d2ae into main Oct 17, 2023
@sayakpaul sayakpaul deleted the fix/pipeline-without-unet branch October 17, 2023 05:47
@sayakpaul sayakpaul mentioned this pull request Oct 17, 2023
mhetrerajat pushed a commit to mhetrerajat/diffusers that referenced this pull request Oct 23, 2023
* fix: sdxl pipeline when unet is not available.

* fix moe

* account for text

* ifx more

* don't make unet optional.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* split conditionals.

* add optional components to sdxl pipeline

* propagate changes to the rest of the pipelines.

* add: test

* add to all

* fix: rest of the pipelines.

* use pipeline_class variable

* separate pipeline mixin

* use safe_serialization

* fix: test

* access actual output.

* add: optional test to adapter and ip2p sdxl pipeline tests/

* add optional test to controlnet sdxl.

* fix tests

* fix ip2p tests

* fix more

* fifx more.

* use np output type.

* fix for StableDiffusionXLMultiControlNetPipelineFastTests.

* fix: SDXLOptionalComponentsTesterMixin

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix tests

* Empty-Commit

* revert previous

* quality

* fix: test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024