@@ -276,25 +276,24 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
276276 https://arxiv.org/abs/2205.11487
277277 """
278278 dtype = sample .dtype
279- batch_size , channels , height , width = sample .shape
279+ batch_size , channels , * remaining_dims = sample .shape
280280
281281 if dtype not in (torch .float32 , torch .float64 ):
282282 sample = sample .float () # upcast for quantile calculation, and clamp not implemented for cpu half
283283
284284 # Flatten sample for doing quantile calculation along each image
285- sample = sample .reshape (batch_size , channels * height * width )
285+ sample = sample .reshape (batch_size , channels * np . prod ( remaining_dims ) )
286286
287287 abs_sample = sample .abs () # "a certain percentile absolute pixel value"
288288
289289 s = torch .quantile (abs_sample , self .config .dynamic_thresholding_ratio , dim = 1 )
290290 s = torch .clamp (
291291 s , min = 1 , max = self .config .sample_max_value
292292 ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
293-
294293 s = s .unsqueeze (1 ) # (batch_size, 1) because clamp will broadcast along dim=0
295294 sample = torch .clamp (sample , - s , s ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
296295
297- sample = sample .reshape (batch_size , channels , height , width )
296+ sample = sample .reshape (batch_size , channels , * remaining_dims )
298297 sample = sample .to (dtype )
299298
300299 return sample
0 commit comments