# EM algorithm
參考:https://mlbhanuyerra.github.io/2018-01-28-Handwritten-Digits_Mixture-of-Bernoulli-Distributions/

In [1]:
import numpy as np
import pandas as pd 
from loadMnist import load_images, load_labels

In [2]:
N = 60000
# HW3讀資料
data = load_images("data/train-images.idx3-ubyte")[:N]
true = load_labels("data/train-labels.idx1-ubyte")[:N]

In [3]:
# train_data to 0/1 矩陣
data[data > 0] = 1

In [4]:
# 初始化 L(每個類別出現機率) and P(每個類別 每個pixel的p)
# w:(10,1)
L = np.random.uniform(.25, .75, 10)
tot = np.sum(L)
L = L/tot
print("init Lambda: \n", L)

P = np.random.rand(10, 784)
print("init P: \n", P)

init Lambda: 
 [0.11661669 0.07766198 0.13782261 0.06922699 0.11677242 0.13518452
 0.1388937  0.05803875 0.07299346 0.07678888]
init P: 
 [[0.29732752 0.67754102 0.28525515 ... 0.07686414 0.3474416  0.83525107]
 [0.19415164 0.6434689  0.72547214 ... 0.59426784 0.84458183 0.17463937]
 [0.95378532 0.16283732 0.93589146 ... 0.0527249  0.0743095  0.85733169]
 ...
 [0.58159857 0.1526059  0.90070812 ... 0.54416737 0.47123231 0.07618872]
 [0.74448118 0.93090253 0.97314533 ... 0.12684871 0.12586899 0.56400256]
 [0.13427634 0.66463372 0.67513633 ... 0.75457829 0.34368036 0.89873484]]


In [5]:
def responsibility(data, L, P):
    # 課本(9.48) 
    W = np.zeros((N,10))
    for i in range(N):
        for j in range(10):    
            W[i, j] = np.prod(data[i] * P[j] + (1 - data[i]) * (1 - P)[j])
            
    # 課本(9.47)
    W = W * L.reshape(1, -1)
    # 讓每一列(一張圖)機率在區間[0, 1]
    sums = np.sum(W, axis=1).reshape(-1, 1)
    sums[sums == 0] = 1
    W = W / sums

    return W


def get_L_new(W):
    # 課本(9.60)
    # 某次出現class i的機率全部加起來/n(把一行的resp加起來/n)
    L = np.sum(W, axis=0)
    L = L/N
    
    return L


def get_P_new(data, W):
    # 課本(9.59) 
    # (某次出現class i的機率*pixel=1出現次數)/ 某次出現class i的機率全部加起來
    sums = np.sum(W, axis=0)
    sums[sums==0] = 1
    W = W/sums
    P = data.T@W
    return P.T

# EM algo

In [6]:
diff = 1000
iteration = 0
for i in range(1, 15):
    # E-step:利用已知的P計算resposibility Wi
    W = responsibility(data, L, P)
        
    # M-step:利用 Estep算出的Wi，推得Wi下的L和P
    L_new = get_L_new(W)
    P_new = get_P_new(data, W)
    
    # 自定義difference: 就是L和P的差和 
    diff_new = np.sum(np.abs(L - L_new)) + np.sum(np.abs(P - P_new))
    print("No. of Iteration: {}, Difference: {}".format(i, diff_new))
    
    # 自定義收斂條件: difference<1
    if abs(diff_new - diff) < 1:
        break
    else:
        L = L_new
        P = P_new
        diff = diff_new
        iteration += 1

No. of Iteration: 1, Difference: 3348.439020165073
No. of Iteration: 2, Difference: 180.72094597322092
No. of Iteration: 3, Difference: 80.9287877762104
No. of Iteration: 4, Difference: 50.24223880110068
No. of Iteration: 5, Difference: 37.56651305390023
No. of Iteration: 6, Difference: 31.25741629042092
No. of Iteration: 7, Difference: 26.26484348016646
No. of Iteration: 8, Difference: 21.50960707270319
No. of Iteration: 9, Difference: 15.956377859801549
No. of Iteration: 10, Difference: 12.895455396656214
No. of Iteration: 11, Difference: 11.328263088291559
No. of Iteration: 12, Difference: 9.390195981358332
No. of Iteration: 13, Difference: 7.294508894776285
No. of Iteration: 14, Difference: 5.934155658122049


In [7]:
def confusion_matrix(true, preds):
    for c in range(10):
        TP, FP, FN, TN = 0, 0, 0, 0
        for n in range(N):
            if c == true[n]:
                if true[n] == preds[n]:
                    TP += 1
                else:
                    FP += 1 
            else:
                if c == preds[n]:
                    FN += 1
                else:
                    TN += 1
        
        confusion_matrix = pd.DataFrame([[TP,FP], [FN, TN]], index=("Is number {}".format(c),"Isn't number {}".format(c)),
                                        columns=("Predict number {}".format(c), "Predict number {}".format(c)))
        print(confusion_matrix)
        print()
        print("Sensitivity (Successfully predict number {}): {:.5f}".format(c, TP/(TP+FP)))
        print("Specificity (Successfully predict number {}): {:.5f}".format(c, TN/(FN+TN)))
        print("------------------------------------------------------------")
        
def error_rate(true, preds):
    error = 0
    for n in range(N):
        if true[n] != preds[n]:
            error += 1
    print("Total error rate: {}".format(error/N))

# result

![title](./group.PNG)

In [8]:
# groups: 因為是非監督式學習，所以你得到的是群(G0,G1,...,G9) 
groups = np.argmax(W, axis=1)

# 找出每個群對應的"類別"字典
group2class_dict = {}
for c in range(10):
    indexs = np.where(groups==c)[0]       # 首先找出Gi群的index
    if len(indexs)==0:                    # !這邊有個問題!: 常常只有八個群(或更少)，所以可能有群沒有index，直接跳過
        continue          
    
    temp = {}
    for i in indexs:                      # 用index找出Gi群裡面最多的真值當作，預測類別
        if true[i] in temp:
            temp[true[i]] += 1
        else:
            temp[true[i]] = 1
    classs = max(temp, key=temp.get)
    print("貼標class:",classs, ",真值數量:", temp)
    group2class_dict[c] = classs

貼標class: 2 ,真值數量: {2: 4051, 3: 799, 6: 3166, 4: 1069, 0: 556, 1: 49, 5: 583, 9: 426, 8: 714, 7: 106}
貼標class: 0 ,真值數量: {0: 4806, 3: 750, 4: 166, 8: 1102, 2: 819, 5: 1716, 1: 3, 7: 40, 6: 314, 9: 93}
貼標class: 3 ,真值數量: {5: 779, 3: 3281, 9: 60, 8: 502, 0: 62, 2: 283, 7: 14, 6: 2, 1: 15, 4: 1}
貼標class: 7 ,真值數量: {4: 4051, 9: 4907, 7: 5633, 3: 491, 8: 702, 5: 606, 2: 72, 1: 29, 0: 6, 6: 4}
貼標class: 6 ,真值數量: {6: 2224, 1: 137, 0: 480, 8: 1915, 7: 16, 3: 535, 2: 649, 5: 576, 9: 69, 4: 269}
貼標class: 1 ,真值數量: {1: 6509, 5: 1161, 8: 915, 9: 394, 7: 456, 4: 285, 2: 83, 6: 206, 3: 275, 0: 13}
貼標class: 8 ,真值數量: {8: 1}
貼標class: 6 ,真值數量: {6: 2, 2: 1, 4: 1}


In [9]:
# 將群貼上類別標籤=preds
preds = np.zeros((60000))
for i in range(N):
    preds[i] = group2class_dict[groups[i]]

In [10]:
confusion_matrix(true, preds)     # 會有兩個類別混淆矩陣很怪
print("Total iteration to converge: {}".format(iteration))
error_rate(true, preds)

                Predict number 0  Predict number 0
Is number 0                 4806              1117
Isn't number 0              5003             49074

Sensitivity (Successfully predict number 0): 0.81141
Specificity (Successfully predict number 0): 0.90748
------------------------------------------------------------
                Predict number 1  Predict number 1
Is number 1                 6509               233
Isn't number 1              3788             49470

Sensitivity (Successfully predict number 1): 0.96544
Specificity (Successfully predict number 1): 0.92887
------------------------------------------------------------
                Predict number 2  Predict number 2
Is number 2                 4051              1907
Isn't number 2              7468             46574

Sensitivity (Successfully predict number 2): 0.67993
Specificity (Successfully predict number 2): 0.86181
------------------------------------------------------------
                Predict number 3  Pre

In [11]:
def plot(P, threshold):
    Pattern = np.asarray(P > threshold, dtype='uint8')
    for g in group2class_dict:
        print('class {}:'.format(group2class_dict[g]))
        
        temp = Pattern[g]
        for i in range(len(temp)):
            if i != 0 and i % 28 == 0:
                print()
            print(temp[i], end="")
        print()
        
    return

In [12]:
plot(P, threshold=0.5)

class 2:
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000110000000000000
0000000000111111100000000000
0000000001111111110000000000
0000000001111111110000000000
0000000011111000010000000000
0000000001100000011000000000
0000000001000000011000000000
0000000000000000111100000000
0000000000000001111100000000
0000000001000011111110000000
0000000011111111111110000000
0000000011111111111111000000
0000000011111111111111000000
0000000111111111111111000000
0000000111111111111111000000
0000000011111111111111000000
0000000011111111111111000000
0000000001111111111110000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
class 0:
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000111111110000