/
classic_em.jl
152 lines (123 loc) · 5.12 KB
/
classic_em.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
    ClassicEM <: AbstractEM

Singleton method type selecting the classic (batch) EM algorithm when passed to `fit_mle!`.

The EM algorithm was introduced by A. P. Dempster, N. M. Laird and D. B. Rubin in 1977 in the reference paper [*Maximum Likelihood from Incomplete Data Via the EM Algorithm*](https://rss.onlinelibrary.wiley.com/doi/abs/10.1111/j.2517-6161.1977.tb01600.x).
"""
struct ClassicEM <: AbstractEM end
"""
    fit_mle!(α::AbstractVector, dists::AbstractVector{F} where {F<:Distribution}, y::AbstractVecOrMat, method::ClassicEM; display=:none, maxiter=1000, atol=1e-3, rtol=nothing, robust=false)

Run the classic EM algorithm, updating in place the mixture weights `α` and the component
distributions `dists` given the observations `y`.

# Keywords
- `display`: `:none`, `:iter` (print the loglikelihood at every iteration) or `:final`
  (print only the convergence / non-convergence message).
- `maxiter`: maximal number of EM iterations.
- `atol`: absolute tolerance — stop when `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾| < atol`.
- `rtol`: relative tolerance `|ℓ⁽ⁱ⁺¹⁾ - ℓ⁽ⁱ⁾| < rtol*(|ℓ⁽ⁱ⁺¹⁾| + |ℓ⁽ⁱ⁾|)/2`;
  not checked when `nothing`.
- `robust`: if `true`, prevent the (log)likelihood from overflowing to `-∞` or `∞`.

Return a `Dict` with keys `"converged"`, `"iterations"` and `"logtots"` (the
loglikelihood recorded at each iteration).
"""
function fit_mle!(
    α::AbstractVector,
    dists::AbstractVector{F} where {F<:Distribution},
    y::AbstractVecOrMat,
    method::ClassicEM;
    display = :none,
    maxiter = 1000,
    atol = 1e-3,
    rtol = nothing,
    robust = false,
)
    @argcheck display in [:none, :iter, :final]
    @argcheck maxiter >= 0
    N, K = size_sample(y), length(dists)
    history = Dict("converged" => false, "iterations" => 0, "logtots" => zeros(0))

    # Buffers reused across iterations: componentwise loglikelihoods `LL`,
    # responsibilities `γ` and per-observation normalizing terms `c`.
    LL = zeros(N, K)
    γ = similar(LL)
    c = zeros(N)

    # Initial E-step so the starting loglikelihood can be reported.
    E_step!(LL, c, γ, dists, α, y; robust = robust)
    logtot = sum(c)
    if display == :iter
        println("Method = $(method)\nIteration 0: Loglikelihood = ", logtot)
    end

    for it in 1:maxiter
        # M-step: re-estimate the weights and each component from the responsibilities.
        α .= vec(mean(γ, dims = 1))
        dists .= [fit_mle(dists[k], y, γ[:, k]) for k in 1:K]

        # E-step: recompute responsibilities and the loglikelihood under the new parameters.
        E_step!(LL, c, γ, dists, α, y; robust = robust)
        logtotp = sum(c)
        if display == :iter
            println("Iteration $(it): loglikelihood = ", logtotp)
        end
        push!(history["logtots"], logtotp)
        history["iterations"] += 1

        # Stop on the absolute criterion, or the relative one when `rtol` is given.
        Δ = abs(logtotp - logtot)
        if Δ < atol || (rtol !== nothing && Δ < rtol * (abs(logtot) + abs(logtotp)) / 2)
            if display in [:iter, :final]
                println("EM converged in ", it, " iterations, final loglikelihood = ", logtotp)
            end
            history["converged"] = true
            break
        end
        logtot = logtotp
    end

    if !history["converged"] && display in [:iter, :final]
        println(
            "EM has not converged after $(history["iterations"]) iterations, final loglikelihood = $logtot",
        )
    end
    return history
end
"""
    fit_mle!(α::AbstractVector, dists::AbstractVector{F} where {F<:Distribution}, y::AbstractVecOrMat, w::AbstractVector, method::ClassicEM; display=:none, maxiter=1000, atol=1e-3, rtol=nothing, robust=false)

Weighted variant of the classic EM: each observation `y[n]` carries a weight `w[n]`
that scales both its contribution to the loglikelihood and to the M-step updates.
Updates `α` and `dists` in place; see the unweighted method for the meaning of the
keywords and of the returned history `Dict`.
"""
function fit_mle!(
    α::AbstractVector,
    dists::AbstractVector{F} where {F<:Distribution},
    y::AbstractVecOrMat,
    w::AbstractVector,
    method::ClassicEM;
    display = :none,
    maxiter = 1000,
    atol = 1e-3,
    rtol = nothing,
    robust = false,
)
    @argcheck display in [:none, :iter, :final]
    @argcheck maxiter >= 0
    N, K = size_sample(y), length(dists)
    @argcheck length(w) == N
    history = Dict("converged" => false, "iterations" => 0, "logtots" => zeros(0))

    # Buffers reused across iterations: componentwise loglikelihoods `LL`,
    # responsibilities `γ` and per-observation normalizing terms `c`.
    LL = zeros(N, K)
    γ = similar(LL)
    c = zeros(N)

    # Initial E-step; the loglikelihood is the w-weighted sum of the per-sample terms.
    E_step!(LL, c, γ, dists, α, y; robust = robust)
    logtot = sum(wᵢ * cᵢ for (wᵢ, cᵢ) in zip(w, c))
    if display == :iter
        println("Method = $(method)\nIteration 0: Loglikelihood = ", logtot)
    end

    for it in 1:maxiter
        # M-step: weighted re-estimation of the mixture weights and of each component.
        α .= vec(mean(γ, weights(w), dims = 1))
        dists .= [fit_mle(dists[k], y, w .* γ[:, k]) for k in 1:K]

        # E-step: recompute responsibilities and the weighted loglikelihood.
        E_step!(LL, c, γ, dists, α, y; robust = robust)
        logtotp = sum(wᵢ * cᵢ for (wᵢ, cᵢ) in zip(w, c))
        if display == :iter
            println("Iteration $(it): loglikelihood = ", logtotp)
        end
        push!(history["logtots"], logtotp)
        history["iterations"] += 1

        # Stop on the absolute criterion, or the relative one when `rtol` is given.
        Δ = abs(logtotp - logtot)
        if Δ < atol || (rtol !== nothing && Δ < rtol * (abs(logtot) + abs(logtotp)) / 2)
            if display in [:iter, :final]
                println("EM converged in ", it, " iterations, final loglikelihood = ", logtotp)
            end
            history["converged"] = true
            break
        end
        logtot = logtotp
    end

    if !history["converged"] && display in [:iter, :final]
        println(
            "EM has not converged after $(history["iterations"]) iterations, final loglikelihood = $logtot",
        )
    end
    return history
end