# EM algorithm
參考:https://mlbhanuyerra.github.io/2018-01-28-Handwritten-Digits_Mixture-of-Bernoulli-Distributions/

In [26]:
import numpy as np
import pandas as pd 
from loadMnist import load_images, load_labels

In [27]:
# HW3讀資料
train_data = load_images("data/train-images.idx3-ubyte")
true       = load_labels("data/train-labels.idx1-ubyte")

In [28]:
# train_data to 0/1 矩陣
train_data[train_data > 0] = 1

In [16]:
# 每個類別出現機率lambda(L)
# L:(10,1)
L = np.random.rand(10)
L = L/L.sum()

# 每個類別 每個pixel的機率
# P:(10, 784)
P = np.random.rand(10, 784)


In [17]:
# E step
def Estep(train_data, P, L):
    # 更新計算每張圖片 每個類別的Bernoulli機率分布
    W = np.zeros((60000, 10))
    for i in range(60000):
        for j in range(10):
            # 一張圖片(i) 一個類別(j) 每個pixel機率相乘 
            W[i, j] = np.prod( train_data[i] * P[j] + 
                             (1 - train_data[i]) * (1 - P[j])
                            )


    # 計算每個類別的responsibility
    W = W * L.reshape(1, -1)   # L變(1, 10)
    row_sums = np.sum(W,axis=1).reshape(-1,1)
    row_sums[row_sums==0] = 1
    W = W / row_sums
    
    return W

In [6]:
# M-step
def Mstep(W, train_data):
    # 更新 每個類別出現機率lambda(L)
    L = (W.sum(axis=0) / 60000).T

    # 更新 P
    col_sums = np.sum(W, axis=0)
    col_sums[col_sums==0] = 1
    W_normalized = W/col_sums
    P = (train_data.T@W_normalized).T
    return L, P

In [18]:
# 迭代開始:
new_diff = 100
diff = 10
iter_converge_no = 1

for i in range(1, 1000):
    W            = Estep(train_data, P, L)
    new_L, new_P = Mstep(W, train_data)
    new_diff     = np.sum(np.abs(L - new_L)) + np.sum(np.abs(P - new_P))
    
    if abs(new_diff-diff) < 1:  # 定義收斂
        iter_converge_no = i    # 收斂次數
        print('No. of Iteration: {}'.format(i), ', Difference: {}'.format(new_diff))
        L    = new_L
        P    = new_P
        break
    else:
        print('No. of Iteration: {}'.format(i), ', Difference: {}'.format(new_diff))
        L    = new_L
        P    = new_P
        diff = new_diff

No. of Iteration: 1 , Difference: 3252.598972091834
No. of Iteration: 2 , Difference: 260.47783653265014
No. of Iteration: 3 , Difference: 104.97206466702296
No. of Iteration: 4 , Difference: 64.73169533664179
No. of Iteration: 5 , Difference: 50.03857371683139
No. of Iteration: 6 , Difference: 35.870180224859546
No. of Iteration: 7 , Difference: 23.878947309638924
No. of Iteration: 8 , Difference: 16.84308972233303
No. of Iteration: 9 , Difference: 12.411290419187356
No. of Iteration: 10 , Difference: 9.058199349615675
No. of Iteration: 11 , Difference: 7.412604512808731
No. of Iteration: 12 , Difference: 6.321506961075096
No. of Iteration: 13 , Difference: 5.697720986481477


In [8]:
# 預測
predict = np.argmax(W, axis=1)

In [9]:
TP, FP, FN, TN = 0, 0, 0, 0
# 計算confusion matrix
for i in range(10):
    for j in range(60000):
        if predict[j] == i:  # i類別
            if predict[j] == true[j]:
                TP += 1
            else:            # predict[j] != true[j]
                FN += 1
        else:
            if true[j] == i:
                FP += 1
            else:
                TN += 1
                    
    # 
    print("Confusion Matrix: ", i)
    confusion_matrix = pd.DataFrame([[TP, FP], [FN, TN]], 
                                    index=("Is number {}".format(i), "Isn't number {}".format(i)),
                                    columns=("Predict number {}".format(i), "Predict not number {}".format(i)))
    print(confusion_matrix)
    print()
    print("Sensitivity (Successfully predict number {}): {}".format(i, TP/(TP+FP)))
    print("Specificity (Successfully predict number {}): {}".format(i, TN/(FN+TN)))
    print()
    print("---------------------------------------------------")

print("Total iteration to converge: ", iter_converge_no)
print("Total error rate: ", (predict!=true).sum()/60000)

Confusion Matrix:  0
                Predict number 0  Predict not number 0
Is number 0                    5                  5918
Isn't number 0              6049                 48028

Sensitivity (Successfully predict number 0): 0.0008441668073611346
Specificity (Successfully predict number 0): 0.8881409841522274

---------------------------------------------------
Confusion Matrix:  1
                Predict number 1  Predict not number 1
Is number 1                   36                 12629
Isn't number 1             16179                 91156

Sensitivity (Successfully predict number 1): 0.0028424792735886302
Specificity (Successfully predict number 1): 0.8492663157404388

---------------------------------------------------
Confusion Matrix:  2
                Predict number 2  Predict not number 2
Is number 2                  428                 18195
Isn't number 2             24967                136410

Sensitivity (Successfully predict number 2): 0.022982333673414596
Speci