From ce7335a266aaed1ac7fbe82d593662a242f3a5f2 Mon Sep 17 00:00:00 2001
From: scxue
Date: Mon, 18 Dec 2023 19:38:29 +0800
Subject: [PATCH 1/2] fix some bugs

---
 src/diffusers/schedulers/__init__.py            |  2 ++
 src/diffusers/schedulers/scheduling_sasolver.py | 12 +++++-------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 40c435dd5637..aae4e4afb9ab 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -65,6 +65,7 @@
     _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
     _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
     _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
+    _import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
 
 try:
     if not is_flax_available():
@@ -155,6 +156,7 @@
         from .scheduling_unipc_multistep import UniPCMultistepScheduler
         from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
         from .scheduling_vq_diffusion import VQDiffusionScheduler
+        from .scheduling_sasolver import SASolverScheduler
 
     try:
         if not is_flax_available():
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index 589f2324c25b..c0fe029558ce 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -16,11 +16,9 @@
 # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
 
 import math
-from typing import List, Optional, Tuple, Union, Callable
-
 import numpy as np
 import torch
-
+from typing import List, Optional, Tuple, Union, Callable
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
@@ -28,9 +26,9 @@
 
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
-        num_diffusion_timesteps,
-        max_beta=0.999,
-        alpha_transform_type="cosine",
+    num_diffusion_timesteps,
+    max_beta=0.999,
+    alpha_transform_type="cosine",
 ):
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -855,4 +853,4 @@ def add_noise(
         return noisy_samples
 
     def __len__(self):
-        return self.config.num_train_timesteps
\ No newline at end of file
+        return self.config.num_train_timesteps

From 9bd3a93204e87bd9201426e1a3b69d26439c6d1f Mon Sep 17 00:00:00 2001
From: scxue
Date: Mon, 18 Dec 2023 20:15:06 +0800
Subject: [PATCH 2/2] fix bugs in repository consistency

---
 .../schedulers/scheduling_sasolver.py | 33 ++++++++++++++++++++++-----------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index c0fe029558ce..719b515706ee 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -273,13 +273,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         https://arxiv.org/abs/2205.11487
         """
         dtype = sample.dtype
-        batch_size, channels, height, width = sample.shape
+        batch_size, channels, *remaining_dims = sample.shape
 
         if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half
 
         # Flatten sample for doing quantile calculation along each image
-        sample = sample.reshape(batch_size, channels * height * width)
+        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
 
         abs_sample = sample.abs()  # "a certain percentile absolute pixel value"
 
@@ -287,11 +287,10 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         s = torch.clamp(
             s, min=1, max=self.config.sample_max_value
         )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
-
         s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
         sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"
 
-        sample = sample.reshape(batch_size, channels, height, width)
+        sample = sample.reshape(batch_size, channels, *remaining_dims)
         sample = sample.to(dtype)
 
         return sample
@@ -299,7 +298,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
     def _sigma_to_t(self, sigma, log_sigmas):
         # get log sigma
-        log_sigma = np.log(sigma)
+        log_sigma = np.log(np.maximum(sigma, 1e-10))
 
         # get distribution
         dists = log_sigma - log_sigmas[:, np.newaxis]
@@ -324,8 +323,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
     def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
         """Constructs the noise schedule of Karras et al. (2022)."""
 
-        sigma_min: float = in_sigmas[-1].item()
-        sigma_max: float = in_sigmas[0].item()
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
 
         rho = 7.0  # 7.0 is the value used in the paper
         ramp = np.linspace(0, 1, num_inference_steps)
@@ -830,10 +841,10 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
     def add_noise(
-            self,
-            original_samples: torch.FloatTensor,
-            noise: torch.FloatTensor,
-            timesteps: torch.IntTensor,
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
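Usage note (not part of the patch): with the first commit exporting SASolverScheduler from diffusers.schedulers, the scheduler can be swapped into an existing pipeline through the usual from_config path. The snippet below is a minimal sketch under assumptions not taken from the patch: the "runwayml/stable-diffusion-v1-5" checkpoint, fp16 weights, a CUDA device, and 25 inference steps are all illustrative choices.

# Minimal sketch: plug the SA-Solver scheduler into a diffusers pipeline.
# Checkpoint, dtype, device, and step count are illustrative, not prescribed by the patch.
import torch
from diffusers import DiffusionPipeline
from diffusers.schedulers import SASolverScheduler

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Rebuild the scheduler from the pipeline's existing scheduler config.
pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
image.save("sa_solver_sample.png")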