diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
index 7b11d704932b..8d50ee6c7ea9 100644
--- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -429,7 +429,22 @@ def multistep_dpm_solver_second_order_update(
         return x_t
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
 
@@ -452,6 +467,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
         """
 
         if self.begin_index is None:
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index bf8e1d98d6c0..45d11c942660 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -401,6 +401,17 @@ def _sigma_to_t(self, sigma, log_sigmas):
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
         if self.config.use_flow_sigmas:
             alpha_t = 1 - sigma
             sigma_t = sigma
@@ -808,7 +819,22 @@ def ind_fn(t, b, c, d):
             raise NotImplementedError("only support log-rho multistep deis now")
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
 
@@ -831,6 +857,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
     def _init_step_index(self, timestep):
         """
         Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
         """
 
         if self.begin_index is None:
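The same `index_for_timestep` docstring is copied into several schedulers in this diff. For orientation, here is a minimal standalone sketch of the lookup it describes; this is simplified from the library code, so the fallback behavior shown is illustrative rather than authoritative:

```python
import torch

# Minimal sketch of the lookup described above (not the library source itself):
# find where `timestep` occurs in a 1-D schedule tensor.
def index_for_timestep(timestep, schedule_timesteps):
    index_candidates = (schedule_timesteps == timestep).nonzero()
    if len(index_candidates) == 0:
        # Timestep not in the schedule: fall back to the final step.
        return len(schedule_timesteps) - 1
    if len(index_candidates) > 1:
        # Duplicated timesteps can occur at the start of a schedule; taking the
        # second match avoids skipping a sigma on the very first step.
        return index_candidates[1].item()
    return index_candidates[0].item()

schedule = torch.tensor([999, 749, 499, 249])
print(index_for_timestep(torch.tensor(499), schedule))  # 2
```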
""" if self.begin_index is None: @@ -927,6 +957,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples without noise. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps at which to add noise to the samples. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index dee97f39ff68..e7ba0ba1f30e 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): The starting `beta` value of inference. beta_end (`float`, defaults to 0.02): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. solver_order (`int`, defaults to 2): The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. - prediction_type (`str`, defaults to `epsilon`, *optional*): - Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`. + prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`): + Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample` + directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen + Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow. thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. @@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): sample_max_value (`float`, defaults to 1.0): The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. 
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index dee97f39ff68..e7ba0ba1f30e 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             The starting `beta` value of inference.
         beta_end (`float`, defaults to 0.02):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         solver_order (`int`, defaults to 2):
             The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
-            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
-            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
+        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
+            Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
+            directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of the [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
@@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         sample_max_value (`float`, defaults to 1.0):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
-        algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
-            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
-            paper, and the `dpmsolver++` type implements the algorithms in the
-            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
-            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
-        solver_type (`str`, defaults to `midpoint`):
-            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
-            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+        algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
+            Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
+            [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
+            algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
+            `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+        solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
+            Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
+            for a small number of steps. It is recommended to use `midpoint` solvers.
         lower_order_final (`bool`, defaults to `True`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -179,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
         flow_shift (`float`, *optional*, defaults to 1.0):
             The shift value for the timestep schedule for flow matching.
-        final_sigmas_type (`str`, defaults to `"zero"`):
+        final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
-        variance_type (`str`, *optional*):
-            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
-            contains the predicted Gaussian variance.
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        variance_type (`"learned"` or `"learned_range"`, *optional*):
+            Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
+            output contains the predicted Gaussian variance.
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         steps_offset (`int`, defaults to 0):
@@ -197,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to use dynamic shifting for the timestep schedule.
+        time_shift_type (`"exponential"`, defaults to `"exponential"`):
+            The type of time shift to apply when using dynamic shifting. Only `"exponential"` is currently supported.
     """
 
     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
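Since the config arguments are now `Literal`-typed, a constructor call spelling out the documented options may help; the combination below is just one sensible choice, not a recommended default:

```python
from diffusers import DPMSolverMultistepScheduler

# Explicitly spell out the Literal-typed config arguments documented above.
scheduler = DPMSolverMultistepScheduler(
    beta_schedule="scaled_linear",
    prediction_type="epsilon",
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    solver_order=2,
    timestep_spacing="linspace",
    final_sigmas_type="zero",
)
print(scheduler.config.algorithm_type)  # dpmsolver++
```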
@@ -208,15 +210,15 @@ def __init__(
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
-        algorithm_type: str = "dpmsolver++",
-        solver_type: str = "midpoint",
+        algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
+        solver_type: Literal["midpoint", "heun"] = "midpoint",
         lower_order_final: bool = True,
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
@@ -225,14 +227,14 @@ def __init__(
         use_lu_lambdas: Optional[bool] = False,
         use_flow_sigmas: Optional[bool] = False,
         flow_shift: Optional[float] = 1.0,
-        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+        final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
         lambda_min_clipped: float = -float("inf"),
-        variance_type: Optional[str] = None,
-        timestep_spacing: str = "linspace",
+        variance_type: Optional[Literal["learned", "learned_range"]] = None,
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
         use_dynamic_shifting: bool = False,
-        time_shift_type: str = "exponential",
+        time_shift_type: Literal["exponential"] = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -331,19 +333,22 @@ def set_begin_index(self, begin_index: int = 0):
 
     def set_timesteps(
         self,
-        num_inference_steps: int = None,
-        device: Union[str, torch.device] = None,
+        num_inference_steps: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
         mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            mu (`float`, *optional*):
+                The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
+                `time_shift_type="exponential"`.
             timesteps (`List[int]`, *optional*):
                 Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                 based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                 must be `None`.
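The new `mu` documentation ties into dynamic shifting. As a rough standalone sketch of what an `"exponential"` time shift computes; the formula is assumed from the flow-matching utilities elsewhere in the codebase, so treat it as illustrative:

```python
import math

import torch

# Hypothetical standalone rendering of the exponential time shift: `mu` is
# model-dependent (pipelines derive it from resolution) and 1.15 is arbitrary.
def exponential_time_shift(mu: float, sigmas: torch.Tensor) -> torch.Tensor:
    return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))

sigmas = torch.linspace(1.0, 1e-3, 8)
print(exponential_time_shift(1.15, sigmas))
```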
@@ -503,7 +508,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
         return sample
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
 
@@ -539,7 +544,18 @@ def _sigma_to_t(self, sigma, log_sigmas):
             t = t.reshape(sigma.shape)
         return t
 
-    def _sigma_to_alpha_sigma_t(self, sigma):
+    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
         if self.config.use_flow_sigmas:
             alpha_t = 1 - sigma
             sigma_t = sigma
@@ -588,8 +604,21 @@ def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> to
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
-    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
-        """Constructs the noise schedule of Lu et al. (2022)."""
+    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
+        """
+        Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
+        Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).
+
+        Args:
+            in_lambdas (`torch.Tensor`):
+                The input lambda values to be converted.
+            num_inference_steps (`int`):
+                The number of inference steps to generate the noise schedule for.
+
+        Returns:
+            `torch.Tensor`:
+                The converted lambda values following the Lu noise schedule.
+        """
         lambda_min: float = in_lambdas[-1].item()
         lambda_max: float = in_lambdas[0].item()
 
@@ -1069,7 +1098,22 @@ def multistep_dpm_solver_third_order_update(
         )
         return x_t
 
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
 
@@ -1088,9 +1132,13 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
 
         return step_index
 
-    def _init_step_index(self, timestep):
+    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
         """
         Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
         """
 
         if self.begin_index is None:
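A standalone sketch of the conversion the `_sigma_to_alpha_sigma_t` docstring above describes, assuming the two branches behave as in the surrounding code (flow sigmas versus the default variance-preserving mapping):

```python
import torch

# For flow sigmas, alpha_t + sigma_t = 1; otherwise sigma is the ratio
# sigma_t / alpha_t, so alpha_t**2 + sigma_t**2 = 1.
def sigma_to_alpha_sigma_t(sigma: torch.Tensor, use_flow_sigmas: bool):
    if use_flow_sigmas:
        alpha_t = 1 - sigma
        sigma_t = sigma
    else:
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t
    return alpha_t, sigma_t

alpha_t, sigma_t = sigma_to_alpha_sigma_t(torch.tensor(1.0), use_flow_sigmas=False)
print(alpha_t**2 + sigma_t**2)  # tensor(1.0000) up to float error
```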
@@ -1105,7 +1153,7 @@ def step(
         model_output: torch.Tensor,
         timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         variance_noise: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
@@ -1115,22 +1163,22 @@ def step(
 
         Args:
             model_output (`torch.Tensor`):
-                The direct output from learned diffusion model.
-            timestep (`int`):
+                The direct output from the learned diffusion model.
+            timestep (`int` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            variance_noise (`torch.Tensor`):
+            variance_noise (`torch.Tensor`, *optional*):
                 Alternative to generating noise with `generator` by directly providing the noise for the variance
                 itself. Useful for methods such as [`LEdits++`].
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
+                If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
         """
 
@@ -1210,6 +1258,21 @@ def add_noise(
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples without noise.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps at which to add noise to the samples.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
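With `step`'s parameters now fully documented, a minimal denoising loop shows how they fit together. The zero `model_output` is a stand-in for a real UNet/transformer call, and the shape is arbitrary:

```python
import torch
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=20)

sample = torch.randn(1, 4, 64, 64)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # placeholder epsilon prediction
    sample = scheduler.step(model_output, t, sample, return_dict=True).prev_sample
print(sample.shape)  # torch.Size([1, 4, 64, 64])
```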
+ """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma @@ -1079,7 +1090,22 @@ def singlestep_dpm_solver_update( raise ValueError(f"Order must be 1, 2, 3, got {order}") # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -1102,6 +1128,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. """ if self.begin_index is None: @@ -1204,6 +1234,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples without noise. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps at which to add noise to the samples. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index eeec588e27a3..5b1e84dc3a25 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -578,7 +578,22 @@ def multistep_dpm_solver_third_order_update( return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -601,6 +616,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. 
""" if self.begin_index is None: diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index d9054c39c9de..9eb37c44aea9 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -423,6 +423,17 @@ def _sigma_to_t(self, sigma, log_sigmas): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): + """ + Convert sigma values to alpha_t and sigma_t values. + + Args: + sigma (`torch.Tensor`): + The sigma value(s) to convert. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing (alpha_t, sigma_t) values. + """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma @@ -1103,7 +1114,22 @@ def stochastic_adams_moulton_update( return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. + """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -1126,6 +1152,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. """ if self.begin_index is None: diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 7dc5f467680b..606dfeb239be 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -513,6 +513,17 @@ def _sigma_to_t(self, sigma, log_sigmas): # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t def _sigma_to_alpha_sigma_t(self, sigma): + """ + Convert sigma values to alpha_t and sigma_t values. + + Args: + sigma (`torch.Tensor`): + The sigma value(s) to convert. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: + A tuple containing (alpha_t, sigma_t) values. + """ if self.config.use_flow_sigmas: alpha_t = 1 - sigma sigma_t = sigma @@ -984,7 +995,22 @@ def multistep_uni_c_bh_update( return x_t # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): + def index_for_timestep( + self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None + ) -> int: + """ + Find the index for a given timestep in the schedule. + + Args: + timestep (`int` or `torch.Tensor`): + The timestep for which to find the index. + schedule_timesteps (`torch.Tensor`, *optional*): + The timestep schedule to search in. If `None`, uses `self.timesteps`. + + Returns: + `int`: + The index of the timestep in the schedule. 
+ """ if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -1007,6 +1033,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): def _init_step_index(self, timestep): """ Initialize the step_index counter for the scheduler. + + Args: + timestep (`int` or `torch.Tensor`): + The current timestep for which to initialize the step index. """ if self.begin_index is None: @@ -1119,6 +1149,21 @@ def add_noise( noise: torch.Tensor, timesteps: torch.IntTensor, ) -> torch.Tensor: + """ + Add noise to the original samples according to the noise schedule at the specified timesteps. + + Args: + original_samples (`torch.Tensor`): + The original samples without noise. + noise (`torch.Tensor`): + The noise to add to the samples. + timesteps (`torch.IntTensor`): + The timesteps at which to add noise to the samples. + + Returns: + `torch.Tensor`: + The noisy samples. + """ # Make sure sigmas and timesteps have the same device and dtype as original_samples sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):