# VQ-VAE Experiments Analysis

Analyze and compare VQ-EMA baseline vs BA-VQ results

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np
import torch
from pathlib import Path
import pandas as pd
import seaborn as sns

# Notebook-wide plotting defaults: seaborn's grid theme plus a wide, short
# default figure size for the side-by-side panels drawn in later cells.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 4)})

## 1. Load Experiment Results

In [None]:
# List all experiments: every sub-directory of `experiments/` is one run.
exp_dir = Path('experiments')

# A missing directory just means nothing has been trained yet; report zero
# runs instead of failing with FileNotFoundError. The names are sorted because
# iterdir() yields entries in arbitrary order and later cells pair runs by
# position.
if exp_dir.is_dir():
    experiments = sorted(d.name for d in exp_dir.iterdir() if d.is_dir())
else:
    experiments = []

print(f"Found {len(experiments)} experiments:")
for exp in experiments:
    print(f"  - {exp}")

In [None]:
# Collect each run's final metrics, keyed by experiment name.
# Runs that have no final_metrics.json are simply left out of `results`.
results = {}
for exp in experiments:
    metrics_file = exp_dir / exp / 'final_metrics.json'
    if not metrics_file.exists():
        continue
    with metrics_file.open() as f:
        results[exp] = json.load(f)

# Tabulate: one row per experiment, one column per metric
df = pd.DataFrame(results).transpose()
df.index.name = 'experiment'
print("\nFinal metrics:")
df

## 2. Comparison: VQ-EMA vs BA-VQ

In [None]:
# Head-to-head comparison of the two K=512 runs (the main comparison).
# Either lookup falls back to an empty dict when that run has no metrics.
baseline = results.get('vq_k512', {})
ba_vq = results.get('ba_k512', {})

if not baseline or not ba_vq:
    # At least one of the two runs is missing: tell the user how to produce it
    print("K=512 experiments not found. Please run:")
    print("  python train.py --quantizer vq_ema --codebook_size 512 --epochs 30 --name vq_k512")
    print("  python train.py --quantizer ba_vq --codebook_size 512 --epochs 30 --name ba_k512")
else:
    print("Comparison at K=512:")
    print("=" * 60)

    metrics_to_compare = ['perplexity', 'usage_rate', 'psnr', 'loss']

    for metric in metrics_to_compare:
        # A metric absent from a run is reported as 0
        vq_val = baseline.get(metric, 0)
        ba_val = ba_vq.get(metric, 0)
        diff = ba_val - vq_val

        # Relative change w.r.t. the baseline; defined as 0 for a zero baseline
        if vq_val != 0:
            pct_change = diff / vq_val * 100
        else:
            pct_change = 0

        print(f"{metric:20s}: VQ-EMA={vq_val:.4f}, BA-VQ={ba_val:.4f}, "
              f"Δ={diff:+.4f} ({pct_change:+.1f}%)")

## 3. Visualizations

In [None]:
# Plot comparison for K=256 and K=512: one panel per metric, with the VQ-EMA
# and BA-VQ runs drawn as side-by-side bars. The figure is also written to
# `comparison_plots.png` inside the experiments directory.
if len(results) >= 4:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Group experiments by quantizer type, based on the run name
    vq_ema_exps = {k: v for k, v in results.items() if 'vq_' in k or 'baseline' in k}
    ba_vq_exps = {k: v for k, v in results.items() if 'ba_' in k}

    metrics = ['perplexity', 'usage_rate', 'psnr']
    titles = ['Perplexity (↑ better)', 'Usage Rate (↑ better)', 'PSNR (↑ better)']

    # The two groups are not guaranteed to have the same number of runs.
    # Size the x axis from the larger group and give each group only as many
    # slots as it has values, so unequal groups no longer break ax.bar().
    n_slots = max(len(vq_ema_exps), len(ba_vq_exps))
    x = np.arange(n_slots)
    width = 0.35

    for ax, metric, title in zip(axes, metrics, titles):
        # Extract values (a metric missing from a run is plotted as 0)
        vq_vals = [v.get(metric, 0) for v in vq_ema_exps.values()]
        ba_vals = [v.get(metric, 0) for v in ba_vq_exps.values()]

        ax.bar(x[:len(vq_vals)] - width/2, vq_vals, width, label='VQ-EMA', alpha=0.8)
        ax.bar(x[:len(ba_vals)] + width/2, ba_vals, width, label='BA-VQ', alpha=0.8)

        ax.set_xlabel('Experiment')
        ax.set_ylabel(metric.replace('_', ' ').title())
        ax.set_title(title)
        ax.set_xticks(x)
        # Runs are paired purely by their position within each group
        ax.set_xticklabels([f"#{slot + 1}" for slot in range(n_slots)])
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig(exp_dir / 'comparison_plots.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("Not enough experiments for visualization. Run all 4 experiments first.")

## 4. Load and Visualize Reconstructions

In [None]:
# Load model and generate reconstructions
from vqvae import VQVAE, get_dataloaders
import torchvision

def load_model(exp_name, codebook_size=512, quantizer_type='vq_ema'):
    """Build a VQVAE and restore its weights from an experiment's checkpoint.

    Args:
        exp_name: Name of the experiment sub-directory under ``exp_dir``.
        codebook_size: Forwarded to the ``VQVAE`` constructor.
        quantizer_type: Forwarded to the ``VQVAE`` constructor.

    Returns:
        The model in eval mode with the checkpoint loaded on CPU, or ``None``
        (after printing a message) when ``final_model.pt`` does not exist.
    """
    model = VQVAE(quantizer_type=quantizer_type, codebook_size=codebook_size)
    checkpoint_path = exp_dir / exp_name / 'final_model.pt'

    if not checkpoint_path.exists():
        print(f"Checkpoint not found: {checkpoint_path}")
        return None

    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model

def show_reconstructions(model, n_samples=8):
    """Plot one validation batch above the model's reconstructions of it.

    Args:
        model: Callable taking an image batch and returning a 4-tuple whose
            first element is the reconstruction (that is how the result is
            unpacked here).
        n_samples: Number of image columns; also used as the batch size
            requested from ``get_dataloaders``.

    Returns:
        The matplotlib figure (2 rows x ``n_samples`` columns). The figure is
        neither shown nor saved here.
    """
    _, val_loader = get_dataloaders(batch_size=n_samples)

    # Get one batch
    x, _ = next(iter(val_loader))

    with torch.no_grad():
        x_recon, _, _, _ = model(x)

    # Denormalize
    # NOTE(review): assumes the loader normalizes with mean 0.5 / std 0.5 - confirm
    x = x * 0.5 + 0.5
    x_recon = x_recon * 0.5 + 0.5

    # squeeze=False keeps `axes` two-dimensional even for n_samples == 1,
    # where plt.subplots would otherwise return a 1-D array and break the
    # axes[row, col] indexing below.
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples*2, 4), squeeze=False)

    # The loader may return fewer images than requested (small dataset);
    # only index what is actually there and leave the other panels blank.
    n_shown = min(n_samples, x.shape[0])

    for i in range(n_samples):
        if i < n_shown:
            # permute moves the leading axis last for imshow
            # (presumably (C, H, W) -> (H, W, C); verify against the loader)
            axes[0, i].imshow(x[i].permute(1, 2, 0).clamp(0, 1))
            axes[1, i].imshow(x_recon[i].permute(1, 2, 0).clamp(0, 1))
        axes[0, i].axis('off')
        axes[1, i].axis('off')

    # Row labels go on the first column only
    axes[0, 0].set_title('Original', fontsize=12)
    axes[1, 0].set_title('Reconstructed', fontsize=12)

    plt.tight_layout()
    return fig

In [None]:
# Reconstructions from the VQ-EMA baseline (K=512). Runs only when metrics for
# that experiment were loaded, and saves the figure next to the experiments.
if 'vq_k512' in results:
    print("VQ-EMA (K=512) Reconstructions:")
    model_vq = load_model('vq_k512', 512, 'vq_ema')
    if model_vq:
        vq_recon_png = exp_dir / 'reconstructions_vq_ema.png'
        fig = show_reconstructions(model_vq)
        fig.savefig(vq_recon_png, dpi=150, bbox_inches='tight')
        plt.show()

In [None]:
# Reconstructions from the BA-VQ model (K=512). Runs only when metrics for
# that experiment were loaded, and saves the figure next to the experiments.
if 'ba_k512' in results:
    print("BA-VQ (K=512) Reconstructions:")
    model_ba = load_model('ba_k512', 512, 'ba_vq')
    if model_ba:
        ba_recon_png = exp_dir / 'reconstructions_ba_vq.png'
        fig = show_reconstructions(model_ba)
        fig.savefig(ba_recon_png, dpi=150, bbox_inches='tight')
        plt.show()

## 5. Statistical Analysis

In [None]:
# If you have multiple runs with different seeds, perform statistical tests.
# Runs independent two-sample t-tests (BA-VQ vs VQ-EMA), one per metric.
from scipy import stats

# Group by quantizer type, based on the run name
vq_ema_results = [v for k, v in results.items() if 'vq_' in k or 'baseline' in k]
ba_vq_results = [v for k, v in results.items() if 'ba_' in k]

if len(vq_ema_results) >= 2 and len(ba_vq_results) >= 2:
    print("Statistical Comparison (t-test):")
    print("=" * 60)

    for metric in ['perplexity', 'usage_rate', 'psnr']:
        # Only use runs that actually recorded this metric. Substituting 0 for
        # a missing value would inject a fake observation into the sample.
        vq_vals = [r[metric] for r in vq_ema_results if metric in r]
        ba_vals = [r[metric] for r in ba_vq_results if metric in r]

        if len(vq_vals) < 2 or len(ba_vals) < 2:
            print(f"{metric:20s}: skipped (fewer than 2 runs per group recorded it)")
            continue

        t_stat, p_value = stats.ttest_ind(ba_vals, vq_vals)

        print(f"{metric:20s}: t={t_stat:+.3f}, p={p_value:.4f}", end='')
        if np.isnan(p_value):
            # NaN compares False against 0.05, so it would otherwise be
            # reported as "Not significant"; call it out explicitly instead.
            print(" ? Undefined (p-value is NaN)")
        elif p_value < 0.05:
            print(" ✓ Significant difference")
        else:
            print(" ✗ Not significant")
else:
    print("Need at least 2 runs of each quantizer type for statistical tests")

## 6. Summary Report

In [None]:
# Generate summary report: run counts, the K=512 key findings and a verdict.
def _pct_change(new, old):
    """Return the change from *old* to *new* in percent (0 when *old* is 0)."""
    return (new - old) / old * 100 if old != 0 else 0


print("\n" + "="*60)
print("EXPERIMENT SUMMARY")
print("="*60)

print(f"\nTotal experiments: {len(results)}")
print(f"VQ-EMA runs: {len([k for k in results if 'vq_' in k or 'baseline' in k])}")
print(f"BA-VQ runs: {len([k for k in results if 'ba_' in k])}")

if baseline and ba_vq:
    print("\n" + "="*60)
    print("KEY FINDINGS (K=512):")
    print("="*60)

    # Same conventions as the K=512 comparison cell: a metric missing from a
    # run counts as 0, and a zero baseline yields a 0% change. This avoids a
    # KeyError / ZeroDivisionError on incomplete metrics files.
    perp_improvement = _pct_change(ba_vq.get('perplexity', 0), baseline.get('perplexity', 0))
    usage_improvement = _pct_change(ba_vq.get('usage_rate', 0), baseline.get('usage_rate', 0))
    psnr_change = ba_vq.get('psnr', 0) - baseline.get('psnr', 0)

    print("\nCodebook Utilization:")
    print(f"  Perplexity improvement: {perp_improvement:+.1f}%")
    print(f"  Usage rate improvement: {usage_improvement:+.1f}%")

    print("\nReconstruction Quality:")
    print(f"  PSNR change: {psnr_change:+.2f} dB")

    # Verdict: "promising" needs >5% gains in both utilization metrics and at
    # most a 0.5 dB PSNR drop; any perplexity gain at all counts as "marginal".
    print("\n" + "="*60)
    if perp_improvement > 5 and usage_improvement > 5 and psnr_change > -0.5:
        print("✓ BA-VQ shows promising improvement!")
    elif perp_improvement > 0:
        print("○ BA-VQ shows marginal improvement")
    else:
        print("✗ BA-VQ does not improve over baseline")
    print("="*60)

print("\nAll results saved to: experiments/")
# The comparison figure is only written once enough experiments exist, so
# only claim it was saved when the file is really there.
if (exp_dir / 'comparison_plots.png').exists():
    print("Plots saved to: experiments/comparison_plots.png")
else:
    print("Plots not generated yet (run the visualization cell with at least 4 experiments)")

## 7. Next Steps

In [None]:
# Suggest what to run next, depending on which experiments already exist.
print("Recommendations for next steps:\n")

if not (baseline and ba_vq):
    # The main K=512 pair is incomplete
    next_steps = (
        "1. Run the baseline experiments (K=512):",
        "   python train.py --quantizer vq_ema --codebook_size 512 --epochs 30 --name vq_k512",
        "   python train.py --quantizer ba_vq --codebook_size 512 --epochs 30 --name ba_k512",
    )
elif len(results) < 4:
    # K=512 is done; the K=256 pair is still outstanding
    next_steps = (
        "1. Run K=256 experiments for comparison:",
        "   python train.py --quantizer vq_ema --codebook_size 256 --epochs 30 --name baseline_test",
        "   python train.py --quantizer ba_vq --codebook_size 256 --epochs 30 --name ba_test",
    )
else:
    # All four short runs exist
    next_steps = (
        "1. Run full training (100 epochs) for final results",
        "2. Test larger codebook sizes (K=1024, 2048)",
        "3. Tune BA-VQ hyperparameters (beta schedule, iterations)",
        "4. Try different datasets (ImageNet, etc.)",
    )

for step_line in next_steps:
    print(step_line)