In [None]:
import os
import cv2 as cv
import sys

import tensorflow as tf
from tensorflow.keras import Sequential
import numpy as np
from PIL import Image
import imageio.v2 as imageio

# Let a second copy of the OpenMP runtime load instead of aborting the process
# (KMP_* variables are read by Intel's OpenMP runtime, libiomp).
# NOTE(review): this runs after cv2/tensorflow are already imported; if the
# duplicate-runtime abort still happens, move this line above those imports.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')


# Loaded Keras models keyed by file path, so each model is read from disk once.
_MODEL_CACHE = {}


def _get_model(model_path):
    """Return the Keras model saved at *model_path*, loading it at most once."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = tf.keras.models.load_model(model_path)
        _MODEL_CACHE[model_path] = model
    return model


def predict_local_image(image_path, model_path="mnist_model_optimized_v9.keras"):
    """Classify the digit in the image at *image_path* and print the result.

    The image is read as grayscale, resized to 28x28, colour-inverted,
    scaled to [0, 1], given a leading batch axis and passed to the Keras
    model stored at *model_path*. The model is cached after the first load,
    so repeated calls only pay for inference.

    Args:
        image_path: Path of the image file to classify.
        model_path: Path of the saved Keras model. Defaults to the file this
            script has always used.

    Returns:
        The predicted digit (index of the highest model output) as an ``int``,
        or ``None`` if the image could not be read.
    """
    # Read the image before touching the model so that a missing or
    # unreadable file fails fast without paying for a model load.
    try:
        img = imageio.imread(image_path, pilmode='L')  # load and convert to grayscale
    except (OSError, ValueError):
        # OSError covers FileNotFoundError as well as undecodable files;
        # imageio may raise ValueError when no backend can read the file.
        print("Error: Image not found or could not be read.")
        return None

    # Resize to 28x28.
    img = np.array(Image.fromarray(img).resize((28, 28)))
    # Invert colours (intended for black strokes on a white background).
    img = 255 - img
    # Normalise pixel values to [0, 1].
    img = img.astype('float32') / 255.0
    # Add a batch dimension -> (1, 28, 28).
    # NOTE(review): assumes the model accepts inputs without a channel axis — confirm.
    img = np.expand_dims(img, axis=0)

    prediction = _get_model(model_path).predict(img)
    predicted_digit = int(np.argmax(prediction))
    print(f"== RESULT:Predicted digit for the image: {predicted_digit}")
    return predicted_digit


def draw(event, x, y, flags, param):
    """OpenCV mouse callback: draw with the left button, classify with the right.

    Dragging with the left button held draws a thick black stroke on the
    global ``img`` and refreshes the "image" window. Releasing the right
    button writes the canvas to ``image.png`` and, if the write succeeded,
    runs ``predict_local_image`` on that file.

    Parameters follow the ``cv.setMouseCallback`` callback signature;
    ``param`` is unused.
    """
    global img, pre_pts

    if event == cv.EVENT_RBUTTONDOWN:
        print("mouse right button pressed")

    if event == cv.EVENT_LBUTTONDOWN:
        print("mouse left button pressed")
        # Start a new stroke at the press position.
        pre_pts = (x, y)

    # `flags` is a bitmask: test only the left-button bit so drawing still
    # works while a modifier key (Ctrl/Shift/Alt) is held.
    if event == cv.EVENT_MOUSEMOVE and flags & cv.EVENT_FLAG_LBUTTON:
        pts = (x, y)
        if pre_pts == (-1, -1):
            # No stroke in progress (e.g. the button was pressed outside the
            # window): start here rather than drawing from the placeholder.
            pre_pts = pts
        img = cv.line(img, pre_pts, pts, (0, 0, 0), thickness=10,
                      lineType=cv.LINE_8, shift=0)
        pre_pts = pts
        cv.imshow("image", img)

    if event == cv.EVENT_LBUTTONUP:
        print("mouse left button released")
        # End the stroke so a later drag cannot join onto it.
        pre_pts = (-1, -1)

    if event == cv.EVENT_RBUTTONUP:
        print("okok...")

        if cv.imwrite("image.png", img):
            # NOTE(review): an earlier variant passed the saved image through
            # utils.main() and predicted on the path it returned.
            predict_local_image("image.png")


if __name__ == '__main__':
    print("hello world")

    # Background canvas the user draws on; draw() reads and updates this global.
    img = cv.imread('bc_image.jpg')

    if img is None:
        # Report on stderr and exit non-zero so shells/callers see the failure.
        print("Failed to load image", file=sys.stderr)
        sys.exit(1)

    # Placeholder for the previous stroke point; draw() sets it on left-button press.
    pre_pts = -1, -1
    cv.imshow("image", img)
    cv.setMouseCallback("image", draw)
    # Block until any key is pressed; mouse callbacks are dispatched meanwhile.
    cv.waitKey(0)
    cv.destroyAllWindows()
