diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
index 65902678e1d9..5659d364d559 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py
@@ -15,6 +15,7 @@
 from dataclasses import dataclass
 
 import flax
+import jax
 import jax.numpy as jnp
 from scipy import integrate
 
@@ -158,14 +159,30 @@ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarra
         sample = sample / ((sigma**2 + 1) ** 0.5)
         return sample
 
-    def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
-        """
-        Compute a linear multistep coefficient.
+    def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order: int, t: int, current_order: int) -> float:
+        r"""
+        Linear multistep (LMS) coefficient for the interval between `\sigma_t` and `\sigma_{t+1}`.
+
+        The coefficient is the definite integral of the Lagrange basis polynomial used in the LMS update (see the
+        k-diffusion reference in the class docstring). The implementation is equivalent to the PyTorch
+        `LMSDiscreteScheduler.get_lms_coefficient` in `diffusers.schedulers.scheduling_lms_discrete`.
+
+        Parameters `order`, `t`, and `current_order` are the same as in the PyTorch scheduler. Here `t` is the
+        **inference step index** into `state.sigmas` (not the training timestep label passed to the UNet).
 
         Args:
-            order (TODO):
-            t (TODO):
-            current_order (TODO):
+            state (`LMSDiscreteSchedulerState`):
+                The scheduler state (provides the `\sigma` schedule).
+            order (`int`):
+                The order of the linear multistep method.
+            t (`int`):
+                Current step index into `state.sigmas` (same as `LMSDiscreteScheduler` `t`).
+            current_order (`int`):
+                Which basis polynomial's integral to return (`0` .. `order - 1`).
+
+        Returns:
+            `float`:
+                The integrated LMS weight for `current_order` at step `t`.
         """
 
         def lms_derivative(tau):
@@ -178,7 +195,7 @@ def lms_derivative(tau):
 
         integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]
 
-        return integrated_coeff
+        return float(integrated_coeff)
 
     def set_timesteps(
         self,
@@ -256,7 +273,10 @@ def step(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
 
-        sigma = state.sigmas[timestep]
+        (step_index,) = jnp.where(state.timesteps == timestep, size=1)
+        step_index = step_index[0]
+        sigma = state.sigmas[step_index]
+        t_idx = int(jax.device_get(step_index))
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         if self.config.prediction_type == "epsilon":
@@ -264,9 +284,12 @@ def step(
         elif self.config.prediction_type == "v_prediction":
             # * c_out + input * c_skip
             pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
+        elif self.config.prediction_type == "sample":
+            pred_original_sample = model_output
         else:
             raise ValueError(
-                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `v_prediction`, or"
+                " `sample`"
             )
 
         # 2. Convert to an ODE derivative
@@ -276,8 +299,8 @@ def step(
             state = state.replace(derivatives=jnp.delete(state.derivatives, 0))
 
         # 3. Compute linear multistep coefficients
-        order = min(timestep + 1, order)
-        lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]
+        lms_order = min(t_idx + 1, order)
+        lms_coeffs = [self.get_lms_coefficient(state, lms_order, t_idx, curr_order) for curr_order in range(lms_order)]
 
         # 4. Compute previous sample based on the derivatives path
         prev_sample = sample + sum(
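
For reviewers, here is a minimal standalone sketch of the two behaviors this patch touches: mapping a UNet timestep label to a step index into the sigma schedule, and integrating the Lagrange basis polynomial to get an LMS coefficient. The `timesteps`/`sigmas` values are toy placeholders (not the scheduler's real schedule), and `lms_coefficient` is a hypothetical free-function restatement of the method above, written from the PyTorch `get_lms_coefficient` logic the docstring references.

```python
# Toy sketch of the patched behavior; the schedules below are illustrative
# placeholders, not the values a real LMSDiscreteScheduler would produce.
import jax
import jax.numpy as jnp
from scipy import integrate

timesteps = jnp.array([999, 749, 499, 249, 0])  # training timestep labels
# host numpy copy; trailing 0.0 mirrors the extra sigma appended by set_timesteps
sigmas = jax.device_get(jnp.array([14.6, 6.1, 2.6, 0.9, 0.1, 0.0]))

# (1) `step` now looks up the inference step index for the timestep label
# instead of indexing `state.sigmas` with the label itself.
timestep = 499
(step_index,) = jnp.where(timesteps == timestep, size=1)
t_idx = int(jax.device_get(step_index[0]))  # concrete int: scipy integrates on the host
assert t_idx == 2

# (2) The LMS coefficient is the definite integral of a Lagrange basis
# polynomial over [sigmas[t], sigmas[t + 1]], as in `get_lms_coefficient`.
def lms_coefficient(order: int, t: int, current_order: int) -> float:
    def lms_derivative(tau):
        prod = 1.0
        for k in range(order):
            if current_order == k:
                continue
            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
        return prod

    return float(integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0])

lms_order = min(t_idx + 1, 4)  # matches `lms_order = min(t_idx + 1, order)` above
coeffs = [lms_coefficient(lms_order, t_idx, i) for i in range(lms_order)]
print(coeffs)  # weights applied to the stored derivatives in the step-4 sum
```

As a sanity check, the Lagrange basis polynomials sum to one, so the printed coefficients should sum to `sigmas[t_idx + 1] - sigmas[t_idx]` (here roughly `-1.7`), which is why the weighted derivative sum in step 4 moves the sample down the sigma schedule.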