Schedulers - what code should go into a "Scheduler" class? #8

@patrickvonplaten

At the moment, the scheduler class has very little logic -> see:

class GaussianDDPMScheduler(nn.Module, ConfigMixin):

whereas the example of the unrolled denoising process is getting quite complicated (copied from the README):

# 3. Denoise                                                                                                                                           
for t in reversed(range(len(scheduler))):
	# 1. predict noise residual
	with torch.no_grad():
		pred_noise_t = self.unet(image, t)

	# 2. compute alphas, betas
	alpha_prod_t = self.noise_scheduler.get_alpha_prod(t)
	alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1)
	beta_prod_t = 1 - alpha_prod_t
	beta_prod_t_prev = 1 - alpha_prod_t_prev

	# 3. compute predicted image from residual
	# First: compute predicted original image from predicted noise also called
	# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
	pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()

	# Second: Clip "predicted x_0"
	pred_original_image = torch.clamp(pred_original_image, -1, 1)

	# Third: Compute coefficients for pred_original_image x_0 and current image x_t
	# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
	pred_original_image_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t
	current_image_coeff = self.noise_scheduler.get_alpha(t).sqrt() * beta_prod_t_prev / beta_prod_t
	# Fourth: Compute predicted previous image µ_t
	# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
	pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image

	# 4. For t > 0, compute predicted variance β̃_t (see formulas (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
	# and sample from it to get previous image
	# x_{t-1} ~ N(pred_prev_image, variance) == add variance to pred_image
	if t > 0:
		variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.noise_scheduler.get_beta(t)
		noise = scheduler.sample_noise(image.shape, device=image.device, generator=generator)
		prev_image = pred_prev_image + variance.sqrt() * noise
	else:
		prev_image = pred_prev_image

	# 5. Set current image to prev_image: x_t -> x_t-1
	image = prev_image

As noted by @patil-suraj, I am also starting to think that we should put more logic into a DDPMNoiseScheduler class, since otherwise we would more or less copy this loop for all other models, such as GLIDE and LDM.

If we give the scheduler class more logic we could reduce the loop to:

for t in reversed(range(len(scheduler))):
	# 1. predict noise residual
	with torch.no_grad():
		pred_noise_t = self.unet(image, t)

	prev_image = scheduler.sample_prev_image(pred_noise_t, image, t)

	image = prev_image

I'm starting to favor this reduced for loop. Obviously, a user could still write the very detailed loop above, but IMO it is important to give the user a function such as def sample_prev_image that can be reused across different models.
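To make the idea concrete, here is one way such a scheduler class could look. This is only a minimal sketch: the class name, the sample_prev_image signature, the NumPy backing (instead of torch, for self-containment), and the toy linear beta schedule are all assumptions for illustration, not the actual diffusers API. The step itself just packages formulas (7) and (15) from https://arxiv.org/pdf/2006.11239.pdf, as in the loop above.

```python
import numpy as np

class DDPMNoiseScheduler:
    """Hypothetical scheduler that owns the noise schedule and one reverse step."""

    def __init__(self, num_steps=1000, beta_start=1e-4, beta_end=0.02, seed=0):
        # Toy linear beta schedule (an assumption for this sketch).
        self.betas = np.linspace(beta_start, beta_end, num_steps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas)
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.betas)

    def sample_prev_image(self, pred_noise_t, image, t):
        # 1. compute alphas, betas for step t
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else 1.0
        beta_prod_t = 1.0 - alpha_prod_t
        beta_prod_t_prev = 1.0 - alpha_prod_t_prev

        # 2. "predicted x_0" of formula (15), clipped to the image range
        pred_original = (image - np.sqrt(beta_prod_t) * pred_noise_t) / np.sqrt(alpha_prod_t)
        pred_original = np.clip(pred_original, -1, 1)

        # 3. mean of q(x_{t-1} | x_t, x_0), formula (7)
        coeff_orig = np.sqrt(alpha_prod_t_prev) * self.betas[t] / beta_prod_t
        coeff_curr = np.sqrt(self.alphas[t]) * beta_prod_t_prev / beta_prod_t
        pred_prev = coeff_orig * pred_original + coeff_curr * image

        # 4. for t > 0, add noise with variance β̃_t, formula (7)
        if t > 0:
            variance = beta_prod_t_prev / beta_prod_t * self.betas[t]
            pred_prev = pred_prev + np.sqrt(variance) * self.rng.standard_normal(image.shape)
        return pred_prev
```

With this, the sampling loop of any model that predicts a noise residual collapses to `image = scheduler.sample_prev_image(pred_noise_t, image, t)`.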

@patil-suraj @anton-l what do you think?

Also would love to hear the opinion of @thomwolf and @srush here :-)
