<a href="https://colab.research.google.com/github/lb-97/GenerativeAI-DDIM-MNIST/blob/main/DDIM-MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import math
from PIL import Image

import tensorflow as tf
from tensorflow import keras, einsum
from tensorflow.keras import Model, Sequential
from tensorflow.keras import layers
# import tensorflow.keras.layers as nn
# import tensorflow_addons as tfa
# import tensorflow_datasets as tfds

# from einops import rearrange
# from einops.layers.tensorflow import Rearrange
# from functools import partial
# from inspect import isfunction

In [None]:
# Number of steps in the forward (noising) diffusion process.
timesteps = 200

# Fixed linear beta schedule: beta[t] is the noise variance added at step t.
beta = np.linspace(0.0001, 0.02, timesteps)

# Quantities for the closed-form forward process
#   x_t = sqrt(alpha_bar[t]) * x_0 + sqrt(1 - alpha_bar[t]) * eps
# (the reparameterization trick), computed in NumPy's default float64 and
# cast to float32 only when wrapped as TensorFlow constants below.
alpha = 1 - beta
alpha_bar = np.cumprod(alpha, 0)
sqrt_alpha_bar = np.sqrt(alpha_bar)
sqrt_one_minus_alpha_bar = np.sqrt(1 - alpha_bar)

# Every schedule tensor is float32. TensorFlow does not promote dtypes
# implicitly, so a float64 constant here would raise as soon as it is combined
# with the float32 images / network outputs.
alpha = tf.constant(alpha, dtype=tf.float32)
alpha_bar = tf.constant(alpha_bar, dtype=tf.float32)
sqrt_alpha_bar = tf.constant(sqrt_alpha_bar, dtype=tf.float32)
sqrt_one_minus_alpha_bar = tf.constant(sqrt_one_minus_alpha_bar, dtype=tf.float32)

In [None]:
# Display the dtype of the schedule constant as the cell output; it was created
# with dtype=tf.float32 above.
(sqrt_alpha_bar.dtype)

tf.float32

In [None]:
def kernel_init(scale):
    """Return a uniform, fan-averaged ``VarianceScaling`` kernel initializer.

    Args:
        scale: Requested variance scale. It is floored at ``1e-10`` so that a
            requested scale of ``0.0`` still yields a valid (near-zero)
            initializer.

    Returns:
        A ``keras.initializers.VarianceScaling`` instance.
    """
    floored_scale = max(scale, 1e-10)
    initializer = keras.initializers.VarianceScaling(
        scale=floored_scale,
        mode="fan_avg",
        distribution="uniform",
    )
    return initializer

class AttentionBlock(layers.Layer):
    """Self-attention across all spatial positions of a feature map.

    The input is group-normalized, projected to queries, keys and values with
    dense layers, and every position attends to every other position. The
    attended values go through a final dense projection and are added to the
    *normalized* input.

    Args:
        units: Number of units in the dense layers
        groups: Number of groups to be used for GroupNormalization layer
    """

    def __init__(self, units, groups=8, **kwargs):
        self.units = units
        self.groups = groups
        super().__init__(**kwargs)

        def make_dense(init_scale):
            # All projections share the same width; only the init scale varies.
            return layers.Dense(units, kernel_initializer=kernel_init(init_scale))

        self.norm = layers.GroupNormalization(groups=groups)
        self.query = make_dense(1.0)
        self.key = make_dense(1.0)
        self.value = make_dense(1.0)
        # The output projection is created with a requested scale of 0.0
        # (floored to a tiny positive value inside kernel_init).
        self.proj = make_dense(0.0)

    def call(self, inputs):
        """Apply attention to `inputs`, indexed as (batch, height, width, channels)."""
        dynamic_shape = tf.shape(inputs)
        n_batch = dynamic_shape[0]
        n_rows = dynamic_shape[1]
        n_cols = dynamic_shape[2]
        inv_sqrt_units = tf.cast(self.units, tf.float32) ** (-0.5)

        normed = self.norm(inputs)
        queries = self.query(normed)
        keys = self.key(normed)
        values = self.value(normed)

        # Scaled dot product between every query position (h, w) and every key
        # position (H, W).
        weights = tf.einsum("bhwc, bHWc->bhwHW", queries, keys) * inv_sqrt_units

        # Flatten the key positions so the softmax normalizes over all of them
        # at once, then restore the two key axes.
        weights = tf.reshape(weights, [n_batch, n_rows, n_cols, n_rows * n_cols])
        weights = tf.nn.softmax(weights, -1)
        weights = tf.reshape(weights, [n_batch, n_rows, n_cols, n_rows, n_cols])

        attended = tf.einsum("bhwHW,bHWc->bhwc", weights, values)
        projected = self.proj(attended)

        # The residual branch is the normalized input, not the raw one.
        return normed + projected

class TimeEmbedding(layers.Layer):
    """Sinusoidal embedding of integer timesteps.

    Frequencies decay geometrically from 1 down to 1/10000 over ``dim // 2``
    values. The output concatenates the sines and cosines of
    ``timestep * frequency`` and therefore has ``2 * (dim // 2)`` features.

    Args:
        dim: Requested embedding width.
    """

    def __init__(self, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.half_dim = dim // 2
        # Log-spacing between consecutive frequencies.
        log_step = math.log(10000) / (self.half_dim - 1)
        frequency_index = tf.range(self.half_dim, dtype=tf.float32)
        self.emb = tf.exp(frequency_index * -log_step)

    def call(self, inputs):
        """Embed a 1-D batch of timesteps into a (batch, features) tensor."""
        steps = tf.cast(inputs, dtype=tf.float32)
        # Outer product: one row per timestep, one column per frequency.
        angles = steps[:, None] * self.emb[None, :]
        return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)


def ResidualBlock(width, groups=8, activation_fn=keras.activations.swish):
    """Build a time-conditioned residual block.

    Args:
        width: Number of output channels.
        groups: Number of groups for the GroupNormalization layers.
        activation_fn: Activation applied to the time embedding and after each
            normalization.

    Returns:
        A function taking ``[x, t]`` (feature map and time embedding) and
        returning the block output with ``width`` channels.
    """

    def apply(inputs):
        features, time_emb = inputs

        # The shortcut needs a 1x1 conv only when the channel count changes.
        if features.shape[3] == width:
            shortcut = features
        else:
            shortcut = layers.Conv2D(
                width, kernel_size=1, kernel_initializer=kernel_init(1.0)
            )(features)

        # Project the time embedding to `width` channels and add singleton
        # spatial axes so it broadcasts over the feature map.
        time_bias = activation_fn(time_emb)
        time_bias = layers.Dense(width, kernel_initializer=kernel_init(1.0))(time_bias)
        time_bias = time_bias[:, None, None, :]

        # First norm -> activation -> conv stage.
        hidden = layers.GroupNormalization(groups=groups)(features)
        hidden = activation_fn(hidden)
        hidden = layers.Conv2D(
            width,
            kernel_size=3,
            padding="same",
            kernel_initializer=kernel_init(1.0),
        )(hidden)

        # Inject the timestep information.
        hidden = layers.Add()([hidden, time_bias])

        # Second stage; its conv is created with a requested init scale of 0.0.
        hidden = layers.GroupNormalization(groups=groups)(hidden)
        hidden = activation_fn(hidden)
        hidden = layers.Conv2D(
            width,
            kernel_size=3,
            padding="same",
            kernel_initializer=kernel_init(0.0),
        )(hidden)

        return layers.Add()([hidden, shortcut])

    return apply


def DownSample(width):
    """Build a downsampling step: a 3x3 convolution with stride 2.

    Args:
        width: Number of output channels.

    Returns:
        A function mapping a feature map to its strided-convolution output.
    """

    def apply(x):
        strided_conv = layers.Conv2D(
            filters=width,
            kernel_size=3,
            strides=2,
            padding="same",
            kernel_initializer=kernel_init(1.0),
        )
        return strided_conv(x)

    return apply


def UpSample(width, interpolation="nearest"):
    """Build an upsampling step: 2x resize followed by a 3x3 convolution.

    Args:
        width: Number of output channels of the convolution.
        interpolation: Interpolation mode passed to ``UpSampling2D``.

    Returns:
        A function mapping a feature map to its upsampled, convolved version.
    """

    def apply(x):
        enlarged = layers.UpSampling2D(size=2, interpolation=interpolation)(x)
        smoothing_conv = layers.Conv2D(
            filters=width,
            kernel_size=3,
            padding="same",
            kernel_initializer=kernel_init(1.0),
        )
        return smoothing_conv(enlarged)

    return apply


def TimeMLP(units, activation_fn=keras.activations.swish):
    """Build a two-layer MLP applied to the time embedding.

    Args:
        units: Width of both dense layers.
        activation_fn: Activation of the first dense layer; the second layer
            has no activation.

    Returns:
        A function mapping the time embedding to its transformed version.
    """

    def apply(inputs):
        hidden_layer = layers.Dense(
            units,
            activation=activation_fn,
            kernel_initializer=kernel_init(1.0),
        )
        output_layer = layers.Dense(units, kernel_initializer=kernel_init(1.0))
        return output_layer(hidden_layer(inputs))

    return apply




In [None]:
# Channel count of the stem convolution; the time embedding is 4x this wide.
first_conv_channels = 64


def build_model(
    img_size,
    img_channels,
    widths,
    has_attention,
    num_res_blocks=2,
    norm_groups=8,
    interpolation="nearest",
    activation_fn=keras.activations.swish,
    out_channels=1,
):
    """Build the time-conditioned U-Net.

    Args:
        img_size: Height and width of the (square) input images.
        img_channels: Number of channels of the input images.
        widths: Channel count for each resolution level, from highest
            resolution to lowest.
        has_attention: One flag per level; when true, an AttentionBlock
            follows every residual block of that level. Entries beyond
            ``len(widths)`` are not read.
        num_res_blocks: Residual blocks per level on the down path (the up
            path uses one more per level, to consume every skip connection).
        norm_groups: Number of groups for the GroupNormalization layers.
        interpolation: Interpolation mode used when upsampling.
        activation_fn: Activation used throughout the network.
        out_channels: Number of channels of the output tensor. Defaults to 1,
            the value that was previously hard-coded.

    Returns:
        A ``keras.Model`` named ``"unet"`` taking ``[image, timestep]`` and
        returning a tensor with ``out_channels`` channels.
    """
    image_input = layers.Input(
        shape=(img_size, img_size, img_channels), name="image_input"
    )
    time_input = keras.Input(shape=(), dtype=tf.int64, name="time_input")

    x = layers.Conv2D(
        first_conv_channels,
        kernel_size=(3, 3),
        padding="same",
        kernel_initializer=kernel_init(1.0),
    )(image_input)

    temb = TimeEmbedding(dim=first_conv_channels * 4)(time_input)
    temb = TimeMLP(units=first_conv_channels * 4, activation_fn=activation_fn)(temb)

    skips = [x]
    last_level = len(widths) - 1

    # DownBlock
    for i, level_width in enumerate(widths):
        for _ in range(num_res_blocks):
            x = ResidualBlock(
                level_width, groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(level_width, groups=norm_groups)(x)
            skips.append(x)

        # Downsample after every level except the last. This is decided by
        # position, not by comparing widths: a value comparison skips the
        # downsample for any earlier level that happens to share the last
        # level's width, which leaves the up path with more upsamples and
        # more skip pops than the down path produced.
        if i != last_level:
            x = DownSample(level_width)(x)
            skips.append(x)

    # MiddleBlock
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )
    x = AttentionBlock(widths[-1], groups=norm_groups)(x)
    x = ResidualBlock(widths[-1], groups=norm_groups, activation_fn=activation_fn)(
        [x, temb]
    )

    # UpBlock: mirrors the down path, consuming one skip per residual block.
    for i in reversed(range(len(widths))):
        for _ in range(num_res_blocks + 1):
            x = layers.Concatenate(axis=-1)([x, skips.pop()])
            x = ResidualBlock(
                widths[i], groups=norm_groups, activation_fn=activation_fn
            )([x, temb])
            if has_attention[i]:
                x = AttentionBlock(widths[i], groups=norm_groups)(x)

        if i != 0:
            x = UpSample(widths[i], interpolation=interpolation)(x)

    # End block
    x = layers.GroupNormalization(groups=norm_groups)(x)
    x = activation_fn(x)
    x = layers.Conv2D(
        out_channels, (3, 3), padding="same", kernel_initializer=kernel_init(0.0)
    )(x)
    return keras.Model([image_input, time_input], x, name="unet")



In [None]:
# Smoke test: build a small U-Net and print its layer summary.
# `has_attention` carries exactly one flag per entry of `widths`; build_model
# only reads the first len(widths) flags, so a fourth entry would be ignored.
build_model(
    img_size=8,
    img_channels=16,
    widths=[64, 128, 256],
    has_attention=[False, False, True],
).summary()

Model: "unet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 time_input (InputLayer)        [(None,)]            0           []                               
                                                                                                  
 time_embedding_16 (TimeEmbeddi  (None, 256)         0           ['time_input[0][0]']             
 ng)                                                                                              
                                                                                                  
 image_input (InputLayer)       [(None, 8, 8, 16)]   0           []                               
                                                                                                  
 dense_370 (Dense)              (None, 256)          65792       ['time_embedding_16[0][0]']   

In [None]:
class DiffusionModel(keras.Model):
    """Noise-prediction diffusion model wrapping the U-Net from `build_model`.

    Training draws a random timestep per image, noises the image with the
    closed-form forward process and regresses the network output onto the
    injected noise using the compiled loss. `generate` runs the reverse process
    from pure Gaussian noise.

    Relies on the module-level schedule (`timesteps`, `alpha`,
    `sqrt_alpha_bar`, `sqrt_one_minus_alpha_bar`).
    """

    def __init__(self):
        super().__init__()
        self.network = build_model(
            img_size=28,
            img_channels=1,
            widths=[64, 128, 256],
            has_attention=[False, False, True, True],
        )
        self.timesteps = timesteps

    def extract(self, t):
        """Unimplemented placeholder; returns None and is not called by this class."""
        return

    def train_step(self, images):
        """Run one optimization step on a batch of clean images.

        Args:
            images: Batch of clean images. A one-element tuple/list wrapping
                the batch is also accepted.

        Returns:
            Dict with the scalar training loss under the key ``"loss"``.
        """
        # Some data pipelines hand train_step a tuple even when only inputs
        # are provided; unwrap it so both forms work.
        if isinstance(images, (tuple, list)):
            images = images[0]
        # The schedule constants and the network are float32.
        images = tf.cast(images, tf.float32)

        # Use the dynamic batch size: the static shape is None whenever the
        # batch dimension is not known at trace time.
        batch_size = tf.shape(images)[0]

        # 1. Draw one random timestep per image.
        t = tf.random.uniform(
            shape=[batch_size], minval=0, maxval=self.timesteps, dtype=tf.int64
        )

        with tf.GradientTape() as tape:
            # 2. Sample the noise to be added to the images in the batch.
            noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)

            # 3. Diffuse the images: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
            sqb = tf.reshape(tf.gather(sqrt_alpha_bar, t), [-1, 1, 1, 1])
            osqb = tf.reshape(tf.gather(sqrt_one_minus_alpha_bar, t), [-1, 1, 1, 1])
            noisy_img = sqb * images + osqb * noise

            # 4. Predict the noise from the diffused images and their timesteps.
            pred_noise = self.network([noisy_img, t], training=True)

            # 5. Compare the prediction with the noise actually injected.
            lo = self.loss(noise, pred_noise)

        # 6. Backpropagate and update the weights of the network.
        gradients = tape.gradient(lo, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        return {"loss": lo}

    def generate(self, shape):
        """Sample images by running the reverse diffusion process.

        Implements ancestral sampling:
            x_{t-1} = (x_t - (1 - a_t) / sqrt(1 - abar_t) * eps_theta(x_t, t)) / sqrt(a_t)
                      + sqrt(1 - a_t) * z
        with z ~ N(0, I) for t > 0 and no noise on the final step.

        Args:
            shape: Output shape ``(batch, height, width, channels)``.

        Returns:
            float32 tensor of generated images with the requested shape.
        """
        batch = shape[0]
        noisy_image = tf.random.normal(shape=shape, dtype=tf.float32)

        for i in range(self.timesteps - 1, -1, -1):
            # The network takes [images, timesteps] with one timestep per image.
            t = tf.fill([batch], tf.constant(i, dtype=tf.int64))
            pred_noise = self.network([noisy_image, t], training=False)

            # Scalar schedule values broadcast over the whole batch; the cast
            # keeps the arithmetic in float32 whatever dtype the constants use.
            a = tf.cast(tf.gather(alpha, i), tf.float32)
            somab = tf.cast(tf.gather(sqrt_one_minus_alpha_bar, i), tf.float32)

            mean = (noisy_image - ((1.0 - a) / somab) * pred_noise) / tf.sqrt(a)

            if i > 0:
                # sigma_t = sqrt(beta_t) = sqrt(1 - alpha_t); unscaled
                # unit-variance noise would swamp the signal at every step.
                noise = tf.random.normal(shape=shape, dtype=tf.float32)
                noisy_image = mean + tf.sqrt(1.0 - a) * noise
            else:
                noisy_image = mean

        return noisy_image

In [None]:
# Load MNIST; the labels are not needed for unconditional generation.
(x_train, x_train_labels), (x_test, _) = keras.datasets.mnist.load_data()

# Add a trailing channel axis: (N, H, W) -> (N, H, W, 1).
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

# Rescale to be centered on zero (range [-0.5, 0.5], assuming 8-bit pixel
# values). Converting to float32 first keeps the arrays at half the memory of
# the float64 default and matches the dtype of the model and noise schedule.
x_train_scaled = x_train.astype(np.float32) / 255.0 - 0.5
x_test_scaled = x_test.astype(np.float32) / 255.0 - 0.5

# Variance of the training pixels after scaling to [0, 1].
data_variance = np.var(x_train / 255.0)

In [None]:
# List the devices TensorFlow can see, using the public tf.config API rather
# than the private tensorflow.python.client.device_lib module.
tf.config.list_physical_devices()

[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 }
 incarnation: 4033244786859197162
 xla_global_id: -1]

In [None]:
# Train on the first GPU when one is visible, otherwise on the CPU, instead of
# requesting a GPU unconditionally.
train_device = (
    "/device:GPU:0" if tf.config.list_physical_devices("GPU") else "/device:CPU:0"
)

with tf.device(train_device):
    model = DiffusionModel()

    # Compile the model: MSE between the injected and the predicted noise.
    model.compile(
        loss=keras.losses.MeanSquaredError(),
        optimizer=keras.optimizers.Adam(learning_rate=0.0002),
    )

    # Train the model
    model.fit(
        x_train_scaled,
        epochs=10,
        batch_size=1,
    )