In [19]:
from mnist.loader import MNIST
import numpy as np
from scipy.optimize import linear_sum_assignment

In [20]:
def plot_confusion_matrix(c,TP,FN,FP,TN):
    print('------------------------------------------------------------')
    print()
    print('Confusion Matrix {}:'.format(c))
    print('\t\t\t  Predict number {} Predict not number {}'.format(c, c))
    print('Is number  \t{}\t\t{}\t\t\t\t{}'.format(c,TP,FN))
    print('Isn\'t number {}\t\t{}\t\t\t\t{}'.format(c,FP,TN))
    print()
    print('Sensitivity (Successfully predict number {}    ): {:.5f}'.format(c,TP/(TP+FN)))
    print('Specificity (Successfully predict not number {}): {:.5f}'.format(c,TN/(TN+FP)))
    print()

In [21]:
mndata = MNIST('data', return_type = 'numpy')
mndata.gz = True
train_images, train_labels = mndata.load_training()
# or
test_images, test_labels = mndata.load_testing()
train_images[train_images < 128] = 0
train_images[train_images >= 128] = 1
#print(train_images[0])

In [22]:
# init lambda
lam = np.random.rand(10)
lam = lam/np.sum(lam)
lam_last = lam
print(lam)
# init P
P = ((np.random.rand(10,784) * 8) + 1) / 10
P_last = P
print(P)
diff = 100
eps = 1
count = 0

[0.11517464 0.12861536 0.09865398 0.10539199 0.02737169 0.01987019
 0.15344624 0.07755176 0.1713451  0.10257906]
[[0.34483083 0.53205507 0.7094896  ... 0.69129812 0.67078551 0.12317658]
 [0.26212159 0.7570986  0.37157588 ... 0.6667813  0.62782587 0.79907219]
 [0.49628176 0.60303248 0.54922699 ... 0.27720541 0.81639982 0.52007273]
 ...
 [0.65420237 0.41608837 0.37170125 ... 0.21467118 0.38939088 0.10144206]
 [0.20421033 0.46021482 0.23285594 ... 0.6606782  0.77146902 0.60541426]
 [0.42920245 0.35204243 0.87580001 ... 0.87252023 0.44734205 0.52740091]]


In [23]:
while diff > eps and count < 15:
    count += 1
    # update W
    # E step
    W = np.zeros((60000, 10))
    W = W.astype(np.float64)
    for i in range(60000):
        for j in range(10):
            W[i, j] = np.prod(train_images[i]*P[j] + (1-train_images[i])*(1-P[j]))
            W[i, j] *= lam[j]
    #print(W)
    sums = np.sum(W, axis=1).reshape(-1, 1)
    #print(sums.shape)
    sums[sums==0] = 1
    W = W/sums
    #print(W)
    
    # M step
    # update lam
    lam = np.sum(W,axis=0)
    lam = lam/60000
    #print(np.sum(lam))

    # update P
    sums = np.sum(W,axis=0)
    sums[sums==0] = 1
    W_normalized = W/sums
    P = train_images.T@W_normalized
    P = P.T

    # calc diff
    diff=np.sum(np.abs(lam-lam_last))+np.sum(np.abs(P-P_last))
    lam_last = lam
    P_last = P
    

In [24]:
#predict classes belonging
maxs = np.argmax(W, axis=1)
unique,counts=np.unique(maxs,return_counts=True)
print(dict(zip(unique,counts)))
print('Lambda:',lam.reshape(1,-1))

# get GT
distribution = np.zeros((10, 784))
labels = np.zeros(10)
for i in train_labels:
    labels[i] += 1
for i in range(60000):
        c=train_labels[i]
        for j in range(784):
            if train_images[i, j]==1:
                distribution[c, j]+=1
distribution = distribution / labels.reshape(-1,1)

# plot GT
for c in range(10):
        print('class',c)
        for i in range(28):
            for j in range(28):
                print(1 if distribution[c,i*28+j]>0.5 else 0,end=' ')
            print()
        print()
        print()

Cost=np.zeros((10,10))
for i in range(10):
    for j in range(10):
        Cost[i,j] = np.linalg.norm(distribution[i] - P[j])
    
row_idx,col_idx=linear_sum_assignment(Cost)
class_order = col_idx
print('Class_order:',class_order)

{0: 4945, 1: 4230, 2: 6575, 3: 8191, 4: 5657, 5: 4933, 6: 7121, 7: 4357, 8: 7410, 9: 6581}
Lambda: [[0.08242335 0.07033943 0.10953044 0.13671932 0.09441596 0.08209499
  0.11873764 0.0725744  0.123474   0.10969046]]
class 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 1 1 1 1 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 

In [25]:
Pattern = np.asarray(P>0.5, dtype='uint8')
for i in range(10):
    print('class {}:'.format(i))
    c = class_order[i]
    pic = Pattern[c]
    # plot each class pattern
    for i in range(28):
        for j in range(28):
            print(pic[i*28+j], end=' ')
        print()
    print()
    print()

# confusion matrix
for k in range(10):
    c = class_order[k]
    TP,FN,FP,TN = 0,0,0,0
    for i in range(60000):
        if train_labels[i] != k and maxs[i] != c:
            TN += 1
        elif train_labels[i] == k and maxs[i] == c:
            TP += 1
        elif train_labels[i] != k and maxs[i] == c:
            FP += 1
        else:
            FN += 1
    plot_confusion_matrix(k, TP, FN, FP, TN)
# print error rate
print('Total iteration to converge: {}'.format(count))
real_transform = np.zeros(60000)
for i in range(60000):
    real_transform[i] = class_order[train_labels[i]]
error = np.count_nonzero(real_transform-maxs)
print('Total error rate: {}'.format(error/60000))

class 0:
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 1 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 
0 0 0 0 0 0 1 1 1 0 0 

------------------------------------------------------------

Confusion Matrix 3:
			  Predict number 3 Predict not number 3
Is number  	3		1615				4516
Isn't number 3		2615				51254

Sensitivity (Successfully predict number 3    ): 0.26342
Specificity (Successfully predict not number 3): 0.95146

------------------------------------------------------------

Confusion Matrix 4:
			  Predict number 4 Predict not number 4
Is number  	4		2073				3769
Isn't number 4		2872				51286

Sensitivity (Successfully predict number 4    ): 0.35484
Specificity (Successfully predict not number 4): 0.94697

------------------------------------------------------------

Confusion Matrix 5:
			  Predict number 5 Predict not number 5
Is number  	5		2212				3209
Isn't number 5		5198				49381

Sensitivity (Successfully predict number 5    ): 0.40804
Specificity (Successfully predict not number 5): 0.90476

------------------------------------------------------------

Confusion Matrix 6:
			  Predict number 6