In [1]:
import torch
import torch.nn as nn

In [2]:
# This UNET-style prediction model was originally included as part of the Score-based generative modelling tutorial 
# by Yang Song et al: https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GaussianFourierProjection(nn.Module):
  """Gaussian random features for encoding time steps."""  
  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    # Randomly sample weights during initialization. These weights are fixed 
    # during optimization and are not trainable.
    self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
  def forward(self, x):
    x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


class Dense(nn.Module):
  """A fully connected layer that reshapes outputs to feature maps."""
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.dense = nn.Linear(input_dim, output_dim)
  def forward(self, x):
    return self.dense(x)[..., None, None]


class ScoreNet(nn.Module):
  """A time-dependent score-based model built upon U-Net architecture."""

  def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
    """Initialize a time-dependent score-based network.

    Args:
      marginal_prob_std: A function that takes time t and gives the standard
        deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
      channels: The number of channels for feature maps of each resolution.
      embed_dim: The dimensionality of Gaussian random feature embeddings.
    """
    super().__init__()
    # Gaussian random feature embedding layer for time
    self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
         nn.Linear(embed_dim, embed_dim))
    # Encoding layers where the resolution decreases
    self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
    self.dense1 = Dense(embed_dim, channels[0])
    self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
    self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
    self.dense2 = Dense(embed_dim, channels[1])
    self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
    self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
    self.dense3 = Dense(embed_dim, channels[2])
    self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
    self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
    self.dense4 = Dense(embed_dim, channels[3])
    self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])    

    # Decoding layers where the resolution increases
    self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
    self.dense5 = Dense(embed_dim, channels[2])
    self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
    self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)    
    self.dense6 = Dense(embed_dim, channels[1])
    self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
    self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)    
    self.dense7 = Dense(embed_dim, channels[0])
    self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
    self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
    
    # The swish activation function
    self.act = lambda x: x * torch.sigmoid(x)
    self.marginal_prob_std = marginal_prob_std
  
  def forward(self, x, t): 
    # Obtain the Gaussian random feature embedding for t   
    embed = self.act(self.embed(t))    
    # Encoding path
    h1 = self.conv1(x)    
    ## Incorporate information from t
    h1 += self.dense1(embed)
    ## Group normalization
    h1 = self.gnorm1(h1)
    h1 = self.act(h1)
    h2 = self.conv2(h1)
    h2 += self.dense2(embed)
    h2 = self.gnorm2(h2)
    h2 = self.act(h2)
    h3 = self.conv3(h2)
    h3 += self.dense3(embed)
    h3 = self.gnorm3(h3)
    h3 = self.act(h3)
    h4 = self.conv4(h3)
    h4 += self.dense4(embed)
    h4 = self.gnorm4(h4)
    h4 = self.act(h4)

    # Decoding path
    h = self.tconv4(h4)
    ## Skip connection from the encoding path
    h += self.dense5(embed)
    h = self.tgnorm4(h)
    h = self.act(h)
    h = self.tconv3(torch.cat([h, h3], dim=1))
    h += self.dense6(embed)
    h = self.tgnorm3(h)
    h = self.act(h)
    h = self.tconv2(torch.cat([h, h2], dim=1))
    h += self.dense7(embed)
    h = self.tgnorm2(h)
    h = self.act(h)
    h = self.tconv1(torch.cat([h, h1], dim=1))

    # Normalize output
    h = h / self.marginal_prob_std(t)[:, None, None, None]
    return h

In [3]:
# ExponentialMovingAverage implementation as used in pytorch vision
# https://github.com/pytorch/vision/blob/main/references/classification/utils.py#L159

# BSD 3-Clause License

# Copyright (c) Soumith Chintala 2016, 
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.

# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.

# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.
    """

    def __init__(self, model, decay, device="cpu"):
        def ema_avg(avg_model_param, model_param, num_averaged):
            return decay * avg_model_param + (1 - decay) * model_param

        super().__init__(model, device, ema_avg, use_buffers=True)

In [4]:
from torchvision import datasets, transforms, utils
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import math
from diffusion_utils import gaussian_nll, pred_xstart_from_eps
from functions import normal_kl, discretized_gaussian_loglik, flat_mean

class DDPM(nn.Module):

    def __init__(self, network, T=100, beta_1=1e-4, beta_T=2e-2, x0_parameterization=False, use_low_discrepancy_sampler=False):
        """
        Initialize Denoising Diffusion Probabilistic Model

        Parameters
        ----------
        network: nn.Module
            The inner neural network used by the diffusion process. Typically a Unet.
        beta_1: float
            beta_t value at t=1 
        beta_T: [float]
            beta_t value at t=T (last step)
        T: int
            The number of diffusion steps.
        """
        
        super(DDPM, self).__init__()

        self.x0_parameterization = x0_parameterization
        self.use_low_discrepancy_sampler = use_low_discrepancy_sampler

        # Normalize time input before evaluating neural network
        # Reshape input into image format and normalize time value before sending it to network model
        self._network = network
        self.network = lambda x, t: (self._network(x.reshape(-1, 1, 28, 28), 
                                                   (t.squeeze()/T))
                                    ).reshape(-1, 28*28)

        # Total number of time steps
        self.T = T

        # Registering as buffers to ensure they get transferred to the GPU automatically
        self.register_buffer("beta", torch.linspace(beta_1, beta_T, T+1))
        self.register_buffer("alpha", 1-self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(dim=0))
        self.register_buffer("beta_bar", self.beta.cumprod(dim=0))
        self.register_buffer("sqrt_alphas_bar", torch.sqrt(self.alpha_bar))
        self.register_buffer("sqrt_one_minus_alphas_bar", torch.sqrt(1-self.alpha_bar))

        # q(x_{t-1} | x_t, x_0)
        #clever having two alphas with and without s but the more the merrier
        self.alphas_bar = self.alpha_bar
        self.betas = self.beta
        self.alphas = self.alpha
        self.betas_bar = self.beta_bar
        alphas_bar_prev = torch.cat([torch.as_tensor([1., ], dtype=torch.float64), self.alpha_bar[:-1]])
        sqrt_alphas_bar_prev = torch.sqrt(alphas_bar_prev)
        self.sqrt_recip_alphas_bar = torch.sqrt(1. / self.alphas_bar)
        self.sqrt_recip_m1_alphas_bar = torch.sqrt(1. / self.alphas_bar - 1.)  # m1: minus 1
        self.posterior_var = self.betas * (1. - alphas_bar_prev) / (1. - self.alphas_bar)
        self.posterior_logvar_clipped = torch.log(torch.cat([self.posterior_var[[1]], self.posterior_var[1:]]))
        self.posterior_mean_coef1 = self.betas * sqrt_alphas_bar_prev / (1. - self.alphas_bar)
        self.posterior_mean_coef2 = torch.sqrt(self.alphas) * (1. - alphas_bar_prev) / (1. - self.alphas_bar)
        

    def forward_diffusion(self, x0, t, epsilon):
        '''
        q(x_t | x_0)
        Forward diffusion from an input datapoint x0 to an xt at timestep t, provided a N(0,1) noise sample epsilon. 
        Note that we can do this operation in a single step

        Parameters
        ----------
        x0: torch.tensor
            x value at t=0 (an input image)
        t: int
            step index 
        epsilon:
            noise sample

        Returns
        -------
        torch.tensor
            image at timestep t
        ''' 

        mean = torch.sqrt(self.alpha_bar[t])*x0
        std = torch.sqrt(1 - self.alpha_bar[t])
        
        return mean + std*epsilon

    def reverse_diffusion(self, xt, t, epsilon):
        """
        p(x_{t-1} | x_t)
        Single step in the reverse direction, from x_t (at timestep t) to x_{t-1}, provided a N(0,1) noise sample epsilon.

        Parameters
        ----------
        xt: torch.tensor
            x value at step t
        t: int
            step index
        epsilon:
            noise sample

        Returns
        -------
        torch.tensor
            image at timestep t-1
        """
        
        #equation 11 in Ho et al, 2020
        mean =  1./torch.sqrt(self.alpha[t]) * (xt - (self.beta[t])/torch.sqrt(1-self.alpha_bar[t])*self.network(xt, t)) 

        #std sounds more like an art: "Experimentally, both σt2 = βt and σ2 = β ̃ = 1−α ̄t−1 β had similar results.""
        std = torch.where(t>0, torch.sqrt(((1-self.alpha_bar[t-1]) / (1-self.alpha_bar[t]))*self.beta[t]), 0)
        
        return mean + std*epsilon

    def reverse_diffusion_x0_parameterization(self, xt, t, epsilon):
        """
        p(x_{t-1} | x_t)
        Single step in the reverse direction, from x_t (at timestep t) to x_{t-1}, provided a N(0,1) noise sample epsilon.

        Parameters
        ----------
        xt: torch.tensor
            x value at step t
        t: int or torch.tensor
            step index
        epsilon:
            noise sample

        Returns
        -------
        torch.tensor
            image at timestep t-1
        """

        # Network now predicts the initial image x0
        estimated_x0 = self.network(xt, t)

        alpha = self.alpha[t]
        alpha_bar = self.alpha_bar[t]
        alpha_bar_prev = self.alpha_bar[t-1]
        beta = self.beta[t]

        # Equation 6+7 in Ho et al, 2020

        beta_tilde = ((1-alpha_bar_prev)/(1-alpha_bar))*beta
        std = torch.sqrt(beta_tilde)

        coeff1 = torch.sqrt(alpha_bar_prev)/(1-alpha_bar)*beta
        coeff2 = torch.sqrt(alpha)*(1-alpha_bar_prev)/(1-alpha_bar)

        mean = coeff1*estimated_x0 + coeff2*xt
        
        return mean + std * epsilon

    @torch.no_grad()
    def sample(self, shape):
        """
        Sample from diffusion model (Algorithm 2 in Ho et al, 2020)

        Parameters
        ----------
        shape: tuple
            Specify shape of sampled output. For MNIST: (nsamples, 28*28)

        Returns
        -------
        torch.tensor
            sampled image            
        """
        
        # Sample xT: Gaussian noise
        xT = torch.randn(shape).to(self.beta.device)

        xt = xT
        for t in range(self.T, 0, -1):
            noise = torch.randn_like(xT) if t > 1 else 0
            t = torch.tensor(t).expand(xt.shape[0], 1).to(self.beta.device)   
            if self.x0_parameterization:
                xt = self.reverse_diffusion_x0_parameterization(xt, t, noise)
            else:         
                xt = self.reverse_diffusion(xt, t, noise)

        return xt

    def low_discrepancy_t_sampler(self, batch_size, device):
        k = batch_size
        u0 = np.random.uniform(0, 1)

        # ti = mod(u0 + i/k, 1) for i=0,1,...,k-1
        ti = torch.fmod(torch.arange(0, k, device=device)/k + u0, 1)
        return (ti*self.T).long().unsqueeze(1)

    
    def elbo_simple(self, x0):
        """
        ELBO training objective (Algorithm 1 in Ho et al, 2020)

        Parameters
        ----------
        x0: torch.tensor
            Input image

        Returns
        -------
        float
            ELBO value            
        """

        if self.use_low_discrepancy_sampler:
            t = self.low_discrepancy_t_sampler(x0.shape[0], x0.device)
        else:
            # Sample time step t
            t = torch.randint(1, self.T, (x0.shape[0],1)).to(x0.device)

        
        # Sample noise
        epsilon = torch.randn_like(x0)

        # TODO: Forward diffusion to produce image at step t
        xt = self.forward_diffusion(x0, t, epsilon)
        
        return -nn.MSELoss(reduction='mean')(epsilon, self.network(xt, t))


    def elbo_simple_x0_reparameterization(self, x0):
        """
        ELBO training objective (Algorithm 1 in Ho et al, 2020), modified. Network tries to predict the noise, but the loss is taken in 

        Parameters
        ----------
        x0: torch.tensor
            Input image

        Returns
        -------
        float
            ELBO value            
        """

        # Sample time step t
        t = torch.randint(1, self.T, (x0.shape[0],1)).to(x0.device)
        
        # Sample noise
        epsilon = torch.randn_like(x0)

        # TODO: Forward diffusion to produce image at step t
        xt = self.forward_diffusion(x0, t, epsilon)

        estimated_x0 = self.network(xt, t)
        
        return -nn.MSELoss(reduction='mean')(x0, estimated_x0)

    
    def loss(self, x0):
        """
        Loss function. Just the negative of the ELBO.
        """
        if self.x0_parameterization:
            return -self.elbo_simple_x0_reparameterization(x0).mean()
        else:
            return -self.elbo_simple(x0).mean()

    
    # === log likelihood ===
    # bpd: bits per dimension

    @staticmethod
    def _extract(
            arr, t, x,
            dtype=torch.float32, device=torch.device("cpu"), ndim=4):
        if x is not None:
            dtype = x.dtype
            device = x.device
            ndim = x.ndim
        out = torch.as_tensor(arr, dtype=dtype, device=device).gather(0, t)
        return out.reshape((-1, ) + (1, ) * (ndim - 1))

    def q_mean_var(self, x_0, t):
        mean = self._extract(self.sqrt_alphas_bar, t, x_0) * x_0
        var = self._extract(1. - self.alphas_bar, t, x_0)
        logvar = self._extract(torch.log(1 - self.alphas_bar), t, x_0)
        return mean, var, logvar

    def q_sample(self, x_0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_0)
        coef1 = self._extract(self.sqrt_alphas_bar, t, x_0)
        coef2 = self._extract(self.sqrt_one_minus_alphas_bar, t, x_0)
        return coef1 * x_0 + coef2 * noise

    def q_posterior_mean_var(self, x_0, x_t, t):
        posterior_mean_coef1 = self._extract(self.posterior_mean_coef1, t, x_0)
        posterior_mean_coef2 = self._extract(self.posterior_mean_coef2, t, x_0)
        posterior_mean = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t
        posterior_var = self._extract(self.posterior_var, t, x_0)
        posterior_logvar = self._extract(self.posterior_logvar_clipped, t, x_0)
        return posterior_mean, posterior_var, posterior_logvar

    def _loss_term_bpd(self, denoise_fn, x_0, x_t, t, clip_denoised, return_pred):
        # calculate L_t
        # t = 0: negative log likelihood of decoder, -\log p(x_0 | x_1)
        # t > 0: variational lower bound loss term, KL term
        true_mean, _, true_logvar = self.q_posterior_mean_var(x_0=x_0, x_t=x_t, t=t)
        model_mean, _, model_logvar, pred_x_0 = self.p_mean_var(x_t=x_t, t=t, clip_denoised=clip_denoised, return_pred=True)
        kl = normal_kl(true_mean, true_logvar, model_mean, model_logvar)
        kl = flat_mean(kl) / math.log(2.)  # natural base to base 2
        decoder_nll = discretized_gaussian_loglik(x_0, model_mean, log_scale=0.5 * model_logvar).neg()
        decoder_nll = flat_mean(decoder_nll) / math.log(2.)
        output = torch.where(t.to(kl.device) > 0, kl, decoder_nll)
        return (output, pred_x_0) if return_pred else output

    
    def p_mean_var(self, x_t, t, clip_denoised=True, return_pred=False):
        """
        Compute the mean and variance of p(x_{t-1} | x_t)
        """
        if self.x0_parameterization:
            # Network predicts x0 directly
            pred_x0 = self.network(x_t, t)
            if clip_denoised:
                pred_x0 = pred_x0.clamp(-1., 1.)
            
            # Compute the posterior mean and variance using the predicted x0
            mean, var, logvar = self.q_posterior_mean_var(x_0=pred_x0, x_t=x_t, t=t)
        else:
            # Network predicts the noise epsilon
            epsilon_theta = self.network(x_t, t)
            
            # Compute the mean of p(x_{t-1} | x_t) using the predicted epsilon
            mean = (1. / torch.sqrt(self.alpha[t])) * (
                x_t - (self.beta[t] / torch.sqrt(1 - self.alpha_bar[t])) * epsilon_theta
            )
            
            # Extract the variance and log variance for p(x_{t-1} | x_t)
            var = self.posterior_var[t]
            logvar = self.posterior_logvar_clipped[t]
            
            if clip_denoised:
                # Estimate x0 from the predicted epsilon
                pred_x0 = (x_t - torch.sqrt(1 - self.alpha_bar[t]) * epsilon_theta) / torch.sqrt(self.alpha_bar[t])
                pred_x0 = pred_x0.clamp(-1., 1.)
            else:
                pred_x0 = None

        if return_pred:
            return mean, var, logvar, pred_x0
        else:
            return mean, var, logvar



    def train_losses(self, denoise_fn, x_0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t, noise=noise)

        # calculate the loss
        # kl: weighted
        # mse: unweighted
        if self.loss_type == "kl":
            losses = self._loss_term_bpd(
                denoise_fn, x_0=x_0, x_t=x_t, t=t, clip_denoised=False, return_pred=False)
        elif self.loss_type == "mse":
            assert self.model_var_type != "learned"
            if self.model_mean_type == "mean":
                target = self.q_posterior_mean_var(x_0=x_0, x_t=x_t, t=t)[0]
            elif self.model_mean_type == "x_0":
                target = x_0
            elif self.model_mean_type == "eps":
                target = noise
            else:
                raise NotImplementedError(self.model_mean_type)
            model_out = denoise_fn(x_t, t)
            losses = flat_mean((target - model_out).pow(2))
        else:
            raise NotImplementedError(self.loss_type)

        return losses

    def _prior_bpd(self, x_0):
        B, T = len(x_0), self.timesteps
        T_mean, _, T_logvar = self.q_mean_var(
            x_0=x_0, t=(T - 1) * torch.ones((B, ), dtype=torch.int64))
        kl_prior = normal_kl(T_mean, T_logvar, mean2=0., logvar2=0.)
        return flat_mean(kl_prior) / math.log(2.)

    def calc_all_bpd(self, denoise_fn, x_0, clip_denoised=True):
        B = x_0.shape[0]  # Ensure B is the batch size
        T = self.T
        t = torch.empty((B,), dtype=torch.int64)  # Use tuple for size
        losses = torch.zeros((B, T), dtype=torch.float32)  # Use tuple for size
        mses = torch.zeros((B, T), dtype=torch.float32)  

        for ti in range(T - 1, -1, -1):
            t.fill_(ti)
            x_t = self.q_sample(x_0, t=t)
            loss, pred_x_0 = self._loss_term_bpd(
                denoise_fn, x_0, x_t=x_t, t=t, clip_denoised=clip_denoised, return_pred=True)
            losses[:, ti] = loss
            mses[:, ti] = flat_mean((pred_x_0 - x_0).pow(2))

        prior_bpd = self._prior_bpd(x_0)
        total_bpd = torch.sum(losses, dim=1) + prior_bpd
        return total_bpd, losses, prior_bpd, mses



def train(model, optimizer, scheduler, dataloader, epochs, device, ema=True, per_epoch_callback=None):
    """
    Training loop
    
    Parameters
    ----------
    model: nn.Module
        Pytorch model
    optimizer: optim.Optimizer
        Pytorch optimizer to be used for training
    scheduler: optim.LRScheduler
        Pytorch learning rate scheduler
    dataloader: utils.DataLoader
        Pytorch dataloader
    epochs: int
        Number of epochs to train
    device: torch.device
        Pytorch device specification
    ema: Boolean
        Whether to activate Exponential Model Averaging
    per_epoch_callback: function
        Called at the end of every epoch
    """

    # Setup progress bar
    total_steps = len(dataloader)*epochs
    progress_bar = tqdm(range(total_steps), desc="Training")

    if ema:
        ema_global_step_counter = 0
        ema_steps = 10
        ema_adjust = dataloader.batch_size * ema_steps / epochs
        ema_decay = 1.0 - 0.995
        ema_alpha = min(1.0, (1.0 - ema_decay) * ema_adjust)
        ema_model = ExponentialMovingAverage(model, device=device, decay=1.0 - ema_alpha)                
    
    for epoch in range(epochs):

        # Switch to train mode
        model.train()

        global_step_counter = 0
        for i, (x, _) in enumerate(dataloader):
            x = x.to(device)
            optimizer.zero_grad()
            loss = model.loss(x)
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Update progress bar
            progress_bar.set_postfix(loss=f"{loss.item():12.4f}", epoch=f"{epoch+1}/{epochs}", lr=f"{scheduler.get_last_lr()[0]:.2E}")
            progress_bar.update()

            if ema:
                ema_global_step_counter += 1
                if ema_global_step_counter%ema_steps==0:
                    ema_model.update_parameters(model)                
        
        if per_epoch_callback:
            per_epoch_callback(ema_model.module if ema else model)
        
        if (epoch + 1) % 10 == 0: 
            torch.save(model.state_dict(), f"model_discrete_ddpm_baseline_2_{epoch}_.pt")


# Parameters
T = 1000
learning_rate = 1e-3
epochs = 100
batch_size = 256


# Rather than treating MNIST images as discrete objects, as done in Ho et al 2020, 
# we here treat them as continuous input data, by dequantizing the pixel values (adding noise to the input data)
# Also note that we map the 0..255 pixel values to [-1, 1], and that we process the 28x28 pixel values as a flattened 784 tensor.
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Lambda(lambda x: x + torch.rand(x.shape)/255),    # Dequantize pixel values
    transforms.Lambda(lambda x: (x-0.5)*2.0),                    # Map from [0,1] -> [-1, -1]
    transforms.Lambda(lambda x: x.flatten())
])

# Download and transform train dataset
dataloader_train = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True, transform=transform),
                                                batch_size=batch_size,
                                                shuffle=True)

# Select device
device = torch.device("cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Construct Unet
# The original ScoreNet expects a function with std for all the
# different noise levels, such that the output can be rescaled.
# Since we are predicting the noise (rather than the score), we
# ignore this rescaling and just set std=1 for all t.
mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

# Construct model
model = DDPM(mnist_unet, T=T, x0_parameterization=False, use_low_discrepancy_sampler=True).to(device)

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Setup simple scheduler
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999)


def reporter(model):
    """Callback function used for plotting images during training"""
    
    # Switch to eval mode
    model.eval()

    with torch.no_grad():
        nsamples = 10
        samples = model.sample((nsamples,28*28)).cpu()
        
        # Map pixel values back from [-1,1] to [0,1]
        samples = (samples+1)/2 
        samples = samples.clamp(0.0, 1.0)

        # Plot in grid
        grid = utils.make_grid(samples.reshape(-1, 1, 28, 28), nrow=nsamples)
        plt.gca().set_axis_off()
        plt.imshow(transforms.functional.to_pil_image(grid), cmap="gray")
        plt.show()   


#train(model, optimizer, scheduler, dataloader_train, 
#      epochs=epochs, device=device, ema=True, per_epoch_callback=reporter)


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Call training loop
#train(model, optimizer, scheduler, dataloader_train, 
#      epochs=epochs, device=device, ema=True, per_epoch_callback=reporter)

from metrics import calculate_inception_score_and_fid

def evaluate(dataloader, DDPM_class, model_ckpt, device, nsamples=10000):
    """
    Evaluate model using Inception Score and FID
    
    Parameters
    ----------
    dataloader: utils.DataLoader
        Pytorch dataloader
    model: nn.Module
        Pytorch model
    model_ckpt: str
        Path to model checkpoint
    device: torch.device
        Pytorch device specification
    nsamples: int
        Number of samples to evaluate
    
    Returns
    -------
    float, float
        Inception Score, FID
    """
    
    # Load scorenet model
    scorenet = ScoreNet((lambda t: torch.ones(1).to(device)))

    #scorenet.load_state_dict(torch.load(ScoreNet_model_ckpt))


    # Load DDPM model
    model = DDPM_class(scorenet, T=T, x0_parameterization=False, use_low_discrepancy_sampler=False).to(device)

    # Load model checkpoint
    model.load_state_dict(torch.load(model_ckpt), strict=False)
    
    # Generate samples
    samples = model.sample((nsamples,28*28)).cpu()
    
    # Map pixel values back from [-1,1] to [0,1]
    samples = (samples+1)/2 
    samples = samples.clamp(0.0, 1.0)

    generated_samples = samples.reshape(-1, 1, 28, 28)

    real_samples = []
    for x, _ in dataloader:
        real_samples.append(x.reshape(-1, 1, 28, 28))
        if len(real_samples)*batch_size >= nsamples:
            break
    real_samples = torch.cat(real_samples, dim=0)[:nsamples].to(device)

    #add 3 channels to generated and real samples
    generated_samples = torch.cat([generated_samples, generated_samples, generated_samples], dim=1)
    real_samples = torch.cat([real_samples, real_samples, real_samples], dim=1)

    print(generated_samples.shape)
    print(real_samples.shape)

    #convert to float tensor
    generated_samples = generated_samples.float()
    real_samples = real_samples.float()

    # Calculate Inception Score and FID
    inception_mean, inception_std, FID =  calculate_inception_score_and_fid(
    generated_samples,
    real_samples,
    batch_size=32,
    device=device,
    resize=True,
    )

    return inception_mean, inception_std, FID

    



In [30]:
model_ckpt = "/Users/marcusdreisler/Documents/phd/courses/PML-project/rfmt/model_discrete_ddpm_baseline_2_99.pt"

dataloader = dataloader_train

inc_mean, inc_std, FID = evaluate(dataloader, DDPM, model_ckpt, device = "cpu", nsamples=1000)

  model.load_state_dict(torch.load(model_ckpt), strict=False)


torch.Size([1000, 3, 28, 28])
torch.Size([1000, 3, 28, 28])




In [31]:
inc_mean, inc_std, FID

(tensor(1.8744), tensor(0.0601), 115.08816528320312)

In [7]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms, utils
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import math
from diffusion_utils import gaussian_nll, pred_xstart_from_eps
from functions import normal_kl, discretized_gaussian_loglik, flat_mean

class DDPM2(nn.Module):

    def __init__(self, network, T=1000, beta_1=1e-4, beta_T=2e-2, x0_parameterization=False, use_low_discrepancy_sampler=False):
        """
        Initialize Denoising Diffusion Probabilistic Model

        Parameters
        ----------
        network: nn.Module
            The inner neural network used by the diffusion process. Typically a Unet.
        beta_1: float
            beta_t value at t=1 
        beta_T: [float]
            beta_t value at t=T (last step)
        T: int
            The number of diffusion steps.
        """
        
        super(DDPM2, self).__init__()

        self.x0_parameterization = x0_parameterization
        self.use_low_discrepancy_sampler = use_low_discrepancy_sampler

        # Normalize time input before evaluating neural network
        # Reshape input into image format and normalize time value before sending it to network model
        self._network = network
        self.network = lambda x, t: (self._network(x.reshape(-1, 1, 28, 28), 
                                                   (t.float() / T))
                                    ).reshape(-1, 28*28)

        # Total number of time steps
        self.T = T

        # Registering as buffers to ensure they get transferred to the GPU automatically
        self.register_buffer("beta", torch.linspace(beta_1, beta_T, T+1))
        self.register_buffer("alpha", 1 - self.beta)
        self.register_buffer("alpha_bar", self.alpha.cumprod(dim=0))
        self.register_buffer("beta_bar", self.beta.cumprod(dim=0))
        self.register_buffer("sqrt_alphas_bar", torch.sqrt(self.alpha_bar))
        self.register_buffer("sqrt_one_minus_alphas_bar", torch.sqrt(1 - self.alpha_bar))

        # q(x_{t-1} | x_t, x_0)
        self.alphas_bar = self.alpha_bar
        self.betas = self.beta
        self.alphas = self.alpha
        self.betas_bar = self.beta_bar
        alphas_bar_prev = torch.cat([torch.tensor([1.], dtype=torch.float32, device=self.beta.device), self.alpha_bar[:-1]])
        sqrt_alphas_bar_prev = torch.sqrt(alphas_bar_prev)
        self.sqrt_recip_alphas_bar = torch.sqrt(1. / self.alphas_bar)
        self.sqrt_recip_m1_alphas_bar = torch.sqrt(1. / self.alphas_bar - 1.)  # m1: minus 1
        self.posterior_var = self.betas * (1. - alphas_bar_prev) / (1. - self.alphas_bar)
        self.posterior_logvar_clipped = torch.log(torch.cat([self.posterior_var[1:2], self.posterior_var[1:]]))
        self.posterior_mean_coef1 = self.betas * sqrt_alphas_bar_prev / (1. - self.alphas_bar)
        self.posterior_mean_coef2 = torch.sqrt(self.alphas) * (1. - alphas_bar_prev) / (1. - self.alphas_bar)

    def forward_diffusion(self, x0, t, epsilon):
        '''
        q(x_t | x_0)
        Forward diffusion from an input datapoint x0 to an xt at timestep t, provided a N(0,1) noise sample epsilon. 
        Note that we can do this operation in a single step

        Parameters
        ----------
        x0: torch.tensor
            x value at t=0 (an input image)
        t: torch.Tensor
            step index [batch_size]
        epsilon:
            noise sample

        Returns
        -------
        torch.tensor
            image at timestep t
        ''' 

        mean = torch.sqrt(self.alpha_bar[t]) * x0
        std = torch.sqrt(1 - self.alpha_bar[t])
        
        return mean + std * epsilon

    def reverse_diffusion(self, xt, t, epsilon):
        """
        p(x_{t-1} | x_t)
        Single step in the reverse direction, from x_t (at timestep t) to x_{t-1}, provided a N(0,1) noise sample epsilon.

        Parameters
        ----------
        xt: torch.tensor
            x value at step t
        t: torch.Tensor
            step index [batch_size]
        epsilon:
            noise sample

        Returns
        -------
        torch.tensor
            image at timestep t-1
        """
        
        # Equation 11 in Ho et al., 2020
        model_mean, _, model_logvar = self.p_mean_var(
            x_t=x_t, t=t, clip_denoised=False, return_pred=False)

        # Compute standard deviation
        std = torch.sqrt(((1 - self.alpha_bar[t-1]) / (1 - self.alpha_bar[t])) * self.beta[t])
        
        return model_mean + std * epsilon

    def reverse_diffusion_x0_parameterization(self, xt, t, epsilon):
        """
        p(x_{t-1} | x_t)
        Single step in the reverse direction, from x_t (at timestep t) to x_{t-1}, provided a N(0,1) noise sample epsilon.

        Parameters
        ----------
        xt: torch.tensor
            x value at step t
        t: torch.Tensor
            step index [batch_size]
        epsilon:
            noise sample

        Returns
        -------
        torch.tensor
            image at timestep t-1
        """

        # Network now predicts the initial image x0
        estimated_x0 = self.network(xt, t)

        alpha = self._extract(self.alpha, t, xt)          # Shape: [batch_size, 1]
        alpha_bar = self._extract(self.alpha_bar, t, xt)  # Shape: [batch_size, 1]
        alpha_bar_prev = self._extract(self.alpha_bar, t-1, xt)  # Shape: [batch_size, 1]
        beta = self._extract(self.beta, t, xt)            # Shape: [batch_size, 1]

        # Equation 6+7 in Ho et al., 2020

        beta_tilde = ((1 - alpha_bar_prev) / (1 - alpha_bar)) * beta  # Shape: [batch_size, 1]
        std = torch.sqrt(beta_tilde)  # Shape: [batch_size, 1]

        coeff1 = torch.sqrt(alpha_bar_prev) / (1 - alpha_bar) * beta  # Shape: [batch_size, 1]
        coeff2 = torch.sqrt(alpha) * (1 - alpha_bar_prev) / (1 - alpha_bar)  # Shape: [batch_size, 1]

        mean = coeff1 * estimated_x0 + coeff2 * xt  # Shape: [batch_size, 784]

        return mean + std * epsilon  # Shape: [batch_size, 784]

    @torch.no_grad()
    def sample(self, shape):
        """
        Sample from diffusion model (Algorithm 2 in Ho et al., 2020)

        Parameters
        ----------
        shape: tuple
            Specify shape of sampled output. For MNIST: (nsamples, 28*28)

        Returns
        -------
        torch.tensor
            sampled image            
        """
        
        # Sample xT: Gaussian noise
        xT = torch.randn(shape).to(self.beta.device)

        xt = xT
        for t in range(self.T, 0, -1):
            noise = torch.randn_like(xT) if t > 1 else torch.zeros_like(xT)
            t_tensor = torch.full((xt.shape[0],), t, dtype=torch.long, device=self.beta.device)
            if self.x0_parameterization:
                xt = self.reverse_diffusion_x0_parameterization(xt, t_tensor, noise)
            else:         
                xt = self.reverse_diffusion(xt, t_tensor, noise)

        return xt

    def low_discrepancy_t_sampler(self, batch_size, device):
        k = batch_size
        u0 = np.random.uniform(0, 1)

        # ti = mod(u0 + i/k, 1) for i=0,1,...,k-1
        ti = torch.fmod(torch.arange(0, k, device=device, dtype=torch.float32) / k + u0, 1)
        return (ti * self.T).long()

    def elbo_simple(self, x0):
        """
        ELBO training objective (Algorithm 1 in Ho et al., 2020)

        Parameters
        ----------
        x0: torch.tensor
            Input image

        Returns
        -------
        float
            ELBO value            
        """

        if self.use_low_discrepancy_sampler:
            t = self.low_discrepancy_t_sampler(x0.shape[0], x0.device)
        else:
            # Sample time step t as a 1D tensor
            t = torch.randint(1, self.T, (x0.shape[0],), device=x0.device)
        
        # Sample noise
        epsilon = torch.randn_like(x0)

        # Forward diffusion to produce image at step t
        xt = self.forward_diffusion(x0, t, epsilon)
        
        return -nn.MSELoss(reduction='mean')(epsilon, self.network(xt, t))

    def elbo_simple_x0_reparameterization(self, x0):
        """
        ELBO training objective (Algorithm 1 in Ho et al., 2020), modified. Network tries to predict the noise, but the loss is taken in 

        Parameters
        ----------
        x0: torch.tensor
            Input image

        Returns
        -------
        float
            ELBO value            
        """

        # Sample time step t as a 1D tensor
        t = torch.randint(1, self.T, (x0.shape[0],), device=x0.device)
        
        # Sample noise
        epsilon = torch.randn_like(x0)

        # Forward diffusion to produce image at step t
        xt = self.forward_diffusion(x0, t, epsilon)

        estimated_x0 = self.network(xt, t)
        
        return -nn.MSELoss(reduction='mean')(x0, estimated_x0)

    def loss(self, x0):
        """
        Loss function. Just the negative of the ELBO.
        """
        if self.x0_parameterization:
            return -self.elbo_simple_x0_reparameterization(x0).mean()
        else:
            return -self.elbo_simple(x0).mean()

    # === log likelihood ===
    # bpd: bits per dimension

    @staticmethod
    def _extract(arr, t, x, dtype=torch.float32, device=torch.device("cpu"), ndim=2):
        """
        Extracts the values from `arr` at indices `t` and reshapes them for broadcasting.

        Parameters:
            arr (torch.Tensor): Tensor from which to extract values. Shape: [T+1].
            t (torch.Tensor): Time step indices. Shape: [batch_size].
            x (torch.Tensor): Reference tensor for device and dtype.
            dtype (torch.dtype): Desired data type of the output.
            device (torch.device): Desired device of the output.
            ndim (int): Number of dimensions for the output tensor.

        Returns:
            torch.Tensor: Extracted and reshaped tensor. Shape: [batch_size, 1].
        """
        if x is not None:
            dtype = x.dtype
            device = x.device
            ndim = x.ndim
        out = torch.as_tensor(arr, dtype=dtype, device=device).gather(0, t)
        return out.reshape((-1, ) + (1, ) * (ndim - 1))

    def q_mean_var(self, x_0, t):
        mean = self._extract(self.sqrt_alphas_bar, t, x_0) * x_0
        var = self._extract(1. - self.alphas_bar, t, x_0)
        logvar = self._extract(torch.log(1 - self.alphas_bar), t, x_0)
        return mean, var, logvar

    def q_sample(self, x_0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_0)
        coef1 = self._extract(self.sqrt_alphas_bar, t, x_0)
        coef2 = self._extract(self.sqrt_one_minus_alphas_bar, t, x_0)
        return coef1 * x_0 + coef2 * noise

    def q_posterior_mean_var(self, x_0, x_t, t):
        posterior_mean_coef1 = self._extract(self.posterior_mean_coef1, t, x_0)
        posterior_mean_coef2 = self._extract(self.posterior_mean_coef2, t, x_0)
        posterior_mean = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t
        posterior_var = self._extract(self.posterior_var, t, x_0)
        posterior_logvar = self._extract(self.posterior_logvar_clipped, t, x_0)
        return posterior_mean, posterior_var, posterior_logvar

    def _loss_term_bpd(self, denoise_fn, x_0, x_t, t, clip_denoised, return_pred):
        # Calculate L_t
        # t = 0: negative log likelihood of decoder, -log p(x_0 | x_1)
        # t > 0: variational lower bound loss term, KL term
        true_mean, _, true_logvar = self.q_posterior_mean_var(x_0=x_0, x_t=x_t, t=t)
        model_mean, model_var, model_logvar, pred_x0 = self.p_mean_var(
            x_t=x_t, t=t, clip_denoised=clip_denoised, return_pred=True)
        kl = normal_kl(true_mean, true_logvar, model_mean, model_logvar)
        kl = flat_mean(kl) / math.log(2.)  # Convert from nats to bits
        decoder_nll = discretized_gaussian_loglik(x_0, model_mean, log_scale=0.5 * model_logvar).neg()
        decoder_nll = flat_mean(decoder_nll) / math.log(2.)
        output = torch.where(t > 0, kl, decoder_nll)
        return (output, pred_x0) if return_pred else output

    def p_mean_var(self, x_t, t, clip_denoised=True, return_pred=False):
        """
        Compute the mean and variance of p(x_{t-1} | x_t)

        Parameters:
            x_t (torch.Tensor): Current sample at time step t. Shape: [batch_size, 784].
            t (torch.Tensor): Time step indices. Shape: [batch_size].
            clip_denoised (bool): Whether to clamp the predicted x0 to [-1, 1].
            return_pred (bool): Whether to return the predicted x0.

        Returns:
            tuple:
                - mean (torch.Tensor): Mean of p(x_{t-1} | x_t). Shape: [batch_size, 784].
                - var (torch.Tensor): Variance of p(x_{t-1} | x_t). Shape: [batch_size, 1].
                - logvar (torch.Tensor): Log variance of p(x_{t-1} | x_t). Shape: [batch_size, 1].
                - pred_x0 (torch.Tensor, optional): Predicted x0, if `return_pred` is True. Shape: [batch_size, 784].
        """
        if self.x0_parameterization:
            # Network predicts x0 directly
            pred_x0 = self.network(x_t, t)  # Shape: [batch_size, 784]
            if clip_denoised:
                pred_x0 = pred_x0.clamp(-1., 1.)
            
            # Compute the posterior mean and variance using the predicted x0
            mean, var, logvar = self.q_posterior_mean_var(x_0=pred_x0, x_t=x_t, t=t)  # Shapes: [batch_size, 784], [batch_size, 1], [batch_size, 1]
        else:
            # Network predicts the noise epsilon
            epsilon_theta = self.network(x_t, t)  # Shape: [batch_size, 784]
            
            # Extract coefficients with shape [batch_size, 1] for broadcasting
            alpha_t = self._extract(self.alpha, t, x_t)          # Shape: [batch_size, 1]
            beta_t = self._extract(self.beta, t, x_t)            # Shape: [batch_size, 1]
            alpha_bar_t = self._extract(self.alpha_bar, t, x_t)  # Shape: [batch_size, 1]
            
            # Compute the mean of p(x_{t-1} | x_t) using the predicted epsilon
            mean = (1. / torch.sqrt(alpha_t)) * (
                x_t - (beta_t / torch.sqrt(1 - alpha_bar_t)) * epsilon_theta
            )  # Shape: [batch_size, 784]
            
            # Extract the variance and log variance with shape [batch_size, 1]
            var = self._extract(self.posterior_var, t, x_t)           # Shape: [batch_size, 1]
            logvar = self._extract(self.posterior_logvar_clipped, t, x_t)  # Shape: [batch_size, 1]
            
            if clip_denoised:
                # Estimate x0 from the predicted epsilon
                pred_x0 = (x_t - torch.sqrt(1 - alpha_bar_t) * epsilon_theta) / torch.sqrt(alpha_bar_t)  # Shape: [batch_size, 784]
                pred_x0 = pred_x0.clamp(-1., 1.)
            else:
                pred_x0 = None

        if return_pred:
            return mean, var, logvar, pred_x0
        else:
            return mean, var, logvar

    def train_losses(self, denoise_fn, x_0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_0)
        x_t = self.q_sample(x_0, t=t)  # Shape: [batch_size, 784]

        # Calculate the loss
        # kl: weighted
        # mse: unweighted
        if self.loss_type == "kl":
            losses = self._loss_term_bpd(
                denoise_fn, x_0=x_0, x_t=x_t, t=t, clip_denoised=False, return_pred=False)
        elif self.loss_type == "mse":
            assert self.model_var_type != "learned"
            if self.model_mean_type == "mean":
                target = self.q_posterior_mean_var(x_0=x0, x_t=x_t, t=t)[0]
            elif self.model_mean_type == "x_0":
                target = x0
            elif self.model_mean_type == "eps":
                target = noise
            else:
                raise NotImplementedError(self.model_mean_type)
            model_out = denoise_fn(x_t, t)
            losses = flat_mean((target - model_out).pow(2))
        else:
            raise NotImplementedError(self.loss_type)

        return losses

    # def _prior_bpd(self, x_0):
    #     B, T = len(x_0), self.T
    #     t = (T - 1) * torch.ones((B,), dtype=torch.long, device=x_0.device)
    #     T_mean, _, T_logvar = self.q_mean_var(
    #         x_0=x_0, t=t)
    #     kl_prior = normal_kl(T_mean, T_logvar, mean2=0., logvar2=0.)
    #     return flat_mean(kl_prior) / math.log(2.)

    def _prior_bpd(self, x_0):
        B, T = len(x_0), self.T
        t = (T - 1) * torch.ones((B,), dtype=torch.long, device=x_0.device)
        T_mean, _, T_logvar = self.q_mean_var(x_0=x_0, t=t)
        
        # Create tensors for mean2 and logvar2
        mean2 = torch.zeros_like(T_mean)
        logvar2 = torch.zeros_like(T_logvar)
        
        # Calculate KL divergence
        kl_prior = normal_kl(T_mean, T_logvar, mean2=mean2, logvar2=logvar2)
        
        return flat_mean(kl_prior) / math.log(2.)


    def calc_all_bpd(self, denoise_fn, x_0, clip_denoised=True):
        B = x_0.shape[0]  # Ensure B is the batch size
        T = self.T
        t = torch.empty((B,), dtype=torch.long, device=x_0.device)  # Use long for indexing
        losses = torch.zeros((B, T), dtype=torch.float32, device=x_0.device)  # Ensure device consistency
        mses = torch.zeros((B, T), dtype=torch.float32, device=x_0.device)  

        for ti in range(T - 1, -1, -1):
            t.fill_(ti)
            x_t = self.q_sample(x_0, t=t)
            loss, pred_x0 = self._loss_term_bpd(
                denoise_fn, x_0, x_t=x_t, t=t, clip_denoised=clip_denoised, return_pred=True)
            losses[:, ti] = loss
            mses[:, ti] = flat_mean((pred_x0 - x_0).pow(2))

        prior_bpd = self._prior_bpd(x_0)
        total_bpd = torch.sum(losses, dim=1) + prior_bpd
        return total_bpd, losses, prior_bpd, mses


In [14]:
device = "cpu"

mnist_unet = ScoreNet((lambda t: torch.ones(1).to(device)))

T = 1000

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x + torch.rand(x.shape)/255),
    transforms.Lambda(lambda x: (x-0.5)*2.0),
    transforms.Lambda(lambda x: x.flatten())
])

dataloader_ = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True, transform=transform),
                                                batch_size=batch_size,
                                                shuffle=True)


DDPM_model = DDPM2(mnist_unet, T=T, x0_parameterization=False, use_low_discrepancy_sampler=True).to(device)
model_ckpt = "/Users/marcusdreisler/Documents/phd/courses/PML-project/rfmt/model_discrete_ddpm_low_dis.pt"
nsamples = 1000

batch_size = 10

DDPM_model.load_state_dict(torch.load(model_ckpt), strict=False)

real_samples = []
for x, _ in dataloader_:
    real_samples.append(x)
    if len(real_samples)*batch_size >= nsamples:
        break

real_samples = torch.cat(real_samples, dim=0)[:batch_size].to(device)
print(real_samples.shape)
#calc bpd

DDPM_model.eval()
# Calculate BPD
total_bpd, losses, prior_bpd, mses = DDPM_model.calc_all_bpd(DDPM_model.network, real_samples, clip_denoised=True)





  DDPM_model.load_state_dict(torch.load(model_ckpt), strict=False)


torch.Size([10, 784])


In [15]:
total_bpd

tensor([2.9898, 3.6926, 3.0691, 3.6941, 2.5601, 3.0409, 2.9990, 2.3733, 3.0140,
        3.2997], grad_fn=<AddBackward0>)

In [16]:
np.mean(total_bpd.detach().numpy()) 

3.0732598