## K近邻法手写识别

## 处理数据

In [1]:
from numpy import *
import os
import operator
def img2vector(filename):
    imgMat = loadtxt(filename, converters = {0: lambda x: array(list(x.decode()), dtype=int)})
    imgVector = imgMat.ravel()
    return imgVector

img2vector('./input/testDigits/0_2.txt')

array([ 0.,  0.,  0., ...,  0.,  0.,  0.])

## 手写识别系统

In [2]:
def classify(inputX, dataSet, label, k):
    '''
    Input:
        inputX:   vector to compare to existing dataset (1xN);
        dataSet:  known data set (MxN);
        label:    data set label (1xM);
        k:        number of neighbors to use
    '''
    # caculate O-distance
    rows = dataSet.shape[0];
    diffMat = tile(inputX,(rows,1)) - dataSet;
    distance = sqrt(sum(diffMat ** 2,1));
    # select kth points with smallest distance
    index = argsort(distance);
    count = {};
    for i in range(k):
        ithlabel = label[index[i]]
        count[ithlabel] = count.get(ithlabel,0) + 1
    sortedClassCount = sorted(count.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def handwritingClassTest():
    # read test data
    trainingFileList = os.listdir('./input/trainingDigits')
    fileNumber = len(trainingFileList)
    trainingMat = zeros((fileNumber,1024))
    label = []
    for n,files in enumerate(trainingFileList):
        filename = files.split('.')[0]
        filelabel = int(filename.split('_')[0])
        label.append(filelabel)
        trainingMat[n,:] = img2vector('./input/trainingDigits/%s' % files)
    
    testFileList = os.listdir('./input/testDigits')
    testFileNumber = len(testFileList)
    errorCount = 0
    summary = {}
    for n,files in enumerate(testFileList):
        testvector = img2vector('./input/testDigits/%s' % files)
        filename = files.split('.')[0]
        givenLabel = int(filename.split('_')[0])
        classified = classify(testvector, trainingMat, label, 3)
        summary[filename] = classified
        if classified != givenLabel:
            errorCount += 1
            print("The classifier come back with %d, the real answer is %d" %(givenLabel, classified))
    print("total error is %f" %(errorCount/testFileNumber))

In [3]:
handwritingClassTest()

The classifier come back with 1, the real answer is 7
The classifier come back with 3, the real answer is 9
The classifier come back with 5, the real answer is 3
The classifier come back with 5, the real answer is 6
The classifier come back with 8, the real answer is 6
The classifier come back with 8, the real answer is 3
The classifier come back with 8, the real answer is 1
The classifier come back with 8, the real answer is 1
The classifier come back with 9, the real answer is 1
The classifier come back with 9, the real answer is 7
total error is 0.010571
