In [10]:

import os
import torch
import torch.nn as nn
from PIL import Image, ImageTk
import tkinter as tk
from tkinter import filedialog, messagebox
from torchvision import transforms, datasets
import customtkinter as ctk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk


# Locations default to the original author's machine; they can be overridden
# via environment variables so the script runs elsewhere without edits.
MODEL_PATH = os.environ.get(
    "CNN_MODEL_PATH",
    r"D:\mohamed\deeplearning\project\project\cnn_manual_50.pth",
)
# Folder whose immediate sub-directories define the class names and their
# indices. The class count/order must match what the checkpoint was trained
# with (presumably this same training folder — confirm if it moves).
DATA_ROOT_FOR_CLASSES = os.environ.get(
    "CNN_DATA_ROOT",
    r"D:\mohamed\deeplearning\project\project\ImageFolder101\train",
)
# Square input resolution fed to the network. CNN_Manual_50's classifier is
# hard-wired to a 256 x 6 x 6 feature map (50 -> 25 -> 12 -> 6 after three
# 2x2 max-pools), so this must stay 50 unless the model is changed too.
IMG_SIZE = 50


class CNN_Manual_50(nn.Module):
    """Small VGG-style CNN for 3 x 50 x 50 RGB inputs.

    Four Conv3x3(no bias) -> BatchNorm -> ReLU stages with 64, 128, 256, 256
    output channels. The last three stages are each followed by a 2x2
    max-pool, shrinking a 50x50 input to 6x6 (50 -> 25 -> 12 -> 6) before the
    fully-connected head.

    NOTE: the position of every layer inside ``features`` / ``classifier``
    determines the state_dict keys (``features.0.weight``,
    ``classifier.4.bias``, ...). Checkpoints are loaded with ``strict=True``,
    so layers must not be reordered, inserted or removed.

    Args:
        num_classes: number of output logits (default 2).
    """

    @staticmethod
    def _conv_block(in_ch, out_ch, pool):
        """Return the layer list Conv3x3 -> BN -> ReLU, plus MaxPool2d(2) if ``pool``."""
        block = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        if pool:
            block.append(nn.MaxPool2d(2))
        return block

    def __init__(self, num_classes=2):
        super().__init__()

        # (in_channels, out_channels, followed_by_maxpool)
        stage_specs = (
            (3, 64, False),
            (64, 128, True),
            (128, 256, True),
            (256, 256, True),
        )
        feature_layers = []
        for in_ch, out_ch, pool in stage_specs:
            feature_layers.extend(self._conv_block(in_ch, out_ch, pool))
        feature_layers.append(nn.Dropout(0.25))
        self.features = nn.Sequential(*feature_layers)

        flat_features = 256 * 6 * 6  # 256 channels on a 6x6 grid for 50x50 input
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        return self.classifier(self.features(x))


# Inference preprocessing: resize to a fixed IMG_SIZE x IMG_SIZE square
# (aspect ratio is not preserved), convert to a float tensor, then normalise
# each channel with the ImageNet mean/std values below. This presumably
# mirrors the training pipeline — keep the two in sync.
predict_transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def load_class_mapping(data_root_train):
    """Derive class-index mappings from the sub-folder names of a training dir.

    Uses the same convention as torchvision's ``ImageFolder``: every immediate
    sub-directory of ``data_root_train`` is a class, and classes are numbered
    in sorted name order. Unlike building a full ``ImageFolder``, only the top
    level is listed; the image files inside each class folder are not walked,
    so start-up stays fast for large training sets.

    Args:
        data_root_train: path to the training folder.

    Returns:
        (class_to_idx, idx_to_class, idx_to_readable):
            class_to_idx    -- folder name -> index
            idx_to_class    -- index -> folder name
            idx_to_readable -- index -> display label ("class0" -> "Benign",
                               "class1" -> "Malignant", otherwise the name)

    Raises:
        FileNotFoundError: if the directory does not exist or contains no
            sub-directories.
    """
    with os.scandir(data_root_train) as entries:
        class_names = sorted(entry.name for entry in entries if entry.is_dir())
    if not class_names:
        raise FileNotFoundError(f"Couldn't find any class folder in {data_root_train}.")

    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    idx_to_class = dict(enumerate(class_names))

    readable = {"class0": "Benign", "class1": "Malignant"}
    idx_to_readable = {i: readable.get(name, name) for i, name in idx_to_class.items()}

    return class_to_idx, idx_to_class, idx_to_readable


def load_checkpoint(model_path, device, num_classes):
    """Build a CNN_Manual_50, load weights from ``model_path`` and return it in eval mode.

    The file may hold either a bare state_dict or a dict with the weights
    under the ``"model_state"`` key. Weights are loaded with ``strict=True``,
    so the architecture and ``num_classes`` must match the checkpoint exactly.

    Args:
        model_path: path to the saved checkpoint.
        device: device the tensors are mapped to and the model is moved to.
        num_classes: output size of the final Linear layer.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (see the torch.load ``weights_only`` option).
    checkpoint = torch.load(model_path, map_location=device)

    wrapped = isinstance(checkpoint, dict) and "model_state" in checkpoint
    state_dict = checkpoint["model_state"] if wrapped else checkpoint

    net = CNN_Manual_50(num_classes=num_classes).to(device)
    net.load_state_dict(state_dict, strict=True)
    return net.eval()  # Module.eval() returns the module itself


@torch.no_grad()
def predict(model, image_path, device, idx_to_readable):
    """Classify a single image file.

    The image is converted to RGB, preprocessed with ``predict_transform`` and
    run through ``model`` as a batch of one; ``@torch.no_grad()`` disables
    gradient tracking. The caller should pass a model already in eval mode.

    Args:
        model: classifier returning logits of shape (1, num_classes).
        image_path: path of the image file to classify.
        device: device the input tensor is moved to (should match the model's).
        idx_to_readable: mapping from class index to display label.

    Returns:
        (label, conf): the readable label of the arg-max class and its softmax
        probability as a Python float in [0, 1].
    """
    # Use a context manager so the file handle opened lazily by Image.open is
    # closed deterministically; convert() loads the pixels into a new image first.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")

    batch = predict_transform(rgb).unsqueeze(0).to(device)
    probs = torch.softmax(model(batch), dim=1)[0]
    pred_idx = int(torch.argmax(probs))
    conf = probs[pred_idx].item()  # .item() already yields a Python float

    return idx_to_readable[pred_idx], conf

# Global CustomTkinter look-and-feel, applied at module level (i.e. before
# main_1() constructs the App window).
ctk.set_appearance_mode("dark")   # appearance mode: "light", "dark" or "system"
ctk.set_default_color_theme("blue")  # built-in themes: "blue", "green", "dark-blue"

class App(ctk.CTk):
    """Main window: choose an image, preview it, and classify it with ``predict``.

    Args:
        model: classifier passed straight to ``predict`` (expected in eval mode).
        device: torch device passed through to ``predict``.
        idx_to_readable: mapping from predicted class index to display label.
    """

    def __init__(self, model, device, idx_to_readable):
        super().__init__()

        self.model = model
        self.device = device
        self.idx_to_readable = idx_to_readable

        # Path of the currently previewed image (None until one is chosen).
        self.current_path = None
        # Reference to the preview image, kept on the instance for the label.
        self.img_tk = None

        self.title("Breast Image Classifier")
        self.geometry("700x560")
        self.resizable(False, False)

        ctk.CTkLabel(
            self,
            text="Breast Image Classifier",
            font=ctk.CTkFont(size=22, weight="bold"),
        ).pack(pady=15)

        # Fixed-size preview container; pack_propagate(False) stops it from
        # resizing to fit its child label.
        self.image_frame = ctk.CTkFrame(self, width=420, height=280)
        self.image_frame.pack(pady=10)
        self.image_frame.pack_propagate(False)

        self.img_label = ctk.CTkLabel(
            self.image_frame,
            text="No image selected",
            # (light-mode, dark-mode) colours: a single light grey made the
            # dark-mode default (near-white) text almost invisible.
            fg_color=("#eaeaea", "gray25"),
            corner_radius=10,
            width=400,
            height=260,
        )
        self.img_label.pack(padx=10, pady=10)

        self.btn_frame = ctk.CTkFrame(self, fg_color="transparent")
        self.btn_frame.pack(pady=10)

        ctk.CTkButton(
            self.btn_frame,
            text="Choose Image",
            width=180,
            command=self.choose_image,
        ).grid(row=0, column=0, padx=10)

        ctk.CTkButton(
            self.btn_frame,
            text="Predict",
            width=180,
            command=self.run_predict,
        ).grid(row=0, column=1, padx=10)

        self.result_var = ctk.StringVar(value="Prediction: -")
        ctk.CTkLabel(
            self,
            textvariable=self.result_var,
            font=ctk.CTkFont(size=16, weight="bold"),
        ).pack(pady=20)

    def choose_image(self):
        """Ask the user for an image file, show a thumbnail of it and reset the result."""
        path = filedialog.askopenfilename(
            filetypes=[("Images", "*.png *.jpg *.jpeg")]
        )
        if not path:
            return  # dialog cancelled

        try:
            # Context manager closes the file handle opened by Image.open.
            with Image.open(path) as src:
                img = src.convert("RGB")
        except (OSError, ValueError) as e:  # UnidentifiedImageError is an OSError
            messagebox.showerror("Error", f"Could not open image:\n{e}")
            return

        img.thumbnail((380, 250))  # fit inside the 400x260 preview label, keeping aspect

        # Only remember the path once the image has actually been opened.
        self.current_path = path
        # CTkImage (instead of ImageTk.PhotoImage) lets CustomTkinter scale the
        # preview on HiDPI displays and avoids its PhotoImage warning.
        self.img_tk = ctk.CTkImage(light_image=img, dark_image=img, size=img.size)
        self.img_label.configure(image=self.img_tk, text="")
        self.result_var.set("Prediction: -")

    def run_predict(self):
        """Classify the chosen image and display the label and confidence."""
        if not self.current_path:
            messagebox.showwarning("No image", "Please choose an image first.")
            return

        try:
            label, conf = predict(
                self.model,
                self.current_path,
                self.device,
                self.idx_to_readable,
            )
        except Exception as e:  # GUI boundary: report any failure to the user
            messagebox.showerror("Error", str(e))
            return

        self.result_var.set(f"Prediction: {label}   (conf = {conf:.2f})")


def main_1():
    """Pick a device, load class names and weights, then run the GUI event loop."""
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Using device:", device)

    # The number of class folders fixes the size of the model's output layer.
    class_to_idx, _idx_to_class, idx_to_readable = load_class_mapping(DATA_ROOT_FOR_CLASSES)
    model = load_checkpoint(MODEL_PATH, device, len(class_to_idx))

    window = App(model, device, idx_to_readable)
    window.mainloop()  # blocks until the window is closed


if __name__ == "__main__":
    main_1()


Using device: cpu



