# Flax Diffusion
This is a partial port of [Denoising Diffusion Implicit Models](https://keras.io/examples/generative/ddim/) from the Keras documentation, with the addition of conditional generation. I omit the Kernel Inception Distance because, to the best of my knowledge, there is currently easy way to load InceptionV3 in Flax. If you're interested in implementing KID or FID in Flax, [matthias-wright/jax-fid](https://github.com/matthias-wright/jax-fid) has a Flax implementation of InceptionV3.

In [89]:
import math
from typing import Any
import flax.linen as nn
from flax.training import train_state
import optax
import jax
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm
from keras.datasets import cifar10

## Hyperparameters

In [33]:
# Sampling.
min_signal_rate = 0.02
max_signal_rate = 0.95

# Architecture.
embedding_dims = 32
embedding_max_frequency = 1000.0
widths = [32, 64, 96, 128]
block_depth = 2

# Optimization.
learning_rate = 1e-4
epochs = 100

# Input.
batch_size = 8
image_width = 32
image_height = 32
channels = 3
num_classes = 10

## Data Preparation

In [82]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

In [83]:
x_train = np.concatenate([x_train, x_test], axis=0)
y_train = np.concatenate([y_train, y_test], axis=0)

In [84]:
x_train = x_train.astype('float32') / 255.

In [85]:
x_train = jnp.reshape(
    x_train, (
        x_train.shape[0] // batch_size,
        batch_size,
        image_width,
        image_height,
        channels
    )
)

y_train = jnp.reshape(
    y_train, (
        y_train.shape[0] // batch_size,
        batch_size,
        y_train.shape[-1]
    )
)
print('x_train shape:', x_train.shape)
print('y_train shape:', y_train.shape)

## Embedding

In [8]:
def sinusoidal_embedding(x, embedding_max_frequency):
    embedding_min_frequency = 1.0
    frequencies = jnp.exp(
        jnp.linspace(
            jnp.log(embedding_min_frequency),
            jnp.log(embedding_max_frequency),
            embedding_dims // 2
        )
    )
    angular_speeds = 2.0 * math.pi * frequencies
    embeddings = jnp.concatenate(
        [jnp.sin(angular_speeds * x), jnp.cos(angular_speeds * x)],
        axis = -1
    )
    return embeddings

## Architecture

In [75]:
class ResidualBlock(nn.Module):
    width: int

    @nn.compact
    def __call__(self, x, train: bool):
        input_width = x.shape[-1]
        if input_width == self.width:
            residual = x
        else:
            residual = nn.Conv(self.width, kernel_size=(1, 1))(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.Conv(self.width, kernel_size=(3, 3))(x)
        x = nn.activation.swish(x)
        x = nn.Conv(self.width, kernel_size=(3, 3))(x)
        x = x + residual
        return x

class DownBlock(nn.Module):
    width: int
    block_depth: int

    @nn.compact
    def __call__(self, x, train: bool):
        x, skips = x

        for _ in range(self.block_depth):
            x = ResidualBlock(self.width)(x, train)
            skips.append(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        return x

class UpBlock(nn.Module):
    width: int
    block_depth: int

    @nn.compact
    def __call__(self, x, train: bool):
        x, skips = x

        upsample_shape = (x.shape[0], x.shape[1] * 2, x.shape[2] * 2, x.shape[3])
        x = jax.image.resize(x, upsample_shape, method='bilinear')

        for _ in range(self.block_depth):
            x = jnp.concatenate([x, skips.pop()], axis=-1)
            x = ResidualBlock(self.width)(x, train)
        return x

class DDIM(nn.Module):
    channels: int
    num_classes: int
    widths: list
    block_depth: int

    @nn.compact
    def __call__(self, x, train: bool):
        x, noise_variances, class_id = x

        e = sinusoidal_embedding(noise_variances)
        e = jax.image.resize(e, shape=x.shape, method='nearest')

        class_embedding = nn.Embed(self.num_classes, x.shape[1])(class_id)
        class_embedding = jnp.expand_dims(class_embedding, axis=1)
        class_embedding = jnp.repeat(class_embedding, x.shape[2], axis=1)
        class_embedding = jnp.expand_dims(class_embedding, axis=-1)

        x = nn.Conv(self.widths[0], kernel_size=(1, 1))(x)
        x = jnp.concatenate([x, e, class_embedding], axis=-1)

        skips = []
        for width in self.widths[:-1]:
            x = DownBlock(width, self.block_depth)([x, skips], train)

        for _ in range(self.block_depth):
            x = ResidualBlock(self.widths[-1])(x, train)

        for width in reversed(self.widths[:-1]):
            x = UpBlock(width, self.block_depth)([x, skips], train)

        x = nn.Conv(self.channels, kernel_size=(1, 1), kernel_init=nn.initializers.zeros_init())(x)
        return x

## Diffusion Schedule

In [28]:
def diffusion_schedule(diffusion_times, max_signal_rate, min_signal_rate):
    start_angle = jnp.arccos(max_signal_rate)
    end_angle = jnp.arccos(min_signal_rate)

    diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)

    signal_rates = jnp.cos(diffusion_angles)
    noise_rates = jnp.sin(diffusion_angles)
    return noise_rates, signal_rates

## Training

In [91]:
class TrainState(train_state.TrainState):
    batch_stats: Any

def create_train_state(module, rng, learning_rate, image_width, image_height, channels):
    x = (jnp.ones([1, image_width, image_height, channels]), jnp.ones([1, 1, 1, 1]), jnp.array([1]))
    variables = module.init(rng, x, True)
    params = variables['params']
    batch_stats = variables['batch_stats']
    tx = optax.adam(learning_rate)
    train_state = TrainState.create(
        apply_fn=module.apply,
        params=params,
        tx=tx,
        batch_stats=batch_stats
    )
    return train_state

@jax.jit
def train_step(state, images, labels, parent_key, max_signal_rate, min_signal_rate):
    noise_key, diffusion_time_key = jax.random.split(parent_key, 2)
    batch_size = len(images)

    def loss_fn(params):
        noises = jax.random.normal(
            noise_key, (batch_size, images.shape[1], images.shape[2], channels)
        )
        diffusion_times = jax.random.uniform(diffusion_time_key, (batch_size, 1, 1, 1))
        noise_rates, signal_rates = diffusion_schedule(
            diffusion_times, max_signal_rate, min_signal_rate
        )
        noisy_images = signal_rates * images + noise_rates * noises

        pred_noises, updates = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
             [noisy_images, noise_rates**2, labels],
             train=True,
             mutable=['batch_stats']
        )

        loss = jnp.mean((pred_noises - noises)**2)
        return loss, updates

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, updates), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates['batch_stats'])
    return loss, state

In [92]:
init_rng = jax.random.PRNGKey(0)
model = DDIM(channels, num_classes, widths, block_depth)
state = create_train_state(model, init_rng, learning_rate, image_width, image_height, channels)
del init_rng

In [93]:
steps_per_epoch = x_train.shape[0] // batch_size

losses = []
for epoch in range(epochs):
    losses_this_epoch = []
    for step in tqdm(range(steps_per_epoch)):
        images = x_train[step]
        labels = y_train[step]

        if images.shape[0] != batch_size:
            continue

        train_step_key = jax.random.PRNGKey(epoch * steps_per_epoch + step)
        loss, state = train_step(
            state,
            images,
            labels,
            train_step_key,
            max_signal_rate,
            min_signal_rate
        )
        losses_this_epoch.append(loss)

    average_loss = sum(losses_this_epoch) / len(losses_this_epoch)
    losses.append(average_loss)
    print(f'Epoch {epoch + 1} loss: {average_loss}')

  0%|          | 0/937 [00:00<?, ?it/s]


TypeError: ignored