add PAG support #7944

Merged on Jun 25, 2024 (50 commits)

Changes from 11 commits

Commits
a6a0429
first draft
yiyixuxu May 13, 2024
3605df9
refactor
yiyixuxu May 14, 2024
f94376c
update
yiyixuxu May 14, 2024
54c3fd6
up
yiyixuxu May 14, 2024
f571430
style
yiyixuxu May 14, 2024
91d0a5b
style
May 14, 2024
01585ab
update
yiyixuxu May 14, 2024
03bdbcd
inpaint + controlnet
yiyixuxu May 14, 2024
b662207
Merge branch 'pag' of github.com:huggingface/diffusers into pag
yiyixuxu May 14, 2024
1fb2c33
style
May 14, 2024
219f4b9
up
May 14, 2024
5641cb4
Update src/diffusers/pipelines/pag_utils.py
yiyixuxu May 14, 2024
8950e80
fix controlnet
KKIEEK May 15, 2024
4cc0b8b
fix compatability issue between PAG and IP-adapter (#8379)
sunovivid Jun 5, 2024
5cbf226
up
yiyixuxu Jun 6, 2024
58804a0
refactor ip-adapter
yiyixuxu Jun 8, 2024
7bc9229
style
Jun 8, 2024
e09e079
Merge branch 'main' into pag
yiyixuxu Jun 8, 2024
1fa54df
style
Jun 8, 2024
ba366f0
u[
Jun 9, 2024
d5a6761
up
Jun 10, 2024
854b70e
fix
Jun 10, 2024
9e4c1b6
add controlnet pag
Jun 10, 2024
623d237
copy
Jun 10, 2024
f30c2bc
add from pipe test for pag + controlnet
Jun 10, 2024
1df4391
up
Jun 10, 2024
191505e
support guess mode
yiyixuxu Jun 17, 2024
58b8330
style
Jun 17, 2024
71cf2f7
add pag + img2img
Jun 17, 2024
6da3bb6
Merge branch 'main' into pag
sayakpaul Jun 17, 2024
1e79c59
remove guess model support from pag controlnet pipeline
yiyixuxu Jun 20, 2024
14b4ddd
noise_pred_uncond -> noise_pred_text
yiyixuxu Jun 20, 2024
91c41e8
Apply suggestions from code review
yiyixuxu Jun 20, 2024
b72ef1c
fix more
yiyixuxu Jun 20, 2024
b7f4ccd
Merge branch 'main' into pag
Jun 24, 2024
d12b4a0
update docstring example
Jun 24, 2024
28e1301
add copied from
Jun 24, 2024
5653b2a
add doc
Jun 25, 2024
17520f2
up
Jun 25, 2024
e11180a
Merge branch 'main' into pag
Jun 25, 2024
074a4f0
fix copies
Jun 25, 2024
18d8b0e
up
Jun 25, 2024
0e337bf
up
Jun 25, 2024
434f63a
up
Jun 25, 2024
41b1ddc
up
Jun 25, 2024
24cadb4
Update docs/source/en/api/pipelines/pag.md
yiyixuxu Jun 25, 2024
c4ceee9
Apply suggestions from code review
yiyixuxu Jun 25, 2024
9db27cf
Update src/diffusers/models/attention_processor.py
yiyixuxu Jun 25, 2024
8ae87e2
add a tip about extending pag support and explain pag scale
Jun 25, 2024
19eb55f
Merge branch 'pag' of github.com:huggingface/diffusers into pag
Jun 25, 2024
228 changes: 228 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -2651,6 +2651,232 @@ def __call__(
return hidden_states


class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing perturbed-attention guidance (PAG) with identity self-attention, using PyTorch 2.0's scaled dot-product attention.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)

# original path
batch_size, sequence_length, _ = hidden_states_org.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

# perturbed path (identity attention)
batch_size, sequence_length, _ = hidden_states_ptb.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

# identity attention for the perturbed path: the attention map is an identity
# matrix, so the attention output is simply the value projection of the input
value = attn.to_v(hidden_states_ptb)
hidden_states_ptb = value

hidden_states_ptb = hidden_states_ptb.to(query.dtype)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
Reviewer comment: (nit) I think these lines of code could be clubbed together and accompanied with a comment.


if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states
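
For intuition, and not as part of the diff: the "perturbed path (identity attention)" above replaces the attention map with an identity matrix, so the attention output collapses to the value projection. A minimal standalone sketch of that equivalence in plain PyTorch (tensor shapes are illustrative):

import torch

batch, seq_len, dim = 2, 16, 64
value = torch.randn(batch, seq_len, dim)

# a full attention step whose attention map is the identity matrix ...
identity_attn = torch.eye(seq_len).expand(batch, seq_len, seq_len)
out_identity = torch.bmm(identity_attn, value)

# ... returns exactly the value tensor, which is why the perturbed branch
# above only computes attn.to_v and skips the query/key projections entirely
assert torch.allclose(out_identity, value)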


class PAGCFGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing perturbed-attention guidance (PAG) combined with classifier-free guidance, with identity self-attention, using PyTorch 2.0's scaled dot-product attention.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

# original path
batch_size, sequence_length, _ = hidden_states_org.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states_org)
key = attn.to_k(hidden_states_org)
value = attn.to_v(hidden_states_org)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states_org = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query.dtype)

# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)

if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)

# perturbed path (identity attention)
batch_size, sequence_length, _ = hidden_states_ptb.shape

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)

# identity attention for the perturbed path: the attention map is an identity
# matrix, so the attention output is simply the value projection of the input
value = attn.to_v(hidden_states_ptb)
hidden_states_ptb = value
hidden_states_ptb = hidden_states_ptb.to(query.dtype)

# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)

# cat
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states
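
A pipeline using this processor is expected to run the UNet on a 3x batch ordered as [unconditional, conditional, perturbed] and then combine the three noise predictions. The sketch below follows the perturbed-attention guidance formulation; the helper name combine_pag_cfg and the pag_scale argument are illustrative and not necessarily the exact names used in this PR:

import torch


def combine_pag_cfg(noise_pred: torch.Tensor, guidance_scale: float, pag_scale: float) -> torch.Tensor:
    # noise_pred comes from a 3x batch ordered [unconditional, conditional, perturbed]
    noise_uncond, noise_text, noise_perturb = noise_pred.chunk(3)
    # classifier-free guidance term plus perturbed-attention guidance term
    return (
        noise_uncond
        + guidance_scale * (noise_text - noise_uncond)
        + pag_scale * (noise_text - noise_perturb)
    )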


LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2691,6 +2917,8 @@ def __call__(
CustomDiffusionAttnProcessor,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
# deprecated
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
24 changes: 21 additions & 3 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -55,6 +55,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pag_utils import PAGMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput

@@ -181,6 +182,7 @@ class StableDiffusionXLControlNetPipeline(
StableDiffusionXLLoraLoaderMixin,
IPAdapterMixin,
FromSingleFileMixin,
PAGMixin,
):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -619,6 +621,7 @@ def check_inputs(
prompt_2,
image,
callback_steps,
guidance_scale,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
@@ -802,6 +805,11 @@ def check_inputs(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)

if hasattr(self, "_pag_cfg") and self._pag_cfg is False and guidance_scale != 0:
raise ValueError(
f"Cannot use guidance scale {guidance_scale} with PAG unconditional guidance. Please set `guidance_scale` to 0."
)

# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
@@ -1223,6 +1231,7 @@ def __call__(
prompt_2,
image,
callback_steps,
guidance_scale,
negative_prompt,
negative_prompt_2,
prompt_embeds,
@@ -1405,6 +1414,11 @@ def __call__(
else:
negative_add_time_ids = add_time_ids

if self.do_perturbed_attention_guidance:
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
@@ -1442,8 +1456,8 @@ def __call__(
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# expand the latents if we are doing classifier free guidance or perturbed attention guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
@@ -1506,7 +1520,11 @@ def __call__(
)[0]

# perform guidance
if self.do_classifier_free_guidance:
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
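
The batch expansion in the denoising loop above generalizes the previous hard-coded * 2 used for classifier-free guidance: the factor is derived from how many prompt-embedding copies were concatenated earlier in __call__. A small illustrative sketch of the resulting batch layouts (shapes are placeholders, not taken from the PR):

import torch

latents = torch.randn(1, 4, 128, 128)

# prompt_embeds.shape[0] // latents.shape[0] evaluates to:
#   1 -> [cond]                     (no guidance)
#   2 -> [uncond, cond] (CFG only)  or  [cond, perturbed] (PAG only)
#   3 -> [uncond, cond, perturbed]  (CFG + PAG)
for factor in (1, 2, 3):
    latent_model_input = torch.cat([latents] * factor)
    print(factor, tuple(latent_model_input.shape))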
