In [1]:
import numpy as np
from numba import jit

def GetImageData(f):
    # magic number
    f.read(4) 
    
    # number of images
    num = f.read(4)
    num = int.from_bytes(num, byteorder='big') #60000

    row = f.read(4)
    row = int.from_bytes(row, byteorder='big') #28

    column = f.read(4)
    column = int.from_bytes(column, byteorder='big') #28

    buf = f.read(row * column * num)
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    data = data.reshape(num, row * column)
    
    return [data, num, row, column]

def GetLabelData(f):
    # magic number
    f.read(4) 

    # number of items
    num = f.read(4)
    num = int.from_bytes(num, byteorder='big') #10000

    buf = f.read(num)
    data = np.frombuffer(buf, dtype=np.uint8)
    data = data.reshape(num)
    
    return [data, num]

def GetIndexOfEachLabel(label_data):
    index = 0
    index_m = []
    for num in range(10):
        index_m.append([])
        for label in label_data:
            if label == num:
                index_m[num].append(index)
            index +=1
        index = 0
        
    return index_m

In [2]:
import numpy as np
import gzip

train_image = 'train-images-idx3-ubyte.gz'
train_label = 'train-labels-idx1-ubyte.gz'

test_image = 't10k-images-idx3-ubyte.gz'
test_label = 't10k-labels-idx1-ubyte.gz'

# Train
f = gzip.open(train_image,'rb')
data = GetImageData(f)
train_image_data, train_image_num, row, column = data

f = gzip.open(train_label,'rb')
data = GetLabelData(f)
train_label_data, train_label_num = data

# Test
f = gzip.open(test_image,'rb')
data = GetImageData(f)
test_image_data, test_image_num, row, column = data

f = gzip.open(test_label,'rb')
data = GetLabelData(f)
test_label_data, test_label_num = data

train_grey_data = np.where(train_image_data < 128, 0, 1)

c = 60000

w = np.zeros((c, 10)) # w[60000][10]
lam = np.zeros((10, )) # lam[10]
prob = np.zeros((row*column, 10)) # prob[784][10]

train_index = GetIndexOfEachLabel(train_label_data)

for i in range(len(train_index)):
    lam[i] = len(train_index[i])
lam = lam/train_image_num

peudoP = 0.5
prob += peudoP

@jit
def EM(lam, w, prob, c, train_grey_data):
    for n in range(c):
        for label in range(10):
            w[n][label] = lam[label]
            
    for t in range(15):
        # E-step
        for n in range(c):
            totalP = 0
            for pixel in range(row*column):
                for label in range(10):
                    w[n][label] *= (prob[pixel][label] ** train_grey_data[n][pixel]) * ((1 - prob[pixel][label]) ** (1 - train_grey_data[n][pixel]))

                if pixel % 10 == 0:
                    w[n] /= w[n].sum()

            w[n] /= w[n].sum()

        # M-step
        lamMLE = np.zeros((10, ))

        for label in range(10):
            totalW = w[:, label].sum()
            lamMLE[label] = totalW/c

        lamMLE = lamMLE/lamMLE.sum()    

        for label in range(10):
            for pixel in range(784):
                P = 0
                totalP = 0
                for i in range(c):
                    P += w[i][label]*train_grey_data[i][pixel]
                    totalP += w[i][label]
                prob[pixel][label] = P / totalP

        lam = lamMLE
        
    return prob

In [3]:
tmp_prob = EM(lam, w, prob, c, train_grey_data)

In [4]:
tmp_prob = np.reshape(tmp_prob, (28, 28, 10))

for label in range(10):
    print("\n\n")
    for i in range(28):
        for j in range(28):
            if tmp_prob[i][j][label] >= 0.45:
                print(1, end='')
            else:
                print(0, end='')
#             print(tmp_prob[label][i][j], end=" ")
        print()




0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000110000000000
0000000000000011111000000000
0000000000000111111000000000
0000000000000111110000000000
0000000000001111000000000000
0000000000001110000000000000
0000000000011100000000000000
0000000000011100000000000000
0000000000111100000000000000
0000000000111101111000000000
0000000000111011111100000000
0000000001110001111100000000
0000000001110000011100000000
0000000001100000011100000000
0000000001100000111100000000
0000000001110011111000000000
0000000001111111110000000000
0000000011111111100000000000
0000000001111110000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000



0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000000000000000000
0000000000000011111000000000
00000000

In [5]:
for i in range(10):
    print(w[:, i].max())
    print(list(w[:, i]).count(0))

1.0
13624
1.0
16362
1.0
3141
1.0
13693
1.0
24473
1.0
20458
1.0
8629
1.0
21528
1.0
6279
1.0
5865


In [6]:
train_label_data.shape

(60000,)

In [7]:
@jit
def ReloadP(lam, w, prob, c, train_grey_data):
    for n in range(c):
        for label in range(10):
            w[n][label] = lam[label]
    for n in range(c):
        totalP = 0
        for pixel in range(row*column):
            for label in range(10):
                w[n][label] *= (prob[pixel][label] ** train_grey_data[n][pixel]) * ((1 - prob[pixel][label]) ** (1 - train_grey_data[n][pixel]))

            if pixel % 10 == 0:
                w[n] /= w[n].sum()

        w[n] /= w[n].sum()

In [8]:
ReloadP(lam, w, prob, c, train_grey_data)

In [16]:
totalTP = 0
for label in range(10):
    TP, FN, FP, TN = 0, 0, 0, 0
    for n in range(60000):
        if train_label_data[n] == label and w[n][label] >= 0.01:
            TP += 1
            totalTP += 1
        elif train_label_data[n] == label and w[n][label] < 0.01:
            FN += 1
        elif train_label_data[n] != label and w[n][label] >= 0.01:
            FP += 1
        elif train_label_data[n] != label and w[n][label] < 0.01:
            TN += 1
    print(f'Confusion Matrix {label} :')
    print(f'                 Predict number {label}   Predict not number {label}')
    print(f'Is number {label}        {TP}                        {FN}')
    print(f'Is not number {label}    {FP}                       {TN}')
    print(f'Sensitivity (Successfully predict number {label}) : {TP/(TP + FN)}')
    print(f'Specificity (Successfully predict not number {label}) : {TN/(TN + FP)}\n')
print(f'error rate: {1-totalTP/60000}')

Confusion Matrix 0 :
                 Predict number 0   Predict not number 0
Is number 0        879                        5044
Is not number 0    6600                       47477
Sensitivity (Successfully predict number 0) : 0.14840452473408747
Specificity (Successfully predict not number 0) : 0.8779518094568856

Confusion Matrix 1 :
                 Predict number 1   Predict not number 1
Is number 1        0                        6742
Is not number 1    5026                       48232
Sensitivity (Successfully predict number 1) : 0.0
Specificity (Successfully predict not number 1) : 0.905629201246761

Confusion Matrix 2 :
                 Predict number 2   Predict not number 2
Is number 2        4170                        1788
Is not number 2    3808                       50234
Sensitivity (Successfully predict number 2) : 0.6998992950654582
Specificity (Successfully predict not number 2) : 0.92953628659191

Confusion Matrix 3 :
                 Predict number 3   Predict not n

In [10]:
w[1][0]

2.737856059607368e-35