In [33]:
from typing import Optional, Tuple, Literal
from dataclasses import dataclass

import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from easydict import EasyDict
from functools import partial

from edm_preconditioner import PreConditioner
from edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV, SCALE_T, SCALE_T_DERIV, DEFAULT_SOLVER_PARAM
from grl.generative_models.intrinsic_model import IntrinsicModel

class Simple(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 32),
            nn.ReLU(),
            nn.Linear(32, 32), 
            nn.ReLU(),
            nn.Linear(32, 2)
        )
    def forward(self, x, noise, class_labels=None):
        return self.model(x)

class EDMModel(nn.Module):
    
    def __init__(self, config: Optional[EasyDict]=None) -> None:
        
        super().__init__()
        self.config= config
        # self.x_size = config.x_size
        self.device = config.device
        
        # EDM Type ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"]
        self.edm_type: str = config.edm_model.path.edm_type
        assert self.edm_type in ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"], \
            f"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}"
        
        #* 1. Construct basic Unet architecture through params in config
        self.base_denoise_network = Simple()

        #* 2. Precond setup
        self.params = config.edm_model.path.params
        self.preconditioner = PreConditioner(
            self.edm_type, 
            base_denoise_model=self.base_denoise_network, 
            use_mixes_precision=False,
            **self.params
        )
        
        #* 3. Solver setup
        self.solver_type = config.edm_model.solver.solver_type
        assert self.solver_type in ['euler', 'heun']
        
        self.solver_params = DEFAULT_SOLVER_PARAM
        self.solver_params.update(config.edm_model.solver.params)
        
        # Initialize sigma_min and sigma_max if not provided
        
        if "sigma_min" not in self.params:
            min = torch.tensor(1e-3)
            self.sigma_min = {
                "VP_edm": SIGMA_T["VP_edm"](min, 19.9, 0.1), 
                "VE_edm": 0.02, 
                "iDDPM_edm": 0.002, 
                "EDM": 0.002
            }[self.edm_type]
        else:
            self.sigma_min = self.params.sigma_min
        if "sigma_max" not in self.params:
            max = torch.tensor(1)
            self.sigma_max = {
                "VP_edm": SIGMA_T["VP_edm"](max, 19.9, 0.1), 
                "VE_edm": 100, 
                "iDDPM_edm": 81, 
                "EDM": 80
            }[self.edm_type]            
        else:
            self.sigma_max = self.params.sigma_max
            
    def get_type(self):
        return "EDMModel"

    # For VP_edm
    def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]:
        # assert the first dim of x is batch size
        print(f"params is {params}")
        rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) 
        if self.edm_type == "VP_edm":
            epsilon_t = params.get("epsilon_t", 1e-5)
            beta_d = params.get("beta_d", 19.9)
            beta_min = params.get("beta_min", 0.1)
            
            rand_uniform = torch.rand(*rand_shape, device=x.device)
            sigma = SIGMA_T["VP_edm"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min)
            weight = 1 / sigma ** 2
        elif self.edm_type == "VE_edm":
            rand_uniform = torch.rand(*rand_shape, device=x.device)
            sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform)
            weight = 1 / sigma ** 2
        elif self.edm_type == "EDM":
            P_mean = params.get("P_mean", -1.2)
            P_std = params.get("P_mean", 1.2)
            sigma_data = params.get("sigma_data", 0.5)
            
            rand_normal = torch.randn(*rand_shape, device=x.device)
            sigma = (rand_normal * P_std + P_mean).exp()
            weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
        return sigma, weight
    
    def forward(self, 
                x: Tensor, 
                class_labels=None) -> Tensor:
        x = x.to(self.device)
        sigma, weight = self._sample_sigma_weight_train(x, **self.params)
        n = torch.randn_like(x) * sigma
        D_xn = self.preconditioner(x+n, sigma, class_labels=class_labels)
        loss = weight * ((D_xn - x) ** 2)
        return loss
    
    
    def _get_sigma_steps_t_steps(self, num_steps=18, epsilon_s=1e-3, rho=7):
        """
        Overview:
            Get the schedule of sigma according to differernt t schedules.
            
        """
        self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min)
        self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max)
    
        # Define time steps in terms of noise level
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=self.device)
        sigma_steps = None
        if self.edm_type == "VP_edm":
            vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)
            vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d
            
            orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
            sigma_steps = SIGMA_T["VP_edm"](orig_t_steps, vp_beta_d, vp_beta_min)
        
        elif self.edm_type == "VE_edm":
            orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (num_steps - 1)))
            sigma_steps = SIGMA_T["VE_edm"](orig_t_steps)
        
        elif self.edm_type == "iDDPM_edm":
            M, C_1, C_2 = self.params.M, self.params.C_1, self.params.C_2
            
            u = torch.zeros(M + 1, dtype=torch.float64, device=self.device)
            alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
            for j in torch.arange(self.params.M, 0, -1, device=self.device): # M, ..., 1
                u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
            u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)]
            
            sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]   
            orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps))         
        
        elif self.edm_type == "EDM": 
            sigma_steps = (self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * \
                (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))) ** rho
            orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps))         
        
        t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])]) # t_N = 0
        
        return sigma_steps, t_steps  
    
    
    def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s=1e-3):
        """
        Overview:
            Get sigma(t) for different solver schedules.
            
        Returns:
            sigma(t), sigma'(t), sigma^{-1}(sigma) 
        """
        vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)
        vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d
        sigma = partial(SIGMA_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
        sigma_deriv = partial(SIGMA_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
        sigma_inv = partial(SIGMA_T_INV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
        scale = partial(SCALE_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
        scale_deriv = partial(SCALE_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)

        return sigma, sigma_deriv, sigma_inv, scale, scale_deriv
  
    
    def sample(self, 
               latents: Tensor, 
               class_labels: Tensor=None, 
               use_stochastic: bool=False, 
               **solver_params) -> Tensor:
        
        # Get sigmas, scales, and timesteps
        print(f"solver_params is {solver_params}")
        num_steps = self.solver_params.num_steps
        epsilon_s = self.solver_params.epsilon_s
        rho = self.solver_params.rho
        
        latents = latents.to(self.device)
        sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=num_steps, epsilon_s=epsilon_s, rho=rho)
        sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv()
                
        S_churn = self.solver_params.S_churn
        S_min = self.solver_params.S_min
        S_max = self.solver_params.S_max
        S_noise = self.solver_params.S_noise
        alpha = self.solver_params.alpha
        
        if not use_stochastic:
            # Main sampling loop
            t_next = t_steps[0]
            x_next = latents.to(torch.float64) * (sigma(t_next) * scale(t_next))
            for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
                x_cur = x_next

                # Increase noise temporarily.
                gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
                t_hat = sigma_inv(self.preconditioner.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
                x_hat = scale(t_hat) / scale(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * scale(t_hat) * S_noise * torch.randn_like(x_cur)

                # Euler step.
                h = t_next - t_hat
                denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels).to(torch.float64)
                d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised
                x_prime = x_hat + alpha * h * d_cur
                t_prime = t_hat + alpha * h

                # Apply 2nd order correction.
                if self.solver_type == 'euler' or i == num_steps - 1:
                    x_next = x_hat + h * d_cur
                else:
                    assert self.solver_type == 'heun'
                    denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels).to(torch.float64)
                    d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised
                    x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
        
        else:
            assert self.edm_type == "EDM", f"Stochastic can only use in EDM, but your precond type is {self.edm_type}"
            x_next = latents.to(torch.float64) * t_steps[0]
            for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
                x_cur = x_next

                # Increase noise temporarily.
                gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
                t_hat = self.preconditioner.round_sigma(t_cur + gamma * t_cur)
                x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

                # Euler step.
                denoised = self.preconditioner(x_hat, t_hat, class_labels).to(torch.float64)
                d_cur = (x_hat - denoised) / t_hat
                x_next = x_hat + (t_next - t_hat) * d_cur

                # Apply 2nd order correction.
                if i < num_steps - 1:
                    denoised = self.preconditioner(x_next, t_next, class_labels).to(torch.float64)
                    d_prime = (x_next - denoised) / t_next
                    x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)


        return x_next


In [34]:
import torch
from easydict import EasyDict

config = EasyDict(
    dict(
        device=torch.device("cuda"),  # Test if all tensors are converted to the same device
        edm_model=dict(            
            path=dict(
                edm_type="EDM", # *["VP_edm", "VE_edm", "iDDPM_edm", "EDM"]
                params=dict(
                    #^ 1: VP_edm
                    # beta_d=19.9, 
                    # beta_min=0.1, 
                    # M=1000, 
                    # epsilon_t=1e-5,
                    # epsilon_s=1e-3,
                    #^ 2: VE_edm
                    # sigma_min=0.02,
                    # sigma_max=100,
                    #^ 3: iDDPM_edm
                    # C_1=0.001,
                    # C_2=0.008,
                    # M=1000,
                    #^ 4: EDM
                    # sigma_min=0.002,
                    # sigma_max=80,
                    # sigma_data=0.5,
                    # P_mean=-1.2,
                    # P_std=1.2,
                )
            ),
            solver=dict(
                solver_type="heun", 
                # *['euler', 'heun']
                params=dict(
                    num_steps=18,
                    alpha=1, 
                    S_churn=0, 
                    S_min=0, 
                    S_max=float("inf"),
                    S_noise=1,
                    rho=7, #* EDM needs rho 
                    epsilon_s=1e-3 #* VP_edm needs epsilon_s
                )
            )
        )
    )
)

edm = EDMModel(config).to(config.device)
x = torch.randn((1024, 2)).to(config.device)
noise = torch.randn_like(x)
loss = edm(x).mean()
sample = edm.sample(x)
sample.shape
loss.backward()


params is {}
solver_params is {}
