# Resources

Here are the main resources I relied on when creating this notebook:

- [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)
- [Diffusion Models - Live Coding Tutorial](https://youtu.be/S_il77Ttrmg?si=GiwY7utZ638VRBDP)
- [How Diffusion Models Work](https://learn.deeplearning.ai/courses/diffusion-models/lesson/1/introduction)


And, of course, searching around the Internet.

In [None]:
import flax
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt
from tensorflow import keras as K

# Sample data

In [None]:
image_rescaling = K.layers.Rescaling(scale=1.0 / 127.5, offset=-1)


def prepare_image(x):
    """Rescale an image from [0, 255] to [-1, 1] and crop it to 27x27.

    The crop starts at offset (1, 1), i.e. it drops the first row and column.
    27 = 3**3, which the U-Net below relies on because it pools by 3 three times.

    Args:
        x: image tensor; presumably a raw MNIST image (28x28x1, values 0-255) —
            the crop needs at least 28 rows/columns.

    Returns:
        A float tensor of spatial size 27x27 with values in [-1, 1] for inputs in
        [0, 255].
    """
    return tf.image.crop_to_bounding_box(image_rescaling(x), 1, 1, 27, 27)


# Backward-compatible alias for the original (misspelled) name used by later cells.
prepare_iamge = prepare_image


(sample,) = tfds.load("mnist", split=["train[:8]"])
sample = sample.map(lambda x: {"image": prepare_iamge(x["image"]), "label": x["label"]})

# Materialise the 8 examples once; both the plot and `X` are built from this list,
# instead of iterating the dataset twice.
sample_examples = list(sample.as_numpy_iterator())

fig, axes = plt.subplots(2, 4, figsize=(6, 3))
for ax, example in zip(axes.flatten(), sample_examples):
    ax.imshow(example["image"], cmap="gray")
    ax.axis("off")
    ax.set_title(example["label"])

X = np.array([example["image"] for example in sample_examples])
X.shape

# U-Net

In [None]:
from typing import Any, Callable, Optional, Sequence

from flax import linen as nn
from flax.linen.module import compact
from jax import numpy as jnp

In [None]:
rng_root = jax.random.PRNGKey(0)
rng_keys = ["noise", "dropout"]
(rng,) = jax.random.split(rng_root, 1)


def update_rngs(rng: jax.Array, rng_keys: Sequence[str]):
    """Split ``rng`` into a fresh carry key plus one sub-key per stream name.

    Args:
        rng: PRNG key to split; it is consumed by this call and should not be
            reused afterwards.
        rng_keys: names of the RNG streams to create (e.g. ``"noise"``).

    Returns:
        ``(new_rng, rngs)``: ``new_rng`` is the key to carry forward, and ``rngs``
        maps each name in ``rng_keys`` (in order) to its own sub-key.
    """
    new_rng, *subkeys = jax.random.split(rng, len(rng_keys) + 1)
    return new_rng, dict(zip(rng_keys, subkeys))


rng, rngs = update_rngs(rng, rng_keys)

In [None]:
class ConvBlock(nn.Module):
    """Two convolutions with ReLU, followed by BatchNorm.

    Attributes:
        features: number of output channels of both convolutions.
        kernel_size: spatial size of the convolution kernels.
        dtype: computation dtype of the convolutions (BatchNorm keeps float32 params).
    """

    features: int
    kernel_size: Sequence[int] = (3, 3)

    dtype: Any = jnp.float32

    @compact
    def __call__(self, x: jax.Array, train: bool = False):
        """Apply Conv -> ReLU -> Conv -> ReLU -> BatchNorm to ``x``.

        ``train=True`` makes BatchNorm use batch statistics; ``train=False`` uses
        its running averages.
        """
        return nn.Sequential(
            [
                nn.Conv(self.features, kernel_size=self.kernel_size, dtype=self.dtype),
                nn.relu,
                nn.Conv(self.features, kernel_size=self.kernel_size, dtype=self.dtype),
                nn.relu,
                nn.BatchNorm(use_running_average=not train, param_dtype=jnp.float32),
            ]
        )(x)


model = ConvBlock(5, dtype=jnp.bfloat16)
variables = model.init(rng, jnp.empty_like(X))
y = model.apply(variables, X)
y.shape

In [None]:
class LearnedTimeEmbed(nn.Module):
    """Learned lookup-table embedding of diffusion time steps.

    ``L`` is the number of table rows (distinct time steps) and ``d_model`` the
    embedding width. Two singleton axes are inserted after the batch axis so the
    result can broadcast against a 4-D feature map.
    """

    L: int
    d_model: int

    dtype: Any = jnp.float32

    @compact
    def __call__(self, t: Sequence[int]):
        table = nn.Embed(self.L, self.d_model, dtype=self.dtype)
        steps = jnp.array(t)
        # for a 1-D `t`: (B, d_model) -> (B, 1, 1, d_model)
        return table(steps)[:, None, None, ...]


model = LearnedTimeEmbed(100, 5, dtype=jnp.bfloat16)
variables = model.init(rng, t=[0])
y = model.apply(variables, t=[0, 1])
y.shape

In [None]:
class DownBlock(nn.Module):
    """Encoder stage: a ConvBlock followed by pooling.

    Attributes:
        features: output channels of the ConvBlock.
        strides: pooling window and stride, i.e. the down-size factor.
        kernel_size: convolution kernel size of the ConvBlock.
        pool_fn: pooling function, called as ``pool_fn(x, window_shape=..., strides=...)``.
        dtype: computation dtype of the convolutions.
    """

    features: int

    strides: Sequence[int] = (2, 2)  # down size factor
    kernel_size: Sequence[int] = (3, 3)
    pool_fn: Callable = nn.max_pool

    dtype: Any = jnp.float32

    @compact
    def __call__(
        self, x: jax.Array, train: bool = False
    ) -> tuple[jax.Array, jax.Array]:
        """Return ``(down, skip)``: the pooled features and the pre-pool features.

        ``skip`` is the ConvBlock output, kept for the decoder's skip connection.
        """
        skip = ConvBlock(self.features, self.kernel_size, dtype=self.dtype)(x, train)
        down = self.pool_fn(skip, window_shape=self.strides, strides=self.strides)

        return down, skip


class UpBlock(nn.Module):
    """Decoder stage: up-sample, add the time embedding, merge the skip, ConvBlock.

    Attributes:
        features: output channels of the transposed convolution and the ConvBlock.
        kernel_size: kernel size of the transposed convolution and the ConvBlock.
        strides: up-sampling factor of the transposed convolution.
        dtype: computation dtype of the layers.
    """

    features: int

    kernel_size: Sequence[int]
    strides: Sequence[int]

    dtype: Any = jnp.float32

    @compact
    def __call__(
        self, x: jax.Array, skip: jax.Array, time_embed: jax.Array, train: bool = False
    ) -> jax.Array:
        """Up-sample ``x`` and fuse it with ``skip`` and ``time_embed``.

        ``time_embed`` is projected to ``features`` channels, passed through ReLU
        and added to the up-sampled map, so it must broadcast against that map
        (e.g. shape (B, 1, 1, D)). ``skip`` must match the up-sampled spatial
        size, because the two are concatenated along the last axis.
        """
        up = nn.ConvTranspose(
            self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            dtype=self.dtype,
        )(x)
        up = up + nn.relu(nn.Dense(up.shape[-1], dtype=self.dtype)(time_embed))
        up = jnp.concatenate([up, skip], axis=-1)
        up = ConvBlock(self.features, self.kernel_size, dtype=self.dtype)(up, train)

        return up

In [None]:
down_block = DownBlock(3, strides=(3, 3))
variables = down_block.init(rng, jnp.empty((27, 27, 1)))
y, y_skip = down_block.apply(variables, jnp.ones((27, 27, 1)))
for arr in (y, y_skip):
    print(arr.shape)

up_block = UpBlock(3, kernel_size=(3, 3), strides=(3, 3))
variables = up_block.init(
    rng,
    jnp.empty_like(y),
    jnp.empty_like(y_skip),
    jnp.empty((1, 1)),
)
y = up_block.apply(variables, y, y_skip, jnp.ones((1, 1)))
print(y.shape)

In [None]:
class UNet(nn.Module):
    """Three-level U-Net that predicts the noise in a noised image ``x`` at steps ``t``.

    Each encoder level pools by 3 (for a 27x27 input: 27 -> 9 -> 3 -> 1). The
    decoder up-samples back, concatenating the matching skip connection and
    adding a learned time-step embedding at every level.

    Attributes:
        T: total number of diffusion time steps (rows of the time-embedding table).
        kernel_size: convolution kernel size used by every down/up block.
        strides: currently unused. Every level pools/up-samples by (3, 3), which
            the 27x27 inputs in this notebook require. Kept so existing
            constructor calls keep working.
        dtype: computation dtype of the convolutional layers.
    """

    T: int  # total time step number
    kernel_size: Sequence[int] = (3, 3)
    strides: Sequence[int] = (2, 2)  # NOTE: unused, see class docstring

    dtype: Any = jnp.float32

    @compact
    def __call__(
        self, x: jax.Array, t: Sequence[int], train: bool = False
    ) -> jax.Array:
        """Predict the noise component of ``x``.

        Args:
            x: rank-4 batch of images (assumed NHWC).
            t: one diffusion time step per image in the batch.
            train: BatchNorm uses batch statistics if True, running averages if False.

        Returns:
            For 27x27 inputs, an array with the same shape as ``x``.
        """
        assert len(x.shape) == 4, f"image shape {x.shape} != 4"
        assert (
            len(t) == x.shape[0]
        ), f"image batch size {x.shape[0]} != embed size {len(t)}"

        time_embeds = LearnedTimeEmbed(self.T, 64, dtype=self.dtype)
        level = dict(kernel_size=self.kernel_size, strides=(3, 3), dtype=self.dtype)

        # down sampling
        down16, skip16 = DownBlock(16, **level)(x, train)
        down32, skip32 = DownBlock(32, **level)(down16, train)
        down64, skip64 = DownBlock(64, **level)(down32, train)

        # up sampling; the embedding module is shared by all levels, so look it up once
        time_embed = time_embeds(t)
        up64 = UpBlock(64, **level)(down64, skip64, time_embed, train=train)
        up32 = UpBlock(32, **level)(up64, skip32, time_embed, train=train)
        up16 = UpBlock(16, **level)(up32, skip16, time_embed, train=train)

        # Output head. The regression target is standard-normal noise, so the head
        # must produce both signs freely. A ReLU before this BatchNorm would bound
        # every prediction from below, so the head is kept affine (conv + BN).
        z = nn.Conv(x.shape[-1], kernel_size=(1, 1), dtype=self.dtype)(up16)
        z = nn.BatchNorm(use_running_average=not train, param_dtype=jnp.float32)(z)

        return z

In [None]:
unet = UNet(5)
variables = unet.init(rng, jnp.empty((1, 27, 27, 1)), t=[0])
y, _ = unet.apply(
    variables, jnp.ones((1, 27, 27, 1)), t=[1], mutable=["batch_stats"]
)
y.shape

In [None]:
unet = UNet(100, dtype=jnp.bfloat16)
unet_var = unet.init(rng, X[:2].astype(jnp.bfloat16), t=[0, 0])
y, _ = unet.apply(
    unet_var, X[:2], t=[1, 2], train=True, mutable=["batch_stats"]
)

print(y.shape)
plt.imshow(y[0].astype(jnp.float32), cmap="gray")

In [None]:
print(
    unet.tabulate(
        {"params": rng, **rngs},
        X.astype(jnp.bfloat16),
        [0] * len(X),
    )
)

# Diffusion Model

In [None]:
class DiffusionModel:
    """Linear-beta diffusion schedule with forward noising and reverse steps.

    Attributes:
        total_steps: number of diffusion steps in the schedule.
        dtype: dtype of the schedule arrays ``betas``, ``alphas`` and ``alphas_bar``.
    """

    total_steps: int
    dtype: Any

    def __init__(self, total_steps: int = 50, dtype: Any = jnp.float32):
        self.total_steps = total_steps
        self.dtype = dtype

        # beta_t rises linearly from 1e-4 to 0.02, alpha_t = 1 - beta_t and
        # alphas_bar[t] = alpha_0 * ... * alpha_t.
        self.betas = jnp.linspace(0.0001, 0.02, self.total_steps, dtype=self.dtype)
        self.alphas = 1.0 - self.betas
        self.alphas_bar = jnp.cumprod(self.alphas, axis=0)

    def __call__(self, x: jax.Array, t: Sequence[int], rng: jax.Array):
        """Alias for :meth:`add_noise`."""
        return self.add_noise(x, t, rng)

    def add_noise(self, x: jax.Array, t: Sequence[int], rng: jax.Array):
        """Sample x_t from q(x_t | x_0) for every image in the batch.

        Args:
            x: clean images; the first axis is the batch (4-D, per the reshape below).
            t: one time step per image.
            rng: PRNG key for the Gaussian noise.

        Returns:
            ``(x_t, noise)`` with x_t = sqrt(ab_t) * x + sqrt(1 - ab_t) * noise.
        """
        assert x.shape[0] == len(t), "batch size mismatch"
        alphas_bar_t = self.alphas_bar[t,].reshape((-1, 1, 1, 1))
        mean = jnp.sqrt(alphas_bar_t) * x
        noise = jax.random.normal(rng, x.shape)
        scaled_noise = jnp.sqrt(1.0 - alphas_bar_t) * noise

        x_t = mean + scaled_noise

        return x_t, noise

    def ddpm(self, x: jax.Array, noise: jax.Array, t: int, rng: jax.Array):
        """One ancestral DDPM step x_t -> x_{t-1}, using the predicted noise ``noise``.

        mean = (x_t - (1 - alpha_t) / sqrt(1 - ab_t) * noise) / sqrt(alpha_t).
        For t > 0, sqrt(beta_t) * N(0, I) drawn from ``rng`` is added to the mean;
        at t == 0 the mean is returned.
        """
        betas_t = self.betas[t,]
        alphas_t = self.alphas[t,]
        alphas_bar_t = self.alphas_bar[t,]
        one_alphas_t = 1.0 - alphas_t
        sqrt_one_alphas_bar_t = jnp.sqrt(1.0 - alphas_bar_t)

        mean = (x - (one_alphas_t / sqrt_one_alphas_bar_t) * noise) / jnp.sqrt(alphas_t)

        if t > 0:
            return mean + jnp.sqrt(betas_t) * jax.random.normal(rng, x.shape)
        else:
            return mean

    def ddim(self, x_t: jax.Array, t: int, t_p: int, noise: jax.Array):
        """Deterministic DDIM jump from step ``t`` to an earlier step ``t_p``.

        Returns sqrt(ab_p) * x0_hat + sqrt(1 - ab_p) * noise, where
        x0_hat = (x_t - sqrt(1 - ab_t) * noise) / sqrt(ab_t).
        """
        ab_t = self.alphas_bar.take(t).reshape((-1, 1, 1, 1))
        ab_p = self.alphas_bar.take(t_p).reshape((-1, 1, 1, 1))

        x_0_ = jnp.sqrt(ab_p / ab_t) * (x_t - jnp.sqrt(1.0 - ab_t) * noise)
        dir_x_t = jnp.sqrt(1.0 - ab_p) * noise

        return x_0_ + dir_x_t

In [None]:
def ddpm_samples(n: int, rng: jax.Array, image_shape: Sequence[int] = (27, 27, 1)):
    """Generate ``n`` images by running the full DDPM reverse chain.

    Starts from standard-normal noise and applies ``dm.ddpm`` for
    t = T-1, ..., 0, predicting the noise at each step with ``unet_apply`` and
    the current ``state``. Relies on notebook globals defined in later cells:
    ``T``, ``unet_apply``, ``state``, ``dm`` and ``rng_keys`` (which must
    include ``"noise"``).

    Args:
        n: number of images to generate.
        rng: PRNG key; it is split internally.
        image_shape: per-image shape (H, W, C); defaults to the 27x27x1 crops.

    Returns:
        Array of shape ``(n, *image_shape)``.
    """
    # A dedicated key for the initial noise, so it is independent of the
    # per-step keys split from `rng` below (a key must not be used twice).
    rng, init_key = jax.random.split(rng)
    samples = jax.random.normal(init_key, (n, *image_shape), dtype=jnp.bfloat16)
    for t in range(T - 1, -1, -1):
        rng, rngs = update_rngs(rng, rng_keys)
        preds = unet_apply(
            {"params": state.params, "batch_stats": state.batch_stats},
            samples,
            [t] * n,
            rngs=rngs,
        )
        samples = dm.ddpm(samples, preds, t, rngs["noise"])

    return samples

In [None]:
dm = DiffusionModel(total_steps=500, dtype=jnp.float16)
x_t, noise = dm(jnp.stack([X[0]] * 3), [1, 10, 20], rng)

fig, axes = plt.subplots(1, 3)
for ax, noisy_image in zip(axes.flatten(), x_t):
    ax.imshow(noisy_image.astype(jnp.float32), cmap="gray")
    ax.axis("off")

In [None]:
pred, _ = unet.apply(
    unet_var, x_t, t=[1, 10, 20], train=True, mutable=["batch_stats"]
)
fig, axes = plt.subplots(1, 3)
for ax, predicted_noise in zip(axes, pred):
    ax.imshow(predicted_noise.astype(jnp.float32), cmap="gray")
    ax.axis("off")

In [None]:
# Take one reverse DDPM step per image, each at the time step it was noised with
# above ([1, 10, 20]); a single unrelated t (previously 498) does not match those
# noise levels. `pred` comes from the freshly initialised U-Net, so this only
# exercises the plumbing, not sample quality.
noised_steps = [1, 10, 20]
step_keys = jax.random.split(rng, len(noised_steps))
output = jnp.concatenate(
    [
        dm.ddpm(x_t[i : i + 1], pred[i : i + 1], t, rng=step_keys[i])
        for i, t in enumerate(noised_steps)
    ],
    axis=0,
)
# output = dm.ddim(x_t, 5, 15, noise)

In [None]:
# Cast to float32 before plotting, as the other plotting cells do (the schedule in
# this section is float16, so results may otherwise reach matplotlib in a low-precision dtype).
plt.imshow(np.asarray(output[0], dtype=np.float32), cmap="gray")

In [None]:
fig, axes = plt.subplots(3, 3, figsize=(6, 4))

# rows: the clean image (repeated), the noised images x_t, the noise that was added
images = [[X[0]] * 3, x_t, noise]
for row_axes, row_images in zip(axes, images):
    for ax, image in zip(row_axes, row_images):
        ax.axis("off")
        ax.imshow(image, cmap="gray")

In [None]:
# Compare the distribution of the true noise with the (untrained) model's prediction.
# Semi-transparent bars keep both histograms visible where they overlap.
plt.hist(np.asarray(noise, dtype=np.float32).ravel(), bins=50, density=True, alpha=0.5, label="truth")
plt.hist(np.asarray(pred, dtype=np.float32).ravel(), bins=50, density=True, alpha=0.5, label="preds")
plt.title("Added noise vs. predicted noise")
plt.legend()
plt.show()

# Training

In [None]:
import os
from datetime import datetime
from functools import partial

import optax  # Common loss functions and optimizers
import orbax.checkpoint as ocp
from clu import metrics
from flax import struct  # Flax dataclasses
from flax.metrics import tensorboard
from flax.training import train_state  # Useful dataclass to keep train state

In [None]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Metrics accumulated during training (updated by ``train_step``)."""

    # Average of the values passed as ``loss=`` to ``single_from_model_output``.
    loss: metrics.Average.from_output("loss")


class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with metrics and BatchNorm statistics."""

    metrics: Metrics  # running totals, merged with each step's metrics in train_step
    batch_stats: Any  # BatchNorm statistics, replaced after every train_step

In [None]:
def create_train_state(
    module: nn.Module,
    params: dict,
    batch_stats: Any,
    learning_rate: float,
    momentum: Optional[float] = None,
    weight_decay: Optional[float] = None,
    warmup_steps: Optional[int] = None,
    max_steps: Optional[int] = None,
):
    """Build the initial ``TrainState`` for ``module``.

    The optimizer is Adam with a constant ``learning_rate``, preceded by
    gradient clipping to a global norm of 1.0. Metrics start empty.

    Note: ``momentum``, ``weight_decay``, ``warmup_steps`` and ``max_steps`` are
    accepted but currently ignored. They belong to the commented-out optimizer /
    schedule variants below.
    """
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=learning_rate),
    )
    # Commented-out alternatives:
    #   optax.adamw(learning_rate, weight_decay=weight_decay)
    #   optax.sgd(learning_rate=learning_rate, momentum=momentum)
    #   optax.warmup_cosine_decay_schedule(
    #       init_value=0.0, peak_value=learning_rate,
    #       warmup_steps=warmup_steps, decay_steps=max_steps,
    #   )

    return TrainState.create(
        apply_fn=module.apply,
        tx=optimizer,
        params=params,
        batch_stats=batch_stats,
        metrics=Metrics.empty(),
    )

In [None]:
@jax.jit
def train_step(state, batch: jax.Array, noises: jax.Array, t: Sequence[int], rngs):
    """Run one optimisation step of noise prediction (mean squared error).

    Args:
        state: current ``TrainState``.
        batch: noised images x_t fed to the model.
        noises: the noise that produced ``batch``; this is the regression target.
        t: one diffusion time step per image.
        rngs: dict of PRNG keys forwarded to ``apply_fn``.

    Returns:
        ``(new_state, step_metrics)``. ``step_metrics`` holds this step only, while
        ``new_state.metrics`` holds the running totals.
    """

    def loss_fn(params, target):
        preds, mutated = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            x=batch,
            t=t,
            rngs=rngs,
            train=True,
            mutable=["batch_stats"],
        )
        assert preds.shape == target.shape

        batch_size = batch.shape[0]
        flat_preds = preds.reshape((batch_size, -1))
        flat_target = target.reshape((batch_size, -1))
        loss = optax.squared_error(flat_preds, flat_target).mean()

        return loss, (flat_preds, mutated)

    (loss, (preds, mutated)), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        state.params, noises
    )
    new_state = state.apply_gradients(grads=grads).replace(
        batch_stats=mutated["batch_stats"]
    )
    step_metrics = new_state.metrics.single_from_model_output(preds=preds, loss=loss)
    new_state = new_state.replace(metrics=new_state.metrics.merge(step_metrics))

    return new_state, step_metrics

In [None]:
# prepare training dataset

(train_ds,) = tfds.load("mnist", split=["train"])
batch_size = 128
num_epochs = 500

num_examples = train_ds.cardinality().numpy()
# `prepare_iamge` (defined in the "Sample data" section) already rescales to
# [-1, 1] with the module-level `image_rescaling` layer, so that layer is not
# redefined here.
train_ds = (
    train_ds.map(
        lambda x: prepare_iamge(x["image"]),
        num_parallel_calls=12,
    )
    .cache()
    # Shuffling after `repeat` mixes examples across epoch boundaries; the
    # buffer holds three epochs' worth of images.
    .repeat(num_epochs)
    .shuffle(num_examples * 3)
    .batch(batch_size, drop_remainder=True)
    .prefetch(4)
)

max_steps = train_ds.cardinality().numpy()
print(f"max steps = {max_steps}")
steps_per_epoch = max_steps // num_epochs
print(f"steps per epoch = {steps_per_epoch}")

In [None]:
# init. model and state

learning_rate = 1.5e-4
T = 500  # number of diffusion steps, also the size of the U-Net's time-embedding table

unet = UNet(T, dtype=jnp.bfloat16)
unet_apply = jax.jit(unet.apply)  # jitted forward pass, used by ddpm_samples
variables = unet.init(
    {"params": rng, **rngs},
    jnp.empty((1, 27, 27, 1), dtype=jnp.bfloat16),
    t=[0],
)
state = create_train_state(
    module=unet,
    params=variables["params"],
    batch_stats=variables["batch_stats"],
    learning_rate=learning_rate,
    warmup_steps=int(0.2 * max_steps),
    max_steps=max_steps,
)

dm = DiffusionModel(total_steps=T, dtype=jnp.float32)

# Debug blocks

The following few cells are for debugging only. Note that the last one runs a real training step and updates `state`.

In [None]:
# Smoke-test one training step on two sample images (the updated state is discarded).
# The result is NOT bound to `metrics`: that name is the `clu.metrics` module
# imported above, and shadowing it breaks re-running the Metrics/TrainState cell.
x_t, noises = dm(X[:2], [3, 4], rng)

_, debug_metrics = train_step(state, x_t, noises, [3, 4], rngs=rngs)
debug_metrics.compute()

In [None]:
# Iterator over the first 5 training batches, consumed by the debug cells below.
train_ds_iter = iter(train_ds.take(5).as_numpy_iterator())

In [None]:
batch = next(train_ds_iter)
batch_image = batch
# Last expression of the cell, so the shape is actually displayed.
batch_image.shape

In [None]:
rng, rngs = update_rngs(rng, rng_keys)
# Separate keys for the time steps and the forward-process noise; using the same
# key for both draws (and then splitting it again later) would correlate them.
rng, t_key, noise_key = jax.random.split(rng, 3)
ts = jax.random.randint(t_key, (batch_image.shape[0],), 0, T)
x_t, noises = dm(batch_image, ts, noise_key)
state, metric_updates = train_step(state, x_t, noises, ts, rngs)
metric_updates.compute()

# Training loop

Finally, let's give the model a name, then we can start training.

In [None]:
model_name = "dm-291k-bfloat16"
# Each run gets a timestamped checkpoint directory. To resume a previous run,
# set `model_checkpoint` to that run's directory name instead (as below).
model_checkpoint = datetime.now().strftime(f"{model_name}_%Y%m%d-%H%M")
# model_checkpoint = "dm-291k-bfloat16_20240423-0149"
checkpoint_path = f"{os.getcwd()}/checkpoint/{model_checkpoint}"
print(
    f"""model name: {model_name},
checkpoint path: {checkpoint_path}
"""
)

# Re-seed the PRNG for training. Only the "noise" stream is created from here on
# (the U-Net defined above has no dropout layers).
rng_root = jax.random.PRNGKey(0)
rng_keys = ["noise"]
(rng,) = jax.random.split(rng_root, 1)

In [None]:
log_dir = f"tb-log/{model_name}"

# one TensorBoard run directory per checkpoint name, under the model's log dir
summary_writer = tf.summary.create_file_writer("/".join([log_dir, model_checkpoint]))

In [None]:
# Launch TensorBoard inside the notebook on the summaries written under `log_dir`.
%load_ext tensorboard

%tensorboard --logdir={log_dir} --bind_all

In [None]:
from tqdm.notebook import tqdm

with ocp.CheckpointManager(
    checkpoint_path,
    options=ocp.CheckpointManagerOptions(max_to_keep=4, create=True),
    item_handlers={
        "state": ocp.StandardCheckpointHandler(),
        "config": ocp.JsonCheckpointHandler(),
    },
) as checkpoint_manager:

    # Resume from the newest checkpoint if there is one. Compare with None:
    # step 0 is a valid checkpoint step but is falsy.
    restored_step = 0
    latest_step = checkpoint_manager.latest_step()
    if latest_step is not None:
        restored_checkpoint = checkpoint_manager.restore(
            latest_step,
            # Same item names / handlers as the `args` passed to `save` below.
            args=ocp.args.Composite(
                state=ocp.args.StandardRestore(state),
                config=ocp.args.JsonRestore(),
            ),
        )
        state = restored_checkpoint["state"]
        restored_step = latest_step

    print(f"last step: {restored_step}")
    steps_per_checkpoint = 100
    train_stop_step = min(50_000, max_steps)
    # Clamp at 0: `Dataset.take` with a negative count takes the WHOLE dataset,
    # which would silently restart a full run when resuming at/after the stop step.
    train_steps = max(0, train_stop_step - restored_step)
    print(f"training epochs: ~ {train_steps // steps_per_epoch}")

    with summary_writer.as_default():
        for step, batch in tqdm(
            enumerate(train_ds.take(train_steps).as_numpy_iterator()),
            desc="training progress",
            initial=restored_step,
            total=train_stop_step,  # the loop stops here, not at max_steps
        ):
            rng, rngs = update_rngs(rng, rng_keys)
            # Independent keys for the time steps and the forward-process noise;
            # reusing `rng` for both (and then splitting it again) correlates them.
            rng, t_key, noise_key = jax.random.split(rng, 3)
            batch_image = batch

            B, *_ = batch_image.shape
            t = jax.random.randint(t_key, (B,), 0, T)
            batch_image_t, noise = dm(batch_image, t, noise_key)
            state, metric_updates = train_step(state, batch_image_t, noise, t, rngs)
            current_step = restored_step + step

            if current_step % steps_per_checkpoint == 0:
                checkpoint_manager.save(
                    current_step,
                    args=ocp.args.Composite(
                        state=ocp.args.StandardSave(state), config=ocp.args.JsonSave({})
                    ),
                )

                for m, v in metric_updates.compute().items():
                    tf.summary.scalar(m, v, current_step)

            if current_step % steps_per_epoch == 0:
                rng, sample_key = jax.random.split(rng)
                samples = ddpm_samples(3, sample_key)
                # tf.summary.image clips float images to [0, 1]. Training images are
                # scaled to [-1, 1], so map samples to [0, 1] first instead of
                # clipping every negative pixel to black.
                samples_01 = np.clip(
                    (np.asarray(samples, dtype=np.float32) + 1.0) / 2.0, 0.0, 1.0
                )
                tf.summary.image("samples", samples_01, current_step)

In [None]:
with ocp.CheckpointManager(
    checkpoint_path,
    options=ocp.CheckpointManagerOptions(max_to_keep=4, create=True),
    item_handlers={
        "state": ocp.StandardCheckpointHandler(),
        "config": ocp.JsonCheckpointHandler(),
    },
) as checkpoint_manager:

    # Restore the newest checkpoint (if any) into `state`. Compare with None:
    # step 0 is a valid checkpoint step but is falsy.
    restored_step = 0
    latest_step = checkpoint_manager.latest_step()
    if latest_step is not None:
        restored_checkpoint = checkpoint_manager.restore(
            latest_step,
            # Mirrors the Composite(state=StandardSave, config=JsonSave) used when saving.
            args=ocp.args.Composite(
                state=ocp.args.StandardRestore(state),
                config=ocp.args.JsonRestore(),
            ),
        )
        state = restored_checkpoint["state"]
        restored_step = latest_step

    print(f"last step: {restored_step}")

In [None]:
%%time

# Generate 25 images with the full DDPM reverse chain (one U-Net call per time step).
denoised_images = ddpm_samples(25, rng)

In [None]:
rows = 5
cols = 5

fig, axes = plt.subplots(rows, cols, figsize=(8, 8))

# One generated image per subplot, in row-major order, titled with its index.
for index, ax in enumerate(axes.flat):
    ax.imshow(denoised_images[index], cmap="gray")
    ax.set_title(index)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
plt.show()

In [None]:
%%time

denoised_images = []
samples = jax.random.normal(rng, (30, 27, 27, 1))
step_size = 20
inference_time_steps = list(range(T - 1, 0, -step_size)) + [0]
# print(inference_time_steps)
for i, t in enumerate(inference_time_steps[:-1]):
    rng, rngs = update_rngs(rng, rng_keys)
    pred_noise = unet.apply(
        {"params": state.params, "batch_stats": state.batch_stats},
        samples,
        [t] * samples.shape[0],
        rngs=rngs,
    )

    samples = dm.ddim(samples, t, inference_time_steps[i + 1], pred_noise)
    denoised_images.append(samples)

In [None]:
rows = 3
cols = 10

fig, axes = plt.subplots(rows, cols, figsize=(12, 4))

# Plot the final DDIM batch, one image per subplot in row-major order.
final_batch = denoised_images[-1]
for index, ax in enumerate(axes.flat):
    ax.imshow(final_batch[index], cmap="gray")
    ax.set_title(index)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout()
plt.show()