# 测试MNIST数据的手写数字 Test handwritten digits on MNIST data

# In [None]:
import sys
import time
from PIL import Image
import os
import joblib
import numpy as np
import matplotlib.pyplot as plt
#忽略警告 Ignore warnings
import warnings
warnings.filterwarnings('ignore')

# 获取指定路径下的所有 .png 文件 Get all .png files in the specified path
def get_file_list(path):
    """Return the paths of all ``.png`` files directly inside *path*.

    Args:
        path: Directory to scan. Sub-directories are not searched.

    Returns:
        A sorted list of ``os.path.join(path, name)`` for every directory
        entry whose name ends with ".png" (case-sensitive match).
    """
    # os.listdir() yields entries in arbitrary order; sorting makes the
    # result reproducible from run to run and across platforms.
    return sorted(
        os.path.join(path, name)
        for name in os.listdir(path)
        if name.endswith(".png")
    )

# 解析出 .png 图片文件的名称 Parse the name of the .png image file
def get_img_name_str(imgPath):
    """Return the file-name component of *imgPath*, e.g. ``"0_1.png"``.

    Args:
        imgPath: Path of an image file.

    Returns:
        The final path component (everything after the last separator).
    """
    # os.path.basename understands every separator the platform accepts
    # (both "\\" and "/" on Windows), whereas splitting on os.path.sep only
    # handles one of them.
    return os.path.basename(imgPath)
 
# 将 28px * 28px 的图像数据转换成 1*784 的 numpy 向量 Convert 28px * 28px image data into 1*784 numpy vector
# 参数：imgFile--图像名  如：0_1.png Parameter: imgFile--image name, such as: 0_1.png
# 返回：1*784 的 numpy 向量 Returns: 1*784 numpy vector
def img2vector(imgFile):
    """Load an image and flatten it into a binarized 1 x N row vector.

    The image is converted to grayscale ('L' mode), every pixel value is
    divided by 255 and rounded, so each element of the result is 0.0 or 1.0.
    For a 28 x 28 image N is 784.

    Args:
        imgFile: Path of the image file, e.g. ``0_1.png``.

    Returns:
        A numpy array of shape ``(1, height * width)``.
    """
    # The context manager releases the underlying file handle even when
    # decoding fails part-way. The pixel data is copied into the numpy array
    # inside the block, so it stays valid after the file is closed.
    with Image.open(imgFile) as img:
        img_arr = np.array(img.convert('L'), 'i')  # grayscale pixel values
    img_normalization = np.round(img_arr / 255)  # binarize to 0 / 1
    return np.reshape(img_normalization, (1, -1))  # flatten into a single row

# Read a list of image files and convert them into a data matrix plus labels
# Parameter:
#    imgFileList: list of image file paths, e.g. MNIST_data/test/0/0_1.png
#       (the name of the directory containing each file is used as its class label)
# Returns: (dataMat, dataLabel)
#    dataMat: [number of samples * 784] matrix, one flattened image per row
#    dataLabel: list of class labels (e.g. '0', '1', ..., '9'), one per sample
def read_and_convert(imgFileList):
    """Load every image in *imgFileList* into one matrix and collect labels.

    The class label of an image is the name of the directory that directly
    contains it (``.../<label>/<file>.png``).

    Args:
        imgFileList: Sequence of image file paths.

    Returns:
        A ``(dataMat, dataLabel)`` tuple. ``dataMat`` is a
        ``len(imgFileList) x 784`` numpy array holding one flattened image
        per row; ``dataLabel`` is a list with the label string of each row,
        in the same order.
    """
    dataNum = len(imgFileList)
    dataMat = np.zeros((dataNum, 784))  # one row per image
    dataLabel = []  # class label of each row
    for i, imgPath in enumerate(imgFileList):
        # Take the parent directory name via os.path instead of splitting on
        # os.path.sep: splitting picks the wrong component when the path
        # mixes "/" and "\\" (Windows) or contains doubled separators.
        classTag = os.path.basename(os.path.dirname(imgPath))
        dataLabel.append(classTag)
        dataMat[i, :] = img2vector(imgPath)
    return dataMat, dataLabel

def svmtest(model_path):
    """Evaluate a saved classifier on the MNIST test images, class by class.

    For each digit directory under ``MNIST_data/test/`` the images are
    loaded and classified with the model stored at *model_path*; the
    prediction time, error count, score and error rate are printed. After
    all classes are processed, the overall totals and a per-digit table are
    printed and the correct / incorrect counts are plotted with matplotlib.

    Args:
        model_path: Path of the model file, loaded with ``joblib.load``.

    Returns:
        None. Results are printed and shown in a matplotlib figure.
    """
    tbasePath = "MNIST_data/test/"
    tcName = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    tst = time.perf_counter()
    allErrCount = 0
    allScore = 0.0
    ErrCount = np.zeros(10, int)   # misclassified samples per digit
    TrueCount = np.zeros(10, int)  # correctly classified samples per digit

    # NOTE: joblib.load unpickles the file, so only load model files that
    # come from a trusted source.
    clf = joblib.load(model_path)

    for tcn in tcName:
        testPath = os.path.join(tbasePath, tcn)
        tflist = get_file_list(testPath)
        tdataMat, tdataLabel = read_and_convert(tflist)
        print("test dataMat shape: {0}, test dataLabel len: {1} ".format(tdataMat.shape, len(tdataLabel)))

        pre_st = time.perf_counter()
        preResult = clf.predict(tdataMat)
        pre_et = time.perf_counter()
        print("Recognition  " + tcn + " spent {:.4f}s.".format((pre_et - pre_st)))

        # Every sample in this directory belongs to class `tcn`, so any
        # other prediction is an error.
        errCount = sum(1 for pred in preResult if pred != tcn)
        ErrCount[int(tcn)] = errCount
        TrueCount[int(tcn)] = len(tdataLabel) - errCount
        print("errorCount: {}.".format(errCount))
        allErrCount += errCount

        score_st = time.perf_counter()
        score = clf.score(tdataMat, tdataLabel)
        score_et = time.perf_counter()
        print("computing score spent {:.6f}s.".format(score_et - score_st))
        allScore += score
        print("score: {:.6f}.".format(score))
        print("error rate is {:.6f}.".format((1 - score)))

    # The end timestamp must come from the same clock as `tst`
    # (perf_counter); process_time() has an unrelated reference point, so
    # subtracting the two gives a meaningless duration.
    tet = time.perf_counter()
    print("Testing All class total spent {:.6f}s.".format(tet - tst))
    print("All error Count is: {}.".format(allErrCount))
    # Unweighted mean of the per-class scores.
    avgAccuracy = allScore / len(tcName)
    print("Average accuracy is: {:.6f}.".format(avgAccuracy))
    print("Average error rate is: {:.6f}.".format(1 - avgAccuracy))

    print("number"," TrueCount"," ErrCount")
    for tcn in tcName:
        tcn = int(tcn)
        print(tcn,"     ",TrueCount[tcn],"      ",ErrCount[tcn])

    plt.figure(figsize=(12, 6))
    x = list(range(10))
    plt.plot(x, TrueCount, color='green', label="TrueCount")  # correct predictions in green
    plt.plot(x, ErrCount, color='red', label="ErrCount")      # wrong predictions in red
    plt.legend(loc='best')  # let matplotlib choose the legend position
    plt.title('Projects')
    plt.xlabel('number')    # x-axis label
    plt.ylabel('count')     # y-axis label
    plt.xticks(np.arange(10), ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
    plt.show()
 
 
if __name__ == "__main__":
    # Script entry point: evaluate the saved SVM model on the test set.
    model_path = 'model/svm.model'
    svmtest(model_path=model_path)
