diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 486f9fb764d1..582f51ab480e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -31,6 +31,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -42,6 +43,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -1323,6 +1331,8 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: