In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

from utils import display, sample_batch
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models, callbacks, utils, metrics, optimizers

In [None]:
# Functions for diffusion process
def linear_diffusion_schedule(diffusion_times):
    """Noise and signal rates for a linear per-step variance schedule.

    Every entry of ``diffusion_times`` is mapped linearly onto a per-step
    variance between 0.0001 and 0.02. The running product of ``1 - variance``
    (taken along the first axis) is the share of signal variance retained so
    far; the two rates are the square roots of the lost and retained shares.

    Args:
        diffusion_times: Tensor of diffusion times. The only caller in this
            notebook passes ``tf.linspace(0.0, 1.0, T)``.

    Returns:
        Tuple ``(noise_rates, signal_rates)`` with the same shape as the input;
        their squares sum to 1 elementwise.
    """
    beta_start, beta_end = 0.0001, 0.02
    step_variances = beta_start + diffusion_times * (beta_end - beta_start)
    retained = tf.math.cumprod(1 - step_variances)
    return tf.sqrt(1 - retained), tf.sqrt(retained)

def cosine_diffusion_schedule(diffusion_times):
    """Noise and signal rates for a cosine diffusion schedule.

    Diffusion times are mapped onto angles between 0 and pi/2, so a time of 0
    gives pure signal (signal rate 1, noise rate 0) and a time of 1 gives pure
    noise (signal rate 0, noise rate 1).

    Args:
        diffusion_times: Tensor of diffusion times.

    Returns:
        Tuple ``(noise_rates, signal_rates)`` with the same shape as the input;
        their squares sum to 1 elementwise.
    """
    # np.pi is used because this notebook imports numpy but not the math
    # module; referencing math.pi here raises NameError at call time.
    diffusion_angles = diffusion_times * np.pi / 2
    signal_rates = tf.cos(diffusion_angles)
    noise_rates = tf.sin(diffusion_angles)
    return noise_rates, signal_rates

def offset_cosine_diffusion_schedule(diffusion_times):
    """Cosine schedule whose signal rate is confined to the range 0.95 .. 0.02.

    The angle is interpolated linearly between ``acos(0.95)`` (at time 0) and
    ``acos(0.02)`` (at time 1), so the signal rate never reaches exactly 1 or 0.

    Args:
        diffusion_times: Tensor of diffusion times.

    Returns:
        Tuple ``(noise_rates, signal_rates)`` with the same shape as the input;
        their squares sum to 1 elementwise.
    """
    signal_at_start, signal_at_end = 0.95, 0.02
    first_angle = tf.acos(signal_at_start)
    last_angle = tf.acos(signal_at_end)

    angles = first_angle + diffusion_times * (last_angle - first_angle)

    return tf.sin(angles), tf.cos(angles)

#### 0. Parameters <a name="parameters"></a>

In [None]:
# --- Image geometry -------------------------------------------------------
# The critic below applies four stride-2 convolutions followed by a 4x4
# "valid" convolution, and the generator upsamples 1 -> 4 -> 8 -> 16 -> 32 -> 64,
# so both networks only line up for 64x64 inputs.
IMAGE_SIZE = 64
CHANNELS = 3

# --- Training hyperparameters ---------------------------------------------
# NOTE(review): the values in this group were blank placeholders (which made
# the cell a SyntaxError). The numbers below are reasonable starting points,
# not tuned results -- adjust them for your dataset and hardware.
BATCH_SIZE = 64
NUM_FEATURES = 64  # not referenced elsewhere in this notebook
Z_DIM = 128  # length of the generator's latent input vector
LEARNING_RATE = 0.0002
EPOCHS = 200

# --- WGAN-GP settings -----------------------------------------------------
CRITIC_STEPS = 3  # critic updates per generator update
GP_WEIGHT = 10.0  # multiplier applied to the gradient penalty in the critic loss
LOAD_MODEL = False  # set True to restore weights from ./checkpoint before training

# Adam moment decay rates. These were assigned twice in this cell (0.5/0.999,
# then 0.5/0.9); only the last assignment ever took effect, so it is kept.
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.9

In [None]:
# The diffusion process function
def apply_diffusion_process(batch_size, T=1000, latent_dim=Z_DIM):
    """Sample latent vectors and add T rounds of schedule-scaled Gaussian noise.

    Starts from standard-normal latent vectors and, for each of the ``T``
    steps, adds fresh Gaussian noise multiplied by a noise rate taken from
    ``linear_diffusion_schedule``.

    Args:
        batch_size: Despite the name, a tensor whose leading dimension gives
            the number of latent vectors to draw (``train_step`` passes the
            batch of real images). The parameter name is kept so existing
            callers keep working.
        T: Number of diffusion steps, which is also the length of the noise
            schedule.
        latent_dim: Size of each latent vector. The default is bound to the
            value ``Z_DIM`` has when this function is defined.

    Returns:
        Float tensor of shape ``(num_samples, latent_dim)`` where
        ``num_samples`` is the leading dimension of ``batch_size``.
    """
    diffusion_times = tf.linspace(0.0, 1.0, T)
    # Only the noise rates are needed here; the signal rates are discarded.
    noise_rates, _ = linear_diffusion_schedule(diffusion_times)

    num_samples = tf.shape(batch_size)[0]

    # Steps beyond the first min(num_samples, T) schedule entries reuse the
    # last of those entries.
    # NOTE(review): tying the schedule length to the batch size looks
    # unintended -- confirm. The indexing is kept exactly as before.
    last_rate_index = tf.minimum(num_samples, T) - 1

    # One latent vector per sample. The count used to be clamped to T as well,
    # which returned fewer vectors than real images whenever the batch was
    # larger than T.
    latent_vectors = tf.random.normal(shape=(num_samples, latent_dim))

    for step in range(T):
        noise_rate = noise_rates[tf.minimum(step, last_rate_index)]
        noise = tf.random.normal(shape=tf.shape(latent_vectors))
        latent_vectors = latent_vectors + noise_rate * noise

    return latent_vectors

#### 1. Data preparation <a name="prepare"></a>

In [None]:
# Read the images in the folder as unlabeled RGB batches, resized to
# IMAGE_SIZE x IMAGE_SIZE.
# NOTE(review): the directory is an absolute, machine-specific path.
train_data = utils.image_dataset_from_directory(
    "/home/WorkingDir/WBC/01_Band_neutrophil",
    labels=None,
    color_mode="rgb",
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    interpolation="bilinear",
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=42,
)

# Repeat indefinitely; the number of batches per epoch is then set by
# steps_per_epoch in the fit() call.
train_data = train_data.repeat()

In [None]:
# Preprocess the data
def preprocess(img):
    """
    Cast the images to float32 and rescale them so that a pixel value of 0
    maps to -1 and a pixel value of 255 maps to 1.
    """
    half_range = 127.5
    pixels = tf.cast(img, "float32")
    return (pixels - half_range) / half_range


# Apply the normalization to every batch of the training dataset.
train = train_data.map(preprocess)

In [None]:
# Take a sample from the preprocessed training set for a visual sanity check.
# sample_batch comes from the local utils module -- presumably it returns one
# batch of images; verify against utils.
train_sample = sample_batch(train)

# Uncomment to render the sampled images (display is also from utils).
# display(train_sample, cmap=None)

#### 2. Building the Diffusion-based WGAN-GP <a name="build"></a>

In [None]:
# Critic: four stride-2 convolutions each halve the spatial size, then a 4x4
# "valid" convolution collapses the remaining feature map (4x4 when IMAGE_SIZE
# is 64) into a single unbounded score per image.
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))

# First block: no dropout.
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(critic_input)
x = layers.LeakyReLU(0.2)(x)

# Second block.
# NOTE(review): this LeakyReLU uses the layer's default slope rather than the
# explicit 0.2 used by every other block -- confirm this is intentional.
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU()(x)
x = layers.Dropout(0.3)(x)

# Third and fourth blocks share the same layout and differ only in width.
for filters in (256, 512):
    x = layers.Conv2D(filters, kernel_size=4, strides=2, padding="same")(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dropout(0.3)(x)

# No activation on the output: the critic produces a raw score.
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)

critic = models.Model(inputs=critic_input, outputs=critic_output)
critic.summary()

In [None]:
# Generator: the latent vector is viewed as a 1x1 feature map and upsampled
# 1 -> 4 -> 8 -> 16 -> 32 -> 64 by transposed convolutions.
generator_input = layers.Input(shape=(Z_DIM,))
x = layers.Reshape((1, 1, Z_DIM))(generator_input)

# Stem: a 4x4 "valid" transposed convolution turns the 1x1 map into 4x4.
x = layers.Conv2DTranspose(
    512, kernel_size=4, strides=1, padding="valid", use_bias=False
)(x)
x = layers.BatchNormalization(momentum=0.9)(x)
x = layers.LeakyReLU(0.2)(x)

# Three identical upsampling stages, each doubling the spatial size while
# halving the channel count.
for filters in (256, 128, 64):
    x = layers.Conv2DTranspose(
        filters, kernel_size=4, strides=2, padding="same", use_bias=False
    )(x)
    x = layers.BatchNormalization(momentum=0.9)(x)
    x = layers.LeakyReLU(0.2)(x)

# Output stage: tanh keeps pixel values in [-1, 1], the range produced by
# preprocess() for the real images.
generator_output = layers.Conv2DTranspose(
    CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)

generator = models.Model(inputs=generator_input, outputs=generator_output)
generator.summary()

In [None]:
# Defining a callback for plotting and saving losses
class LossPlotter(callbacks.Callback):
    """Keras callback that records the GAN losses per epoch and plots them.

    At the end of every epoch the four logged losses (``c_loss``,
    ``c_wass_loss``, ``c_gp``, ``g_loss``) are appended to per-loss lists.
    Every ``num_steps`` epochs (starting with epoch 0) the accumulated curves
    are plotted, saved as ``gan_losses_epoch_<epoch>.pdf`` and shown.
    """

    def __init__(self, gan, num_steps):
        """
        Args:
            gan: The GAN model being trained. Stored on the instance but not
                otherwise used by this callback.
            num_steps: Plot every ``num_steps`` epochs.
        """
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        self.gan = gan
        self.num_steps = num_steps
        self.c_loss_list = []
        self.c_wass_loss_list = []
        self.c_gp_list = []
        self.g_loss_list = []

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's losses and plot when the epoch is due."""
        self.c_loss_list.append(logs["c_loss"])
        self.c_wass_loss_list.append(logs["c_wass_loss"])
        self.c_gp_list.append(logs["c_gp"])
        self.g_loss_list.append(logs["g_loss"])

        if epoch % self.num_steps == 0:
            self.plot_losses(epoch)

    def plot_losses(self, epoch):
        """Plot all recorded loss curves and save the figure as a PDF.

        Args:
            epoch: Epoch index, used only in the output file name.
        """
        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        # The legend names each series after the quantity that is logged:
        # c_loss is the Wasserstein term plus the weighted gradient penalty,
        # c_wass_loss the Wasserstein term alone, c_gp the penalty alone.
        ax.plot(self.c_loss_list, color='black', linewidth=0.25, label='Critic loss (Wasserstein + weighted GP)')
        ax.plot(self.c_wass_loss_list, color='green', linewidth=0.25, label='Critic Wasserstein loss')
        ax.plot(self.c_gp_list, color='blue', linewidth=0.25, label='Critic gradient penalty')
        ax.plot(self.g_loss_list, color='orange', linewidth=0.25, label='Generator_loss')

        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')

        # The x axis spans the whole planned run (global EPOCHS) so successive
        # plots are directly comparable.
        ax.set_xlim(0, EPOCHS)
        ax.legend(loc="upper right")

        fig.savefig(f'gan_losses_epoch_{epoch}.pdf', format='pdf')
        plt.show()
        # This method runs repeatedly during training; release the figure so
        # open figures do not pile up.
        plt.close(fig)

In [None]:
class DWGANGP(models.Model):
    """WGAN-GP whose critic is trained against fakes made from diffused latents.

    Wraps a critic and a generator and overrides ``train_step`` so that each
    batch passed through ``fit()`` performs ``critic_steps`` critic updates
    followed by a single generator update.
    """

    def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):
        """
        Args:
            critic: Model mapping a batch of images to one score per image.
            generator: Model mapping latent vectors to images.
            latent_dim: Length of the latent vectors fed to the generator.
            critic_steps: Number of critic updates per generator update.
            gp_weight: Multiplier applied to the gradient penalty in the
                critic loss.
        """
        super().__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.critic_steps = critic_steps
        self.gp_weight = gp_weight

    def compile(self, c_optimizer, g_optimizer):
        """Attach one optimizer per network and create the loss trackers.

        Args:
            c_optimizer: Optimizer used for the critic's variables.
            g_optimizer: Optimizer used for the generator's variables.
        """
        super().compile()
        self.c_optimizer = c_optimizer
        self.g_optimizer = g_optimizer
        self.c_wass_loss_metric = metrics.Mean(name="c_wass_loss")
        self.c_gp_metric = metrics.Mean(name="c_gp")
        self.c_loss_metric = metrics.Mean(name="c_loss")
        self.g_loss_metric = metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        """The four loss trackers reported by ``train_step``."""
        return [
            self.c_loss_metric,
            self.c_wass_loss_metric,
            self.c_gp_metric,
            self.g_loss_metric,
        ]

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Penalty for critic gradient norms that deviate from 1.

        The critic is evaluated on random points between each real image and
        its paired fake image, and the squared distance of the gradient norm
        from 1 is averaged over the batch.

        Args:
            batch_size: Number of images in the batch (scalar).
            real_images: Batch of real images.
            fake_images: Batch of generated images, same shape as
                ``real_images``.

        Returns:
            Scalar tensor holding the mean penalty.
        """
        # The interpolation weight must lie in [0, 1] so that every sample sits
        # on the straight line between a real and a fake image, as specified by
        # WGAN-GP. A normal draw also produces weights outside that interval,
        # i.e. points beyond either end of the segment.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            # interpolated is not a variable, so it has to be watched explicitly.
            gp_tape.watch(interpolated)
            pred = self.critic(interpolated, training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        # Per-image L2 norm over height, width and channels.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_images):
        """Run ``critic_steps`` critic updates, then one generator update.

        Args:
            real_images: One batch of real training images.

        Returns:
            Dict mapping each metric name (``c_loss``, ``c_wass_loss``,
            ``c_gp``, ``g_loss``) to its current running mean.
        """
        batch_size = tf.shape(real_images)[0]

        # --- Critic updates ---
        for _ in range(self.critic_steps):
            # Use this model's own latent size instead of depending on the
            # helper's module-level default.
            noised_vectors = apply_diffusion_process(
                real_images, latent_dim=self.latent_dim
            )

            with tf.GradientTape() as tape:
                fake_images = self.generator(
                    noised_vectors, training=True
                )
                fake_predictions = self.critic(fake_images, training=True)
                real_predictions = self.critic(real_images, training=True)

                # Wasserstein critic loss: mean fake score minus mean real score.
                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(
                    real_predictions
                )
                c_gp = self.gradient_penalty(
                    batch_size, real_images, fake_images
                )
                c_loss = c_wass_loss + c_gp * self.gp_weight

            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            self.c_optimizer.apply_gradients(
                zip(c_gradient, self.critic.trainable_variables)
            )

        # --- Generator update ---
        # NOTE(review): the generator is trained on plain standard-normal
        # latents while the critic above sees fakes made from diffused latents
        # -- confirm this asymmetry is intended.
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            fake_predictions = self.critic(fake_images, training=True)
            # The generator tries to maximize the critic's score on its fakes.
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )

        # Critic metrics reflect the last of the critic updates above.
        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)

        return {m.name: m.result() for m in self.metrics}

In [None]:
# Assemble the GAN from the critic and generator built above, using the
# hyperparameters from the parameters cell.
dwgangp = DWGANGP(
    critic=critic,
    generator=generator,
    latent_dim=Z_DIM,
    critic_steps=CRITIC_STEPS,
    gp_weight=GP_WEIGHT,
)

In [None]:
# Optionally resume from saved weights. The path is the same one the
# ModelCheckpoint callback writes to during training.
if LOAD_MODEL:
    dwgangp.load_weights("./checkpoint/checkpoint.ckpt")

#### 3. Training the Model <a name="train"></a>

In [None]:
# Compile the model with two separate Adam optimizers (one for the critic, one
# for the generator). Both use the same learning rate and beta values but keep
# independent optimizer state.
dwgangp.compile(
    c_optimizer=optimizers.Adam(
        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2
    ),
    g_optimizer=optimizers.Adam(
        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2
    ),
)

In [None]:
# Save the model weights (not the full model) at the end of every epoch. The
# same file path is reused each time and is the one read back when LOAD_MODEL
# is True.
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath="./checkpoint/checkpoint.ckpt",
    save_weights_only=True,
    save_freq="epoch",
    verbose=0,
)

# Write training logs for TensorBoard under ./logs.
tensorboard_callback = callbacks.TensorBoard(log_dir="./logs")


class ImageGenerator(callbacks.Callback):
    """Keras callback that shows generated images near the end of training.

    Images are generated only in the final two epochs (relative to the global
    ``EPOCHS``):

    * second-to-last epoch: the images are kept in ``last_epoch_images``;
    * last epoch: a fresh set of images is displayed immediately.

    When training ends, the kept images are displayed again and saved to
    ``./output/generated_img_last_epoch.pdf``.

    NOTE(review): the images written to that file therefore come from the
    second-to-last epoch, not the final one, and nothing is saved when
    ``EPOCHS`` is 1 -- confirm this is the intended behaviour.
    """

    def __init__(self, num_img, latent_dim):
        """
        Args:
            num_img: Number of images to generate each time.
            latent_dim: Length of the latent vectors fed to the generator.
        """
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.last_epoch_images = None

    def on_epoch_end(self, epoch, logs=None):
        """Generate images in the final two epochs; do nothing before that."""
        if epoch < (EPOCHS - 2):
            self.last_epoch_images = None
            return

        random_latent_vectors = tf.random.normal(
            shape=(self.num_img, self.latent_dim)
        )
        generated_images = self.model.generator(random_latent_vectors)
        # Map the generator's tanh output from [-1, 1] back to [0, 255].
        generated_images = generated_images * 127.5 + 127.5
        generated_images = generated_images.numpy()

        if epoch == (EPOCHS - 1):
            display(
                generated_images,
                cmap=None,
            )
        elif epoch == (EPOCHS - 2):
            self.last_epoch_images = generated_images

    def on_train_end(self, logs=None):
        """Display and save the kept images, if any were stored."""
        if self.last_epoch_images is not None:
            display(
                self.last_epoch_images,
                save_to="./output/generated_img_last_epoch.pdf",
                cmap=None,
            )

In [None]:
# Train the GAN. The training dataset repeats indefinitely, so steps_per_epoch
# decides how many batches make up one epoch.
# NOTE(review): steps_per_epoch=2 means only two batches per epoch, which looks
# like a debugging value -- confirm before a real training run.
# NOTE(review): the LossPlotter callback defined above is not registered here.
dwgangp.fit(
    train,
    epochs=EPOCHS,
    steps_per_epoch=2,
    callbacks=[
        model_checkpoint_callback,
        tensorboard_callback,
        ImageGenerator(num_img=25, latent_dim=Z_DIM),
    ],
)

In [None]:
# Plot the per-epoch losses recorded during fit() and save the figure as a PDF.
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
# The legend names each series after the quantity that is logged: c_loss is the
# Wasserstein term plus the weighted gradient penalty, c_wass_loss the
# Wasserstein term alone, and c_gp the gradient penalty alone.
ax.plot(dwgangp.history.history['c_loss'], color='black', linewidth=0.25, label='Critic loss (Wasserstein + weighted GP)')
ax.plot(dwgangp.history.history['c_wass_loss'], color='green', linewidth=0.25, label='Critic Wasserstein loss')
ax.plot(dwgangp.history.history['c_gp'], color='blue', linewidth=0.25, label='Critic gradient penalty')
ax.plot(dwgangp.history.history['g_loss'], color='orange', linewidth=0.25, label='Generator_loss')

ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')

ax.set_xlim(0, EPOCHS)
ax.legend(loc="upper right")

fig.savefig('dwgangp_losses_final.pdf', format='pdf')
plt.show()

In [None]:
# Persist the two trained networks separately as full models (the checkpoint
# callback above only stores weights).
generator.save("./models/generator")
critic.save("./models/critic")

#### Generate images

In [None]:
# Draw 25 standard-normal latent vectors, run them through the trained
# generator and show the result. No random seed is set, so every run produces
# different images.
z_sample = np.random.normal(size=(25, Z_DIM))
imgs = dwgangp.generator.predict(z_sample)
# display comes from the local utils module.
display(imgs, cmap=None)