# 🤖 **Generative AI Comprehensive Lecture Series**

## **Applied Scientist Interview Preparation**
### *Amazon & Tech Industry Focus*

---

## 📚 **Course Overview**

This comprehensive lecture series covers the fundamental and advanced concepts in Generative AI that every Applied Scientist should master. It is designed specifically for Amazon Applied Scientist interviews and senior technical roles.

### **🎯 Learning Objectives**
By the end of this series, you will:
1. **Master foundational concepts** in generative modeling
2. **Understand modern architectures** (VAEs, GANs, Diffusion Models, LLMs)
3. **Implement key algorithms** from scratch
4. **Apply GenAI to real-world problems** and business scenarios
5. **Navigate ethical considerations** and safety in AI
6. **Demonstrate production deployment** knowledge

### **📖 Lecture Structure**
- **4 Comprehensive Lectures** (90 minutes each)
- **Hands-on Implementation** with working code
- **Interview-focused assessments** and Q&A
- **Real-world case studies** and applications
- **Production considerations** and scalability

### **🏢 Target Audience**
- Applied Scientists preparing for Amazon/Meta/Google interviews
- ML Engineers transitioning to GenAI roles
- Senior practitioners seeking comprehensive GenAI knowledge
- Technical leaders building GenAI strategy

---

# 📋 **Prerequisites & Setup**

## **Mathematical Background**
- **Probability & Statistics**: Distributions, Bayes' theorem, KL divergence
- **Linear Algebra**: Matrix operations, eigenvalues, SVD
- **Calculus**: Gradients, chain rule, optimization
- **Information Theory**: Entropy, mutual information

## **Technical Prerequisites**
- **Deep Learning**: Neural networks, backpropagation, optimization
- **PyTorch/TensorFlow**: Tensor operations, autograd, model building
- **Transformers**: Attention mechanisms, encoder-decoder architectures
- **Computer Vision**: CNNs, image processing fundamentals

## **Business Context**
- Understanding of **product development** lifecycle
- **Customer-centric thinking** and user experience design
- **Scalability considerations** for production systems
- **Ethical AI** and responsible deployment practices

In [None]:
# Essential imports for GenAI implementations
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import math
import random
from typing import Tuple, List, Dict, Optional, Union
from dataclasses import dataclass
import json
import time
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Plotting configuration
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

print("✅ Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

---

# 🎓 **LECTURE 1: Foundations of Generative Modeling**
## *Duration: 90 minutes*

### **📚 Learning Objectives**
- Understand the mathematical foundations of generative modeling
- Distinguish between discriminative and generative models
- Master probability distributions and density estimation
- Implement basic generative models from scratch
- Apply concepts to real-world scenarios

---

## **🧮 1.1 Mathematical Foundations**

### **Generative vs Discriminative Models**

**Discriminative Models**: Learn P(y|x)
- Goal: Classify or predict given input
- Examples: Logistic Regression, SVMs, Random Forests
- Focus: Decision boundaries

**Generative Models**: Learn P(x) or P(x,y)
- Goal: Generate new samples from learned distribution
- Examples: VAEs, GANs, Diffusion Models, LLMs
- Focus: Data distribution modeling
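To make the distinction concrete, here is a minimal NumPy sketch of a generative classifier on hypothetical 1-D data. It models p(x|y) as one Gaussian per class plus a prior p(y), recovers the discriminative quantity p(y|x) through Bayes' rule, and can also sample new x for a chosen class, which a model of p(y|x) alone cannot do. The class means, variances, and sample sizes are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: class 0 ~ N(-2, 1), class 1 ~ N(2, 1), equal class sizes
x0 = rng.normal(-2.0, 1.0, 500)
x1 = rng.normal(2.0, 1.0, 500)

def gaussian_logpdf(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

# Generative model: estimate p(x|y) (one Gaussian per class) and p(y)
params = {0: (x0.mean(), x0.std()), 1: (x1.mean(), x1.std())}
n_total = len(x0) + len(x1)
priors = {0: len(x0) / n_total, 1: len(x1) / n_total}

def posterior(x):
    """p(y|x) from p(x, y) = p(x|y) p(y) via Bayes' rule; column j holds p(y|x_j)"""
    log_joint = np.array([gaussian_logpdf(x, *params[c]) + np.log(priors[c]) for c in (0, 1)])
    log_joint -= log_joint.max(axis=0)  # subtract the max for numerical stability
    joint = np.exp(log_joint)
    return joint / joint.sum(axis=0)

def sample(c, n):
    """Generate n new points from the learned class-conditional p(x|y=c)"""
    mu, sigma = params[c]
    return rng.normal(mu, sigma, n)

print(posterior(np.array([-3.0, 0.0, 3.0])))
print(sample(1, 5))
```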

### **Key Mathematical Concepts**

#### **1. Probability Density Functions**
For continuous variables:
```
P(a ≤ X ≤ b) = ∫[a to b] p(x) dx
```
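A quick numerical check of this definition, assuming a Gaussian p(x): integrating the density with a midpoint rule matches the closed-form Gaussian CDF, written here with `math.erf`. The helper names are illustrative only.

```python
import math

def gaussian_pdf(x, mu=0.0, sigma=1.0):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def prob_interval(a, b, n=10_000, **kw):
    """Approximate P(a <= X <= b) = ∫[a to b] p(x) dx with the midpoint rule"""
    h = (b - a) / n
    return h * sum(gaussian_pdf(a + (i + 0.5) * h, **kw) for i in range(n))

def prob_interval_exact(a, b, mu=0.0, sigma=1.0):
    """Same probability from the Gaussian CDF: Φ(x) = (1 + erf((x - mu) / (sigma·√2))) / 2"""
    cdf = lambda x: 0.5 * (1 + math.erf((x - mu) / (sigma * math.sqrt(2))))
    return cdf(b) - cdf(a)

print(prob_interval(-1, 1), prob_interval_exact(-1, 1))  # both ≈ 0.6827
```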

#### **2. Maximum Likelihood Estimation (MLE)**
Given dataset D = {x₁, x₂, ..., xₙ}, find θ that maximizes:
```
L(θ) = ∏ᵢ p(xᵢ|θ)
log L(θ) = Σᵢ log p(xᵢ|θ)
```
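For a Gaussian, setting the gradient of log L(θ) to zero gives the sample mean and the biased (divide-by-n) standard deviation. A short NumPy sketch on synthetic data (the true μ = 3, σ = 2 are arbitrary choices) that computes these estimates and shows that perturbing either parameter lowers the log-likelihood:

```python
import numpy as np

rng = np.random.default_rng(42)
data = rng.normal(loc=3.0, scale=2.0, size=5_000)  # D = {x_1, ..., x_n}

def log_likelihood(mu, sigma, x):
    """log L(θ) = Σ_i log p(x_i | θ) for a Gaussian with θ = (mu, sigma)"""
    return np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2)

# Closed-form MLE for a Gaussian
mu_hat = data.mean()
sigma_hat = np.sqrt(np.mean((data - mu_hat) ** 2))  # divides by n, not n - 1

print(f"MLE: mu={mu_hat:.3f}, sigma={sigma_hat:.3f}")
print("log L at MLE:        ", log_likelihood(mu_hat, sigma_hat, data))
print("log L with mu + 0.2: ", log_likelihood(mu_hat + 0.2, sigma_hat, data))
print("log L with 1.1*sigma:", log_likelihood(mu_hat, 1.1 * sigma_hat, data))
```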

#### **3. KL Divergence**
Measures difference between distributions:
```
D_KL(P||Q) = ∫ p(x) log(p(x)/q(x)) dx
```
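The integral can be estimated by sampling x ~ P and averaging log p(x) - log q(x). This NumPy sketch compares that Monte Carlo estimate with the closed-form Gaussian KL used in the `ProbabilityDistributions` cell below, for the same pair N(0, 1) and N(1, 1.5²), and shows that swapping the arguments changes the value (KL is not symmetric). The sample size is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

def kl_monte_carlo(mu_p, s_p, mu_q, s_q, n=200_000):
    """D_KL(P||Q) = E_{x~P}[log p(x) - log q(x)], estimated from samples of P"""
    x = rng.normal(mu_p, s_p, n)
    return np.mean(log_normal(x, mu_p, s_p) - log_normal(x, mu_q, s_q))

def kl_closed_form(mu_p, s_p, mu_q, s_q):
    """Exact KL between two univariate Gaussians"""
    return np.log(s_q / s_p) + (s_p**2 + (mu_p - mu_q) ** 2) / (2 * s_q**2) - 0.5

print(kl_monte_carlo(0, 1, 1, 1.5), kl_closed_form(0, 1, 1, 1.5))  # both ≈ 0.350
print(kl_closed_form(1, 1.5, 0, 1))  # ≈ 0.720, a different value with the arguments swapped
```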

#### **4. Evidence Lower Bound (ELBO)**
Key concept for VAEs:
```
log p(x) ≥ E_q[log p(x|z)] - D_KL(q(z|x)||p(z))
```
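A sanity check of the bound on a hypothetical linear-Gaussian model where log p(x) is tractable: with p(z) = N(0, 1) and p(x|z) = N(z, 1), the marginal is p(x) = N(0, 2) and the true posterior is p(z|x) = N(x/2, 1/2). For a Gaussian q(z|x) = N(m, s²), both ELBO terms have closed forms, so no sampling is needed. The observation x = 1.3 and the candidate q's are arbitrary.

```python
import numpy as np

# Hypothetical toy model chosen because log p(x) is available in closed form:
#   p(z) = N(0, 1),  p(x|z) = N(z, 1)   =>   p(x) = N(0, 2),  p(z|x) = N(x/2, 1/2)
def log_px(x):
    return -0.5 * np.log(2 * np.pi * 2.0) - x**2 / 4.0

def elbo(x, m, s2):
    """ELBO for q(z|x) = N(m, s2): E_q[log p(x|z)] - D_KL(q(z|x) || p(z))"""
    expected_loglik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - m) ** 2 + s2)
    kl = 0.5 * (s2 + m**2 - 1.0 - np.log(s2))
    return expected_loglik - kl

x = 1.3
print(f"log p(x) = {log_px(x):.4f}")
for m, s2 in [(0.0, 1.0), (1.0, 0.3), (x / 2, 0.5)]:
    print(f"q = N({m:.2f}, {s2:.2f}): ELBO = {elbo(x, m, s2):.4f}")
# The last q equals the true posterior N(x/2, 1/2), where the bound becomes an equality
```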

In [None]:
# Mathematical Foundations Implementation

class ProbabilityDistributions:
    """Implementation of key probability distributions for GenAI"""
    
    @staticmethod
    def gaussian_pdf(x, mu=0, sigma=1):
        """Gaussian probability density function"""
        return (1 / (sigma * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
    
    @staticmethod
    def kl_divergence_gaussian(mu1, sigma1, mu2, sigma2):
        """KL divergence between two Gaussian distributions"""
        return np.log(sigma2 / sigma1) + (sigma1**2 + (mu1 - mu2)**2) / (2 * sigma2**2) - 0.5
    
    @staticmethod
    def sample_gaussian(mu, sigma, n_samples=1000):
        """Sample from Gaussian distribution"""
        return np.random.normal(mu, sigma, n_samples)

# Demonstrate probability concepts
prob_dist = ProbabilityDistributions()

# Create sample data
x = np.linspace(-5, 5, 1000)
y1 = prob_dist.gaussian_pdf(x, mu=0, sigma=1)
y2 = prob_dist.gaussian_pdf(x, mu=1, sigma=1.5)

# Plot distributions
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(x, y1, label='μ=0, σ=1', linewidth=2)
plt.plot(x, y2, label='μ=1, σ=1.5', linewidth=2)
plt.title('Gaussian Probability Density Functions')
plt.xlabel('x')
plt.ylabel('p(x)')
plt.legend()
plt.grid(True, alpha=0.3)

# Calculate and display KL divergence
kl_div = prob_dist.kl_divergence_gaussian(0, 1, 1, 1.5)
print(f"KL Divergence D_KL(N(0,1)||N(1,1.5²)): {kl_div:.4f}")

# Sample and plot histograms
plt.subplot(1, 2, 2)
samples1 = prob_dist.sample_gaussian(0, 1, 1000)
samples2 = prob_dist.sample_gaussian(1, 1.5, 1000)

plt.hist(samples1, bins=50, alpha=0.7, label='μ=0, σ=1', density=True)
plt.hist(samples2, bins=50, alpha=0.7, label='μ=1, σ=1.5', density=True)
plt.title('Sampled Data Histograms')
plt.xlabel('x')
plt.ylabel('Density')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n🎯 **Key Insights:**")
print("1. Generative models learn the underlying data distribution p(x)")
print("2. KL divergence measures how different two distributions are")
print("3. Lower KL divergence = better approximation of target distribution")
print("4. MLE finds parameters that maximize likelihood of observed data")

## **🎲 1.2 Types of Generative Models**

### **Taxonomy of Generative Models**

```
Generative Models
├── Explicit Density Models
│   ├── Tractable Density
│   │   ├── Autoregressive Models (GPT, PixelCNN)
│   │   └── Flow-based Models (RealNVP, Glow)
│   └── Approximate Density (variational bound)
│       ├── Variational Autoencoders (VAEs)
│       └── Diffusion Models (DDPM, DDIM)
└── Implicit Density Models
    └── Generative Adversarial Networks (GANs)
```

### **Model Comparison**

| Model Type | Pros | Cons | Best Use Cases |
|------------|------|------|----------------|
| **VAEs** | Stable training, good representations | Blurry outputs | Representation learning, data compression |
| **GANs** | Sharp, realistic outputs | Training instability | Image generation, style transfer |
| **Diffusion** | High quality, stable | Slow sampling | Image generation, inpainting |
| **Autoregressive** | Exact likelihood, stable | Slow, sequential generation | Text, code generation |
| **Flows** | Exact likelihood, invertible | Limited expressiveness | Density estimation, anomaly detection |

### **🎯 Interview Focus Points**
1. **Trade-offs**: Quality vs Speed vs Stability
2. **Use Cases**: When to choose each model type
3. **Scalability**: Production deployment considerations
4. **Evaluation**: How to measure generation quality

In [None]:
# Simple Generative Model Implementation

class SimpleGaussianGenerator(nn.Module):
    """A simple generative model that learns to generate 2D points from a Gaussian mixture"""
    
    def __init__(self, latent_dim=2, hidden_dim=64):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Generator network
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)  # Output 2D points
        )
    
    def forward(self, z):
        """Generate 2D points from latent code z"""
        return self.generator(z)
    
    def sample(self, n_samples=1000):
        """Sample n_samples points from the learned distribution"""
        z = torch.randn(n_samples, self.latent_dim).to(device)
        with torch.no_grad():
            return self.forward(z)

# Create target distribution (mixture of Gaussians)
def create_target_data(n_samples=2000):
    """Create a mixture of Gaussians as target distribution"""
    # Component 1: centered at (-2, -2)
    x1 = np.random.multivariate_normal([-2, -2], [[0.5, 0], [0, 0.5]], n_samples//4)
    
    # Component 2: centered at (2, -2)
    x2 = np.random.multivariate_normal([2, -2], [[0.5, 0], [0, 0.5]], n_samples//4)
    
    # Component 3: centered at (-2, 2)
    x3 = np.random.multivariate_normal([-2, 2], [[0.5, 0], [0, 0.5]], n_samples//4)
    
    # Component 4: centered at (2, 2)
    x4 = np.random.multivariate_normal([2, 2], [[0.5, 0], [0, 0.5]], n_samples//4)
    
    return np.vstack([x1, x2, x3, x4])

# Train the simple generator
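def mmd_rbf(x, y, bandwidths=(0.5, 1.0, 2.0)):
    """Biased estimate of the squared Maximum Mean Discrepancy (MMD²) with a sum of RBF kernels"""
    def sq_dists(a, b):
        # Pairwise squared Euclidean distances, shape (len(a), len(b))
        return (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)

    d_xx, d_yy, d_xy = sq_dists(x, x), sq_dists(y, y), sq_dists(x, y)
    mmd = 0.0
    for h in bandwidths:
        gamma = 1.0 / (2 * h ** 2)
        k_xx = torch.exp(-gamma * d_xx).mean()
        k_yy = torch.exp(-gamma * d_yy).mean()
        k_xy = torch.exp(-gamma * d_xy).mean()
        mmd = mmd + k_xx + k_yy - 2 * k_xy
    return mmd

def train_simple_generator(model, target_data, epochs=1000, lr=0.001, batch_size=512):
    """Train the generator by minimizing the kernel MMD between generated and target samples"""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    target_tensor = torch.FloatTensor(target_data).to(device)
    
    losses = []
    
    for epoch in range(epochs):
        # Mini-batch of real points and of generated points
        idx = torch.randint(0, len(target_tensor), (batch_size,), device=device)
        real = target_tensor[idx]
        z = torch.randn(batch_size, model.latent_dim, device=device)
        generated = model(z)
        
        # Matching only the mean and variance cannot tell one blob from four modes;
        # MMD with RBF kernels compares the whole distributions
        loss = mmd_rbf(generated, real)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        if epoch % 200 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
    
    return losses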

# Create and train the model
target_data = create_target_data(2000)
model = SimpleGaussianGenerator().to(device)

print("Training simple generative model...")
losses = train_simple_generator(model, target_data, epochs=1000)

# Generate samples and visualize
generated_samples = model.sample(2000).cpu().numpy()

# Plot results
plt.figure(figsize=(15, 5))

# Target distribution
plt.subplot(1, 3, 1)
plt.scatter(target_data[:, 0], target_data[:, 1], alpha=0.6, s=20)
plt.title('Target Distribution\n(Mixture of 4 Gaussians)')
plt.xlabel('X₁')
plt.ylabel('X₂')
plt.grid(True, alpha=0.3)
plt.axis('equal')

# Generated distribution
plt.subplot(1, 3, 2)
plt.scatter(generated_samples[:, 0], generated_samples[:, 1], alpha=0.6, s=20, color='orange')
plt.title('Generated Distribution\n(After Training)')
plt.xlabel('X₁')
plt.ylabel('X₂')
plt.grid(True, alpha=0.3)
plt.axis('equal')

# Training loss
plt.subplot(1, 3, 3)
plt.plot(losses)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)
plt.yscale('log')

plt.tight_layout()
plt.show()

print("\n🎯 **Key Observations:**")
print("1. Generator learns to approximate target distribution")
print("2. Quality depends on architecture and loss function")
print("3. Trade-off between model complexity and training stability")
print("4. Evaluation requires comparing distributions, not individual samples")

## **🧠 1.3 Autoregressive Models**

### **Concept and Mathematical Foundation**

Autoregressive models factorize the joint probability using the chain rule:

```
p(x₁, x₂, ..., xₙ) = ∏ᵢ₌₁ⁿ p(xᵢ | x₁, x₂, ..., xᵢ₋₁)
```
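A worked example of the chain-rule factorization, assuming a hypothetical 3-token vocabulary and a bigram model, i.e. each conditional p(xᵢ | x₁, ..., xᵢ₋₁) is simplified to depend only on xᵢ₋₁. Neural autoregressive models such as the LSTM below condition on the full prefix, but the product-of-conditionals structure is the same, and it always yields a properly normalized joint distribution.

```python
import itertools
import numpy as np

# Hypothetical bigram model over a 3-token vocabulary
p_first = np.array([0.5, 0.3, 0.2])            # p(x_1)
p_next = np.array([[0.1, 0.6, 0.3],            # p(x_i | x_{i-1}); row = previous token
                   [0.4, 0.4, 0.2],
                   [0.7, 0.2, 0.1]])

def sequence_log_prob(seq):
    """log p(x_1..x_n) = log p(x_1) + Σ_{i>1} log p(x_i | x_{i-1})"""
    logp = np.log(p_first[seq[0]])
    for prev, cur in zip(seq[:-1], seq[1:]):
        logp += np.log(p_next[prev, cur])
    return logp

print(np.exp(sequence_log_prob([0, 1, 1])))  # 0.5 * 0.6 * 0.4 ≈ 0.12

# Probabilities of all 27 length-3 sequences sum to 1
total = sum(np.exp(sequence_log_prob(s)) for s in itertools.product(range(3), repeat=3))
print(total)
```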

### **Key Properties**
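- **Exact likelihood computation**: p(x) is a product of conditionals, so it can be computed exactly
- **Sequential generation**: Elements are generated one at a time, each conditioned on the ones before it
- **Flexible**: The chain-rule factorization holds for any joint distribution, so expressiveness is limited only by how well each conditional is modeled
- **Stable training**: Plain maximum likelihood; no adversarial training needed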

### **Applications**
- **Language Models**: GPT, Transformer-XL (BERT is a masked language model, not an autoregressive one)
- **Image Models**: PixelCNN, PixelRNN
- **Audio Models**: WaveNet, SampleRNN
- **Code Generation**: Codex, CodeT5

### **🎯 Interview Questions**
1. "How do autoregressive models handle variable-length sequences?"
2. "What are the trade-offs between parallel training and sequential inference?"
3. "How would you implement teacher forcing vs. free running?"
4. "What are the challenges in autoregressive image generation?"

In [None]:
# Simple Autoregressive Model Implementation

class SimpleAutoregressiveModel(nn.Module):
    """A simple autoregressive model for sequence generation"""
    
    def __init__(self, vocab_size, embed_dim=64, hidden_dim=128, num_layers=2):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        
        # LSTM for sequential modeling
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        
        # Output projection
        self.output_proj = nn.Linear(hidden_dim, vocab_size)
        
        # For sampling
        self.register_buffer('sos_token', torch.tensor(0))  # Start of sequence
        self.register_buffer('eos_token', torch.tensor(1))  # End of sequence
    
    def forward(self, x, hidden=None):
        """Forward pass for training"""
        batch_size, seq_len = x.shape
        
        # Embed tokens
        embedded = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        
        # LSTM forward
        lstm_out, hidden = self.lstm(embedded, hidden)
        
        # Project to vocabulary
        logits = self.output_proj(lstm_out)  # (batch_size, seq_len, vocab_size)
        
        return logits, hidden
    
    def generate(self, max_length=20, temperature=1.0, top_k=None):
        """Generate sequences autoregressively"""
        self.eval()
        generated = [self.sos_token.item()]
        hidden = None
        
        with torch.no_grad():
            for _ in range(max_length):
                # Current input
                current_input = torch.tensor([[generated[-1]]]).to(device)
                
                # Forward pass
                logits, hidden = self.forward(current_input, hidden)
                
                # Apply temperature
                logits = logits[0, -1] / temperature
                
                # Apply top-k filtering
                if top_k is not None:
                    values, indices = torch.topk(logits, top_k)
                    logits[logits < values[-1]] = -float('inf')
                
                # Sample next token
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                
                generated.append(next_token)
                
                # Stop at end of sequence
                if next_token == self.eos_token.item():
                    break
        
        return generated
    
    def compute_perplexity(self, data_loader):
        """Compute perplexity on dataset"""
        self.eval()
        total_loss = 0
        total_tokens = 0
        
        with torch.no_grad():
            for batch in data_loader:
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)
                
                logits, _ = self.forward(inputs)
                
                # Compute cross-entropy loss
                loss = F.cross_entropy(
                    logits.view(-1, self.vocab_size),
                    targets.view(-1),
                    reduction='sum'
                )
                
                total_loss += loss.item()
                total_tokens += targets.numel()
        
        avg_loss = total_loss / total_tokens
        perplexity = torch.exp(torch.tensor(avg_loss))
        
        return perplexity.item()

# Create synthetic sequence data
def create_sequence_data(vocab_size=10, seq_length=15, num_sequences=1000):
    """Create synthetic arithmetic sequences for demonstration"""
    sequences = []
    
    for _ in range(num_sequences):
        # Create arithmetic sequence: start, start+step, start+2*step, ...
        start = np.random.randint(2, vocab_size-3)  # Tokens 0 (SOS) and 1 (EOS) are reserved
        step = np.random.choice([-1, 1])  # Up or down
        
        sequence = [0]  # SOS token
        current = start
        
        for _ in range(seq_length-2):  # -2 for SOS and EOS
            if 2 <= current < vocab_size-1:  # Valid range
                sequence.append(current)
                current += step
            else:
                break
        
        sequence.append(1)  # EOS token
        
        # Pad to fixed length
        while len(sequence) < seq_length:
            sequence.append(1)  # Pad with EOS
        
        sequences.append(sequence[:seq_length])
    
    return np.array(sequences)

# Prepare data
vocab_size = 10
seq_length = 10
sequences = create_sequence_data(vocab_size, seq_length, 1000)

# Create input-target pairs (shift by one for autoregressive training)
inputs = sequences[:, :-1]
targets = sequences[:, 1:]

# Convert to tensors and create data loader
inputs_tensor = torch.LongTensor(inputs)
targets_tensor = torch.LongTensor(targets)
dataset = torch.utils.data.TensorDataset(inputs_tensor, targets_tensor)
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Create and train model
model = SimpleAutoregressiveModel(vocab_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Training autoregressive model...")
train_losses = []

for epoch in range(50):
    epoch_loss = 0
    model.train()
    
    for batch_inputs, batch_targets in data_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        
        optimizer.zero_grad()
        
        logits, _ = model(batch_inputs)
        loss = criterion(logits.view(-1, vocab_size), batch_targets.view(-1))
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(data_loader)
    train_losses.append(avg_loss)
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")

# Generate some sequences
print("\n📝 **Generated Sequences:**")
for i in range(5):
    generated = model.generate(max_length=15, temperature=0.8)
    print(f"Sample {i+1}: {generated}")

# Compute perplexity (on the training data here, EOS padding included; use a held-out set in practice)
perplexity = model.compute_perplexity(data_loader)
print(f"\n📊 **Model Perplexity (training data)**: {perplexity:.2f}")

# Plot training loss
plt.figure(figsize=(10, 4))
plt.plot(train_losses)
plt.title('Autoregressive Model Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Cross-Entropy Loss')
plt.grid(True, alpha=0.3)
plt.show()

print("\n🎯 **Key Insights:**")
print("1. Autoregressive models predict next token given previous context")
print("2. Teacher forcing during training vs. free running during inference")
print("3. Temperature controls randomness in generation")
print("4. Perplexity measures how well model predicts the test data")
print("5. Lower perplexity = better language model")

## **🔍 1.4 Evaluation Metrics for Generative Models**

### **Quality Metrics**

#### **1. Likelihood-based Metrics**
- **Log-likelihood**: Higher is better
- **Perplexity**: Lower is better (exp of the average per-token negative log-likelihood)
- **Bits per dimension**: Negative log-likelihood in bits (log base 2) divided by the number of dimensions; lower is better
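The three likelihood-based metrics are the same quantity on different scales. A minimal sketch, assuming the model's total negative log-likelihood is available in nats (natural log), which is how `F.cross_entropy` with `reduction='sum'` reports it in `compute_perplexity` above; the helper name is illustrative.

```python
import math

def likelihood_metrics(total_nll_nats, num_tokens_or_dims):
    """Convert a summed negative log-likelihood (natural log) into common metrics"""
    avg_nll = total_nll_nats / num_tokens_or_dims
    return {
        "log_likelihood": -total_nll_nats,      # higher is better
        "perplexity": math.exp(avg_nll),        # lower is better
        "bits_per_dim": avg_nll / math.log(2),  # lower is better
    }

# Sanity check: a uniform model over 256 pixel values spends log(256) nats on every dimension
D = 32 * 32 * 3
print(likelihood_metrics(D * math.log(256), D))  # perplexity ≈ 256, bits/dim ≈ 8
```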

#### **2. Sample-based Metrics**
- **Inception Score (IS)**: Measures quality (confident class predictions) and diversity (many distinct predicted classes) from generated samples only
- **Fréchet Inception Distance (FID)**: Compares feature distributions
- **LPIPS**: Learned perceptual similarity

#### **3. Human Evaluation**
- **Quality ratings**: Human judges rate realism
- **Preference studies**: A/B testing between models
- **Task-specific metrics**: BLEU for text, user engagement for recommendations

### **Diversity Metrics**
- **Mode coverage**: How many modes of data distribution are captured
- **Precision and Recall**: Precision measures quality (generated samples that look real); recall measures diversity (real data covered by generated samples)
- **Self-BLEU**: Diversity in text generation

### **🎯 Business Metrics**
- **User engagement**: Click-through rates, time spent
- **Conversion rates**: Purchases, sign-ups
- **A/B test results**: Statistical significance of improvements
- **Computational cost**: Inference time, resource usage
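For the A/B-testing bullet, one common way to check whether a conversion-rate lift is statistically significant is a two-proportion z-test. A minimal standard-library sketch, assuming large samples (normal approximation) and a two-sided test; the conversion counts are hypothetical.

```python
import math

def two_proportion_z_test(conv_a, n_a, conv_b, n_b):
    """Two-sided z-test for a difference in conversion rates between variants A and B"""
    p_a, p_b = conv_a / n_a, conv_b / n_b
    p_pool = (conv_a + conv_b) / (n_a + n_b)  # pooled rate under H0: p_a == p_b
    se = math.sqrt(p_pool * (1 - p_pool) * (1 / n_a + 1 / n_b))
    z = (p_b - p_a) / se
    p_value = math.erfc(abs(z) / math.sqrt(2))  # equals 2 * (1 - Φ(|z|))
    return z, p_value

# Hypothetical numbers: 10,000 users per arm, 5.0% vs 5.6% conversion
z, p = two_proportion_z_test(500, 10_000, 560, 10_000)
print(f"z = {z:.2f}, p = {p:.3f}")
```

With these numbers z ≈ 1.89 and p ≈ 0.06, so the lift would not be declared significant at α = 0.05 despite looking sizable.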

In [None]:
# Evaluation Metrics Implementation

class GenerativeModelEvaluator:
    """Comprehensive evaluation suite for generative models"""
    
    def __init__(self):
        self.metrics = {}
    
    def inception_score(self, generated_images, batch_size=32, splits=10):
        """Placeholder Inception Score: IS = exp(E_x[D_KL(p(y|x) || p(y))]).
        
        A real implementation obtains p(y|x) from a pre-trained Inception-v3 classifier.
        Here p(y|x) is drawn at random, so the returned value does NOT depend on
        `generated_images`; it only illustrates the formula.
        """
        n_samples = len(generated_images)
        n_classes = 10
        
        # Random stand-in for classifier predictions p(y|x)
        predictions = np.random.dirichlet(np.ones(n_classes), n_samples)
        
        # Calculate IS
        scores = []
        for i in range(splits):
            part = predictions[i * (n_samples // splits): (i + 1) * (n_samples // splits)]
            p_y = np.mean(part, axis=0)
            kl_divs = [np.sum(p * (np.log(p + 1e-10) - np.log(p_y + 1e-10))) for p in part]
            scores.append(np.exp(np.mean(kl_divs)))
        
        return np.mean(scores), np.std(scores)
    
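    def frechet_distance(self, real_features, generated_features):
        """Fréchet distance between Gaussians fitted to real and generated features"""
        from scipy import linalg
        
        # Calculate means
        mu1, mu2 = np.mean(real_features, axis=0), np.mean(generated_features, axis=0)
        
        # Calculate covariances
        sigma1 = np.cov(real_features, rowvar=False)
        sigma2 = np.cov(generated_features, rowvar=False)
        
        # Matrix square root of sigma1 @ sigma2 (np.sqrt would be element-wise, which is wrong)
        covmean = linalg.sqrtm(sigma1 @ sigma2)
        if np.iscomplexobj(covmean):
            covmean = covmean.real  # drop tiny imaginary parts caused by numerical error
        
        # FID = ||mu1 - mu2||² + Tr(Σ1 + Σ2 - 2(Σ1Σ2)^(1/2))
        diff = mu1 - mu2
        fid = diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean)
        
        return fid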
    
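    def precision_recall(self, real_samples, generated_samples, k=3):
        """k-NN precision and recall for generative models.
        
        Each sample defines a ball whose radius is the distance to its k-th nearest
        neighbor within its own set. Precision = fraction of generated samples inside at
        least one real ball; recall = fraction of real samples inside at least one generated ball.
        """
        from sklearn.neighbors import NearestNeighbors
        from sklearn.metrics import pairwise_distances
        
        # k + 1 neighbors because each point's nearest neighbor in its own set is itself
        nn_real = NearestNeighbors(n_neighbors=k + 1).fit(real_samples)
        nn_gen = NearestNeighbors(n_neighbors=k + 1).fit(generated_samples)
        real_radii = nn_real.kneighbors(real_samples)[0][:, k]
        gen_radii = nn_gen.kneighbors(generated_samples)[0][:, k]
        
        # dists[i, j] = distance between generated sample i and real sample j
        dists = pairwise_distances(generated_samples, real_samples)
        
        # Precision: generated samples that land on the estimated real-data manifold (quality)
        precision = np.mean((dists <= real_radii[None, :]).any(axis=1))
        
        # Recall: real samples that land on the estimated generated-data manifold (diversity)
        recall = np.mean((dists.T <= gen_radii[None, :]).any(axis=1))
        
        return precision, recall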
    
    def mode_coverage(self, real_samples, generated_samples, threshold=0.5):
        """Estimate mode coverage using clustering"""
        from sklearn.cluster import KMeans
        from sklearn.metrics import pairwise_distances
        
        # Cluster real samples to identify modes
        n_modes = min(10, len(real_samples) // 50)  # Heuristic for number of modes
        kmeans = KMeans(n_clusters=n_modes, random_state=42)
        real_clusters = kmeans.fit_predict(real_samples)
        centers = kmeans.cluster_centers_
        
        # For each mode, check if generated samples are close
        covered_modes = 0
        for center in centers:
            distances = pairwise_distances([center], generated_samples)[0]
            if np.min(distances) < threshold:
                covered_modes += 1
        
        coverage = covered_modes / n_modes
        return coverage, n_modes
    
    def compute_all_metrics(self, real_samples, generated_samples):
        """Compute comprehensive evaluation metrics"""
        results = {}
        
        # Inception Score
        is_mean, is_std = self.inception_score(generated_samples)
        results['inception_score'] = {'mean': is_mean, 'std': is_std}
        
        # Fréchet Distance (simplified with raw samples as features)
        fid = self.frechet_distance(real_samples, generated_samples)
        results['frechet_distance'] = fid
        
        # Precision and Recall
        precision, recall = self.precision_recall(real_samples, generated_samples)
        results['precision'] = precision
        results['recall'] = recall
        results['f1_score'] = 2 * precision * recall / (precision + recall + 1e-10)
        
        # Mode Coverage
        coverage, n_modes = self.mode_coverage(real_samples, generated_samples)
        results['mode_coverage'] = {'coverage': coverage, 'total_modes': n_modes}
        
        return results

# Demonstrate evaluation on our previous simple generator
evaluator = GenerativeModelEvaluator()

# Use our previous target and generated data
real_data = target_data
generated_data = generated_samples

print("🔍 **Evaluating Generative Model Performance**")
print("=" * 50)

# Compute all metrics
metrics = evaluator.compute_all_metrics(real_data, generated_data)

print(f"📊 **Inception Score (placeholder, random predictions)**: {metrics['inception_score']['mean']:.3f} ± {metrics['inception_score']['std']:.3f}")
print(f"📏 **Fréchet Distance**: {metrics['frechet_distance']:.3f}")
print(f"🎯 **Precision**: {metrics['precision']:.3f}")
print(f"🔄 **Recall**: {metrics['recall']:.3f}")
print(f"⚖️ **F1 Score**: {metrics['f1_score']:.3f}")
print(f"🎭 **Mode Coverage**: {metrics['mode_coverage']['coverage']:.3f} ({metrics['mode_coverage']['total_modes']} modes)")

# Visualize the distributions and the metric scores
plt.figure(figsize=(12, 5))

# Plot original and generated distributions
plt.subplot(1, 2, 1)
plt.scatter(real_data[:, 0], real_data[:, 1], alpha=0.6, label='Real', s=20)
plt.scatter(generated_data[:, 0], generated_data[:, 1], alpha=0.6, label='Generated', s=20)
plt.title('Real vs Generated Distributions')
plt.xlabel('X₁')
plt.ylabel('X₂')
plt.legend()
plt.grid(True, alpha=0.3)
plt.axis('equal')

# Plot metrics comparison
plt.subplot(1, 2, 2)
metric_names = ['Precision', 'Recall', 'F1 Score', 'Mode Coverage']
metric_values = [metrics['precision'], metrics['recall'], 
                metrics['f1_score'], metrics['mode_coverage']['coverage']]

bars = plt.bar(metric_names, metric_values, alpha=0.7)
plt.title('Evaluation Metrics')
plt.ylabel('Score')
plt.ylim(0, 1)
plt.grid(True, alpha=0.3, axis='y')

# Add value labels on bars
for bar, value in zip(bars, metric_values):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{value:.3f}', ha='center', va='bottom')

plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

print("\n🎯 **Evaluation Best Practices:**")
print("1. **Multiple Metrics**: No single metric captures all aspects")
print("2. **Quality vs Diversity**: Precision-Recall trade-off")
print("3. **Domain-Specific**: Choose metrics relevant to your application")
print("4. **Human Evaluation**: Ultimate test for many applications")
print("5. **Business Metrics**: Align with actual product success")

print("\n📈 **Interpretation:**")
print(f"• Precision ({metrics['precision']:.3f}): fraction of generated samples on the real data manifold (higher = more realistic)")
print(f"• Recall ({metrics['recall']:.3f}): fraction of real samples covered by the generated distribution (higher = more diverse)")
print(f"• Mode Coverage ({metrics['mode_coverage']['coverage']:.3f}): fraction of estimated modes with a nearby generated sample")
print(f"• Fréchet Distance ({metrics['frechet_distance']:.3f}): lower = closer means and covariances")

---

# 🎓 **LECTURE 2: Variational Autoencoders (VAEs)**
## *Duration: 90 minutes*

### **📚 Learning Objectives**
- Master the mathematical foundations of VAEs and ELBO
- Understand the reparameterization trick and its importance
- Implement VAEs from scratch with proper training procedures
- Explore advanced VAE variants (β-VAE, WAE, VQ-VAE)
- Apply VAEs to real-world representation learning problems

---

## **🧮 2.1 Mathematical Foundation of VAEs**

### **The Generative Model Setup**

VAEs model data using a latent variable model:
- **Latent variables**: z ~ p(z) (typically N(0,I))
- **Generative model**: p_θ(x|z) (decoder)
- **Inference model**: q_φ(z|x) (encoder)

### **The Variational Lower Bound (ELBO)**

The key insight: we want to maximize log p(x), but computing it requires integrating over every latent z, which is intractable when p(x|z) is a neural network:

```
log p(x) = log ∫ p(x,z) dz = log ∫ p(x|z)p(z) dz
```

Introducing a variational distribution q(z|x) and applying Jensen's inequality (log is concave) gives:

```
log p(x) ≥ E_q(z|x)[log p(x|z)] - D_KL(q(z|x)||p(z))
```
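The Jensen step can be spelled out by multiplying and dividing the integrand by q(z|x). The last line records the size of the gap, which shows the bound is tight exactly when q(z|x) equals the true posterior p(z|x):

$$
\begin{aligned}
\log p(x) &= \log \int q(z\mid x)\,\frac{p(x\mid z)\,p(z)}{q(z\mid x)}\,dz
           = \log \mathbb{E}_{q(z\mid x)}\!\left[\frac{p(x\mid z)\,p(z)}{q(z\mid x)}\right] \\
&\ge \mathbb{E}_{q(z\mid x)}\!\left[\log \frac{p(x\mid z)\,p(z)}{q(z\mid x)}\right] \quad \text{(Jensen; } \log \text{ is concave)} \\
&= \mathbb{E}_{q(z\mid x)}\big[\log p(x\mid z)\big] - D_{KL}\big(q(z\mid x)\,\|\,p(z)\big), \\[4pt]
\log p(x) - \text{ELBO} &= D_{KL}\big(q(z\mid x)\,\|\,p(z\mid x)\big) \;\ge\; 0.
\end{aligned}
$$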

This is the **Evidence Lower BOund (ELBO)**:

```
L(θ,φ;x) = E_q_φ(z|x)[log p_θ(x|z)] - D_KL(q_φ(z|x)||p(z))
           ↑                           ↑
    Reconstruction term         KL regularization term
```
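For the usual choice q_φ(z|x) = N(μ, diag(σ²)) and p(z) = N(0, I), the regularization term has a closed form, so only the reconstruction term needs Monte Carlo sampling. A NumPy sketch, assuming the encoder outputs log-variances (as `fc_logvar` does in the VAE cell below); the function name and the toy batch are illustrative.

```python
import numpy as np

def kl_diag_gaussian_to_std_normal(mu, logvar):
    """D_KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions.

    Per dimension this is 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2)."""
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=-1)

# Toy batch of 2 posteriors with latent_dim = 2
mu = np.array([[0.0, 0.0], [1.0, 0.0]])
logvar = np.array([[0.0, 0.0], [0.0, np.log(0.5)]])
print(kl_diag_gaussian_to_std_normal(mu, logvar))  # first row matches the prior, so its KL is 0
```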

### **The Reparameterization Trick**

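Problem: Sampling z ~ q_φ(z|x) is not a differentiable operation, so gradients of the reconstruction term cannot flow back to the encoder parameters φ.

Solution: Reparameterize z = μ + σ ⊙ ε, where ε ~ N(0,I). The randomness now lives in ε, and z is a deterministic, differentiable function of μ and σ.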

```
z ~ q_φ(z|x) = N(μ_φ(x), σ_φ²(x))
z = μ_φ(x) + σ_φ(x) ⊙ ε, where ε ~ N(0,I)
```
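A small NumPy check of why the trick works: writing z = μ + σ·ε turns the gradient of an expectation into an average of ordinary chain-rule derivatives. The test function f(z) = z² and the values of μ and σ are arbitrary choices, used because E[f(z)] = μ² + σ² has known gradients 2μ and 2σ. In PyTorch, autograd performs the same chain rule when `reparameterize` builds z from `mu`, `std`, and `eps`.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.8
eps = rng.standard_normal(100_000)  # ε ~ N(0, 1), independent of (μ, σ)
z = mu + sigma * eps                # z = μ + σ·ε  ~  N(μ, σ²)

# f(z) = z², so E[f(z)] = μ² + σ² and the exact gradients are 2μ and 2σ.
# Because z is a deterministic function of (μ, σ, ε), the chain rule applies per sample:
#   ∂f/∂μ = 2z · ∂z/∂μ = 2z,     ∂f/∂σ = 2z · ∂z/∂σ = 2z·ε
grad_mu = np.mean(2 * z)
grad_sigma = np.mean(2 * z * eps)

print(grad_mu, 2 * mu)        # Monte Carlo estimate vs. exact value
print(grad_sigma, 2 * sigma)  # Monte Carlo estimate vs. exact value
```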

### **🎯 Key Advantages of VAEs**
1. **Stable Training**: No adversarial optimization
2. **Principled**: Derived from maximum likelihood
3. **Latent Representation**: Meaningful continuous latent space
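4. **Tractable Objective**: The ELBO is cheap to estimate; the KL term has a closed form for a Gaussian q and prior, and the reconstruction term needs only a few Monte Carlo samples

### **🎯 Key Limitations**
1. **Blurry Outputs**: A factorized Gaussian or Bernoulli likelihood (MSE/BCE loss) rewards averaging over plausible outputs
2. **Posterior Collapse**: q(z|x) collapses to the prior p(z), so the latents carry no information about x and a powerful decoder ignores them
3. **Limited Expressiveness**: Diagonal-Gaussian posterior and prior assumptions restrict what the latent space can represent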

In [None]:
# Comprehensive VAE Implementation

class VAE(nn.Module):
    """Variational Autoencoder with proper ELBO optimization"""
    
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        
        # Encoder network (inference model q_φ(z|x))
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Latent space parameters
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder network (generative model p_θ(x|z))
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # For MNIST (pixel values in [0,1])
        )
    
    def encode(self, x):
        """Encode input to latent distribution parameters"""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """Reparameterization trick: z = μ + σ⊙ε"""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu  # Use mean for deterministic inference
    
    def decode(self, z):
        """Decode latent variable to reconstruction"""
        return self.decoder(z)
    
    def forward(self, x):
        """Complete forward pass"""
        # Flatten input if needed
        if len(x.shape) > 2:
            x = x.view(x.size(0), -1)
        
        # Encode
        mu, logvar = self.encode(x)
        
        # Reparameterize
        z = self.reparameterize(mu, logvar)
        
        # Decode
        recon_x = self.decode(z)
        
        return recon_x, mu, logvar, z
    
    def sample(self, num_samples=64):
        """Generate samples by decoding z ~ p(z) = N(0, I)"""
        self.eval()
        device = next(self.parameters()).device  # sample on the model's own device
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim, device=device)
            samples = self.decode(z)
        return samples
    
    def interpolate(self, x1, x2, num_steps=10):
        """Interpolate between two inputs in latent space"""
        self.eval()
        with torch.no_grad():
            # Encode both inputs
            mu1, _ = self.encode(x1.view(1, -1))
            mu2, _ = self.encode(x2.view(1, -1))
            
            # Interpolate in latent space
            interpolations = []
            for i in range(num_steps):
                alpha = i / (num_steps - 1)
                z_interp = (1 - alpha) * mu1 + alpha * mu2
                x_interp = self.decode(z_interp)
                interpolations.append(x_interp)
            
            return torch.cat(interpolations, dim=0)

def vae_loss_function(recon_x, x, mu, logvar, beta=1.0):
    """
    VAE loss function (negative ELBO)
    
    Args:
        recon_x: Reconstructed input
        x: Original input
        mu: Latent mean
        logvar: Latent log variance
        beta: Weight for KL divergence (β-VAE)
    
    Returns:
        loss: Total loss
        reconstruction_loss: Reconstruction term
        kl_loss: KL divergence term
    """
    # Flatten inputs
    if len(x.shape) > 2:
        x = x.view(x.size(0), -1)
    if len(recon_x.shape) > 2:
        recon_x = recon_x.view(recon_x.size(0), -1)
    
    # Reconstruction loss (negative log-likelihood)
    # For MNIST, we use binary cross-entropy
    reconstruction_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # KL divergence: D_KL(q(z|x) || p(z))
    # Analytical form for Gaussian q and standard normal p
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Total loss
    loss = reconstruction_loss + beta * kl_loss
    
    return loss, reconstruction_loss, kl_loss

# Advanced VAE Training Class
class VAETrainer:
    """Comprehensive VAE training with monitoring and evaluation"""
    
    def __init__(self, model, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.train_losses = []
        self.val_losses = []
        self.kl_losses = []
        self.recon_losses = []
        
    def train_epoch(self, train_loader, optimizer, beta=1.0):
        """Train for one epoch"""
        self.model.train()
        epoch_loss = 0
        epoch_recon_loss = 0
        epoch_kl_loss = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(self.device)
            
            optimizer.zero_grad()
            
            # Forward pass
            recon_batch, mu, logvar, z = self.model(data)
            
            # Compute loss
            loss, recon_loss, kl_loss = vae_loss_function(
                recon_batch, data, mu, logvar, beta
            )
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Accumulate losses
            epoch_loss += loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_kl_loss += kl_loss.item()
        
        # Average losses
        epoch_loss /= len(train_loader.dataset)
        epoch_recon_loss /= len(train_loader.dataset)
        epoch_kl_loss /= len(train_loader.dataset)
        
        return epoch_loss, epoch_recon_loss, epoch_kl_loss
    
    def validate(self, val_loader, beta=1.0):
        """Validate the model"""
        self.model.eval()
        val_loss = 0
        
        with torch.no_grad():
            for data, _ in val_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar, z = self.model(data)
                loss, _, _ = vae_loss_function(recon_batch, data, mu, logvar, beta)
                val_loss += loss.item()
        
        val_loss /= len(val_loader.dataset)
        return val_loss
    
    def train(self, train_loader, val_loader, epochs=50, lr=1e-3, beta=1.0, 
              beta_schedule=None, print_every=10):
        """Complete training loop"""
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        print(f"Training VAE for {epochs} epochs...")
        print(f"Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        
        for epoch in range(epochs):
            # Adjust beta if schedule provided
            current_beta = beta
            if beta_schedule:
                current_beta = beta_schedule(epoch, epochs)
            
            # Train
            train_loss, recon_loss, kl_loss = self.train_epoch(
                train_loader, optimizer, current_beta
            )
            
            # Validate
            val_loss = self.validate(val_loader, current_beta)
            
            # Store losses
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.recon_losses.append(recon_loss)
            self.kl_losses.append(kl_loss)
            
            # Print progress
            if epoch % print_every == 0:
                print(f'Epoch {epoch:3d}: Train Loss: {train_loss:.4f}, '
                      f'Val Loss: {val_loss:.4f}, '
                      f'Recon: {recon_loss:.4f}, '
                      f'KL: {kl_loss:.4f}, '
                      f'β: {current_beta:.3f}')
    
    def plot_training_curves(self):
        """Plot training curves"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Total loss
        axes[0, 0].plot(self.train_losses, label='Train')
        axes[0, 0].plot(self.val_losses, label='Validation')
        axes[0, 0].set_title('Total Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # Reconstruction loss
        axes[0, 1].plot(self.recon_losses)
        axes[0, 1].set_title('Reconstruction Loss')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].grid(True, alpha=0.3)
        
        # KL loss
        axes[1, 0].plot(self.kl_losses)
        axes[1, 0].set_title('KL Divergence')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('KL Loss')
        axes[1, 0].grid(True, alpha=0.3)
        
        # Loss ratio
        recon_ratio = np.array(self.recon_losses) / (np.array(self.recon_losses) + np.array(self.kl_losses))
        axes[1, 1].plot(recon_ratio)
        axes[1, 1].set_title('Reconstruction Loss Ratio')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Recon / (Recon + KL)')
        axes[1, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

# Beta scheduling for β-VAE
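def beta_schedule_cyclical(epoch, total_epochs, n_cycles=4):
    """Cyclical β annealing: in each cycle, β ramps linearly from 0 to 1 over the
    first half of the cycle, then stays at 1. Repeated ramps help mitigate KL vanishing."""
    cycle_length = max(1, total_epochs // n_cycles)  # guard against total_epochs < n_cycles
    cycle_pos = epoch % cycle_length
    return min(1.0, cycle_pos / (cycle_length * 0.5))

def beta_schedule_linear(epoch, total_epochs, max_beta=1.0):
    """Linear warm-up: β grows from 0 to max_beta over the first half of training"""
    return min(max_beta, epoch / (total_epochs * 0.5))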

# Load and prepare MNIST data
def load_mnist_for_vae(batch_size=128):
    """Load MNIST data for VAE training"""
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    
    train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader

# Demonstration
print("🏗️ **Setting up VAE training on MNIST**")

# Load data
train_loader, test_loader = load_mnist_for_vae(batch_size=128)
print(f"✅ Loaded MNIST: {len(train_loader.dataset)} train, {len(test_loader.dataset)} test samples")

# Create VAE model
vae = VAE(input_dim=784, hidden_dim=400, latent_dim=20)
print(f"✅ Created VAE with {sum(p.numel() for p in vae.parameters()):,} parameters")

# Create trainer
trainer = VAETrainer(vae, device)

# Train the model
print("🚀 **Starting VAE training...**")
trainer.train(
    train_loader, 
    test_loader, 
    epochs=30, 
    lr=1e-3, 
    beta=1.0,
    print_every=5
)

print("✅ **Training completed!**")
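The KL term in `vae_loss_function` uses the closed form for a diagonal Gaussian posterior against a standard normal prior, -½ Σ(1 + log σ² - μ² - σ²). A quick way to trust such a formula is to compare it with a Monte Carlo estimate of E_q[log q(z) - log p(z)]. The NumPy sketch below does that for arbitrary example values; the helper names are hypothetical.

```python
import numpy as np

def kl_closed_form(mu, logvar):
    """D_KL(N(mu, diag(exp(logvar))) || N(0, I)), same formula as in vae_loss_function."""
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar))

def kl_monte_carlo(mu, logvar, num_samples=200_000, seed=0):
    """Estimate E_q[log q(z) - log p(z)] by sampling z ~ q."""
    rng = np.random.default_rng(seed)
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal((num_samples, mu.shape[0]))
    z = mu + std * eps
    # Diagonal Gaussian log-densities; the log(2*pi) constants cancel in the difference
    log_q = -0.5 * np.sum(logvar + eps**2, axis=1)
    log_p = -0.5 * np.sum(z**2, axis=1)
    return np.mean(log_q - log_p)

mu = np.array([0.5, -1.0, 0.0])
logvar = np.array([0.0, -0.5, 0.7])
cf = kl_closed_form(mu, logvar)
mc = kl_monte_carlo(mu, logvar)
print(cf, mc)  # both close to 0.835
```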

In [None]:
# VAE Visualization and Analysis

def visualize_vae_results(vae, test_loader, device, num_samples=10):
    """Comprehensive VAE visualization"""
    vae.eval()
    
    # Get test batch
    test_data, _ = next(iter(test_loader))
    test_data = test_data.to(device)
    
    with torch.no_grad():
        # Reconstructions
        recon_data, mu, logvar, z = vae(test_data[:num_samples])
        
        # Generated samples
        generated = vae.sample(num_samples)
        
        # Interpolations
        interpolated = vae.interpolate(test_data[0], test_data[1], num_steps=num_samples)
    
    # Plot results
    fig, axes = plt.subplots(4, num_samples, figsize=(num_samples*2, 8))
    
    for i in range(num_samples):
        # Original
        axes[0, i].imshow(test_data[i].cpu().squeeze(), cmap='gray')
        axes[0, i].set_title('Original' if i == 0 else '')
        axes[0, i].axis('off')
        
        # Reconstruction
        axes[1, i].imshow(recon_data[i].cpu().view(28, 28), cmap='gray')
        axes[1, i].set_title('Reconstruction' if i == 0 else '')
        axes[1, i].axis('off')
        
        # Generated
        axes[2, i].imshow(generated[i].cpu().view(28, 28), cmap='gray')
        axes[2, i].set_title('Generated' if i == 0 else '')
        axes[2, i].axis('off')
        
        # Interpolation
        axes[3, i].imshow(interpolated[i].cpu().view(28, 28), cmap='gray')
        axes[3, i].set_title('Interpolation' if i == 0 else '')
        axes[3, i].axis('off')
    
    plt.tight_layout()
    plt.show()

def analyze_latent_space(vae, test_loader, device, num_samples=1000):
    """Analyze the learned latent space"""
    vae.eval()
    
    latents = []
    labels = []
    
    with torch.no_grad():
        for i, (data, label) in enumerate(test_loader):
            if len(latents) * test_loader.batch_size >= num_samples:
                break
                
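            data = data.to(device)
            # The encoder's first layer is nn.Linear(784, ...), so flatten 1x28x28 images
            mu, logvar = vae.encode(data.view(data.size(0), -1))
            latents.append(mu.cpu())
            labels.append(label)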
    
    latents = torch.cat(latents, dim=0)[:num_samples]
    labels = torch.cat(labels, dim=0)[:num_samples]
    
    # Apply t-SNE for visualization if latent_dim > 2
    if vae.latent_dim > 2:
        print("Applying t-SNE for latent space visualization...")
        tsne = TSNE(n_components=2, random_state=42)
        latents_2d = tsne.fit_transform(latents.numpy())
    else:
        latents_2d = latents.numpy()
    
    # Plot latent space
    plt.figure(figsize=(12, 5))
    
    # Latent space colored by digit
    plt.subplot(1, 2, 1)
    scatter = plt.scatter(latents_2d[:, 0], latents_2d[:, 1], c=labels, cmap='tab10', alpha=0.7)
    plt.colorbar(scatter)
    plt.title('Latent Space (colored by digit)')
    plt.xlabel('Latent Dimension 1')
    plt.ylabel('Latent Dimension 2')
    
    # Latent space density
    plt.subplot(1, 2, 2)
    plt.hist2d(latents_2d[:, 0], latents_2d[:, 1], bins=50, cmap='Blues')
    plt.colorbar()
    plt.title('Latent Space Density')
    plt.xlabel('Latent Dimension 1')
    plt.ylabel('Latent Dimension 2')
    
    plt.tight_layout()
    plt.show()
    
    return latents_2d, labels

def evaluate_vae_metrics(vae, test_loader, device):
    """Evaluate VAE using various metrics"""
    vae.eval()
    
    total_loss = 0
    total_recon_loss = 0
    total_kl_loss = 0
    num_samples = 0
    
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon_data, mu, logvar, z = vae(data)
            
            loss, recon_loss, kl_loss = vae_loss_function(recon_data, data, mu, logvar)
            
            total_loss += loss.item()
            total_recon_loss += recon_loss.item()
            total_kl_loss += kl_loss.item()
            num_samples += data.size(0)
    
    # Average losses
    avg_loss = total_loss / num_samples
    avg_recon_loss = total_recon_loss / num_samples
    avg_kl_loss = total_kl_loss / num_samples
    
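    # -avg_loss is an ELBO estimate: a lower bound on log p(x), not the log-likelihood itself.
    # In eval mode, reparameterize() returns mu, so this is a deterministic (z = mu) approximation.
    log_likelihood = -avg_loss
    
    # Bits per dimension from the negative ELBO: an upper bound on the true value.
    # BCE on continuous [0,1] pixels is not a discrete likelihood, so this is not
    # directly comparable to published bits/dim numbers.
    bits_per_dim = avg_loss / (np.log(2) * 784)  # 784 = 28*28 for MNIST
    
    print("📊 **VAE Evaluation Metrics:**")
    print(f"Average Test Loss: {avg_loss:.4f}")
    print(f"Reconstruction Loss: {avg_recon_loss:.4f}")
    print(f"KL Divergence: {avg_kl_loss:.4f}")
    print(f"ELBO (lower bound on log-likelihood): {log_likelihood:.4f}")
    print(f"Bits per Dimension (upper bound): {bits_per_dim:.4f}")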
    
    return {
        'total_loss': avg_loss,
        'reconstruction_loss': avg_recon_loss,
        'kl_loss': avg_kl_loss,
        'log_likelihood': log_likelihood,
        'bits_per_dim': bits_per_dim
    }

# Run comprehensive evaluation
print("🔍 **Evaluating trained VAE...**")

# Plot training curves
trainer.plot_training_curves()

# Visualize results
print("🎨 **Visualizing VAE results...**")
visualize_vae_results(vae, test_loader, device, num_samples=8)

# Analyze latent space
print("🧠 **Analyzing latent space...**")
latents_2d, labels = analyze_latent_space(vae, test_loader, device, num_samples=1000)

# Evaluate metrics
metrics = evaluate_vae_metrics(vae, test_loader, device)

print("\n🎯 **Key Insights:**")
print("1. **Reconstruction Quality**: How well VAE reconstructs inputs")
print("2. **Latent Structure**: Meaningful organization in latent space")
print("3. **Generation Quality**: Samples from prior distribution")
print("4. **Interpolation**: Smooth transitions in latent space")
print("5. **Bits per Dimension**: Compression efficiency metric")
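The "log-likelihood" reported by `evaluate_vae_metrics` is -avg_loss, i.e. an ELBO estimate and therefore only a lower bound on log p(x). A common way to get a tighter estimate is importance sampling with K samples from q(z|x), log (1/K) Σ_k p(x, z_k)/q(z_k|x), the bound used by importance-weighted autoencoders. The sketch below illustrates the gap on a one-dimensional toy model (an assumption chosen because log p(x) is known in closed form), not on the trained MNIST VAE; the function names are hypothetical.

```python
import numpy as np

def log_normal(x, mean, var):
    """Log-density of a univariate Gaussian N(mean, var)."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def elbo_and_iw_estimate(x, q_mean, q_var, num_samples=10_000, seed=0):
    """Toy model: p(z) = N(0, 1), p(x|z) = N(z, 1), proposal q(z|x) = N(q_mean, q_var)."""
    rng = np.random.default_rng(seed)
    z = q_mean + np.sqrt(q_var) * rng.standard_normal(num_samples)
    # log importance weights: log p(z) + log p(x|z) - log q(z|x)
    log_w = log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0) - log_normal(z, q_mean, q_var)
    elbo = np.mean(log_w)                                    # Monte Carlo ELBO
    iw = np.logaddexp.reduce(log_w) - np.log(num_samples)    # log (1/K) sum_k w_k, computed stably
    return elbo, iw

x = 1.0
exact = log_normal(x, 0.0, 2.0)   # marginal p(x) = N(0, 2) for this toy model
elbo, iw = elbo_and_iw_estimate(x, q_mean=0.3, q_var=1.0)  # deliberately imperfect q
print(f"exact {exact:.4f}  ELBO {elbo:.4f}  IW estimate {iw:.4f}")
```

The ELBO falls short of log p(x) by KL(q(z|x) || p(z|x)), while the importance-weighted estimate closes most of that gap. With a trained VAE, the z_k would come from the encoder and p(x|z_k) from the decoder.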

## **🚀 2.2 Advanced VAE Variants**

### **β-VAE: Controlling Disentanglement**

β-VAE reweights the KL term of the ELBO with a hyperparameter β:

```
L_β = E_q(z|x)[log p(x|z)] - β·D_KL(q(z|x)||p(z))
```

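**Effects of β:**
- **β < 1**: Prioritizes reconstruction; the latent space is less regularized, so samples drawn from the prior tend to degrade
- **β = 1**: Standard VAE (the usual ELBO)
- **β > 1**: Encourages disentanglement, at the cost of reconstruction quality and a higher risk of posterior collapse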
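The direction of these effects can be checked on a one-dimensional linear Gaussian model where the β-weighted objective has a closed-form optimum (a toy setting assumed here, not the MNIST model). Writing q(z|x) = N(m, v) and setting the derivatives of the objective to zero gives m = x / (1 + β·s²) and v = β·s² / (β·s² + 1), where s² is the decoder noise variance. As β grows the optimal posterior moves toward the prior N(0, 1); as β shrinks it becomes nearly deterministic. `beta_optimal_posterior` is a hypothetical helper.

```python
def beta_optimal_posterior(x, beta, noise_var=1.0):
    """Optimal q(z|x) = N(m, v) for the beta-weighted objective in a 1-D linear Gaussian model.

    Model (an assumption for illustration): p(z) = N(0, 1), p(x|z) = N(z, noise_var).
    Setting the derivatives of
        -((x - m)**2 + v) / (2 * noise_var) - beta * 0.5 * (v + m**2 - 1 - log v)
    with respect to m and v to zero gives the closed forms below.
    """
    m = x / (1.0 + beta * noise_var)
    v = beta * noise_var / (beta * noise_var + 1.0)
    return m, v

for beta in [0.1, 1.0, 4.0, 100.0]:
    m, v = beta_optimal_posterior(x=2.0, beta=beta)
    print(f"beta={beta:6.1f}  mean={m:.3f}  var={v:.3f}")
```

At β = 1 the optimum equals the exact posterior N(x/2, 1/2); at β = 100 the mean is about 0.02 and the variance about 0.99, i.e. the posterior has nearly collapsed onto the prior.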

### **Wasserstein AutoEncoder (WAE)**

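WAE minimizes an optimal-transport (Wasserstein-type) cost between the data distribution and the model distribution. The resulting objective penalizes a divergence between the *aggregated* posterior q_Z and the prior, rather than the per-example KL used in the VAE:

```
L_WAE = E_{x~p_X} E_{z~q(z|x)}[c(x, G(z))] + λ·D_Z(q_Z, p_Z)
```

Where:
- c(x, G(z)) is the reconstruction cost
- q_Z = E_{x~p_X}[q(z|x)] is the aggregated posterior
- D_Z is a divergence between q_Z and p_Z (MMD in WAE-MMD, an adversarial estimate in WAE-GAN)
- λ controls the regularization strength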
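In WAE-MMD, D_Z is estimated from minibatch samples: encoder codes z ~ q(z|x) over a batch stand in for q_Z, and draws from the prior stand in for p_Z. The sketch below uses the unbiased MMD² estimator with an inverse multiquadratics kernel, the kernel family used in the WAE paper's experiments; the single scale C = 2·d·σ² and the function names are assumptions for illustration.

```python
import numpy as np

def imq_kernel(a, b, scale):
    """Inverse multiquadratics kernel k(a, b) = C / (C + ||a - b||^2) for all row pairs."""
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return scale / (scale + sq_dists)

def mmd_squared(q_samples, p_samples, prior_var=1.0):
    """Unbiased estimate of MMD^2 between two sample sets (rows are samples)."""
    n, d = q_samples.shape
    m = p_samples.shape[0]
    scale = 2.0 * d * prior_var            # C = 2 * d_z * sigma_z^2 (assumed single scale)
    k_qq = imq_kernel(q_samples, q_samples, scale)
    k_pp = imq_kernel(p_samples, p_samples, scale)
    k_qp = imq_kernel(q_samples, p_samples, scale)
    # Drop the diagonal (self-similarity) terms for the unbiased within-set averages
    term_qq = (k_qq.sum() - np.trace(k_qq)) / (n * (n - 1))
    term_pp = (k_pp.sum() - np.trace(k_pp)) / (m * (m - 1))
    return term_qq + term_pp - 2.0 * k_qp.mean()

rng = np.random.default_rng(0)
prior = rng.standard_normal((500, 2))            # samples from p_Z = N(0, I)
matched = rng.standard_normal((500, 2))          # codes already distributed like the prior
shifted = rng.standard_normal((500, 2)) + 3.0    # codes far from the prior
print(mmd_squared(matched, prior), mmd_squared(shifted, prior))
```

The first value is near zero and the second is clearly positive; in training, λ times this estimate would be added to the reconstruction cost.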

### **Vector Quantized VAE (VQ-VAE)**

VQ-VAE uses discrete latent representations:

```
z_q = e_k where k = argmin_j ||z_e - e_j||_2
```
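The argmin above is usually computed for all codebook entries at once by expanding ||z_e - e_j||² = ||z_e||² + ||e_j||² - 2·z_e·e_j, which is what `VectorQuantizer.forward` does with `torch.matmul`. A NumPy sketch of the same lookup, with a made-up 3-entry codebook (`quantize` is a hypothetical helper):

```python
import numpy as np

def quantize(z_e, codebook):
    """Map each row of z_e (N, D) to its nearest codebook row (K, D) in L2 distance."""
    # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, computed for all (N, K) pairs at once
    distances = (np.sum(z_e**2, axis=1, keepdims=True)
                 + np.sum(codebook**2, axis=1)
                 - 2.0 * z_e @ codebook.T)
    indices = np.argmin(distances, axis=1)
    return codebook[indices], indices

codebook = np.array([[0.0, 0.0], [1.0, 1.0], [-2.0, 0.5]])   # K = 3 codes, D = 2
z_e = np.array([[0.9, 1.2], [-1.5, 0.2], [0.1, -0.1]])
z_q, k = quantize(z_e, codebook)
print(k)  # [1 2 0]
```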

**Key Features:**
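- **Discrete latents**: Deterministic nearest-neighbour quantization sidesteps the KL-driven posterior collapse of Gaussian VAEs (unused "dead" codes are the analogous failure mode)
- **Codebook learning**: Learnable discrete representations
- **Straight-through estimator**: Gradients skip the non-differentiable argmin by copying ∂L/∂z_q to z_e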
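Because the codebook is learned, a standard health check is how many codes are actually used. A minimal sketch: the perplexity of the code-usage distribution, the exponential of its entropy, equals K when all K codes are used uniformly and falls toward 1 when a few codes dominate. With the `VQVAE` below, the indices would come from `encode` (after `.cpu().numpy()`); `codebook_usage` is a hypothetical helper.

```python
import numpy as np

def codebook_usage(indices, num_embeddings):
    """Return (perplexity, fraction of codes used) for an array of code indices."""
    counts = np.bincount(np.asarray(indices).ravel(), minlength=num_embeddings)
    probs = counts / counts.sum()
    nonzero = probs[probs > 0]                     # treat 0 * log 0 as 0
    perplexity = np.exp(-np.sum(nonzero * np.log(nonzero)))
    return float(perplexity), float(np.mean(counts > 0))

print(codebook_usage([0, 1, 2, 3] * 25, num_embeddings=8))   # perplexity ~4, usage 0.5
print(codebook_usage([5] * 100, num_embeddings=8))           # perplexity 1.0, usage 0.125
```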

### **🎯 Interview Focus: When to Use Each Variant**

| Variant | Best For | Trade-offs |
|---------|----------|------------|
| **Standard VAE** | General representation learning | Blurry reconstructions |
| **β-VAE** | Disentangled representations | Reconstruction vs disentanglement |
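| **WAE** | Better sample quality | Needs an MMD or adversarial latent penalty, with kernel or discriminator tuning |
| **VQ-VAE** | Discrete representations, autoregressive modeling | Needs a separately trained prior (e.g. PixelCNN) for generation; limited continuous interpolation |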

In [None]:
# Advanced VAE Variants Implementation

class BetaVAE(VAE):
    """β-VAE for disentangled representation learning"""
    
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20, beta=4.0):
        super().__init__(input_dim, hidden_dim, latent_dim)
        self.beta = beta
    
    def loss_function(self, recon_x, x, mu, logvar):
        """β-VAE loss function"""
        return vae_loss_function(recon_x, x, mu, logvar, beta=self.beta)

class VectorQuantizer(nn.Module):
    """Vector Quantization layer for VQ-VAE"""
    
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        
        # Initialize embeddings
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)
    
    def forward(self, inputs):
        """
        Quantize inputs using codebook
        
        Args:
            inputs: (batch_size, embedding_dim, height, width)
        Returns:
            quantized: Quantized version of inputs
            vq_loss: Vector quantization loss
            encoding_indices: Indices of chosen embeddings
        """
        # Convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        flat_input = inputs.view(-1, self.embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self.embeddings.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self.embeddings.weight.t()))
        
        # Find closest embeddings
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self.embeddings.weight).view(input_shape)
        
        # Calculate VQ loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        vq_loss = q_latent_loss + self.commitment_cost * e_latent_loss
        
        # Straight-through estimator
        quantized = inputs + (quantized - inputs).detach()
        
        # Convert back to BCHW
        quantized = quantized.permute(0, 3, 1, 2).contiguous()
        
        return quantized, vq_loss, encoding_indices

class VQVAE(nn.Module):
    """Vector Quantized VAE implementation"""
    
    def __init__(self, num_embeddings=512, embedding_dim=64, hidden_dim=128):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, hidden_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, embedding_dim, 1)
        )
        
        # Vector Quantizer
        self.vq = VectorQuantizer(num_embeddings, embedding_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(embedding_dim, hidden_dim, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, 3, stride=1, padding=1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        # Encode
        z_e = self.encoder(x)
        
        # Quantize
        z_q, vq_loss, encoding_indices = self.vq(z_e)
        
        # Decode
        recon = self.decoder(z_q)
        
        return recon, vq_loss, encoding_indices
    
    def encode(self, x):
        """Encode input to discrete indices"""
        z_e = self.encoder(x)
        _, _, encoding_indices = self.vq(z_e)
        return encoding_indices
    
    def decode_indices(self, indices):
        """Decode from discrete indices (as returned by encode)"""
        # Look up codebook vectors; the quantizer flattened positions in BHWC order
        z_q = self.vq.embeddings(indices.view(-1))
        
        # Restore (B, H, W, D), then permute to BCHW (assumes 7x7 latents for 28x28 MNIST)
        z_q = z_q.view(-1, 7, 7, self.embedding_dim).permute(0, 3, 1, 2).contiguous()
        
        return self.decoder(z_q)

def train_vqvae(model, train_loader, epochs=30, lr=1e-3):
    """Train VQ-VAE model"""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    device = next(model.parameters()).device  # train on the model's own device
    
    model.train()
    train_losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
        epoch_recon_loss = 0
        epoch_vq_loss = 0
        
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            
            optimizer.zero_grad()
            
            recon, vq_loss, _ = model(data)
            
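            # Reconstruction loss (summed over pixels and batch)
            recon_loss = F.binary_cross_entropy(recon, data, reduction='sum')
            
            # Total loss. vq_loss is a mean, so next to the summed reconstruction term
            # the commitment part has little influence on the encoder; consider matching reductions.
            loss = recon_loss + vq_loss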
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_vq_loss += vq_loss.item()
        
        # Average losses
        epoch_loss /= len(train_loader.dataset)
        epoch_recon_loss /= len(train_loader.dataset)
        epoch_vq_loss /= len(train_loader.dataset)
        
        train_losses.append(epoch_loss)
        
        if epoch % 5 == 0:
            print(f'Epoch {epoch:3d}: Total: {epoch_loss:.4f}, '
                  f'Recon: {epoch_recon_loss:.4f}, VQ: {epoch_vq_loss:.4f}')
    
    return train_losses

# Demonstration of advanced variants
print("🚀 **Training Advanced VAE Variants**")

# 1. β-VAE with β=4.0
print("\n1️⃣ **β-VAE Training (β=4.0 for disentanglement)**")
beta_vae = BetaVAE(beta=4.0).to(device)
beta_trainer = VAETrainer(beta_vae, device)
# VAETrainer does not read beta_vae.beta, so pass it explicitly (the default would train with β=1)
beta_trainer.train(train_loader, test_loader, epochs=20, lr=1e-3, beta=beta_vae.beta, print_every=5)

# 2. VQ-VAE
print("\n2️⃣ **VQ-VAE Training**")
vq_vae = VQVAE(num_embeddings=256, embedding_dim=64).to(device)
vq_losses = train_vqvae(vq_vae, train_loader, epochs=20, lr=1e-3)

# Compare models
def compare_vae_variants(standard_vae, beta_vae, vq_vae, test_loader):
    """Compare different VAE variants"""
    fig, axes = plt.subplots(3, 8, figsize=(16, 6))
    
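    # Get test data
    test_data, _ = next(iter(test_loader))
    test_data = test_data[:8].to(device)
    
    # Inference mode for all models (the VAEs then decode the posterior mean)
    for model in (standard_vae, beta_vae, vq_vae):
        model.eval()
    
    with torch.no_grad():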
        # Standard VAE
        recon_std, _, _, _ = standard_vae(test_data)
        
        # β-VAE
        recon_beta, _, _, _ = beta_vae(test_data)
        
        # VQ-VAE
        recon_vq, _, _ = vq_vae(test_data)
    
    for i in range(8):
        # Standard VAE
        axes[0, i].imshow(recon_std[i].cpu().view(28, 28), cmap='gray')
        axes[0, i].set_title('Standard VAE' if i == 0 else '')
        axes[0, i].axis('off')
        
        # β-VAE
        axes[1, i].imshow(recon_beta[i].cpu().view(28, 28), cmap='gray')
        axes[1, i].set_title('β-VAE (β=4.0)' if i == 0 else '')
        axes[1, i].axis('off')
        
        # VQ-VAE
        axes[2, i].imshow(recon_vq[i].cpu().squeeze(), cmap='gray')
        axes[2, i].set_title('VQ-VAE' if i == 0 else '')
        axes[2, i].axis('off')
    
    plt.tight_layout()
    plt.show()

# Compare the models
print("\n🔄 **Comparing VAE Variants**")
compare_vae_variants(vae, beta_vae, vq_vae, test_loader)

# Plot training comparison
plt.figure(figsize=(15, 5))

plt.subplot(1, 3, 1)
plt.plot(trainer.train_losses, label='Standard VAE')
plt.title('Standard VAE Training')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 2)
plt.plot(beta_trainer.train_losses, label='β-VAE', color='orange')
plt.title('β-VAE Training (β=4.0)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)

plt.subplot(1, 3, 3)
plt.plot(vq_losses, label='VQ-VAE', color='green')
plt.title('VQ-VAE Training')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n🎯 **Key Comparisons:**")
print("1. **Standard VAE**: Balanced reconstruction and regularization")
print("2. **β-VAE**: Better disentanglement but potentially worse reconstruction")
print("3. **VQ-VAE**: Discrete representations; avoids KL-driven posterior collapse but can leave codebook entries unused")
print("4. **Training Stability**: β-VAE behavior depends strongly on β; VQ-VAE needs codebook-usage monitoring")
print("5. **Use Cases**: Choose based on application requirements")