# Unconditional Diffusion Model

In [None]:
# Standard library imports
from dataclasses import dataclass
import os
import re

# Third-party library imports
import cv2
from PIL import Image
import numpy as np
import torch
from tqdm import tqdm

# Local module imports
from utils.models import UNet
from utils.diffusion import Diffusion
from utils.utils import load_model_eval

## Configuration

In [None]:
@dataclass
class Config:
    """Notebook-wide settings, read directly as ``Config.<name>``.

    None of the attributes carry a type annotation, so ``@dataclass`` does not
    turn them into instance fields; they act as shared class-level constants
    and the class never needs to be instantiated.
    """

    # Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Folder name used under both "models/" and "samples/".
    model_name = "s2tld1"

    # File stem of the checkpoint loaded from models/<model_name>/.
    checkpoint_name = "checkpoint-400"

    # Number of images generated per sampling batch.
    batch_size = 3

## Sampling
### Helper functions

In [None]:
# Function to save a single image
def save_image(image, filename):
    """Write one channels-first image tensor to ``<filename>.png``.

    Args:
        image: Tensor indexed as ``(C, H, W)`` (see the transpose below). Values
            are mapped with ``x * 0.5 + 0.5`` and clamped to ``[0, 1]``, so a
            ``[-1, 1]`` range maps onto the full 8-bit range. Presumably C == 3
            (RGB) — TODO confirm; PIL cannot build an image from a trailing
            channel axis of size 1.
        filename: Output path WITHOUT the ``.png`` extension. Any missing
            parent directories are created.
    """
    # Map [-1, 1] -> [0, 1] and clamp whatever the sampler overshot.
    image = torch.clamp(image * 0.5 + 0.5, 0, 1)

    # detach() makes this safe even when called outside torch.no_grad().
    pixels = np.transpose(image.detach().cpu().numpy(), (1, 2, 0))
    pil_image = Image.fromarray((pixels * 255).astype(np.uint8))

    path = f"{filename}.png"
    directory = os.path.dirname(path)
    # A bare file name has no directory part and os.makedirs("") would raise;
    # exist_ok=True avoids the check-then-create race.
    if directory:
        os.makedirs(directory, exist_ok=True)

    pil_image.save(path)

In [None]:
# Function to find the next available sample number
def find_next_sample_num(directory):
    """Return one past the largest numeric sample index found in *directory*.

    An entry counts as a sample when its name is all digits (a per-sample
    process directory such as ``00003``) or all digits followed by ``.png``
    (a saved image such as ``00003.png``). Any zero-padding width is accepted,
    so images saved as ``000003.png`` are counted too instead of being
    silently ignored, which previously restarted numbering at 0 and
    overwrote earlier samples.

    Args:
        directory: Directory to scan. A directory that does not exist yet
            simply has no samples.

    Returns:
        The next unused sample index; 0 when no numeric entries exist.
    """
    try:
        entries = os.listdir(directory)
    except FileNotFoundError:
        return 0

    # Group 1 captures the digits of both forms, so each match yields the index.
    pattern = re.compile(r"^(\d+)(?:\.png)?$")
    numeric_values = [int(match.group(1)) for match in map(pattern.match, entries) if match]

    # default=-1 makes an empty directory start at 0.
    return max(numeric_values, default=-1) + 1

In [None]:
# Function to sample images from the model
def sample_model(checkpoint, num_batches, batch_size, save_process_imgs=False, num_steps=500):
    """Generate images with the reverse diffusion process and save them as PNGs.

    Final samples are always written to ``samples/<model_name>/single/`` as
    ``NNNNN.png``, numbered after the highest index already present there.
    When *save_process_imgs* is true, every sample also gets its own directory
    ``samples/<model_name>/process/NNNNN/``. That directory holds the starting
    noise (named after ``diffusion.diffusion_steps``) and the images at up to
    *num_steps* evenly spaced timesteps, named by zero-padded timestep.

    Args:
        checkpoint: Checkpoint object handed to ``load_model_eval`` to restore
            the model (and diffusion) state.
        num_batches: Number of sampling batches to run.
        batch_size: Number of images generated per batch.
        save_process_imgs: Also save intermediate diffusion images.
        num_steps: Number of evenly spaced timesteps in ``[0, diffusion_steps)``
            whose images are saved when *save_process_imgs* is true.
    """
    # Load model
    diffusion = Diffusion(device=Config.device)
    model = UNet().to(Config.device)
    load_model_eval(checkpoint, model, diffusion)

    # Final images always go to the single-image directory. Previously, with
    # save_process_imgs=True, this variable was left pointing at the LAST process
    # directory, so final images overwrote/extended that sample's frames.
    single_sample_dir = os.path.join("samples", Config.model_name, "single")
    os.makedirs(single_sample_dir, exist_ok=True)
    single_sample_num = find_next_sample_num(single_sample_dir)

    # One directory per sample for its intermediate images (empty when disabled).
    sample_dirs = []
    if save_process_imgs:
        base_dir = os.path.join("samples", Config.model_name, "process")
        # The base directory must exist before it can be scanned for numbering.
        os.makedirs(base_dir, exist_ok=True)
        first_num = find_next_sample_num(base_dir)
        for sample_num in range(first_num, first_num + num_batches * batch_size):
            sample_dir = os.path.join(base_dir, f"{sample_num:05d}")
            os.makedirs(sample_dir, exist_ok=True)
            sample_dirs.append(sample_dir)

    # Timesteps whose x_t gets saved. Loop-invariant, so computed once. A set
    # (rather than popping a sorted list) keeps saving working when linspace
    # yields duplicate ints, i.e. when num_steps > diffusion_steps.
    save_t_steps = set(
        np.linspace(0, diffusion.diffusion_steps, num_steps, dtype=int, endpoint=False).tolist()
    )
    sample_shape = (batch_size, 3, 64, 64)

    # Sample images from model (algorithm 2)
    with torch.no_grad():
        model.eval()

        for i in range(num_batches):
            # Process directories belonging to this batch (empty if not saving them).
            batch_dirs = sample_dirs[batch_size * i:batch_size * (i + 1)]

            x_t = torch.normal(0, 1, sample_shape, device=Config.device)  # Initialized to x_T

            # Save initial pure noise image (in process image directory)
            for j, sample_dir in enumerate(batch_dirs):
                save_image(x_t[j], os.path.join(sample_dir, f"{diffusion.diffusion_steps:06d}"))

            # Iterate over all reverse diffusion time steps from T to 1
            for t in tqdm(range(diffusion.diffusion_steps, 0, -1), desc=f"Sampling - Batch {i+1}/{num_batches}"):
                t_vec = t * torch.ones(x_t.shape[0], device=Config.device)
                epsilon_pred = model(x_t, t_vec)
                x_t = diffusion.remove_noise(x_t, t, epsilon_pred)

                # Save intermediate images (in process image directory)
                if batch_dirs and (t - 1) in save_t_steps:
                    for j, sample_dir in enumerate(batch_dirs):
                        save_image(x_t[j], os.path.join(sample_dir, f"{t-1:06d}"))

            # Save final images (in single image directory). Five digits match the
            # numbering convention that find_next_sample_num scans for.
            for j in range(batch_size):
                save_image(x_t[j], os.path.join(single_sample_dir, f"{single_sample_num:05d}"))
                single_sample_num += 1

In [None]:
# Function to create an image grid from a list of images
def create_image_grid(images, grid_size):
    """Tile equally sized images row-major into a single ``uint8`` canvas.

    Args:
        images: Sequence of arrays sharing the first image's ``(H, W, ...)``
            shape, e.g. ``(H, W, 3)`` images from ``cv2.imread``. Any trailing
            (channel) dimensions are carried over, so grayscale ``(H, W)``
            inputs work as well. Images beyond ``rows * cols`` are ignored;
            unused cells stay black.
        grid_size: ``(rows, cols)`` of the grid.

    Returns:
        ``uint8`` array of shape ``(rows * H, cols * W, ...)``.

    Raises:
        ValueError: If *images* is empty, or a used image is ``None`` (what
            ``cv2.imread`` returns for a file it cannot read).
    """
    rows, cols = grid_size
    if len(images) == 0:
        raise ValueError("create_image_grid needs at least one image")

    selected = images[:rows * cols]
    if any(img is None for img in selected):
        raise ValueError("create_image_grid received None; an image failed to load")

    # Assume all images are the same size as the first one
    img_height, img_width = images[0].shape[:2]
    trailing_dims = tuple(images[0].shape[2:])

    # Initialize blank image for the grid
    grid_img = np.zeros((img_height * rows, img_width * cols) + trailing_dims, dtype=np.uint8)

    # Place each image in its respective position (row-major order)
    for idx, img in enumerate(selected):
        row, col = divmod(idx, cols)
        grid_img[row * img_height:(row + 1) * img_height, col * img_width:(col + 1) * img_width] = img

    return grid_img

In [None]:
# Function to create a video from several samples and their intermediate diffusion images
def create_video(process_dir, video_name, grid_size, fps=30):
    """Render the saved diffusion processes of several samples as a grid video.

    Each 5-digit sub-directory of *process_dir* holds one sample's PNG frames.
    Frames are played in descending file-name order, i.e. from the starting
    noise (which ``sample_model`` names after the largest timestep) down to
    timestep 0. Video frame *t* tiles frame *t* of every sample, in sorted
    directory order, via ``create_image_grid``.

    Args:
        process_dir: Directory containing the per-sample process directories.
        video_name: Output video path, written with the ``mp4v`` codec. Missing
            parent directories are created.
        grid_size: ``(rows, cols)``; only the first ``rows * cols`` samples are used.
        fps: Frames per second of the output video.

    Raises:
        ValueError: If no frames are found or samples have differing frame counts.
        OSError: If OpenCV cannot open the video writer.
    """
    rows, cols = grid_size

    # Sorted so the grid layout is deterministic; samples beyond the grid would
    # be discarded by create_image_grid anyway, so don't read them at all.
    sample_names = sorted(
        name for name in os.listdir(process_dir) if name.isdigit() and len(name) == 5
    )[:rows * cols]

    # Per-sample frame paths, highest timestep first
    frame_paths = []
    for name in sample_names:
        sample_dir = os.path.join(process_dir, name)
        pngs = sorted((f for f in os.listdir(sample_dir) if f.endswith(".png")), reverse=True)
        frame_paths.append([os.path.join(sample_dir, f) for f in pngs])

    if not frame_paths or not frame_paths[0]:
        raise ValueError(f"No process images found in {process_dir!r}")
    num_frames = len(frame_paths[0])
    if any(len(paths) != num_frames for paths in frame_paths):
        raise ValueError("All sample directories must contain the same number of frames")

    # cv2.VideoWriter fails silently when the target directory is missing.
    output_dir = os.path.dirname(video_name)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    video = None
    try:
        # Create grid of images at each diffusion time step
        for t in tqdm(range(num_frames), desc="Creating Video"):
            frame_images = [cv2.imread(paths[t]) for paths in frame_paths]
            grid_image = create_image_grid(frame_images, grid_size)

            # Initialize the writer lazily, once the frame size is known
            if video is None:
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                video = cv2.VideoWriter(video_name, fourcc, fps, (grid_image.shape[1], grid_image.shape[0]))
                if not video.isOpened():
                    raise OSError(f"Could not open video writer for {video_name!r}")

            video.write(grid_image)
    finally:
        # Always release the writer, even if reading or tiling a frame failed
        if video is not None:
            video.release()

In [None]:
# Function to create a grid image from samples in a directory
def create_sample_grid(sample_dir, output_file, grid_size):
    """Tile the PNG samples in *sample_dir* into one image saved at *output_file*.

    Files are taken in sorted name order so the layout is deterministic, and
    only the first ``rows * cols`` of them are loaded. Missing parent
    directories of *output_file* are created.

    Args:
        sample_dir: Directory containing the ``.png`` samples.
        output_file: Path of the grid image to write.
        grid_size: ``(rows, cols)`` of the grid.

    Raises:
        OSError: If OpenCV fails to write *output_file*.
    """
    num_blocks = grid_size[0] * grid_size[1]
    file_names = sorted(name for name in os.listdir(sample_dir) if name.endswith(".png"))[:num_blocks]
    images = [cv2.imread(os.path.join(sample_dir, name)) for name in file_names]
    grid_image = create_image_grid(images, grid_size)

    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # cv2.imwrite signals failure through its return value instead of raising.
    if not cv2.imwrite(output_file, grid_image):
        raise OSError(f"Failed to write grid image to {output_file!r}")

In [None]:
# Function to create a grid image showing the diffusion process for a single sample
def create_process_image(process_sample_dir, output_file, grid_size):
    """Save a grid of evenly spaced frames from one sample's diffusion process.

    Frames are ordered by descending file name (highest timestep first). Then
    ``rows * cols`` indices are spread evenly from the first to the last
    frame; indices repeat when there are fewer frames than grid cells.

    Args:
        process_sample_dir: Directory holding one sample's ``.png`` frames.
        output_file: Path of the grid image to write. Missing parent
            directories are created.
        grid_size: ``(rows, cols)`` of the grid.

    Raises:
        ValueError: If the directory contains no ``.png`` files.
        OSError: If OpenCV fails to write *output_file*.
    """
    # Get frame file names, highest timestep first
    file_names = sorted(
        (name for name in os.listdir(process_sample_dir) if name.endswith(".png")),
        reverse=True,
    )
    if not file_names:
        raise ValueError(f"No process images found in {process_sample_dir!r}")

    # Select an evenly spaced subset of frames for the grid
    num_blocks = grid_size[0] * grid_size[1]
    picks = np.linspace(0, len(file_names) - 1, num_blocks, dtype=int).tolist()

    # Load images
    images = [cv2.imread(os.path.join(process_sample_dir, file_names[k])) for k in picks]

    # Create and save grid
    grid_image = create_image_grid(images, grid_size)
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # cv2.imwrite signals failure through its return value instead of raising.
    if not cv2.imwrite(output_file, grid_image):
        raise OSError(f"Failed to write process image to {output_file!r}")

## Output
### Sample Generation

In [None]:
# Load model checkpoint. map_location remaps stored tensors onto Config.device so a
# checkpoint saved on a GPU can still be loaded when only the CPU is available.
checkpoint = torch.load(
    os.path.join("models", Config.model_name, f"{Config.checkpoint_name}.pt"),
    map_location=Config.device,
)

In [None]:
# Generate final samples only (intermediate diffusion images are not saved,
# since save_process_imgs keeps its default of False)
sample_model(
    checkpoint,
    num_batches=1,
    batch_size=Config.batch_size,
)

In [None]:
# Generate final samples together with their intermediate diffusion images
# (one batch, images saved at 500 evenly spaced timesteps)
sample_model(
    checkpoint,
    1,
    Config.batch_size,
    save_process_imgs=True,
    num_steps=500,
)

# Formatting

In [None]:
# Where the sampling cells stored their images; the formatting cells below read from here
single_sample_dir, process_sample_dir = (
    os.path.join("samples", Config.model_name, subdir) for subdir in ("single", "process")
)

In [None]:
# Create sample grid. os.path.join replaces "output\sample_grid.png": "\s" is an
# invalid escape sequence, and a backslash is not a path separator outside Windows.
create_sample_grid(single_sample_dir, os.path.join("output", "sample_grid.png"), grid_size=(5, 8))

In [None]:
# Create process grid for sample 00003. os.path.join replaces "output\process_grid.png",
# whose "\p" is an invalid escape sequence and not a path separator outside Windows.
create_process_image(
    os.path.join(process_sample_dir, "00003"),
    os.path.join("output", "process_grid.png"),
    grid_size=(2, 10),
)

In [None]:
# Create video. os.path.join replaces 'output\diffusion.mp4', whose "\d" is an
# invalid escape sequence and not a path separator outside Windows.
create_video(process_sample_dir, os.path.join("output", "diffusion.mp4"), grid_size=(10, 16), fps=30)