diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 802ba0f099f9..8da3434fc347 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -219,8 +219,14 @@ def __init__( sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - sigmas = torch.from_numpy(sigmas[::-1].copy()).to(dtype=torch.float32) - timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + sigmas = sigmas[::-1].copy() + + if self.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_train_timesteps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) # setable values self.num_inference_steps = None @@ -229,7 +235,7 @@ def __init__( if timestep_type == "continuous" and prediction_type == "v_prediction": self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas]) else: - self.timesteps = timesteps + self.timesteps = torch.from_numpy(timesteps.astype(np.float32)) self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])