Skip to content

Commit c6e56e9

Browse files
authored
Add Recent Timestep Scheduling Improvements to DDIM Inverse Scheduler (#3865)
* Add Recent Timestep Scheduling Improvements to DDIM Inverse Scheduler Roll timesteps by one to reflect origin-destination semantic discrepancy Restore `set_alpha_to_one` option to handle negative initial timesteps Remove `set_alpha_to_zero` option not used due to previous truncation * Bugfix * Remove unnecessary calls to `detach()` Use `self.image_processor.preprocess` in DiffEdit pipeline functions * Preprocess list input for inverted image latents in diffedit pipeline * Add `timestep_spacing` and `steps_offset` to `DPMSolverMultistepInverseScheduler` * Update expected test results to account for inverting last forward diffusion step * Fix inversion progress bar bug * Add first draft for proper fast tests for DDIMInverseScheduler * Add deprecated DDIMInverseScheduler kwarg to ConfigMixer registry * Fix test failure in DPMMultistepInverseScheduler Invert step specification leads to negative noise variance in SDE-based algs Add first draft for proper fast tests for DPMMultistepInverseScheduler * Update expected test results to account for inverting last forward diffusion step Clean up diffedit fast test
1 parent 27062c3 commit c6e56e9

File tree

8 files changed

+555
-90
lines changed

8 files changed

+555
-90
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ def generate_mask(
992992
)
993993

994994
# 4. Preprocess image
995-
image = preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)
995+
image = self.image_processor.preprocess(image).repeat_interleave(num_maps_per_mask, dim=0)
996996

997997
# 5. Set timesteps
998998
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1176,7 +1176,7 @@ def invert(
11761176
do_classifier_free_guidance = guidance_scale > 1.0
11771177

11781178
# 3. Preprocess image
1179-
image = preprocess(image)
1179+
image = self.image_processor.preprocess(image)
11801180

11811181
# 4. Prepare latent variables
11821182
num_images_per_prompt = 1
@@ -1201,9 +1201,9 @@ def invert(
12011201

12021202
# 7. Noising loop where we obtain the intermediate noised latent image for each timestep.
12031203
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
1204-
inverted_latents = [latents.detach().clone()]
1205-
with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
1206-
for i, t in enumerate(timesteps[:-1]):
1204+
inverted_latents = []
1205+
with self.progress_bar(total=num_inference_steps) as progress_bar:
1206+
for i, t in enumerate(timesteps):
12071207
# expand the latents if we are doing classifier free guidance
12081208
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
12091209
latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
@@ -1270,7 +1270,7 @@ def invert(
12701270
# 8. Post-processing
12711271
image = None
12721272
if decode_latents:
1273-
image = self.decode_latents(latents.flatten(0, 1).detach())
1273+
image = self.decode_latents(latents.flatten(0, 1))
12741274

12751275
# 9. Convert to PIL.
12761276
if decode_latents and output_type == "pil":
@@ -1291,7 +1291,7 @@ def __call__(
12911291
self,
12921292
prompt: Optional[Union[str, List[str]]] = None,
12931293
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
1294-
image_latents: torch.FloatTensor = None,
1294+
image_latents: Union[torch.FloatTensor, PIL.Image.Image] = None,
12951295
inpaint_strength: Optional[float] = 0.8,
12961296
num_inference_steps: int = 50,
12971297
guidance_scale: float = 7.5,
@@ -1447,7 +1447,13 @@ def __call__(
14471447
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device)
14481448

14491449
# 6. Preprocess image latents
1450-
image_latents = preprocess(image_latents)
1450+
if isinstance(image_latents, list) and any(isinstance(l, torch.Tensor) and l.ndim == 5 for l in image_latents):
1451+
image_latents = torch.cat(image_latents).detach()
1452+
elif isinstance(image_latents, torch.Tensor) and image_latents.ndim == 5:
1453+
image_latents = image_latents.detach()
1454+
else:
1455+
image_latents = self.image_processor.preprocess(image_latents).detach()
1456+
14511457
latent_shape = (self.vae.config.latent_channels, latent_height, latent_width)
14521458
if image_latents.shape[-3:] != latent_shape:
14531459
raise ValueError(
@@ -1458,8 +1464,9 @@ def __call__(
14581464
image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape)
14591465
if image_latents.shape[:2] != (batch_size, len(timesteps)):
14601466
raise ValueError(
1461-
f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, "
1462-
f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps."
1467+
f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)}"
1468+
f" timesteps, but has batch size {image_latents.shape[0]} with latent images from"
1469+
f" {image_latents.shape[1]} timesteps."
14631470
)
14641471
image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1)
14651472
image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype)
@@ -1468,7 +1475,7 @@ def __call__(
14681475
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
14691476

14701477
# 8. Denoising loop
1471-
latents = image_latents[0].detach().clone()
1478+
latents = image_latents[0].clone()
14721479
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
14731480
with self.progress_bar(total=num_inference_steps) as progress_bar:
14741481
for i, t in enumerate(timesteps):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,8 +1183,8 @@ def invert(
11831183

11841184
# 7. Denoising loop where we obtain the cross-attention maps.
11851185
num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
1186-
with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
1187-
for i, t in enumerate(timesteps[:-1]):
1186+
with self.progress_bar(total=num_inference_steps) as progress_bar:
1187+
for i, t in enumerate(timesteps):
11881188
# expand the latents if we are doing classifier free guidance
11891189
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
11901190
latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)

src/diffusers/schedulers/scheduling_ddim_inverse.py

Lines changed: 86 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,43 @@ def alpha_bar_fn(t):
9090
return torch.tensor(betas, dtype=torch.float32)
9191

9292

93+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
94+
def rescale_zero_terminal_snr(betas):
95+
"""
96+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
97+
98+
99+
Args:
100+
betas (`torch.FloatTensor`):
101+
the betas that the scheduler is being initialized with.
102+
103+
Returns:
104+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
105+
"""
106+
# Convert betas to alphas_bar_sqrt
107+
alphas = 1.0 - betas
108+
alphas_cumprod = torch.cumprod(alphas, dim=0)
109+
alphas_bar_sqrt = alphas_cumprod.sqrt()
110+
111+
# Store old values.
112+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
113+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
114+
115+
# Shift so the last timestep is zero.
116+
alphas_bar_sqrt -= alphas_bar_sqrt_T
117+
118+
# Scale so the first timestep is back to the old value.
119+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
120+
121+
# Convert alphas_bar_sqrt to betas
122+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
123+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
124+
alphas = torch.cat([alphas_bar[0:1], alphas])
125+
betas = 1 - alphas
126+
127+
return betas
128+
129+
93130
class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
94131
"""
95132
DDIMInverseScheduler is the reverse scheduler of [`DDIMScheduler`].
@@ -126,9 +163,19 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
126163
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
127164
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
128165
https://imagen.research.google/video/paper.pdf)
166+
timestep_spacing (`str`, default `"leading"`):
167+
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
168+
Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information.
169+
rescale_betas_zero_snr (`bool`, default `False`):
170+
whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf).
171+
This can enable the model to generate very bright and dark samples instead of limiting it to samples with
172+
medium brightness. Loosely related to
173+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
129174
"""
130175

131176
order = 1
177+
ignore_for_config = ["kwargs"]
178+
_deprecated_kwargs = ["set_alpha_to_zero"]
132179

133180
@register_to_config
134181
def __init__(
@@ -139,18 +186,20 @@ def __init__(
139186
beta_schedule: str = "linear",
140187
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
141188
clip_sample: bool = True,
142-
set_alpha_to_zero: bool = True,
189+
set_alpha_to_one: bool = True,
143190
steps_offset: int = 0,
144191
prediction_type: str = "epsilon",
145192
clip_sample_range: float = 1.0,
193+
timestep_spacing: str = "leading",
194+
rescale_betas_zero_snr: bool = False,
146195
**kwargs,
147196
):
148-
if kwargs.get("set_alpha_to_one", None) is not None:
197+
if kwargs.get("set_alpha_to_zero", None) is not None:
149198
deprecation_message = (
150-
"The `set_alpha_to_one` argument is deprecated. Please use `set_alpha_to_zero` instead."
199+
"The `set_alpha_to_zero` argument is deprecated. Please use `set_alpha_to_one` instead."
151200
)
152-
deprecate("set_alpha_to_one", "1.0.0", deprecation_message, standard_warn=False)
153-
set_alpha_to_zero = kwargs["set_alpha_to_one"]
201+
deprecate("set_alpha_to_zero", "1.0.0", deprecation_message, standard_warn=False)
202+
set_alpha_to_one = kwargs["set_alpha_to_zero"]
154203
if trained_betas is not None:
155204
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
156205
elif beta_schedule == "linear":
@@ -166,15 +215,19 @@ def __init__(
166215
else:
167216
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
168217

218+
# Rescale for zero SNR
219+
if rescale_betas_zero_snr:
220+
self.betas = rescale_zero_terminal_snr(self.betas)
221+
169222
self.alphas = 1.0 - self.betas
170223
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
171224

172225
# At every step in inverted ddim, we are looking into the next alphas_cumprod
173-
# For the final step, there is no next alphas_cumprod, and the index is out of bounds
174-
# `set_alpha_to_zero` decides whether we set this parameter simply to zero
226+
# For the initial step, there is no current alphas_cumprod, and the index is out of bounds
227+
# `set_alpha_to_one` decides whether we set this parameter simply to one
175228
# in this case, self.step() just output the predicted noise
176-
# or whether we use the final alpha of the "non-previous" one.
177-
self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1]
229+
# or whether we use the initial alpha used in training the diffusion model.
230+
self.initial_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
178231

179232
# standard deviation of the initial noise distribution
180233
self.init_noise_sigma = 1.0
@@ -215,12 +268,29 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
215268
)
216269

217270
self.num_inference_steps = num_inference_steps
218-
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
219-
# creates integer timesteps by multiplying by ratio
220-
# casting to int to avoid issues when num_inference_step is power of 3
221-
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
271+
272+
# "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
273+
if self.config.timestep_spacing == "leading":
274+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
275+
# creates integer timesteps by multiplying by ratio
276+
# casting to int to avoid issues when num_inference_step is power of 3
277+
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().copy().astype(np.int64)
278+
timesteps += self.config.steps_offset
279+
elif self.config.timestep_spacing == "trailing":
280+
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
281+
# creates integer timesteps by multiplying by ratio
282+
# casting to int to avoid issues when num_inference_step is power of 3
283+
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)[::-1]).astype(np.int64)
284+
timesteps -= 1
285+
else:
286+
raise ValueError(
287+
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
288+
)
289+
290+
# Roll timesteps array by one to reflect reversed origin and destination semantics for each step
291+
timesteps = np.roll(timesteps, 1)
292+
timesteps[0] = int(timesteps[1] - step_ratio)
222293
self.timesteps = torch.from_numpy(timesteps).to(device)
223-
self.timesteps += self.config.steps_offset
224294

225295
def step(
226296
self,
@@ -237,12 +307,8 @@ def step(
237307

238308
# 2. compute alphas, betas
239309
# change original implementation to exactly match noise levels for analogous forward process
240-
alpha_prod_t = self.alphas_cumprod[timestep]
241-
alpha_prod_t_prev = (
242-
self.alphas_cumprod[prev_timestep]
243-
if prev_timestep < self.config.num_train_timesteps
244-
else self.final_alpha_cumprod
245-
)
310+
alpha_prod_t = self.alphas_cumprod[timestep] if timestep >= 0 else self.initial_alpha_cumprod
311+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
246312

247313
beta_prod_t = 1 - alpha_prod_t
248314

0 commit comments

Comments
 (0)