# EM Algorithm

latent variable 이 존재하는 probabilistic model 의 maximum likelihood 혹은 maximum a posteriori (MAP) 추정 문제를 풀기 위한 iterative 알고리즘이다. 굉장히 많은 probabilistic model 을 학습하는 데 널리 사용되며, Clustering 에서 다뤘던 GMM 은 물론이고 HMM, RBM 등의 문제를 해결하는 데에도 사용된다.

In [2]:
import numpy as np

In [3]:
# Observed data: five observations, each a pair of non-negative counts.
# NOTE(review): presumably (successes, failures) out of 10 trials per row,
# since every row sums to 10 - confirm against the data source.
xs = np.array(
    [
        [5, 5],
        [9, 1],
        [8, 2],
        [4, 6],
        [7, 3],
    ]
)
xs

array([[5, 5],
       [9, 1],
       [8, 2],
       [4, 6],
       [7, 3]])

In [13]:
# Initial parameter guesses, one row per mixture component:
# row 0 is component A, row 1 is component B. Each row holds the
# probabilities of the two outcome categories and sums to 1.
thetas = np.array(
    [
        [0.6, 0.4],
        [0.5, 0.5],
    ]
)
thetas

array([[ 0.6,  0.4],
       [ 0.5,  0.5]])

In [14]:
# Two equal weights of 0.5.
# NOTE(review): presumably the mixing proportions of components A and B;
# the EM loop in this file never reads `pis` - confirm intended use.
pis = np.array(
    [0.5, 0.5]
)
pis

array([ 0.5,  0.5])

In [10]:
# Stopping criteria for the iterative fit.
max_iter = 100  # hard cap on the number of iterations
tol = 0.01      # stop once the log likelihood changes by less than this

# Log likelihood from the previous iteration; 0 is only a starting placeholder.
ll_old = 0

In [12]:
# EM iterations for a two-component (A / B) mixture whose per-component
# category probabilities are the rows of `thetas`.
# Each pass runs an E-step over every observation in `xs`, then an M-step that
# overwrites `thetas` in place. The loop ends when the expected complete-data
# log likelihood changes by less than `tol`, or after `max_iter` passes.
for i in range(max_iter):

    # Reset the per-iteration accumulators.
    # ws_*: responsibility of each component for each observation.
    ws_A = []
    ws_B = []

    # vs_*: observations weighted by those responsibilities (expected counts).
    vs_A = []
    vs_B = []

    ll_new = 0

    # thetas only change in the M-step below, so take the logs once per pass
    # instead of once per observation.
    log_theta_A = np.log(thetas[0])
    log_theta_B = np.log(thetas[1])

    for x in xs:

        # multinomial (binomial) log likelihood of x under each component.
        # The combinatorial constant is left out: it is identical for A and B
        # and therefore cancels in the responsibilities.
        ll_A = np.sum(x * log_theta_A)
        ll_B = np.sum(x * log_theta_B)

        # [EQN 1] responsibilities w_k = exp(ll_k) / (exp(ll_A) + exp(ll_B)).
        # Normalize in log space: exponentiating ll_A / ll_B directly
        # underflows to 0 for large counts and would yield 0/0 = NaN.
        log_denom = np.logaddexp(ll_A, ll_B)
        w_A = np.exp(ll_A - log_denom)
        w_B = np.exp(ll_B - log_denom)

        ws_A.append(w_A)
        ws_B.append(w_B)

        # used for calculating theta
        vs_A.append(w_A * x)
        vs_B.append(w_B * x)

        # update complete log likelihood
        ll_new += w_A * ll_A + w_B * ll_B

    # M-step: update values for parameters given current distribution
    # [EQN 2] normalize the expected counts of each component so that the
    # row sums to 1. The assignment writes into `thetas` in place.
    thetas[0] = np.sum(vs_A, 0)/np.sum(vs_A)
    thetas[1] = np.sum(vs_B, 0)/np.sum(vs_B)

    # report the current parameter estimates and log likelihood
    print("Iteration: %d" % (i+1))
    print("theta_A = %.2f, theta_B = %.2f, ll = %.2f" % (thetas[0,0], thetas[1,0], ll_new))

    # converged: the log likelihood barely moved since the previous pass
    if np.abs(ll_new - ll_old) < tol:
        break
    ll_old = ll_new

Iteration: 1
theta_A = 0.71, theta_B = 0.58, ll = -32.69
Iteration: 2
theta_A = 0.75, theta_B = 0.57, ll = -31.26
Iteration: 3
theta_A = 0.77, theta_B = 0.55, ll = -30.76
Iteration: 4
theta_A = 0.78, theta_B = 0.53, ll = -30.33
Iteration: 5
theta_A = 0.79, theta_B = 0.53, ll = -30.07
Iteration: 6
theta_A = 0.79, theta_B = 0.52, ll = -29.95
Iteration: 7
theta_A = 0.80, theta_B = 0.52, ll = -29.90
Iteration: 8
theta_A = 0.80, theta_B = 0.52, ll = -29.88
Iteration: 9
theta_A = 0.80, theta_B = 0.52, ll = -29.87


- https://mk-minchul.github.io/EM/
- http://issactoast.com/130
