# IEEE Access Paper: A Deep Learning Approach Based on Explainable Artificial Intelligence for Skin Lesion Classification
**By: University of Engineering and Technology, Lahore, Pakistan**

# Data Collection

In [2]:
import pandas as pd
import os

dataset_path = "/kaggle/input/all-isic-data-20240629"
# Join a file name *relative* to dataset_path: os.path.join discards every component
# before an absolute one, so an absolute second argument silently ignored dataset_path.
metadata = pd.read_csv(os.path.join(dataset_path, "metadata.csv"), low_memory=False)
image_dir = os.path.join(dataset_path, "images")

# Keep only rows that carry a diagnosis label (the training target).
metadata = metadata.dropna(subset=['diagnosis'])
print(f"Filtered metadata size: {len(metadata)}")
print("Unique diagnosis values:", metadata['diagnosis'].unique())

Filtered metadata size: 53826
Unique diagnosis values: ['nevus' 'melanoma' 'atypical melanocytic proliferation' 'scar'
 'solar lentigo' 'seborrheic keratosis' 'actinic keratosis'
 'basal cell carcinoma' 'squamous cell carcinoma' 'dermatofibroma'
 'vascular lesion' 'lichenoid keratosis' 'lentigo NOS' 'verruca'
 'clear cell acanthoma' 'angiofibroma or fibrous papule' 'angioma'
 'atypical spitz tumor' 'AIMP' 'neurofibroma' 'lentigo simplex'
 'acrochordon' 'angiokeratoma' 'other' 'cafe-au-lait macule'
 'pigmented benign keratosis' 'melanoma metastasis' 'pyogenic granuloma'
 'sebaceous adenoma' 'sebaceous hyperplasia' 'nevus spilus'
 'mucosal melanosis']


In [3]:
# Inspect the diagnosis column: a sample of values, the distinct labels,
# per-label counts (including NaN) and the column dtype.
diagnosis = metadata['diagnosis']
for heading, value in (
    ("Diagnosis column values (first 10):", diagnosis.head(10)),
    ("\nUnique values in diagnosis:", diagnosis.unique()),
    ("\nValue counts:", diagnosis.value_counts(dropna=False)),
    ("\nData type:", diagnosis.dtype),
):
    print(heading)
    print(value)

Diagnosis column values (first 10):
0        nevus
2        nevus
5        nevus
8        nevus
9        nevus
10       nevus
13    melanoma
15       nevus
16       nevus
17       nevus
Name: diagnosis, dtype: object

Unique values in diagnosis:
['nevus' 'melanoma' 'atypical melanocytic proliferation' 'scar'
 'solar lentigo' 'seborrheic keratosis' 'actinic keratosis'
 'basal cell carcinoma' 'squamous cell carcinoma' 'dermatofibroma'
 'vascular lesion' 'lichenoid keratosis' 'lentigo NOS' 'verruca'
 'clear cell acanthoma' 'angiofibroma or fibrous papule' 'angioma'
 'atypical spitz tumor' 'AIMP' 'neurofibroma' 'lentigo simplex'
 'acrochordon' 'angiokeratoma' 'other' 'cafe-au-lait macule'
 'pigmented benign keratosis' 'melanoma metastasis' 'pyogenic granuloma'
 'sebaceous adenoma' 'sebaceous hyperplasia' 'nevus spilus'
 'mucosal melanosis']

Value counts:
diagnosis
nevus                                 32697
melanoma                               7349
basal cell carcinoma                  

# Step-1: Image Preprocessing    
To enhance image quality and focus on the lesion areas, the following preprocessing steps are applied:

1.  **Objective:** Enhance image quality, standardize dimensions, and focus on regions of interest (ROIs).

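    I’ll preprocess images by:
    
    * Applying noise reduction (Gaussian blur).
    * Resizing to 256x256 and then center-cropping to 224x224 (the standard input size for models like ResNet).
    * Scaling pixel values to [0, 1] and then normalizing each channel with the ImageNet mean and standard deviation.
    * Cropping ROIs (simplified to central cropping for now).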
2. **Implementation:**

Since the dataset is large, images are preprocessed on the fly during training to save memory: a PyTorch `Dataset` loads and transforms each image only when it is requested.
We assume `metadata.csv` has the columns `isic_id` (image filename without extension) and `diagnosis` (label).

In [4]:
import pandas as pd
import os
from PIL import Image
import torch
from torchvision import transforms
from sklearn.preprocessing import LabelEncoder

class ISICDataset(torch.utils.data.Dataset):
    """Map-style dataset over ISIC metadata rows, loading ``<image_dir>/<isic_id>.jpg``.

    Each item is ``(image, label)``. ``image`` is the RGB image, passed through ``transform``
    when one is given. ``label`` is the integer code a ``LabelEncoder`` assigned to the row's
    ``diagnosis``. An item whose image cannot be opened or decoded is returned as ``None``.
    The default DataLoader collate function cannot batch ``None``, so pass
    ``collate_fn=ISICDataset.skip_none_collate`` to drop such items instead.
    """

    def __init__(self, metadata, image_dir, transform=None):
        # Positional index 0..n-1 so iloc[idx] and self.labels[idx] refer to the same row.
        self.metadata = metadata.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform
        # LabelEncoder codes are positions in the sorted unique diagnoses, so datasets
        # built from the same metadata share one label mapping.
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(self.metadata['diagnosis'])

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        img_id = self.metadata.iloc[idx]['isic_id']
        img_path = os.path.join(self.image_dir, f"{img_id}.jpg")
        try:
            # `with` closes the file handle. convert() forces the real decode inside the
            # try, so truncated or corrupt files are caught here as well.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except (OSError, ValueError, SyntaxError):
            # Missing or unreadable image (PIL's UnidentifiedImageError is an OSError).
            # Narrower than a bare `except:` so KeyboardInterrupt and genuine bugs surface.
            return None
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

    @staticmethod
    def skip_none_collate(batch):
        """DataLoader ``collate_fn`` that drops ``None`` items before default collation.

        Returns ``None`` when no usable item remains, so the caller can skip that batch.
        """
        samples = [sample for sample in batch if sample is not None]
        if not samples:
            return None
        return torch.utils.data.dataloader.default_collate(samples)

# Preprocessing transforms (also reused as `val_transforms`): resize, center-crop to the
# 224x224 model input, Gaussian denoising, conversion to a [0, 1] tensor, then ImageNet
# mean/std normalization.
preprocess_transforms = transforms.Compose([
    transforms.Resize((256, 256)),  # Ensure images are large enough for the crop
    transforms.CenterCrop(224),  # Crop to 224x224
    # Fixed sigma: with its default sigma=(0.1, 2.0), GaussianBlur draws a new random sigma
    # on every call, which made validation preprocessing non-deterministic. sigma=1.0
    # matches the gaussian_filter(sigma=1) denoising used by preprocess_image later on.
    transforms.GaussianBlur(kernel_size=5, sigma=1.0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Step-2: Data Augmentation
To increase data diversity and prevent overfitting, the following augmentation techniques are applied:

* **Rotation:** Random rotations at various angles.

* **Flipping:** Horizontal and vertical flips.

* **Cropping:** Random crops to simulate zoom.

* **Brightness and Contrast Adjustment:** Randomly altering brightness and contrast levels.

* **Noise Addition:** Introducing random noise to images (planned; not yet implemented in the transforms below).

In [5]:
from torchvision import transforms

# Training transforms with augmentation. The random ops run on the PIL image before
# tensor conversion; normalization uses the ImageNet statistics expected by the
# pretrained backbone.
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),  # Ensure images are large enough for the random crop
    transforms.RandomCrop(224),  # Random 224x224 crop (simulated zoom/shift)
    transforms.RandomRotation(40),  # Random angle in [-40, 40] degrees
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    # Brightness/contrast jitter from the augmentation plan; same strength as the
    # ISIC 2019 pipeline further below.
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    # Default sigma range (0.1, 2.0) is sampled per image, so this doubles as blur augmentation.
    transforms.GaussianBlur(kernel_size=5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# NOTE: the "noise addition" step from the augmentation plan is not implemented here.

# Validation transforms: the preprocessing pipeline only, no augmentation.
val_transforms = preprocess_transforms

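# Step-3: Model Training with ResNet-50
This cell uses transfer learning: an ImageNet-pretrained ResNet-50 is fine-tuned for skin lesion classification. The paper uses ResNet-18, which the self-contained pipeline further below follows. The final fully connected layer is replaced so that it outputs one score per distinct `diagnosis` value in the metadata. That is 32 classes in this dataset snapshot, rather than the ISIC 2019 class set.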

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split
import numpy as np

# Create datasets. Both wrap the full metadata; the index split below decides which rows
# each loader actually draws. Two instances exist only because the transforms differ.
train_dataset = ISICDataset(metadata, image_dir, transform=train_transforms)
val_dataset = ISICDataset(metadata, image_dir, transform=val_transforms)

# Random 80/20 split of row indices (not stratified by diagnosis).
indices = np.arange(len(metadata))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)

train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)


def collate_skip_missing(batch):
    """Collate a batch after dropping samples the dataset returned as ``None``.

    ISICDataset returns ``None`` for images it cannot open. The default collate function
    raises on ``None``, so without this the ``data is None`` checks below could never fire.
    Returns ``None`` when every sample in the batch is missing.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return None
    return torch.utils.data.dataloader.default_collate(samples)


# Create DataLoaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                          num_workers=2, collate_fn=collate_skip_missing)
val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler,
                        num_workers=2, collate_fn=collate_skip_missing)

# Define model: ImageNet-pretrained ResNet-50 with a new head sized to the number of
# distinct diagnoses (the same set the datasets' LabelEncoder encodes).
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
num_classes = len(np.unique(metadata['diagnosis']))
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss and optimizer (all layers are trainable in this cell)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 50
best_val_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    for data in train_loader:
        if data is None:  # every sample in this batch failed to load
            continue
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Guard against an epoch in which no batch could be loaded.
    train_acc = 100 * correct / total if total else 0.0
    mean_loss = running_loss / num_batches if num_batches else 0.0
    print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}, Train Acc: {train_acc:.2f}%")

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            if data is None:
                continue
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc = 100 * correct / total if total else 0.0
    print(f"Validation Acc: {val_acc:.2f}%")
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
    # Early stop once validation accuracy exceeds 95%.
    if val_acc > 95:
        break

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 196MB/s] 


Epoch 1, Loss: 1.1809, Train Acc: 65.09%
Validation Acc: 67.71%
Epoch 2, Loss: 1.0468, Train Acc: 67.88%
Validation Acc: 69.07%
Epoch 3, Loss: 1.0009, Train Acc: 68.70%
Validation Acc: 69.30%
Epoch 4, Loss: 0.9675, Train Acc: 69.29%
Validation Acc: 70.42%


# Model Explainability with LIME
To interpret the model's predictions, LIME (Local Interpretable Model-Agnostic Explanations) is employed. LIME provides visual explanations by highlighting regions in the image that most influenced the model's decision.

In [None]:
from lime import lime_image
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt

def _lime_classifier_fn(model):
    """Wrap ``model`` as a LIME ``classifier_fn`` returning softmax probabilities.

    The returned function takes the channel-last batch LIME passes, ``(N, H, W, C)``. It moves
    the batch to the model's device in ``(N, C, H, W)`` layout, runs the model without
    autograd, and returns the per-class probabilities as a NumPy array.
    """
    first_param = next(iter(model.parameters()), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    def classifier_fn(images):
        batch = torch.as_tensor(images, dtype=torch.float32, device=target_device)
        with torch.no_grad():
            logits = model(batch.permute(0, 3, 1, 2))
        return torch.softmax(logits, dim=1).cpu().numpy()

    return classifier_fn


def explain_prediction(model, image_tensor):
    """Plot the LIME explanation of ``model``'s top predicted class for one image.

    Args:
        model: classifier mapping an ``(N, C, H, W)`` float batch to class logits.
        image_tensor: one ``(C, H, W)`` image tensor, on CPU or GPU.
            NOTE(review): an item from ``val_dataset`` is ImageNet-normalized (values
            outside [0, 1]); LIME's default segmentation may behave differently on such
            inputs — confirm.
    """
    model.eval()
    # detach()/cpu() so GPU tensors or tensors requiring grad convert to NumPy.
    image = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        image,
        # Runs on the model's device and yields probabilities, which LIME ranks labels by.
        classifier_fn=_lime_classifier_fn(model),
        top_labels=1,
        hide_color=0,
        num_samples=1000
    )

    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=True,
        num_features=5,
        hide_rest=False
    )

    # `temp` keeps the input's value range (not 0-255): min-max scale it into [0, 1] for display.
    low, high = temp.min(), temp.max()
    display_img = (temp - low) / (high - low) if high > low else temp - low
    plt.imshow(mark_boundaries(display_img, mask))
    plt.title('LIME Explanation')
    plt.axis('off')
    plt.show()


In [None]:
"""
Skin Lesion Classification with Explainable AI
Based on "A Deep Learning Approach Based on Explainable Artificial Intelligence for Skin Lesion Classification"

This implementation combines the best elements from the IEEE Access paper notebook
with a complete training, evaluation, and explainability pipeline.
"""

# ===== Step 1: Import Required Libraries =====
import os
import numpy as np
import pandas as pd
import random
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from scipy.ndimage import gaussian_filter

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Scikit-learn imports for evaluation
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, precision_score, recall_score, f1_score

# LIME for explainability
import lime
from lime import lime_image
from skimage.segmentation import mark_boundaries

# Seed every RNG the pipeline touches so data splits, augmentation and init are repeatable.
def set_seed(seed=42):
    """Seed NumPy's global RNG, Python's `random`, and PyTorch (CPU and current CUDA device).

    Also requests deterministic cuDNN kernels and disables cuDNN autotuning, trading some
    speed for run-to-run reproducibility.
    """
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

# Check if GPU is available; models and batches below are moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ===== Step 2: Define Constants =====
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 0.001

# Class names for the 8 ISIC 2019 training classes. The order matches the one-hot
# ground-truth columns MEL, NV, BCC, AK, BKL, DF, VASC, SCC (label index = list position).
CLASS_NAMES = [
    'Melanoma', 
    'Nevus',
    'Basal Cell Carcinoma',
    'Actinic Keratosis',
    'Benign Keratosis',
    'Dermatofibroma',
    'Vascular Lesion',
    'Squamous Cell Carcinoma'
]
# Derived so the classifier head size can never drift from the list of names (8 here).
NUM_CLASSES = len(CLASS_NAMES)

# Define paths - update BASE_PATH according to your directory structure. The second
# os.path.join argument must be relative, otherwise BASE_PATH is discarded.
BASE_PATH = "/kaggle/input/all-isic-data-20240629"
IMAGES_PATH = os.path.join(BASE_PATH, "images")
METADATA_PATH = os.path.join(BASE_PATH, "metadata.csv")
# Labels are read from the same metadata.csv; point this at an ISIC 2019 ground-truth CSV
# (image + one-hot class columns) to use that format instead.
GROUND_TRUTH_PATH = METADATA_PATH

# ===== Step 3: Image Preprocessing Functions =====
def preprocess_image(image_path):
    """
    Preprocess an image according to the IEEE Access paper methodology:
    - Center crop (ROI extraction)
    - Resize to IMG_SIZE x IMG_SIZE
    - Gaussian noise reduction (spatial axes only)
    - Normalization to [0, 1]

    The paper's zero-padding step is a no-op here: the resize already yields a square
    IMG_SIZE x IMG_SIZE image, so no padding is ever needed.

    Returns:
        float64 array of shape (IMG_SIZE, IMG_SIZE, 3) with values in [0, 1].
    """
    # Load image; `with` closes the underlying file handle.
    with Image.open(image_path) as img:
        img_np = np.array(img.convert('RGB'))

    # Step 1: Center Crop (ROI Extraction) to the largest centered square
    h, w = img_np.shape[:2]
    side = min(h, w)
    startx = w // 2 - side // 2
    starty = h // 2 - side // 2
    cropped = img_np[starty:starty + side, startx:startx + side]

    # Step 2: Resize to IMG_SIZE x IMG_SIZE (INTER_AREA is suited to downscaling)
    resized = cv2.resize(cropped, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)

    # Step 3: Gaussian Noise Reduction. sigma=(1, 1, 0) blurs height and width only;
    # a scalar sigma would also smooth across the R, G, B channel axis and mix colors.
    denoised = gaussian_filter(resized, sigma=(1, 1, 0))

    # Step 4: Normalize uint8 values to [0, 1]
    return denoised / 255.0

# ===== Step 4: Dataset and DataLoader Classes =====
class SkinLesionDataset(Dataset):
    """Dataset of ``(image, label)`` pairs built from parallel lists of paths and labels.

    With ``preprocess=True`` each image goes through :func:`preprocess_image` and becomes a
    float32 ``[C, H, W]`` tensor. Otherwise the raw RGB PIL image is used. ``transform``
    (e.g. augmentation/normalization), when given, is then applied to whichever was produced.
    """

    def __init__(self, image_paths, labels, transform=None, preprocess=True):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.preprocess = preprocess

    def __len__(self):
        return len(self.image_paths)

    def _load(self, path):
        # Raw PIL image when preprocessing is off; otherwise the paper's pipeline as a CHW tensor.
        if not self.preprocess:
            return Image.open(path).convert('RGB')
        array = preprocess_image(path)
        return torch.from_numpy(array).float().permute(2, 0, 1)

    def __getitem__(self, idx):
        sample = self._load(self.image_paths[idx])
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx]

# ISIC archive `diagnosis` strings mapped onto the 8 ISIC 2019 classes (index = position in
# CLASS_NAMES / the MEL..SCC columns). Diagnoses not listed here are dropped.
# TODO(review): decide whether e.g. 'angioma' or 'pyogenic granuloma' should count as VASC.
_ISIC2019_DIAGNOSIS_TO_INDEX = {
    'melanoma': 0,
    'nevus': 1,
    'basal cell carcinoma': 2,
    'actinic keratosis': 3,
    'seborrheic keratosis': 4,
    'solar lentigo': 4,
    'lichenoid keratosis': 4,
    'pigmented benign keratosis': 4,
    'dermatofibroma': 5,
    'vascular lesion': 6,
    'squamous cell carcinoma': 7,
}


def _load_image_paths_and_labels(csv_path, images_dir):
    """Read (image_paths, integer labels) from a ground-truth or archive-metadata CSV.

    Two layouts are supported:
    * ISIC 2019 ground truth: an ``image`` column plus one-hot MEL..SCC columns;
    * ISIC archive metadata: ``isic_id`` and ``diagnosis`` columns, mapped through
      ``_ISIC2019_DIAGNOSIS_TO_INDEX`` (missing/unmapped diagnoses are dropped).

    Raises:
        ValueError: if no labelled image remains.
    """
    table = pd.read_csv(csv_path, low_memory=False)
    class_columns = ['MEL', 'NV', 'BCC', 'AK', 'BKL', 'DF', 'VASC', 'SCC']
    if 'image' in table.columns and set(class_columns).issubset(table.columns):
        image_ids = table['image']
        # Convert one-hot encoded labels to class indices
        labels = np.argmax(table[class_columns].values, axis=1)
    else:
        known = table[table['diagnosis'].isin(list(_ISIC2019_DIAGNOSIS_TO_INDEX))]
        image_ids = known['isic_id']
        labels = known['diagnosis'].map(_ISIC2019_DIAGNOSIS_TO_INDEX).to_numpy()

    image_paths = [os.path.join(images_dir, f"{img_id}.jpg") for img_id in image_ids]
    if not image_paths:
        raise ValueError(f"No labelled images found in {csv_path}")
    return image_paths, labels


def prepare_data():
    """
    Load the ISIC labels, split them (stratified) into train/val/test, and create DataLoaders.

    Labels come from GROUND_TRUTH_PATH (see `_load_image_paths_and_labels` for the accepted
    CSV layouts). Test = 15% of all data; val = 15% of the remainder.

    Returns:
        (train_loader, val_loader, test_loader, test_paths, test_labels)
    """
    print("Loading and preprocessing data...")

    # A single labelled table: concatenating metadata with itself column-wise (as before)
    # duplicated every column and never provided the one-hot class columns.
    image_paths, labels = _load_image_paths_and_labels(GROUND_TRUTH_PATH, IMAGES_PATH)

    # Split data into train+val and test sets
    train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
        image_paths, labels, test_size=0.15, stratify=labels, random_state=42
    )

    # Further split train+val into train and val sets
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_val_paths, train_val_labels, test_size=0.15, stratify=train_val_labels, random_state=42
    )

    # Print dataset statistics
    print(f"Dataset split: Train={len(train_paths)}, Val={len(val_paths)}, Test={len(test_paths)}")

    # Data augmentation for training; inputs are already preprocessed [0, 1] CHW tensors.
    train_transform = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # No augmentation for validation and test sets
    val_test_transform = transforms.Compose([
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Create datasets
    train_dataset = SkinLesionDataset(train_paths, train_labels, transform=train_transform, preprocess=True)
    val_dataset = SkinLesionDataset(val_paths, val_labels, transform=val_test_transform, preprocess=True)
    test_dataset = SkinLesionDataset(test_paths, test_labels, transform=val_test_transform, preprocess=True)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    return train_loader, val_loader, test_loader, test_paths, test_labels

# ===== Step 5: Model Definition =====
def create_model():
    """
    Create a ResNet-18 model with transfer learning, as specified in the paper.

    ImageNet weights are loaded. Every parameter outside ``layer4`` and the classification
    head is frozen. ``fc`` is replaced by Dropout(0.5) + Linear(in_features -> NUM_CLASSES).

    Returns:
        The model, moved to ``device``.
    """
    print("Creating ResNet-18 model with transfer learning...")

    # `weights=` replaces the deprecated `pretrained=True` flag (IMAGENET1K_V1 is what
    # pretrained=True loaded), matching how the earlier ResNet-50 cell loads its weights.
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    # Fine-tune only the last residual stage and the head; earlier stages keep their
    # ImageNet features. NOTE: frozen BatchNorm layers still update running stats in train mode.
    for name, param in model.named_parameters():
        if 'layer4' not in name and 'fc' not in name:
            param.requires_grad = False

    # Replace the final fully connected layer (new parameters are trainable by default)
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),  # Add dropout for regularization
        nn.Linear(num_ftrs, NUM_CLASSES)
    )

    model = model.to(device)

    return model

# ===== Step 6: Training Functions =====
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over `loader`: trains when `optimizer` is given, otherwise evaluates.

    Returns:
        (mean per-sample loss, accuracy in [0, 1]); (0.0, 0.0) for an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    running_loss, correct, total = 0.0, 0, 0

    # Autograd only while training; evaluation runs without building graphs.
    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # criterion averages over the batch, so re-weight by batch size
            running_loss += loss.item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        return 0.0, 0.0
    return running_loss / total, correct / total


def train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS):
    """
    Train the model, validate after every epoch, and restore the best-validation weights.

    Uses Adam on the trainable parameters and halves the learning rate when validation
    accuracy has not improved for 5 epochs. The best weights are also written to
    'best_model.pth', and loss/accuracy curves are saved via `plot_training_metrics`.

    Returns:
        The model holding the weights of the best validation epoch (or the final weights
        if no epoch improved on 0 accuracy).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)
    # `verbose=` is deprecated in recent PyTorch; LR reductions are logged below instead.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

    # For tracking the best model
    best_val_acc = 0.0
    best_model_wts = None

    # Lists to store metrics for plotting
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    print("Beginning training...")

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Training phase
        epoch_train_loss, epoch_train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        train_losses.append(epoch_train_loss)
        train_accs.append(epoch_train_acc)
        print(f"Training Loss: {epoch_train_loss:.4f}, Training Accuracy: {epoch_train_acc:.4f}")

        # Validation phase
        epoch_val_loss, epoch_val_acc = _run_epoch(model, val_loader, criterion)
        val_losses.append(epoch_val_loss)
        val_accs.append(epoch_val_acc)
        print(f"Validation Loss: {epoch_val_loss:.4f}, Validation Accuracy: {epoch_val_acc:.4f}")

        # Update learning rate scheduler
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(epoch_val_acc)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate to {lr_after:.2e}")

        # Save the best model
        if epoch_val_acc > best_val_acc:
            best_val_acc = epoch_val_acc
            # state_dict() returns references to the live parameter tensors; a shallow
            # dict copy would be overwritten by later optimizer steps, so clone them.
            best_model_wts = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f"New best model saved with accuracy: {best_val_acc:.4f}")

            # Save the model checkpoint
            torch.save(model.state_dict(), 'best_model.pth')

        print("-" * 50)

    # Load best model weights (if any epoch improved)
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)

    # Plot training and validation metrics
    plot_training_metrics(train_losses, val_losses, train_accs, val_accs)

    return model

def plot_training_metrics(train_losses, val_losses, train_accs, val_accs):
    """
    Save side-by-side per-epoch loss and accuracy curves to 'training_metrics.png'.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # (axes, train curve, val curve, train label, val label, y label, title)
    panels = (
        (loss_ax, train_losses, val_losses, 'Train Loss', 'Validation Loss',
         'Loss', 'Training and Validation Loss'),
        (acc_ax, train_accs, val_accs, 'Train Accuracy', 'Validation Accuracy',
         'Accuracy', 'Training and Validation Accuracy'),
    )
    for ax, train_curve, val_curve, train_label, val_label, y_label, title in panels:
        ax.plot(train_curve, label=train_label)
        ax.plot(val_curve, label=val_label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()

    fig.tight_layout()
    fig.savefig('training_metrics.png')
    plt.close(fig)

# ===== Step 7: Model Evaluation =====
def evaluate_model(model, test_loader):
    """
    Evaluate the model on the test set, print detailed metrics, and save a confusion matrix.

    Metrics are computed over all len(CLASS_NAMES) classes even when a class is absent from
    the test labels or predictions. Without `labels=`, classification_report raises on a
    target_names length mismatch, and the confusion matrix would not line up with CLASS_NAMES.

    Returns:
        (accuracy, precision, recall, f1), with support-weighted precision/recall/F1.
    """
    model.eval()
    all_preds = []
    all_labels = []

    print("Evaluating model on test set...")

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader):
            outputs = model(inputs.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    class_ids = list(range(len(CLASS_NAMES)))
    # zero_division=0: a never-predicted class scores 0 precision without a warning.
    metric_kwargs = dict(labels=class_ids, average='weighted', zero_division=0)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, **metric_kwargs)
    recall = recall_score(all_labels, all_preds, **metric_kwargs)
    f1 = f1_score(all_labels, all_preds, **metric_kwargs)

    # Confusion matrix is always len(CLASS_NAMES) x len(CLASS_NAMES)
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    # Print metrics
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1 Score: {f1:.4f}")

    # Print detailed classification report
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, labels=class_ids,
                                target_names=CLASS_NAMES, zero_division=0))

    # Plot confusion matrix
    plot_confusion_matrix(cm, CLASS_NAMES)

    return accuracy, precision, recall, f1

def plot_confusion_matrix(cm, class_names):
    """
    Save an annotated heatmap of confusion matrix `cm` to 'confusion_matrix.png'.

    Rows are true labels and columns are predicted labels; both are ticked with `class_names`.
    Each cell shows its integer count. The count is drawn in white when it is above half the
    matrix maximum (dark background) and in black otherwise.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    heatmap = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.set_title('Confusion Matrix')
    fig.colorbar(heatmap, ax=ax)

    positions = np.arange(len(class_names))
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_yticks(positions)
    ax.set_yticklabels(class_names)

    half_max = cm.max() / 2.
    for row, col in np.ndindex(cm.shape):
        count = cm[row, col]
        ax.text(col, row, format(count, 'd'),
                horizontalalignment="center",
                color="white" if count > half_max else "black")

    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    fig.tight_layout()
    fig.savefig('confusion_matrix.png')
    plt.close(fig)

# ===== Step 8: Explainable AI with LIME =====
def explain_predictions(model, test_paths, test_labels, num_samples=5):
    """
    Use LIME to explain model predictions for randomly chosen test images.

    One figure row is drawn per image:
    * the original image with its true class;
    * the preprocessed model input with the predicted class and probability;
    * the LIME superpixels that most supported the top predicted class.
    The figure is saved to 'lime_explanations.png'.

    Args:
        model: trained classifier, already on ``device``.
        test_paths: image file paths.
        test_labels: integer labels aligned with ``test_paths`` (indices into CLASS_NAMES).
        num_samples: number of images to explain; clamped to ``len(test_paths)``.
    """
    print("Generating LIME explanations for sample predictions...")

    # np.random.choice(..., replace=False) raises when asked for more items than exist.
    num_samples = min(num_samples, len(test_paths))
    if num_samples <= 0:
        print("No test images available; skipping LIME explanations.")
        return

    explainer = lime_image.LimeImageExplainer()
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    model.eval()

    def batch_predict(images):
        """LIME classifier_fn: HxWx3 images from preprocess_image -> softmax probabilities."""
        batch = torch.stack([torch.from_numpy(img).float().permute(2, 0, 1) for img in images])
        # Same normalization the validation/test transforms apply after preprocessing.
        batch = normalize(batch).to(device)
        with torch.no_grad():
            return torch.softmax(model(batch), dim=1).cpu().numpy()

    # Select random samples to explain
    indices = np.random.choice(len(test_paths), num_samples, replace=False)

    # squeeze=False keeps `axes` 2-D even for a single row, so one code path suffices.
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples), squeeze=False)

    for row, idx in zip(axes, indices):
        img_path = test_paths[idx]
        true_class = CLASS_NAMES[test_labels[idx]]

        # Original image (file closed after reading) and the preprocessed model input
        with Image.open(img_path) as img:
            orig_img = np.array(img.convert('RGB'))
        preprocessed_img = preprocess_image(img_path)

        # Model prediction through the same path LIME uses
        probabilities = batch_predict([preprocessed_img])[0]
        pred_idx = int(np.argmax(probabilities))
        pred_class = CLASS_NAMES[pred_idx]
        pred_prob = float(probabilities[pred_idx])

        # Get LIME explanation
        explanation = explainer.explain_instance(
            preprocessed_img,
            batch_predict,
            top_labels=1,
            hide_color=0,
            num_samples=100
        )

        # Get the explanation for the top predicted class
        temp, mask = explanation.get_image_and_mask(
            explanation.top_labels[0],
            positive_only=True,
            num_features=5,
            hide_rest=False
        )

        panels = (
            (orig_img, f"Original Image\nTrue: {true_class}"),
            (preprocessed_img, f"Model Prediction\n{pred_class} ({pred_prob:.2f})"),
            (mark_boundaries(temp, mask), "LIME Explanation\n(Highlighted areas influenced the decision)"),
        )
        for ax, (picture, title) in zip(row, panels):
            ax.imshow(picture)
            ax.set_title(title)
            ax.axis('off')

    fig.tight_layout()
    fig.savefig('lime_explanations.png')
    plt.close(fig)

    print(f"LIME explanations generated for {num_samples} sample images.")

# ===== Step 9: Main Function =====
def main():
    """
    Run the full pipeline: prepare data, build or load the model, evaluate, explain with LIME.
    """
    print("Starting the skin lesion classification process...")

    # Step 1: Prepare data
    train_loader, val_loader, test_loader, test_paths, test_labels = prepare_data()

    # Step 2: Create model
    model = create_model()

    # Step 3: Reuse a saved checkpoint when it fits this architecture; otherwise train.
    # The earlier ResNet-50 cell of this notebook also writes 'best_model.pth', and its
    # state dict cannot be loaded into this ResNet-18 (load_state_dict raises RuntimeError).
    checkpoint_path = 'best_model.pth'
    loaded = False
    if os.path.exists(checkpoint_path):
        print("Loading pre-trained model...")
        try:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            loaded = True
        except RuntimeError as err:
            first_line = str(err).splitlines()[0] if str(err) else type(err).__name__
            print(f"Checkpoint '{checkpoint_path}' does not match this model ({first_line}).")
    if not loaded:
        print("Training model from scratch...")
        model = train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS)

    # Step 4: Evaluate model
    accuracy, precision, recall, f1 = evaluate_model(model, test_loader)

    # Step 5: Generate LIME explanations
    explain_predictions(model, test_paths, test_labels, num_samples=5)

    print("Process completed successfully!")
    print(f"Final metrics - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, "
          f"Recall: {recall:.4f}, F1 Score: {f1:.4f}")

if __name__ == "__main__":
    main()