<a href="https://colab.research.google.com/github/lrav35/diffusion/blob/main/OG_QuickDiffusionModel_TF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
from tqdm.auto import trange, tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

In [2]:
IMG_SIZE = 32            # model input is IMG_SIZE x IMG_SIZE x 3 (see make_model)
BATCH_SIZE = 64
FIG_SIZE = (9, 9)        # matplotlib figure size used by show_examples
category = 5  # CIFAR-10 label 5 = dog
timesteps = 32  # number of steps in the noise schedule
num_images = 36  # images shown per example grid (6 x 6)
SEED = 42
# Seed NumPy's global RNG so the draws in generate_ts / forward_noise are
# reproducible on Restart & Run All.
# NOTE(review): this does not seed TensorFlow weight initialisation; add
# tf.random.set_seed(SEED) as well if model-level reproducibility is needed.
np.random.seed(SEED)
# Noise level per step, decreasing linearly: time_bar[0] == 1.0 (pure noise
# under forward_noise) down to time_bar[timesteps] == 0.0 (clean image).
time_bar = 1 - np.linspace(0, 1., timesteps + 1)

In [3]:
# Load CIFAR-10 (Keras returns it already split into train/test) and keep only
# the training images whose label equals `category` from the config cell.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train[y_train.squeeze() == category]
assert len(X_train) > 0, f'no training images with label {category}'
# Cast to float32 before scaling: Keras layers default to float32, and the
# implicit float64 produced by `uint8 / 127.5` would double the memory used.
# Per the Keras docs the pixels are uint8 in [0, 255], so this maps to [-1, 1].
X_train = X_train.astype(np.float32) / 127.5 - 1.0
assert X_train.min() >= -1.0 and X_train.max() <= 1.0

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
# functions for viewing images
def show_examples(data, num_images=num_images):
  """Display the first ``num_images`` images of ``data`` in a grid.

  Each image is passed through ``scale`` (min-max to [0, 1]) before
  ``plt.imshow`` so data in any range, e.g. the [-1, 1] ``X_train``, renders.

  Args:
    data: indexable collection of images (e.g. ``X_train``).
    num_images: maximum number of images to show. The default is the
      notebook-level ``num_images``, captured when this cell was executed.
      It is clipped to ``len(data)`` so short inputs do not raise IndexError.
  """
  count = min(num_images, len(data))
  # Use ceil, not int(): a non-square count (e.g. 10) would otherwise get a
  # 3x3 grid, and plt.subplot(3, 3, 10) raises ValueError.
  cols = max(1, int(np.ceil(np.sqrt(count))))
  rows = max(1, int(np.ceil(count / cols)))
  plt.figure(figsize=FIG_SIZE)
  for i in range(count):
    plt.subplot(rows, cols, i + 1)
    plt.imshow(scale(data[i]))
    plt.axis('off')

def scale(img):
    """Min-max rescale ``img`` to [0, 1] for display with ``plt.imshow``.

    Returns a new array and leaves ``img`` unmodified. A constant image (zero
    range) is returned as all zeros. Dividing by the zero range instead would
    be 0/0 and produce an all-NaN image.
    """
    shifted = img - img.min()
    peak = shifted.max()
    if peak > 0:
        return shifted / peak
    return shifted

In [5]:
# util functions
def generate_ts(num):
  """Draw ``num`` random integer timesteps uniformly from ``[0, timesteps)``.

  The upper bound is exclusive, so every returned ``t`` has ``t + 1`` as a
  valid index into ``time_bar``, which has ``timesteps + 1`` entries.
  """
  upper = timesteps  # exclusive bound for np.random.randint
  return np.random.randint(low=0, high=upper, size=num)

In [6]:
# forward noise
def forward_noise(img, ts, noise=None, schedule=None):
  """Blend images with Gaussian noise at two adjacent steps of a schedule.

  Returns the noisy image at step ``ts`` and at step ``ts + 1``. With the
  default schedule ``time_bar`` (decreasing from 1.0 to 0.0), step ``ts + 1``
  is the less noisy of the pair. Both outputs share the same noise sample.

  Args:
    img: image array. The leading axis is the batch when ``ts`` is an array
      (e.g. (batch, H, W, C)); it can be a single image when ``ts`` is a scalar.
    ts: int or integer array-like (lists accepted) of step indices. Each
      index must leave ``ts + 1`` inside ``schedule``.
    noise: optional noise with the same shape as ``img``. When omitted it is
      drawn from a standard normal; passing it makes the call deterministic.
    schedule: optional 1-D array of noise levels. Defaults to the notebook's
      ``time_bar``.

  Returns:
    Tuple ``(current_img, next_img)`` computed as
    ``img * (1 - level) + noise * level`` for the two levels.

  Raises:
    ValueError: if ``noise`` is given with a shape different from ``img``.
    IndexError: (from NumPy) if ``ts + 1`` runs past the end of ``schedule``.
  """
  img = np.asarray(img)
  ts = np.asarray(ts)  # a plain list would raise TypeError on `ts + 1`
  schedule = time_bar if schedule is None else np.asarray(schedule)

  current_time = schedule[ts]
  next_time = schedule[ts + 1]

  if noise is None:
    noise = np.random.normal(size=img.shape)  # noise mask
  else:
    noise = np.asarray(noise)
    if noise.shape != img.shape:
      raise ValueError(
          f'noise shape {noise.shape} does not match img shape {img.shape}')

  # Reshape the per-sample levels to (-1, 1, ..., 1) so they broadcast over
  # every non-leading axis of img.
  new_shape = (-1,) + (1,) * (img.ndim - 1)
  current_time = np.reshape(current_time, new_shape)
  next_time = np.reshape(next_time, new_shape)

  current_img = img * (1 - current_time) + noise * current_time
  next_img = img * (1 - next_time) + noise * next_time
  return current_img, next_img

In [7]:
def block(x_img, x_ts):
  """Conv block whose features are modulated by a timestep embedding.

  Computes ``relu(LayerNorm(conv_b(x_img) + relu(conv_a(x_img)) * t))``. Here
  ``t = relu(Dense(128)(x_ts))`` is reshaped to (1, 1, 128), so it scales each
  of the 128 channels at every spatial position.

  Args:
    x_img: rank-4 Keras tensor (batch, H, W, C), as required by Conv2D.
    x_ts: timestep embedding tensor. Assumed to be (batch, features) so the
      Dense output reshapes to (1, 1, 128) — TODO confirm for other callers.

  Returns:
    Tensor with the spatial size of ``x_img`` ('same' padding) and 128 channels.
  """
  def pipe(tensor, *layers):
    # Feed `tensor` through the already-constructed `layers`, left to right.
    for layer in layers:
      tensor = layer(tensor)
    return tensor

  channels = 128

  # Image branch that will be gated by the timestep.
  features = pipe(
      x_img,
      keras.layers.Conv2D(channels, kernel_size=3, padding='same'),
      keras.layers.Activation('relu'),
  )

  # Per-channel time gate, broadcast over the H and W axes.
  gate = pipe(
      x_ts,
      keras.layers.Dense(channels),
      keras.layers.Activation('relu'),
      keras.layers.Reshape((1, 1, channels)),
  )
  gated = features * gate

  # Parallel conv on the raw input, summed with the gated branch.
  merged = keras.layers.Conv2D(channels, kernel_size=3, padding='same')(x_img) + gated

  return pipe(
      merged,
      keras.layers.LayerNormalization(),
      keras.layers.Activation('relu'),
  )

In [15]:
def mlp(x, x_ts):
    """Fully connected bottleneck that mixes image features with the timestep.

    Flattens ``x`` and concatenates the timestep embedding ``x_ts``. It then
    applies two Dense -> LayerNorm -> ReLU stages (128, then 4*4*32 units) and
    reshapes the result into a (4, 4, 32) feature map.

    Args:
      x: feature map to flatten (in make_model, the 4x4 bottleneck features).
      x_ts: timestep embedding. Assumed rank-2 (batch, features) so it can be
        concatenated with the flattened features — TODO confirm.

    Returns:
      Keras tensor of shape (batch, 4, 4, 32).
    """
    hidden = keras.layers.Flatten()(x)
    hidden = keras.layers.Concatenate()([hidden, x_ts])
    for units in (128, 4 * 4 * 32):
        hidden = keras.layers.Dense(units)(hidden)
        hidden = keras.layers.LayerNormalization()(hidden)
        hidden = keras.layers.Activation('relu')(hidden)
    return keras.layers.Reshape((4, 4, 32))(hidden)

In [28]:
def make_model():
  """Build the timestep-conditioned U-Net-style denoising network.

  Inputs:
    ``x_input``: (IMG_SIZE, IMG_SIZE, 3) noisy image.
    ``x_ts_input``: (1,) timestep value.

  Architecture:
    - Encoder: 4 ``block`` levels with MaxPool2D(2) between them.
    - Bottleneck: ``mlp``.
    - Decoder: 4 ``block`` levels, each fed its mirrored encoder skip and
      followed by UpSampling2D(2), except the last level.
    - Head: 1x1 Conv2D producing 3 channels.

  Returns:
    An uncompiled ``tf.keras.models.Model``.
  """
  img_in = keras.layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='x_input')
  ts_in = keras.layers.Input(shape=(1,), name='x_ts_input')

  # Shared timestep embedding passed to every block and to the bottleneck.
  ts_emb = keras.layers.Dense(192)(ts_in)
  ts_emb = keras.layers.LayerNormalization()(ts_emb)
  ts_emb = keras.layers.Activation('relu')(ts_emb)

  depth = 4  # spatial sizes 32 -> 16 -> 8 -> 4 when IMG_SIZE == 32

  # Encoder: keep each level's output as a skip connection; no pooling after
  # the deepest level.
  skips = []
  h = img_in
  for level in range(depth):
    h = block(h, ts_emb)
    skips.append(h)
    if level < depth - 1:
      h = keras.layers.MaxPool2D(2)(h)

  h = mlp(h, ts_emb)

  # Decoder: consume skips deepest-first; no upsampling after the last level.
  for level, skip in enumerate(reversed(skips)):
    h = keras.layers.Concatenate()([h, skip])
    h = block(h, ts_emb)
    if level < depth - 1:
      h = keras.layers.UpSampling2D(2)(h)

  out = keras.layers.Conv2D(3, kernel_size=1, padding='same')(h)
  return tf.keras.models.Model([img_in, ts_in], out)

In [29]:
# Build a fresh, untrained instance of the denoiser defined by make_model.
model = make_model()

In [30]:
# Print the layer-by-layer architecture, output shapes and parameter counts.
model.summary()

Model: "model_5"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 x_ts_input (InputLayer)     [(None, 1)]                  0         []                            
                                                                                                  
 dense_30 (Dense)            (None, 192)                  384       ['x_ts_input[0][0]']          
                                                                                                  
 layer_normalization_30 (La  (None, 192)                  384       ['dense_30[0][0]']            
 yerNormalization)                                                                                
                                                                                                  
 activation_78 (Activation)  (None, 192)                  0         ['layer_normalization_30