# Image Classification using MobileNetV2 and XGBoost

This notebook implements an image classification pipeline using:
1. **Feature Extraction**: MobileNetV2 pre-trained on ImageNet for deep feature extraction
2. **Classification**: XGBoost for efficient multi-class classification
3. **Visualization**: UMAP/t-SNE for embedding space visualization

## Pipeline Overview
1. Load images from dataset directory structure
2. Extract deep features using MobileNetV2 (with caching for efficiency)
3. Train XGBoost classifier on extracted features
4. Evaluate model performance with confusion matrix
5. Visualize feature embeddings in 3D space (optional)

## Import Dependencies

In [None]:
# Standard library imports
import os
import pickle as pkl
from pathlib import Path
import random

# Scientific computing
import numpy as np
from PIL import Image

# Machine learning
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, 
    ConfusionMatrixDisplay,
    classification_report,
    precision_score,
    recall_score,
    f1_score
)
from sklearn.manifold import TSNE

# Deep learning
import torch
from torchvision import models
from torchvision.transforms import v2 as transforms

# XGBoost
import xgboost as xgb

# Utilities
from tqdm import tqdm
from matplotlib import pyplot as plt
import umap

## Helper Functions

In [None]:
def load_dataset(dataset_path, max_per_class=None):
    """
    Load images and labels from a directory where each subdirectory is a class.

    Args:
        dataset_path (str): Root directory; every immediate subdirectory is
            treated as one class and its name is used as the label. Non-directory
            entries at the root are ignored.
        max_per_class (int, optional): Maximum number of images to load per class.
            If None, loads all images. Defaults to None.

    Returns:
        tuple: (images, labels) where images is a list of RGB PIL Images and
        labels is a list of class names of the same length and order.

    Notes:
        Entries whose names start with "." and non-file entries inside a class
        directory are skipped and do NOT count toward ``max_per_class``.
        Class and file order follow ``os.listdir`` (unsorted), except that files
        are shuffled with a fixed seed when a limit below 100000 is given.
    """
    images, labels = [], []

    for class_name in os.listdir(dataset_path):
        class_path = os.path.join(dataset_path, class_name)

        # Only subdirectories represent classes.
        if not Path(class_path).is_dir():
            continue

        fnames = list(os.listdir(class_path))

        # Fixed-seed shuffle so a limited subset is a reproducible random sample
        # rather than the first files in directory order. Limits >= 100000 are
        # treated as "effectively unlimited" and keep directory order.
        if max_per_class and max_per_class < 100000:
            random.Random(0).shuffle(fnames)

        n_loaded = 0
        for image_name in tqdm(fnames, desc=f"Loading {class_name}"):
            # Skip hidden files (e.g. .DS_Store).
            if image_name.startswith("."):
                continue

            # Count only images actually loaded, so skipped entries do not
            # consume slots of the per-class limit.
            if max_per_class is not None and n_loaded >= max_per_class:
                break

            image_path = os.path.join(class_path, image_name)
            # Nested directories would make Image.open fail; skip them.
            if not os.path.isfile(image_path):
                continue

            # The context manager closes the file handle; convert() produces a
            # new in-memory image that stays valid after the file is closed.
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))
            labels.append(class_name)
            n_loaded += 1

    return images, labels


# Feature-extraction models cached per device (keyed by str(device)), so the
# pretrained weights are built/loaded once instead of on every call/batch.
_FEATURE_MODELS = {}


def _get_feature_model(device):
    """Return the cached MobileNetV2 feature extractor for ``device``, building it on first use."""
    key = str(device)
    model = _FEATURE_MODELS.get(key)
    if model is None:
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
        # Replace the final classification layer with Identity so the network
        # returns penultimate-layer features instead of class logits.
        model.classifier[1] = torch.nn.Identity()
        model = model.to(device)
        model.eval()
        _FEATURE_MODELS[key] = model
    return model


def extract_features(image, is_batch=False):
    """
    Extract deep features from images using MobileNetV2 pre-trained on ImageNet.

    The model is created once per device and reused across calls.

    Args:
        image (torch.Tensor): A single preprocessed image tensor without a batch
            dimension or, when ``is_batch`` is True, a batch of such tensors
            stacked along a leading batch dimension.
        is_batch (bool): Whether input is a batch of images. Defaults to False.

    Returns:
        numpy.ndarray: For a single image, a 1-D feature vector; for a batch, a
        2-D array with one row per image (1280 features for MobileNetV2).
    """
    # Add a batch dimension for a single image so the model sees a batch.
    batch_img_tensor = image if is_batch else image.unsqueeze(0)

    # Use the GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _get_feature_model(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        features = model(batch_img_tensor.to(device))

    features_flattened = torch.flatten(features, start_dim=1).cpu().numpy()

    # Drop the artificial batch dimension again for single-image input.
    return features_flattened if is_batch else features_flattened[0]


def get_features_batched(images, adjust_size=True):
    """
    Run a list of images through the feature extractor as a single batch.

    Args:
        images (list): PIL Images, or - when ``adjust_size`` is False - tensors
            that have already been preprocessed to the model's input format.
        adjust_size (bool): Apply the resize/crop/normalize pipeline to each
            image before batching. Defaults to True.

    Returns:
        numpy.ndarray: Feature matrix of shape (n_images, 1280).
    """
    # Resize(256) -> CenterCrop(224) -> float tensor scaled to [0, 1] ->
    # per-channel normalization with the ImageNet mean/std.
    imagenet_pipeline = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    if adjust_size:
        model_inputs = [imagenet_pipeline(picture) for picture in images]
    else:
        model_inputs = images

    return extract_features(torch.stack(model_inputs), is_batch=True)


class XGBWrapper:
    """
    Thin wrapper around the XGBoost Booster API for multi-class classification.

    String (or any hashable) labels are encoded to integers with a LabelEncoder
    at fit time and decoded back to the original labels at predict time.
    """

    def __init__(self, params, epochs):
        """
        Initialize XGBoost wrapper.

        Args:
            params (dict): XGBoost parameters (max_depth, eta, objective, etc.).
                A copy is stored, so the caller's dict is never modified.
            epochs (int): Number of boosting rounds
        """
        # Copy so fit() adding "num_class" does not leak into the caller's dict.
        self.params = dict(params)
        self.epochs = epochs
        self.label_encoder = LabelEncoder()
        self.model = None

    def fit(self, X, y):
        """
        Train XGBoost model on feature matrix X with labels y.

        Args:
            X (numpy.ndarray): Feature matrix
            y (array-like): Class labels (strings)

        Returns:
            XGBWrapper: self, to allow chaining.
        """
        # Encode labels to integers 0..n_classes-1, as required by multi:* objectives.
        encoded = self.label_encoder.fit_transform(y)
        self.params["num_class"] = len(self.label_encoder.classes_)
        dtrain = xgb.DMatrix(X, label=encoded)

        self.model = xgb.train(self.params, dtrain, num_boost_round=self.epochs)
        return self

    def predict(self, X):
        """
        Predict class labels for feature matrix X.

        Args:
            X (numpy.ndarray): Feature matrix

        Returns:
            numpy.ndarray: Predicted class labels (original label values).

        Raises:
            sklearn.exceptions.NotFittedError: If called before ``fit``
                (subclass of AttributeError, as raised previously).
            ValueError: If the booster returns 1-D output for an objective
                other than "multi:softmax".
        """
        if self.model is None:
            from sklearn.exceptions import NotFittedError

            raise NotFittedError("XGBWrapper.predict() called before fit().")

        scores = np.asarray(self.model.predict(xgb.DMatrix(X)))

        if scores.ndim == 2:
            # Probability matrix (e.g. "multi:softprob"): pick the most likely class.
            class_idx = scores.argmax(axis=1)
        elif self.params.get("objective") == "multi:softmax":
            # "multi:softmax" already returns class indices (as floats).
            class_idx = scores.astype(int)
        else:
            raise ValueError(
                f"Unexpected 1-D prediction output for objective "
                f"{self.params.get('objective')!r}; use a multi:* objective."
            )

        return self.label_encoder.inverse_transform(class_idx)

## Dataset Loading and Feature Extraction

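This section loads the dataset and extracts features using MobileNetV2. Features are cached to disk so they are not recomputed in subsequent runs. The cache is validated only by the number of images, so delete the cache file manually if images change while their total count stays the same.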

In [None]:
# Path to dataset root directory (one subdirectory per class); can be
# overridden with the DATASET_PATH environment variable.
dataset_path = os.environ.get("DATASET_PATH", "/Users/klavs/Desktop/code-refractored/3")
max_per_class = 100000  # Maximum samples per class (higher = better if data quality is good)
batch_size = 500  # Batch size for feature extraction (adjust based on GPU memory)
dimred = True  # Whether to perform dimensionality reduction visualization


def _extract_all_features(images, batch_size):
    """Extract features for ``images`` in chunks of ``batch_size`` and stack them row-wise."""
    if not images:
        raise ValueError(f"No images found under {dataset_path!r}; nothing to extract.")
    feats = [
        get_features_batched(images[batch_off : batch_off + batch_size])
        for batch_off in tqdm(range(0, len(images), batch_size), desc="Extracting features")
    ]
    return np.concatenate(feats, axis=0)


def _save_feature_cache(features, cache_path):
    """Pickle ``features`` to ``cache_path`` (overwriting any existing file)."""
    with open(cache_path, "wb") as fp:
        pkl.dump(features, fp)
    print(f"Features saved to cache: {cache_path}")


print("Loading dataset...")
images, labels = load_dataset(dataset_path, max_per_class=max_per_class)
print(f"Loaded {len(images)} images across {len(set(labels))} classes")

# Feature caching saves computation time on subsequent runs
feat_cache_fname = f"{dataset_path}/feat_cache_{max_per_class}_mobilenetv2.pkl"

X = None
if Path(feat_cache_fname).exists():
    print(f"Loading features from cache: {feat_cache_fname}")
    # NOTE: pickle can execute arbitrary code on load; only use caches you created.
    with open(feat_cache_fname, "rb") as fp:
        X = pkl.load(fp)

    # The cache is validated only by row count: it must have one feature
    # vector per loaded image, otherwise it is stale and gets rebuilt.
    if len(X) != len(images):
        print(f"WARNING: Cache size mismatch! Images: {len(images)}, Features: {len(X)}")
        print("Re-extracting features...")
        X = None
    else:
        print(f"Successfully loaded {X.shape[0]} feature vectors of dimension {X.shape[1]}")
else:
    print("No cache found. Extracting features...")

if X is None:
    # First run, or the cache was stale: extract and overwrite the cache file.
    X = _extract_all_features(images, batch_size)
    _save_feature_cache(X, feat_cache_fname)
    print(f"Extracted {X.shape[0]} feature vectors of dimension {X.shape[1]}")

## Model Training and Evaluation

Train an XGBoost classifier on the extracted features and evaluate its performance.

In [None]:
print("Splitting dataset into train/test sets...")
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.05, random_state=44
)
print(f"Training set: {len(X_train)} samples")
print(f"Test set: {len(X_test)} samples")
print("\nTraining XGBoost classifier...")

# XGBoost hyperparameters
params = {
    "max_depth": 3,          # Maximum tree depth (controls model complexity)
    "eta": 0.3,              # Learning rate (step size shrinkage)
    "objective": "multi:softprob",  # Multi-class classification with probability output
}

# Train model
clf = XGBWrapper(params, epochs=100)
clf.fit(X_train, y_train)

# Save the trained wrapper (booster + label encoder) to disk
model_save_path = f"{dataset_path}/xgb_trained_mobilenetv2_max{max_per_class}.pkl"
with open(model_save_path, "wb") as fp:
    pkl.dump(clf, fp)
print(f"Model saved to: {model_save_path}")
print("\nEvaluating model on test set...")
y_pred = clf.predict(X_test)

# Use every label present in either y_test or y_pred. With a 5% test split a
# rare class can be absent from y_test yet still be predicted; passing only the
# y_test classes as target_names would make classification_report raise.
classes = sorted(set(y_test) | set(y_pred))
n_classes = len(classes)

# Overall accuracy. The correct count is computed directly instead of
# int(accuracy * n), which can be off by one after float truncation.
accuracy = accuracy_score(y_test, y_pred)
n_correct = int(accuracy_score(y_test, y_pred, normalize=False))
print(f"\nAccuracy: {accuracy:.4f} ({n_correct} / {len(y_test)} correct)")

# Detailed classification report
print(classification_report(y_test, y_pred, labels=classes))

# Confusion matrix on the test set, over the same label set
ConfusionMatrixDisplay.from_predictions(
    y_test, y_pred, labels=classes, xticks_rotation="vertical"
)
plt.title("Confusion Matrix (test set)")
plt.tight_layout()
plt.show()

## Feature Space Visualization (Optional)

Visualize the high-dimensional feature embeddings in 3D using dimensionality reduction techniques (t-SNE and UMAP).

In [None]:
def _class_color_map(categories):
    """Map each category to a hex color: six fixed colors first, then tab20 colors (cycling)."""
    from matplotlib.colors import to_hex

    base_colors = ["#aa0000", "#00aa00", "#0000aa", "#aaaa00", "#aa00aa", "#00aaaa"]
    cmap = plt.get_cmap("tab20")
    n_extra = max(0, len(categories) - len(base_colors))
    # Extra colors for datasets with more than six classes (repeats after 26).
    palette = base_colors + [to_hex(cmap(i % cmap.N)) for i in range(n_extra)]
    return dict(zip(categories, palette))


def _plot_embedding_3d(coords, point_colors, legend_elements, title, axis_prefix, out_path):
    """Scatter-plot the first three columns of ``coords`` in 3D, save to ``out_path`` and show."""
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], c=point_colors, alpha=0.6)
    ax.set_title(title)
    ax.set_xlabel(f"{axis_prefix} Component 1")
    ax.set_ylabel(f"{axis_prefix} Component 2")
    ax.set_zlabel(f"{axis_prefix} Component 3")
    ax.legend(handles=legend_elements, loc="upper right")

    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.show()


if dimred:
    from matplotlib.patches import Patch

    print("Performing dimensionality reduction for visualization...")
    # Sorted so class -> color assignment is identical across runs.
    unique_categories = sorted(set(labels))

    # Works for any number of classes (no KeyError beyond six).
    color_map = _class_color_map(unique_categories)
    point_colors = [color_map[cat] for cat in labels]
    legend_elements = [Patch(facecolor=color_map[cat], label=cat) for cat in unique_categories]

    print(f"Class colors: {color_map}")

    print("Running t-SNE...")
    tsne = TSNE(n_components=3, random_state=42, verbose=1)
    X_tsne = tsne.fit_transform(X)
    _plot_embedding_3d(
        X_tsne, point_colors, legend_elements,
        "t-SNE Projection of Feature Space", "t-SNE", "tsne_visualization.png",
    )
    print("t-SNE visualization saved as 'tsne_visualization.png'")

    print("Running UMAP...")
    um = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=3, random_state=42, verbose=True)
    X_umap = um.fit_transform(X)
    _plot_embedding_3d(
        X_umap, point_colors, legend_elements,
        "UMAP Projection of Feature Space", "UMAP", "umap_visualization.png",
    )
    print("UMAP visualization saved as 'umap_visualization.png'")

    print("\nVisualization complete!")
else:
    print("Dimensionality reduction visualization skipped (dimred=False)")