diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py index 4fe082d88957..a30af53f77a7 100644 --- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py +++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -26,6 +26,7 @@ from ...models.embeddings import get_1d_rotary_pos_embed from ...schedulers import EDMDPMSolverMultistepScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -33,6 +34,12 @@ from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline from .modeling_stable_audio import StableAudioProjectionModel +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -725,6 +732,9 @@ def __call__( if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() # 9. Post-processing if not output_type == "latent":