In [1]:
import os
import torch
import pandas as pd
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import torchvision.transforms as transforms
from tkinter import ttk
import torch.nn as nn
from efficientnet_pytorch import EfficientNet
from efficientnetB7_model import EfficientNetB7Model  # Import my custom model class
from collections import OrderedDict

print("All required libraries are installed and working!")


All required libraries are installed and working!


# .pth module compatible loading

In [2]:
# Function to remove `module.` prefix from keys
def remove_module_prefix(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace("module.", "")  # Remove `module.` prefix
        new_state_dict[new_key] = v
    return new_state_dict

# Define the base directory and model path
BASE_DIR = "/Users/hafeez/Desktop"
MODEL_PATH = os.path.join(
    BASE_DIR, 'Thesis_Hafeez', 'Thesis_Code', 
    'Enhanced-Skin-Lesion-detection-using-Deep-Learning-model', 
    'results', 'output', 'efficientnetb7_trained_model.pth'
)

# Ensure compatibility with CPU-only systems
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the custom model with the correct metadata features count
num_metadata_features = 3  # Number of metadata features: sex, age, and anatomy
model = EfficientNetB7Model(num_metadata_features)
model = model.to(device)

# Load the state_dict from the checkpoint
try:
    # Check if the model path exists
    if not os.path.exists(MODEL_PATH):
        print(f"Error: Model path does not exist at {MODEL_PATH}")
        raise FileNotFoundError(f"{MODEL_PATH} not found.")

    # Attempt to load the checkpoint
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    
    # Clean the state_dict by removing 'module.' prefix if it exists
    if any(key.startswith("module.") for key in checkpoint.keys()):
        print("Detected 'module.' prefix in keys. Cleaning state_dict...")
        checkpoint = remove_module_prefix(checkpoint)

    # Check if checkpoint keys match the model keys
    checkpoint_keys = set(checkpoint.keys())
    model_keys = set(model.state_dict().keys())
    if checkpoint_keys != model_keys:
        print("Warning: Mismatched keys detected between checkpoint and model.")
        print("Mismatched keys:", checkpoint_keys.difference(model_keys))
    
    # Try strict loading with the cleaned state_dict
    try:
        model.load_state_dict(checkpoint, strict=True)
        print("Model loaded successfully with strict=True.")
    except RuntimeError as e:
        print("Strict loading failed. Error:", e)
        print("Attempting to load with strict=False...")
        model.load_state_dict(checkpoint, strict=False)
        print("Model loaded successfully with strict=False. Warning: Verify mismatched keys.")
    
    # Save the cleaned model only if 'module.' was present
    if any(key.startswith("module.") for key in checkpoint.keys()):
        CLEANED_MODEL_PATH = os.path.join(
            BASE_DIR, 'Thesis_Hafeez', 'Thesis_Code', 
            'Enhanced-Skin-Lesion-detection-using-Deep-Learning-model', 
            'results', 'output', 'efficientnetb7_cleaned_model.pth'
        )
        torch.save(model.state_dict(), CLEANED_MODEL_PATH)
        print(f"Cleaned model saved at {CLEANED_MODEL_PATH}")

except FileNotFoundError as fnfe:
    print(f"Checkpoint file not found at {MODEL_PATH}. Error: {fnfe}")
except Exception as ex:
    print(f"An error occurred: {ex}")


Loaded pretrained weights for efficientnet-b7
Detected 'module.' prefix in keys. Cleaning state_dict...
Model loaded successfully with strict=True.


# Define Paths and Load the Model

In [3]:
# Define the base directory and paths
BASE_DIR = "/Users/hafeez/Desktop"
IMAGE_DIR = os.path.join(BASE_DIR, 'Thesis_Hafeez', 'Dataset', 'Predictions', 'ISIC_2020_Test')
PRED_CSV_PATH = os.path.join(BASE_DIR, 'Thesis_Hafeez', 'Dataset', 'Predictions', 'ISIC_2020_Test_Metadata.csv')
CLEANED_MODEL_PATH = os.path.join(BASE_DIR, 'Thesis_Hafeez', 'Thesis_Code', 
                                   'Enhanced-Skin-Lesion-detection-using-Deep-Learning-model', 
                                   'results', 'output', 'efficientnetb7_cleaned_model.pth')


In [4]:
from efficientnetB7_model import EfficientNetB7Model  # Import your custom model

# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize your custom model with the correct number of metadata features
num_metadata_features = 3  # Adjust this based on your use case
model = EfficientNetB7Model(num_metadata_features)

# Ensure compatibility with your device
model = model.to(DEVICE)

# Load the state_dict
checkpoint = torch.load(CLEANED_MODEL_PATH, map_location=DEVICE)

# Load the weights into your model
model.load_state_dict(checkpoint, strict=True)

# Set the model to evaluation mode
model.eval()

print("Custom model loaded and ready for inference.")

Loaded pretrained weights for efficientnet-b7
Custom model loaded and ready for inference.


# Load Metadata and Define Image Transform

In [5]:
# Load metadata CSV
try:
    ground_truth_df = pd.read_csv(PRED_CSV_PATH)
    print("Metadata loaded successfully.")
except FileNotFoundError:
    print(f"Error: Metadata CSV not found at {PRED_CSV_PATH}")
    raise

# Define the transform for input images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
print("Transformations defined for image preprocessing.")


Metadata loaded successfully.
Transformations defined for image preprocessing.


In [6]:
# Define mappings for sex and anatomy categories
sex_mapping = {
    'male': 0,
    'female': 1
}

anatomy_mapping = {
    'head/neck': 0,
    'torso': 1,
    'lower extremity': 2,
    'upper extremity': 3,
    'palms/soles': 4,
    'oral/genital': 5,
    # Add more mappings as needed based on your data
}

# Utility Functions for Prediction and Metadata Handling

In [7]:
# Function to predict image using the model
def predict_image(image_path, metadata):
    try:
        # Load and transform the image
        image = Image.open(image_path).convert('RGB')
        input_tensor = transform(image).unsqueeze(0).to(DEVICE)

        # Prepare metadata tensor
        age, sex, anatomy = metadata
        if age is None or sex is None or anatomy is None:
            raise ValueError(f"Incomplete metadata for image: {os.path.basename(image_path)}")

        # Map metadata to tensor values
        metadata_tensor = torch.tensor(
            [[age, sex_mapping.get(sex, -1), anatomy_mapping.get(anatomy, -1)]]
        ).float().to(DEVICE)

        # Perform prediction
        with torch.no_grad():
            output = model(input_tensor, metadata_tensor).squeeze()
            probability = torch.sigmoid(output).item()
            if not (0.0 <= probability <= 1.0): 
                raise ValueError(f"Invalid probability value: {probability}")
            return probability
    except Exception as e:
        print(f"Error predicting image: {e}")
        return None

# Function to get metadata from CSV
def get_metadata(image_name):
    try:
        # Match the image name with the metadata
        ground_truth_row = ground_truth_df[ground_truth_df['image_name'] == image_name]
        if not ground_truth_row.empty:
            # Extract metadata fields
            age = ground_truth_row.iloc[0]['age_approx']
            sex = ground_truth_row.iloc[0]['sex']
            anatomy = ground_truth_row.iloc[0]['anatom_site_general']

            # Handle missing values or unknown fields
            if pd.isna(age) or pd.isna(sex) or pd.isna(anatomy):
                print(f"Warning: Missing metadata for image: {image_name}")
                return None, None, None

            return age, sex, anatomy
        else:
            print(f"Warning: Image not found in metadata CSV: {image_name}")
            return None, None, None
    except Exception as e:
        print(f"Error accessing metadata: {e}")
        return None, None, None

# GUI Setup for Multi-Image Selection and Prediction

In [8]:
# Function to calculate final prediction based on hybrid approach
def calculate_final_prediction(probabilities):
    """
    Calculates the final prediction based on average probability and high-confidence predictions.
    """
    if not probabilities:
        raise ValueError("Probability list is empty. Cannot calculate final prediction.")
    
    avg_probability = sum(probabilities) / len(probabilities)
    high_confidence_exists = any(prob > 0.9 for prob in probabilities)

    # Final decision logic
    if high_confidence_exists and avg_probability > 0.5:
        return 'malignant'
    elif avg_probability > 0.5:
        return 'malignant'
    else:
        return 'benign'

In [None]:
# old approach
'''''
# GUI for multi-image selection and prediction
def select_images():
    # Ask user to select multiple images
    file_paths = filedialog.askopenfilenames(initialdir=IMAGE_DIR, title="Select Images",
                                             filetypes=[("All files", "*.*")])
    if file_paths:
        try:
            probabilities = []
            result_text = ""

            for file_path in file_paths:
                # Get the image name (with extension)
                image_name = os.path.basename(file_path)

                # Fetch metadata from CSV
                metadata = get_metadata(image_name)

                # Ensure metadata exists
                if None in metadata:
                    result_text += f"\nError: Metadata missing for {image_name}. Skipping."
                    continue

                # Make prediction
                probability = predict_image(file_path, metadata)
                if probability is not None:
                    probabilities.append(probability)
                    age, sex, anatomy = metadata

                    # Map numerical metadata back to categorical values for display
                    sex_str = next((key for key, value in sex_mapping.items() if value == sex_mapping[sex]), "Unknown")
                    anatomy_str = next((key for key, value in anatomy_mapping.items() if value == anatomy_mapping[anatomy]), "Unknown")

                    result_text += (
                        f"\nImage: {image_name} | "
                        f"Probability: {probability:.2f} | "
                        f"Metadata: Age: {age}, Sex: {sex_str}, Anatomy: {anatomy_str}"
                    )

            # Calculate final prediction
            if probabilities:
                final_prediction = calculate_final_prediction(probabilities)
                result_text += f"\n\nFinal Prediction: {final_prediction.capitalize()}"

            # Update the result label
            result_label.config(text=result_text, justify='left', font=('Helvetica', 12), foreground='#003366')
        except Exception as e:
            messagebox.showerror("Error", f"An error occurred: {e}")

# Set up the GUI
root = tk.Tk()
root.title("Skin Lesion Classification Tool")
root.geometry("800x900")
root.configure(bg='#e6f2ff')

# Title label
title_label = ttk.Label(root, text="Skin Lesion Classification Tool", font=('Helvetica', 20, 'bold'), background='#e6f2ff')
title_label.pack(pady=20)

# Image display panel
panel = tk.Label(root, bg='#e6f2ff', borderwidth=2, relief="groove")
panel.pack(pady=20)

# Button to select images
btn = ttk.Button(root, text="Select Images", command=select_images)
btn.pack(pady=10)

# Label to display prediction and ground truth
result_label = ttk.Label(root, text="", wraplength=700)
result_label.pack(pady=20)

# Add university logo
try:
    logo_path = os.path.join(BASE_DIR, 'Thesis_Hafeez', 'Thesis_Code', 'Enhanced-Skin-Lesion-detection-using-Deep-Learning-model', 'results', 'uds.jpg')
    logo_image = Image.open(logo_path).resize((220, 150), Image.LANCZOS)  # Use LANCZOS for resizing
    logo_photo = ImageTk.PhotoImage(logo_image)
    logo_label = tk.Label(root, image=logo_photo, bg='#e6f2ff')
    logo_label.image = logo_photo
    logo_label.pack(side='bottom', pady=20)
except (FileNotFoundError, OSError) as e:
    print(f"Error loading university logo: {e}")


# Run the GUI
root.mainloop()
'''''

In [9]:
# Initialize the main Tkinter window
root = tk.Tk()
root.title("Skin Lesion Classification Tool")
root.geometry("900x1000")
root.configure(bg='#e6f2ff')

# GUI for multi-image selection and prediction
def select_images():
    file_paths = filedialog.askopenfilenames(
        initialdir=IMAGE_DIR,
        title="Select Images",
        filetypes=[("Image files", "*.jpg *.jpeg *.png")]  # Allow common image formats
    )
    if file_paths:
        try:
            # Clear existing widgets in the predictions frame
            for widget in predictions_frame.winfo_children():
                widget.destroy()

            probabilities = []
            images_metadata = []
            grid_entries = []  # To ensure all selected images are displayed, even with missing metadata

            for file_path in file_paths:
                image_name = os.path.basename(file_path)
                metadata = get_metadata(image_name)

                if None in metadata:
                    # Include the image with a missing metadata warning
                    grid_entries.append({
                        'image_name': image_name,
                        'probability': None,
                        'age': None,
                        'sex': None,
                        'anatomy': None,
                        'file_path': file_path,
                        'warning': f"Metadata missing for {image_name}. Skipping prediction."
                    })
                    continue

                probability = predict_image(file_path, metadata)
                if probability is not None:
                    probabilities.append(probability)
                    age, sex, anatomy = metadata

                    sex_str = next((key for key, value in sex_mapping.items() if value == sex_mapping[sex]), "Unknown")
                    anatomy_str = next((key for key, value in anatomy_mapping.items() if value == anatomy_mapping[anatomy]), "Unknown")

                    grid_entries.append({
                        'image_name': image_name,
                        'probability': probability,
                        'age': age,
                        'sex': sex_str,
                        'anatomy': anatomy_str,
                        'file_path': file_path,
                        'warning': None
                    })

            # Display predictions in a grid
            num_columns = 3  # Adjust this for the desired number of columns
            for idx, data in enumerate(grid_entries):
                row = idx // num_columns
                col = idx % num_columns

                frame = tk.Frame(predictions_frame, bg='#e6f2ff', relief='groove', borderwidth=1)
                frame.grid(row=row, column=col, padx=10, pady=10, sticky='nsew')

                # Display image thumbnail
                img = Image.open(data['file_path']).resize((100, 100), Image.LANCZOS)
                img_thumbnail = ImageTk.PhotoImage(img)
                img_label = tk.Label(frame, image=img_thumbnail, bg='#e6f2ff')
                img_label.image = img_thumbnail
                img_label.pack()

                # Display metadata and probability or warning
                if data['warning']:
                    text = f"Image: {data['image_name']}\n{data['warning']}"
                    tk.Label(frame, text=text, font=('Helvetica', 12), bg='#e6f2ff', fg='red', justify='center').pack()
                else:
                    text = (f"Image: {data['image_name']}\n"
                            f"Probability: {data['probability']:.2f}\n"
                            f"Age: {data['age']}, Sex: {data['sex']}, Anatomy: {data['anatomy']}")
                    tk.Label(frame, text=text, font=('Helvetica', 12), bg='#e6f2ff', justify='center').pack()

            # Calculate and display final prediction
            if probabilities:
                final_prediction = calculate_final_prediction(probabilities)
                final_prediction_text = f"Final Prediction: {final_prediction.capitalize()}"
                tk.Label(predictions_frame, text=final_prediction_text,
                         font=('Helvetica', 16, 'bold'), fg='#003366', bg='#e6f2ff').grid(
                    row=(len(grid_entries) // num_columns) + 1, column=0, columnspan=num_columns, pady=10)

        except Exception as e:
            messagebox.showerror("Error", f"An error occurred: {e}")

# Scrollable frame setup
scrollbar_canvas = tk.Canvas(root, bg='#e6f2ff')
scrollbar_frame = tk.Frame(scrollbar_canvas, bg='#e6f2ff')
scrollbar = tk.Scrollbar(root, orient='vertical', command=scrollbar_canvas.yview)
scrollbar_canvas.configure(yscrollcommand=scrollbar.set)

scrollbar_canvas.pack(side='left', fill='both', expand=True)
scrollbar.pack(side='right', fill='y')

scrollbar_window = scrollbar_canvas.create_window((0, 0), window=scrollbar_frame, anchor="nw")

def on_frame_configure(event):
    scrollbar_canvas.configure(scrollregion=scrollbar_canvas.bbox("all"))

scrollbar_frame.bind("<Configure>", on_frame_configure)

# Predictions frame inside the scrollable area
predictions_frame = tk.Frame(scrollbar_frame, bg='#e6f2ff')
predictions_frame.pack(fill='both', expand=True)

# Button to select images
btn = ttk.Button(root, text="Select Images", command=select_images)
btn.pack(pady=10)

# Add university logo
try:
    logo_path = os.path.join(BASE_DIR, 'Thesis_Hafeez', 'Thesis_Code', 'Enhanced-Skin-Lesion-detection-using-Deep-Learning-model', 'results', 'uds.jpg')
    logo_image = Image.open(logo_path).resize((220, 150), Image.LANCZOS)
    logo_photo = ImageTk.PhotoImage(logo_image)
    logo_label = tk.Label(root, image=logo_photo, bg='#e6f2ff')
    logo_label.image = logo_photo
    logo_label.pack(side='bottom', pady=20)
except (FileNotFoundError, OSError) as e:
    print(f"Error loading university logo: {e}")

# Run the GUI
root.mainloop()




: 