In [None]:
import numpy as np
import numpy.random as npr
from scipy.special import gammaln, digamma

In [1284]:
def simulate_LDA(K, V, N, M, eta0=0.1, alpha0=0.5, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    beta = rs.dirichlet(np.full(V, eta0), size=K)
    theta = rs.dirichlet(np.full(K, alpha0), size=N)
    X = []
    for i in range(N):
        x_i = np.zeros(M[i], dtype=int)
        for j in range(M[i]):
            z_ij = rs.choice(K, p=theta[i]) + 1
            x_ij = rs.choice(V, p=beta[z_ij-1]) + 1
            x_i[j] = x_ij
        X.append(x_i)
    return X

def get_unique_idxs(X):
    N = len(X)
    Ms = [len(x_i) for x_i in X]
    unique_idxs = []

    for i in range(N):
        unique_idxs_i = np.zeros(Ms[i], dtype=int)
        for j in range(Ms[i]):
            unique_idxs_i[j] = X[i][j] - 1
        unique_idxs_i = np.unique(np.sort(unique_idxs_i))
        unique_idxs.append(unique_idxs_i)
    return unique_idxs

def get_idx_counts(X, unique_idxs):
    N = len(X)
    idx_counts = []

    for i in range(N):
        counts = np.zeros(len(unique_idxs[i]), dtype=int)
        for j, val in enumerate(unique_idxs[i]):
            counts[j] = np.sum(X[i].astype(float) == (val+1))
        idx_counts.append(counts)
    return idx_counts

def init_variational_params(X, K, V, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    N = len(X)
    Ms = np.array([len(x_i) for x_i in X])
    lambd = rs.uniform(low=0.01, high=1.0, size=(K, V))
    # lambd = np.full((K, V), 1.0)
    gamma = np.ones((N, K))
    phi = []
    for M_i in Ms:
        phi_i = np.ones((M_i, K))
        phi_i = phi_i / K
        phi.append(phi_i)
    return lambd, gamma, phi

def log_dir(x, alpha):
    return gammaln(np.sum(alpha)) - np.sum(gammaln(alpha)) + np.sum((alpha-1) * np.log(x + 1e-10))

def compute_ELBO(lambd, gamma, phi, unique_idxs, num_occs):
    N = gamma.shape[0]

    E_log_p_beta = np.sum((eta0-1) * (digamma(lambd) - digamma(np.sum(lambd, axis=1, keepdims=True))))
    E_log_p_theta = np.sum((alpha0-1) * (digamma(gamma) - digamma(np.sum(gamma, axis=1, keepdims=True))))
    E_log_q_beta = np.sum(-gammaln(np.sum(lambd, axis=1)) + np.sum(gammaln(lambd), axis=1) \
            - np.sum((lambd - 1) * (digamma(lambd) - digamma(np.sum(lambd, axis=1, keepdims=True))), axis=1))
    E_log_q_theta = np.sum(-gammaln(np.sum(gamma, axis=1)) + np.sum(gammaln(gamma), axis=1) \
            - np.sum((gamma - 1) * (digamma(gamma) - digamma(np.sum(gamma, axis=1, keepdims=True))), axis=1))
    
    E_log_p_x_z = 0.0
    for i in range(N):
        unique_idx = unique_idxs[i]
        counts = num_occs[i]
        j = 0
        for idx in unique_idx:
            E_log_p_x_z += counts[j] * np.sum(phi[i][j] * (digamma(gamma[i])-digamma(np.sum(gamma[i])))) \
                + counts[j] * np.sum(phi[i][j] * (digamma(lambd[:, idx])-digamma(np.sum(lambd, axis=1))))
            j += 1

    E_log_q_z = 0.0
    for i in range(N):
        unique_idx = unique_idxs[i]
        counts = num_occs[i]
        j = 0
        for idx in unique_idx:
            E_log_q_z += -np.sum(phi[i][j] * np.log(phi[i][j]))
            j += 1
    return E_log_p_beta + E_log_p_theta + E_log_q_beta + E_log_q_theta + E_log_p_x_z + E_log_q_z

In [1250]:
def sample_variational_params(var_params):
    lambd, gamma, phi = var_params
    K, V = lambd.shape
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])
    
    beta = np.zeros((K, V))
    for k in range(K):
        # beta[k] = npr.dirichlet(lambd[k])
        beta[k] = npr.dirichlet(lambd[k]) + 1e-10
        beta[k] /= beta[k].sum()
    
    theta = np.zeros((N, K))
    for i in range(N):
        # theta[i] = npr.dirichlet(gamma[i])
        theta[i] = npr.dirichlet(gamma[i]) + 1e-10
        theta[i] /= theta[i].sum()

    z = []
    for i in range(N):
        z_i = np.zeros(Ms[i], dtype=int)
        for j in range(Ms[i]):
            z_i[j] = npr.choice(K, p=phi[i][j]) + 1
        z.append(z_i)
    
    return beta, theta, z

def log_variational_dist(latent_params, var_params):
    lambd, gamma, phi = var_params
    beta, theta, z = latent_params
    K = lambd.shape[0]
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])

    log_q = 0.0
    for k in range(K):
        log_q += log_dir(beta[k], lambd[k])

    for i in range(N):
        log_q += log_dir(theta[i], gamma[i])
        for j in range(Ms[i]):
            for k in range(K):
                log_q += float(z[i][j]-1 == k) * np.log(phi[i][j, k] + eps)
    return log_q

def log_joint_prob(latent_params, X):
    beta, theta, z = latent_params
    K, V = beta.shape
    N = theta.shape[0]
    Ms = np.array([len(z[i]) for i in range(len(z))])

    log_p = 0.0
    for k in range(K):
        log_p += log_dir(beta[k], np.full(V, eta0))

    for i in range(N):
        log_p += log_dir(theta[i], np.full(K, alpha0))
        for j in range(Ms[i]):
            z_ij, x_ij = z[i][j], X[i][j]
            log_p += np.log(theta[i, z_ij-1])
            log_p += np.log(beta[z_ij-1, x_ij-1])
    return log_p

def score_variational_dist(latent_params, var_params):
    beta, theta, z = latent_params
    lambd, gamma, phi = var_params
    K, V = lambd.shape
    N = gamma.shape[0]
    Ms = np.array([len(phi[i]) for i in range(len(phi))])
    grad_lambda = np.zeros_like(lambd)
    grad_gamma = np.zeros_like(gamma)
    grad_phi = [np.zeros_like(phi_i) for phi_i in phi]

    for k in range(K):
        grad_lambda_k = np.zeros(V)
        for v in range(V):
            grad_lambda_k[v] = digamma(np.sum(lambd[k])) - digamma(lambd[k, v]) + np.log(beta[k, v] + 1e-10)
        grad_lambda[k] = grad_lambda_k
    
    for i in range(N):
        grad_gamma_i = np.zeros(K)
        for k in range(K):
            grad_gamma_i[k] = digamma(np.sum(gamma[i])) - digamma(gamma[i, k]) + np.log(theta[i, k] + 1e-10)
        grad_gamma[i] = grad_gamma_i
    
    for i in range(N):
        for j in range(Ms[i]):
            grad_phi_ij = np.zeros(K)
            for k in range(K):
                grad_phi_ij[k] = float(z[i][j]-1 == k) / phi[i][j, k]
            grad_phi[i][j] = grad_phi_ij
    return grad_lambda, grad_gamma, grad_phi

In [None]:
rs = npr.RandomState(0)
eta0 = 0.3
alpha0 = 0.5
K = 5
V = 100
N = 10
M = npr.poisson(50, size=N)
X = simulate_LDA(K, V, N, M, eta0, alpha0)
unique_idxs = get_unique_idxs(X)
num_occs = get_idx_counts(X, unique_idxs)

S = 20
rho = 1e-6
eps = 1e-6
lambd, gamma, phi = init_variational_params(X, K, V)
G_lambda = np.zeros_like(lambd)
G_gamma = np.zeros_like(gamma)
G_phi = [np.zeros_like(phi_i) for phi_i in phi]

print(compute_ELBO(lambd, gamma, phi, unique_idxs, num_occs))

for t in range(1000):
    grad_lambda = np.zeros_like(lambd)
    grad_gamma = np.zeros_like(gamma)
    grad_phi = [np.zeros_like(phi_i) for phi_i in phi]
    for s in range(S):
        beta_s, theta_s, z_s = sample_variational_params((lambd, gamma, phi))
        grad_lambda_s, grad_gamma_s, grad_phi_s = score_variational_dist((beta_s, theta_s, z_s), (lambd, gamma, phi))
        log_p, log_q = log_joint_prob((beta_s, theta_s, z_s), X), log_variational_dist((beta_s, theta_s, z_s), (lambd, gamma, phi))
        grad_lambda += grad_lambda_s * (log_p - log_q)
        grad_gamma += grad_gamma_s * (log_p - log_q)
        for i in range(len(grad_phi)):
            grad_phi[i] += grad_phi_s[i] * (log_p - log_q)
    grad_lambda /= S
    grad_gamma /= S
    grad_phi = [g_phi / S for g_phi in grad_phi]

    G_lambda += np.square(grad_lambda)
    G_gamma += np.square(grad_gamma)
    G_phi = [G_p + np.square(g_p) for G_p, g_p in zip(G_phi, grad_phi)]

    rho_lambda = rho / (np.sqrt(G_lambda) + eps)
    rho_gamma = rho / (np.sqrt(G_gamma) + eps)
    rho_phi = [rho / (np.sqrt(G_p) + eps) for G_p in G_phi]

    lambd += rho_lambda * grad_lambda
    lambd = np.maximum(lambd, 1e-3)
    gamma += rho_gamma * grad_gamma
    gamma = np.maximum(gamma, 1e-3)
    for i in range(len(phi)):
        phi[i] += rho_phi[i] * grad_phi[i]
        phi[i] = np.maximum(phi[i], 1e-10)
        phi[i] /= phi[i].sum(axis=1, keepdims=True)
            
        # if t % 100 == 0:
        #     print(t)
    print(compute_ELBO(lambd, gamma, phi, unique_idxs, num_occs))

-4983.819658127381
-4983.905264196464
-4983.966246182873
-4984.014674972568
-4984.057099537765
-4984.095234404971
-4984.1303220457
-4984.162286111251
-4984.192364114485
-4984.220846626944
-4984.247522591442
-4984.272862917039
-4984.297307809569
-4984.320495357874
-4984.343256898017
-4984.364996246092
-4984.386148615327
-4984.406784340388
-4984.426870423872
-4984.446353703554
-4984.465077389978
-4984.483426421914
-4984.501413378982
-4984.5193670972885
-4984.536480717368
-4984.553444428044
-4984.569771104143
-4984.58586665467
-4984.602068564449
-4984.6178513271025
-4984.633326861395
-4984.648659238888
-4984.663519965062
-4984.678264753955
-4984.692983680963
-4984.7073993318545
-4984.721552714937
-4984.735491770767
-4984.749233657117
-4984.763031949853
-4984.7761146502435
-4984.78914840146
-4984.8022983259925
-4984.815128298545
-4984.827903438429
-4984.840589780726
-4984.85304804686
-4984.865530555126
-4984.877682327683
-4984.889785022901
-4984.9018345623035
-4984.913781412323
-4984.92560

KeyboardInterrupt: 

In [1281]:
lambd

array([[1.        , 1.00346299, 1.00336161, 1.00031725, 1.00413804,
        1.00568828, 1.        , 1.        , 1.00005877, 1.0002066 ,
        1.        , 1.0085892 , 1.00428117, 1.01208755, 1.00050942,
        1.01014529, 1.        , 1.00292325, 1.01744589, 1.        ,
        1.0018596 , 1.00618972, 1.00173486, 1.00002366, 1.00044266,
        1.        , 1.00771825, 1.01150733, 1.00256214, 1.00117589,
        1.00008221, 1.        , 1.00875977, 1.        , 1.        ,
        1.01172746, 1.00043438, 1.00234683, 1.        , 1.        ,
        1.0020654 , 1.00128344, 1.00616477, 1.00279991, 1.00153313,
        1.        , 1.        , 1.00179698, 1.00007665, 1.01499926,
        1.        , 1.        , 1.00005689, 1.        , 1.        ,
        1.00139097, 1.0062881 , 1.        , 1.        , 1.        ,
        1.00887612, 1.        , 1.        , 1.00634993, 1.01246871,
        1.        , 1.01116507, 1.01497582, 1.00831365, 1.00022908,
        1.00128752, 1.00428388, 1.        , 1.00

In [1231]:
import scipy
scipy.stats.mode(lambd)[0]

array([0.72772895, 0.68986909, 0.5139746 , 0.81693573, 0.67666415,
       0.72916291, 0.52684395, 0.71954486, 0.78033844, 0.74273193,
       0.5948293 , 0.67831366, 0.60617041, 0.37333081, 0.9905991 ,
       0.34404053, 0.54714105, 0.87253999, 0.3270894 , 0.84705663,
       0.80041732, 0.4112345 , 0.47779821, 0.76437366, 0.98003957,
       0.59814953, 0.66800309, 0.65138105, 0.84521121, 0.69373105,
       0.55705901, 0.56451172, 0.40557593, 0.69829906, 0.64126996,
       0.78974439, 0.42131593, 0.71449213, 0.85733011, 0.71171898,
       0.59919093, 0.70407542, 0.66668948, 0.56677303, 0.52834288,
       0.998471  , 0.87207133, 0.90712203, 0.71797699, 0.37034996,
       0.6840777 , 0.84403024, 0.65780185, 0.93805543, 0.78395794,
       0.61876346, 1.02362356, 0.92187373, 0.95010749, 0.67143199,
       0.92023049, 0.68014147, 0.68260278, 0.96160331, 0.61899003,
       0.62203863, 0.84426588, 0.84446863, 0.75051007, 0.84629898,
       0.63250784, 0.77954966, 0.75507628, 0.71519343, 0.58479