In [None]:
import numpy as np

### Setup

In [None]:
# Number of latent topics (here: sport and astronomy).
K: int = 2

# Maps every vocabulary word to its integer index in the topic-word distributions.
vocabulary: dict = {
    'ball': 0,
    'planet': 1,
    'star': 2,
    'tennis': 3,
    'basketball': 4,
    'football': 5,
    'sun': 6,
    'moon': 7,
    'earth': 8,
    'run': 9,
}

V: int = len(vocabulary)  # Number of words in the vocabulary

# Toy corpus: each document is a list of vocabulary words.
documents: list = [
    ['ball', 'tennis', 'basketball', 'football', 'star'],
    ['planet', 'star', 'sun', 'moon', 'earth', 'ball'],
    ['ball', 'tennis', 'basketball', 'planet', 'football'],
    ['star', 'sun', 'moon', 'earth', 'ball', 'run'],
    ['ball', 'tennis', 'basketball', 'planet', 'earth'],
]

M: int = len(documents)  # Number of documents

N: int = sum(len(doc) for doc in documents)  # Total number of words in all documents

# Dirichlet prior for the document-topic distribution: one entry per topic.
# Sized from K so the prior cannot drift out of sync when K is changed.
alpha = np.ones(K, dtype=int)

# Dirichlet prior for the topic-word distribution: one entry per vocabulary word.
beta = np.full(V, 0.01)

### Algorithm

In [None]:
'''


Algorithm

	1. Choose θ_i ~ Dir(α), where i∈{1,…,M} and Dir(α) is a Dirichlet distribution with a symmetric parameter α, which is typically sparse (α<1)

	2. Choose ϕ_k  ~ Dir(β), where k∈{1,…,K} and β typically is sparse

	3. For each of the word positions i,j, where i∈{1,…,M}, and j∈{1,…,N_i }
	
		a. Choose a topic z_(i,j)  ~ Multinomial(θ_i)
		
		b. Choose a word w_(i,j) ~ Multinomial(ϕ_(z_(i,j)))

Note that multinomial distribution here refers to the multinomial with only one trial, which is also known as the categorical distribution.

'''

def lda(docs=None, num_topics=None, vocab_size=None,
        doc_topic_prior=None, topic_word_prior=None, rng=None):
    """Draw one sample from the LDA generative process.

    Every argument is optional; an omitted argument falls back to the
    corresponding module-level value, so a bare ``lda()`` call keeps working
    on the module's own corpus and priors.

    Args:
        docs: Sequence of documents; only the number of documents and the
            length of each one are used. Defaults to ``documents``.
        num_topics: Number of topics. Defaults to ``K``.
        vocab_size: Number of vocabulary words. Defaults to ``V``.
        doc_topic_prior: Dirichlet parameter of the per-document topic
            distribution (length ``num_topics``). Defaults to ``alpha``.
        topic_word_prior: Dirichlet parameter of the per-topic word
            distribution (length ``vocab_size``). Defaults to ``beta``.
        rng: Object providing ``dirichlet`` and ``choice`` (for example a
            ``numpy.random.Generator``). Defaults to the global ``np.random``
            state.

    Returns:
        Tuple ``(theta, phi, z, w)``:
        ``theta`` has one topic distribution per document,
        ``phi`` has one word distribution per topic,
        ``z[i, j]`` is the topic sampled for position ``j`` of document ``i``,
        ``w[i, j]`` is the word index sampled for that position.
        ``z`` and ``w`` are zero-padded up to the longest document, so padded
        cells cannot be told apart from index 0; use the document lengths to
        know which cells are valid.
    """
    # Resolve defaults at call time so the module-level setup may be redefined
    # (e.g. by re-running a notebook cell) without redefining this function.
    if docs is None:
        docs = documents
    if num_topics is None:
        num_topics = K
    if vocab_size is None:
        vocab_size = V
    if doc_topic_prior is None:
        doc_topic_prior = alpha
    if topic_word_prior is None:
        topic_word_prior = beta
    sampler = np.random if rng is None else rng

    n_docs = len(docs)
    theta = sampler.dirichlet(doc_topic_prior, n_docs)
    phi = sampler.dirichlet(topic_word_prior, num_topics)

    # default=0 keeps an empty corpus from raising inside max().
    longest = max((len(doc) for doc in docs), default=0)
    z = np.zeros((n_docs, longest), dtype=int)
    w = np.zeros((n_docs, longest), dtype=int)

    for i, doc in enumerate(docs):
        for j in range(len(doc)):
            # Topic first, then a word from that topic's distribution.
            z[i, j] = sampler.choice(num_topics, p=theta[i])
            w[i, j] = sampler.choice(vocab_size, p=phi[z[i, j]])

    return theta, phi, z, w

theta, phi, z, w = lda()

# Report, document by document, the topic that was sampled for each word position.
# Zipping against the document truncates the padded topic row to the real length.
for doc_number, (doc, topic_row) in enumerate(zip(documents, z), start=1):
    print(f"Document {doc_number}:")
    for word, topic in zip(doc, topic_row):
        print(f"Word: {word}, Topic: {topic}")
    print()