In [2]:
import math
import dataclasses
import numpy as np
import matplotlib.pyplot

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

import hydra
from omegaconf import OmegaConf


In [3]:
%%writefile ./configs/ddpm.yaml
# DDPM hyperparameters; loaded with OmegaConf and exposed to the notebook as the `hp` dict.

# --- Training ---
batch_size: 32          # examples per batch in the tf.data training pipeline
num_epochs: 1           # NOTE(review): not read by any cell shown here; presumably used by a training loop
total_timesteps: 1000   # presumably the number of diffusion steps (GaussianDiffusion `timesteps`) - confirm
learning_rate: 0.0002   # NOTE(review): presumably the optimizer learning rate; not read by any cell shown here

# --- Images ---
img_size: 64            # images are center-cropped and resized to img_size x img_size
img_channels: 3
clip_min: -1.0          # pixel values are rescaled to [-1, 1] and clipped to [clip_min, clip_max]
clip_max: 1.0

# --- U-Net architecture ---
first_conv_channels: 64 # channels of the stem conv; level widths = first_conv_channels * channel_multiplier
channel_multiplier:     # one entry per resolution level
  - 1
  - 2
  - 4
  - 8
has_attention:          # one flag per entry of channel_multiplier; presumably passed to build_model(has_attention=...)
  - false
  - false
  - true
  - true
num_res_blocks: 2       # presumably residual blocks per level (build_model `num_res_blocks`) - confirm

# --- Data ---
dataset_name: oxford_flowers102   # TFDS dataset name
splits:                           # TFDS splits to load; the loading cell unpacks exactly one dataset
  - train


Overwriting ./configs/ddpm.yaml


In [4]:
cfg = OmegaConf.load("./configs/ddpm.yaml")
# Derived setting: channel width of each U-Net resolution level.
cfg.widths = [mult * cfg.first_conv_channels for mult in cfg.channel_multiplier]
# Plain Python (dict/list) copy of the config; the rest of the notebook indexes it as hp['...'].
hp = OmegaConf.to_object(cfg)

In [5]:
def augment(img):
    """Data augmentation: randomly mirror the image along its width axis.

    Args:
        img: Image tensor (height, width, channels).

    Returns:
        The image, flipped left-right with probability 1/2.
    """
    flipped = tf.image.random_flip_left_right(img)
    return flipped

def resize_and_rescale(img, size):
    """Center-crop to a square, resize, and rescale pixel values for the diffusion model.

    Args:
        img: Image tensor indexed as (height, width, channels); presumably uint8
            in [0, 255] as produced by TFDS - confirm for other datasets.
        size: Target (height, width) passed to ``tf.image.resize``.

    Returns:
        A float32 image of spatial size ``size``. Values are mapped from [0, 255]
        to [-1, 1] and then clipped to [hp['clip_min'], hp['clip_max']].
    """
    dims = tf.shape(img)
    rows, cols = dims[0], dims[1]
    side = tf.minimum(rows, cols)

    # Largest centered square crop.
    top = (rows - side) // 2
    left = (cols - side) // 2
    square = tf.image.crop_to_bounding_box(img, top, left, side, side)

    resized = tf.image.resize(tf.cast(square, dtype=tf.float32), size=size, antialias=True)

    # [0, 255] -> [-1, 1], then clamp to the configured range.
    scaled = resized / 127.5 - 1.0
    return tf.clip_by_value(scaled, hp['clip_min'], hp['clip_max'])

def train_preprocessing(x):
    """Map one TFDS example dict to a normalized, randomly flipped training image.

    Args:
        x: TFDS example; only its ``'image'`` entry is used.

    Returns:
        The image resized to (hp['img_size'], hp['img_size']), rescaled, then augmented.
    """
    target_size = (hp['img_size'], hp['img_size'])
    return augment(resize_and_rescale(x['image'], size=target_size))

# Load the single configured split; unpacking `(ds,)` requires hp['splits'] to list exactly one split.
(ds,) = tfds.load(hp['dataset_name'], split=hp['splits'], with_info=False, shuffle_files=True)

# Shuffle individual examples *before* batching. Shuffling after .batch() only
# permutes whole batches, so the same images would always share a batch.
# The buffer holds 32 batches' worth of preprocessed (small, resized) images.
train_ds = (
    ds.map(train_preprocessing, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(hp['batch_size'] * 32)
    .batch(hp['batch_size'], drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)


2024-08-21 20:42:21.277811: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2024-08-21 20:42:21.277849: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-08-21 20:42:21.277856: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-08-21 20:42:21.277875: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-08-21 20:42:21.277896: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [6]:
class GaussianDiffusion:
    """DDPM Gaussian diffusion utilities: forward noising and one reverse sampling step.

    Builds a linear beta schedule of ``timesteps`` steps. It precomputes every
    per-step coefficient needed by the forward process q(x_t | x_0), the posterior
    q(x_{t-1} | x_t, x_0) and the reverse step. Coefficients are computed in
    float64 NumPy and stored as float32 ``tf.constant`` vectors of length ``timesteps``.

    Args:
        beta_start: Noise variance at the first step.
        beta_end: Noise variance at the last step.
        timesteps: Number of diffusion steps.
        clip_min: Lower bound used when clipping the reconstructed x_0.
        clip_max: Upper bound used when clipping the reconstructed x_0.
    """

    def __init__(self, beta_start=1e-4, beta_end=0.02, timesteps=1000, clip_min=-1.0, clip_max=1.0):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        self.clip_min = clip_min
        self.clip_max = clip_max
        self.num_timesteps = int(timesteps)

        # Linear schedule in float64 so the cumulative product stays accurate.
        betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with a leading 1.0 so that step 0 is defined.
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        def as_tf(values):
            return tf.constant(values, dtype=tf.float32)

        self.betas = as_tf(betas)
        self.alphas_cumprod = as_tf(alphas_cumprod)
        self.alphas_cumprod_prev = as_tf(alphas_cumprod_prev)

        # Forward process q(x_t | x_0).
        self.sqrt_alphas_cumprod = as_tf(np.sqrt(alphas_cumprod))
        self.sqrt_one_minus_alphas_cumprod = as_tf(np.sqrt(1.0 - alphas_cumprod))
        self.log_one_minus_alphas_cumprod = as_tf(np.log(1.0 - alphas_cumprod))

        # Recovering x_0 from x_t and a noise estimate.
        self.sqrt_recip_alphas_cumprod = as_tf(np.sqrt(1.0 / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod = as_tf(np.sqrt(1.0 / alphas_cumprod - 1))

        # Posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.posterior_variance = as_tf(posterior_variance)
        # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1); floor it before log to avoid -inf.
        self.posterior_log_variance_clipped = as_tf(np.log(np.maximum(posterior_variance, 1e-20)))
        self.posterior_mean_coef1 = as_tf(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.posterior_mean_coef2 = as_tf((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod))

    def _extract(self, a, t, x_shape):
        """Gather ``a[t]`` per batch element and reshape to (batch, 1, 1, 1) for broadcasting.

        Args:
            a: 1-D coefficient table indexed by timestep.
            t: Integer timestep tensor with one entry per batch element.
            x_shape: Shape (e.g. ``tf.shape(x)``) of the 4-D tensor the result multiplies.
        """
        batch_size = x_shape[0]
        out = tf.gather(a, t)
        return tf.reshape(out, [batch_size, 1, 1, 1])

    def q_mean_variance(self, x_start, t):
        """Mean, variance and log-variance of q(x_t | x_0).

        Returns:
            ``(sqrt(alpha_bar_t) * x_0, 1 - alpha_bar_t, log(1 - alpha_bar_t))``, where the
            latter two have shape (batch, 1, 1, 1).
        """
        x_start_shape = tf.shape(x_start)
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
        log_variance = self._extract(self.log_one_minus_alphas_cumprod, t, x_start_shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise):
        """Sample x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise."""
        x_start_shape = tf.shape(x_start)
        return (
            self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape) * noise
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Invert q_sample: x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * noise."""
        x_t_shape = tf.shape(x_t)
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        x_t_shape = tf.shape(x_t)
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
        posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t_shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, pred_noise, x, t, clip_denoised=True):
        """Reverse-step distribution parameters, given the model's noise prediction.

        Reconstructs x_0 from ``pred_noise`` and optionally clips it to
        [clip_min, clip_max]. Returns the posterior mean, variance and log-variance.
        """
        x_recon = self.predict_start_from_noise(x, t, pred_noise)
        if clip_denoised:
            x_recon = tf.clip_by_value(x_recon, self.clip_min, self.clip_max)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, pred_noise, x, t, clip_denoised=True):
        """Draw x_{t-1} from p(x_{t-1} | x_t); no noise is added for elements where t == 0."""
        model_mean, _, model_log_variance = self.p_mean_variance(pred_noise, x, t, clip_denoised)
        # Dynamic shape so this also works when the batch size is unknown (e.g. inside tf.function).
        noise = tf.random.normal(shape=tf.shape(x), dtype=x.dtype)
        nonzero_mask = tf.reshape(tf.cast(tf.not_equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1, 1])
        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
        

In [None]:
def kernel_init(scale):
    """Return a fan-avg, uniform VarianceScaling initializer.

    ``scale`` is floored at 1e-10 because VarianceScaling requires a positive scale.
    ``kernel_init(0.0)`` therefore yields a near-zero initializer.
    """
    return keras.initializers.VarianceScaling(
        scale=max(scale, 1e-10),
        mode='fan_avg',
        distribution='uniform',
    )

class AttentionBlock(layers.Layer):
    """Single-head spatial self-attention over all H*W positions, with a residual connection.

    Args:
        units: Width of the query/key/value/output projections. The residual add
            requires this to equal the input's channel count.
        groups: Number of groups for the GroupNormalization applied before attention.
    """

    def __init__(self, units, groups=8, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.groups = groups

        self.norm = layers.GroupNormalization(groups=groups)
        self.query = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.key = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        self.value = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        # Near-zero init so the block starts close to the identity mapping.
        self.proj = layers.Dense(units, kernel_initializer=kernel_init(0.0))

    def call(self, inputs):
        shape = tf.shape(inputs)
        batch_size, height, width = shape[0], shape[1], shape[2]
        scale = tf.cast(self.units, tf.float32) ** (-0.5)

        # Normalize into a separate tensor so the residual below bypasses the norm.
        h = self.norm(inputs)
        q = self.query(h)
        k = self.key(h)
        v = self.value(h)

        # Scores between every query position (h, w) and key position (H, W);
        # softmax over the flattened H*W key axis.
        attn_score = tf.einsum("bhwc,bHWc->bhwHW", q, k) * scale
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height * width])
        attn_score = tf.nn.softmax(attn_score, -1)
        attn_score = tf.reshape(attn_score, [batch_size, height, width, height, width])

        proj = tf.einsum("bhwHW,bHWc->bhwc", attn_score, v)
        proj = self.proj(proj)
        return inputs + proj
    


In [None]:
class TimeEmbedding(layers.Layer):
    """Sinusoidal timestep embedding.

    Each timestep is multiplied by ``dim // 2`` geometrically spaced frequencies.
    The output is the sines followed by the cosines of those products. When
    ``dim`` is odd, a trailing zero column is appended so the output width is
    exactly ``dim``. Expects a rank-1 tensor of timesteps (one per batch element).

    Args:
        dim: Output embedding width; must be >= 4 so the frequency spacing
            ``log(10000) / (dim // 2 - 1)`` is defined.

    Raises:
        ValueError: If ``dim < 4``.
    """

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        if dim < 4:
            raise ValueError(f"TimeEmbedding dim must be >= 4, got {dim}")
        self.dim = dim
        self.half_dim = dim // 2
        log_freq_step = math.log(10000) / (self.half_dim - 1)
        # Frequencies 1, 10000^(-1/(half_dim-1)), ..., 1/10000.
        self.emb = tf.exp(tf.range(self.half_dim, dtype=tf.float32) * -log_freq_step)

    def call(self, inputs):
        inputs = tf.cast(inputs, dtype=tf.float32)
        emb = inputs[:, None] * self.emb[None, :]
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        if self.dim % 2 == 1:
            # sin/cos halves give 2 * (dim // 2) columns; pad one zero column for odd dim.
            emb = tf.pad(emb, [[0, 0], [0, 1]])
        return emb

In [None]:
def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
    """Factory for a time-conditioned residual block.

    The returned ``apply([x, t])`` computes
    GroupNorm -> act -> 3x3 conv -> + projected time embedding ->
    GroupNorm -> act -> 3x3 conv (near-zero init) -> + residual.

    Args:
        width: Output channel count.
        groups: Groups for both GroupNormalization layers.
        activation_fn: Nonlinearity used before each conv and on the time embedding.

    Returns:
        A function taking ``[x, t]`` (feature map with channels last, time embedding
        of shape (batch, features)) and returning a feature map with ``width`` channels.
    """
    def apply(inputs):
        x, t = inputs
        input_width = x.shape[3]

        # Match channel count on the skip path with a 1x1 conv only when needed.
        if input_width == width:
            residual = x
        else:
            residual = layers.Conv2D(width, kernel_size=1, kernel_initializer=kernel_init(1.0))(x)

        # Project the time embedding to `width` and broadcast it over height and width.
        temb = activation_fn(t)
        temb = layers.Dense(width, kernel_initializer=kernel_init(1.0))(temb)[:, None, None, :]

        x = layers.GroupNormalization(groups)(x)
        x = activation_fn(x)
        x = layers.Conv2D(width, kernel_size=3, padding='same', kernel_initializer=kernel_init(1.0))(x)

        x = layers.Add()([x, temb])
        x = layers.GroupNormalization(groups)(x)
        x = activation_fn(x)

        # Near-zero init so the block starts close to returning `residual`.
        x = layers.Conv2D(width, kernel_size=3, padding='same', kernel_initializer=kernel_init(0.0))(x)
        x = layers.Add()([x, residual])
        return x

    return apply


In [None]:
def DownSample(width):
    """Factory for a 2x spatial downsampling step.

    The returned function applies a 3x3, stride-2, 'same'-padded convolution.
    This halves height and width (rounding up) and produces ``width`` channels.
    """
    def apply(x):
        x = layers.Conv2D(width, kernel_size=3, strides=2, padding='same', kernel_initializer=kernel_init(1.0))(x)
        return x

    return apply

def UpSample(width, interpolation='nearest'):
    """Factory for a 2x spatial upsampling step.

    The returned function doubles height and width with ``UpSampling2D``
    (using ``interpolation``). It then applies a 3x3 'same'-padded conv with
    ``width`` output channels.
    """
    def apply(x):
        upsampled = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
        conv = layers.Conv2D(width, kernel_size=3, padding='same', kernel_initializer=kernel_init(1.0))
        return conv(upsampled)

    return apply

In [None]:
def TimeMLP(units, activation_fn=keras.activations.swish):
    """Factory for a two-layer MLP applied to the time embedding.

    The returned function maps its input through Dense(units, activation_fn)
    followed by a linear Dense(units).
    """
    def apply(inputs):
        hidden = layers.Dense(units, activation=activation_fn, kernel_initializer=kernel_init(1.0))(inputs)
        return layers.Dense(units, kernel_initializer=kernel_init(1.0))(hidden)

    return apply

def build_model(img_size, img_channels, widths, has_attention, num_res_blocks=2, norm_groups=8, interpolation='nearest', activation_fn=keras.activations.swish, first_conv_channels=None):
    """Build the DDPM noise-prediction U-Net as a functional Keras model.

    Args:
        img_size: Height and width of the (square) input images.
        img_channels: Channels of the input images; the predicted noise has the same count.
        widths: Channel width of each resolution level, from highest to lowest resolution.
        has_attention: One flag per level; True adds an AttentionBlock after each residual block.
        num_res_blocks: Residual blocks per encoder level (the decoder uses one more per level).
        norm_groups: Groups for every GroupNormalization layer.
        interpolation: Interpolation mode passed to UpSampling2D.
        activation_fn: Nonlinearity used throughout.
        first_conv_channels: Channels of the stem conv. The time-embedding width is
            4x this value. Defaults to ``hp['first_conv_channels']``.

    Returns:
        ``keras.Model`` named 'unet'. It takes ``[image (B, img_size, img_size, img_channels),
        timestep (B,) int64]`` and returns a tensor with the image's shape.
    """
    if first_conv_channels is None:
        first_conv_channels = hp['first_conv_channels']

    image_input = layers.Input((img_size, img_size, img_channels), name='image_input')
    time_input = keras.Input(shape=(), dtype=tf.int64, name='time_input')

    x = layers.Conv2D(first_conv_channels, kernel_size=(3, 3), padding='same', kernel_initializer=kernel_init(1.0))(image_input)

    temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
    temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)

    skips = [x]

    # Encoder: residual (+ attention) blocks per level, downsampling between levels.
    last_level = len(widths) - 1
    for i in range(len(widths)):
        for _ in range(num_res_blocks):
            x = ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)
            skips.append(x)

        # Compare by index (mirrors `i != 0` in the decoder) so repeated widths still downsample.
        if i != last_level:
            x = DownSample(widths[i])(x)
            skips.append(x)

    # Bottleneck.
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])
    x = AttentionBlock(widths[-1], groups=norm_groups)(x)
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)([x, temb])

    # Decoder: consume skips in reverse order, upsampling between levels.
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = ResidualBlock(widths[i], groups=norm_groups, activation_fn=activation_fn)([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)

        if i != 0:
            x = UpSample(widths[i], interpolation=interpolation)(x)

    # Output head: norm -> activation -> conv back to image channels (near-zero init).
    x = layers.GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(img_channels, kernel_size=(3, 3), padding='same', kernel_initializer=kernel_init(0.0))(x)
    return keras.Model([image_input, time_input], x, name='unet')