From 9821496769b64e5456b4184029f98e20ff9f3368 Mon Sep 17 00:00:00 2001
From: fancy45daddy <124528204+fancy45daddy@users.noreply.github.com>
Date: Fri, 13 Dec 2024 21:59:11 -0800
Subject: [PATCH 1/2] Update pipeline_controlnet.py

---
 .../pipelines/controlnet/pipeline_controlnet.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 486f9fb764d1..1d48cc28f76c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -31,6 +31,7 @@
 from ...utils import (
     USE_PEFT_BACKEND,
     deprecate,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -41,6 +42,12 @@
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -1323,6 +1330,8 @@ def __call__(
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:

From 2604f6e5a95d98538a7eaa96a0a51b3faa75084a Mon Sep 17 00:00:00 2001
From: hlky
Date: Sat, 14 Dec 2024 08:36:12 +0000
Subject: [PATCH 2/2] make style

---
 src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 1d48cc28f76c..582f51ab480e 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -42,6 +42,7 @@
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
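
Reviewer note, not part of the patch series: torch_xla traces tensor ops
lazily and only materializes them at a barrier, so calling xm.mark_step()
once per denoising iteration flushes that step's graph to the XLA device
instead of letting the trace grow across the whole sampling loop. Below is
a minimal sketch of the guarded-import pattern these commits add, isolated
from the pipeline. The run_denoising_loop function and its body are
hypothetical stand-ins for the pipeline's real loop; is_torch_xla_available
and xm.mark_step are the actual diffusers/torch_xla calls from the diff.

import torch

from diffusers.utils import is_torch_xla_available

# Same guard as the patch: import torch_xla only when an XLA backend is
# installed, so the code keeps working on plain CPU/GPU environments.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def run_denoising_loop(num_steps: int = 4) -> None:
    # Hypothetical stand-in for the denoising loop in
    # StableDiffusionControlNetPipeline.__call__.
    latents = torch.zeros(1, 4, 64, 64)
    for _ in range(num_steps):
        latents = latents * 0.99  # placeholder for one UNet + scheduler step

        # On an XLA device this cuts the lazily traced graph and dispatches
        # the accumulated work; elsewhere the flag is False and nothing runs.
        if XLA_AVAILABLE:
            xm.mark_step()


if __name__ == "__main__":
    run_denoising_loop()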