# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import expon, multivariate_normal
from sympy import nextprime

# Set the dimensions of the problem.
# nextprime() returns a SymPy Integer; cast to a built-in int so quantities
# derived from these sizes (e.g. nseVar) stay ordinary floats. Otherwise
# NumPy arithmetic with a SymPy number can produce object-dtype arrays on
# which ufuncs such as np.sqrt fail.
n = int(nextprime(20))    # size of observations
d = int(nextprime(3))     # number of latent dimensions
T = int(nextprime(1000))  # number of trials

# Specify SNR and noise variance
SNR = 0.5
nseVar = d ** 2 / SNR

# Generate data from the factor-analysis model
#   y_t = W z_t + eps_t,   eps_t[i] ~ N(0, 1 / lambda_[i])
W = np.random.randn(n, d)  # weight matrix
z = np.random.randn(d, T)  # latents
lambda_ = expon.rvs(scale=1, size=(n,)) / nseVar  # noise precisions
# Scale row i of the white noise by its standard deviation 1/sqrt(lambda_[i])
# (same result as left-multiplying by diag(sqrt(1/lambda_))).
Y = W @ z + np.sqrt(1.0 / lambda_)[:, None] * np.random.randn(n, T)

# Do EM (with help from Bishop)
# Convention: lambOld / lambNew hold noise *precisions* (1 / variance), the
# same parameterisation as lambda_ in the generative model, so that
# Psi^{-1} = diag(lamb) and Psi = diag(1 / lamb).
# initialize
I = 1200  # number of iterations
Wn = np.random.randn(n, d)  # Initialize new W
# Draw initial noise variances, then store them as precisions.
# float() keeps the result a float64 array even if nseVar is a non-NumPy
# scalar (e.g. a SymPy number).
lambNew = 1.0 / (float(nseVar) * expon.rvs(scale=1, size=(n,)))
S = Y @ Y.T / T  # raw data sample covariance
ll = np.empty(I)  # to store log-likelihoods

for ii in range(I):
    Wo = Wn
    lambOld = lambNew

    # E-step
    # Posterior over latents: cov G = (I + W' Psi^{-1} W)^{-1},
    # mean E[z_t] = G W' Psi^{-1} y_t.  invPsiOld already *is* Psi^{-1},
    # so it must be used directly (not inverted again).
    invPsiOld = np.diag(lambOld)
    invG = np.eye(d) + Wo.T @ invPsiOld @ Wo
    Ez = Wo.T @ invPsiOld @ Y
    Ez = np.linalg.solve(invG, Ez)
    # sum_t E[z_t z_t'] = T * G + sum_t E[z_t] E[z_t]'
    sumEzz = T * np.linalg.inv(invG) + Ez @ Ez.T

    # M-step
    # W_new = (sum_t y_t E[z_t]') (sum_t E[z_t z_t'])^{-1}; sumEzz is
    # symmetric, so solve against the transpose and transpose back.
    Wn = Y @ Ez.T
    Wn = np.linalg.solve(sumEzz, Wn.T).T
    # Psi_new = diag(S - W_new (1/T) sum_t E[z_t] y_t'); stored as precision.
    lambNew = 1.0 / np.diag(S - Wn @ Ez @ Y.T / T)

    # document progress: marginal log-likelihood under y ~ N(0, W W' + Psi)
    cov_matrix = Wn @ Wn.T + np.diag(1.0 / lambNew)
    ll[ii] = np.sum(multivariate_normal.logpdf(Y.T, mean=np.zeros(n), cov=cov_matrix))

# Compare the fitted parameters with the ground truth. W itself is only
# identifiable up to rotation, so the low-rank covariance W W^T is compared.
WW = W @ W.T
WnWn = Wn @ Wn.T
pltclr = plt.cm.tab10(0)

fig, (ax_lam, ax_ww) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: true vs. estimated noise precisions, log-log with identity line
lam_range = [min(lambda_), max(lambda_)]
ax_lam.loglog(lam_range, lam_range, 'k', linewidth=2)
ax_lam.scatter(lambda_, lambNew, edgecolor=pltclr, facecolor=pltclr, alpha=0.5)

# Right panel: true vs. estimated entries of W W^T with identity line
ww_range = [WW.min(), WW.max()]
ax_ww.plot(ww_range, ww_range, 'k', linewidth=2)
ax_ww.scatter(WW.flatten(), WnWn.flatten(), edgecolor=pltclr, facecolor=pltclr, alpha=0.2)

# Shared cosmetics for both panels
for ax, title in ((ax_lam, r'$\lambda$'), (ax_ww, r'$W W^T$')):
    ax.set_title(title)
    ax.set_xlabel('True')
    ax.set_ylabel('Estimated')
    ax.axis('equal')
    ax.grid(False)

plt.show()

# Verify marginal log-likelihood is increasing every step.
# Only the iteration axis is log-scaled: the summed log-likelihood is
# typically negative, and a log-scaled y-axis cannot display non-positive
# values (the curve would not be drawn).
plt.figure()
plt.semilogx(range(1, I + 1), ll, '-o', linewidth=2)
plt.ylabel('Marginal Log-likelihood')
plt.xlabel('Iteration')
plt.grid(False)
plt.show()
