In [1]:
import re
import numpy as np
import numpy.random as npr
import scipy.special as spp
from scipy.stats import mode
from scipy.special import digamma, logsumexp, loggamma
from scipy.stats import dirichlet as dir
from tqdm import tqdm
import matplotlib.pyplot as plt

rs = npr.RandomState(0)
K, V, N = 10, 300, 30
eta0, alpha0 = 0.8, 1 / K
Ms = rs.poisson(60, size=N)

In [44]:
def generate_lda(K, V, N, Ms, eta0=0.5, alpha0=0.5, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    eta0_vec = np.ones(V) * eta0
    alpha0_vec = np.ones(K) * alpha0

    beta = dir.rvs(eta0_vec, size=K, random_state=rs_int)
    theta = dir.rvs(alpha0_vec, size=N, random_state=rs_int)
    X = []
    for i in range(N):
        x_i = np.zeros(Ms[i], dtype=int)
        for j in range(Ms[i]):
            z_ij = rs.choice(K, p=theta[i])
            x_ij = rs.choice(V, p=beta[z_ij])
            x_i[j] = x_ij
        X.append(x_i)
    return X

def init_var_params(X, K, V, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    N = len(X)
    log_lambd = np.log(rs.uniform(low=0.3, high=1.0, size=(K, V)))
    log_gamma = np.log(rs.uniform(low=0.3, high=1.0, size=(N, K)))
    return log_lambd, log_gamma

def sample_params(var_params, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    log_lambd, log_gamma = var_params
    beta = np.zeros_like(log_lambd)
    theta = np.zeros_like(log_gamma)
    # for k in range(K):
        # beta[k] = dir.rvs(np.exp(log_lambd[k]), random_state=rs_int)[0]
        # beta[k] = rs.dirichlet(np.exp(log_lambd[k]))
    # for i in range(N):
        # theta[i] = dir.rvs(np.exp(log_gamma[i]), random_state=rs_int)[0]
        # theta[i] = rs.dirichlet(np.exp(log_gamma[i]))
    log_gamm_beta = np.log(npr.gamma(np.exp(log_lambd), 1))
    log_probs_beta = log_gamm_beta - logsumexp(log_gamm_beta, axis=1)[:, None]
    beta = np.exp(log_probs_beta)
    log_gamm_gamma = np.log(npr.gamma(np.exp(log_gamma), 1))
    log_probs_gamma = log_gamm_gamma - logsumexp(log_gamm_gamma, axis=1)[:, None]
    theta = np.exp(log_probs_gamma)
    
    return beta, theta

def log_var_approx(var_params, latent_params):
    log_lambd, log_gamma = var_params
    beta, theta = latent_params
    K = log_lambd.shape[0]
    N = log_gamma.shape[0]

    log_q_beta = sum(dir.logpdf(beta[k], np.exp(log_lambd[k])) for k in range(K))
    log_q_theta = sum(dir.logpdf(theta[i], np.exp(log_gamma[i])) for i in range(N))
    return float(log_q_beta + log_q_theta)

def log_joint_prob(latent_params, X):
    beta, theta = latent_params
    K, V = beta.shape
    N = theta.shape[0]
    eta0_vec = np.ones(V) * eta0
    alpha0_vec = np.ones(K) * alpha0

    log_p_beta = sum(dir.logpdf(beta[k], eta0_vec) for k in range(K))
    log_p_theta = sum(dir.logpdf(theta[i], alpha0_vec) for i in range(N))
    log_p_x = 0.0
    for _, (theta_i, x_i) in enumerate(zip(theta, X)):
        beta_xi = beta[:, x_i]
        log_word_probs = np.log(np.sum(theta_i[:, None] * beta_xi, axis=0))
        log_p_x += np.sum(log_word_probs)
    return float(log_p_beta + log_p_theta + log_p_x)

def score_dir(x, alpha):
    return digamma(np.sum(alpha)) - digamma(alpha) + np.log(x)

def score_var_dist(var_params, latent_params):
    log_lambd, log_gamma = var_params
    beta, theta = latent_params
    grad_lambd = (digamma(np.sum(np.exp(log_lambd), axis=0)) - digamma(np.exp(log_lambd)) + np.log(beta)) * np.exp(log_lambd)
    grad_gamma = (digamma(np.sum(np.exp(log_gamma), axis=0)) - digamma(np.exp(log_gamma)) + np.log(theta)) * np.exp(log_gamma)
    return grad_lambd, grad_gamma

def estimate_ELBO(var_params, X, S):
    log_lambd, log_gamma = var_params
    ELBO = 0.0
    for _ in range(S):
        beta_s, theta_s = sample_params((log_lambd, log_gamma))
        log_p = log_joint_prob((beta_s, theta_s), X)
        log_q = log_var_approx((log_lambd, log_gamma), (beta_s, theta_s))
        ELBO += (log_p - log_q)
    return ELBO / S

In [51]:
V = 10
K = 2
N = 1
eta0 = 1 / K
alpha0 = 1 / K
lr_init = 1e-3
eps = 1e-8
x = generate_lda(K=K, V=V, N=N, Ms=[50], eta0=(1/K), alpha0=(1/K), rs_int=8)
lambd, gamma = init_var_params(X=x, K=K, V=V)
G_lambd = np.zeros_like(lambd)
G_gamma = np.zeros_like(gamma)

beta, theta = sample_params((lambd, gamma))

log_p_beta = 0.0
for k in range(K):
    beta_k = beta[k]
    log_p_beta += loggamma(V * eta0) - V * loggamma(eta0) + (eta0 - 1) * np.sum(np.log(beta_k))
log_p_theta = 0.0
for i in range(N):
    theta_i = theta[i]
    log_p_theta += loggamma(K * alpha0) - K * loggamma(alpha0) + (alpha0 - 1) * np.sum(np.log(theta_i))
log_p_x = 0.0
for i, x_i in enumerate(x):
    for x_ij in x_i:
        log_p_x += np.log(np.sum(theta[i] * beta[:, x_ij]))

log_q_beta = 0.0
for k in range(K):
    lambd_k = np.exp(lambd[k])
    beta_k = beta[k]
    log_q_beta += loggamma(np.sum(lambd_k)) - np.sum(loggamma(lambd_k)) + np.sum((lambd_k - 1) * np.log(beta_k))
log_q_theta = 0.0
for i in range(N):
    gamma_i = np.exp(gamma[i])
    theta_i = theta[i]
    log_q_theta += loggamma(np.sum(gamma_i)) - np.sum(loggamma(gamma_i)) + np.sum((gamma_i - 1) * np.log(theta_i))
log_p = (log_p_beta + log_p_theta + log_p_x)
log_q = (log_q_beta + log_q_theta)

score_lambd = np.zeros_like(lambd)
for k in range(K):
    lambd_k = np.exp(lambd[k])
    beta_k = beta[k]
    digam_sum_lambd_k = digamma(np.sum(lambd_k))
    for v in range(V):
        score_lambd[k, v] = lambd_k[v] * ( digam_sum_lambd_k - digamma(lambd_k[v]) + np.log(beta_k[v]) )
score_gamma = np.zeros_like(gamma)
for i in range(N):
    gamma_i = np.exp(gamma[i])
    theta_i = theta[i]
    digam_sum_gamma_i = digamma(np.sum(gamma_i))
    for k in range(K):
        score_gamma[i, k] = gamma_i[k] * ( digam_sum_gamma_i - digamma(gamma_i[k]) + np.log(gamma_i[k]) )

stoch_score_lambd = score_lambd * (log_p - log_q)
stoch_score_gamma = score_gamma * (log_p - log_q)
G_lambd += np.square(stoch_score_lambd)
rho_lambd = lr_init / (G_lambd + eps)**(-0.5)
lambd = lambd + rho_lambd * stoch_score_lambd
np.exp(lambd), stoch_score_lambd, rho_lambd

(array([[1.07748094e+24, 1.27587924e+01, 1.15010154e-20, 5.61388630e-05,
         1.27228984e+00, 5.71844231e-01, 9.55263217e-06, 2.80483595e-04,
         1.63312099e-07, 8.86346757e-06],
        [5.53298420e+01, 6.08643884e-01, 1.99974169e-28, 4.84376568e-02,
         1.22059160e+01, 4.52764274e-01, 2.01187704e-06, 4.71117697e-02,
         3.19019664e+30, 8.93364658e-07]]),
 array([[ 236.8992711 ,   54.33566478, -213.38751544,  -93.19240876,
           28.88081906,   15.03430508, -106.36226672,  -85.43174572,
         -120.72932856, -107.66337979],
        [  66.16467866,   23.24345222, -252.11803762,  -46.23972901,
           56.97321086,  -10.49094126, -113.48000884,  -50.84490704,
          266.52380005, -115.92462659]]),
 array([[0.23689927, 0.05433566, 0.21338752, 0.09319241, 0.02888082,
         0.01503431, 0.10636227, 0.08543175, 0.12072933, 0.10766338],
        [0.06616468, 0.02324345, 0.25211804, 0.04623973, 0.05697321,
         0.01049094, 0.11348001, 0.05084491, 0.2665238 ,

In [43]:
lambd = np.log(rs.gamma(1., 1./5., size=(K, V)))
gamma = np.log(rs.gamma(1., 1./5., size=(N, K)))
G_lambd = np.zeros_like(lambd)
G_gamma = np.zeros_like(gamma)

beta, theta = sample_params((lambd, gamma))
log_p_beta = 0.0
for k in range(K):
    beta_k = beta[k]
    log_p_beta += loggamma(V * eta0) - V * loggamma(eta0) + (eta0 - 1) * np.sum(np.log(beta_k))
log_p_theta = 0.0
for i in range(N):
    theta_i = theta[i]
    log_p_theta += loggamma(K * alpha0) - K * loggamma(alpha0) + (alpha0 - 1) * np.sum(np.log(theta_i))
log_p_x = 0.0
for i, x_i in enumerate(x):
    for x_ij in x_i:
        log_p_x += np.log(np.sum(theta[i] * beta[:, x_ij]))
log_q_beta = 0.0
for k in range(K):
    lambd_k = np.exp(lambd[k])
    beta_k = beta[k]
    log_q_beta += loggamma(np.sum(lambd_k)) - np.sum(loggamma(lambd_k)) + np.sum((lambd_k - 1) * np.log(beta_k))
log_q_theta = 0.0
for i in range(N):
    gamma_i = np.exp(gamma[i])
    theta_i = theta[i]
    log_q_theta += loggamma(np.sum(gamma_i)) - np.sum(loggamma(gamma_i)) + np.sum((gamma_i - 1) * np.log(theta_i))

log_p = (log_p_beta + log_p_theta + log_p_x)
log_q = (log_q_beta + log_q_theta)
score_lambd = np.zeros_like(lambd)
for k in range(K):
    lambd_k = np.exp(lambd[k])
    beta_k = beta[k]
    digam_sum_lambd_k = digamma(np.sum(lambd_k))
    for v in range(V):
        score_lambd[k, v] = lambd_k[v] * ( digam_sum_lambd_k - digamma(lambd_k[v]) + np.log(beta_k[v]) )
score_gamma = np.zeros_like(gamma)
for i in range(N):
    gamma_i = np.exp(gamma[i])
    theta_i = theta[i]
    digam_sum_gamma_i = digamma(np.sum(gamma_i))
    for k in range(K):
        score_gamma[i, k] = gamma_i[k] * ( digam_sum_gamma_i - digamma(gamma_i[k]) + np.log(gamma_i[k]) )
stoch_score_lambd = score_lambd * (log_p - log_q)
stoch_score_gamma = score_gamma * (log_p - log_q)
G_lambd += np.square(stoch_score_lambd)
rho_lambd = lr_init / (G_lambd + eps)**(-0.5)
lambd = lambd + rho_lambd * stoch_score_lambd
lambd, stoch_score_lambd, rho_lambd

(array([[ 1.11206666e+04, -1.44586752e+03, -7.48153502e+02,
         -1.41260245e+02, -5.34881039e+01,  1.76485466e+00,
          5.85850063e+02,  6.49442167e+02, -1.80389301e+02,
          2.71341985e-01],
        [ 4.04755944e+02, -4.28045336e+02, -1.85844975e+02,
         -8.84703726e+02, -6.50546851e+02, -1.01032641e+03,
          4.00166533e+01,  2.91441018e+03,  5.13105989e+03,
          5.01019778e+02]]),
 array([[ 3335.20755651, -1202.22741002,  -864.63095286,  -374.33169354,
          -227.75890606,    51.42886826,   767.1496872 ,   806.48309001,
          -422.77787901,    78.73293313],
        [  638.13591099,  -651.61752759,  -430.84082526,  -939.50238497,
          -805.33431856, -1004.35897705,   205.92006648,  1708.19653726,
          2265.41897963,   710.06149571]]),
 array([[3.33520756, 1.20222741, 0.86463095, 0.37433169, 0.22775891,
         0.05142887, 0.76714969, 0.80648309, 0.42277788, 0.07873293],
        [0.63813591, 0.65161753, 0.43084083, 0.93950238, 0.80533432

In [52]:
V = 100
K = 3
N = 10
Ms = np.array([55, 50, 65])
eta0 = 1 / K
alpha0 = 1 / K
lr_init = 1e-3
eps = 1e-8

def generate_lda(N, V, K, Ms, eta0=(1/K), alpha0=(1/K), rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    eta0_vec = np.ones(V) * eta0
    alpha0_vec = np.ones(K) * alpha0
    beta = rs.dirichlet(eta0_vec, size=K)
    theta = rs.dirichlet(alpha0_vec, size=N)
    x = []
    for i, M_i in enumerate(Ms):
        z_i = rs.multinomial(1, theta[i], size=M_i).argmax(axis=1)
        x_i = np.array([rs.multinomial(1, beta[z_ij]).argmax() for z_ij in z_i])
        x.append(x_i)
    return x
x = generate_lda(N=N, V=V, K=K, Ms=Ms, eta0=eta0, alpha0=alpha0, rs_int=0)

def init_var_params(K, V, N, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    log_lambd = np.log(rs.uniform(low=0.25, high=0.75, size=(K, V)))
    log_gamma = np.log(np.ones(shape=(N, K)))
    return log_lambd, log_gamma
log_lambd, log_gamma = init_var_params(K=K, V=V, N=N, rs_int=0)
G_lambd = np.zeros_like(log_lambd)
G_gamma = np.zeros_like(log_gamma)

def sample_params(var_params, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    log_lambd, log_gamma = var_params
    beta = np.zeros_like(log_lambd)
    theta = np.zeros_like(log_gamma)
    for k in range(K):
        beta[k] = rs.dirichlet(np.exp(log_lambd[k]))
    for i in range(N):
        theta[i] = rs.dirichlet(np.exp(log_gamma[i]))
    return beta, theta

def log_joint_prob(latent_params, x):
    beta, theta = latent_params
    log_p_beta = 0.0
    for k in range(K):
        beta_k = beta[k]
        log_p_beta += loggamma(V * eta0) - V * loggamma(eta0) + (eta0 - 1) * np.sum(np.log(beta_k))
    log_p_theta = 0.0
    for i in range(N):
        theta_i = theta[i]
        log_p_theta += loggamma(K * alpha0) - K * loggamma(alpha0) + (alpha0 - 1) * np.sum(np.log(theta_i))
    log_p_x = 0.0
    for i, x_i in enumerate(x):
        for x_ij in x_i:
            log_p_x += np.log(np.sum(theta[i] * beta[:, x_ij]))
    return log_p_beta + log_p_theta + log_p_x

def log_var_approx(var_params, latent_params):
    log_lambd, log_gamma = var_params
    beta, theta = latent_params
    log_q_beta = 0.0
    for k in range(K):
        lambd_k = np.exp(log_lambd[k])
        beta_k = beta[k]
        log_q_beta += loggamma(np.sum(lambd_k)) - np.sum(loggamma(lambd_k)) + np.sum((lambd_k - 1) * np.log(beta_k))
    log_q_theta = 0.0
    for i in range(N):
        gamma_i = np.exp(log_gamma[i])
        theta_i = theta[i]
        log_q_theta += loggamma(np.sum(gamma_i)) - np.sum(loggamma(gamma_i)) + np.sum((gamma_i - 1) * np.log(theta_i))
    return log_q_beta + log_q_theta

S = 10
for t in range(100000):
    betas, thetas = [], []
    for _ in range(S):
        beta_s, theta_s = sample_params((log_lambd, log_gamma))
        betas.append(beta_s)
        thetas.append(theta_s)

    stoch_score_lambd = np.zeros_like(log_lambd)
    stoch_score_gamma = np.zeros_like(log_gamma)
    for s in tqdm(range(S)):
        beta, theta = betas[s], thetas[s]
        log_p = log_joint_prob((beta, theta), x)
        log_q = log_var_approx((log_lambd, log_gamma), (beta, theta))

        score_lambd = np.zeros_like(log_lambd)
        for k in range(K):
            lambd_k = np.exp(log_lambd[k])
            beta_k = beta[k]
            digam_sum_lambd_k = digamma(np.sum(lambd_k))
            for v in range(V):
                score_lambd[k, v] = lambd_k[v] * ( digam_sum_lambd_k - digamma(lambd_k[v]) + np.log(beta_k[v]) )

        score_gamma = np.zeros_like(log_gamma)
        for i in range(N):
            gamma_i = np.exp(log_gamma[i])
            theta_i = theta[i]
            digam_sum_gamma_i = digamma(np.sum(gamma_i))
            for k in range(K):
                score_gamma[i, k] = gamma_i[k] * ( digam_sum_gamma_i - digamma(gamma_i[k]) + np.log(theta_i[k]) )
        stoch_score_lambd += score_lambd * (log_p - log_q)
        stoch_score_gamma += score_gamma * (log_p - log_q)
    stoch_score_lambd /= S
    stoch_score_gamma /= S

    G_lambd += stoch_score_lambd**2
    rho_lambd = lr_init / (np.sqrt(G_lambd) + eps)
    log_lambd = log_lambd + rho_lambd * stoch_score_lambd

    G_gamma += stoch_score_gamma**2
    rho_gamma = lr_init / (np.sqrt(G_gamma) + eps)
    log_gamma = log_gamma + rho_gamma * stoch_score_gamma

100%|██████████| 10/10 [00:00<00:00, 392.40it/s]
100%|██████████| 10/10 [00:00<00:00, 480.01it/s]
100%|██████████| 10/10 [00:00<00:00, 492.97it/s]
100%|██████████| 10/10 [00:00<00:00, 477.10it/s]
100%|██████████| 10/10 [00:00<00:00, 541.53it/s]
100%|██████████| 10/10 [00:00<00:00, 460.88it/s]
100%|██████████| 10/10 [00:00<00:00, 430.48it/s]
100%|██████████| 10/10 [00:00<00:00, 510.03it/s]
100%|██████████| 10/10 [00:00<00:00, 539.04it/s]
100%|██████████| 10/10 [00:00<00:00, 518.73it/s]
100%|██████████| 10/10 [00:00<00:00, 527.13it/s]
100%|██████████| 10/10 [00:00<00:00, 537.15it/s]
100%|██████████| 10/10 [00:00<00:00, 489.10it/s]
100%|██████████| 10/10 [00:00<00:00, 443.28it/s]
100%|██████████| 10/10 [00:00<00:00, 476.25it/s]
100%|██████████| 10/10 [00:00<00:00, 430.10it/s]
100%|██████████| 10/10 [00:00<00:00, 466.54it/s]
100%|██████████| 10/10 [00:00<00:00, 445.40it/s]
100%|██████████| 10/10 [00:00<00:00, 464.98it/s]
100%|██████████| 10/10 [00:00<00:00, 443.57it/s]
100%|██████████| 10/

In [628]:
mode(np.exp(log_lambd), axis=1)[0], mode(np.exp(log_gamma), axis=1)[0]

(array([0.06729349, 0.08094639, 0.09667891]),
 array([0.93326689, 0.62713365, 0.45124123, 0.60129675, 0.4641892 ,
        0.39540843, 0.57982581, 0.46793802, 0.64269007, 0.86167202]))

In [122]:
def init_var_params(X, K, V, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    N = len(X)
    log_lambd = np.log(rs.uniform(low=0.3, high=1.0, size=(K, V)))
    log_gamma = np.log(rs.uniform(low=0.3, high=1.0, size=(N, K)))
    return log_lambd, log_gamma

rs = npr.RandomState(0)
K, V, N = 10, 300, 30
eta0, alpha0 = 1 / K, 1 / K
Ms = rs.poisson(60, size=N)
S = 100
lr = 1e-4
eps = 1e-6

X = generate_lda(K, V, N, Ms, eta0=eta0, alpha0=alpha0)
lambd, gamma = init_var_params(X, K, V)
G_lambd, G_gamma = np.zeros_like(lambd), np.zeros_like(gamma)
max_iters = 21

for t in range(max_iters):
    betas = []
    thetas = []
    for _ in range(S):
        beta_s, theta_s = sample_params((lambd, gamma))
        betas.append(beta_s)
        thetas.append(theta_s)

    stoch_score_grad_lambd = np.zeros_like(lambd)
    stoch_score_grad_gamma = np.zeros_like(gamma)
    for s in tqdm(range(S), desc=f"Iteration {t} | Calculating stochastic score gradient"):
        score_lambd, score_gamma = score_var_dist((lambd, gamma), (betas[s], thetas[s]))
        log_p = log_joint_prob((betas[s], thetas[s]), X)
        log_q = log_var_approx((lambd, gamma), (betas[s], thetas[s]))
        stoch_score_grad_lambd += score_lambd * (log_p - log_q)
        stoch_score_grad_gamma += score_gamma * (log_p - log_q)
    stoch_score_grad_lambd /= S
    stoch_score_grad_gamma /= S

    stoch_score_grad_lambd = np.clip(stoch_score_grad_lambd, -100, 100)
    stoch_score_grad_gamma = np.clip(stoch_score_grad_gamma, -100, 100)

    G_lambd += np.power(stoch_score_grad_lambd, 2)
    lambd = lambd + np.multiply((lr / np.power(G_lambd, 0.5)), stoch_score_grad_lambd)

    G_gamma += np.power(stoch_score_grad_gamma, 2)
    gamma = gamma + np.multiply((lr / np.power(G_gamma, 0.5)), stoch_score_grad_gamma)

    if t % 5 == 0:
        print(f"\nMC ELBO: {estimate_ELBO((lambd, gamma), X, S)}\n")

Iteration 0 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 186.61it/s]



MC ELBO: -13229.900704645395



Iteration 1 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 226.77it/s]
Iteration 2 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 246.59it/s]
Iteration 3 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 244.05it/s]
Iteration 4 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 233.75it/s]
Iteration 5 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 216.58it/s]



MC ELBO: -13238.28590415385



Iteration 6 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 236.33it/s]
Iteration 7 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 237.80it/s]
Iteration 8 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 236.84it/s]
Iteration 9 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 216.20it/s]
Iteration 10 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 244.26it/s]



MC ELBO: -13227.803020040255



Iteration 11 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 239.71it/s]
Iteration 12 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 247.90it/s]
Iteration 13 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 236.14it/s]
Iteration 14 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 221.52it/s]
Iteration 15 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 242.85it/s]



MC ELBO: -13227.100151368619



Iteration 16 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 241.89it/s]
Iteration 17 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 242.66it/s]
Iteration 18 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 235.29it/s]
Iteration 19 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 243.60it/s]
Iteration 20 | Calculating stochastic score gradient: 100%|██████████| 100/100 [00:00<00:00, 246.78it/s]



MC ELBO: -13230.373878964008



In [3]:
def digamma_vec(x):
    return digamma(x)

def log_joint_prob(X, zeta, theta, K, V, eta0, alpha0):
    N = len(X)
    log_prob = 0.0

    for k in range(K):
        log_prob += dir.logpdf(zeta[k], np.ones(V)*eta0)
    
    for i in range(N):
        log_prob += dir.logpdf(theta[i], np.ones(K)*alpha0)
        for j in range(len(X[i])):
            word_idx = X[i][j]
            word_probs = np.zeros(K)
            for k in range(K):
                word_probs[k] = theta[i][k] * zeta[k][word_idx]
            log_prob += np.log(np.sum(word_probs))
    return log_prob

def log_q_density(zeta, theta, lambd, gamma):
    K, V = lambd.shape
    N = len(gamma)
    log_q = 0.0

    for k in range(K):
        log_q += dir.logpdf(zeta[k], lambd[k])
    
    for i in range(N):
        log_q += dir.logpdf(theta[i], gamma[i])
    return log_q