From d3c25b26708919af38ef58513c66da7d276346fa Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 7 Dec 2024 12:54:06 +0000 Subject: [PATCH 01/20] Separate Sigma Schedule --- src/diffusers/schedulers/__init__.py | 2 + .../schedulers/scheduling_heun_discrete.py | 125 ++---------------- src/diffusers/schedulers/sigmas/__init__.py | 60 +++++++++ .../schedulers/sigmas/beta_sigmas.py | 64 +++++++++ .../schedulers/sigmas/exponential_sigmas.py | 42 ++++++ .../schedulers/sigmas/karras_sigmas.py | 47 +++++++ 6 files changed, 225 insertions(+), 115 deletions(-) create mode 100644 src/diffusers/schedulers/sigmas/__init__.py create mode 100644 src/diffusers/schedulers/sigmas/beta_sigmas.py create mode 100644 src/diffusers/schedulers/sigmas/exponential_sigmas.py create mode 100644 src/diffusers/schedulers/sigmas/karras_sigmas.py diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index bb9088538653..a07d7fff62c8 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -74,6 +74,7 @@ _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"] _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] + _import_structure["sigmas"] = ["BetaSigmas", "ExponentialSigmas", "KarrasSigmas"] try: if not is_flax_available(): @@ -174,6 +175,7 @@ from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler + from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas try: if not is_flax_available(): diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index f2aaa738233b..da9c6ce67374 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -22,10 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, is_scipy_available from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin - - -if is_scipy_available(): - import scipy.stats +from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas @dataclass @@ -119,14 +116,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): Clip the predicted sample for numerical stability. clip_sample_range (`float`, defaults to 1.0): The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, - the sigmas are determined according to a sequence of noise levels {σi}. - use_exponential_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. - use_beta_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta - Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. timestep_spacing (`str`, defaults to `"linspace"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. @@ -134,6 +123,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): An offset added to the inference steps, as required by some model families. """ + ignore_for_config = ["sigma_schedule"] _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @@ -146,20 +136,14 @@ def __init__( beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, prediction_type: str = "epsilon", - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, + sigma_schedule: Optional[Union[BetaSigmas, ExponentialSigmas, KarrasSigmas]] = None, clip_sample: Optional[bool] = False, clip_sample_range: float = 1.0, timestep_spacing: str = "linspace", steps_offset: int = 0, ): - if self.config.use_beta_sigmas and not is_scipy_available(): + if isinstance(sigma_schedule, BetaSigmas) and not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError( - "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." - ) if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) elif beta_schedule == "linear": @@ -178,9 +162,10 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + self.sigma_schedule = sigma_schedule + # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) - self.use_karras_sigmas = use_karras_sigmas self._step_index = None self._begin_index = None @@ -287,12 +272,8 @@ def set_timesteps( raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.") if num_inference_steps is not None and timesteps is not None: raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") - if timesteps is not None and self.config.use_karras_sigmas: - raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") - if timesteps is not None and self.config.use_exponential_sigmas: - raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") - if timesteps is not None and self.config.use_beta_sigmas: - raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.") + if timesteps is not None and self.sigma_schedule is not None: + raise ValueError("Cannot use `timesteps` with `sigma_schedule`") num_inference_steps = num_inference_steps or len(timesteps) self.num_inference_steps = num_inference_steps @@ -325,14 +306,8 @@ def set_timesteps( log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - if self.config.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - elif self.config.use_exponential_sigmas: - sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - elif self.config.use_beta_sigmas: - sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + if self.sigma_schedule is not None: + sigmas = self.sigma_schedule(sigmas) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) @@ -376,86 +351,6 @@ def _sigma_to_t(self, sigma, log_sigmas): t = t.reshape(sigma.shape) return t - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential - def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) - return sigmas - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta - def _convert_to_beta( - self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 - ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.array( - [ - sigma_min + (ppf * (sigma_max - sigma_min)) - for ppf in [ - scipy.stats.beta.ppf(timestep, alpha, beta) - for timestep in 1 - np.linspace(0, 1, num_inference_steps) - ] - ] - ) - return sigmas - @property def state_in_first_order(self): return self.dt is None diff --git a/src/diffusers/schedulers/sigmas/__init__.py b/src/diffusers/schedulers/sigmas/__init__.py new file mode 100644 index 000000000000..34a7ff5edff2 --- /dev/null +++ b/src/diffusers/schedulers/sigmas/__init__.py @@ -0,0 +1,60 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["beta_sigmas"] = ["BetaSigmas"] + _import_structure["exponential_sigmas"] = ["ExponentialSigmas"] + _import_structure["karras_sigmas"] = ["KarrasSigmas"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_pt_objects import * # noqa F403 + else: + from .beta_sigmas import BetaSigmas + from .exponential_sigmas import ExponentialSigmas + from .karras_sigmas import KarrasSigmas + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/schedulers/sigmas/beta_sigmas.py b/src/diffusers/schedulers/sigmas/beta_sigmas.py new file mode 100644 index 000000000000..a390bd9c90ea --- /dev/null +++ b/src/diffusers/schedulers/sigmas/beta_sigmas.py @@ -0,0 +1,64 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import numpy as np +import torch + +from ...utils import is_scipy_available + + +if is_scipy_available(): + import scipy.stats + + +class BetaSigmas: + def __init__( + self, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + alpha: float = 0.6, + beta: float = 0.6, + ): + if not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.alpha = alpha + self.beta = beta + + def __call__(self, in_sigmas: torch.Tensor): + sigma_min = self.sigma_min + if sigma_min is None: + sigma_min = in_sigmas[-1].item() + sigma_max = self.sigma_max + if sigma_max is None: + sigma_max = in_sigmas[0].item() + + num_inference_steps = len(in_sigmas) + + alpha = self.alpha + beta = self.beta + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas diff --git a/src/diffusers/schedulers/sigmas/exponential_sigmas.py b/src/diffusers/schedulers/sigmas/exponential_sigmas.py new file mode 100644 index 000000000000..70f0a794fe0b --- /dev/null +++ b/src/diffusers/schedulers/sigmas/exponential_sigmas.py @@ -0,0 +1,42 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional + +import numpy as np +import torch + + +class ExponentialSigmas: + def __init__( + self, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def __call__(self, in_sigmas: torch.Tensor): + sigma_min = self.sigma_min + if sigma_min is None: + sigma_min = in_sigmas[-1].item() + sigma_max = self.sigma_max + if sigma_max is None: + sigma_max = in_sigmas[0].item() + + num_inference_steps = len(in_sigmas) + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas diff --git a/src/diffusers/schedulers/sigmas/karras_sigmas.py b/src/diffusers/schedulers/sigmas/karras_sigmas.py new file mode 100644 index 000000000000..a60d5e017ef4 --- /dev/null +++ b/src/diffusers/schedulers/sigmas/karras_sigmas.py @@ -0,0 +1,47 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import numpy as np +import torch + + +class KarrasSigmas: + def __init__( + self, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + rho: float = 7.0, + ): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def __call__(self, in_sigmas: torch.Tensor): + sigma_min = self.sigma_min + if sigma_min is None: + sigma_min = in_sigmas[-1].item() + sigma_max = self.sigma_max + if sigma_max is None: + sigma_max = in_sigmas[0].item() + + num_inference_steps = len(in_sigmas) + + rho = self.rho + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas From 193f90df21bc8d8f7cb5768d8d7f7836d2fee9d0 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 7 Dec 2024 13:47:39 +0000 Subject: [PATCH 02/20] test_backend_registration --- src/diffusers/schedulers/__init__.py | 2 -- src/diffusers/schedulers/sigmas/__init__.py | 23 +++++++++++++++++-- .../schedulers/sigmas/beta_sigmas.py | 2 ++ .../utils/dummy_torch_and_scipy_objects.py | 6 +++++ 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index a07d7fff62c8..bb9088538653 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -74,7 +74,6 @@ _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"] _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"] - _import_structure["sigmas"] = ["BetaSigmas", "ExponentialSigmas", "KarrasSigmas"] try: if not is_flax_available(): @@ -175,7 +174,6 @@ from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler - from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas try: if not is_flax_available(): diff --git a/src/diffusers/schedulers/sigmas/__init__.py b/src/diffusers/schedulers/sigmas/__init__.py index 34a7ff5edff2..a66b34890c3f 100644 --- a/src/diffusers/schedulers/sigmas/__init__.py +++ b/src/diffusers/schedulers/sigmas/__init__.py @@ -19,6 +19,7 @@ OptionalDependencyNotAvailable, _LazyModule, get_objects_from_module, + is_scipy_available, is_torch_available, is_transformers_available, ) @@ -35,10 +36,21 @@ _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) else: - _import_structure["beta_sigmas"] = ["BetaSigmas"] _import_structure["exponential_sigmas"] = ["ExponentialSigmas"] _import_structure["karras_sigmas"] = ["KarrasSigmas"] +try: + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_scipy_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_scipy_objects)) + +else: + _import_structure["beta_sigmas"] = ["BetaSigmas"] + + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not is_torch_available(): @@ -47,10 +59,17 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_pt_objects import * # noqa F403 else: - from .beta_sigmas import BetaSigmas from .exponential_sigmas import ExponentialSigmas from .karras_sigmas import KarrasSigmas + try: + if not (is_torch_available() and is_scipy_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_scipy_objects import * # noqa F403 + else: + from .beta_sigmas import BetaSigmas + else: import sys diff --git a/src/diffusers/schedulers/sigmas/beta_sigmas.py b/src/diffusers/schedulers/sigmas/beta_sigmas.py index a390bd9c90ea..67269a5496ab 100644 --- a/src/diffusers/schedulers/sigmas/beta_sigmas.py +++ b/src/diffusers/schedulers/sigmas/beta_sigmas.py @@ -40,6 +40,8 @@ def __init__( self.beta = beta def __call__(self, in_sigmas: torch.Tensor): + if not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") sigma_min = self.sigma_min if sigma_min is None: sigma_min = in_sigmas[-1].item() diff --git a/src/diffusers/utils/dummy_torch_and_scipy_objects.py b/src/diffusers/utils/dummy_torch_and_scipy_objects.py index a1ff25863822..94aa351ddeec 100644 --- a/src/diffusers/utils/dummy_torch_and_scipy_objects.py +++ b/src/diffusers/utils/dummy_torch_and_scipy_objects.py @@ -15,3 +15,9 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "scipy"]) + +class BetaSigmas(metaclass=DummyObject): + _backends = ["torch", "scipy"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "scipy"]) From 8703cdc962aebbdf54daa68ec06f2bfa0b600c78 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 18 Dec 2024 13:31:35 +0000 Subject: [PATCH 03/20] make --- src/diffusers/utils/dummy_torch_and_scipy_objects.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/diffusers/utils/dummy_torch_and_scipy_objects.py b/src/diffusers/utils/dummy_torch_and_scipy_objects.py index 94aa351ddeec..a1ff25863822 100644 --- a/src/diffusers/utils/dummy_torch_and_scipy_objects.py +++ b/src/diffusers/utils/dummy_torch_and_scipy_objects.py @@ -15,9 +15,3 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "scipy"]) - -class BetaSigmas(metaclass=DummyObject): - _backends = ["torch", "scipy"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "scipy"]) From ada44e7a3e31fc751acec1185f97f515645fd013 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 18 Dec 2024 13:40:19 +0000 Subject: [PATCH 04/20] check_torch_dependencies --- src/diffusers/schedulers/scheduling_heun_discrete.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index da9c6ce67374..4ad4b85464b6 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -22,7 +22,11 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, is_scipy_available from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin -from .sigmas import BetaSigmas, ExponentialSigmas, KarrasSigmas +from .sigmas import ExponentialSigmas, KarrasSigmas + + +if is_scipy_available(): + from .sigmas import BetaSigmas @dataclass From f12841cd546ae219f1f5cd99e943d29662bfe411 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 18 Dec 2024 19:02:11 +0000 Subject: [PATCH 05/20] --- src/diffusers/configuration_utils.py | 28 ++ src/diffusers/pipelines/flux/pipeline_flux.py | 4 +- .../pipelines/mochi/pipeline_mochi.py | 8 +- .../schedulers/schedules/__init__.py | 0 .../schedulers/schedules/beta_schedule.py | 252 ++++++++++++ .../schedulers/schedules/flow_schedule.py | 183 +++++++++ .../scheduling_euler_ancestral_discrete.py | 211 +++------- .../schedulers/scheduling_euler_discrete.py | 372 ++---------------- .../schedulers/scheduling_heun_discrete.py | 209 +++------- src/diffusers/schedulers/scheduling_utils.py | 57 ++- src/diffusers/schedulers/sigmas/__init__.py | 79 ---- .../schedulers/sigmas/beta_sigmas.py | 3 +- .../schedulers/sigmas/exponential_sigmas.py | 3 +- .../schedulers/sigmas/karras_sigmas.py | 3 +- 14 files changed, 680 insertions(+), 732 deletions(-) create mode 100644 src/diffusers/schedulers/schedules/__init__.py create mode 100644 src/diffusers/schedulers/schedules/beta_schedule.py create mode 100644 src/diffusers/schedulers/schedules/flow_schedule.py delete mode 100644 src/diffusers/schedulers/sigmas/__init__.py diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index d21ada6fbe60..8184d6247661 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -245,6 +245,34 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) + # Handle old scheduler configs + if "Scheduler" in cls.__name__ and "schedule_config" not in config: + prediction_type = config.pop("prediction_type", None) + _class_name = config.pop("_class_name", None) + _diffusers_version = config.pop("_diffusers_version", None) + use_karras_sigmas = config.pop("use_karras_sigmas", None) + use_exponential_sigmas = config.pop("use_exponential_sigmas", None) + use_beta_sigmas = config.pop("use_beta_sigmas", None) + if use_karras_sigmas: + sigma_schedule_config = {"class_name": "KarrasSigmas"} + elif use_exponential_sigmas: + sigma_schedule_config = {"class_name": "ExponentialSigmas"} + elif use_beta_sigmas: + sigma_schedule_config = {"class_name": "BetaSigmas"} + else: + sigma_schedule_config = {} + if "beta_schedule" in config: + config.update({"class_name": "BetaSchedule"}) + elif "shift" in config: + config.update({"class_name": "FlowMatchSchedule"}) + config = { + "_class_name": _class_name, + "_diffusers_version": _diffusers_version, + "prediction_type": prediction_type, + "schedule_config": config, + "sigma_schedule_config": sigma_schedule_config, + } + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) # Allow dtype to be specified on initialization diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index ec2801625552..c8b8306eea9c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast @@ -699,7 +698,8 @@ def __call__( ) # 5. Prepare timesteps - sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if self.scheduler.schedule.__class__.__name__ != "FlowMatchFlux": + self.scheduler._schedule.set_base_schedule("FlowMatchFlux") image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index aac4e32e33f0..72a43827c9bf 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import T5EncoderModel, T5TokenizerFast @@ -495,6 +494,7 @@ def __call__( num_frames: int = 19, num_inference_steps: int = 64, timesteps: List[int] = None, + sigmas: List[float] = None, guidance_scale: float = 4.5, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -652,10 +652,8 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 5. Prepare timestep - # from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 - threshold_noise = 0.025 - sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise) - sigmas = np.array(sigmas) + if self.scheduler.schedule.__class__.__name__ != "FlowMatchLinearQuadratic": + self.scheduler._schedule.set_base_schedule("FlowMatchLinearQuadratic") timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/schedulers/schedules/__init__.py b/src/diffusers/schedulers/schedules/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/diffusers/schedulers/schedules/beta_schedule.py b/src/diffusers/schedulers/schedules/beta_schedule.py new file mode 100644 index 000000000000..d5b93c3b3923 --- /dev/null +++ b/src/diffusers/schedulers/schedules/beta_schedule.py @@ -0,0 +1,252 @@ +from typing import List, Optional, Union + +import math +import numpy as np +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ..sigmas.beta_sigmas import BetaSigmas +from ..sigmas.exponential_sigmas import ExponentialSigmas +from ..sigmas.karras_sigmas import KarrasSigmas + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class BetaSchedule: + + scale_model_input = True + + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + rescale_betas_zero_snr: bool = False, + interpolation_type: str = "linear", + timestep_spacing: str = "linspace", + timestep_type: str = "discrete", # can be "discrete" or "continuous" + steps_offset: int = 0, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" + **kwargs, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + self.num_train_timesteps = num_train_timesteps + self.beta_start = beta_start + self.beta_end = beta_end + self.beta_schedule = beta_schedule + self.trained_betas = trained_betas + self.rescale_betas_zero_snr = rescale_betas_zero_snr + self.interpolation_type = interpolation_type + self.timestep_spacing = timestep_spacing + self.timestep_type = timestep_type + self.steps_offset = steps_offset + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.final_sigmas_type = final_sigmas_type + + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def __call__( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + sigma_schedule: Optional[Union[KarrasSigmas, ExponentialSigmas, BetaSigmas]] = None, + **kwargs, + ): + if sigmas is not None: + log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)) + sigmas = np.array(sigmas).astype(np.float32) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]]) + + else: + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.float32) + else: + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + if self.timestep_spacing == "linspace": + timesteps = np.linspace( + 0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32 + )[::-1].copy() + elif self.timestep_spacing == "leading": + step_ratio = self.num_train_timesteps // num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) + ) + timesteps += self.steps_offset + elif self.timestep_spacing == "trailing": + step_ratio = self.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + if self.interpolation_type == "linear": + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + elif self.interpolation_type == "log_linear": + sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() + else: + raise ValueError( + f"{self.interpolation_type} is not implemented. Please specify interpolation_type to either" + " 'linear' or 'log_linear'" + ) + + if sigma_schedule is not None: + sigmas = sigma_schedule(sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + if self.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + + # TODO: Support the full EDM scalings for all prediction types and timestep types + if self.timestep_type == "continuous" and self.prediction_type == "v_prediction": + timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device) + else: + timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) + + return sigmas, timesteps diff --git a/src/diffusers/schedulers/schedules/flow_schedule.py b/src/diffusers/schedulers/schedules/flow_schedule.py new file mode 100644 index 000000000000..a8332ee75f12 --- /dev/null +++ b/src/diffusers/schedulers/schedules/flow_schedule.py @@ -0,0 +1,183 @@ +from typing import List, Optional, Union + +import math +import numpy as np +import torch + +from ...configuration_utils import ConfigMixin, register_to_config +from ..sigmas.beta_sigmas import BetaSigmas +from ..sigmas.exponential_sigmas import ExponentialSigmas +from ..sigmas.karras_sigmas import KarrasSigmas + +class FlowMatchSD3: + + def _sigma_to_t(self, sigma): + return sigma * self.num_train_timesteps + + def __call__(self, num_inference_steps: int, num_train_timesteps: int, shift: float, use_dynamic_shifting: bool = False, **kwargs): + self.num_train_timesteps = num_train_timesteps + timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + sigma_min = sigmas[-1].item() + sigma_max = sigmas[0].item() + timesteps = np.linspace( + self._sigma_to_t(sigma_max), self._sigma_to_t(sigma_min), num_inference_steps + ) + sigmas = timesteps / num_train_timesteps + return sigmas + +class FlowMatchFlux: + def __call__(self, num_inference_steps: int, **kwargs): + return np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + +class FlowMatchLinearQuadratic: + def __call__(self, num_inference_steps: int, threshold_noise: float = 0.25, linear_steps: Optional[int] = None, **kwargs): + if linear_steps is None: + linear_steps = num_inference_steps // 2 + linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] + threshold_noise_step_diff = linear_steps - threshold_noise * num_inference_steps + quadratic_steps = num_inference_steps - linear_steps + quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2) + linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2) + const = quadratic_coef * (linear_steps**2) + quadratic_sigma_schedule = [ + quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_inference_steps) + ] + sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + sigma_schedule = [1.0 - x for x in sigma_schedule] + return sigma_schedule + +class FlowMatchHunyuanVideo: + def __call__(self, num_inference_steps: int, **kwargs): + return np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1].copy() + +BASE_SCHEDULE_MAP = { + "FlowMatchHunyuanVideo": FlowMatchHunyuanVideo, + "FlowMatchLinearQuadratic": FlowMatchLinearQuadratic, + "FlowMatchFlux": FlowMatchFlux, + "FlowMatchSD3": FlowMatchSD3, +} + +class FlowMatchSchedule: + + scale_model_input = False + + base_schedules = BASE_SCHEDULE_MAP + + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + base_schedule: Optional[Union[str]] = None, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + **kwargs, + ): + self.set_base_schedule(base_schedule) + self.num_train_timesteps = num_train_timesteps + self.shift = shift + self.use_dynamic_shifting = use_dynamic_shifting + self.base_shift = base_shift + self.max_shift = max_shift + self.base_image_seq_len = base_image_seq_len + self.max_image_seq_len = max_image_seq_len + self.invert_sigmas = invert_sigmas + self.shift_terminal = shift_terminal + self.sigma_min = sigma_min + self.sigma_max = sigma_max + + def set_base_schedule(self, base_schedule: Union[str]): + if base_schedule is None: + raise ValueError("Must set base schedule.") + if isinstance(base_schedule, str): + if base_schedule not in self.base_schedules: + raise ValueError(f"Expected one of {self.base_schedules.keys()}") + _class = self.base_schedules[base_schedule] + self.base_schedule = _class() + else: + self.base_schedule = base_schedule() + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def __call__( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + sigma_schedule: Optional[Union[KarrasSigmas, ExponentialSigmas, BetaSigmas]] = None, + mu: Optional[float] = None, + shift: Optional[float] = None, + ): + shift = shift or self.shift + if self.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = self.base_schedule( + num_inference_steps=num_inference_steps, + num_train_timesteps=self.num_train_timesteps, + shift=shift, + use_dynamic_shifting=self.use_dynamic_shifting, + ) + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) + self.num_inference_steps = num_inference_steps + + if self.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + if self.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if sigma_schedule is not None: + sigmas = sigma_schedule(sigmas) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + timesteps = sigmas * self.num_train_timesteps + + if self.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + return sigmas, timesteps diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 4df43a160ce1..d66af0b04deb 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union -import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config @@ -47,88 +45,6 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.Tensor] = None -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Ancestral sampling with Euler method steps. @@ -169,53 +85,19 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__( self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + schedule_config, + sigma_schedule_config, prediction_type: str = "epsilon", - timestep_spacing: str = "linspace", - steps_offset: int = 0, - rescale_betas_zero_snr: bool = False, ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") - - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.betas) - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - if rescale_betas_zero_snr: - # Close to 0 without being 0 so first sigma is not inf - # FP16 smallest positive subnormal works well here - self.alphas_cumprod[-1] = 2**-24 - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) + self.set_schedule(schedule_config) + self.set_sigma_schedule(sigma_schedule_config) # setable values self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.is_scale_input_called = False + self.is_scale_input_called = False self._step_index = None self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property def init_noise_sigma(self): @@ -274,7 +156,15 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T self.is_scale_input_called = True return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + shift: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -284,39 +174,58 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - self.num_inference_steps = num_inference_steps - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ - ::-1 - ].copy() - elif self.config.timestep_spacing == "leading": - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - timesteps -= 1 - else: + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` should be set.") + if num_inference_steps is None and timesteps is None and sigmas is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.") + if num_inference_steps is not None and (timesteps is not None or sigmas is not None): + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "KarrasSigmas" + ): + raise ValueError("Cannot set `timesteps` with `KarrasSigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "ExponentialSigmas" + ): + raise ValueError("Cannot set `timesteps` with `ExponentialSigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "BetaSigmas" + ): + raise ValueError("Cannot set `timesteps` with `BetaSigmas`.") + if ( + timesteps is not None + and self._schedule.config.get("timestep_type", None) == "continuous" + and self.config.prediction_type == "v_prediction" + ): raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." ) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) + if num_inference_steps is None: + num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1 + self.num_inference_steps = num_inference_steps + + sigmas, timesteps = self._schedule( + num_inference_steps=num_inference_steps, + device=device, + timesteps=timesteps, + sigmas=sigmas, + sigma_schedule=self._sigma_schedule, + mu=mu, + shift=shift, + ) - self.timesteps = torch.from_numpy(timesteps).to(device=device) + self.timesteps = timesteps.to(device=device) self._step_index = None self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): @@ -384,7 +293,7 @@ def step( ), ) - if not self.is_scale_input_called: + if self._schedule.scale_model_input and not self.is_scale_input_called: logger.warning( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 56757f3ca197..a2191346accc 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -12,22 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union -import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, is_scipy_available, logging +from ..utils import BaseOutput, logging from ..utils.torch_utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin -if is_scipy_available(): - import scipy.stats - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -50,88 +45,6 @@ class EulerDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.Tensor] = None -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr -def rescale_zero_terminal_snr(betas): - """ - Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) - - - Args: - betas (`torch.Tensor`): - the betas that the scheduler is being initialized with. - - Returns: - `torch.Tensor`: rescaled betas with zero terminal SNR - """ - # Convert betas to alphas_bar_sqrt - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, dim=0) - alphas_bar_sqrt = alphas_cumprod.sqrt() - - # Store old values. - alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() - alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() - - # Shift so the last timestep is zero. - alphas_bar_sqrt -= alphas_bar_sqrt_T - - # Scale so the first timestep is back to the old value. - alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) - - # Convert alphas_bar_sqrt to betas - alphas_bar = alphas_bar_sqrt**2 # Revert sqrt - alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod - alphas = torch.cat([alphas_bar[0:1], alphas]) - betas = 1 - alphas - - return betas - - class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Euler scheduler. @@ -186,77 +99,19 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__( self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + schedule_config, + sigma_schedule_config, prediction_type: str = "epsilon", - interpolation_type: str = "linear", - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, - timestep_spacing: str = "linspace", - timestep_type: str = "discrete", # can be "discrete" or "continuous" - steps_offset: int = 0, - rescale_betas_zero_snr: bool = False, - final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" ): - if self.config.use_beta_sigmas and not is_scipy_available(): - raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError( - "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." - ) - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") - - if rescale_betas_zero_snr: - self.betas = rescale_zero_terminal_snr(self.betas) - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - if rescale_betas_zero_snr: - # Close to 0 without being 0 so first sigma is not inf - # FP16 smallest positive subnormal works well here - self.alphas_cumprod[-1] = 2**-24 - - sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0) - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + self.set_schedule(schedule_config) + self.set_sigma_schedule(sigma_schedule_config) # setable values self.num_inference_steps = None - # TODO: Support the full EDM scalings for all prediction types and timestep types - if timestep_type == "continuous" and prediction_type == "v_prediction": - self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) - else: - self.timesteps = timesteps - - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) - self.is_scale_input_called = False - self.use_karras_sigmas = use_karras_sigmas - self.use_exponential_sigmas = use_exponential_sigmas - self.use_beta_sigmas = use_beta_sigmas - self._step_index = None self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property def init_noise_sigma(self): @@ -322,6 +177,8 @@ def set_timesteps( device: Union[str, torch.device] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + shift: Optional[float] = None, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -348,15 +205,27 @@ def set_timesteps( raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.") if num_inference_steps is not None and (timesteps is not None or sigmas is not None): raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") - if timesteps is not None and self.config.use_karras_sigmas: + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "KarrasSigmas" + ): raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") - if timesteps is not None and self.config.use_exponential_sigmas: - raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.") - if timesteps is not None and self.config.use_beta_sigmas: - raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.") if ( timesteps is not None - and self.config.timestep_type == "continuous" + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "ExponentialSigmas" + ): + raise ValueError("Cannot set `timesteps` with `ExponentialSigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "BetaSigmas" + ): + raise ValueError("Cannot set `timesteps` with `BetaSigmas`.") + if ( + timesteps is not None + and self._schedule.config.get("timestep_type", None) == "continuous" and self.config.prediction_type == "v_prediction" ): raise ValueError( @@ -367,190 +236,21 @@ def set_timesteps( num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1 self.num_inference_steps = num_inference_steps - if sigmas is not None: - log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)) - sigmas = np.array(sigmas).astype(np.float32) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]]) - - else: - if timesteps is not None: - timesteps = np.array(timesteps).astype(np.float32) - else: - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = np.linspace( - 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32 - )[::-1].copy() - elif self.config.timestep_spacing == "leading": - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) - ) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - ) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - if self.config.interpolation_type == "linear": - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - elif self.config.interpolation_type == "log_linear": - sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() - else: - raise ValueError( - f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either" - " 'linear' or 'log_linear'" - ) - - if self.config.use_karras_sigmas: - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - - elif self.config.use_exponential_sigmas: - sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - - elif self.config.use_beta_sigmas: - sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) - - # TODO: Support the full EDM scalings for all prediction types and timestep types - if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction": - self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device) - else: - self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) + sigmas, timesteps = self._schedule( + num_inference_steps=num_inference_steps, + device=device, + timesteps=timesteps, + sigmas=sigmas, + sigma_schedule=self._sigma_schedule, + mu=mu, + shift=shift, + ) self._step_index = None self._begin_index = None + self.timesteps = timesteps self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(np.maximum(sigma, 1e-10)) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - - # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - # Copied from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L26 - def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) - return sigmas - - def _convert_to_beta( - self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 - ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.array( - [ - sigma_min + (ppf * (sigma_max - sigma_min)) - for ppf in [ - scipy.stats.beta.ppf(timestep, alpha, beta) - for timestep in 1 - np.linspace(0, 1, num_inference_steps) - ] - ] - ) - return sigmas - def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -622,7 +322,7 @@ def step( ), ) - if not self.is_scale_input_called: + if self._schedule.scale_model_input and not self.is_scale_input_called: logger.warning( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 4ad4b85464b6..017c855a4b43 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -20,13 +19,8 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, is_scipy_available +from ..utils import BaseOutput from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin -from .sigmas import ExponentialSigmas, KarrasSigmas - - -if is_scipy_available(): - from .sigmas import BetaSigmas @dataclass @@ -48,51 +42,6 @@ class HeunDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.Tensor] = None -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): """ Scheduler with Heun steps for discrete beta schedules. @@ -134,46 +83,15 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__( self, - num_train_timesteps: int = 1000, - beta_start: float = 0.00085, # sensible defaults - beta_end: float = 0.012, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + schedule_config, + sigma_schedule_config, prediction_type: str = "epsilon", - sigma_schedule: Optional[Union[BetaSigmas, ExponentialSigmas, KarrasSigmas]] = None, - clip_sample: Optional[bool] = False, - clip_sample_range: float = 1.0, - timestep_spacing: str = "linspace", - steps_offset: int = 0, ): - if isinstance(sigma_schedule, BetaSigmas) and not is_scipy_available(): - raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine") - elif beta_schedule == "exp": - self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="exp") - else: - raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - self.sigma_schedule = sigma_schedule - - # set all values - self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + self.set_schedule(schedule_config) + self.set_sigma_schedule(sigma_schedule_config) self._step_index = None self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): @@ -253,8 +171,10 @@ def set_timesteps( self, num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None, - num_train_timesteps: Optional[int] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + shift: Optional[float] = None, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -272,53 +192,56 @@ def set_timesteps( generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` must be `None`, and `timestep_spacing` attribute will be ignored. """ - if num_inference_steps is None and timesteps is None: - raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.") - if num_inference_steps is not None and timesteps is not None: - raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") - if timesteps is not None and self.sigma_schedule is not None: - raise ValueError("Cannot use `timesteps` with `sigma_schedule`") - - num_inference_steps = num_inference_steps or len(timesteps) + + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` should be set.") + if num_inference_steps is None and timesteps is None and sigmas is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.") + if num_inference_steps is not None and (timesteps is not None or sigmas is not None): + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "KarrasSigmas" + ): + raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "ExponentialSigmas" + ): + raise ValueError("Cannot set `timesteps` with `ExponentialSigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "BetaSigmas" + ): + raise ValueError("Cannot set `timesteps` with `BetaSigmas`.") + if ( + timesteps is not None + and self._schedule.config.get("timestep_type", None) == "continuous" + and self.config.prediction_type == "v_prediction" + ): + raise ValueError( + "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." + ) + + if num_inference_steps is None: + num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1 self.num_inference_steps = num_inference_steps - num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps - if timesteps is not None: - timesteps = np.array(timesteps, dtype=np.float32) - else: - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() - elif self.config.timestep_spacing == "leading": - step_ratio = num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = num_train_timesteps / self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - timesteps -= 1 - else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - - if self.sigma_schedule is not None: - sigmas = self.sigma_schedule(sigmas) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - sigmas = torch.from_numpy(sigmas).to(device=device) + sigmas, timesteps = self._schedule( + num_inference_steps=num_inference_steps, + device=device, + timesteps=timesteps, + sigmas=sigmas, + sigma_schedule=self._sigma_schedule, + mu=mu, + shift=shift, + ) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - timesteps = torch.from_numpy(timesteps) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) self.timesteps = timesteps.to(device=device) @@ -331,30 +254,6 @@ def set_timesteps( self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(np.maximum(sigma, 1e-10)) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - @property def state_in_first_order(self): return self.dt is None diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index f20224b19009..94b2bcaa5573 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -15,12 +15,17 @@ import os from dataclasses import dataclass from enum import Enum -from typing import Optional, Union +from typing import Dict, Optional, Union import torch from huggingface_hub.utils import validate_hf_hub_args from ..utils import BaseOutput, PushToHubMixin +from .schedules.beta_schedule import BetaSchedule +from .schedules.flow_schedule import FlowMatchSchedule +from .sigmas.beta_sigmas import BetaSigmas +from .sigmas.exponential_sigmas import ExponentialSigmas +from .sigmas.karras_sigmas import KarrasSigmas SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -56,6 +61,17 @@ class KarrasDiffusionSchedulers(Enum): "StableDiffusionVideoSigmas": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.0], } +SCHEDULE_MAP = { + "BetaSchedule": BetaSchedule, + "FlowMatchSchedule": FlowMatchSchedule, +} + +SIGMA_SCHEDULE_MAP = { + "BetaSigmas": BetaSigmas, + "ExponentialSigmas": ExponentialSigmas, + "KarrasSigmas": KarrasSigmas, +} + @dataclass class SchedulerOutput(BaseOutput): @@ -90,6 +106,10 @@ class SchedulerMixin(PushToHubMixin): config_name = SCHEDULER_CONFIG_NAME _compatibles = [] has_compatibles = True + schedule_configs = SCHEDULE_MAP + sigma_configs = SIGMA_SCHEDULE_MAP + _schedule = None + _sigma_schedule = None @classmethod @validate_hf_hub_args @@ -191,3 +211,38 @@ def _get_compatibles(cls): getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) ] return compatible_classes + + def set_schedule(self, schedule: Union[Dict]): + if isinstance(schedule, dict): + class_name = schedule.get("class_name", None) + if class_name is None: + raise ValueError("Schedule config `class_name` is None.") + elif class_name not in self.schedule_configs: + raise ValueError(f"Expected one of {self.schedule_configs.keys()}") + _class = self.schedule_configs[class_name] + self._schedule = _class(**schedule) + else: + self._schedule = schedule + + def set_sigma_schedule(self, sigma_schedule: Union[Dict]): + if isinstance(sigma_schedule, dict): + if not sigma_schedule: + self._sigma_schedule = None + return + class_name = sigma_schedule.get("class_name", None) + if class_name is None: + raise ValueError("Schedule config `class_name` is None.") + elif class_name not in self.sigma_configs: + raise ValueError(f"Expected one of {self.sigma_configs.keys()}") + _class = self.sigma_configs[class_name] + sigma_min = getattr(self._schedule, "sigma_min", None) or sigma_schedule.get("sigma_min", None) + sigma_max = getattr(self._schedule, "sigma_max", None) or sigma_schedule.get("sigma_max", None) + sigma_schedule.update( + { + "sigma_min": sigma_min, + "sigma_max": sigma_max, + } + ) + self._sigma_schedule = _class(**sigma_schedule) + else: + self._sigma_schedule = sigma_schedule diff --git a/src/diffusers/schedulers/sigmas/__init__.py b/src/diffusers/schedulers/sigmas/__init__.py deleted file mode 100644 index a66b34890c3f..000000000000 --- a/src/diffusers/schedulers/sigmas/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -from ...utils import ( - DIFFUSERS_SLOW_IMPORT, - OptionalDependencyNotAvailable, - _LazyModule, - get_objects_from_module, - is_scipy_available, - is_torch_available, - is_transformers_available, -) - - -_dummy_objects = {} -_import_structure = {} - -try: - if not (is_transformers_available() and is_torch_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_pt_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) -else: - _import_structure["exponential_sigmas"] = ["ExponentialSigmas"] - _import_structure["karras_sigmas"] = ["KarrasSigmas"] - -try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils import dummy_torch_and_scipy_objects # noqa F403 - - _dummy_objects.update(get_objects_from_module(dummy_torch_and_scipy_objects)) - -else: - _import_structure["beta_sigmas"] = ["BetaSigmas"] - - -if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: - try: - if not is_torch_available(): - raise OptionalDependencyNotAvailable() - - except OptionalDependencyNotAvailable: - from ...utils.dummy_pt_objects import * # noqa F403 - else: - from .exponential_sigmas import ExponentialSigmas - from .karras_sigmas import KarrasSigmas - - try: - if not (is_torch_available() and is_scipy_available()): - raise OptionalDependencyNotAvailable() - except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_scipy_objects import * # noqa F403 - else: - from .beta_sigmas import BetaSigmas - - -else: - import sys - - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) - for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/schedulers/sigmas/beta_sigmas.py b/src/diffusers/schedulers/sigmas/beta_sigmas.py index 67269a5496ab..d056f857f71d 100644 --- a/src/diffusers/schedulers/sigmas/beta_sigmas.py +++ b/src/diffusers/schedulers/sigmas/beta_sigmas.py @@ -31,6 +31,7 @@ def __init__( sigma_max: Optional[float] = None, alpha: float = 0.6, beta: float = 0.6, + **kwargs, ): if not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") @@ -39,7 +40,7 @@ def __init__( self.alpha = alpha self.beta = beta - def __call__(self, in_sigmas: torch.Tensor): + def __call__(self, in_sigmas: torch.Tensor, **kwargs): if not is_scipy_available(): raise ImportError("Make sure to install scipy if you want to use beta sigmas.") sigma_min = self.sigma_min diff --git a/src/diffusers/schedulers/sigmas/exponential_sigmas.py b/src/diffusers/schedulers/sigmas/exponential_sigmas.py index 70f0a794fe0b..d67883904ad7 100644 --- a/src/diffusers/schedulers/sigmas/exponential_sigmas.py +++ b/src/diffusers/schedulers/sigmas/exponential_sigmas.py @@ -24,11 +24,12 @@ def __init__( self, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, + **kwargs, ): self.sigma_min = sigma_min self.sigma_max = sigma_max - def __call__(self, in_sigmas: torch.Tensor): + def __call__(self, in_sigmas: torch.Tensor, **kwargs): sigma_min = self.sigma_min if sigma_min is None: sigma_min = in_sigmas[-1].item() diff --git a/src/diffusers/schedulers/sigmas/karras_sigmas.py b/src/diffusers/schedulers/sigmas/karras_sigmas.py index a60d5e017ef4..f2a17b0921a9 100644 --- a/src/diffusers/schedulers/sigmas/karras_sigmas.py +++ b/src/diffusers/schedulers/sigmas/karras_sigmas.py @@ -24,12 +24,13 @@ def __init__( sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, rho: float = 7.0, + **kwargs, ): self.sigma_min = sigma_min self.sigma_max = sigma_max self.rho = rho - def __call__(self, in_sigmas: torch.Tensor): + def __call__(self, in_sigmas: torch.Tensor, **kwargs): sigma_min = self.sigma_min if sigma_min is None: sigma_min = in_sigmas[-1].item() From cc849e2ff1dc4df8f060dfca6c9f809d5cca7435 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 19 Dec 2024 11:07:55 +0000 Subject: [PATCH 06/20] Notes, sana schedule, scale_noise->add_noise --- .../pipelines/flux/pipeline_flux_img2img.py | 4 +- .../schedulers/schedules/beta_schedule.py | 23 ++++--- .../schedulers/schedules/flow_schedule.py | 61 +++++++++++++------ .../scheduling_euler_ancestral_discrete.py | 5 +- .../schedulers/scheduling_euler_discrete.py | 5 +- .../schedulers/scheduling_heun_discrete.py | 5 +- 6 files changed, 68 insertions(+), 35 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 2b336fbdd472..ab9a07ae6b44 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -562,7 +562,9 @@ def prepare_latents( image_latents = torch.cat([image_latents], dim=0) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) + # NOTE: `scale_noise` changed to `add_noise` + # the signature is `noise`, `timestep` instead of `timestep`, `noise` + latents = self.scheduler.add_noise(image_latents, noise, timestep) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids diff --git a/src/diffusers/schedulers/schedules/beta_schedule.py b/src/diffusers/schedulers/schedules/beta_schedule.py index d5b93c3b3923..78bf58c50dc8 100644 --- a/src/diffusers/schedulers/schedules/beta_schedule.py +++ b/src/diffusers/schedulers/schedules/beta_schedule.py @@ -1,14 +1,14 @@ +import math from typing import List, Optional, Union -import math import numpy as np import torch -from ...configuration_utils import ConfigMixin, register_to_config from ..sigmas.beta_sigmas import BetaSigmas from ..sigmas.exponential_sigmas import ExponentialSigmas from ..sigmas.karras_sigmas import KarrasSigmas + def betas_for_alpha_bar( num_diffusion_timesteps, max_beta=0.999, @@ -52,6 +52,7 @@ def alpha_bar_fn(t): betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) return torch.tensor(betas, dtype=torch.float32) + def rescale_zero_terminal_snr(betas): """ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) @@ -89,7 +90,6 @@ def rescale_zero_terminal_snr(betas): class BetaSchedule: - scale_model_input = True def __init__( @@ -132,7 +132,7 @@ def __init__( # Close to 0 without being 0 so first sigma is not inf # FP16 smallest positive subnormal works well here self.alphas_cumprod[-1] = 2**-24 - + self.num_train_timesteps = num_train_timesteps self.beta_start = beta_start self.beta_end = beta_end @@ -181,6 +181,7 @@ def __call__( ): if sigmas is not None: log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)) + # NOTE: current usage is **with** `sigma_last` - different than FlowMatch. sigmas = np.array(sigmas).astype(np.float32) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]]) @@ -190,9 +191,9 @@ def __call__( else: # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.timestep_spacing == "linspace": - timesteps = np.linspace( - 0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32 - )[::-1].copy() + timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ + ::-1 + ].copy() elif self.timestep_spacing == "leading": step_ratio = self.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio @@ -205,9 +206,7 @@ def __call__( step_ratio = self.num_train_timesteps / num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = ( - (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) - ) + timesteps = (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) timesteps -= 1 else: raise ValueError( @@ -240,7 +239,7 @@ def __call__( ) sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) # TODO: Support the full EDM scalings for all prediction types and timestep types @@ -248,5 +247,5 @@ def __call__( timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device) else: timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) - + return sigmas, timesteps diff --git a/src/diffusers/schedulers/schedules/flow_schedule.py b/src/diffusers/schedulers/schedules/flow_schedule.py index a8332ee75f12..ba96755168b2 100644 --- a/src/diffusers/schedulers/schedules/flow_schedule.py +++ b/src/diffusers/schedulers/schedules/flow_schedule.py @@ -1,21 +1,32 @@ +import math from typing import List, Optional, Union -import math import numpy as np import torch -from ...configuration_utils import ConfigMixin, register_to_config from ..sigmas.beta_sigmas import BetaSigmas from ..sigmas.exponential_sigmas import ExponentialSigmas from ..sigmas.karras_sigmas import KarrasSigmas -class FlowMatchSD3: - - def _sigma_to_t(self, sigma): - return sigma * self.num_train_timesteps - def __call__(self, num_inference_steps: int, num_train_timesteps: int, shift: float, use_dynamic_shifting: bool = False, **kwargs): - self.num_train_timesteps = num_train_timesteps +class FlowMatchSD3: + def __call__( + self, + num_inference_steps: int, + num_train_timesteps: int, + shift: float, + use_dynamic_shifting: bool = False, + **kwargs, + ) -> np.ndarray: + """ + This is different to others that directly calculate `sigmas`. + It needs `sigma_min` and `sigma_max` after shift + https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L89-L95 + Then we calculate `sigmas` from that `sigma_min` and `sigma_max`. + https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L238-L240 + Shifting happens again after (outside of this). + https://github.com/huggingface/diffusers/blob/0ed09a17bbab784a78fb163b557b4827467b0468/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L248-L251 + """ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) @@ -25,18 +36,20 @@ def __call__(self, num_inference_steps: int, num_train_timesteps: int, shift: fl sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) sigma_min = sigmas[-1].item() sigma_max = sigmas[0].item() - timesteps = np.linspace( - self._sigma_to_t(sigma_max), self._sigma_to_t(sigma_min), num_inference_steps - ) + timesteps = np.linspace(sigma_max * num_train_timesteps, sigma_min * num_train_timesteps, num_inference_steps) sigmas = timesteps / num_train_timesteps return sigmas + class FlowMatchFlux: - def __call__(self, num_inference_steps: int, **kwargs): + def __call__(self, num_inference_steps: int, **kwargs) -> np.ndarray: return np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + class FlowMatchLinearQuadratic: - def __call__(self, num_inference_steps: int, threshold_noise: float = 0.25, linear_steps: Optional[int] = None, **kwargs): + def __call__( + self, num_inference_steps: int, threshold_noise: float = 0.25, linear_steps: Optional[int] = None, **kwargs + ) -> np.ndarray: if linear_steps is None: linear_steps = num_inference_steps // 2 linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)] @@ -49,22 +62,33 @@ def __call__(self, num_inference_steps: int, threshold_noise: float = 0.25, line quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_inference_steps) ] sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule - sigma_schedule = [1.0 - x for x in sigma_schedule] + sigma_schedule = np.array([1.0 - x for x in sigma_schedule]).astype(np.float32) return sigma_schedule + class FlowMatchHunyuanVideo: - def __call__(self, num_inference_steps: int, **kwargs): + def __call__(self, num_inference_steps: int, **kwargs) -> np.ndarray: return np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1].copy() + +class FlowMatchSANA: + def __call__(self, num_inference_steps: int, num_train_timesteps: int, shift: float, **kwargs) -> np.ndarray: + alphas = np.linspace(1, 1 / num_train_timesteps, num_inference_steps + 1) + sigmas = 1.0 - alphas + sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy() + return sigmas + + BASE_SCHEDULE_MAP = { "FlowMatchHunyuanVideo": FlowMatchHunyuanVideo, "FlowMatchLinearQuadratic": FlowMatchLinearQuadratic, "FlowMatchFlux": FlowMatchFlux, "FlowMatchSD3": FlowMatchSD3, + "FlowMatchSANA": FlowMatchSANA, } -class FlowMatchSchedule: +class FlowMatchSchedule: scale_model_input = False base_schedules = BASE_SCHEDULE_MAP @@ -145,7 +169,7 @@ def __call__( ): shift = shift or self.shift if self.use_dynamic_shifting and mu is None: - raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + raise ValueError("You have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") if sigmas is None: sigmas = self.base_schedule( @@ -155,9 +179,8 @@ def __call__( use_dynamic_shifting=self.use_dynamic_shifting, ) else: + # NOTE: current usage is **without** `sigma_last` - different than BetaSchedule sigmas = np.array(sigmas).astype(np.float32) - num_inference_steps = len(sigmas) - self.num_inference_steps = num_inference_steps if self.use_dynamic_shifting: sigmas = self.time_shift(mu, 1.0, sigmas) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index d66af0b04deb..8b4c5e4a242d 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -384,7 +384,10 @@ def add_noise( while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + if self._schedule.__class__.__name__ == "FlowMatchSchedule": + noisy_samples = (1.0 - sigma) * original_samples + noise * sigma + else: + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index a2191346accc..3700e1670deb 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -413,7 +413,10 @@ def add_noise( while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + if self._schedule.__class__.__name__ == "FlowMatchSchedule": + noisy_samples = (1.0 - sigma) * original_samples + noise * sigma + else: + noisy_samples = original_samples + noise * sigma return noisy_samples def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 017c855a4b43..3022e3851327 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -401,7 +401,10 @@ def add_noise( while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) - noisy_samples = original_samples + noise * sigma + if self._schedule.__class__.__name__ == "FlowMatchSchedule": + noisy_samples = (1.0 - sigma) * original_samples + noise * sigma + else: + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): From 323806c5f3acce96e4ddde2c9c290b0165b9de4c Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:16:59 +0000 Subject: [PATCH 07/20] SamplingMixin --- .../schedulers/schedules/beta_schedule.py | 4 + src/diffusers/schedulers/scheduling_utils.py | 132 ++++++++++++++++++ src/diffusers/schedulers/sigmas/__init__.py | 0 3 files changed, 136 insertions(+) create mode 100644 src/diffusers/schedulers/sigmas/__init__.py diff --git a/src/diffusers/schedulers/schedules/beta_schedule.py b/src/diffusers/schedulers/schedules/beta_schedule.py index 78bf58c50dc8..86b8d14f601b 100644 --- a/src/diffusers/schedulers/schedules/beta_schedule.py +++ b/src/diffusers/schedulers/schedules/beta_schedule.py @@ -104,6 +104,8 @@ def __init__( timestep_spacing: str = "linspace", timestep_type: str = "discrete", # can be "discrete" or "continuous" steps_offset: int = 0, + clip_sample: Optional[bool] = False, + clip_sample_range: float = 1.0, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" @@ -143,6 +145,8 @@ def __init__( self.timestep_spacing = timestep_spacing self.timestep_type = timestep_type self.steps_offset = steps_offset + self.clip_sample = clip_sample + self.clip_sample_range = clip_sample_range self.sigma_min = sigma_min self.sigma_max = sigma_max self.final_sigmas_type = final_sigmas_type diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 94b2bcaa5573..14b95fe25f12 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -246,3 +246,135 @@ def set_sigma_schedule(self, sigma_schedule: Union[Dict]): self._sigma_schedule = _class(**sigma_schedule) else: self._sigma_schedule = sigma_schedule + +class SamplingMixin: + _step_index = None + _begin_index = None + timesteps = None + num_inference_steps = None + is_scale_input_called = False + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + if self._schedule.__class__.__name__ == "FlowMatchSchedule": + noisy_samples = (1.0 - sigma) * original_samples + noise * sigma + else: + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: + if self._schedule.__class__.__name__ != "BetaSchedule": + raise ValueError("`get_velocity` only supports `BetaSchedule`.") + if ( + isinstance(timesteps, int) + or isinstance(timesteps, torch.IntTensor) + or isinstance(timesteps, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if sample.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timesteps = timesteps.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timesteps = timesteps.to(sample.device) + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + alphas_cumprod = self._schedule.alphas_cumprod.to(sample) + sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/sigmas/__init__.py b/src/diffusers/schedulers/sigmas/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From 488fb7bb062e1079425cba18caae7af26b0f3529 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:17:12 +0000 Subject: [PATCH 08/20] EulerAncestralDiscreteScheduler --- .../scheduling_euler_ancestral_discrete.py | 118 ++---------------- 1 file changed, 13 insertions(+), 105 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 8b4c5e4a242d..ca092b87cd23 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from ..utils.torch_utils import randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SamplingMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -45,7 +45,7 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.Tensor] = None -class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin, SamplingMixin): """ Ancestral sampling with Euler method steps. @@ -92,46 +92,17 @@ def __init__( self.set_schedule(schedule_config) self.set_sigma_schedule(sigma_schedule_config) - # setable values - self.num_inference_steps = None - - self.is_scale_input_called = False - self._step_index = None - self._begin_index = None - @property + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.init_noise_sigma def init_noise_sigma(self): # standard deviation of the initial noise distribution + max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() if self.config.timestep_spacing in ["linspace", "trailing"]: - return self.sigmas.max() - - return (self.sigmas.max() ** 2 + 1) ** 0.5 - - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index + return max_sigma - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index + return (max_sigma**2 + 1) ** 0.5 + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -147,18 +118,19 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T `torch.Tensor`: A scaled input sample. """ - if self.step_index is None: self._init_step_index(timestep) sigma = self.sigmas[self.step_index] sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True return sample + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.set_timesteps def set_timesteps( self, - num_inference_steps: int, + num_inference_steps: int = None, device: Union[str, torch.device] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, @@ -201,11 +173,11 @@ def set_timesteps( raise ValueError("Cannot set `timesteps` with `BetaSigmas`.") if ( timesteps is not None - and self._schedule.config.get("timestep_type", None) == "continuous" + and self._schedule.timestep_type == "continuous" and self.config.prediction_type == "v_prediction" ): raise ValueError( - "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." + "Cannot set `timesteps` with `schedule.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." ) if num_inference_steps is None: @@ -222,35 +194,11 @@ def set_timesteps( shift=shift, ) - self.timesteps = timesteps.to(device=device) self._step_index = None self._begin_index = None + self.timesteps = timesteps.to(device=device) self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - def step( self, model_output: torch.Tensor, @@ -352,43 +300,3 @@ def step( return EulerAncestralDiscreteSchedulerOutput( prev_sample=prev_sample, pred_original_sample=pred_original_sample ) - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timesteps.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timesteps.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - if self._schedule.__class__.__name__ == "FlowMatchSchedule": - noisy_samples = (1.0 - sigma) * original_samples + noise * sigma - else: - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps From de5fe50fe4ed552b6107d6f6574de2e9ab35d7bb Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:17:17 +0000 Subject: [PATCH 09/20] EulerDiscreteScheduler --- .../schedulers/scheduling_euler_discrete.py | 153 +----------------- 1 file changed, 7 insertions(+), 146 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 3700e1670deb..321f90ff1409 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from ..utils.torch_utils import randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SamplingMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -45,7 +45,7 @@ class EulerDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.Tensor] = None -class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): +class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin, SamplingMixin): """ Euler scheduler. @@ -106,13 +106,6 @@ def __init__( self.set_schedule(schedule_config) self.set_sigma_schedule(sigma_schedule_config) - # setable values - self.num_inference_steps = None - - self.is_scale_input_called = False - self._step_index = None - self._begin_index = None - @property def init_noise_sigma(self): # standard deviation of the initial noise distribution @@ -122,31 +115,6 @@ def init_noise_sigma(self): return (max_sigma**2 + 1) ** 0.5 - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index - def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -188,15 +156,6 @@ def set_timesteps( The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated - based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` - must be `None`, and `timestep_spacing` attribute will be ignored. - sigmas (`List[float]`, *optional*): - Custom sigmas used to support arbitrary timesteps schedule schedule. If `None`, timesteps and sigmas - will be generated based on the relevant scheduler attributes. If `sigmas` is passed, - `num_inference_steps` and `timesteps` must be `None`, and the timesteps will be generated based on the - custom sigmas schedule. """ if timesteps is not None and sigmas is not None: @@ -210,7 +169,7 @@ def set_timesteps( and self._sigma_schedule is not None and self._sigma_schedule.__class__.__name__ == "KarrasSigmas" ): - raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") + raise ValueError("Cannot set `timesteps` with `KarrasSigmas`.") if ( timesteps is not None and self._sigma_schedule is not None @@ -225,11 +184,11 @@ def set_timesteps( raise ValueError("Cannot set `timesteps` with `BetaSigmas`.") if ( timesteps is not None - and self._schedule.config.get("timestep_type", None) == "continuous" + and self._schedule.timestep_type == "continuous" and self.config.prediction_type == "v_prediction" ): raise ValueError( - "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." + "Cannot set `timesteps` with `schedule.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." ) if num_inference_steps is None: @@ -248,30 +207,8 @@ def set_timesteps( self._step_index = None self._begin_index = None - self.timesteps = timesteps - self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication - - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - - def _init_step_index(self, timestep): - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas.to("cpu") def step( self, @@ -382,79 +319,3 @@ def step( ) return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timesteps.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timesteps.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - if self._schedule.__class__.__name__ == "FlowMatchSchedule": - noisy_samples = (1.0 - sigma) * original_samples + noise * sigma - else: - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor: - if ( - isinstance(timesteps, int) - or isinstance(timesteps, torch.IntTensor) - or isinstance(timesteps, torch.LongTensor) - ): - raise ValueError( - ( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.get_velocity()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep." - ), - ) - - if sample.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) - timesteps = timesteps.to(sample.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(sample.device) - timesteps = timesteps.to(sample.device) - - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - alphas_cumprod = self.alphas_cumprod.to(sample) - sqrt_alpha_prod = alphas_cumprod[step_indices] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[step_indices]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity - - def __len__(self): - return self.config.num_train_timesteps From 6e5341bdd47b1ff9d398a5c7e7e8b615ede9e3ef Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:17:24 +0000 Subject: [PATCH 10/20] HeunDiscreteScheduler --- .../schedulers/scheduling_heun_discrete.py | 143 +++--------------- 1 file changed, 20 insertions(+), 123 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 3022e3851327..e1e6d4654b83 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SamplingMixin @dataclass @@ -42,7 +42,7 @@ class HeunDiscreteSchedulerOutput(BaseOutput): pred_original_sample: Optional[torch.Tensor] = None -class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): +class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin, SamplingMixin): """ Scheduler with Heun steps for discrete beta schedules. @@ -76,7 +76,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): An offset added to the inference steps, as required by some model families. """ - ignore_for_config = ["sigma_schedule"] _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @@ -90,65 +89,21 @@ def __init__( self.set_schedule(schedule_config) self.set_sigma_schedule(sigma_schedule_config) - self._step_index = None - self._begin_index = None - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - @property + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.init_noise_sigma def init_noise_sigma(self): # standard deviation of the initial noise distribution + max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max() if self.config.timestep_spacing in ["linspace", "trailing"]: - return self.sigmas.max() - - return (self.sigmas.max() ** 2 + 1) ** 0.5 + return max_sigma - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index + return (max_sigma**2 + 1) ** 0.5 - def scale_model_input( - self, - sample: torch.Tensor, - timestep: Union[float, torch.Tensor], - ) -> torch.Tensor: + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. Args: sample (`torch.Tensor`): @@ -165,11 +120,13 @@ def scale_model_input( sigma = self.sigmas[self.step_index] sample = sample / ((sigma**2 + 1) ** 0.5) + + self.is_scale_input_called = True return sample def set_timesteps( self, - num_inference_steps: Optional[int] = None, + num_inference_steps: int = None, device: Union[str, torch.device] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, @@ -184,13 +141,6 @@ def set_timesteps( The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - num_train_timesteps (`int`, *optional*): - The number of diffusion steps used when training the model. If `None`, the default - `num_train_timesteps` attribute is used. - timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, timesteps will be - generated based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` - must be `None`, and `timestep_spacing` attribute will be ignored. """ if timesteps is not None and sigmas is not None: @@ -204,7 +154,7 @@ def set_timesteps( and self._sigma_schedule is not None and self._sigma_schedule.__class__.__name__ == "KarrasSigmas" ): - raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.") + raise ValueError("Cannot set `timesteps` with `KarrasSigmas`.") if ( timesteps is not None and self._sigma_schedule is not None @@ -219,11 +169,11 @@ def set_timesteps( raise ValueError("Cannot set `timesteps` with `BetaSigmas`.") if ( timesteps is not None - and self._schedule.config.get("timestep_type", None) == "continuous" + and self._schedule.timestep_type == "continuous" and self.config.prediction_type == "v_prediction" ): raise ValueError( - "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." + "Cannot set `timesteps` with `schedule.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." ) if num_inference_steps is None: @@ -240,33 +190,20 @@ def set_timesteps( shift=shift, ) - self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) - + sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) - self.timesteps = timesteps.to(device=device) - - # empty dt and derivative self.prev_derivative = None self.dt = None - self._step_index = None self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication @property def state_in_first_order(self): return self.dt is None - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - def step( self, model_output: Union[torch.Tensor, np.ndarray], @@ -327,9 +264,9 @@ def step( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) - if self.config.clip_sample: + if self._schedule.clip_sample: pred_original_sample = pred_original_sample.clamp( - -self.config.clip_sample_range, self.config.clip_sample_range + -self._schedule.clip_sample_range, self._schedule.clip_sample_range ) if self.state_in_first_order: @@ -369,43 +306,3 @@ def step( ) return HeunDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.Tensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timesteps.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timesteps.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - if self._schedule.__class__.__name__ == "FlowMatchSchedule": - noisy_samples = (1.0 - sigma) * original_samples + noise * sigma - else: - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps From 53ba24b73d6066b845691ca3838ea84ea51f000d Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:17:37 +0000 Subject: [PATCH 11/20] scale_noise->add_noise --- examples/community/pipeline_flux_differential_img2img.py | 6 +++--- .../pipeline_stable_diffusion_3_differential_img2img.py | 2 +- .../pipelines/flux/pipeline_flux_control_img2img.py | 2 +- .../pipelines/flux/pipeline_flux_control_inpaint.py | 6 +++--- .../flux/pipeline_flux_controlnet_image_to_image.py | 2 +- .../pipelines/flux/pipeline_flux_controlnet_inpainting.py | 6 +++--- src/diffusers/pipelines/flux/pipeline_flux_inpaint.py | 6 +++--- src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py | 2 +- .../pipeline_stable_diffusion_3_img2img.py | 2 +- .../pipeline_stable_diffusion_3_inpaint.py | 6 +++--- 10 files changed, 20 insertions(+), 20 deletions(-) diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py index 68cb69115bde..a98fd4deb9b3 100644 --- a/examples/community/pipeline_flux_differential_img2img.py +++ b/examples/community/pipeline_flux_differential_img2img.py @@ -582,7 +582,7 @@ def prepare_latents( if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self.scheduler.add_noise(image_latents, noise, timestep) else: noise = latents.to(device) latents = noise @@ -976,8 +976,8 @@ def __call__( if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] - image_latent = self.scheduler.scale_noise( - original_image_latents, torch.tensor([noise_timestep]), noise + image_latent = self.scheduler.add_noise( + original_image_latents, noise, torch.tensor([noise_timestep]) ) # start diff diff diff --git a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py index 8cee5ecbc141..83322395edc4 100644 --- a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py +++ b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py @@ -640,7 +640,7 @@ def prepare_latents( shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents.to(device=device, dtype=dtype) return latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index 7001b19569f2..a60dacd62d79 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -580,7 +580,7 @@ def prepare_latents( image_latents = torch.cat([image_latents], dim=0) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self.scheduler.add_noise(image_latents, noise, timestep) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index a9ac1c72c6ed..b31935ac18f6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -646,7 +646,7 @@ def prepare_latents( image_latents = torch.cat([image_latents], dim=0) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self.scheduler.add_noise(image_latents, noise, timestep) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, noise, image_latents, latent_image_ids @@ -1091,8 +1091,8 @@ def __call__( init_mask = mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - image_latents, torch.tensor([noise_timestep]), noise + init_latents_proper = self.scheduler.add_noise( + image_latents, noise, torch.tensor([noise_timestep]) ) else: init_latents_proper = image_latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 4c82d73f0379..e47fb615ffaa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -579,7 +579,7 @@ def prepare_latents( image_latents = torch.cat([image_latents], dim=0) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self.scheduler.add_noise(image_latents, noise, timestep) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 85943b278dc6..e1a25b6457aa 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -605,7 +605,7 @@ def prepare_latents( if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self.scheduler.add_noise(image_latents, noise, timestep) else: noise = latents.to(device) latents = noise @@ -1159,8 +1159,8 @@ def __call__( if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 15abdb90ebd0..c7d1ef0f4bc1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -582,7 +582,7 @@ def prepare_latents( if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) + latents = self.scheduler.add_noise(image_latents, noise, timestep) else: noise = latents.to(device) latents = noise @@ -977,8 +977,8 @@ def __call__( if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py index 24e31fa4cfc7..01e0cf8f5261 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -702,7 +702,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents - init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents.to(device=device, dtype=dtype) return latents diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 77daf5b0b4e0..bb124902fd6b 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -708,7 +708,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents - init_latents = self.scheduler.scale_noise(init_latents, timestep, noise) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents.to(device=device, dtype=dtype) return latents diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index e1cfdb3e6e97..a9caa92a7bc1 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -717,7 +717,7 @@ def prepare_latents( if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise) + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) else: noise = latents.to(device) latents = noise @@ -1203,8 +1203,8 @@ def __call__( if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.scale_noise( - init_latents_proper, torch.tensor([noise_timestep]), noise + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) ) latents = (1 - init_mask) * init_latents_proper + init_mask * latents From 2e80a5df72ec56d818207ea816742a2faa204b10 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:18:43 +0000 Subject: [PATCH 12/20] make --- src/diffusers/pipelines/flux/pipeline_flux.py | 1 + src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py | 2 +- src/diffusers/schedulers/scheduling_euler_discrete.py | 2 +- src/diffusers/schedulers/scheduling_heun_discrete.py | 2 +- src/diffusers/schedulers/scheduling_utils.py | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index ea3639f17620..35c06d54a5db 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -15,6 +15,7 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union +import numpy as np import torch from transformers import ( CLIPImageProcessor, diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index ca092b87cd23..39d52903c59a 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from ..utils.torch_utils import randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SamplingMixin +from .scheduling_utils import KarrasDiffusionSchedulers, SamplingMixin, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 321f90ff1409..b852c9207182 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from ..utils.torch_utils import randn_tensor -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SamplingMixin +from .scheduling_utils import KarrasDiffusionSchedulers, SamplingMixin, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index e1e6d4654b83..5f4280f1f6e9 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SamplingMixin +from .scheduling_utils import KarrasDiffusionSchedulers, SamplingMixin, SchedulerMixin @dataclass diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 14b95fe25f12..712237e83d35 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -247,6 +247,7 @@ def set_sigma_schedule(self, sigma_schedule: Union[Dict]): else: self._sigma_schedule = sigma_schedule + class SamplingMixin: _step_index = None _begin_index = None @@ -375,6 +376,5 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity - def __len__(self): return self.config.num_train_timesteps From b34539e9e89e9869733fa1fa42dc05f49999249e Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:22:30 +0000 Subject: [PATCH 13/20] self.scheduler.config --- examples/community/adaptive_mask_inpainting.py | 2 +- examples/community/edict_pipeline.py | 4 ++-- examples/community/lpw_stable_diffusion_xl.py | 8 ++++---- .../masked_stable_diffusion_xl_img2img.py | 4 ++-- examples/community/pipeline_demofusion_sdxl.py | 8 ++++---- .../pipeline_flux_differential_img2img.py | 8 ++++---- examples/community/pipeline_flux_rf_inversion.py | 16 ++++++++-------- examples/community/pipeline_flux_with_cfg.py | 8 ++++---- .../pipeline_kolors_differential_img2img.py | 8 ++++---- .../community/pipeline_null_text_inversion.py | 4 ++-- .../community/pipeline_sdxl_style_aligned.py | 8 ++++---- ...ine_stable_diffusion_xl_controlnet_adapter.py | 4 ++-- ...le_diffusion_xl_controlnet_adapter_inpaint.py | 8 ++++---- ...e_stable_diffusion_xl_differential_img2img.py | 8 ++++---- .../pipeline_stable_diffusion_xl_ipex.py | 4 ++-- examples/community/rerender_a_video.py | 2 +- examples/community/sde_drag.py | 2 +- .../stable_diffusion_xl_controlnet_reference.py | 4 ++-- .../community/stable_diffusion_xl_reference.py | 4 ++-- .../animatediff/pipeline_animatediff_sdxl.py | 4 ++-- .../pipeline_controlnet_inpaint_sd_xl.py | 8 ++++---- .../controlnet/pipeline_controlnet_sd_xl.py | 4 ++-- .../pipeline_controlnet_union_inpaint_sd_xl.py | 8 ++++---- .../pipeline_controlnet_union_sd_xl.py | 4 ++-- .../audio_diffusion/pipeline_audio_diffusion.py | 2 +- .../pipeline_spectrogram_diffusion.py | 2 +- src/diffusers/pipelines/flux/pipeline_flux.py | 8 ++++---- .../pipelines/flux/pipeline_flux_control.py | 8 ++++---- .../flux/pipeline_flux_control_img2img.py | 8 ++++---- .../flux/pipeline_flux_control_inpaint.py | 8 ++++---- .../pipelines/flux/pipeline_flux_controlnet.py | 8 ++++---- .../pipeline_flux_controlnet_image_to_image.py | 8 ++++---- .../flux/pipeline_flux_controlnet_inpainting.py | 8 ++++---- .../pipelines/flux/pipeline_flux_fill.py | 8 ++++---- .../pipelines/flux/pipeline_flux_img2img.py | 8 ++++---- .../pipelines/flux/pipeline_flux_inpaint.py | 8 ++++---- src/diffusers/pipelines/free_init_utils.py | 2 +- .../kandinsky/pipeline_kandinsky_img2img.py | 2 +- .../pipelines/kolors/pipeline_kolors.py | 4 ++-- .../pipelines/kolors/pipeline_kolors_img2img.py | 8 ++++---- .../pipeline_leditspp_stable_diffusion_xl.py | 4 ++-- src/diffusers/pipelines/ltx/pipeline_ltx.py | 8 ++++---- .../pipelines/ltx/pipeline_ltx_image2video.py | 8 ++++---- .../pipelines/lumina/pipeline_lumina.py | 2 +- .../pag/pipeline_pag_controlnet_sd_xl.py | 4 ++-- .../pipelines/pag/pipeline_pag_kolors.py | 4 ++-- .../pipelines/pag/pipeline_pag_sd_xl.py | 4 ++-- .../pipelines/pag/pipeline_pag_sd_xl_img2img.py | 8 ++++---- .../pipelines/pag/pipeline_pag_sd_xl_inpaint.py | 8 ++++---- .../pipeline_stable_diffusion_3.py | 10 +++++----- .../pipeline_stable_diffusion_3_img2img.py | 10 +++++----- .../pipeline_stable_diffusion_3_inpaint.py | 10 +++++----- .../pipeline_stable_diffusion_xl.py | 4 ++-- .../pipeline_stable_diffusion_xl_img2img.py | 8 ++++---- .../pipeline_stable_diffusion_xl_inpaint.py | 8 ++++---- ...eline_stable_diffusion_xl_instruct_pix2pix.py | 4 ++-- .../pipeline_stable_diffusion_xl_adapter.py | 4 ++-- .../pipeline_text_to_video_zero_sdxl.py | 4 ++-- .../unidiffuser/pipeline_unidiffuser.py | 2 +- 59 files changed, 178 insertions(+), 178 deletions(-) diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py index a9de26b29a89..2013ea648a95 100644 --- a/examples/community/adaptive_mask_inpainting.py +++ b/examples/community/adaptive_mask_inpainting.py @@ -1148,7 +1148,7 @@ def __call__( # run segmentation if use_adaptive_mask: if enforce_full_mask_ratio > 0.0: - use_default_mask = t < self.scheduler.config.num_train_timesteps * enforce_full_mask_ratio + use_default_mask = t < self.scheduler._schedule.num_train_timesteps * enforce_full_mask_ratio elif enforce_full_mask_ratio == 0.0: use_default_mask = False else: diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py index ac977f79abec..9958e7616f8f 100644 --- a/examples/community/edict_pipeline.py +++ b/examples/community/edict_pipeline.py @@ -97,7 +97,7 @@ def noise_step( model_output: torch.Tensor, timestep: torch.Tensor, ): - prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps + prev_timestep = timestep - self.scheduler._schedule.num_train_timesteps / self.scheduler.num_inference_steps alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep) alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep) @@ -116,7 +116,7 @@ def denoise_step( model_output: torch.Tensor, timestep: torch.Tensor, ): - prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps + prev_timestep = timestep - self.scheduler._schedule.num_train_timesteps / self.scheduler.num_inference_steps alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep) alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep) diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py index 13d1e2a1156a..4bb8d1a16966 100644 --- a/examples/community/lpw_stable_diffusion_xl.py +++ b/examples/community/lpw_stable_diffusion_xl.py @@ -1050,8 +1050,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N if denoising_start is not None: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1819,8 +1819,8 @@ def denoising_value_valid(dnv): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/masked_stable_diffusion_xl_img2img.py b/examples/community/masked_stable_diffusion_xl_img2img.py index c6b0ced527b5..d1492b3448ac 100644 --- a/examples/community/masked_stable_diffusion_xl_img2img.py +++ b/examples/community/masked_stable_diffusion_xl_img2img.py @@ -376,8 +376,8 @@ def denoising_value_valid(dnv): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py index f83d1b401420..8374621afef2 100644 --- a/examples/community/pipeline_demofusion_sdxl.py +++ b/examples/community/pipeline_demofusion_sdxl.py @@ -933,8 +933,8 @@ def __call__( if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) @@ -1040,8 +1040,8 @@ def __call__( 1 + torch.cos( torch.pi - * (self.scheduler.config.num_train_timesteps - t) - / self.scheduler.config.num_train_timesteps + * (self.scheduler._schedule.num_train_timesteps - t) + / self.scheduler._schedule.num_train_timesteps ) ).cpu() ) diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py index a98fd4deb9b3..9f9d0de56eae 100644 --- a/examples/community/pipeline_flux_differential_img2img.py +++ b/examples/community/pipeline_flux_differential_img2img.py @@ -876,10 +876,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index c8a87a426dc0..ea4128c7ceb4 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -822,10 +822,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, @@ -992,10 +992,10 @@ def invert( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inversion_steps = retrieve_timesteps( self.scheduler, diff --git a/examples/community/pipeline_flux_with_cfg.py b/examples/community/pipeline_flux_with_cfg.py index 06da6da899cd..992ed18f204d 100644 --- a/examples/community/pipeline_flux_with_cfg.py +++ b/examples/community/pipeline_flux_with_cfg.py @@ -757,10 +757,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/examples/community/pipeline_kolors_differential_img2img.py b/examples/community/pipeline_kolors_differential_img2img.py index e5570248d22b..9fc2578360c8 100644 --- a/examples/community/pipeline_kolors_differential_img2img.py +++ b/examples/community/pipeline_kolors_differential_img2img.py @@ -580,8 +580,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N if denoising_start is not None: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1159,8 +1159,8 @@ def denoising_value_valid(dnv): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/pipeline_null_text_inversion.py b/examples/community/pipeline_null_text_inversion.py index 7e27b4647bc9..ea7faa8029d8 100644 --- a/examples/community/pipeline_null_text_inversion.py +++ b/examples/community/pipeline_null_text_inversion.py @@ -87,7 +87,7 @@ def latent2image(self, latents): return image def prev_step(self, model_output, timestep, sample): - prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + prev_timestep = timestep - self.scheduler._schedule.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[timestep] alpha_prod_t_prev = ( self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod @@ -100,7 +100,7 @@ def prev_step(self, model_output, timestep, sample): def next_step(self, model_output, timestep, sample): timestep, next_timestep = ( - min(timestep - self.scheduler.config.num_train_timesteps // self.num_inference_steps, 999), + min(timestep - self.scheduler._schedule.num_train_timesteps // self.num_inference_steps, 999), timestep, ) alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py index 8328bc2caed9..d906acb0fb94 100644 --- a/examples/community/pipeline_sdxl_style_aligned.py +++ b/examples/community/pipeline_sdxl_style_aligned.py @@ -874,8 +874,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N if denoising_start is not None: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1785,8 +1785,8 @@ def denoising_value_valid(dnv): ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py index ae495979f366..d24904a9aa4f 100644 --- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py @@ -1279,8 +1279,8 @@ def __call__( if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py index 94ca71cf7b1b..7933f7d65fad 100644 --- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py @@ -1057,8 +1057,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N if denoising_start is not None: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1691,8 +1691,8 @@ def denoising_value_valid(dnv): elif denoising_end is not None and denoising_value_valid(denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py index 584820e86254..4ccef64ec72d 100644 --- a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py +++ b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py @@ -626,8 +626,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N if denoising_start is not None: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1306,8 +1306,8 @@ def denoising_value_valid(dnv): elif denoising_end is not None and denoising_value_valid(denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/pipeline_stable_diffusion_xl_ipex.py b/examples/community/pipeline_stable_diffusion_xl_ipex.py index 022dfb1abf82..56ed5f1fa543 100644 --- a/examples/community/pipeline_stable_diffusion_xl_ipex.py +++ b/examples/community/pipeline_stable_diffusion_xl_ipex.py @@ -1055,8 +1055,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index d9c616ab5ebc..9880755ca7c6 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -1079,7 +1079,7 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None): # get x_t from x_0 latents = self.scheduler.add_noise(pred_x0, noise_pred, t).to(latents_dtype) - prev_t = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + prev_t = t - self.scheduler._schedule.num_train_timesteps // self.scheduler.num_inference_steps if i == len(timesteps) - 1: alpha_t_prev = 1.0 else: diff --git a/examples/community/sde_drag.py b/examples/community/sde_drag.py index 902eaa99f417..082923356339 100644 --- a/examples/community/sde_drag.py +++ b/examples/community/sde_drag.py @@ -295,7 +295,7 @@ def train_lora(self, prompt, image, lora_step=100, lora_rank=16, generator=None) # Sample a random timestep for each image timesteps = torch.randint( - 0, self.scheduler.config.num_train_timesteps, (bsz,), device=model_input.device, generator=generator + 0, self.scheduler._schedule.num_train_timesteps, (bsz,), device=model_input.device, generator=generator ) timesteps = timesteps.long() diff --git a/examples/community/stable_diffusion_xl_controlnet_reference.py b/examples/community/stable_diffusion_xl_controlnet_reference.py index ac3159e5e6e8..e9a66f0aebad 100644 --- a/examples/community/stable_diffusion_xl_controlnet_reference.py +++ b/examples/community/stable_diffusion_xl_controlnet_reference.py @@ -1177,8 +1177,8 @@ def hacked_UpBlock2D_forward( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/examples/community/stable_diffusion_xl_reference.py b/examples/community/stable_diffusion_xl_reference.py index 6439280cb185..34acd105fed4 100644 --- a/examples/community/stable_diffusion_xl_reference.py +++ b/examples/community/stable_diffusion_xl_reference.py @@ -1035,8 +1035,8 @@ def hacked_UpBlock2D_forward( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 6016917537b9..08f53f3b2fa4 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -1179,8 +1179,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index c6c4ce935a1f..234ccf28de00 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -1047,8 +1047,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1702,8 +1702,8 @@ def denoising_value_valid(dnv): elif denoising_end is not None and denoising_value_valid(denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 536c00ee361c..d6191856b8b3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1434,8 +1434,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py index 7012f3b95458..8f991fab8a22 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -1001,8 +1001,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1623,8 +1623,8 @@ def denoising_value_valid(dnv): elif denoising_end is not None and denoising_value_valid(denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index dcd885f7d604..4cda44f2b380 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -1338,8 +1338,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py b/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py index 47044e050acf..599e4fb29153 100644 --- a/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +++ b/src/diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py @@ -293,7 +293,7 @@ def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray: sample = torch.Tensor(sample).to(self.device) for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): - prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + prev_timestep = t - self.scheduler._schedule.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[t] alpha_prod_t_prev = ( self.scheduler.alphas_cumprod[prev_timestep] diff --git a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py index b8ac8e1416bf..b6c99db38b85 100644 --- a/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +++ b/src/diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py @@ -231,7 +231,7 @@ def __call__( output = self.decode( encodings_and_masks=encodings_and_masks, input_tokens=x, - noise_time=t / self.scheduler.config.num_train_timesteps, # rescale to [0, 1) + noise_time=t / self.scheduler._schedule.num_train_timesteps, # rescale to [0, 1) ) # Compute previous output: x_t -> x_t-1 diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 35c06d54a5db..8e4aa53fe768 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -825,10 +825,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index ac8474becb78..da1415eccde7 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -802,10 +802,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index a60dacd62d79..6fbeba20dd98 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -809,10 +809,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index b31935ac18f6..edbef4463a86 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -985,10 +985,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 4c2d2a0a3db9..de05dafbaa1d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -876,10 +876,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index e47fb615ffaa..7138121cc4a5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -864,10 +864,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index e1a25b6457aa..f3fae63ecc03 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -1017,10 +1017,10 @@ def __call__( ) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 723478ce724d..7272bec94eef 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -882,10 +882,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index ab9a07ae6b44..f311a263b851 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -748,10 +748,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index c7d1ef0f4bc1..b21655d63aa3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -877,10 +877,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/free_init_utils.py b/src/diffusers/pipelines/free_init_utils.py index 1fb67592ca4f..61b9f47f48cf 100644 --- a/src/diffusers/pipelines/free_init_utils.py +++ b/src/diffusers/pipelines/free_init_utils.py @@ -159,7 +159,7 @@ def _apply_free_init( temporal_stop_frequency=self._free_init_temporal_stop_frequency, ) - current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1 + current_diffuse_timestep = self.scheduler._schedule.num_train_timesteps - 1 diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long() z_t = self.scheduler.add_noise( diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py index ef5241fee5d2..ba6b727e6dea 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -420,7 +420,7 @@ def __call__( timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) # the formular to calculate timestep for add_noise is taken from the original kandinsky repo - latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2 + latent_timestep = int(self.scheduler._schedule.num_train_timesteps * strength) - 2 latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device) diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py index 1d2d07572d68..e607abc0eb65 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py @@ -958,8 +958,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 6ddda7acf2a8..eab46ab947db 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -576,8 +576,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1138,8 +1138,8 @@ def denoising_value_valid(dnv): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py index 834445bfcd06..a3b0ff3c809c 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py @@ -1098,8 +1098,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 96d41bb3224b..62c0e264f9fa 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -671,10 +671,10 @@ def __call__( sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) mu = calculate_shift( video_sequence_length, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index 71fd725c915b..3858200246a9 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -741,10 +741,10 @@ def __call__( sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) mu = calculate_shift( video_sequence_length, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 0a59d98919f0..6e2e221a305b 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -812,7 +812,7 @@ def __call__( current_timestep = current_timestep.expand(latent_model_input.shape[0]) # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image - current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps + current_timestep = 1 - current_timestep / self.scheduler._schedule.num_train_timesteps # prepare image_rotary_emb for positional encoding # dynamic scaling_factor for different resolution. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index 15a93357470f..214620cb1e75 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -1463,8 +1463,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py index 3e84f44adcf7..71ae8f399131 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py @@ -1010,8 +1010,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index c2611164a049..b9004726ccd1 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -1200,8 +1200,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 6d634d524848..14edc44d2435 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -672,8 +672,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1400,8 +1400,8 @@ def denoising_value_valid(dnv): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index 7f85c13ac561..422b66557c63 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -921,8 +921,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1609,8 +1609,8 @@ def denoising_value_valid(dnv): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index a53d786798ca..d151e04ac034 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1007,17 +1007,17 @@ def __call__( # 5. Prepare timesteps scheduler_kwargs = {} - if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + if self.scheduler._schedule.use_dynamic_shifting and mu is None: _, _, height, width = latents.shape image_seq_len = (height // self.transformer.config.patch_size) * ( width // self.transformer.config.patch_size ) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) scheduler_kwargs["mu"] = mu elif mu is not None: diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index bb124902fd6b..e0a4484b92d9 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -939,16 +939,16 @@ def __call__( # 4. Prepare timesteps scheduler_kwargs = {} - if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + if self.scheduler._schedule.use_dynamic_shifting and mu is None: image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( int(width) // self.vae_scale_factor // self.transformer.config.patch_size ) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) scheduler_kwargs["mu"] = mu elif mu is not None: diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index a9caa92a7bc1..4d87f48a6d4a 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -1049,16 +1049,16 @@ def __call__( # 3. Prepare timesteps scheduler_kwargs = {} - if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: + if self.scheduler._schedule.use_dynamic_shifting and mu is None: image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( int(width) // self.vae_scale_factor // self.transformer.config.patch_size ) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler._schedule.base_image_seq_len, + self.scheduler._schedule.max_image_seq_len, + self.scheduler._schedule.base_shift, + self.scheduler._schedule.max_shift, ) scheduler_kwargs["mu"] = mu elif mu is not None: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index d83fa6201117..084f6412f0f6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -1175,8 +1175,8 @@ def __call__( ): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 126f25a41adc..2716737d4ecf 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -661,8 +661,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1370,8 +1370,8 @@ def denoising_value_valid(dnv): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index a378ae65eb30..faaf35c053b0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -922,8 +922,8 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N # that is, strength is determined by the denoising_start instead. discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_start * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_start * self.scheduler._schedule.num_train_timesteps) ) ) @@ -1585,8 +1585,8 @@ def denoising_value_valid(dnv): elif self.denoising_end is not None and denoising_value_valid(self.denoising_end): discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (self.denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (self.denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b59f2312726d..d07005f29b90 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -882,8 +882,8 @@ def __call__( if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 20569d0adb32..ef10f5ff1499 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -1206,8 +1206,8 @@ def __call__( if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py index 9ff473cc3a38..d6595b1767fc 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py @@ -1274,8 +1274,8 @@ def __call__( if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: discrete_timestep_cutoff = int( round( - self.scheduler.config.num_train_timesteps - - (denoising_end * self.scheduler.config.num_train_timesteps) + self.scheduler._schedule.num_train_timesteps + - (denoising_end * self.scheduler._schedule.num_train_timesteps) ) ) num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py index 4f65caf4e610..08aa10b82c5e 100644 --- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -1331,7 +1331,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # max_timestep = timesteps[0] - max_timestep = self.scheduler.config.num_train_timesteps + max_timestep = self.scheduler._schedule.num_train_timesteps # 6. Prepare latent variables if mode == "joint": From 459a0cb94d1fcbd8744614f717fced1f14b01f2a Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:51:32 +0000 Subject: [PATCH 14/20] override FlowMatch with pipeline from_pretrained --- src/diffusers/configuration_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 8184d6247661..7593bcfa5708 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -485,6 +485,12 @@ def extract_init_dict(cls, config_dict, **kwargs): # Skip keys that were not present in the original config, so default __init__ values were used used_defaults = config_dict.get("_use_default_values", []) config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"} + if ( + "scheduler" in config_dict + and isinstance(config_dict["scheduler"], list) + and config_dict["scheduler"][1].startswith("FlowMatch") + ): + config_dict["scheduler"][1] = config_dict["scheduler"][1].replace("FlowMatch", "") # 0. Copy origin config dict original_dict = dict(config_dict.items()) @@ -522,6 +528,8 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove attributes from orig class that cannot be expected orig_cls_name = config_dict.pop("_class_name", cls.__name__) + if orig_cls_name.startswith("FlowMatch"): + orig_cls_name = orig_cls_name.replace("FlowMatch", "") if ( isinstance(orig_cls_name, str) and orig_cls_name != cls.__name__ From d9ad3f89d0233c0b234db158379698177947092e Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:51:42 +0000 Subject: [PATCH 15/20] set default flow base schedule --- src/diffusers/schedulers/schedules/flow_schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/schedules/flow_schedule.py b/src/diffusers/schedulers/schedules/flow_schedule.py index ba96755168b2..262dd1a3a2f9 100644 --- a/src/diffusers/schedulers/schedules/flow_schedule.py +++ b/src/diffusers/schedulers/schedules/flow_schedule.py @@ -124,7 +124,7 @@ def __init__( def set_base_schedule(self, base_schedule: Union[str]): if base_schedule is None: - raise ValueError("Must set base schedule.") + base_schedule = self.base_schedules["FlowMatchSD3"] if isinstance(base_schedule, str): if base_schedule not in self.base_schedules: raise ValueError(f"Expected one of {self.base_schedules.keys()}") From 730931af2d987b522f2435d2e24b25b6bd4034dc Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 10:57:13 +0000 Subject: [PATCH 16/20] prediction_type --- src/diffusers/configuration_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 7593bcfa5708..c0f11c132883 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -247,12 +247,12 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un # Handle old scheduler configs if "Scheduler" in cls.__name__ and "schedule_config" not in config: - prediction_type = config.pop("prediction_type", None) _class_name = config.pop("_class_name", None) _diffusers_version = config.pop("_diffusers_version", None) use_karras_sigmas = config.pop("use_karras_sigmas", None) use_exponential_sigmas = config.pop("use_exponential_sigmas", None) use_beta_sigmas = config.pop("use_beta_sigmas", None) + prediction_type = config.pop("prediction_type", None) if use_karras_sigmas: sigma_schedule_config = {"class_name": "KarrasSigmas"} elif use_exponential_sigmas: @@ -265,10 +265,11 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un config.update({"class_name": "BetaSchedule"}) elif "shift" in config: config.update({"class_name": "FlowMatchSchedule"}) + if prediction_type: + config.update({"prediction_type": prediction_type}) config = { "_class_name": _class_name, "_diffusers_version": _diffusers_version, - "prediction_type": prediction_type, "schedule_config": config, "sigma_schedule_config": sigma_schedule_config, } From 6928f03489aff831cd683776b24ba5a0ff463aac Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 11:02:25 +0000 Subject: [PATCH 17/20] base_schedule --- src/diffusers/pipelines/flux/pipeline_flux.py | 2 +- src/diffusers/pipelines/mochi/pipeline_mochi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 8e4aa53fe768..0f4a6669417a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -820,7 +820,7 @@ def __call__( ) # 5. Prepare timesteps - if self.scheduler.schedule.__class__.__name__ != "FlowMatchFlux": + if self.scheduler._schedule.base_schedule.__class__.__name__ != "FlowMatchFlux": self.scheduler._schedule.set_base_schedule("FlowMatchFlux") image_seq_len = latents.shape[1] mu = calculate_shift( diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index 72a43827c9bf..8ba1135f03a0 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -652,7 +652,7 @@ def __call__( prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) # 5. Prepare timestep - if self.scheduler.schedule.__class__.__name__ != "FlowMatchLinearQuadratic": + if self.scheduler._schedule.base_schedule.__class__.__name__ != "FlowMatchLinearQuadratic": self.scheduler._schedule.set_base_schedule("FlowMatchLinearQuadratic") timesteps, num_inference_steps = retrieve_timesteps( From 647658b7ce5a5838ef87be88e1c40d5075a8b0f6 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 2 Jan 2025 11:18:56 +0000 Subject: [PATCH 18/20] clip_sample --- src/diffusers/schedulers/schedules/beta_schedule.py | 4 ---- src/diffusers/schedulers/scheduling_heun_discrete.py | 6 ++++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/schedules/beta_schedule.py b/src/diffusers/schedulers/schedules/beta_schedule.py index 86b8d14f601b..78bf58c50dc8 100644 --- a/src/diffusers/schedulers/schedules/beta_schedule.py +++ b/src/diffusers/schedulers/schedules/beta_schedule.py @@ -104,8 +104,6 @@ def __init__( timestep_spacing: str = "linspace", timestep_type: str = "discrete", # can be "discrete" or "continuous" steps_offset: int = 0, - clip_sample: Optional[bool] = False, - clip_sample_range: float = 1.0, sigma_min: Optional[float] = None, sigma_max: Optional[float] = None, final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" @@ -145,8 +143,6 @@ def __init__( self.timestep_spacing = timestep_spacing self.timestep_type = timestep_type self.steps_offset = steps_offset - self.clip_sample = clip_sample - self.clip_sample_range = clip_sample_range self.sigma_min = sigma_min self.sigma_max = sigma_max self.final_sigmas_type = final_sigmas_type diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 5f4280f1f6e9..aabf6dfda8a8 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -85,6 +85,8 @@ def __init__( schedule_config, sigma_schedule_config, prediction_type: str = "epsilon", + clip_sample: Optional[bool] = False, + clip_sample_range: float = 1.0, ): self.set_schedule(schedule_config) self.set_sigma_schedule(sigma_schedule_config) @@ -264,9 +266,9 @@ def step( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" ) - if self._schedule.clip_sample: + if self.config.clip_sample: pred_original_sample = pred_original_sample.clamp( - -self._schedule.clip_sample_range, self._schedule.clip_sample_range + -self.config.clip_sample_range, self.config.clip_sample_range ) if self.state_in_first_order: From a4453edbfd8549b9b09ce64cf3102bf380180fbc Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 13 Jan 2025 10:38:22 +0000 Subject: [PATCH 19/20] deis --- src/diffusers/configuration_utils.py | 32 +- .../schedulers/schedules/beta_schedule.py | 4 + .../schedulers/scheduling_deis_multistep.py | 436 ++++-------------- src/diffusers/schedulers/scheduling_utils.py | 34 ++ tests/schedulers/test_scheduler_deis.py | 17 +- tests/schedulers/test_schedulers.py | 3 +- 6 files changed, 160 insertions(+), 366 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index c0f11c132883..5d60812abaf9 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -252,7 +252,9 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un use_karras_sigmas = config.pop("use_karras_sigmas", None) use_exponential_sigmas = config.pop("use_exponential_sigmas", None) use_beta_sigmas = config.pop("use_beta_sigmas", None) + use_flow_sigmas = config.pop("use_flow_sigmas", None) prediction_type = config.pop("prediction_type", None) + schedule_config = {} if use_karras_sigmas: sigma_schedule_config = {"class_name": "KarrasSigmas"} elif use_exponential_sigmas: @@ -263,16 +265,32 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un sigma_schedule_config = {} if "beta_schedule" in config: config.update({"class_name": "BetaSchedule"}) - elif "shift" in config: + from .schedulers.schedules.beta_schedule import BetaSchedule + + expected_kwargs = list(inspect.signature(BetaSchedule.__init__).parameters)[1:-1] + for expected_kwarg in expected_kwargs: + if expected_kwarg in config: + schedule_config[expected_kwarg] = config.pop(expected_kwarg) + elif "shift" in config or use_flow_sigmas: config.update({"class_name": "FlowMatchSchedule"}) + from .schedulers.schedules.flow_schedule import FlowMatchSchedule + + expected_kwargs = list(inspect.signature(FlowMatchSchedule.__init__).parameters)[1:-1] + for expected_kwarg in expected_kwargs: + if expected_kwarg in config: + schedule_config[expected_kwarg] = config.pop(expected_kwarg) + if prediction_type == "flow_prediction": + prediction_type = "epsilon" if prediction_type: config.update({"prediction_type": prediction_type}) - config = { - "_class_name": _class_name, - "_diffusers_version": _diffusers_version, - "schedule_config": config, - "sigma_schedule_config": sigma_schedule_config, - } + config.update( + { + "_class_name": _class_name, + "_diffusers_version": _diffusers_version, + "schedule_config": schedule_config, + "sigma_schedule_config": sigma_schedule_config, + } + ) init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) diff --git a/src/diffusers/schedulers/schedules/beta_schedule.py b/src/diffusers/schedulers/schedules/beta_schedule.py index 78bf58c50dc8..9ba456308be3 100644 --- a/src/diffusers/schedulers/schedules/beta_schedule.py +++ b/src/diffusers/schedulers/schedules/beta_schedule.py @@ -133,6 +133,10 @@ def __init__( # FP16 smallest positive subnormal works well here self.alphas_cumprod[-1] = 2**-24 + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.num_train_timesteps = num_train_timesteps self.beta_start = beta_start self.beta_end = beta_end diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 6a653f183bba..3a2fdff25f51 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -15,67 +15,17 @@ # DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info # The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py -import math from typing import List, Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import deprecate, is_scipy_available -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +from ..utils import deprecate +from .scheduling_utils import KarrasDiffusionSchedulers, SamplingMixin, SchedulerMixin, SchedulerOutput -if is_scipy_available(): - import scipy.stats - - -# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar -def betas_for_alpha_bar( - num_diffusion_timesteps, - max_beta=0.999, - alpha_transform_type="cosine", -): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. - Choose from `cosine` or `exp` - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - if alpha_transform_type == "cosine": - - def alpha_bar_fn(t): - return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 - - elif alpha_transform_type == "exp": - - def alpha_bar_fn(t): - return math.exp(t * -12.0) - - else: - raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): +class DEISMultistepScheduler(SchedulerMixin, ConfigMixin, SamplingMixin): """ `DEISMultistepScheduler` is a fast high order solver for diffusion ordinary differential equations (ODEs). @@ -133,11 +83,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__( self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, + schedule_config, + sigma_schedule_config, solver_order: int = 2, prediction_type: str = "epsilon", thresholding: bool = False, @@ -146,43 +93,9 @@ def __init__( algorithm_type: str = "deis", solver_type: str = "logrho", lower_order_final: bool = True, - use_karras_sigmas: Optional[bool] = False, - use_exponential_sigmas: Optional[bool] = False, - use_beta_sigmas: Optional[bool] = False, - use_flow_sigmas: Optional[bool] = False, - flow_shift: Optional[float] = 1.0, - timestep_spacing: str = "linspace", - steps_offset: int = 0, ): - if self.config.use_beta_sigmas and not is_scipy_available(): - raise ImportError("Make sure to install scipy if you want to use beta sigmas.") - if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: - raise ValueError( - "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." - ) - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 + self.set_schedule(schedule_config) + self.set_sigma_schedule(sigma_schedule_config) # settings for DEIS if algorithm_type not in ["deis"]: @@ -197,42 +110,37 @@ def __init__( else: raise NotImplementedError(f"solver type {solver_type} is not implemented for {self.__class__}") - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) self.model_outputs = [None] * solver_order self.lower_order_nums = 0 - self._step_index = None - self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index + def init_noise_sigma(self): + return 1.0 - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Args: - begin_index (`int`): - The begin index for the scheduler. + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. """ - self._begin_index = begin_index + return sample - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + shift: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -242,222 +150,74 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ - # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(self.config.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64) - timesteps -= 1 - else: + + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` should be set.") + if num_inference_steps is None and timesteps is None and sigmas is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps` or `sigmas.") + if num_inference_steps is not None and (timesteps is not None or sigmas is not None): + raise ValueError("Can only pass one of `num_inference_steps` or `timesteps` or `sigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "KarrasSigmas" + ): + raise ValueError("Cannot set `timesteps` with `KarrasSigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "ExponentialSigmas" + ): + raise ValueError("Cannot set `timesteps` with `ExponentialSigmas`.") + if ( + timesteps is not None + and self._sigma_schedule is not None + and self._sigma_schedule.__class__.__name__ == "BetaSigmas" + ): + raise ValueError("Cannot set `timesteps` with `BetaSigmas`.") + if ( + timesteps is not None + and self._schedule.timestep_type == "continuous" + and self.config.prediction_type == "v_prediction" + ): raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + "Cannot set `timesteps` with `schedule.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`." ) - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - log_sigmas = np.log(sigmas) - if self.config.use_karras_sigmas: - sigmas = np.flip(sigmas).copy() - sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - elif self.config.use_exponential_sigmas: - sigmas = np.flip(sigmas).copy() - sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - elif self.config.use_beta_sigmas: - sigmas = np.flip(sigmas).copy() - sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - elif self.config.use_flow_sigmas: - alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) - sigmas = 1.0 - alphas - sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() - timesteps = (sigmas * self.config.num_train_timesteps).copy() - sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) - else: - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) - - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) + if num_inference_steps is None: + num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1 + self.num_inference_steps = num_inference_steps + + sigmas, timesteps = self._schedule( + num_inference_steps=num_inference_steps, + device=device, + timesteps=timesteps, + sigmas=sigmas, + sigma_schedule=self._sigma_schedule, + mu=mu, + shift=shift, + ) + self._step_index = None + self._begin_index = None + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas.to("cpu") self.model_outputs = [ None, ] * self.config.solver_order self.lower_order_nums = 0 - # add an index counter for schedulers that allow duplicated timesteps - self._step_index = None - self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma, log_sigmas): - # get log sigma - log_sigma = np.log(np.maximum(sigma, 1e-10)) - - # get distribution - dists = log_sigma - log_sigmas[:, np.newaxis] - - # get sigmas range - low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) - high_idx = low_idx + 1 - - low = log_sigmas[low_idx] - high = log_sigmas[high_idx] - - # interpolate sigmas - w = (low - log_sigma) / (low - high) - w = np.clip(w, 0, 1) - - # transform interpolation to time range - t = (1 - w) * low_idx + w * high_idx - t = t.reshape(sigma.shape) - return t - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): - if self.config.use_flow_sigmas: + if self._schedule.__class__.__name__ == "FlowMatchSchedule": alpha_t = 1 - sigma sigma_t = sigma - else: + elif self._schedule.__class__.__name__ == "BetaSchedule": alpha_t = 1 / ((sigma**2 + 1) ** 0.5) sigma_t = sigma * alpha_t - - return alpha_t, sigma_t - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras - def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - rho = 7.0 # 7.0 is the value used in the paper - ramp = np.linspace(0, 1, num_inference_steps) - min_inv_rho = sigma_min ** (1 / rho) - max_inv_rho = sigma_max ** (1 / rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return sigmas - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential - def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: - """Constructs an exponential noise schedule.""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min - else: - sigma_min = None - - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) - return sigmas - - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta - def _convert_to_beta( - self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 - ) -> torch.Tensor: - """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" - - # Hack to make sure that other schedulers which copy this function don't break - # TODO: Add this logic to the other schedulers - if hasattr(self.config, "sigma_min"): - sigma_min = self.config.sigma_min else: - sigma_min = None + raise ValueError("..") - if hasattr(self.config, "sigma_max"): - sigma_max = self.config.sigma_max - else: - sigma_max = None - - sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() - sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() - - sigmas = np.array( - [ - sigma_min + (ppf * (sigma_max - sigma_min)) - for ppf in [ - scipy.stats.beta.ppf(timestep, alpha, beta) - for timestep in 1 - np.linspace(0, 1, num_inference_steps) - ] - ] - ) - return sigmas + return alpha_t, sigma_t def convert_model_output( self, @@ -502,13 +262,10 @@ def convert_model_output( x0_pred = model_output elif self.config.prediction_type == "v_prediction": x0_pred = alpha_t * sample - sigma_t * model_output - elif self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " - "`v_prediction`, or `flow_prediction` for the DEISMultistepScheduler." + "`v_prediction` for the DEISMultistepScheduler." ) if self.config.thresholding: @@ -754,19 +511,6 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): return step_index - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): - """ - Initialize the step_index counter for the scheduler. - """ - - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - def step( self, model_output: torch.Tensor, @@ -832,28 +576,14 @@ def step( return SchedulerOutput(prev_sample=prev_sample) - def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.Tensor`): - The input sample. - - Returns: - `torch.Tensor`: - A scaled input sample. - """ - return sample - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + # TODO: integrate with add_noise in SamplingMixin def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + print("add_noise") # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 712237e83d35..3d9c0fff602a 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -17,6 +17,7 @@ from enum import Enum from typing import Dict, Optional, Union +import numpy as np import torch from huggingface_hub.utils import validate_hf_hub_args @@ -376,5 +377,38 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + def __len__(self): return self.config.num_train_timesteps diff --git a/tests/schedulers/test_scheduler_deis.py b/tests/schedulers/test_scheduler_deis.py index 986a8f6a44cf..da0c9c057612 100644 --- a/tests/schedulers/test_scheduler_deis.py +++ b/tests/schedulers/test_scheduler_deis.py @@ -18,10 +18,15 @@ class DEISMultistepSchedulerTest(SchedulerCommonTest): def get_scheduler_config(self, **kwargs): config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", + "schedule_config": { + "class_name": "BetaSchedule", + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "final_sigma_type": "sigma_min", + }, + "sigma_schedule_config": {}, "solver_order": 2, } @@ -161,7 +166,9 @@ def test_switch(self): def test_timesteps(self): for timesteps in [25, 50, 100, 999, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) + scheduler_config = self.get_scheduler_config() + scheduler_config["schedule_config"].update({"num_train_timesteps": timesteps}) + self.check_over_configs(**scheduler_config) def test_thresholding(self): self.check_over_configs(thresholding=False) diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py index fc7f22d2a8e5..9a61d1e61960 100755 --- a/tests/schedulers/test_schedulers.py +++ b/tests/schedulers/test_schedulers.py @@ -754,7 +754,8 @@ def test_trained_betas(self): continue scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.1, 0.3])) + scheduler_config["schedule_config"].update({"trained_betas": np.array([0.1, 0.3])}) + scheduler = scheduler_class(**scheduler_config) with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_pretrained(tmpdirname) From 3b4722967da46592342f1a350de227c19392e7b8 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 13 Jan 2025 11:00:24 +0000 Subject: [PATCH 20/20] deis --- src/diffusers/schedulers/schedules/beta_schedule.py | 10 +++++++--- src/diffusers/schedulers/scheduling_deis_multistep.py | 4 ++-- tests/schedulers/test_scheduler_deis.py | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/schedules/beta_schedule.py b/src/diffusers/schedulers/schedules/beta_schedule.py index 9ba456308be3..be5b15fc0b40 100644 --- a/src/diffusers/schedulers/schedules/beta_schedule.py +++ b/src/diffusers/schedulers/schedules/beta_schedule.py @@ -195,9 +195,13 @@ def __call__( else: # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.timestep_spacing == "linspace": - timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[ - ::-1 - ].copy() + # TODO: check this + timesteps = ( + np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) elif self.timestep_spacing == "leading": step_ratio = self.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 3a2fdff25f51..77eae2776fe7 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -215,7 +215,7 @@ def _sigma_to_alpha_sigma_t(self, sigma): alpha_t = 1 / ((sigma**2 + 1) ** 0.5) sigma_t = sigma * alpha_t else: - raise ValueError("..") + raise ValueError("Unsupported schedule type.") return alpha_t, sigma_t @@ -491,7 +491,7 @@ def ind_fn(t, b, c, d): else: raise NotImplementedError("only support log-rho multistep deis now") - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + # TODO: integrate with SamplingMixin def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps diff --git a/tests/schedulers/test_scheduler_deis.py b/tests/schedulers/test_scheduler_deis.py index da0c9c057612..8c57721c6d20 100644 --- a/tests/schedulers/test_scheduler_deis.py +++ b/tests/schedulers/test_scheduler_deis.py @@ -24,7 +24,7 @@ def get_scheduler_config(self, **kwargs): "beta_start": 0.0001, "beta_end": 0.02, "beta_schedule": "linear", - "final_sigma_type": "sigma_min", + "final_sigmas_type": "sigma_min", }, "sigma_schedule_config": {}, "solver_order": 2,