In [335]:
import numpy as np
import numpy.random as npr
from scipy.special import gammaln, digamma

In [402]:
def simulate_LDA(K, V, N, M, eta=0.1, alpha=0.5, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    beta = rs.dirichlet(np.full(V, eta), size=K)
    theta = rs.dirichlet(np.full(K, alpha), size=N)
    X = []
    for i in range(N):
        x_i = np.zeros(M[i], dtype=int)
        for j in range(M[i]):
            z_ij = rs.choice(K, p=theta[i])
            x_ij = rs.choice(V, p=beta[z_ij])
            x_i[j] = x_ij
        X.append(x_i)
    return X

def init_variational_params(X, K, V):
    N = len(X)
    Ms = np.array([len(x_i) for x_i in X])
    lambd = npr.uniform(low=0.01, high=1.0, size=(K, V))
    gamma = np.ones((N, K))
    phi = []
    for M_i in Ms:
        phi_i = np.ones((M_i, K))
        phi_i = phi_i / K
        phi.append(phi_i)
    return lambd, gamma, phi

B = 100
rs = npr.RandomState(0)
eta = 0.1
alpha = 0.5
K = 10
V = 500
N = 50
M = npr.poisson(100, size=N)
X = simulate_LDA(K, V, N, M, eta, alpha)

In [444]:
def score_dir(x, alpha):
    return digamma(np.sum(alpha)) - digamma(alpha)

def score_cat(z, phi):
    return z / phi

def sample_variational_params(params, num_samples):
    lambd_params, gamma_params, phi_params = params
    K, V = lambd.shape
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])
    betas = []
    thetas = []
    zs = []
    for _ in range(num_samples):
        beta = np.zeros((K, V))
        for k in range(K):
            beta[k] = npr.dirichlet(lambd_params[k])
        
        theta = np.zeros((N, K))
        for i in range(N):
            theta[i] = npr.dirichlet(gamma_params[i])

        z = []
        for i in range(N):
            z_i = np.zeros(Ms[i], dtype=int)
            for j in range(Ms[i]):
                z_i[j] = npr.choice(K, p=phi_params[i][j])
            z.append(z_i)
        betas.append(beta)
        thetas.append(theta)
        zs.append(z)
    if num_samples == 1:
        return betas[0], thetas[0], zs[0]
    else:
        return betas, thetas, zs

In [404]:
def log_dir(x, alpha):
    return gammaln(np.sum(alpha)) - np.sum(gammaln(alpha)) + np.sum((alpha-1) * np.log(x))

def log_variational_dist(latent_params, var_params):
    lambd, gamma, phi = var_params
    K = lambd.shape[0]
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])
    beta, theta, z = latent_params

    log_q_beta = 0.0
    for k in range(K):
        log_q_beta += log_dir(beta[k], lambd[k])
    
    log_q_theta = 0.0
    for i in range(N):
        log_q_theta += log_dir(theta[i], gamma[i])

    log_q_z = 0.0
    for i in range(N):
        for j in range(Ms[i]):
            for k in range(K):
                log_q_z += float(z[i][j] == k) * np.log(phi[i][j, k])
    return log_q_beta + log_q_theta + log_q_z

def log_joint_prob(latent_params, X):
    beta, theta, z = latent_params
    K, V = beta.shape
    N = theta.shape[0]
    Ms = np.array([len(z[i]) for i in range(len(z))])

    log_p_beta = 0.0
    for k in range(K):
        log_p_beta += log_dir(beta[k], np.full(V, eta))

    log_p_theta = 0.0
    log_p_z = 0.0
    log_p_x = 0.0
    for i in range(N):
        log_p_theta += log_dir(theta[i], np.full(K, alpha))
        for j in range(Ms[i]):
            log_p_z += np.log(theta[i, z[i][j]])
            log_p_x += np.log(beta[z[i][j], X[i][j]])
    return log_p_beta + log_p_theta + log_p_z + log_p_x

In [422]:
def score_variational_dist(latent_params, var_params):
    beta, theta, z = latent_params
    lambd, gamma, phi = var_params
    K = lambd.shape[0]
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])
    grad_lambda = np.zeros_like(lambd)
    grad_gamma = np.zeros_like(gamma)
    grad_phi = [np.zeros_like(phi[i]) for i in range(N)]

    for k in range(K):
        grad_lambda[k] = score_dir(beta[k], lambd[k])
    
    for i in range(N):
        grad_gamma[i] = score_dir(theta[k], gamma[k])
    
    for i in range(N):
        for j in range(Ms[i]):
            for k in range(K):
                grad_phi[i][j, k] = float(z[i][j] == k) / phi[i][j, k]
    return grad_lambda, grad_gamma, grad_phi

In [451]:
lr = 10e-4
eps = 10e-6
S = 10
lambd, gamma, phi = init_variational_params(X, K, V)
beta_samp, theta_samp, z_samp = sample_variational_params((lambd, gamma, phi), S)
accum_grad_lambda, accum_grad_gamma, accum_grad_phi = 0, 0, None

for s in range(S):
    grad_lambda, grad_gamma, grad_phi = score_variational_dist((beta_samp[s], theta_samp[s], z_samp[s]), (lambd, gamma, phi))
    accum_grad_lambda += grad_lambda
    accum_grad_gamma += grad_gamma
    # accum_grad_phi += grad_phi

# lambd + (accum_grad_lambda / S)
(lr / np.sqrt((np.square(accum_grad_lambda) + eps))) * accum_grad_lambda, 10e-4
# grad_lambda * (log_joint_prob((beta, theta, z), X) - log_variational_dist((beta, theta, z), (lambd, gamma, phi)))

(array([[0.001, 0.001, 0.001, ..., 0.001, 0.001, 0.001],
        [0.001, 0.001, 0.001, ..., 0.001, 0.001, 0.001],
        [0.001, 0.001, 0.001, ..., 0.001, 0.001, 0.001],
        ...,
        [0.001, 0.001, 0.001, ..., 0.001, 0.001, 0.001],
        [0.001, 0.001, 0.001, ..., 0.001, 0.001, 0.001],
        [0.001, 0.001, 0.001, ..., 0.001, 0.001, 0.001]]),
 0.001)