Skip to content

Commit 4fdca5a

Browse files
patrickvonplaten authored and zhen-hao.chu committed
Fix UniPC scheduler for 1D (huggingface#5276)
1 parent b9df3b5 commit 4fdca5a

9 files changed

27 additions and 36 deletions

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -276,25 +276,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
276276
https://arxiv.org/abs/2205.11487
277277
"""
278278
dtype = sample.dtype
279-
batch_size, channels, height, width = sample.shape
279+
batch_size, channels, *remaining_dims = sample.shape
280280

281281
if dtype not in (torch.float32, torch.float64):
282282
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
283283

284284
# Flatten sample for doing quantile calculation along each image
285-
sample = sample.reshape(batch_size, channels * height * width)
285+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
286286

287287
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
288288

289289
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
290290
s = torch.clamp(
291291
s, min=1, max=self.config.sample_max_value
292292
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
293-
294293
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
295294
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
296295

297-
sample = sample.reshape(batch_size, channels, height, width)
296+
sample = sample.reshape(batch_size, channels, *remaining_dims)
298297
sample = sample.to(dtype)
299298

300299
return sample

src/diffusers/schedulers/scheduling_ddim_parallel.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -298,25 +298,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
298298
https://arxiv.org/abs/2205.11487
299299
"""
300300
dtype = sample.dtype
301-
batch_size, channels, height, width = sample.shape
301+
batch_size, channels, *remaining_dims = sample.shape
302302

303303
if dtype not in (torch.float32, torch.float64):
304304
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
305305

306306
# Flatten sample for doing quantile calculation along each image
307-
sample = sample.reshape(batch_size, channels * height * width)
307+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
308308

309309
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
310310

311311
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
312312
s = torch.clamp(
313313
s, min=1, max=self.config.sample_max_value
314314
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
315-
316315
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
317316
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
318317

319-
sample = sample.reshape(batch_size, channels, height, width)
318+
sample = sample.reshape(batch_size, channels, *remaining_dims)
320319
sample = sample.to(dtype)
321320

322321
return sample

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -330,25 +330,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
330330
https://arxiv.org/abs/2205.11487
331331
"""
332332
dtype = sample.dtype
333-
batch_size, channels, height, width = sample.shape
333+
batch_size, channels, *remaining_dims = sample.shape
334334

335335
if dtype not in (torch.float32, torch.float64):
336336
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
337337

338338
# Flatten sample for doing quantile calculation along each image
339-
sample = sample.reshape(batch_size, channels * height * width)
339+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
340340

341341
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
342342

343343
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
344344
s = torch.clamp(
345345
s, min=1, max=self.config.sample_max_value
346346
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
347-
348347
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
349348
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
350349

351-
sample = sample.reshape(batch_size, channels, height, width)
350+
sample = sample.reshape(batch_size, channels, *remaining_dims)
352351
sample = sample.to(dtype)
353352

354353
return sample

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -344,25 +344,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
344344
https://arxiv.org/abs/2205.11487
345345
"""
346346
dtype = sample.dtype
347-
batch_size, channels, height, width = sample.shape
347+
batch_size, channels, *remaining_dims = sample.shape
348348

349349
if dtype not in (torch.float32, torch.float64):
350350
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
351351

352352
# Flatten sample for doing quantile calculation along each image
353-
sample = sample.reshape(batch_size, channels * height * width)
353+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
354354

355355
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
356356

357357
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
358358
s = torch.clamp(
359359
s, min=1, max=self.config.sample_max_value
360360
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
361-
362361
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
363362
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
364363

365-
sample = sample.reshape(batch_size, channels, height, width)
364+
sample = sample.reshape(batch_size, channels, *remaining_dims)
366365
sample = sample.to(dtype)
367366

368367
return sample

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -268,25 +268,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
268268
https://arxiv.org/abs/2205.11487
269269
"""
270270
dtype = sample.dtype
271-
batch_size, channels, height, width = sample.shape
271+
batch_size, channels, *remaining_dims = sample.shape
272272

273273
if dtype not in (torch.float32, torch.float64):
274274
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
275275

276276
# Flatten sample for doing quantile calculation along each image
277-
sample = sample.reshape(batch_size, channels * height * width)
277+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
278278

279279
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
280280

281281
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
282282
s = torch.clamp(
283283
s, min=1, max=self.config.sample_max_value
284284
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
285-
286285
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
287286
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
288287

289-
sample = sample.reshape(batch_size, channels, height, width)
288+
sample = sample.reshape(batch_size, channels, *remaining_dims)
290289
sample = sample.to(dtype)
291290

292291
return sample

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -288,25 +288,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
288288
https://arxiv.org/abs/2205.11487
289289
"""
290290
dtype = sample.dtype
291-
batch_size, channels, height, width = sample.shape
291+
batch_size, channels, *remaining_dims = sample.shape
292292

293293
if dtype not in (torch.float32, torch.float64):
294294
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
295295

296296
# Flatten sample for doing quantile calculation along each image
297-
sample = sample.reshape(batch_size, channels * height * width)
297+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
298298

299299
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
300300

301301
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
302302
s = torch.clamp(
303303
s, min=1, max=self.config.sample_max_value
304304
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
305-
306305
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
307306
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
308307

309-
sample = sample.reshape(batch_size, channels, height, width)
308+
sample = sample.reshape(batch_size, channels, *remaining_dims)
310309
sample = sample.to(dtype)
311310

312311
return sample

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -298,25 +298,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
298298
https://arxiv.org/abs/2205.11487
299299
"""
300300
dtype = sample.dtype
301-
batch_size, channels, height, width = sample.shape
301+
batch_size, channels, *remaining_dims = sample.shape
302302

303303
if dtype not in (torch.float32, torch.float64):
304304
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
305305

306306
# Flatten sample for doing quantile calculation along each image
307-
sample = sample.reshape(batch_size, channels * height * width)
307+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
308308

309309
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
310310

311311
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
312312
s = torch.clamp(
313313
s, min=1, max=self.config.sample_max_value
314314
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
315-
316315
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
317316
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
318317

319-
sample = sample.reshape(batch_size, channels, height, width)
318+
sample = sample.reshape(batch_size, channels, *remaining_dims)
320319
sample = sample.to(dtype)
321320

322321
return sample

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -302,25 +302,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
302302
https://arxiv.org/abs/2205.11487
303303
"""
304304
dtype = sample.dtype
305-
batch_size, channels, height, width = sample.shape
305+
batch_size, channels, *remaining_dims = sample.shape
306306

307307
if dtype not in (torch.float32, torch.float64):
308308
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
309309

310310
# Flatten sample for doing quantile calculation along each image
311-
sample = sample.reshape(batch_size, channels * height * width)
311+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
312312

313313
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
314314

315315
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
316316
s = torch.clamp(
317317
s, min=1, max=self.config.sample_max_value
318318
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
319-
320319
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
321320
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
322321

323-
sample = sample.reshape(batch_size, channels, height, width)
322+
sample = sample.reshape(batch_size, channels, *remaining_dims)
324323
sample = sample.to(dtype)
325324

326325
return sample

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -282,25 +282,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
282282
https://arxiv.org/abs/2205.11487
283283
"""
284284
dtype = sample.dtype
285-
batch_size, channels, height, width = sample.shape
285+
batch_size, channels, *remaining_dims = sample.shape
286286

287287
if dtype not in (torch.float32, torch.float64):
288288
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
289289

290290
# Flatten sample for doing quantile calculation along each image
291-
sample = sample.reshape(batch_size, channels * height * width)
291+
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
292292

293293
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
294294

295295
s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
296296
s = torch.clamp(
297297
s, min=1, max=self.config.sample_max_value
298298
) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
299-
300299
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
301300
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
302301

303-
sample = sample.reshape(batch_size, channels, height, width)
302+
sample = sample.reshape(batch_size, channels, *remaining_dims)
304303
sample = sample.to(dtype)
305304

306305
return sample

0 commit comments

Comments
 (0)