From 36415e89450c8f3d08cb48970d046b30944f26ff Mon Sep 17 00:00:00 2001 From: Konstantine Tsafatinos Date: Wed, 7 Aug 2024 15:23:31 -0400 Subject: [PATCH] add vae slicing and tiling to flux pipeline --- src/diffusers/pipelines/flux/pipeline_flux.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index c1a7010d919a..5eed20d46925 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -454,6 +454,35 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + def prepare_latents( self, batch_size,