In [1]:
import numpy as np 
import struct
import matplotlib.pyplot as plt

class DataUtils(object):
    """Decoder for MNIST files stored in the big-endian IDX binary layout.

    One instance wraps one file: call ``getImage`` on an image file
    (16-byte header followed by one unsigned byte per pixel) or ``getLabel``
    on a label file (8-byte header followed by one unsigned byte per label).
    """

    def __init__(self, filename=None, outpath=None):
        """Store the file path and build the ``struct`` format strings.

        :param filename: path of the binary file to decode.
        :param outpath: stored on the instance; no method of this class
            reads it.
        """
        self._filename = filename
        self._outpath = outpath

        self._tag = '>'               # struct prefix: big-endian byte order
        self._twoBytes = 'II'         # label header: two unsigned 32-bit ints
        self._fourBytes = 'IIII'      # image header: four unsigned 32-bit ints
        self._pictureBytes = '784B'   # one 28x28 image as unsigned bytes
        self._labelByte = '1B'        # one label as an unsigned byte
        self._twoBytes2 = self._tag + self._twoBytes
        self._fourBytes2 = self._tag + self._fourBytes
        self._pictureBytes2 = self._tag + self._pictureBytes
        self._labelByte2 = self._tag + self._labelByte

    def getImage(self):
        """Convert an MNIST binary image file into binarised pixel features.

        Every image is thresholded against its own mean intensity: pixels at
        or above the mean become 256, all others become 0.  The mean is the
        floor of ``sum / count``, which is what the ``/`` operator produced
        for integers under the Python 2 kernel this notebook was written for.

        The number of pixels per image is taken from the rows/cols fields of
        the file header instead of being fixed at 784.

        NOTE(review): 256 is outside the 0-255 range of the source bytes; it
        is kept because downstream cells were run with this value.

        :returns: integer ``numpy`` array of shape
            ``(numImgs, numRows * numCols)`` containing only 0 and 256.
        :raises ValueError: if the file holds fewer pixel bytes than the
            header announces.
        """
        # The context manager closes the file even if read() raises.
        with open(self._filename, 'rb') as binfile:
            buf = binfile.read()

        numMagic, numImgs, numRows, numCols = struct.unpack_from(
            self._fourBytes2, buf, 0)
        headerSize = struct.calcsize(self._fourBytes2)
        pixelsPerImage = numRows * numCols

        # View the pixel payload as a (numImgs, pixelsPerImage) byte matrix
        # without copying it element by element.
        pixels = np.frombuffer(buf, dtype=np.uint8,
                               count=numImgs * pixelsPerImage,
                               offset=headerSize)
        pixels = pixels.reshape(numImgs, pixelsPerImage)

        # Per-image floor mean; summing in int64 avoids uint8 overflow.
        avg = pixels.sum(axis=1, dtype=np.int64) // pixelsPerImage
        return np.where(pixels >= avg[:, np.newaxis], 256, 0)

    def getLabel(self):
        """Convert an MNIST binary label file into numeric labels.

        :returns: one-dimensional integer ``numpy`` array with one entry per
            label announced in the file header.
        :raises ValueError: if the file holds fewer label bytes than the
            header announces.
        """
        with open(self._filename, 'rb') as binFile:
            buf = binFile.read()

        magic, numItems = struct.unpack_from(self._twoBytes2, buf, 0)
        headerSize = struct.calcsize(self._twoBytes2)
        labels = np.frombuffer(buf, dtype=np.uint8, count=numItems,
                               offset=headerSize)
        # astype copies into a writable array of the platform default int,
        # the same dtype np.array() gave for a list of Python ints.
        return labels.astype(int)


In [2]:
def main(data_dir='/Users/luojp/Desktop/pro1'):
    """Load the MNIST training and test sets from ``data_dir``.

    :param data_dir: directory holding the four MNIST files
        ``train-images-idx3-ubyte``, ``train-labels-idx1-ubyte``,
        ``t10k-images-idx3-ubyte`` and ``t10k-labels-idx1-ubyte``.
        Defaults to the location originally hard-coded in this notebook.
    :returns: tuple ``(train_X, train_y, test_X, test_y)`` as produced by
        ``DataUtils.getImage`` / ``DataUtils.getLabel``.
    """
    import os  # local import so this cell does not rely on another cell

    def _path(name):
        # Build the full path of one MNIST file inside data_dir.
        return os.path.join(data_dir, name)

    train_X = DataUtils(filename=_path('train-images-idx3-ubyte')).getImage()
    train_y = DataUtils(filename=_path('train-labels-idx1-ubyte')).getLabel()
    test_X = DataUtils(_path('t10k-images-idx3-ubyte')).getImage()
    test_y = DataUtils(_path('t10k-labels-idx1-ubyte')).getLabel()

    return train_X, train_y, test_X, test_y

In [3]:
# Read all four MNIST files once; the remaining cells work on these arrays.
train_X, train_y, test_X, test_y = main()

In [4]:
import random
from sklearn.utils import safe_indexing
# Seed the training sample with 5 distinct, randomly chosen examples of each
# digit class (50 indices into train_X / train_y in total).
whole_index = []
for majority_class in range(0, 10):
    index_arr = np.flatnonzero(train_y == majority_class)
    # replace=False guarantees 5 different dataset indices, and raises
    # ValueError (rather than looping forever) if a class has fewer than 5.
    class_seed = np.random.choice(index_arr, size=5, replace=False)
    whole_index += class_seed.tolist()

In [5]:
# Pull the rows listed in whole_index out of the training features and labels.
sample_X, sample_y = (
    safe_indexing(source, whole_index) for source in (train_X, train_y)
)

In [6]:
from sklearn.neighbors import KNeighborsClassifier
import numpy
# 3-nearest-neighbour classifier using a k-d tree, fitted on the current sample.
neigh = KNeighborsClassifier(algorithm='kd_tree', n_neighbors=3)
neigh.fit(X=sample_X, y=sample_y)

KNeighborsClassifier(algorithm='kd_tree', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=3, p=2,
           weights='uniform')

In [7]:
# Grow the sample: visit the training set in random order and add every
# example the current classifier gets wrong, refitting after each addition,
# until the sample holds 1000 examples.
already_selected = set(int(i) for i in whole_index)
for index in random.sample(range(0, len(train_y)), len(train_y)):
    # Checked first and with >= so that re-running the cell cannot overshoot.
    if len(sample_X) >= 1000:
        break
    # An example that is already in the sample must not be added twice.
    if index in already_selected:
        continue
    target_image = train_X[index]
    prediction = neigh.predict([target_image])
    if prediction[0] != train_y[index]:
        whole_index = numpy.append(whole_index, index)
        already_selected.add(index)
        sample_X = safe_indexing(train_X, whole_index)
        sample_y = safe_indexing(train_y, whole_index)
        neigh.fit(sample_X, sample_y)

In [9]:
# Appends 10000-4540 (= 5460) randomly drawn training indices to whole_index
# and rebuilds sample_X / sample_y from the enlarged index list.
# NOTE(review): this cell carries the execution count 9 but sits before the
# cell numbered 8, whose recorded output is 1000; running it in file order
# would make the sample larger than that. It looks like a leftover
# experiment -- confirm whether it should be removed.
# NOTE(review): the meaning of the constants 10000 and 4540 is not visible
# in this notebook -- confirm.
for index in random.sample(range(0, len(train_y)), 10000-4540):
    target_image=train_X[index]  # assigned here but not used inside this loop
    whole_index=numpy.append(whole_index,index)
sample_X=safe_indexing(train_X,whole_index)
sample_y=safe_indexing(train_y,whole_index)

In [8]:
# Display how many examples the selected sample currently holds.
len(sample_X)

1000

In [9]:
# Baseline: 1000 distinct training indices drawn uniformly at random.
random_list = random.sample(range(len(train_y)), 1000)

In [10]:
# Pull the rows listed in random_list out of the training features and labels.
random_X, random_y = (
    safe_indexing(source, random_list) for source in (train_X, train_y)
)

In [11]:
# Test-set accuracy when the classifier is trained on the random baseline.
neigh.fit(random_X, random_y)
random_prediction = neigh.predict(test_X)
# Element-wise comparison gives one boolean per test example.
random_correct = (random_prediction == test_y)
random_acc = np.mean(random_correct)
# Function-call form prints the same text on Python 2 and Python 3.
print(random_acc)

0.8911


In [12]:
# Test-set accuracy when the classifier is trained on the selected sample.
neigh.fit(sample_X, sample_y)
sample_prediction = neigh.predict(test_X)
# Element-wise comparison gives one boolean per test example.
sample_correct = (sample_prediction == test_y)
sample_acc = np.mean(sample_correct)
# Function-call form prints the same text on Python 2 and Python 3.
print(sample_acc)

0.9007
