# BOM-VAE vs β-VAE Comparison

**Hypothesis**: BOM-VAE achieves comparable or better results than β-VAE without a β sweep (its squeeze gain and KL targets are fixed in advance rather than tuned per run).

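**Adaptive squeeze rule** (applied at the end of each epoch, from epoch 3 on):
```
squeeze_amount = (s_min - 0.5) * k
squeeze_factor = max(0.5, 1 - squeeze_amount)
```
Here s_min is the epoch average of the bottleneck (minimum) constraint score and k = 0.5 is a gain. The MSE floor is multiplied by squeeze_factor, and each KL-box bound moves a fraction (1 − squeeze_factor) of the way toward its target.
- When s_min = 0.9: squeeze aggressively (factor 0.8)
- When s_min = 0.55: squeeze gently (factor 0.975)
- When s_min ≤ 0.5: stop squeezing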

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

In [None]:
# Shared VAE architecture
class VAE(nn.Module):
    """Convolutional VAE shared by every method in this comparison.

    Encoder: four stride-2 convolutions (1 -> 32 -> 64 -> 128 -> 256 channels),
    each followed by GroupNorm(8) and LeakyReLU(0.2), then two linear heads
    producing the posterior mean and log-variance. The linear heads expect a
    256 x 4 x 4 feature map, i.e. 64 x 64 single-channel inputs. The decoder
    mirrors the encoder with transposed convolutions and ends in a Sigmoid,
    so reconstructions lie in (0, 1).
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        self.latent_dim = latent_dim

        # Layers are created in exactly the same order as the original
        # literal nn.Sequential lists, so parameter names and seeded
        # initialisation are unchanged.
        down_layers = []
        for c_in, c_out in ((1, 32), (32, 64), (64, 128), (128, 256)):
            down_layers += [nn.Conv2d(c_in, c_out, 3, 2, 1), nn.GroupNorm(8, c_out), nn.LeakyReLU(0.2)]
        self.enc = nn.Sequential(*down_layers)

        flat_features = 256 * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, flat_features)

        up_layers = []
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            up_layers += [nn.ConvTranspose2d(c_in, c_out, 3, 2, 1, 1), nn.GroupNorm(8, c_out), nn.LeakyReLU(0.2)]
        up_layers += [nn.ConvTranspose2d(32, 1, 3, 2, 1, 1), nn.Sigmoid()]
        self.dec = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode, draw one reparameterised z, decode; return (recon, mu, logvar).

        z is sampled in both train and eval mode.
        """
        features = self.enc(x).view(x.size(0), -1)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        std = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(mu) * std
        seed_map = self.fc_dec(z).view(-1, 256, 4, 4)
        return self.dec(seed_map), mu, logvar

In [None]:
# Data: MNIST resized to 64x64 (the VAE's linear heads assume a 4x4 map after
# four stride-2 convolutions) and converted to tensors in [0, 1].
transform = transforms.Compose([transforms.Resize(64), transforms.ToTensor()])
train_data, test_data = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
train_loader, test_loader = (
    DataLoader(dataset, batch_size=64, shuffle=shuffle, num_workers=2)
    for dataset, shuffle in ((train_data, True), (test_data, False))
)
print(f"Train: {len(train_loader)} batches, Test: {len(test_loader)} batches")

## Shared Metrics

In [None]:
def ssim_index(x, y, c1=0.01**2, c2=0.03**2):
    """Global (single-window) SSIM per image. Expects x, y in [0, 1].

    Means, variances and the covariance are taken over the whole H x W plane
    of each channel. There is no sliding Gaussian window as in the classic
    SSIM definition. The per-channel SSIM values are then averaged over
    channels. Inputs are (B, C, H, W) and the result has shape (B,).
    c1 and c2 are the usual stabilising constants for a [0, 1] dynamic range.
    """
    spatial = [2, 3]
    mean_x = x.mean(spatial, keepdim=True)
    mean_y = y.mean(spatial, keepdim=True)
    dev_x = x - mean_x
    dev_y = y - mean_y
    var_x = (dev_x ** 2).mean(spatial, keepdim=True)
    var_y = (dev_y ** 2).mean(spatial, keepdim=True)
    cov_xy = (dev_x * dev_y).mean(spatial, keepdim=True)

    top = (2 * mean_x * mean_y + c1) * (2 * cov_xy + c2)
    bottom = (mean_x ** 2 + mean_y ** 2 + c1) * (var_x + var_y + c2)
    # 1e-8 guards against a zero denominator.
    return (top / (bottom + 1e-8)).mean([1, 2, 3])


def compute_metrics(x, x_recon, mu, logvar):
    """Per-sample reconstruction and latent metrics.

    Returns four (B,) tensors:
      mse   - mean squared error over all pixels of each sample,
      kl    - KL(N(mu, exp(logvar)) || N(0, I)) summed over latent dims,
      sharp - mean absolute horizontal/vertical finite difference of
              x_recon (a crude edge-strength proxy),
      ssim  - ssim_index(x, x_recon).
    """
    batch = x.size(0)
    mse = F.mse_loss(x_recon, x, reduction='none').view(batch, -1).mean(1)
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1)

    grad_w = (x_recon[..., 1:] - x_recon[..., :-1]).abs()
    grad_h = (x_recon[..., 1:, :] - x_recon[..., :-1, :]).abs()
    sharp = (grad_w.mean([1, 2, 3]) + grad_h.mean([1, 2, 3])) / 2

    return mse, kl, sharp, ssim_index(x, x_recon)


def evaluate(model, loader, device):
    """Average MSE, KL, sharpness and SSIM of `model` over every sample in `loader`.

    Switches the model to eval mode and runs without gradients. The VAE in
    this notebook still samples z in eval mode, so repeated calls can differ
    slightly. Returns {'mse', 'kl', 'sharp', 'ssim'} -> per-sample mean.
    """
    model.eval()
    keys = ('mse', 'kl', 'sharp', 'ssim')
    collected = {key: [] for key in keys}

    with torch.no_grad():
        for batch in loader:
            x = batch[0].to(device)
            x_recon, mu, logvar = model(x)
            per_sample = compute_metrics(x, x_recon, mu, logvar)
            for key, values in zip(keys, per_sample):
                collected[key].extend(values.cpu().numpy())

    return {key: np.mean(collected[key]) for key in keys}


---
## β-VAE Training

In [None]:
def train_beta_vae(model, loader, device, beta, n_epochs=20):
    """
    Train `model` as a standard β-VAE for `n_epochs`.

    Loss = mean per-sample MSE + β * mean per-sample KL, optimised with
    Adam (lr 1e-3) and gradient-norm clipping at 1.0.

    Returns a list with one dict per optimiser step holding the batch-mean
    'mse', 'kl', 'sharp' and 'ssim'.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    history = []

    for epoch in range(1, n_epochs + 1):
        model.train()
        losses, mses, kls = [], [], []

        progress = tqdm(loader, desc=f"β-VAE (β={beta}) Epoch {epoch}")
        for batch in progress:
            x = batch[0].to(device)

            optimizer.zero_grad()
            x_recon, mu, logvar = model(x)
            mse, kl, sharp, ssim = compute_metrics(x, x_recon, mu, logvar)
            mse_mean = mse.mean()
            kl_mean = kl.mean()
            loss = mse_mean + beta * kl_mean

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            loss_val = loss.item()
            mse_val = mse_mean.item()
            kl_val = kl_mean.item()
            losses.append(loss_val)
            mses.append(mse_val)
            kls.append(kl_val)

            history.append({
                'mse': mse_val,
                'kl': kl_val,
                'sharp': sharp.mean().item(),
                'ssim': ssim.mean().item(),
            })
            progress.set_postfix({'loss': f"{loss_val:.4f}", 'mse': f"{mse_val:.4f}", 'kl': f"{kl_val:.0f}"})

        print(f"  Epoch {epoch}: loss={np.mean(losses):.4f}, mse={np.mean(mses):.4f}, kl={np.mean(kls):.0f}")

    return history


---
## BOM-VAE Training with Adaptive Squeeze

In [None]:
def regular_constraint_lower_better(value, floor):
    """Score for objectives where lower is better (MSE).

    1 at value == 0, falling linearly to 0 at value == floor and negative
    beyond it.
    """
    margin = floor - value
    return margin / floor


def regular_constraint_higher_better(value, ceiling):
    """Score for objectives where higher is better (SSIM).

    0 at value == 0 and 1 at value == ceiling. Not clipped, so values above
    the ceiling score above 1.
    """
    fraction_of_ceiling = value / ceiling
    return fraction_of_ceiling


def box_constraint(value, floor_low, optimum, floor_high):
    """Score for objectives that need to stay in a range (KL).

    Tent-shaped: 0 at floor_low, rising linearly to 1 at optimum, then
    falling linearly to 0 at floor_high. Negative outside the box. `value`
    must be a tensor (torch.minimum is used).
    """
    rising_edge = (value - floor_low) / (optimum - floor_low)
    falling_edge = (floor_high - value) / (floor_high - optimum)
    return torch.minimum(rising_edge, falling_edge)


def compute_bom_loss(
    x, x_recon, mu, logvar,
    mse_floor, kl_floor_low, kl_optimum, kl_floor_high, ssim_ceiling
):
    """Compute the BOM loss -mean(log(s_min)) from MSE, KL and SSIM scores.

    Each sample gets three scores:
      * MSE: lower-is-better against `mse_floor`,
      * KL: tent score for the box [kl_floor_low, kl_optimum, kl_floor_high],
      * SSIM: ssim / ssim_ceiling.
    The sample's bottleneck score s_min is the smallest of the three.

    Returns:
        (loss, metrics). `loss` is None when any sample in the batch has a
        non-positive or NaN s_min, because -log(s_min) is undefined there.
        The caller is expected to skip the optimiser step in that case.
        `metrics` always holds batch means of the raw metrics and scores plus
        the number of violating samples. 'bottleneck' (the objective most
        often at the minimum) and 'loss' are added only when loss is not None.
    """
    mse, kl, sharp, ssim = compute_metrics(x, x_recon, mu, logvar)

    mse_score = regular_constraint_lower_better(mse, mse_floor)
    kl_score = box_constraint(kl, kl_floor_low, kl_optimum, kl_floor_high)
    ssim_score = regular_constraint_higher_better(ssim, ssim_ceiling)

    scores = torch.stack([mse_score, kl_score, ssim_score], dim=1)
    s_min, min_idx = torch.min(scores, dim=1)

    # `~(s_min > 0)` also counts NaN as a violation. `s_min <= 0` is False for
    # NaN, which would let a NaN loss reach backward() and corrupt the weights.
    violations = (~(s_min > 0)).sum().item()

    metrics = {
        'mse': mse.mean().item(),
        'kl': kl.mean().item(),
        'sharp': sharp.mean().item(),
        'ssim': ssim.mean().item(),
        'mse_score': mse_score.mean().item(),
        'kl_score': kl_score.mean().item(),
        'ssim_score': ssim_score.mean().item(),
        's_min': s_min.mean().item(),
        'violations': violations,
    }

    if violations > 0:
        return None, metrics

    loss = -torch.log(s_min).mean()
    names = ['mse', 'kl', 'ssim']
    metrics['bottleneck'] = names[torch.bincount(min_idx, minlength=3).argmax().item()]
    metrics['loss'] = loss.item()

    return loss, metrics


In [None]:
def calibrate_bom(model, loader, device, n_batches=50):
    """Calibrate the initial BOM constraints from the model's current outputs.

    Runs up to `n_batches` batches of `loader` without gradients, collects
    per-sample MSE, KL, sharpness and SSIM, and derives deliberately loose
    starting constraints:
      mse_floor    = 2 x largest observed MSE
      KL box       = [0.1 x min KL, mean KL, 50 x max KL]
      ssim_ceiling = min(0.95, mean SSIM + 0.05)

    Returns a dict with keys 'mse_floor', 'kl_floor_low', 'kl_optimum',
    'kl_floor_high' and 'ssim_ceiling'.

    Raises:
        ValueError: if the loader yields no samples (n_batches <= 0 or an
        empty loader).
    """
    # Calibrate in train mode, then restore whatever mode the caller had.
    # The VAE in this notebook uses GroupNorm (not BatchNorm) and no dropout,
    # so for it the mode does not change the statistics.
    was_training = model.training
    model.train()
    all_mse, all_kl, all_sharp, all_ssim = [], [], [], []

    try:
        with torch.no_grad():
            for i, batch in enumerate(loader):
                if i >= n_batches:
                    break
                x = batch[0].to(device)
                x_recon, mu, logvar = model(x)
                mse, kl, sharp, ssim = compute_metrics(x, x_recon, mu, logvar)
                all_mse.extend(mse.cpu().numpy())
                all_kl.extend(kl.cpu().numpy())
                all_sharp.extend(sharp.cpu().numpy())
                all_ssim.extend(ssim.cpu().numpy())
    finally:
        model.train(was_training)

    if not all_mse:
        raise ValueError("calibrate_bom: no samples were seen; check `loader` and `n_batches`")

    mse_arr = np.array(all_mse)
    kl_arr = np.array(all_kl)
    sharp_arr = np.array(all_sharp)
    ssim_arr = np.array(all_ssim)

    ssim_ceiling = min(0.95, ssim_arr.mean() + 0.05)

    params = {
        'mse_floor': mse_arr.max() * 2.0,
        'kl_floor_low': kl_arr.min() * 0.1,
        'kl_optimum': kl_arr.mean(),
        'kl_floor_high': kl_arr.max() * 50.0,  # Very loose initially
        'ssim_ceiling': ssim_ceiling,
    }

    print(f"Calibration: MSE={mse_arr.mean():.4f}, KL={kl_arr.mean():.1f}, Sharp={sharp_arr.mean():.4f}, SSIM={ssim_arr.mean():.4f}")
    print(
        f"Initial constraints: mse_floor={params['mse_floor']:.4f}, "
        f"kl_box=[{params['kl_floor_low']:.1f}, {params['kl_optimum']:.1f}, {params['kl_floor_high']:.1f}], "
        f"ssim_ceiling={params['ssim_ceiling']:.3f}"
    )

    return params


In [None]:
def _squeeze_kl_box(kl_box, kl_targets, squeeze_factor):
    """Move each KL-box bound a fraction (1 - squeeze_factor) of the way toward its target.

    Bounds only move inward. The lower bound and the optimum rise toward
    their targets if below them. The upper bound falls toward its target if
    above it, but is never moved to or below the optimum, because
    box_constraint divides by (floor_high - optimum).
    """
    kl_floor_low, kl_optimum, kl_floor_high = kl_box
    target_low, target_optimum, target_high = kl_targets
    pull = 1 - squeeze_factor

    if kl_floor_low < target_low:
        kl_floor_low += (target_low - kl_floor_low) * pull
    if kl_optimum < target_optimum:
        kl_optimum += (target_optimum - kl_optimum) * pull
    if kl_floor_high > target_high:
        candidate_high = kl_floor_high - (kl_floor_high - target_high) * pull
        if candidate_high > kl_optimum:
            kl_floor_high = candidate_high

    return kl_floor_low, kl_optimum, kl_floor_high


def _mean_or_nan(values):
    """Mean of a list as a float, or NaN for an empty list (avoids numpy's empty-mean warning)."""
    return float(np.mean(values)) if values else float('nan')


def train_bom_vae(model, loader, device, n_epochs=20, *, squeeze_k=0.5,
                  min_s_min_for_squeeze=0.5, squeeze_start_epoch=3,
                  kl_targets=(50, 80, 150)):
    """
    BOM-VAE training with adaptive squeeze.

    Constraints are calibrated from the untrained model (calibrate_bom).
    Each step then minimises -mean(log(s_min)) (compute_bom_loss); batches
    with any violating sample are skipped without an optimiser step.

    At the end of every epoch from `squeeze_start_epoch` on:
        squeeze_amount = (s_min - min_s_min_for_squeeze) * squeeze_k
        squeeze_factor = max(0.5, 1 - squeeze_amount)
    - s_min > min_s_min_for_squeeze: squeeze proportionally
    - s_min <= min_s_min_for_squeeze: don't squeeze
    Here s_min is the epoch average of the per-batch mean bottleneck score,
    taken over non-skipped batches only (0 if every batch was skipped).
    Squeezing multiplies the MSE floor by squeeze_factor and moves the KL box
    toward `kl_targets` = (floor_low, optimum, floor_high).

    Returns:
        List of per-step metric dicts from compute_bom_loss, including the
        skipped steps.
    """
    params = calibrate_bom(model, loader, device)

    mse_floor = params['mse_floor']
    kl_box = (params['kl_floor_low'], params['kl_optimum'], params['kl_floor_high'])
    ssim_ceiling = params['ssim_ceiling']

    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    history = []

    for epoch in range(1, n_epochs + 1):
        model.train()
        epoch_loss, epoch_s_min = [], []
        epoch_mse, epoch_kl, epoch_ssim = [], [], []
        epoch_violations = 0

        pbar = tqdm(loader, desc=f"BOM-VAE Epoch {epoch}")
        for batch in pbar:
            x = batch[0].to(device)

            optimizer.zero_grad()
            x_recon, mu, logvar = model(x)

            loss, metrics = compute_bom_loss(
                x, x_recon, mu, logvar,
                mse_floor, *kl_box, ssim_ceiling
            )

            if loss is not None:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                epoch_loss.append(metrics['loss'])
                epoch_s_min.append(metrics['s_min'])
            else:
                epoch_violations += metrics['violations']

            # Track epoch-level means for the summary line (the original
            # printed only the last batch's values).
            epoch_mse.append(metrics['mse'])
            epoch_kl.append(metrics['kl'])
            epoch_ssim.append(metrics['ssim'])

            history.append(metrics)
            pbar.set_postfix({'s_min': f"{metrics['s_min']:.3f}", 'kl': f"{metrics['kl']:.0f}"})

        avg_s_min = np.mean(epoch_s_min) if epoch_s_min else 0
        kl_floor_low, kl_optimum, kl_floor_high = kl_box
        print(
            f"  Epoch {epoch}: s_min={avg_s_min:.3f}, violations={epoch_violations}, "
            f"mse={_mean_or_nan(epoch_mse):.4f}, kl={_mean_or_nan(epoch_kl):.0f}, "
            f"ssim={_mean_or_nan(epoch_ssim):.3f}"
        )
        print(f"    KL box: [{kl_floor_low:.1f}, {kl_optimum:.1f}, {kl_floor_high:.1f}]")
        print(f"    SSIM ceiling: {ssim_ceiling:.3f}")

        # Adaptive squeeze
        if epoch >= squeeze_start_epoch and avg_s_min > min_s_min_for_squeeze:
            squeeze_amount = (avg_s_min - min_s_min_for_squeeze) * squeeze_k
            # e.g. s_min=0.9, k=0.5 -> factor=0.8. Never squeeze more than 50% at once.
            squeeze_factor = max(0.5, 1.0 - squeeze_amount)

            print(f"    🔧 Squeeze: s_min={avg_s_min:.3f} -> factor={squeeze_factor:.2f}")

            mse_floor *= squeeze_factor
            kl_box = _squeeze_kl_box(kl_box, kl_targets, squeeze_factor)

    return history


---
## Run Comparison

In [None]:
N_EPOCHS = 20    # epoch budget shared by every run, so the comparison is like-for-like
results = dict()  # run name -> {'model', 'history', 'test'}

In [None]:
# β-VAE baselines: a fresh model per β, each trained for N_EPOCHS and scored on the test set.
betas = [0.0001, 0.001, 0.01, 0.1]

for beta in betas:
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"Training β-VAE with β={beta}")
    print(rule)

    model = VAE(latent_dim=128).to(device)
    history = train_beta_vae(model, train_loader, device, beta=beta, n_epochs=N_EPOCHS)
    test_metrics = evaluate(model, test_loader, device)

    results[f'beta_{beta}'] = dict(model=model, history=history, test=test_metrics)

    print(
        f"\nTest results: MSE={test_metrics['mse']:.4f}, "
        f"KL={test_metrics['kl']:.1f}, SSIM={test_metrics['ssim']:.4f}"
    )


In [None]:
# BOM-VAE: same architecture and epoch budget as the β-VAE runs, but no β.
banner = '=' * 60
print(f"\n{banner}")
print("Training BOM-VAE (no β tuning required)")
print(banner)

model_bom = VAE(latent_dim=128).to(device)
history_bom = train_bom_vae(model_bom, train_loader, device, n_epochs=N_EPOCHS)
test_metrics_bom = evaluate(model_bom, test_loader, device)

results['bom'] = dict(model=model_bom, history=history_bom, test=test_metrics_bom)

print(
    f"\nTest results: MSE={test_metrics_bom['mse']:.4f}, "
    f"KL={test_metrics_bom['kl']:.1f}, SSIM={test_metrics_bom['ssim']:.4f}"
)


---
## Results Comparison

In [None]:
# Summary table: one row of test-set metrics per run.
heavy_rule = "=" * 70
light_rule = "-" * 70

print("\n" + heavy_rule)
print("FINAL RESULTS")
print(heavy_rule)
print(f"{'Method':<20} {'MSE':>10} {'KL':>10} {'SSIM':>10}")
print(light_rule)

for run_name, run in results.items():
    scores = run['test']
    print(f"{run_name:<20} {scores['mse']:>10.4f} {scores['kl']:>10.1f} {scores['ssim']:>10.4f}")

print(light_rule)


In [None]:
# Training curves comparison: per-step batch metrics for every run.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
panels = [('mse', 'MSE (↓ better)'), ('kl', 'KL Divergence'), ('ssim', 'SSIM (↑ better)')]

for run_name, run in results.items():
    curve_label = run_name.replace('_', '=')
    for ax, (metric_key, _) in zip(axes, panels):
        ax.plot([step[metric_key] for step in run['history']], label=curve_label, alpha=0.8)

for ax, (_, title) in zip(axes, panels):
    ax.set_title(title)
    ax.set_xlabel('Step')
    ax.legend()
axes[0].set_yscale('log')

plt.tight_layout()
plt.savefig('training_comparison.png', dpi=150)
plt.show()


In [None]:
# Pareto plot: test MSE vs test KL, one marker per run (squares = β-VAE, large circle = BOM).
plt.figure(figsize=(10, 6))

palette = plt.cm.viridis(np.linspace(0, 1, len(results)))

for color, (run_name, run) in zip(palette, results.items()):
    scores = run['test']
    plt.scatter(
        scores['mse'], scores['kl'],
        s=100 if 'bom' in run_name else 60,
        c=[color],
        marker='s' if 'beta' in run_name else 'o',
        label=run_name.replace('_', '='),
        edgecolors='black',
    )

plt.xlabel('MSE (↓ better)')
plt.ylabel('KL (moderate is better)')
plt.title('Pareto Front: MSE vs KL')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('pareto_comparison.png', dpi=150)
plt.show()

In [None]:
# Reconstructions comparison: the first 8 test digits (top row), then each model's reconstruction.
n_show = 8
test_batch = next(iter(test_loader))[0][:n_show].to(device)

n_models = len(results)
# squeeze=False keeps `axes` 2-D even with a single row or column.
fig, axes = plt.subplots(n_models + 1, n_show, figsize=(2 * n_show, 2 * (n_models + 1)), squeeze=False)

row_images = [('Original', test_batch)]
for name, data in results.items():
    model = data['model']
    model.eval()
    with torch.no_grad():
        recon, _, _ = model(test_batch)
    row_images.append((name.replace('_', '='), recon))

for row, (row_label, images) in enumerate(row_images):
    for i, img in enumerate(images[:n_show]):
        ax = axes[row, i]
        # Grayscale on a fixed [0, 1] scale: the images are single-channel
        # intensities, and per-image autoscaling would hide brightness differences.
        ax.imshow(img.cpu().squeeze(0), cmap='gray', vmin=0, vmax=1)
        # Hide ticks and frame but keep the axis: ax.axis('off') would also
        # hide the y-label used below as the row title.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(row_label, fontsize=12)

plt.tight_layout()
plt.savefig('reconstructions_comparison.png', dpi=150)
plt.show()

In [None]:
# Samples from prior comparison: decode the same z ~ N(0, I) with every model.
n_show = 8
n_models = len(results)
# All runs share one VAE architecture; read the latent size from it instead of hard-coding 128.
latent_dim = next(iter(results.values()))['model'].latent_dim
z = torch.randn(n_show, latent_dim, device=device)

# squeeze=False keeps `axes` 2-D even with a single model.
fig, axes = plt.subplots(n_models, n_show, figsize=(2 * n_show, 2 * n_models), squeeze=False)

for row, (name, data) in enumerate(results.items()):
    model = data['model']
    model.eval()
    with torch.no_grad():
        samples = model.dec(model.fc_dec(z).view(-1, 256, 4, 4))

    for i in range(n_show):
        ax = axes[row, i]
        ax.imshow(samples[i].cpu().squeeze(0), cmap='gray', vmin=0, vmax=1)
        # Keep the axis (so the row's y-label stays visible) but hide ticks and frame.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(name.replace('_', '='), fontsize=12)

plt.suptitle('Samples from Prior (same z for all models)', fontsize=14)
plt.tight_layout()
plt.savefig('samples_comparison.png', dpi=150)
plt.show()

---
## Analysis

In [None]:
# Find best β-VAE: split the β runs from the BOM run.
beta_results = {run_name: run for run_name, run in results.items() if 'beta' in run_name}
bom_result = results['bom']

# Best by MSE (lowest test MSE among the β runs).
best_mse_beta = min(beta_results.items(), key=lambda item: item[1]['test']['mse'])
best_mse_name, best_mse_run = best_mse_beta
print(f"Best β-VAE by MSE: {best_mse_name} (MSE={best_mse_run['test']['mse']:.4f})")
print(f"BOM-VAE MSE: {bom_result['test']['mse']:.4f}")
print()

# Best by balanced score (low MSE, moderate KL, high SSIM)
def balanced_score(t):
    """Heuristic single-number summary of a run's test metrics (higher is better).

    Product of three factors:
      1 / (mse + 0.001)     - rewards low MSE (0.001 avoids division by zero),
      1 / (|kl - 100| + 10) - peaks at KL = 100 and decays on either side,
      ssim                  - rewards high SSIM.
    The constants are hand-picked, so the score is only useful for ranking runs.
    """
    mse_term = 1.0 / (t['mse'] + 0.001)
    kl_term = 1.0 / (abs(t['kl'] - 100) + 10)
    return mse_term * kl_term * t['ssim']

# Highest balanced score among the β runs, compared against BOM-VAE.
best_balanced_beta = max(beta_results.items(), key=lambda item: balanced_score(item[1]['test']))
best_balanced_name, best_balanced_run = best_balanced_beta
print(f"Best β-VAE by balanced score: {best_balanced_name}")
print(f"  Score: {balanced_score(best_balanced_run['test']):.4f}")
print(f"BOM-VAE balanced score: {balanced_score(bom_result['test']):.4f}")


In [None]:
# Conclusion. This text is fixed, not computed from the runs above, so it
# states the hypothesis; check it against the FINAL RESULTS table.
print("\n" + "="*70)
print("CONCLUSION")
print("="*70)
print("""
β-VAE requires tuning β to balance reconstruction vs regularization.
Different β values give different tradeoffs:
- Low β (0.0001): Good MSE, but KL is barely penalized and grows large
- High β (0.1): Controlled KL, but poor reconstruction (risk of posterior collapse)

BOM-VAE aims for a balanced solution without a β sweep:
- No β hyperparameter to tune (squeeze gain and KL targets are fixed in advance)
- Adaptive squeeze tightens the constraints while every objective has slack
- All objectives are explicitly constrained

Using SSIM alongside MSE makes it easier to line up scores with
visual similarity, since MSE can look good numerically while
still missing fine details.

The key insight: BOM optimizes the WORST objective at each step,
preventing any single objective from being sacrificed.

(Compare these expectations against the FINAL RESULTS table.)
""")


---
## Experimental: BHiLBO-VAE (BHiVAE + LBO)

This section adds a minimal BHiLBO-VAE implementation and a quick smoke-test training loop.
It is intended for rapid sanity checks (not a full benchmark).


In [None]:
import math
from torch.distributions import MultivariateNormal, kl_divergence

class BHiLBO_VAE(nn.Module):
    """Experimental hierarchical MLP VAE that reports raw per-sample goal metrics.

    The encoder is a chain of three linear stages (shared -> mid -> detail),
    each with its own Gaussian latent head. The three latents are
    concatenated and decoded by a two-layer MLP with a sigmoid output.

    The module also owns two auxiliary networks:
      * `disc`: a real-vs-reconstruction discriminator on flattened images
        (outputs a logit);
      * `d_sep`: a sigmoid classifier over concatenated (z, h1), used by
        `estimate_tc`.

    `forward` returns the flat reconstruction and a dict of raw metrics, each
    of shape (B,). `self.constraints` holds the (target, failure) values used
    to turn those metrics into scores.
    """

    def __init__(self, input_dim=64 * 64, hidden_dim=400, core_latent_dim=16, mid_latent_dim=32, detail_latent_dim=64, block_size=4):
        super().__init__()
        self.input_dim = input_dim
        self.core_latent_dim = core_latent_dim
        self.mid_latent_dim = mid_latent_dim
        self.detail_latent_dim = detail_latent_dim

        # Hierarchical encoder: hidden_dim -> hidden_dim // 2 -> hidden_dim // 4
        self.fc_shared = nn.Linear(input_dim, hidden_dim)
        self.fc_mid_hidden = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc_detail_hidden = nn.Linear(hidden_dim // 2, hidden_dim // 4)

        self.fc_mu_core = nn.Linear(hidden_dim, core_latent_dim)
        self.fc_logvar_core = nn.Linear(hidden_dim, core_latent_dim)

        self.fc_mu_mid = nn.Linear(hidden_dim // 2, mid_latent_dim)
        self.fc_logvar_mid = nn.Linear(hidden_dim // 2, mid_latent_dim)

        self.fc_mu_detail = nn.Linear(hidden_dim // 4, detail_latent_dim)
        self.fc_logvar_detail = nn.Linear(hidden_dim // 4, detail_latent_dim)

        total_latent_dim = core_latent_dim + mid_latent_dim + detail_latent_dim
        self.total_latent_dim = total_latent_dim

        # Decoder
        self.fc_decode_hidden = nn.Linear(total_latent_dim, hidden_dim)
        self.fc_decode_out = nn.Linear(hidden_dim, input_dim)

        # Discriminator (for recon realism); outputs a logit.
        self.disc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        # Separator over (z, h1) pairs used as a density-ratio estimator; outputs a probability.
        self.d_sep = nn.Sequential(
            nn.Linear(total_latent_dim + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1), nn.Sigmoid()
        )

        # Block-diagonal prior: unit variances, covariance 0.5 between latent
        # dims in the same run of `block_size` consecutive dims, 0 across runs.
        prior_cov = torch.eye(total_latent_dim)
        for start in range(0, total_latent_dim, block_size):
            end = min(start + block_size, total_latent_dim)
            block = torch.full((end - start, end - start), 0.5)
            block.diagonal().fill_(1.0)
            prior_cov[start:end, start:end] = block
        self.register_buffer('prior_cov', prior_cov)
        self.register_buffer('prior_mean', torch.zeros(total_latent_dim))

        # Constraint configs (minimal set for smoke test). In normalize_scores,
        # `target` maps to score 1 and `failure` to score 0.
        # NOTE(review): every entry is tagged 'BOX', but the scoring is
        # one-sided, and the 'type' field is not read anywhere in this notebook.
        self.constraints = {
            'recon': {'type': 'BOX', 'target': 0.0, 'failure': 0.12},
            'kl_core': {'type': 'BOX', 'target': 100.0, 'failure': 4000.0},
            'kl_mid': {'type': 'BOX', 'target': 100.0, 'failure': 4000.0},
            'kl_detail': {'type': 'BOX', 'target': 100.0, 'failure': 4000.0},
            'disc': {'type': 'BOX', 'target': 0.5, 'failure': 0.0},
            'sep': {'type': 'BOX', 'target': 0.0, 'failure': 50.0},
            'prior': {'type': 'BOX', 'target': 1.0, 'failure': 100.0},
        }

    def encode(self, x):
        """Return ((mu_core, mu_mid, mu_detail), (logvars...), (h1, h2, h3)) for flat x."""
        h1 = F.relu(self.fc_shared(x))
        mu_core = self.fc_mu_core(h1)
        logvar_core = self.fc_logvar_core(h1)

        h2 = F.relu(self.fc_mid_hidden(h1))
        mu_mid = self.fc_mu_mid(h2)
        logvar_mid = self.fc_logvar_mid(h2)

        h3 = F.relu(self.fc_detail_hidden(h2))
        mu_detail = self.fc_mu_detail(h3)
        logvar_detail = self.fc_logvar_detail(h3)

        return (mu_core, mu_mid, mu_detail), (logvar_core, logvar_mid, logvar_detail), (h1, h2, h3)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latents to a flat reconstruction in (0, 1)."""
        h = F.relu(self.fc_decode_hidden(z))
        return torch.sigmoid(self.fc_decode_out(h))

    def estimate_tc(self, z, h):
        """Batch mean of log(d / (1 - d)) for `d_sep` on matched (z, h) pairs.

        Returns a 0-dim tensor.
        NOTE(review): this is a meaningful density-ratio estimate only if
        `d_sep` is trained to separate matched pairs from pairs with z
        shuffled across the batch. No such training step appears in this
        notebook. Also, with h = h1 this measures dependence between z and h1
        rather than total correlation among z's dimensions. Confirm the intent.
        """
        d_true = self.d_sep(torch.cat([z, h], dim=1))
        # The 1e-6 keeps the ratio finite as d_true -> 1.
        return torch.log(d_true / (1 - d_true + 1e-6)).mean()

    def forward(self, x):
        """Reconstruct `x` and compute raw per-sample goal metrics.

        `x` is flattened to (B, -1) and must have input_dim values per sample.
        It is used as the BCE target, so presumably its intensities are in
        [0, 1].

        Returns:
            (recon_x, metrics). recon_x is flat (B, input_dim). Every entry of
            metrics has shape (B,):
              recon   - per-sample mean binary cross-entropy,
              kl_*    - KL to N(0, I) for each latent level,
              disc    - sigmoid(disc(recon_x)),
              sep     - batch-level estimate_tc value, repeated per sample,
              prior   - KL(q(z|x) || block-diagonal prior) over the full z.
        """
        x_flat = x.view(x.size(0), -1)
        mus, logvars, hs = self.encode(x_flat)
        mu_core, mu_mid, mu_detail = mus
        logvar_core, logvar_mid, logvar_detail = logvars
        h1 = hs[0]

        z_core = self.reparameterize(mu_core, logvar_core)
        z_mid = self.reparameterize(mu_mid, logvar_mid)
        z_detail = self.reparameterize(mu_detail, logvar_detail)

        z = torch.cat([z_core, z_mid, z_detail], dim=1)
        recon_x = self.decode(z)

        raw_recon = F.binary_cross_entropy(recon_x, x_flat, reduction='none').mean(dim=1)
        raw_kl_core = -0.5 * torch.sum(1 + logvar_core - mu_core.pow(2) - logvar_core.exp(), dim=1)
        raw_kl_mid = -0.5 * torch.sum(1 + logvar_mid - mu_mid.pow(2) - logvar_mid.exp(), dim=1)
        raw_kl_detail = -0.5 * torch.sum(1 + logvar_detail - mu_detail.pow(2) - logvar_detail.exp(), dim=1)
        raw_disc = torch.sigmoid(self.disc(recon_x)).squeeze(-1)

        # estimate_tc returns one batch-level scalar. Broadcast it to (B,) so it
        # has the same shape as the other metrics; torch.stack in the scoring
        # code raises on a 0-dim tensor mixed with (B,) tensors.
        raw_sep = self.estimate_tc(z, h1).expand(x_flat.size(0))

        q_dist = MultivariateNormal(torch.cat(mus, dim=1), torch.diag_embed(torch.cat(logvars, dim=1).exp()))
        p_dist = MultivariateNormal(self.prior_mean, self.prior_cov)
        raw_prior = kl_divergence(q_dist, p_dist)

        metrics = {
            'recon': raw_recon,
            'kl_core': raw_kl_core,
            'kl_mid': raw_kl_mid,
            'kl_detail': raw_kl_detail,
            'disc': raw_disc,
            'sep': raw_sep,
            'prior': raw_prior
        }
        return recon_x, metrics

def normalize_scores(metrics, cfg):
    """Map raw metrics to linear scores: 1 at the target, 0 at the failure value.

    'disc' is higher-is-better (score rises from failure toward target). All
    other goals are lower-is-better. Keys come out in the fixed order recon,
    kl_core, kl_mid, kl_detail, disc, sep, prior, which downstream stacking
    relies on. Works element-wise on tensors or plain floats.
    """
    def lower_is_better(name):
        target, failure = cfg[name]['target'], cfg[name]['failure']
        return (failure - metrics[name]) / (failure - target)

    def higher_is_better(name):
        target, failure = cfg[name]['target'], cfg[name]['failure']
        return (metrics[name] - failure) / (target - failure)

    scores = {name: lower_is_better(name) for name in ('recon', 'kl_core', 'kl_mid', 'kl_detail')}
    scores['disc'] = higher_is_better('disc')
    scores['sep'] = lower_is_better('sep')
    scores['prior'] = lower_is_better('prior')
    return scores

def lbo_loss_from_scores(S):
    """Return the per-sample bottleneck score: the minimum over all goal scores.

    Despite the name this is a score, not a loss; callers take -log of it.
    Scores are broadcast to a common shape before stacking, so a batch-level
    0-dim score can be mixed with per-sample (B,) scores. Plain torch.stack
    would raise on such a mix.
    """
    stacked = torch.stack(torch.broadcast_tensors(*S.values()), dim=1)
    return stacked.min(dim=1).values


In [None]:
# Smoke-test: a small subset of the training data and a few batches, to confirm
# the plumbing runs end to end (not a benchmark).
n_smoke_batches = 5
smoke_subset = torch.utils.data.Subset(train_data, range(64 * n_smoke_batches))
small_loader = DataLoader(smoke_subset, batch_size=64, shuffle=True, num_workers=2)

model_bhilbo = BHiLBO_VAE().to(device)
# `disc` and `d_sep` are auxiliary networks updated only by optim_disc. Keeping
# them out of the main optimizer stops the LBO step from retuning them to
# report good 'disc'/'sep' scores instead of improving the VAE.
main_params = [
    p for name, p in model_bhilbo.named_parameters()
    if not name.startswith(('disc.', 'd_sep.'))
]
optim_main = optim.Adam(main_params, lr=2e-3)
optim_disc = optim.Adam(list(model_bhilbo.disc.parameters()) + list(model_bhilbo.d_sep.parameters()), lr=2e-4)

model_bhilbo.train()
n_main_steps = 0
for i, (x, _) in enumerate(small_loader):
    if i >= n_smoke_batches:
        break
    x = x.to(device)

    # Train the discriminator on real images vs. reconstructions. The
    # reconstructions are treated as constants, so no graph is needed for them.
    with torch.no_grad():
        recon_x, _ = model_bhilbo(x)
    optim_disc.zero_grad()
    d_real = model_bhilbo.disc(x.view(x.size(0), -1))
    d_fake = model_bhilbo.disc(recon_x)
    loss_disc = (
        F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
        + F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
    )
    loss_disc.backward()
    optim_disc.step()

    # Main LBO step: maximise the worst goal score. A batch is skipped if any
    # sample has failed a goal (score <= 0) or produced NaN. With the default
    # constraints this may skip every batch of an untrained model; the step
    # count printed below makes that visible.
    optim_main.zero_grad()
    recon_x, metrics = model_bhilbo(x)
    S = normalize_scores(metrics, model_bhilbo.constraints)
    min_s = lbo_loss_from_scores(S)
    if (min_s <= 0).any() or torch.isnan(min_s).any():
        continue
    loss = -torch.log(min_s).mean()
    loss.backward()
    optim_main.step()
    n_main_steps += 1

print(f'BHiLBO-VAE smoke test complete ({n_main_steps} main steps taken).')


---
## Experimental: Log-Sum Goal VAE (log goals gradient scaling)

This variant uses the same goal scores but minimizes −Σ log(goal) instead of −log(min goal).
Every goal therefore receives a gradient scaled by 1/goal, not just the current bottleneck.


In [None]:
def logsum_goal_loss(S):
    """Return per-sample loss = -sum(log(goal)).

    Each goal's gradient is scaled by 1/goal, so weak goals are pushed
    hardest, but every goal contributes (unlike a min-based loss). Scores are
    broadcast to a common shape before stacking, so a batch-level 0-dim score
    can be mixed with per-sample (B,) scores. All goals must be positive; a
    goal <= 0 gives inf/NaN.
    """
    stacked = torch.stack(torch.broadcast_tensors(*S.values()), dim=1)
    return -(torch.log(stacked).sum(dim=1))

# Smoke-test: log-sum goal loss variant. Reuses `small_loader` from the
# BHiLBO smoke test above.
model_logsum = BHiLBO_VAE().to(device)
# No discriminator is trained in this variant. Keep `disc` and `d_sep` out of
# the optimizer so the goal loss cannot simply retune them to report good
# 'disc'/'sep' scores.
logsum_params = [
    p for name, p in model_logsum.named_parameters()
    if not name.startswith(('disc.', 'd_sep.'))
]
optim_logsum = optim.Adam(logsum_params, lr=2e-3)

model_logsum.train()
n_logsum_steps = 0
for i, (x, _) in enumerate(small_loader):
    if i >= 5:
        break
    x = x.to(device)
    optim_logsum.zero_grad()
    recon_x, metrics = model_logsum(x)
    S = normalize_scores(metrics, model_logsum.constraints)
    # log(goal) is undefined for goal <= 0, so skip batches where any
    # sample's worst goal has failed or gone NaN.
    min_s = lbo_loss_from_scores(S)
    if (min_s <= 0).any() or torch.isnan(min_s).any():
        continue
    loss = logsum_goal_loss(S).mean()
    loss.backward()
    optim_logsum.step()
    n_logsum_steps += 1

print(f'Log-sum goal VAE smoke test complete ({n_logsum_steps} steps taken).')
