Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
654889a
fix: sdxl pipeline when unet is not available.
sayakpaul Oct 5, 2023
f204066
fix more
sayakpaul Oct 5, 2023
1c05263
account for text
sayakpaul Oct 6, 2023
fe2e9ee
fix more
sayakpaul Oct 6, 2023
e3bf831
don't make unet optional.
sayakpaul Oct 12, 2023
8fe7349
Apply suggestions from code review
sayakpaul Oct 13, 2023
8b0bfda
split conditionals.
sayakpaul Oct 13, 2023
9351ea9
add optional components to sdxl pipeline
sayakpaul Oct 13, 2023
646ecd1
propagate changes to the rest of the pipelines.
sayakpaul Oct 13, 2023
6b0ae28
add: test
sayakpaul Oct 13, 2023
8f8b8ec
Merge branch 'main' into fix/pipeline-without-unet
sayakpaul Oct 13, 2023
d41ddef
add to all
sayakpaul Oct 13, 2023
da2185e
fix: rest of the pipelines.
sayakpaul Oct 13, 2023
b4ecc81
resolve conflicts.
sayakpaul Oct 13, 2023
53fe5ad
use pipeline_class variable
sayakpaul Oct 14, 2023
3f412f6
separate pipeline mixin
sayakpaul Oct 14, 2023
04fee72
use safe_serialization
sayakpaul Oct 14, 2023
0b556d8
fix: test
sayakpaul Oct 14, 2023
7aa432a
access actual output.
sayakpaul Oct 14, 2023
91e4e19
add: optional test to adapter and ip2p sdxl pipeline tests/
sayakpaul Oct 14, 2023
7d33928
add optional test to controlnet sdxl.
sayakpaul Oct 14, 2023
30239a9
fix tests
sayakpaul Oct 14, 2023
55c22f9
fix ip2p tests
sayakpaul Oct 14, 2023
40f44b8
fix more
sayakpaul Oct 14, 2023
fe107e8
fix more.
sayakpaul Oct 14, 2023
66e71be
use np output type.
sayakpaul Oct 14, 2023
d9f9d6d
fix for StableDiffusionXLMultiControlNetPipelineFastTests.
sayakpaul Oct 15, 2023
6cade1a
fix: SDXLOptionalComponentsTesterMixin
sayakpaul Oct 15, 2023
773ca86
Apply suggestions from code review
sayakpaul Oct 16, 2023
a8486da
Merge branch 'main' into fix/pipeline-without-unet
sayakpaul Oct 16, 2023
5aaa23e
fix tests
sayakpaul Oct 16, 2023
1ddd586
Empty-Commit
sayakpaul Oct 16, 2023
38e16f8
revert previous
sayakpaul Oct 16, 2023
29fc48d
Merge branch 'main' into fix/pipeline-without-unet
sayakpaul Oct 16, 2023
f5748dc
quality
sayakpaul Oct 16, 2023
dbab63f
Merge branch 'main' into fix/pipeline-without-unet
sayakpaul Oct 17, 2023
e30b3e5
fix: test
sayakpaul Oct 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "text_encoder"]
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]

def __init__(
self,
Expand Down Expand Up @@ -317,12 +317,17 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
scale_lora_layers(self.text_encoder_2, lora_scale)
if self.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if self.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder_2, lora_scale)

prompt = [prompt] if isinstance(prompt, str) else prompt

Expand Down Expand Up @@ -438,7 +443,11 @@ def encode_prompt(

negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
if self.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
Expand All @@ -447,7 +456,12 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

if self.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

Expand All @@ -459,10 +473,15 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down Expand Up @@ -885,7 +904,14 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
return timesteps, num_inference_steps - t_start

def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
self,
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
dtype,
text_encoder_projection_dim=None,
):
if self.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
Expand All @@ -895,7 +921,7 @@ def _get_add_time_ids(
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)

passed_add_embed_dim = (
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

Expand Down Expand Up @@ -1391,13 +1417,19 @@ def denoising_value_valid(dnv):

# 10. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
aesthetic_score,
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

Expand Down
67 changes: 49 additions & 18 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ class StableDiffusionXLControlNetPipeline(
watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
watermarker is used.
"""
model_cpu_offload_seq = (
"text_encoder->text_encoder_2->unet->vae" # leave controlnet out on purpose because it iterates with unet
)
# leave controlnet out on purpose because it iterates with unet
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]

def __init__(
self,
Expand Down Expand Up @@ -285,12 +285,17 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
scale_lora_layers(self.text_encoder_2, lora_scale)
if self.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if self.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder_2, lora_scale)

prompt = [prompt] if isinstance(prompt, str) else prompt

Expand Down Expand Up @@ -406,7 +411,11 @@ def encode_prompt(

negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
if self.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
Expand All @@ -415,7 +424,12 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

if self.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

Expand All @@ -427,10 +441,15 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down Expand Up @@ -706,11 +725,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
return latents

# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
add_time_ids = list(original_size + crops_coords_top_left + target_size)

passed_add_embed_dim = (
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

Expand Down Expand Up @@ -1088,8 +1109,17 @@ def __call__(
target_size = target_size or (height, width)

add_text_embeds = pooled_prompt_embeds
if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)

if negative_original_size is not None and negative_target_size is not None:
Expand All @@ -1098,6 +1128,7 @@ def __call__(
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
else:
negative_add_time_ids = add_time_ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
watermarker will be used.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "text_encoder"]
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]

def __init__(
self,
Expand Down Expand Up @@ -329,12 +329,17 @@ def encode_prompt(
self._lora_scale = lora_scale

# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
scale_lora_layers(self.text_encoder_2, lora_scale)
if self.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)

if self.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
scale_lora_layers(self.text_encoder_2, lora_scale)

prompt = [prompt] if isinstance(prompt, str) else prompt

Expand Down Expand Up @@ -450,7 +455,11 @@ def encode_prompt(

negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
if self.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
Expand All @@ -459,7 +468,12 @@ def encode_prompt(
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

if self.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

Expand All @@ -471,10 +485,15 @@ def encode_prompt(
bs_embed * num_images_per_prompt, -1
)

if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)

if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2)

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

Expand Down Expand Up @@ -832,6 +851,7 @@ def _get_add_time_ids(
negative_crops_coords_top_left,
negative_target_size,
dtype,
text_encoder_projection_dim=None,
):
if self.config.requires_aesthetics_score:
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
Expand All @@ -843,7 +863,7 @@ def _get_add_time_ids(
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)

passed_add_embed_dim = (
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

Expand Down Expand Up @@ -1275,6 +1295,12 @@ def __call__(
if negative_target_size is None:
negative_target_size = target_size
add_text_embeds = pooled_prompt_embeds

if self.text_encoder_2 is None:
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
else:
text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

add_time_ids, add_neg_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
Expand All @@ -1285,6 +1311,7 @@ def __call__(
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=text_encoder_projection_dim,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

Expand Down
Loading