Description
At the moment, the scheduler class has very little logic, see:

`class GaussianDDPMScheduler(nn.Module, ConfigMixin):`

whereas the example of the unrolled denoising process is getting quite complicated (copied from the README):
```python
# 3. Denoise
for t in reversed(range(len(scheduler))):
    # 1. predict noise residual
    with torch.no_grad():
        pred_noise_t = self.unet(image, t)

    # 2. compute alphas, betas
    alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
    alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    # 3. compute predicted image from residual
    # First: compute predicted original image from predicted noise, also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()

    # Second: clip "predicted x_0"
    pred_original_image = torch.clamp(pred_original_image, -1, 1)

    # Third: compute coefficients for pred_original_image x_0 and current image x_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
    current_image_coeff = self.noise_scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t

    # Fourth: compute predicted previous image µ_t
    # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
    pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image

    # 5. For t > 0, compute predicted variance βt (see formulas (6) and (7) from
    # https://arxiv.org/pdf/2006.11239.pdf) and sample from it to get the previous image:
    # x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
    if t > 0:
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t).sqrt()
        noise = scheduler.sample_noise(image.shape, device=image.device, generator=generator)
        prev_image = pred_prev_image + variance * noise
    else:
        prev_image = pred_prev_image

    # 6. Set current image to prev_image: x_t -> x_t-1
    image = prev_image
```

As noted by @patil-suraj, I'm also starting to think that we should put more logic into a `DDPMNoiseScheduler` class, since otherwise we more or less copy this loop for all other models, such as GLIDE and LDM.
If we give the scheduler class more logic, we could reduce the loop to:

```python
for t in reversed(range(len(scheduler))):
    # 1. predict noise residual
    with torch.no_grad():
        pred_noise_t = self.unet(image, t)

    prev_image = scheduler.sample_prev_image(pred_noise_t, image, t)
    image = prev_image
```

I'm starting to favor this reduced for loop. Obviously a user could still write the very detailed loop above, but IMO it would be important to give the user a function, such as `def sample_prev_image`, that can be reused for different models.
@patil-suraj @anton-l what do you think?
Also would love to hear the opinion of @thomwolf and @srush here :-)