Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,22 @@ def multistep_dpm_solver_second_order_update(
return x_t

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

Expand All @@ -452,6 +467,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""

if self.begin_index is None:
Expand Down
47 changes: 46 additions & 1 deletion src/diffusers/schedulers/scheduling_deis_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,17 @@ def _sigma_to_t(self, sigma, log_sigmas):

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
Expand Down Expand Up @@ -808,7 +819,22 @@ def ind_fn(t, b, c, d):
raise NotImplementedError("only support log-rho multistep deis now")

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.
Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

Expand All @@ -831,6 +857,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.
Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""

if self.begin_index is None:
Expand Down Expand Up @@ -927,6 +957,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.
Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
Expand Down
155 changes: 109 additions & 46 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,17 @@ def _sigma_to_t(self, sigma, log_sigmas):

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.

Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.

Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
Expand Down
47 changes: 46 additions & 1 deletion src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,17 @@ def _sigma_to_t(self, sigma, log_sigmas):

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.

Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.

Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
Expand Down Expand Up @@ -1079,7 +1090,22 @@ def singlestep_dpm_solver_update(
raise ValueError(f"Order must be 1, 2, 3, got {order}")

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.

Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

Expand All @@ -1102,6 +1128,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.

Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""

if self.begin_index is None:
Expand Down Expand Up @@ -1204,6 +1234,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.

Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.

Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
Expand Down
21 changes: 20 additions & 1 deletion src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,22 @@ def multistep_dpm_solver_third_order_update(
return x_t

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.

Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

Expand All @@ -601,6 +616,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.

Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""

if self.begin_index is None:
Expand Down
32 changes: 31 additions & 1 deletion src/diffusers/schedulers/scheduling_sasolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,17 @@ def _sigma_to_t(self, sigma, log_sigmas):

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.

Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.

Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
Expand Down Expand Up @@ -1103,7 +1114,22 @@ def stochastic_adams_moulton_update(
return x_t

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.

Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

Expand All @@ -1126,6 +1152,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.

Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""

if self.begin_index is None:
Expand Down
47 changes: 46 additions & 1 deletion src/diffusers/schedulers/scheduling_unipc_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,17 @@ def _sigma_to_t(self, sigma, log_sigmas):

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.

Args:
sigma (`torch.Tensor`):
The sigma value(s) to convert.

Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing (alpha_t, sigma_t) values.
"""
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = sigma
Expand Down Expand Up @@ -984,7 +995,22 @@ def multistep_uni_c_bh_update(
return x_t

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

Args:
timestep (`int` or `torch.Tensor`):
The timestep for which to find the index.
schedule_timesteps (`torch.Tensor`, *optional*):
The timestep schedule to search in. If `None`, uses `self.timesteps`.

Returns:
`int`:
The index of the timestep in the schedule.
"""
if schedule_timesteps is None:
schedule_timesteps = self.timesteps

Expand All @@ -1007,6 +1033,10 @@ def index_for_timestep(self, timestep, schedule_timesteps=None):
def _init_step_index(self, timestep):
"""
Initialize the step_index counter for the scheduler.

Args:
timestep (`int` or `torch.Tensor`):
The current timestep for which to initialize the step index.
"""

if self.begin_index is None:
Expand Down Expand Up @@ -1119,6 +1149,21 @@ def add_noise(
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the original samples according to the noise schedule at the specified timesteps.

Args:
original_samples (`torch.Tensor`):
The original samples without noise.
noise (`torch.Tensor`):
The noise to add to the samples.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise to the samples.

Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
Expand Down
Loading