# Diffusion models

Synthetic 2-D data

2024-2-10

Author: Yung-Kyun Noh, Ph.D.

Hanyang University / Korea Institute for Advanced Study

2024 Machine Learning Algorithms

This notebook implements simple diffusion models on a synthetic 2-dimensional data set.

The code is based on functions adapted from the following example,

https://colab.research.google.com/drive/1AZ2_BAwXrU8InE_qAE9cFZ0lsIO5a_xp?usp=sharing,

which is explained in

https://medium.com/mlearning-ai/enerating-images-with-ddpms-a-pytorch-implementation-cef5a2ba8cb1.

Parts of the sample hands-on code from the following NVIDIA DLI program are also used.

https://www.nvidia.com/en-us/training/instructor-led-workshops/generative-ai-with-diffusion-models/.


In [None]:
import random
import numpy as np
import glob
import torch
import torch.nn.functional as F
import torch.nn as nn

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

# Visualization tools
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from PIL import Image
from torchvision.utils import save_image, make_grid

import math

# Setting reproducibility: seed the Python, NumPy and PyTorch random number
# generators with the same fixed value.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
!nvidia-smi

In [None]:
# Getting device
run_gpu = 1    # preferred GPU index: 0,1,2,3,...

cuda_available = torch.cuda.is_available()
print(cuda_available)

if cuda_available:
    # Fall back to GPU 0 when the preferred index does not exist on this
    # machine, instead of failing later with an invalid device ordinal.
    if run_gpu >= torch.cuda.device_count():
        run_gpu = 0
    dev = 'cuda:' + str(run_gpu)
else:
    # Keep `dev` usable on CPU-only machines: later cells pass it to `.to(dev)`.
    dev = 'cpu'
# dev = 'cpu'   # uncomment to force CPU

device = torch.device(dev)
# Report the name of the GPU that was actually selected (not always GPU 0).
print(f"Using device: {device}\t" + (f"{torch.cuda.get_device_name(device)}" if device.type == 'cuda' else "CPU"))

In [None]:
def draw_data(data_0, data_1, title_str='Data'):
    """Scatter-plot two sets of 2-D points on a fresh figure.

    Args:
        data_0: indexable as ``data_0[:, 0]`` / ``data_0[:, 1]``; drawn as red
            markers labelled 'class 0'.
        data_1: same layout; drawn as blue triangles labelled 'class 1'.
        title_str: title shown above the axes.
    """
    figure, axes = plt.subplots()

    # One (points, scatter keyword arguments) pair per class.
    groups = (
        (data_0, dict(label='class 0', c='red', alpha=.3)),
        (data_1, dict(label='class 1', marker='^', c='blue', alpha=.3)),
    )
    for points, style in groups:
        axes.scatter(points[:, 0], points[:, 1], **style)

    axes.set_title(title_str)
    axes.legend()


In [None]:
# generate two Gaussians (class 0 & class 1)
dim = 2
datanum_0 = 200
datanum_1 = 200
mean_0 = np.array([-.1, .1])
mean_1 = np.array([.1, -.1])
cov_0 = np.array([[.1,.02],[.02,.1]])
cov_1 = np.array([[.1,.09],[.09,.1]])

# Sample x = z @ L^T + mean with z ~ N(0, I), where L is the Cholesky factor of
# the covariance. The factorisation runs in the NumPy arrays' float64 precision
# and is then cast to float32.
# Use the resolved `device` (not the raw `dev` string) so that this cell also
# runs when the notebook has fallen back to the CPU.
L = torch.linalg.cholesky(torch.from_numpy(cov_0).to(device)).to(torch.float32)
data_0 = torch.matmul(torch.randn(datanum_0, dim, device=device, dtype=torch.float32), L.T) \
        + torch.from_numpy(mean_0).to(torch.float32).to(device)
L = torch.linalg.cholesky(torch.from_numpy(cov_1).to(device)).to(torch.float32)
data_1 = torch.matmul(torch.randn(datanum_1, dim, device=device, dtype=torch.float32), L.T) \
        + torch.from_numpy(mean_1).to(torch.float32).to(device)

# NumPy alternative (CPU only):
# data_0 = np.random.multivariate_normal(mean_0, cov_0, datanum_0)
# data_1 = np.random.multivariate_normal(mean_1, cov_1, datanum_1)

n_classes = 2
data = torch.cat([data_0, data_1])
datanums = [len(data_0), len(data_1)]
# Class labels as floats (0. for data_0 rows, 1. for data_1 rows); created on the CPU.
labels = torch.cat([torch.zeros(datanums[0]), torch.ones(datanums[1])])


In [None]:
# Plot the two training classes (moved to the CPU for matplotlib) on a fixed
# [-3, 3] x [-3, 3] window with a grid.
draw_data(data_0.cpu(), data_1.cpu(), title_str='Data')
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.grid(True)

In [None]:
# DDPM class
class MyDDPM(nn.Module):
    """DDPM wrapper for image batches of shape (n, c, h, w).

    Holds the noise-prediction ``network`` together with a linear beta
    schedule and the derived ``alphas`` / ``alpha_bars``.

    NOTE: a later cell of this notebook defines ``MyDDPM`` again for 2-D data
    of shape (n, d); once that cell runs, it replaces this definition.

    Args:
        network: module that predicts the added noise.
        n_steps: number of diffusion steps.
        min_beta: first value of the linear beta schedule.
        max_beta: last value of the linear beta schedule.
        device: device the network and the schedule tensors are moved to.
    """

    def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None):
        super().__init__()
        self.n_steps = n_steps
        self.device = device
        self.network = network.to(device)
        # Number of steps is typically in the order of thousands
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
        self.alphas = 1 - self.betas
        # alpha_bar_t = alpha_0 * ... * alpha_t, computed in a single pass
        # instead of one torch.prod call per prefix.
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward(self, x0, t, eta=None):
        """Return x0 noised directly to step ``t`` (forward process).

        Args:
            x0: clean batch of shape (n, c, h, w).
            t: one step index per sample, used to index ``alpha_bars``.
            eta: noise to mix in; standard normal noise is drawn when None.

        Returns:
            sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eta.
        """
        n, c, h, w = x0.shape
        a_bar = self.alpha_bars[t]

        if eta is None:
            eta = torch.randn(n, c, h, w).to(self.device)

        # Reshape the per-sample coefficients so they broadcast over (c, h, w).
        noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta
        return noisy

    def backward(self, x, t):
        """Return the network's estimate of the noise in ``x`` at steps ``t``."""
        return self.network(x, t)

    def backward_cfg(self, x, t, c, c_mask):   # Classifier-free guidance
        """Return the noise estimate given context ``c`` and its mask ``c_mask``."""
        return self.network(x, t, c, c_mask)

In [None]:
# Diffuse data
# T_col and T_row are only used here to form T = 80.
# (presumably a column/row grid of diffusion steps for plotting — TODO confirm)
T_col = 8
T_row = 10
# T is later passed to epsilon_diffuse, which divides the time step by it.
T = T_col*T_row
# Same numeric values as the min_beta / max_beta used when the model is defined.
B_start = 0.0001
B_end = 0.02


In [None]:
class EmbedBlock(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) mapping input_dim to emb_dim.

    The input is flattened to rows of length ``input_dim`` before the MLP, so
    the output has shape (-1, emb_dim).
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
            # Leaves the (batch, emb_dim) shape as it is; an earlier variant
            # unflattened to (emb_dim, 1, 1) instead.
            nn.Unflatten(1, (emb_dim,)),
        )

    def forward(self, x):
        flat = x.view(-1, self.input_dim)
        return self.model(flat)

In [None]:
class SinusoidalPositionEmbedBlock(nn.Module):
    """Sinusoidal embedding of a batch of (time) values.

    For ``half_dim = dim // 2`` geometrically spaced frequencies between 1 and
    1/10000, the output concatenates ``sin(time * freq)`` and
    ``cos(time * freq)`` along the last axis, giving ``2 * half_dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time``; a trailing axis of size ``2 * (dim // 2)`` is added."""
        device = time.device
        half_dim = self.dim // 2
        # Guard the denominator: for dim 2 or 3, half_dim - 1 is 0 and the
        # plain division would raise ZeroDivisionError. With a single
        # frequency the exponent is 0 anyway, so the scale value is irrelevant.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        angles = time[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [None]:
class epsilon_diffuse(nn.Module):
    """Small MLP that predicts the noise added to a sample, given time and class context.

    Four linear layers (input_dim -> ns_layer[0] -> ns_layer[1] -> ns_layer[2]
    -> output_dim) with ReLU between them. Before the second and third layers
    the hidden activation is multiplied by a context embedding and shifted by
    a time embedding.

    Args:
        input_dim: size of the last axis of ``x``.
        output_dim: size of the predicted noise vector.
        T: divisor applied to the time step before it is embedded.
        ns_layer: widths of the three hidden layers.
        t_embed_dim: size of the sinusoidal time embedding.
        c_embed_dim: size of the (masked) context vector fed to the context
            embedding blocks.
    """

    # ns_layer defaults to a tuple rather than a list, so the default is not a
    # shared mutable object; lists passed by callers are still accepted.
    def __init__(self, input_dim, output_dim, T, ns_layer=(10, 5, 3), t_embed_dim=8, c_embed_dim=3):
        super().__init__()
        self.T = T
        self.fc1 = nn.Linear(input_dim, ns_layer[0], bias=True)       # hidden layer 1
        self.fc2 = nn.Linear(ns_layer[0], ns_layer[1], bias=True)     # hidden layer 2
        self.fc3 = nn.Linear(ns_layer[1], ns_layer[2], bias=True)     # hidden layer 3
        self.fc4 = nn.Linear(ns_layer[2], output_dim)                 # last layer

        self.sinusoidaltime = SinusoidalPositionEmbedBlock(t_embed_dim)
        # Time embeddings sized to match the outputs of fc1 and fc2.
        self.t_emb1 = EmbedBlock(t_embed_dim, ns_layer[0])
        self.t_emb2 = EmbedBlock(t_embed_dim, ns_layer[1])
        # Context embeddings sized to match the outputs of fc1 and fc2.
        self.c_embed1 = EmbedBlock(c_embed_dim, ns_layer[0])
        self.c_embed2 = EmbedBlock(c_embed_dim, ns_layer[1])

    def forward(self, x, t, c, c_mask):
        """Predict the noise in ``x``.

        Args:
            x: noisy samples; last axis of size ``input_dim``.
            t: diffusion step per sample.
            c: context vectors (one-hot class encodings in this notebook).
            c_mask: multiplied element-wise with ``c``; zeros drop the context.

        Returns:
            Tensor with last axis of size ``output_dim``.
        """
        # Scale the step index by 1/T (maps [0, T] to [0, 1] when T equals the
        # number of diffusion steps), then embed it.
        t = t.float() / self.T
        t = self.sinusoidaltime(t)
        t_emb1 = self.t_emb1(t)
        t_emb2 = self.t_emb2(t)

        # Masked-out context entries become zero before being embedded.
        c = c * c_mask
        c_emb1 = self.c_embed1(c)
        c_emb2 = self.c_embed2(c)

        out = torch.relu(self.fc1(x))
        # Condition the hidden state: scale by context, shift by time.
        out = torch.relu(self.fc2(c_emb1 * out + t_emb1))
        out = torch.relu(self.fc3(c_emb2 * out + t_emb2))
        out = self.fc4(out)   # no activation on the predicted noise
        return out


### Initiating the model

In [None]:
# DDPM class
class MyDDPM(nn.Module):
    """DDPM wrapper for 2-D (vector) data batches of shape (n, d).

    Holds the noise-prediction ``network`` together with a linear beta
    schedule and the derived ``alphas`` / ``alpha_bars``.

    Args:
        network: module that predicts the added noise.
        n_steps: number of diffusion steps.
        min_beta: first value of the linear beta schedule.
        max_beta: last value of the linear beta schedule.
        device: device the network and the schedule tensors are moved to.
    """

    def __init__(self, network, n_steps=200, min_beta=10 ** -4, max_beta=0.02, device=None):
        super().__init__()
        self.n_steps = n_steps
        self.device = device
        self.network = network.to(device)
        # Number of steps is typically in the order of thousands
        self.betas = torch.linspace(min_beta, max_beta, n_steps).to(device)
        self.alphas = 1 - self.betas
        # alpha_bar_t = alpha_0 * ... * alpha_t, computed in a single pass
        # instead of one torch.prod call per prefix.
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def forward(self, x0, t, eta=None):
        """Return x0 noised directly to step ``t`` (forward process).

        Args:
            x0: clean batch of shape (n, d).
            t: one step index per sample, used to index ``alpha_bars``.
            eta: noise to mix in; standard normal noise is drawn when None.

        Returns:
            sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eta.
        """
        n, d = x0.shape
        a_bar = self.alpha_bars[t]

        if eta is None:
            eta = torch.randn(n, d).to(self.device)

        # Reshape the per-sample coefficients so they broadcast over d.
        noisy = a_bar.sqrt().reshape(n, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1) * eta
        return noisy

    def backward(self, x, t):
        """Return the network's estimate of the noise in ``x`` at steps ``t``."""
        return self.network(x, t)

    def backward_cfg(self, x, t, c, c_mask):   # Classifier-free guidance
        """Return the noise estimate given context ``c`` and its mask ``c_mask``."""
        return self.network(x, t, c, c_mask)

In [None]:
# Defining model
n_steps, min_beta, max_beta = 1000, 10 ** -4, 0.02  # Originally used by the authors
input_dim=2
output_dim=2

# NOTE(review): the network divides the time step by T (= T_col*T_row = 80),
# while time steps are drawn from [0, n_steps) with n_steps = 1000, so t / T is
# not confined to [0, 1]. The cell that reloads the trained weights builds the
# network with the same T; if this is changed, change both together and retrain.
ddpm = MyDDPM(epsilon_diffuse(input_dim, output_dim, T, ns_layer=[50,50,3], c_embed_dim=2), \
              n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)


In [None]:
# Total number of scalar parameters in the model.
sum(p.numel() for p in ddpm.parameters())

## Training

In [None]:
def get_context_mask(c, drop_prob, n_classes=2, device='cpu'):
    """One-hot encode class labels and draw a random keep/drop mask.

    Args:
        c: tensor of class labels (cast to int64 before encoding).
        drop_prob: probability that a mask entry is 0.
        n_classes: width of the one-hot encoding.
        device: device both returned tensors are placed on.

    Returns:
        (c_hot, c_mask): the int64 one-hot encoding and a float mask of the
        same shape whose entries are 1 with probability ``1 - drop_prob``.
    """
    class_ids = c.to(torch.int64)
    c_hot = F.one_hot(class_ids, num_classes=n_classes).to(device)
    # Each entry is kept independently with probability 1 - drop_prob.
    keep_prob = torch.ones_like(c_hot).float() - drop_prob
    c_mask = torch.bernoulli(keep_prob).to(device)
    return c_hot, c_mask

In [None]:
def training_loop_cfg(ddpm, data, labels, n_epochs, optim, device, n_classes=10, c_drop_prob=0.1, display=False, store_path="ddpm_model.pt"):
    """Train the noise predictor with random context dropping (classifier-free guidance).

    Every epoch is one full-batch gradient step: a fresh context mask, fresh
    noise and fresh time steps are drawn for the whole data set, the data is
    noised by ``ddpm`` and the network is fitted to the noise with an MSE loss.
    Whenever the loss improves on the best value so far, the model's
    ``state_dict`` is written to ``store_path``.

    Args:
        ddpm: model exposing ``n_steps``, ``__call__(x0, t, eta)`` and
            ``backward_cfg(x, t, c, c_mask)``.
        data: training samples, one per row.
        labels: class label of each row of ``data``.
        n_epochs: number of full-batch steps.
        optim: optimizer over ``ddpm``'s parameters.
        device: device the computation runs on.
        n_classes: width of the one-hot context. NOTE: the default is 10,
            whereas this notebook's data has 2 classes; the training cell
            passes ``n_classes`` explicitly.
        c_drop_prob: probability of zeroing a context-mask entry.
        display: if True, calls ``show_images`` / ``generate_new_images`` after
            each epoch. NOTE(review): neither helper is defined in this
            notebook as shown, so they must be provided before enabling this.
        store_path: file the best ``state_dict`` is saved to.
    """
    mse = nn.MSELoss()
    best_loss = float("inf")
    n_steps = ddpm.n_steps

    # Make sure the training set is on the same device as the model and the
    # noise / time-step tensors (a no-op when it is already there).
    x0 = data.to(device)
    n = len(x0)

    for epoch in tqdm(range(n_epochs), desc="Training progress", colour="#00ff00"):
        c_hot, c_mask = get_context_mask(labels, c_drop_prob, n_classes=n_classes, device=device)  # New

        # Picking some noise for each sample in the batch and a timestep
        eta = torch.randn_like(x0)
        t = torch.randint(0, n_steps, (n,)).to(device)

        # Computing the noisy samples based on x0 and the time-step (forward process)
        noisy_imgs = ddpm(x0, t, eta)

        # Getting model estimation of noise based on the samples and the time-step
        eta_theta = ddpm.backward_cfg(noisy_imgs, t.reshape(n, -1), c_hot, c_mask)

        # Optimizing the MSE between the noise plugged and the predicted noise
        loss = mse(eta_theta, eta)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # The whole data set is one batch, so the epoch loss is the batch loss.
        epoch_loss = loss.item()

        # Display images generated at this epoch
        if display:
            show_images(generate_new_images(ddpm, device=device), f"Images generated at epoch {epoch + 1}")

        log_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"

        # Storing the model
        if best_loss > epoch_loss:
            best_loss = epoch_loss
            torch.save(ddpm.state_dict(), store_path)
            log_string += " --> Best model ever (stored)"

        print(log_string)

In [None]:
# Training hyper-parameters: one full-batch Adam step per epoch.
n_classes = 2
# n_epochs = 20
n_epochs = 1000
lr = 0.001

# Training
# The best-scoring weights are written to store_path during training.
store_path = "ddpm_2d_cfg.pt"
training_loop_cfg(ddpm, data, labels, n_epochs, optim=Adam(ddpm.parameters(), lr), device=device, \
                  n_classes=n_classes, store_path=store_path)

In [None]:
# Loading the trained model
# The network is rebuilt with the same constructor arguments as the trained
# one (including T, which is not part of the saved state_dict) before the
# stored weights are loaded into it.
best_model = MyDDPM(epsilon_diffuse(input_dim, output_dim, T, ns_layer=[50,50,3], c_embed_dim=2), \
                    n_steps=n_steps, min_beta=min_beta, max_beta=max_beta, device=device)
store_path = "ddpm_2d_cfg.pt"
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
best_model.load_state_dict(torch.load(store_path, map_location=device))
best_model.eval()
print("Model loaded")

In [None]:
# with w
def generate_new_data_cfg(ddpm, n_samples, labels, n_classes=2,
                          device=None, frames_per_gif=100, gif_name="sampling.gif",
                          d=2, w_val=0.):
    """Sample new points from a trained DDPM with classifier-free guidance.

    Starts from standard normal noise and runs the reverse process from step
    ``ddpm.n_steps - 1`` down to 0. At each step the noise estimate is
    ``(1 + w_val) * eps(x, t, c) - w_val * eps(x, t, no context)``.

    Args:
        ddpm: model exposing ``n_steps``, ``device``, ``alphas``,
            ``alpha_bars``, ``betas`` and ``backward_cfg``.
        n_samples: number of points to generate.
        labels: class label for each generated point.
        n_classes: width of the one-hot context encoding.
        device: device to sample on; defaults to ``ddpm.device``.
        frames_per_gif: unused; kept for interface compatibility.
        gif_name: unused; kept for interface compatibility.
        d: dimensionality of each generated point.
        w_val: guidance weight; 0 gives the plain conditional estimate.

    Returns:
        Tensor of shape (n_samples, d) with the generated points.
    """
    with torch.no_grad():
        if device is None:
            device = ddpm.device

        # Starting from random noise
        x = torch.randn(n_samples, d).to(device)

        # The conditioning does not change between steps, so build it once.
        # drop_prob = 0 keeps every context entry; n_classes is forwarded so
        # that the encoding width follows the caller's setting.
        c_hot, keep_mask = get_context_mask(labels, 0, n_classes=n_classes, device=device)
        drop_mask = torch.zeros_like(keep_mask)

        for t in reversed(range(ddpm.n_steps)):
            # Estimating noise to be removed
            time_tensor = torch.full((n_samples, 1), t, dtype=torch.long, device=device)
            eta_theta_keep_class = ddpm.backward_cfg(x, time_tensor, c_hot, keep_mask)
            eta_theta_drop_class = ddpm.backward_cfg(x, time_tensor, c_hot, drop_mask)
            eta_theta = (1 + w_val) * eta_theta_keep_class - w_val * eta_theta_drop_class

            alpha_t = ddpm.alphas[t]
            alpha_t_bar = ddpm.alpha_bars[t]

            # Partially denoising the sample
            x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta)

            if t > 0:
                z = torch.randn(n_samples, d).to(device)

                # Option 1: sigma_t squared = beta_t
                beta_t = ddpm.betas[t]
                sigma_t = beta_t.sqrt()

                # Option 2: sigma_t squared = beta_tilda_t
                # prev_alpha_t_bar = ddpm.alpha_bars[t-1] if t > 0 else ddpm.alphas[0]
                # beta_tilda_t = ((1 - prev_alpha_t_bar)/(1 - alpha_t_bar)) * beta_t
                # sigma_t = beta_tilda_t.sqrt()

                # Adding some more noise like in Langevin Dynamics fashion
                x = x + sigma_t * z

    return x

In [None]:
# Generate n_gen points, the first half labelled class 0 and the second half
# class 1. w_val is left at its default of 0.
# NOTE: this rebinds the global `labels`, which held the training labels until
# now, to the labels of the generated samples.
n_gen = 300
labels = torch.cat([torch.zeros(int(n_gen/2)), torch.ones(int(n_gen/2))])
generated = generate_new_data_cfg(
        best_model, n_gen, labels,
        n_classes=n_classes,
        device=device,
        gif_name="synthetic_2d_cfg.gif"
    )

In [None]:
# Plot the generated points split by their requested class, on the same
# [-3, 3] x [-3, 3] window as the training data.
draw_data(generated[labels==0].cpu(), generated[labels==1].cpu(), title_str='Generated Data')
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.grid(True)

In [None]:
# Same generation and plot as before, now with guidance weight w_val = 1.
# NOTE: `labels` is rebound here to the labels of the generated samples.
w_val = 1.
n_gen = 300
labels = torch.cat([torch.zeros(int(n_gen/2)), torch.ones(int(n_gen/2))])
generated = generate_new_data_cfg(
        best_model, n_gen, labels,
        n_classes=n_classes,
        device=device,
        gif_name="synthetic_2d_cfg.gif", w_val=w_val
    )

draw_data(generated[labels==0].cpu(), generated[labels==1].cpu(), title_str='Generated Data')
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.grid(True)

In [None]:
# Same generation and plot as before, now with guidance weight w_val = 2.
# NOTE: `labels` is rebound here to the labels of the generated samples.
w_val = 2.
n_gen = 300
labels = torch.cat([torch.zeros(int(n_gen/2)), torch.ones(int(n_gen/2))])
generated = generate_new_data_cfg(
        best_model, n_gen, labels,
        n_classes=n_classes,
        device=device,
        gif_name="synthetic_2d_cfg.gif", w_val=w_val
    )

draw_data(generated[labels==0].cpu(), generated[labels==1].cpu(), title_str='Generated Data')
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.grid(True)