In [2]:
using Turing, Turing.RandomMeasures
using Plots, StatsPlots
using Statistics, Random, LinearAlgebra
using MCMCChains

Define the LDA model.

In [10]:
@model function LDA(w, K, D)
    # K: number of topics
    # D: number of words
    # M = number of documents
    # this gets the length of the 1st dimension of array w
    M = size(w, 1)
    # N = number of words per document
    # length of the 2nd dimension of w
    N = size(w, 2)

    # topic distributions
    # A Vector of Vectors, size M, each initialized to undef
    # Each inner vector will have K entries that add up to 1.
    θ = Vector{Vector}(undef, M)
    α = 1.0
    for m = 1:M
        θ[m] ~ Dirichlet(K, α)
    end
    # println("theta:")
    # println(θ)

    # word distributions (for each topic)
    ψ = Vector{Vector}(undef, K)
    η = 0.01
    for k = 1:K
        ψ[k] ~ Dirichlet(D, η)
    end

    # println("ψ (word distributions for each topic):")
    # println(ψ)

    # one entry in outer vec per doc
    # the ints represent the topic assignment of each word
    z = Vector{Vector{Int}}(undef, M)

    for m = 1:M
        # in each doc, initialize each word's topic as 0 (I think)
        z[m] = zeros(Int, N)
        for n = 1:N
            # select topic for word n in document m
            # draw from the topic distribution for that doc
            z[m][n] ~ Categorical(θ[m])
            # select symbol for word n in document m from topic z[m][n]
            # draw from the word distribution for that topic
            w[m,n] ~ Categorical(ψ[z[m][n]])
        end
    end
    # println("z:")
    # println(z)
    # println("w:")
    # println(w)
    return w
end


LDA (generic function with 4 methods)

In [11]:
# number of docs
M = 2
# number words per doc
N = 10
# number of topics
K = 4
# number of words in corpus
D = 20

20

TODO: Import the data. It should be an M x N matrix, where each entry is an int representing a word.

In [12]:
condition_data = [1 2 3 1 2 9 1 2 3 1; 12 3 4 5 12 3 4 15 2 3]
condition_data

2×10 Matrix{Int64}:
  1  2  3  1   2  9  1   2  3  1
 12  3  4  5  12  3  4  15  2  3

Condition the model with the provided documents.

In [13]:
conditioned_LDA = LDA(condition_data, K, D)

DynamicPPL.Model{typeof(LDA), (:w, :K, :D), (), (), Tuple{Matrix{Int64}, Int64, Int64}, Tuple{}, DynamicPPL.DefaultContext}(LDA, (w = [1 2 … 3 1; 12 3 … 2 3], K = 4, D = 20), NamedTuple(), DynamicPPL.DefaultContext())

Sample the model. It currently uses a Sequential Monte Carlo (SMC) sampler, but it can also be configured to use importance sampling (IS), Metropolis Hastings (MH), or Particle Gibbs (PG). It can also combine multiple samplers so one is used for the discrete variables and a different one is used for the continuous variables, such as Hamiltonian Markov Chain (HMC) or the No U-Turn Sampler (NUTS).

In [14]:
cond_chain = sample(conditioned_LDA, SMC(), 1000)

Chains MCMC chain (1000×110×1 Array{Float64, 3}):

Log evidence      = -179.75136502367724
Iterations        = 1:1:1000
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 5.94 seconds
Compute duration  = 5.94 seconds
parameters        = θ[1][1], θ[1][2], θ[1][3], θ[1][4], θ[2][1], θ[2][2], θ[2][3], θ[2][4], ψ[1][1], ψ[1][2], ψ[1][3], ψ[1][4], ψ[1][5], ψ[1][6], ψ[1][7], ψ[1][8], ψ[1][9], ψ[1][10], ψ[1][11], ψ[1][12], ψ[1][13], ψ[1][14], ψ[1][15], ψ[1][16], ψ[1][17], ψ[1][18], ψ[1][19], ψ[1][20], ψ[2][1], ψ[2][2], ψ[2][3], ψ[2][4], ψ[2][5], ψ[2][6], ψ[2][7], ψ[2][8], ψ[2][9], ψ[2][10], ψ[2][11], ψ[2][12], ψ[2][13], ψ[2][14], ψ[2][15], ψ[2][16], ψ[2][17], ψ[2][18], ψ[2][19], ψ[2][20], ψ[3][1], ψ[3][2], ψ[3][3], ψ[3][4], ψ[3][5], ψ[3][6], ψ[3][7], ψ[3][8], ψ[3][9], ψ[3][10], ψ[3][11], ψ[3][12], ψ[3][13], ψ[3][14], ψ[3][15], ψ[3][16], ψ[3][17], ψ[3][18], ψ[3][19], ψ[3][20], ψ[4][1], ψ[4][2], ψ[4][3], ψ[4][4], ψ[4][5], ψ[4][6], ψ[4][7], ψ[4][8], ψ[4][9], ψ[4][10], ψ[4][11], ψ

This represents the word distribution of each topic.

In [15]:
topic_word_dists = Vector{Vector{Float64}}(undef, K)
for j = 1:K
    topic_word_dists[j] = [mean(cond_chain, "ψ[$j][$i]") for i in 1:D]
end
topic_word_dists

4-element Vector{Vector{Float64}}:
 [8.339945352934453e-67, 1.6207337651679205e-6, 1.8417237966967173e-84, 6.839910920634829e-125, 1.147373085156585e-10, 7.0902000101806676e-43, 2.276688720631089e-9, 1.4622788041341917e-52, 0.4931699580797472, 2.409253584478786e-29, 9.758457185140399e-9, 3.4436617200581426e-86, 8.154700304200247e-22, 3.824201656689989e-46, 4.582575090813536e-14, 4.5263650236723275e-57, 1.1776059229366975e-63, 0.06513525946925518, 0.44169131488971214, 1.8346775919086609e-6]
 [0.9544412231845386, 5.697446433965199e-13, 1.0855917988602781e-14, 1.1307077064669073e-16, 2.5624065185083515e-65, 1.378185610179918e-133, 1.2147668870925905e-11, 1.1926722689074882e-18, 2.603744383525184e-17, 3.2618526668577096e-29, 1.2314821454624585e-7, 0.04555865365349639, 0.0, 7.321726490934666e-101, 1.0220123270515363e-12, 6.338887950081083e-34, 5.565970761555031e-35, 3.142418481910138e-36, 1.4309558167761946e-65, 4.302778977929825e-20]
 [1.802149674953276e-21, 3.229906076659316e-10, 0.375912

Query the distribution of topics in each document.

In [21]:
document_topic_distributions = Vector{Vector{Float64}}(undef, M)
for j = 1:M
    document_topic_distributions[j] = [mean(cond_chain, "θ[$j][$i]") for i in 1:K]
end
document_topic_distributions

2-element Vector{Vector{Float64}}:
 [0.09295633367114257, 0.5804187860456147, 0.09753235199927433, 0.22909252828396762]
 [0.1693518229234748, 0.7405616591578033, 0.06317470671035684, 0.026911811208365922]

Get the highest probability topic for each movie.

In [27]:
highest_prob_topic_per_movie = Vector{Tuple{Int, Float64}}(undef, M)
for doc = 1:M
    max_prob = 0.0
    max_ind = 0
    for topic = 1:K
        if document_topic_distributions[doc][topic] > max_prob
            max_prob = document_topic_distributions[doc][topic]
            max_ind = topic
        end
    end
    highest_prob_topic_per_movie[doc] = (max_ind, max_prob) 
end
highest_prob_topic_per_movie

2-element Vector{Tuple{Int64, Float64}}:
 (2, 0.5804187860456147)
 (2, 0.7405616591578033)