In [1]:
from google.colab import drive
drive.mount('/content/drive')

!cp drive/MyDrive/ai-faces.zip .
!unzip ai-faces.zip

In [2]:
import os
import keras
import keras.losses
import keras.optimizers
import keras.layers as layers
import keras.losses as losses
import keras.callbacks as callbacks
import keras.backend as K
import tensorflow as tf
import numpy as np
import shutil

In [3]:

embedding_length = 2048   # VQ codebook size of the pretrained autoencoder
embedding_dim = 4         # size of each codebook vector = latent channel count
image_size = (224, 224)
# The autoencoder latents are 1/4 of the image resolution in each spatial dim.
latent_dim_size = (image_size[0] // 4, image_size[1] // 4, embedding_dim)
beta = 0.25
runeager = False
small_dataset = True
# Passed to `dataset.take` on a *batched* dataset, so this counts batches.
ds_size = 20048
batch_size = 32
test_size = 5
epochs = 1000
# Number of nested U-Net levels between the top level and the bottom level.
num_layers = 2
# Number of columns in list_attr_celeba.txt
num_label_columns = 40

steps = 500
variance_schedule_start = 0.0001
variance_schedule_end = 0.02
# Linear beta schedule from start to end *inclusive* (as in DDPM). The old
# `i / steps` formula stopped one increment short of `variance_schedule_end`.
variance_schedule = np.linspace(variance_schedule_start, variance_schedule_end, steps).tolist()

filters = 32

version = 1

root = ""  # set to "drive/MyDrive/" to read/write models and images on Google Drive

img_output_path = f"{root}generated-latent-diffusion-v{version}"
log_dir = "logs/latent-diffusion"
model_path = f"{root}models/latent_diffusion_faces_v{version}.keras"

# Start every run with a clean TensorBoard training log.
if os.path.exists(os.path.join(log_dir, "train")):
  shutil.rmtree(os.path.join(log_dir, "train"))

ae_model_path = f"{root}models/vqgan_faces_v8.keras"


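## Encoder and decoder

The custom layers below are used by the saved VQ autoencoder and must be defined before it can be loaded.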

In [4]:
class Swish(layers.Layer):
  """Swish activation, f(x) = x * sigmoid(x)."""

  def call(self, x):
    gate = K.sigmoid(x)
    return gate * x

class GroupNormalization(layers.Layer):
  """Group normalization for rank-4 (batch, W, H, C) feature maps.

  Channels are split into `num_groups` groups. Each group is normalized by the
  mean and variance taken over both spatial axes and the channels of that
  group. A learned per-channel scale and bias are then applied.

  Args:
    num_groups: number of channel groups; must divide the channel count.
    epsilon: added to the variance before the square root for stability.
  """
  def __init__(self, num_groups = 32, epsilon=1e-7, **kwargs):
    super().__init__(**kwargs)
    self.num_groups = num_groups
    self.epsilon = epsilon

  def build(self, input_shape):
    (_, _, _, C) = input_shape
    if C % self.num_groups != 0:
      raise ValueError(
        f"GroupNormalization: channel count {C} is not divisible by num_groups={self.num_groups}.")
    # Identity affine transform at initialization (scale 1, bias 0), as in
    # standard group norm. Weights restored from a saved model replace these.
    # Names are passed by keyword: works with both Keras 2 and Keras 3
    # `add_weight` signatures.
    self.channel_weights = self.add_weight(
      name="channel_weights", shape=(1, 1, 1, C), initializer="ones", trainable=True)
    self.channel_biases = self.add_weight(
      name="channel_biases", shape=(1, 1, 1, C), initializer="zeros", trainable=True)

  def call(self, x):
    (_, W, H, C) = x.shape
    B = tf.shape(x)[0]
    # (B, W, H, C) -> (B, W, H, groups, C // groups)
    x = tf.reshape(x, shape=(B, W, H, self.num_groups, C // self.num_groups))
    # Statistics per sample and group: over W, H and the in-group channels.
    mean, var = tf.nn.moments(x, [1, 2, 4], keepdims=True)
    x = (x - mean) / tf.sqrt(var + self.epsilon)
    x = tf.reshape(x, shape=(B, W, H, C))
    return x * self.channel_weights + self.channel_biases

  def get_config(self):
    config = super().get_config()
    config.update({
      "num_groups": self.num_groups,
      "epsilon": self.epsilon
    })
    return config

class VectorQuantization(layers.Layer):
  """VQ-VAE style codebook lookup for rank-4 (batch, W, H, C) inputs.

  Every length-C vector of the input is replaced by its nearest codebook
  entry (squared Euclidean distance). The codebook and commitment losses are
  registered with `add_loss`. Gradients reach the input through a
  straight-through estimator.

  Args:
    embedding_length: number of codebook entries.
    embedding_dim: size of each codebook entry; must equal the input's C.
    beta: weight of the commitment (encoder) loss term.
  """
  def __init__(self, embedding_length, embedding_dim, beta=0.25, **kwargs):
    super().__init__(**kwargs)
    self.embedding_length = embedding_length
    self.embedding_dim = embedding_dim
    self.beta = beta
    # Name passed by keyword: works with both Keras 2 and Keras 3
    # `add_weight` signatures.
    self.embedding = self.add_weight(
      name="embedding",
      shape=(embedding_length, embedding_dim),
      initializer=tf.random_uniform_initializer(-1.0, 1.0),
      trainable=True)

  def call(self, input):
    (_, w, h, c) = input.shape
    B = tf.shape(input)[0]
    # One row per spatial position: (N, c) with N = B * w * h.
    flat = tf.reshape(input, shape=(B * w * h, c))
    # (N, 1, c) - (L, c) broadcasts to (N, L, c): the same differences the
    # previous tile+reshape produced, without materializing the (N, L*c) copy.
    diff = tf.pow(tf.expand_dims(flat, axis=1) - self.embedding, 2)
    diff = tf.reduce_sum(diff, axis=-1)
    embedding_indexes = tf.argmin(diff, axis=-1)
    embedding_indexes = tf.reshape(embedding_indexes, shape=(B, w, h))
    quantized_vectors = tf.gather(self.embedding, embedding_indexes)

    # Codebook loss moves entries toward the (frozen) inputs; the beta-weighted
    # commitment loss moves the inputs toward the (frozen) entries.
    embedding_loss = tf.reduce_mean((tf.stop_gradient(input) - quantized_vectors) ** 2)
    encoding_loss = tf.reduce_mean((input - tf.stop_gradient(quantized_vectors)) ** 2)
    self.add_loss(embedding_loss + self.beta * encoding_loss)

    # Straight-through estimator: the forward value is the quantized vectors,
    # the gradient flows to `input` unchanged.
    return input + tf.stop_gradient(quantized_vectors - input)

  def get_config(self):
    config = super().get_config()
    config.update({
      "embedding_length": self.embedding_length,
      "embedding_dim": self.embedding_dim,
      "beta": self.beta
    })
    return config

In [5]:
# Custom layer classes referenced by name in the saved VQ autoencoder; Keras
# needs them to deserialize the model.
custom_objects = dict()
custom_objects["Swish"] = Swish
custom_objects["GroupNormalization"] = GroupNormalization
custom_objects["VectorQuantization"] = VectorQuantization

autoencoder = keras.models.load_model(ae_model_path, custom_objects=custom_objects)
autoencoder.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 224, 224, 32  896         ['input_2[0][0]']                
                                )                                                                 
                                                                                                  
 group_normalization (GroupNorm  (None, 224, 224, 32  64         ['conv2d[0][0]']                 
 alization)                     )                                                             

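The encoder should not include the vector-quantization layer, but the decoder should. `conv2d_12` is the last layer before quantization, so the encoder ends at its output. The decoder starts at the input of the `vector_quantization` layer.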

In [6]:
# Split the autoencoder at the quantizer. The encoder stops just before it
# (output of conv2d_12). The decoder starts at the quantizer's input, so it
# still quantizes whatever latents it is given.
pre_quant_output = autoencoder.get_layer("conv2d_12").output
quantizer_input = autoencoder.get_layer("vector_quantization").input
encoder = keras.models.Model(autoencoder.input, pre_quant_output, name="encoder")
decoder = keras.models.Model(quantizer_input, autoencoder.output, name="decoder")

In [7]:

class Diffusion(keras.models.Model):
  """DDPM-style trainer and sampler around a noise-prediction network.

  The wrapped `model` is called as `model([x_t, labels, t_input])`, where
  `t_input = t / num_steps` has shape (batch, 1). It is trained to predict the
  Gaussian noise that was mixed into the clean input.

  Args:
    model: noise-prediction network with inputs `[latents, labels, t_input]`.
    num_steps: number of diffusion steps T; timesteps are 0 .. T-1.
    variance_schedule: sequence of T per-step noise variances (betas).
    **kwargs: forwarded to `keras.Model` (e.g. `name`).
  """
  def __init__(self, model,  num_steps, variance_schedule, **kwargs):
    # Unpack kwargs; passing the dict positionally did not forward them.
    super().__init__(**kwargs)
    self.num_steps = num_steps
    self.loss_tracker = keras.metrics.Mean("loss")
    self.model = model
    self.loss_fn = keras.losses.MSE
    self.variance_schedule = variance_schedule
    self.alpha = [1 - b for b in variance_schedule]
    # Running products: alpha_accumulated[t] = alpha[0] * ... * alpha[t].
    self.alpha_accumulated = []
    total = 1.
    for a in self.alpha:
      total *= a
      self.alpha_accumulated.append(total)

  @property
  def metrics(self):
    # Listing the tracker lets Keras reset it at the start of every epoch.
    return [self.loss_tracker]

  def train_step(self, input):
    """One optimization step on a `(clean_latents, labels)` batch."""
    real_images, labels = input
    batch_size = tf.shape(real_images)[0]
    # The label batch may be longer than the image batch (e.g. a partial
    # final image batch), so trim it to match.
    labels = labels[:batch_size]
    t = tf.random.uniform(shape=(batch_size,), minval=0, maxval=self.num_steps, dtype=tf.int32)
    t_input = tf.reshape(tf.cast(t, tf.float32) / self.num_steps, (batch_size, 1))
    noise = tf.random.normal(shape=tf.shape(real_images))
    alpha_t = tf.gather(self.alpha_accumulated, t)
    alpha_t = tf.reshape(alpha_t, shape=(batch_size, 1, 1, 1))
    # Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.
    noisy_input = tf.sqrt(alpha_t) * real_images + tf.sqrt(1 - alpha_t) * noise

    with tf.GradientTape() as tape:
      # training=True so dropout inside the network is actually active.
      predicted_noise = self.model([noisy_input, labels, t_input], training=True)
      # Scalar loss: gradients are averaged instead of implicitly summed over
      # every sample and pixel.
      noise_loss = tf.reduce_mean(self.loss_fn(noise, predicted_noise))
    grads = tape.gradient(noise_loss, self.model.trainable_weights)
    self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
    self.loss_tracker.update_state(noise_loss)

    return {
      "loss": self.loss_tracker.result()
    }

  def sample(self, num_images, labels=None):
    """Generate latents by running the reverse diffusion process.

    Args:
      num_images: number of latents to generate.
      labels: optional (num_images, num_labels) 0/1 attribute flags. They are
        mapped once to -1/+1 (the label encoding used in training). Random
        flags are drawn when None.

    Returns:
      Tensor of shape (num_images, *latent shape of the wrapped model's input).
    """
    # Input order matches the call below: [latents, labels, t_input].
    latent_input, label_input = self.model.inputs[0], self.model.inputs[1]
    result_shape = (num_images, *latent_input.shape[1:])
    num_labels = label_input.shape[-1]

    if labels is None:
      # maxval is exclusive for integer dtypes, so 2 yields flags in {0, 1}.
      labels = tf.random.uniform(shape=(num_images, num_labels), minval=0, maxval=2, dtype=tf.int32)
    # Convert to -1 or 1 exactly once (not on every denoising step).
    labels = tf.cast(labels, tf.float32) * 2. - 1.

    result = tf.random.normal(shape=result_shape)
    # Training draws t from [0, num_steps), so denoise through t = 0 inclusive.
    for t in reversed(range(self.num_steps)):
      # No fresh noise is added on the final step.
      if t > 0:
        z = tf.random.normal(shape=result_shape)
      else:
        z = tf.zeros(shape=result_shape)

      alpha = self.alpha[t]
      alpha_t = self.alpha_accumulated[t]
      t_input = tf.fill((num_images, 1), t / self.num_steps)
      predicted_noise = self.model([result, labels, t_input], training=False)
      noise_factor = (1 - alpha) / tf.sqrt(1 - alpha_t)
      sigma = tf.sqrt(self.variance_schedule[t])
      result = (1 / tf.sqrt(alpha)) * (result - noise_factor * predicted_noise) + sigma * z

    return result

  


In [8]:
def conv_block(filters):
  """Return a builder for a 3x3, same-padded, ReLU Conv2D with `filters` outputs."""
  def apply(tensor):
    conv = layers.Conv2D(filters=filters, kernel_size=3, padding="same", activation="relu")
    return conv(tensor)
  return apply

def time_embedding_proj(shape, idx, dense_dim=8):
  """Return a builder that projects a time embedding to a (W, H, C) feature map.

  Args:
    shape: target 4-D shape (batch, W, H, C); the batch entry is ignored.
    idx: level index used to make the layer names unique.
    dense_dim: width of the hidden Dense layer.
  """
  def apply(t_emb):
    _, width, height, channels = shape
    hidden = layers.Dense(units=dense_dim, activation="relu", name=f"time_embedding_{idx}_0")(t_emb)
    flat_map = layers.Dense(units=width * height, name=f"time_embedding_{idx}_1")(hidden)
    plane = layers.Reshape(target_shape=(width, height, 1), name=f"time_embedding_reshape_{idx}")(flat_map)
    lift = layers.Conv2D(filters=channels, kernel_size=1, activation="relu", name=f"time_embedding_{idx}_2")
    return lift(plane)
  return apply

def downsample():
  """Return a builder that applies 2x2 max pooling."""
  return lambda tensor: layers.MaxPooling2D(pool_size=2)(tensor)

def upsample(filters):
  """Return a builder for a stride-2, 4x4, same-padded Conv2DTranspose."""
  def apply(tensor):
    up = layers.Conv2DTranspose(filters=filters, kernel_size=4, strides=2, padding="same")
    return up(tensor)
  return apply

def dropout(rate):
  """Return a builder that applies `layers.SpatialDropout2D(rate)`."""
  def apply(tensor):
    drop = layers.SpatialDropout2D(rate)
    return drop(tensor)
  return apply

def unet_layer(filters, next_layer):
  """Return a builder for one U-Net level.

  The returned `inner(x, labels, time_embedding, idx=0)` does the following:
  - adds label conditioning to `x`;
  - applies conv -> (self-attention, bottom level only) -> dropout -> conv;
  - if `next_layer` is given: downsamples, recurses into the deeper level with
    `filters * 2`, upsamples, adds the skip connection and a projected time
    embedding, and applies two more conv blocks.

  Args:
    filters: number of conv filters at this level.
    next_layer: builder for the next deeper level (called as
      `next_layer(x, labels, filters, time_embedding, idx)`), or None for the
      bottom level.
  """
  def inner(x, labels, time_embedding, idx=0):
    # Project labels to 8 features. NOTE(review): these projected labels (not
    # the raw ones) are passed to deeper levels, which project them again —
    # confirm this is intended.
    labels = layers.Dense(units=8, name=f"labels_proj_{idx}")(labels)
    _, w, h, c = x.shape
    # Label features -> one (w, h, 1) plane -> c channels, added to the input.
    mapped_labels = layers.Dense(units=w * h, name=f"map_labels_{idx}")(labels)
    mapped_labels = layers.Reshape(target_shape=(w, h, 1))(mapped_labels)
    mapped_labels = layers.Conv2D(filters=c, kernel_size=1, name=f"conv_labels_{idx}")(mapped_labels)
    x = layers.add([x, mapped_labels])
    x = conv_block(filters)(x)
    # In bottom layer, do self attention
    if next_layer is None:
      x = self_attention(3)(x)
    x = dropout(0.2)(x)
    x = conv_block(filters)(x)
    # Skip connection, merged back in after the deeper levels return.
    residual = x
    if next_layer is not None:
      # NOTE(review): adding `residual` after pool (floor) + transposed conv (x2)
      # presumably needs even spatial dims at every level (56 -> 28 -> 14 -> 7
      # with the default config) — confirm before changing num_layers/image size.
      x = downsample()(x)
      x = next_layer(x, labels, filters * 2, time_embedding, idx + 1)
      x = upsample(filters)(x)
      # Time conditioning is only injected here, on the way back up; the bottom
      # level does not use `time_embedding`.
      time = time_embedding_proj(x.shape, idx)(time_embedding)
      x = layers.add([residual, x, time])
      x = conv_block(filters)(x)
      x = conv_block(filters)(x)
    return x
  return inner

def sublayer(next_layer):
  """Wrap `unet_layer` in the `(x, labels, filters, time_embedding, idx)` call
  signature that a parent level uses to invoke its deeper level."""
  def inner(x, labels, filters, time_embedding, idx):
    level = unet_layer(filters, next_layer)
    return level(x, labels, time_embedding, idx)
  return inner

def positional_encoding2d():
  """Return a builder producing a learned 2-D positional encoding.

  For a (B, W, H, C) input, the builder creates two normalized coordinate maps
  (index / axis length, so values lie in [0, 1)) over the (W, H) grid. It tiles
  them over the batch and projects them to C channels with a 1x1 Conv2D. The
  input's values are not used, only its shape. Non-square maps are supported.
  """
  def inner(inputs):
    _, w, h, c = inputs.shape
    batch_size = tf.shape(inputs)[0]
    # rows[i, j] = i / w and cols[i, j] = j / h, both of shape (w, h).
    rows, cols = tf.meshgrid(
      tf.range(w, dtype=tf.float32) / w,
      tf.range(h, dtype=tf.float32) / h,
      indexing="ij")
    # Second-axis coordinate first: same channel order (and, for square maps,
    # same values) as the previous implementation.
    indexes = tf.stack([cols, rows], axis=-1)
    assert indexes.shape == (w, h, 2)

    indexes = tf.expand_dims(indexes, axis=0)
    indexes = tf.tile(indexes, [batch_size, 1, 1, 1])

    # TODO: try the sinusoidal encoding from "Attention Is All You Need".
    return layers.Conv2D(c, kernel_size=1, strides=1, padding="same")(indexes)
  return inner

def self_attention(num_heads=1, key_dim=64):
  """Return a builder for spatial multi-head self-attention on (B, W, H, C) maps.

  A learned 2-D positional encoding is added first. The attention output
  replaces the input; no residual connection is added here.
  """
  def inner(x):
    pos = positional_encoding2d()(x)
    x = layers.add([x, pos])
    # For a rank-4 (B, W, H, C) input the spatial axes are (1, 2). The previous
    # (2, 3) pulled a non-spatial axis into the attention.
    x = layers.MultiHeadAttention(num_heads, key_dim, attention_axes=(1, 2))(x, x, x)
    return x
  return inner

def cross_attention(num_heads=1, key_dim=64):
  """Return a builder for spatial multi-head cross-attention on (B, W, H, C) maps.

  `qk` supplies the queries and keys and `v` supplies the values; both get a
  learned 2-D positional encoding first. The key (`qk`) and value (`v`) inputs
  must share their spatial shape.
  """
  def inner(qk, v):
    pos_qk = positional_encoding2d()(qk)
    pos_v = positional_encoding2d()(v)
    qk = layers.add([qk, pos_qk])
    v = layers.add([v, pos_v])
    # Spatial axes of a rank-4 (B, W, H, C) tensor are (1, 2), not (2, 3).
    # Keras call order is (query, value, key).
    return layers.MultiHeadAttention(num_heads, key_dim, attention_axes=(1, 2))(qk, v, qk)
  return inner



In [9]:

def get_model(num_labels, filters):
  """Load the saved noise-prediction U-Net if present, otherwise build a new one.

  Args:
    num_labels: length of the label-vector input.
    filters: conv filters at the top U-Net level (doubled at each deeper level).

  Returns:
    keras.Model with inputs [latents, labels, t_input] and an output with
    `embedding_dim` channels. If `model_path` exists, the saved model is
    returned as-is and both arguments are ignored.
  """
  if os.path.exists(model_path):
    return keras.models.load_model(model_path)

  latent_w, latent_h, latent_c = latent_dim_size
  image_input = keras.Input(shape=(latent_w, latent_h, latent_c), name="images")
  t_input = keras.Input(shape=(1,), name="t_input")
  label_input = keras.Input(shape=(num_labels,), name="labels")

  # Nest the builders from the bottom level upward: the bottom level plus
  # `num_layers` wrappers, all below the top level created next.
  deeper = sublayer(None)
  for _ in range(num_layers):
    deeper = sublayer(deeper)

  features = unet_layer(filters, deeper)(image_input, label_input, t_input)
  output = layers.Conv2D(filters=embedding_dim, kernel_size=3, padding="same")(features)
  return keras.Model([image_input, label_input, t_input], output)


In [10]:

class DiffusionMonitor(keras.callbacks.Callback):
  """At each epoch end, sample latents, decode them and save them as PNGs.

  Files go to `img_output_path` as `generated_{epoch:03d}_{i}.png`.

  Args:
    decoder: model mapping latents to images.
    num_img: number of images generated per epoch.
  """
  def __init__(self, decoder, num_img=3):
    # Initialize base Callback state (model, params, etc.).
    super().__init__()
    self.num_img = num_img
    self.decoder = decoder

  def on_epoch_end(self, epoch, logs=None):
    generated_latents = self.model.sample(self.num_img)
    generated_images = self.decoder(generated_latents)
    # makedirs also creates missing parents (e.g. when `root` points elsewhere)
    # and is a no-op if the directory already exists.
    os.makedirs(img_output_path, exist_ok=True)

    for i in range(self.num_img):
      img = keras.utils.array_to_img(generated_images[i])
      img.save(os.path.join(img_output_path, f"generated_{epoch:03d}_{i}.png"))

class Save(keras.callbacks.Callback):
  """At each epoch end, save the inner noise-prediction network to `model_path`.

  `self.model` is the `Diffusion` wrapper being fit. Its `.model` attribute
  (the U-Net) is what gets saved, so `get_model` can reload it directly.
  """
  def __init__(self, model_path):
    # Initialize base Callback state (model, params, etc.).
    super().__init__()
    self.model_path = model_path

  def on_epoch_end(self, epoch, logs=None):
    # Create the target directory (e.g. "models/") on the first save.
    model_dir = os.path.dirname(self.model_path)
    if model_dir:
      os.makedirs(model_dir, exist_ok=True)
    self.model.model.save(self.model_path)


In [11]:

# Build (or reload) the noise-prediction U-Net and wrap it in the diffusion trainer.
model = get_model(num_label_columns, filters)
model.summary()

diffusion = Diffusion(model, steps, variance_schedule)
diffusion.compile(run_eagerly=runeager, optimizer="rmsprop")

# NOTE: this list shadows the `keras.callbacks` alias imported as `callbacks`.
callbacks = [DiffusionMonitor(decoder), Save(model_path)]


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 labels (InputLayer)            [(None, 40)]         0           []                               
                                                                                                  
 labels_proj_0 (Dense)          (None, 8)            328         ['labels[0][0]']                 
                                                                                                  
 map_labels_0 (Dense)           (None, 3136)         28224       ['labels_proj_0[0][0]']          
                                                                                                  
 reshape (Reshape)              (None, 56, 56, 1)    0           ['map_labels_0[0][0]']           
                                                                                              

In [12]:

# Per-image attributes. Line 0 is skipped (in CelebA it holds the image count),
# line 1 holds the attribute names, then one "<filename> <value> ..." row per image.
fname = "list_attr_celeba.txt"

with open(fname) as f:
  # Drop blank lines. The old `lines[2:-1]` silently lost the last record
  # whenever the file had no trailing newline.
  attr_lines = [line for line in f.read().splitlines() if line.strip()]

header = attr_lines[1].split()
assert len(header) == num_label_columns, f"expected {num_label_columns} attributes, got {len(header)}"
attrs_by_file = {}
for line in attr_lines[2:]:
  file_name, *values = line.split()
  attrs_by_file[file_name] = [int(v) for v in values]

image_dataset = keras.utils.image_dataset_from_directory(
  f"img_align_celeba{'_small' if small_dataset else ''}",
  label_mode=None,
  image_size=image_size,
  batch_size=batch_size,
  smart_resize=True,
  shuffle=False)

# Align labels with images by file name, so the (small) image directory does
# not have to hold exactly the first N attribute rows in order. If this Keras
# version does not expose `file_paths`, fall back to the attribute file order.
image_files = getattr(image_dataset, "file_paths", None)
if image_files is None:
  label_rows = list(attrs_by_file.values())
else:
  label_rows = [attrs_by_file[os.path.basename(p)] for p in image_files]
raw_data = np.array(label_rows, dtype=np.float32)

label_dataset = tf.data.Dataset.from_tensor_slices(raw_data).batch(batch_size)
# Scale pixels from [0, 255] to [-1, 1], then encode each batch to latents.
image_dataset = image_dataset.map(lambda x: x / (255. / 2) - 1.)
image_dataset = image_dataset.map(lambda x: encoder(x))

dataset = tf.data.Dataset.zip((image_dataset, label_dataset))
# The dataset is batched, so `take` counts batches, not images.
dataset = dataset.take(ds_size)
dataset = dataset.take(ds_size)


Found 1299 files belonging to 1 classes.


In [13]:

# Train; every epoch ends with sample images (DiffusionMonitor) and a checkpoint (Save).
diffusion.fit(dataset, epochs=epochs, callbacks=callbacks)



Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000

KeyboardInterrupt: 