add PAG support for SD architecture #8725

Merged: 11 commits, Jun 29, 2024

Contributor

@shauray8 shauray8 commented Jun 28, 2024

What does this PR do?

Adds PAG (Perturbed-Attention Guidance) support for SD models (StableDiffusionPAGPipeline). Continuation of #7944

Fixes #8710 (partially)

Before submitting

Who can review?

@yiyixuxu
Anyone in the community is free to review the PR once the tests have passed.


For anyone wondering, here are some results from my testing.

[Images: comparison between activation layers; comparison between PAG and no-PAG]

I expected that applying perturbed attention to the later layers would give much better quality, and that applying it to the middle layers would be much faster.


Usage [SD+PAG]

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import make_image_grid

# Load the checkpoint with PAG enabled on the selected attention layers
pipeline = AutoPipelineForText2Image.from_pretrained(
    "Lykon/DreamShaper",
    enable_pag=True,
    pag_applied_layers=["mid", "up.block_1.attentions_0"],
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()

pag_scales = [0.0, 3.0]
guidance_scales = [0.0, 2.0]

# Generate one image per (pag_scale, guidance_scale) pair, reusing the same seed
grid = []
for pag_scale in pag_scales:
    for guidance_scale in guidance_scales:
        generator = torch.Generator(device="cpu").manual_seed(0)
        images = pipeline(
            prompt="a polar bear sitting in a chair drinking a milkshake",
            negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
            num_inference_steps=30,
            guidance_scale=guidance_scale,
            generator=generator,
            pag_scale=pag_scale,
            height=512,
            width=512,
        ).images
        grid.append(images[0])

# Save the grid: rows follow pag_scales, columns follow guidance_scales
make_image_grid(grid, rows=len(pag_scales), cols=len(guidance_scales)).save("sample.png")
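
For readers unsure how pag_scale and guidance_scale interact in the grid above, here is a minimal sketch of the combination as I understand the PAG formulation. It is not taken from this PR's diff: the function name combine_guidance and the scalar inputs are hypothetical stand-ins for the noise-prediction tensors a pipeline would use.

```python
def combine_guidance(text, perturbed, pag_scale, uncond=None, guidance_scale=1.0):
    """Blend noise predictions (scalars here, for illustration only).

    uncond=None models the case where classifier-free guidance is off.
    This is an assumed formulation, not code copied from the pipeline.
    """
    # PAG pushes the prediction away from the perturbed-attention branch.
    pag_term = pag_scale * (text - perturbed)
    if uncond is None:
        return text + pag_term
    # With classifier-free guidance, both guidance terms are added together.
    return uncond + guidance_scale * (text - uncond) + pag_term


no_pag = combine_guidance(text=1.0, perturbed=0.5, pag_scale=0.0)
pag_only = combine_guidance(text=1.0, perturbed=0.5, pag_scale=3.0)
pag_and_cfg = combine_guidance(1.0, 0.5, 3.0, uncond=0.0, guidance_scale=2.0)
print(no_pag, pag_only, pag_and_cfg)
```

With pag_scale=0.0 the PAG term vanishes, which is why the first row of the grid serves as the no-PAG baseline.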

**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
Contributor Author

@yiyixuxu should I remove all the deprecation messages? I think this has been deprecated for a long time.

Collaborator

I think we can remove this method here since it is not used in the pipeline


return prompt_embeds, negative_prompt_embeds

def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
Member

@a-r-r-o-w a-r-r-o-w Jun 28, 2024

Could you add "# Copied from" comments to every method that requires one, in the same way it is done in other pipelines?

Contributor Author

Ahh, forgot to do that, done!
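
For readers unfamiliar with the convention requested here: a "# Copied from <dotted.path>" comment sits directly above a method to record where it was copied from. The sketch below shows one way to list such annotations in a source string. The helper find_copied_methods and the excerpt are hypothetical illustrations, not the repository's own tooling (the thread relies on make fix-copies for the real check).

```python
import re
from typing import Dict

# Matches a "# Copied from <dotted.path>" comment directly above a "def <name>(" line.
COPIED_FROM = re.compile(
    r"^[ \t]*# Copied from (?P<origin>[\w.]+)[^\n]*\n[ \t]*def (?P<name>\w+)\(",
    re.MULTILINE,
)


def find_copied_methods(source: str) -> Dict[str, str]:
    """Map each annotated method name to the dotted path it claims to be copied from."""
    return {m.group("name"): m.group("origin") for m in COPIED_FROM.finditer(source)}


# Hypothetical excerpt modelled on the decode_latents annotation quoted in this thread.
EXAMPLE = '''
class StableDiffusionPAGPipeline:
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        ...

    def check_inputs(self, prompt):
        ...
'''

copied = find_copied_methods(EXAMPLE)
print(copied)
```

Only decode_latents is reported, because check_inputs has no annotation above it.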

@Bhavay-2001
Contributor

Hi @shauray8, I just have one question. How did you compare against StableDiffusionPipeline when adding PAG support, and how did you figure out where to add the PAG-related lines? I am having some difficulty with that.
Thanks

@yiyixuxu yiyixuxu mentioned this pull request Jun 28, 2024
5 tasks
Collaborator

@yiyixuxu yiyixuxu left a comment

Thanks, very nice work!
I think there are some deprecated methods from SD1.5 that we do not need to add to the PAG pipeline. Other than that, it is perfect!

return image, has_nsfw_concept

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
Collaborator

we can remove this method here if it is not used by this pipeline :)

callback_on_step_end_tensor_inputs: List[str] = ["latents"],
pag_scale: float = 3.0,
pag_adaptive_scale: float = 0.0,
**kwargs,
Collaborator

Suggested change: delete the line
**kwargs,

can remove this if not used

@@ -76,7 +76,7 @@
 >>> pipe = AutoPipelineForText2Image.from_pretrained(
 ...     "stabilityai/stable-diffusion-xl-base-1.0",
 ...     torch_dtype=torch.float16,
-...     enabe_pag=True,
+...     enable_pag=True,
Collaborator

thank you!!

@yiyixuxu
Collaborator

Hi @Bhavay-2001,
you are working on StableDiffusionControlNetPAGImg2ImgPipeline, right? If so, I think you can:

  1. copy over the code from StableDiffusionControlNetImg2ImgPipeline as a starting point
  2. compare the code between StableDiffusionXLControlNetPAGImg2ImgPipeline and StableDiffusionXLControlNetImg2ImgPipeline, understand the changes we introduced for PAG, and apply the same logic to StableDiffusionControlNetPAGImg2ImgPipeline
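
The comparison in step 2 can be done with any diff tool. As a minimal sketch, the standard-library difflib can list the lines that a PAG variant adds on top of its base pipeline. The two source strings below are toy stand-ins and added_lines is a hypothetical helper; in practice you would pass the text of the two real pipeline files instead.

```python
import difflib
from typing import List


def added_lines(base_src: str, variant_src: str) -> List[str]:
    """Return lines present in variant_src but not in base_src, per a unified diff."""
    diff = difflib.unified_diff(
        base_src.splitlines(),
        variant_src.splitlines(),
        fromfile="base",
        tofile="variant",
        lineterm="",
        n=0,  # no context lines, only the changes
    )
    # Keep only "+" lines, skipping the "+++" file header.
    return [
        line[1:].strip()
        for line in diff
        if line.startswith("+") and not line.startswith("+++")
    ]


# Toy stand-ins for the two pipeline sources (hypothetical, heavily simplified).
BASE = """\
def __call__(self, prompt, guidance_scale=7.5):
    return self.denoise(prompt, guidance_scale)
"""
PAG = """\
def __call__(self, prompt, guidance_scale=7.5, pag_scale=3.0):
    self._pag_scale = pag_scale
    return self.denoise(prompt, guidance_scale)
"""

pag_changes = added_lines(BASE, PAG)
for line in pag_changes:
    print(line)
```

The printed lines are the PAG-specific additions, which is the set of changes to carry over to the new pipeline.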

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu
Collaborator

You can get the CI to pass too: run make style and make fix-copies.

@shauray8
Contributor Author

@yiyixuxu removed the methods mentioned above, with all the changes necessary for the CI to go green.

@yiyixuxu yiyixuxu merged commit 8690e8b into huggingface:main Jun 29, 2024
14 of 15 checks passed
@yiyixuxu
Collaborator

@shauray8 thanks for your contribution!

@shauray8 shauray8 deleted the pag_sd15 branch July 4, 2024 10:22
Successfully merging this pull request may close these issues.

Add PAG support to SD1.5