# In [6]:
# Imported libraries
import tkinter as tk
from tkinter import Canvas, Button, Entry
from PIL import Image, ImageDraw, ImageOps
import os
import torch
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

# In [8]:
# Neural network model definition.
class neural_network_model_try(nn.Module):
    """Small CNN that maps single-channel images to log-probabilities over 10 digits.

    fc1 takes 20*10*10 features. Given the 5x5 conv, 2x2 max-pool and 3x3
    conv, that corresponds to a (batch, 1, 28, 28) input.

    The attribute names cv1/cv2/fc1/fc2 become the state_dict keys, so they
    must match the checkpoint passed to ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order (conv, conv, linear, linear) so
        # seeded random initialisation is reproducible.
        self.cv1 = nn.Conv2d(1, 10, 5)           # 1 -> 10 channels, 5x5 kernel
        self.cv2 = nn.Conv2d(10, 20, 3)          # 10 -> 20 channels, 3x3 kernel
        self.fc1 = nn.Linear(20 * 10 * 10, 500)  # flattened 20x10x10 maps -> 500 hidden units
        self.fc2 = nn.Linear(500, 10)            # 500 hidden units -> 10 class scores (digits 0-9)

    def forward(self, x):
        """Return log-softmax class scores of shape (batch, 10) for input batch *x*."""
        batch = x.size(0)
        # conv -> ReLU -> 2x2 max-pool, then conv -> ReLU
        features = F.max_pool2d(F.relu(self.cv1(x)), 2, 2)
        features = F.relu(self.cv2(features))
        # Flatten each sample's feature maps into one vector for the dense layers.
        hidden = F.relu(self.fc1(features.view(batch, -1)))
        return F.log_softmax(self.fc2(hidden), dim=1)

# Create the model instance and load the trained weights.
# Fall back to the CPU when CUDA is unavailable instead of failing at start-up.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = neural_network_model_try().to(device)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load trusted checkpoints
# (weights_only=True restricts this on torch >= 1.13).
model.load_state_dict(torch.load('digit_recognizer.pth', map_location=device))
model.eval()  # switch the model to evaluation (inference) mode

# Graphical user interface.
class DigitRecognizerApp:
    """Tk drawing pad that classifies hand-drawn digits and saves labelled samples.

    Every stroke is drawn twice: once on the visible Tk ``Canvas`` and once on
    an off-screen 280x280 grayscale PIL image (``self.image``). The PIL image
    is what ``recognize_digit`` classifies and ``save_digit`` writes to disk.
    """

    def __init__(self, root, model):
        """Create the widgets and bind the mouse handlers.

        Args:
            root: Parent Tk widget, normally the ``tk.Tk()`` main window.
            model: Trained network. It is called on a (1, 1, 28, 28) float
                tensor, and its output is argmax-ed along dim 1 to get the digit.
        """
        self.root = root
        self.model = model
        # Widgets, packed top to bottom.
        self.canvas = Canvas(self.root, width=280, height=280, background="white")
        self.canvas.pack()
        self.clear_button = Button(self.root, text="Clear", command=self.clear_canvas)
        self.clear_button.pack()
        self.recognize_button = Button(self.root, text="Recognize", command=self.recognize_digit)
        self.recognize_button.pack()
        self.save_button = Button(self.root, text="Save", command=self.save_digit)
        self.save_button.pack()
        self.result_label = tk.Label(self.root, text="Draw a digit and click 'Recognize'", font=("Helvetica", 16))
        self.result_label.pack()
        # Entry box holding the label that save_digit puts in the file name.
        self.label_entry = Entry(self.root)
        self.label_entry.pack()
        # Previous pointer position; each motion event draws a segment from here.
        self.last_x = None
        self.last_y = None
        self.drawing = False
        # Off-screen copy of the drawing: same size as the canvas, white background.
        self.image = Image.new("L", (280, 280), "white")
        self.drawer = ImageDraw.Draw(self.image)
        self.canvas.bind("<Button-1>", self.start_draw)
        self.canvas.bind("<B1-Motion>", self.draw)

    def clear_canvas(self):
        """Erase the canvas and the off-screen image, and reset the result label."""
        self.canvas.delete("all")
        self.result_label.config(text="Draw a digit and click 'Recognize'")
        self.image = Image.new("L", (280, 280), "white")
        self.drawer = ImageDraw.Draw(self.image)

    def recognize_digit(self):
        """Classify the current drawing and show the predicted digit.

        The 280x280 drawing is downscaled to 28x28, the input size that
        ``neural_network_model_try`` is built for (fc1 expects 20*10*10
        features). Feeding the full-size image fails with a shape mismatch in fc1.
        """
        # NOTE(review): the inversion assumes the weights were trained on
        # MNIST-style white-on-black digits (ImageOps was imported for this
        # but never used). If training used the black-on-white images from
        # save_digit, drop the invert. If training applied
        # transforms.Normalize, apply the same normalization here. Confirm
        # against the training script.
        small = ImageOps.invert(self.image).resize((28, 28), Image.LANCZOS)
        tensor_image = transforms.ToTensor()(small).unsqueeze(0)
        # Run on whatever device the model's weights live on, rather than a
        # module-level global that may not match the model passed in.
        first_param = next(self.model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")
        with torch.no_grad():
            output = self.model(tensor_image.to(model_device))
        predicted_digit = output.argmax(dim=1).item()
        self.result_label.config(text=f"Predicted Digit: {predicted_digit}")

    def start_draw(self, event):
        """Mouse-button press: remember where the stroke starts."""
        self.last_x = event.x
        self.last_y = event.y
        self.drawing = True

    def draw(self, event):
        """Mouse drag: draw a segment from the previous point on both surfaces."""
        if self.drawing:
            x, y = event.x, event.y
            if self.last_x is not None and self.last_y is not None:
                self.canvas.create_line((self.last_x, self.last_y, x, y), fill="black", width=8)
                self.drawer.line((self.last_x, self.last_y, x, y), fill="black", width=8)
            self.last_x = x
            self.last_y = y

    def save_digit(self, save_dir='digit_images'):
        """Save the drawing as ``<index>_<label>.png`` inside *save_dir*.

        The label is read from the entry box and must parse as an integer.
        ``<index>`` is the smallest non-negative integer not already used as
        the numeric prefix of a ``.png`` in *save_dir*. Nothing happens when
        the entry is empty. A non-integer label is reported in the result
        label instead of raising inside the Tk callback.

        Args:
            save_dir: Output directory, created if missing. Defaults to
                ``'digit_images'``.
        """
        raw_label = self.label_entry.get()
        if not raw_label:
            return
        try:
            label = int(raw_label)
        except ValueError:
            self.result_label.config(text=f"Invalid label: {raw_label!r}")
            return
        os.makedirs(save_dir, exist_ok=True)
        existing_numbers = set()
        for filename in os.listdir(save_dir):
            if not filename.endswith('.png'):
                continue
            # Ignore PNGs that don't follow the <index>_<label>.png pattern
            # instead of crashing on them.
            try:
                existing_numbers.add(int(filename.split('_')[0]))
            except ValueError:
                continue
        number = 0
        while number in existing_numbers:
            number += 1
        path = os.path.join(save_dir, f'{number}_{label}.png')
        self.image.save(path)
        self.result_label.config(text=f"Saved {path}")

# Application entry point: build the main window, attach the app, and hand
# control to Tk's event loop (blocks until the window is closed).
if __name__ == "__main__":
    root = tk.Tk()
    app = DigitRecognizerApp(root=root, model=model)
    root.mainloop()