# In [2]:
import sys
import os
import numpy as np
from PIL import Image, ImageOps
import tensorflow as tf
from tensorflow.keras.models import load_model
from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QVBoxLayout, QHBoxLayout,
    QPushButton, QLabel, QFileDialog, QWidget, QTextEdit,
    QProgressBar, QMessageBox, QGroupBox
)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QPixmap, QFont

# Predefined plant class names (based on dataset).
# This list is used as the label set when the loaded model's output size
# equals len(CLASS_NAMES); a predicted index i is then displayed as
# CLASS_NAMES[i], so the order here has to match the model's output order.
# NOTE(review): 'peper chili' looks like a misspelling of "pepper chili" --
# presumably it mirrors the dataset's label; confirm before changing it.
CLASS_NAMES = [
    'aloevera', 'banana', 'bilimbi', 'cantaloupe', 'cassava',
    'coconut', 'corn', 'cucumber', 'curcuma', 'eggplant',
    'galangal', 'ginger', 'guava', 'kale', 'longbeans',
    'mango', 'melon', 'orange', 'paddy', 'papaya',
    'peper chili', 'pineapple', 'pomelo', 'shallot', 'soybeans',
    'spinach', 'sweet potatoes', 'tobacco', 'waterapple', 'watermelon'
]

# Background thread for model prediction
class PredictionThread(QThread):
    """Worker thread that classifies a single image off the GUI thread.

    Signals:
        prediction_done(str, float, list): emitted when ``run`` finishes.
            On success the arguments are the best class name, its score,
            and a list of up to three ``{'class': str, 'confidence': float}``
            dicts ordered from highest to lowest score. On failure they are
            ``("Error: <message>", 0.0, [])``.
    """

    prediction_done = pyqtSignal(str, float, list)

    def __init__(self, model, image_path, actual_classes):
        """Store the inputs for the prediction.

        Args:
            model: object whose ``predict(array, verbose=0)`` is called in
                ``run``.
            image_path: path of the image file to classify.
            actual_classes: sequence of class labels indexed by the model's
                output position.
        """
        super().__init__()
        self.model = model
        self.image_path = image_path
        self.actual_classes = actual_classes

    def _label_for(self, index):
        """Return the label for ``index`` or a ``Class_<index>`` placeholder."""
        if index < len(self.actual_classes):
            return self.actual_classes[index]
        return f"Class_{index}"

    def run(self):
        """Run prediction in background thread"""
        try:
            # The context manager closes the underlying file handle; the
            # converted/fitted copy lives in memory and stays usable.
            with Image.open(self.image_path) as source_image:
                image = ImageOps.fit(source_image.convert('RGB'), (224, 224))

            # Scale pixel values to [0, 1] and add a leading batch axis.
            image_array = np.expand_dims(np.array(image) / 255.0, axis=0)

            scores = self.model.predict(image_array, verbose=0)[0]

            # Cast NumPy integers to plain ints so labels format predictably.
            predicted_class_idx = int(np.argmax(scores))
            confidence = float(scores[predicted_class_idx])
            class_name = self._label_for(predicted_class_idx)

            # argsort is ascending: take the last three and reverse them.
            top_3_indices = np.argsort(scores)[-3:][::-1]
            top_3_predictions = [
                {
                    'class': self._label_for(int(idx)),
                    'confidence': float(scores[idx]),
                }
                for idx in top_3_indices
            ]

            self.prediction_done.emit(class_name, confidence, top_3_predictions)

        except Exception as e:
            # Thread boundary: report any failure through the signal instead
            # of letting the exception escape the worker thread.
            self.prediction_done.emit(f"Error: {str(e)}", 0.0, [])


class PlantDetectionGUI(QMainWindow):
    """Main GUI application for plant species detection.

    The window lets the user pick an image, runs the loaded model on it in a
    ``PredictionThread`` and shows the best class plus the top three scores.
    While a prediction is in flight the Select/Detect/Clear buttons are
    disabled so that the running worker thread is never replaced and the
    displayed result always belongs to the displayed image.
    """

    def __init__(self):
        super().__init__()
        self.model = None
        self.current_image_path = None
        self.model_classes = []
        # Initialised up front so every method can test them safely.
        self.prediction_thread = None
        self._detecting = False
        self.initUI()
        self.load_model()

    def initUI(self):
        """Build the window: title, model info, preview, buttons, progress
        bar, results box and status line."""
        self.setWindowTitle("Plant Species Detection System")
        self.setGeometry(100, 100, 1000, 800)
        self.setStyleSheet("""
            QMainWindow { background-color: #f0f8f0; }
            QGroupBox {
                font-weight: bold;
                font-size: 14px;
                border: 2px solid #2e7d32;
                border-radius: 10px;
                margin-top: 10px;
                padding-top: 15px;
                background-color: white;
            }
            QPushButton {
                background-color: #2e7d32;
                color: white;
                border: none;
                padding: 10px 20px;
                font-size: 13px;
                border-radius: 6px;
                font-weight: bold;
            }
            QPushButton:hover { background-color: #388e3c; }
            QTextEdit {
                border: 2px solid #81c784;
                border-radius: 8px;
                padding: 10px;
                background-color: white;
                font-family: Arial;
                font-size: 13px;
            }
            QLabel {
                font-size: 14px;
                color: #1b5e20;
            }
            QProgressBar {
                border: 2px solid #81c784;
                border-radius: 8px;
                text-align: center;
                font-weight: bold;
                height: 20px;
            }
            QProgressBar::chunk {
                background-color: #4caf50;
                border-radius: 6px;
            }
        """)

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)
        layout.setSpacing(15)
        layout.setContentsMargins(20, 20, 20, 20)

        title = QLabel("Plant Species Detection System")
        title.setAlignment(Qt.AlignCenter)
        title.setFont(QFont("Arial", 20, QFont.Bold))
        title.setStyleSheet("color: #2e7d32; margin: 10px;")
        layout.addWidget(title)

        self.model_info = QLabel("Model information will be displayed here.")
        self.model_info.setAlignment(Qt.AlignCenter)
        layout.addWidget(self.model_info)

        # Image display
        image_group = QGroupBox("Image Preview")
        image_layout = QVBoxLayout()

        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setMinimumSize(450, 350)
        self.image_label.setText("No image selected.\nSelect an image to begin detection.")
        self.image_label.setStyleSheet(
            "border: 2px dashed #4caf50; background-color: #f1f8e9; border-radius: 10px;"
        )
        image_layout.addWidget(self.image_label)
        image_group.setLayout(image_layout)
        layout.addWidget(image_group)

        # Buttons
        button_layout = QHBoxLayout()
        self.select_btn = QPushButton("Select Image")
        self.select_btn.clicked.connect(self.select_image)

        self.detect_btn = QPushButton("Detect Plant Species")
        self.detect_btn.clicked.connect(self.detect_plant)
        self.detect_btn.setEnabled(False)

        self.clear_btn = QPushButton("Clear All")
        self.clear_btn.clicked.connect(self.clear_all)

        button_layout.addWidget(self.select_btn)
        button_layout.addWidget(self.detect_btn)
        button_layout.addWidget(self.clear_btn)
        layout.addLayout(button_layout)

        # Progress bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        self.progress_bar.setFormat("Detecting... %p%")
        layout.addWidget(self.progress_bar)

        # Results
        results_group = QGroupBox("Detection Results")
        results_layout = QVBoxLayout()

        self.results_text = QTextEdit()
        self.results_text.setReadOnly(True)
        self.results_text.setPlaceholderText("Results will appear here...")
        results_layout.addWidget(self.results_text)
        results_group.setLayout(results_layout)
        layout.addWidget(results_group)

        # Status
        self.status_label = QLabel("Ready to detect plant species.")
        self.status_label.setStyleSheet(
            "color: #2e7d32; font-weight: bold; padding: 5px; background-color: #e8f5e9;"
        )
        layout.addWidget(self.status_label)

    def load_model(self):
        """Load trained deep learning model.

        Tries each candidate path in order and loads the first that exists.
        A probe prediction on a random 1x224x224x3 input determines the
        number of output classes: if it equals ``len(CLASS_NAMES)`` those
        names are used, otherwise generic ``Class_<i>`` labels are created.
        ``self.model`` is only assigned once loading and probing both
        succeeded, so a failed load never leaves a half-usable model behind.
        Failures are reported through ``show_error`` and the status label.
        """
        possible_paths = [
            "plant_model.h5",
            "model.h5",
            "plant_detection_model.h5",
            "plant_classifier.h5",
            r"C:\Users\ACER\OneDrive\Desktop\Aushadhe\resnet50_aushadhee.h5",
            "resnet50_aushadhee.h5"
        ]

        model_path = next(
            (path for path in possible_paths if os.path.exists(path)), None
        )
        if model_path is None:
            self.show_error("Model file not found. Please ensure the model file exists.")
            self.status_label.setText("Model not found.")
            return

        self.status_label.setText("Loading model...")
        try:
            # Inside a method the bare name resolves to the module-level
            # Keras ``load_model`` import, not to this method.
            model = load_model(model_path)

            dummy_input = np.random.random((1, 224, 224, 3))
            prediction = model.predict(dummy_input, verbose=0)
            output_classes = prediction.shape[1]
        except Exception as e:
            # Top-level UI boundary: surface any loader/probe failure.
            self.show_error(f"Error loading model: {str(e)}")
            self.status_label.setText("Error loading model.")
            return

        self.model = model
        if output_classes == len(CLASS_NAMES):
            self.model_classes = CLASS_NAMES
        else:
            self.model_classes = [f"Class_{i}" for i in range(output_classes)]

        self.model_info.setText(f"Model Loaded Successfully | Classes: {len(self.model_classes)}")
        self.status_label.setText("Model loaded successfully.")

    def _set_busy(self, busy):
        """Switch the controls between idle and detection-in-progress."""
        self._detecting = busy
        self.progress_bar.setVisible(busy)
        self.select_btn.setEnabled(not busy)
        self.clear_btn.setEnabled(not busy)
        # Detect is only meaningful when idle and an image is selected.
        self.detect_btn.setEnabled(
            not busy and self.current_image_path is not None
        )

    def select_image(self):
        """Select an image from file explorer"""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Select Plant Image", "", "Image Files (*.png *.jpg *.jpeg *.bmp *.tiff)"
        )
        if not file_path:
            return

        self.current_image_path = file_path
        pixmap = QPixmap(file_path)
        if pixmap.isNull():
            # Qt could not decode the file for display. Detection reads the
            # file separately in the worker, so it is still allowed.
            self.image_label.setText("Preview not available for this image.")
        else:
            self.image_label.setPixmap(
                pixmap.scaled(450, 350, Qt.KeepAspectRatio, Qt.SmoothTransformation)
            )
        # Never re-enable Detect while a prediction is still in flight.
        self.detect_btn.setEnabled(not self._detecting)
        self.status_label.setText(f"Image selected: {os.path.basename(file_path)}")

    def detect_plant(self):
        """Run prediction using loaded model"""
        if self._detecting:
            # A worker is already running; replacing it could destroy a
            # live QThread.
            return
        if not self.current_image_path:
            self.show_error("Please select an image first.")
            return
        if self.model is None:
            self.show_error("Model not loaded.")
            return

        previous_thread = self.prediction_thread
        if previous_thread is not None:
            # The previous run has already emitted its result; this only
            # lets the thread finish winding down before it is dropped.
            previous_thread.wait()

        self._set_busy(True)
        self.progress_bar.setRange(0, 0)
        self.status_label.setText("Detecting plant species...")

        self.prediction_thread = PredictionThread(self.model, self.current_image_path, self.model_classes)
        self.prediction_thread.prediction_done.connect(self.on_prediction_done)
        self.prediction_thread.start()

    def on_prediction_done(self, class_name, confidence, top_3_predictions):
        """Display prediction results.

        Args:
            class_name: best class label, or a string starting with
                ``"Error:"`` when the worker failed.
            confidence: score of the best class (shown as a percentage).
            top_3_predictions: list of ``{'class', 'confidence'}`` dicts.
        """
        self._set_busy(False)

        if class_name.startswith("Error:"):
            self.show_error(class_name)
            self.status_label.setText("Detection failed.")
            return

        lines = [
            f"Detected Plant: {class_name}",
            f"Confidence: {confidence * 100:.2f}%",
            "",
            "Top 3 Predictions:",
        ]
        lines.extend(
            f"{rank}. {pred['class']} - {pred['confidence'] * 100:.2f}%"
            for rank, pred in enumerate(top_3_predictions, 1)
        )

        # Plain text: the report must never be interpreted as rich text.
        self.results_text.setPlainText("\n".join(lines) + "\n")
        self.status_label.setText("Detection completed successfully.")

    def clear_all(self):
        """Clear all data and reset GUI"""
        self.current_image_path = None
        self.image_label.clear()
        self.image_label.setText("No image selected.\nSelect an image to begin detection.")
        self.results_text.clear()
        self.detect_btn.setEnabled(False)
        self.status_label.setText("Ready to detect plant species.")

    def show_error(self, message):
        """Display error dialog"""
        QMessageBox.critical(self, "Error", message)

    def closeEvent(self, event):
        """Confirm before closing the application.

        When the user confirms, a still-running prediction thread is waited
        for first so the QThread is not torn down while it is executing.
        """
        reply = QMessageBox.question(
            self,
            "Exit",
            "Are you sure you want to exit?",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.No
        )
        if reply != QMessageBox.Yes:
            event.ignore()
            return

        worker = self.prediction_thread
        if worker is not None and worker.isRunning():
            # Blocks until the current prediction finishes.
            worker.wait()
        event.accept()


def main():
    """Create the Qt application, show the main window and exit with the
    event loop's return code."""
    application = QApplication(sys.argv)
    application.setStyle('Fusion')
    application.setFont(QFont("Arial", 10))

    main_window = PlantDetectionGUI()
    main_window.show()

    exit_code = application.exec_()
    sys.exit(exit_code)


# Start the GUI only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()


# Notebook output: SystemExit: 0