Commit abe8d63
add inferring_controlnet_cond_batch
takuma104 committed May 4, 2023
1 parent 364d59d commit abe8d63
Showing 1 changed file with 22 additions and 16 deletions.
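
For orientation, a hedged usage sketch of the pipeline this commit touches. The model IDs and the blank conditioning image are illustrative assumptions, not part of the commit; guess_mode is the existing __call__ flag whose classifier-free-guidance handling this change folds into inferring_controlnet_cond_batch.

# Hypothetical usage sketch -- model IDs and the blank conditioning image are placeholders.
import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

cond_image = Image.new("RGB", (512, 512))  # stand-in for a real Canny edge map

# guess_mode=True with guidance_scale > 1 is the path that now sets
# inferring_controlnet_cond_batch=True inside __call__.
image = pipe("a bird", image=cond_image, guess_mode=True, guidance_scale=7.5).images[0]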
@@ -658,8 +658,7 @@ def prepare_image(
         num_images_per_prompt,
         device,
         dtype,
-        do_classifier_free_guidance=False,
-        guess_mode=False,
+        inferring_controlnet_cond_batch=False,
     ):
         if not isinstance(image, torch.Tensor):
             if isinstance(image, PIL.Image.Image):
@@ -696,7 +695,7 @@ def prepare_image(
 
         image = image.to(device=device, dtype=dtype)
 
-        if do_classifier_free_guidance and not guess_mode:
+        if not inferring_controlnet_cond_batch:
             image = torch.cat([image] * 2)
 
         return image
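
The hunk above moves the batching decision behind a single flag. A minimal standalone sketch of the rule (tensor shapes are made up): under classifier-free guidance the UNet runs on an [uncond, cond] doubled batch, so the conditioning image is doubled to match unless ControlNet will only ever see the conditional half.

import torch

# Stand-in conditioning image: batch_size * num_images_per_prompt = 2
image = torch.randn(2, 3, 512, 512)

inferring_controlnet_cond_batch = False  # ControlNet will see both CFG halves
if not inferring_controlnet_cond_batch:
    # Double along the batch dim so the image lines up with [uncond, cond] latents.
    image = torch.cat([image] * 2)

print(image.shape)  # torch.Size([4, 3, 512, 512])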
@@ -898,7 +897,16 @@ def __call__(
         if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)
 
-        # 3. Encode input prompt
+        # 3. Determine whether to infer ControlNet only for the conditional batch.
+        global_pool_conditions = False
+        if isinstance(self.controlnet, ControlNetModel):
+            global_pool_conditions = self.controlnet.config.global_pool_conditions
+        else:
+            ...  # TODO: Implement for MultiControlNetModel
+
+        inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
+
+        # 4. Encode input prompt
         prompt_embeds = self._encode_prompt(
             prompt,
             device,
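
The new step 3 concentrates the mode decision in one boolean. Here is a standalone sketch of its truth table; the example values are assumptions, while global_pool_conditions is the real ControlNetModel config field read above.

do_classifier_free_guidance = True  # in the pipeline: guidance_scale > 1.0
guess_mode = False                  # user-requested guess mode
global_pool_conditions = True       # read from controlnet.config (e.g. a shuffle-style ControlNet)

# ControlNet is restricted to the conditional batch when guess mode or a
# globally pooled conditioning is active -- and only under CFG, since without
# CFG there is no unconditional batch to skip.
inferring_controlnet_cond_batch = (guess_mode or global_pool_conditions) and do_classifier_free_guidance
assert inferring_controlnet_cond_batch is True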
@@ -909,7 +917,7 @@ def __call__(
             negative_prompt_embeds=negative_prompt_embeds,
         )
 
-        # 4. Prepare image
+        # 5. Prepare image
         if isinstance(self.controlnet, ControlNetModel):
             image = self.prepare_image(
                 image=image,
@@ -919,8 +927,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=self.controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
-                guess_mode=guess_mode,
+                inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
             )
         elif isinstance(self.controlnet, MultiControlNetModel):
             images = []
@@ -934,8 +941,7 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=self.controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
-                    guess_mode=guess_mode,
+                    inferring_controlnet_cond_batch=inferring_controlnet_cond_batch,
                 )
 
                 images.append(image_)
@@ -944,11 +950,11 @@ def __call__(
         else:
             assert False
 
-        # 5. Prepare timesteps
+        # 6. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
 
-        # 6. Prepare latent variables
+        # 7. Prepare latent variables
         num_channels_latents = self.unet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
@@ … @@ def __call__(
             latents,
         )
 
-        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 8. Denoising loop
+        # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ … @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
-                    # Infer ControlNet only for the conditional batch.
+                if inferring_controlnet_cond_batch:
+                    # Inferring ControlNet only for the conditional batch.
                     controlnet_latent_model_input = latents
                     controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                 else:
@@ … @@ def __call__(
                     return_dict=False,
                 )
 
-                if guess_mode and do_classifier_free_guidance:
+                if inferring_controlnet_cond_batch:
                     # Inferred ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
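
To make the padding trick concrete, here is a toy sketch with made-up shapes and random tensors standing in for real ControlNet outputs: the residuals computed for the conditional half are prepended with zeros so that, when the UNet adds them to its skip connections, the unconditional half passes through unchanged.

import torch

prompt_embeds = torch.randn(4, 77, 768)   # [uncond, cond] stacked along dim 0
latents = torch.randn(2, 4, 64, 64)       # conditional-only latent batch

# Conditional half of the embeddings, as in prompt_embeds.chunk(2)[1] above.
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
assert controlnet_prompt_embeds.shape[0] == latents.shape[0]

# Random stand-ins for ControlNet's conditional-batch outputs.
down_block_res_samples = [torch.randn(2, 320, 64, 64), torch.randn(2, 320, 32, 32)]
mid_block_res_sample = torch.randn(2, 1280, 8, 8)

# Prepend zeros for the unconditional batch; adding zero residuals leaves it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

assert all(d.shape[0] == 4 for d in down_block_res_samples)
assert mid_block_res_sample.shape[0] == 4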