In [10]:
using DataFrames, Distributions, DataFramesMeta
using Random

In [11]:
"""
    gen_data(; n = 500000, mu = [0, 2], sigma = 1, pZ1 = 0.8, rng = Random.default_rng())

Simulate `n` observations from a two-component Gaussian mixture.

Each observation draws a latent label `Z ∈ {1, 2}` with `P(Z = 1) = pZ1`, then
`Y = mu[Z] + ε` with `ε ~ Normal(0, sigma)`. The latent `Z` is not returned. The
result is a `DataFrame` with columns

- `Obs_ID`: observation index `1:n`
- `Y`: the simulated value
- `p_Y_given_Z_1`, `p_Y_given_Z_2`: the density of `Y` under `Normal(mu[1], sigma)`
  and `Normal(mu[2], sigma)` respectively

# Keywords
- `rng`: random number generator used for all draws. Pass a seeded generator
  (e.g. `MersenneTwister(1)`) for reproducible data. The default draws from the
  global RNG, exactly as before this keyword existed.

# Throws
- `ArgumentError` if `mu` does not have exactly two entries, `sigma <= 0`, or
  `pZ1` is outside `[0, 1]`.
"""
function gen_data(
    ;
    n = 500000,
    mu = [0, 2],
    sigma = 1,
    pZ1 = 0.8,
    rng::AbstractRNG = Random.default_rng(),
)
    length(mu) == 2 ||
        throw(ArgumentError("mu must have exactly 2 entries, got $(length(mu))"))
    sigma > 0 || throw(ArgumentError("sigma must be positive, got $sigma"))
    0 <= pZ1 <= 1 || throw(ArgumentError("pZ1 must be in [0, 1], got $pZ1"))

    # Z = 1 when the uniform draw is <= pZ1 (probability pZ1), otherwise Z = 2.
    # Uniform draws are taken before the Normal draws to keep the original RNG order.
    z = (rand(rng, Uniform(0, 1), n) .> pZ1) .+ 1
    y = rand(rng, Normal(0, sigma), n) .+ mu[z]

    return DataFrame(
        Obs_ID = 1:n,
        Y = y,
        p_Y_given_Z_1 = pdf.(Normal(mu[1], sigma), y),
        p_Y_given_Z_2 = pdf.(Normal(mu[2], sigma), y),
    )
end

gen_data (generic function with 1 method)

In [12]:
# Simulate 500,000 observations; each comes from component 1 with probability 0.8.
data = gen_data(; n = 500_000, pZ1 = 0.8);
first(data, 6)

Unnamed: 0_level_0,Obs_ID,Y,p_Y_given_Z_1,p_Y_given_Z_2
Unnamed: 0_level_1,Int64,Float64,Float64,Float64
1,1,-0.909342,0.263846,0.00579317
2,2,-0.391804,0.369467,0.0228386
3,3,0.564351,0.340213,0.142348
4,4,0.104665,0.396763,0.0661992
5,5,-0.350631,0.375157,0.0251809
6,6,0.732426,0.305086,0.178653


In [13]:
"""
    fit_model!(data; pi_hat_0 = 0.5, tolerance = 0.0001, max_iterations = 1000, progress)

Estimate the mixing weight `pi_hat` by alternating `M_step` and `E_step!`.

Starts from `pi_hat_0`. Each iteration calls `M_step(data)` for a new `pi_hat`,
then `E_step!(data, pi_hat)`, which writes its columns into `data`, and finally
`loglik(data)`. Iteration stops when the log-likelihood gain `ll_diff` falls below
`tolerance`, or after `max_iterations` iterations.

Returns the filled rows of `progress` (columns `iter`, `pi_hat`, `ll`, `ll_diff`):

- Row 1 is iteration 0, the initial guess, with `ll_diff = NaN`.
- Each further row is one completed iteration.
- If `max_iterations` is reached without convergence, every completed iteration
  is still returned and a warning is logged.

# Throws
- `ArgumentError` if `pi_hat_0` is outside `[0, 1]`, `max_iterations` is negative,
  or `progress` has fewer than `max_iterations + 1` rows.
"""
function fit_model!(
    data;
    pi_hat_0 = 0.5,
    tolerance = 0.0001,
    max_iterations = 1000,
    progress = DataFrame(
        iter = 1:(max_iterations + 1),
        pi_hat = Vector{Float64}(undef, max_iterations + 1),
        ll = Vector{Float64}(undef, max_iterations + 1),
        ll_diff = Vector{Float64}(undef, max_iterations + 1),
    ),
)
    0 <= pi_hat_0 <= 1 ||
        throw(ArgumentError("pi_hat_0 must be in [0, 1], got $pi_hat_0"))
    max_iterations >= 0 ||
        throw(ArgumentError("max_iterations must be non-negative, got $max_iterations"))
    size(progress, 1) >= max_iterations + 1 || throw(ArgumentError(
        "progress needs at least $(max_iterations + 1) rows, got $(size(progress, 1))"))

    pi_hat = pi_hat_0
    E_step!(data, pi_hat)
    ll = loglik(data)
    progress[1, :] = (0, pi_hat, ll, NaN)

    # Last iteration written to `progress`. Updated on every pass (not only on
    # convergence), so an unconverged run still returns all of its rows.
    last_iter = 0
    converged = false
    for i in 1:max_iterations
        pi_hat = M_step(data)
        E_step!(data, pi_hat)

        ll_old = ll
        ll = loglik(data)
        ll_diff = ll - ll_old
        progress[i + 1, :] = (i, pi_hat, ll, ll_diff)
        last_iter = i

        # Stop once the log-likelihood gain drops below the tolerance.
        if ll_diff < tolerance
            converged = true
            break
        end
    end

    if !converged && max_iterations > 0
        @warn "fit_model! did not converge within $max_iterations iterations" tolerance
    end

    return progress[1:(last_iter + 1), :]
end

fit_model! (generic function with 1 method)

In [14]:
"""
    E_step!(data, pi_hat)

Write the E-step columns into `data` for mixing weight `pi_hat`, then return `data`.

Columns written, in this order:

- `pY_Z1 = p_Y_given_Z_1 * pi_hat`
- `pY_Z2 = p_Y_given_Z_2 * (1 - pi_hat)`
- `pY = pY_Z1 + pY_Z2`
- `pZ1_given_Y = pY_Z1 / pY`
"""
function E_step!(data, pi_hat)
    weighted_1 = data[!, :p_Y_given_Z_1] .* pi_hat
    weighted_2 = data[!, :p_Y_given_Z_2] .* (1 - pi_hat)
    total = weighted_1 + weighted_2

    data[!, :pY_Z1] = weighted_1
    data[!, :pY_Z2] = weighted_2
    data[!, :pY] = total
    # Share of each observation's total density that belongs to component 1.
    data[!, :pZ1_given_Y] = weighted_1 ./ total
    return data
end

E_step! (generic function with 1 method)

In [15]:
"""
    M_step(data)

Return the mean of column `pZ1_given_Y` (as written by `E_step!`).
`fit_model!` uses this as the next `pi_hat`.
"""
M_step(data) = mean(data.pZ1_given_Y)

M_step (generic function with 1 method)

In [16]:
"""
    loglik(data)

Return the sum of `log` over column `pY`, i.e. the log-likelihood `fit_model!` tracks.
"""
function loglik(data)
    log_densities = log.(data.pY)
    return sum(log_densities)
end

loglik (generic function with 1 method)

In [17]:
# First call: the reported time includes JIT compilation of fit_model! and its helpers.
@time progress = fit_model!(data; tolerance = 1e-5);


  0.456945 seconds (687.27 k allocations: 398.864 MiB, 46.30% gc time, 29.05% compilation time)


The first call also includes JIT compilation time; subsequent calls reuse the compiled code and run faster:

In [18]:
# Second call: code is already compiled, so this timing reflects the EM run itself.
@time progress = fit_model!(data; tolerance = 1e-5)

  0.124853 seconds (12.99 k allocations: 363.058 MiB, 13.93% gc time)


Unnamed: 0_level_0,iter,pi_hat,ll,ll_diff
Unnamed: 0_level_1,Int64,Float64,Float64,Float64
1,0,0.5,-878417.0,
2,1,0.665171,-837972.0,40445.4
3,2,0.73717,-829089.0,8882.77
4,3,0.769413,-827044.0,2045.07
5,4,0.784626,-826549.0,495.184
6,5,0.792058,-826425.0,123.725
7,6,0.795761,-826393.0,31.4526
8,7,0.797624,-826385.0,8.06886
9,8,0.798568,-826383.0,2.07977
10,9,0.799046,-826383.0,0.537366
