In [48]:
import numpy as  np
import matplotlib.pyplot as plt

from genimages import genimages

In [59]:
def FreeEnergy(X, sigma_sq, pie, alpha, lamda, mu, precisions, eps=1e-8):

    '''
    Free Energy wrt the variational distribution and the current estimate of 
    the parameters.

    Parameters
    ----------
    X: np.ndarray
        D x N data matrix
    sigma_sq: float
        noise variance
    pie: np.ndarray
        K x 1 vector of priors on the latents
    alpha: np.array
        K dimensional parameter vector of precisions of latents
    lamda: np.ndarray
        K x N matrix of parameter values characterizing the variational
        distribution
    mu: np.ndarray
        D x K matrix of means of latent features
    precisions: np.array
        K dimensional vector of precisions of posterior on latents
    eps: float
        correction for numerical stability of log

    Returns
    -------
    F: int
        variational free energy wrt the factored distribution and the current 
        estimate of the parameters 
    '''

    N, D = X.shape[1], X.shape[0]
    lamda_sum_n = np.sum(lamda, axis=1)
    ESS_lamda = lamda@lamda.T; np.fill_diagonal(ESS_lamda, lamda_sum_n)
    ESS_mu =  mu.T@mu; np.fill_diagonal(ESS_mu, np.diag(ESS_mu) + D/precisions)
    
    expectation_log_joint = -N*D/2 * np.log(2*np.pi*sigma_sq) + \
        -1/(2*sigma_sq) * (np.vdot(X, X) + np.vdot(ESS_lamda, ESS_mu) - 2*np.vdot(lamda, mu.T@X)) + \
        np.dot(lamda_sum_n, np.log(pie+eps)) + np.dot(N-lamda_sum_n, np.log(1-pie+eps)) + \
        D/2 * np.sum(np.log(2*np.pi*alpha)) - alpha@np.diag(ESS_mu)/2
    
    H_mBernoulli = - np.sum(lamda*np.log(lamda+eps) + (1-lamda)*np.log(1-lamda+eps)) 
    H_mGaussian = D/2 * np.sum(np.log(2*np.pi*np.e/precisions))
    
    print(expectation_log_joint, H_mBernoulli, H_mGaussian)
    F = expectation_log_joint +  H_mBernoulli + H_mGaussian

    return F

In [60]:
def MeanField(X, sigma_sq, pie, alpha, lamda, mu, precisions, maxsteps, eps=1e-1):

    '''
    Factored Variational E-step for learning with Binary Latent Factor Model.

    Parameters
    ----------
    X: np.ndarray
        D x N data matrix
    sigma_sq: float
        noise variance
    pie: np.ndarray
        K x 1 vector of priors on the latents
    alpha: np.array
        K dimensional parameter vector of precisions of latents
    lamda: np.ndarray
        K x N matrix of initial parameter values characterizing the variational
        distribution
    mu: np.ndarray
        D x K matrix of means of latent features
    precisions: np.array
        K dimensional vector of precisions of posterior on latents
    maxsteps: int
        maximum number of steps of the fixed point interations
    eps: float
        minimum improvement in free energy to continue iterations 

    Returns
    -------
    lamda: np.ndarray
        K x N matrix of updated parameter values characterizing the variational
        distribution
    mu: np.ndarray
        D x K matrix of means of latent features
    precisions: np.array
        precisions of distributions on latent features
    F: float
        Variational Free Energy
    '''

    prior_odds = (1-pie) / (pie+1e-8)
    F = float('-inf')

    for _ in range(maxsteps):
        for i in range(pie.size):
            mu_i, lamda_i = mu[:, [i]], lamda[[i]]
            exponent = mu_i.T @ (-2*X + mu_i + 2*(mu@lamda-mu_i@lamda_i)) / (2*sigma_sq)
            lamda[[i]] = 1 / (1 + prior_odds[i]*np.exp(exponent)); lamda_i = lamda[[i]]
            precisions[i] = precisions[i] + np.sum(lamda_i)/sigma_sq
            mu[:, [i]] = (X - (mu@lamda-mu_i@lamda_i)) @ lamda_i.T / (precisions[i]*sigma_sq)
        old_F, F = F, FreeEnergy(X, sigma_sq, pie, alpha, lamda, mu, precisions)
        if F-old_F < eps:
            break

    return lamda, mu, precisions, F

In [61]:
def MStep(X, lamda, mu, precisions):

    '''
    Maximisation step for learning with Binary Latent Factor Model.
    Parameters
    ----------
    X: np.ndarray
        D x N data matrix
    lamda: np.ndarray
        K x N matrix of parameter values characterizing the variational 
        distribution
    mu: np.ndarray
        D x K matrix of means of latent features
    precisions: np.array
        K dimensional vector of precisions of posterior on latents

    Returns
    -------
    sigma_sq: float
        updated estimate of noise variance
    pie: np.ndarray
        K x 1 vector of parameters characterising the distribution on 
        the latents
    alpha: np.array
        K dimensional parameter vector of precisions of latents
    '''

    D, N = X.shape
    ESS_lamda = lamda@lamda.T; np.fill_diagonal(ESS_lamda, np.sum(lamda, axis=1))
    ESS_mu =  mu.T@mu; np.fill_diagonal(ESS_mu, np.diag(ESS_mu) + D/precisions)
    
    sigma_sq = (np.vdot(X, X) + np.vdot(ESS_lamda, ESS_mu) - 2*np.vdot(lamda, mu.T@X)) / (N*D)
    pie = np.mean(lamda, axis=1)
    alpha = D / (D*precisions + np.diag(ESS_mu))

    return sigma_sq, pie, alpha

In [62]:
def LearnBinFactors(X, K, iterations):

    '''
    Factored Variational EM for learning with Binary Latent Factor Model.

    Parameters
    ----------
    X: np.ndarray
        D x N data matrix
    K: int
        number of latent binary factors
    iterations: int
        maximum number of iterations of EM

    Returns
    -------
    mu: np.ndarray
        D x K matrix of means: mu @ S = X
    sigma_sq: float
        estimate of noise variance
    pie: np.ndarray
        K x 1 vector of parameters characterising the distribution on the
        latents
    Fs: np.ndarray
        Free energy values after each iteration
    '''

    D, N = X.shape
    sigma_sq = np.random.rand()*1.5-0.5; pie = np.random.rand(K); alpha = np.random.rand(K)
    lamda = np.random.rand(K, N); mu = np.random.rand(D, K); precisions = np.random.rand(K)

    Fs = np.zeros(iterations); Fs[0] = FreeEnergy(X, sigma_sq, pie, alpha, lamda, mu, precisions)
    for i in range(iterations):
        lamda, mu, precisions, _ = MeanField(X, sigma_sq, pie, alpha, lamda, mu, precisions, maxsteps=50, eps=0.)
        sigma_sq, pie, alpha = MStep(X, lamda, mu, precisions)
        Fs[i] = FreeEnergy(X, sigma_sq, pie, alpha, lamda, mu, precisions)
        if i != 0:
            fe_increment = Fs[i] - Fs[i-1]
            assert fe_increment >= 0., fe_increment
            if fe_increment < 1e-2:
                break
            
    return mu, alpha, sigma_sq, pie, Fs[:i+1]

In [63]:
N = 400
X = genimages(N).T
D = X.shape[0]

K = 8
mu, alpha, sigma_sq, pie, Fs = LearnBinFactors(X, K, 100)

plt.plot(Fs, color='blue')
plt.xlabel('EM Iterations', font='serif', size=12)
plt.ylabel('Free Eenrgy', font='serif', size=12)
plt.title('Factored Variational EM', font='serif', size=14)
plt.tight_layout()
plt.grid()
plt.show()

fig = plt.figure(figsize=(8, 4.5))
fig.patch.set_facecolor('lightblue')
for k in range(K):
    plt.subplot(2, 4, k+1)
    plt.imshow(np.reshape(mu[:,k], (4, 4)), cmap=plt.gray(), interpolation='none')
    plt.axis('off')
plt.tight_layout()
plt.show()

-925106.3854318497 1593.7720751934407 292.14841255348574
-9166.23231517059 141.4672753360619 -166.65715661371075
-7694.6724940002105 189.75102284589013 -245.74956757420114
-6816.454371981954 462.8899022821786 -305.42501113203775
-6794.801599539534 776.57355589918 -344.0812437965272
-7296.428422207424 844.1978398118825 -370.9618147130809
-6937.977537985441 844.1978398118825 -370.9618147130809
-8214.176287532075 1072.6167549759016 -383.67586017070914
-8516.038054239003 1183.6479053697492 -394.5178501637479
-8035.568803198284 1183.6479053697492 -394.5178501637479


AssertionError: -781.6972351056429