@@ -90,6 +90,43 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)


+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
     """
     DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`].
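As a quick, hedged sanity check (not part of the diff itself), the effect of `rescale_zero_terminal_snr` can be verified numerically: the first cumulative alpha is preserved while the last one is driven to zero, so the final timestep corresponds to pure noise. The linear beta values below are purely illustrative.

```python
import torch
# The function is copied from scheduling_ddim, so it can be imported from there.
from diffusers.schedulers.scheduling_ddim import rescale_zero_terminal_snr

betas = torch.linspace(0.0001, 0.02, 1000)  # illustrative linear schedule
rescaled = rescale_zero_terminal_snr(betas)

alpha_bar = torch.cumprod(1.0 - rescaled, dim=0)
print(alpha_bar[0])   # unchanged first value (matches the original schedule)
print(alpha_bar[-1])  # ~0.0 -> zero terminal SNR at the last timestep
```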
@@ -126,9 +163,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
             https://imagen.research.google/video/paper.pdf)
+        timestep_spacing (`str`, default `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
+            Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
+        rescale_betas_zero_snr (`bool`, default `False`):
+            whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
+            This can enable the model to generate very bright and dark samples instead of limiting it to samples with
+            medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """

     order = 1
+    ignore_for_config = ["kwargs"]
+    _deprecated_kwargs = ["set_alpha_to_zero"]

     @register_to_config
     def __init__(
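For readers of the diff, here is a brief sketch of how the two new config options might be used together; the argument values are illustrative, and both options default to the previous behaviour (`"leading"` spacing, no rescaling).

```python
from diffusers import DDIMInverseScheduler

# Illustrative construction; "trailing" spacing and zero-SNR rescaling follow the
# recommendations of https://arxiv.org/abs/2305.08891.
inverse_scheduler = DDIMInverseScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    timestep_spacing="trailing",
    rescale_betas_zero_snr=True,
)
inverse_scheduler.set_timesteps(50)
print(inverse_scheduler.timesteps[:5])
```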
@@ -139,18 +186,20 @@ def __init__(
         beta_schedule: str = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
-        set_alpha_to_zero: bool = True,
+        set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
         clip_sample_range: float = 1.0,
+        timestep_spacing: str = "leading",
+        rescale_betas_zero_snr: bool = False,
         **kwargs,
     ):
-        if kwargs.get("set_alpha_to_one", None) is not None:
+        if kwargs.get("set_alpha_to_zero", None) is not None:
             deprecation_message = (
-                "The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead."
+                "The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead."
             )
-            deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False)
-            set_alpha_to_zero = kwargs["set_alpha_to_one"]
+            deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False)
+            set_alpha_to_one = kwargs["set_alpha_to_zero"]
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
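The deprecation shim above keeps existing call sites working: the old keyword is absorbed by `**kwargs` (excluded from the config via `ignore_for_config`), a deprecation warning is emitted, and its value is forwarded to the renamed argument. A hedged illustration:

```python
from diffusers import DDIMInverseScheduler

# Both calls select the same boundary-alpha behaviour in the new code; the first
# one additionally emits a deprecation warning (removal planned for diffusers 1.0.0).
old_style = DDIMInverseScheduler(set_alpha_to_zero=False)  # deprecated spelling
new_style = DDIMInverseScheduler(set_alpha_to_one=False)   # preferred spelling
```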
@@ -166,15 +215,19 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

         # At every step in inverted ddim, we are looking into the next alphas_cumprod
-        # For the final step, there is no next alphas_cumprod, and the index is out of bounds
-        # `set_alpha_to_zero` decides whether we set this parameter simply to zero
+        # For the initial step, there is no current alphas_cumprod, and the index is out of bounds
+        # `set_alpha_to_one` decides whether we set this parameter simply to one
         # in this case, self.step() just output the predicted noise
-        # or whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]
+        # or whether we use the initial alpha used in training the diffusion model.
+        self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
@@ -215,12 +268,29 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
             )

         self.num_inference_steps = num_inference_steps
-        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-        # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
+
+        # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
+        if self.config.timestep_spacing == "leading":
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == "trailing":
+            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+            # creates integer timesteps by multiplying by ratio
+            # casting to int to avoid issues when num_inference_step is power of 3
+            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64)
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
+            )
+
+        # Roll timesteps array by one to reflect reversed origin and destination semantics for each step
+        timesteps = np.roll(timesteps, 1)
+        timesteps[0] = int(timesteps[1] - step_ratio)
         self.timesteps = torch.from_numpy(timesteps).to(device)
-        self.timesteps += self.config.steps_offset

     def step(
         self,
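To illustrate what the new spacing logic produces, the snippet below reproduces the same NumPy operations outside the class for 1000 training timesteps and 10 inference steps (purely illustrative numbers). Note how the roll plus the overwrite of the first entry gives each step an origin/destination pair suited to inversion, with the first origin falling just below the training range.

```python
import numpy as np

num_train_timesteps, num_inference_steps, steps_offset = 1000, 10, 0

# "leading" spacing
step_ratio = num_train_timesteps // num_inference_steps
leading = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
leading += steps_offset
leading = np.roll(leading, 1)
leading[0] = int(leading[1] - step_ratio)
# leading  -> [-100, 0, 100, 200, 300, 400, 500, 600, 700, 800]

# "trailing" spacing
step_ratio = num_train_timesteps / num_inference_steps
trailing = np.round(np.arange(num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64) - 1
trailing = np.roll(trailing, 1)
trailing[0] = int(trailing[1] - step_ratio)
# trailing -> [-1, 99, 199, 299, 399, 499, 599, 699, 799, 899]
```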
@@ -237,12 +307,8 @@ def step(

         # 2. compute alphas, betas
         # change original implementation to exactly match noise levels for analogous forward process
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = (
-            self.alphas_cumprod[prev_timestep]
-            if prev_timestep < self.config.num_train_timesteps
-            else self.final_alpha_cumprod
-        )
+        alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]

         beta_prod_t = 1 - alpha_prod_t

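Tying the pieces together (a hedged reading of the diff, with purely hypothetical indices): with the rolled schedules shown earlier, the very first inversion step carries an origin timestep below zero, so the simplified lookup above falls back to `initial_alpha_cumprod`, i.e. the clean-sample value of 1.0 when `set_alpha_to_one=True`, while the destination alpha is read directly from `alphas_cumprod`.

```python
import torch

# Standalone reproduction of the new lookup; schedule and indices are illustrative only.
betas = torch.linspace(0.0001, 0.02, 1000)             # assumed schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
initial_alpha_cumprod = torch.tensor(1.0)               # set_alpha_to_one=True

timestep, prev_timestep = -100, 0                       # hypothetical first "leading" step
alpha_prod_t = alphas_cumprod[timestep] if timestep >= 0 else initial_alpha_cumprod
alpha_prod_t_prev = alphas_cumprod[prev_timestep]
# alpha_prod_t == 1.0 (no noise at the origin), alpha_prod_t_prev == alphas_cumprod[0]
```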