Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1089,6 +1093,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1235,6 +1240,9 @@ def __call__(
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1081,6 +1085,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1211,6 +1216,9 @@ def __call__(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1191,6 +1195,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1375,6 +1380,9 @@ def __call__(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1427,6 +1431,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1695,6 +1700,9 @@ def denoising_value_valid(dnv):

with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,10 @@ def denoising_end(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1245,6 +1249,7 @@ def __call__(
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._denoising_end = denoising_end
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1442,6 +1447,9 @@ def __call__(
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,10 @@ def cross_attention_kwargs(self):
def num_timesteps(self):
return self._num_timesteps

@property
def interrupt(self):
return self._interrupt

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -1338,6 +1342,7 @@ def __call__(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
Expand Down Expand Up @@ -1510,6 +1515,9 @@ def __call__(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
Expand Down
Loading