From 7d94a9ea8e0bb5f2151d391682878874de2534e3 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 27 Feb 2024 12:42:40 +0530 Subject: [PATCH 1/4] denormalize latents with the mean and std if available --- .../pipeline_stable_diffusion_xl.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e90fe6571f63..23d170f4fe7e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1313,7 +1313,16 @@ def __call__( self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + # unscale/denormalize the latents + # denormalize with the mean and std if available + if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + latents = ( + latents * self.vae.config.latents_std / self.vae.config.scale_factor + self.vae.config.latents_mean + ) + else: + latents = latents * self.vae.config.scale_factor + + image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: From b9398c44e3db9d867c5a99b8e3a8cf91d70890b7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 27 Feb 2024 13:26:42 +0530 Subject: [PATCH 2/4] fix denormalize --- .../pipeline_stable_diffusion_xl.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 23d170f4fe7e..5c09fc4eed60 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1314,13 +1314,19 @@ def __call__( latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) # unscale/denormalize the latents - # denormalize with the mean and std if available - if hasattr(self.vae.config, "latents_mean") and hasattr(self.vae.config, "latents_std"): + # denormalize with the mean and std if available and not None + if ( + hasattr(self.vae.config, "latents_mean") + and self.vae.config.latents_mean is not None + and hasattr(self.vae.config, "latents_std") + and self.vae.config.latents_std is not None + ): latents = ( - latents * self.vae.config.latents_std / self.vae.config.scale_factor + self.vae.config.latents_mean + latents * self.vae.config.latents_std / self.vae.config.scaling_factor + + self.vae.config.latents_mean ) else: - latents = latents * self.vae.config.scale_factor + latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] From bf7a7ec0bbe71ddc1e04ca0562b6daf9a6e60d95 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 27 Feb 2024 15:14:38 +0530 Subject: [PATCH 3/4] add latent mean and std in vae config --- src/diffusers/models/autoencoders/autoencoder_kl.py | 2 ++ .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 9 ++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 67c8c4f1df68..9bbf2023eb99 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -80,6 +80,8 @@ def __init__( norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, force_upcast: float = True, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 5c09fc4eed60..aee6500c9eea 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1321,10 +1321,13 @@ def __call__( and hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None ): - latents = ( - latents * self.vae.config.latents_std / self.vae.config.scaling_factor - + self.vae.config.latents_mean + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean else: latents = latents / self.vae.config.scaling_factor From 6f5ce08f5e8f75f92418a31e1d736239d94b3d5f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 27 Feb 2024 15:29:15 +0530 Subject: [PATCH 4/4] address sayak's comment --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index aee6500c9eea..14376cc2d9ca 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1315,12 +1315,9 @@ def __call__( # unscale/denormalize the latents # denormalize with the mean and std if available and not None - if ( - hasattr(self.vae.config, "latents_mean") - and self.vae.config.latents_mean is not None - and hasattr(self.vae.config, "latents_std") - and self.vae.config.latents_std is not None - ): + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: latents_mean = ( torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) )