In [None]:
# Import required libraries
# Standard library imports
import os
import sys
from pathlib import Path

# Third-party imports
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from tensorflow import keras

# Add src directory to path
sys.path.append(str(Path.cwd().parent))

# Local imports
from src.data_utils import CIFAR10_CLASSES, load_cifar10
from src.model_utils import build_vae
from src.visualization import (
    plot_image_grid,
    plot_interpolation,
    plot_latent_space_2d
)

# Seed the global NumPy and TensorFlow RNGs so np.random.* draws and
# TensorFlow random ops/initializers repeat across runs (GPU kernels may
# still introduce run-to-run nondeterminism).
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(42)

print(f'TensorFlow version: {tf.__version__}')
gpu_devices = tf.config.list_physical_devices("GPU")
print(f'GPU Available: {gpu_devices}')

## Load and Prepare CIFAR-10

In [None]:
# Load CIFAR-10 as (train, test) splits. normalize=True asks load_cifar10 to
# rescale pixel values; the exact range is defined in src.data_utils
# (presumably [0, 1] — confirm there).
(x_train, y_train), (x_test, y_test) = load_cifar10(normalize=True)

n_train, n_test = len(x_train), len(x_test)
print(f"Training samples: {n_train}")
print(f"Test samples: {n_test}")
print(f"Image shape: {x_train.shape[1:]}")

In [None]:
# Show 10 random test images titled with their class names.
# NOTE(review): assumes y_test holds flat integer class ids; (N, 1)-shaped
# labels would not index CIFAR10_CLASSES cleanly — confirm load_cifar10's output.
sample_indices = np.random.choice(len(x_test), 10, replace=False)
sample_images, sample_labels = x_test[sample_indices], y_test[sample_indices]
titles = [CIFAR10_CLASSES[class_id] for class_id in sample_labels]

fig = plot_image_grid(sample_images, titles=titles)
plt.suptitle('Sample CIFAR-10 Images', fontsize=16, y=1.02)
plt.show()

## Build Variational Autoencoder

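The VAE consists of:
1. **Encoder**: Maps each image to the parameters of a Gaussian over the latent space: a mean μ and a log-variance log σ².
2. **Sampling Layer**: Draws $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. This is the reparameterization trick, which lets gradients flow through the sampling step. In this model the encoder returns `(z_mean, z_log_var, z)`, so the sampling happens inside the encoder.
3. **Decoder**: Maps latent vectors back to images.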

In [None]:
# VAE hyperparameters
LATENT_DIM = 128        # dimensionality of the latent vector z
EPOCHS = 100            # upper bound; the EarlyStopping callback may end training sooner
BATCH_SIZE = 128
LEARNING_RATE = 1e-3    # Adam step size

# Output directories, resolved relative to the notebook's working directory
# (assumes the notebook runs from a subfolder that sits next to models/ and logs/).
models_dir = Path('../models')
logs_dir = Path('../logs')
for output_dir in (models_dir, logs_dir):
    output_dir.mkdir(exist_ok=True)

In [None]:
# build_vae returns the full model together with its encoder and decoder parts
vae, encoder, decoder = build_vae(latent_dim=LATENT_DIM)

# compile() receives only an optimizer — no loss argument — so the VAE
# presumably computes its reconstruction + KL loss internally (custom train step).
adam = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
vae.compile(optimizer=adam)

for label, component in (("\nVAE", vae), ("Encoder", encoder), ("Decoder", decoder)):
    print(f"{label} Parameters: {component.count_params():,}")

In [None]:
# Print a banner followed by the Keras summary for each sub-model
banner = "=" * 60
for heading, component in (("ENCODER ARCHITECTURE", encoder), ("DECODER ARCHITECTURE", decoder)):
    print("\n" + banner)
    print(heading)
    print(banner)
    component.summary()

## Train the VAE

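Training a VAE involves optimizing two terms:
- **Reconstruction Loss**: How well is the input reconstructed from its latent code?
- **KL Divergence**: How close is the encoder's latent distribution $q(z \mid x)$ to the standard normal prior $\mathcal{N}(0, I)$?

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{rec}} + \beta \cdot D_{\text{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big)$$

We use β = 0.0005 for CIFAR-10 to balance the two terms. β is not set anywhere in this notebook; it must be configured inside `build_vae` (confirm the value there). With β ≪ 1 this is a KL-down-weighted VAE. The original β-VAE instead uses β > 1 to encourage disentangled latents.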

In [None]:
# Callbacks:
#  - ModelCheckpoint: rewrites the checkpoint whenever val_loss improves
#    (save_best_only=True), so the file holds the best epoch seen so far.
#  - EarlyStopping: stops after 15 epochs without val_loss improvement.
#    restore_best_weights puts the best weights back into `vae` when early
#    stopping triggers; whether it also restores after running all EPOCHS
#    depends on the Keras version — the checkpoint file is the safe source.
#  - TensorBoard: scalar logs plus weight histograms every epoch.
#  - ReduceLROnPlateau: halves the learning rate after 7 stagnant epochs (floor 1e-6).
# NOTE(review): saving the full model as .keras requires the model returned by
# build_vae to be serializable (get_config / registered custom objects) — confirm.
checkpoint_path = models_dir / 'vae.keras'

callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_path),
        monitor='val_loss',
        save_best_only=True,
        verbose=1
    ),
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=15,
        restore_best_weights=True,
        verbose=1
    ),
    keras.callbacks.TensorBoard(
        log_dir=str(logs_dir / 'vae'),
        histogram_freq=1
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=7,
        min_lr=1e-6,
        verbose=1
    )
]

print("\n" + "="*60)
print("Training Variational Autoencoder")
print("="*60 + "\n")

# Unsupervised training: only images are passed (no targets), and the
# validation set is likewise a one-element tuple containing just x_test.
history = vae.fit(
    x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(x_test,),
    callbacks=callbacks,
    verbose=1
)

# Report the path the checkpoint was actually written to (derived from models_dir).
print(f"\n✅ Best model (lowest val_loss) saved to: {checkpoint_path}")

## Visualize Training Progress

VAE training involves monitoring multiple loss components.

In [None]:
# Plot each loss component (total, reconstruction, KL) for training and, when
# the history contains it, validation. Every panel uses the same guard, so a
# missing val_* key (e.g. fit() run without validation data) skips that curve
# instead of raising KeyError.
loss_panels = [
    # (history key, y-axis label, panel title)
    ('loss', 'Total Loss', 'Total VAE Loss'),
    ('reconstruction_loss', 'Reconstruction Loss', 'Reconstruction Loss'),
    ('kl_loss', 'KL Divergence', 'KL Divergence Loss'),
]

fig, axes = plt.subplots(1, len(loss_panels), figsize=(18, 5))

for ax, (metric, ylabel, title) in zip(axes, loss_panels):
    ax.plot(history.history[metric], label='Train')
    val_metric = f'val_{metric}'
    if val_metric in history.history:
        ax.plot(history.history[val_metric], label='Validation')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('VAE Training History', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

## Generate New Images

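The KL term pulls the latent distribution toward $\mathcal{N}(0, I)$. This means we can draw random vectors $z \sim \mathcal{N}(0, I)$ and decode them into new images that were never in the training set.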

In [None]:
# Draw latent vectors from a standard normal distribution and decode them
n_samples = 20
prior_shape = (n_samples, LATENT_DIM)
random_latent_vectors = np.random.normal(size=prior_shape)
generated_images = decoder.predict(random_latent_vectors, verbose=0)

# Lay the 20 decoded images out as a 4 x 5 grid
grid_rows, grid_cols = 4, 5
fig = plot_image_grid(generated_images, rows=grid_rows, cols=grid_cols, figsize=(15, 12))
plt.suptitle('Randomly Generated Images from VAE', fontsize=16, y=0.995)
plt.show()

## Compare: Real vs Reconstructed vs Generated

In [None]:
# A few random test images to compare against
n_compare = 5
compare_idx = np.random.choice(len(x_test), n_compare, replace=False)
test_samples = x_test[compare_idx]

# Reconstructions come from passing the real images through the full VAE;
# "generated" images decode fresh standard-normal latent vectors.
reconstructed = vae.predict(test_samples, verbose=0)
random_z = np.random.normal(size=(n_compare, LATENT_DIM))
generated = decoder.predict(random_z, verbose=0)

# One figure row per source: (images, row title, extra title styling)
comparison_rows = (
    (test_samples, 'Real Images', {}),
    (reconstructed, 'Reconstructed', {}),
    (generated, 'Generated (Random)', {'color': 'green'}),
)

fig, axes = plt.subplots(3, n_compare, figsize=(n_compare * 3, 9))
for row_axes, (images, row_title, title_style) in zip(axes, comparison_rows):
    for col, ax in enumerate(row_axes):
        ax.imshow(images[col])
        ax.axis('off')
        if col == 0:
            ax.set_title(row_title, fontweight='bold', fontsize=14, **title_style)

plt.suptitle('Real vs Reconstructed vs Generated', fontsize=16, y=0.995)
plt.tight_layout()
plt.show()

## Latent Space Interpolation

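Because the VAE latent space is continuous, we can interpolate linearly between the encoded means of two images, $z_\alpha = (1-\alpha)\,\mu_1 + \alpha\,\mu_2$ for $\alpha \in [0, 1]$. Decoding each intermediate point gives a smooth visual transition from one image to the other.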

In [None]:
# Two fixed test images to morph between; slicing keeps a batch axis of size 1
idx1, idx2 = 42, 123
img1 = x_test[idx1:idx1 + 1]
img2 = x_test[idx2:idx2 + 1]

# Keep only the mean output of the encoder (not the sampled z) so the
# endpoints are deterministic
z1_mean, _, _ = encoder.predict(img1, verbose=0)
z2_mean, _, _ = encoder.predict(img2, verbose=0)

# Straight-line path between the two means, one row per step
n_steps = 10
interpolated_z = np.vstack([
    (1 - alpha) * z1_mean + alpha * z2_mean
    for alpha in np.linspace(0, 1, n_steps)
])
interpolated_images = decoder.predict(interpolated_z, verbose=0)

# One row of thumbnails; only the two endpoints get a title
endpoint_titles = {n_steps - 1: 'End', 0: 'Start'}
fig, axes = plt.subplots(1, n_steps, figsize=(n_steps * 2, 2))
for step, ax in enumerate(axes):
    ax.imshow(interpolated_images[step])
    ax.axis('off')
    title = endpoint_titles.get(step)
    if title is not None:
        ax.set_title(title, fontweight='bold')

plt.suptitle('Latent Space Interpolation', fontsize=14, y=1.05)
plt.tight_layout()
plt.show()

In [None]:
# Five more interpolations, each between a freshly drawn random pair of test
# images (reuses n_steps from the previous cell)
fig, axes = plt.subplots(5, n_steps, figsize=(n_steps * 2, 10))

for row_axes in axes:
    idx1, idx2 = np.random.choice(len(x_test), 2, replace=False)

    # Encode both endpoints, keeping the mean output only
    z1_mean, _, _ = encoder.predict(x_test[idx1:idx1 + 1], verbose=0)
    z2_mean, _, _ = encoder.predict(x_test[idx2:idx2 + 1], verbose=0)

    path_z = np.vstack([
        (1 - alpha) * z1_mean + alpha * z2_mean
        for alpha in np.linspace(0, 1, n_steps)
    ])
    path_images = decoder.predict(path_z, verbose=0)

    for ax, image in zip(row_axes, path_images):
        ax.imshow(image)
        ax.axis('off')

plt.suptitle('Multiple Latent Space Interpolations', fontsize=16, y=0.995)
plt.tight_layout()
plt.show()

## Visualize Latent Space Structure

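We project the `LATENT_DIM`-dimensional latent means of 2,000 random test images to 2D with t-SNE. t-SNE preserves local neighbourhoods rather than global distances. Read the plot for which points cluster together, not for the distances between clusters.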

In [None]:
# Encode a random subset of the test set (t-SNE cost grows quickly with the
# number of points)
n_viz = 2000
viz_indices = np.random.choice(len(x_test), n_viz, replace=False)
viz_images, viz_labels = x_test[viz_indices], y_test[viz_indices]

print("Encoding images to latent space...")
z_mean, z_log_var, z = encoder.predict(viz_images, verbose=1)

# Embed the latent means (not the sampled z) in two dimensions
print("\nComputing t-SNE projection...")
z_2d = TSNE(n_components=2, random_state=42, perplexity=30).fit_transform(z_mean)

print("Done!")

In [None]:
# Scatter the 2-D t-SNE embedding with one colour per class.
# NOTE(review): assumes viz_labels is a flat integer array; (N, 1) labels would
# make the mask 2-D and break z_2d[mask, 0] — confirm load_cifar10's label shape.
fig, ax = plt.subplots(figsize=(12, 10))

# tab10 sampled at 10 evenly spaced points gives 10 distinct colours
palette = plt.cm.tab10(np.linspace(0, 1, 10))

for class_idx, class_color in enumerate(palette):
    in_class = viz_labels == class_idx
    ax.scatter(
        z_2d[in_class, 0], z_2d[in_class, 1],
        c=[class_color], label=CIFAR10_CLASSES[class_idx],
        alpha=0.6, s=20,
    )

ax.set_xlabel('t-SNE Dimension 1', fontsize=12)
ax.set_ylabel('t-SNE Dimension 2', fontsize=12)
ax.set_title('VAE Latent Space (t-SNE Projection)', fontsize=14, fontweight='bold')
ax.legend(loc='best', fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Latent Space Arithmetic

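We can perform vector arithmetic on latent means: `z_cat + z_dog - z_bird = ?`. Each class is represented by a single randomly drawn test image, so the result is illustrative rather than a systematic test.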

In [None]:
# Helper: draw one random example of a given class
def get_class_example(class_idx, dataset, labels):
    """Pick one random sample of the requested class.

    Args:
        class_idx: Integer class id to look for in ``labels``.
        dataset: Array of samples, indexed along its first axis.
        labels: Class ids aligned with ``dataset``'s first axis (flat, or
            2-D with the sample index on axis 0).

    Returns:
        ``dataset[i:i + 1]`` for a uniformly chosen ``i`` with
        ``labels[i] == class_idx`` — a batch of one, ready for ``predict``.

    Raises:
        ValueError: raised by ``np.random.choice`` when no sample has that class.
    """
    matching_rows = np.nonzero(labels == class_idx)[0]
    pick = np.random.choice(matching_rows)
    return dataset[pick:pick + 1]

# One random test example per class, using CIFAR-10 integer ids
cat_img = get_class_example(3, x_test, y_test)   # cat
dog_img = get_class_example(5, x_test, y_test)   # dog
bird_img = get_class_example(2, x_test, y_test)  # bird

# Encode each image, keeping only the mean output of the encoder
cat_z, _, _ = encoder.predict(cat_img, verbose=0)
dog_z, _, _ = encoder.predict(dog_img, verbose=0)
bird_z, _, _ = encoder.predict(bird_img, verbose=0)

# cat + dog - bird in latent space, then decode the result
result_z = (cat_z + dog_z) - bird_z
result_img = decoder.predict(result_z, verbose=0)

# Render the equation: each panel is an image with a bold title, or an operator
# symbol (title None)
panels = [
    (cat_img[0], 'Cat', {}),
    ('+', None, None),
    (dog_img[0], 'Dog', {}),
    ('-', None, None),
    (bird_img[0], 'Bird', {}),
    ('=', None, None),
    (result_img[0], 'Result', {'color': 'green'}),
]

fig, axes = plt.subplots(1, len(panels), figsize=(14, 2))
for ax, (content, title, title_style) in zip(axes, panels):
    if title is None:
        ax.text(0.5, 0.5, content, fontsize=24, ha='center', va='center')
    else:
        ax.imshow(content)
        ax.set_title(title, fontweight='bold', **title_style)
    ax.axis('off')

plt.suptitle('Latent Space Arithmetic', fontsize=14, y=1.1)
plt.tight_layout()
plt.show()

## Conclusions

### What We Learned:

1. **Generative Capability**: VAEs can generate new, plausible images by sampling from the latent space.

2. **Smooth Latent Space**: The probabilistic nature of VAEs creates a continuous latent space, enabling smooth interpolations.

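3. **Learned Structure**: The t-SNE plot shows how strongly the latent space groups images of the same class together. Any clustering there emerges without the model ever seeing labels, since the VAE is trained on images alone.

4. **Interpretable Representations**: Latent space arithmetic hints at meaningful feature directions. A single random example per class is only anecdotal evidence, though.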

### Advantages of VAEs:

- **Principled probabilistic framework**: Based on variational inference
- **Continuous latent space**: Enables interpolation and exploration
- **Balanced objectives**: Reconstruction + regularization via KL divergence

### Limitations:

- Generated images can be blurry (compared to GANs)
- Requires careful tuning of β (KL divergence weight)
- More complex training than standard autoencoders

### Applications:

- **Creative tools**: Generate variations of designs, artworks
- **Data augmentation**: Create synthetic training data
- **Compression**: Lossy compression with generative reconstruction
- **Drug discovery**: Generate novel molecular structures
- **Anomaly detection**: Probabilistic anomaly scores

### Next Steps:

1. Explore the Streamlit app for interactive generation
2. Try conditional VAEs (CVAEs) to control generation by class
3. Experiment with different β values (β-VAE)
4. Compare with GANs and diffusion models