In [547]:
import numpy as np
import numpy.random as npr
import scipy.stats
from scipy.special import gammaln, digamma
# Tiny additive floor intended to keep np.log(...) of variational probabilities finite (guards log(0)).
eps = 1e-100

In [1129]:
def simulate_LDA(K, V, N, M, eta0=0.1, alpha0=0.5, rs_int=None):
    """Draw a synthetic corpus from the LDA generative model.

    Parameters
    ----------
    K : int
        Number of topics.
    V : int
        Vocabulary size.
    N : int
        Number of documents.
    M : int or sequence of int
        Document lengths. A scalar gives every document the same length;
        otherwise M[i] is the length of document i.
    eta0 : float
        Symmetric Dirichlet concentration for each topic's word distribution.
    alpha0 : float
        Symmetric Dirichlet concentration for each document's topic mixture.
    rs_int : int or None
        Seed for the internal RandomState. When None, a seed in [0, 100) is
        drawn from the global numpy stream *at call time*. (The previous
        default was evaluated once at definition time, so every default call
        silently reused the same seed.)

    Returns
    -------
    list of np.ndarray
        N integer arrays; X[i][j] is the 1-based word id (1..V) of word j in
        document i. Topic assignments are not returned.
    """
    if rs_int is None:
        rs_int = npr.randint(low=0, high=100)
    rs = npr.RandomState(rs_int)

    # Allow a single scalar length for all documents.
    if np.ndim(M) == 0:
        M = np.full(N, int(M))

    beta = rs.dirichlet(np.full(V, eta0), size=K)      # (K, V) topic-word probabilities
    theta = rs.dirichlet(np.full(K, alpha0), size=N)   # (N, K) doc-topic probabilities
    X = []
    for i in range(N):
        x_i = np.zeros(M[i], dtype=int)
        for j in range(M[i]):
            # Draw order (topic, then word) is kept so a given seed yields the same corpus.
            z_ij = rs.choice(K, p=theta[i])              # 0-based topic index
            x_i[j] = rs.choice(V, p=beta[z_ij]) + 1      # 1-based word id
        X.append(x_i)
    return X

def init_variational_params(X, K, V, random_state=None):
    """Create initial mean-field variational parameters for LDA.

    Parameters
    ----------
    X : list of array-like
        The corpus; only len(X) and len(X[i]) are used.
    K : int
        Number of topics.
    V : int
        Vocabulary size.
    random_state : None, int or np.random.RandomState
        Source of randomness for lambd. None (the default) uses the global
        numpy stream, as before; an int or RandomState makes it reproducible.

    Returns
    -------
    lambd : (K, V) array, uniform draws in [0.01, 1.0)
    gamma : (N, K) array of ones
    phi : list of (M_i, K) arrays, each row uniform (1/K)
    """
    if random_state is None:
        rng = npr
    elif isinstance(random_state, npr.RandomState):
        rng = random_state
    else:
        rng = npr.RandomState(random_state)

    N = len(X)
    lambd = rng.uniform(low=0.01, high=1.0, size=(K, V))
    gamma = np.ones((N, K))
    phi = [np.full((len(x_i), K), 1.0 / K) for x_i in X]
    return lambd, gamma, phi

In [1130]:
def sample_variational_params(var_params):
    """Draw one sample (beta, theta, z) from the mean-field variational q.

    var_params is the tuple (lambd, gamma, phi):
      * beta[k]  ~ Dirichlet(lambd[k]), shifted by 1e-10 and renormalized
      * theta[i] ~ Dirichlet(gamma[i]), shifted by 1e-10 and renormalized
      * z[i][j]  ~ Categorical(phi[i][j]), returned as a 1-based topic id

    All randomness comes from the global numpy stream, consumed in the order
    beta rows, theta rows, then z (document by document, word by word).
    """
    lambd, gamma, phi = var_params
    num_topics, vocab_size = lambd.shape
    num_docs = gamma.shape[0]

    beta = np.zeros((num_topics, vocab_size))
    for k, conc in enumerate(lambd):
        draw = npr.dirichlet(conc) + 1e-10
        beta[k] = draw / draw.sum()

    theta = np.zeros((num_docs, num_topics))
    for d in range(num_docs):
        draw = npr.dirichlet(gamma[d]) + 1e-10
        theta[d] = draw / draw.sum()

    z = [
        np.array(
            [npr.choice(num_topics, p=phi[d][w]) + 1 for w in range(len(phi[d]))],
            dtype=int,
        )
        for d in range(num_docs)
    ]
    return beta, theta, z

def log_dir(x, alpha):
    """Log-density of Dirichlet(alpha) evaluated at the point x.

    A 1e-10 floor inside the log keeps zero coordinates of x finite.
    """
    log_normalizer = gammaln(np.sum(alpha)) - np.sum(gammaln(alpha))
    log_kernel = np.sum((alpha - 1) * np.log(x + 1e-10))
    return log_normalizer + log_kernel

def log_variational_dist(latent_params, var_params, eps=1e-100):
    """Log-density log q(beta, theta, z) under the mean-field variational family.

    Parameters
    ----------
    latent_params : tuple (beta, theta, z)
        beta (K, V), theta (N, K), and z a list of 1-based topic-id arrays.
    var_params : tuple (lambd, gamma, phi)
        Dirichlet parameters lambd (K, V) and gamma (N, K), and phi a list of
        (M_i, K) categorical probability arrays.
    eps : float
        Floor added inside log(phi) to avoid log(0). This used to be read from
        a module-level global, which the training cell rebinds to the AdaGrad
        constant; it is now an explicit argument.

    Returns
    -------
    float
        sum_k log Dir(beta[k]; lambd[k]) + sum_i log Dir(theta[i]; gamma[i])
        + sum_{i,j} log phi[i][j, z[i][j] - 1].
    """
    lambd, gamma, phi = var_params
    beta, theta, z = latent_params

    log_q = 0.0
    for k in range(lambd.shape[0]):
        log_q += log_dir(beta[k], lambd[k])

    for i in range(gamma.shape[0]):
        log_q += log_dir(theta[i], gamma[i])
        phi_i = phi[i]
        n_words = len(phi_i)
        # Select phi[i][j, z_ij - 1] directly instead of summing one-hot * log phi over k.
        # Assumes z values lie in 1..K (as drawn by sample_variational_params).
        topic_idx = np.asarray(z[i][:n_words], dtype=int) - 1
        log_q += np.sum(np.log(phi_i[np.arange(n_words), topic_idx] + eps))
    return log_q

def log_joint_prob(latent_params, X, eta0=None, alpha0=None):
    """Log joint density log p(X, beta, theta, z) of the LDA model.

    Parameters
    ----------
    latent_params : tuple (beta, theta, z)
        beta (K, V) topic-word probabilities, theta (N, K) doc-topic
        probabilities, and z a list of 1-based topic-id arrays.
    X : list of array-like
        Corpus of 1-based word ids; X[i][j] pairs with z[i][j].
    eta0, alpha0 : float or None
        Symmetric Dirichlet concentrations of the priors on beta and theta.
        When None, they fall back to the notebook-level globals of the same
        name, preserving the original hidden dependency for existing callers.

    Returns
    -------
    float
    """
    namespace = globals()
    if eta0 is None:
        if "eta0" not in namespace:
            raise NameError("eta0 is not defined; pass eta0=... explicitly")
        eta0 = namespace["eta0"]
    if alpha0 is None:
        if "alpha0" not in namespace:
            raise NameError("alpha0 is not defined; pass alpha0=... explicitly")
        alpha0 = namespace["alpha0"]

    beta, theta, z = latent_params
    K, V = beta.shape
    N = theta.shape[0]
    beta_prior = np.full(V, eta0)
    theta_prior = np.full(K, alpha0)

    log_p = 0.0
    for k in range(K):
        log_p += log_dir(beta[k], beta_prior)

    for i in range(N):
        log_p += log_dir(theta[i], theta_prior)
        topic_idx = np.asarray(z[i], dtype=int) - 1
        word_idx = np.asarray(X[i][:len(topic_idx)], dtype=int) - 1
        # log p(z_ij | theta_i) + log p(x_ij | beta_{z_ij}), summed over words.
        log_p += np.sum(np.log(theta[i, topic_idx]))
        log_p += np.sum(np.log(beta[topic_idx, word_idx]))
    return log_p

def score_variational_dist(latent_params, var_params):
    """Score function d/d(params) log q(beta, theta, z) at a sampled latent point.

    Parameters
    ----------
    latent_params : tuple (beta, theta, z)
        beta (K, V), theta (N, K), z a list of 1-based topic-id arrays.
    var_params : tuple (lambd, gamma, phi)
        lambd (K, V), gamma (N, K), phi a list of (M_i, K) arrays.

    Returns
    -------
    grad_lambda : (K, V) array
        digamma(sum_v lambd[k]) - digamma(lambd[k, v]) + log(beta[k, v] + 1e-10)
    grad_gamma : (N, K) array
        digamma(sum_k gamma[i]) - digamma(gamma[i, k]) + log(theta[i, k] + 1e-10)
    grad_phi : list of (M_i, K) arrays
        1 / phi[i][j, k] where k == z[i][j] - 1, else 0. This treats each phi
        entry as a free parameter; the simplex constraint is not included.

    V is now taken from lambd.shape. The previous version read a notebook
    global `V`, which breaks when that global is absent or differs.
    """
    beta, theta, z = latent_params
    lambd, gamma, phi = var_params

    # d/d alpha_c log Dir(x; alpha) = digamma(sum(alpha)) - digamma(alpha_c) + log x_c
    grad_lambda = (digamma(lambd.sum(axis=1, keepdims=True)) - digamma(lambd)
                   + np.log(beta + 1e-10))
    grad_gamma = (digamma(gamma.sum(axis=1, keepdims=True)) - digamma(gamma)
                  + np.log(theta + 1e-10))

    grad_phi = []
    for i, phi_i in enumerate(phi):
        grad_i = np.zeros_like(phi_i)
        rows = np.arange(len(phi_i))
        cols = np.asarray(z[i][:len(phi_i)], dtype=int) - 1
        # Only the sampled category has a non-zero derivative of log phi.
        grad_i[rows, cols] = 1.0 / phi_i[rows, cols]
        grad_phi.append(grad_i)
    return grad_lambda, grad_gamma, grad_phi

In [1140]:
# --- Reproducibility -------------------------------------------------------
# `rs` drives the data simulation. The global stream is also seeded because
# init_variational_params / sample_variational_params draw from np.random.
SEED = 0
rs = npr.RandomState(SEED)
npr.seed(SEED)

# --- Synthetic corpus --------------------------------------------------------
eta0 = 0.3     # Dirichlet concentration of the topic-word prior
alpha0 = 0.5   # Dirichlet concentration of the doc-topic prior
K = 5
V = 100
N = 10
M = rs.poisson(50, size=N)   # document lengths
X = simulate_LDA(K, V, N, M, eta0, alpha0, rs_int=SEED)

# --- Black-box VI with AdaGrad ---------------------------------------------
S = 20               # Monte Carlo samples per gradient estimate
eta = 0.04           # AdaGrad base step size (unrelated to the prior eta0)
adagrad_eps = 1e-6   # AdaGrad stabilizer; deliberately NOT named `eps`, so the
                     # log-floor global used by log_variational_dist is not clobbered
N_ITERS = 300

lambd, gamma, phi = init_variational_params(X, K, V)
G_lambda = np.zeros_like(lambd)
G_gamma = np.zeros_like(gamma)
G_phi = [np.zeros_like(phi_i) for phi_i in phi]

for t in range(N_ITERS):
    var_params = (lambd, gamma, phi)
    grad_lambda = np.zeros_like(lambd)
    grad_gamma = np.zeros_like(gamma)
    grad_phi = [np.zeros_like(phi_i) for phi_i in phi]
    for _ in range(S):
        latent_s = sample_variational_params(var_params)
        score_lambda, score_gamma, score_phi = score_variational_dist(latent_s, var_params)
        # Score-function (REINFORCE) weight: log p(X, latent) - log q(latent).
        weight = log_joint_prob(latent_s, X) - log_variational_dist(latent_s, var_params)
        grad_lambda += score_lambda * weight
        grad_gamma += score_gamma * weight
        for g_i, score_i in zip(grad_phi, score_phi):
            g_i += score_i * weight
    grad_lambda /= S
    grad_gamma /= S
    grad_phi = [g_i / S for g_i in grad_phi]

    # AdaGrad: per-coordinate step eta / sqrt(sum of squared past gradients).
    G_lambda += np.square(grad_lambda)
    G_gamma += np.square(grad_gamma)
    G_phi = [G_i + np.square(g_i) for G_i, g_i in zip(G_phi, grad_phi)]

    rho_lambda = eta / np.sqrt(G_lambda + adagrad_eps)
    rho_gamma = eta / np.sqrt(G_gamma + adagrad_eps)
    rho_phi = [eta / np.sqrt(G_i + adagrad_eps) for G_i in G_phi]

    # Gradient ascent step, then project back to the valid parameter region.
    lambd = np.maximum(lambd + rho_lambda * grad_lambda, 1e-3)
    gamma = np.maximum(gamma + rho_gamma * grad_gamma, 1e-3)
    for i in range(len(phi)):
        phi[i] += rho_phi[i] * grad_phi[i]
        phi[i] = np.maximum(phi[i], 1e-10)
        phi[i] /= phi[i].sum(axis=1, keepdims=True)   # keep each row on the simplex

    if t % 100 == 0:
        print(t)

0
100
200


In [1141]:
# Column-wise modes of the fitted variational parameters.
# NOTE(review): scipy.stats.mode returns the smallest value when there are ties,
# so for continuous-valued columns whose entries are all distinct this is just
# the column minimum. It is likely not a meaningful summary; consider e.g.
# lambd.argmax(axis=1) or gamma / gamma.sum(axis=1, keepdims=True) instead.
lambd_mode = scipy.stats.mode(lambd)[0]
gamma_mode = scipy.stats.mode(gamma)[0]
lambd_mode[lambd_mode > 1e-3], gamma_mode

(array([0.31964538, 0.33947036, 0.26728506, 0.32274846, 0.4740857 ,
        0.41559834, 0.30780697, 0.5641044 , 0.27805669, 0.4614689 ,
        0.26140196, 0.26279354, 0.28955787, 0.33479368, 0.35474976,
        0.42149211, 0.42233743]),
 array([0.87462824, 0.88907063, 0.85539712, 0.76451928, 0.76181388]))