# BOM-VAE vs β-VAE Comparison

**Hypothesis**: BOM-VAE achieves results comparable to or better than β-VAE without requiring the β hyperparameter to be tuned.

**Adaptive squeeze rule**:
```
squeeze_amount = (s_min - 0.5) * k
```
- When s_min = 0.9: squeeze aggressively (with k = 0.5, squeeze_amount = 0.2)
- When s_min = 0.55: squeeze gently (with k = 0.5, squeeze_amount = 0.025)
- When s_min ≤ 0.5: stop squeezing

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

In [None]:
# Shared VAE architecture
class VAE(nn.Module):
    """Convolutional VAE shared by every method in the comparison.

    The encoder is four stride-2 convolutions (1 -> 32 -> 64 -> 128 -> 256
    channels), each followed by an 8-group GroupNorm and LeakyReLU(0.2).
    The linear heads expect a flattened 256x4x4 feature map, which is what
    the encoder produces for 64x64 single-channel inputs. The decoder mirrors
    the encoder with transposed convolutions and ends in a Sigmoid, so
    reconstructions lie in [0, 1].

    Args:
        latent_dim: Size of the latent vector z.
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: every stage halves the spatial resolution.
        enc_layers = []
        for c_in, c_out in [(1, 32), (32, 64), (64, 128), (128, 256)]:
            enc_layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            enc_layers.append(nn.GroupNorm(8, c_out))
            enc_layers.append(nn.LeakyReLU(0.2))
        self.enc = nn.Sequential(*enc_layers)

        # Posterior heads (mean and log-variance) and the latent -> feature map projection.
        flat_features = 256 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, flat_features)

        # Decoder: every stage doubles the spatial resolution.
        dec_layers = []
        for c_in, c_out in [(256, 128), (128, 64), (64, 32)]:
            dec_layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1)
            )
            dec_layers.append(nn.GroupNorm(8, c_out))
            dec_layers.append(nn.LeakyReLU(0.2))
        dec_layers.append(
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)
        )
        dec_layers.append(nn.Sigmoid())
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Encode x, sample z by reparameterization, and decode.

        A fresh z is sampled on every call, regardless of train/eval mode.

        Returns:
            Tuple (reconstruction, mu, logvar).
        """
        batch = x.size(0)
        features = self.enc(x).view(batch, -1)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)

        # z = mu + eps * sigma, with sigma = exp(logvar / 2).
        noise = torch.randn_like(mu)
        sigma = torch.exp(0.5 * logvar)
        z = mu + noise * sigma

        feature_map = self.fc_dec(z).view(-1, 256, 4, 4)
        return self.dec(feature_map), mu, logvar

In [None]:
# Data
# MNIST digits, resized so the shorter side is 64 px and converted to tensors.
transform = transforms.Compose([
    transforms.Resize(size=64),
    transforms.ToTensor(),
])

# Train and test splits are downloaded into ./data on first use.
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training stream is shuffled; the test order stays fixed.
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=2)
print(f"Train: {len(train_loader)} batches, Test: {len(test_loader)} batches")

## Shared Metrics

In [None]:
def compute_metrics(x, x_recon, mu, logvar):
    """Return per-sample MSE, KL divergence and sharpness.

    Args:
        x: Input batch.
        x_recon: Reconstruction, indexed as a 4-D tensor with the same shape as x.
        mu: Posterior means, one row per sample.
        logvar: Posterior log-variances, same shape as mu.

    Returns:
        Tuple (mse, kl, sharp) of 1-D tensors with one entry per sample:
        mse is the mean squared error over all pixels, kl is the KL
        divergence to a standard normal summed over latent dimensions, and
        sharp is the average of the mean absolute differences between
        neighbouring pixels along the last and second-to-last axes of x_recon.
    """
    batch = x.size(0)

    squared_error = F.mse_loss(x_recon, x, reduction='none')
    mse = squared_error.view(batch, -1).mean(1)

    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl = -0.5 * kl_terms.sum(1)

    # Finite differences of the reconstruction along width and height.
    grad_w = torch.abs(x_recon[:, :, :, 1:] - x_recon[:, :, :, :-1])
    grad_h = torch.abs(x_recon[:, :, 1:, :] - x_recon[:, :, :-1, :])
    sharp = (grad_w.mean([1, 2, 3]) + grad_h.mean([1, 2, 3])) / 2

    return mse, kl, sharp


def evaluate(model, loader, device):
    """Average the per-sample metrics of `model` over every batch in `loader`.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Callable returning (reconstruction, mu, logvar) for a batch.
        loader: Iterable of batches whose first element is the input tensor.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Dict with keys 'mse', 'kl' and 'sharp', each the mean over all samples.
    """
    model.eval()
    # Keys are in the same order as the tuple returned by compute_metrics.
    collected = {'mse': [], 'kl': [], 'sharp': []}

    with torch.no_grad():
        for batch in loader:
            inputs = batch[0].to(device)
            recon, mu, logvar = model(inputs)
            per_sample = compute_metrics(inputs, recon, mu, logvar)
            for key, values in zip(collected, per_sample):
                collected[key].extend(values.cpu().numpy())

    return {key: np.mean(values) for key, values in collected.items()}

---
## β-VAE Training

In [None]:
def train_beta_vae(model, loader, device, beta, n_epochs=20):
    """
    Train `model` with the standard β-VAE objective.

    Loss = mean(MSE) + β * mean(KL), optimised with Adam (lr=1e-3) and
    gradient-norm clipping at 1.0.

    Args:
        model: VAE returning (reconstruction, mu, logvar).
        loader: Iterable of training batches; the input is element 0.
        device: Device the inputs are moved to.
        beta: Weight of the KL term.
        n_epochs: Number of passes over `loader`.

    Returns:
        List with one dict per optimisation step holding the batch-mean
        'mse', 'kl' and 'sharp'.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    history = []

    for epoch in range(1, n_epochs + 1):
        model.train()
        losses, mses, kls = [], [], []

        pbar = tqdm(loader, desc=f"β-VAE (β={beta}) Epoch {epoch}")
        for batch in pbar:
            x = batch[0].to(device)

            optimizer.zero_grad()
            x_recon, mu, logvar = model(x)

            mse, kl, sharp = compute_metrics(x, x_recon, mu, logvar)
            loss = mse.mean() + beta * kl.mean()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Read each scalar off the device once and reuse it below.
            loss_val = loss.item()
            mse_val = mse.mean().item()
            kl_val = kl.mean().item()
            sharp_val = sharp.mean().item()

            losses.append(loss_val)
            mses.append(mse_val)
            kls.append(kl_val)
            history.append({'mse': mse_val, 'kl': kl_val, 'sharp': sharp_val})

            pbar.set_postfix({'loss': f"{loss_val:.4f}", 'mse': f"{mse_val:.4f}", 'kl': f"{kl_val:.0f}"})

        print(f"  Epoch {epoch}: loss={np.mean(losses):.4f}, mse={np.mean(mses):.4f}, kl={np.mean(kls):.0f}")

    return history

---
## BOM-VAE Training with Adaptive Squeeze

In [None]:
def regular_constraint_lower_better(value, floor):
    """Score a lower-is-better objective (MSE) against its bound `floor`.

    The score is 1 when value == 0, 0 when value == floor, and negative
    once value exceeds floor.
    """
    headroom = floor - value
    return headroom / floor


def regular_constraint_higher_better(value, ceiling):
    """Score a higher-is-better objective (sharpness) as a fraction of `ceiling`.

    The score is 0 when value == 0 and 1 when value == ceiling; it is not
    capped, so values above ceiling score above 1.
    """
    fraction_of_ceiling = value / ceiling
    return fraction_of_ceiling


def box_constraint(value, floor_low, optimum, floor_high):
    """Score an objective that must stay inside a range (KL).

    The score rises linearly from 0 at floor_low to 1 at optimum, then falls
    linearly back to 0 at floor_high; outside [floor_low, floor_high] it is
    negative. `value` must be a tensor because the two ramps are combined
    with torch.minimum.
    """
    rising = (value - floor_low) / (optimum - floor_low)
    falling = (floor_high - value) / (floor_high - optimum)
    return torch.minimum(rising, falling)


def compute_bom_loss(x, x_recon, mu, logvar, mse_floor, kl_floor_low, kl_optimum, kl_floor_high, sharp_ceiling):
    """Compute the BOM loss: the negative log of each sample's worst score.

    Every sample receives three scores (MSE, KL, sharpness); the loss is
    -log(min score) averaged over the batch.

    Args:
        x, x_recon, mu, logvar: Passed straight to compute_metrics.
        mse_floor: Bound for the MSE score (score reaches 0 at this MSE).
        kl_floor_low, kl_optimum, kl_floor_high: Box for the KL score.
        sharp_ceiling: Reference value for the sharpness score.

    Returns:
        Tuple (loss, metrics). `loss` is None when at least one sample
        violates a constraint, i.e. its worst score is non-positive or
        non-finite; the caller is expected to skip the optimisation step.
        `metrics` always holds the batch means 'mse', 'kl', 'sharp',
        'mse_score', 'kl_score', 'sharp_score', 's_min', the 'violations'
        count and 'bottleneck' (the objective that is most often the worst
        one in the batch). 'loss' is only present when `loss` is not None.
    """
    mse, kl, sharp = compute_metrics(x, x_recon, mu, logvar)

    mse_score = regular_constraint_lower_better(mse, mse_floor)
    kl_score = box_constraint(kl, kl_floor_low, kl_optimum, kl_floor_high)
    sharp_score = regular_constraint_higher_better(sharp, sharp_ceiling)

    scores = torch.stack([mse_score, kl_score, sharp_score], dim=1)
    s_min, min_idx = torch.min(scores, dim=1)

    # A plain `s_min <= 0` test is False for NaN, which would let a NaN score
    # through to log() and back-propagate NaN gradients into the model.
    # Counting every non-finite score as a violation makes the caller skip
    # such batches instead.
    feasible = torch.isfinite(s_min) & (s_min > 0)
    violations = (~feasible).sum().item()

    names = ['mse', 'kl', 'sharp']
    metrics = {
        'mse': mse.mean().item(),
        'kl': kl.mean().item(),
        'sharp': sharp.mean().item(),
        'mse_score': mse_score.mean().item(),
        'kl_score': kl_score.mean().item(),
        'sharp_score': sharp_score.mean().item(),
        's_min': s_min.mean().item(),
        'violations': violations,
        # Reported for skipped batches too, so the violated objective can be identified.
        'bottleneck': names[torch.bincount(min_idx, minlength=3).argmax().item()],
    }

    if violations > 0:
        return None, metrics

    loss = -torch.log(s_min).mean()
    metrics['loss'] = loss.item()

    return loss, metrics

In [None]:
def calibrate_bom(model, loader, device, n_batches=50):
    """Derive initial BOM constraint bounds from the model's current outputs.

    Runs up to `n_batches` batches through the model without gradients
    (the model is put in train mode) and collects per-sample MSE, KL and
    sharpness.

    Returns:
        Dict with:
          'mse_floor'      2x the largest observed MSE,
          'kl_floor_low'   0.1x the smallest observed KL,
          'kl_optimum'     the mean observed KL,
          'kl_floor_high'  50x the largest observed KL,
          'sharp_ceiling'  the mean observed sharpness.
    """
    model.train()
    mse_vals, kl_vals, sharp_vals = [], [], []

    with torch.no_grad():
        for step, batch in enumerate(loader):
            if step >= n_batches:
                break
            inputs = batch[0].to(device)
            recon, mu, logvar = model(inputs)
            per_sample = compute_metrics(inputs, recon, mu, logvar)
            for bucket, values in zip((mse_vals, kl_vals, sharp_vals), per_sample):
                bucket.extend(values.cpu().numpy())

    mse_arr = np.array(mse_vals)
    kl_arr = np.array(kl_vals)
    sharp_arr = np.array(sharp_vals)

    params = dict(
        mse_floor=mse_arr.max() * 2.0,
        kl_floor_low=kl_arr.min() * 0.1,
        kl_optimum=kl_arr.mean(),
        kl_floor_high=kl_arr.max() * 50.0,  # Very loose initially
        sharp_ceiling=sharp_arr.mean(),
    )

    print(f"Calibration: MSE={mse_arr.mean():.4f}, KL={kl_arr.mean():.1f}, Sharp={sharp_arr.mean():.4f}")
    print(f"Initial constraints: mse_floor={params['mse_floor']:.4f}, kl_box=[{params['kl_floor_low']:.1f}, {params['kl_optimum']:.1f}, {params['kl_floor_high']:.1f}]")

    return params

In [None]:
def train_bom_vae(model, loader, device, n_epochs=20):
    """
    BOM-VAE training with adaptive squeeze.

    The constraint bounds are first calibrated from the untrained model via
    calibrate_bom. Each step minimises the BOM loss; batches in which any
    sample violates a constraint are skipped (no optimiser step) and only
    counted.

    Squeeze rule: squeeze_amount = (s_min - 0.5) * k
    - s_min > 0.5: squeeze proportionally
    - s_min <= 0.5: don't squeeze

    From epoch 3 onwards, after each epoch the MSE floor is multiplied by
    (1 - squeeze_amount), floored at 0.5, and the KL box bounds are moved the
    same fraction of the way toward per-dimension targets scaled by
    model.latent_dim.

    Args:
        model: VAE exposing `latent_dim` and returning (reconstruction, mu, logvar).
        loader: Iterable of training batches; the input is element 0.
        device: Device the inputs are moved to.
        n_epochs: Number of passes over `loader`.

    Returns:
        List with the metrics dict of every batch, including skipped ones.
    """
    # Calibrate
    params = calibrate_bom(model, loader, device)
    print('Recalibrated BOM constraints for current architecture.')

    mse_floor = params['mse_floor']
    kl_floor_low = params['kl_floor_low']
    kl_optimum = params['kl_optimum']
    kl_floor_high = params['kl_floor_high']
    sharp_ceiling = params['sharp_ceiling']

    # Targets
    # Per-dimension KL targets (tuned for latent_dim=128 => totals 50/80/150)
    kl_floor_low_per_dim = 0.390625
    kl_optimum_per_dim = 0.625
    kl_floor_high_per_dim = 1.171875
    target_kl_floor_low = kl_floor_low_per_dim * model.latent_dim
    target_kl_optimum = kl_optimum_per_dim * model.latent_dim
    target_kl_floor_high = kl_floor_high_per_dim * model.latent_dim
    if kl_floor_high < target_kl_floor_high:
        print('⚠️ KL upper bound below target; consider loosening constraints.')
    else:
        print('✅ KL targets appear attainable with recalibrated bounds.')

    # Adaptive squeeze settings
    squeeze_k = 0.5  # Gain factor
    min_s_min_for_squeeze = 0.5
    squeeze_start_epoch = 3

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    history = []

    for epoch in range(1, n_epochs + 1):
        model.train()
        # s_min is tracked for optimised batches only; mse/kl for every batch.
        epoch_s_min, epoch_mse, epoch_kl = [], [], []
        epoch_violations = 0

        pbar = tqdm(loader, desc=f"BOM-VAE Epoch {epoch}")
        for batch in pbar:
            x = batch[0].to(device)

            optimizer.zero_grad()
            x_recon, mu, logvar = model(x)

            loss, metrics = compute_bom_loss(
                x, x_recon, mu, logvar,
                mse_floor, kl_floor_low, kl_optimum, kl_floor_high, sharp_ceiling
            )

            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                epoch_s_min.append(metrics['s_min'])
            else:
                # Constraint violated somewhere in the batch: skip the update.
                epoch_violations += metrics['violations']

            epoch_mse.append(metrics['mse'])
            epoch_kl.append(metrics['kl'])
            history.append(metrics)
            pbar.set_postfix({'s_min': f"{metrics['s_min']:.3f}", 'kl': f"{metrics['kl']:.0f}"})

        avg_s_min = np.mean(epoch_s_min) if epoch_s_min else 0
        # Report epoch means, as train_beta_vae does, rather than whatever the
        # last batch happened to produce.
        avg_mse = np.mean(epoch_mse) if epoch_mse else float('nan')
        avg_kl = np.mean(epoch_kl) if epoch_kl else float('nan')
        print(f"  Epoch {epoch}: s_min={avg_s_min:.3f}, violations={epoch_violations}, mse={avg_mse:.4f}, kl={avg_kl:.0f}")
        print(f"    KL box: [{kl_floor_low:.1f}, {kl_optimum:.1f}, {kl_floor_high:.1f}]")

        # Adaptive squeeze
        if epoch >= squeeze_start_epoch and avg_s_min > min_s_min_for_squeeze:
            squeeze_amount = (avg_s_min - min_s_min_for_squeeze) * squeeze_k
            squeeze_factor = 1.0 - squeeze_amount  # e.g., s_min=0.9 -> factor=0.8
            squeeze_factor = max(0.5, squeeze_factor)  # Don't squeeze more than 50%

            print(f"    🔧 Squeeze: s_min={avg_s_min:.3f} -> factor={squeeze_factor:.2f}")

            # Squeeze MSE floor
            mse_floor *= squeeze_factor

            # Squeeze KL box toward targets; each bound only ever moves
            # inward (lower bounds up, upper bound down).
            if kl_floor_low < target_kl_floor_low:
                kl_floor_low += (target_kl_floor_low - kl_floor_low) * (1 - squeeze_factor)
            if kl_optimum < target_kl_optimum:
                kl_optimum += (target_kl_optimum - kl_optimum) * (1 - squeeze_factor)
            if kl_floor_high > target_kl_floor_high:
                kl_floor_high -= (kl_floor_high - target_kl_floor_high) * (1 - squeeze_factor)

    return history


---
## Run Comparison

In [None]:
# Number of training epochs shared by every method in the comparison.
N_EPOCHS: int = 20
# Filled by the training cells: method name -> {'model', 'history', 'test'}.
results: dict = {}

In [None]:
# β-VAE with different β values
# One freshly initialised model is trained and evaluated per β value.
betas = [0.0001, 0.001, 0.01, 0.1]

for beta in betas:
    print("\n" + "=" * 60)
    print(f"Training β-VAE with β={beta}")
    print("=" * 60)

    model = VAE(latent_dim=128).to(device)
    history = train_beta_vae(model, train_loader, device, beta=beta, n_epochs=N_EPOCHS)
    test_metrics = evaluate(model, test_loader, device)

    # Stored under 'beta_<β>' so later cells can tell β-VAE runs from BOM.
    results[f'beta_{beta}'] = dict(model=model, history=history, test=test_metrics)

    print(f"\nTest results: MSE={test_metrics['mse']:.4f}, KL={test_metrics['kl']:.1f}, Sharp={test_metrics['sharp']:.4f}")

In [None]:
# BOM-VAE
# Train and evaluate a single BOM-VAE with the same architecture and epoch budget.
print("\n" + "=" * 60)
print("Training BOM-VAE (no β tuning required)")
print("=" * 60)

model_bom = VAE(latent_dim=128).to(device)
history_bom = train_bom_vae(model_bom, train_loader, device, n_epochs=N_EPOCHS)
test_metrics_bom = evaluate(model_bom, test_loader, device)

results['bom'] = dict(model=model_bom, history=history_bom, test=test_metrics_bom)

print(f"\nTest results: MSE={test_metrics_bom['mse']:.4f}, KL={test_metrics_bom['kl']:.1f}, Sharp={test_metrics_bom['sharp']:.4f}")

---
## Results Comparison

In [None]:
# Summary table
# Fixed-width table of test-set metrics, one row per trained method.
print(f"\n{'=' * 70}")
print("FINAL RESULTS")
print("=" * 70)
print(f"{'Method':<20} {'MSE':>10} {'KL':>10} {'Sharpness':>12}")
print("-" * 70)

for name in results:
    t = results[name]['test']
    print(f"{name:<20} {t['mse']:>10.4f} {t['kl']:>10.1f} {t['sharp']:>12.4f}")

print("-" * 70)

In [None]:
# Training curves comparison
# One panel per metric; every method contributes one per-step curve to each panel.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for name, data in results.items():
    label = name.replace('_', '=')
    for ax, key in zip(axes, ('mse', 'kl', 'sharp')):
        ax.plot([step[key] for step in data['history']], label=label, alpha=0.8)

for ax, title in zip(axes, ('MSE (↓ better)', 'KL Divergence', 'Sharpness (↑ better)')):
    ax.set_title(title)
    ax.set_xlabel('Step')
    ax.legend()

# MSE is shown on a log scale.
axes[0].set_yscale('log')

plt.tight_layout()
plt.savefig('training_comparison.png', dpi=150)
plt.show()

In [None]:
# Pareto plot: MSE vs KL
# Scatter of test MSE against test KL, one point per method.
plt.figure(figsize=(10, 6))

colors = plt.cm.viridis(np.linspace(0, 1, len(results)))

for idx, (name, data) in enumerate(results.items()):
    t = data['test']
    # β-VAE runs are squares; the BOM run is a larger circle.
    if 'beta' in name:
        marker = 's'
    else:
        marker = 'o'
    if 'bom' in name:
        size = 100
    else:
        size = 60
    plt.scatter(
        t['mse'], t['kl'],
        s=size, c=[colors[idx]], marker=marker,
        label=name.replace('_', '='), edgecolors='black',
    )

plt.xlabel('MSE (↓ better)')
plt.ylabel('KL (moderate is better)')
plt.title('Pareto Front: MSE vs KL')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('pareto_comparison.png', dpi=150)
plt.show()

In [None]:
# Reconstructions comparison
# Grid of the first 8 test images (top row) and each model's reconstruction
# of them (one row per model).
test_batch = next(iter(test_loader))[0][:8].to(device)

n_models = len(results)
# squeeze=False keeps `axes` two-dimensional even if `results` is empty.
fig, axes = plt.subplots(n_models + 1, 8, figsize=(16, 2*(n_models+1)), squeeze=False)

# Original
# A fixed [0, 1] intensity range is used so every row is drawn on the same
# scale instead of being contrast-stretched per image.
for i in range(8):
    axes[0, i].imshow(test_batch[i].cpu().squeeze(0), cmap='gray', vmin=0, vmax=1)
axes[0, 0].set_ylabel('Original', fontsize=12)

# Each model's reconstruction
for row, (name, data) in enumerate(results.items(), 1):
    model = data['model']
    model.eval()
    with torch.no_grad():
        recon, _, _ = model(test_batch)

    for i in range(8):
        axes[row, i].imshow(recon[i].cpu().squeeze(0), cmap='gray', vmin=0, vmax=1)
    axes[row, 0].set_ylabel(name.replace('_', '='), fontsize=12)

# Hide ticks and frames instead of calling axis('off'): turning the axis off
# also suppresses the axis labels, so the row labels would never be drawn.
for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

plt.tight_layout()
plt.savefig('reconstructions_comparison.png', dpi=150)
plt.show()

In [None]:
# Samples from prior comparison
# Decode the same 8 prior samples with every model (one row per model).
# The latent size 128 matches the VAE(latent_dim=128) models trained above.
z = torch.randn(8, 128, device=device)

n_models = len(results)
# squeeze=False keeps `axes` two-dimensional when only one model is present;
# otherwise axes[row, i] would fail on a 1-D array.
fig, axes = plt.subplots(n_models, 8, figsize=(16, 2*n_models), squeeze=False)

for row, (name, data) in enumerate(results.items()):
    model = data['model']
    model.eval()
    with torch.no_grad():
        samples = model.dec(model.fc_dec(z).view(-1, 256, 4, 4))

    # A fixed [0, 1] intensity range keeps the rows visually comparable.
    for i in range(8):
        axes[row, i].imshow(samples[i].cpu().squeeze(0), cmap='gray', vmin=0, vmax=1)
    axes[row, 0].set_ylabel(name.replace('_', '='), fontsize=12)

# Hide ticks and frames instead of calling axis('off'): turning the axis off
# also suppresses the axis labels, so the row labels would never be drawn.
for ax in axes.ravel():
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

plt.suptitle('Samples from Prior (same z for all models)', fontsize=14)
plt.tight_layout()
plt.savefig('samples_comparison.png', dpi=150)
plt.show()

---
## Analysis

In [None]:
# Find best β-VAE
# Split the β-VAE runs from the BOM run.
beta_results = {
    run_name: run for run_name, run in results.items() if 'beta' in run_name
}
bom_result = results['bom']

# Best by MSE: the β-VAE run with the lowest test MSE, as a (name, run) pair.
best_mse_beta = min(beta_results.items(), key=lambda item: item[1]['test']['mse'])
print(f"Best β-VAE by MSE: {best_mse_beta[0]} (MSE={best_mse_beta[1]['test']['mse']:.4f})")
print(f"BOM-VAE MSE: {bom_result['test']['mse']:.4f}")
print()

# Best by balanced score (low MSE, moderate KL, high sharp)
def balanced_score(t):
    """Combine test metrics into one score that rewards a balanced model.

    Args:
        t: Mapping with keys 'mse', 'kl' and 'sharp'.

    Returns:
        Product of three factors: 1 / (mse + 0.001), which grows as MSE
        falls; 1 / (|kl - 100| + 10), which peaks at KL = 100; and the
        sharpness itself.
    """
    inverse_mse = 1.0 / (t['mse'] + 0.001)
    kl_closeness = 1.0 / (abs(t['kl'] - 100) + 10)  # Peak at KL=100
    return inverse_mse * kl_closeness * t['sharp']

# β-VAE run with the highest balanced score, compared against the BOM run.
best_balanced_beta = max(
    beta_results.items(),
    key=lambda item: balanced_score(item[1]['test']),
)
print(f"Best β-VAE by balanced score: {best_balanced_beta[0]}")
print(f"  Score: {balanced_score(best_balanced_beta[1]['test']):.4f}")
print(f"BOM-VAE balanced score: {balanced_score(bom_result['test']):.4f}")

In [None]:
# Conclusion
# NOTE(review): the text below is a fixed string. It is printed unchanged
# whatever the metrics computed above turn out to be, so it should be read
# against the results table rather than taken as derived from it.
print("\n" + "="*70)
print("CONCLUSION")
print("="*70)
print("""
β-VAE requires tuning β to balance reconstruction vs regularization.
Different β values give different tradeoffs:
- Low β (0.0001): Good MSE, but KL may collapse or explode
- High β (0.1): Controlled KL, but poor reconstruction

BOM-VAE automatically finds a balanced solution:
- No β hyperparameter to tune
- Adaptive squeeze finds the Pareto frontier
- All objectives are explicitly constrained

The key insight: BOM optimizes the WORST objective at each step,
preventing any single objective from being sacrificed.
""")