In [15]:
import numpy as np, warnings, pandas as pd
from numpy.linalg import solve, svd, LinAlgError
from scipy.stats  import ttest_rel, wilcoxon
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ---------- helpers ----------
def safe_inv(A, lam=1e-6):
    ridge = max(lam*np.linalg.norm(A,'fro')/A.shape[0], 1e-8)
    try:    return solve(A+ridge*np.eye(A.shape[0]), np.eye(A.shape[0]))
    except LinAlgError: return np.linalg.pinv(A+ridge*np.eye(A.shape[0]))

def safe_solve(A,b,lam=1e-6):
    ridge = max(lam*np.linalg.norm(A,'fro')/A.shape[0], 1e-8)
    try:    return solve(A+ridge*np.eye(A.shape[0]), b)
    except LinAlgError: return np.linalg.lstsq(A+ridge*np.eye(A.shape[0]), b, rcond=None)[0]

def pca_init(R,k):
    _,s,Vt = svd(R, full_matrices=False)
    r = max(1, min(k, (s>1e-8).sum()))
    F = Vt.T[:,:r]*np.sqrt(s[:r])
    return np.pad(F, ((0,0),(0,k-r)))        # (K,k)

# ---------- synthetic data (harder) ----------
def simulate(N,K,p,k,seed):
    rng  = np.random.default_rng(seed)
    base = rng.standard_normal((N,p))
    Xs   = [base + 0.50*rng.standard_normal((N,p)) for _ in range(K)]
    B    = rng.standard_normal((p,K))
    F0   = 1.2*rng.standard_normal((K,k))
    D0   = 0.02 + 0.15*rng.random(K)          # smaller idio noise
    U    = rng.standard_normal((N,k))
    Y    = np.column_stack([Xs[i]@B[:,[i]] for i in range(K)]) \
           + U@F0.T + rng.standard_normal((N,K))*np.sqrt(D0)
    return Xs,Y,B

def gram(Xs):
    p,K = Xs[0].shape[1], len(Xs)
    out = np.empty((K,K,p,p))
    for j in range(K):
        Xt = Xs[j].T
        for l in range(K):
            out[j,l] = Xt @ Xs[l]
    return out

# ---------- EM (dense) ----------
def em_sur(Xs,Y,k,G,lam=1e-4,eps=1e-6,iters=45):
    K,p,N = len(Xs), Xs[0].shape[1], Y.shape[0]
    β = np.column_stack([safe_solve(G[j,j], Xs[j].T@Y[:,j]) for j in range(K)])
    XB=lambda B: np.column_stack([Xs[j]@B[:,[j]] for j in range(K)])
    R=Y-XB(β); F=pca_init(R,k); D=np.var(R,0)+eps
    for _ in range(iters):
        Dinv=np.diag(1/D)
        Cf=safe_inv(np.eye(k)+F.T@Dinv@F)
        EZ=(Y-XB(β))@Dinv@F@Cf
        F=(Y-XB(β)).T@EZ@safe_inv(EZ.T@EZ+N*Cf+lam*np.eye(k))
        R=Y-XB(β)-EZ@F.T; D=np.mean(R**2,0)+eps
        Dinv=np.diag(1/D); Cf=safe_inv(np.eye(k)+F.T@Dinv@F)
        Σinv=Dinv-Dinv@F@Cf@F.T@Dinv                  # biggest array
        A  = np.zeros((p*K,p*K)); rhs=np.zeros(p*K)
        for j in range(K):
            Sj=Σinv[:,j]
            for l in range(K):
                A[j*p:(j+1)*p,l*p:(l+1)*p]=Sj[l]*G[j,l]
            rhs[j*p:(j+1)*p]=Xs[j].T@(Y@Sj)
        β=safe_solve(A+lam*np.eye(p*K),rhs).reshape(K,p).T
    Σ = F@F.T + np.diag(D)
    mem_mb = Σinv.nbytes/1e6          # deterministic “peak”
    return β,Σ,mem_mb

# ---------- ALS (memory-light) ----------
def als_sur(Xs,Y,k,G,lam=1e-4,eps=1e-6,sweeps=6):
    K,p,N = len(Xs), Xs[0].shape[1], Y.shape[0]
    β = np.column_stack([safe_solve(G[j,j], Xs[j].T@Y[:,j]) for j in range(K)])
    XB=lambda B: np.column_stack([Xs[j]@B[:,[j]] for j in range(K)])
    R=Y-XB(β); F=pca_init(R,k); D=np.var(R,0)+eps
    for _ in range(sweeps):
        Dinv=np.diag(1/D)
        Cf=safe_inv(np.eye(k)+F.T@Dinv@F)
        Σinv=Dinv-Dinv@F@Cf@F.T@Dinv
        A  = np.zeros((p*K,p*K)); rhs=np.zeros(p*K)
        for j in range(K):
            Sj=Σinv[:,j]
            for l in range(K):
                A[j*p:(j+1)*p,l*p:(l+1)*p]=Sj[l]*G[j,l]
            rhs[j*p:(j+1)*p]=Xs[j].T@(Y@Sj)
        β=safe_solve(A+lam*np.eye(p*K),rhs).reshape(K,p).T
        R=Y-XB(β)
        U=R@F@safe_inv(F.T@F+lam*np.eye(k))
        F=R.T@U@safe_inv(U.T@U+lam*np.eye(k))
        D=np.mean((R-U@F.T)**2,0)+eps
    Σ = F@F.T + np.diag(D)
    mem_mb = (K*k + K)*8/1e6          # size of F and D (upper-bound)
    return β,Σ,mem_mb

# ---------- benchmark ----------
def run(Ks=(50,80,120), reps=6, N=220, p=3, k=3):
    rows=[]
    for K in Ks:
        for r in range(reps):
            X,Y,B = simulate(N,K,p,k,seed=33+K*7+r); G=gram(X)
            βe,Σe,mem_em   = em_sur(X,Y,k,G)
            βa,Σa,mem_als  = als_sur(X,Y,k,G)
            rows.append(dict(K=K,rep=r,
                             RMSE_EM =np.sqrt(np.mean((βe-B)**2)),
                             RMSE_ALS=np.sqrt(np.mean((βa-B)**2)),
                             MEM_EM  =mem_em,
                             MEM_ALS =mem_als))
    return pd.DataFrame(rows)

# ---------------- run ----------------
if __name__=="__main__":
    df=run()
    tbl=df.groupby("K")[['RMSE_EM','RMSE_ALS','MEM_EM','MEM_ALS']].mean().round(3)
    print("\n=== mean RMSE & estimated peak MB ===\n", tbl)
    for K,g in df.groupby("K"):
        t,p=ttest_rel(g.RMSE_EM,g.RMSE_ALS); w,pw=wilcoxon(g.RMSE_EM,g.RMSE_ALS)
        print(f"K={K}: ΔRMSE={g.RMSE_EM.mean()-g.RMSE_ALS.mean():+.3f} "
              f"| p_t={p:.3g}, p_w={pw:.3g}, mem ratio ≈ {(g.MEM_EM.mean()/g.MEM_ALS.mean()):.1f}×")



=== mean RMSE & estimated peak MB ===
      RMSE_EM  RMSE_ALS  MEM_EM  MEM_ALS
K                                      
50     0.021     0.021   0.020    0.002
80     0.020     0.020   0.051    0.003
120    0.020     0.020   0.115    0.004
K=50: ΔRMSE=+0.000 | p_t=0.862, p_w=1, mem ratio ≈ 12.5×
K=80: ΔRMSE=-0.000 | p_t=0.145, p_w=0.156, mem ratio ≈ 20.0×
K=120: ΔRMSE=+0.000 | p_t=0.665, p_w=0.688, mem ratio ≈ 30.0×
