In [63]:
import math, tarfile, urllib.request, numpy as np, matplotlib.pyplot as plt, torch
from pathlib import Path; from tqdm import trange; from sklearn.decomposition import PCA
torch.set_default_dtype(torch.float64)

DEV, DEBUG = ("cuda" if torch.cuda.is_available() else "cpu"), True
LR_X, LR_HYP  = 5e-3, 1e-4
BATCH, T_TOTAL, INNER0, INNER = 128, 100, 20, 5
JITTER, MAX_EXP, CLIPQ, GR_CLIP = 1e-6, 60., 1e6, 10.
BOUNDS = {"log_sf2": (-5., 6.), "log_alpha": (-5., 5.),
          "log_beta": (-5., 2.), "log_s2x": (-5., 5.)}
rho = lambda t, t0=200., k=.7: (t0+t)**(-k)
safe_exp = lambda x: torch.exp(torch.clamp(x,max=MAX_EXP))
def chol_or_eig(M,eps=1e-6):
    I=torch.eye(M.size(-1),device=M.device,dtype=M.dtype)
    try: return torch.linalg.cholesky(M+eps*I)
    except RuntimeError:
        e,v=torch.linalg.eigh(M); e=torch.clamp(e,min=eps)
        return v@torch.diag(torch.sqrt(e))

# ------------------------------ data --------------------------------
root=Path("oil_data"); root.mkdir(exist_ok=True)
url ="http://staffwww.dcs.shef.ac.uk/people/N.Lawrence/resources/3PhData.tar.gz"
arc =root/"3PhData.tar.gz"
if not arc.exists(): urllib.request.urlretrieve(url, arc)
with tarfile.open(arc) as tar:
    tar.extract("DataTrn.txt",    path=root)
    tar.extract("DataTrnLbls.txt",path=root)
Y_np = np.loadtxt(root/"DataTrn.txt")
lbl_np=np.loadtxt(root/"DataTrnLbls.txt").astype(int)
Y   = torch.tensor(Y_np,   device=DEV)
lbl = torch.tensor(lbl_np, device=DEV)
N,D = Y.shape;  Q=2

# --------------------------- latent X -------------------------------
mu_x    = torch.tensor(PCA(Q).fit_transform(Y_np), device=DEV, requires_grad=True)
log_s2x = torch.full_like(mu_x,-2.0,requires_grad=True)

# ---------------------- kernel & inducing ---------------------------
sqrtM=8; grid=torch.linspace(-1.5,1.5,sqrtM,device=DEV)
Z=torch.stack(torch.meshgrid(grid,grid,indexing="ij"),-1).reshape(-1,Q); M=Z.size(0)
log_sf2      = torch.tensor(0.,device=DEV,requires_grad=True)
log_alpha    = torch.zeros(Q,device=DEV,requires_grad=True)
log_beta_inv = torch.tensor(-3.2,device=DEV,requires_grad=True)

def k_se(x,z,lsf,la):
    sf2,a=safe_exp(lsf),safe_exp(la)
    return sf2*safe_exp(-.5*((x[...,None,:]-z)**2*a).sum(-1))
noise_var=lambda : safe_exp(log_beta_inv)

def upd_Kinv():
    K=k_se(Z,Z,
           log_sf2.clamp(*BOUNDS["log_sf2"]),
           log_alpha.clamp(*BOUNDS["log_alpha"])) \
      +JITTER*torch.eye(M,device=DEV)
    L=chol_or_eig(K); return K,torch.cholesky_inverse(L)
K_MM,Kinv=upd_Kinv()

# ------------------------- q(U)  (block) ----------------------------
m_u = torch.zeros(D,M,device=DEV)
C_u = torch.eye(M,device=DEV).expand(D,M,M).clone()
Sigma = lambda : C_u@C_u.transpose(-1,-2)
sample_U = lambda : m_u + (C_u@torch.randn(D,M,device=DEV)[...,None]).squeeze(-1)

def nat_from_mom():
    Lam=-.5*torch.linalg.inv(Sigma()); h=(-2*Lam@m_u[...,None]).squeeze(-1)
    return h,Lam
def set_from_nat(h_new,Lam_new,eps=1e-8):
    global m_u,C_u
    for d in range(D):
        Lam=.5*(Lam_new[d]+Lam_new[d].T)
        eigv,eigvec=torch.linalg.eigh(Lam); eigv=torch.minimum(eigv,-eps)
        Sd=torch.linalg.inv((eigvec*(-2*eigv))@eigvec.T)
        C_u[d]=chol_or_eig(Sd,eps); m_u[d]=Sd@h_new[d]
Lmb_prior=(-.5*Kinv).expand(D,M,M).clone()

# ------------------------- psi-statistics ---------------------------
def psi_stats(mu,s2):
    sf2=safe_exp(log_sf2.clamp(*BOUNDS["log_sf2"]))
    a  =safe_exp(log_alpha.clamp(*BOUNDS["log_alpha"]))
    psi0=torch.full((mu.size(0),),sf2.item(),device=DEV)
    d1=a*s2+1.; c1=d1.rsqrt().prod(-1,keepdim=True)
    diff=mu[:,None,:]-Z; psi1=sf2*c1*safe_exp(-.5*((a*diff**2)/d1[:,None,:]).sum(-1))
    d2=a*s2+2.; c2=d2.rsqrt().prod(-1,keepdim=True)
    ZZ=Z[:,None,:]-Z[None,:,:]; dist=(a*ZZ**2).sum(-1)
    mid=.5*(Z[:,None,:]+Z[None,:,:])
    mc=(mu[:,None,None,:]-mid)**2
    expo=-.25*dist -.5*((a*mc)/d2[:,None,None,:]).sum(-1)
    psi2=sf2**2*c2[:,None,None]*safe_exp(expo)
    return psi0,psi1,psi2

# --------------------------- local ----------------------------------
def local(idx,U_s,Sigma_det,train_beta,dbg=False):
    mu,s2=mu_x[idx],log_s2x[idx].exp(); B=mu.size(0)
    psi0,psi1,psi2=psi_stats(mu,s2)
    A = psi1 @ Kinv                               # (B,M)

    if dbg and DEBUG:
        print("A",A.shape,"psi1",psi1.shape)

    fmu   = A @ U_s.T                             # (B,D)
    var_f = torch.stack([(A@Sigma_det[d]*A).sum(-1) for d in range(D)],1)

    base   = noise_var() if train_beta else noise_var().detach()
    trace  = (psi2*Kinv).sum((-2,-1))             # (B,)
    sigma2 = torch.clamp(base+psi0-trace,1e-6,1e3)# (B,)
    sigma2_unsq = sigma2[:,None]                  # (B,1)

    # ---------- r ----------------------------------------------------
    Yscaled = Y[idx]/sigma2_unsq                  # (B,D)
    r = (Yscaled[:,:,None]) * (A[:,None,:])       # (B,D,M)

    # ---------- Q ----------------------------------------------------
    Aexp   = A[:,None,:]                          # (B,1,M)
    outer  = Aexp.transpose(-1,-2)*Aexp           # (B,1,M,M)
    outer  = outer.expand(B,D,-1,-1)              # (B,D,M,M)
    Q = (-.5/sigma2_unsq)[:,:,None,None]*outer    # (B,D,M,M)

    quad=((Y[idx]-fmu)**2+var_f)/sigma2_unsq; quad=torch.clamp(quad,max=CLIPQ)
    ll  = (-.5*math.log(2*math.pi)-.5*sigma2.log()[:,None] -.5*quad).sum(-1)
    klx = .5*((s2+mu**2)-s2.log()-1.).sum(-1)
    return (ll-klx).mean(), r.detach(), Q.detach()

# --------------------------- optim ----------------------------------
opt_x   = torch.optim.Adam([mu_x,log_s2x],lr=LR_X)
opt_hyp = torch.optim.Adam([log_sf2,log_alpha,log_beta_inv],lr=LR_HYP)

its,elbos=[],[]
for t in trange(1,T_TOTAL+1,ncols=100):
    Sigma_det=Sigma().detach(); idx=torch.randint(0,N,(BATCH,),device=DEV)

    for _ in range(INNER0 if t<=50 else INNER):
        opt_x.zero_grad(set_to_none=True)
        elbx,_,_=local(idx,sample_U().detach(),Sigma_det,False,dbg=(t==1))
        (-elbx).backward(retain_graph=True)
        torch.nn.utils.clip_grad_norm_([mu_x,log_s2x],GR_CLIP); opt_x.step()
        with torch.no_grad(): log_s2x.clamp_(*BOUNDS["log_s2x"])

    U_s=sample_U()
    elbo,r_b,Q_b=local(idx,U_s,Sigma(),True)
    opt_hyp.zero_grad(set_to_none=True); (-elbo).backward(); opt_hyp.step()

    with torch.no_grad():
        for p,b in ((log_sf2,"log_sf2"),(log_alpha,"log_alpha"),(log_beta_inv,"log_beta")):
            p.clamp_(*BOUNDS[b])
        K_MM,Kinv=upd_Kinv(); Lmb_prior.copy_((-0.5*Kinv).expand_as(Lmb_prior))
        h_nat,Lam_nat=nat_from_mom()
        r_sum,Q_sum = r_b.sum(0),Q_b.sum(0)
        r_tilde = r_sum + 2*(Q_sum @ (U_s-m_u)[...,None]).squeeze(-1)
        lr,scale=rho(t),N/idx.size(0)
        h_new =(1-lr)*h_nat  + lr*scale*r_tilde
        Lam_new=(1-lr)*Lam_nat+ lr*(Lmb_prior+scale*Q_sum)
        set_from_nat(h_new,Lam_new)

    if t%25==0 or t==1:
        full_elbo,_,_=local(torch.arange(N,device=DEV),sample_U(),Sigma(),False)
        its.append(t); elbos.append(full_elbo.item())
        print(f"\nELBO @ {t:3d}: {full_elbo.item():.4e}")

# --------------------------- plots ----------------------------------
plt.figure(figsize=(12,5))
plt.subplot(1,2,1); plt.plot(its,elbos,'-o'); plt.grid(ls=':')
plt.xlabel('iteration'); plt.ylabel('ELBO'); plt.title('ELBO trajectory')
plt.subplot(1,2,2)
plt.scatter(mu_x.detach().cpu()[:,0],mu_x.detach().cpu()[:,1],
            c=lbl.cpu(),cmap='brg',s=14)
plt.gca().set_aspect('equal'); plt.title('latent space')
plt.xlabel('mu_1'); plt.ylabel('mu_2'); plt.tight_layout(); plt.show()


  0%|                                                                       | 0/100 [00:00<?, ?it/s]

A torch.Size([128, 64]) psi1 torch.Size([128, 64])





RuntimeError: The size of tensor a (12) must match the size of tensor b (128) at non-singleton dimension 2