"""
Denoising diffusion (DDPM-style) forward and reverse processes for data of shape [N x D].

References
    - https://github.com/quickgrid/pytorch-diffusion
"""

In [12]:
import copy
import math
import os
import logging
import pathlib
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torchvision.utils
from torch.cuda.amp import GradScaler
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torch import optim
from torch.functional import F
#from torch.utils.tensorboard import SummaryWriter
#from transformers import AutoTokenizer, AutoModel
#from memory_efficient_attention_pytorch import Attention

In [44]:
class Diffusion:
    """DDPM-style diffusion process over samples of shape [N x D].

    Holds a linear beta schedule of length ``noise_steps``. Implements the
    forward (noising) process ``q(x_t | x_0)`` and the iterative reverse
    (denoising) process driven by a noise-prediction model.

    Note: schedule index 0 is never used. Training timesteps are drawn from
    ``[1, noise_steps - 1]`` and sampling runs ``t = noise_steps - 1, ..., 1``.
    """

    def __init__(
            self,
            device: str,
            N: int,
            D: int,
            noise_steps: int = 1000,
            beta_start: float = 1e-4,
            beta_end: float = 0.02,
    ):
        """
        Args:
            device: Device on which the schedule tensors, sampled noise and
                sampled timesteps are created.
            N: Size of the second dimension of generated data [n x N x D].
            D: Size of the last dimension of generated data [n x N x D].
            noise_steps: Length of the beta schedule.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
        """
        self.device = device
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.N = N
        self.D = D

        # beta_t, alpha_t = 1 - beta_t, alpha_hat_t = prod_{s <= t} alpha_s.
        # alpha / alpha_hat are only needed to derive the coefficients below,
        # so they are kept as locals rather than stored on the instance.
        self.beta = self.linear_noise_schedule()
        alpha = 1 - self.beta
        alpha_hat = torch.cumprod(alpha, dim=0)

        # Coefficients of q(x_t | x_0): sqrt(alpha_hat_t) and sqrt(1 - alpha_hat_t)
        self.sqrt_alpha_hat = torch.sqrt(alpha_hat)
        self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - alpha_hat)

        # Coefficients of the reverse step: sqrt(alpha_t) and the std sigma_t = sqrt(beta_t)
        self.sqrt_alpha = torch.sqrt(alpha)
        self.std_beta = torch.sqrt(self.beta)

    def linear_noise_schedule(self) -> torch.Tensor:
        """Return the per-step variances beta, linearly spaced from
        ``beta_start`` to ``beta_end`` over ``noise_steps`` steps.
        """
        return torch.linspace(start=self.beta_start, end=self.beta_end, steps=self.noise_steps, device=self.device)

    def forward(self, x_0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward process (q-function): sample x_t given x_0 and t.

        Args:
            x_0: data without noise [n x N x D]
            t: timestep per sample [n] (an int also works and applies to all samples)

        Returns:
            x_t: diffused data x at timestep t [n x N x D]
            epsilon: the Gaussian noise that was mixed into x_t [n x N x D]
        """
        sqrt_alpha_hat = self.sqrt_alpha_hat[t].view(-1, 1, 1)
        sqrt_one_minus_alpha_hat = self.sqrt_one_minus_alpha_hat[t].view(-1, 1, 1)
        epsilon = torch.randn_like(x_0, device=self.device)  # standard Gaussian, same shape as x_0
        return sqrt_alpha_hat * x_0 + sqrt_one_minus_alpha_hat * epsilon, epsilon

    def sample_timesteps(self, batch_size: int) -> torch.Tensor:
        """Draw one uniformly random timestep per sample from [1, noise_steps - 1].

        ``torch.randint``'s ``high`` bound is exclusive, so ``noise_steps``
        itself is never drawn.

        Args:
            batch_size: number of timesteps to draw

        Returns:
            t: randomly sampled timesteps for each sample in a batch [batch_size]
        """
        return torch.randint(low=1, high=self.noise_steps, size=(batch_size,), device=self.device)

    def backward(
            self,
            eps_model: Callable[..., torch.Tensor],
            n: int,
            scale_factor: int = 2,
            graph_cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Reverse (denoising) process: generate samples starting from pure noise.

        For ``t = noise_steps - 1, ..., 1`` computes
        ``x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps_model(x_t, t))
        + sigma_t * z`` with ``sigma_t = sqrt(beta_t)`` and ``z ~ N(0, I)``.
        ``z = 0`` on the last step.

        Args:
            eps_model: Noise prediction model (``eps_theta(x_t, t)`` in the paper).
                It is called as ``eps_model(x=x_t, t=t, graph_cond=graph_cond)``.
                If it is an ``nn.Module``, it is switched to eval mode while
                sampling and its previous train/eval mode is restored afterwards.
            n: Number of samples to generate.
            scale_factor: Currently unused; kept for interface compatibility.
            graph_cond: Conditioning tensor forwarded unchanged to ``eps_model``.

        Returns:
            x_0: generated denoised data [n x N x D]
        """
        # Sampling must not use training-mode behaviour (dropout, batch-norm
        # batch statistics). Remember the caller's mode so it can be restored.
        was_training = None
        if isinstance(eps_model, nn.Module):
            was_training = eps_model.training
            eps_model.eval()

        try:
            with torch.no_grad():
                # 1) sample x_T from standard Gaussian noise (n samples)
                x = torch.randn((n, self.N, self.D), device=self.device)

                # 2) iteratively sample x_{t-1} from x_t
                for i in reversed(range(1, self.noise_steps)):
                    t = torch.full((n,), i, dtype=torch.long, device=self.device)

                    sqrt_alpha_t = self.sqrt_alpha[t].view(-1, 1, 1)
                    beta_t = self.beta[t].view(-1, 1, 1)
                    sqrt_one_minus_alpha_hat_t = self.sqrt_one_minus_alpha_hat[t].view(-1, 1, 1)
                    sigma_t = self.std_beta[t].view(-1, 1, 1)

                    # No additional noise is added when computing x_0 from x_1.
                    z = torch.randn_like(x) if i > 1 else torch.zeros_like(x)

                    predicted_noise = eps_model(x=x, t=t, graph_cond=graph_cond)
                    x = (1 / sqrt_alpha_t) * (
                        x - (beta_t / sqrt_one_minus_alpha_hat_t) * predicted_noise
                    ) + sigma_t * z
        finally:
            if was_training is not None:
                eps_model.train(was_training)

        return x

In [45]:
# For testing: placeholder noise-prediction model.

class RGCN(nn.Module):
    """Placeholder noise-prediction model used for testing: returns ``w * x``.

    ``t`` and ``graph_cond`` are accepted so the model matches the call
    signature ``eps_model(x=..., t=..., graph_cond=...)`` used by
    ``Diffusion.backward``, but they are ignored.
    """

    def __init__(self, w: float):
        # nn.Module.__init__ must run first. Without it the instance lacks the
        # module's internal registries and calling ``model(...)`` fails.
        super().__init__()
        self.w = w

    def forward(
            self,
            x: torch.Tensor,
            t: torch.LongTensor,
            graph_cond: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.w * x


In [46]:
# Testing: noise one sample with the forward process, then run the full
# reverse process with the placeholder model.
torch.manual_seed(0)  # make the random draws below reproducible

x_0 = torch.tensor([[[1, 2, 3], [3, 2, 1]]]).float()  # [n=1 x N=2 x D=3]
t = torch.tensor([3])  # one timestep per sample, shape [n]
Diff = Diffusion(
    device=torch.device("cpu"),
    N=2,
    D=3,
    noise_steps=10,
    beta_start=1e-4,
    beta_end=0.02,
)
# Use a separate name for the instance so the `RGCN` class is not shadowed
# and this cell can be re-run.
rgcn = RGCN(0.5)
sample = Diff.forward(x_0, t)
rand_timesteps = Diff.sample_timesteps(3)
x_pred = Diff.backward(
    eps_model=rgcn.forward,
    n=2,
    scale_factor=2,
    graph_cond=None,
)
print(x_pred)

tensor([[[-0.5697,  0.0794,  1.1458],
         [ 1.7177, -0.3367,  1.0535]],

        [[-0.2977, -0.7332,  0.4197],
         [ 1.0442, -0.3718,  0.7526]]])
