In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.applications.densenet import preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D
from sklearn import svm
from sklearn.model_selection import train_test_split
from lime import lime_image
from skimage.segmentation import mark_boundaries
import os
from PIL import Image
from tensorflow.keras.preprocessing import image
from joblib import dump, load  # For saving and loading models

# Locations of the cached artefacts (trained SVM and the saved train/test split)
model_dir = "model_dir"
svm_model_file, train_data_file, test_data_file = (
    os.path.join(model_dir, filename)
    for filename in ("svm_model.joblib", "train_data.npy", "test_data.npy")
)

# Make sure the cache directory is present before anything is written into it
os.makedirs(model_dir, exist_ok=True)

# Feature extractor: DenseNet121 with ImageNet weights and no classification head,
# followed by global average pooling over the last feature map
base_model = DenseNet121(weights='imagenet', include_top=False)
pooled_output = GlobalAveragePooling2D()(base_model.output)
model = Model(inputs=base_model.input, outputs=pooled_output)

# Load and preprocess images
def load_and_preprocess_img(img_path, target_size=(224, 224)):
    """Read an image file and return it as a preprocessed batch of one.

    Args:
        img_path: Path of the image file to load.
        target_size: Size handed to ``image.load_img`` for resizing;
            defaults to ``(224, 224)``.

    Returns:
        The result of DenseNet's ``preprocess_input`` applied to the image
        array after a leading batch axis has been added.
    """
    pixels = image.img_to_array(image.load_img(img_path, target_size=target_size))
    batch_of_one = pixels[np.newaxis, ...]  # add the leading batch axis
    return preprocess_input(batch_of_one)


# Extract features using DenseNet121 + Global Average Pooling
def extract_features(img_array):
    """Encode a collection of raw images into DenseNet121 feature vectors.

    The images are preprocessed and sent through ``model`` in one ``predict``
    call rather than one call per image. LIME invokes this function for every
    batch of perturbed samples, so the per-call overhead of ``predict`` would
    otherwise dominate the run time.

    Args:
        img_array: Array or sequence of un-normalised images that all share
            the same shape (e.g. the perturbed samples LIME generates).

    Returns:
        A 2-D array holding one flattened feature vector per input image, or
        an empty array when ``img_array`` contains no images.
    """
    # np.array(..., dtype=float32) always copies. preprocess_input may write its
    # result over a float input, so working on a copy leaves the caller's data intact.
    batch = np.array(img_array, dtype="float32")
    if batch.size == 0:
        return np.array([])

    # verbose=0 suppresses the progress output Keras would otherwise print per call
    features = model.predict(preprocess_input(batch), verbose=0)
    return features.reshape(len(batch), -1)

# Function to save data
def save_data(X_train, X_test, y_train, y_test, img_train_paths, img_test_paths):
    """Write the train/test split to ``train_data_file`` as a single dictionary.

    The six arguments are stored under keys equal to their own names. The
    dictionary is passed to ``np.save`` as-is; ``load_data`` reads it back with
    ``allow_pickle=True``.
    """
    payload = {
        'X_train': X_train,
        'X_test': X_test,
        'y_train': y_train,
        'y_test': y_test,
        'img_train_paths': img_train_paths,
        'img_test_paths': img_test_paths,
    }
    np.save(train_data_file, payload)

# Function to load data
def load_data():
    """Read back the split stored in ``train_data_file``.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test, img_train_paths,
        img_test_paths)`` taken from the saved dictionary.
    """
    # The file holds a pickled dict wrapped in a 0-d array; .item() unwraps it.
    # NOTE: allow_pickle=True should only ever be used on files this script wrote.
    stored = np.load(train_data_file, allow_pickle=True).item()
    wanted_keys = ('X_train', 'X_test', 'y_train', 'y_test', 'img_train_paths', 'img_test_paths')
    return tuple(stored[key] for key in wanted_keys)

# Load images
clean_dir = "mix/clean"
dirty_dir = "mix/dirty"

# os.listdir() returns entries in arbitrary order, so they are sorted here:
# otherwise the fixed random_state used for the train/test split would not give
# a reproducible split across runs or machines. Entries that are not regular
# files (e.g. the ".ipynb_checkpoints" folder Jupyter creates) are skipped
# because they cannot be loaded as images.
clean_images = [
    path
    for path in (os.path.join(clean_dir, name) for name in sorted(os.listdir(clean_dir)))
    if os.path.isfile(path)
]
dirty_images = [
    path
    for path in (os.path.join(dirty_dir, name) for name in sorted(os.listdir(dirty_dir)))
    if os.path.isfile(path)
]

# Create labels (0: Clean, 1: Dirty)
clean_labels = np.zeros(len(clean_images))
dirty_labels = np.ones(len(dirty_images))

# Combine images and labels; both sequences keep the same clean-then-dirty order
all_images = clean_images + dirty_images
all_labels = np.concatenate([clean_labels, dirty_labels])

# Reuse the cached classifier and split when both files exist; otherwise build them
cache_is_complete = all(os.path.exists(path) for path in (svm_model_file, train_data_file))
if not cache_is_complete:
    print("Image encoding...")

    # Push every image through the DenseNet121 feature extractor, one at a time
    features = np.array([
        model.predict(load_and_preprocess_img(img_path)).flatten()
        for img_path in all_images
    ])

    print("Training SVM model...")
    # Hold out 30% of the samples; the image paths are split alongside the
    # features and labels so each test sample can be traced back to its file
    split = train_test_split(features, all_labels, all_images, test_size=0.3, random_state=42)
    X_train, X_test, y_train, y_test, img_train_paths, img_test_paths = split

    # Linear SVM; probability=True is required for predict_proba later on
    svm_classifier = svm.SVC(kernel='linear', probability=True)
    svm_classifier.fit(X_train, y_train)

    # Cache the fitted classifier and the split for subsequent runs
    dump(svm_classifier, svm_model_file)
    save_data(X_train, X_test, y_train, y_test, img_train_paths, img_test_paths)
else:
    print("Loading existing SVM model and data...")
    svm_classifier = load(svm_model_file)
    X_train, X_test, y_train, y_test, img_train_paths, img_test_paths = load_data()

# Class probabilities for every sample of the test set
probs = svm_classifier.predict_proba(X_test)

# Column 0 is used as the "clean" probability, column 1 as the "dirty" one
clean_class_probs, dirty_class_probs = probs[:, 0], probs[:, 1]

# Pair each test image of a class with the probability assigned to that same class
clean_test_images = [
    (path, p)
    for path, p, true_label in zip(img_test_paths, clean_class_probs, y_test)
    if true_label == 0
]
dirty_test_images = [
    (path, p)
    for path, p, true_label in zip(img_test_paths, dirty_class_probs, y_test)
    if true_label == 1
]

# Keep the 30 images per class with the highest probability, best first
clean_test_images.sort(key=lambda pair: pair[1], reverse=True)
clean_test_images = clean_test_images[:30]
dirty_test_images.sort(key=lambda pair: pair[1], reverse=True)
dirty_test_images = dirty_test_images[:30]

# Function to predict probabilities for LIME
def svm_predict_proba(img_arrays):
    """Classifier function handed to LIME: raw images in, class probabilities out.

    Args:
        img_arrays: Batch of images, forwarded unchanged to ``extract_features``.

    Returns:
        The ``predict_proba`` output of ``svm_classifier`` for the extracted features.
    """
    return svm_classifier.predict_proba(extract_features(img_arrays))

# Initialize the LIME image explainer with its default settings. This single
# instance is the one display_and_save_lime_images uses for every explanation.
explainer = lime_image.LimeImageExplainer()


# Function to preprocess images for LIME and DenseNet
def preprocess_for_lime(img_path, target_size=(224, 224)):
    """Load an image file as a raw RGB pixel array for LIME.

    No DenseNet normalisation is applied here; the function only opens the
    file, converts it to RGB when needed and resizes it.

    Args:
        img_path: Path of the image file to open.
        target_size: Size passed to PIL's ``resize``; defaults to ``(224, 224)``.

    Returns:
        NumPy array built from the resized RGB image.
    """
    # Pillow opens files lazily. The context manager guarantees the underlying
    # file handle is released as soon as the pixels have been read, instead of
    # leaving that to garbage collection.
    with Image.open(img_path) as source:
        rgb = source if source.mode == 'RGB' else source.convert('RGB')
        return np.array(rgb.resize(target_size))

# Function to get the starting index by checking existing LIME images
def get_start_index(save_file_prefix, folder="."):
    """Return the index the next saved LIME image should use.

    Looks in ``folder`` for files named exactly ``<save_file_prefix>_<N>.png``
    where ``N`` consists of digits only, and returns the largest ``N`` plus one.
    Files that merely share the prefix (e.g. ``<prefix>_summary.png`` or
    ``<prefix>.png``) are ignored instead of causing a ``ValueError``.

    Args:
        save_file_prefix: File name prefix, without the trailing underscore.
        folder: Directory to scan; defaults to the current directory.

    Returns:
        ``max(N) + 1`` over the matching files, or ``1`` when none match.

    Raises:
        FileNotFoundError: If ``folder`` does not exist.
    """
    if not os.path.exists(folder):
        raise FileNotFoundError(f"The directory {folder} does not exist.")

    head = save_file_prefix + "_"
    tail = ".png"

    indices = []
    for file_name in os.listdir(folder):
        if not (file_name.startswith(head) and file_name.endswith(tail)):
            continue
        # Whatever sits between prefix and extension, e.g. lime_dirty_image_12.png -> "12"
        middle = file_name[len(head):-len(tail)]
        # isascii() guards against unicode digits that isdigit() accepts but int() rejects
        if middle.isascii() and middle.isdigit():
            indices.append(int(middle))

    # Continue after the highest existing index; start from 1 when nothing exists yet
    return max(indices) + 1 if indices else 1

# Function to display LIME explanation on the original image with proper alignment
def display_and_save_lime_images(test_images, class_name, save_file_prefix, output_folder="."):
    """Create and save a LIME explanation figure for each test image.

    For every ``(img_path, prob)`` pair a two-panel figure (original image next
    to the image with LIME region boundaries) is written to
    ``<output_folder>/<save_file_prefix>_<i>.png``, where ``i`` is the 1-based
    position in ``test_images``. Positions below the index returned by
    ``get_start_index`` are skipped, so an interrupted run can be resumed.

    Args:
        test_images: Iterable of ``(img_path, prob)`` pairs; ``prob`` is not used here.
        class_name: ``"dirty"`` explains class 1, any other value explains class 0.
        save_file_prefix: Prefix of the output file names.
        output_folder: Directory for the figures; created when missing.
    """
    # Ensure output directory exists
    if not os.path.exists(output_folder):
        print(f"Creating directory: {output_folder}")
        os.makedirs(output_folder, exist_ok=True)

    # Get the starting index based on existing files
    start_index = get_start_index(save_file_prefix, output_folder)

    # Class whose evidence is visualised (0: clean, 1: dirty)
    target_label = 1 if class_name == "dirty" else 0

    for i, (img_path, _prob) in enumerate(test_images, start=1):
        if i < start_index:
            print(f"Skipping image {i} as it already exists...")
            continue  # Skip images that have already been processed

        img_lime = preprocess_for_lime(img_path)

        # Explain the requested class explicitly. With top_labels=1 LIME only
        # explains the model's top prediction, and get_image_and_mask raises
        # KeyError whenever that prediction differs from target_label.
        explanation = explainer.explain_instance(
            img_lime,
            svm_predict_proba,
            labels=(target_label,),
            top_labels=None,
            hide_color=0,
            num_samples=2000,
        )

        # Only the mask is needed; the image LIME returns alongside it is discarded
        _, mask = explanation.get_image_and_mask(
            label=target_label, positive_only=True, num_features=5, hide_rest=False
        )

        # img_lime is exactly the array LIME segmented, so the mask lines up
        # pixel-for-pixel without opening and resizing the file a second time
        overlay = mark_boundaries(img_lime, mask, mode='thick', color=(0, 0, 1))

        # Plot the original image and LIME explanation side-by-side
        fig, axes = plt.subplots(1, 2, figsize=(15, 7))

        axes[0].imshow(img_lime)
        axes[0].axis('off')
        axes[0].set_title(f"Original Image {i}")

        axes[1].imshow(overlay)
        axes[1].axis('off')
        axes[1].set_title(f"LIME Explanation {i}")

        # Save the figure as a high-resolution PNG, then release it
        save_file_name = os.path.join(output_folder, f"{save_file_prefix}_{i}.png")
        fig.tight_layout()
        fig.savefig(save_file_name, dpi=300)
        plt.close(fig)

# Example usage: generate LIME figures for the dirty images first, then the clean ones
for lime_inputs, lime_class, lime_prefix in (
    (dirty_test_images, "dirty", "lime_dirty_image"),
    (clean_test_images, "clean", "lime_clean_image"),
):
    display_and_save_lime_images(
        lime_inputs,
        class_name=lime_class,
        save_file_prefix=lime_prefix,
        output_folder="output_dir2",
    )




Loading existing SVM model and data...
Skipping image 1 as it already exists...
Skipping image 2 as it already exists...
Skipping image 3 as it already exists...


  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]



  0%|          | 0/2000 [00:00<?, ?it/s]

