Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 34 additions & 11 deletions src/diffusers/schedulers/scheduling_lms_discrete_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dataclasses import dataclass

import flax
import jax
import jax.numpy as jnp
from scipy import integrate

Expand Down Expand Up @@ -158,14 +159,30 @@ def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarra
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample

def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order, t, current_order):
"""
Compute a linear multistep coefficient.
def get_lms_coefficient(self, state: LMSDiscreteSchedulerState, order: int, t: int, current_order: int) -> float:
r"""
Linear multistep (LMS) coefficient for the interval between `\sigma_t` and `\sigma_{t+1}`.

The coefficient is the definite integral of the Lagrange basis polynomial used in the LMS update (see the
k-diffusion reference in the class docstring). The implementation is equivalent to the PyTorch
`LMSDiscreteScheduler.get_lms_coefficient` in `diffusers.schedulers.scheduling_lms_discrete`.

Parameters `order`, `t`, and `current_order` are the same as in the PyTorch scheduler. Here `t` is the
**inference step index** into `state.sigmas` (not the training timestep label passed to the UNet).

Args:
order (TODO):
t (TODO):
current_order (TODO):
state (`LMSDiscreteSchedulerState`):
The scheduler state (provides the `\sigma` schedule).
order (`int`):
The order of the linear multistep method.
t (`int`):
Current step index into `state.sigmas` (same as `LMSDiscreteScheduler` `t`).
current_order (`int`):
Which basis polynomial's integral to return (`0` .. `order - 1`).

Returns:
`float`:
The integrated LMS weight for `current_order` at step `t`.
"""

def lms_derivative(tau):
Expand All @@ -178,7 +195,7 @@ def lms_derivative(tau):

integrated_coeff = integrate.quad(lms_derivative, state.sigmas[t], state.sigmas[t + 1], epsrel=1e-4)[0]

return integrated_coeff
return float(integrated_coeff)

def set_timesteps(
self,
Expand Down Expand Up @@ -256,17 +273,23 @@ def step(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)

sigma = state.sigmas[timestep]
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]
sigma = state.sigmas[step_index]
t_idx = int(jax.device_get(step_index))

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `v_prediction`, or"
" `sample`"
)

# 2. Convert to an ODE derivative
Expand All @@ -276,8 +299,8 @@ def step(
state = state.replace(derivatives=jnp.delete(state.derivatives, 0))

# 3. Compute linear multistep coefficients
order = min(timestep + 1, order)
lms_coeffs = [self.get_lms_coefficient(state, order, timestep, curr_order) for curr_order in range(order)]
lms_order = min(t_idx + 1, order)
lms_coeffs = [self.get_lms_coefficient(state, lms_order, t_idx, curr_order) for curr_order in range(lms_order)]

# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
Expand Down
Loading