Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/diffusers/schedulers/scheduling_unipc_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
batch_size, channels, *remaining_dims = sample.shape

if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half

# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

abs_sample = sample.abs() # "a certain percentile absolute pixel value"

Expand All @@ -300,7 +300,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"

sample = sample.reshape(batch_size, channels, height, width)
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)

return sample
Expand Down Expand Up @@ -534,14 +534,14 @@ def multistep_uni_p_bh_update(
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
Expand Down Expand Up @@ -670,15 +670,15 @@ def multistep_uni_c_bh_update(
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
Expand Down
110 changes: 110 additions & 0 deletions tests/schedulers/test_scheduler_unipc.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,113 @@ def test_full_loop_with_noise(self):

assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}"
assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}"


class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
@property
def dummy_sample(self):
batch_size = 4
num_channels = 3
width = 8

sample = torch.rand((batch_size, num_channels, width))

return sample

@property
def dummy_noise_deter(self):
batch_size = 4
num_channels = 3
width = 8

num_elems = batch_size * num_channels * width
sample = torch.arange(num_elems).flip(-1)
sample = sample.reshape(num_channels, width, batch_size)
sample = sample / num_elems
sample = sample.permute(2, 0, 1)

return sample

@property
def dummy_sample_deter(self):
batch_size = 4
num_channels = 3
width = 8

num_elems = batch_size * num_channels * width
sample = torch.arange(num_elems)
sample = sample.reshape(num_channels, width, batch_size)
sample = sample / num_elems
sample = sample.permute(2, 0, 1)

return sample

def test_switch(self):
# make sure that iterating over schedulers with same config names gives same results
# for defaults
scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
sample = self.full_loop(scheduler=scheduler)
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.2441) < 1e-3

scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
scheduler = DEISMultistepScheduler.from_config(scheduler.config)
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)

sample = self.full_loop(scheduler=scheduler)
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.2441) < 1e-3

def test_full_loop_no_noise(self):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.2441) < 1e-3

def test_full_loop_with_karras(self):
sample = self.full_loop(use_karras_sigmas=True)
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.2898) < 1e-3

def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.1014) < 1e-3

def test_full_loop_with_karras_and_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
result_mean = torch.mean(torch.abs(sample))

assert abs(result_mean.item() - 0.1944) < 1e-3

def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)

num_inference_steps = 10
t_start = 8

model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)

# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])

for i, t in enumerate(timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample

result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}"
assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"