# EM algorithm
參考:https://mlbhanuyerra.github.io/2018-01-28-Handwritten-Digits_Mixture-of-Bernoulli-Distributions/

In [1]:
import numpy as np
import pandas as pd 
from loadMnist import load_images, load_labels

In [2]:
N = 60000
# HW3讀資料
data = load_images("data/train-images.idx3-ubyte")[:N]
true = load_labels("data/train-labels.idx1-ubyte")[:N]

In [3]:
# train_data to 0/1 矩陣
data[data > 0] = 1

In [4]:
# 初始化 L(每個類別出現機率) and P(每個類別 每個pixel的p)
# w:(10,1)
L = np.random.uniform(.25, .75, 10)
tot = np.sum(L)
L = L/tot
print("init Lambda: \n", L)

P = np.random.rand(10, 784)
print("init P: \n", P)

init Lambda: 
 [0.13412131 0.0626771  0.11505688 0.11837364 0.13749633 0.06142893
 0.13529738 0.08342226 0.05665833 0.09546785]
init P: 
 [[0.0136914  0.50442465 0.02623742 ... 0.06754934 0.68955997 0.76303332]
 [0.88901833 0.23731753 0.14747632 ... 0.47673125 0.57263376 0.04668272]
 [0.75840194 0.73155086 0.64148766 ... 0.60799022 0.22246316 0.05544319]
 ...
 [0.53737863 0.93767736 0.13419727 ... 0.28406708 0.71861264 0.95060504]
 [0.35890996 0.08639793 0.57564042 ... 0.72286219 0.01067405 0.2073523 ]
 [0.55691593 0.01495693 0.54079675 ... 0.23901225 0.66898999 0.90917637]]


In [5]:
def responsibility(data, L, P):
    # 課本(9.48) 
    W = np.zeros((N,10))
    for i in range(N):
        for j in range(10):    
            W[i, j] = np.prod(data[i] * P[j] + (1 - data[i]) * (1 - P)[j])
            
    # 課本(9.47)
    W = W * L.reshape(1, -1)
    # 讓每一列(一張圖)機率在區間[0, 1]
    sums = np.sum(W, axis=1).reshape(-1, 1)
    sums[sums == 0] = 1
    W = W / sums

    return W


def get_L_new(W):
    # 課本(9.60)
    # 某次出現class i的機率全部加起來/n(把一行的resp加起來/n)
    L = np.sum(W, axis=0)
    L = L/N
    
    return L


def get_P_new(data, W):
    # 課本(9.59) 
    # (某次出現class i的機率*pixel=1出現次數)/ 某次出現class i的機率全部加起來
    sums = np.sum(W, axis=0)
    sums[sums==0] = 1
    W = W/sums
    P = data.T@W
    return P.T

# EM algo

In [6]:
diff = 1000
iteration = 0
for i in range(1, 100):
    # E-step:利用已知的P計算resposibility Wi
    W = responsibility(data, L, P)
        
    # M-step:利用 Estep算出的Wi，推得Wi下的L和P
    L_new = get_L_new(W)
    P_new = get_P_new(data, W)
    
    # 自定義difference: 就是L和P的差和 
    diff_new = np.sum(np.abs(L - L_new)) + np.sum(np.abs(P - P_new))
    print("No. of Iteration: {}, Difference: {}".format(i, diff_new))
    
    # 自定義收斂條件: difference<1
    if abs(diff_new - diff) < 1:
        break
    else:
        L = L_new
        P = P_new
        diff = diff_new
        iteration += 1

No. of Iteration: 1, Difference: 3210.498170533274
No. of Iteration: 2, Difference: 314.6151423200065
No. of Iteration: 3, Difference: 122.18440569481548
No. of Iteration: 4, Difference: 76.15565478482719
No. of Iteration: 5, Difference: 56.70866719599603
No. of Iteration: 6, Difference: 39.70606600185123
No. of Iteration: 7, Difference: 28.623979455665324
No. of Iteration: 8, Difference: 20.478836041137097
No. of Iteration: 9, Difference: 19.197215501093254
No. of Iteration: 10, Difference: 17.518486855895475
No. of Iteration: 11, Difference: 15.312429163639393
No. of Iteration: 12, Difference: 13.257700507908417
No. of Iteration: 13, Difference: 11.235716207181705
No. of Iteration: 14, Difference: 9.786410689436202
No. of Iteration: 15, Difference: 8.073349116685428
No. of Iteration: 16, Difference: 6.66322019716467
No. of Iteration: 17, Difference: 5.808112458157886


In [7]:
def confusion_matrix(true, preds):
    for c in range(10):
        TP, FP, FN, TN = 0, 0, 0, 0
        for n in range(N):
            if c == true[n]:
                if true[n] == preds[n]:
                    TP += 1
                else:
                    FP += 1 
            else:
                if c == preds[n]:
                    FN += 1
                else:
                    TN += 1
        
        confusion_matrix = pd.DataFrame([[TP,FP], [FN, TN]], index=("Is number {}".format(c),"Isn't number {}".format(c)),
                                        columns=("Predict number {}".format(c), "Predict number {}".format(c)))
        print(confusion_matrix)
        print()
        print("Sensitivity (Successfully predict number {}): {:.5f}".format(c, TP/(TP+FP)))
        print("Specificity (Successfully predict number {}): {:.5f}".format(c, TN/(FN+TN)))
        print("------------------------------------------------------------")
        
def error_rate(true, preds):
    error = 0
    for n in range(N):
        if true[n] != preds[n]:
            error += 1
    print("Total error rate: {}".format(error/N))

# result

![title](./group.PNG)

In [8]:
# groups: 因為是非監督式學習，所以你得到的是群(G0,G1,...,G9) 
groups = np.argmax(W, axis=1)

# 找出每個群對應的"類別"字典
group2class_dict = {}
for c in range(10):
    indexs = np.where(groups==c)[0]       # 首先找出Gi群的index
    if len(indexs)==0:                    # !這邊有個問題!: 常常只有八個群(或更少)，所以可能有群沒有index，直接跳過
        continue          
    
    temp = {}
    for i in indexs:                      # 用index找出Gi群裡面最多的真值當作，預測類別
        if true[i] in temp:
            temp[true[i]] += 1
        else:
            temp[true[i]] = 1
    classs = max(temp, key=temp.get)
    print("貼標class:",classs, ",真值數量:", temp)
    group2class_dict[c] = classs

貼標class: 4 ,真值數量: {4: 3427, 9: 3118, 2: 91, 7: 2646, 3: 280, 8: 291, 5: 326, 0: 9, 1: 7, 6: 13}
貼標class: 2 ,真值數量: {3: 1187, 2: 1951, 1: 39, 5: 537, 8: 619, 6: 77, 0: 102, 9: 26, 7: 19, 4: 5}
貼標class: 1 ,真值數量: {1: 3545, 4: 5, 6: 8, 8: 40, 3: 5, 7: 12, 2: 2, 9: 4, 5: 1}
貼標class: 8 ,真值數量: {4: 1467, 5: 2481, 7: 1335, 8: 2627, 9: 1007, 0: 434, 3: 437, 1: 798, 2: 259, 6: 178}
貼標class: 6 ,真值數量: {2: 3265, 6: 5362, 1: 49, 4: 339, 0: 151, 5: 164, 9: 26, 3: 243, 7: 13, 8: 60}
貼標class: 0 ,真值數量: {0: 4761, 5: 331, 6: 55, 3: 92, 2: 125, 4: 18, 9: 42, 8: 94, 7: 17}
貼標class: 1 ,真值數量: {1: 2295, 9: 1640, 4: 580, 7: 2216, 3: 356, 5: 158, 8: 554, 2: 64, 6: 138, 0: 17}
貼標class: 8 ,真值數量: {8: 1}
貼標class: 3 ,真值數量: {5: 1423, 3: 3531, 8: 1565, 9: 86, 6: 87, 0: 449, 2: 201, 1: 9, 7: 7, 4: 1}


In [9]:
# 將群貼上類別標籤=preds
preds = np.zeros((60000))
for i in range(N):
    preds[i] = group2class_dict[groups[i]]

In [10]:
confusion_matrix(true, preds)     # 會有兩個類別混淆矩陣很怪
print("Total iteration to converge: {}".format(iteration))
error_rate(true, preds)

                Predict number 0  Predict number 0
Is number 0                 4761              1162
Isn't number 0               774             53303

Sensitivity (Successfully predict number 0): 0.80382
Specificity (Successfully predict number 0): 0.98569
------------------------------------------------------------
                Predict number 1  Predict number 1
Is number 1                 5840               902
Isn't number 1              5800             47458

Sensitivity (Successfully predict number 1): 0.86621
Specificity (Successfully predict number 1): 0.89110
------------------------------------------------------------
                Predict number 2  Predict number 2
Is number 2                 1951              4007
Isn't number 2              2611             51431

Sensitivity (Successfully predict number 2): 0.32746
Specificity (Successfully predict number 2): 0.95169
------------------------------------------------------------
                Predict number 3  Pre

In [11]:
def plot(P, threshold):
    Pattern = np.asarray(P > threshold, dtype='uint8')
    for g in group2class_dict:
        print('class {}:'.format(group2class_dict[g]))
        
        temp = Pattern[g]
        for i in range(len(temp)):
            if i != 0 and i % 28 == 0:
                print()
            print(temp[i], end="")
        print()
        
    return

In [15]:
plot(P, threshold=0.65)    
# 改進與感想
# 1. 第二多的類別也許才是最佳代表分類 
# 2. threshold大小 顯示類別的關鍵pixel

class 4:
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000001111100111000000000
0000000011111000111100000000
0000000011100000011100000000
0000000011100000011100000000
0000000011000000011100000000
0000000011000000111100000000
0000000011000001111100000000
0000000001100011111100000000
0000000000000011111000000000
0000000000000001111000000000
0000000000000001111000000000
0000000000000001110000000000
0000000000000001110000000000
0000000000000001100000000000
0000000000000001100000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
class 2:
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000001111111100000000