In [50]:
import numpy as np
import matplotlib.pyplot as plt
import math
import random
# Locations of the four MNIST IDX files (images and labels for the training
# and test splits), given relative to the current working directory.
train_img_path = './train-images.idx3-ubyte'
test_img_path = './t10k-images.idx3-ubyte'
train_label_path = './train-labels.idx1-ubyte'
test_label_path = './t10k-labels.idx1-ubyte'


def get_head_info(filename):
    """Read the header of an IDX file and report its size and sample count.

    The header starts with a 4-byte big-endian magic number whose last byte
    is the number of dimensions, followed by one 4-byte big-endian integer
    per dimension.  The first dimension is the number of samples in the
    file; in image files the next two are the row and column counts.

    Args:
        filename: Path of the IDX file to inspect.

    Returns:
        A ``(head_length, sample_count)`` tuple: the size of the header in
        bytes and the number of samples the header declares.

    Raises:
        ValueError: If the header is truncated or declares zero dimensions.
    """
    with open(filename, 'rb') as pf:
        magic = pf.read(4)
        if len(magic) != 4:
            raise ValueError(
                "%s: file is too short to contain an IDX magic number"
                % (filename,))
        # The magic number is big-endian, so its lowest byte is the last one.
        dimension_cnt = magic[3]
        if dimension_cnt == 0:
            raise ValueError(
                "%s: IDX header declares no dimensions" % (filename,))
        raw_dims = pf.read(4 * dimension_cnt)

    # A short read would otherwise be decoded as zeros by int.from_bytes.
    if len(raw_dims) != 4 * dimension_cnt:
        raise ValueError(
            "%s: IDX header is truncated (expected %d dimension fields)"
            % (filename, dimension_cnt))

    dimension = [
        int.from_bytes(raw_dims[start:start + 4], byteorder='big')
        for start in range(0, len(raw_dims), 4)
    ]
    sample_count = dimension[0]
    head_length = 4 * dimension_cnt + 4  # magic number + dimension fields
    return head_length, sample_count


# Size of a single MNIST image, in pixels.
IMAGE_ROW = 28
IMAGE_COL = 28


def read_image_p(pf, head_len, offset):
    """Read one image from an open IDX image file as a flat uint8 vector.

    Args:
        pf: Binary file object, opened for reading, positioned anywhere.
        head_len: Header size in bytes, as returned by ``get_head_info``.
        offset: Zero-based index of the image to read.

    Returns:
        A writable 1-D ``np.uint8`` array of ``IMAGE_ROW * IMAGE_COL``
        pixel values (the image is stored flattened).

    Raises:
        ValueError: If the file ends before the whole image could be read.
    """
    pixel_cnt = IMAGE_ROW * IMAGE_COL
    # Jump straight to the first byte of the requested image.
    pf.seek(head_len + pixel_cnt * offset)
    # Fetch the whole image in one read instead of one call per pixel.
    raw = pf.read(pixel_cnt)
    if len(raw) != pixel_cnt:
        raise ValueError(
            "image %d is incomplete: expected %d bytes, got %d"
            % (offset, pixel_cnt, len(raw)))
    # frombuffer returns a read-only view of `raw`; copy so the caller gets
    # an independent, writable array.
    return np.frombuffer(raw, dtype=np.uint8).copy()


def load_labels(filename):
    """Load every label from an IDX label file in one pass.

    The file starts with a 4-byte magic number and a 4-byte big-endian label
    count, followed by one byte per label.

    Args:
        filename: Path of the IDX label file.

    Returns:
        A list of ints, one per label, in file order.

    Raises:
        ValueError: If the header or the label data is truncated.
    """
    with open(filename, 'rb') as label_file:
        header = label_file.read(8)
        if len(header) != 8:
            raise ValueError(
                "%s: file is too short to contain a label header"
                % (filename,))
        # Bytes 0-3 are the magic number (not needed here); 4-7 the count.
        cnt = int.from_bytes(header[4:8], byteorder='big')
        raw = label_file.read(cnt)

    # A short read would otherwise silently produce too few labels.
    if len(raw) != cnt:
        raise ValueError(
            "%s: expected %d labels, found %d" % (filename, cnt, len(raw)))
    # Iterating a bytes object yields each byte as an int.
    return list(raw)

def load_pics(filename):
    """Load every image from an IDX image file.

    Args:
        filename: Path of the IDX image file.

    Returns:
        ``np.array`` built from the list of images returned by
        ``read_image_p``, in file order (one flattened image per entry).
    """
    head_len, sample_cnt = get_head_info(filename)
    # The context manager closes the file even if a read fails part-way.
    with open(filename, 'rb') as p:
        pics = [read_image_p(p, head_len, offset)
                for offset in range(sample_cnt)]
    return np.array(pics)
        

def get_database():
    """Load the training set used as the k-NN reference database.

    Returns:
        A ``(pics, labels)`` tuple: the images loaded from
        ``train_img_path`` and the labels loaded from ``train_label_path``.
    """
    return load_pics(train_img_path), load_labels(train_label_path)

# with open(filename,'rb') as f:
#     img = read_image_p(f,head_len,1)
#     print(img)
#     plt.imshow(img)
#     plt.title('num')
#     plt.show()




In [55]:
def get_distance(vec_a, vec_b, type='euler'):
    """Return the distance between two vectors.

    Args:
        vec_a: First vector (array-like of numbers).
        vec_b: Second vector, same shape as ``vec_a``.
        type: Distance metric; only ``'euler'`` (Euclidean) is supported.
            The name shadows the builtin but is kept because it is part of
            the function's signature.

    Returns:
        The Euclidean distance as a Python float.

    Raises:
        ValueError: If ``type`` names an unsupported metric.
    """
    if type != 'euler':
        raise ValueError("unsupported distance type: %r" % (type,))
    # Convert to float first: subtracting unsigned integer arrays (such as
    # raw uint8 pixels) wraps around instead of going negative.
    diff = (np.asarray(vec_a, dtype=np.float64)
            - np.asarray(vec_b, dtype=np.float64))
    return math.sqrt(np.sum(diff ** 2))

def get_mean(pics):
    """Center a batch of images by subtracting the mean over the batch.

    Args:
        pics: Array-like whose first axis indexes the images.

    Returns:
        A new float64 array of the same shape as ``pics`` with the mean
        taken along axis 0 subtracted from every image.  The input is left
        unmodified.
    """
    mean_pics = np.mean(pics, axis=0)
    # np.float was only an alias for the builtin float and was removed in
    # NumPy 1.24; np.float64 is the dtype it resolved to.
    return np.asarray(pics, dtype=np.float64) - mean_pics


def search(vec, K, samples, labels):
    """Return the labels of the K samples closest to ``vec``.

    Args:
        vec: Query vector.
        K: Number of nearest neighbours to return.
        samples: Iterable of sample vectors to compare against.
        labels: Sequence of labels, indexed in step with ``samples``.

    Returns:
        A list with the labels of the K nearest samples, nearest first.
    """
    # Distance from vec to every sample, in sample order.
    distance = [get_distance(vec, sample) for sample in samples]
    # Truncate the sorted indices before looking labels up, so only K labels
    # are fetched rather than one per sample.
    nearest = np.argsort(distance)[:K]
    return [labels[idx] for idx in nearest]

def cnt_labels(labels):
    """Return the label that occurs most often in ``labels``.

    When several labels share the highest count, the one that appears
    first in ``labels`` is returned.  An empty sequence raises
    ``IndexError``.
    """
    # Start from the first label so that ties resolve in its favour.
    winner = labels[0]
    tally = {}
    for item in labels:
        tally[item] = tally.get(item, 0) + 1
    # Only a strictly larger count can displace the current winner.
    for candidate, votes in tally.items():
        if votes > tally[winner]:
            winner = candidate
    return winner

def knn(input, data, labels, K):
    """Classify one image with k-nearest-neighbours.

    Looks up the labels of the K samples in ``data`` nearest to ``input``
    (via ``search``) and returns the most frequent one (via ``cnt_labels``).
    """
    return cnt_labels(search(input, K, data, labels))

def predict(test_path, database, labels, K, limit=1000):
    """Classify test images with k-NN and return their predicted labels.

    Args:
        test_path: Path of the IDX image file holding the test images.
        database: Reference samples the test images are compared against.
        labels: Labels of ``database``, in the same order.
        K: Number of neighbours consulted for each prediction.
        limit: Maximum number of test images to classify, taken from the
            start of the file.  ``None`` classifies every image.  If the
            file holds fewer images than ``limit``, all of them are used.

    Returns:
        A list with one predicted label per classified image.
    """
    # NOTE(review): the test images are centered with their own mean, not
    # with the mean of `database` -- confirm this is intended.
    test = get_mean(load_pics(test_path))
    if limit is not None:
        # Slicing clamps to the number of images actually available.
        test = test[:limit]
    return [knn(pic, database, labels, K) for pic in test]

def acc(pred_labels, origin_labels):
    """Return the fraction of predictions that match the true labels.

    Args:
        pred_labels: Sequence of predicted labels.
        origin_labels: Sequence of true labels; must be at least as long as
            ``pred_labels``.  Entries beyond ``len(pred_labels)`` are
            ignored.

    Returns:
        The number of matching positions divided by ``len(pred_labels)``,
        or ``0.0`` when there are no predictions.

    Raises:
        IndexError: If ``origin_labels`` is shorter than ``pred_labels``.
    """
    total = len(pred_labels)
    if total == 0:
        # Nothing was predicted; avoid dividing by zero.
        return 0.0
    correct = sum(
        1 for i, pred in enumerate(pred_labels) if pred == origin_labels[i]
    )
    return correct / total


In [None]:
# Build the centered training set, classify the test images with 3-NN and
# report the accuracy against the true labels.
pics, labels = get_database()
pics = get_mean(pics)
res = predict(test_img_path, pics, labels, 3)
# Keep exactly as many true labels as there are predictions, rather than
# hard-coding the number of images predict classifies.
test_labels = load_labels(test_label_path)[:len(res)]
accuracy = acc(res, test_labels)
print(res)
print(test_labels)
print(accuracy)