In [None]:
# brain_tumor_gui_notebook.py
import sys, os
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms, models
from PIL import Image

from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QLabel, QPushButton, QFileDialog,
    QVBoxLayout, QWidget, QHBoxLayout, QListWidget, QListWidgetItem, QMessageBox
)
from PyQt5.QtGui import QPixmap, QImage
from PyQt5.QtCore import Qt, QTimer

# ---- user configuration ----------------------------------------------------
# Folder that holds the project files; also the directory the "Load Model"
# file dialog opens in.
ARCHIVE_DEFAULT = Path(r"C:\Users\ACER\Desktop\archive (3)")
# Model files tried, in this order, by the automatic load at start-up.
DEFAULT_MODEL_NAMES = [
    ARCHIVE_DEFAULT.joinpath("trained_brain_tumor_model.h5"),
    Path("outputs").joinpath("best_brain_tumor_model.pth"),
]

# Preprocessing parameters. MEAN/STD match the usual ImageNet normalisation
# values -- presumably what the model was trained with; confirm.
IMAGE_SIZE = 224
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Label order used whenever a checkpoint does not carry its own class list.
DEFAULT_CLASS_NAMES = ['glioma', 'meningioma', 'pituitary', 'notumor']


# Inference device: CUDA when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Evaluation preprocessing: resize (to 255 px for IMAGE_SIZE=224), centre-crop
# to IMAGE_SIZE, convert to a tensor and normalise with MEAN/STD.
val_transform = transforms.Compose(
    [
        transforms.Resize(int(IMAGE_SIZE * 1.14)),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD),
    ]
)

def pil_image_to_qpixmap(pil_image, max_size=(520, 520)):
    """Convert a PIL image to a QPixmap that fits inside *max_size*.

    The image is converted to RGB (or to grayscale ``L`` if RGB conversion
    fails). It is downscaled with Lanczos resampling so it fits within
    ``max_size`` and is never upscaled. Its pixels are then handed to Qt
    through a C-contiguous uint8 numpy buffer.

    Args:
        pil_image: PIL image to display, or ``None``.
        max_size: ``(max_width, max_height)`` bounding box in pixels.

    Returns:
        A ``QPixmap``. Returns ``None`` when *pil_image* is ``None``, has a
        zero dimension, has an unsupported channel count, or cannot be
        converted (the error is printed).
    """
    if pil_image is None:
        return None

    # Normalise to one of the two pixel layouts handled below.
    if getattr(pil_image, 'mode', None) not in ('RGB', 'L'):
        try:
            pil_image = pil_image.convert('RGB')
        except Exception:
            pil_image = pil_image.convert('L')

    width, height = pil_image.size
    if width <= 0 or height <= 0:
        # Scaling below would divide by zero.
        print("pil->qpixmap failed: image has zero size", (width, height))
        return None

    # Fit inside max_size without upscaling. Keep at least 1 px per side so a
    # very elongated image never resizes to a zero dimension.
    max_w, max_h = max_size
    scale = min(max_w / width, max_h / height, 1.0)
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    if new_size != (width, height):
        pil_image = pil_image.resize(new_size, Image.LANCZOS)

    arr = np.asarray(pil_image)
    if arr.dtype != np.uint8:
        # Rescale other dtypes into 0..255. A blank image (max <= 0) becomes
        # black instead of dividing by zero.
        peak = float(arr.max()) if arr.size else 0.0
        if peak > 0:
            arr = np.clip(255.0 * arr / peak, 0, 255).astype(np.uint8)
        else:
            arr = np.zeros(arr.shape, dtype=np.uint8)
    if arr.ndim == 3 and arr.shape[2] == 4:
        arr = arr[:, :, :3]  # drop alpha; this slice is a strided view
    # QImage reads rows as tightly packed bytes (bytesPerLine below), so the
    # buffer must be C-contiguous -- this also repacks the RGBA slice above.
    arr = np.ascontiguousarray(arr)

    try:
        if arr.ndim == 2:
            rows, cols = arr.shape
            qimg = QImage(arr.data, cols, rows, cols, QImage.Format_Grayscale8)
        elif arr.ndim == 3 and arr.shape[2] == 3:
            rows, cols, _ = arr.shape
            qimg = QImage(arr.data, cols, rows, 3 * cols, QImage.Format_RGB888)
        else:
            return None
        # QImage does not own `arr`'s buffer; build the pixmap while `arr`
        # is still referenced.
        return QPixmap.fromImage(qimg)
    except Exception as e:
        # Fallback: retry from a fresh RGB conversion.
        try:
            rgb = np.ascontiguousarray(np.asarray(pil_image.convert('RGB')))
            rows2, cols2, _ = rgb.shape
            qimg = QImage(rgb.data, cols2, rows2, 3 * cols2, QImage.Format_RGB888)
            return QPixmap.fromImage(qimg)
        except Exception:
            print("pil->qpixmap failed:", e)
            return None

class ModelLoader:
    """Load a saved classifier and run single-image inference with it.

    ``load`` accepts either of these ``torch.save`` layouts:

    * a full ``torch.nn.Module`` object, or
    * a checkpoint dict. The weights live under ``'model_state'``,
      ``'model_state_dict'``, ``'state_dict'`` or ``'model'``, or the dict is
      itself the state dict. It may also carry a ``'classes'`` list/tuple of
      label names.

    A state dict is loaded into a torchvision ResNet-50 or EfficientNet-B0,
    chosen from the parameter names. The final linear layer is sized from the
    saved head weights when they are present.

    Attributes:
        device: passed to ``torch.load(map_location=...)``; the model is also
            moved to it.
        model: the loaded module in eval mode, or ``None`` until a
            successful ``load``.
        classes: class names for the current model. This is the checkpoint's
            ``'classes'`` entry when present, otherwise ``DEFAULT_CLASS_NAMES``.
    """

    # Keys under which a checkpoint dict may store the weights (or a module).
    _STATE_KEYS = ('model_state', 'model_state_dict', 'state_dict', 'model')

    def __init__(self, device):
        self.device = device
        self.model = None
        # Copy so that mutating self.classes never alters the module default.
        self.classes = list(DEFAULT_CLASS_NAMES)

    def infer_backbone_from_state_dict(self, state_dict):
        """Guess which torchvision architecture produced *state_dict*.

        Returns ``'resnet'`` when any key mentions ``layer1``/``layer2`` or an
        ``fc.`` head. Returns ``'efficientnet'`` when keys mention ``features``
        or ``classifier``. Falls back to ``'resnet'``.
        """
        keys = list(state_dict.keys())
        if any('layer1' in k or 'layer2' in k or 'fc.' in k for k in keys):
            return 'resnet'
        if any('features' in k or 'classifier' in k for k in keys):
            return 'efficientnet'
        return 'resnet'

    def build_model_for_state_dict(self, state_dict, num_classes):
        """Create an untrained backbone matching *state_dict* with a
        ``num_classes``-way linear head."""
        if self.infer_backbone_from_state_dict(state_dict) == 'efficientnet':
            model = models.efficientnet_b0(weights=None)
            in_features = model.classifier[1].in_features
            model.classifier[1] = torch.nn.Linear(in_features, num_classes)
        else:
            model = models.resnet50(weights=None)
            model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        return model

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return *state_dict* without a ``module.`` prefix on every key.

        State dicts saved from a DataParallel/DDP-wrapped model carry this
        prefix. Dicts whose keys are not all prefixed are returned unchanged.
        """
        prefix = 'module.'
        keys = list(state_dict.keys())
        if keys and all(isinstance(k, str) and k.startswith(prefix) for k in keys):
            return {k[len(prefix):]: v for k, v in state_dict.items()}
        return state_dict

    @staticmethod
    def _infer_num_classes(state_dict, default):
        """Return the output width of the saved final linear layer, else *default*."""
        for key in ('fc.weight', 'classifier.1.weight'):
            weight = state_dict.get(key)
            if isinstance(weight, torch.Tensor) and weight.dim() == 2:
                return int(weight.shape[0])
        return default

    def load(self, path):
        """Load the model stored at *path* and make it the active model.

        Args:
            path: ``str`` or ``Path`` of a file written by ``torch.save``.

        Raises:
            FileNotFoundError: if *path* does not exist.
            RuntimeError: if the file holds neither a module nor a state dict.
                ``load_state_dict`` also raises RuntimeError on key or shape
                mismatches.

        ``model`` and ``classes`` are only replaced after the load succeeds.
        On any failure they keep their previous values.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found: {path}")
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only open model files from trusted sources.
        # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which
        # rejects full pickled modules. Such files must be re-saved as a
        # state-dict checkpoint (or loaded with an explicit, trusted
        # weights_only=False).
        loaded = torch.load(str(path), map_location=self.device)

        classes = None  # class names carried by the checkpoint, if any
        model = None
        if isinstance(loaded, torch.nn.Module):
            model = loaded
        elif isinstance(loaded, dict):
            if isinstance(loaded.get('classes'), (list, tuple)):
                classes = list(loaded['classes'])
            state = None
            for key in self._STATE_KEYS:
                if key in loaded:
                    state = loaded[key]
                    break
            if state is None:
                # The dict itself is the state dict; drop the metadata entry.
                state = {k: v for k, v in loaded.items() if k != 'classes'}
            if isinstance(state, torch.nn.Module):
                # e.g. torch.save({'model': model, ...}): a module under a key.
                model = state
            elif isinstance(state, dict):
                state = self._strip_module_prefix(state)
                fallback = len(classes) if classes else len(DEFAULT_CLASS_NAMES)
                num_classes = self._infer_num_classes(state, fallback)
                if classes and len(classes) != num_classes:
                    print(f"Warning: checkpoint lists {len(classes)} classes but "
                          f"the model head has {num_classes} outputs")
                model = self.build_model_for_state_dict(state, num_classes)
                model.load_state_dict(state)

        if model is None:
            raise RuntimeError("Unsupported model file format. Save with torch.save(model) or torch.save({'model_state': model.state_dict(), 'classes': [...]})")

        # Commit only now, so a failed load leaves the previous state intact.
        self.model = model.to(self.device)
        self.model.eval()
        self.classes = classes if classes is not None else list(DEFAULT_CLASS_NAMES)

    def predict(self, pil_image):
        """Return the softmax probabilities (1-D numpy array) for *pil_image*.

        Non-RGB images are converted to RGB first, because val_transform
        normalises three channels.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")
        if getattr(pil_image, 'mode', 'RGB') != 'RGB':
            pil_image = pil_image.convert('RGB')
        batch = val_transform(pil_image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
        return F.softmax(logits, dim=1).cpu().numpy()[0]

class MainWindow(QMainWindow):
    """Single-window GUI: load a model and an image, run inference, show results.

    The left column holds the image preview and a status/prediction label.
    The right column holds the Load Model / Load Image / Predict buttons and
    the per-class probability list. Once the Qt event loop starts, the window
    tries to auto-load one of DEFAULT_MODEL_NAMES.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Brain Tumor Classifier (Notebook)")
        self.resize(920, 620)

        self.model_loader = ModelLoader(device)

        # widgets
        self.image_label = QLabel("No image loaded")
        self.image_label.setAlignment(Qt.AlignCenter)
        # Fixed preview box; load_image scales pixmaps to the same 520x520.
        self.image_label.setFixedSize(520, 520)
        self.image_label.setStyleSheet("border:1px solid gray;")

        self.pred_label = QLabel("Model: not loaded")
        self.pred_label.setWordWrap(True)
        self.pred_label.setAlignment(Qt.AlignLeft)

        self.load_model_btn = QPushButton("Load Model")
        self.load_model_btn.clicked.connect(self.load_model)

        self.load_image_btn = QPushButton("Load Image")
        self.load_image_btn.clicked.connect(self.load_image)

        self.predict_btn = QPushButton("Predict")
        self.predict_btn.clicked.connect(self.run_prediction)
        self.predict_btn.setEnabled(False)  # enabled once a model is loaded

        self.prob_list = QListWidget()

        # layouts
        left = QVBoxLayout()
        left.addWidget(self.image_label)
        left.addSpacing(6)
        left.addWidget(self.pred_label)

        right = QVBoxLayout()
        right.addWidget(self.load_model_btn)
        right.addWidget(self.load_image_btn)
        right.addWidget(self.predict_btn)
        right.addSpacing(12)
        right.addWidget(QLabel("Class probabilities:"))
        right.addWidget(self.prob_list)

        top = QHBoxLayout()
        top.addLayout(left)
        top.addLayout(right)

        central = QWidget()
        central.setLayout(top)
        self.setCentralWidget(central)

        self.current_image = None       # RGB PIL image fed to the model
        self.current_image_path = None

        # A 0 ms single-shot timer runs the auto-load after the event loop starts.
        QTimer.singleShot(0, self.try_autoload_default_model)

    def safe_set_label_text(self, label, text):
        """Set *label*'s text, printing a warning instead of raising RuntimeError.

        NOTE(review): PyQt raises RuntimeError when the wrapped C++ widget has
        already been deleted -- presumably the case this guards against.
        """
        try:
            label.setText(text)
        except RuntimeError as e:
            print("Warning: widget update failed:", e)

    def try_autoload_default_model(self):
        """Load the first existing, loadable file in DEFAULT_MODEL_NAMES.

        Each failure is printed and the next candidate is tried. If none
        loads, the status label asks the user to pick a model manually.
        """
        for p in DEFAULT_MODEL_NAMES:
            try:
                if p.exists():
                    self.model_loader.load(p)
                    self.safe_set_label_text(self.pred_label, f"Model loaded from: {p}\nDevice: {device}")
                    self.predict_btn.setEnabled(True)
                    print("Auto-loaded model from", p)
                    return
            except Exception as e:
                print("Auto-load failed for", p, ":", e)
        self.safe_set_label_text(self.pred_label, 'Model not loaded. Click "Load Model" to pick model file')

    def load_model(self):
        """Ask the user for a model file and load it, reporting the outcome in a dialog."""
        path, _ = QFileDialog.getOpenFileName(self, "Select model file", str(ARCHIVE_DEFAULT),
                                              "PyTorch model (*.pth *.pt *.h5);;All files (*)")
        if not path:
            return
        try:
            self.model_loader.load(path)
        except Exception as e:
            QMessageBox.critical(self, "Load error", f"Failed to load model:\n{e}")
            return
        self.safe_set_label_text(self.pred_label, f"Model loaded from: {path}\nDevice: {device}")
        self.predict_btn.setEnabled(True)
        # Any probabilities still listed came from the previous model.
        self.prob_list.clear()
        QMessageBox.information(self, "Model loaded", "Model loaded successfully.")

    def load_image(self):
        """Ask the user for an image, convert it to RGB, display it and make it current.

        The image file is closed as soon as it has been decoded. The image
        only becomes the prediction input after it has been displayed, so the
        preview and the image Predict uses always match.
        """
        path, _ = QFileDialog.getOpenFileName(self, "Select image", os.getcwd(),
                                              "Images (*.png *.jpg *.jpeg *.bmp);;All files (*)")
        if not path:
            return
        try:
            # Image.open is lazy and keeps the file open. The context manager
            # closes it; convert() returns an independent in-memory image.
            with Image.open(path) as opened:
                try:
                    pil_img = opened.convert('RGB')
                except Exception:
                    pil_img = Image.fromarray(np.asarray(opened)).convert('RGB')
            print("Opened image:", path, "size:", pil_img.size, "mode:", pil_img.mode)
        except Exception as e:
            QMessageBox.critical(self, "Image error", f"Unable to open image:\n{e}")
            print("Image open error:", e)
            return

        pix = pil_image_to_qpixmap(pil_img, max_size=(520, 520))
        if pix is None:
            QMessageBox.critical(self, "Display error", "Failed to create displayable image. See notebook output.")
            return

        self.current_image = pil_img
        self.current_image_path = path
        self.image_label.setPixmap(pix)
        # Probabilities from the previous image no longer apply.
        self.prob_list.clear()
        self.safe_set_label_text(self.pred_label, "Image loaded. Click Predict to run inference.")
        self.predict_btn.setEnabled(self.model_loader.model is not None)

    def run_prediction(self):
        """Classify the current image and show the top class plus all probabilities.

        If the loader's class list does not match the number of outputs,
        generic ``class_<i>`` names are shown instead.
        """
        if self.current_image is None:
            QMessageBox.warning(self, "No image", "Please load an image first.")
            return
        if self.model_loader.model is None:
            QMessageBox.warning(self, "No model", "Please load a model first.")
            return
        try:
            probs = self.model_loader.predict(self.current_image)
        except Exception as e:
            QMessageBox.critical(self, "Inference error", f"Model inference failed:\n{e}")
            print("Inference error:", e)
            return

        classes = list(self.model_loader.classes) if self.model_loader.classes else DEFAULT_CLASS_NAMES
        if len(classes) != len(probs):
            classes = [f"class_{i}" for i in range(len(probs))]

        top_idx = int(np.argmax(probs))
        top_name = classes[top_idx]
        top_prob = float(probs[top_idx])
        self.safe_set_label_text(self.pred_label, f"Prediction: {top_name} ({top_prob*100:.2f}%)\nModel device: {device}")

        self.prob_list.clear()
        for name, p in zip(classes, probs):
            self.prob_list.addItem(QListWidgetItem(f"{name}: {p*100:.2f}%"))


# run_app

def run_app():
    """Create and show the main window, adapting to notebook or script use.

    Inside IPython/Jupyter, this enables the Qt event-loop integration with
    ``%gui qt`` and returns ``(window, app)`` so the caller can keep
    references alive. In a plain Python process, it runs ``app.exec_()`` and
    exits the interpreter with its return code (it never returns).
    """
    try:
        from IPython import get_ipython
        ip = get_ipython()
    except Exception:
        ip = None

    if ip is None:
        # Standard script mode: block in the Qt event loop until the window
        # closes. Reuse an existing QApplication, because Qt allows only one.
        app = QApplication.instance()
        if app is None:
            app = QApplication(sys.argv)
        win = MainWindow()
        win.show()
        sys.exit(app.exec_())

    # Notebook mode: let IPython drive the Qt event loop instead of exec_().
    try:
        ip.run_line_magic('gui', 'qt')
    except Exception as e:
        print("Warning: could not enable the IPython Qt event loop (%gui qt); "
              "the window may be unresponsive:", e)
    app = QApplication.instance()
    if app is None:
        app = QApplication([])
    win = MainWindow()
    win.show()
    # Return handles so the notebook keeps references to the window and app.
    return win, app


if __name__ == "__main__":
    # Guarded so that importing this module does not launch the GUI. IPython
    # runs notebook cells in the __main__ namespace, so executing the cell
    # still starts the app and binds win_app.
    win_app = run_app()



Auto-loaded model from outputs\best_brain_tumor_model.pth
Opened image: C:/Users/ACER/OneDrive/Desktop/archive (3)/Testing/meningioma/Te-me_0011.jpg size: (200, 223) mode: RGB
