In [26]:
# knn实现手写数字识别
import operator
import os
from collections import Counter
from os import listdir

import numpy as np


def img2vector(filename, rows=32, cols=32):
    """Convert a text-encoded digit image into a ``(1, rows * cols)`` float row vector.

    The first ``cols`` characters of each of the first ``rows`` lines are parsed
    with ``int`` and stored row-major, so pixel ``(i, j)`` lands at column
    ``cols * i + j``. The defaults (32 x 32 -> 1 x 1024) match the handwritten
    digit files used by this script.

    Args:
        filename: Path to the text image file.
        rows: Number of image lines to read.
        cols: Number of characters to read from each line.

    Returns:
        A NumPy float array of shape ``(1, rows * cols)``.

    Raises:
        IndexError / ValueError: if a line is too short (including a missing
            line at end of file) or contains a non-digit character.
    """
    return_vector = np.zeros((1, rows * cols))
    # `with` closes the file even when parsing raises part-way through.
    with open(filename) as fr:
        for i in range(rows):
            line_str = fr.readline()
            for j in range(cols):
                return_vector[0, cols * i + j] = int(line_str[j])
    return return_vector


def knn(input_x, dataset, labels, k):
    """Classify ``input_x`` by majority vote among its ``k`` nearest training samples.

    Args:
        input_x: Query vector, either 1-D of length ``d`` or a ``(1, d)`` row;
            it is broadcast against every row of ``dataset``.
        dataset: 2-D NumPy array of training samples, one per row, shape ``(n, d)``.
        labels: Sequence of ``n`` labels; ``labels[i]`` belongs to ``dataset[i]``.
        k: Number of nearest neighbours that vote. Values larger than ``n``
            are clamped to ``n``.

    Returns:
        The label with the most votes. On a tie, the label whose closest voting
        neighbour is nearest to ``input_x`` wins.

    Raises:
        ValueError: if ``k < 1`` or ``dataset`` has no rows.
    """
    dataset_size = dataset.shape[0]
    if k < 1:
        raise ValueError('k must be >= 1, got %r' % (k,))
    if dataset_size == 0:
        raise ValueError('dataset is empty')
    k = min(k, dataset_size)
    # Broadcasting replaces np.tile. Squared Euclidean distance ranks samples
    # exactly like the true distance (sqrt is monotonic), so the root is skipped.
    sq_distances = ((dataset - input_x) ** 2).sum(axis=1)
    # A stable sort keeps equidistant samples in dataset order, so ties are deterministic.
    nearest = sq_distances.argsort(kind='stable')[:k]
    # Counter remembers first-insertion order and most_common keeps it among equal
    # counts, so a tie goes to the label met first, i.e. from the nearer neighbour.
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Handwritten-digit recognition test harness.
def _label_from_filename(file_name):
    """Return the integer digit label encoded at the start of ``file_name``.

    The name is cut at its first '.', then at its first '_', and the leading
    piece is parsed with ``int``; names are assumed to look like
    '<label>_<anything>.<ext>'.
    """
    stem = file_name.split('.')[0]
    return int(stem.split('_')[0])


def _load_digit_dir(directory):
    """Load every file in ``directory`` as a digit image with its label.

    Returns:
        ``(data_mat, labels)``: ``data_mat`` has one 1024-column row per file
        (as produced by ``img2vector``) and ``labels[i]`` is the digit parsed
        from the i-th file name. Files are processed in sorted name order so
        results do not depend on ``listdir``'s arbitrary ordering.
    """
    file_names = sorted(listdir(directory))
    data_mat = np.zeros((len(file_names), 1024))
    labels = []
    for row, file_name in enumerate(file_names):
        labels.append(_label_from_filename(file_name))
        data_mat[row, :] = img2vector(os.path.join(directory, file_name))
    return data_mat, labels


def hand_writing_digits_test(training_dir='./datasets/digits/trainingDigits',
                             test_dir='./datasets/digits/testDigits',
                             k=3, verbose=False):
    """Classify every test image with kNN against the training images and report the error rate.

    Args:
        training_dir: Directory of labelled training image files.
        test_dir: Directory of labelled test image files.
        k: Number of neighbours passed to ``knn``.
        verbose: If true, print the prediction and ground truth for each test file.

    Returns:
        The fraction of test files whose predicted label differs from the label
        parsed from the file name. The value is also printed.

    Raises:
        ValueError: if ``test_dir`` contains no files (the rate would be undefined).
    """
    training_data_mat, training_data_labels = _load_digit_dir(training_dir)
    test_data_mat, test_data_labels = _load_digit_dir(test_dir)
    num_of_test_data = len(test_data_labels)
    if num_of_test_data == 0:
        raise ValueError('no test files found in %s' % test_dir)
    error_count = 0
    for test_data_vector, class_num in zip(test_data_mat, test_data_labels):
        classifier_result = knn(test_data_vector, training_data_mat, training_data_labels, k)
        if verbose:
            print('The classifier came back with: %d, the real answer is: %d'
                  % (classifier_result, class_num))
        if classifier_result != class_num:
            error_count += 1
    error_rate = error_count / float(num_of_test_data)
    print('The total error rate is: %f' % error_rate)
    return error_rate


def main():
    """Script entry point: run the handwritten-digit kNN evaluation with its default settings."""
    hand_writing_digits_test()


if __name__ == '__main__':
    main()

The total error rate is: 0.010571
