In [72]:
from mnist import MNIST
import numpy as np
from scipy.optimize import linear_sum_assignment
from util import *

In [73]:
mndata = MNIST('data', return_type = 'numpy')
mndata.gz = True
train_images, train_labels = mndata.load_training()
# or
test_images, test_labels = mndata.load_testing()
train_images[train_images < 128] = 0
train_images[train_images >= 128] = 1
#print(train_images[0])

In [74]:
# init lambda
lam = np.random.rand(10)
lam = lam/np.sum(lam)
lam_last = lam
print(lam)
# init P
P = ((np.random.rand(10,784) * 8) + 1) / 10
P_last = P
print(P)
diff = 100
eps = 1
count = 0



[0.08356023 0.19012468 0.16636728 0.05711445 0.03137817 0.05605658
 0.02453316 0.14243347 0.17719872 0.07123327]
[[0.3306048  0.65403366 0.42846179 ... 0.13574808 0.50315622 0.23075331]
 [0.73027799 0.33520706 0.58260496 ... 0.2328892  0.87006403 0.12510935]
 [0.85694551 0.77901818 0.43700837 ... 0.11534167 0.53501739 0.78906704]
 ...
 [0.71908245 0.8418705  0.79432999 ... 0.38480704 0.38779311 0.44944844]
 [0.63733325 0.62262109 0.13190097 ... 0.66969099 0.71380148 0.25551638]
 [0.26265183 0.49146708 0.30570761 ... 0.1740134  0.33949431 0.28865669]]


In [75]:
while diff > eps and count < 15:
    count += 1
    # update W
    # E step
    W = np.zeros((60000, 10))
    W = W.astype(np.float64)
    for i in range(60000):
        for j in range(10):
            W[i, j] = np.prod(train_images[i]*P[j] + (1-train_images[i])*(1-P[j]))
            W[i, j] *= lam[j]
    #print(W)
    sums = np.sum(W, axis=1).reshape(-1, 1)
    #print(sums.shape)
    sums[sums==0] = 1
    W = W/sums
    #print(W)
    
    # M step
    # update lam
    lam = np.sum(W,axis=0)
    lam = lam/60000
    print(np.sum(lam))

    # update P
    sums = np.sum(W,axis=0)
    sums[sums==0] = 1
    W_normalized = W/sums
    P = train_images.T@W_normalized
    P = P.T

    # calc diff
    diff=np.sum(np.abs(lam-lam_last))+np.sum(np.abs(P-P_last))
    lam_last = lam
    P_last = P
    
    #break


0.9999999999999962
0.9999999999999859
0.999999999999996
0.9999999999999946
0.9999999999999957
0.9999999999999961
0.9999999999999983
0.9999999999999964
0.9999999999999979
0.9999999999999984
0.9999999999999972
0.9999999999999976
0.9999999999999976
0.9999999999999976
0.9999999999999981


In [76]:
#predict classes belonging
maxs = np.argmax(W, axis=1)
unique,counts=np.unique(maxs,return_counts=True)
print(dict(zip(unique,counts)))
print('Lambda:',lam.reshape(1,-1))

# get GT
distribution = np.zeros((10, 784))
labels = np.zeros(10)
for i in train_labels:
    labels[i] += 1
for i in range(60000):
        c=train_labels[i]
        for j in range(784):
            if train_images[i, j]==1:
                distribution[c, j]+=1
distribution = distribution / labels.reshape(-1,1)

# plot GT
for c in range(10):
        print('class',c)
        for i in range(28):
            for j in range(28):
                print(1 if distribution[c,i*28+j]>0.5 else 0,end=' ')
            print()
        print()
        print()

Cost=np.zeros((10,10))
for i in range(10):
    for j in range(10):
        Cost[i,j] = np.linalg.norm(distribution[i] - P[j])
    
row_idx,col_idx=linear_sum_assignment(Cost)
class_order = col_idx
print('Class_order:',class_order)



{0: 8756, 1: 2853, 2: 5336, 3: 4413, 4: 6359, 5: 3166, 6: 7216, 7: 7359, 8: 5536, 9: 9006}
Lambda: [[0.14609671 0.04760586 0.08899257 0.07352451 0.10602787 0.0527075
  0.12013831 0.12269749 0.09224192 0.14996726]]
class 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 1 1 1 1 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0

In [77]:
Pattern = np.asarray(P>0.5, dtype='uint8')
for i in range(10):
    print('class {}:'.format(i))
    c = class_order[i]
    pic = Pattern[c]
    # plot each class pattern
    for i in range(28):
        for j in range(28):
            print(pic[i*28+j], end=' ')
        print()
    print()
    print()

# confusion matrix
for k in range(10):
    c = class_order[k]
    TP,FN,FP,TN = 0,0,0,0
    for i in range(60000):
        if train_labels[i] != k and maxs[i] != c:
            TN += 1
        elif train_labels[i] == k and maxs[i] == c:
            TP += 1
        elif train_labels[i] != k and maxs[i] == c:
            FP += 1
        else:
            FN += 1
    plot_confusion_matrix(k, TP, FN, FP, TN)
# print error rate
print('Total iteration to converge: {}'.format(count))
real_transform = np.zeros(60000)
for i in range(60000):
    real_transform[i] = class_order[train_labels[i]]
error = np.count_nonzero(real_transform-maxs)
print('Total error rate: {}'.format(error/60000))



class 0:
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 
0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 
0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 
0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 
0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 
0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 
0 0 0 0 0 0 1 1 1 0 0 