Skip to content

Commit

Permalink
fix(img2img): do not attempt to do a zero-step img2img when strength …
Browse files Browse the repository at this point in the history
…is low (#2472)
  • Loading branch information
keturn committed Feb 2, 2023
2 parents 5ef66ca + da181ce commit 80c5322
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions ldm/invoke/generator/diffusers_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
import einops
import torch
import torchvision.transforms as T
from diffusers.models import attention
from diffusers.utils.import_utils import is_xformers_available

from ...models.diffusion import cross_attention_control
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter

Expand Down Expand Up @@ -506,11 +504,7 @@ def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_ste
strength,
noise: torch.Tensor, run_id=None, callback=None
) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)

timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
result_latents, result_attention_maps = self.latents_from_embeddings(
initial_latents, num_inference_steps, conditioning_data,
timesteps=timesteps,
Expand All @@ -526,6 +520,18 @@ def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_ste
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
return self.check_for_safety(output, dtype=conditioning_data.dtype)

def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device) -> (torch.Tensor, int):
    """Build the denoising timestep schedule for an img2img run.

    Delegates to diffusers' ``StableDiffusionImg2ImgPipeline`` to trim the
    full schedule according to ``strength``, then guarantees the result is
    never empty so at least one denoising step is always performed.

    :param num_inference_steps: requested number of scheduler steps.
    :param strength: img2img strength in [0, 1]; lower values trim more of
        the schedule away.
    :param device: torch device the timestep tensor should live on.
    :return: tuple of (timesteps tensor, adjusted step count).
    """
    delegate = StableDiffusionImg2ImgPipeline(**self.components)
    # The delegate is built from our own components, so its scheduler must
    # be the very same object — otherwise the timesteps it computes would
    # not apply to this pipeline.
    assert delegate.scheduler is self.scheduler
    delegate.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, step_count = delegate.get_timesteps(num_inference_steps, strength, device=device)
    # Workaround for low strength resulting in zero timesteps.
    # TODO: submit upstream fix for zero-step img2img
    if timesteps.numel() == 0:
        # Fall back to the final (least-noisy) timestep of the schedule so
        # the caller always has at least one step to run.
        timesteps = self.scheduler.timesteps[-1:]
        step_count = timesteps.numel()
    return timesteps, step_count

def inpaint_from_embeddings(
self,
init_image: torch.FloatTensor,
Expand All @@ -549,11 +555,7 @@ def inpaint_from_embeddings(
if init_image.dim() == 3:
init_image = init_image.unsqueeze(0)

img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)

assert img2img_pipeline.scheduler is self.scheduler
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, device=device)

# 6. Prepare latent variables
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
Expand Down

0 comments on commit 80c5322

Please sign in to comment.