In [1]:
import numpy as np
import scipy.stats as stats

### HW  Problem

Write Python code that accepts a sequence $X_1, X_2, \ldots, X_n$ of coin-flip results and initial guesses $\theta_A^0, \theta_B^0$, and runs EM to find $\hat{\theta}_A, \hat{\theta}_B$.

Write more code to assemble these to give $\hat{\theta}_A, \hat{\theta}_B$.

In [2]:
def em(x, n, num_iter=100, init=None):
    """
    Estimate the head probabilities of two biased coins with EM.

    Each trial is modelled as n flips of one of two coins; which coin was
    used is the hidden variable. The two coins are weighted equally in the
    E-step (no mixing proportion is estimated).

    Parameters:
    x: 1-d array-like of counts of heads in a series of coin flip experiments
    n: number of coin flips in each trial
    num_iter: number of EM iterations to run (there is no convergence test)
    init: optional pair (theta_A, theta_B) of starting head probabilities in
        [0, 1]; when omitted, both are drawn with np.random.random

    Returns:
    1-d array holding the two estimated head probabilities. Nothing pins a
    coin to a slot, so with a random start the order may differ between calls.

    Raises:
    ValueError: if x is empty, if x holds a value outside [0, n], or if init
        is not two probabilities in [0, 1].
    """
    x = np.asarray(x)  # lets plain lists through; `n - x` style math needs an array
    if x.size == 0:
        raise ValueError("x must contain at least one trial")
    if np.max(x) > n:
        raise ValueError("Values in x should not be greater than n")
    if np.min(x) < 0:
        raise ValueError("Values in x should not be negative")

    if init is None:
        param = np.random.random(2)
    else:
        param = np.array(init, dtype=float)  # copy, so the caller's init is never modified
        # the comparison is written so that NaN entries are rejected as well
        if param.shape != (2,) or not np.all((param >= 0) & (param <= 1)):
            raise ValueError("init should hold two probabilities in [0, 1]")

    for _ in range(num_iter):
        #Expectation step
        #responsibility of each coin for each trial, computed from binomial
        #log-likelihoods so that tiny likelihoods do not underflow to 0/0
        log_lik = np.array([stats.binom.logpmf(x, n, p) for p in param])
        peak = np.max(log_lik, axis=0)
        peak = np.where(np.isfinite(peak), peak, 0.0)  # avoid (-inf) - (-inf) = nan
        weight = np.exp(log_lik - peak)
        total = np.sum(weight, axis=0)
        possible = total > 0
        #a trial that both coins consider impossible is split evenly
        resp = np.where(possible, weight / np.where(possible, total, 1.0), 0.5)

        #Maximization step
        exp_heads = np.sum(resp * x, axis=1)  #expected heads credited to each coin
        exp_flips = n * np.sum(resp, axis=1)  #expected heads + tails credited to each coin
        credited = exp_flips > 0
        #a coin that was credited with no flips keeps its previous estimate
        param = np.where(credited, exp_heads / np.where(credited, exp_flips, 1.0), param)
    return param

In [3]:
#test simulation
true_param = np.array([0.10, 0.80]) #probability of heads for coin A and for coin B

#pick at random which coin is used in each of the 100 experiments
#(em() is never given this array; it only sees the head counts)
coin_used = np.random.binomial(1, 0.5, size=100)#(bernoulli values) 0 for coin A and 1 for coin B

#generate samples
n = 10 #number of flips per experiment
#draw the head counts with the same n that is later passed to em(), so the
#data and the model cannot drift apart if n is changed
x = np.random.binomial(n, true_param[coin_used], size=len(coin_used))

In [4]:
#repeat em from 20 independent random starts and print each estimate
n_restarts = 20
for _ in range(n_restarts):
    estimate = em(x, n)
    print(estimate)

[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.81856534  0.09230997]
[ 0.09230997  0.81856534]
[ 0.81856534  0.09230997]
[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.81856534  0.09230997]
[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.81856534  0.09230997]
[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.09230997  0.81856534]
[ 0.81856534  0.09230997]
[ 0.09230997  0.81856534]
