diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index 67c8c4f1df68..9bbf2023eb99 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -80,6 +80,8 @@ def __init__( norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, + latents_mean: Optional[Tuple[float]] = None, + latents_std: Optional[Tuple[float]] = None, force_upcast: float = True, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e90fe6571f63..14376cc2d9ca 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1313,7 +1313,22 @@ def __call__( self.upcast_vae() latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] # cast back to fp16 if needed if needs_upcasting: