In [None]:
# Shell escape: run the nvidia-smi command-line tool and show its output in the notebook.
!nvidia-smi

In [None]:
# Optional: expose only CUDA device 0 to this process.
# ignore this when running on colab!
import os
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import tensorflow as tf
tfkl = tf.keras.layers
import tensorflow_probability as tfp
tfd = tfp.distributions
import numpy as np
from matplotlib import pyplot as plt 
import tensorflow_addons as tfa

from data.utils import parse_image_example

In [None]:
def scale_data(images):
    """Rescale values affinely via ``2*x - 1`` (so 0 maps to -1 and 1 maps to 1)."""
    doubled = 2 * images
    return doubled - 1

def descale_data(images):
    """Rescale values affinely via ``(x + 1) / 2`` (so -1 maps to 0 and 1 maps to 1)."""
    shifted = images + 1
    return shifted / 2

In [None]:
batch_size = 512


def _pad_and_normalize(images):
    """Add a trailing channel axis, zero-pad height and width by 2 on each side, and scale to [0, 1].

    images: uint8 array of shape (n, height, width), as returned by the Keras MNIST loader.
    returns: float32 array of shape (n, height + 4, width + 4, 1).
    """
    padded = np.pad(images[..., None], ((0, 0), (2, 2), (2, 2), (0, 0)))
    return padded.astype(np.float32) / 255.


(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
train_images = _pad_and_normalize(train_images)
test_images = _pad_and_normalize(test_images)

# note that scale_data is already mapped onto the datasets here.
# for sampling etc., we need to be mindful to apply descale_data.
# although it doesn't matter much for MNIST or single-channel images in general,
# as colormaps are scaled automatically
train_data = (tf.data.Dataset.from_tensor_slices(train_images)
              # shuffle buffer covers the whole training set, whatever its size
              .shuffle(len(train_images))
              .batch(batch_size, drop_remainder=True)
              .map(scale_data))
test_data = tf.data.Dataset.from_tensor_slices(test_images).batch(32).map(scale_data)

In [None]:
# Collect every batch of test_data (already passed through scale_data) into one array.
# NOTE: this rebinds test_images to the scaled version of the test set.
test_images = np.concatenate(list(test_data), axis=0)

# show the first 64 test images in an 8x8 grid, mapped back with descale_data
preview = descale_data(test_images[:64])
plt.figure(figsize=(15,15))
for slot, picture in enumerate(preview, start=1):
    plt.subplot(8, 8, slot)
    plt.imshow(picture, vmin=0, vmax=1, cmap="Greys")
    plt.axis("off")
plt.show()

In [None]:
# noise schedule.
# parameters from the paper: t=1000, betas=np.linspace(0.0001, 0.02, tmax)
# I reduced t to 200 for faster sampling, seems to be ok
# (note the upper end of the beta range used here is 0.1)
tmax = 200
betas = np.linspace(0.0001, 0.1, num=tmax).astype(np.float32)
alphas = 1 - betas
# running product of the alphas over the time steps
alphas_bar = alphas.cumprod()

# curves of alphas_bar and of its square root over the time steps
plt.plot(alphas_bar)
plt.plot(np.sqrt(alphas_bar))

In [None]:
# here we can look at the forward process slowly turning data to noise
# one figure is drawn for every `step`-th entry of the schedule
step = 20
clean_imgs = test_images[:64]
for t_index, alpha_bar in enumerate(alphas_bar[::step]):
    # factor applied to the clean images, and standard deviation of the added noise
    shrinkage = np.sqrt(alpha_bar)
    noise_scale = np.sqrt(1 - alpha_bar)
    noisy_imgs = shrinkage * clean_imgs + noise_scale*np.random.normal(size=clean_imgs.shape)

    plt.figure(figsize=(15, 15))
    # the title belongs to the whole figure, so set it once instead of once per subplot
    plt.suptitle("t = {}: Shrinkage {}; Noise scale: {}".format(t_index*step+1, shrinkage, noise_scale))
    for ind, image in enumerate(descale_data(noisy_imgs)):
        plt.subplot(8, 8, ind+1)
        plt.imshow(image, vmin=0, vmax=1, cmap="Greys")
        plt.axis("off")
    plt.show()

In [None]:
class Diffusion(tf.keras.Model):
    """Functional Keras model with diffusion-specific training, evaluation and sampling logic.

    The underlying network is called as ``self([noisy_images, normalized_t])`` and its
    output is treated as a prediction of the noise that was mixed into the images.

    Args:
        inputs, outputs: forwarded unchanged to ``tf.keras.Model``.
        alphas, alphas_bar, betas: noise-schedule values indexed by time step;
            each one is converted to a float32 tensor.
        t_max: number of diffusion steps; time steps run over 0 .. t_max - 1.
        **kwargs: forwarded to ``tf.keras.Model``.
    """
    def __init__(self, inputs, outputs, alphas, alphas_bar, betas, t_max, **kwargs):
        super().__init__(inputs, outputs, **kwargs)
        self.loss_tracker = tf.keras.metrics.Mean("loss")
        
        self.alphas_tensor = tf.convert_to_tensor(alphas, dtype=tf.float32)
        self.alphas_bar_tensor = tf.convert_to_tensor(alphas_bar, dtype=tf.float32)
        self.betas_tensor = tf.convert_to_tensor(betas, dtype=tf.float32)
        self.tmax = t_max

    
    def diffusion_loss(self, image_batch, training=None):
        """Noise-prediction loss for one batch.

        Every image gets its own random time step and its own Gaussian noise; the model
        sees the noised image plus the normalized time step and has to predict the noise.

        Args:
            image_batch: 4D batch of images (axes 1, 2, 3 are summed over in the loss).
            training: forwarded to the model call.

        Returns:
            Scalar: squared error summed per image, then averaged over the batch.
        """
        # this samples a batch_size tensor of t values
        sampled_ts = tf.random.uniform([tf.shape(image_batch)[0]], 0, self.tmax, dtype=tf.int32)
        target_epsilons = tf.random.normal(tf.shape(image_batch))

        # per-example alpha_bar, reshaped to broadcast against the 4D image batch
        batch_alphas_bar = tf.gather(self.alphas_bar_tensor, sampled_ts)[:, None, None, None]
        noisy_batch = (tf.math.sqrt(batch_alphas_bar) * image_batch 
                       + tf.math.sqrt(1 - batch_alphas_bar) * target_epsilons)
        # turns it into batch x 1 matrix and normalizes
        normalized_ts = tf.cast(sampled_ts, tf.float32)[:, None] / self.tmax
        
        #%%%%%%%%%% TO UPDATE 
        # this function needs to receive conditioning inputs (e.g. class labels) and hand them to the model call
        # for classifier-free guidance, the labels needs to be occasionally (randomly) dropped out
        noise_prediction = self([noisy_batch, normalized_ts], training=training)
        loss = tf.reduce_mean(tf.reduce_sum((target_epsilons - noise_prediction)**2,
                                            axis=[1,2,3]))
        return loss

    def train_step(self, data):
        """Run one optimization step on a batch and return the running mean loss.

        Gradients are applied with ``self.optimizer``, i.e. the optimizer handed to
        ``compile()``, so the model does not depend on a notebook-level global.
        """
        #%%%%%%%%% TO UPDATE
        # your dataset should include tuples of images,labels.
        # then you can do
        # image_batch, label_batch = data
        # and pass both to the loss (labels required for conditioning)
        
        # same thing in test_step!
        with tf.GradientTape() as tape:
            loss = self.diffusion_loss(data, training=True)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}
    
    def test_step(self, data):
        """Evaluate the loss on a batch (no weight update) and return the running mean loss."""
        loss = self.diffusion_loss(data, training=False)

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}
    
    @tf.function
    def langevin_step(self, sample, t_step, beta_version="one"):
        """Perform one reverse (denoising) step of the sampler.

        Args:
            sample: current batch of noisy images.
            t_step: integer time step index; at step 0 no noise is added.
            beta_version: selects the noise scale sigma for steps > 0:
                "one" -> sqrt(beta_t),
                "two" -> sqrt((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t).

        Returns:
            The batch after one denoising step.

        Raises:
            ValueError: if beta_version is neither "one" nor "two".
        """
        # fail with a clear message; otherwise sigma would be left undefined further down
        if beta_version not in ("one", "two"):
            raise ValueError("Invalid beta_version; valid choices are 'one', 'two'.")

        z = tf.random.normal(tf.shape(sample))
        if t_step == 0:
            z = tf.zeros_like(z)

        # batch x 1 column holding the normalized time step for every sample
        t_normalized = tf.cast(t_step, tf.float32) * tf.ones([tf.shape(sample)[0], 1]) / self.tmax
        #%%%%%%%%% TO UPDATE
        # this function needs to take class inputs from langevin_sampler
        # and put them into the model.
        #
        # for simple class-conditioning, that is enough.
        #
        # for guided diffusion, you will want to use
        # (1+ w)*model_call(with_condition) - w*model_call(without_condition)
        # without_condition refers to "dropped out" conditioning
        # w (guidance weight) should also be an input
        model_output = self([sample, t_normalized], training=False)

        alpha_here = self.alphas_tensor[t_step]
        alpha_bar_here = self.alphas_bar_tensor[t_step]

        if t_step == 0:
            sigma = tf.convert_to_tensor(0., tf.float32)
        else:
            # these are two different choice for beta described in the paper
            if beta_version == "one":
                sigma = tf.math.sqrt(self.betas_tensor[t_step])
            elif beta_version == "two":
                sigma = tf.math.sqrt((1 - self.alphas_bar_tensor[t_step-1]) / (1 - alpha_bar_here) 
                                     * self.betas_tensor[t_step])

        noise = sigma * z
        # remove the predicted noise contribution, rescale, then add fresh noise
        sample = (1/tf.math.sqrt(alpha_here) 
                  * (sample - (1 - alpha_here)/tf.math.sqrt(1 - alpha_bar_here) * model_output))
        sample = sample + noise
        return sample
    
    def langevin_sampler(self, n_samples=64, beta_version="one"):
        """Generate samples by starting from Gaussian noise and running all reverse steps.

        Args:
            n_samples: number of images to generate.
            beta_version: forwarded to ``langevin_step``.

        Returns:
            Batch of n_samples generated images, shaped like the model's image input.
        """
        #%%%%%%%%%%%%% TO UPDATE
        # for conditional generation, you most likely want to change this function
        # to accept class inputs.
        # maybe even remove n_samples, and instead use len(class_inputs) as the number of samples.
        # assuming you are passing a batch_size vector of classes.
        #
        # for guided diffusion, also accept guidance weight as input to pass to the step function
        sample = tf.random.normal((n_samples,) + self.input_shape[0][1:])
        # walk the time steps backwards: tmax - 1, ..., 0
        for t_step in tf.range(self.tmax)[::-1]:
            sample = self.langevin_step(sample, t_step, beta_version)

        return sample

In [None]:
# Upsample -> conv layer that can be used in the decoder instead of a transposed convolution
class UpsampleConv2D(tfkl.Layer):
    """Bilinear upsampling by `strides` (only when strides > 1) followed by a Conv2D.

    The constructor mirrors the convolution layers' (n_filters, filter_size, strides, padding)
    arguments; `strides` sets the upsampling factor and is not passed to the convolution.
    """
    def __init__(self, n_filters, filter_size, strides=1, padding="same", **kwargs):
        super().__init__(**kwargs)
        self.strides = strides
        self.conv = tfkl.Conv2D(n_filters, filter_size, padding=padding, name=f"{self.name}_conv")
        # the upsampling layer only exists when there is something to upsample
        if self.strides > 1:
            self.upsample = tfkl.UpSampling2D(size=strides, interpolation="bilinear",
                                              name=f"{self.name}_upsample")
    
    def call(self, inputs):
        conv_input = self.upsample(inputs) if self.strides > 1 else inputs
        return self.conv(conv_input)


# Normalization -> Activation -> Convolution
# often referred to as "pre-activation", popular for residual networks.
# if you are interested: https://arxiv.org/pdf/1603.05027.pdf
class NormActConv(tfkl.Layer):
    """Group normalization -> adaptive scale/shift from time -> activation -> convolution.

    Args:
        n_filters: number of output filters of the convolution.
        mode: "conv" (Conv2D), "transpose" (Conv2DTranspose) or "upconv" (UpsampleConv2D).
        strides: strides handed to the convolution layer.
        activation: callable applied before the convolution; a falsy value skips it.
        **kwargs: forwarded to ``tfkl.Layer``.

    Raises:
        ValueError: if `mode` is not one of the three valid choices.
    """
    def __init__(self,
                 n_filters,
                 mode,
                 strides=1,
                 activation=tf.nn.gelu,
                 **kwargs):
        if mode not in ["conv", "transpose", "upconv"]:
            raise ValueError("Invalid mode; valid choices are 'conv', 'transpose', 'upconv'.")
        
        super().__init__(**kwargs)
        if mode == "conv":
            layer_fn = tfkl.Conv2D
        elif mode == "transpose":
            layer_fn = tfkl.Conv2DTranspose
        else:
            layer_fn = UpsampleConv2D

        # NOTE harcoded 3x3 filter size, could change this
        self.conv = layer_fn(n_filters, 3, strides=strides,
                padding="same", name=self.name + "_conv_main")
        self.activation = activation

    @staticmethod
    def _group_count(n_channels):
        """Largest group count <= n_channels // 4 (but at least 1) that divides n_channels."""
        groups = max(1, n_channels // 4)
        # walk down to the next divisor; terminates at 1 at the latest
        while n_channels % groups:
            groups -= 1
        return groups

    def build(self, input_shape):
        n_channels = input_shape[-1]
        # aim for 4 channels per group; _group_count keeps this valid for channel counts
        # that are smaller than 4 or not evenly divisible
        self.normalization = tfa.layers.GroupNormalization(groups=self._group_count(n_channels),
                                                           name=self.name + "_normalization")
        # these components are for adaptive normalization
        # zero-initialized kernels for the scale and shift projections
        self.time_mult = tfkl.Dense(n_channels, kernel_initializer=tf.keras.initializers.Zeros(),
                                    name=self.name + "_adanorm_scale")
        self.time_add = tfkl.Dense(n_channels, kernel_initializer=tf.keras.initializers.Zeros(),
                                    name=self.name + "_adanorm_shift")
        
        #%%%%%%%%% TO UPDATE
        # you can add another set of shift, scale layers for the class-conditioning input
        
    def call(self, inputs, time):
        """Apply the layer.

        inputs: feature map whose last axis is the channel axis.
        time: conditioning tensor; its Dense projections must broadcast against `inputs`.
        """
        # build() always creates the normalization layer, so it is applied unconditionally
        normed = self.normalization(inputs)
            
        # adaptive normalization implements conditioning on time
        normed = (1 + self.time_mult(time)) * normed + self.time_add(time)
        #%%%%%%%%%%%%% TO UPDATE
        # the layer should also receive a class input and apply adaptive normalization.
        # the same way it is done for time!
        
        if self.activation:
            acted = self.activation(normed)
        else:
            acted = normed
        
        conved = self.conv(acted)
        return conved


# residual block with two convolutional layers
class ResidualBlock(tfkl.Layer):
    """Two stacked NormActConv layers plus a 1x1 shortcut projection of the block input.

    Only the first main layer and the shortcut use `strides`; the second main layer
    always runs with stride 1. `mode` picks the layer type ("conv", "upconv", otherwise
    transposed convolution for the shortcut).
    """
    def __init__(self,
                 n_filters,
                 mode,
                 strides=1,
                 activation=tf.nn.gelu,
                 **kwargs):
        super().__init__(**kwargs)
        self.n_filters = n_filters
        self.strides = strides

        self.main_layer1 = NormActConv(n_filters, mode, strides,
                                       activation=activation,
                                       name=f"{self.name}_main1")
        self.main_layer2 = NormActConv(n_filters, mode, 1,
                                       activation=activation,
                                       name=f"{self.name}_main2")

        # shortcut layer type follows the mode; anything else falls back to transposed conv
        shortcut_types = {"conv": tfkl.Conv2D, "upconv": UpsampleConv2D}
        shortcut_fn = shortcut_types.get(mode, tfkl.Conv2DTranspose)
        self.shortcut = shortcut_fn(n_filters, 1, strides=strides,
                                    name=f"{self.name}_shortcut")
        
    def call(self, inputs, time):
        #%%%%%%%%%%% TO UPDATE
        # this layer needs to receive class conditioning input
        # and pass it to the main layers
        main_path = self.main_layer1(inputs, time)
        main_path = self.main_layer2(main_path, time)
        
        skip_path = self.shortcut(inputs)
        
        # residual connection is re-scaled by 1/sqrt(2).
        # this helps.
        rescale = 1/tf.math.sqrt(2.)
        return rescale * (main_path + skip_path)


    
#%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# THESE ARE CURRENTLY NOT USED IN THE MODEL!!
# so if you don't use them, you don't need to update them for conditioning, duh

# a "level" is a series of residual blocks operating at the same resolution 
# and with the same number of filters
class DownLevel(tfkl.Layer):
    """Chain of `n_blocks` "conv" ResidualBlocks; only the first block uses `strides`."""
    def __init__(self, n_filters, n_blocks, strides=1, **kwargs):
        super().__init__(**kwargs)
        level_blocks = []
        for ind in range(n_blocks):
            # every block after the first one keeps the resolution (stride 1)
            block_strides = strides if ind == 0 else 1
            level_blocks.append(ResidualBlock(n_filters, "conv", strides=block_strides,
                                              name=self.name + "_block{}".format(ind)))
        self.blocks = level_blocks
        
    def call(self, inputs, time):
        #%%%%%%%%%% TO UPDATE
        # this layer needs to receive class conditioning input
        # and pass it to the residual blocks
        hidden = inputs
        for block in self.blocks:
            hidden = block(hidden, time)
            
        return hidden
    
    
class UpLevel(tfkl.Layer):
    """Chain of `n_blocks` "transpose" ResidualBlocks; only the first block uses `strides`."""
    def __init__(self, n_filters, n_blocks, strides=1, **kwargs):
        super().__init__(**kwargs)
        # NOTE could change "transpose" to "upconv" to use upsampling + convolution
        level_blocks = []
        for ind in range(n_blocks):
            # every block after the first one keeps the resolution (stride 1)
            block_strides = strides if ind == 0 else 1
            level_blocks.append(ResidualBlock(n_filters, "transpose", strides=block_strides,
                                              name=self.name + "_block{}".format(ind)))
        self.blocks = level_blocks
        
    def call(self, inputs, time):
        #%%%%%%%%%% TO UPDATE
        # this layer needs to receive class conditioning input
        # and pass it to the residual blocks
        hidden = inputs
        for block in self.blocks:
            hidden = block(hidden, time)
            
        return hidden

In [None]:
import tensorflow_addons as tfa


#%%%%%%%%%%%%%%%%% TO UPDATE
# this function needs to receive class conditioning input similar to time
# Same in the function below (decoder stack)
def residual_stack_encoder(inputs, t_input,
                           filters, strides, blocks,
                           mode, name):
    """Build the encoder: an initial 3x3 convolution followed by levels of residual blocks.

    filters, strides, blocks: per-level filter count, stride and number of blocks
        (iterated together; `filters` must also support indexing for the initial conv).
    mode: layer mode handed to every ResidualBlock.
    name: prefix for the block names ("<name>_<level>_<block>", both counted from 1).

    Returns (final output, list with the output of every residual block in order).
    """
    # this collects outputs of all residual blocks, to be used in the decoder for skip connections
    collected = []
    # one initial convolution to add some channels
    hidden = tfkl.Conv2D(filters[0], 3, padding="same")(inputs)
    level_specs = zip(filters, strides, blocks)
    for level_no, (level_filters, level_stride, level_blocks) in enumerate(level_specs, start=1):
        for block_no in range(1, level_blocks + 1):
            # only the first block of a level applies the level's stride
            block_stride = level_stride if block_no == 1 else 1
            block_name = "_".join([name, str(level_no), str(block_no)])
            # %%%%%%%%%%%%%%% TO UPDATE
            # these layers need to receive class conditiong input, just like they are receiving time
            block = ResidualBlock(level_filters, mode, strides=block_stride, name=block_name)
            hidden = block(hidden, t_input)
            collected.append(hidden)
        
    return hidden, collected

def residual_stack_decoder(inputs, t_input, all_hidden,
                           filters, strides, blocks,
                           mode, name):
    """Build the decoder: levels of residual blocks fed with skip connections from the encoder.

    all_hidden: encoder block outputs, ordered so that entry i is concatenated to the
        input of the i-th decoder block (entry 0 is never used). Any iterable is
        accepted and it is not modified.
    filters, strides, blocks: per-level filter count, stride and number of blocks.
    mode: layer mode handed to every ResidualBlock.
    name: prefix for the block names ("<name>_<level>_<block>", both counted from 1).

    Returns the output of the last residual block.
    """
    # private copy: pooled tensors are never written back into the caller's list,
    # and iterators (e.g. reversed(...)) work as well as lists
    skips = list(all_hidden)
    outputs = inputs
    global_ind = 0
    for level_ind, (level_filters, level_stride, level_blocks) in enumerate(zip(filters, strides, blocks)):
        for block_ind in range(level_blocks):
            
            # skip connection from the encoder (none for the very first decoder block)
            if global_ind > 0:
                skip = skips[global_ind]
                # if the heights differ, average-pool the skip tensor before concatenating
                if outputs.shape[1] != skip.shape[1]:
                    skip = tfkl.AvgPool2D(padding="same")(skip)
                outputs = tfkl.Concatenate(axis=-1)([outputs, skip])
            global_ind += 1
            
            # only the first block of a level applies the level's stride
            outputs = ResidualBlock(level_filters,
                                    mode, 
                                    strides=level_stride if block_ind == 0 else 1,
                                    name="_".join([name, str(level_ind+1), str(block_ind+1)]))(outputs, t_input)
        
    return outputs

In [None]:
def positional_encoding_v2(input_t, n_freqs):
    """Encode time values as sines and cosines at geometrically spaced frequencies.

    input_t: b x 1
    n_freqs: scalar
    returns: b x 2*n_freqs (all sines first, then all cosines)

    Reads the notebook-level `tmax` for the upper frequency bound.
    """
    # create frequencies up to below nyquist frequency
    freq_values = np.geomspace(0.1, tmax//2, n_freqs).astype(np.float32)
    frequencies = tf.convert_to_tensor(freq_values)
    # each entry is b x n_freqs
    features = [trig_fn(2*np.pi*frequencies*input_t) for trig_fn in (tf.math.sin, tf.math.cos)]
    return tf.concat(features, axis=-1)


# image input (shape taken from the test images) and a scalar time input per example
inp = tf.keras.Input(test_images.shape[1:])
t_input = tf.keras.Input((1,))
# two singleton axes are inserted so the time features can broadcast over height and width
t_encoded = positional_encoding_v2(t_input, n_freqs=32)[:, None, None, :]
t_encoded = tfkl.Dense(32, tf.nn.gelu)(t_encoded)
t_encoded = tfkl.Dense(32, tf.nn.gelu)(t_encoded)


#%%%%%%%%%% TO UPDATE
# add a class input just like time. you could receive it as single indices,
# or already apply one_hot outside the model, and here receive input with shape (10,) (assuming 10 classes)

# positional encoding does not make sense for classes. so just use a bunch of Dense layers.
# if you receive classes as indices, your first layer should be an Embedding
# if you receive one-hot, just use Dense


# you can make this smaller if it takes too long
# one entry per level: number of residual blocks, filters, and stride of the level's first block
blocks_per_level = [2, 2, 2, 2]
filters = [16, 32, 64, 128]
strides = [1, 2, 2, 2]
encoder_output, all_hidden = residual_stack_encoder(
    inp, t_encoded,
    filters,
    strides,
    blocks_per_level,
    "conv", "encoder")

# the decoder walks the levels in reverse (reversed filters, block counts and hidden outputs);
# note that `strides` is passed in its original order
decoder_output = residual_stack_decoder(
    encoder_output, t_encoded, list(reversed(all_hidden)),
    reversed(filters),
    strides,
    reversed(blocks_per_level),
    "upconv", "decoder")

# final layer goes back to data shape
decoder_final = tfkl.Conv2D(inp.shape[-1], 1)(decoder_output)

# wrap the network in the Diffusion model together with the noise schedule
score_model = Diffusion([inp, t_input], decoder_final, alphas, alphas_bar, betas, tmax)
score_model.summary()

In [None]:
# use fewer steps if it takes too long
train_steps = 20000
n_data = 60000
# convert the step budget into a whole number of epochs
steps_per_epoch = n_data // batch_size
n_epochs = train_steps // steps_per_epoch
# cosine learning-rate decay over the full step budget
lr = tf.optimizers.schedules.CosineDecay(initial_learning_rate=0.005, decay_steps=train_steps)

# NOTE  if you use significantly fewer steps, consider reducing the ema_momentum
# or turn it off completely (use_ema=False)
optimizer = tf.optimizers.Adam(learning_rate=lr, use_ema=True, ema_momentum=0.99)

score_model.compile(optimizer=optimizer, jit_compile=True)


# a note on training, the loss is quite noisy and unreliable.
# it's also hard to interpret.
# for this setup you can expect values around 20 or so.
# that should lead to decent samples.
# for reference, one epoch takes about 10 seconds on my hardware.

In [None]:
#%%%%%%%%%% TO UPDATE
# as the sampler is called here, you will want to create & provide some class values for conditioning
class ImageGenCallback(tf.keras.callbacks.Callback):
    """Samples 100 images from the model and plots them as a 10x10 grid.

    Runs at the start of every epoch whose index is a multiple of `frequency`.
    """
    def __init__(self, frequency, **kwargs):
        super().__init__(**kwargs)
        self.frequency = frequency
        
    def on_epoch_begin(self, epoch, logs=None):
        # nothing to do on epochs that are not a multiple of the frequency
        if epoch % self.frequency != 0:
            return

        samples = self.model.langevin_sampler(n_samples=100)
        generated_batch = descale_data(samples)

        plt.figure(figsize=(15,15))
        for slot, image in enumerate(generated_batch, start=1):
            plt.subplot(10, 10, slot)
            plt.imshow(image, cmap="Greys", vmin=0, vmax=1)
            plt.axis("off")
        plt.suptitle("Random generations")
        plt.show()


# train; sample images are plotted every 20 epochs.
# `callbacks` is documented as a list of callbacks, so pass one.
score_model.fit(train_data, validation_data=test_data, epochs=n_epochs,
                callbacks=[ImageGenCallback(20)])

In [None]:
# draw 100 samples from the model and show them in a 10x10 grid
raw_samples = score_model.langevin_sampler(n_samples=100)
generated_batch = descale_data(raw_samples)

plt.figure(figsize=(15,15))
for slot, picture in enumerate(generated_batch, start=1):
    plt.subplot(10, 10, slot)
    plt.imshow(picture, cmap="Greys", vmin=0, vmax=1)
    plt.axis("off")
plt.suptitle("Random generations")
plt.show()

In [None]:
# make sure the target directory exists, so saving does not fail on a fresh checkout
# (os is imported here because the earlier `import os` cell may be skipped, e.g. on Colab)
import os
os.makedirs("weights", exist_ok=True)
score_model.save_weights("weights/weights_diffusion_mnist.hdf5")