# VQ-VAE Exploratory Analysis

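Interactive analysis of codebook-initialization experiments on a linear-Gaussian dataset, comparing uniform, k-means, and rate–distortion (R(D))-optimal Gaussian initialization.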

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from src.data.linear_gaussian import LinearGaussianDataset
from src.models.vqvae import LinearGaussianVQVAE

# IPython magic (notebook-only syntax): render figures inline below each cell.
%matplotlib inline
# Style names with the 'seaborn-v0_8-' prefix are the ones newer matplotlib
# releases ship for the old seaborn styles.
# NOTE(review): confirm the installed matplotlib provides this style name.
plt.style.use('seaborn-v0_8-darkgrid')

## 1. Generate Data

In [None]:
# Build the synthetic linear-Gaussian dataset used by the cells below.
# The keyword arguments are forwarded unchanged to LinearGaussianDataset.
# NOTE(review): d / k are presumably the observed and latent dimensions — confirm.
dataset_kwargs = dict(
    n_samples=10000,
    d=64,
    k=8,
    sigma_noise=0.1,
    seed=42,
)
dataset = LinearGaussianDataset(**dataset_kwargs)

print(dataset)
data_variance = dataset.X.var()
print(f"\nData variance: {data_variance:.3f}")
latent_variance = dataset.sigma_z_squared
print(f"Latent variance (σ_z²): {latent_variance:.3f}")

## 2. Load Experiment Results

In [None]:
# Load per-step training metrics for each initialization method.
# Each run is expected at <results_dir>/<method>/metrics.csv. A run whose file
# is missing is reported and skipped, so the remaining cells still operate on
# whatever results are available instead of the whole cell aborting.
from pathlib import Path

results_dir = '../results/idea_7_linear_gaussian'
methods = ['uniform', 'kmeans', 'rd_gaussian']


def load_metrics(root, method_names):
    """Return ``{method: DataFrame}`` for every method whose metrics.csv loads.

    Args:
        root: directory containing one sub-directory per method.
        method_names: iterable of method sub-directory names, in display order.

    Returns:
        A dict (insertion-ordered like ``method_names``) mapping each method
        with a readable ``metrics.csv`` to its DataFrame. Methods whose file
        does not exist are reported on stdout and omitted.
    """
    loaded = {}
    for name in method_names:
        csv_path = Path(root) / name / 'metrics.csv'
        try:
            loaded[name] = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(f"[skip] {name}: no metrics file at {csv_path}")
    return loaded


metrics = load_metrics(results_dir, methods)
if not metrics:
    print(f"No metrics files found under {results_dir}; check the path.")

# Display summary
for method, df in metrics.items():
    print(f"\n{method.upper()}:")
    print(f"  Steps: {len(df)}")
    if df.empty:
        # A header-only CSV has no rows; iloc[0] would raise IndexError.
        continue
    print(f"  Initial Q.Err: {df['quantization_error'].iloc[0]:.6f}")
    print(f"  Final Q.Err: {df['quantization_error'].iloc[-1]:.6f}")

## 3. Visualize Training Curves

In [None]:
# Plot the four tracked training metrics for every loaded run on a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

colors = {'uniform': 'red', 'kmeans': 'blue', 'rd_gaussian': 'green'}

# One entry per panel: (axes position, CSV column, title, log-scale y-axis?)
panels = [
    ((0, 0), 'quantization_error', 'Quantization Error', True),
    ((0, 1), 'perplexity', 'Perplexity', False),
    ((1, 0), 'dead_codes', 'Dead Codes', False),
    ((1, 1), 'recon_loss', 'Reconstruction Loss', True),
]

for pos, column, title, log_y in panels:
    ax = axes[pos]
    for method, df in metrics.items():
        # colors.get -> None lets matplotlib pick from its default colour
        # cycle for any method without a fixed colour (instead of KeyError).
        ax.plot(df['step'], df[column],
                label=method, color=colors.get(method), linewidth=2)
    ax.set_title(title)
    ax.set_xlabel('step')
    # Only add a legend when something was plotted; an empty legend
    # triggers a matplotlib "no artists with labels" warning.
    if ax.has_data():
        ax.legend()
    if log_y:
        ax.set_yscale('log')

plt.tight_layout()
plt.show()

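## 4. Compare to Theory

This section treats the $k$-dimensional latent as i.i.d. Gaussian with per-dimension variance $\sigma_z^2$. A codebook of $K$ entries then has rate $R = \log_2(K)/k$ bits per dimension. The Gaussian distortion–rate function gives the lower bound $D^*(R) = k\,\sigma_z^2\,2^{-2R}$ on the total squared error per latent vector.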

In [None]:
# Compute the theoretical R(D) bound and compare measured distortion to it.
#
# Treating the k-dim latent as i.i.d. Gaussian with per-dimension variance
# sigma_z^2, a codebook of `codebook_size` entries spends
#     R = log2(codebook_size) / k   bits per dimension,
# and the Gaussian distortion-rate function gives
#     D*(R) = k * sigma_z^2 * 2^(-2R)
# total squared error per latent vector.
# NOTE(review): the ratios below are only meaningful if 'quantization_error'
# is measured on the same scale (total squared error per k-dim latent) —
# confirm against the training code.

# Presumably the dataset stores the k it was constructed with; fall back to
# the previously hard-coded 8 so the result is unchanged if it does not.
k = getattr(dataset, 'k', 8)
codebook_size = 256  # TODO confirm: must match the codebook size used in training
sigma_z_sq = dataset.sigma_z_squared

R = np.log2(codebook_size) / k
D_theory = k * sigma_z_sq * (2 ** (-2 * R))

print("\nTheoretical Analysis:")
print(f"  Rate R: {R:.3f} bits/dim")
print(f"  σ_z²: {sigma_z_sq:.3f}")
print(f"  Theoretical bound D*: {D_theory:.6f}")

print("\nInitial Distortion:")
for method, df in metrics.items():
    if df.empty:  # header-only CSV: nothing to compare
        continue
    init_dist = df['quantization_error'].iloc[0]
    ratio = init_dist / D_theory
    print(f"  {method:12s}: {init_dist:.6f} ({ratio:.2f}× theory)")

# Also report where each run ended up, i.e. how close training got to D*.
print("\nFinal Distortion:")
for method, df in metrics.items():
    if df.empty:
        continue
    final_dist = df['quantization_error'].iloc[-1]
    ratio = final_dist / D_theory
    print(f"  {method:12s}: {final_dist:.6f} ({ratio:.2f}× theory)")

## 5. Your Analysis Here

Add custom analysis, plots, or experiments below.

In [None]:
# Your code here
