Skip to content

Commit a7de965

Browse files
authored
Fix unipc use_karras_sigmas exception - fixes #4580 (#4581)
* Fix unipc karras sigmas exception - fixes #4580 * Add unipc scheduler tests for karras sigmas
1 parent 351aab6 commit a7de965

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
278278

279279
return sample
280280

281+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
282+
def _sigma_to_t(self, sigma, log_sigmas):
283+
# get log sigma
284+
log_sigma = np.log(sigma)
285+
286+
# get distribution
287+
dists = log_sigma - log_sigmas[:, np.newaxis]
288+
289+
# get sigmas range
290+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
291+
high_idx = low_idx + 1
292+
293+
low = log_sigmas[low_idx]
294+
high = log_sigmas[high_idx]
295+
296+
# interpolate sigmas
297+
w = (low - log_sigma) / (low - high)
298+
w = np.clip(w, 0, 1)
299+
300+
# transform interpolation to time range
301+
t = (1 - w) * low_idx + w * high_idx
302+
t = t.reshape(sigma.shape)
303+
return t
304+
305+
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
306+
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
307+
"""Constructs the noise schedule of Karras et al. (2022)."""
308+
309+
sigma_min: float = in_sigmas[-1].item()
310+
sigma_max: float = in_sigmas[0].item()
311+
312+
rho = 7.0 # 7.0 is the value used in the paper
313+
ramp = np.linspace(0, 1, num_inference_steps)
314+
min_inv_rho = sigma_min ** (1 / rho)
315+
max_inv_rho = sigma_max ** (1 / rho)
316+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
317+
return sigmas
318+
281319
def convert_model_output(
282320
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
283321
) -> torch.FloatTensor:

tests/schedulers/test_scheduler_unipc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,24 @@ def test_full_loop_no_noise(self):
208208

209209
assert abs(result_mean.item() - 0.2464) < 1e-3
210210

211+
def test_full_loop_with_karras(self):
212+
sample = self.full_loop(use_karras_sigmas=True)
213+
result_mean = torch.mean(torch.abs(sample))
214+
215+
assert abs(result_mean.item() - 0.2925) < 1e-3
216+
211217
def test_full_loop_with_v_prediction(self):
212218
sample = self.full_loop(prediction_type="v_prediction")
213219
result_mean = torch.mean(torch.abs(sample))
214220

215221
assert abs(result_mean.item() - 0.1014) < 1e-3
216222

223+
def test_full_loop_with_karras_and_v_prediction(self):
224+
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
225+
result_mean = torch.mean(torch.abs(sample))
226+
227+
assert abs(result_mean.item() - 0.1966) < 1e-3
228+
217229
def test_fp16_support(self):
218230
scheduler_class = self.scheduler_classes[0]
219231
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)

0 commit comments

Comments
 (0)