In [110]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import dirichlet, norm, uniform
import seaborn as sns
import itertools
# Select seaborn's "darkgrid" plotting style.
sns.set_style("darkgrid")
# Seed NumPy's global random number generator, which the np.random.* calls in
# this notebook draw from.
np.random.seed(seed=1)

### Data and Model

Let $\theta_1,\dots,\theta_K\in\Delta^{V-1}$ denote $K$ topics which define the topic simplex $\mathcal{P}=\text{conv}(\theta_1,\dots,\theta_K)$. For each document $\boldsymbol{w}_i = (X_{i1},\dots,X_{iN})$, we first assume that its admixture proportion of topics is generated as 
\begin{equation}
    \beta_i = (\beta_{ik})_{k=1}^{K} \overset{\text{iid}}{\sim} \text{Dir}_K(\gamma) ,\quad i\in [D].
\end{equation}
For each word $X_{ij}$, the latent variable $z_{ij}$ that determines the topic of $X_{ij}$ is modeled as \begin{equation}
    z_{ij}|\beta_i \overset{\text{iid}}{\sim}  \text{Multinomial}(\beta_i),\quad i\in [D], j\in [N].
\end{equation}
Finally
\begin{equation}
    X_{ij}  | z_{ij} = k \sim \text{Multinomial}(\theta_{k}),\quad i\in [D], j \in [N], v\in [V], k\in [K].
\end{equation}
The documents are independent, and the parameters of interest are $\theta_1,\dots,\theta_K$ (the topics) and $\gamma$ (the mixing parameter). We assume $K$ is fixed and known.

We observe $(X_{ij})_{i\in [D], j\in [N]}$ and want to make inference about $(\theta_k)_{k\in [K]}$ and $(\beta_{i})_{i\in [D]}$.

### Index: 

$D$: Number of documents, $i=1,\dots, D$

$N$: Number of words in each document, $j=1,\dots, N$

$K$: Number of topics, $k = 1,\dots, K$

$V$: Number of words in the vocab, $v=1,\dots, V$


### Prior:
We place the prior $\theta_k \sim \text{Dir}_V(\alpha)$ independently for $k = 1,\dots, K$, and $\gamma \sim U(0, 10)$.

### Inference of posterior using Gibbs MCMC:

Because of the conjugacy of Dirichlet prior with respect to multinomial likelihood, the posterior is also Dirichlet. 
For each $k \in [K]$
\begin{equation}
    \theta_{k}| (X_{ij}, z_{ij})\sim \text{Dir}(\alpha + n_{k1}, \dots, \alpha + n_{kV}),
\end{equation}
for $n_{kv} = \#\{(i,j): (X_{ij} = v) \& (z_{ij} = k) \}\forall v\in [V]$. Similarly, for each $i\in [D],$
\begin{equation}
    \beta_i | (z_{ij}), \gamma \sim \text{Dir}(\gamma + m_{i1}, \dots, \gamma + m_{iK})
\end{equation}
for $m_{ik} = \#\{j : z_{ij} = k\}$. In addition, the update rule for $(z_{ij})$ can be obtained from Bayes' rule:
\begin{equation}
    P(z_{ij} = k | X_{ij} = v, (\theta_k), (\beta_i)) = \dfrac{\theta_{kv} \beta_{ik}}{\sum_{k'\in [K]} \theta_{k'v} \beta_{ik'}}
\end{equation}
Finally, we can update $\gamma$ conditional on $(\beta_i)_{i\in [D]}$ using a Metropolis-Hastings step.

In [20]:
def gen_prior(D, N, V, K, α=1, u=10.0):
    """
    Draw an initial state for the Gibbs sampler from the prior.

    θ_k ~ Dir_V(α) independently for k = 1..K, γ ~ U(0, u),
    β_i | γ ~ Dir_K(γ) for each document i, and z_ij | β_i ~ Categorical(β_i).

    Parameters
    ----------
    D, N, V, K : int
        Number of documents, words per document, vocabulary size, and topics.
    α : float
        Symmetric Dirichlet concentration of the prior on each topic.
    u : float
        Upper bound of the uniform prior on γ.

    Returns
    -------
    θ : (K, V) array of topics, β : (D, K) array of admixture proportions,
    z : (D, N) integer array of topic labels, γ : float.
    """
    θ = np.random.dirichlet([α] * V, size=K)
    γ = np.random.uniform(low=0, high=u)
    β = np.random.dirichlet([γ] * K, size=D)
    # Topic labels are indices into the K topics, so store them as integers.
    z = np.zeros((D, N), dtype=int)
    for i in range(D):
        z[i] = np.random.choice(a=K, size=N, p=β[i])
    return θ, β, z, γ

def compare(θ, θprime, K):
    """
    Distance between two sets of K topics, ignoring how the topics are labelled.

    Every permutation of the K rows of θprime is tried, and the smallest
    norm (np.linalg.norm) of θ minus the permuted θprime is returned.
    """
    return min(
        np.linalg.norm(θ - θprime[list(row_order)])
        for row_order in itertools.permutations(range(K))
    )

In [129]:
def sample_theta(X, z, α, n_topics=None, vocab_size=None):
    """
    Gibbs update for the topics θ given the words and their topic labels.

    Parameters
    ----------
    X : (D, N) array of word ids (values 0..V-1).
    z : (D, N) array of topic labels (values 0..K-1), same shape as X.
    α : float, symmetric Dirichlet concentration of the prior on each topic.
    n_topics, vocab_size : int, optional
        K and V. When omitted, the notebook-level variables K and V are used,
        so existing calls sample_theta(X, z, α) keep working.

    Returns
    -------
    θ : (K, V) array; row k is drawn from Dir(α + n_k1, ..., α + n_kV), where
    n_kv counts the positions with X == v and z == k.
    """
    n_topics = K if n_topics is None else n_topics
    vocab_size = V if vocab_size is None else vocab_size

    word_ids = np.asarray(X).ravel()
    topic_ids = np.asarray(z).ravel()
    # One-hot encode both label arrays; their product is the K x V count table.
    # Entries that match no topic / no word contribute nothing, exactly as in
    # an elementwise (X == v) & (z == k) count.
    topic_onehot = (topic_ids[:, None] == np.arange(n_topics)).astype(float)
    word_onehot = (word_ids[:, None] == np.arange(vocab_size)).astype(float)
    n = topic_onehot.T @ word_onehot  # shape (K, V)

    θ = np.zeros((n_topics, vocab_size))
    for k in range(n_topics):
        θ[k] = dirichlet.rvs(n[k] + α)[0]
    return θ

def sample_beta(z, γ, n_topics=None):
    """
    Gibbs update for the admixture proportions β given the topic labels.

    Parameters
    ----------
    z : (D, N) array of topic labels (values 0..K-1).
    γ : float, symmetric Dirichlet concentration of the prior on each β_i.
    n_topics : int, optional
        K. When omitted, the notebook-level variable K is used, so existing
        calls sample_beta(z, γ) keep working.

    Returns
    -------
    β : (D, K) array; row i is drawn from Dir(γ + m_i1, ..., γ + m_iK), where
    m_ik is the number of words of document i labelled with topic k.
    """
    n_topics = K if n_topics is None else n_topics
    z = np.asarray(z)
    # The number of documents is read from z rather than from a global D.
    n_docs = z.shape[0]

    # m[i, k] = #{j : z[i, j] == k}, computed for all (i, k) at once.
    m = (z[:, :, None] == np.arange(n_topics)).sum(axis=1)

    β = np.zeros((n_docs, n_topics))
    for i in range(n_docs):
        β[i] = dirichlet.rvs(m[i] + γ)[0]
    return β

def sample_z(θ, β, X):
    """
    Gibbs update for the topic labels z given the topics, proportions and words.

    For every word, P(z_ij = k | X_ij = v, θ, β) is proportional to
    θ[k, v] * β[i, k]; one label is drawn per word from that distribution.

    Parameters
    ----------
    θ : (K, V) array of topics.
    β : (D, K) array of admixture proportions.
    X : (D, N) array of word ids (values 0..V-1); cast to int for indexing.

    Returns
    -------
    z : (D, N) integer array of topic labels in 0..K-1.

    Raises
    ------
    ValueError
        If the weights of some word do not have a positive, finite sum, so
        that no distribution over topics can be formed.
    """
    θ = np.asarray(θ, dtype=float)
    β = np.asarray(β, dtype=float)
    word_ids = np.asarray(X).astype(int)
    n_topics = θ.shape[0]

    # weights[i, j, k] = θ[k, X[i, j]] * β[i, k]; the sizes D, N, K all come
    # from the arguments rather than from notebook-level variables.
    weights = θ.T[word_ids] * β[:, None, :]
    cdf = np.cumsum(weights, axis=-1)
    total = cdf[..., -1]
    if not np.all(np.isfinite(total)) or np.any(total <= 0):
        raise ValueError("topic weights must have a positive, finite sum for every word")

    # Inverse-CDF sampling for all words at once: the label is the number of
    # cumulative weights that are <= the uniform draw, so zero-weight topics
    # are never selected. This also avoids storing a size-1 array into a
    # scalar slot, which NumPy has deprecated.
    draws = np.random.uniform(size=total.shape) * total
    z = (draws[..., None] >= cdf).sum(axis=-1)
    # Guard against floating-point rounding pushing a draw up to the total.
    return np.minimum(z, n_topics - 1)

def sample_gamma(γ, β, σ=0.1, u=10.0, verbose=True):
    """
    One Metropolis-Hastings update of γ given the admixture proportions β.

    The proposal is a Normal random walk centred at the current γ with
    standard deviation σ. The prior on γ is U(0, u), so proposals outside
    (0, u) are rejected outright, and inside the interval the acceptance
    ratio reduces to the ratio of Dirichlet likelihoods of the rows of β.

    Parameters
    ----------
    γ : float, current value.
    β : (D, K) array; each row is a point on the simplex.
    σ : float, standard deviation of the random-walk proposal.
    u : float, upper bound of the uniform prior on γ.
    verbose : bool, print the value of γ after the update.

    Returns
    -------
    The new value of γ (unchanged when the proposal is rejected).
    """
    β = np.asarray(β)
    # The number of topics is the row length of β, not a notebook-level K.
    n_topics = β.shape[1]

    γ_prime = np.random.normal(γ, σ)
    if 0 < γ_prime < u:
        # Work with log densities so the ratio neither underflows nor overflows.
        log_r = sum(
            dirichlet.logpdf(row, [γ_prime] * n_topics)
            - dirichlet.logpdf(row, [γ] * n_topics)
            for row in β
        )
        if np.log(np.random.uniform()) < log_r:
            γ = γ_prime
    if verbose:
        print('gamma is {}'.format(γ))
    return γ

### Toy example: 

We will generate 

$D = 100$ documents

each document has $N = 20$ words 

vocabulary of $V = 10$ words

number of topics $K = 3$

In [142]:
def generate_data(D, N, V, K, α=1, γ=0.5):
    """
    Simulate a corpus from the topic model.

    Topics are drawn as θ_k ~ Dir_V(α) and admixture proportions as
    β_i ~ Dir_K(γ). The words of document i are then drawn i.i.d. from the
    mixture η_i = β_i @ θ, which is the word distribution obtained after
    summing out the per-word topic labels.

    Parameters
    ----------
    D, N, V, K : int
        Number of documents, words per document, vocabulary size, and topics.
    α : float, Dirichlet concentration used to draw the topics.
    γ : float, Dirichlet concentration used to draw the admixture proportions
        (the "true" mixing parameter of the simulated data).

    Returns
    -------
    X : (D, N) integer array of word ids in 0..V-1,
    θ : (K, V) array of topics, β : (D, K) array of admixture proportions.
    """
    # Word ids are indices into the vocabulary, so store them as integers.
    X = np.zeros((D, N), dtype=int)
    θ = np.random.dirichlet([α] * V, size=K)
    β = np.random.dirichlet([γ] * K, size=D)
    for i in range(D):
        η = β[i] @ θ
        X[i] = np.random.choice(a=V, size=N, p=η)
    return X, θ, β

In [143]:
# Sizes of the toy problem: documents, words per document, vocabulary, topics.
D, N, V, K = 100, 20, 10, 3
# Simulate the corpus and keep the generating topics / proportions for comparison.
X, true_θ, true_β = generate_data(D, N, V, K, α=1)
X.shape

(100, 20)

## Inference

In [144]:

# Dirichlet concentration of the prior on each topic; also passed to
# sample_theta in the sampling loop.
α = 1
# Start the sampler from one draw of (θ, β, z, γ) from the prior.
θ, β, z, γ = gen_prior(D, N, V, K, α=α, u=10.0)

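Run the Gibbs sampler, printing the current $\gamma$ and the error of $\theta$ (minimised over permutations of the topics) after every sweep.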

In [145]:
# Each sweep updates θ, β, z and γ in turn, each conditional on the current
# values of the others, then reports the distance of θ to the true topics.
n_sweeps = 1000
for _sweep in range(n_sweeps):
    θ = sample_theta(X, z, α)
    β = sample_beta(z, γ)
    z = sample_z(θ, β, X)
    γ = sample_gamma(γ, β, σ=0.1)
    θ_error = compare(θ, true_θ, K)
    print('error of theta is {}'.format(θ_error))

gamma is 3.0653624138439306
error of theta is 0.29912562256663594
gamma is 3.0366519694431293
error of theta is 0.2743835154419978
gamma is 3.0359044350270423
error of theta is 0.29152208072476365
gamma is 3.0359044350270423
error of theta is 0.2963726359369583
gamma is 3.1496637856245444
error of theta is 0.290995716436802
gamma is 3.187654796156647
error of theta is 0.26231537458530635
gamma is 3.3421772193383448
error of theta is 0.25089940462397814
gamma is 3.3955640842709482
error of theta is 0.23284951086151287
gamma is 3.4503660410715025
error of theta is 0.21170070433573626
gamma is 3.3820475543484543
error of theta is 0.23075021470824195
gamma is 3.4077341861255075
error of theta is 0.20812307308130398
gamma is 3.377129306815989
error of theta is 0.1932488392594845
gamma is 3.335509310591176
error of theta is 0.18903822490731736
gamma is 3.179424764120106
error of theta is 0.1781594726607614
gamma is 3.1376679566761654
error of theta is 0.21345554766849792
gamma is 3.121966833

gamma is 0.9221936938390664
error of theta is 0.1683265330525042
gamma is 0.9600355843361975
error of theta is 0.19536379159741052
gamma is 0.8679103111150316
error of theta is 0.18678997395619792
gamma is 0.9007679610550147
error of theta is 0.1979538997752873
gamma is 0.9007679610550147
error of theta is 0.195010231758175
gamma is 0.9007679610550147
error of theta is 0.18959708795425245
gamma is 0.9331072317732909
error of theta is 0.20087859899080152
gamma is 1.0001151467636846
error of theta is 0.17983626297154928
gamma is 1.0001151467636846
error of theta is 0.1680879016895577
gamma is 1.0001151467636846
error of theta is 0.1912178555497352
gamma is 1.0001151467636846
error of theta is 0.15473592346001075
gamma is 1.0598452749362206
error of theta is 0.16357011103100452
gamma is 1.0598452749362206
error of theta is 0.18056170465010873
gamma is 1.0778743184915853
error of theta is 0.17824568479335565
gamma is 1.0778743184915853
error of theta is 0.18377972213837787
gamma is 1.03303

gamma is 0.8454377183530436
error of theta is 0.1680514525698963
gamma is 0.8454377183530436
error of theta is 0.15951439215380958
gamma is 0.8454377183530436
error of theta is 0.17699247746502822
gamma is 0.805602864468445
error of theta is 0.19223249133276296
gamma is 0.805602864468445
error of theta is 0.15991420862219316
gamma is 0.7858397787570784
error of theta is 0.14855475275622748
gamma is 0.7858397787570784
error of theta is 0.16112593573336834
gamma is 0.8583032062683238
error of theta is 0.14953827989325905
gamma is 0.8396727163967908
error of theta is 0.12899765697533652
gamma is 0.8396727163967908
error of theta is 0.15989632697529577
gamma is 0.8396727163967908
error of theta is 0.14757548120174935
gamma is 0.7975422232938357
error of theta is 0.14384450278441868
gamma is 0.9199532512009891
error of theta is 0.1644876433892357
gamma is 0.8942590787801897
error of theta is 0.15894665947030506
gamma is 1.0160359895930444
error of theta is 0.17353946995503164
gamma is 1.016

gamma is 0.8583298274476078
error of theta is 0.1314629076138344
gamma is 0.8723652385325066
error of theta is 0.1279471872918255
gamma is 0.8723652385325066
error of theta is 0.1357157794077419
gamma is 0.8723652385325066
error of theta is 0.16703854580392172
gamma is 0.8723652385325066
error of theta is 0.13948599744088838
gamma is 0.8723652385325066
error of theta is 0.13733732665272516
gamma is 0.8723652385325066
error of theta is 0.1280485521607554
gamma is 0.8723652385325066
error of theta is 0.1566623186144329
gamma is 0.8723652385325066
error of theta is 0.18916966652506503
gamma is 0.8723652385325066
error of theta is 0.18290213130253333
gamma is 0.9455273220613181
error of theta is 0.15180403188461675
gamma is 0.9888785061481568
error of theta is 0.17042490872254776
gamma is 1.0659192553559165
error of theta is 0.18467911057955047
gamma is 0.9916332600570401
error of theta is 0.17164636381929554
gamma is 0.9916332600570401
error of theta is 0.19081342881972674
gamma is 0.9916

gamma is 1.1977385960801081
error of theta is 0.19896486745313008
gamma is 1.1743022529001934
error of theta is 0.192045708884886
gamma is 1.1670604325193394
error of theta is 0.2152420755836619
gamma is 1.12416496718827
error of theta is 0.2112641294274884
gamma is 1.12416496718827
error of theta is 0.18551234320764073
gamma is 1.12416496718827
error of theta is 0.1919987600332786
gamma is 1.2034473737130655
error of theta is 0.21129489026755485
gamma is 1.2034473737130655
error of theta is 0.18269325480514984
gamma is 1.2034473737130655
error of theta is 0.17910555743656878
gamma is 1.2025289979245304
error of theta is 0.18909770770037357
gamma is 1.2025289979245304
error of theta is 0.17947774350152187
gamma is 1.077454505841576
error of theta is 0.19715522670465926
gamma is 0.9596110108370078
error of theta is 0.19445759192091752
gamma is 0.9596110108370078
error of theta is 0.20667655175086155
gamma is 0.9535832890283115
error of theta is 0.18759726259648063
gamma is 1.09183106368

gamma is 0.9811977945765258
error of theta is 0.16442171166424827
gamma is 1.0253782862990661
error of theta is 0.1645292704980112
gamma is 1.0253782862990661
error of theta is 0.19262242568484478
gamma is 1.1455824211311842
error of theta is 0.19936725492256327
gamma is 1.1455824211311842
error of theta is 0.1859570457265499
gamma is 1.1668644867133782
error of theta is 0.1659244087417564
gamma is 1.1668644867133782
error of theta is 0.16454911177680231
gamma is 1.1668644867133782
error of theta is 0.14518727893336802
gamma is 1.1668644867133782
error of theta is 0.149922639570947
gamma is 1.1668644867133782
error of theta is 0.1919065103901328
gamma is 1.1668644867133782
error of theta is 0.1677237737960385
gamma is 1.1499354091586633
error of theta is 0.21555430399694564
gamma is 1.1499354091586633
error of theta is 0.1942599107406621
gamma is 1.1035010952799322
error of theta is 0.18136685903777383
gamma is 1.1035010952799322
error of theta is 0.1906634106688757
gamma is 1.10350109

gamma is 0.4870346772785171
error of theta is 0.17579951639938696
gamma is 0.4870346772785171
error of theta is 0.19245425764094967
gamma is 0.4870346772785171
error of theta is 0.19185859910913672
gamma is 0.5140081371988796
error of theta is 0.15986294187401132
gamma is 0.4754898435619658
error of theta is 0.17803716509790063
gamma is 0.4754898435619658
error of theta is 0.1650592503709547
gamma is 0.4754898435619658
error of theta is 0.16153548205907428
gamma is 0.48341318763387203
error of theta is 0.13006688480363413
gamma is 0.6287159845265297
error of theta is 0.1494768860755209
gamma is 0.6287159845265297
error of theta is 0.16251672060053157
gamma is 0.6287159845265297
error of theta is 0.1694049485337974
gamma is 0.6287159845265297
error of theta is 0.19770353604603802
gamma is 0.6287159845265297
error of theta is 0.17586916834136507
gamma is 0.6287159845265297
error of theta is 0.15488573203564285
gamma is 0.606950107027508
error of theta is 0.15847701025732758
gamma is 0.60

gamma is 0.4397694699996692
error of theta is 0.1642769366714249
gamma is 0.4397694699996692
error of theta is 0.17812713556303517
gamma is 0.4340099611008655
error of theta is 0.16857975391867136
gamma is 0.4340099611008655
error of theta is 0.14346994992982176
gamma is 0.4340099611008655
error of theta is 0.13564528363325923
gamma is 0.45353703361719017
error of theta is 0.1536842170001625
gamma is 0.45353703361719017
error of theta is 0.12267432494640278
gamma is 0.45353703361719017
error of theta is 0.12515311698372017
gamma is 0.45353703361719017
error of theta is 0.13483285545776474
gamma is 0.5106653121923612
error of theta is 0.1229501193071538
gamma is 0.5449741705105053
error of theta is 0.1251360231577069
gamma is 0.4370841277033167
error of theta is 0.15754491550199115
gamma is 0.44027408428997333
error of theta is 0.14113121806457724
gamma is 0.44027408428997333
error of theta is 0.17487265586041292
gamma is 0.44027408428997333
error of theta is 0.13118898848142993
gamma i

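We can see that $\gamma$ moves from its initial value (about $3$) towards the true value $0.5$ used to generate the data, and that the error of $\theta$ (the vertices of the topic polytope, compared up to a permutation of the topics) decreases.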