# In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers

class VAE(keras.Model):
    """Fully-connected variational autoencoder for flat feature vectors.

    Encoder: ``input_dim`` -> ReLU(``hidden_dim``) -> two linear heads giving
    the mean and log-variance of a diagonal Gaussian over ``latent_dim``
    latent variables. Decoder: ``latent_dim`` -> ReLU(``hidden_dim``) ->
    sigmoid(``input_dim``).
    """

    def __init__(self, input_dim=9, hidden_dim=7, latent_dim=2):
        super(VAE, self).__init__()
        # The sizes must be set before the builders run; they read them.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    def build_encoder(self):
        """Return a functional model mapping ``x`` to ``[z_mean, z_log_var]``."""
        x_in = keras.Input(shape=(self.input_dim,))
        h = layers.Dense(self.hidden_dim, activation="relu")(x_in)
        # Two parallel heads parameterise the Gaussian posterior.
        mu = layers.Dense(self.latent_dim, name="z_mean")(h)
        log_var = layers.Dense(self.latent_dim, name="z_log_var")(h)
        return keras.Model(x_in, [mu, log_var], name="encoder")

    def build_decoder(self):
        """Return a functional model mapping a latent code to a sigmoid output vector."""
        z_in = keras.Input(shape=(self.latent_dim,))
        h = layers.Dense(self.hidden_dim, activation="relu")(z_in)
        # Sigmoid keeps every reconstructed feature in (0, 1).
        x_out = layers.Dense(self.input_dim, activation="sigmoid")(h)
        return keras.Model(z_in, x_out, name="decoder")

    def reparameterize(self, z_mean, z_log_var):
        """Sample ``z = mean + exp(log_var / 2) * eps`` with ``eps ~ N(0, 1)``.

        Writing the sample this way keeps it differentiable with respect to
        ``z_mean`` and ``z_log_var`` (the reparameterization trick).
        """
        noise = tf.random.normal(shape=(tf.shape(z_mean)[0], self.latent_dim))
        sigma = tf.exp(0.5 * z_log_var)
        return z_mean + sigma * noise

    def call(self, inputs):
        """Encode, sample and decode; return ``(reconstruction, z_mean, z_log_var)``."""
        z_mean, z_log_var = self.encoder(inputs)
        reconstructed = self.decoder(self.reparameterize(z_mean, z_log_var))
        return reconstructed, z_mean, z_log_var

    def encode(self, x):
        """Return a *sampled* latent code for ``x`` (not just the posterior mean)."""
        return self.reparameterize(*self.encoder(x))

    def decode(self, z):
        """Map latent codes ``z`` back to feature space."""
        return self.decoder(z)

class VAETrainer:
    """Custom training loop for a model returning ``(reconstruction, z_mean, z_log_var)``.

    Running means of the total, reconstruction and KL losses are reset at the
    start of every epoch. The values reported after an epoch are therefore
    that epoch's batch-averaged losses.
    """

    def __init__(self, vae, optimizer=None):
        self.vae = vae
        self.optimizer = optimizer or keras.optimizers.Adam(1e-3)

        # Running-mean metrics, one per loss term.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def compute_loss(self, x):
        """Return ``(total, reconstruction, kl)`` losses for batch ``x``.

        Reconstruction is the squared error summed over axis 1 and averaged
        over the batch. KL is the closed-form divergence between the diagonal
        Gaussian posterior and N(0, I), likewise summed over axis 1 and
        averaged over the batch.
        """
        reconstructed, z_mean, z_log_var = self.vae(x)

        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(x - reconstructed), axis=1)
        )

        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

        total_loss = reconstruction_loss + kl_loss
        return total_loss, reconstruction_loss, kl_loss

    @tf.function
    def train_step(self, x):
        """Run one optimisation step on batch ``x`` and return the running metrics."""
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self.compute_loss(x)

        gradients = tape.gradient(total_loss, self.vae.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.vae.trainable_variables))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return self._current_metrics()

    def _current_metrics(self):
        """Return a snapshot of the running means, keyed as in the progress printout."""
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train(self, dataset, epochs=100, verbose=1):
        """Train for ``epochs`` passes over ``dataset`` (an iterable of input batches).

        When ``verbose`` is truthy, the epoch-mean losses are printed every
        10th epoch.
        """
        trackers = (
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        )
        for epoch in range(epochs):
            # ``reset_state`` is the current name; the old ``reset_states``
            # alias was deprecated and is gone in Keras 3.
            for tracker in trackers:
                tracker.reset_state()

            for x_batch in dataset:
                self.train_step(x_batch)

            if verbose and (epoch + 1) % 10 == 0:
                # Read the trackers directly. The last train_step return value
                # would be undefined if the dataset yielded no batches.
                metrics = self._current_metrics()
                print(f"Epoch {epoch + 1}/{epochs}")
                print(f"Loss: {metrics['loss']:.4f}, "
                      f"Reconstruction: {metrics['reconstruction_loss']:.4f}, "
                      f"KL: {metrics['kl_loss']:.4f}")

def generate_synthetic_data(n_samples=1000, seed=42, shuffle=True):
    """Generate two Gaussian clusters of 9-dimensional vectors clipped to [0, 1].

    Args:
        n_samples: Total number of rows returned. With an odd total, the extra
            row goes to the second cluster.
        seed: Seed for a local ``numpy.random.Generator``. The global NumPy
            RNG state is left untouched.
        shuffle: Shuffle rows (labels permuted in lockstep). This is on by
            default because callers split train/test by position. Without
            shuffling, the tail of the array holds only the second cluster.

    Returns:
        ``(data, labels)``: ``data`` is float32 with shape ``(n_samples, 9)``,
        and ``labels`` is an int array of 0/1 cluster ids.
    """
    rng = np.random.default_rng(seed)

    cluster_means = (
        [0.3, 0.7, 0.2, 0.8, 0.1, 0.9, 0.4, 0.6, 0.5],
        [0.8, 0.2, 0.9, 0.1, 0.7, 0.3, 0.6, 0.4, 0.5],
    )
    cluster_sizes = (n_samples // 2, n_samples - n_samples // 2)
    cov = np.eye(9) * 0.05  # isotropic, variance 0.05 per feature

    data = np.vstack([
        rng.multivariate_normal(mean=mean, cov=cov, size=size)
        for mean, size in zip(cluster_means, cluster_sizes)
    ])
    data = np.clip(data, 0, 1)  # keep values in [0, 1] for the sigmoid decoder

    labels = np.concatenate([
        np.full(size, cluster_id) for cluster_id, size in enumerate(cluster_sizes)
    ])

    if shuffle:
        order = rng.permutation(len(data))
        data, labels = data[order], labels[order]

    return data.astype(np.float32), labels.astype(int)

def plot_latent_space(vae, data, labels=None, title="Latent Space Representation"):
    """Scatter-plot the encoder means of ``data`` on the first two latent axes.

    When ``labels`` is given, points are coloured by it and a colour bar
    labelled 'Cluster' is added.
    """
    means, _ = vae.encoder(data)
    xs, ys = means[:, 0], means[:, 1]

    plt.figure(figsize=(8, 6))
    if labels is None:
        plt.scatter(xs, ys, alpha=0.6)
    else:
        points = plt.scatter(xs, ys, c=labels, cmap='viridis', alpha=0.6)
        plt.colorbar(points, label='Cluster')

    plt.xlabel("Latent Dimension 1")
    plt.ylabel("Latent Dimension 2")
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.show()

def plot_reconstructions(vae, data, n_samples=5):
    """Bar-plot random rows of ``data`` (top) above their VAE reconstructions (bottom).

    Args:
        vae: Model whose call returns ``(reconstruction, z_mean, z_log_var)``.
        data: 2-D array ``(N, features)``; bars are drawn on a 0-1 y-axis.
        n_samples: Number of columns to draw, capped at ``len(data)``.
    """
    n_samples = min(n_samples, len(data))
    if n_samples <= 0:
        return  # nothing to draw; plt.subplots(2, 0) would raise

    indices = np.random.choice(len(data), n_samples, replace=False)
    x_sample = data[indices]
    reconstructions, _, _ = vae(x_sample)
    reconstructions = np.asarray(reconstructions)
    feature_ticks = range(x_sample.shape[-1])  # not hard-coded to 9 features

    # squeeze=False keeps ``axes`` 2-D even when n_samples == 1.
    _, axes = plt.subplots(2, n_samples, figsize=(15, 6), squeeze=False)

    for i in range(n_samples):
        rows = ((0, x_sample[i], "Original"), (1, reconstructions[i], "Reconstructed"))
        for row, values, caption in rows:
            ax = axes[row, i]
            ax.bar(feature_ticks, values)
            ax.set_title(f"{caption} {i+1}")
            ax.set_ylim(0, 1)
            ax.set_xticks(feature_ticks)

    plt.tight_layout()
    plt.show()

def plot_generated_samples(vae, n_samples=5):
    """Decode ``n_samples`` latent vectors drawn from N(0, I) and bar-plot each one.

    Args:
        vae: Model exposing ``latent_dim`` and ``decode(z)``.
        n_samples: Number of vectors to generate (one subplot each).
    """
    random_latent = tf.random.normal(shape=(n_samples, vae.latent_dim))
    generated_samples = np.asarray(vae.decode(random_latent))
    # Take the feature count from the decoder output instead of assuming 9.
    feature_ticks = range(generated_samples.shape[-1])

    plt.figure(figsize=(15, 3))
    for i in range(n_samples):
        plt.subplot(1, n_samples, i + 1)
        plt.bar(feature_ticks, generated_samples[i])
        plt.title(f"Generated {i+1}")
        plt.ylim(0, 1)
        plt.xticks(feature_ticks)

    plt.suptitle("Generated Samples from Random Latent Vectors")
    plt.tight_layout()
    plt.show()

def interpolate_in_latent_space(vae, data, n_steps=5):
    """Decode evenly spaced points between the latent means of two random rows.

    Two distinct rows of ``data`` are encoded (the ``z_mean`` output is used,
    not a sample). ``n_steps`` points on the segment from the first mean to
    the second are decoded, and each is drawn as a bar chart titled with its
    blend weight α.
    """
    idx_a, idx_b = np.random.choice(len(data), 2, replace=False)
    z1_mean, _ = vae.encoder(data[idx_a:idx_a + 1])
    z2_mean, _ = vae.encoder(data[idx_b:idx_b + 1])

    # A float32 column of blend weights. One broadcast expression yields all
    # n_steps latent points and avoids mixing float64 NumPy scalars with
    # float32 tensors.
    alphas = np.linspace(0, 1, n_steps, dtype=np.float32)
    weights = alphas[:, None]
    interpolated_z = (1 - weights) * z1_mean + weights * z2_mean
    interpolated_samples = np.asarray(vae.decode(interpolated_z))
    feature_ticks = range(interpolated_samples.shape[-1])  # not hard-coded to 9

    plt.figure(figsize=(15, 3))
    for i in range(n_steps):
        plt.subplot(1, n_steps, i + 1)
        plt.bar(feature_ticks, interpolated_samples[i])
        plt.title(f"α = {alphas[i]:.2f}")
        plt.ylim(0, 1)
        plt.xticks(feature_ticks)

    plt.suptitle("Latent Space Interpolation")
    plt.tight_layout()
    plt.show()

def print_model_summary(vae):
    """Print the VAE's layer sizes, then the Keras summaries of encoder and decoder."""
    header = [
        "VAE Architecture:",
        "=" * 50,
        f"Input dimension: {vae.input_dim}",
        f"Hidden dimension: {vae.hidden_dim}",
        f"Latent dimension: {vae.latent_dim}",
    ]
    for line in header:
        print(line)

    for part_name, sub_model in (("Encoder", vae.encoder), ("Decoder", vae.decoder)):
        print(f"\n{part_name}:")
        sub_model.summary()

# Example usage
if __name__ == "__main__":
    # --- Data: synthetic 9-D vectors with an 80/20 positional train/test split ---
    print("Generating synthetic 9D data...")
    data, labels = generate_synthetic_data(n_samples=2000)

    train_size = int(0.8 * len(data))
    x_train, x_test = np.split(data, [train_size])
    y_train, y_test = np.split(labels, [train_size])

    batch_size = 64
    train_dataset = (
        tf.data.Dataset.from_tensor_slices(x_train)
        .shuffle(buffer_size=1000)
        .batch(batch_size)
    )

    # --- Model and training ---
    vae = VAE(input_dim=9, hidden_dim=7, latent_dim=2)
    print_model_summary(vae)

    trainer = VAETrainer(vae)
    print("\nTraining VAE...")
    trainer.train(train_dataset, epochs=100)

    # --- Visual checks, run in order ---
    for message, plot_fn, plot_args in (
        ("\nPlotting reconstructions...", plot_reconstructions, (vae, x_test)),
        ("Plotting generated samples...", plot_generated_samples, (vae,)),
        ("Plotting latent space...", plot_latent_space, (vae, x_test, y_test)),
        ("Plotting latent space interpolation...", interpolate_in_latent_space, (vae, x_test)),
    ):
        print(message)
        plot_fn(*plot_args)

# Recorded cell output: ModuleNotFoundError: No module named 'tensorflow.python'

# In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers

class VAE(keras.Model):
    """Convolutional variational autoencoder for images.

    Encoder: two stride-2 Conv2D layers (32, 64 filters), a 16-unit dense
    layer, then ``z_mean`` / ``z_log_var`` heads. The decoder mirrors this
    with Conv2DTranspose layers and a sigmoid output, so it reconstructs
    pixels in (0, 1). Because of the two stride-2 stages, image height and
    width must be divisible by 4. The default 28x28 shape gives a 7x7
    bottleneck.
    """

    def __init__(self, latent_dim=32, input_shape=(28, 28, 1)):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        # Stored as ``image_shape``: ``input_shape`` is a read-only property
        # on Keras layers/models, so assigning ``self.input_shape`` raises
        # AttributeError.
        self.image_shape = tuple(input_shape)

        height, width, _ = self.image_shape
        if height % 4 or width % 4:
            raise ValueError(
                f"image height and width must be divisible by 4, got {self.image_shape}"
            )

        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    def build_encoder(self):
        """Build the encoder: image -> ``[z_mean, z_log_var]``."""
        inputs = keras.Input(shape=self.image_shape)

        # Two stride-2 convolutions halve the spatial size twice.
        x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(inputs)
        x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
        x = layers.Flatten()(x)
        x = layers.Dense(16, activation="relu")(x)

        z_mean = layers.Dense(self.latent_dim, name="z_mean")(x)
        z_log_var = layers.Dense(self.latent_dim, name="z_log_var")(x)

        return keras.Model(inputs, [z_mean, z_log_var], name="encoder")

    def build_decoder(self):
        """Build the decoder: latent code -> image of ``self.image_shape``."""
        height, width, channels = self.image_shape
        # Bottleneck size. The two stride-2 transposed convolutions below
        # upsample it back to (height, width); this was 7x7 for the 28x28 default.
        base_h, base_w = height // 4, width // 4

        latent_inputs = keras.Input(shape=(self.latent_dim,))
        x = layers.Dense(base_h * base_w * 64, activation="relu")(latent_inputs)
        x = layers.Reshape((base_h, base_w, 64))(x)

        x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
        x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)

        # Match the input channel count instead of hard-coding 1.
        decoder_outputs = layers.Conv2DTranspose(
            channels, 3, activation="sigmoid", padding="same"
        )(x)

        return keras.Model(latent_inputs, decoder_outputs, name="decoder")

    def reparameterize(self, z_mean, z_log_var):
        """Sample ``z = mean + exp(log_var / 2) * eps`` with ``eps ~ N(0, 1)``."""
        batch_size = tf.shape(z_mean)[0]
        epsilon = tf.random.normal(shape=(batch_size, self.latent_dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def call(self, inputs):
        """Encode, sample and decode; return ``(reconstruction, z_mean, z_log_var)``."""
        z_mean, z_log_var = self.encoder(inputs)
        z = self.reparameterize(z_mean, z_log_var)
        reconstructed = self.decoder(z)
        return reconstructed, z_mean, z_log_var

    def encode(self, x):
        """Return a sampled latent code for ``x``."""
        z_mean, z_log_var = self.encoder(x)
        return self.reparameterize(z_mean, z_log_var)

    def decode(self, z):
        """Map latent codes back to images."""
        return self.decoder(z)

class VAETrainer:
    """Custom training loop for the convolutional VAE.

    Running means of the total, reconstruction and KL losses are reset at the
    start of every epoch, so the printed values are epoch averages.
    """

    def __init__(self, vae, optimizer=None):
        self.vae = vae
        self.optimizer = optimizer or keras.optimizers.Adam(1e-4)

        # Running-mean metrics, one per loss term.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def compute_loss(self, x):
        """Return ``(total, reconstruction, kl)`` losses for an image batch ``x``.

        ``binary_crossentropy`` reduces the last (channel) axis. The result is
        then summed over the two spatial axes and averaged over the batch.
        """
        reconstructed, z_mean, z_log_var = self.vae(x)

        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(
                keras.losses.binary_crossentropy(x, reconstructed), axis=(1, 2)
            )
        )

        # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian.
        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

        total_loss = reconstruction_loss + kl_loss
        return total_loss, reconstruction_loss, kl_loss

    @tf.function
    def train_step(self, x):
        """Run one optimisation step on batch ``x`` and return the running metrics."""
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self.compute_loss(x)

        gradients = tape.gradient(total_loss, self.vae.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.vae.trainable_variables))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return self._current_metrics()

    def _current_metrics(self):
        """Return a snapshot of the running means, keyed as in the progress printout."""
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train(self, dataset, epochs=10, verbose=1):
        """Train for ``epochs`` passes over ``dataset``; print epoch means when ``verbose``."""
        trackers = (
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        )
        for epoch in range(epochs):
            # ``reset_states`` was a deprecated alias removed in Keras 3.
            for tracker in trackers:
                tracker.reset_state()

            for x_batch in dataset:
                self.train_step(x_batch)

            if verbose:
                # Read the trackers directly so an empty dataset cannot leave
                # ``metrics`` unbound.
                metrics = self._current_metrics()
                print(f"Epoch {epoch + 1}/{epochs}")
                print(f"Loss: {metrics['loss']:.4f}, "
                      f"Reconstruction: {metrics['reconstruction_loss']:.4f}, "
                      f"KL: {metrics['kl_loss']:.4f}")

def load_and_preprocess_data():
    """Load MNIST images and return ``(x_train, x_test)`` ready for the conv VAE.

    Pixels are scaled from 0-255 to float32 in [0, 1], and a trailing channel
    axis is appended. Labels are discarded.
    """
    def _prepare(images):
        # Scale to [0, 1] as float32, then add the channel axis (same as expand_dims(..., -1)).
        return (images.astype("float32") / 255.0)[..., np.newaxis]

    (train_images, _), (test_images, _) = keras.datasets.mnist.load_data()
    return _prepare(train_images), _prepare(test_images)

def plot_latent_space(vae, x_test, y_test=None, n_samples=1000):
    """Scatter-plot encoder means for a random subset of ``x_test``.

    Args:
        vae: Model exposing an ``encoder`` that returns ``(z_mean, z_log_var)``.
        x_test: Array of inputs to sample from.
        y_test: Optional labels aligned with ``x_test``; used for colouring.
        n_samples: Number of points to plot, capped at ``len(x_test)``.
            Without the cap, ``np.random.choice(replace=False)`` raises on
            small inputs.

    Only the first two latent dimensions are drawn (latent_dim must be >= 2).
    """
    n_samples = min(n_samples, len(x_test))
    indices = np.random.choice(len(x_test), n_samples, replace=False)
    x_sample = x_test[indices]

    z_mean, _ = vae.encoder(x_sample)
    z_mean = np.asarray(z_mean)

    plt.figure(figsize=(8, 6))
    if y_test is not None:
        y_sample = np.asarray(y_test)[indices]
        scatter = plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_sample, cmap='tab10', alpha=0.6)
        plt.colorbar(scatter)
    else:
        plt.scatter(z_mean[:, 0], z_mean[:, 1], alpha=0.6)

    plt.xlabel("Latent Dimension 1")
    plt.ylabel("Latent Dimension 2")
    plt.title("Latent Space Representation")
    plt.show()

def plot_generated_images(vae, n_samples=10):
    """Decode ``n_samples`` standard-normal latent vectors and show them in one row.

    Only channel 0 of each decoded image is displayed, in grayscale.
    """
    latent_draws = tf.random.normal(shape=(n_samples, vae.latent_dim))
    images = vae.decode(latent_draws)

    plt.figure(figsize=(20, 4))
    for position in range(1, n_samples + 1):
        plt.subplot(1, n_samples, position)
        plt.imshow(images[position - 1, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.suptitle("Generated Images from Random Latent Vectors")
    plt.show()

def plot_reconstructions(vae, x_test, n_samples=10):
    """Show random test images (top row) above their VAE reconstructions (bottom row)."""
    picks = np.random.choice(len(x_test), n_samples, replace=False)
    originals = x_test[picks]
    recon, _, _ = vae(originals)

    plt.figure(figsize=(20, 4))
    panels = ((originals, "Original"), (recon, "Reconstructed"))
    for col in range(n_samples):
        for row, (images, caption) in enumerate(panels):
            # Row 0 occupies subplot slots 1..n, row 1 slots n+1..2n.
            plt.subplot(2, n_samples, row * n_samples + col + 1)
            plt.imshow(images[col, :, :, 0], cmap='gray')
            plt.title(caption)
            plt.axis('off')
    plt.show()

# Example usage
if __name__ == "__main__":
    # Images only; training is unsupervised.
    x_train, x_test = load_and_preprocess_data()

    # Create dataset
    batch_size = 128
    train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

    # Create and train VAE
    vae = VAE(latent_dim=2)  # 2D latent space for visualization
    trainer = VAETrainer(vae)

    print("Training VAE...")
    trainer.train(train_dataset, epochs=10)

    # Visualizations
    print("Plotting reconstructions...")
    plot_reconstructions(vae, x_test)

    print("Plotting generated images...")
    plot_generated_images(vae)

    # load_and_preprocess_data() discards labels. Fetch the test labels here;
    # they are only used to colour the latent-space scatter.
    (_, _), (_, y_test) = keras.datasets.mnist.load_data()
    print("Plotting latent space...")
    plot_latent_space(vae, x_test, y_test)

# In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers

class CVAE(keras.Model):
    """Conditional VAE whose encoder and decoder both receive a class condition.

    Encoder: concat(x, condition) -> ReLU(``hidden_dim``) -> (z_mean, z_log_var).
    Decoder: concat(z, condition) -> ReLU(``hidden_dim``) -> sigmoid(``input_dim``).
    Conditions are ``num_classes``-wide vectors (one-hot in this file).
    """

    def __init__(self, input_dim=9, hidden_dim=7, latent_dim=2, num_classes=3):
        super(CVAE, self).__init__()
        # Sizes must exist before the builders run; they read them.
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        self.encoder = self.build_encoder()
        self.decoder = self.build_decoder()

    def _conditioned_inputs(self, primary_dim, primary_name):
        """Create a primary Input, a condition Input and their concatenation."""
        primary = keras.Input(shape=(primary_dim,), name=primary_name)
        condition = keras.Input(shape=(self.num_classes,), name="condition_input")
        joined = layers.Concatenate()([primary, condition])
        return primary, condition, joined

    def build_encoder(self):
        """Return a model mapping ``[data, condition]`` to ``[z_mean, z_log_var]``."""
        x_in, c_in, joined = self._conditioned_inputs(self.input_dim, "data_input")
        h = layers.Dense(self.hidden_dim, activation="relu")(joined)
        mu = layers.Dense(self.latent_dim, name="z_mean")(h)
        log_var = layers.Dense(self.latent_dim, name="z_log_var")(h)
        return keras.Model([x_in, c_in], [mu, log_var], name="encoder")

    def build_decoder(self):
        """Return a model mapping ``[latent, condition]`` to a sigmoid output vector."""
        z_in, c_in, joined = self._conditioned_inputs(self.latent_dim, "latent_input")
        h = layers.Dense(self.hidden_dim, activation="relu")(joined)
        # Sigmoid keeps outputs in (0, 1).
        x_out = layers.Dense(self.input_dim, activation="sigmoid")(h)
        return keras.Model([z_in, c_in], x_out, name="decoder")

    def reparameterize(self, z_mean, z_log_var):
        """Sample ``z = mean + exp(log_var / 2) * eps`` with ``eps ~ N(0, 1)``."""
        noise = tf.random.normal(shape=(tf.shape(z_mean)[0], self.latent_dim))
        sigma = tf.exp(0.5 * z_log_var)
        return z_mean + sigma * noise

    def call(self, inputs):
        """Take ``(data, conditions)``; return ``(reconstruction, z_mean, z_log_var)``."""
        data, conditions = inputs
        z_mean, z_log_var = self.encoder([data, conditions])
        z = self.reparameterize(z_mean, z_log_var)
        return self.decoder([z, conditions]), z_mean, z_log_var

    def encode(self, data, conditions):
        """Return a sampled latent code for ``data`` under ``conditions``."""
        return self.reparameterize(*self.encoder([data, conditions]))

    def decode(self, z, conditions):
        """Decode latent codes ``z`` under ``conditions``."""
        return self.decoder([z, conditions])

class CVAETrainer:
    """Custom training loop for :class:`CVAE` over ``(data, condition)`` batches.

    Running means of the total, reconstruction and KL losses are reset at the
    start of every epoch, so the reported values are epoch averages.
    """

    def __init__(self, cvae, optimizer=None):
        self.cvae = cvae
        self.optimizer = optimizer or keras.optimizers.Adam(1e-3)

        # Running-mean metrics, one per loss term.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def compute_loss(self, data, conditions):
        """Return ``(total, reconstruction, kl)`` losses for one batch.

        Reconstruction is the squared error summed over axis 1 and averaged
        over the batch. KL is the closed-form divergence to N(0, I), summed
        over latent dims and averaged over the batch.
        """
        reconstructed, z_mean, z_log_var = self.cvae([data, conditions])

        reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(data - reconstructed), axis=1)
        )

        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

        total_loss = reconstruction_loss + kl_loss
        return total_loss, reconstruction_loss, kl_loss

    @tf.function
    def train_step(self, data, conditions):
        """Run one optimisation step and return the running metrics."""
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self.compute_loss(data, conditions)

        gradients = tape.gradient(total_loss, self.cvae.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.cvae.trainable_variables))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return self._current_metrics()

    def _current_metrics(self):
        """Return a snapshot of the running means, keyed as in the progress printout."""
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train(self, dataset, epochs=100, verbose=1):
        """Train for ``epochs`` passes over ``dataset`` yielding ``(data, condition)`` batches.

        When ``verbose`` is truthy, epoch-mean losses are printed every 10th epoch.
        """
        trackers = (
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        )
        for epoch in range(epochs):
            # ``reset_states`` was a deprecated alias removed in Keras 3.
            for tracker in trackers:
                tracker.reset_state()

            for data_batch, condition_batch in dataset:
                self.train_step(data_batch, condition_batch)

            if verbose and (epoch + 1) % 10 == 0:
                # Read the trackers directly so an empty dataset cannot leave
                # ``metrics`` unbound.
                metrics = self._current_metrics()
                print(f"Epoch {epoch + 1}/{epochs}")
                print(f"Loss: {metrics['loss']:.4f}, "
                      f"Reconstruction: {metrics['reconstruction_loss']:.4f}, "
                      f"KL: {metrics['kl_loss']:.4f}")

def generate_conditional_synthetic_data(n_samples=1500, num_classes=3, seed=42):
    """Generate shuffled 9-dimensional vectors from per-class Gaussians, clipped to [0, 1].

    Args:
        n_samples: Total rows returned. The remainder of
            ``n_samples / num_classes`` goes to the lowest class ids, so no
            rows are dropped.
        num_classes: Number of classes to draw. Must be between 1 and the
            number of predefined patterns (3).
        seed: Seed for a local ``numpy.random.Generator``. The global NumPy
            RNG state is not modified.

    Returns:
        ``(data, labels)``: float32 ``(n_samples, 9)`` and int class ids.

    Raises:
        ValueError: If ``num_classes`` is outside the supported range.
    """
    # One (mean, isotropic variance) pattern per class.
    class_patterns = [
        # Class 0: Low values pattern
        {"mean": [0.2, 0.3, 0.1, 0.4, 0.2, 0.3, 0.1, 0.2, 0.3], "cov": 0.03},
        # Class 1: Medium values pattern
        {"mean": [0.5, 0.6, 0.4, 0.7, 0.5, 0.6, 0.4, 0.5, 0.6], "cov": 0.03},
        # Class 2: High values pattern
        {"mean": [0.8, 0.7, 0.9, 0.6, 0.8, 0.7, 0.9, 0.8, 0.7], "cov": 0.03},
    ]
    if not 1 <= num_classes <= len(class_patterns):
        raise ValueError(
            f"num_classes must be between 1 and {len(class_patterns)}, got {num_classes}"
        )

    rng = np.random.default_rng(seed)
    base_count, remainder = divmod(n_samples, num_classes)

    data_list = []
    labels_list = []
    for class_id in range(num_classes):
        count = base_count + (1 if class_id < remainder else 0)
        pattern = class_patterns[class_id]

        class_data = rng.multivariate_normal(
            mean=pattern["mean"],
            cov=np.eye(9) * pattern["cov"],
            size=count,
        )
        data_list.append(np.clip(class_data, 0, 1))
        labels_list.append(np.full(count, class_id))

    data = np.vstack(data_list).astype(np.float32)
    labels = np.concatenate(labels_list).astype(int)

    # Shuffle so positional train/test splits see every class.
    indices = rng.permutation(len(data))
    return data[indices], labels[indices]

def labels_to_onehot(labels, num_classes):
    """One-hot encode integer class ids as a TensorFlow tensor of depth ``num_classes``."""
    return tf.one_hot(indices=labels, depth=num_classes)

def plot_conditional_latent_space(cvae, data, labels, title="Conditional Latent Space"):
    """Scatter-plot encoder means of ``data`` (first two latent dims), one colour per class.

    Each row is encoded under the one-hot condition of its own label.
    """
    labels = np.asarray(labels)
    conditions_onehot = labels_to_onehot(labels, cvae.num_classes)

    z_mean, _ = cvae.encoder([data, conditions_onehot])
    # Convert to NumPy first. tf.Tensor indexing rejects a boolean-array
    # index combined with a column index, so ``z_mean[mask, 0]`` would raise
    # TypeError on the raw tensor.
    z_mean = np.asarray(z_mean)

    plt.figure(figsize=(10, 8))
    colors = ['red', 'blue', 'green', 'orange', 'purple']

    for class_id in range(cvae.num_classes):
        mask = labels == class_id
        plt.scatter(z_mean[mask, 0], z_mean[mask, 1],
                    # Wrap around rather than IndexError beyond five classes.
                    c=colors[class_id % len(colors)],
                    label=f'Class {class_id}', alpha=0.6)

    plt.xlabel("Latent Dimension 1")
    plt.ylabel("Latent Dimension 2")
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

def plot_conditional_reconstructions(cvae, data, labels, n_samples=5):
    """Bar-plot random originals and their reconstructions, two rows per class.

    For each class, ``n_samples`` rows of ``data`` with that label are drawn
    without replacement and reconstructed under that class's condition.
    Originals go on row ``2 * class_id`` and reconstructions on the row
    directly below. Classes with fewer than ``n_samples`` rows are left blank.
    """
    labels = np.asarray(labels)
    feature_ticks = range(data.shape[-1])  # not hard-coded to 9 features

    # squeeze=False keeps ``axes`` 2-D even when n_samples == 1.
    _, axes = plt.subplots(2 * cvae.num_classes, n_samples,
                           figsize=(15, 4 * cvae.num_classes), squeeze=False)

    for class_id in range(cvae.num_classes):
        class_mask = labels == class_id
        class_data = data[class_mask]
        class_labels = labels[class_mask]

        if len(class_data) < n_samples:
            continue  # not enough rows for a without-replacement draw

        indices = np.random.choice(len(class_data), n_samples, replace=False)
        x_sample = class_data[indices]
        y_sample = class_labels[indices]

        conditions_onehot = labels_to_onehot(y_sample, cvae.num_classes)
        reconstructions, _, _ = cvae([x_sample, conditions_onehot])
        reconstructions = np.asarray(reconstructions)

        row_orig = class_id * 2
        row_recon = row_orig + 1
        for i in range(n_samples):
            ax = axes[row_orig, i]
            ax.bar(feature_ticks, x_sample[i])
            ax.set_title(f"Class {class_id} - Original {i+1}")
            ax.set_ylim(0, 1)
            ax.set_xticks(feature_ticks)

            ax = axes[row_recon, i]
            ax.bar(feature_ticks, reconstructions[i])
            ax.set_title(f"Class {class_id} - Reconstructed {i+1}")
            ax.set_ylim(0, 1)
            ax.set_xticks(feature_ticks)

    plt.tight_layout()
    plt.show()

def generate_conditional_samples(cvae, n_samples_per_class=3):
    """Decode standard-normal latent vectors under each class condition and bar-plot them.

    The grid has one row per class and ``n_samples_per_class`` columns.
    """
    # squeeze=False always yields a 2-D axes grid. The original special case
    # missed n_samples_per_class == 1, where the grid collapses to 1-D.
    _, axes = plt.subplots(cvae.num_classes, n_samples_per_class,
                           figsize=(12, 3 * cvae.num_classes), squeeze=False)

    for class_id in range(cvae.num_classes):
        # One-hot condition for this class, repeated for every sample.
        conditions = np.zeros((n_samples_per_class, cvae.num_classes))
        conditions[:, class_id] = 1
        conditions = tf.constant(conditions, dtype=tf.float32)

        random_latent = tf.random.normal(shape=(n_samples_per_class, cvae.latent_dim))
        generated_samples = np.asarray(cvae.decode(random_latent, conditions))
        feature_ticks = range(generated_samples.shape[-1])  # not hard-coded to 9

        for i in range(n_samples_per_class):
            ax = axes[class_id, i]
            ax.bar(feature_ticks, generated_samples[i])
            ax.set_title(f"Class {class_id} - Generated {i+1}")
            ax.set_ylim(0, 1)
            ax.set_xticks(feature_ticks)

    plt.suptitle("Generated Samples by Condition")
    plt.tight_layout()
    plt.show()

def interpolate_between_conditions(cvae, data, labels):
    """Decode a 5-step path between the first two classes that occur in ``data``.

    For each of those two classes, the first matching row is encoded
    (``z_mean``) under its own one-hot condition. Along the path, both the
    latent code and the condition vector are blended linearly with weight α,
    and each decoded vector is drawn as a bar chart.
    """
    labels = np.asarray(labels)

    # Keep the real class id with each sample. If a class is absent, list
    # positions no longer equal class ids.
    present = []
    for class_id in range(cvae.num_classes):
        class_data = data[labels == class_id]
        if len(class_data) > 0:
            present.append((class_id, class_data[0:1]))  # first sample, kept 2-D

    if len(present) < 2:
        print("Need at least 2 classes for interpolation")
        return

    (class_a, sample_a), (class_b, sample_b) = present[:2]

    def one_hot_row(class_id):
        """Return a (1, num_classes) float32 one-hot condition."""
        row = np.zeros((1, cvae.num_classes), dtype=np.float32)
        row[0, class_id] = 1
        return tf.constant(row)

    z1, _ = cvae.encoder([sample_a, one_hot_row(class_a)])
    z2, _ = cvae.encoder([sample_b, one_hot_row(class_b)])

    n_steps = 5
    # float32 weights match the encoder output dtype.
    alphas = np.linspace(0, 1, n_steps, dtype=np.float32)

    plt.figure(figsize=(15, 6))
    for i, alpha in enumerate(alphas):
        z_interp = (1 - alpha) * z1 + alpha * z2

        # Blend between the two actual classes' one-hot vectors.
        condition = np.zeros((1, cvae.num_classes), dtype=np.float32)
        condition[0, class_a] = 1 - alpha
        condition[0, class_b] = alpha

        generated = np.asarray(cvae.decode(z_interp, tf.constant(condition)))
        feature_ticks = range(generated.shape[-1])  # not hard-coded to 9

        plt.subplot(1, n_steps, i + 1)
        plt.bar(feature_ticks, generated[0])
        plt.title(f"α = {alpha:.2f}")
        plt.ylim(0, 1)
        plt.xticks(feature_ticks)

    plt.suptitle("Interpolation Between Conditions")
    plt.tight_layout()
    plt.show()

def print_cvae_summary(cvae):
    """Print the CVAE's sizes and class count, then the encoder/decoder Keras summaries."""
    header = [
        "CVAE Architecture:",
        "=" * 50,
        f"Input dimension: {cvae.input_dim}",
        f"Hidden dimension: {cvae.hidden_dim}",
        f"Latent dimension: {cvae.latent_dim}",
        f"Number of classes: {cvae.num_classes}",
    ]
    for line in header:
        print(line)

    sections = (
        ("\nEncoder (takes data + condition):", cvae.encoder),
        ("\nDecoder (takes latent + condition):", cvae.decoder),
    )
    for heading, sub_model in sections:
        print(heading)
        sub_model.summary()

# Example usage
if __name__ == "__main__":
    # --- Data: labelled synthetic 9-D vectors with an 80/20 positional split ---
    print("Generating synthetic conditional 9D data...")
    num_classes = 3
    data, labels = generate_conditional_synthetic_data(n_samples=1800, num_classes=num_classes)

    train_size = int(0.8 * len(data))
    x_train, x_test = np.split(data, [train_size])
    y_train, y_test = np.split(labels, [train_size])

    # The CVAE takes its class conditions as one-hot vectors.
    y_train_onehot, y_test_onehot = [labels_to_onehot(y, num_classes) for y in (y_train, y_test)]

    batch_size = 64
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train_onehot))
        .shuffle(buffer_size=1000)
        .batch(batch_size)
    )

    # --- Model and training ---
    cvae = CVAE(input_dim=9, hidden_dim=7, latent_dim=2, num_classes=num_classes)
    print_cvae_summary(cvae)

    trainer = CVAETrainer(cvae)
    print("\nTraining CVAE...")
    trainer.train(train_dataset, epochs=100)

    # --- Visual checks, run in order ---
    for message, step_fn, step_args in (
        ("\nPlotting conditional reconstructions...", plot_conditional_reconstructions, (cvae, x_test, y_test)),
        ("Generating conditional samples...", generate_conditional_samples, (cvae,)),
        ("Plotting conditional latent space...", plot_conditional_latent_space, (cvae, x_test, y_test)),
        ("Interpolating between conditions...", interpolate_between_conditions, (cvae, x_test, y_test)),
    ):
        print(message)
        step_fn(*step_args)