In [1]:
import torch

In [None]:
class LinearNoiseSchedule:
    """Linear beta schedule for a DDPM-style diffusion process.

    Precomputes, for every timestep ``t`` in ``[0, num_timesteps)``:

    * ``betas[t]``                            -- noise variance added at step t
    * ``alphas[t] = 1 - betas[t]``
    * ``alphas_cum_prod[t] = prod(alphas[:t+1])``     (alpha-bar)
    * ``sqrt_alphas_cum_prod[t] = sqrt(alphas_cum_prod[t])``
    * ``sqrt_one_minus_alphas_cum_prod[t] = sqrt(1 - alphas_cum_prod[t])``

    The schedule tensors are created on the default (CPU) device; each method
    moves what it needs to the device of its input tensor.

    Args:
        num_timesteps: number of diffusion steps.
        beta_start: beta at the first step.
        beta_end: beta at the last step (betas are linearly interpolated).
    """

    def __init__(self, num_timesteps, beta_start, beta_end):
        self.num_timesteps = num_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)

        self.alphas = 1. - self.betas
        self.alphas_cum_prod = torch.cumprod(self.alphas, dim=0)

        self.sqrt_alphas_cum_prod = torch.sqrt(self.alphas_cum_prod)
        # sqrt(1 - alpha_bar): the std of the noise component in x_t.
        self.sqrt_one_minus_alphas_cum_prod = torch.sqrt(1. - self.alphas_cum_prod)
        # Legacy (misspelled) name kept so existing callers keep working.
        self.sqrt_one_minus_alphas_num_prod = self.sqrt_one_minus_alphas_cum_prod

    def add_noise(self, original, noise, t):
        """Sample x_t ~ q(x_t | x_0) in closed form.

        Computes ``sqrt(alpha_bar_t) * original + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            original: clean samples x_0; the first dimension is assumed to be
                the batch dimension when ``t`` is a per-sample index tensor.
            noise: noise tensor broadcastable to ``original`` (typically N(0, I)).
            t: timestep index -- a Python int, or an integer tensor of shape
                ``(batch,)`` with one timestep per sample.

        Returns:
            The noised samples x_t, same shape as ``original``.
        """
        device = original.device
        # Move the schedule to the data's device before indexing; indexing a
        # CPU tensor with a CUDA index tensor raises an error.
        sqrt_alphas_cum_prod = self.sqrt_alphas_cum_prod.to(device)[t]
        sqrt_one_minus_alphas_cum_prod = self.sqrt_one_minus_alphas_cum_prod.to(device)[t]

        # Append trailing singleton dims so per-sample coefficients broadcast
        # over all non-batch dimensions of `original`.
        for _ in range(original.dim() - 1):
            sqrt_alphas_cum_prod = sqrt_alphas_cum_prod.unsqueeze(-1)
            sqrt_one_minus_alphas_cum_prod = sqrt_one_minus_alphas_cum_prod.unsqueeze(-1)

        return sqrt_alphas_cum_prod * original + sqrt_one_minus_alphas_cum_prod * noise

    def sample_prev_timestep(self, xt, noise_pred, t, x0_range=(0.0, 1.0)):
        """Take one reverse (denoising) step x_t -> x_{t-1}.

        Uses the DDPM posterior:
            mean = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
            var  = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)   (t > 0)

        Args:
            xt: current noisy samples x_t.
            noise_pred: the model's prediction of the noise eps in x_t.
            t: a single timestep (Python int or 0-d integer tensor) shared by
                the whole batch.
            x0_range: ``(min, max)`` used to clamp the predicted x_0, or
                ``None`` to disable clamping. Defaults to ``(0.0, 1.0)``.

        Returns:
            ``(x_prev, x0_pred)``: the sample at step t-1 (the posterior mean,
            without added noise, when ``t == 0``) and the predicted x_0.
        """
        device = xt.device
        alphas = self.alphas.to(device)
        betas = self.betas.to(device)
        alphas_cum_prod = self.alphas_cum_prod.to(device)
        sqrt_alphas_cum_prod = self.sqrt_alphas_cum_prod.to(device)
        sqrt_one_minus_alphas_cum_prod = self.sqrt_one_minus_alphas_cum_prod.to(device)

        # Invert the forward process to estimate the clean sample.
        x0 = (xt - sqrt_one_minus_alphas_cum_prod[t] * noise_pred) / sqrt_alphas_cum_prod[t]
        if x0_range is not None:
            x0 = torch.clamp(x0, x0_range[0], x0_range[1])

        # Posterior mean of q(x_{t-1} | x_t, x_0) expressed through eps.
        mean = (xt - (betas[t] / sqrt_one_minus_alphas_cum_prod[t]) * noise_pred) / torch.sqrt(alphas[t])

        if t == 0:
            # Final step: return the mean without adding noise.
            return mean, x0

        variance = (1. - alphas_cum_prod[t - 1]) / (1. - alphas_cum_prod[t])
        variance = variance * betas[t]
        sigma = variance ** 0.5
        z = torch.randn_like(xt)

        return mean + sigma * z, x0