In [1]:
# brain_tumor_gui.py
import sys
import os
import json
from pathlib import Path
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, 
                             QHBoxLayout, QPushButton, QLabel, QFileDialog, 
                             QTextEdit, QGroupBox, QProgressBar, QMessageBox)
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QPixmap, QFont

class BrainTumorModel:
    """Wrap a fine-tuned ResNet-50 classifier for brain MRI images.

    The network is built and its trained weights are loaded at construction
    time.  If every loading strategy fails, ``self.model`` is ``None`` and
    callers are expected to check it before predicting.
    """

    def __init__(self, model_path, class_names):
        """Build the model and preprocessing pipeline.

        Args:
            model_path: Path to a saved PyTorch checkpoint or state dict.
            class_names: Ordered class labels; ``len(class_names)`` sets the
                size of the output layer and index *i* names output logit *i*.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.class_names = class_names
        self.model = self.load_model(model_path)
        self.transform = self.get_transform()

    def _build_model(self):
        """Return an untrained ResNet-50 carrying the classifier head used in training."""
        model = models.resnet50(pretrained=False)
        in_features = model.fc.in_features
        # Must match the training-time head exactly, or the saved weights won't fit.
        model.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.Linear(512, len(self.class_names))
        )
        return model

    @staticmethod
    def _read_checkpoint(model_path):
        """Deserialize *model_path* onto the CPU.

        The freshly built model lives on the CPU until it is moved with
        ``.to(self.device)``, so CPU tensors are all ``load_state_dict`` needs.
        Loading to CPU also avoids failures when a checkpoint saved on, e.g.,
        ``cuda:1`` is opened on a machine without that device.
        """
        # NOTE(review): torch.load unpickles arbitrary objects. Only load
        # checkpoints from trusted sources; consider weights_only=True
        # (PyTorch >= 1.13) if these files may come from elsewhere.
        return torch.load(model_path, map_location=torch.device('cpu'))

    @staticmethod
    def extract_state_dict(checkpoint):
        """Return the parameter dict contained in *checkpoint*.

        Accepted inputs:
        - a bare state dict;
        - a training checkpoint dict that wraps it under ``'model_state_dict'``
          or ``'state_dict'``;
        - a pickled ``nn.Module``.

        When every key carries the ``'module.'`` prefix added by
        ``nn.DataParallel``, the prefix is stripped.  Any other object is
        returned unchanged so that ``load_state_dict`` reports the problem.
        """
        if isinstance(checkpoint, nn.Module):
            return checkpoint.state_dict()

        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ('model_state_dict', 'state_dict'):
                if isinstance(checkpoint.get(wrapper_key), dict):
                    state_dict = checkpoint[wrapper_key]
                    break

        if isinstance(state_dict, dict) and state_dict and all(
            isinstance(key, str) and key.startswith('module.') for key in state_dict
        ):
            state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
        return state_dict

    @staticmethod
    def filter_compatible_state_dict(state_dict, reference_state_dict):
        """Return the entries of *state_dict* whose key exists in *reference_state_dict* with the same shape.

        ``load_state_dict(strict=False)`` tolerates missing and unexpected keys
        but still raises on shape mismatches, so those are dropped here too.
        """
        return {
            key: value
            for key, value in state_dict.items()
            if key in reference_state_dict
            and getattr(value, 'shape', None) == reference_state_dict[key].shape
        }

    def load_model(self, model_path):
        """Load the trained model with the correct architecture.

        A strict load is tried first.  On any failure, the method falls back
        to :meth:`load_model_alternative`.

        Returns:
            The model in eval mode on ``self.device``, or ``None`` if both
            strategies fail.
        """
        try:
            model = self._build_model()
            checkpoint = self._read_checkpoint(model_path)

            print("üîç Checkpoint type:", type(checkpoint))
            print("üîç Checkpoint keys:", checkpoint.keys() if isinstance(checkpoint, dict) else "Not a dict")

            # Strict load: every parameter must match the architecture exactly.
            model.load_state_dict(self.extract_state_dict(checkpoint))

            model = model.to(self.device)
            model.eval()
            print("‚úÖ Model loaded successfully from", model_path)
            return model

        except Exception as e:
            print(f"‚ùå Error loading model: {e}")
            # Try alternative loading method
            return self.load_model_alternative(model_path)

    def load_model_alternative(self, model_path):
        """Best-effort load that skips extra keys and shape mismatches.

        Only parameters whose name *and* shape match the freshly built model
        are copied.  Parameters missing from the checkpoint keep their initial
        (untrained) values and are reported on stdout.

        Returns:
            The model in eval mode on ``self.device``, or ``None`` if loading
            fails or no checkpoint parameter matches the architecture.
        """
        try:
            model = self._build_model()
            state_dict = self.extract_state_dict(self._read_checkpoint(model_path))
            if not isinstance(state_dict, dict):
                raise TypeError(f"unsupported checkpoint type: {type(state_dict).__name__}")

            filtered_state_dict = self.filter_compatible_state_dict(state_dict, model.state_dict())
            if not filtered_state_dict:
                # Without this check a completely random network would be
                # reported as "loaded".
                raise ValueError("no checkpoint parameters match the model architecture")

            result = model.load_state_dict(filtered_state_dict, strict=False)
            if result.missing_keys:
                print(f"‚ö†Ô∏è {len(result.missing_keys)} model parameters were not found in the "
                      f"checkpoint and keep their untrained values: {result.missing_keys[:5]}")

            model = model.to(self.device)
            model.eval()
            print("‚úÖ Model loaded with alternative method")
            return model

        except Exception as e:
            print(f"‚ùå Alternative loading failed: {e}")
            return None

    def get_transform(self):
        """Get the same transform used during training"""
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def predict(self, image_path):
        """Predict the class of one brain MRI image.

        Returns:
            ``(predicted_class, confidence_percent, probabilities)``, where
            *probabilities* is a NumPy array of softmax scores ordered like
            ``self.class_names``.  On failure, returns ``(None, 0, None)``
            after printing the error.
        """
        if self.model is None:
            print("‚ùå Prediction error: model is not loaded")
            return None, 0, None

        try:
            # Context manager closes the file handle that PIL keeps open lazily.
            with Image.open(image_path) as image:
                rgb_image = image.convert('RGB')
            input_batch = self.transform(rgb_image).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.model(input_batch)
                probabilities = torch.nn.functional.softmax(output[0], dim=0)
                confidence, predicted_idx = torch.max(probabilities, 0)

            predicted_class = self.class_names[predicted_idx.item()]
            confidence_percent = confidence.item() * 100

            return predicted_class, confidence_percent, probabilities.cpu().numpy()

        except Exception as e:
            print(f"‚ùå Prediction error: {e}")
            return None, 0, None

class PredictionThread(QThread):
    """Worker thread that runs ``model.predict`` away from the GUI thread.

    Exactly one signal is emitted per run:

    * ``prediction_finished(label, confidence, probabilities, image_path)``:
      on success, where *probabilities* is a plain list (empty if the model
      returned none);
    * ``prediction_error(message)``: when no label was produced or an
      exception escaped.
    """
    prediction_finished = pyqtSignal(str, float, list, str)
    prediction_error = pyqtSignal(str)

    def __init__(self, model, image_path):
        super().__init__()
        self.model = model
        self.image_path = image_path

    def run(self):
        """Thread entry point: run the prediction and report through a signal."""
        try:
            label, score, probs = self.model.predict(self.image_path)
            if not label:
                self.prediction_error.emit("Prediction failed")
                return
            # Qt signals declared with `list` need a Python list, not an ndarray.
            prob_list = [] if probs is None else probs.tolist()
            self.prediction_finished.emit(label, score, prob_list, self.image_path)
        except Exception as exc:
            self.prediction_error.emit(str(exc))

class BrainTumorGUI(QMainWindow):
    """Main window: load a brain MRI image, classify it, and display the results.

    The model is loaded once at start-up from the ``outputs/`` folder.  Each
    analysis runs in a :class:`PredictionThread` so the UI stays responsive.
    """

    # Labels used when neither outputs/class_info.json nor
    # outputs/training_info.json is present.
    DEFAULT_CLASS_NAMES = ('Glioma', 'Meningioma', 'Pituitary', 'No Tumor')

    # Checkpoint candidates, tried in order; the first existing file wins.
    # Every candidate is read with torch.load regardless of its extension.
    MODEL_PATHS = (
        "outputs/best_model.h5",
        "outputs/best_model.keras",
        "outputs/best_model.legacy",
        "outputs/final_model.h5",
        "outputs/final_model.keras",
        "outputs/final_model.legacy",
        "outputs/best_model.pth",
        "outputs/best_model.pt",
        "outputs/final_model.pth",
        "outputs/final_model.pt",
    )

    def __init__(self):
        super().__init__()
        self.model = None
        self.current_image_path = None
        # Keep a reference to the running worker: Qt aborts the process if a
        # QThread object is destroyed while its thread is still running.
        self.prediction_thread = None
        self.init_ui()
        self.load_model()

    def init_ui(self):
        """Initialize the user interface"""
        self.setWindowTitle("üß† Brain Tumor Classification System")
        self.setFixedSize(1000, 700)

        # Set dark theme
        self.set_dark_theme()

        # Central widget
        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        # Main layout
        layout = QHBoxLayout()
        central_widget.setLayout(layout)

        # Left panel - Image display
        left_panel = self.create_image_panel()
        layout.addWidget(left_panel, 2)

        # Right panel - Controls and results
        right_panel = self.create_control_panel()
        layout.addWidget(right_panel, 1)

    def set_dark_theme(self):
        """Set dark theme for the application"""
        self.setStyleSheet("""
            QMainWindow {
                background-color: #2b2b2b;
                color: #ffffff;
            }
            QGroupBox {
                color: #ffffff;
                font-weight: bold;
                border: 2px solid #555555;
                border-radius: 5px;
                margin-top: 1ex;
                padding-top: 10px;
            }
            QGroupBox::title {
                subcontrol-origin: margin;
                left: 10px;
                padding: 0 5px 0 5px;
            }
            QPushButton {
                background-color: #4CAF50;
                border: none;
                color: white;
                padding: 10px;
                text-align: center;
                text-decoration: none;
                font-size: 14px;
                margin: 4px 2px;
                border-radius: 5px;
                font-weight: bold;
            }
            QPushButton:hover {
                background-color: #45a049;
            }
            QPushButton:pressed {
                background-color: #3d8b40;
            }
            QPushButton:disabled {
                background-color: #666666;
                color: #999999;
            }
            QLabel {
                color: #ffffff;
            }
            QTextEdit {
                background-color: #1e1e1e;
                color: #ffffff;
                border: 1px solid #555555;
                border-radius: 3px;
                padding: 5px;
            }
            QProgressBar {
                border: 2px solid #555555;
                border-radius: 5px;
                text-align: center;
                color: white;
                background-color: #1e1e1e;
            }
            QProgressBar::chunk {
                background-color: #4CAF50;
                width: 20px;
            }
        """)

    def create_image_panel(self):
        """Create the image display panel"""
        panel = QGroupBox("Brain MRI Scan")
        layout = QVBoxLayout()

        # Image display label
        self.image_label = QLabel()
        self.image_label.setAlignment(Qt.AlignCenter)
        self.image_label.setMinimumSize(600, 500)
        self.image_label.setStyleSheet("background-color: #1e1e1e; border: 2px dashed #555555;")
        self.image_label.setText("No image selected\n\nClick 'Load MRI Scan' to begin")
        self.image_label.setFont(QFont("Arial", 12))

        layout.addWidget(self.image_label)
        panel.setLayout(layout)
        return panel

    def create_control_panel(self):
        """Create the control and results panel"""
        panel = QGroupBox("Analysis Controls")
        layout = QVBoxLayout()

        # Load image button
        self.load_btn = QPushButton("üìÅ Load MRI Scan")
        self.load_btn.clicked.connect(self.load_image)
        layout.addWidget(self.load_btn)

        # Analyze button
        self.analyze_btn = QPushButton("üîç Analyze Tumor")
        self.analyze_btn.clicked.connect(self.analyze_image)
        self.analyze_btn.setEnabled(False)
        layout.addWidget(self.analyze_btn)

        # Progress bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        # Results section
        results_group = QGroupBox("Analysis Results")
        results_layout = QVBoxLayout()

        self.results_text = QTextEdit()
        self.results_text.setReadOnly(True)
        self.results_text.setMaximumHeight(200)
        self.results_text.setHtml("""
            <center>
                <h3>üß† Brain Tumor Classifier</h3>
                <p>Load a brain MRI scan to analyze for tumor detection</p>
                <p><b>Model Status:</b> Loading...</p>
            </center>
        """)
        results_layout.addWidget(self.results_text)

        # Confidence bars; more labels are added on demand by _confidence_label.
        self.confidence_group = QGroupBox("Confidence Levels")
        confidence_layout = QVBoxLayout()

        self.confidence_labels = {}
        for i in range(4):
            label = QLabel()
            label.setVisible(False)
            confidence_layout.addWidget(label)
            self.confidence_labels[i] = label

        self.confidence_group.setLayout(confidence_layout)
        results_layout.addWidget(self.confidence_group)

        results_group.setLayout(results_layout)
        layout.addWidget(results_group)

        # Status label
        self.status_label = QLabel("Loading model...")
        self.status_label.setStyleSheet("color: #888888; font-style: italic;")
        layout.addWidget(self.status_label)

        panel.setLayout(layout)
        return panel

    def _read_class_names(self):
        """Return the class names for the model.

        The ``'classes'`` list comes from the first existing file among
        ``outputs/class_info.json`` and ``outputs/training_info.json``.  The
        default names are used if that file lacks the key or neither file
        exists.
        """
        for info_path in ("outputs/class_info.json", "outputs/training_info.json"):
            if os.path.exists(info_path):
                with open(info_path, 'r', encoding='utf-8') as f:
                    info = json.load(f)
                return info.get('classes', list(self.DEFAULT_CLASS_NAMES))
        return list(self.DEFAULT_CLASS_NAMES)

    def load_model(self):
        """Locate a checkpoint and its class names under ``outputs/`` and load the model.

        On failure, shows a critical message box and marks the status label.
        ``self.model`` is assigned only once the weights loaded successfully.
        """
        try:
            model_path = next((path for path in self.MODEL_PATHS if os.path.exists(path)), None)
            if not model_path:
                QMessageBox.critical(self, "Error",
                    "No model file found! Please ensure model files are in the 'outputs' folder.")
                self.status_label.setText("‚ùå Model loading failed")
                return

            class_names = self._read_class_names()

            print(f"üéØ Loading model from: {model_path}")
            print(f"üéØ Class names: {class_names}")

            model = BrainTumorModel(model_path, class_names)
            if model.model is None:
                QMessageBox.critical(self, "Error", "Failed to load the model!")
                self.status_label.setText("‚ùå Model loading failed")
                return
            self.model = model

            self.status_label.setText("‚úÖ Model loaded successfully")
            # Build the list from the actual classes instead of assuming exactly four.
            class_lines = "<br>\n".join(f"- {name}" for name in class_names)
            self.results_text.setHtml(f"""
                <center>
                    <h3>üß† Brain Tumor Classifier</h3>
                    <p>Model loaded successfully!</p>
                    <p><b>Supported classes:</b><br>
                    {class_lines}</p>
                </center>
            """)
            print("‚úÖ Model and classes loaded successfully")

        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to load model: {str(e)}")
            self.status_label.setText("‚ùå Model loading failed")

    def load_image(self):
        """Let the user pick an MRI image, preview it, and enable analysis."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Select Brain MRI Scan", "",
            "Image Files (*.png *.jpg *.jpeg *.bmp *.tiff)"
        )
        if not file_path:
            return

        self.current_image_path = file_path
        pixmap = QPixmap(file_path)
        if pixmap.isNull():
            # Qt could not decode the file (e.g. a missing image-format
            # plugin). PIL may still read it, so analysis stays available.
            self.image_label.setText("Preview not available for this file")
        else:
            # Scale image to fit label while maintaining aspect ratio
            scaled_pixmap = pixmap.scaled(
                self.image_label.width() - 20,
                self.image_label.height() - 20,
                Qt.KeepAspectRatio,
                Qt.SmoothTransformation
            )
            self.image_label.setPixmap(scaled_pixmap)

        self.analyze_btn.setEnabled(True)
        self.status_label.setText(f"Loaded: {Path(file_path).name}")

        # Clear previous results
        self.results_text.clear()
        for label in self.confidence_labels.values():
            label.setVisible(False)

    def analyze_image(self):
        """Start a background prediction for the currently loaded image.

        Does nothing if no image is loaded, the model is unavailable, or a
        prediction is already running.
        """
        if not self.current_image_path or self.model is None:
            return
        if self.prediction_thread is not None and self.prediction_thread.isRunning():
            return

        # Lock both buttons so neither a new image nor a new analysis can
        # replace the worker while it is still running.
        self.analyze_btn.setEnabled(False)
        self.load_btn.setEnabled(False)
        self.progress_bar.setVisible(True)
        self.progress_bar.setRange(0, 0)  # Indeterminate progress

        # Start prediction in separate thread
        self.prediction_thread = PredictionThread(self.model, self.current_image_path)
        self.prediction_thread.prediction_finished.connect(self.on_prediction_finished)
        self.prediction_thread.prediction_error.connect(self.on_prediction_error)
        self.prediction_thread.start()

    def _end_analysis(self):
        """Hide the progress bar and re-enable the controls locked by analyze_image."""
        self.progress_bar.setVisible(False)
        self.analyze_btn.setEnabled(True)
        self.load_btn.setEnabled(True)

    @staticmethod
    def _is_no_tumor(class_name):
        """Return True if *class_name* denotes the tumor-free class.

        Matching ignores case, spaces and punctuation, so "No Tumor",
        "notumor" and "no_tumor" all qualify.
        """
        normalized = ''.join(ch for ch in class_name.lower() if ch.isalnum())
        return 'notumor' in normalized

    def _confidence_label(self, index):
        """Return the confidence label for class *index*.

        The label is created on demand when the model has more classes than
        labels were pre-built for.
        """
        label = self.confidence_labels.get(index)
        if label is None:
            label = QLabel()
            self.confidence_group.layout().addWidget(label)
            self.confidence_labels[index] = label
        return label

    @staticmethod
    def _confidence_bar_html(class_name, confidence_percent, bar_color):
        """Return rich text showing *class_name*, its percentage, and a 20-segment bar.

        QLabel renders only Qt's HTML subset, which has no CSS flexbox and no
        sized <div> boxes, so the bar is drawn with block characters.
        """
        filled = min(max(int(round(confidence_percent / 5)), 0), 20)
        block = "\u2588"
        return (
            f'<div style="margin: 5px 0;">'
            f'<span>{class_name}</span> <span>{confidence_percent:.1f}%</span><br>'
            f'<span style="color: {bar_color};">{block * filled}</span>'
            f'<span style="color: #333333;">{block * (20 - filled)}</span>'
            f'</div>'
        )

    def on_prediction_finished(self, predicted_class, confidence, probabilities, image_path):
        """Display the predicted class, per-class confidences, and any medical advice."""
        self._end_analysis()

        no_tumor = self._is_no_tumor(predicted_class)

        # Determine color based on prediction
        if no_tumor:
            color = "green"
            emoji = "‚úÖ"
            status = "Healthy"
        else:
            color = "red"
            emoji = "‚ö†Ô∏è"
            status = "Medical Attention Required"

        # Display results
        result_html = f"""
            <center>
                <h2 style="color: {color};">{emoji} {predicted_class}</h2>
                <h3 style="color: {color};">Confidence: {confidence:.2f}%</h3>
                <p><b>Status:</b> {status}</p>
            </center>
        """

        self.results_text.setHtml(result_html)
        self.status_label.setText(f"Analysis complete - {predicted_class}")

        # Show confidence bars for all classes (probabilities follow class_names order).
        for i, (class_name, prob) in enumerate(zip(self.model.class_names, probabilities)):
            confidence_percent = prob * 100

            # Highlight the predicted class; every other class is grey.
            if class_name == predicted_class:
                bar_color = "#4CAF50" if no_tumor else "#ff4444"
            else:
                bar_color = "#666666"

            label = self._confidence_label(i)
            label.setText(self._confidence_bar_html(class_name, confidence_percent, bar_color))
            label.setVisible(True)

        # Show medical advice
        if not no_tumor:
            self.show_medical_advice(predicted_class)

    def on_prediction_error(self, error_message):
        """Handle prediction errors"""
        self._end_analysis()
        QMessageBox.critical(self, "Analysis Error", f"Failed to analyze image:\n{error_message}")
        self.status_label.setText("Analysis failed")

    def show_medical_advice(self, tumor_type):
        """Append advice for the detected tumor type to the results panel.

        The lookup is case-insensitive, so both "Glioma" and "glioma" match.
        Unknown types get a generic recommendation.
        """
        advice_map = {
            "glioma": "Gliomas are tumors that occur in the brain and spinal cord. Consult a neurologist immediately for further evaluation and treatment planning.",
            "meningioma": "Meningiomas are tumors that arise from the meninges. While often benign, they should be evaluated by a neurosurgeon for potential treatment.",
            "pituitary": "Pituitary tumors affect the pituitary gland at the base of the brain. Endocrine evaluation and neurosurgical consultation are recommended."
        }

        advice = advice_map.get(tumor_type.strip().lower(), "Please consult with a medical professional for proper diagnosis and treatment.")

        self.results_text.append(f"""
            <br>
            <div style="background: #ff4444; padding: 10px; border-radius: 5px; color: white;">
                <b>‚ö†Ô∏è Medical Recommendation:</b><br>
                {advice}
            </div>
        """)

def main():
    """Run the brain-tumor classification window until it is closed.

    An already-running QApplication is reused instead of constructing a
    second one.  Qt permits only one application object per process, which
    matters when this cell is re-executed in the same notebook kernel.
    """
    app = QApplication.instance() or QApplication(sys.argv)

    # Set application properties
    app.setApplicationName("Brain Tumor Classification System")
    app.setApplicationVersion("1.0")

    # Create and show main window
    window = BrainTumorGUI()
    window.show()

    sys.exit(app.exec_())


if __name__ == "__main__":
    main()

üéØ Loading model from: outputs/best_model.h5
üéØ Class names: ['glioma', 'meningioma', 'notumor', 'pituitary']


  checkpoint = torch.load(model_path)


üîç Checkpoint type: <class 'collections.OrderedDict'>
üîç Checkpoint keys: odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.0.downsample.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'lay

SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
