# EM algorithm
參考:https://mlbhanuyerra.github.io/2018-01-28-Handwritten-Digits_Mixture-of-Bernoulli-Distributions/

In [1]:
import numpy as np
import pandas as pd 
from loadMnist import load_images, load_labels

In [2]:
N = 60000
# HW3讀資料
data = load_images("data/train-images.idx3-ubyte")[:N]
true = load_labels("data/train-labels.idx1-ubyte")[:N]

In [3]:
# train_data to 0/1 矩陣
data[data > 0] = 1

In [4]:
# 初始化 L(每個類別出現機率) and P(每個類別 每個pixel的p)
# w:(10,1)
L = np.random.uniform(.25, .75, 10)
tot = np.sum(L)
L = L/tot
print("init Lambda: \n", L)

P = np.random.rand(10, 784)
print("init P: \n", P)

init Lambda: 
 [0.10685176 0.12676213 0.0915706  0.11883933 0.07130692 0.0814007
 0.08120382 0.10652006 0.14951053 0.06603413]
init P: 
 [[0.229566   0.12182178 0.09241054 ... 0.75741331 0.81600426 0.28669333]
 [0.29121682 0.67051234 0.42958692 ... 0.68763728 0.42670072 0.42067238]
 [0.72274536 0.61870434 0.14521331 ... 0.17204987 0.73618818 0.58009875]
 ...
 [0.27528253 0.62781116 0.49006376 ... 0.645893   0.82445003 0.98585206]
 [0.27335035 0.68253712 0.75568175 ... 0.87254517 0.38799735 0.07083456]
 [0.98363703 0.24769425 0.97446195 ... 0.6039563  0.32485887 0.7722092 ]]


In [5]:
def responsibility(data, L, P):
    # 課本(9.48) 
    W = np.zeros((N,10))
    for i in range(N):
        for j in range(10):    
            W[i, j] = np.prod(data[i] * P[j] + (1 - data[i]) * (1 - P)[j])
            
    # 課本(9.47)
    W = W * L.reshape(1, -1)
    # 讓每一列(一張圖)機率在區間[0, 1]
    sums = np.sum(W, axis=1).reshape(-1, 1)
    sums[sums == 0] = 1
    W = W / sums

    return W


def get_L_new(W):
    # 課本(9.60)
    # 某次出現class i的機率全部加起來/n(把一行的resp加起來/n)
    L = np.sum(W, axis=0)
    L = L/N
    
    return L


def get_P_new(data, W):
    # 課本(9.59) 
    # (某次出現class i的機率*pixel=1出現次數)/ 某次出現class i的機率全部加起來
    sums = np.sum(W, axis=0)
    sums[sums==0] = 1
    W = W/sums
    P = data.T@W
    return P.T

# EM algo

In [6]:
diff = 1000
iteration = 0
for i in range(1, 100):
    # E-step:利用已知的P計算resposibility Wi
    W = responsibility(data, L, P)
        
    # M-step:利用 Estep算出的Wi，推得Wi下的L和P
    L_new = get_L_new(W)
    P_new = get_P_new(data, W)
    
    # 自定義difference: 就是L和P的差和 
    diff_new = np.sum(np.abs(L - L_new)) + np.sum(np.abs(P - P_new))
    print("No. of Iteration: {}, Difference: {}".format(i, diff_new))
    
    # 自定義收斂條件: difference<1
    if abs(diff_new - diff) < 1:
        break
    else:
        L = L_new
        P = P_new
        diff = diff_new
        iteration += 1

No. of Iteration: 1, Difference: 3194.62177908525
No. of Iteration: 2, Difference: 220.6520511136953
No. of Iteration: 3, Difference: 130.65746408435862
No. of Iteration: 4, Difference: 86.23840756087189
No. of Iteration: 5, Difference: 75.56648463965796
No. of Iteration: 6, Difference: 52.369241213692376
No. of Iteration: 7, Difference: 36.16076425620772
No. of Iteration: 8, Difference: 24.643671220066498
No. of Iteration: 9, Difference: 15.514135406609972
No. of Iteration: 10, Difference: 9.275618501863539
No. of Iteration: 11, Difference: 5.804495350772368
No. of Iteration: 12, Difference: 4.314068533290799
No. of Iteration: 13, Difference: 3.389790388957467


In [7]:
def confusion_matrix(true, preds):
    for c in range(10):
        TP, FP, FN, TN = 0, 0, 0, 0
        for n in range(N):
            if c == true[n]:
                if true[n] == preds[n]:
                    TP += 1
                else:
                    FP += 1 
            else:
                if c == preds[n]:
                    FN += 1
                else:
                    TN += 1
        
        confusion_matrix = pd.DataFrame([[TP,FP], [FN, TN]], index=("Is number {}".format(c),"Isn't number {}".format(c)),
                                        columns=("Predict number {}".format(c), "Predict number {}".format(c)))
        print(confusion_matrix)
        print()
        print("Sensitivity (Successfully predict number {}): {:.5f}".format(c, TP/(TP+FP)))
        print("Specificity (Successfully predict number {}): {:.5f}".format(c, TN/(FN+TN)))
        print("------------------------------------------------------------")
        
def error_rate(true, preds):
    error = 0
    for n in range(N):
        if true[n] != preds[n]:
            error += 1
    print("Total error rate: {}".format(error/N))

# result

![title](./group.PNG)

In [8]:
# groups: 因為是非監督式學習，所以你得到的是群(G0,G1,...,G9) 
groups = np.argmax(W, axis=1)

# 找出每個群對應的"類別"字典
group2class_dict = {}
for c in range(10):
    indexs = np.where(groups==c)[0]       # 首先找出Gi群的index
    if len(indexs)==0:                    # !這邊有個問題!: 常常只有八個群(或更少)，所以可能有群沒有index，直接跳過
        continue          
    
    temp = {}
    for i in indexs:                      # 用index找出Gi群裡面最多的真值當作，預測類別
        if true[i] in temp:
            temp[true[i]] += 1
        else:
            temp[true[i]] = 1
    classs = max(temp, key=temp.get)
    print("貼標class:",classs, ",真值數量:", temp)
    group2class_dict[c] = classs

貼標class: 1 ,真值數量: {1: 6222, 7: 153, 5: 45, 8: 296, 2: 114, 6: 258, 3: 205, 4: 49, 9: 72, 0: 3}
貼標class: 7 ,真值數量: {9: 2323, 4: 1195, 7: 2481, 3: 222, 5: 140, 8: 213, 2: 22, 1: 56, 6: 12, 0: 4}
貼標class: 4 ,真值數量: {4: 2409, 7: 585, 9: 1482, 3: 85, 8: 101, 5: 156, 2: 117, 6: 37, 0: 21, 1: 8}
貼標class: 6 ,真值數量: {2: 3987, 6: 5152, 1: 62, 7: 67, 4: 235, 0: 138, 5: 162, 9: 26, 8: 69, 3: 270}
貼標class: 7 ,真值數量: {4: 1654, 7: 2811, 9: 1808, 2: 11, 3: 32, 8: 235, 5: 155, 1: 6, 6: 1}
貼標class: 8 ,真值數量: {8: 1}
貼標class: 0 ,真值數量: {0: 4691, 5: 127, 3: 57, 6: 53, 9: 28, 8: 67, 7: 13, 2: 98, 4: 8}
貼標class: 8 ,真值數量: {3: 1158, 5: 2768, 8: 3029, 0: 804, 4: 290, 6: 331, 2: 473, 1: 355, 9: 91, 7: 85}
貼標class: 3 ,真值數量: {5: 1868, 3: 4102, 2: 1136, 9: 119, 8: 1840, 0: 262, 1: 33, 7: 70, 6: 74, 4: 2}


In [9]:
# 將群貼上類別標籤=preds
preds = np.zeros((60000))
for i in range(N):
    preds[i] = group2class_dict[groups[i]]

In [10]:
confusion_matrix(true, preds)     # 會有兩個類別混淆矩陣很怪
print("Total iteration to converge: {}".format(iteration))
error_rate(true, preds)

                Predict number 0  Predict number 0
Is number 0                 4691              1232
Isn't number 0               451             53626

Sensitivity (Successfully predict number 0): 0.79200
Specificity (Successfully predict number 0): 0.99166
------------------------------------------------------------
                Predict number 1  Predict number 1
Is number 1                 6222               520
Isn't number 1              1195             52063

Sensitivity (Successfully predict number 1): 0.92287
Specificity (Successfully predict number 1): 0.97756
------------------------------------------------------------
                Predict number 2  Predict number 2
Is number 2                    0              5958
Isn't number 2                 0             54042

Sensitivity (Successfully predict number 2): 0.00000
Specificity (Successfully predict number 2): 1.00000
------------------------------------------------------------
                Predict number 3  Pre

In [11]:
def plot(P, threshold):
    Pattern = np.asarray(P > threshold, dtype='uint8')
    for g in group2class_dict:
        print('class {}:'.format(group2class_dict[g]))
        
        temp = Pattern[g]
        for i in range(len(temp)):
            if i != 0 and i % 28 == 0:
                print()
            print(temp[i], end="")
        print()
        
    return

In [27]:
plot(P, threshold=0.55)    
# 改進與感想
# 1. 第二多的類別也許才是最佳代表分類 
# 2. threshold大小(0.55 => 0.65) 顯示類別的關鍵pixel
# 3. 由2發現其實類別6是一個大群，也許可以再往下分群得到比較正確的類別

class 1:
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000001100000000000
0000000000000011110000000000
0000000000000011110000000000
0000000000000011110000000000
0000000000000011110000000000
0000000000000111100000000000
0000000000000111100000000000
0000000000000111100000000000
0000000000000111100000000000
0000000000001111100000000000
0000000000001111000000000000
0000000000001111000000000000
0000000000001111000000000000
0000000000001111000000000000
0000000000011110000000000000
0000000000011110000000000000
0000000000011110000000000000
0000000000011110000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
class 7:
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000