From 0708cdb7009b11cbaa2c9b6c64f3a4dec129372f Mon Sep 17 00:00:00 2001 From: ihkap11 Date: Fri, 9 Feb 2024 21:19:11 +0000 Subject: [PATCH 1/3] Bugfix: correct import for diffusers --- examples/community/pipeline_prompt2prompt.py | 243 +++++++++++++++---- 1 file changed, 195 insertions(+), 48 deletions(-) diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py index 200b5571ef70..1d16e26bcc8c 100644 --- a/examples/community/pipeline_prompt2prompt.py +++ b/examples/community/pipeline_prompt2prompt.py @@ -21,8 +21,11 @@ import torch import torch.nn.functional as F -from ...src.diffusers.models.attention import Attention -from ...src.diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionPipelineOutput +from diffusers.models.attention import Attention +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionPipeline, + StableDiffusionPipelineOutput, +) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg @@ -31,12 +34,16 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) # rescale the results from guidance (fixes overexposure) noise_pred_rescaled = noise_cfg * (std_text / std_cfg) # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) return noise_cfg @@ -165,7 +172,11 @@ def __call__( """ self.controller = create_controller( - prompt, cross_attention_kwargs, num_inference_steps, tokenizer=self.tokenizer, device=self.device + prompt, + cross_attention_kwargs, + num_inference_steps, + tokenizer=self.tokenizer, + device=self.device, ) self.register_attention_control(self.controller) # add attention controller @@ -192,7 +203,9 @@ def __call__( # 3. 
Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None ) prompt_embeds = self._encode_prompt( prompt, @@ -230,29 +243,43 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=prompt_embeds + ).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample # step callback latents = self.controller.step_callback(latents) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -260,8 +287,12 @@ def __call__( # 8. 
Post-processing if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) else: image = latents has_nsfw_concept = None @@ -272,7 +303,9 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + image = self.image_processor.postprocess( + image, output_type=output_type, do_denormalize=do_denormalize + ) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: @@ -281,13 +314,19 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) def register_attention_control(self, controller): attn_procs = {} cross_att_count = 0 for name in self.unet.attn_processors.keys(): - None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim + ( + None + if name.endswith("attn1.processor") + else self.unet.config.cross_attention_dim + ) if name.startswith("mid_block"): self.unet.config.block_out_channels[-1] place_in_unet = "mid" @@ -302,7 +341,9 @@ def register_attention_control(self, controller): else: continue cross_att_count += 1 - attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet) + attn_procs[name] = P2PCrossAttnProcessor( + controller=controller, place_in_unet=place_in_unet + ) self.unet.set_attn_processor(attn_procs) controller.num_att_layers = cross_att_count @@ -314,14 +355,26 @@ def __init__(self, controller, place_in_unet): self.controller = controller self.place_in_unet = place_in_unet - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) query = attn.to_q(hidden_states) is_cross = encoder_hidden_states is not None - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -346,7 +399,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a def create_controller( - prompts: List[str], cross_attention_kwargs: Dict, num_inference_steps: int, tokenizer, device + prompts: List[str], + cross_attention_kwargs: Dict, + num_inference_steps: int, + tokenizer, + device, ) -> AttentionControl: edit_type = cross_attention_kwargs.get("edit_type", None) local_blend_words = cross_attention_kwargs.get("local_blend_words", None) @@ -358,27 +415,49 @@ def create_controller( # only replace if edit_type == "replace" and local_blend_words is None: return 
AttentionReplace( - prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device + prompts, + num_inference_steps, + n_cross_replace, + n_self_replace, + tokenizer=tokenizer, + device=device, ) # replace + localblend if edit_type == "replace" and local_blend_words is not None: lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device) return AttentionReplace( - prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device + prompts, + num_inference_steps, + n_cross_replace, + n_self_replace, + lb, + tokenizer=tokenizer, + device=device, ) # only refine if edit_type == "refine" and local_blend_words is None: return AttentionRefine( - prompts, num_inference_steps, n_cross_replace, n_self_replace, tokenizer=tokenizer, device=device + prompts, + num_inference_steps, + n_cross_replace, + n_self_replace, + tokenizer=tokenizer, + device=device, ) # refine + localblend if edit_type == "refine" and local_blend_words is not None: lb = LocalBlend(prompts, local_blend_words, tokenizer=tokenizer, device=device) return AttentionRefine( - prompts, num_inference_steps, n_cross_replace, n_self_replace, lb, tokenizer=tokenizer, device=device + prompts, + num_inference_steps, + n_cross_replace, + n_self_replace, + lb, + tokenizer=tokenizer, + device=device, ) # reweight @@ -389,7 +468,9 @@ def create_controller( assert len(equalizer_words) == len( equalizer_strengths ), "equalizer_words and equalizer_strengths must be of same length." - equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) + equalizer = get_equalizer( + prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer + ) return AttentionReweight( prompts, num_inference_steps, @@ -400,7 +481,9 @@ def create_controller( equalizer=equalizer, ) - raise ValueError(f"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight.") + raise ValueError( + f"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight." 
+    )


 class AttentionControl(abc.ABC):
@@ -447,7 +530,14 @@ def forward(self, attn, is_cross: bool, place_in_unet: str):
 class AttentionStore(AttentionControl):
     @staticmethod
     def get_empty_store():
-        return {"down_cross": [], "mid_cross": [], "up_cross": [], "down_self": [], "mid_self": [], "up_self": []}
+        return {
+            "down_cross": [],
+            "mid_cross": [],
+            "up_cross": [],
+            "down_self": [],
+            "mid_self": [],
+            "up_self": [],
+        }

     def forward(self, attn, is_cross: bool, place_in_unet: str):
         key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
@@ -466,7 +556,8 @@ def between_steps(self):

     def get_average_attention(self):
         average_attention = {
-            key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
+            key: [item / self.cur_step for item in self.attention_store[key]]
+            for key in self.attention_store
         }
         return average_attention

@@ -485,7 +576,10 @@ class LocalBlend:
     def __call__(self, x_t, attention_store):
         k = 1
         maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
-        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words) for item in maps]
+        maps = [
+            item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words)
+            for item in maps
+        ]
         maps = torch.cat(maps, dim=1)
         maps = (maps * self.alpha_layers).sum(-1).mean(1)
         mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
@@ -497,7 +591,13 @@ def __call__(self, x_t, attention_store):
         return x_t

     def __init__(
-        self, prompts: List[str], words: [List[List[str]]], tokenizer, device, threshold=0.3, max_num_words=77
+        self,
+        prompts: List[str],
+        words: [List[List[str]]],
+        tokenizer,
+        device,
+        threshold=0.3,
+        max_num_words=77,
     ):
         self.max_num_words = 77

@@ -531,7 +631,9 @@ def replace_cross_attention(self, attn_base, att_replace):

     def forward(self, attn, is_cross: bool, place_in_unet: str):
         super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
         # FIXME not replace correctly
-        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+        if is_cross or (
+            self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
+        ):
             h = attn.shape[0] // (self.batch_size)
             attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
             attn_base, attn_repalce = attn[0], attn[1:]
@@ -551,7 +653,9 @@ def __init__(
         self,
         prompts,
         num_steps: int,
-        cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+        cross_replace_steps: Union[
+            float, Tuple[float, float], Dict[str, Tuple[float, float]]
+        ],
         self_replace_steps: Union[float, Tuple[float, float]],
         local_blend: Optional[LocalBlend],
         tokenizer,
@@ -569,7 +673,9 @@ def __init__(
         ).to(self.device)
         if isinstance(self_replace_steps, float):
             self_replace_steps = 0, self_replace_steps
-        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(
+            num_steps * self_replace_steps[1]
+        )
         self.local_blend = local_blend  # defined outside and passed in

@@ -588,7 +694,13 @@ def __init__(
         device=None,
     ):
         super(AttentionReplace, self).__init__(
-            prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device
+            prompts,
+            num_steps,
+            cross_replace_steps,
+            self_replace_steps,
+            local_blend,
+            tokenizer,
+            device,
         )
         self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)

@@ -610,7 +722,13 @@ def __init__(
         device=None,
     ):
         super(AttentionRefine, self).__init__(
-            prompts, num_steps, 
cross_replace_steps, self_replace_steps, local_blend, tokenizer, device + prompts, + num_steps, + cross_replace_steps, + self_replace_steps, + local_blend, + tokenizer, + device, ) self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer) self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device) @@ -620,7 +738,9 @@ def __init__( class AttentionReweight(AttentionControlEdit): def replace_cross_attention(self, attn_base, att_replace): if self.prev_controller is not None: - attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) + attn_base = self.prev_controller.replace_cross_attention( + attn_base, att_replace + ) attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] return attn_replace @@ -637,7 +757,13 @@ def __init__( device=None, ): super(AttentionReweight, self).__init__( - prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer, device + prompts, + num_steps, + cross_replace_steps, + self_replace_steps, + local_blend, + tokenizer, + device, ) self.equalizer = equalizer.to(self.device) self.prev_controller = controller @@ -645,7 +771,10 @@ def __init__( ### util functions for all Edits def update_alpha_time_word( - alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None + alpha, + bounds: Union[float, Tuple[float, float]], + prompt_ind: int, + word_inds: Optional[torch.Tensor] = None, ): if isinstance(bounds, float): bounds = 0, bounds @@ -659,7 +788,11 @@ def update_alpha_time_word( def get_time_words_attention_alpha( - prompts, num_steps, cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], tokenizer, max_num_words=77 + prompts, + num_steps, + cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], + tokenizer, + max_num_words=77, ): if not isinstance(cross_replace_steps, dict): cross_replace_steps = {"default_": cross_replace_steps} @@ -667,14 +800,23 @@ def get_time_words_attention_alpha( cross_replace_steps["default_"] = (0.0, 1.0) alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) for i in range(len(prompts) - 1): - alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i) + alpha_time_words = update_alpha_time_word( + alpha_time_words, cross_replace_steps["default_"], i + ) for key, item in cross_replace_steps.items(): if key != "default_": - inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] + inds = [ + get_word_inds(prompts[i], key, tokenizer) + for i in range(1, len(prompts)) + ] for i, ind in enumerate(inds): if len(ind) > 0: - alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) - alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) + alpha_time_words = update_alpha_time_word( + alpha_time_words, item, i, ind + ) + alpha_time_words = alpha_time_words.reshape( + num_steps + 1, len(prompts) - 1, 1, 1, max_num_words + ) return alpha_time_words @@ -687,7 +829,9 @@ def get_word_inds(text: str, word_place: int, tokenizer): word_place = [word_place] out = [] if len(word_place) > 0: - words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] + words_encode = [ + tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text) + ][1:-1] cur_len, ptr = 0, 0 for i in range(len(words_encode)): @@ -750,7 +894,10 @@ def get_replacement_mapper(prompts, tokenizer, max_len=77): ### util functions for ReweightEdit 
 def get_equalizer(
-    text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]], tokenizer
+    text: str,
+    word_select: Union[int, Tuple[int, ...]],
+    values: Union[List[float], Tuple[float, ...]],
+    tokenizer,
 ):
     if isinstance(word_select, (int, str)):
         word_select = (word_select,)

From 1403b938d59b144aa95d85b34c36312f6d503bb4 Mon Sep 17 00:00:00 2001
From: ihkap11 
Date: Fri, 9 Feb 2024 21:24:50 +0000
Subject: [PATCH 2/3] Fix: Prompt2Prompt example

---
 examples/community/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/community/README.md b/examples/community/README.md
index 6dbac2e16d7a..97b1b037113a 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -2287,9 +2287,9 @@ Here's a full example for `ReplaceEdit`:
 import torch
 import numpy as np
 import matplotlib.pyplot as plt
-from diffusers.pipelines import Prompt2PromptPipeline
+from diffusers import DiffusionPipeline

-pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
+pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")

 prompts = ["A turtle playing with a ball",
            "A monkey playing with a ball"]

From e932a0145bf3955badd66d4b6a7611b8a7d12cd5 Mon Sep 17 00:00:00 2001
From: ihkap11 
Date: Fri, 9 Feb 2024 21:57:24 +0000
Subject: [PATCH 3/3] Format style

---
 examples/community/pipeline_prompt2prompt.py | 129 +++++--------------
 1 file changed, 31 insertions(+), 98 deletions(-)

diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 1d16e26bcc8c..541d93b69b68 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -34,16 +34,12 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
     Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
     """
-    std_text = noise_pred_text.std(
-        dim=list(range(1, noise_pred_text.ndim)), keepdim=True
-    )
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
     # rescale the results from guidance (fixes overexposure)
     noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
     # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
-    noise_cfg = (
-        guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
-    )
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
     return noise_cfg

@@ -203,9 +199,7 @@ def __call__(

        # 3. 
Encode input prompt text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) - if cross_attention_kwargs is not None - else None + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None ) prompt_embeds = self._encode_prompt( prompt, @@ -243,43 +237,29 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) if do_classifier_free_guidance else latents - ) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=prompt_embeds - ).sample + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) if do_classifier_free_guidance and guidance_rescale > 0.0: # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg( - noise_pred, noise_pred_text, guidance_rescale=guidance_rescale - ) + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # step callback latents = self.controller.step_callback(latents) # call the callback, if provided - if i == len(timesteps) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) @@ -287,12 +267,8 @@ def __call__( # 8. 
Post-processing if not output_type == "latent": - image = self.vae.decode( - latents / self.vae.config.scaling_factor, return_dict=False - )[0] - image, has_nsfw_concept = self.run_safety_checker( - image, device, prompt_embeds.dtype - ) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None @@ -303,9 +279,7 @@ def __call__( else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess( - image, output_type=output_type, do_denormalize=do_denormalize - ) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: @@ -314,19 +288,13 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput( - images=image, nsfw_content_detected=has_nsfw_concept - ) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) def register_attention_control(self, controller): attn_procs = {} cross_att_count = 0 for name in self.unet.attn_processors.keys(): - ( - None - if name.endswith("attn1.processor") - else self.unet.config.cross_attention_dim - ) + (None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim) if name.startswith("mid_block"): self.unet.config.block_out_channels[-1] place_in_unet = "mid" @@ -341,9 +309,7 @@ def register_attention_control(self, controller): else: continue cross_att_count += 1 - attn_procs[name] = P2PCrossAttnProcessor( - controller=controller, place_in_unet=place_in_unet - ) + attn_procs[name] = P2PCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet) self.unet.set_attn_processor(attn_procs) controller.num_att_layers = cross_att_count @@ -363,18 +329,12 @@ def __call__( attention_mask=None, ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask( - attention_mask, sequence_length, batch_size - ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) is_cross = encoder_hidden_states is not None - encoder_hidden_states = ( - encoder_hidden_states - if encoder_hidden_states is not None - else hidden_states - ) + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -468,9 +428,7 @@ def create_controller( assert len(equalizer_words) == len( equalizer_strengths ), "equalizer_words and equalizer_strengths must be of same length." - equalizer = get_equalizer( - prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer - ) + equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer) return AttentionReweight( prompts, num_inference_steps, @@ -481,9 +439,7 @@ def create_controller( equalizer=equalizer, ) - raise ValueError( - f"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight." - ) + raise ValueError(f"Edit type {edit_type} not recognized. 
Use one of: replace, refine, reweight.")


 class AttentionControl(abc.ABC):
@@ -556,8 +512,7 @@ def between_steps(self):

     def get_average_attention(self):
         average_attention = {
-            key: [item / self.cur_step for item in self.attention_store[key]]
-            for key in self.attention_store
+            key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store
         }
         return average_attention

@@ -576,10 +531,7 @@ class LocalBlend:
     def __call__(self, x_t, attention_store):
         k = 1
         maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
-        maps = [
-            item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words)
-            for item in maps
-        ]
+        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words) for item in maps]
         maps = torch.cat(maps, dim=1)
         maps = (maps * self.alpha_layers).sum(-1).mean(1)
         mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
@@ -631,9 +583,7 @@ def replace_cross_attention(self, attn_base, att_replace):

     def forward(self, attn, is_cross: bool, place_in_unet: str):
         super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
         # FIXME not replace correctly
-        if is_cross or (
-            self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]
-        ):
+        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
             h = attn.shape[0] // (self.batch_size)
             attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
             attn_base, attn_repalce = attn[0], attn[1:]
@@ -653,9 +603,7 @@ def __init__(
         self,
         prompts,
         num_steps: int,
-        cross_replace_steps: Union[
-            float, Tuple[float, float], Dict[str, Tuple[float, float]]
-        ],
+        cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
         self_replace_steps: Union[float, Tuple[float, float]],
         local_blend: Optional[LocalBlend],
         tokenizer,
@@ -673,9 +621,7 @@ def __init__(
         ).to(self.device)
         if isinstance(self_replace_steps, float):
             self_replace_steps = 0, self_replace_steps
-        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(
-            num_steps * self_replace_steps[1]
-        )
+        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
         self.local_blend = local_blend  # defined outside and passed in

@@ -738,9 +684,7 @@ class AttentionReweight(AttentionControlEdit):
     def replace_cross_attention(self, attn_base, att_replace):
         if self.prev_controller is not None:
-            attn_base = self.prev_controller.replace_cross_attention(
-                attn_base, att_replace
-            )
+            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
         attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
         return attn_replace

@@ -800,23 +744,14 @@ def get_time_words_attention_alpha(
         cross_replace_steps["default_"] = (0.0, 1.0)
     alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
     for i in range(len(prompts) - 1):
-        alpha_time_words = update_alpha_time_word(
-            alpha_time_words, cross_replace_steps["default_"], i
-        )
+        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
     for key, item in cross_replace_steps.items():
         if key != "default_":
-            inds = [
-                get_word_inds(prompts[i], key, tokenizer)
-                for i in range(1, len(prompts))
-            ]
+            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
             for i, ind in enumerate(inds):
                 if len(ind) > 0:
-                    alpha_time_words = update_alpha_time_word(
-                        alpha_time_words, item, i, ind
-                    )
+                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
-    alpha_time_words = alpha_time_words.reshape(
-        num_steps + 1, 
len(prompts) - 1, 1, 1, max_num_words - ) + alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) + alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) return alpha_time_words @@ -829,9 +764,7 @@ def get_word_inds(text: str, word_place: int, tokenizer): word_place = [word_place] out = [] if len(word_place) > 0: - words_encode = [ - tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text) - ][1:-1] + words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] cur_len, ptr = 0, 0 for i in range(len(words_encode)):
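
For reference, here is a minimal usage sketch of the pipeline once this series is applied. It follows the README example updated in PATCH 2/3; the `cross_attention_kwargs` keys (`edit_type`, `n_cross_replace`, `n_self_replace`) are the ones `create_controller` reads in the pipeline file, while the seed, step count, and reweighting values below are illustrative rather than taken from the patches:

```python
import torch

from diffusers import DiffusionPipeline

# Load the community pipeline; custom_pipeline resolves
# examples/community/pipeline_prompt2prompt.py (the file fixed above).
pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="pipeline_prompt2prompt",
).to("cuda")

# Two structurally identical prompts: with edit_type="replace" and no
# local_blend_words, create_controller builds an AttentionReplace controller.
prompts = ["A turtle playing with a ball", "A monkey playing with a ball"]
cross_attention_kwargs = {
    "edit_type": "replace",
    "n_self_replace": 0.4,
    "n_cross_replace": {"default_": 1.0, "monkey": 0.4},
}

generator = torch.Generator(device="cuda").manual_seed(0)  # illustrative seed
outputs = pipe(
    prompt=prompts,
    num_inference_steps=50,
    generator=generator,
    cross_attention_kwargs=cross_attention_kwargs,
)
images = outputs.images  # one image per prompt; images[1] carries the edit
```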