<a href="https://colab.research.google.com/github/j-cutrone/649-Group-8-Project/blob/main/649_group8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CNN vs ViT Comparison for the Morphological Mosquito Classification Model

### Description:
This notebook implements a comparative deep learning pipeline that classifies mosquito species with five architectures: three convolutional networks (ResNet50, ConvNeXt Tiny, MobileNetV3 Large) and two Vision Transformers (ViT Base and ViT Large, patch size 16). The workflow uses the timm library for transfer learning. ImageNet-pretrained backbones are mostly frozen and fine-tuned with a new classification head on a stratified split of 85% training and 15% validation data. Images from the Colombian and Aedes sources are white-balance corrected (gray-world), and per-channel normalization statistics are computed on the training split. After training, each model is evaluated on the validation set with confusion matrices, multi-class ROC curves, and classification reports across 12 classes. Explainability modules then generate occlusion sensitivity maps, which show the image regions that drive each model's prediction. They also produce t-SNE plots that compare how the architectures cluster image features.

The dataset is aggregated from multiple sources stored in Google Drive under the `VectorCam_images` folder. It covers Anopheles (gambiae, funestus, stephensi, darlingi, nuneztovari, albimanus, coustani), Aedes (aegypti, albopictus), Culex, Mansonia, and non-mosquito images. The sources are:
- a general `all_labels.csv` collection;
- Colombian datasets for An. darlingi, An. nuneztovari, and An. albimanus;
- Colombian Aedes aegypti and Aedes albopictus folders;
- An. stephensi images (plus some An. gambiae) from the University of Notre Dame;
- An. coustani samples from Uganda.

The script consolidates these heterogeneous CSVs and directory structures into a master dataframe. It drops unused labels (original species codes 2 and 4) and organizes the files into a per-class local directory structure for efficient loading.

In [None]:
# --- 1. Load in Packages ---
!pip install -q timm

import os
import shutil
import glob

import pandas as pd
from sklearn.model_selection import train_test_split
from torchvision import transforms, datasets
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, Subset
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torchvision import models
from tqdm.notebook import tqdm
from google.colab import drive
import copy

import numpy as np
import time
from pathlib import Path
import timm

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_curve, auc
import matplotlib.pyplot as plt
import itertools

import torch.nn.functional as F
import math
from scipy.ndimage import zoom

import ipywidgets as widgets
from IPython.display import display, clear_output
import io
import cv2
import gc

import seaborn as sns
from sklearn.manifold import TSNE
from torchvision import models, transforms, datasets

In [None]:
# --- 2. Mount Google Drive ---
BASE_DRIVE_PATH = '/content/gdrive/MyDrive/649_group8/FinalProject/VectorCam_images'

# Mount only if the mount point is absent (re-running the cell must not re-mount).
if not os.path.exists('/content/gdrive'):
    drive.mount('/content/gdrive')
    print("Drive mounted successfully.")
else:
    print("Drive already mounted.")
print(f"Base Drive Path set to: {BASE_DRIVE_PATH}")
if not os.path.exists(BASE_DRIVE_PATH):
    print(f"WARNING: BASE_DRIVE_PATH not found at {BASE_DRIVE_PATH}. Please ensure your Google Drive structure is correct.")
    # Re-checking BASE_DRIVE_PATH itself can never succeed here; list its parent instead so the
    # user can see what the Drive folder actually contains (e.g. a misspelled sub-folder).
    parent_dir = os.path.dirname(BASE_DRIVE_PATH)
    if os.path.exists(parent_dir):
        print(f"Contents of {parent_dir}: {os.listdir(parent_dir)}")
    else:
        print(f"{parent_dir} not found.")

In [None]:
# --- 3. Image Preprocessing Function: White Balance Function ---
def white_balance(img):
    """Apply gray-world white balance and return a new PIL image.

    Each of the first three channels is rescaled so that its mean becomes the
    average of the three channel means. A channel whose mean is not positive is
    left unscaled. Results are clipped to [0, 255] and cast to uint8. Any extra
    channels beyond the first three (e.g. alpha) pass through unchanged.

    Args:
        img: image convertible with ``np.asarray`` to an (H, W, >=3) array
            (callers pass RGB images).

    Returns:
        PIL.Image.Image built from the balanced uint8 array.
    """
    pixels = np.asarray(img).astype(float)

    channel_means = [np.mean(pixels[:, :, c]) for c in range(3)]
    gray_level = sum(channel_means) / 3

    for c, channel_mean in enumerate(channel_means):
        gain = 1 if not channel_mean > 0 else gray_level / channel_mean
        pixels[:, :, c] *= gain

    return Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))

In [None]:
# --- 4. Construct Master DataFrame ---
# Integer labels assigned below (they match CLASS_LABELS in the class-definition section):
#   0 An. funestus, 1 An. gambiae, 2 An. albimanus, 3 An. darlingi, 4 An. nuneztovari,
#   5 An. stephensi, 6 An. coustani, 7 Ae. aegypti, 8 Ae. albopictus, 9 Culex,
#   10 Mansonia, 11 Non-mosquito
# Each record is {'path', 'label', 'wb', 'src'}; 'wb' marks images to white-balance when they
# are copied to the local directory.
RAW_IMAGES_PATH = os.path.join(BASE_DRIVE_PATH, 'raw_images')
IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')
all_data = []
master_df = pd.DataFrame()


def _list_image_files(directory):
    """Return the image files directly inside `directory`, sorted.

    glob returns entries in filesystem order, which is not guaranteed to be stable.
    Sorting keeps master_df's row order reproducible, and with it the seeded
    train/val split.
    """
    candidates = glob.glob(os.path.join(directory, '*.*'))
    return sorted(f for f in candidates if f.lower().endswith(IMAGE_EXTENSIONS))


if not os.path.exists(RAW_IMAGES_PATH):
    print(f"CRITICAL ERROR: RAW_IMAGES_PATH not found at {RAW_IMAGES_PATH}. Please verify your Google Drive setup.")
    print("Proceeding with an empty master_df, which will likely cause subsequent steps to fail or use default values.")
else:

    # A. Original Data: drop species codes 2 and 4, remap the remaining codes to project labels.
    csv_orig = os.path.join(RAW_IMAGES_PATH, 'all_labels.csv')
    if os.path.exists(csv_orig):
        df = pd.read_csv(csv_orig)
        df['species'] = pd.to_numeric(df['species'], errors='coerce')
        df = df[~df['species'].isin([2, 4])].copy()
        mapping = {0: 0, 1: 1, 3: 9, 5: 10, 6: 11}
        df['species_label'] = df['species'].map(mapping)
        df = df.dropna(subset=['species_label'])
        for _, row in df.iterrows():
            path = os.path.join(RAW_IMAGES_PATH, 'images_cropped_padded_original', str(row['Image_name']))
            all_data.append({'path': path, 'label': int(row['species_label']), 'wb': False, 'src': 'orig'})
    else:
        print(f"Warning: Original CSV not found at {csv_orig}")

    # B. Colombia CSV (darlingi / nuneztovari, matched by substring of the species name)
    csv_col = os.path.join(RAW_IMAGES_PATH, 'Colombia', 'An_Darlingi_An_Nuneztovari.csv')
    dir_col = os.path.join(RAW_IMAGES_PATH, 'Colombia', 'Anopheles_Darlingi_Anopheles_Nuneztovari')
    if os.path.exists(csv_col):
        df = pd.read_csv(csv_col)
        for _, row in df.iterrows():
            path = os.path.join(dir_col, str(row['Image_name']))
            sp = str(row['species_name']).lower()
            if 'darlingi' in sp:
                all_data.append({'path': path, 'label': 3, 'wb': True, 'src': 'col_csv'})
            elif 'nuneztovari' in sp:
                all_data.append({'path': path, 'label': 4, 'wb': True, 'src': 'col_csv'})
    else:
        print(f"Warning: Colombia CSV not found at {csv_col}")

    # C. Colombia Folder (Albimanus)
    dir_alb = os.path.join(RAW_IMAGES_PATH, 'Colombia', 'Anopheles_Albimanus')
    if os.path.exists(dir_alb):
        for f in _list_image_files(dir_alb):
            all_data.append({'path': f, 'label': 2, 'wb': True, 'src': 'col_alb'})
    else:
        print(f"Warning: Albimanus directory not found at {dir_alb}")

    # D. Stephensi CSV: column names are discovered by substring; rows can be stephensi or gambiae.
    csv_step = os.path.join(RAW_IMAGES_PATH, 'Stephensi', 'NotreDame_stephensi_VectorCAM_ids.csv')
    dir_step = os.path.join(RAW_IMAGES_PATH, 'Stephensi', 'VectorCam_UND_cropped_padded')
    if os.path.exists(csv_step):
        df = pd.read_csv(csv_step)
        img_col_name = next((c for c in df.columns if 'image' in c.lower()), None)
        sp_col_name = next((c for c in df.columns if 'species' in c.lower() or 'morph' in c.lower()), None)

        if img_col_name and sp_col_name:
            for _, row in df.iterrows():
                path = os.path.join(dir_step, str(row[img_col_name]))
                sp = str(row[sp_col_name]).lower()
                if 'stephensi' in sp:
                    all_data.append({'path': path, 'label': 5, 'wb': False, 'src': 'step'})
                elif 'gambiae' in sp:
                    all_data.append({'path': path, 'label': 1, 'wb': False, 'src': 'step'})
        else:
            print(f"Warning: Stephensi CSV ({csv_step}) missing expected image ({img_col_name}) or species ({sp_col_name}) column. Skipping.")
    else:
        print(f"Warning: Stephensi CSV not found at {csv_step}")

    # E. Aedes Folders
    for sub, lbl in [('colombia_2025_crop_pad_Aedes Aegypti', 7), ('colombia_2025_crop_pad_Aedes Albopictus', 8)]:
        d = os.path.join(RAW_IMAGES_PATH, 'Aedes', sub)
        if os.path.exists(d):
            for f in _list_image_files(d):
                all_data.append({'path': f, 'label': lbl, 'wb': True, 'src': 'aedes'})
        else:
            print(f"Warning: Aedes directory not found at {d}")

    # F. Uganda Folder
    dir_ug = os.path.join(RAW_IMAGES_PATH, 'Uganda', 'Coustani_images')
    if os.path.exists(dir_ug):
        for f in _list_image_files(dir_ug):
            all_data.append({'path': f, 'label': 6, 'wb': False, 'src': 'ug'})
    else:
        print(f"Warning: Uganda directory not found at {dir_ug}")

    master_df = pd.DataFrame(all_data)
    if not master_df.empty:
        # A path listed more than once (e.g. repeated CSV rows) is one image. Keeping several rows
        # for it would let the same file land in both the train and val splits.
        n_rows = len(master_df)
        master_df = master_df.drop_duplicates(subset='path', keep='first').reset_index(drop=True)
        if len(master_df) < n_rows:
            print(f"Dropped {n_rows - len(master_df)} duplicate image path row(s).")

print("\nMaster DataFrame Counts:")
if not master_df.empty:
    print(master_df['label'].value_counts().sort_index())
    print(f"Total images in master_df: {len(master_df)}")
else:
    print("Master DataFrame is empty. Check data paths and loading logic.")

In [None]:
# --- 5. Process and Organize Images ---
# Copy every master_df image to LOCAL_DIR/<label>/<basename>, white-balancing rows with wb=True.
LOCAL_DIR = 'organized_images_local'
if os.path.exists(LOCAL_DIR): shutil.rmtree(LOCAL_DIR)
os.makedirs(LOCAL_DIR, exist_ok=True)

print("Processing images to local directory...")
processed_count = 0
errors = []
kept_index = []        # master_df index labels whose image now exists in LOCAL_DIR
written_paths = set()  # destination files already written in this run

if not master_df.empty:
    for idx, row in tqdm(master_df.iterrows(), total=len(master_df)):
        try:
            if not os.path.exists(row['path']):
                errors.append(f"Source file not found: {row['path']}")
                continue

            dest_dir = os.path.join(LOCAL_DIR, str(row['label']))
            os.makedirs(dest_dir, exist_ok=True)
            fname = os.path.basename(row['path'])
            dest_path = os.path.join(dest_dir, fname)

            # Files are keyed only by (label, basename), here and in the split step. A second
            # source image with the same basename would overwrite the first while both rows
            # pointed at one file, which could then be copied into both train and val.
            if dest_path in written_paths:
                errors.append(f"Duplicate destination skipped: {row['path']} -> {dest_path}")
                continue

            if row['wb']:
                with Image.open(row['path']) as img:
                    img_wb = white_balance(img.convert('RGB'))
                    img_wb.save(dest_path)
            else:
                shutil.copy2(row['path'], dest_path)
            written_paths.add(dest_path)
            kept_index.append(idx)
            processed_count += 1
        except Exception as e:
            errors.append(f"Error processing {row['path']}: {e}")

    # Keep only rows backed by a local file so the split step never references missing images.
    master_df = master_df.loc[kept_index]

print(f"Processed {processed_count} images. Errors: {len(errors)}")
if errors:
    print("Some errors occurred during image processing (first 5):")
    for err in errors[:5]:
        print(f"- {err}")

In [None]:
# --- 6. Data Splitting ---
# Start each run from an empty split directory.
TARGET_ROOT = 'split_dataset'
if os.path.exists(TARGET_ROOT):
    shutil.rmtree(TARGET_ROOT)
os.makedirs(TARGET_ROOT, exist_ok=True)

# Empty placeholders, so the copy step below has frames to inspect even if splitting is skipped.
train_df, val_df = pd.DataFrame(), pd.DataFrame()

if master_df.empty:
    print("Skipping data splitting: master_df is empty.")
else:
    # New split: 85% train, 15% validation, stratified so each label keeps its proportion.
    train_df, val_df = train_test_split(
        master_df,
        test_size=0.15,
        stratify=master_df['label'],
        random_state=42,
    )
    print(f"Split sizes: Train={len(train_df)}, Val={len(val_df)}")

def copy_split(df_to_copy, split_name):
    """Copy one split's images from LOCAL_DIR into TARGET_ROOT/<split_name>/<label>/.

    Each source is looked up as LOCAL_DIR/<label>/<basename of row['path']>, the
    layout written by the image-processing step. Missing sources are reported
    and skipped.

    Args:
        df_to_copy: DataFrame with 'path' and 'label' columns; nothing is copied if it is empty.
        split_name: sub-directory name under TARGET_ROOT (e.g. 'train', 'val').
    """
    if df_to_copy.empty:
        print(f"Skipping copying for {split_name} set: DataFrame is empty.")
        return

    split_dir = os.path.join(TARGET_ROOT, split_name)
    os.makedirs(split_dir, exist_ok=True)
    # iterrows() has no length, so pass total= to give the progress bar a denominator.
    rows = df_to_copy.iterrows()
    for _, row in tqdm(rows, total=len(df_to_copy), desc=f"Populating {split_name}"):
        lbl = str(row['label'])
        fname = os.path.basename(row['path'])
        src = os.path.join(LOCAL_DIR, lbl, fname)
        dst = os.path.join(split_dir, lbl, fname)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        if os.path.exists(src):
            shutil.copy2(src, dst)
        else:
            print(f"Warning: Source file not found for copying to {split_name}: {src}")

# Materialize both splits on disk, training set first.
for split_name, split_frame in (('train', train_df), ('val', val_df)):
    copy_split(split_frame, split_name)
print("Data splitting complete.")

In [None]:
# --- 6. Calculate Normalization Statistics ---
# Per-channel mean/std of the resized training images, used for input normalization.
TRAIN_DIR = 'split_dataset/train'

calc_stats_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()])

if os.path.exists(TRAIN_DIR):
    train_dataset_stats = datasets.ImageFolder(root=TRAIN_DIR, transform=calc_stats_transform)
    loader_stats = DataLoader(
        train_dataset_stats,
        batch_size=64,
        shuffle=False,
        num_workers=2)

    print(f"Calculating mean and std for {len(train_dataset_stats)} images...")

    # Accumulate per-channel sums over every pixel, in float64. Averaging per-batch means
    # instead would over-weight the final batch whenever it holds fewer than 64 images.
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    pixel_count = 0

    for data, _ in tqdm(loader_stats, desc="Calculating Stats"):
        data = data.double()
        channel_sum += data.sum(dim=[0, 2, 3])
        channel_sq_sum += (data ** 2).sum(dim=[0, 2, 3])
        pixel_count += data.shape[0] * data.shape[2] * data.shape[3]

    mean = channel_sum / pixel_count
    # Population std: sqrt(E[x^2] - E[x]^2); clamp guards tiny negative values from rounding.
    std = (channel_sq_sum / pixel_count - mean ** 2).clamp(min=0).sqrt()

    MEAN_NORM_FACTOR = [round(m.item(), 4) for m in mean]
    STD_NORM_FACTOR = [round(s.item(), 4) for s in std]

    print(f"Calculated Mean: {MEAN_NORM_FACTOR}")
    print(f"Calculated Std:  {STD_NORM_FACTOR}")

else:
    print(f"WARNING: Training directory not found at {TRAIN_DIR}. Cannot calculate normalization factors. Setting to ImageNet defaults.")
    MEAN_NORM_FACTOR = [0.485, 0.456, 0.406] # ImageNet defaults
    STD_NORM_FACTOR = [0.229, 0.224, 0.225] # ImageNet defaults

In [None]:
# --- 7. Configuration and Hyperparameters ---
DRIVE_ROOT = "/content/gdrive/MyDrive/"
PROJECT_DIR = os.path.join(DRIVE_ROOT, "649_group8")
CHECKPOINT_DIR = os.path.join(PROJECT_DIR, "checkpoints")

# Reuse a non-zero NUM_CLASSES from an earlier cell; otherwise fall back to the 12 project classes.
if globals().get('NUM_CLASSES', 0) == 0:
    print("WARNING: NUM_CLASSES not set from DataLoader or is 0. Setting to default 12 for model config.")
    NUM_CLASSES = 12
elif NUM_CLASSES != 12:
    print(f"WARNING: NUM_CLASSES is {NUM_CLASSES}, but expected 12. Proceeding with {NUM_CLASSES}.")

# Optimization and early-stopping hyperparameters.
LEARNING_RATE, WEIGHT_DECAY = 1e-4, 1e-4
NUM_EPOCHS, PATIENCE = 100, 7

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Using device: {DEVICE}")
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")
print(f"Confirmed NUM_CLASSES for model config: {NUM_CLASSES}")
print("-" * 50)

In [None]:
# --- 8. Model Definition and Training Functions ---

# --- A. Model Initialization Function ---
def _infer_head_in_features(model):
    """Return the width of the features fed to the classifier, via one dummy forward pass.

    `model.num_features` is not always that width. MobileNetV3, for example, has a
    `conv_head` that produces 1280 features after the backbone. The model must have
    been created with num_classes=0, so its forward output is exactly the classifier
    input. Probing avoids per-architecture special cases.
    """
    input_size = model.default_cfg.get('input_size', (3, 224, 224))
    was_training = model.training
    model.eval()  # eval mode so the probe does not update BatchNorm running statistics
    try:
        with torch.no_grad():
            features = model(torch.zeros(1, *input_size))
    finally:
        model.train(was_training)
    return features.shape[-1]


def create_model(model_name, num_classes):
    """Build a pretrained timm backbone with a new trainable linear head.

    Everything is frozen except the new head and:
      * ResNet models: `layer4` (last residual stage), if present;
      * ConvNeXt models: `stages[3]` (last stage), if present;
      * any other model (ViT, MobileNetV3, ...): nothing beyond the head.

    Args:
        model_name: timm model identifier, e.g. 'resnet50' or 'vit_base_patch16_224'.
        num_classes: number of output logits of the new head.

    Returns:
        (model, trainable_params): the model and the list of parameters with requires_grad=True.
    """
    # num_classes=0 makes timm drop the original classifier (identity), so the forward pass
    # returns the pooled features the classifier receives.
    model = timm.create_model(model_name, pretrained=True, num_classes=0)

    for param in model.parameters():
        param.requires_grad = False

    if 'resnet' in model_name:
        print(f"Unfreezing layer4 for {model_name}...")
        # Unfreeze the last residual block (layer4)
        if hasattr(model, 'layer4'):
            for param in model.layer4.parameters():
                param.requires_grad = True
    elif 'convnext' in model_name:
        print(f"Unfreezing last stage for {model_name}...")
        if hasattr(model, 'stages'):
            for param in model.stages[3].parameters():
                param.requires_grad = True
    elif 'mobilenet' in model_name:
        print("MobilenetV3: no specific layers unfrozen apart from head.")

    n_features = _infer_head_in_features(model)
    if n_features != model.num_features:
        print(f"  (model.num_features for {model_name} is {model.num_features}; "
              f"classifier input width is {n_features}.)")

    # default_cfg['classifier'] names the head attribute, possibly dotted (e.g. 'head.fc').
    classifier_name = model.default_cfg['classifier']
    new_head = nn.Linear(n_features, num_classes)  # fresh parameters require grad by default
    if '.' in classifier_name:
        module_path, attr_name = classifier_name.rsplit('.', 1)
        parent_module = model.get_submodule(module_path)
        setattr(parent_module, attr_name, new_head)
    else:
        setattr(model, classifier_name, new_head)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    return model, trainable_params

# --- B. Training Loop Function ---
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Train `model` for one pass over `loader`.

    Args:
        model: network to optimize (switched to train mode here).
        loader: DataLoader yielding (inputs, labels); `loader.dataset` must support len().
        criterion: loss called as criterion(outputs, labels).
        optimizer: optimizer over the trainable parameters.
        device: device each batch is moved to.

    Returns:
        (epoch_loss, epoch_acc) as Python floats: mean loss per sample (assuming `criterion`
        averages over the batch, as nn.CrossEntropyLoss does by default) and top-1 accuracy
        over the whole dataset.

    Raises:
        ValueError: if `loader.dataset` is empty.
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("train_one_epoch: loader.dataset is empty.")

    model.train()
    running_loss = 0.0
    correct_preds = 0  # Python int, so the result is valid even if no batch is yielded

    for inputs, labels in tqdm(loader, desc="Training", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        correct_preds += (outputs.argmax(dim=1) == labels).sum().item()

    return running_loss / num_samples, correct_preds / num_samples

# --- C. Validation Loop Function ---
def validate_one_epoch(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Args:
        model: network to evaluate (switched to eval mode here).
        loader: DataLoader yielding (inputs, labels); `loader.dataset` must support len().
        criterion: loss called as criterion(outputs, labels).
        device: device each batch is moved to.

    Returns:
        (epoch_loss, epoch_acc) as Python floats: mean loss per sample (assuming `criterion`
        averages over the batch) and top-1 accuracy over the whole dataset.

    Raises:
        ValueError: if `loader.dataset` is empty.
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("validate_one_epoch: loader.dataset is empty.")

    model.eval()
    running_loss = 0.0
    correct_preds = 0  # Python int, so the result is valid even if no batch is yielded

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Validation", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            correct_preds += (outputs.argmax(dim=1) == labels).sum().item()

    return running_loss / num_samples, correct_preds / num_samples

# --- D. Experiment Runner Function ---
def run_experiment(model_name, config, train_loader, val_loader, checkpoint_dir):
    """Fine-tune one timm model with validation-loss early stopping and checkpointing.

    Args:
        model_name: timm identifier passed to create_model.
        config: dict with keys 'device', 'num_classes', 'criterion' (loss class),
            'optimizer' (optimizer class), 'lr', 'weight_decay', 'scheduler' (class taking
            step_size/gamma), 'scheduler_step', 'scheduler_gamma', 'num_epochs', 'patience'.
        train_loader, val_loader: DataLoaders for training and validation.
        checkpoint_dir: directory for '<model_name>_best_checkpoint.pth'.

    Returns:
        (model, best_val_loss, best_val_acc). If any epoch improved the validation loss,
        the model carries the weights from the best such epoch.
    """
    print(f"\n--- Starting Experiment: {model_name} ---")

    device = config['device']
    model, trainable_params = create_model(model_name, config['num_classes'])
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_count = sum(p.numel() for p in trainable_params)
    print(f"Total Parameters: {total_params:,} | Trainable: {trainable_count:,}")

    # Only the unfrozen parameters are handed to the optimizer.
    criterion = config['criterion']()
    optimizer = config['optimizer'](trainable_params, lr=config['lr'], weight_decay=config['weight_decay'])
    scheduler = config['scheduler'](optimizer, step_size=config['scheduler_step'], gamma=config['scheduler_gamma'])

    num_epochs = config['num_epochs']
    patience = config['patience']
    checkpoint_path = os.path.join(checkpoint_dir, f"{model_name}_best_checkpoint.pth")

    best_val_loss, best_val_acc = float('inf'), 0.0
    best_model_state = None
    stale_epochs = 0
    start_time = time.time()

    for epoch_idx in range(1, num_epochs + 1):
        print(f"\nEpoch {epoch_idx}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device)
        scheduler.step()

        print(f"  Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
        print(f"  Valid Loss: {val_loss:.4f} | Acc: {val_acc:.4f}")

        if not val_loss < best_val_loss:
            stale_epochs += 1
            print(f"  No improvement. Patience: {stale_epochs}/{patience}")
        else:
            print(f"  Validation loss improved ({best_val_loss:.4f} -> {val_loss:.4f}). Saving model...")
            best_val_loss, best_val_acc = val_loss, val_acc
            stale_epochs = 0
            # Persist to disk and keep an in-memory copy to restore at the end.
            torch.save(model.state_dict(), checkpoint_path)
            best_model_state = copy.deepcopy(model.state_dict())

        if stale_epochs >= patience:
            print("Early stopping triggered.")
            break

    elapsed = time.time() - start_time
    print(f"\n--- Experiment Finished: {model_name} ---")
    print(f"Total Time: {elapsed:.2f}s | Best Val Loss: {best_val_loss:.4f} | Best Val Acc: {best_val_acc:.4f}")

    if best_model_state:
        model.load_state_dict(best_model_state)

    return model, best_val_loss, best_val_acc

In [None]:
# --- 9. Execute Training Experiments ---

# --- A. Setup Checkpoint Directory ---
CHECKPOINT_DIR = '/content/gdrive/MyDrive/649_group8/checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# --- B. Define Models and Configuration ---
MODELS_TO_RUN = [
    'resnet50',
    'convnext_tiny',
    'vit_base_patch16_224',
    'vit_large_patch16_224',
    'mobilenetv3_large_100']

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters come from section 7 so they are defined in one place.
experiment_config = {
    'num_classes': NUM_CLASSES,
    'device': DEVICE,
    'criterion': nn.CrossEntropyLoss,
    'optimizer': optim.AdamW,
    'lr': LEARNING_RATE,
    'weight_decay': WEIGHT_DECAY,
    'scheduler': StepLR,
    'scheduler_step': 7,
    'scheduler_gamma': 0.1,
    'num_epochs': NUM_EPOCHS,
    'patience': PATIENCE,
}


def _build_split_loader(split_dir, shuffle, batch_size=32):
    """Build a DataLoader over split_dir/<label>/ whose targets are the integer folder labels.

    ImageFolder sorts class folders as strings ('0', '1', '10', '11', '2', ...), so its
    class indices do not equal the species labels. The samples are remapped to int(folder).
    Inputs are normalized with the statistics from the normalization section, or with
    ImageNet values if that section was not run.
    """
    mean = globals().get('MEAN_NORM_FACTOR', [0.485, 0.456, 0.406])
    std = globals().get('STD_NORM_FACTOR', [0.229, 0.224, 0.225])
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)])
    ds = datasets.ImageFolder(root=split_dir, transform=tfm)
    idx_to_label = {idx: int(name) for name, idx in ds.class_to_idx.items()}
    ds.samples = [(p, idx_to_label[t]) for p, t in ds.samples]
    ds.imgs = ds.samples
    ds.targets = [t for _, t in ds.samples]
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=2)


# Use loaders from an earlier cell if present; otherwise build them from the split directories.
# Without this, the loaders are undefined and the check below raises NameError.
train_loader = globals().get('train_loader')
val_loader = globals().get('val_loader')
if train_loader is None or val_loader is None:
    _split_root = globals().get('TARGET_ROOT', 'split_dataset')
    _train_dir = os.path.join(_split_root, 'train')
    _val_dir = os.path.join(_split_root, 'val')
    if os.path.isdir(_train_dir) and os.path.isdir(_val_dir):
        print(f"train_loader/val_loader not defined; building them from {_split_root}/ ...")
        train_loader = _build_split_loader(_train_dir, shuffle=True)
        val_loader = _build_split_loader(_val_dir, shuffle=False)

print(f"Starting training on device: {DEVICE}")
print(f"Checkpoints will be saved to: {CHECKPOINT_DIR}")

# --- C. Run Experiments ---
results = {}

if train_loader is None or val_loader is None:
    print("CRITICAL ERROR: train_loader or val_loader is None. Cannot run experiments. Please check previous data loading steps.")

else:
    for model_name in MODELS_TO_RUN:
        model, best_loss, best_acc = run_experiment(
            model_name=model_name,
            config=experiment_config,
            train_loader=train_loader,
            val_loader=val_loader,
            checkpoint_dir=CHECKPOINT_DIR
        )
        results[model_name] = {'best_val_loss': best_loss, 'best_val_acc': best_acc}

        # Release this model before building the next one (ViT-Large is large); the best
        # weights are already saved to disk by run_experiment.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # --- D. Final Summary ---
    print("\n\n--- All Experiments Complete ---")
    print("Final Results (Best Validation Loss and Accuracy):")
    for model_name, metrics in results.items():
        print(f"  - {model_name}: Loss={metrics['best_val_loss']:.4f}, Acc={metrics['best_val_acc']:.4f}")

In [None]:
# --- 10. Final Class Definition ---
NUM_CLASSES = 12

# Integer label -> species name. The label is the value assigned when building master_df
# and is also the name of each per-class image folder. The counts were noted from an
# earlier run of the pipeline.
CLASS_LABELS = dict(enumerate([
    "Anopheles funestus",      # Label 0: 821 counts
    "Anopheles gambiae",       # Label 1: 1071 counts
    "Anopheles albimanus",     # Label 2: 257 counts
    "Anopheles darlingi",      # Label 3: 393 counts
    "Anopheles nuneztovari",   # Label 4: 648 counts
    "Anopheles stephensi",     # Label 5: 418 counts
    "Anopheles coustani",      # Label 6: 111 counts
    "Aedes aegypti",           # Label 7: 310 counts
    "Aedes albopictus",        # Label 8: 158 counts
    "Culex",                   # Label 9: 576 counts
    "Mansonia",                # Label 10: 583 counts
    "Non-mosquito",            # Label 11: 341 counts
]))
SPECIES_NAMES = [CLASS_LABELS[label] for label in sorted(CLASS_LABELS)]

print(f"Analysis set up for {NUM_CLASSES} classes with defined species names.")

In [None]:
# --- 11. Define F(x): ROC Curves, Precision Curves, Confusion Matrices ---

def plot_confusion_matrix(cm, classes, model_name, normalize=False, title='Confusion Matrix', cmap=plt.cm.Blues):
    """Plot a confusion matrix as an annotated heatmap.

    Args:
        cm: square array; rows are true labels and columns predicted labels.
        classes: tick labels, one per row/column of `cm`.
        model_name: shown on a second title line.
        normalize: if True, each row is divided by its total, so cells are per-true-class
            rates. Rows with no true samples are shown as zeros instead of NaN.
        title: plot title. When `normalize` is True and `title` is left at its default,
            'Normalized Confusion Matrix' is used instead.
        cmap: matplotlib colormap.
    """
    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class absent from the true labels has a zero row total; avoid 0/0 -> NaN.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
        if title == 'Confusion Matrix':
            title = 'Normalized Confusion Matrix'

    plt.figure(figsize=(10, 8))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(f'{title}\n({model_name})')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # Annotate every cell, switching text color on dark cells for readability.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

def plot_multiclass_roc(eval_results, class_names):
    """Plot one-vs-rest ROC curves (plus the micro-average) for each evaluated model.

    Args:
        eval_results: {model_name: {'labels': (n,) int array, 'probs': (n, n_classes) array}}.
        class_names: one name per class index; column i of 'probs' belongs to class i.

    A class is skipped (with a printed note) when the evaluation labels do not contain both
    positive and negative samples for it, because its ROC curve is undefined.
    """
    from sklearn.preprocessing import label_binarize

    class_ids = np.arange(len(class_names))
    for model_name, data in eval_results.items():
        true_labels = data['labels']
        pred_probs = data['probs']

        # Binarize the true labels for one-vs-rest ROC calculation.
        y_true = label_binarize(true_labels, classes=class_ids)
        if y_true.shape[1] == 1:
            # label_binarize collapses a 2-class problem to a single column; expand it.
            y_true = np.hstack([1 - y_true, y_true])

        plt.figure(figsize=(10, 8))

        # Calculate Micro-Average ROC
        fpr_micro, tpr_micro, _ = roc_curve(y_true.ravel(), pred_probs.ravel())
        roc_auc_micro = auc(fpr_micro, tpr_micro)
        plt.plot(fpr_micro, tpr_micro,
                 label=f'Micro-Avg ROC (AUC = {roc_auc_micro:0.2f})',
                 color='deeppink', linestyle=':', linewidth=4)

        # Calculate ROC for each class
        for i, name in enumerate(class_names):
            positives = int(y_true[:, i].sum())
            if positives == 0 or positives == len(y_true):
                print(f"  ({model_name}) Skipping ROC for {name}: needs both positive and negative samples.")
                continue
            fpr, tpr, _ = roc_curve(y_true[:, i], pred_probs[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, label=f'Class {name} (AUC = {roc_auc:0.2f})')

        plt.plot([0, 1], [0, 1], 'k--', linewidth=2)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'ROC Curve Analysis: {model_name}')
        plt.legend(loc="lower right")
        plt.show()

In [None]:
# --- 12. Performance Reporting ---

print("\n--- Generating Analysis and Visualizations ---")

CLASS_NAMES = SPECIES_NAMES
# Pass the full label set to sklearn so reports and matrices always have one row per class,
# even if some class is missing from both a model's true and predicted labels.
CLASS_INDICES = np.arange(len(CLASS_NAMES))
EVAL_RESULTS = {}

print("Evaluating models on the validation set...")
for model_name in MODELS_TO_RUN:
    print(f"\nLoading and evaluating {model_name}...")

    # Check the checkpoint first, so a missing one does not build (and fetch pretrained
    # weights for) a model that would then be discarded.
    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{model_name}_best_checkpoint.pth")
    if not os.path.exists(checkpoint_path):
        print(f"  -> Checkpoint not found for {model_name}. Skipping evaluation.")
        continue

    model, _ = create_model(model_name, NUM_CLASSES)
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()

    all_labels = []
    all_probs = []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc=f"Evaluating {model_name}", leave=False):
            outputs = model(inputs.to(DEVICE))
            probs = torch.softmax(outputs, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    EVAL_RESULTS[model_name] = {
        'labels': np.array(all_labels),
        'probs': np.array(all_probs)}

    del model
    torch.cuda.empty_cache()
    gc.collect()

print("Evaluation complete. Generating reports...")

for model_name, data in EVAL_RESULTS.items():
    true_labels = data['labels']
    pred_probs = data['probs']
    pred_labels = np.argmax(pred_probs, axis=1)

    print(f"\n###########################################")
    print(f"## {model_name.upper()} CLASSIFICATION REPORT ##")
    print(f"###########################################")

    report = classification_report(true_labels, pred_labels, labels=CLASS_INDICES,
                                   target_names=CLASS_NAMES, zero_division=0)
    print(report)

    accuracy = accuracy_score(true_labels, pred_labels)
    print(f"Overall Accuracy: {accuracy:.4f}\n")

    cm = confusion_matrix(true_labels, pred_labels, labels=CLASS_INDICES)
    plot_confusion_matrix(cm, classes=CLASS_NAMES, model_name=model_name, title='Confusion Matrix (Counts)')

    plot_confusion_matrix(cm, classes=CLASS_NAMES, model_name=model_name, normalize=True, title='Normalized Confusion Matrix')

plot_multiclass_roc(EVAL_RESULTS, CLASS_NAMES)

print("\n--- All evaluation plots generated. ---")

In [None]:
# --- 13. Model Explainability via Occlusion Mapping ---

# --- I. Define Global Constants ---

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 12
CHECKPOINT_DIR = '/content/gdrive/MyDrive/649_group8/checkpoints'

# Index i is integer label i (same order as CLASS_LABELS in the class-definition section).
SPECIES_NAMES = [
    'Anopheles funestus', 'Anopheles gambiae', 'Anopheles albimanus',
    'Anopheles darlingi', 'Anopheles nuneztovari', 'Anopheles stephensi',
    'Anopheles coustani', 'Aedes aegypti', 'Aedes albopictus',
    'Culex', 'Mansonia', 'Non-mosquito']

# --- Define Single Input Image ---
# The path follows the organized_images_local/<label>/<filename> layout written earlier.
ANALYSIS_IMAGE = {
    'path': '/content/organized_images_local/1/01_UNY40_20240702090942.jpg',
    'true_idx': 1,  # Anopheles gambiae
    'label': 'Anopheles gambiae'}
TRUE_LABEL = ANALYSIS_IMAGE['label']
TRUE_IDX = ANALYSIS_IMAGE['true_idx']
TARGET_IMAGE_PATH = ANALYSIS_IMAGE['path']


# Models to run
MODELS_TO_RUN = [
    'resnet50',
    'convnext_tiny',
    'vit_large_patch16_224',
    'vit_base_patch16_224',
    'mobilenetv3_large_100']

# --- II. Define Transforms and Load Image ---

# NOTE(review): this uses ImageNet mean/std. Confirm it matches the normalization used by the
# loaders the models were trained with (the notebook also computes dataset-specific
# MEAN_NORM_FACTOR/STD_NORM_FACTOR); a mismatch changes the inputs the models see here.
VAL_TEST_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Un-normalized version of the same image, for display.
plot_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()])

# Load and transform the single image (once)
try:
    img_pil = Image.open(TARGET_IMAGE_PATH).convert('RGB')
except FileNotFoundError:
    print(f"CRITICAL ERROR: Input Image not found at {TARGET_IMAGE_PATH}. Check your path.")
    # exit() is meant for interactive shells; in a notebook kernel it does not cleanly abort
    # just this cell. Re-raising stops the cell and keeps the kernel's state intact.
    raise

img_plot = plot_transform(img_pil).permute(1, 2, 0).cpu().numpy()
img_tensor = VAL_TEST_TRANSFORM(img_pil).unsqueeze(0).to(DEVICE)
print(f"Input image '{os.path.basename(TARGET_IMAGE_PATH)}' (True: {TRUE_LABEL}) loaded.")

# --- III. Prediction Function ---

def get_model_prediction(model, image_tensor, true_idx):
    """Classify one image and compare the top-1 class with `true_idx`.

    Args:
        model: classifier returning (batch, num_classes) logits; switched to eval mode.
        image_tensor: input batch; `.item()` below requires a batch of exactly one image.
        true_idx: ground-truth class index.

    Returns:
        (predicted_idx, predicted_name, predicted_confidence, is_correct). The name comes
        from SPECIES_NAMES; the confidence is the softmax probability of the predicted class.
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor)
        class_probs = F.softmax(logits, dim=1)
        top_idx = torch.argmax(logits, dim=1).item()
        top_name = SPECIES_NAMES[top_idx]
        top_conf = class_probs[0, top_idx].item()
        matches_truth = top_idx == true_idx

    return top_idx, top_name, top_conf, matches_truth

# --- IV. Occlusion Mapping Function (Monitors a specific analysis_class_idx) ---

def generate_occlusion_map(model, original_image_tensor, analysis_class_idx,
                          mask_size=16, stride=8, device=DEVICE,
                          occlusion_value=0.4485, output_size=(224, 224)):
    """Build an occlusion-sensitivity heatmap for one class of one image.

    A square patch of constant value is slid across the image. At every
    position the model is re-run, and the drop in the raw output score of
    ``analysis_class_idx`` relative to the un-occluded image is recorded.
    Negative drops are clipped to zero. The grid is then linearly interpolated
    to ``output_size`` and scaled so its maximum is 1, unless it is all zero.
    This costs one forward pass per occluder position.

    Args:
        model: Classifier whose output is indexed as ``[0, class_idx]``.
        original_image_tensor: Batch shaped (B, C, H, W). Only sample 0 is
            occluded and scored.
        analysis_class_idx: Class whose score is monitored.
        mask_size: Side length, in pixels, of the square occluder.
        stride: Step, in pixels, between occluder positions.
        device: Device the image and its occluded copies are moved to. It
            must match the model's device.
        occlusion_value: Constant written directly into the input tensor. It
            is therefore in the same space as that tensor; for the inputs in
            this file that is the normalized space.
        output_size: (rows, cols) of the returned heatmap.

    Returns:
        2-D numpy array of shape ``output_size`` with values in [0, 1].
    """
    model.eval()
    # Put the input on `device` once, so the baseline and every occluded
    # variant run on the same device.
    image = original_image_tensor.to(device)
    _, C, H, W = image.shape

    with torch.no_grad():
        baseline_score = model(image)[0, analysis_class_idx].item()

    # Number of occluder positions per axis. The last patch is clipped at the
    # border when (size - mask_size) is not a multiple of stride. max(1, ...)
    # keeps one position when the mask is larger than the image.
    rows = max(1, int(math.ceil((H - mask_size) / stride) + 1))
    cols = max(1, int(math.ceil((W - mask_size) / stride) + 1))

    # The occluder is the same at every position, so build it once.
    occlusion_patch = torch.ones(C, mask_size, mask_size).to(device) * occlusion_value

    heatmap = np.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            top, left = i * stride, j * stride
            bottom, right = min(top + mask_size, H), min(left + mask_size, W)

            occluded_image = image.clone()
            occluded_image[0, :, top:bottom, left:right] = occlusion_patch[:, :bottom - top, :right - left]

            with torch.no_grad():
                occluded_score = model(occluded_image)[0, analysis_class_idx].item()

            heatmap[i, j] = baseline_score - occluded_score

    # Keep only positions where occlusion lowered the monitored score.
    heatmap = np.maximum(0, heatmap)

    out_h, out_w = output_size
    heatmap_resized = zoom(heatmap, (out_h / rows, out_w / cols), order=1)

    peak = heatmap_resized.max()
    if peak > 0:
        heatmap_resized /= peak

    return heatmap_resized

# --- V. Execution Loop (1x2 Plot per Model) ---

print("\nStarting Occlusion Mapping and Classification for each model...")

# For each architecture: load its best checkpoint, classify the analysis
# image, and plot the input next to an occlusion map for the predicted class.
for model_name in tqdm(MODELS_TO_RUN, desc="Processing Models"):

    # A. Load the model
    try:
        model, _ = create_model(model_name, NUM_CLASSES)
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{model_name}_best_checkpoint.pth")

        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
        model.to(DEVICE).eval()

    except Exception as e:
        print(f"Error loading model {model_name}: {e}. Skipping.")
        continue

    # B. Get the model's prediction
    predicted_idx, predicted_name, predicted_confidence, is_correct = get_model_prediction(
        model, img_tensor, TRUE_IDX)

    accuracy_color = 'green' if is_correct else 'red'
    accuracy_status = 'CORRECT' if is_correct else 'INCORRECT'

    # C. Generate the Occlusion Map for the model's PREDICTED class
    try:
        heatmap = generate_occlusion_map(
            model=model,
            original_image_tensor=img_tensor,
            analysis_class_idx=predicted_idx)  # Monitor the model's actual prediction

    except Exception as e:
        print(f"Skipping visualization for {model_name} due to error during mapping: {e}")
        continue

    # D. Create the 1x2 Plot
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Title colour signals whether the prediction matches TRUE_IDX.
    fig.suptitle(
        f"Model: {model_name.upper()} | Prediction: {predicted_name} "
        f"| Status: {accuracy_status}",
        fontsize=14,
        color=accuracy_color)

    # --- Plot 1: Input Image ---
    axes[0].imshow(img_plot)
    axes[0].set_title(
        f"Input Image (True: {TRUE_LABEL})",
        fontsize=10)
    axes[0].axis('off')

    # --- Plot 2: Occlusion Map (heatmap overlaid on the input) ---
    axes[1].imshow(img_plot)
    axes[1].imshow(heatmap, cmap='jet', alpha=0.6)
    axes[1].set_title(
        f"Saliency Map for '{predicted_name}'\nConf: {predicted_confidence:.4f}",
        fontsize=10)
    axes[1].axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.9])
    plt.show()

In [None]:
# --- 14. Interactive Explainability Tool Methods ---

class SuppressOutput:
    """Context manager that temporarily points sys.stdout and sys.stderr at os.devnull.

    Only writes that go through the Python-level ``sys.stdout`` and
    ``sys.stderr`` objects are silenced. On exit the original streams are
    restored and the devnull handle is closed, even if the body raised.
    Exceptions from the body are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        # Keep our own reference so __exit__ closes exactly the handle opened
        # here, even if the body rebinds sys.stdout/sys.stderr.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        sys.stderr = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr
        self._devnull.close()
        # A falsy return lets any exception raised in the body propagate.
        return False

# --- I. Model Configuration ---
# Models run by the interactive tool, in processing/display order.
MODELS_CONFIG = [
    {'name': 'resnet50'},
    {'name': 'convnext_tiny'},
    {'name': 'mobilenetv3_large_100'},
    {'name': 'vit_base_patch16_224'},
    {'name': 'vit_large_patch16_224'}]

# Per-model family label and the progress message printed while it loads.
MODEL_ARCHITECTURES = {
    'resnet50': {
        'type': 'Deep CNN',
        'description': "Loading ResNet50..."
    },
    'convnext_tiny': {
        'type': 'Modern CNN',
        'description': "Loading ConvNeXt Tiny..."
    },
    'mobilenetv3_large_100': {
        'type': 'Lightweight CNN',
        'description': "Loading MobileNetV3..."
    },
    'vit_base_patch16_224': {
        'type': 'Vision Transformer (ViT)',
        'description': "Loading ViT Base..."
    },
    'vit_large_patch16_224': {
        'type': 'Vision Transformer (ViT)',
        'description': "Loading ViT Large..."
    }}

# --- II. Parameter Counts ---
# Parameter counts in millions; printed in the progress log and plot titles.
MODEL_PARAMS = {
    'resnet50': 25.6,
    'convnext_tiny': 28.6,
    'mobilenetv3_large_100': 5.5,
    'vit_base_patch16_224': 86.8,
    'vit_large_patch16_224': 307.3,
}

# --- III. Transforms ---
# Same evaluation transform as the single-image cell, redefined so this cell
# can run on its own.
VAL_TEST_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# --- IV. Widget Components ---
uploader = widgets.FileUpload(accept='image/*', multiple=False)
process_btn = widgets.Button(description="Run Comparative Analysis", button_style='success', icon='check')
out_display = widgets.Output()

# --- V. Event Handler ---
def _read_uploaded_file(upload_value):
    """Return ``(filename, raw_bytes)`` for the first file in a FileUpload value.

    ipywidgets 7 stores ``FileUpload.value`` as a dict keyed by filename, where
    each entry holds a ``'content'`` field. ipywidgets 8 stores a tuple of dicts
    with ``'name'`` and ``'content'`` keys. Both layouts are accepted.
    """
    if isinstance(upload_value, dict):
        fname = next(iter(upload_value))
        content = upload_value[fname]['content']
    else:
        first = upload_value[0]
        fname = first['name']
        content = first['content']
    # ipywidgets 8 delivers a memoryview; normalise to bytes.
    return fname, bytes(content)


def _plot_comparison(results, orig_rgb):
    """Display the input image followed by one occlusion-map overlay per model.

    Args:
        results: Non-empty list of per-model dicts with 'name', 'pred',
            'conf', 'map' and 'runtime' keys.
        orig_rgb: The 224x224 RGB input image as a numpy array.
    """
    orig_bgr = cv2.cvtColor(orig_rgb, cv2.COLOR_RGB2BGR)
    num_plots = 1 + len(results)
    fig, axes = plt.subplots(1, num_plots, figsize=(4 * num_plots, 5))

    # The headline prediction is the first successfully processed model's.
    primary_pred = results[0]['pred']
    fig.suptitle(
        f"Comparative Occlusion Maps (Prediction: {primary_pred})",
        fontsize=16
    )

    axes[0].imshow(orig_rgb)
    axes[0].set_title(f"Input Image (224x224)", fontsize=10)
    axes[0].axis('off')

    for ax, result in zip(axes[1:], results):
        heatmap_cv = cv2.resize(result['map'], (224, 224))
        # Clip before scaling so out-of-range values cannot wrap around in uint8.
        heatmap_cv = np.uint8(255 * np.clip(heatmap_cv, 0, 1))
        heatmap_cv = cv2.applyColorMap(heatmap_cv, cv2.COLORMAP_JET)

        overlay = cv2.addWeighted(orig_bgr, 0.6, heatmap_cv, 0.4, 0)
        ax.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
        ax.set_title(
            f"{result['name']}\n"
            f"Pred: {result['pred']} (Conf: {result['conf']:.2%})\n"
            f"Run: {result['runtime']:.2f}s | {MODEL_PARAMS.get(result['name'], 'N/A')}M Params",
            fontsize=9,
            color='black')
        ax.axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.9])
    display(fig)
    # Close so the inline backend does not render the figure a second time.
    plt.close(fig)


def on_process_click(b):
    """Button callback that compares every configured model on the uploaded image.

    For each entry in MODELS_CONFIG it does the following:
      1. Builds the model and loads ``<name>_best_checkpoint.pth`` from
         CHECKPOINT_DIR.
      2. Predicts a class for the uploaded image.
      3. Computes an occlusion map and prints load/inference/map timings.

    It then shows the input next to one heatmap overlay per model that
    completed every step. All messages and figures go to ``out_display``.

    Args:
        b: The clicked button, supplied by ipywidgets (unused).
    """
    out_display.clear_output()

    if not uploader.value:
        with out_display:
            print("Please upload an image first!")
        return

    try:
        fname, content = _read_uploaded_file(uploader.value)
        pil_img = Image.open(io.BytesIO(content)).convert('RGB')
        pil_img = pil_img.resize((512, 512))
    except Exception as e:
        with out_display:
            print(f"Error reading or resizing image: {e}")
        return

    input_tensor = VAL_TEST_TRANSFORM(pil_img).unsqueeze(0).to(DEVICE)
    orig_rgb = np.array(pil_img.resize((224, 224)))

    with out_display:
        print(f"Analyzing: {fname}")
        print("---")

    results = []
    model_count = len(MODELS_CONFIG)
    for i, config in enumerate(tqdm(MODELS_CONFIG, desc="Overall Analysis Progress")):
        m_name = config['name']
        ckpt_path = os.path.join(CHECKPOINT_DIR, f"{m_name}_best_checkpoint.pth")
        params = MODEL_PARAMS.get(m_name, 'N/A')

        with out_display:
            arch_info = MODEL_ARCHITECTURES.get(m_name, {'description': f"Loading {m_name}. No detailed architecture description available."})
            print(f"\n[STEP {i+1}/{model_count}] {arch_info['description']} (Parameters: {params}M)")

        total_start_time = time.time()
        model = None
        result = None  # stays None unless every step for this model succeeds

        try:
            with tqdm(total=3, desc=f"  {m_name}", leave=False) as pbar:
                # Step 1: build the architecture (silencing its console output) and load weights.
                with SuppressOutput():
                    model, _ = create_model(m_name, NUM_CLASSES)

                load_start = time.time()
                model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
                model.to(DEVICE).eval()
                load_time = time.time() - load_start
                pbar.update(1)

                # Step 2: inference.
                inference_start = time.time()
                with torch.no_grad():
                    probs = torch.softmax(model(input_tensor), dim=1)
                    top_prob, top_idx = torch.max(probs, 1)
                inference_time = time.time() - inference_start

                pred_idx = top_idx.item()
                pred_name = SPECIES_NAMES[pred_idx] if 'SPECIES_NAMES' in globals() and 0 <= pred_idx < len(SPECIES_NAMES) else f"Idx {pred_idx}"
                pbar.update(1)

                # Step 3: occlusion map.
                # NOTE(review): generate_occlusion_map_single is not defined in this
                # part of the file. Only generate_occlusion_map is, and it returns a
                # single array rather than a tuple. Confirm it exists elsewhere;
                # otherwise every model lands in the except branch below.
                cam_start = time.time()
                cam_map, _ = generate_occlusion_map_single(model, input_tensor.clone(), device=DEVICE)
                cam_time = time.time() - cam_start

                result = {
                    'name': m_name,
                    'pred': pred_name,
                    'conf': top_prob.item(),
                    'map': cam_map,
                    'runtime': time.time() - total_start_time,
                    'load_time': load_time,
                    'inference_time': inference_time,
                    'cam_time': cam_time,
                }
                pbar.update(1)

        except Exception as e:
            with out_display:
                if os.path.exists(ckpt_path):
                    print(f"  -> Critical Error during inference/loading: {e}")
                else:
                    print(f"  -> Checkpoint not found. Skipping.")
        finally:
            # Release this model before the next one is built, even after a failure.
            del model
            torch.cuda.empty_cache()
            gc.collect()

        # --- Print Runtime Analysis (only for models that completed every step) ---
        if result is not None:
            results.append(result)
            with out_display:
                print(f"  -> Runtime Analysis (Total: {result['runtime']:.3f}s)")
                print(f"    - Model Loading: {result['load_time']:.3f}s")
                print(f"    - Inference:     {result['inference_time']:.3f}s")
                print(f"    - Occlusion Map: {result['cam_time']:.3f}s")
    # --- End Model Processing ---

    # --- VI. Plotting: input image plus one overlay per successful model ---
    if results:
        with out_display:
            _plot_comparison(results, orig_rgb)

    with out_display:
        print("=" * 60)

process_btn.on_click(on_process_click)

# --- VII. Display UI ---
# Layout, top to bottom:
#   - the title;
#   - an "Upload Image:" label next to the file picker;
#   - the run button;
#   - the output area that on_process_click writes into.
display(
    widgets.VBox(
        children=[
            widgets.HTML(value="<h2 style='font-size: 24pt;'>Mosquito Species Classifier</h2>"),
            widgets.HBox(
                children=[
                    widgets.Label(
                        value="Upload Image:",
                        style={'description_width': 'initial', 'font_size': '14pt'},
                    ),
                    uploader,
                ]
            ),
            process_btn,
            out_display,
        ]
    )
)

In [None]:
# --- 15. TSNE plots ---
print("\nStarting t-SNE Analysis (Combined + Individual Mode)...")

# Checkpoint and plot locations for this cell (overrides any earlier CHECKPOINT_DIR).
CHECKPOINT_DIR = '/content/gdrive/MyDrive/649_group8/checkpoints'
OUTPUT_DIR = '/content/gdrive/MyDrive/649_group8/FinalProject/tsne_plots_new'
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Saving all plots to: {OUTPUT_DIR}")

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Display name -> checkpoint filename and timm architecture name.
# Every checkpoint follows the "<timm_name>_best_checkpoint.pth" naming scheme.
model_configs = {
    display_name: {'file': f'{timm_name}_best_checkpoint.pth', 'timm_name': timm_name}
    for display_name, timm_name in (
        ('MobileNetV3', 'mobilenetv3_large_100'),
        ('ViT-Large', 'vit_large_patch16_224'),
        ('ResNet-50', 'resnet50'),
        ('ViT-Base', 'vit_base_patch16_224'),
        ('ConvNeXt-Tiny', 'convnext_tiny'),
    )
}

def prepare_timm_model(config, num_classes, checkpoint_dir, device):
    """Build a timm model, load its fine-tuned weights and strip the classifier head.

    Args:
        config: Dict with 'timm_name' (architecture passed to
            ``timm.create_model``) and 'file' (checkpoint filename).
        num_classes: Head size the checkpoint was trained with. It is needed so
            the strict state-dict load matches before the head is removed.
        checkpoint_dir: Directory containing the checkpoint file.
        device: Device the model and checkpoint tensors are placed on.

    Returns:
        The model in eval mode after ``reset_classifier(0)``, or None if the
        checkpoint is missing or any step fails. The reason is printed.
    """
    try:
        ckpt_path = os.path.join(checkpoint_dir, config['file'])
        # Check the checkpoint first, so a missing file doesn't cost building
        # (and allocating on `device`) a potentially very large model.
        if not os.path.exists(ckpt_path):
            print(f"  [Error] Checkpoint not found: {ckpt_path}")
            return None

        model = timm.create_model(config['timm_name'], pretrained=False, num_classes=num_classes)
        model = model.to(device)

        checkpoint = torch.load(ckpt_path, map_location=device)
        # Accept both a bare state dict and a {'model_state_dict': ...} wrapper.
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint

        # Drop the 'module.' prefix that (Distributed)DataParallel wrappers add
        # to parameter names, so keys match the unwrapped model.
        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        model.load_state_dict(state_dict, strict=True)
        model.reset_classifier(0)  # model now returns features instead of logits
        model.eval()
        return model
    except Exception as e:
        print(f"  [Error] Failed to load {config['timm_name']}: {e}")
        return None

def extract_features(model, loader, device):
    """Run ``model`` over ``loader`` and collect one feature vector per sample.

    Args:
        model: Feature extractor, already in eval mode. In this file it is a
            classifier whose head was removed with ``reset_classifier(0)``.
        loader: Iterable of ``(images, labels)`` tensor batches.
        device: Device the images are moved to.

    Returns:
        ``(features, labels)``: a 2-D array with one row per sample, and a
        1-D label array aligned with it.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    features_list, labels_list = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(loader, desc="Extracting", leave=False):
            output = model(imgs.to(device))
            if output.ndim == 3:
                # Assumed token sequence (B, N, D): average over tokens only.
                # Averaging the last two dims here would collapse each sample
                # to a single scalar.
                output = output.mean(dim=1)
            elif output.ndim > 3:
                # Feature map (B, C, H, W): global-average-pool the spatial dims.
                output = output.mean(dim=[-2, -1])
            features_list.append(output.cpu().numpy())
            labels_list.append(lbls.numpy())

    if not features_list:
        raise ValueError("extract_features: loader produced no batches")
    return np.vstack(features_list), np.concatenate(labels_list)

if 'val_loader' not in locals():
    print("CRITICAL ERROR: val_loader missing. Please run the Data Preparation code first.")
else:
    TRAIN_NUM_CLASSES = 12
    print(f"Assuming models trained with {TRAIN_NUM_CLASSES} classes.")

    # Seaborn cycles a named qualitative palette once there are more hue levels
    # than colours. 'tab10' has 10 colours, so with 12 classes two pairs of
    # classes were drawn in the same colour. 'tab20' has room for all 12.
    TSNE_PALETTE = 'tab20'

    plot_data_storage = []

    # A. Loop through models to generate data and individual plots
    for display_name, config in model_configs.items():
        print(f"\n--- Processing {display_name} ---")

        # A. Load model (head removed, so it returns features)
        model = prepare_timm_model(config, TRAIN_NUM_CLASSES, CHECKPOINT_DIR, device)

        # Compare with None explicitly: an nn.Module's truthiness is not a
        # reliable "loaded" signal (container modules define __len__).
        if model is None:
            print(f"  > Skipping {display_name} due to load error.")
            continue

        # B. Extract Features
        X, y = extract_features(model, val_loader, device)

        # C. Output Point Count
        num_points = len(y)
        print(f"  > Number of data points (images): {num_points}")

        # D. Compute t-SNE (perplexity must be smaller than the number of samples)
        print(f"  > Computing t-SNE...")
        curr_perp = min(30, num_points - 1) if num_points > 1 else 1
        tsne = TSNE(n_components=2, random_state=42, perplexity=curr_perp, init='pca', learning_rate='auto')
        X_embedded = tsne.fit_transform(X)

        # Store data for combined plot
        plot_data_storage.append({
            'name': display_name,
            'coords': X_embedded,
            'labels': y
        })

        # E. Plot & Save INDIVIDUAL Figure
        plt.figure(figsize=(10, 8))
        sns.scatterplot(
            x=X_embedded[:, 0],
            y=X_embedded[:, 1],
            hue=y,
            palette=TSNE_PALETTE,
            legend='full',
            s=80, alpha=0.7
        )
        plt.title(f"{display_name} (n={num_points})", fontsize=16, fontweight='bold')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., title="Class ID")
        plt.tight_layout()

        indiv_path = os.path.join(OUTPUT_DIR, f"tsne_{display_name.replace(' ', '_')}.png")
        plt.savefig(indiv_path, dpi=300, bbox_inches='tight')
        print(f"  > Individual plot saved to: {indiv_path}")
        plt.close()  # Close to save memory

        # Cleanup model
        del model
        torch.cuda.empty_cache()

    # B. Generate COMBINED Plot: one panel per successfully processed model
    if plot_data_storage:
        num_panels = len(plot_data_storage)
        print(f"\n--- Generating Combined Plot for {num_panels} models ---")

        # Size the grid to the models that actually succeeded (5 wide per panel,
        # as before). squeeze=False keeps `axes` indexable for a single panel.
        fig, axes = plt.subplots(1, num_panels, figsize=(5 * num_panels, 6), squeeze=False)
        axes = axes.flatten()
        last_panel = num_panels - 1

        for i, (ax, data) in enumerate(zip(axes, plot_data_storage)):
            X_emb = data['coords']
            y_lbl = data['labels']
            name = data['name']

            sns.scatterplot(
                x=X_emb[:, 0],
                y=X_emb[:, 1],
                hue=y_lbl,
                palette=TSNE_PALETTE,
                legend='full' if i == last_panel else False,  # Legend only on last subplot
                s=60, alpha=0.7,
                ax=ax
            )
            ax.set_title(f"{name}\n(n={len(y_lbl)})", fontsize=14, fontweight='bold')
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel('')
            ax.set_ylabel('')

        # Move the last panel's legend to a single figure-level legend.
        handles, labels = axes[last_panel].get_legend_handles_labels()
        if handles:
            panel_legend = axes[last_panel].get_legend()
            if panel_legend is not None:
                panel_legend.remove()
            fig.legend(handles, labels, loc='center right', title='Class ID', bbox_to_anchor=(1.02, 0.5))

        plt.suptitle("t-SNE Clustering Comparison", fontsize=20, y=1.05)
        plt.tight_layout()

        combined_path = os.path.join(OUTPUT_DIR, "tsne_combined_all_models.png")
        plt.savefig(combined_path, bbox_inches='tight', dpi=300)
        print(f"  > Combined plot saved to: {combined_path}")
        plt.show()
    else:
        print("No models were successfully processed to create a combined plot.")