In [1]:

import math, urllib.request, tarfile
from pathlib import Path
import torch, numpy as np, matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from tqdm import trange

# ---------- settings ----------
torch.set_default_dtype(torch.float64)
DEV = "cuda" if torch.cuda.is_available() else "cpu"

LR_X, LR_HYP = 5e-3, 1e-4
BATCH, T_TOTAL, INNER0, INNER = 128, 100, 20, 5
JITTER, MAX_EXP, CLIP_QUAD = 1e-6, 60.0, 1e6
GRAD_CLIP_X = 10.0
BOUNDS = {"log_sf2": (-5., 6.), "log_alpha": (-5., 5.),
          "log_beta": (-5., 2.), "log_s2x": (-5., 5.)}

# ---------- helpers ----------
def rho(t, t0=200., k=0.7): return (t0 + t) ** (-k)
def safe_exp(x): return torch.exp(torch.clamp(x, max=MAX_EXP))
def chol_or_eig(M, eps=1e-6):
    I = torch.eye(M.size(-1), device=M.device, dtype=M.dtype)
    try: return torch.linalg.cholesky(M + eps * I)
    except RuntimeError:
        e, v = torch.linalg.eigh(M)
        e = torch.clamp(e, min=eps)
        return v @ torch.diag(torch.sqrt(e))

def check_nan(name, T, outer, inner):
    if not torch.isfinite(T).all():
        print(f"[NaN] {name} step={outer:03d}/{inner:02d} "
              f"min={T.min().item():+.3e} "
              f"max={T.max().item():+.3e} "
              f"mean={T.mean().item():+.3e}", flush=True)

# ---------- data ----------
root = Path("oil_data")
root.mkdir(exist_ok=True)
URL = "http://staffwww.dcs.shef.ac.uk/people/N.Lawrence/resources/3PhData.tar.gz"
arc = root / "3PhData.tar.gz"
if not arc.exists():
    urllib.request.urlretrieve(URL, arc)
with tarfile.open(arc) as tar:
    tar.extract("DataTrn.txt", path=root)
    tar.extract("DataTrnLbls.txt", path=root)
Y_np = np.loadtxt(root / "DataTrn.txt")
lbl_np = np.loadtxt(root / "DataTrnLbls.txt").astype(int)
Y = torch.tensor(Y_np, device=DEV)
lbl = torch.tensor(lbl_np, device=DEV)
N, D = Y.shape
Q = 2

# ---------- latent q(x_n) ----------
mu_x = torch.tensor(PCA(Q).fit_transform(Y_np), device=DEV, requires_grad=True)
log_s2x = torch.full_like(mu_x, -2.0, requires_grad=True)

# ---------- kernel / inducing ----------
sqrtM = 8
grid = torch.linspace(-1.5, 1.5, sqrtM, device=DEV)
Z = torch.stack(torch.meshgrid(grid, grid, indexing="ij"), -1).reshape(-1, Q)
M = Z.size(0)
log_sf2 = torch.tensor(0.0, device=DEV, requires_grad=True)
log_alpha = torch.zeros(Q, device=DEV, requires_grad=True)
log_beta_inv = torch.tensor(-3.2, device=DEV, requires_grad=True)

def k_se(x, z, lsf, la):
    sf2, a = safe_exp(lsf), safe_exp(la)
    return sf2 * safe_exp(-0.5 * ((x.unsqueeze(-2) - z)**2 * a).sum(-1))

def noise_var(): return safe_exp(log_beta_inv)

def whitening():
    K = k_se(Z, Z,
             log_sf2.clamp(*BOUNDS["log_sf2"]),
             log_alpha.clamp(*BOUNDS["log_alpha"])) \
        + JITTER * torch.eye(M, device=DEV)
    L = chol_or_eig(K)
    return L, torch.linalg.inv(L)

L, Linv = whitening()

# ---------- global q(U) ----------
m_u = torch.zeros(D, M, device=DEV)
C_u = torch.eye(M, device=DEV).expand(D, M, M).clone()

def Sigma(): return C_u @ C_u.transpose(-1, -2)
def sample_U():
    eps = torch.randn(D, M, device=DEV)
    return m_u + (C_u @ eps.unsqueeze(-1)).squeeze(-1)

def nat_from_moments():
    S = Sigma()
    Lmb = -0.5 * torch.linalg.inv(S)
    h = (-2 * Lmb @ m_u.unsqueeze(-1)).squeeze(-1)
    return h, Lmb

def set_from_natural(h_new, Lmb_new, eps=1e-8):
    global m_u, C_u
    for d in range(D):
        Ld = 0.5 * (Lmb_new[d] + Lmb_new[d].T)
        eigv, eigvec = torch.linalg.eigh(Ld)
        eigv = torch.minimum(eigv, torch.tensor(-eps, device=DEV))
        Sd = torch.linalg.inv((eigvec * (-2 * eigv)) @ eigvec.T)
        C_u[d] = chol_or_eig(Sd, eps)
        m_u[d] = Sd @ h_new[d]

Lmb_prior = -0.5 * torch.eye(M, device=DEV).expand(D, M, M)

# ---------- psi statistics ----------
def psi_whiten(mu, s2, outer, inner):
    sf2 = safe_exp(log_sf2.clamp(*BOUNDS["log_sf2"]))
    a = safe_exp(log_alpha.clamp(*BOUNDS["log_alpha"]))
    d1 = a * s2 + 1.0
    c1 = d1.rsqrt().prod(-1, keepdim=True)
    diff = mu.unsqueeze(1) - Z
    psi1 = sf2 * c1 * safe_exp(-0.5 * ((a * diff**2) / d1.unsqueeze(1)).sum(-1))
    d2 = a * s2 + 2.0
    c2 = d2.rsqrt().prod(-1, keepdim=True)
    ZZ = Z.unsqueeze(1) - Z.unsqueeze(0)
    dist = (a * ZZ**2).sum(-1)
    mid = (Z.unsqueeze(1) + Z.unsqueeze(0)) / 2.0
    mu_c = (mu[:, None, None, :] - mid)**2
    expo = -0.25 * dist - 0.5 * ((a * mu_c) / d2[:, None, None, :]).sum(-1)
    psi2 = sf2**2 * c2.unsqueeze(-1) * safe_exp(expo)
    psi1_w = (Linv @ psi1.T).T
    tmp = psi2 @ Linv.T
    psi2_w = (Linv @ tmp.transpose(1, 2)).transpose(1, 2)
    check_nan("psi1_w", psi1_w, outer, inner)
    check_nan("psi2_w", psi2_w, outer, inner)
    return sf2.expand(mu.size(0)), psi1_w, psi2_w

# ---------- local ELBO ----------
def local(idx, U_s, S_det, train_beta, outer, inner):
    mu, s2 = mu_x[idx], log_s2x[idx].exp()
    psi0, psi1_w, psi2_w = psi_whiten(mu, s2, outer, inner)
    fmu = psi1_w @ U_s.T
    var_f = torch.stack([((psi1_w @ S_det[d]) * psi1_w).sum(-1) for d in range(D)], 1)
    base = noise_var() if train_beta else noise_var().detach()
    trace_psi2 = psi2_w.diagonal(dim1=-2, dim2=-1).sum(-1)
    sigma2 = torch.clamp(base + psi0 - trace_psi2, 1e-6, 1e3)
    check_nan("sigma2", sigma2, outer, inner)
    quad = ((Y[idx] - fmu)**2 + var_f) / sigma2[:, None]
    quad = torch.clamp(quad, max=CLIP_QUAD)
    ll = (-0.5 * math.log(2 * math.pi) - 0.5 * sigma2.log()[:, None] - 0.5 * quad).sum(-1)
    klx = 0.5 * ((s2 + mu**2) - s2.log() - 1).sum(-1)
    elbo = (ll - klx).mean()
    check_nan("local_elbo", elbo, outer, inner)
    r = (Y[idx] / sigma2[:, None])[:, :, None] * psi1_w[:, None, :]
    Q = -0.5 / sigma2[:, None, None, None] * psi2_w[:, None, :, :]
    return elbo, r.detach(), Q.detach()

# ---------- optimizers ----------
opt_x = torch.optim.Adam([mu_x, log_s2x], lr=LR_X)
opt_hyp = torch.optim.Adam([log_sf2, log_alpha, log_beta_inv], lr=LR_HYP)

# ---------- training ----------
its, elbos = [], []
bar = trange(1, T_TOTAL + 1, ncols=110)
for t in bar:
    S_det = Sigma().detach()
    idx = torch.randint(0, N, (BATCH,), device=DEV)

    inner_iters = INNER0 if t <= 50 else INNER
    for k in range(inner_iters):
        opt_x.zero_grad(set_to_none=True)
        elbx, _, _ = local(idx, sample_U().detach(), S_det, train_beta=False, outer=t, inner=k)
        (-elbx).backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_([mu_x, log_s2x], GRAD_CLIP_X)
        opt_x.step()
        with torch.no_grad():
            log_s2x.clamp_(*BOUNDS["log_s2x"])

    U_s = sample_U()
    elbo, r_b, Q_b = local(idx, U_s, Sigma(), train_beta=True, outer=t, inner=-1)
    opt_hyp.zero_grad(set_to_none=True)
    (-elbo).backward()
    opt_hyp.step()
    with torch.no_grad():
        log_sf2.clamp_(*BOUNDS["log_sf2"])
        log_alpha.clamp_(*BOUNDS["log_alpha"])
        log_beta_inv.clamp_(*BOUNDS["log_beta"])
        L, Linv = whitening()

    with torch.no_grad():
        h_nat, Lmb_nat = nat_from_moments()
        r_sum, Q_sum = r_b.sum(0), Q_b.sum(0)
        r_tilde = r_sum + 2 * (Q_sum @ (U_s - m_u).unsqueeze(-1)).squeeze(-1)
        scale, lr = N / len(idx), rho(t)
        h_new = (1 - lr) * h_nat + lr * scale * r_tilde
        Lmb_new = (1 - lr) * Lmb_nat + lr * (Lmb_prior + scale * Q_sum)
        set_from_natural(h_new, Lmb_new)

    if t % 25 == 0 or t == 1:
        full_elbo, _, _ = local(torch.arange(N, device=DEV), sample_U(), Sigma(),
                                train_beta=False, outer=t, inner=99)
        its.append(t)
        elbos.append(full_elbo.item())
        bar.set_postfix(ELBO=f"{full_elbo.item():.2e}")

# ---------- plots ----------
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(its, elbos, "-o")
plt.grid(ls=":")
plt.xlabel("iteration")
plt.ylabel("ELBO")
plt.title("ELBO trajectory")

plt.subplot(1, 2, 2)
plt.scatter(mu_x.detach().cpu()[:, 0], mu_x.detach().cpu()[:, 1], c=lbl.cpu(), cmap="brg", s=12)
plt.gca().set_aspect("equal")
plt.title("latent space")
plt.xlabel("mu_1")
plt.ylabel("mu_2")
plt.tight_layout()
plt.show()


  2%|█▏                                                       | 2/100 [00:06<05:13,  3.20s/it, ELBO=-6.00e+06]


KeyboardInterrupt: 