In [547]:
import numpy as np
import numpy.random as npr
from scipy.special import gammaln, digamma
eps = 1e-100

In [584]:
K = 10
V = 10
N = 50
M = npr.poisson(50, size=N)
eta0 = 0.1
alpha0 = 0.1

rs = npr.RandomState(npr.randint(0, 100))
beta = rs.dirichlet(np.full(V, eta0), size=K)
theta = rs.dirichlet(np.full(K, alpha0), size=N)
X = []
i = 0
x_i = np.zeros(M[i], dtype=int)
for j in range(M[i]):
    z_ij = rs.choice(K, p=theta[i])
    x_ij = rs.choice(V, p=beta[z_ij]) + 1
    x_i[j] = x_ij
# for i in range(N):
#     x_i = np.zeros(M[i], dtype=int)
#     for j in range(M[i]):
#         z_ij = rs.choice(K, p=theta[i])
#         x_ij = rs.choice(V, p=beta[z_ij])
#         x_i[j] = x_ij
x_i

array([ 8,  9,  6,  5,  8, 10,  9,  6,  5,  6, 10,  6,  5,  6,  5, 10,  5,
        5,  9,  6,  6,  9,  6,  5,  6,  8,  6, 10,  6,  2,  5,  3,  9, 10,
        5,  5,  5,  7,  5, 10,  2,  9,  9,  6,  6,  9,  5,  5,  6,  6,  5,
        9,  5,  8,  5,  6,  6])

In [945]:
def simulate_LDA(K, V, N, M, eta0=0.1, alpha0=0.5, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    beta = rs.dirichlet(np.full(V, eta0), size=K)
    theta = rs.dirichlet(np.full(K, alpha0), size=N)
    X = []
    for i in range(N):
        x_i = np.zeros(M[i], dtype=int)
        for j in range(M[i]):
            z_ij = rs.choice(K, p=theta[i])
            x_ij = rs.choice(V, p=beta[z_ij])
            x_i[j] = x_ij
        X.append(x_i)
    return X

def init_variational_params(X, K, V):
    N = len(X)
    Ms = np.array([len(x_i) for x_i in X])
    lambd = npr.uniform(low=0.01, high=1.0, size=(K, V))
    gamma = np.ones((N, K))
    phi = []
    for M_i in Ms:
        phi_i = np.ones((M_i, K))
        phi_i = phi_i / K
        phi.append(phi_i)
    return lambd, gamma, phi

rs = npr.RandomState(0)
eta0 = 0.1
alpha0 = 0.5
K = 5
V = 500
N = 10
M = npr.poisson(10, size=N)
X = simulate_LDA(K, V, N, M, eta0, alpha0)

In [None]:
def score_dir(x, alpha):
    return digamma(np.sum(alpha)) - digamma(alpha) + np.log(x + 1e-10)

def score_cat(z, phi):
    return z / phi

def sample_variational_params(var_params):
    lambd, gamma, phi = var_params
    K, V = lambd.shape
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])
    
    beta = np.zeros((K, V))
    for k in range(K):
        beta[k] = npr.dirichlet(lambd[k])
    
    theta = np.zeros((N, K))
    for i in range(N):
        theta[i] = npr.dirichlet(gamma[i])

    z = []
    for i in range(N):
        z_i = np.zeros(Ms[i], dtype=int)
        for j in range(Ms[i]):
            z_i[j] = npr.choice(K, p=phi[i][j])
        z.append(z_i)
    
    return beta, theta, z

In [869]:
def log_dir(x, alpha):
    return gammaln(np.sum(alpha)) - np.sum(gammaln(alpha)) + np.sum((alpha-1) * np.log(x + 1e-10))

def log_variational_dist(latent_params, var_params):
    lambd, gamma, phi = var_params
    beta, theta, z = latent_params
    K = lambd.shape[0]
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])

    log_q = 0.0
    for k in range(K):
        log_q += log_dir(beta[k], lambd[k])

    for i in range(N):
        log_q += log_dir(theta[i], gamma[i])
        for j in range(Ms[i]):
            for k in range(K):
                log_q += float(z[i][j] == k) * np.log(phi[i][j, k] + eps)
    return log_q

def log_joint_prob(latent_params, X):
    beta, theta, z = latent_params
    K, V = beta.shape
    N = theta.shape[0]
    Ms = np.array([len(z[i]) for i in range(len(z))])

    log_p_beta = 0.0
    for k in range(K):
        log_p_beta += log_dir(beta[k], np.full(V, eta))

    log_p_theta = 0.0
    log_p_z = 0.0
    log_p_x = 0.0
    for i in range(N):
        log_p_theta += log_dir(theta[i], np.full(K, alpha))
        for j in range(Ms[i]):
            log_p_z += np.log(theta[i, z[i][j]])
            log_p_x += np.log(beta[z[i][j], X[i][j]])
    return log_p_beta + log_p_theta + log_p_z + log_p_x

In [974]:
def score_variational_dist(latent_params, var_params):
    beta, theta, z = latent_params
    lambd, gamma, phi = var_params
    K = lambd.shape[0]
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])
    grad_lambda = np.zeros(V)
    grad_gamma = np.zeros(K)
    grad_phi = np.zeros(K)

    for k in range(K):
        grad_lambda_k = np.zeros(V)
        for v in range(V):
            assert beta[k, v] >= 0
            grad_lambda_k[v] = digamma(np.sum(lambd[k])) - digamma(lambd[k, v]) + np.log(beta[k, v] + 1e-10)
        grad_lambda += grad_lambda_k
    
    for i in range(N):
        grad_gamma_i = np.zeros(K)
        for k in range(K):
            assert theta[i, k] > 0
            grad_gamma_i[k] = digamma(np.sum(gamma[i])) - digamma(gamma[i, k]) + np.log(theta[i, k] + 1e-10)
        grad_gamma += grad_gamma_i
    
    for i in range(N):
        for j in range(Ms[i]):
            grad_phi_ij = np.zeros(K)
            for k in range(K):
                grad_phi_ij[k] = float(z[i][j] == k) / phi[i][j, k]
            grad_phi += grad_phi_ij
    return grad_lambda, grad_gamma, grad_phi

In [992]:
S = 20
eta = 0.0000001
eps = 1e-6
lr_lambda = lr_gamma = lr_phi = eta
G_lambda = np.zeros(V)
G_gamma = np.zeros(K)
G_phi = np.zeros(K)

lambd, gamma, phi = init_variational_params(X, K, V)
grad_lambda = np.zeros(V)
grad_gamma = np.zeros(K)
grad_phi = np.zeros(K)

for t in range(1000):
    for s in range(S):
        beta_s, theta_s, z_s = sample_variational_params((lambd, gamma, phi))
        grad_lambda_s, grad_gamma_s, grad_phi_s = score_variational_dist((beta_s, theta_s, z_s), (lambd, gamma, phi))
        log_p, log_q = log_joint_prob((beta_s, theta_s, z_s), X), log_variational_dist((beta_s, theta_s, z_s), (lambd, gamma, phi))
        grad_lambda += grad_lambda_s * (log_p - log_q)
        grad_gamma += grad_gamma_s * (log_p - log_q)
        grad_phi += grad_phi_s * (log_p - log_q)
    grad_lambda /= S
    grad_gamma /= S
    grad_phi /= S

    G_lambda += np.square(grad_lambda)
    G_gamma += np.square(grad_gamma)
    G_phi += np.square(grad_phi)

    lr_lambda = eta / np.sqrt(G_lambda + eps)
    lr_gamma = eta / np.sqrt(G_gamma + eps)
    lr_phi = eta / np.sqrt(G_phi + eta)

    lambd += lr_lambda * grad_lambda
    lambd = np.maximum(lambd, 1e-5)
    gamma += lr_gamma * grad_gamma
    gamma = np.maximum(gamma, 1e-5)
    for i in range(N):
        phi[i] += lr_phi * grad_phi
        phi[i] = phi[i] / phi[i].sum(axis=1, keepdims=True)
    if t % 100 == 0:
        print(t)

0
100
200
300
400
500
600
700
800
900


In [993]:
import scipy
scipy.stats.mode(lambd)

ModeResult(mode=array([0.02894835, 0.58892686, 0.13111334, 0.25840583, 0.24143926,
       0.16777638, 0.09798775, 0.29752736, 0.01899156, 0.25634053,
       0.38678771, 0.15177852, 0.06888556, 0.36558001, 0.23014562,
       0.01411915, 0.07472059, 0.14255709, 0.35967489, 0.17878896,
       0.0626343 , 0.01070142, 0.03588282, 0.05943561, 0.13892424,
       0.30845921, 0.16994307, 0.13341858, 0.03092116, 0.33621839,
       0.40971375, 0.01507124, 0.7200196 , 0.1445776 , 0.10218153,
       0.59500313, 0.18863063, 0.39598547, 0.21616192, 0.21537945,
       0.03679043, 0.19236899, 0.09426296, 0.23378499, 0.04341779,
       0.03446724, 0.02516611, 0.59384448, 0.3696326 , 0.14679525,
       0.07819982, 0.38616314, 0.1200598 , 0.11410448, 0.29588956,
       0.27239536, 0.01062785, 0.1029726 , 0.21183283, 0.09150654,
       0.15005094, 0.02184367, 0.08930682, 0.22849604, 0.2144991 ,
       0.32954936, 0.03748791, 0.01320275, 0.19205902, 0.2155242 ,
       0.1572662 , 0.56371885, 0.26965703, 0.2

In [None]:
beta, _, _ = sample_variational_params((lambd, gamma, phi))
K, V = beta.shape

array([[2.00659721e-04, 0.00000000e+00, 6.71866779e-05, ...,
        0.00000000e+00, 1.87893042e-03, 9.49965702e-03],
       [1.35615162e-04, 9.00542717e-04, 2.41716938e-08, ...,
        6.95212744e-08, 5.82136775e-03, 1.80706597e-03],
       [2.95655303e-05, 3.24707332e-05, 7.62776683e-08, ...,
        1.96422889e-07, 0.00000000e+00, 4.56455141e-03],
       [5.32103402e-03, 7.99700914e-05, 3.41885343e-03, ...,
        7.40578658e-04, 7.35425060e-03, 1.35452073e-03],
       [7.43319979e-03, 3.58847573e-04, 0.00000000e+00, ...,
        6.72027375e-04, 4.59674536e-03, 1.95999394e-04]])