Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/diffusers/schedulers/scheduling_edm_euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,13 @@ def step(

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)

eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)

if gamma > 0:
noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)
eps = noise * s_noise
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
Expand Down
9 changes: 4 additions & 5 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,14 +638,13 @@ def step(

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)

eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)

if gamma > 0:
noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)
eps = noise * s_noise
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,13 @@ def step(

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)

eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)

if gamma > 0:
noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)
eps = noise * s_noise
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

if self.state_in_first_order:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,6 @@ def step(
gamma = 0
sigma_hat = sigma * (gamma + 1) # Note: sigma_hat == sigma for now

device = model_output.device
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
Expand Down Expand Up @@ -564,6 +561,9 @@ def step(
self.sample = None

prev_sample = sample + derivative * dt
noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
)
prev_sample = prev_sample + noise * sigma_up

# upon completion increase step index by one
Expand Down
16 changes: 8 additions & 8 deletions tests/schedulers/test_scheduler_kdpm2_ancestral.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def test_full_loop_no_noise(self):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 13849.3877) < 1e-2
assert abs(result_mean.item() - 18.0331) < 5e-3
assert abs(result_sum.item() - 13979.9433) < 1e-2
assert abs(result_mean.item() - 18.2030) < 5e-3

def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
Expand Down Expand Up @@ -92,8 +92,8 @@ def test_full_loop_with_v_prediction(self):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 328.9970) < 1e-2
assert abs(result_mean.item() - 0.4284) < 1e-3
assert abs(result_sum.item() - 331.8133) < 1e-2
assert abs(result_mean.item() - 0.4320) < 1e-3

def test_full_loop_device(self):
if torch_device == "mps":
Expand All @@ -119,8 +119,8 @@ def test_full_loop_device(self):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 13849.3818) < 1e-1
assert abs(result_mean.item() - 18.0331) < 1e-3
assert abs(result_sum.item() - 13979.9433) < 1e-1
assert abs(result_mean.item() - 18.2030) < 1e-3

def test_full_loop_with_noise(self):
if torch_device == "mps":
Expand Down Expand Up @@ -154,5 +154,5 @@ def test_full_loop_with_noise(self):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 93087.0312) < 1e-2, f" expected result sum 93087.0312, but get {result_sum}"
assert abs(result_mean.item() - 121.2071) < 5e-3, f" expected result mean 121.2071, but get {result_mean}"
assert abs(result_sum.item() - 93087.3437) < 1e-2, f" expected result sum 93087.3437, but get {result_sum}"
assert abs(result_mean.item() - 121.2074) < 5e-3, f" expected result mean 121.2074, but get {result_mean}"
Loading