diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index aded6c224671..651532b06ddb 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -79,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): methods the library implements for all schedulers such as loading and saving. Args: - num_train_timesteps (`int`, defaults to 1000): + num_train_timesteps (`int`, defaults to `1000`): The number of diffusion steps to train the model. - beta_start (`float`, defaults to 0.0001): + beta_start (`float`, defaults to `0.0001`): The starting `beta` value of inference. - beta_end (`float`, defaults to 0.02): + beta_end (`float`, defaults to `0.02`): The final `beta` value. - beta_schedule (`str`, defaults to `"linear"`): - The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. trained_betas (`np.ndarray`, *optional*): Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. skip_prk_steps (`bool`, defaults to `False`): @@ -97,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): Each diffusion step uses the alphas product value at that step and at the previous one. For the final step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, otherwise it uses the alpha value at step 0. - prediction_type (`str`, defaults to `epsilon`, *optional*): + prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process) - or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf) - paper). - timestep_spacing (`str`, defaults to `"leading"`): + or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper). + timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`): The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): + steps_offset (`int`, defaults to `0`): An offset added to the inference steps, as required by some model families. """ @@ -117,12 +115,12 @@ def __init__( num_train_timesteps: int = 1000, beta_start: float = 0.0001, beta_end: float = 0.02, - beta_schedule: str = "linear", + beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, skip_prk_steps: bool = False, set_alpha_to_one: bool = False, - prediction_type: str = "epsilon", - timestep_spacing: str = "leading", + prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", + timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading", steps_offset: int = 0, ): if trained_betas is not None: @@ -164,7 +162,7 @@ def __init__( self.plms_timesteps = None self.timesteps = None - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -243,7 +241,7 @@ def step( The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. Returns: @@ -276,14 +274,13 @@ def step_prk( The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. - """ if self.num_inference_steps is None: raise ValueError( @@ -335,14 +332,13 @@ def step_plms( The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - return_dict (`bool`): + return_dict (`bool`, defaults to `True`): Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. Returns: [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. - """ if self.num_inference_steps is None: raise ValueError( @@ -403,19 +399,27 @@ def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tens """ return sample - def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): - # See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778 - # this function computes x_(t−δ) using the formula of (9) - # Note that x_t needs to be added to both sides of the equation - - # Notation ( -> - # alpha_prod_t -> α_t - # alpha_prod_t_prev -> α_(t−δ) - # beta_prod_t -> (1 - α_t) - # beta_prod_t_prev -> (1 - α_(t−δ)) - # sample -> x_t - # model_output -> e_θ(x_t, t) - # prev_sample -> x_(t−δ) + def _get_prev_sample( + self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor + ) -> torch.Tensor: + """ + Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM + paper](https://huggingface.co/papers/2202.09778). + + Args: + sample (`torch.Tensor`): + The current sample x_t. + timestep (`int`): + The current timestep t. + prev_timestep (`int`): + The previous timestep (t-δ). + model_output (`torch.Tensor`): + The model output e_θ(x_t, t). + + Returns: + `torch.Tensor`: + The previous sample x_(t-δ). + """ alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t @@ -489,5 +493,5 @@ def add_noise( noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps