In [116]:
import autograd.numpy as np
import autograd.numpy.random as npr
import autograd.scipy.stats.dirichlet as dir
from scipy.stats import multinomial as multinom

In [78]:
def simulate_LDA(K, V, N, M, ETA=0.1, ALPHA=0.5, rs_int=npr.randint(low=0, high=100)):
    rs = npr.RandomState(rs_int)
    BETA = rs.dirichlet(np.full(V, ETA), size=K)
    THETA = rs.dirichlet(np.full(K, ALPHA), size=N)
    X = []
    for i in range(N):
        x_i = np.zeros(M[i], dtype=int)
        for j in range(M[i]):
            z_ij = rs.choice(K, p=THETA[i])
            x_ij = rs.choice(V, p=BETA[z_ij])
            x_i[j] = x_ij
        X.append(x_i)
    return X

def init_variational_params(X, K, V):
    N = len(X)
    M = np.array([len(x_i) for x_i in X])
    LAMBDA = npr.uniform(low=0.01, high=1.0, size=(K, V))
    GAMMA = np.ones((N, K))
    PHI = []
    for M_i in M:
        PHI_i = np.ones((M_i, K))
        PHI_i = PHI_i / K
        PHI.append(PHI_i)
    return (LAMBDA, GAMMA, PHI)

def unpack_params(params):
    LAMBDA, GAMMA, PHI = params[0], params[1], params[2]
    return LAMBDA, GAMMA, PHI

def sample_variational_params(variational_params, num_samples=100):
    LAMBDA, GAMMA, PHI = unpack_params(variational_params)
    K, V = LAMBDA.shape
    N = GAMMA.shape[0]
    BETAs = []
    THETAs = []
    Zs = []

    for _ in range(num_samples):
        BETA = np.zeros((K, V))
        for k in range(K):
            BETA[k] = npr.dirichlet(LAMBDA[k])

        THETA = np.zeros((N, K))
        for i in range(N):
            THETA[i] = npr.dirichlet(GAMMA[i])

        Z = []
        for i in range(N):
            M_i = len(PHI[i])
            z_i = np.zeros(M_i, dtype=int)
            for j in range(M_i):
                z_i[j] = npr.choice(K, p=PHI[i][j])
            Z.append(z_i)
        BETAs.append(BETA)
        THETAs.append(THETA)
        Zs.append(Z)
    return BETAs, THETAs, Zs

rs = npr.RandomState(0)
K = 10
V = 500
N = 50
M = rs.poisson(75, size=N)
X = simulate_LDA(K, V, N, M)
init_var_params = init_variational_params(X, K, V)

np.int64(7)

In [None]:
LAMBDA, GAMMA, PHI = unpack_params(init_var_params)
BETA, THETA, Z = sample_variational_params(init_var_params, 1)
BETA = BETA[0]
THETA = THETA[0]
Z = Z[0]

np.sum([dir.logpdf(BETA[k], LAMBDA[k]) for k in range(K)]) + np.sum([dir.logpdf(THETA[i], GAMMA[i]) for i in range(N)])
multinom.logpmf(Z[0], np.sum(Z[0]), p=PHI[0, 0])

TypeError: list indices must be integers or slices, not tuple