# BHiVAE with LBO Bottleneck (MNIST)

This notebook rebuilds the training flow from scratch: a BHiVAE architecture with an LBO-style bottleneck, trained with the **Muon** optimizer instead of Adam.

## Colab quick start
If you're in Colab, run the next cell to clone the repo and install dependencies.

In [None]:
#@title (Colab) Clone repo and install dependencies
import os
import sys
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    repo_url = 'https://github.com/<your-org>/<your-repo>.git'  # TODO: replace
    repo_dir = Path('/content/exparamental_vae')
    if not repo_dir.exists():
        !git clone {repo_url} {repo_dir}
    %cd {repo_dir}
    !pip -q install torch torchvision tqdm matplotlib
else:
    print('Not running in Colab; skipping clone/install.')


In [None]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm import tqdm

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')


## Configuration

In [None]:
# Single place for every tunable value read by the cells below.
config = dict(
    # Data / model / optimisation
    image_size=28,
    latent_dim=32,
    batch_size=128,
    epochs=30,
    lr=2e-4,
    weight_decay=1e-4,
    decoder_likelihood='bernoulli',  # 'bernoulli' -> BCE, 'gaussian' -> MSE
    # LBO bottleneck targets: thresholds (D_ok, K_min, K_max, B_max) and the
    # temperatures (tau_*) that divide each violation before exponentiation.
    D_ok=0.12,
    K_min=6.0,
    K_max=18.0,
    tau_D=0.03,
    tau_K_low=2.0,
    tau_K_high=2.0,
    tau_M=40.0,
    B_max=0.08,
    tau_B=0.02,
    # Logging / sample dumps
    log_every=100,
    sample_every=1,
    sample_dir='samples',
    num_samples=8,
)
config


## Model (BHiVAE)

In [None]:
class BHiVAE(nn.Module):
    """Small convolutional VAE with a diagonal-Gaussian latent.

    The encoder is two stride-2 convolutions followed by two linear heads that
    produce the latent mean and log-variance. The decoder mirrors it with a
    linear layer and two stride-2 transposed convolutions ending in a sigmoid.
    The hard-coded ``64 * 7 * 7`` feature size matches single-channel 28x28
    inputs.

    Args:
        latent_dim: Size of the latent vector.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        flat_features = 64 * 7 * 7

        # Each stride-2 convolution halves the spatial size (28 -> 14 -> 7).
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(8, 32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2),
        ]
        self.enc = nn.Sequential(*encoder_layers)

        # Two heads over the flattened encoder features.
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # Projects a latent vector back to a (64, 7, 7) feature map.
        self.fc_dec = nn.Linear(latent_dim, flat_features)

        # Each stride-2 transposed convolution doubles the spatial size (7 -> 14 -> 28).
        decoder_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(8, 32),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = self.enc(x)
        flat = features.view(x.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors ``z`` to images with values in [0, 1] (sigmoid output)."""
        feature_map = self.fc_dec(z).view(-1, 64, 7, 7)
        return self.dec(feature_map)

    def forward(self, x):
        """Encode, sample a latent, decode; return ``(x_hat, mu, logvar)``."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        x_hat = self.decode(sample)
        return x_hat, mu, logvar


## Muon Optimizer

In [None]:
class Muon(torch.optim.Optimizer):
    """Minimal Muon-style optimizer: a momentum buffer applied as a unit-norm step.

    Light-weight placeholder implementation; adjust it to your Muon spec.
    For every parameter the update is::

        g = grad + weight_decay * p
        v = momentum * v + g
        p = p - lr * v / max(||v||, 1e-8)

    so each parameter tensor moves by (at most) ``lr`` in L2 norm per step.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Step length applied to the normalized direction.
        momentum: Decay factor of the velocity buffer.
        weight_decay: Coefficient of the L2 term added to the gradient.
    """

    def __init__(self, params, lr=2e-4, momentum=0.95, weight_decay=1e-4):
        hyperparams = {'lr': lr, 'momentum': momentum, 'weight_decay': weight_decay}
        super().__init__(params, hyperparams)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one update; return the closure's loss, or ``None`` without a closure."""
        if closure is None:
            loss = None
        else:
            # The closure has to build a graph, so re-enable autograd around it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            step_size = group['lr']
            beta = group['momentum']
            decay = group['weight_decay']

            for param in group['params']:
                g = param.grad
                if g is None:
                    continue
                if decay != 0:
                    # Out-of-place add keeps param.grad itself untouched.
                    g = g.add(param, alpha=decay)

                slot = self.state[param]
                buf = slot.get('velocity')
                if buf is None:
                    buf = slot['velocity'] = torch.zeros_like(param)

                buf.mul_(beta)
                buf.add_(g)

                # Normalize by the buffer's overall L2 norm; the floor avoids 0/0.
                scale = buf.norm().clamp_min(1e-8)
                param.add_(buf / scale, alpha=-step_size)

        return loss


## Data

In [None]:
# Resize to the configured side length, then convert each image to a tensor.
transform = transforms.Compose([
    transforms.Resize(config['image_size']),
    transforms.ToTensor(),
])

# Downloads MNIST into ./data on first use.
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Settings shared by both loaders. Pinned host memory only speeds up copies to
# a GPU, so request it only when CUDA is available; asking for it on a
# CPU-only machine merely produces a warning.
loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)
print(f'Train batches: {len(train_loader)}, Test batches: {len(test_loader)}')


## LBO Bottleneck

In [None]:
def kl_divergence(mu, logvar):
    """KL divergence from N(mu, diag(exp(logvar))) to N(0, I), summed over dim 1.

    Returns one value per row of ``mu`` / ``logvar``.
    """
    per_dim = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_dim.sum(1)

def recon_loss(x, x_hat, likelihood='bernoulli'):
    """Per-sample reconstruction loss, averaged over all non-batch elements.

    Args:
        x: Target batch; dim 0 is the batch dimension.
        x_hat: Reconstruction with the same shape as ``x``. For ``'bernoulli'``
            its values must lie in [0, 1], as ``F.binary_cross_entropy`` requires.
        likelihood: ``'bernoulli'`` for binary cross-entropy or ``'gaussian'``
            for mean squared error.

    Returns:
        Tensor of shape ``(x.size(0),)`` with one loss value per sample.

    Raises:
        ValueError: If ``likelihood`` is not one of the two supported names.
            Previously any other string silently selected MSE, which hid
            typos in the configuration.
    """
    if likelihood == 'bernoulli':
        per_element = F.binary_cross_entropy(x_hat, x, reduction='none')
    elif likelihood == 'gaussian':
        per_element = F.mse_loss(x_hat, x, reduction='none')
    else:
        raise ValueError(
            f"Unknown decoder likelihood {likelihood!r}; expected 'bernoulli' or 'gaussian'."
        )
    # Flatten everything except the batch dimension before averaging.
    return per_element.reshape(x.size(0), -1).mean(1)

# Keys of the gates that enter the LBO loss, in stacking order.
_GATE_KEYS = ('g_recon', 'g_kl_low', 'g_kl_high', 'g_ink', 'g_bg')


def bottleneck_gates(x, x_hat, mu, logvar, cfg):
    """Compute the soft gates of the LBO bottleneck for one batch.

    Every gate has the form ``exp(-penalty)`` with ``penalty >= 0``: it is 1
    while its constraint holds and decays toward 0 as the violation grows.

    * ``g_recon``   - reconstruction loss ``D`` above ``cfg['D_ok']``
    * ``g_kl_low``  - KL ``K`` below ``cfg['K_min']``
    * ``g_kl_high`` - KL ``K`` above ``cfg['K_max']``
    * ``g_ink``     - absolute difference between the summed intensities of
      ``x_hat`` and ``x``
    * ``g_bg``      - mean of ``x_hat`` over dims 1-3 above ``cfg['B_max']``

    Each violation is divided by its ``tau_*`` temperature from ``cfg``.

    Args:
        x: Input batch (4-D; dim 0 is the batch dimension).
        x_hat: Reconstruction with the same shape as ``x``.
        mu: Posterior means.
        logvar: Posterior log-variances.
        cfg: Mapping holding ``decoder_likelihood``, the thresholds and the
            temperatures listed above.

    Returns:
        Dict of per-sample tensors: the five gates, ``'D'`` (reconstruction
        loss), ``'K'`` (KL), and ``'log_gates'``, a dict with the logarithm of
        every gate. ``lbo_loss`` uses the log values so a gate that underflows
        to 0 in floating point does not turn the loss into ``inf``.
    """
    D = recon_loss(x, x_hat, cfg['decoder_likelihood'])
    K = kl_divergence(mu, logvar)

    ink_x = x.view(x.size(0), -1).sum(1)
    ink_hat = x_hat.view(x_hat.size(0), -1).sum(1)
    bg = x_hat.mean(dim=[1, 2, 3])

    # log(gate) = -penalty; kept explicitly so the loss never has to take the
    # logarithm of an exponential that may have underflowed.
    log_gates = {
        'g_recon': -(torch.relu(D - cfg['D_ok']) / cfg['tau_D']),
        'g_kl_low': -(torch.relu(cfg['K_min'] - K) / cfg['tau_K_low']),
        'g_kl_high': -(torch.relu(K - cfg['K_max']) / cfg['tau_K_high']),
        'g_ink': -(torch.abs(ink_hat - ink_x) / cfg['tau_M']),
        'g_bg': -(torch.relu(bg - cfg['B_max']) / cfg['tau_B']),
    }

    gates = {key: torch.exp(log_gates[key]) for key in _GATE_KEYS}
    gates['D'] = D
    gates['K'] = K
    gates['log_gates'] = log_gates
    return gates


def lbo_loss(gates):
    """Negative log of the weakest gate per sample, averaged over the batch.

    Mathematically this is ``mean(-log(min_i g_i))``. When ``gates`` carries a
    ``'log_gates'`` entry the value is computed as ``mean(-min_i log g_i)``,
    which is the same quantity but stays finite, with finite gradients, even
    when a gate underflows to 0 (in float32 that happens once a penalty
    exceeds roughly 103).

    Args:
        gates: Dict with the five gate tensors and, optionally, ``'log_gates'``
            as returned by ``bottleneck_gates``.

    Returns:
        Scalar loss tensor.
    """
    log_gates = gates.get('log_gates')
    if log_gates is not None:
        log_stack = torch.stack([log_gates[key] for key in _GATE_KEYS], dim=0)
        return -log_stack.min(dim=0).values.mean()

    # Fallback for dicts that only hold the exponentiated gates: floor the
    # minimum at the smallest normal float so log() cannot return -inf.
    gate_stack = torch.stack([gates[key] for key in _GATE_KEYS], dim=0)
    m = gate_stack.min(dim=0).values
    m = m.clamp_min(torch.finfo(m.dtype).tiny)
    return -torch.log(m).mean()


## Training

In [None]:
def train_epoch(model, loader, optimizer, cfg):
    """Train ``model`` for one pass over ``loader`` and return the mean LBO loss.

    The returned value is weighted by batch size, so a smaller final batch
    does not count as much as a full one.

    Args:
        model: Module whose forward returns ``(x_hat, mu, logvar)``.
        loader: Iterable yielding ``(x, label)`` batches; labels are ignored.
        optimizer: Optimizer over ``model``'s parameters.
        cfg: Config mapping; ``log_every`` sets the print interval in steps and
            the whole mapping is forwarded to ``bottleneck_gates``.

    Returns:
        Per-sample average of the loss over the epoch as a Python float
        (``nan`` if ``loader`` yields no batches).
    """
    model.train()
    loss_sum = 0.0
    sample_count = 0
    for step, (x, _) in enumerate(tqdm(loader, desc='train'), start=1):
        x = x.to(device)
        optimizer.zero_grad()
        x_hat, mu, logvar = model(x)
        gates = bottleneck_gates(x, x_hat, mu, logvar, cfg)
        loss = lbo_loss(gates)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Read the scalar once; every .item() call forces a device sync.
        loss_value = loss.item()
        batch_size = x.size(0)
        loss_sum += loss_value * batch_size
        sample_count += batch_size

        if step % cfg['log_every'] == 0:
            print(
                f"step={step} loss={loss_value:.4f} "
                f"D={gates['D'].mean().item():.4f} K={gates['K'].mean().item():.4f}"
            )
    return loss_sum / sample_count if sample_count else float('nan')

@torch.no_grad()
def evaluate(model, loader, cfg):
    """Return the mean LBO loss of ``model`` over ``loader`` without updating it.

    The average is weighted by batch size, so a smaller final batch does not
    bias the metric. The model's forward pass still samples the latent, so
    repeated calls can give slightly different values.

    Args:
        model: Module whose forward returns ``(x_hat, mu, logvar)``.
        loader: Iterable yielding ``(x, label)`` batches; labels are ignored.
        cfg: Config mapping forwarded to ``bottleneck_gates``.

    Returns:
        Per-sample average loss as a Python float (``nan`` if ``loader``
        yields no batches).
    """
    model.eval()
    loss_sum = 0.0
    sample_count = 0
    for x, _ in loader:
        x = x.to(device)
        x_hat, mu, logvar = model(x)
        gates = bottleneck_gates(x, x_hat, mu, logvar, cfg)
        batch_size = x.size(0)
        loss_sum += lbo_loss(gates).item() * batch_size
        sample_count += batch_size
    return loss_sum / sample_count if sample_count else float('nan')

@torch.no_grad()
def save_samples(model, loader, cfg, epoch):
    """Write a PNG comparing inputs with their reconstructions.

    Takes the first batch from ``loader``, keeps its first
    ``cfg['num_samples']`` images and saves a two-row grid (inputs on top,
    reconstructions below) to ``cfg['sample_dir']/epoch_XXX.png``.
    """
    model.eval()
    out_dir = cfg['sample_dir']
    os.makedirs(out_dir, exist_ok=True)

    batch, _ = next(iter(loader))
    batch = batch.to(device)
    recon, _mu, _logvar = model(batch)

    count = cfg['num_samples']
    # With nrow=count the first `count` tiles (inputs) form the top row.
    tiles = torch.cat([batch[:count], recon[:count]], dim=0)
    grid = make_grid(tiles, nrow=count, pad_value=1.0)
    out_path = os.path.join(out_dir, f"epoch_{epoch:03d}.png")
    save_image(grid, out_path)


## Run

In [None]:
# Seed PyTorch's global RNG before anything random happens, so that weight
# initialisation, DataLoader shuffling and the reparameterization noise repeat
# across "Restart & Run All" (up to any nondeterministic GPU kernels).
# Set config['seed'] to change it; 0 is used when the key is absent.
torch.manual_seed(config.get('seed', 0))

model = BHiVAE(latent_dim=config['latent_dim']).to(device)
optimizer = Muon(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

for epoch in range(1, config['epochs'] + 1):
    train_loss = train_epoch(model, train_loader, optimizer, config)
    test_loss = evaluate(model, test_loader, config)
    # Dump an input/reconstruction grid every `sample_every` epochs.
    if epoch % config['sample_every'] == 0:
        save_samples(model, test_loader, config, epoch)
    print(f'Epoch {epoch}: train LBO {train_loss:.4f} | test LBO {test_loss:.4f}')
