In [1]:
import random
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PyQt5.QtCore import Qt, QPoint
from PyQt5.QtGui import QPixmap, QPainter, QPen, QFont, QImage
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QLabel, QVBoxLayout, QHBoxLayout, QFrame, QGridLayout
# モデルクラスの定義
class CNN_LSTM(nn.Module):
    """Convolutional feature extractor followed by a 3-layer LSTM and a 10-way head.

    Three conv stages (1 -> 16 -> 32 -> 64 channels) are applied, the result is
    flattened, regrouped into a sequence of 64-value steps, run through the
    LSTM, and the last time step is mapped to 10 output scores by two linear
    layers. No softmax is applied; the forward pass returns raw scores.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one Conv(5x5, padding 2) -> BatchNorm -> ReLU -> MaxPool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # Pooling window is 1x2 but the stride is 2 along both axes.
            nn.MaxPool2d(kernel_size=(1, 2), stride=2),
        )

    def __init__(self):
        super().__init__()

        # Attribute names and the order of sub-modules inside each stage
        # determine the state_dict keys, so they are kept as-is.
        self.layer1 = self._conv_stage(1, 16)
        self.layer2 = self._conv_stage(16, 32)
        self.layer3 = self._conv_stage(32, 64)
        self.layer4 = nn.Flatten()

        self.lstm = nn.LSTM(input_size=64, hidden_size=256, num_layers=3, batch_first=True)

        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return a (batch, 10) tensor of class scores for a single-channel image batch.

        The flattened feature length must be divisible by 64 for the regrouping
        below to succeed.
        """
        features = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        # Consecutive runs of 64 flattened values become the LSTM time steps.
        sequence = features.view(features.size(0), -1, 64)
        hidden_states, _ = self.lstm(sequence)
        last_step = hidden_states[:, -1, :]

        # The two linear layers are applied back to back, with no activation between them.
        return self.fc2(self.fc1(last_step))


# Load the trained model for CPU inference.
model = CNN_LSTM()
# Switch to inference mode: without this the BatchNorm2d layers normalise with
# the statistics of each single drawn image and keep updating their running
# averages on every prediction (torch.no_grad() does not prevent that).
# Done before load_state_dict so that call stays the last expression of the
# cell and its key-matching report remains the cell output.
model.eval()
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load("cnn_lstm_mnist.pth", map_location=torch.device("cpu")))

<All keys matched successfully>

In [2]:
def preprocess_image(image):
    """Resize an image array to 28x28 and return it as a tensor with a leading batch axis.

    The array is resized with OpenCV's default interpolation, converted with
    torchvision's ``ToTensor`` and given an extra dimension at position 0.
    """
    to_tensor = transforms.ToTensor()
    resized = cv2.resize(image, (28, 28))
    return to_tensor(resized).unsqueeze(0)

class PaintWidget(QFrame):
    """Fixed 300x300 drawing canvas: white strokes on a black background.

    Strokes are painted into an off-screen ``QPixmap`` which is blitted to the
    widget on every paint event and can be exported with ``get_image()``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        self.setFixedSize(300, 300)
        self.pixmap = QPixmap(self.size())
        self.pixmap.fill(Qt.black)
        # Previous cursor position; each stroke segment starts here.
        self.last_point = QPoint()

    def paintEvent(self, event):
        """Draw the off-screen pixmap onto the widget."""
        painter = QPainter(self)
        painter.drawPixmap(self.rect(), self.pixmap)

    def mousePressEvent(self, event):
        """Remember where a left-button stroke starts."""
        if event.button() == Qt.LeftButton:
            self.last_point = event.pos()

    def mouseMoveEvent(self, event):
        """Extend the current stroke while the left button is held down."""
        # buttons() is a bitmask; test the left-button bit instead of comparing
        # for equality so drawing continues when another button is also held.
        if event.buttons() & Qt.LeftButton:
            painter = QPainter(self.pixmap)
            pen = QPen(Qt.white, 20, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin)
            painter.setPen(pen)
            painter.drawLine(self.last_point, event.pos())
            # Release the pixmap explicitly rather than relying on garbage collection.
            painter.end()
            self.last_point = event.pos()
            self.update()

    def clear(self):
        """Erase the canvas and, if the parent exposes ``result_label``, reset it."""
        self.pixmap.fill(Qt.black)
        self.update()
        # The widget may have no parent (the constructor default) or a parent
        # without a result label; only touch the label when it exists.
        result_label = getattr(self.parent(), "result_label", None)
        if result_label is not None:
            result_label.setText("結果: ")

    def get_image(self):
        """Return the current canvas contents as a ``QImage``."""
        return self.pixmap.toImage()

In [None]:
class DigitRecognizerApp(QWidget):
    """Main window of the handwritten-digit quiz.

    Shows a random target digit, a ``PaintWidget`` to draw on, and three
    buttons: answer (run the model on the drawing), clear, and next question.
    Relies on the module-level ``model`` and ``preprocess_image``.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("手書き数字認識アプリ")
        self.setFixedSize(1024, 768)

        # Digit the user is asked to draw.
        self.current_question = random.randint(0, 9)
        self.title_label = QLabel("回答ボタンで開始")
        self.title_label.setFont(QFont("Arial", 24, QFont.Bold))
        self.title_label.setAlignment(Qt.AlignCenter)

        self.subtitle_label = QLabel("AIによる手書き数字認識")
        self.subtitle_label.setFont(QFont("Arial", 18))
        self.subtitle_label.setAlignment(Qt.AlignCenter)

        self.question_label = QLabel(f"{self.current_question}")
        self.question_label.setFont(QFont("Arial", 48, QFont.Bold))
        self.question_label.setAlignment(Qt.AlignCenter)
        self.question_label.setFixedSize(150, 150)

        self.paint_widget = PaintWidget(self)
        self.result_label = QLabel("結果: 　")
        self.result_label.setFont(QFont("Arial", 24, QFont.Bold))
        self.result_label.setAlignment(Qt.AlignCenter)

        self.recognize_button = QPushButton("回答")
        self.clear_button = QPushButton("クリア")
        self.next_button = QPushButton("次へ")

        self.recognize_button.clicked.connect(self.recognize_digit)
        self.clear_button.clicked.connect(self.paint_widget.clear)
        self.next_button.clicked.connect(self.generate_new_question)

        instructions_label = QLabel("使い方:\n・回答を始める場合は「回答」を押してください。\n・文字は黒い部分に書いてください。")
        instructions_label.setFont(QFont("Arial", 14))
        instructions_label.setAlignment(Qt.AlignCenter)

        # Grid: row 0 title, row 1 question + subtitle, row 2 canvas + result,
        # row 3 buttons, row 4 usage instructions.
        grid_layout = QGridLayout()
        grid_layout.addWidget(self.title_label, 0, 0, 1, 3, Qt.AlignCenter)
        grid_layout.addWidget(self.question_label, 1, 0, Qt.AlignCenter)
        grid_layout.addWidget(self.paint_widget, 2, 0, Qt.AlignCenter)
        grid_layout.addWidget(self.recognize_button, 3, 0)
        grid_layout.addWidget(self.clear_button, 3, 1)
        grid_layout.addWidget(self.next_button, 3, 2)
        grid_layout.addWidget(self.subtitle_label, 1, 1, 1, 2, Qt.AlignCenter)
        grid_layout.addWidget(self.result_label, 2, 1, 1, 2, Qt.AlignCenter)
        grid_layout.addWidget(instructions_label, 4, 0, 1, 3, Qt.AlignCenter)

        self.setLayout(grid_layout)

    def _canvas_to_gray(self):
        """Return the canvas contents as an (height, width) uint8 grayscale array."""
        # Pin the pixel format so the buffer is guaranteed to hold 4 bytes per
        # pixel; the format returned by QPixmap.toImage() is otherwise
        # platform-dependent.
        image = self.paint_widget.get_image().convertToFormat(QImage.Format_RGB32)
        buffer = image.bits().asarray(image.byteCount())
        # np.array copies the bytes, so the result does not depend on `image` staying alive.
        img_array = np.array(buffer, dtype=np.uint8).reshape(image.height(), image.width(), 4)
        # Format_RGB32 stores each pixel as 0xffRRGGBB, i.e. bytes B, G, R, 0xff
        # in memory on little-endian hosts, hence BGRA rather than RGBA.
        return cv2.cvtColor(img_array, cv2.COLOR_BGRA2GRAY)

    def recognize_digit(self):
        """Classify the current drawing and show whether it matches the question."""
        # Reuse the shared helper (resize to 28x28, to tensor, add batch axis).
        img_tensor = preprocess_image(self._canvas_to_gray())

        with torch.no_grad():
            output = model(img_tensor)
            predicted = torch.argmax(output, 1).item()

        if predicted == self.current_question:
            self.result_label.setText(f"推論結果: {predicted} (正解！)")
        else:
            self.result_label.setText(f"推論結果: {predicted} (失敗)")

    def generate_new_question(self):
        """Pick a new random target digit, display it and reset the canvas."""
        self.current_question = random.randint(0, 9)
        self.question_label.setText(f"{self.current_question}")
        self.paint_widget.clear()






: 

In [None]:
if __name__ == "__main__":
    # Qt allows only one QApplication per process. When this cell is re-run in
    # the same kernel an instance already exists, so reuse it instead of
    # constructing a second one.
    app = QApplication.instance() or QApplication(sys.argv)
    window = DigitRecognizerApp()
    window.show()

    try:
        # exec_() blocks until the window is closed; sys.exit then raises
        # SystemExit, which is caught so the cell ends with a message instead
        # of propagating the exit to the notebook kernel.
        sys.exit(app.exec_())
    except SystemExit:
        print("アプリを終了しました")

