In [77]:
# Adapted from: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/ 

import enum
import math 
import numpy as np
import torch as th
from torch.distributions import Normal 
import unet
import sys
sys.path.insert(1, '/home/juliatest/Dropbox/diffusion/twisted_diffusion/twisted_diffusion_sampler-main/smc_utils')
import dist_util



In [78]:
# Named beta schedules. "linear" gives num_diffusion_timesteps evenly spaced
# betas from 1e-4 to 0.02 when num_diffusion_timesteps == 1000; both ends are
# scaled by 1000 / num_diffusion_timesteps for other lengths. "cosine"
# discretizes the cosine alpha_bar curve via betas_for_alpha_bar.
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed to maintain backwards compatibility.

    :param schedule_name: "linear" or "cosine".
    :param num_diffusion_timesteps: the number of betas to produce.
    :return: a 1-D numpy array of betas.
    :raises NotImplementedError: for any other schedule name.
    """
    # Author's note: same schedule as the default in sde_lib.
    if schedule_name == "cosine":
        def cosine_alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        return betas_for_alpha_bar(num_diffusion_timesteps, cosine_alpha_bar)

    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Linear schedule from Ho et al, extended to work for any number of
    # diffusion steps.
    step_scale = 1000 / num_diffusion_timesteps
    return np.linspace(
        step_scale * 0.0001,
        step_scale * 0.02,
        num_diffusion_timesteps,
        dtype=np.float64,
    )

    

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function.

    alpha_bar defines the cumulative product of (1 - beta) over time on
    t in [0, 1]. Step i gets beta_i = 1 - alpha_bar((i + 1) / N) / alpha_bar(i / N),
    capped at max_beta.

    :param num_diffusion_timesteps: the number of betas to produce (N).
    :param alpha_bar: a callable that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    :return: a 1-D numpy array of N betas.
    """
    n = num_diffusion_timesteps
    return np.array(
        [
            min(1 - alpha_bar((step + 1) / n) / alpha_bar(step / n), max_beta)
            for step in range(n)
        ]
    )

class LossType(enum.Enum):
    """Which training loss to use."""

    MSE = enum.auto()           # raw MSE loss (and KL when learning variances)
    RESCALED_MSE = enum.auto()  # raw MSE loss (with RESCALED_KL when learning variances)
    KL = enum.auto()            # the variational lower-bound
    RESCALED_KL = enum.auto()   # like KL, but rescaled to estimate the full VLB

    def is_vb(self):
        """Return True when this loss is one of the variational-bound (KL) losses."""
        return self in (LossType.KL, LossType.RESCALED_KL)


class ModelMeanType(enum.Enum):
    """
    Which type of output the model predicts.

    Values are explicit but identical to what enum.auto() assigned (1, 2, 3).
    """

    PREVIOUS_X = 1  # the model predicts x_{t-1}
    START_X = 2     # the model predicts x_0
    EPSILON = 3     # the model predicts epsilon

class ModelVarType(enum.Enum):
    """
    What is used as the model's output variance.

    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.

    Values are explicit but identical to what enum.auto() assigned (1..4).
    """

    LEARNED = 1
    FIXED_SMALL = 2
    FIXED_LARGE = 3
    LEARNED_RANGE = 4

In [97]:
class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    Unlike upstream guided-diffusion, p_mean_variance / p_sample take a single
    timestep ``t`` shared by every particle in the batch. ``t`` may be a
    Python int or a single-element integer tensor.

    :param betas: a 1-D numpy array of betas for each diffusion timestep,
                  starting at T and going to 1 (e.g. the output of
                  get_named_beta_schedule).
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    :param conf: optional configuration object, stored as ``self.conf``
                 (not read by this class).
    :param t_truncate: timestep at which p_mean_variance stores
                       ``self._model_variance_at_t_truncate``. This happens
                       only for LEARNED_RANGE and only when t_truncate > 1.
                       The default 0 disables it.
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
        conf=None,
        t_truncate=0,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps

        self.conf = conf
        # p_mean_variance reads this. It used to be unset until a caller
        # assigned it by hand, so give it a safe default here.
        self.t_truncate = t_truncate

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas = alphas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # Quantities for the forward process q(x_t | x_0).
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_alphas_cumprod_prev = np.sqrt(self.alphas_cumprod_prev)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # Posterior q(x_{t-1} | x_t, x_0) (Ho et al. 2020, eqs. 6-7).
        # variance = beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # The variance is exactly 0 at t = 0, so the log is "clipped" by
        # reusing the t = 1 value there.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        # mean = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

    @staticmethod
    def _timestep_index(t):
        """
        Convert a single timestep to a plain Python int.

        Accepts an int (or numpy integer) or a single-element tensor, so the
        value can index the numpy schedule arrays and be compared with ints.

        :raises ValueError: if ``t`` is a tensor with more than one element.
        """
        if th.is_tensor(t):
            if t.numel() != 1:
                raise ValueError(
                    f"expected a single timestep, got a tensor of shape {tuple(t.shape)}"
                )
            return int(t.item())
        return int(t)

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0) = N(sqrt(alphabar_t) x_0, (1 - alphabar_t) I).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance). The mean has x_start's
                 shape; variance and log_variance are the per-timestep values.
        """
        mean = self.sqrt_alphas_cumprod[t] * x_start
        variance = 1.0 - self.alphas_cumprod[t]
        log_variance = self.log_one_minus_alphas_cumprod[t]
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None, num_particles=16, return_logprob=False):
        """
        Diffuse the data for a given number of diffusion steps.
        In other words, sample from q(x_t | x_0) = N(sqrt(alphabar_t) x_0, (1 - alphabar_t) I).

        :param x_start: the initial data batch (a tensor).
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: must be None; externally supplied noise is not supported.
        :param num_particles: passed to the sampler as ``sample_shape``.
        :param return_logprob: if True, also request the samples' log-probability.
        :return: whatever ``_gaussian_sample`` returns for these arguments.
        """
        assert noise is None
        mean = self.sqrt_alphas_cumprod[t] * x_start
        # The standard deviation depends only on t. It used to be multiplied
        # by x_start, which scaled the noise by the data (and removed it
        # wherever x_0 == 0).
        scale = self.sqrt_one_minus_alphas_cumprod[t] * th.ones_like(x_start)
        # NOTE(review): _gaussian_sample is not defined in this notebook; it
        # must be provided elsewhere before this method is called.
        return _gaussian_sample(mean, scale=scale,
                                sample_shape=th.Size([num_particles]),
                                return_logprob=return_logprob)

    def q_sample_tp1_given_t(self, x_t, t):
        """
        Sample x_{t+1} given x_t for t in [0, T-1], using the single forward step
        q(x_{t+1} | x_t) = N(sqrt(alpha_{t+1}) x_t, (1 - alpha_{t+1}) I).

        ``self.alphas[t]`` holds alpha_{t+1} because the arrays are 0-indexed.
        """
        alphas_tp1 = self.alphas[t]
        mean = np.sqrt(alphas_tp1) * x_t
        scale = np.sqrt(1 - alphas_tp1)  # standard deviation; independent of x_t
        return mean + scale * th.randn_like(x_t)

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        :return: (posterior_mean, posterior_variance, posterior_log_variance_clipped).
                 The mean has x_t's shape; the other two are per-timestep values.
        """
        posterior_mean = (
            self.posterior_mean_coef1[t] * x_start
            + self.posterior_mean_coef2[t] * x_t
        )
        posterior_variance = self.posterior_variance[t]
        posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_posterior_sample(self, x_start, x_t, t, return_logprob=False, t_init=False):
        """
        Sample from the diffusion posterior q(x_{t-1} | x_t, x_0).

        :param t_init: if True, skip sampling and return
                       ``(posterior_mean, zeros_like(posterior_mean))``.
        :return: otherwise, whatever ``_gaussian_sample`` returns for the
                 posterior mean/variance and ``return_logprob``.
        """
        posterior_mean, posterior_variance, _ = self.q_posterior_mean_variance(x_start, x_t, t)
        if t_init:
            # NOTE(review): this is always a 2-tuple, even when
            # return_logprob is False. Confirm callers expect that.
            return posterior_mean, th.zeros_like(posterior_mean)
        # NOTE(review): _gaussian_sample is not defined in this notebook.
        return _gaussian_sample(mean=posterior_mean, variance=posterior_variance,
                                return_logprob=return_logprob)

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get the parameters of p(x_{t-1} | x_t), as well as a
        prediction of the initial x, x_0. No sampling happens here.

        :param model: the model. It is called as
                      ``model(x, timesteps, y=y)`` with one timestep per particle.
        :param x: the [P x C x H x W] tensor at time t (P particles).
        :param t: a single timestep shared by all particles (int or
                  single-element integer tensor).
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments.
            Only the ``"y"`` entry is used: it is expanded to P entries and
            passed to the model as ``y=``.
        :return: a dict with the following keys, each shaped like ``x``:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}

        num_particles, C, H, W = x.shape
        t = self._timestep_index(t)

        # The model expects one timestep per particle.
        t_tensor = th.tensor([t] * num_particles, device=x.device)
        y = model_kwargs.get("y", None)
        if y is not None:
            # NOTE(review): assumes y is a tensor that expands to (P,).
            y = y.expand(num_particles)
        model_output = model(x, self._scale_timesteps(t_tensor), y=y)

        if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
            # 2*C output channels: the first C are the mean parameterization,
            # the last C are the variance values.
            assert model_output.shape == (num_particles, C * 2, H, W)
            model_output, model_var_values = th.split(model_output, C, dim=1)

            if self.model_var_type == ModelVarType.LEARNED:
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                # LEARNED_RANGE: interpolate in log space between the clipped
                # posterior variance and beta_t, with frac = (v + 1) / 2.
                min_log = self.posterior_log_variance_clipped[t]
                max_log = np.log(self.betas[t])
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)

                # hack: only active when t_truncate > 1
                # (need to set gaussian diffusion t_truncate = 0 to disable)
                if self.t_truncate > 1 and t == self.t_truncate:
                    _beta_t = 1 - self.alphas[t]
                    _max_log = np.log(_beta_t)
                    self._model_variance_at_t_truncate = th.exp(frac * _max_log)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                ModelVarType.FIXED_LARGE: (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                ModelVarType.FIXED_SMALL: (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            # Broadcast the per-timestep scalars to x's shape. This matches the
            # learned branches; the shape assert below failed on bare scalars.
            model_variance = th.full_like(x, float(model_variance[t]))
            model_log_variance = th.full_like(x, float(model_log_variance[t]))

        def process_xstart(x):
            # Optional user transform, then optional clipping to [-1, 1].
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == ModelMeanType.PREVIOUS_X:
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
            if self.model_mean_type == ModelMeanType.START_X:
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            # Mean of q(x_{t-1} | x_t, pred(x_0)).
            model_mean, _, _ = self.q_posterior_mean_variance(
                x_start=pred_xstart, x_t=x, t=t
            )
        else:
            raise NotImplementedError(self.model_mean_type)

        assert (
            model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        )

        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Recover x_0 = sqrt(1/alphabar_t) x_t - sqrt(1/alphabar_t - 1) eps."""
        assert x_t.shape == eps.shape
        return (
            self.sqrt_recip_alphas_cumprod[t] * x_t
            - self.sqrt_recipm1_alphas_cumprod[t] * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        """
        Recover x_0 from x_t and the posterior mean x_{t-1} by inverting
        xprev = posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t.

        This was previously called by p_mean_variance (PREVIOUS_X) but missing.
        """
        assert x_t.shape == xprev.shape
        return (
            xprev - self.posterior_mean_coef2[t] * x_t
        ) / self.posterior_mean_coef1[t]

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        Returns ``p_mean_var["mean"] + p_mean_var["variance"] * gradient``
        (both mean and gradient cast to float).

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        # The default used to crash on ``**None``.
        if model_kwargs is None:
            model_kwargs = {}
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        return p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()

    def _scale_timesteps(self, t):
        """
        Map timesteps to what the model expects.

        With ``rescale_timesteps`` the result is a float tensor scaled by
        1000 / num_timesteps; otherwise ``t`` is returned unchanged.
        """
        if self.rescale_timesteps:
            # th.as_tensor lets plain ints be rescaled too (t.float() failed on them).
            return th.as_tensor(t).float() * (1000.0 / self.num_timesteps)
        return t

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        model_kwargs=None,
        **kwargs
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor x_t.
        :param t: the value of t (int or single-element tensor), starting at 0
                  for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param kwargs: accepted for compatibility and ignored (for example, a
            ``cond_fn`` passed here is not applied).
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'mean': the mean of p(x_{t-1} | x_t).
                 - 'std': the standard deviation of p(x_{t-1} | x_t).
                 - 'pred_xstart': a prediction of x_0.
        """
        # Draw the noise before the model call, as before, so the RNG stream is unchanged.
        noise = th.randn_like(x)

        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

        # No noise at the final step (t == 0). This works for an int or a
        # single-element tensor t; ``(t != 0).float()`` crashed for ints.
        nonzero_mask = 0.0 if self._timestep_index(t) == 0 else 1.0

        std = th.exp(0.5 * out["log_variance"])
        sample = out["mean"] + nonzero_mask * std * noise

        return {"sample": sample,
                "mean": out["mean"],
                "std": std,
                "pred_xstart": out["pred_xstart"]}

In [98]:
# 1000-step linear schedule (beta from 1e-4 to 0.02).
betas = get_named_beta_schedule(schedule_name="linear", num_diffusion_timesteps=1000)

In [99]:
# Epsilon-predicting diffusion with a learned-range variance, trained with MSE.
GD = GaussianDiffusion(
    betas=betas,
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.LEARNED_RANGE,
    loss_type=LossType.MSE,
)

In [100]:
# UNet for 28x28 single-channel images. out_channels=2 is the 2 * in_channels
# output that GaussianDiffusion.p_mean_variance expects for learned-variance
# models (mean-parameterization channels, then variance channels).
unet_model = unet.UNetModel(
    image_size=28,
    in_channels=1,
    model_channels=64,
    out_channels=2,
    num_res_blocks=3,
    attention_resolutions=[1, 2, 4],
    dropout=0,
    channel_mult=(1, 2, 2, 2),
    conv_resample=True,
    dims=2,
    num_classes=None,
    use_checkpoint=False,
    use_fp16=False,
    num_heads=4,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=True,
    resblock_updown=True,
    use_new_attention_order=False,
    diffusion_steps=1000,
    use_value_logger=False,
)

In [101]:
import os

# Root of the image experiment. Override with the TDS_IMAGE_EXP_DIR
# environment variable instead of editing this absolute, machine-specific path.
home_folder = os.environ.get(
    "TDS_IMAGE_EXP_DIR",
    "/home/juliatest/Dropbox/diffusion/twisted_diffusion/twisted_diffusion_sampler-main/image_exp",
)

In [None]:
# Load the trained checkpoint into the UNet, then move it to the first GPU.
# NOTE(review): this cell shows no execution count in the saved notebook, so
# later outputs may have been produced without these weights. Re-run top to bottom.
checkpoint_path = f"{home_folder}/models/model060000.pt"
unet_model.load_state_dict(dist_util.load_state_dict(checkpoint_path))
unet_model.to("cuda:0")

In [103]:
# Dummy all-ones input (1 particle, 1 channel, 28x28) and timestep 10, on the GPU.
x = th.ones((1, 1, 28, 28), device="cuda:0")
t = th.tensor([10], device="cuda:0")

In [104]:
# Explicitly disable the t_truncate variance bookkeeping in p_mean_variance.
GD.t_truncate = 0
a = GD.p_mean_variance(
    model=unet_model,
    x=x,
    t=t,
    clip_denoised=True,
    denoised_fn=None,
    model_kwargs=None,
)

torch.Size([1, 2, 28, 28])
