In [1]:
import math
from dataclasses import dataclass
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

from models.externals.guided_diffusion import UNetModel

In [2]:
T: int = 1_000  # total number of diffusion steps (default schedule length)

In [3]:
def linear_betas(
    num_diffusion_timesteps: int,
    beta_start: float = 0.0001,
    beta_end: float = 0.02,
    reference_timesteps: int = 1000,
) -> torch.Tensor:
    """Linear beta schedule (Ho et al., DDPM), rescaled to any number of steps.

    ``beta_start`` / ``beta_end`` are the endpoints for a ``reference_timesteps``
    step process. Both are multiplied by
    ``reference_timesteps / num_diffusion_timesteps``. The reference is an
    explicit parameter rather than a notebook-global, so changing the global
    step count does not silently alter the schedule's definition.

    Args:
        num_diffusion_timesteps: length of the returned schedule; must be > 0.
        beta_start: first beta of the reference schedule.
        beta_end: last beta of the reference schedule.
        reference_timesteps: step count the endpoints are specified for.

    Returns:
        1-D tensor of ``num_diffusion_timesteps`` values spaced linearly from
        ``beta_start * scale`` to ``beta_end * scale``.

    Raises:
        ValueError: if ``num_diffusion_timesteps`` is not positive.
    """
    if num_diffusion_timesteps <= 0:
        raise ValueError(
            f"num_diffusion_timesteps must be positive, got {num_diffusion_timesteps}"
        )
    scale = reference_timesteps / num_diffusion_timesteps
    return torch.linspace(
        beta_start * scale,
        beta_end * scale,
        num_diffusion_timesteps,
    )

def cosine_betas(num_diffusion_timesteps: int, s: float = 0.008, max_beta: float = 0.999) -> torch.Tensor:
    """Cosine beta schedule (Nichol & Dhariwal, "Improved DDPM").

    Defines alpha_bar(u) = cos^2(((u + s) / (1 + s)) * pi / 2) for u in [0, 1].
    Each beta is ``1 - alpha_bar(next) / alpha_bar(current)`` over consecutive
    fractions ``step / n`` and ``(step + 1) / n``. Each value is clipped to at
    most ``max_beta``, which caps the near-1 betas at the end of the schedule.

    Returns:
        1-D tensor with ``num_diffusion_timesteps`` entries (empty when 0).
    """
    def _alpha_bar(frac: float) -> float:
        return math.cos((frac + s) / (1 + s) * math.pi / 2) ** 2

    n = num_diffusion_timesteps
    return torch.tensor([
        min(1 - _alpha_bar((step + 1) / n) / _alpha_bar(step / n), max_beta)
        for step in range(n)
    ])


class GaussianDiffusion:
    """DDPM ancestral sampler around a noise-prediction (epsilon) model.

    Precomputes the beta / alpha / alpha-bar schedules. Implements the
    reverse-process update of Ho et al. (2020), Algorithm 2, with the fixed
    variance choice sigma_t^2 = beta_t.

    NOTE(review): the model output is treated as the predicted noise with the
    same shape as ``x_t``. If the UNet also learns the variance (twice as many
    output channels), the output must be split before use. TODO confirm
    against the model config.
    """

    def __init__(self, model: UNetModel, num_diffusion_timesteps: int = T, noise_schedule: str = 'linear'):
        """Build the diffusion schedules.

        Args:
            model: noise-prediction network, called as ``model(x_t, timesteps)``.
                May be ``None`` when only the schedules are needed.
            num_diffusion_timesteps: number of diffusion steps (schedule length).
            noise_schedule: ``'linear'`` or ``'cosine'``.

        Raises:
            ValueError: if ``noise_schedule`` is not a supported name.
        """
        schedules = {'linear': linear_betas, 'cosine': cosine_betas}
        # Reject typos explicitly instead of silently falling back to cosine.
        if noise_schedule not in schedules:
            raise ValueError(
                f"unknown noise_schedule {noise_schedule!r}; expected one of {sorted(schedules)}"
            )
        self.model = model
        self.num_diffusion_timesteps = num_diffusion_timesteps
        self.noise_schedule = noise_schedule
        self.betas: torch.Tensor = schedules[noise_schedule](num_diffusion_timesteps)
        self.alpha_t = 1 - self.betas
        self.alpha_bar = torch.cumprod(self.alpha_t, dim=0)

    def denoise_at_t(self, x_t: torch.Tensor, t: int) -> torch.Tensor:
        """Perform one reverse step x_t -> x_{t-1}.

        Args:
            x_t: batched noisy sample at step ``t``, presumably (N, C, H, W).
                The model receives a 1-D LongTensor of N copies of ``t``, which
                matches guided-diffusion's ``UNetModel.forward`` convention.
                TODO confirm for the vendored copy.
            t: 0-based step index in ``[0, num_diffusion_timesteps)`` (an int
                or a single-element tensor).

        Returns:
            Tensor with the same shape as ``x_t``.
        """
        t_index = int(t)
        # Steps are 0-based, so index 0 produces x_0 and gets no noise
        # (Algorithm 2's "t > 1" is written for 1-based t).
        z = torch.randn_like(x_t) if t_index > 0 else torch.zeros_like(x_t)

        timesteps = torch.full(
            (x_t.shape[0],), t_index, dtype=torch.long, device=x_t.device
        )
        eps = self.model(x_t, timesteps)

        alpha = self.alpha_t[t_index]
        beta = self.betas[t_index]
        alpha_bar = self.alpha_bar[t_index]
        mean = (x_t - beta / torch.sqrt(1 - alpha_bar) * eps) / torch.sqrt(alpha)
        # Fixed variance sigma_t^2 = beta_t, so the noise is scaled by
        # sqrt(beta_t), not beta_t. TODO: support learned/posterior variance.
        return mean + torch.sqrt(beta) * z

    @torch.no_grad()  # sampling needs no gradients; avoids keeping the graph of every step
    def denoise(self, batch_size: Optional[int] = None, device: Optional[torch.device] = None) -> torch.Tensor:
        """Run the full reverse process starting from pure Gaussian noise.

        Args:
            batch_size: number of samples to draw. With ``None`` (the default),
                one sample is drawn and returned without a batch dimension.
            device: device for the sample tensor. ``None`` uses torch's default
                device.

        Returns:
            A (C, H, W) tensor when ``batch_size`` is None, otherwise a
            (batch_size, C, H, W) tensor, where C = ``model.in_channels`` and
            H = W = ``model.image_size``.
        """
        n_samples = 1 if batch_size is None else batch_size
        shape = (
            n_samples,
            self.model.in_channels,
            self.model.image_size,
            self.model.image_size,
        )
        x = torch.randn(size=shape, device=device)
        for t in reversed(range(self.num_diffusion_timesteps)):
            x = self.denoise_at_t(x, t)
        return x[0] if batch_size is None else x

In [None]:


    # Compare how the cumulative signal coefficient alpha_bar decays under each
    # schedule. Only the schedules are needed here, so no model is attached.
    fig, ax = plt.subplots()
    for schedule in ('linear', 'cosine'):
        gd = GaussianDiffusion(
            model=None,
            num_diffusion_timesteps=T,
            noise_schedule=schedule,
        )
        # Give seaborn explicit NumPy x/y vectors rather than a raw torch tensor.
        sns.lineplot(
            x=np.arange(gd.num_diffusion_timesteps),
            y=gd.alpha_bar.numpy(),
            label=schedule,
            ax=ax,
        )
    ax.set(
        title=r'$\bar{\alpha}_t$ per noise schedule',
        xlabel='timestep $t$',
        ylabel=r'$\bar{\alpha}_t$',
    )
    plt.show()
