# Recognize your own handwritten digits

# In [1]:
from PIL import Image
import sys
import time
import os
import joblib
import numpy as np
import matplotlib.pyplot as plt

def get_file_list(path):
    """Return the paths of all ``.png`` files directly inside ``path``.

    Args:
        path: Directory to scan. Sub-directories are not searched.

    Returns:
        A list of ``os.path.join(path, name)`` strings, ordered by file
        name. ``os.listdir`` yields entries in arbitrary order, so the names
        are sorted to keep the sample order reproducible between runs.

    Raises:
        OSError: If ``path`` cannot be listed (propagated from ``os.listdir``).
    """
    return [
        os.path.join(path, name)
        for name in sorted(os.listdir(path))
        if name.endswith(".png")
    ]
 
def get_img_name_str(imgPath):
    """Return the file-name component of ``imgPath`` (e.g. ``0_1.png``).

    ``os.path.basename`` is used instead of splitting on ``os.path.sep`` so
    that every separator the platform accepts is honoured (on Windows both
    ``\\`` and ``/``).

    Args:
        imgPath: Path of an image file.

    Returns:
        The part of ``imgPath`` after the last path separator; an empty
        string if the path ends with a separator.
    """
    return os.path.basename(imgPath)
 
 
def img2vector(imgFile):
    """Load an image and flatten it into a single binarized row vector.

    The image is converted to grayscale (PIL mode ``'L'``), every pixel is
    divided by 255 and rounded to the nearest integer, and the result is
    reshaped to one row.

    Args:
        imgFile: Image to open with ``Image.open``, e.g. ``0_1.png``.

    Returns:
        A 2-D numpy array of shape ``(1, height * width)``; a 28px * 28px
        image therefore yields a 1 * 784 vector.
    """
    grayscale = Image.open(imgFile).convert('L')
    pixels = np.array(grayscale, dtype='i')
    # Scale to [0, 1] and round, so each pixel ends up as 0.0 or 1.0.
    binarized = (pixels / 255).round()
    return binarized.reshape(1, -1)
 
 
def read_and_convert(imgFileList):
    """Build a feature matrix and a label list from a list of image files.

    Args:
        imgFileList: Sequence of image file paths. The class label of each
            image is the name of the directory that directly contains it,
            e.g. ``MNIST-data/train/3/3_1.png`` is labelled ``"3"``. A path
            without a parent directory component is labelled ``""``.

    Returns:
        A ``(dataMat, dataLabel)`` tuple. ``dataMat`` is a numpy array of
        shape ``(len(imgFileList), 784)`` holding one ``img2vector`` row per
        image; ``dataLabel`` is the list of label strings in the same order.

    Raises:
        ValueError: Raised by numpy when an image does not flatten to
            exactly 784 values (i.e. it is not 28px * 28px).
    """
    dataMat = np.zeros((len(imgFileList), 784))  # one 784-value row per image
    dataLabel = []
    for row, imgPath in enumerate(imgFileList):
        # The label is the parent directory name, not part of the file name.
        dataLabel.append(os.path.basename(os.path.dirname(imgPath)))
        dataMat[row, :] = img2vector(imgPath)
    return dataMat, dataLabel
 
def svmtest(model_path, image_dir="image/"):
    """Run a saved classifier on every ``.png`` image found in ``image_dir``.

    Prints the shape of the feature matrix, the time spent in ``predict``
    and the number of predictions made.

    Args:
        model_path: Path of the model file, loaded with ``joblib.load``.
        image_dir: Directory holding the images to recognize. Defaults to
            ``"image/"``.

    Returns:
        The value returned by ``clf.predict``: one predicted label per image,
        in the order produced by ``get_file_list``.
    """
    # NOTE(review): joblib.load unpickles the file; only load trusted models.
    clf = joblib.load(model_path)

    tflist = get_file_list(image_dir)
    tdataMat, tdataLabel = read_and_convert(tflist)
    print("test dataMat shape: {0}, test dataLabel len: {1} ".format(tdataMat.shape, len(tdataLabel)))

    pre_st = time.perf_counter()
    preResult = clf.predict(tdataMat)
    pre_et = time.perf_counter()
    print("Recognition  1 spent {:.4f}s.".format((pre_et - pre_st)))
    # NOTE: this prints how many predictions were made, not the labels
    # themselves; the labels are available through the return value.
    print("predict result: {}".format(len(preResult)))

    # No accuracy score is computed here: tdataLabel holds parent-directory
    # names, which for a flat folder such as "image/" are not ground-truth
    # digits, so a score against them would be meaningless.
    return preResult

 
if __name__ == "__main__":
    # Script entry point: recognize the images with the saved SVM model.
    model_path = "model/svm.model"
    svmtest(model_path)




# Output of the cell above:
# test dataMat shape: (1, 784), test dataLabel len: 1 
# Recognition  1 spent 0.0157s.
# predict result: 1
