In [6]:
import numpy as np
import math
import pandas as pd
import cv2
import matplotlib.pyplot as plt

In [7]:
from struct import unpack
import gzip
import numpy as np
from numpy import zeros, uint8, float32

def get_labeled_data(imagefile, labelfile):
    """Read a gzip-compressed IDX image file and its matching IDX label file.

    Parameters
    ----------
    imagefile : str or path-like
        gzip file whose header is magic, count, rows, cols (big-endian
        unsigned 32-bit ints), followed by count*rows*cols unsigned bytes.
    labelfile : str or path-like
        gzip file whose header is magic, count, followed by count unsigned bytes.

    Returns
    -------
    (x, y) : tuple of ndarrays
        x : float32 array, shape (N, rows, cols), raw byte values 0-255.
        y : uint8 array, shape (N, 1), one label byte per image.

    Raises
    ------
    Exception
        If the image and label counts differ, or either file ends before
        the advertised number of bytes.
    """
    # `with` guarantees both gzip handles are closed, even on error.
    with gzip.open(imagefile, 'rb') as images, gzip.open(labelfile, 'rb') as labels:
        # Headers are big-endian unsigned ints, hence '>I'. The magic number is read but not checked.
        _magic, number_of_images, rows, cols = unpack('>IIII', images.read(16))
        _magic, N = unpack('>II', labels.read(8))

        if number_of_images != N:
            raise Exception('number of labels did not match the number of images')

        # Read the whole payload at once instead of one byte per unpack() call.
        n_pixels = N * rows * cols
        pixel_bytes = images.read(n_pixels)
        label_bytes = labels.read(N)

    if len(pixel_bytes) != n_pixels or len(label_bytes) != N:
        raise Exception('image or label file is truncated')

    x = np.frombuffer(pixel_bytes, dtype=uint8).reshape(N, rows, cols).astype(float32)
    # frombuffer views immutable bytes; copy so callers get a writable array.
    y = np.frombuffer(label_bytes, dtype=uint8).reshape(N, 1).copy()
    return (x, y)

In [8]:
def resize(imagefile, size=20):
    """Crop each image to the bounding box of its non-zero pixels and stretch it.

    Parameters
    ----------
    imagefile : ndarray, shape (num, rows, cols)
        Stack of grayscale images. Values are clipped to [0, 255] and cast
        to uint8 before cropping.
    size : int, default 20
        Side length of each output image.

    Returns
    -------
    ndarray of float64, shape (num, size, size)
        Cropped images resized with cv2's default (bilinear) interpolation.
        An image with no non-zero pixel is left as all zeros.
    """
    num, _rows, _cols = imagefile.shape
    images = np.clip(imagefile, 0, 255).astype('uint8')
    imagefile_resized = np.zeros((num, size, size))
    for i in range(num):
        img = images[i]
        # Bounding box of non-zero pixels, computed directly instead of via
        # cv2.findContours (whose return tuple differs between OpenCV 3 and 4).
        rows_on = np.flatnonzero(img.any(axis=1))
        cols_on = np.flatnonzero(img.any(axis=0))
        if rows_on.size == 0:
            continue  # blank image: nothing to crop; cv2.resize would fail on an empty crop
        digit = img[rows_on[0]:rows_on[-1] + 1, cols_on[0]:cols_on[-1] + 1]
        # cv2.resize takes dsize as (width, height).
        imagefile_resized[i] = cv2.resize(digit, (size, size))
    return imagefile_resized

In [9]:
def normalize(data, threshold=127):
    """Binarize intensities: 1.0 where a value is strictly greater than `threshold`, else 0.0.

    Parameters
    ----------
    data : array-like
        Pixel values of any shape (typically (num, rows, cols)).
    threshold : number, default 127
        Values above this become 1.0; all others (including NaN) become 0.0.

    Returns
    -------
    ndarray of float64 with the same shape as `data`.
    """
    # A single vectorized comparison replaces the per-pixel triple loop.
    return (np.asarray(data) > threshold).astype(np.float64)

In [10]:
def CalProb(train_data, train_label, n_classes=10):
    """Estimate Laplace-smoothed Bernoulli naive Bayes parameters.

    Parameters
    ----------
    train_data : ndarray, shape (num, row, col)
        Training images. Intended to be binarized (0/1), e.g. by `normalize`.
    train_label : array-like of ints with `num` elements
        One class label per image; the (num, 1) column from
        `get_labeled_data` works as-is.
    n_classes : int, default 10
        Number of classes; labels must lie in [0, n_classes).

    Returns
    -------
    pyj : ndarray, shape (n_classes,)
        Smoothed class priors, (count_c + 1) / (num + n_classes).
    pyjk : ndarray, shape (col, row, n_classes)
        Smoothed per-pixel probabilities,
        (sum of pixel (j, k) over class-c images + 1) / (count_c + 2).
        Stored TRANSPOSED: `pyjk.T[c][j][k]` is the value for class c at
        pixel (j, k). `predictBer` indexes it this way.

    Raises
    ------
    ValueError
        If the number of labels differs from the number of images.
    IndexError
        If a label is outside [0, n_classes).
    """
    train_data = np.asarray(train_data)
    num, row, col = train_data.shape
    labels = np.asarray(train_label).reshape(-1).astype(np.intp)
    if labels.shape[0] != num:
        raise ValueError('expected %d labels, got %d' % (num, labels.shape[0]))
    if num and (labels.min() < 0 or labels.max() >= n_classes):
        raise IndexError('labels must lie in [0, %d)' % n_classes)

    pyj = np.bincount(labels, minlength=n_classes).astype(float)
    pyjk = np.zeros((n_classes, row, col))
    for c in range(n_classes):
        # Summing an empty selection gives zeros, so absent classes are fine.
        pyjk[c] = train_data[labels == c].sum(axis=0)

    # Transposing lets the per-class counts broadcast along the last axis.
    pyjk = (pyjk.T + 1) / (pyj + 2)
    pyj = (pyj + 1) / (num + n_classes)
    return pyj, pyjk

In [11]:
def CalBerProb(xk, pyjk):
    """Bernoulli log-likelihood of observation `xk` given success probability `pyjk`.

    Computes xk * log(pyjk) + (1 - xk) * log(1 - pyjk), elementwise for arrays.
    """
    log_p_on = np.log(pyjk)
    log_p_off = np.log(1 - pyjk)
    return xk * log_p_on + (1 - xk) * log_p_off

In [19]:
def predictBer(test_data, test_label, pyjk, pyj):
    """Classify images with Bernoulli naive Bayes, then print and return the accuracy.

    Parameters
    ----------
    test_data : ndarray, shape (num, row, col)
        Images to classify, in the same representation used for training.
    test_label : array-like of ints with `num` elements
        True labels; the (num, 1) column from `get_labeled_data` works.
    pyjk : ndarray, shape (col, row, n_classes)
        Per-pixel probabilities in the transposed layout returned by
        `CalProb` (`pyjk.T[c][j][k]` is class c, pixel (j, k)).
    pyj : ndarray, shape (n_classes,)
        Class priors.

    Returns
    -------
    float
        Fraction of images whose arg-max class equals the true label.

    Raises
    ------
    ValueError
        If `test_data` is empty or the label count does not match.
    """
    test_data = np.asarray(test_data, dtype=float)
    num = test_data.shape[0]
    if num == 0:
        raise ValueError('test_data is empty; accuracy is undefined')
    labels = np.asarray(test_label).reshape(-1)
    if labels.shape[0] != num:
        raise ValueError('expected %d labels, got %d' % (num, labels.shape[0]))

    n_classes = len(pyj)
    theta = np.asarray(pyjk).T.reshape(n_classes, -1)   # (n_classes, row*col)
    pixels = test_data.reshape(num, -1)                 # (num, row*col)

    # Sum over pixels of x*log(p) + (1-x)*log(1-p), for every image/class pair,
    # expressed as two matrix products instead of per-pixel Python loops.
    scores = (pixels @ np.log(theta).T
              + (1 - pixels) @ np.log(1 - theta).T
              + np.log(pyj))
    predictions = np.argmax(scores, axis=1)

    # Flattening the labels keeps the accuracy a plain scalar, not a 1-element array.
    acc = float(np.mean(predictions == labels))
    print('Test accuracy is: ', acc)
    return acc

In [21]:
def BerNB(imagefile='train-images-idx3-ubyte.gz',
          labelfile='train-labels-idx1-ubyte.gz',
          imagefile2='t10k-images-idx3-ubyte.gz',
          labelfile2='t10k-labels-idx1-ubyte.gz'):
    """Train Bernoulli naive Bayes on binarized training images and score it on the test files.

    Parameters
    ----------
    imagefile, labelfile : str
        Training image/label gzip IDX files.
    imagefile2, labelfile2 : str
        Test image/label gzip IDX files.
    The defaults are the standard MNIST file names, resolved relative to
    the current working directory.

    Returns
    -------
    Whatever `predictBer` returns; the accuracy is also printed by `predictBer`.
    """
    train_data, train_label = get_labeled_data(imagefile, labelfile)
    test_data, test_label = get_labeled_data(imagefile2, labelfile2)
    train_bin = normalize(train_data)
    test_bin = normalize(test_data)
    pyj, pyjk = CalProb(train_bin, train_label)
    return predictBer(test_bin, test_label, pyjk, pyj)

In [22]:
# Baseline: fit on the binarized training files, evaluate on the separate t10k test files.
BerNB()

Test accuracy is:  [0.8427]


In [23]:
def BerNB_scratch(imagefile='train-images-idx3-ubyte.gz',
                  labelfile='train-labels-idx1-ubyte.gz',
                  imagefile2='t10k-images-idx3-ubyte.gz',
                  labelfile2='t10k-labels-idx1-ubyte.gz'):
    """Like `BerNB`, but each binarized image is cropped to its content and stretched by `resize`.

    Parameters
    ----------
    imagefile, labelfile : str
        Training image/label gzip IDX files.
    imagefile2, labelfile2 : str
        Test image/label gzip IDX files.
    The defaults are the standard MNIST file names, resolved relative to
    the current working directory.

    Returns
    -------
    Whatever `predictBer` returns; the accuracy is also printed by `predictBer`.
    """
    train_data, train_label = get_labeled_data(imagefile, labelfile)
    test_data, test_label = get_labeled_data(imagefile2, labelfile2)
    # Threshold first, then crop to the bounding box and stretch.
    train_stretched = resize(normalize(train_data))
    test_stretched = resize(normalize(test_data))
    pyj, pyjk = CalProb(train_stretched, train_label)
    return predictBer(test_stretched, test_label, pyjk, pyj)

In [24]:
# Same train/test split as BerNB, but binarized images are cropped to their
# non-zero bounding box and stretched to a fixed size before fitting and scoring.
BerNB_scratch()

Test accuracy is:  [0.8324]


In [18]:
def BerNB_traintest(imagefile='train-images-idx3-ubyte.gz',
                    labelfile='train-labels-idx1-ubyte.gz'):
    """Fit Bernoulli naive Bayes on the binarized training set and score it on that same set.

    Parameters
    ----------
    imagefile, labelfile : str
        Training image/label gzip IDX files. The defaults are the standard
        MNIST file names, resolved relative to the current working directory.

    Returns
    -------
    Whatever `predictBer` returns (here, training accuracy); it is also printed.
    """
    train_data, train_label = get_labeled_data(imagefile, labelfile)
    train_bin = normalize(train_data)
    pyj, pyjk = CalProb(train_bin, train_label)
    # Evaluate on the very images the model was fitted on.
    return predictBer(train_bin, train_label, pyjk, pyj)

In [15]:
# Sanity check: accuracy on the training set itself (no held-out data).
BerNB_traintest()

Test accuracy is:  [0.83576667]


In [17]:
def BerNB_traintest_stretched(imagefile='train-images-idx3-ubyte.gz',
                              labelfile='train-labels-idx1-ubyte.gz'):
    """Training-set accuracy of Bernoulli naive Bayes on cropped/stretched images.

    Parameters
    ----------
    imagefile, labelfile : str
        Training image/label gzip IDX files. The defaults are the standard
        MNIST file names, resolved relative to the current working directory.

    Returns
    -------
    Whatever `predictBer` returns (here, training accuracy); it is also printed.
    """
    train_data, train_label = get_labeled_data(imagefile, labelfile)
    # NOTE(review): this crops/stretches the raw 0-255 images and thresholds
    # afterwards, whereas BerNB_scratch thresholds first and then resizes.
    # Confirm the differing order is intentional when comparing the two.
    train_bin = normalize(resize(train_data))
    pyj, pyjk = CalProb(train_bin, train_label)
    return predictBer(train_bin, train_label, pyjk, pyj)

In [20]:
# Training-set accuracy when raw images are cropped/stretched before binarization.
BerNB_traintest_stretched()

Test accuracy is:  [0.81345]
