# Conditional Denoising Diffusion Probabilistic Model (DDPM) - Implementation

The purpose of this notebook is to implement the Conditional Denoising Diffusion Probabilistic Model (DDPM) architecture, as outlined in section 3.4.5 of the bachelor thesis.

The code provided in this notebook was developed using the Kaggle platform.

The code in this notebook incorporates the following sources as references:

- https://github.com/dome272/Diffusion-Models-pytorch
- https://arxiv.org/pdf/2102.09672.pdf
- https://arxiv.org/pdf/2105.05233.pdf

## Step 1 - Importing Dependencies

- Importing the necessary libraries to execute the code.

In [3]:
import os
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import copy
import seaborn as sns


## Step 2 - Hyperparameter Settings

- Set the hyperparameters for the Conditional DDPM deep generative model, and check whether a GPU is available for use.

In [None]:
torch.manual_seed(42)   # Fixed random seed for reproducible runs
num_classes = 10        # Defining the number of classes since this model follows a conditional approach
image_size = 64         # Image size for training the model. It will also be the size of the synthetic images created
batch_size = 4          # Batch size to train the model
workers = 2             # Number of CPU workers to process the data
ngpu = 2                # Number of GPUs available
noise_steps = 350       # Number of timesteps in the forward (noise-adding) process
epochs = 180            # Number of epochs to train the Conditional DDPM
lr = 1e-5               # Learning rate of the AdamW optimizer

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
device

## Step 3 - Dataset Loading

- Defining the custom class for loading the images as a PyTorch dataset.

In [3]:
class Dataset(Dataset):
    def __init__(self, labels_file, root_dir, transform=None):
        self.annotations = pd.read_csv(labels_file, header=None)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, str(self.annotations.iloc[index, 1]), self.annotations.iloc[index, 0])
        image = Image.open(img_path)
        image = image.convert("RGB")
        label = torch.tensor(int(self.annotations.iloc[(index, 2)]))

        if self.transform:
            image = self.transform(image)

        return(image, label)
    
    def __getlabel__(self, index):
        label = (self.annotations.iloc[(index, 1)])        

        return(label)

- The preprocessing transformation matches the data to the format expected by the DDPM model.
- Since this is a conditional model, the labels .csv file must cover the complete dataset rather than a single class.

In [4]:
preprocessing = transforms.Compose([transforms.Resize(image_size), 
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                   ])

labels_file = '/path/to/complete/labels/csv'
root_dir = '/path/to/root/image/folder'

dataset = Dataset(labels_file=labels_file, root_dir=root_dir, transform=preprocessing)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
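
The dataset class above reads the labels file without a header row and uses column 0 as the file name, column 1 as the class subfolder, and column 2 as the integer label. The following is a minimal sketch of that layout using only the standard library. The file names and the use of class names as folder names are hypothetical and chosen for illustration only.

```python
import csv
import io
import os

# Hypothetical rows, for illustration only: file name, class folder, integer label.
# There is no header row, matching `header=None` in the dataset class.
example_csv = "img_0001.jpg,2_crescent_gap,2\nimg_0002.jpg,5_silk_spot,5\n"

root_dir = "/path/to/root/image/folder"
rows = list(csv.reader(io.StringIO(example_csv)))

for file_name, class_folder, label in rows:
    # Same path construction as `__getitem__`: root / class folder / file name.
    img_path = os.path.join(root_dir, class_folder, file_name)
    print(img_path, int(label))
```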

## Step 4 - Defining the Diffusion Tools for the Forward Process

- Setting the noise schedule and complete forward process in the DDPM architecture.
- The forward process is responsible for adding noise to the original image based on the noise schedule and the timestep.
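
For reference, the closed form that the `noise_images` method in this step evaluates can be written as follows. The notation mirrors the attribute names of the `Diffusion` class (`beta`, `alpha`, `alpha_hat`); the symbol $\hat{\alpha}_t$ is this notebook's name for the cumulative product that the DDPM literature usually writes as $\bar{\alpha}_t$.

```latex
\beta_t \text{ linearly spaced between } \beta_{\text{start}} \text{ and } \beta_{\text{end}}, \qquad
\alpha_t = 1 - \beta_t, \qquad
\hat{\alpha}_t = \prod_{s \le t} \alpha_s

x_t = \sqrt{\hat{\alpha}_t}\; x_0 + \sqrt{1 - \hat{\alpha}_t}\; \epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)
```

Because $x_t$ depends only on $x_0$, $t$ and a single noise draw, any timestep can be sampled directly during training without iterating through the earlier ones.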

In [5]:
class Diffusion:
    def __init__(self, noise_steps=noise_steps, beta_start=1e-4, beta_end=0.02, img_size=256, device=device):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, labels)
                if cfg_scale > 0:
                    uncond_predicted_noise = model(x, t, None)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x
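
To get a feel for the linear schedule, the sketch below recomputes `alpha_hat` in plain Python (no PyTorch) with the same defaults as the class above and prints how much of the original image and how much noise are mixed at a few timesteps. The helper name `linear_alpha_hat` is introduced here for illustration and is not part of the notebook.

```python
def linear_alpha_hat(noise_steps=350, beta_start=1e-4, beta_end=0.02):
    """Return the cumulative products alpha_hat[t] for a linear beta schedule."""
    step = (beta_end - beta_start) / (noise_steps - 1)
    betas = [beta_start + i * step for i in range(noise_steps)]
    alpha_hat, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta  # alpha_t = 1 - beta_t, accumulated as a product
        alpha_hat.append(running)
    return alpha_hat


alpha_hat = linear_alpha_hat()
for t in (1, 100, 349):
    signal = alpha_hat[t] ** 0.5          # coefficient of the clean image x_0
    noise = (1.0 - alpha_hat[t]) ** 0.5   # coefficient of the Gaussian noise
    print(f"t={t:3d}  signal={signal:.3f}  noise={noise:.3f}")
```

With 350 steps and this beta range, the final `alpha_hat` remains at a few percent rather than vanishing, so the last noised image still carries a small trace of the original even though sampling starts from pure Gaussian noise.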


## Step 5 - Defining the Reverse Process Neural Network (U-net Architecture)

- This code defines the U-Net architecture responsible for the reverse process in the Conditional DDPM model.
- The reverse process removes the noise that was added to the images in the forward process.
- Because the model is now conditional, the class label of each image is also passed to the U-Net, where its embedding is added to the timestep embedding.

In [6]:
class SelfAttention(nn.Module):
    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)


class Down(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_dim,
                out_channels
            ),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class Up(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_dim,
                out_channels
            ),
        )

    def forward(self, x, skip_x, t):
        x = self.up(x)
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb


class UNet_conditional(nn.Module):
    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device=device):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, 8)

        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def pos_encoding(self, t, channels):
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            t += self.label_emb(y)

        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output
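
The `pos_encoding` method turns each integer timestep into a sinusoidal vector of length `time_dim`, and the label embedding is added to that vector when labels are given. Below is a plain-Python sketch of the same encoding for a single scalar timestep. It is meant only to make the layout (sine half followed by cosine half) easy to inspect and is not a replacement for the tensor version.

```python
import math


def pos_encoding(t, channels=256):
    """Sinusoidal embedding of a scalar timestep: sine half, then cosine half."""
    inv_freq = [1.0 / (10000 ** (i / channels)) for i in range(0, channels, 2)]
    sin_part = [math.sin(t * f) for f in inv_freq]
    cos_part = [math.cos(t * f) for f in inv_freq]
    return sin_part + cos_part


emb = pos_encoding(t=10)
print(len(emb))   # one value per embedding channel
print(emb[:3])    # first entries of the sine half
```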

## Step 6 - Defining the Exponential Moving Average (EMA) Algorithm

- Following common practice in state-of-the-art diffusion models, this Conditional DDPM maintains an exponential moving average (EMA) of the model weights for more stable training.
- The EMA class is defined below and used in the training function of the Conditional DDPM.

In [None]:
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())
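
The update rule of the class above can be followed on a single scalar weight. The sketch below is a simplified illustration: `ema_step` is a hypothetical helper that combines the warm-up check from `step_ema` with the formula from `update_average`, and the weight values are made up.

```python
def ema_step(ema_weight, weight, step, beta=0.995, step_start_ema=2000):
    """One EMA update for a single scalar weight, with the same warm-up rule."""
    if step < step_start_ema:
        # Warm-up: the EMA copy simply mirrors the current weight.
        return weight
    return ema_weight * beta + (1.0 - beta) * weight


# Toy scenario: the EMA copy is at 0.0 when the warm-up ends,
# and the model weight then stays at 1.0 for 100 optimizer steps.
ema_weight = 0.0
for step in range(2000, 2100):
    ema_weight = ema_step(ema_weight, 1.0, step)
print(round(ema_weight, 3))
```

With `beta = 0.995`, each update moves the EMA copy only 0.5% of the way toward the current weight, which is why it changes much more smoothly than the trained model.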

## Step 7 - Utils Functions for Visualization and Saving

- `plot_images` - Plots the synthetic data during the training of the model, which is useful for visualizing how the images evolve over the epochs.
- `save_synthetic_data` - Saves the synthetic images. It works in chunks to limit memory usage, since the model has to run the full reverse process for every image it creates.

In [15]:
def plot_images(images):
    plt.figure(figsize=(32, 32))
    plt.imshow(torch.cat([
        torch.cat([i for i in images.cpu()], dim=-1),
    ], dim=-2).permute(1, 2, 0).cpu())
    plt.show()


def save_synthetic_data(model, diffusion, save_path, num_instances, label, chunk_size, label_name):
    labels = torch.ones(num_instances, dtype=torch.long) * label
    labels = labels.to(device)
    sampled_images = []
    for i in range(0, num_instances, chunk_size):
        # Sample images in chunks
        chunk_size_curr = min(chunk_size, num_instances - i)
        with torch.no_grad():
            sampled_images.append(diffusion.sample(model, n=chunk_size_curr, labels=labels[i:i+chunk_size_curr])) 
    sampled_images = torch.cat(sampled_images, dim=0)
    for i, image in enumerate(sampled_images):
        array_image = image.permute(1, 2, 0).to('cpu').numpy()
        pil_image = Image.fromarray(array_image)
        pil_image = transforms.Resize((224, 224))(pil_image)
        pil_image.save(os.path.join(save_path, label_name, 'ddpm_cond_'+label_name+'_'+str(i)+'.jpg'))


## Step 8 - Training the Conditional DDPM

- Defining the training function for the Conditional DDPM. The code also samples and plots synthetic images every 10 epochs.
- Because the model is conditional, the training loop samples one synthetic instance per class.

In [None]:
def train(epochs, lr, device, dataloader, num_classes):
    losses = []
    losses_epoch = []
    model = UNet_conditional(num_classes=num_classes).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=image_size, device=device)
    l = len(dataloader)
    ema = EMA(0.995)
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)


    for epoch in range(epochs):
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            images = images.to(device)
            labels = labels.to(device)            
            t = diffusion.sample_timesteps(images.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(images, t)
            if np.random.random() < 0.1:
                labels = None            
            predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.step_ema(ema_model, model)

            pbar.set_postfix(MSE=loss.item())

            losses.append(loss.item())

        if epoch % 10 == 0:
            # One label per class, so one synthetic image is sampled for each class
            labels = torch.arange(num_classes).long().to(device)
            sampled_images = diffusion.sample(model, n=len(labels), labels=labels)
            plot_images(sampled_images)
        
        # Average only the batches of the current epoch (`l` batches per epoch)
        losses_epoch.append(sum(losses[-l:]) / l)

    return model, diffusion, losses_epoch
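
The training loop drops the labels for roughly 10% of the batches, so the same network also learns an unconditional noise prediction. At sampling time, `Diffusion.sample` combines both predictions with `torch.lerp(uncond, cond, cfg_scale)`, which computes `uncond + cfg_scale * (cond - uncond)`. The sketch below reproduces that formula on plain Python lists with made-up numbers; `guided_noise` is a hypothetical helper name.

```python
def guided_noise(uncond, cond, cfg_scale=3.0):
    """Move from the unconditional prediction toward (or past) the conditional one."""
    return [u + cfg_scale * (c - u) for u, c in zip(uncond, cond)]


# Toy predictions for two pixels (illustrative values only).
uncond = [0.0, 1.0]
cond = [1.0, 1.0]

print(guided_noise(uncond, cond, cfg_scale=1.0))  # [1.0, 1.0]: the conditional prediction
print(guided_noise(uncond, cond, cfg_scale=3.0))  # [3.0, 1.0]: extrapolated past it
```

A scale above 1 extrapolates beyond the conditional prediction wherever the two predictions differ, and leaves the value unchanged where they agree.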

- Training the model.

In [None]:
model, diffusion, losses_epoch =  train(epochs = epochs, 
                                  lr = lr, 
                                  device=device, 
                                  dataloader=dataloader, 
                                  num_classes=num_classes)

## Step 9 - Plotting the Loss over Epochs

- Visualizing the loss of the U-Net architecture over epochs.

In [None]:
sns.set_theme(style="whitegrid")
plt.figure(figsize=(10,5))
plt.title("Training loss")
plt.plot(losses_epoch, label="Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

## Step 10 - Saving the Synthetic Images

- The images are saved class by class in the chosen folder.
- `label` specifies the desired class.
- The `target` dictionary defines the desired number of instances per class.
- `chunk_size` sets how many images the saving function samples at a time.
- **Note:** `save_path` must already contain a subdirectory named after the class.

In [None]:
discrete_class = {
    0: "0_punching_hole",
    1: "1_welding_line",
    2: "2_crescent_gap",
    3: "3_water_spot",
    4: "4_oil_spot",
    5: "5_silk_spot",
    6: "6_inclusion",
    7: "7_rolled_pit",
    8: "8_crease",
    9: "9_waist_folding"
}

target = {0: 860, 1: 826, 2: 856, 3: 816, 4: 870, 5: 584, 6: 862, 7: 981, 8: 967, 9: 904}


save_path = "/root/path/to/save/the/images"     # Root path for the saved images; must contain one subdirectory per class, named after the class.
label = 2                                       # Discrete label of the class to generate.
chunk_size = 10                                 # Number of images sampled per chunk.

save_synthetic_data(model, diffusion, save_path, target[label], label, chunk_size, discrete_class[label])
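
Since the saving function writes into a subdirectory named after the class, those folders have to exist before it is called. One possible way to create them up front is sketched below. The class names are copied from `discrete_class`, while the temporary directory only stands in for `save_path` so that the example does not touch a real location.

```python
import os
import tempfile

# Class folder names copied from the `discrete_class` dictionary.
class_names = [
    "0_punching_hole", "1_welding_line", "2_crescent_gap", "3_water_spot",
    "4_oil_spot", "5_silk_spot", "6_inclusion", "7_rolled_pit",
    "8_crease", "9_waist_folding",
]

# Assumption for this demo: a temporary directory plays the role of `save_path`.
with tempfile.TemporaryDirectory() as save_path:
    for name in class_names:
        # `exist_ok=True` makes the call safe to repeat if the folder is already there.
        os.makedirs(os.path.join(save_path, name), exist_ok=True)
    created = sorted(os.listdir(save_path))
    print(created)
```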