From 8a878cccf8fb44ae91a37ac3f992c6c261c75a54 Mon Sep 17 00:00:00 2001
From: fancy45daddy <124528204+fancy45daddy@users.noreply.github.com>
Date: Wed, 4 Dec 2024 00:00:34 -0800
Subject: [PATCH] Update pipeline_stable_audio.py

---
 .../pipelines/stable_audio/pipeline_stable_audio.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
index 4fe082d88957..a30af53f77a7 100644
--- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
@@ -26,6 +26,7 @@
 from ...models.embeddings import get_1d_rotary_pos_embed
 from ...schedulers import EDMDPMSolverMultistepScheduler
 from ...utils import (
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -33,6 +34,12 @@
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_stable_audio import StableAudioProjectionModel
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -725,6 +732,9 @@ def __call__(
                     if callback is not None and i % callback_steps == 0:
                         step_idx = i // getattr(self.scheduler, "order", 1)
                         callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
 
         # 9. Post-processing
         if not output_type == "latent":
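
Usage note (not part of the patch): a minimal sketch of how the pipeline could be run on an XLA device so that the new xm.mark_step() call takes effect. The checkpoint id, prompt, and generation parameters below are illustrative assumptions, not something the patch prescribes.

# Minimal sketch (assumes torch_xla is installed, an XLA/TPU device is available,
# and the public "stabilityai/stable-audio-open-1.0" checkpoint is used; prompt and
# step counts are illustrative only).
import torch_xla.core.xla_model as xm

from diffusers import StableAudioPipeline

device = xm.xla_device()
pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0")
pipe = pipe.to(device)

# With this patch applied, the pipeline calls xm.mark_step() after every denoising
# step, so each step's lazy-tensor graph is materialized and executed on the XLA
# device instead of the entire loop being traced as a single graph.
audio = pipe(
    prompt="A hammer hitting a wooden surface",
    num_inference_steps=100,
    audio_end_in_s=10.0,
).audios[0]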