@@ -278,6 +278,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
278
278
279
279
return sample
280
280
281
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
282
+ def _sigma_to_t (self , sigma , log_sigmas ):
283
+ # get log sigma
284
+ log_sigma = np .log (sigma )
285
+
286
+ # get distribution
287
+ dists = log_sigma - log_sigmas [:, np .newaxis ]
288
+
289
+ # get sigmas range
290
+ low_idx = np .cumsum ((dists >= 0 ), axis = 0 ).argmax (axis = 0 ).clip (max = log_sigmas .shape [0 ] - 2 )
291
+ high_idx = low_idx + 1
292
+
293
+ low = log_sigmas [low_idx ]
294
+ high = log_sigmas [high_idx ]
295
+
296
+ # interpolate sigmas
297
+ w = (low - log_sigma ) / (low - high )
298
+ w = np .clip (w , 0 , 1 )
299
+
300
+ # transform interpolation to time range
301
+ t = (1 - w ) * low_idx + w * high_idx
302
+ t = t .reshape (sigma .shape )
303
+ return t
304
+
305
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
306
+ def _convert_to_karras (self , in_sigmas : torch .FloatTensor , num_inference_steps ) -> torch .FloatTensor :
307
+ """Constructs the noise schedule of Karras et al. (2022)."""
308
+
309
+ sigma_min : float = in_sigmas [- 1 ].item ()
310
+ sigma_max : float = in_sigmas [0 ].item ()
311
+
312
+ rho = 7.0 # 7.0 is the value used in the paper
313
+ ramp = np .linspace (0 , 1 , num_inference_steps )
314
+ min_inv_rho = sigma_min ** (1 / rho )
315
+ max_inv_rho = sigma_max ** (1 / rho )
316
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho )) ** rho
317
+ return sigmas
318
+
281
319
def convert_model_output (
282
320
self , model_output : torch .FloatTensor , timestep : int , sample : torch .FloatTensor
283
321
) -> torch .FloatTensor :
0 commit comments