Merged
2 changes: 2 additions & 0 deletions src/diffusers/schedulers/__init__.py
@@ -65,6 +65,7 @@
_import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
_import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
_import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]

try:
if not is_flax_available():
@@ -155,6 +156,7 @@
from .scheduling_unipc_multistep import UniPCMultistepScheduler
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
from .scheduling_sasolver import SASolverScheduler

try:
if not is_flax_available():
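For context, these two entries (the lazy `_import_structure` registration and the eager import) are what expose the new class from the subpackage. A minimal sketch of the resulting import path, assuming the package is installed from this branch:

```python
# Both routes resolve to the same class: the lazy module defers to the
# scheduling_sasolver submodule on first attribute access.
from diffusers.schedulers import SASolverScheduler
from diffusers.schedulers.scheduling_sasolver import SASolverScheduler as DirectSASolverScheduler

print(SASolverScheduler is DirectSASolverScheduler)  # expected: True
```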
45 changes: 27 additions & 18 deletions src/diffusers/schedulers/scheduling_sasolver.py
@@ -16,21 +16,19 @@
# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

import math
from typing import List, Optional, Tuple, Union, Callable

import numpy as np
import torch

from typing import List, Optional, Tuple, Union, Callable
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -275,33 +273,32 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
batch_size, channels, *remaining_dims = sample.shape

if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half

# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

abs_sample = sample.abs() # "a certain percentile absolute pixel value"

s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
s = torch.clamp(
s, min=1, max=self.config.sample_max_value
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]

s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"

sample = sample.reshape(batch_size, channels, height, width)
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)

return sample
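
The switch from unpacking `height, width` to collecting `*remaining_dims` is what lets dynamic thresholding handle latents that are not strictly 4-D (e.g. video tensors with an extra frame dimension). A standalone sketch of the same reshape logic, written as an illustration rather than a copy of the scheduler method:

```python
import numpy as np
import torch


def dynamic_threshold(sample: torch.Tensor, ratio: float = 0.995, max_value: float = 1.0) -> torch.Tensor:
    """Illustrative re-implementation of the generalized thresholding above."""
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # quantile/clamp need full precision on CPU half tensors

    # Flatten everything past the channel dim so the quantile is taken per sample.
    flat = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))
    s = torch.quantile(flat.abs(), ratio, dim=1).clamp(min=1, max=max_value).unsqueeze(1)
    flat = torch.clamp(flat, -s, s) / s  # threshold to [-s, s], then rescale

    return flat.reshape(batch_size, channels, *remaining_dims).to(dtype)


dynamic_threshold(torch.randn(2, 4, 8, 8))     # image latents (B, C, H, W)
dynamic_threshold(torch.randn(2, 4, 3, 8, 8))  # e.g. video latents (B, C, F, H, W)
```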

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
log_sigma = np.log(np.maximum(sigma, 1e-10))

# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
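
The `np.maximum(sigma, 1e-10)` guard matters because the schedule can terminate at `sigma = 0`, where a plain `np.log` returns `-inf` and poisons the interpolation with NaNs. A small sketch of the difference, assuming `log_sigmas` holds the per-timestep log-sigmas as in the scheduler:

```python
import numpy as np

log_sigmas = np.log(np.linspace(0.03, 14.6, 1000))  # assumed per-timestep log-sigmas

sigma = 0.0  # final sigma of a schedule that ends at zero noise
log_sigma_unsafe = np.log(sigma)              # -inf (numpy warns), NaNs downstream
log_sigma = np.log(np.maximum(sigma, 1e-10))  # finite, so the lookup degrades gracefully

# Same distance computation as in _sigma_to_t.
dists = log_sigma - log_sigmas[:, np.newaxis]
```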
@@ -326,8 +323,20 @@ def _sigma_to_t(self, sigma, log_sigmas):
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""

sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None

if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None

sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
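
The hunk is truncated before the actual schedule construction, but the code that follows `ramp` is the standard Karras et al. (2022) rho-interpolation between `sigma_max` and `sigma_min`. A sketch of that tail, assuming the same variable names as above:

```python
import numpy as np


def karras_sigmas(sigma_min: float, sigma_max: float, num_inference_steps: int, rho: float = 7.0) -> np.ndarray:
    """Interpolate from sigma_max down to sigma_min in rho-space (Karras et al., 2022)."""
    ramp = np.linspace(0, 1, num_inference_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    # ramp = 0 -> sigma_max, ramp = 1 -> sigma_min, i.e. a decreasing noise schedule.
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho


print(karras_sigmas(sigma_min=0.03, sigma_max=14.6, num_inference_steps=5))
```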
@@ -832,10 +841,10 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -855,4 +864,4 @@ def add_noise(
return noisy_samples
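
The body of `add_noise` is collapsed here, but since it is copied from `DDPMScheduler.add_noise` it performs the usual forward-diffusion mixing, `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`. A compact sketch of that computation (an illustration, not the scheduler's exact code):

```python
import torch


def add_noise_sketch(
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
    alphas_cumprod: torch.Tensor,
) -> torch.Tensor:
    # Make sure alphas_cumprod and timesteps live where the samples do.
    alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
    # Broadcast the per-sample scalars over the remaining sample dimensions.
    while sqrt_alpha_prod.ndim < original_samples.ndim:
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    return sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
```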

def __len__(self):
return self.config.num_train_timesteps
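
End to end, the point of these two files is that the new scheduler can be dropped into an existing pipeline. A hedged usage sketch (the checkpoint id is a placeholder and the `from_config` swap is the generic diffusers pattern, not something this diff itself shows):

```python
from diffusers import DiffusionPipeline
from diffusers.schedulers import SASolverScheduler

# Placeholder checkpoint id; substitute any diffusers-format pipeline.
pipe = DiffusionPipeline.from_pretrained("some-org/some-diffusion-checkpoint")
pipe.scheduler = SASolverScheduler.from_config(pipe.scheduler.config)

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
```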
return self.config.num_train_timesteps