In [1]:
import numpy as np
import torch
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import (
    cosine_beta_schedule,
)

from diffusion_multinomial import MultinomialDiffusion
from my_multinomial import MultiGaussianDiffusion

def log_add_exp(a, b):
    """Elementwise ``log(exp(a) + exp(b))``, computed stably.

    Delegates to ``torch.logaddexp``, which returns ``-inf`` when both inputs
    are ``-inf`` (e.g. two zero-probability log terms). The previous
    max-shift form evaluated ``-inf - (-inf) = nan`` in that case and
    propagated NaN.

    Args:
        a: tensor of log-values.
        b: tensor of log-values, broadcastable against ``a``.

    Returns:
        Tensor holding ``log(exp(a) + exp(b))`` with the broadcast shape.
    """
    return torch.logaddexp(a, b)


def extract(a, t, x_shape):
    """Gather ``a[..., t[i]]`` per batch element and shape it for broadcasting.

    Args:
        a: tensor of per-timestep coefficients, indexed along its last axis.
        t: integer index tensor whose first dimension is the batch size.
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        The gathered values reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing_ones)


def log_1_min_a(a):
    """Elementwise ``log(1 - exp(a))`` for a tensor of log-values ``a``.

    The ``1e-40`` added inside the log keeps its argument positive when
    ``exp(a)`` rounds to exactly 1 (i.e. ``a == 0``), so the result stays
    finite there.
    """
    complement = 1 - a.exp() + 1e-40
    return complement.log()


class MyDiff:
    """Log-space forward distribution ``q(x_t | x_0)`` mixing x_0 with a uniform.

    For each timestep, ``q_pred`` returns
    ``log(alpha_bar_t * x_0 + (1 - alpha_bar_t) / num_classes)``,
    where ``alpha_bar_t`` is the cumulative product of the per-step
    ``alpha`` derived from ``cosine_beta_schedule``.

    Args:
        timesteps: passed to ``cosine_beta_schedule``; valid ``t`` indices are
            presumably ``0 .. timesteps - 1`` (TODO confirm schedule length).
        num_classes: number of categories used for the uniform term.
        beta_schedule: ``None`` or ``"cosine"``; both select the cosine
            schedule, anything else raises ``ValueError``.
    """

    def __init__(self, timesteps=1000, num_classes=4, beta_schedule=None) -> None:
        if beta_schedule not in (None, "cosine"):
            raise ValueError(
                f"unsupported beta_schedule {beta_schedule!r}; only 'cosine' is implemented"
            )
        # Previously hard-coded to 4, silently ignoring the argument.
        self.num_classes = num_classes

        # Convert explicitly so the torch-only helpers below (``.exp()`` and
        # ``gather``) work whether the schedule comes back as numpy or torch.
        betas = torch.as_tensor(cosine_beta_schedule(timesteps), dtype=torch.float64)
        alphas = torch.sqrt(1 - betas)
        log_alpha = torch.log(alphas)
        self.log_cumprod_alphas = torch.cumsum(log_alpha, dim=0)

        # Despite the attribute names, these hold log(1 - alpha_t) and
        # log(1 - cumprod(alpha)_t).
        self.log_betas = log_1_min_a(log_alpha)
        self.log_cumprod_betas = log_1_min_a(self.log_cumprod_alphas)

    def q_pred_my(self, log_x_0, t):
        """Alias of :meth:`q_pred`, kept for existing callers."""
        return self.q_pred(log_x_0, t)

    def q_pred(self, log_x_start, t):
        """Log-probabilities of ``q(x_t | x_0)`` for a batch of timesteps.

        Args:
            log_x_start: log-probabilities of x_0 with the batch on the first
                axis; classes are presumably the last axis (TODO confirm).
            t: integer tensor with one timestep index per batch element.

        Returns:
            ``log(alpha_bar_t * exp(log_x_start) + (1 - alpha_bar_t) / K)``,
            broadcast to ``log_x_start``'s shape.
        """
        log_cumprod_alpha_t = extract(self.log_cumprod_alphas, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_cumprod_betas, t, log_x_start.shape)

        return log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes),
        )


In [2]:
# Two batch elements in log space; log(0) = -inf marks zero-probability classes.
x_0 = torch.log(torch.tensor([[0.5, 0, 0.25, 0.25], [1, 0, 0, 0]]))
my_diff = MyDiff(timesteps=100, num_classes=4, beta_schedule='cosine')
t = torch.tensor([0, 9])  # one timestep index per batch element

torch.exp(my_diff.q_pred_my(x_0, t))


tensor([[4.9992e-01, 7.8923e-05, 2.5000e-01, 2.5000e-01],
        [9.8946e-01, 3.5131e-03, 3.5131e-03, 3.5131e-03]], dtype=torch.float64)

In [3]:
# Reference implementation for comparison. The identity denoise_fn is enough
# because only the forward q_pred is evaluated in this cell.
reference_diff = MultinomialDiffusion(4, shape=(4,), denoise_fn=lambda x: x, timesteps=100)

torch.exp(reference_diff.q_pred(x_0, t))

tensor([[4.9992e-01, 7.8923e-05, 2.5000e-01, 2.5000e-01],
        [9.8946e-01, 3.5131e-03, 3.5131e-03, 3.5131e-03]])

In [4]:
# Third implementation under test; reuses x_0 and t from the first comparison cell.
ob = MultiGaussianDiffusion(timesteps=100, num_classes=4, beta_schedule='cosine', decode_network=None)

torch.exp(ob.q_pred(x_0, t))

tensor([[4.9992e-01, 7.8923e-05, 2.5000e-01, 2.5000e-01],
        [9.8946e-01, 3.5131e-03, 3.5131e-03, 3.5131e-03]], dtype=torch.float64)

In [10]:
log_x_t = ob.q_pred(x_0, t)
# NOTE(review): the displayed result has shape (2,), not (batch, num_classes)
# like q_pred's output — confirm q_posterior's return contract.
ob.q_posterior(x_0, log_x_t, t).exp()

tensor([0.3749, 0.9791], dtype=torch.float64)

In [6]:
# Both batch elements evaluated at the first timestep (t = 0).
ob.q_pred(x_0, torch.tensor([0, 0])).exp()

tensor([[4.9992e-01, 7.8923e-05, 2.5000e-01, 2.5000e-01],
        [9.9976e-01, 7.8923e-05, 7.8923e-05, 7.8923e-05]], dtype=torch.float64)

In [9]:
# Probabilities of log_x_t, which is assigned in the q_posterior cell above.
log_x_t.exp()

tensor([[4.9992e-01, 7.8923e-05, 2.5000e-01, 2.5000e-01],
        [9.8946e-01, 3.5131e-03, 3.5131e-03, 3.5131e-03]], dtype=torch.float64)