In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers

2024-10-14 11:36:45.844143: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-14 11:36:45.858480: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-14 11:36:45.870782: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-14 11:36:45.874411: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-14 11:36:45.884966: I tensorflow/core/platform/cpu_feature_guar

## Model Set Up

Cyclical Annealing

In [2]:
def cyclical_annealing(epoch, n_cycles, ratio):
    """Return the KL weight for ``epoch`` under a cyclical linear schedule.

    Within each cycle the weight ramps linearly from 0 up to 1 over the first
    ``ratio`` fraction of the cycle and is then held at 1 until the cycle ends.

    Args:
        epoch: Zero-based epoch index.
        n_cycles: Length of one cycle, in epochs. Despite the name, the value
            is used as the period of the schedule, not as a number of cycles.
        ratio: Fraction of each cycle spent ramping up. A value <= 0 means
            "no warm-up" and yields a constant weight of 1.

    Returns:
        The annealing weight as a NumPy float, capped at 1.0.
    """
    if ratio <= 0:
        # Without this guard the first epoch of every cycle computes 0 / 0,
        # which propagates NaN into the loss.
        return np.float64(1.0)
    completed_cycles = np.floor(epoch / n_cycles)
    cycle_position = epoch - completed_cycles * n_cycles
    return np.minimum(1.0, cycle_position / (n_cycles * ratio))


class CyclicalAnnealingCallback(tf.keras.callbacks.Callback):
    """Keras callback that refreshes the model's KL weight before each epoch.

    The attached model must expose a ``kl_annealing`` attribute that supports
    ``assign`` (e.g. a ``tf.Variable``).
    """

    def __init__(self, total_epochs, n_cycles, ratio, **kwargs):
        """Store the schedule settings.

        Args:
            total_epochs: Stored on the instance; not read by this class.
            n_cycles: Forwarded to ``cyclical_annealing`` as ``n_cycles``.
            ratio: Forwarded to ``cyclical_annealing`` as ``ratio``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.total_epochs, self.n_cycles, self.ratio = total_epochs, n_cycles, ratio

    def on_epoch_begin(self, epoch, logs=None):
        """Compute this epoch's weight and write it into ``model.kl_annealing``."""
        kl_weight = cyclical_annealing(epoch, n_cycles=self.n_cycles, ratio=self.ratio)
        self.model.kl_annealing.assign(kl_weight)

Sampling layer: Implements the reparametrization trick for sampling from a normal distribution.

In [3]:
# Registered under its own name: several classes in this notebook were all
# registered as 'custom>VAE', so each registration overwrote the previous one.
@tf.keras.utils.register_keras_serializable(package='custom', name='Sampling')
class Sampling(layers.Layer):
    """Reparameterisation trick: draws z = z_mean + exp(0.5 * z_log_var) * eps.

    ``eps`` is standard-normal noise with the same (batch, dim) shape as
    ``z_mean``.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        """Sample a latent vector.

        Args:
            inputs: A pair ``(z_mean, z_log_var)`` of 2-D tensors.

        Returns:
            Tensor ``z`` with the same shape as ``z_mean``.
        """
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = keras.backend.random_normal(shape=(batch, dim))
        # exp(0.5 * log_var) is the standard deviation.
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

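Linear block: a reusable dense (ReLU) + batch-normalisation block that keeps the encoder and decoder definitions short. (Dropout is not currently implemented in this block.)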

In [4]:
# Registered under its own name: several classes in this notebook were all
# registered as 'custom>VAE', so each registration overwrote the previous one.
@tf.keras.utils.register_keras_serializable(package='custom', name='LinearBlock')
class LinearBlock(keras.Model):
    """Dense layer (ReLU, He-normal init) followed by batch normalisation."""

    def __init__(self, units, **kwargs):
        """Create the block.

        Args:
            units: Output width of the dense layer.
            **kwargs: Forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.supports_masking = True
        self.units = units
        self.dense = layers.Dense(units, activation="relu", kernel_initializer='he_normal')
        self.batch_norm = layers.BatchNormalization()

    def build(self, input_shape):
        """Create the weights by running one dummy forward pass."""
        super().build(input_shape)
        # A shape may carry unknown (None) dimensions, e.g. the batch axis;
        # random input cannot be drawn for those, so substitute size 1.
        dummy_shape = [1 if dim is None else dim for dim in input_shape]
        self.call(tf.random.normal(dummy_shape))
        self.built = True
    
    def get_config(self):
        """Return the base config extended with ``units``."""
        config = super().get_config()
        config.update({
            "units": self.units,
        })
        return config

    def call(self, inputs, training=False):
        """Apply dense -> batch norm; ``training`` is passed to batch norm."""
        x = self.dense(inputs)
        x = self.batch_norm(x, training=training)
        return x

### Encoder

In [5]:
# Registered under its own name: several classes in this notebook were all
# registered as 'custom>VAE', so each registration overwrote the previous one.
@tf.keras.utils.register_keras_serializable(package='custom', name='Encoder')
class Encoder(keras.Model):
    """Maps input to a triplet (z_mean, z_log_var, z)."""
    
    def __init__(self, latent_dim=32, masking_ratio=0.5, masking_value=-1, **kwargs):
        """Create the encoder stack.

        Args:
            latent_dim: Size of the latent vectors.
            masking_ratio: Default masking probability used by ``random_masking``.
            masking_value: Sentinel passed to the ``Masking`` layer and used as
                the default fill value in ``random_masking``.
            **kwargs: Forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.masking_ratio = masking_ratio
        self.masking_value = masking_value

        self.masking = layers.Masking(mask_value=self.masking_value)
        self.layer1 = LinearBlock(2048)
        self.layer2 = LinearBlock(1024)
        self.layer3 = LinearBlock(512)
        self.layer4 = LinearBlock(256)
        self.layer5 = LinearBlock(256)

        self.dense_mean = layers.Dense(self.latent_dim) 
        self.dense_log_var = layers.Dense(self.latent_dim)
        self.sampling = Sampling()

    def random_masking(self, data, ratio=None, mask_val=None):
        """Randomly masks input data.
        
        Args:
            data: Input data to be masked.
            ratio: The ratio of masking. Defaults to instance's masking_ratio.
            mask_val: The value to mask. Defaults to instance's masking_value.
        
        Returns:
            Masked data, with the same shape and dtype as ``data``.
        """
        # Compare against None explicitly: `x or default` would silently
        # discard legitimate falsy arguments such as ratio=0.0 or mask_val=0.
        if ratio is None:
            ratio = self.masking_ratio
        if mask_val is None:
            mask_val = self.masking_value
        data = tf.convert_to_tensor(data)
        keep = tf.random.uniform(shape=tf.shape(data)) > ratio
        # Cast the fill value so tf.where sees matching dtypes (an int
        # sentinel such as -1 would otherwise clash with float data).
        fill = tf.fill(tf.shape(data), tf.cast(mask_val, data.dtype))
        return tf.where(keep, data, fill)

    def get_config(self):
        """Returns the configuration of the Encoder."""
        config = super().get_config()
        config.update({
            "latent_dim": self.latent_dim,
            "masking_ratio": self.masking_ratio,
            "masking_value": self.masking_value,
        })
        return config

    def build(self, input_shape):
        """Build the model by calling it with a random input."""
        super().build(input_shape)
        # Unknown (None) dimensions cannot be sampled; use size 1 instead.
        dummy_shape = [1 if dim is None else dim for dim in input_shape]
        self.call(tf.random.normal(dummy_shape))
        self.built = True


    def call(self, inputs, training=False):
        """Forward pass for the Encoder.
        
        Args:
            inputs: Input tensor.
            training: Whether the layer should behave in training mode or inference mode.
        
        Returns:
            A triplet (z_mean, z_log_var, z) representing latent variables.
        """
        # NOTE(review): on 2-D (batch, features) input, keras Masking appears
        # to act per row (only rows whose features ALL equal mask_value are
        # affected) -- confirm this is the intended handling of individually
        # masked genes.
        inputs = self.masking(inputs)
        x = self.layer1(inputs, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        x = self.layer5(x, training=training)

        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z

### Decoder

In [6]:
# Registered under its own name: several classes in this notebook were all
# registered as 'custom>VAE', so each registration overwrote the previous one.
@tf.keras.utils.register_keras_serializable(package='custom', name='Decoder')
class Decoder(keras.Model):
    """Maps a latent vector back to ``original_dim`` outputs in (0, 1)."""

    def __init__(self, latent_dim=32, original_dim=61124, **kwargs):
        """Create the decoder stack.

        Args:
            latent_dim: Size of the latent input; stored for ``get_config``.
            original_dim: Width of the sigmoid output layer.
            **kwargs: Forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.supports_masking = True
        self.latent_dim = latent_dim
        self.original_dim = original_dim

        self.layer1 = LinearBlock(256)
        self.layer2 = LinearBlock(256)
        self.layer3 = LinearBlock(512)
        self.layer4 = LinearBlock(1024)
        self.layer5 = LinearBlock(2048)

        self.dense_output = layers.Dense(original_dim, activation="sigmoid")  

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "latent_dim": self.latent_dim,
            "original_dim": self.original_dim,
        })
        return config

    def build(self, input_shape):
        """Create the weights by running one dummy forward pass."""
        super().build(input_shape)
        # Unknown (None) dimensions cannot be sampled; use size 1 instead.
        dummy_shape = [1 if dim is None else dim for dim in input_shape]
        self.call(keras.random.normal(dummy_shape))
        self.built = True

    def call(self, inputs, training=False):
        """Run the five linear blocks, then the sigmoid output layer."""
        x = self.layer1(inputs, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        x = self.layer5(x, training=training)
        return self.dense_output(x)

### VAE

In [7]:
@tf.keras.utils.register_keras_serializable(package='custom', name='VAE')
class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training.

    ``call`` returns ``(reconstructed, kl_loss)``. The custom ``train_step`` and
    ``test_step`` expect ``data`` to be an ``(x, y)`` pair where ``x`` is fed
    to the encoder and ``y`` is the reconstruction target.
    """
    def __init__(self,
            original_dim=61124, 
            latent_dim=32, 
            masking_value=-1.0, 
            loss = tf.keras.losses.MeanAbsoluteError(),
            **kwargs
    ):
        """Create the encoder, decoder and the KL-annealing variable.

        Args:
            original_dim: Width of the input / reconstruction.
            latent_dim: Size of the latent space.
            masking_value: Sentinel value forwarded to the encoder.
            loss: Callable ``loss(y_true, y_pred)`` used for reconstruction.
            **kwargs: Forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.original_dim, self.latent_dim = original_dim, latent_dim
        self.supports_masking, self.masking_value = True, masking_value 
        self.loss = loss

        # Pass by keyword: Encoder's second positional parameter is
        # masking_ratio, so a positional masking_value would land there.
        self.encoder = Encoder(latent_dim=self.latent_dim, masking_value=self.masking_value)
        self.decoder = Decoder(self.latent_dim, self.original_dim)

        # Weight on the KL term during training; updated externally
        # (e.g. by CyclicalAnnealingCallback).
        self.kl_annealing = tf.Variable(1.0, trainable=False)

    def get_config(self):
        """Return the base config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            "original_dim": self.original_dim,
            "latent_dim": self.latent_dim,
            "masking_value": self.masking_value,
            "loss": self.loss
        })
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Recreate the model, reviving a serialized ``loss`` entry if present."""
        config = dict(config)
        loss_config = config.get("loss")
        if isinstance(loss_config, dict):
            # A saved config carries the loss in serialized (dict) form.
            config["loss"] = keras.losses.deserialize(loss_config)
        return cls(**config)

    def build(self, input_shape=None):
        """Build encoder and decoder with batch-size-1 dummy shapes.

        Args:
            input_shape: Accepted so that a Keras-initiated build call does
                not fail; ignored, since shapes come from ``original_dim``
                and ``latent_dim``.
        """
        self.encoder.build((1, self.original_dim))
        self.decoder.build((1, self.latent_dim))
        self.call(keras.random.normal((1, self.original_dim)))
        self.built = True

    def call(self, inputs, training=False):
        """Encode, sample, decode.

        Returns:
            ``(reconstructed, kl_loss)`` where ``kl_loss`` is the KL term
            averaged over both the batch and the latent dimensions.
        """
        z_mean, z_log_var, z = self.encoder(inputs, training=training)
        reconstructed = self.decoder(z, training=training)
        # NOTE(review): reduce_mean averages over latent dims as well, while
        # the reconstruction term is scaled by original_dim -- confirm this
        # relative weighting is intended.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        return reconstructed, kl_loss        

    def train_step(self, data):
        """One optimisation step on ``reconstruction + kl_annealing * KL``."""
        x = data[0]
        y = data[1]
        with tf.GradientTape() as tape:
            reconstructed, kl_loss = self(inputs=x, training=True)
            reconstruction_loss = self.original_dim * self.loss(y, reconstructed)
            ELBO = reconstruction_loss + kl_loss * self.kl_annealing
        
        gradients = tape.gradient(ELBO, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Score against the target y (as test_step does), not the possibly
        # masked input x.
        for metric in self.metrics:
            metric.update_state(y, reconstructed)
        
        return {"ELBO": ELBO, 
                "kl_loss": kl_loss, 
                "reconstruction_loss": reconstruction_loss,
                "annealing": self.kl_annealing}        

    def test_step(self, data):
        """Evaluate one batch; the KL term is not annealed here."""
        x = data[0]
        y = data[1]
        reconstructed, kl_loss = self(inputs=x, training=False)
        reconstruction_loss = self.original_dim * self.loss(y, reconstructed)
        ELBO = reconstruction_loss + kl_loss
        
        for metric in self.metrics:
            metric.update_state(y, reconstructed)

        return {"ELBO": ELBO, "kl_loss": kl_loss, 
                "reconstruction_loss": reconstruction_loss,
                "annealing": self.kl_annealing}

### Helper Functions

In [8]:
def build_vae(
        original_dim=61124, 
        latent_dim=64, 
        masking_value=-1., 
        loss=tf.keras.losses.MeanAbsoluteError()
):
    """Construct a ``VariationalAutoEncoder`` and build it before returning.

    Args:
        original_dim: Width of the model input / reconstruction.
        latent_dim: Size of the latent space.
        masking_value: Sentinel value forwarded to the model.
        loss: Reconstruction loss callable.

    Returns:
        The built ``VariationalAutoEncoder`` instance.
    """
    vae = VariationalAutoEncoder(
        original_dim=original_dim,
        latent_dim=latent_dim,
        masking_value=masking_value,
        loss=loss,
    )
    vae.build()
    return vae

## Training

In [9]:
import os
# NOTE(review): TensorFlow was already imported in the first cell, so this
# setting likely comes too late to silence the C++ logs (this cell's output
# still shows INFO lines) -- consider moving it above `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Show which GPUs TensorFlow can see.
print(tf.config.list_physical_devices('GPU'))

# Project-local helper module; the training cells below call vae_helper.train_model.
import vae_helper

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


I0000 00:00:1728877007.964385  181944 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1728877007.990475  181944 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
I0000 00:00:1728877007.990592  181944 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.


### Without Masking

In [None]:
# No masking, 5000-gene "random" data files.
train = np.load('Data/train_top_5000_genes_random.npy')
val = np.load('Data/val_top_5000_genes_random.npy')
vae_helper.train_model(
    train,
    val,
    original_dim=5000,
    num_masked_genes=0,
    latent_dim=64,
    batch_size=128,
    num_epochs=200,
    kl_annealing_cycle_length=200,
    kl_annealing_ratio=0.8,
    by='random',
    mask_by=None,
)

In [None]:
# No masking, 5000-gene "kl" data files.
train = np.load('Data/train_top_5000_genes_kl.npy')
val = np.load('Data/val_top_5000_genes_kl.npy')
# `by` presumably labels the gene-selection method of the loaded files (the
# sibling cells pass 'random' / 'frequency' to match their data); it was
# 'random' here, apparently copy-pasted from the previous cell.
vae_helper.train_model(train, val, original_dim=5000, num_masked_genes=0, latent_dim=64, batch_size=128, num_epochs=200,
                    kl_annealing_cycle_length=200, kl_annealing_ratio=0.8, by='kl', mask_by=None)

In [None]:
# No masking, 5000-gene "frequency" data files.
train = np.load('Data/train_top_5000_genes_frequency.npy')
val = np.load('Data/val_top_5000_genes_frequency.npy')
vae_helper.train_model(
    train,
    val,
    original_dim=5000,
    num_masked_genes=0,
    latent_dim=64,
    batch_size=128,
    num_epochs=200,
    kl_annealing_cycle_length=200,
    kl_annealing_ratio=0.8,
    by='frequency',
    mask_by=None,
)

### With Masking

In [None]:
# Scaled data files; `train` and `val` are reused by the two cells below.
train = np.load('Data/train_scaled.npy')
val = np.load('Data/val_scaled.npy')
vae_helper.train_model(
    train,
    val,
    original_dim=60660,
    num_masked_genes=5000,
    latent_dim=64,
    batch_size=128,
    num_epochs=200,
    kl_annealing_cycle_length=200,
    kl_annealing_ratio=0.8,
    by='all',
    mask_by='random',
)

In [None]:
# Reuses `train` / `val` loaded in the previous cell; mask_by='kl'.
vae_helper.train_model(
    train,
    val,
    original_dim=60660,
    num_masked_genes=5000,
    latent_dim=64,
    batch_size=128,
    num_epochs=200,
    kl_annealing_cycle_length=200,
    kl_annealing_ratio=0.8,
    by='all',
    mask_by='kl',
)

In [None]:
# Reuses `train` / `val` loaded two cells above; mask_by='frequency'.
vae_helper.train_model(
    train,
    val,
    original_dim=60660,
    num_masked_genes=5000,
    latent_dim=64,
    batch_size=128,
    num_epochs=200,
    kl_annealing_cycle_length=200,
    kl_annealing_ratio=0.8,
    by='all',
    mask_by='frequency',
)