In [1]:
from sklearn.mixture import GaussianMixture
from tensorflow.keras import layers, models, backend as K
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [5]:
class Sampling(layers.Layer):
    """Reparameterisation-trick layer: draws z ~ N(z_mean, exp(z_log_var)).

    Called with ``[z_mean, z_log_var]``; returns ``z_mean + sigma * eps`` where
    ``sigma = exp(z_log_var / 2)`` and ``eps`` is standard-normal noise with the
    same (batch, latent_dim) shape. Writing the sample this way keeps it
    differentiable with respect to ``z_mean`` and ``z_log_var``.
    """

    def call(self, inputs):
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        # Fresh noise for every call, sized from the dynamic batch and latent dims.
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        sigma = tf.exp(0.5 * log_var)
        return mean + sigma * noise

In [6]:
# Reproducibility: seeds Python's `random`, NumPy and TensorFlow in one call so weight
# initialisation and sampling noise are repeatable across Restart & Run All
# (some GPU kernels may still be nondeterministic).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

latent_dim = 2  # Dimensionality of the latent space

# Encoder: 28x28x1 image -> [z_mean, z_log_var, sampled z]
encoder_inputs = layers.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 3, activation='relu', strides=2, padding='same')(encoder_inputs)  # -> 14x14x32
x = layers.Conv2D(64, 3, activation='relu', strides=2, padding='same')(x)  # -> 7x7x64
x = layers.Flatten()(x)
x = layers.Dense(16, activation='relu')(x)
z_mean = layers.Dense(latent_dim, name='z_mean')(x)
z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
z = Sampling()([z_mean, z_log_var])
encoder = models.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')

# Decoder: latent vector -> 28x28x1 image with sigmoid (per-pixel [0, 1]) outputs
latent_inputs = layers.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation='relu')(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same')(x)  # -> 14x14x64
x = layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)  # -> 28x28x32
decoder_outputs = layers.Conv2DTranspose(1, 3, activation='sigmoid', padding='same')(x)
decoder = models.Model(latent_inputs, decoder_outputs, name='decoder')

# The trainable VAE (with reconstruction + KL loss) is the subclassed `VAE` model in
# the next cell. A plain functional encoder->decoder Model built here would carry no
# loss, is never compiled, and would just be overwritten, so it is not created.

In [7]:
class VAE(models.Model):
    """Variational autoencoder: encoder -> sampled z -> decoder, with a built-in loss.

    On every call the model registers (via ``add_loss``) the sum of:

    * reconstruction term: per-image *sum* of pixel-wise binary cross-entropy,
      averaged over the batch;
    * KL term: KL(N(z_mean, exp(z_log_var)) || N(0, I)) summed over latent dims,
      averaged over the batch.

    Because the model supplies its own loss, ``compile`` needs no ``loss=`` and
    ``fit`` can be called with inputs only.

    Args:
        encoder: model mapping images to ``[z_mean, z_log_var, z]``.
        decoder: model mapping a latent ``z`` back to an image with values in [0, 1].
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)

        # binary_crossentropy averages over the last (channel) axis; summing the
        # remaining per-pixel values gives each image's total reconstruction error
        # without hard-coding the image size.
        pixel_bce = tf.keras.losses.binary_crossentropy(inputs, reconstructed)
        per_image_bce = tf.reduce_sum(
            tf.reshape(pixel_bce, [tf.shape(pixel_bce)[0], -1]), axis=1
        )
        reconstruction_loss = tf.reduce_mean(per_image_bce)

        # Closed-form KL divergence between the diagonal Gaussian posterior and N(0, I).
        kl_per_dim = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = -0.5 * tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=-1))

        self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed


vae = VAE(encoder, decoder)
vae.compile(optimizer='adam')  # no `loss=`: the model adds its own via add_loss

In [9]:
# Load dataset (MNIST); the digit labels are not needed to train the VAE.
(x_train_raw, _), (x_test_raw, _) = tf.keras.datasets.mnist.load_data()


def to_model_input(images):
    """Add a trailing channel axis and rescale pixels from [0, 255] to float32 [0, 1]."""
    return np.expand_dims(images, -1).astype('float32') / 255


x_train = to_model_input(x_train_raw)
x_test = to_model_input(x_test_raw)

# The encoder expects (N, 28, 28, 1) inputs, and the sigmoid decoder / BCE loss
# expect targets in [0, 1].
assert x_train.shape[1:] == x_test.shape[1:] == (28, 28, 1)
assert 0.0 <= x_train.min() and x_train.max() <= 1.0

# Keep the History object (per-epoch loss / val_loss) instead of echoing its repr.
history = vae.fit(x_train, epochs=3, batch_size=128, validation_data=(x_test, None))

Epoch 1/3
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m32s[0m 68ms/step - loss: 191.8419 - val_loss: 173.7255
Epoch 2/3
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 62ms/step - loss: 171.3431 - val_loss: 166.5556
Epoch 3/3
[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 55ms/step - loss: 165.4107 - val_loss: 162.7740


<keras.src.callbacks.history.History at 0x376079d50>

In [10]:
# Extract latent representations: fit the GMM on the encoder's posterior means
# (not the sampled z) so the mixture sees noise-free latent codes.
N_COMPONENTS = 7  # NOTE(review): MNIST has 10 digit classes; 7 is a free modelling choice — confirm intent.
GMM_SEED = 0  # controls the GMM's random initialisation so the fit is reproducible

# Named `z_mean_train` (not `z_mean`) so it does not shadow the encoder's symbolic `z_mean` tensor.
z_mean_train, _, _ = encoder.predict(x_train, batch_size=128)
gmm = GaussianMixture(n_components=N_COMPONENTS, covariance_type='full', random_state=GMM_SEED)
gmm.fit(z_mean_train)

[1m469/469[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step


In [16]:
# Encode the held-out images (keeping only the posterior means, the encoder's first
# output) and ask the fitted GMM how strongly each mixture component claims each image.
z_mean_test = encoder.predict(x_test, batch_size=128)[0]
responsibilities = gmm.predict_proba(z_mean_test)  # (n_test_images, n_components); each row sums to 1

[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step


In [17]:
responsibilities_rounded = np.round(responsibilities, 3)  # for display only

print("Rounded mixture responsibilities for the first test image:")
print(responsibilities_rounded[0])

# Sum the *unrounded* values: rounding each component to 3 decimals independently
# can make the displayed numbers add up to e.g. 0.999 or 1.001, so the rounded row
# is not a reliable check of the sum-to-one invariant.
print("Sum of responsibilities for the first test image:")
print(round(float(np.sum(responsibilities[0])), 6))

# Every row returned by predict_proba is a probability distribution over components.
assert np.allclose(responsibilities.sum(axis=1), 1.0), "responsibilities must sum to 1 for every image"

Rounded mixture responsibilities for the first test image:
[0.142 0.005 0.089 0.711 0.    0.053 0.   ]
Sum of responsibilities for the first test image:
1.0
