# In [None]:
import keras
from keras import ops
from keras import layers
# Disabled preprocessing, kept as a bare string literal (evaluated and
# discarded, so none of these statements run). If enabled, the first two would
# append a trailing axis and cast to float32, and the third would cast the
# labels to int.
# NOTE(review): presumably this conversion happens elsewhere (e.g. an earlier
# notebook cell) before the arrays are used below -- confirm.
'''
hit_fingerprints = hit_fingerprints[..., None].astype("float32")
x_train = x_train[..., None].astype("float32")
y_train = y_train.astype("int")
'''

class Sampling(layers.Layer):
    """Reparameterization layer for the latent code of a molecular fingerprint.

    Given ``(z_mean, z_log_var)`` it returns ``z_mean + sigma * epsilon`` where
    ``sigma = exp(0.5 * z_log_var)`` and ``epsilon`` is standard-normal noise.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Stateful seed source, initialised with a fixed value, that feeds the
        # noise draw in `call`.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        """Draw one latent sample per row of the batch.

        Args:
            inputs: Pair ``(z_mean, z_log_var)``; the noise is shaped from the
                first two dimensions of ``z_mean``.

        Returns:
            Tensor ``z_mean + exp(0.5 * z_log_var) * epsilon``.
        """
        mean, log_variance = inputs
        n_rows = ops.shape(mean)[0]
        n_cols = ops.shape(mean)[1]
        noise = keras.random.normal(
            shape=(n_rows, n_cols), seed=self.seed_generator
        )
        # exp(0.5 * log(var)) is the standard deviation.
        std_dev = ops.exp(0.5 * log_variance)
        return mean + std_dev * noise


# Width of the latent vector z; the decoder's input layer below uses it too.
latent_dim = 10

# Encoder: fingerprint of 2048 positions x 1 channel -> (z_mean, z_log_var, z).
encoder_inputs = keras.Input(shape=(2048, 1))

# Two convolution + pooling stages. No padding argument is given, so with
# Keras' default "valid" padding the length goes 2048 -> 2046 -> 1023 -> 1021
# -> 510 while the channel count goes 1 -> 32 -> 64.
x = layers.Conv1D(32, 3, activation="relu")(encoder_inputs)
x = layers.MaxPool1D(2)(x)
x = layers.Conv1D(64, 3, activation="relu")(x)
x = layers.MaxPool1D(2)(x)

# Flatten to a vector and compress through a dense bottleneck with dropout.
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Dropout(0.3)(x)

# Two parallel linear heads parameterize the latent Gaussian; Sampling draws z.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])

encoder = keras.Model(
    inputs=encoder_inputs,
    outputs=[z_mean, z_log_var, z],
    name="encoder",
)
encoder.summary()


# Decoder: latent vector of size latent_dim -> reconstruction of shape (2048, 1).
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(128, activation="relu")(latent_inputs)

# Expand to 512 positions x 64 channels (512 == 2048 // 2**2, i.e. the target
# length divided by the two 2x upsampling steps that follow).
x = layers.Dense(512 * 64, activation="relu")(x)
x = layers.Reshape((512, 64))(x)

# "same" padding keeps the length through each transposed convolution, so only
# the upsampling layers change it: 512 -> 1024 -> 2048.
x = layers.Conv1DTranspose(64, 3, padding="same", activation="relu")(x)
x = layers.UpSampling1D(2)(x)
x = layers.Conv1DTranspose(32, 3, padding="same", activation="relu")(x)
x = layers.UpSampling1D(2)(x)

# Single-channel sigmoid output: one value in (0, 1) per fingerprint position.
decoder_outputs = layers.Conv1DTranspose(
    1, 3, padding="same", activation="sigmoid"
)(x)
decoder = keras.Model(inputs=latent_inputs, outputs=decoder_outputs, name="decoder")
decoder.summary()

class VAE(keras.Model):
    """Variational autoencoder that trains an encoder/decoder pair jointly.

    Args:
        encoder: Model returning ``(z_mean, z_log_var, z)`` for a batch.
        decoder: Model mapping ``z`` to a reconstruction shaped like the
            encoder input.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        # Running means of each loss term, reported by `train_step`.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Return the loss trackers so Keras can manage (reset) them."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one gradient step on a batch and return the running losses.

        Args:
            data: The batch from ``fit()``: either the input tensor itself or
                a tuple/list such as ``(x, y)``, of which only the first
                element is used. A rank-2 batch gets a trailing channel axis.

        Returns:
            Dict with keys ``"loss"``, ``"reconstruction_loss"`` and
            ``"kl_loss"`` holding the running mean of each term.
        """
        # fit(x, y) or a dataset of tuples delivers (x, ...); the VAE is
        # unsupervised, so keep only the inputs instead of failing on
        # `tuple.shape`.
        if isinstance(data, (tuple, list)):
            data = data[0]
        if len(data.shape) == 2:
            # (batch, positions) -> (batch, positions, 1) to match the
            # decoder output used in the loss below.
            data = ops.expand_dims(data, axis=-1)
        # NOTE(review): `tf` is not imported in the visible part of this file;
        # presumably `import tensorflow as tf` happens elsewhere -- confirm.
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy averages over the last (channel) axis,
            # leaving one value per position: shape (batch, positions).
            per_position_bce = keras.losses.binary_crossentropy(
                data, reconstruction
            )
            # Sum over positions per sample, then average over the batch, so
            # this term is on the same per-sample scale as the KL term (which
            # is summed over latent dimensions). Averaging over positions
            # instead would shrink it by the number of positions and let the
            # KL term dominate.
            reconstruction_loss = ops.mean(ops.sum(per_position_bce, axis=1))
            kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
            kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

# Wire the two sub-models together and train on the fingerprint array only
# (no labels are passed). Adam is used with its default settings.
# NOTE(review): `hit_fingerprints` is not defined in the visible code;
# presumably it is prepared elsewhere -- confirm.
vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(x=hit_fingerprints, batch_size=128, epochs=30)


def plot_label_clusters(vae, data, labels):
    """Scatter-plot the first two latent-mean coordinates, colored by label.

    Args:
        vae: Object whose ``encoder.predict`` returns three arrays, the first
            being the latent means.
        data: Inputs forwarded to ``vae.encoder.predict``.
        labels: Per-sample values used as the scatter colors.

    Only ``z[0]`` and ``z[1]`` are drawn; any further latent dimensions are
    not shown.
    """
    # NOTE(review): `plt` is not imported in the visible part of this file;
    # presumably matplotlib.pyplot is imported elsewhere -- confirm.
    latent_means, _, _ = vae.encoder.predict(data, verbose=0)
    horizontal = latent_means[:, 0]
    vertical = latent_means[:, 1]

    plt.figure(figsize=(12, 10))
    plt.scatter(horizontal, vertical, c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()


# Visualize where the labelled samples land in the trained latent space.
# NOTE(review): `x_train` / `y_train` are not assigned in the visible code
# (they only appear inside the disabled string near the top) -- presumably
# they are defined elsewhere; confirm.
plot_label_clusters(vae=vae, data=x_train, labels=y_train)