# 손글씨 입력 및 예측 시스템

In [None]:
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from PIL import Image
import pandas as pd
import numpy as np
import random
import os
import gc
import cv2
import sys

import torch
import torchvision.transforms as T
import timm

import tensorflow as tf

from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QVBoxLayout, QFileDialog
from PyQt5.QtGui import QPainter, QPen, QPixmap
from PyQt5.QtCore import Qt

# Fix every random seed so repeated runs are reproducible.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cuDNN autotuning (benchmark=True) may select different kernels from run to
# run, which works against the seeding above; request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Prefer Apple's Metal backend (the original hard-coded target) and fall back
# to CUDA, then CPU, so the script also starts on machines without MPS.
_mps_backend = getattr(torch.backends, 'mps', None)
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class PredDataset(torch.utils.data.Dataset):
    """Dataset of the image files found directly under a directory.

    Each item is opened with PIL, converted to single-channel grayscale
    (mode "L") and passed through ``transforms``.

    Args:
        root_path: Directory holding the images to run prediction on.
        transforms: Callable applied to each grayscale PIL image; its return
            value is what ``__getitem__`` yields.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``root_path`` cannot be
            listed.
    """

    # Suffixes (compared case-insensitively) accepted as image files.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, root_path, transforms):
        self.transforms = transforms
        # os.listdir returns entries in arbitrary order and also lists hidden
        # files (e.g. ".DS_Store", "._name.png") and sub-directories. Sort for
        # a stable sample order and keep only regular, visible image files so
        # Image.open is never handed something it cannot decode.
        self.data = [
            os.path.join(root_path, name)
            for name in sorted(os.listdir(root_path))
            if not name.startswith('.')
            and name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_path, name))
        ]

    def __len__(self):
        """Return the number of image files collected at construction time."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load image ``idx`` as grayscale and return it transformed."""
        # convert() yields an independent in-memory image, so the underlying
        # file can be closed as soon as the with-block exits.
        with Image.open(self.data[idx]) as source:
            image = source.convert("L")
        return self.transforms(image)

class DrawBoard(QWidget):
    """280x280 black canvas for drawing white strokes with the left mouse button.

    The widget also hosts three buttons: ``Clear`` wipes the canvas, ``Save``
    writes the canvas to a file chosen by the user, and ``Predict`` classifies
    the images stored in ``IMAGE_DIR`` (not the live canvas) with the
    checkpoint at ``MODEL_PATH``.
    """

    # Checkpoint loaded by image_pred and the folder whose images it classifies.
    MODEL_PATH = '/Volumes/samsung ssd/handwriting_classification/resnet18_mnist_epochs_30_last_model.pth'
    IMAGE_DIR = '/Volumes/samsung ssd/handwriting_classification/saved_images'

    def __init__(self):
        super().__init__()
        # Previous cursor position of the stroke in progress; None between strokes.
        self.last_point = None
        self.initUI()

    def initUI(self):
        """Create the black canvas pixmap and lay out the three buttons."""
        self.setWindowTitle('Handwriting Input')
        self.setFixedSize(280, 280)
        self.image = QPixmap(self.size())
        self.image.fill(Qt.black)

        # Add the buttons.
        clear_btn = QPushButton('Clear', self)
        clear_btn.clicked.connect(self.clear)

        save_btn = QPushButton('Save', self)
        save_btn.clicked.connect(self.save_image)

        pred_btn = QPushButton('Predict', self)
        pred_btn.clicked.connect(self.image_pred)

        layout = QVBoxLayout()
        layout.addWidget(clear_btn)
        layout.addWidget(save_btn)
        layout.addWidget(pred_btn)
        self.setLayout(layout)

    def paintEvent(self, event):
        """Blit the canvas pixmap onto the widget."""
        canvasPainter = QPainter(self)
        canvasPainter.drawPixmap(self.rect(), self.image)

    def mousePressEvent(self, event):
        """Start a stroke and draw a dot so a plain click leaves a mark."""
        if event.button() == Qt.LeftButton:
            self.last_point = event.pos()
            self._draw_segment(self.last_point, self.last_point)

    def mouseMoveEvent(self, event):
        """Extend the current stroke to the cursor position."""
        # Mask instead of ==, so drawing continues if another button is also held.
        if event.buttons() & Qt.LeftButton:
            current = event.pos()
            start = self.last_point if self.last_point is not None else current
            # Connect consecutive positions with a line; isolated points leave
            # gaps whenever the cursor moves faster than events are delivered.
            self._draw_segment(start, current)
            self.last_point = current

    def mouseReleaseEvent(self, event):
        """Finish the current stroke."""
        if event.button() == Qt.LeftButton:
            self.last_point = None

    def _draw_segment(self, start, end):
        """Paint a 20px white segment from ``start`` to ``end`` on the canvas."""
        painter = QPainter(self.image)
        painter.setPen(QPen(Qt.white, 20, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin))
        if start == end:
            painter.drawPoint(end)
        else:
            painter.drawLine(start, end)
        # Release the pixmap explicitly before the widget repaints from it.
        painter.end()
        self.update()

    def clear(self):
        """Reset the canvas to solid black."""
        self.image.fill(Qt.black)
        self.last_point = None
        self.update()

    def save_image(self):
        """Ask for a destination file and write the canvas to it."""
        # Open the save-file dialog.
        options = QFileDialog.Options()
        file_name, _ = QFileDialog.getSaveFileName(self, "Save Image", "", "PNG Files (*.png);;All Files (*)", options=options)

        if not file_name:
            return

        # QPixmap.save picks the format from the suffix; default to PNG when
        # the user typed a bare name so the save does not fail for that reason.
        if not os.path.splitext(file_name)[1]:
            file_name += '.png'

        # save() reports failure through its return value, not an exception.
        if self.image.save(file_name):
            print(f"이미지가 저장되었습니다: {file_name}")
        else:
            print(f"이미지 저장에 실패했습니다: {file_name}")

    def predict(self, model, pred_dataloader):
        """Run ``model`` over ``pred_dataloader`` and collect its top-1 results.

        Args:
            model: Module whose output is softmaxed over dimension 1.
            pred_dataloader: Iterable yielding batches of image tensors.

        Returns:
            Tuple ``(total_preds, total_max_probs)``: two lists with one 0-d
            tensor per image, holding the arg-max class index and its softmax
            probability respectively.
        """
        model.eval()
        # Feed batches on whatever device the model already lives on instead
        # of assuming a fixed backend.
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device('cpu')

        total_preds = []
        total_max_probs = []
        with torch.no_grad():
            for images in tqdm(pred_dataloader):
                outputs = model(images.to(model_device))
                probs = torch.nn.functional.softmax(outputs, dim=1)
                max_prob, preds = torch.max(probs, 1)
                total_max_probs.extend(max_prob)
                total_preds.extend(preds)

        gc.collect()

        print(total_preds)
        print(total_max_probs)

        return total_preds, total_max_probs

    @staticmethod
    def _pick_device():
        """Return 'mps' when available, otherwise 'cuda', otherwise 'cpu'."""
        mps_backend = getattr(torch.backends, 'mps', None)
        if mps_backend is not None and mps_backend.is_available():
            return 'mps'
        if torch.cuda.is_available():
            return 'cuda'
        return 'cpu'

    def image_pred(self):
        """Classify every image in ``IMAGE_DIR`` and show each result with OpenCV.

        One window per image is opened, showing the image upscaled to 280x280
        with the predicted class and its probability drawn on top. The call
        then waits in ``cv2.waitKey(0)`` before closing the windows.
        """
        transforms = T.Compose([T.Resize((28, 28)), T.ToTensor()])
        pred_dataset = PredDataset(self.IMAGE_DIR, transforms)
        if len(pred_dataset) == 0:
            # Nothing to classify: skip loading the model and opening windows.
            print("예측할 이미지가 없습니다.")
            return

        target_device = self._pick_device()
        model = timm.create_model('resnet18', in_chans=1, num_classes=10, pretrained=False)
        # map_location lets a checkpoint saved on another backend load here.
        model.load_state_dict(torch.load(self.MODEL_PATH, map_location=target_device))
        model.to(target_device)

        pred_dataloader = torch.utils.data.DataLoader(pred_dataset, batch_size=32, shuffle=False)
        total_preds, total_max_probs = self.predict(model, pred_dataloader)

        for i, (pred, prob) in enumerate(zip(total_preds, total_max_probs)):
            # Tensor is (C, H, W) in [0, 1] after ToTensor; make it H x W x C uint8.
            image = pred_dataset[i].permute(1, 2, 0).numpy() * 255
            image = image.astype(np.uint8)

            # Upscale for display; the image stays grayscale.
            image = cv2.resize(image, (280, 280), interpolation=cv2.INTER_NEAREST)

            # Overlay the prediction result.
            cv2.putText(image, f"Prediction: {pred.item()}", (10, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 1)
            cv2.putText(image, f"Prob: {prob.item():.2f}", (10, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 1)
            cv2.imshow(f'Prediction {i + 1}', image)

        cv2.waitKey(0)
        cv2.destroyAllWindows()

# Reuse the running QApplication if there is one (e.g. when this cell is
# executed again in the same kernel); Qt allows only one instance per process.
app = QApplication.instance()
if app is None:
    app = QApplication(sys.argv)
window = DrawBoard()
window.show()
exit_code = app.exec_()

# Release the Qt objects before exiting: sys.exit() raises SystemExit, so any
# cleanup placed after it would never run.
del window, app
gc.collect()
sys.exit(exit_code)

  0%|          | 0/1 [00:00<?, ?it/s]

[tensor(8, device='mps:0'), tensor(9, device='mps:0'), tensor(4, device='mps:0'), tensor(5, device='mps:0'), tensor(0, device='mps:0'), tensor(7, device='mps:0'), tensor(6, device='mps:0'), tensor(9, device='mps:0'), tensor(2, device='mps:0'), tensor(3, device='mps:0'), tensor(0, device='mps:0'), tensor(1, device='mps:0'), tensor(0, device='mps:0'), tensor(0, device='mps:0')]
[tensor(0.9903, device='mps:0'), tensor(1.0000, device='mps:0'), tensor(0.8176, device='mps:0'), tensor(1., device='mps:0'), tensor(1., device='mps:0'), tensor(1.0000, device='mps:0'), tensor(1.0000, device='mps:0'), tensor(0.9962, device='mps:0'), tensor(1.0000, device='mps:0'), tensor(0.9999, device='mps:0'), tensor(1.0000, device='mps:0'), tensor(1.0000, device='mps:0'), tensor(1., device='mps:0'), tensor(0.9971, device='mps:0')]
