In [2]:
import numpy as np
from scipy.stats import binom

# 50 observed integer counts (values range from 0 to 8), laid out ten per row.
# Passed to em_2coins, which uses M = 10 by default -- presumably each entry is
# the number of successes in 10 trials; confirm against the data source.
y = np.array([
    7, 7, 1, 3, 2, 1, 5, 1, 4, 0,
    2, 3, 3, 1, 2, 2, 3, 1, 4, 0,
    8, 3, 5, 3, 0, 0, 7, 1, 1, 3,
    1, 3, 2, 4, 1, 6, 2, 2, 4, 1,
    3, 1, 1, 2, 7, 3, 3, 2, 2, 2,
])

In [3]:
def e_step(data, M, mu1_cur, mu2_cur, pi1_cur):
    """E-step: responsibilities of the two binomial components for each observation.

    Parameters
    ----------
    data : array_like
        Observed counts, passed to ``binom.pmf`` as the number of successes.
    M : int
        Number of trials passed to ``binom.pmf``.
    mu1_cur, mu2_cur : float
        Current success probabilities of component 1 and component 2.
    pi1_cur : float
        Current mixing weight of component 1; component 2 gets ``1 - pi1_cur``.

    Returns
    -------
    (r_i1, r_i2) : tuple of numpy.ndarray
        Per-observation posterior weights of component 1 and component 2.
        They sum to 1 element-wise. If both weighted likelihoods are 0 for an
        observation, the division yields NaN for that entry.
    """
    pi2_cur = 1 - pi1_cur

    # Evaluate each weighted likelihood once; the original recomputed every
    # pmf and the shared denominator for both responsibilities.
    weighted_1 = binom.pmf(data, M, mu1_cur) * pi1_cur
    weighted_2 = binom.pmf(data, M, mu2_cur) * pi2_cur
    evidence = weighted_1 + weighted_2

    r_i1 = weighted_1 / evidence
    r_i2 = weighted_2 / evidence
    return r_i1, r_i2

In [4]:
def m_step(data, M, r1, r2):
    """M-step: re-estimate mixing weights and success probabilities.

    Parameters
    ----------
    data : array_like
        Observed counts.
    M : int
        Number of trials per observation.
    r1, r2 : array_like
        Per-observation responsibilities of component 1 and component 2,
        the same length as ``data``.

    Returns
    -------
    (pi1_new, pi2_new, mu1_new, mu2_new) : tuple of float
        ``pi1_new`` is the mean of ``r1`` and ``pi2_new = 1 - pi1_new``.
        Each ``mu`` is the responsibility-weighted sum of ``data`` divided by
        ``M`` times the total responsibility of that component (NaN/inf if
        that total is 0).
    """
    # Coerce once so plain lists work and the reductions below are vectorised;
    # the builtin sum() walks a numpy array element by element in Python.
    data = np.asarray(data)
    r1 = np.asarray(r1)
    r2 = np.asarray(r2)

    # Total responsibility of each component, reused for both pi and mu.
    weight_1 = np.sum(r1)
    weight_2 = np.sum(r2)

    pi1_new = weight_1 / len(data)
    pi2_new = 1 - pi1_new
    mu1_new = np.sum(data * r1) / (M * weight_1)
    mu2_new = np.sum(data * r2) / (M * weight_2)

    return pi1_new, pi2_new, mu1_new, mu2_new

In [5]:
def em_2coins(data, M = 10, mu1 = None, mu2 = None, pi1 = None, eps = .0005, max_iter = 1000):
    """Fit a two-component binomial mixture to ``data`` by EM, printing every iteration.

    Alternates ``e_step`` and ``m_step`` until the parameter change drops
    below ``eps`` or the iteration budget is exhausted.

    Parameters
    ----------
    data : array_like
        Observed counts.
    M : int, default 10
        Number of trials per observation.
    mu1, mu2 : float, optional
        Starting success probabilities. Drawn with ``np.random.random()``
        when left as ``None``.
    pi1 : float, optional
        Starting mixing weight of component 1 (component 2 gets ``1 - pi1``).
        Drawn with ``np.random.random()`` when left as ``None``.
    eps : float, default 0.0005
        Convergence tolerance on
        ``sqrt(d_mu1**2 + d_mu2**2 + 2 * d_pi1**2)``; the factor 2 accounts
        for ``pi2 = 1 - pi1`` moving by the same amount as ``pi1``.
    max_iter : int or None, default 1000
        Upper bound on the number of EM iterations. ``None`` removes the bound.

    Returns
    -------
    (pi1, pi2, mu1, mu2) : tuple of float
        The most recent parameter estimates.
    """
    # `is None` instead of `x or ...`: an explicitly supplied 0 / 0.0 is falsy
    # and would otherwise be silently replaced by a random draw.
    if mu1 is None:
        mu1 = np.random.random()
    if mu2 is None:
        mu2 = np.random.random()
    if pi1 is None:
        pi1 = np.random.random()
    pi2 = 1 - pi1
    k = 0
    converged = False
    print("Starting parameters: pi1 = %0.3f,  pi2 = %0.3f,  mu1 = %0.3f,  mu2 = %0.3f\n" % (pi1, pi2, mu1, mu2))
    print("№itr\t  pi1     pi2     mu1     mu2")
    while max_iter is None or k < max_iter:
        r1, r2 = e_step(data, M, mu1, mu2, pi1)
        pi1_new, pi2_new, mu1_new, mu2_new = m_step(data, M, r1, r2)
        print("#%d:\t  %0.3f   %0.3f   %0.3f   %0.3f" % (k, pi1_new, pi2_new, mu1_new, mu2_new))
        k = k + 1

        # Size of this update, measured before the current values are overwritten.
        step = np.sqrt((mu1 - mu1_new)**2 + (mu2 - mu2_new)**2 + 2 * (pi1 - pi1_new)**2)
        mu1, mu2, pi1, pi2 = mu1_new, mu2_new, pi1_new, pi2_new

        if step < eps:
            converged = True
            break
        if not np.isfinite(step):
            # NaN/inf parameters can never pass the `step < eps` test, so
            # iterating further would only repeat the same non-finite values.
            break
    print("---------------------------------------")
    if converged:
        print("Successfully finished after %d iterations" % k)
    else:
        print("Stopped after %d iterations without reaching eps = %g" % (k, eps))
    return pi1, pi2, mu1, mu2

In [6]:
# No starting values are passed, so em_2coins draws mu1, mu2 and pi1 with
# np.random.random(); the printed trace therefore differs from run to run.
em_2coins(y)

Starting parameters: pi1 = 0.049,  pi2 = 0.951,  mu1 = 0.388,  mu2 = 0.344

№itr	  pi1     pi2     mu1     mu2
#0:	  0.044   0.956   0.355   0.266
#1:	  0.054   0.946   0.469   0.259
#2:	  0.092   0.908   0.620   0.234
#3:	  0.127   0.873   0.662   0.213
#4:	  0.139   0.861   0.658   0.208
#5:	  0.144   0.856   0.651   0.206
#6:	  0.147   0.853   0.647   0.205
#7:	  0.149   0.851   0.644   0.205
#8:	  0.150   0.850   0.642   0.204
#9:	  0.151   0.849   0.641   0.204
#10:	  0.151   0.849   0.640   0.204
#11:	  0.152   0.848   0.640   0.204
#12:	  0.152   0.848   0.640   0.204
---------------------------------------
Successfully finished after 13 iterations
