In [1]:
import numpy as np


def em(data, thetas, max_iter=50, eps=1e-3, verbose=True):
    """
    Estimate per-component outcome probabilities with the EM algorithm.

    Each row of ``data`` holds the outcome counts of one experiment, and each
    row of ``thetas`` holds one component's outcome probabilities. Which
    component produced each experiment is the hidden variable. Component
    priors are not modelled: responsibilities are obtained by normalising the
    per-component likelihoods directly.

    :param data: array-like of shape (n_experiments, n_outcomes) with counts
    :param thetas: array-like of shape (n_components, n_outcomes) with the
        initial probability estimates; at least two components are required
        because the progress line reports the first two
    :param max_iter: maximum number of EM iterations
    :param eps: stop once the objective changes by less than this between
        two consecutive iterations
    :param verbose: when True (default) print the progress lines each iteration
    :return: ndarray of shape (n_components, n_outcomes) with the estimates
    """
    data = np.asarray(data, dtype=float)
    thetas = np.asarray(thetas, dtype=float)

    # Start from -inf so the convergence test can never fire on the very first
    # iteration just because the first objective value happens to be near 0.
    ll_old = -np.inf
    for i in range(max_iter):
        # ---- E-step: posterior over the hidden component of each experiment ----
        # log_like[k, n] = sum_j data[n, j] * log(thetas[k, j])
        log_like = np.log(thetas) @ data.T
        # Subtract the per-experiment maximum before exponentiating. The shift
        # cancels in the normalisation below, but prevents every component
        # from underflowing to 0 (and the ratio becoming 0/0 = NaN) when the
        # counts are large.
        shifted = log_like - log_like.max(axis=0, keepdims=True)
        like = np.exp(shifted)
        ws = like / like.sum(axis=0, keepdims=True)

        # ---- M-step: re-estimate probabilities from responsibility-weighted counts ----
        weighted_counts = ws @ data
        thetas = weighted_counts / weighted_counts.sum(axis=1, keepdims=True)

        # Objective used for the stopping rule: responsibility-weighted
        # log-likelihood, evaluated with the parameters used in the E-step.
        ll_new = np.sum(ws * log_like)
        if verbose:
            print("Iteration:%d" % (i+1))
            print("theta_B = %.2f, theta_C = %.2f, ll = %.2f"
                  % (thetas[0, 0], thetas[1, 0], ll_new))
        if np.abs(ll_new - ll_old) < eps:
            break
        ll_old = ll_new

    return thetas



In [2]:
# Observed data: 5 independent experiments of 10 coin tosses each, stored as
# (heads, tails) counts. For example, the first experiment gave 5 heads and 5 tails.
observed_data = np.array([
    [5, 5],
    [9, 1],
    [8, 2],
    [4, 6],
    [7, 3],
])
# Initial guess: coin B lands heads with probability 0.6, coin C with probability 0.5.
initial_thetas = np.array([
    [0.6, 0.4],
    [0.5, 0.5],
])
# Run the EM optimisation and show the final estimates.
thetas = em(observed_data, initial_thetas, max_iter=30, eps=1e-3)
print(thetas)

Iteration:1
theta_B = 0.71, theta_C = 0.58, ll = -32.69
Iteration:2
theta_B = 0.75, theta_C = 0.57, ll = -31.26
Iteration:3
theta_B = 0.77, theta_C = 0.55, ll = -30.76
Iteration:4
theta_B = 0.78, theta_C = 0.53, ll = -30.33
Iteration:5
theta_B = 0.79, theta_C = 0.53, ll = -30.07
Iteration:6
theta_B = 0.79, theta_C = 0.52, ll = -29.95
Iteration:7
theta_B = 0.80, theta_C = 0.52, ll = -29.90
Iteration:8
theta_B = 0.80, theta_C = 0.52, ll = -29.88
Iteration:9
theta_B = 0.80, theta_C = 0.52, ll = -29.87
Iteration:10
theta_B = 0.80, theta_C = 0.52, ll = -29.87
Iteration:11
theta_B = 0.80, theta_C = 0.52, ll = -29.87
Iteration:12
theta_B = 0.80, theta_C = 0.52, ll = -29.87
[[0.7967829  0.2032171 ]
 [0.51959543 0.48040457]]
