In [25]:
import tensorflow as tf
from keras import layers, Model
from scipy.special import digamma, gammaln
import numpy as np

# Define the Encoder network
class Encoder(Model):
    """Two-layer dense network mapping an input batch to alpha values.

    The hidden layer has 128 ReLU units; the output layer has
    ``latent_dim`` units with a softmax activation, so each output row is
    positive and sums to one.

    Args:
        latent_dim: Number of units in the output layer.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Hidden representation.
        self.fc1 = layers.Dense(128, activation='relu')
        # Output head; its activations are used as the alpha values.
        self.fc2 = layers.Dense(latent_dim, activation='softmax')

    def call(self, x):
        """Return the alpha values produced for the batch ``x``."""
        return self.fc2(self.fc1(x))

# Define the Decoder network
class Decoder(Model):
    """Two-layer dense network mapping a latent batch back to input space.

    The hidden layer has 128 ReLU units; the output layer has
    ``original_dim`` sigmoid units, so every output lies in (0, 1).

    Args:
        original_dim: Number of units in the output layer.
    """

    def __init__(self, original_dim):
        super().__init__()
        # Hidden representation.
        self.fc1 = layers.Dense(128, activation='relu')
        # Output head with one sigmoid unit per reconstructed feature.
        self.fc2 = layers.Dense(original_dim, activation='sigmoid')

    def call(self, z):
        """Return the reconstruction produced for the latent batch ``z``."""
        return self.fc2(self.fc1(z))

# Define the VAE model
class VAE(Model):
    """Variational autoencoder with a Dirichlet-style latent variable.

    The encoder produces per-dimension concentration values ``alpha``; a
    latent point on the simplex is drawn from them with a reparameterised
    sampler and decoded back to input space.

    Args:
        original_dim: Number of input features.
        latent_dim: Number of latent dimensions.
        prior_alpha: Concentration of the prior used in the KL term. A scalar
            or anything broadcastable against a ``(batch, latent_dim)``
            tensor. Defaults to ``1.0``.
    """

    # Lower bound used to keep logarithms and gamma functions finite.
    _EPS = 1e-7

    def __init__(self, original_dim, latent_dim, prior_alpha=1.0):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(original_dim)
        self.original_dim = original_dim
        self.latent_dim = latent_dim
        self.prior_alpha = prior_alpha

    def call(self, x):
        """Encode ``x``, sample a latent point and decode it.

        Sampling uses the inverse-CDF approximation of the Gamma
        distribution, ``v = (u * alpha * Gamma(alpha)) ** (1 / alpha)`` with
        ``u ~ Uniform(0, 1)``, followed by normalisation ``z = v / sum(v)``.
        This makes ``z`` a differentiable function of ``alpha``, so the
        reconstruction loss can train the encoder.

        Returns:
            A tuple ``(x_recon, alpha)``.
        """
        # Floor alpha so lgamma/digamma and the division below stay finite.
        alpha = tf.maximum(self.encoder(x), self._EPS)
        # u is kept strictly positive so that log(u) is finite.
        u = tf.random.uniform(
            shape=tf.shape(alpha),
            minval=self._EPS,
            maxval=1.0,
            dtype=alpha.dtype,
        )
        # log v, using alpha * Gamma(alpha) == Gamma(alpha + 1).
        log_v = (tf.math.log(u) + tf.math.lgamma(alpha + 1.0)) / alpha
        # softmax(log v) == v / sum(v), computed without underflow.
        z = tf.nn.softmax(log_v, axis=1)
        x_recon = self.decoder(z)
        return x_recon, alpha

    def compute_loss(self, x):
        """Return the negative ELBO (reconstruction + KL) averaged over the batch."""
        x_recon, alpha_hat = self.call(x)

        # Per-feature binary cross-entropy, summed over features. Computed
        # element-wise because tf.keras.losses.binary_crossentropy already
        # averages over the last axis and returns a rank-1 tensor.
        x = tf.cast(x, x_recon.dtype)
        probs = tf.clip_by_value(x_recon, self._EPS, 1.0 - self._EPS)
        bce = -(x * tf.math.log(probs) + (1.0 - x) * tf.math.log(1.0 - probs))
        recon_loss = tf.reduce_mean(tf.reduce_sum(bce, axis=1))

        # KL between independent Gamma(alpha_hat, 1) and Gamma(prior, 1)
        # components. TensorFlow ops (not SciPy) keep this differentiable.
        prior = tf.cast(self.prior_alpha, alpha_hat.dtype)
        kl_per_dim = (
            tf.math.lgamma(prior)
            - tf.math.lgamma(alpha_hat)
            + (alpha_hat - prior) * tf.math.digamma(alpha_hat)
        )
        # Sum over latent dimensions, average over the batch so the KL term
        # is on the same scale as the reconstruction term.
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=1))

        return recon_loss + kl_loss

    def train_step(self, x):
        """Run one optimisation step on ``x`` and return ``{'loss': loss}``."""
        with tf.GradientTape() as tape:
            loss = self.compute_loss(x)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {'loss': loss}


In [26]:
def train_vae(vae, dataset, epochs=10, hyper_update=True):
    """Train ``vae`` on ``dataset`` and optionally re-estimate alpha.

    Args:
        vae: Model exposing ``train_step(batch)`` (returning a dict with a
            ``'loss'`` entry) and an ``encoder`` callable.
        dataset: Iterable of input batches.
        epochs: Number of passes over ``dataset``.
        hyper_update: If True, after every step alpha is re-estimated from
            the normalised encoder outputs with a method-of-moments
            estimator. If the model has a ``prior_alpha`` attribute it is
            overwritten with the new estimate.

    Returns:
        The most recent alpha estimate (one value per latent dimension), or
        ``None`` if no estimate was computed.
    """
    eps = 1e-7  # guards the division and keeps alpha strictly positive
    alpha = None

    for epoch in range(epochs):
        loss = None
        for x_batch in dataset:
            loss = vae.train_step(x_batch)

            # Update alpha parameter if hyper_update is True
            if hyper_update:
                alpha_hat = vae.encoder(x_batch)
                p_nk = alpha_hat / tf.reduce_sum(alpha_hat, axis=1, keepdims=True)
                # First and second moments of each latent dimension.
                mu_1 = tf.reduce_mean(p_nk, axis=0)
                mu_2 = tf.reduce_mean(tf.square(p_nk), axis=0)
                # The variance is zero when every row of p_nk is identical.
                variance = tf.maximum(mu_2 - tf.square(mu_1), eps)
                # Precision estimate, averaged over latent dimensions.
                S = tf.reduce_mean((mu_1 - mu_2) / variance)
                # S / N * sum_n p_nk is S times the first moment.
                alpha = tf.maximum(S * mu_1, eps)
                if hasattr(vae, 'prior_alpha'):
                    vae.prior_alpha = alpha

        if loss is None:
            # The dataset yielded no batches, so there is no loss to report.
            print(f"Epoch {epoch + 1}, Loss: n/a (empty dataset)")
        else:
            print(f"Epoch {epoch + 1}, Loss: {loss['loss'].numpy()}")

    return alpha

# Seed NumPy and TensorFlow so the demo is reproducible on Restart & Run All
SEED = 0
tf.random.set_seed(SEED)

# Create a VAE model
original_dim = 784  # For example, for MNIST dataset
latent_dim = 10
vae = VAE(original_dim, latent_dim)
vae.compile(optimizer=tf.keras.optimizers.Adam())

# Dummy dataset for demonstration; float32 halves the memory of the default
# float64 and avoids a cast on every batch
x_train = np.random.default_rng(SEED).random((60000, original_dim), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(32)

# Train the VAE
train_vae(vae, dataset)


ImportError: ignored