In [None]:
import torch


class Sampler:
    """Linear-beta-schedule diffusion sampler.

    Provides the closed-form forward noising step (``add_noise``) and one step
    of the reverse sampling loop (``remove_noise``, "Algorithm 2 Sampling").
    Timesteps are 0-indexed: step ``t`` here corresponds to step ``t + 1`` in
    the paper's 1-indexed notation.

    The schedule tensors are created on torch's default device (normally the
    CPU). Lookups move the timestep indices to that device and move the
    looked-up values to the device of the input image.
    """

    def __init__(self, num_steps=1000, beta_start=0.0001, beta_end=0.02):
        """Precompute the beta, alpha and cumulative-alpha schedules.

        Args:
            num_steps: number of diffusion steps.
            beta_start: beta at the first step.
            beta_end: beta at the last step.
        """
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.beta_schedule = self.linear_beta_schedule()
        self.alpha = 1 - self.beta_schedule
        # alpha_bar_t = prod_{s <= t} alpha_s
        self.alpha_cummulative_prod = torch.cumprod(self.alpha, dim=-1)
        # alpha_bar_{t-1}, with the empty product 1.0 at t == 0. Precomputing it
        # avoids indexing with ``t - 1`` (which wraps around to the last element
        # at t == 0) and then patching that result in place.
        self.alpha_cummulative_prod_prev = torch.cat(
            [
                torch.ones(
                    1,
                    dtype=self.alpha_cummulative_prod.dtype,
                    device=self.alpha_cummulative_prod.device,
                ),
                self.alpha_cummulative_prod[:-1],
            ]
        )

    def linear_beta_schedule(self):
        """Return ``num_steps`` betas evenly spaced from beta_start to beta_end (both inclusive)."""
        return torch.linspace(self.beta_start, self.beta_end, self.num_steps)

    def _repeated_unsqueeze(self, target, tensor):
        """Append trailing singleton dims to ``tensor`` until it has ``target.dim()`` dims.

        This lets one schedule value per leading index broadcast over the
        remaining axes of ``target``.
        """
        while target.dim() > tensor.dim():
            tensor = tensor.unsqueeze(-1)
        return tensor

    def _schedule_at(self, schedule, timesteps, target):
        """Return ``schedule[timesteps]`` on ``target``'s device, shaped to broadcast against it."""
        # Indexing a tensor with indices on another (non-CPU) device raises, so
        # move the indices to the schedule first. as_tensor also accepts ints.
        index = torch.as_tensor(timesteps, device=schedule.device)
        values = schedule[index].to(target.device)
        return self._repeated_unsqueeze(target, values)

    def add_noise(self, image, timesteps):
        """Noise a clean input to step ``timesteps`` in closed form.

        noisy = sqrt(alpha_bar_t) * image + sqrt(1 - alpha_bar_t) * noise,
        with noise drawn from a standard normal.

        Args:
            image: clean input tensor, of any rank. The values selected by
                ``timesteps`` are broadcast against its leading dimension(s).
            timesteps: int or integer tensor of 0-indexed steps, e.g. one
                per sample with shape (batch,).

        Returns:
            Tuple ``(noisy_image, noise)``.
        """
        alpha_bar_t = self._schedule_at(self.alpha_cummulative_prod, timesteps, image)
        mean_coeff = alpha_bar_t**0.5
        std_coeff = (1 - alpha_bar_t) ** 0.5
        noise = torch.randn_like(image)
        noisy_image = mean_coeff * image + std_coeff * noise
        return noisy_image, noise

    def remove_noise(self, image, timesteps, predicted_noise):
        """Take one reverse step from ``image`` at ``timesteps`` (line 4 of Algorithm 2).

        out = 1/sqrt(alpha_t) * (image - beta_t / sqrt(1 - alpha_bar_t) * predicted_noise)
              + sigma_t * z

        Here sigma_t^2 is the posterior variance
        beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) (formula (7)).
        Because alpha_bar_{t-1} is defined as 1 at t == 0, the variance is
        exactly zero on the final step, so no noise is injected there.

        Args:
            image: current noisy sample, of any rank.
            timesteps: int or integer tensor of 0-indexed steps, broadcast
                against the leading dimension(s) of ``image``.
            predicted_noise: model's noise estimate, same shape as ``image``.

        Returns:
            The less-noisy sample, same shape as ``image``.
        """
        beta_t = self._schedule_at(self.beta_schedule, timesteps, image)
        alpha_t = self._schedule_at(self.alpha, timesteps, image)
        alpha_bar_t = self._schedule_at(self.alpha_cummulative_prod, timesteps, image)
        alpha_bar_prev = self._schedule_at(
            self.alpha_cummulative_prod_prev, timesteps, image
        )

        z = torch.randn_like(image)  # z in line 4 of Algorithm 2
        variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)  # beta_tilde_t, formula (7)
        sigma_t_z = variance**0.5 * z

        # The paper writes beta_t as (1 - alpha_t).
        noise_coeff = beta_t / (1 - alpha_bar_t) ** 0.5
        reciprocal_root_alpha_t = alpha_t ** (-0.5)

        mean = reciprocal_root_alpha_t * (image - noise_coeff * predicted_noise)
        return mean + sigma_t_z


sampler = Sampler()

# Example: smoke-test one reverse step on random data.
# rand = torch.randn(4, 3, 64, 64)
# pred_noise = torch.randn_like(rand)
# randtime = torch.randint(0, 1000, (4,))
# sampler.remove_noise(image=rand, timesteps=randtime, predicted_noise=pred_noise)

tensor([[[[ 1.7592e-01,  3.3930e-01, -1.7045e+00,  ...,  4.9146e-01,
           -4.1139e-01, -4.6745e-01],
          [ 6.7702e-01, -1.0162e+00, -1.3460e+00,  ...,  2.4024e-01,
           -1.6603e+00,  1.5718e+00],
          [ 1.0607e+00,  6.4687e-01,  1.0622e+00,  ...,  1.2427e+00,
           -1.1806e+00, -1.7679e+00],
          ...,
          [-4.5491e-01, -4.6363e-01,  6.6864e-01,  ..., -2.9415e-01,
           -7.1583e-01, -2.7467e-01],
          [-2.6381e-01,  8.6508e-01,  5.2518e-01,  ..., -1.2277e+00,
            2.0265e-01, -8.6257e-01],
          [ 7.2699e-01,  2.7403e-01,  7.7304e-02,  ..., -6.3923e-01,
            1.1699e+00, -3.2380e-01]],

         [[ 1.0390e+00, -4.6748e-01, -2.5108e-01,  ...,  1.3551e+00,
            1.0221e+00,  2.4719e+00],
          [ 7.9235e-02, -8.4059e-02,  1.0399e+00,  ...,  7.3882e-01,
            8.6183e-01,  4.8081e-01],
          [ 7.6888e-01,  7.0274e-01,  8.7423e-01,  ..., -6.2833e-01,
            1.2914e+00,  1.7990e+00],
          ...,
     