In [1]:
import sys
import numpy as np
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QVBoxLayout, QLabel, QFileDialog
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtCore import Qt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping

# Load MNIST data
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Preprocess MNIST data
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255.0
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255.0
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

# Split into training and validation sets
train_images, val_images, train_labels, val_labels = train_test_split(train_images, train_labels, test_size=0.2, random_state=42)

# Model definition
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model
model.fit(train_images, train_labels, epochs=10, validation_data=(val_images, val_labels), batch_size = 32)

# Save the model
# model.save('/Users/behnam/python-projects/Neural Network/Computer Vision/Image Classification/Model_Fitting')


2024-01-25 19:29:27.797642: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x169f75d30>

In [2]:
# Image Uploader Widget for inference
class ImageUploader(QWidget):
    def __init__(self):
        super().__init__()
        self.model = tf.keras.models.load_model('/Users/behnam/python-projects/Neural Network/Computer Vision/Image Classification/Model_Fitting')

        # GUI elements
        self.upload_button = QPushButton("Upload Image", clicked=self.upload_image)
        self.img_label = QLabel()
        self.result_label = QLabel("Prediction: ")
        
        # Layout setup
        layout = QVBoxLayout()
        layout.addWidget(self.upload_button)
        layout.addWidget(self.img_label)
        layout.addWidget(self.result_label)
        self.setLayout(layout)

    def preprocess_image(self, file_path):
        # Image loading and preprocessing to match MNIST data
        img = Image.open(file_path).convert('L')  # Convert to grayscale
        img_resized = img.resize((28, 28))  # Resize to match MNIST
        img_array = np.array(img_resized).astype('float32') / 255.0  # Normalize
        img_array = img_array.reshape((1, 28, 28, 1))  # Reshape for the model
        return img_array

    def predict_image(self, img_array):
        prediction = self.model.predict(img_array)
        return np.argmax(prediction)

    def upload_image(self):
        # Image upload and display
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(self, "Upload Image", "", "Images (*.png *.jpg *.bmp *.gif)", options=options)
        if file_path:
            img_array = self.preprocess_image(file_path)
            class_name = self.predict_image(img_array)
            self.result_label.setText(f"Prediction: {class_name}")
            pixmap = QPixmap(file_path).scaled(100, 100, Qt.KeepAspectRatio)
            self.img_label.setPixmap(pixmap)

if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = ImageUploader()
    window.setWindowTitle("MNIST Handwriting Prediction")
    window.show()
    sys.exit(app.exec_())



SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
