From 2d81556c35482073ae7b1b00b147777bed3e2d67 Mon Sep 17 00:00:00 2001 From: Pakkapon Phongthawee Date: Wed, 9 Oct 2024 18:46:49 +0700 Subject: [PATCH 1/2] make controlnet support interrupt --- src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 8 ++++++++ .../pipelines/controlnet/pipeline_controlnet_img2img.py | 8 ++++++++ .../pipelines/controlnet/pipeline_controlnet_inpaint.py | 8 ++++++++ .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 8 ++++++++ .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 8 ++++++++ .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 8 ++++++++ 6 files changed, 48 insertions(+) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 9b2fefe7b0a4..c4188b97e486 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -893,6 +893,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1089,6 +1093,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1235,6 +1240,9 @@ def __call__( is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 2a4f46d61990..16689b4fa22c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -890,6 +890,10 @@ def cross_attention_kwargs(self): @property def num_timesteps(self): return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -1081,6 +1085,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1211,6 +1216,9 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 9f7d464f9a91..7463fd3a2a30 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -976,6 +976,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1191,6 +1195,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1375,6 +1380,9 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 17fd2cb6c81d..73f6fbb5e3f2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -1144,6 +1144,10 @@ def cross_attention_kwargs(self): @property def num_timesteps(self): return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -1427,6 +1431,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1695,6 +1700,9 @@ def denoising_value_valid(dnv): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index fdebcdf83641..fdbbc462ffc9 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -989,6 +989,10 @@ def denoising_end(self): @property def num_timesteps(self): return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -1245,6 +1249,7 @@ def __call__( self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._denoising_end = denoising_end + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1442,6 +1447,9 @@ def __call__( is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index af19f3c309f8..736f3434ea52 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -1070,6 +1070,10 @@ def cross_attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def interrupt(self): + return self._interrupt + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -1338,6 +1342,7 @@ def __call__( self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): @@ -1510,6 +1515,9 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) From ede26e41a891beb48cda8acc5dc1d32472cb565e Mon Sep 17 00:00:00 2001 From: Pakkapon Phongthawee Date: Wed, 9 Oct 2024 19:10:14 +0700 Subject: [PATCH 2/2] remove white space in controlnet interrupt --- src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 2 +- .../pipelines/controlnet/pipeline_controlnet_img2img.py | 4 ++-- .../pipelines/controlnet/pipeline_controlnet_inpaint.py | 2 +- .../pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py | 2 +- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index c4188b97e486..60ad5eda8e0f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1242,7 +1242,7 @@ def __call__( for i, t in enumerate(timesteps): if self.interrupt: continue - + # Relevant thread: # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 16689b4fa22c..4cdec5b3cf5f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -890,7 +890,7 @@ def cross_attention_kwargs(self): @property def num_timesteps(self): return self._num_timesteps - + @property def interrupt(self): return self._interrupt @@ -1218,7 +1218,7 @@ def __call__( for i, t in enumerate(timesteps): if self.interrupt: continue - + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 7463fd3a2a30..da5a02d14108 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -1382,7 +1382,7 @@ def __call__( for i, t in enumerate(timesteps): if self.interrupt: continue - + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 73f6fbb5e3f2..496ad8d73c1d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -1144,7 +1144,7 @@ def cross_attention_kwargs(self): @property def num_timesteps(self): return self._num_timesteps - + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index fdbbc462ffc9..e480a87a70ce 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -989,7 +989,7 @@ def denoising_end(self): @property def num_timesteps(self): return self._num_timesteps - + @property def interrupt(self): return self._interrupt diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 736f3434ea52..21cd87f7570e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -1517,7 +1517,7 @@ def __call__( for i, t in enumerate(timesteps): if self.interrupt: continue - + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)