# Chapter 9 EM算法
## EM模型

In [1]:
import numpy as np

class EM_Model():
    """Expectation-maximisation (EM) fitter for simple mixture models.

    Two mixture families are supported, selected by ``prob_type``:

    * ``'Bernoulli'`` -- every row of the training data is a sequence of 0/1
      trials drawn from one of several coins; each component is described by
      ``{'alph': mixing weight, 'prob': probability of a 1}``.
    * ``'Gaussian'`` -- every sample is a scalar drawn from one of several
      normal distributions; each component is described by
      ``{'alph': mixing weight, 'loc': mean, 'scale': standard deviation}``.
    """

    def __init__(self, prob_type = 'Bernoulli'):
        self.prob_type = prob_type

    @staticmethod
    def _count_log(count, prob):
        """Return ``count * log(prob)`` using the convention ``0 * log(0) == 0``.

        This mirrors ``prob ** count`` (where ``0 ** 0 == 1``) in log space.
        """
        if prob == 0:
            return np.where(count > 0, -np.inf, 0.0)
        return count * np.log(prob)

    @staticmethod
    def _responsibilities(log_weights):
        """Normalise per-component log weights into posterior probabilities.

        ``log_weights`` has one row per component. The per-sample maximum is
        subtracted before exponentiating so that the largest term is exactly 1
        and the sum cannot underflow to zero.
        """
        shifted = log_weights - np.max(log_weights, axis=0)
        weights = np.exp(shifted)
        return weights / np.sum(weights, axis=0)

    def bnl_prob(self, train_x, params):
        """Run one EM iteration for a mixture of Bernoulli distributions.

        Args:
            train_x: 2-D array-like of 0/1 values, one row per sample.
            params: sequence of dicts with keys ``'alph'`` and ``'prob'``.

        Returns:
            A 1-D object ndarray holding one updated parameter dict per
            component (same keys as ``params``).
        """
        train_x = np.asarray(train_x)
        T_num = np.sum(train_x == 1, axis = 1)
        F_num = np.sum(train_x == 0, axis = 1)

        # E-step, in log space: alph * prob**T * (1 - prob)**F underflows to 0
        # for long rows, which would turn the normalisation into 0/0.
        with np.errstate(divide = 'ignore'):
            log_weights = np.array([
                np.log(p['alph'])
                + EM_Model._count_log(T_num, p['prob'])
                + EM_Model._count_log(F_num, 1 - p['prob'])
                for p in params
            ])
        hide_prob = EM_Model._responsibilities(log_weights)

        # M-step: responsibility-weighted counts of ones and zeros.
        T_p = np.sum(T_num * hide_prob, axis = 1)
        F_p = np.sum(F_num * hide_prob, axis = 1)
        new_params = [
            {'alph': np.mean(resp), 'prob': t / (t + f)}
            for resp, t, f in zip(hide_prob, T_p, F_p)
        ]
        return np.array(new_params)

    def gaussian_prob(self, train_x, params):
        """Run one EM iteration for a mixture of univariate Gaussians.

        Args:
            train_x: 1-D array-like of scalar samples.
            params: sequence of dicts with keys ``'alph'``, ``'loc'`` (mean)
                and ``'scale'`` (standard deviation).

        Returns:
            A list with one updated parameter dict per component.
        """
        train_x = np.asarray(train_x, dtype = float)

        # E-step, in log space. The 1/sqrt(2*pi) factor is omitted because it
        # is common to every component and cancels in the normalisation.
        with np.errstate(divide = 'ignore'):
            log_weights = np.array([
                np.log(p['alph']) - np.log(p['scale'])
                - np.power(train_x - p['loc'], 2) / (2 * np.power(p['scale'], 2))
                for p in params
            ])
        hide_prob = EM_Model._responsibilities(log_weights)

        # M-step: responsibility-weighted mean and standard deviation.
        new_params = []
        for resp in hide_prob:
            total = np.sum(resp)
            loc = np.dot(resp.T, train_x) / total
            scale = np.sqrt(np.dot(resp.T, np.power(train_x - loc, 2)) / total)
            new_params.append({'alph': np.mean(resp), 'loc': loc, 'scale': scale})
        return new_params

    def fit(self, init_val, train_x, epoch = 10, eps = 0.000001, verbose = True):
        """Iterate EM steps until convergence or until ``epoch`` steps are done.

        Args:
            init_val: initial parameter dicts, one per mixture component.
            train_x: training data in the layout expected by the selected
                mixture family.
            epoch: maximum number of EM iterations.
            eps: stop once no parameter changes by ``eps`` or more.
            verbose: when true, print the parameters after every iteration
                that has not yet converged.

        Returns:
            The final parameter estimates (``init_val`` itself when
            ``epoch`` is 0).

        Raises:
            ValueError: if ``self.prob_type`` is not a supported family.
        """
        steps = {'Bernoulli': self.bnl_prob, 'Gaussian': self.gaussian_prob}
        if self.prob_type not in steps:
            raise ValueError(
                "unsupported prob_type %r; expected 'Bernoulli' or 'Gaussian'"
                % (self.prob_type,)
            )
        step = steps[self.prob_type]

        params = init_val
        for _ in range(epoch):
            new_params = step(train_x, params)
            max_diff = 0
            for new, old in zip(new_params, params):
                for key in new:
                    diff = np.abs(new[key] - old[key])
                    if diff > max_diff:
                        max_diff = diff
            # Adopt the newest estimate before testing for convergence so the
            # caller receives the most refined parameters.
            params = new_params
            if max_diff < eps:
                break
            if verbose:
                print(params)
        return params

## 算法测试
### 混合伯努利分布测试
初始值说明：alph 为选择第 i 个模型的概率的初始值，prob 为该模型下取值为 1 的概率的初始值

In [6]:
# Build 300 rows of 100 coin flips each. With probability 0.8 a row comes from
# a coin whose P(1) is 0.3 (threshold 0.7); otherwise from one whose P(1) is
# 0.6 (threshold 0.4).
rows = []
for _ in range(300):
    threshold = 0.7 if np.random.rand() > 0.2 else 0.4
    rows.append(np.where(np.random.rand(100) > threshold, 1, 0))
train_x = np.array(rows)

# Fit a two-component Bernoulli mixture starting from a rough initial guess.
model = EM_Model()
initial_guess = [{'alph' : 0.7, 'prob' : 0.4}, {'alph':0.3, 'prob' : 0.7}]
model.fit(initial_guess, train_x, 20)

[[0 0 0 ... 0 0 0]
 [1 0 0 ... 0 0 0]
 [0 1 0 ... 0 0 1]
 ...
 [0 0 0 ... 1 0 1]
 [0 0 1 ... 1 0 1]
 [0 1 1 ... 0 0 0]]
[{'alph': 0.870817945092795, 'prob': 0.3146831852005208}
 {'alph': 0.12918205490720497, 'prob': 0.6167128659290555}]
[{'alph': 0.8299463760280815, 'prob': 0.3034571456070492}
 {'alph': 0.17005362397191853, 'prob': 0.5989101516614664}]
[{'alph': 0.8283786178324937, 'prob': 0.30315241729475223}
 {'alph': 0.17162138216750628, 'prob': 0.5976820502976621}]
[{'alph': 0.8282791458887997, 'prob': 0.303133855480086}
 {'alph': 0.17172085411120036, 'prob': 0.5976009706224854}]
[{'alph': 0.8282726992570061, 'prob': 0.3031326545695716}
 {'alph': 0.17172730074299386, 'prob': 0.5975957085606616}]


### 混合高斯分布测试
初始值说明：alph 为选择第 i 个模型的概率的初始值，loc 为期望，scale 为标准差

In [7]:
# Draw 300 scalars: with probability 0.6 from N(3, 1), otherwise from N(50, 7)
# (the second argument of np.random.normal is the standard deviation).
draws = []
for _ in range(300):
    if np.random.rand() > 0.4:
        mean, std = 3, 1
    else:
        mean, std = 50, 7
    draws.append(np.random.normal(mean, std))
train_x = np.array(draws)

# Fit a two-component Gaussian mixture from deliberately poor initial values.
model = EM_Model(prob_type='Gaussian')
initial_guess = [{'alph': 0.6, 'loc':5, 'scale':9}, {'alph': 0.4, 'loc':6, 'scale':10}]
model.fit(initial_guess, train_x, 20)

[{'alph': 0.40466593248492233, 'loc': 7.969481524597407, 'scale': 13.825940851132833}, {'alph': 0.5953340675150778, 'loc': 33.49811919467401, 'scale': 23.218988230941964}]
[{'alph': 0.4185789663208641, 'loc': 4.122510902638575, 'scale': 6.760269176893133}, {'alph': 0.5814210336791359, 'loc': 36.87852972722482, 'scale': 21.750490750804683}]
[{'alph': 0.5007956870861883, 'loc': 2.9803043521910344, 'scale': 1.074311148666671}, {'alph': 0.49920431291381173, 'loc': 43.41914743193235, 'scale': 17.09662181788495}]
[{'alph': 0.5595145463194462, 'loc': 2.964457364267405, 'scale': 1.0172982274202018}, {'alph': 0.44048545368055375, 'loc': 48.82997163667947, 'scale': 9.075106954688277}]
[{'alph': 0.5666563540849847, 'loc': 2.9847709532786197, 'scale': 1.055382637985247}, {'alph': 0.4333436459150153, 'loc': 49.559304729925366, 'scale': 7.129460945197463}]
[{'alph': 0.5666666614776379, 'loc': 2.9848390743883337, 'scale': 1.055497217845223}, {'alph': 0.4333333385223621, 'loc': 49.56032348386422, 'sca