In [3]:
# trial.py

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import sys
import os
import numpy as np

# Import the model architecture
from model_architecture import get_model  # Ensure model_architecture.py is in the same directory

# Human-readable label for each model output index. predict_disease() looks up
# CLASS_NAMES[index] with the argmax index of the model's softmax output, so the
# ORDER of this list is significant, not just its contents. Its length (10)
# matches the num_classes default of load_model().
# NOTE(review): the order is not alphabetical by label; it presumably mirrors
# the label encoding used when the model was trained -- confirm against the
# training pipeline before reordering, renaming, or adding entries.
CLASS_NAMES = [
    'Eczema',
    'Warts Molluscum and other Viral Infections',
    'Melanoma',
    'Atopic Dermatitis',
    'Basal Cell Carcinoma (BCC)',
    'Melanocytic Nevi (NV)',
    'Benign Keratosis-like Lesions (BKL)',
    'Psoriasis pictures Lichen Planus and related diseases',
    'Seborrheic Keratoses and other Benign Tumors',
    'Tinea Ringworm Candidiasis and other Fungal Infections'
]

# Default checkpoint path used by load_model(). Being relative, it is resolved
# against the process's current working directory, not this file's directory.
MODEL_PATH = 'skin_disease_model.pth'

def load_model(model_path=MODEL_PATH, num_classes=10):
    """
    Load the ML model from a .pth file.

    Builds the architecture with ``get_model`` and restores its weights from
    ``model_path`` onto the CPU. The checkpoint may be either a bare
    state_dict or a dict that wraps one under a ``'state_dict'`` key. A
    leading ``"module."`` prefix on parameter names (present when the model
    was wrapped in ``nn.DataParallel`` / ``DistributedDataParallel`` at save
    time) is removed before loading.

    Args:
        model_path (str): Path to the .pth file containing the state_dict.
        num_classes (int): Number of output classes. Should equal
            ``len(CLASS_NAMES)`` so predicted indices map onto the label list.

    Returns:
        model (nn.Module): Loaded machine learning model in evaluation mode.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: Propagated from ``load_state_dict`` when the checkpoint's
            keys or tensor shapes do not match the architecture.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"The model file {model_path} does not exist.")

    # Initialize the model architecture
    model = get_model(num_classes=num_classes)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source; on PyTorch >= 1.13
    # consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))

    # Some checkpoints wrap the weights in a larger dict (e.g. alongside epoch
    # or optimizer state) under the 'state_dict' key; unwrap it.
    if isinstance(state_dict, dict) and 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    # DataParallel/DDP register the wrapped network as a submodule named
    # "module", so every saved key starts with "module.". Strip only that
    # LEADING prefix -- str.replace would also rewrite keys that merely
    # contain "module." further in. The emptiness check avoids an IndexError.
    prefix = "module."
    if isinstance(state_dict, dict) and state_dict:
        if next(iter(state_dict)).startswith(prefix):
            from collections import OrderedDict
            state_dict = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in state_dict.items()
            )

    # Load state_dict into the model (strict: keys must match exactly)
    model.load_state_dict(state_dict)

    # Set model to evaluation mode (disables dropout, uses running BN stats)
    model.eval()

    return model

def preprocess_image(image_path, target_size=(224, 224)):
    """
    Load and preprocess the image for prediction.

    The image is converted to RGB, resized to exactly ``target_size`` (the
    aspect ratio is not preserved), converted to a float tensor, and
    normalized per channel with the mean/std below. These must match the
    preprocessing the model was trained with.

    Args:
        image_path (str): Path to the input image.
        target_size (tuple): Desired image size as (height, width) -- the
            order torchvision's ``Resize`` expects for a sequence.

    Returns:
        torch.Tensor: Preprocessed image tensor of shape (1, 3, H, W),
        suitable for the model.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the file cannot be decoded or transformed; the
            underlying exception is chained as ``__cause__``.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"The image file {image_path} does not exist.")

    try:
        # Define the transformation pipeline
        preprocess = transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Adjust if different
                                 std=[0.229, 0.224, 0.225])   # Adjust if different
        ])

        # convert() forces the pixel data to load into a new image, so the
        # source file can be closed as soon as the with-block exits.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        image = preprocess(image)
        image = image.unsqueeze(0)  # Add batch dimension
        return image
    except Exception as e:
        # Keep the original traceback attached for debugging.
        raise ValueError(f"Error processing image: {e}") from e

def predict_disease(image_path, model):
    """
    Classify the skin condition shown in an image.

    Args:
        image_path (str): Path to the input image; it is preprocessed with
            ``preprocess_image`` into a single-image batch.
        model (nn.Module): Classifier producing logits for that batch; it is
            assumed to return shape (1, num_classes), since softmax is taken
            over dim 1 and the top result is read with ``.item()``.

    Returns:
        tuple: ``(disease, confidence)`` -- ``disease`` is the matching entry
        of ``CLASS_NAMES`` (or ``"Unknown Disease"`` when the top index falls
        outside that list) and ``confidence`` is the top softmax probability
        as a Python float.

    Raises:
        FileNotFoundError, ValueError: Propagated from ``preprocess_image``.
    """
    batch = preprocess_image(image_path)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)
        top_prob, top_idx = torch.max(probs, 1)
        index = top_idx.item()
        score = top_prob.item()

    # Guard against a model head wider than the label list.
    in_range = 0 <= index < len(CLASS_NAMES)
    disease = CLASS_NAMES[index] if in_range else "Unknown Disease"
    return disease, score

# Command-line entry point: python trial.py <path_to_image>
if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python trial.py <path_to_image>", file=sys.stderr)
        sys.exit(1)

    image_file = sys.argv[1]

    try:
        # Load the model and predict the disease
        model = load_model()
        disease, confidence = predict_disease(image_file, model)
    except Exception as error:
        # Top-level boundary: report on stderr and exit non-zero so shell
        # callers can detect the failure.
        print(f"Error: {error}", file=sys.stderr)
        sys.exit(1)

    print(f"Predicted Disease: {disease} (Confidence: {confidence*100:.2f}%)")


ModuleNotFoundError: No module named 'model_architecture'