In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [99]:
# generate the data
BASE = [1, 2, 3, 4, 5, 6]
def sample(threshold=0.4, n_samples=10000, seed=0):
    np.random.seed(seed)
    x = []
    for i in range(n_samples):
        if np.random.uniform(0, 1) < threshold:
            x.append(np.random.choice(BASE, p=[0.01, 0.05, 0.05, 0.05, 0.05, 0.79]))
        else:
            x.append(np.random.choice(BASE, p=[1/6, 1/6, 1/6, 1/6, 1/6, 1/6]))
    count = {}
    for data in x:
        if data not in count.keys():
            count[data] = 1
        else:
            count[data] += 1
    return x, count

data, count = sample()
print(count)

{5: 1083, 4: 1043, 6: 4843, 3: 1070, 1: 921, 2: 1040}


In [5]:
init_para = {'w_11': 1/6, 'w_21': 1/6, 'w_31': 1/6, 'w_41': 1/6, 'w_51': 1/6, 'w_61': 1/6,
             'w_12': 1/6, 'w_22': 1/6, 'w_32': 1/6, 'w_42': 1/6, 'w_52': 1/6, 'w_62': 1/6,
             'pi_1': 1/2, 'pi_2': 1/2}

In [14]:
post_para = {'z_11': 1/2, 'z_21': 1/2, 'z_31': 1/2, 'z_41': 1/2, 'z_51': 1/2, 'z_61': 1/2,
             'z_12': 1/2, 'z_22': 1/2, 'z_32': 1/2, 'z_42': 1/2, 'z_52': 1/2, 'z_62': 1/2}

In [48]:
def compute_post(post_, prior_):
    for z_ik in post_.keys():
        i, k = z_ik[-2], z_ik[-1]
        w_ik = 'w_' + i + k     
        pi_k = 'pi_' + k
        
        # z_ik = pi_k * w_ik / (pi_1 * w_i1 + pi_2 * w_i2)
        post_[z_ik] = prior_[pi_k] * prior_[w_ik] / (prior_['pi_1']*prior_['w_' + i + '1'] + prior_['pi_2']*prior_['w_'+i+'2'])
    
    return post_

In [52]:
init_prior = {'w_11': 1/6, 'w_21': 1/6, 'w_31': 1/6, 'w_41': 1/6, 'w_51': 1/6, 'w_61': 1/6,
             'w_12': 0.2, 'w_22': 0.1, 'w_32': 0.2, 'w_42': 0.3, 'w_52': 0.1, 'w_62': 0.1,
             'pi_1': 1/2, 'pi_2': 1/2}
print(compute_post(post_para, init_prior))

{'z_11': 0.4545454545454545, 'z_21': 0.625, 'z_31': 0.4545454545454545, 'z_41': 0.3571428571428571, 'z_51': 0.625, 'z_61': 0.625, 'z_12': 0.5454545454545454, 'z_22': 0.375, 'z_32': 0.5454545454545454, 'z_42': 0.6428571428571428, 'z_52': 0.375, 'z_62': 0.375}


In [31]:
def compute_prior(post_, prior_, count):
    # compute normalize const N_k
    # N_k = sum_i (n_i * z_ik)
    normalize_const = {}
    for k in range(1, 3):
        N_k = 'N_' + str(k)
        normalize_const[N_k] = 0
        for i in range(1, 7):
            z_ik = 'z_' + str(i) + str(k)
            normalize_const[N_k] += count[i] * post_[z_ik] 
    # compute w_ik and pi_k
    for key in prior_.keys():
        # compute w_ik
        if 'w' in key:
            w_ik = key
            i, k = w_ik[-2], w_ik[-1]
            z_ik = 'z_' + i + k
            N_k = 'N_' + k
            # w_ik = n_i * z_ik / N_k
            prior_[w_ik] = count[int(i)] * post_[z_ik] / normalize_const[N_k]
        # compute the pi_k
        else:
            pi_k = key
            k = pi_k[-1]
            N_k = 'N_' + k
            # pi_k = N_k / sum_k(N_k)
            prior_[pi_k] = normalize_const[N_k] / (normalize_const['N_1'] + normalize_const['N_2'])
    
    return prior_

In [53]:
post_ = {'z_11': 0.4545454545454545, 'z_21': 0.625, 'z_31': 0.4545454545454545, 'z_41': 0.3571428571428571, 'z_51': 0.625, 'z_61': 0.625, 'z_12': 0.5454545454545454, 'z_22': 0.375, 'z_32': 0.5454545454545454, 'z_42': 0.6428571428571428, 'z_52': 0.375, 'z_62': 0.375}
print(compute_prior(post_, init_para, count))
del post_

{'w_11': 0.13798715006179327, 'w_21': 0.19351515663260999, 'w_31': 0.13532197769299853, 'w_41': 0.10618931014101154, 'w_51': 0.18642235919952715, 'w_61': 0.24056404627205943, 'w_12': 0.18575525071688753, 'w_22': 0.13025291275930662, 'w_32': 0.1821674545971221, 'w_42': 0.2144245524296675, 'w_52': 0.12547882921284456, 'w_62': 0.16192100028417164, 'pi_1': 0.5287053571428572, 'pi_2': 0.47129464285714284}


In [95]:
def check_prior(prior_):
    w_1_sum, w_2_sum, pi_sum = 0, 0, 0
    for i in range(1, 7):
        w_1_sum += prior_['w_'+ str(i) + '1']
        w_2_sum += prior_['w_'+ str(i) + '2']
    pi_sum = prior_['pi_1'] + prior_['pi_2']
    assert (w_1_sum - 1) < 1e-5, "w_1_sum: " + str(w_1_sum)
    assert (w_2_sum - 1) < 1e-5, "w_2_sum: " + str(w_2_sum)
    assert (pi_sum - 1)  < 1e-5 , "pi_sum: "  + str(pi_sum)

In [96]:
def check_post(post_):
    z1_sum = post_['z_11'] + post_['z_12']
    z2_sum = post_['z_21'] + post_['z_22']
    assert (z1_sum - 1) < 1e-5, "z1_sum: " + str(z1_sum)
    assert (z2_sum - 1) < 1e-5, "z2_sum: " + str(z2_sum)

In [104]:
def mix_multi_model(count, init_prior=None, init_post=None, epochs=100, epsilon=1e-3):
    ### init parameter 
    ### Please do not give symmetric parameter
    if init_prior is None:
        prior_para = {'w_11': 0.1, 'w_21': 0.2, 'w_31': 0.2, 'w_41': 0.1, 'w_51': 0.3, 'w_61': 0.1,
                      'w_12': 0.01, 'w_22': 0.05, 'w_32': 0.05, 'w_42': 0.05, 'w_52': 0.05, 'w_62': 0.79,
                      'pi_1': 0.3,  'pi_2': 0.7 }
    else:
        prior_para = init_prior
    if init_post is None:
        post_para = {'z_11': 0.7, 'z_21': 0.7, 'z_31': 0.6, 'z_41': 0.6, 'z_51': 0.6, 'z_61': 0.1,
                     'z_12': 0.3, 'z_22': 0.3, 'z_32': 0.4, 'z_42': 0.4, 'z_52': 0.4, 'z_62': 0.9}
    else:
        post_para = init_post
    
    for i_epoch in range(epochs):
        ### copy parameter to determine convergency
        post_copy = post_para.copy()
        prior_copy = prior_para.copy()
        
        ### step 1
        ### compute the post parameter
        try:
            post_para = compute_post(post_para, prior_para)
        except Exception as e:
            if i == 0:
                print("*"*40)
                print("输入参数有误!请检查参数是否符合概率定义(for user)")
                print("*"*40)
            else:
                print("*"*40)
                print("模型计算有误!(for developer)")
                print("*"*40)
            print(e)
            return prior_para, post_para
        
        ### step 2
        ### under the post parameter, compute the optimal prior parameter
        try:
            prior_para = compute_prior(post_para, prior_para, count)
        except Exception as e:
            if i == 0:
                print("*"*40)
                print("输入参数有误!请检查参数是否符合概率定义(for user)")
                print("*"*40)
            else:
                print("*"*40)
                print("模型计算有误!(for developer)")
                print("*"*40)
            print(e)
            return prior_para, post_para
        
        ### step 3
        ### check out 
        check_prior(prior_para)
        check_post(post_para)
        
        ### step 4
        ### compute the difference of pi_k
        diff = 0
        for key in prior_para.keys():
            diff += abs(prior_para[key] - prior_copy[key])
        # if i_epoch % 10 == 0:
        print("{}/{} ---> diff: {}".format(i_epoch, epochs, diff))
        
        if False:
            print("{}/{} ---> diff: {}".format(i_epoch, epochs, diff))
            print("训练结束, 参数收敛!")
            return prior_para, post_para
    return prior_para, post_para

In [105]:
prior_para, post_para = mix_multi_model(count)
print(prior_para)
print(post_para)

0/100 ---> diff: 0.5543257784134089
1/100 ---> diff: 7.632783294297951e-17
2/100 ---> diff: 4.85722573273506e-17
3/100 ---> diff: 3.469446951953614e-17
4/100 ---> diff: 0.0
5/100 ---> diff: 0.0
6/100 ---> diff: 0.0
7/100 ---> diff: 0.0
8/100 ---> diff: 0.0
9/100 ---> diff: 0.0
10/100 ---> diff: 0.0
11/100 ---> diff: 0.0
12/100 ---> diff: 0.0
13/100 ---> diff: 0.0
14/100 ---> diff: 0.0
15/100 ---> diff: 0.0
16/100 ---> diff: 0.0
17/100 ---> diff: 0.0
18/100 ---> diff: 0.0
19/100 ---> diff: 0.0
20/100 ---> diff: 0.0
21/100 ---> diff: 0.0
22/100 ---> diff: 0.0
23/100 ---> diff: 0.0
24/100 ---> diff: 0.0
25/100 ---> diff: 0.0
26/100 ---> diff: 0.0
27/100 ---> diff: 0.0
28/100 ---> diff: 0.0
29/100 ---> diff: 0.0
30/100 ---> diff: 0.0
31/100 ---> diff: 0.0
32/100 ---> diff: 0.0
33/100 ---> diff: 0.0
34/100 ---> diff: 0.0
35/100 ---> diff: 0.0
36/100 ---> diff: 0.0
37/100 ---> diff: 0.0
38/100 ---> diff: 0.0
39/100 ---> diff: 0.0
40/100 ---> diff: 0.0
41/100 ---> diff: 0.0
42/100 ---> diff: 