Skip to content

Commit

Permalink
adding missing update equations, removing uneeded dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
funnell committed Dec 3, 2018
1 parent 6dd6266 commit 3c92669
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/ILDA.jl
Expand Up @@ -321,13 +321,16 @@ function fit_heldout(Xheldout::Vector{Matrix{Int}}, model::ILDA;

heldout_model = ILDA(model.K, model.α, model.η, model.features, Xheldout)
heldout_model.λ = deepcopy(model.λ)
heldout_model.β = deepcopy(model.β)
heldout_model.Elnβ = deepcopy(model.Elnβ)

ll = Float64[]
for iter in 1:maxiter
update_γ!(heldout_model)
update_ϕ!(heldout_model)

update_θ!(heldout_model)

push!(ll, calculate_loglikelihood(Xheldout, heldout_model))

if verbose
Expand Down
3 changes: 3 additions & 0 deletions src/LDA.jl
Expand Up @@ -267,13 +267,16 @@ function fit_heldout(Xheldout::Vector{Matrix{Int}}, model::LDA;

heldout_model = LDA(model.K, model.α, model.η, Xheldout)
heldout_model.λ = deepcopy(model.λ)
heldout_model.β = deepcopy(model.β)
heldout_model.Elnβ = deepcopy(model.Elnβ)

ll = Float64[]
for iter in 1:maxiter
update_γ!(heldout_model)
update_ϕ!(heldout_model)

update_θ!(heldout_model)

push!(ll, calculate_loglikelihood(Xheldout, heldout_model))

if verbose
Expand Down
15 changes: 9 additions & 6 deletions src/MMCTM.jl
Expand Up @@ -181,11 +181,11 @@ end

function update_θ!(model::MMCTM, d::Int)
offset = 0
for m in 1:model.M
for w in 1:size(model.X[d][m], 1)
@inbounds for m in 1:model.M
@inbounds for w in 1:size(model.X[d][m], 1)
v = model.X[d][m][w, 1]

for k in 1:model.K[m]
@inbounds for k in 1:model.K[m]
model.θ[d][m][k, w] = exp(
model.λ[d][offset + k] + model.Elnϕ[m][k][v]
)
Expand Down Expand Up @@ -453,7 +453,8 @@ function fitdoc!(model::MMCTM, d::Int)
update_λ!(model, d)
end

function fit!(model::MMCTM; maxiter=100, tol=1e-4, verbose=true, autoα=false)
function fit!(model::MMCTM; maxiter=100, tol=1e-4, verbose=true, autoα=false,
updateΣ=true)
ll = Vector{Float64}[]

for iter in 1:maxiter
Expand All @@ -462,7 +463,9 @@ function fit!(model::MMCTM; maxiter=100, tol=1e-4, verbose=true, autoα=false)
end

update_μ!(model)
update_Σ!(model)
if updateΣ
update_Σ!(model)
end
update_γ!(model)
if autoα
update_α!(model)
Expand Down Expand Up @@ -555,7 +558,7 @@ function fit_heldout(Xheldout::Vector{Vector{Matrix{Int}}}, model::MMCTM;
fitdoc!(heldout_model, d)
end

update_props!(model)
update_props!(heldout_model)

push!(ll, calculate_loglikelihoods(heldout_model))

Expand Down
1 change: 0 additions & 1 deletion src/MultiModalMuSig.jl
Expand Up @@ -2,7 +2,6 @@ module MultiModalMuSig

using Distributions
using NLopt
using Clustering
using StatsFuns

export IMMCTM, MMCTM, ILDA, LDA, fit!
Expand Down

0 comments on commit 3c92669

Please sign in to comment.