In [1]:
using DataFrames, Distributions, DataFramesMeta

In [2]:
function gen_data(
    ; 
    n = 500000, 
    mu = [0, 2], 
    sigma = 1, 
    pZ1 = 0.8)
    
    data = DataFrame(
        Obs_ID = 1:n,
        Z = (rand(Uniform(0,1), n) .> pZ1) .+ 1
        )

    @transform!(data, :Y = rand(Normal(0, sigma), n) + mu[:Z])    
    @select!(data, :Obs_ID, :Y)
    @transform!(data, :p_Y_given_Z_1 = pdf.(Normal(mu[1], sigma), :Y))
    @transform!(data, :p_Y_given_Z_2 = pdf.(Normal(mu[2], sigma), :Y))

    return data
  end

gen_data (generic function with 1 method)

In [3]:
data = gen_data(n = 500000, pZ1 = 0.8);
first(data, 6)

Unnamed: 0_level_0,Obs_ID,Y,p_Y_given_Z_1,p_Y_given_Z_2
Unnamed: 0_level_1,Int64,Float64,Float64,Float64
1,1,1.64912,0.102413,0.375125
2,2,-0.71507,0.308942,0.0100043
3,3,1.50139,0.129247,0.35231
4,4,-0.227602,0.388742,0.0333718
5,5,-0.0106089,0.39892,0.0528545
6,6,0.470564,0.357131,0.12387


In [4]:
function fit_model!(
    data; 
    pi_hat_0 = 0.5, 
    tolerance = 0.00001,
    max_iterations = 1000,
    verbose = false
    )

    # pre-allocate a table of results by iteration:
    progress = DataFrame(
        Iteration = 0:max_iterations, 
        pi_hat = Vector{Float64}(undef, max_iterations+1), 
        loglik = Vector{Float64}(undef, max_iterations+1), 
        diff_loglik = Vector{Float64}(undef, max_iterations+1)
        )

    # initial E step, to perform needed calculations for initial likelihood:
    E_step!(data, pi_hat_0) 
    ll = loglik(data)
    progress[1, :] = (0, pi_hat_0, ll, NaN)
    
    last_iter = 0
    for i in 1:max_iterations

        # M step: re-estimate parameters
        pi_hat = M_step(data)

        # E step: re-compute distribution of missing variables, using parameters
        E_step!(data, pi_hat)
        
        # Assess convergence

        ## save the previous log-likelihood so we can test for convergence
        ll_old = ll
        
        ## here's the new log-likelihood
        ll = loglik(data)

        ll_diff = ll - ll_old
        
        progress[i+1,:] = (i, pi_hat, ll, ll_diff)
        
        if(verbose)
            print(progress[i+1,:])
        end
        
        if ll_diff < tolerance
            last_iter = i
            break
        end
    end
    
    return progress[1:(last_iter + 1), :]
end

fit_model! (generic function with 1 method)

In [5]:
function E_step!(data, pi_hat)
    @transform!(data, :pY_Z1 = :p_Y_given_Z_1 .* pi_hat)
    @transform!(data, :pY_Z2 = :p_Y_given_Z_2 .* (1- pi_hat))
    @transform!(data, :pY = :pY_Z1 + :pY_Z2)
    @transform!(data, :pZ1_given_Y = :pY_Z1 ./ :pY)
end

E_step! (generic function with 1 method)

In [6]:
function M_step(data)
    data[:, :pZ1_given_Y] |> mean
end

M_step (generic function with 1 method)

In [7]:
function loglik(data)
    data[:, :pY] .|> log |> sum
end

loglik (generic function with 1 method)

In [8]:
@time progress = fit_model!(data, tolerance = 0.00001);


  1.440720 seconds (6.98 M allocations: 881.487 MiB, 9.69% gc time, 87.68% compilation time)


The first run requires compilation; on subsequent calls it will be faster:

In [9]:
@time progress = fit_model!(data, tolerance = 0.00001);

  0.159684 seconds (13.08 k allocations: 504.204 MiB, 15.17% gc time)


Note that the memory allocation is also lower this time, because the first run expands the `DataFrame` `data` with additional columns which are reused in the second run:

In [10]:
first(data, 6)

Unnamed: 0_level_0,Obs_ID,Y,p_Y_given_Z_1,p_Y_given_Z_2,pY_Z1,pY_Z2,pY
Unnamed: 0_level_1,Int64,Float64,Float64,Float64,Float64,Float64,Float64
1,1,1.64912,0.102413,0.375125,0.0820079,0.0747413,0.156749
2,2,-0.71507,0.308942,0.0100043,0.247387,0.00199329,0.249381
3,3,1.50139,0.129247,0.35231,0.103495,0.0701956,0.173691
4,4,-0.227602,0.388742,0.0333718,0.311287,0.00664912,0.317937
5,5,-0.0106089,0.39892,0.0528545,0.319438,0.0105309,0.329968
6,6,0.470564,0.357131,0.12387,0.285975,0.0246803,0.310655


Here are the iterative steps and final estimate:

In [11]:
progress

Unnamed: 0_level_0,Iteration,pi_hat,loglik,diff_loglik
Unnamed: 0_level_1,Int64,Float64,Float64,Float64
1,0,0.5,-877696.0,
2,1,0.665478,-837059.0,40637.5
3,2,0.73781,-828079.0,8979.53
4,3,0.770275,-826000.0,2078.93
5,4,0.785627,-825494.0,506.171
6,5,0.793145,-825367.0,127.183
7,6,0.796901,-825334.0,32.5178
8,7,0.798796,-825326.0,8.39098
9,8,0.799758,-825324.0,2.17561
10,9,0.800248,-825323.0,0.565479
