## Running the EM Algorithm

Now that we have worked out how to update the parameter estimates $\theta_A$ and $\theta_B$ at each iteration, let's code up the formulas.

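Recall that the M-step updates are

$$\hat{\theta}_A = \frac{\sum_i \Pr(Z_i = A \mid X_i, \theta)\, x_i}{10 \sum_i \Pr(Z_i = A \mid X_i, \theta)}, \qquad \hat{\theta}_B = \frac{\sum_i \Pr(Z_i = B \mid X_i, \theta)\, x_i}{10 \sum_i \Pr(Z_i = B \mid X_i, \theta)},$$

where $x_i$ is the number of heads in the $i$-th set of 10 tosses and $\theta = (\theta_A, \theta_B)$ denotes the current estimates.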
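The M-step can be checked on its own. The sketch below takes one coin's posterior weights as given and returns its weighted fraction of heads; `m_step` is a hypothetical helper name, and the weights passed in are made-up values for illustration rather than output of an E-step.

```python
import numpy as np

def m_step(weights, heads, n_flips=10):
    """Weighted maximum-likelihood estimate of one coin's bias.

    weights[i] stands in for Pr(Z_i = coin | X_i, theta) and heads[i] is x_i,
    the number of heads in the i-th set of n_flips tosses.
    """
    weights = np.asarray(weights, dtype=float)
    heads = np.asarray(heads, dtype=float)
    expected_heads = np.sum(weights * heads)      # heads attributed to this coin
    expected_tosses = n_flips * np.sum(weights)   # tosses attributed to this coin
    return expected_heads / expected_tosses

heads = [5, 9, 8, 4, 7]
print(m_step([0.5] * 5, heads))        # equal weights: pooled head rate 33/50 = 0.66
print(m_step([1, 0, 0, 0, 0], heads))  # all weight on the first set: 5/10 = 0.5
```

Equal weights ignore the coin labels entirely, which is why the first result is just the pooled head rate.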

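Further recall that, by Bayes' rule,

$$\Pr(Z_i = A \mid X_i, \theta) = \frac{\Pr(X_i \mid Z_i = A, \theta)\,\Pr(Z_i = A)}{\Pr(X_i \mid Z_i = A, \theta)\,\Pr(Z_i = A) + \Pr(X_i \mid Z_i = B, \theta)\,\Pr(Z_i = B)},$$

where $\Pr(X_i \mid Z_i = A, \theta)$ is the binomial probability of observing $x_i$ heads in 10 tosses of a coin with bias $\theta_A$. A similar formula applies for $\Pr(Z_i = B \mid X_i, \theta)$. If each coin is equally likely to be picked, the prior $\Pr(Z_i = A) = \Pr(Z_i = B) = 1/2$ cancels and the posterior is just the ratio of the two likelihoods, which is what the code below computes.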
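Written out with an explicit prior, the Bayes'-rule step looks like this. This is a sketch: `posterior_A` and its `prior_A` argument are hypothetical names, and the values $\theta_A = 0.6$, $\theta_B = 0.5$ are an assumed example. With `prior_A=0.5` the prior cancels and the result equals the plain likelihood ratio.

```python
import numpy as np
from scipy.stats import binom

def posterior_A(heads, theta_A, theta_B, prior_A=0.5, n_flips=10):
    """Pr(Z_i = A | X_i, theta) for each observed head count, via Bayes' rule."""
    heads = np.asarray(heads)
    joint_A = binom.pmf(heads, n_flips, theta_A) * prior_A        # Pr(X_i | Z_i = A, theta) Pr(Z_i = A)
    joint_B = binom.pmf(heads, n_flips, theta_B) * (1 - prior_A)  # Pr(X_i | Z_i = B, theta) Pr(Z_i = B)
    return joint_A / (joint_A + joint_B)                          # divide by Pr(X_i | theta)

heads = [5, 9, 8, 4, 7]
post_A = posterior_A(heads, theta_A=0.6, theta_B=0.5)
post_B = 1 - post_A  # Pr(Z_i = B | X_i, theta); each pair sums to one
print(np.round(post_A, 2))
```

Sets with many heads (9 or 8) lean toward the higher-bias coin A, while the 4-heads set leans toward coin B.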

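We implement one EM iteration in the function `new_theta`: it computes these posteriors (the E-step) and then the updated estimates (the M-step). Its updates can be checked against the worked example at http://karlrosaen.com/ml/notebooks/em-coin-flips/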

```python
import numpy as np
from scipy.stats import binom
```

```python
def new_theta(old_theta, obs_heads, n_flips=10):
    """Run one EM iteration for the two-coin model; return the updated biases."""
    # E-step: likelihood of each head count under each coin,
    # i.e. Pr(X_i | Z_i = A, theta) and Pr(X_i | Z_i = B, theta)
    lik_A = binom.pmf(obs_heads, n_flips, old_theta['A'])
    lik_B = binom.pmf(obs_heads, n_flips, old_theta['B'])
    # Posterior that each set came from coin A or B (equal priors cancel)
    post_A = lik_A / (lik_A + lik_B)
    post_B = lik_B / (lik_A + lik_B)

    # M-step: posterior-weighted fraction of heads for each coin
    new_thetaA = np.sum(post_A * obs_heads) / (n_flips * np.sum(post_A))
    new_thetaB = np.sum(post_B * obs_heads) / (n_flips * np.sum(post_B))

    return {'A': new_thetaA, 'B': new_thetaB}
```
    
    

```python
# theta holds the current estimates for coin A and coin B
theta = {'A': 0.5, 'B': 0.5}
obs_heads = np.array([5, 9, 8, 4, 7])  # heads observed in each set of 10 tosses

print(f"0: thetaA/thetaB: {theta['A']:.3f} / {theta['B']:.3f}")  # initial guess
for i in range(6):
    theta = new_theta(theta, obs_heads)
    print(f"{i + 1}: thetaA/thetaB: {theta['A']:.3f} / {theta['B']:.3f}")
```

```
0: thetaA/thetaB: 0.500 / 0.500
1: thetaA/thetaB: nan / nan
2: thetaA/thetaB: nan / nan
3: thetaA/thetaB: nan / nan
4: thetaA/thetaB: nan / nan
5: thetaA/thetaB: nan / nan
6: thetaA/thetaB: nan / nan
```
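The starting guess matters. With $\theta_A = \theta_B = 0.5$, as in the run above, the two likelihoods are identical for every set, so every posterior is exactly $1/2$ and the M-step gives both coins the pooled head rate $\sum_i x_i / (10 \cdot 5) = 33/50 = 0.66$. From then on the estimates stay equal, so EM cannot tell the coins apart; worked by hand, that start gives 0.660 / 0.660 at every iteration. The sketch below copies `new_theta` so it runs on its own, adds hypothetical helpers `log_likelihood` and `run_em`, and compares the symmetric start with an asymmetric one (0.6 / 0.5 is an assumed choice). It also checks that the marginal log-likelihood $\sum_i \log \Pr(X_i \mid \theta)$ never decreases, a standard property of EM.

```python
import numpy as np
from scipy.stats import binom

N_FLIPS = 10  # tosses per set, as in the formulas above

def new_theta(old_theta, obs_heads):
    # compact copy of the E-step and M-step above (equal priors on the coins)
    lik_A = binom.pmf(obs_heads, N_FLIPS, old_theta['A'])
    lik_B = binom.pmf(obs_heads, N_FLIPS, old_theta['B'])
    post_A = lik_A / (lik_A + lik_B)
    post_B = 1 - post_A
    return {'A': np.sum(post_A * obs_heads) / (N_FLIPS * np.sum(post_A)),
            'B': np.sum(post_B * obs_heads) / (N_FLIPS * np.sum(post_B))}

def log_likelihood(theta, obs_heads):
    # marginal log-likelihood sum_i log Pr(X_i | theta), with Pr(Z_i = A) = Pr(Z_i = B) = 1/2
    mix = (0.5 * binom.pmf(obs_heads, N_FLIPS, theta['A'])
           + 0.5 * binom.pmf(obs_heads, N_FLIPS, theta['B']))
    return np.sum(np.log(mix))

def run_em(theta, obs_heads, n_iter=10):
    """Hypothetical driver: apply new_theta n_iter times, recording the log-likelihood."""
    history = [log_likelihood(theta, obs_heads)]
    for _ in range(n_iter):
        theta = new_theta(theta, obs_heads)
        history.append(log_likelihood(theta, obs_heads))
    return theta, history

obs_heads = np.array([5, 9, 8, 4, 7])
for start in ({'A': 0.5, 'B': 0.5}, {'A': 0.6, 'B': 0.5}):  # second start is an assumed choice
    theta, history = run_em(start, obs_heads)
    # EM should never decrease the log-likelihood (allowing for rounding error)
    monotone = all(b >= a - 1e-12 for a, b in zip(history, history[1:]))
    print(f"start {start['A']}/{start['B']} -> {theta['A']:.3f} / {theta['B']:.3f}, monotone: {monotone}")
```

With the symmetric start both estimates end at 0.660, while the asymmetric start lets $\hat{\theta}_A$ and $\hat{\theta}_B$ move apart.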


