# Manifold Study - UMAP Visualization with Interactive Hover

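This notebook extracts image-level embeddings from an intermediate layer of a trained YOLO model and visualizes them with UMAP.
Hover over a point to see its image path and bounding box, then pass that path to `visualize_sample()` to inspect the sample (clicking a point does nothing by itself).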

In [None]:
import numpy as np
from pathlib import Path
from collections import defaultdict
import random
from tqdm import tqdm
import umap
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from ultralytics.models.yolo import YOLO
import cv2

CLASS_NAMES = ['knife', 'gun', 'rifle', 'baseball_bat']
# YOLO class index -> human-readable name; the index is the position in CLASS_NAMES.
CLASS_ID_TO_NAME = dict(enumerate(CLASS_NAMES))

In [None]:
def load_balanced_samples(dataset_path, samples_per_class=100, split='train', seed=42):
    """Collect YOLO annotations from one dataset split and subsample them per class.

    Label files are read from ``<dataset_path>/labels/<split>/*.txt``. Each is
    paired with ``<dataset_path>/images/<split>/<stem>.jpg``, falling back to
    ``.png``. Label files without a matching image are skipped.

    Every label line with at least five whitespace-separated fields becomes one
    sample. There is one sample per annotated object, so an image can
    contribute several samples.

    Args:
        dataset_path: dataset root directory.
        samples_per_class: maximum number of samples drawn per class. ``None``
            keeps every sample. Classes with fewer samples are kept whole and a
            warning is printed.
        split: split sub-directory name, e.g. ``'train'``.
        seed: seed for the sampling/shuffling RNG.

    Returns:
        A shuffled list of dicts with keys ``image_path``, ``label_path``,
        ``class_id``, ``class_name`` and ``bbox`` (the four floats that follow
        the class id on the label line).
    """
    # Private RNG: reproducible sampling without reseeding the global
    # ``random`` module as a side effect.
    rng = random.Random(seed)

    def class_name_of(class_id):
        return CLASS_ID_TO_NAME.get(class_id, f'class_{class_id}')

    images_dir = Path(dataset_path) / 'images' / split
    labels_dir = Path(dataset_path) / 'labels' / split

    class_samples = defaultdict(list)

    # glob() order depends on the filesystem; sort it so that `seed` alone
    # determines which samples are drawn.
    label_files = sorted(labels_dir.glob('*.txt'))
    print(f"Found {len(label_files)} label files in {split} split")

    for label_path in label_files:
        image_path = images_dir / (label_path.stem + '.jpg')
        if not image_path.exists():
            image_path = images_dir / (label_path.stem + '.png')
            if not image_path.exists():
                continue

        with open(label_path, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.split()
                if len(parts) < 5:
                    continue
                class_id = int(parts[0])
                class_samples[class_id].append({
                    'image_path': str(image_path),
                    'label_path': str(label_path),
                    'class_id': class_id,
                    'class_name': class_name_of(class_id),
                    'bbox': [float(x) for x in parts[1:5]]
                })

    print("\nClass distribution in dataset:")
    for class_id, samples in sorted(class_samples.items()):
        print(f"  {class_name_of(class_id)}: {len(samples)} samples")

    balanced_samples = []
    # Fixed (sorted) class order so the RNG stream is consumed deterministically.
    for class_id, samples in sorted(class_samples.items()):
        if samples_per_class is None:
            selected = samples
        elif len(samples) >= samples_per_class:
            selected = rng.sample(samples, samples_per_class)
        else:
            selected = samples
            print(f"  Warning: {class_name_of(class_id)} has only {len(samples)} samples")
        balanced_samples.extend(selected)

    rng.shuffle(balanced_samples)

    print(f"\nTotal samples: {len(balanced_samples)}")
    return balanced_samples

In [None]:
def extract_embeddings(model, samples, layer_idx=10, imgsz=640):
    """Compute one pooled feature vector per sample from a YOLO layer.

    A forward hook on ``model.model.model[layer_idx]`` captures that layer's
    output and averages it over dims 2 and 3. These are presumably the
    spatial H, W axes of an NCHW feature map (TODO: confirm for layers other
    than 10).

    The model runs on the full image, so the vector describes the whole image
    rather than the sample's bbox. Every sample that shares an image gets that
    image's vector. Each image is therefore inferred only once and its vector
    is reused.

    Args:
        model: loaded YOLO model. ``model.model.model`` is indexed by layer,
            and the model is called as ``model(path, imgsz=..., verbose=False)``.
        samples: list of sample dicts with at least ``image_path`` and
            ``class_name`` keys.
        layer_idx: index of the layer to hook.
        imgsz: inference image size passed to the model.

    Returns:
        ``(embeddings, metadata)``:
        - ``embeddings``: a 2-D array stacking one vector per successfully
          embedded sample.
        - ``metadata``: the matching sample dicts, in the same order.
        Samples whose image raises during inference are skipped and reported.

    Raises:
        ValueError: if no sample could be embedded.
    """
    embeddings = []
    metadata = []

    feature_maps = {}

    def hook_fn(module, inputs, output):
        pooled = output.mean(dim=(2, 3))
        feature_maps['embedding'] = pooled.detach().cpu().numpy()

    # One vector per image path. Samples are not grouped by image
    # (load_balanced_samples shuffles them), so reusing "the last computed
    # embedding" for a repeated image would attach a different image's features.
    image_embeddings = {}
    failed_images = set()

    target_layer = model.model.model[layer_idx]
    handle = target_layer.register_forward_hook(hook_fn)
    try:
        for sample in tqdm(samples, desc=f"Extracting embeddings from layer {layer_idx}"):
            img_path = sample['image_path']

            if img_path in failed_images:
                continue

            if img_path not in image_embeddings:
                # Clear any stale value, so a pass that never reaches the hooked
                # layer fails loudly instead of reusing the previous image's vector.
                feature_maps.pop('embedding', None)
                try:
                    _ = model(img_path, imgsz=imgsz, verbose=False)
                    image_embeddings[img_path] = feature_maps['embedding'].squeeze()
                except Exception as e:
                    print(f"Error processing {img_path}: {e}")
                    failed_images.add(img_path)
                    continue

            embeddings.append(image_embeddings[img_path])
            metadata.append(sample)
    finally:
        # Always detach the hook, even on interruption, so later model calls are unaffected.
        handle.remove()

    if not embeddings:
        raise ValueError("No embeddings were extracted; check the samples' image paths.")

    embeddings = np.vstack(embeddings)
    print(f"\nExtracted embeddings shape: {embeddings.shape}")
    print("Samples per class:")
    class_counts = defaultdict(int)
    for m in metadata:
        class_counts[m['class_name']] += 1
    for cls, count in sorted(class_counts.items()):
        print(f"  {cls}: {count}")

    return embeddings, metadata

In [None]:
def run_umap(embeddings, n_neighbors=15, min_dist=0.1, metric='cosine', n_components=3):
    """Project ``embeddings`` to ``n_components`` dimensions with UMAP.

    ``random_state`` is fixed at 42 so repeated runs give the same layout.

    Returns:
        The array returned by ``umap.UMAP(...).fit_transform(embeddings)``.
    """
    umap_params = dict(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        n_components=n_components,
        random_state=42,
        metric=metric,
    )
    print(f"Running UMAP (n_neighbors={n_neighbors}, min_dist={min_dist}, metric={metric}, n_components={n_components})...")
    projected = umap.UMAP(**umap_params).fit_transform(embeddings)
    print("UMAP complete!")
    return projected

In [None]:
def create_interactive_umap(coords, metadata, save_html=None, title=None):
    """
    Create an interactive 2D Plotly scatter plot of UMAP coordinates, one trace per class.

    Only the first two columns of ``coords`` are drawn, so a 3-component UMAP
    result appears projected onto its first two dimensions.

    Hover shows:
    - image filename
    - class
    - bbox (cx, cy, w, h as stored in the sample)
    - full image path

    Args:
        coords: array-like with one row per entry of ``metadata`` and at
            least two columns.
        metadata: sample dicts with ``image_path``, ``class_name`` and
            ``bbox`` keys.
        save_html: if given, the figure is also written to this path as HTML.
        title: plot title. Defaults to the layer-10 (C2PSA) title; pass your
            own when the embeddings come from a different layer.

    Returns:
        The plotly ``go.Figure`` (it is also displayed via ``fig.show()``).
    """
    # Accept nested lists too; the fancy indexing below needs an ndarray.
    coords = np.asarray(coords)
    if title is None:
        title = 'UMAP Visualization of YOLO Embeddings (Layer 10 - C2PSA)'

    filenames = [Path(m['image_path']).name for m in metadata]
    class_names = [m['class_name'] for m in metadata]
    image_paths = [m['image_path'] for m in metadata]
    bboxes = [f"cx={m['bbox'][0]:.2f}, cy={m['bbox'][1]:.2f}, w={m['bbox'][2]:.2f}, h={m['bbox'][3]:.2f}" for m in metadata]

    colors = {
        'knife': '#e74c3c',
        'gun': '#3498db',
        'rifle': '#2ecc71',
        'baseball_bat': '#9b59b6'
    }

    # Row indices per class, collected in a single pass.
    class_rows = defaultdict(list)
    for i, cls in enumerate(class_names):
        class_rows[cls].append(i)

    fig = go.Figure()

    for cls in sorted(class_rows):
        mask = class_rows[cls]

        fig.add_trace(go.Scatter(
            x=coords[mask, 0],
            y=coords[mask, 1],
            mode='markers',
            name=f'{cls} ({len(mask)})',
            marker=dict(
                size=8,
                color=colors.get(cls, '#95a5a6'),
                opacity=0.7,
                line=dict(width=1, color='white')
            ),
            text=[filenames[i] for i in mask],
            customdata=[[image_paths[i], bboxes[i]] for i in mask],
            hovertemplate=(
                '<b>%{text}</b><br>'
                'Class: ' + cls + '<br>'
                'BBox: %{customdata[1]}<br>'
                'Path: %{customdata[0]}<br>'
                '<extra></extra>'
            )
        ))

    fig.update_layout(
        title=dict(
            text=title,
            font=dict(size=16)
        ),
        xaxis_title='UMAP Dimension 1',
        yaxis_title='UMAP Dimension 2',
        legend=dict(
            yanchor='top',
            y=0.99,
            xanchor='left',
            x=0.01,
            bgcolor='rgba(255,255,255,0.8)'
        ),
        width=1000,
        height=800,
        template='plotly_white'
    )

    if save_html:
        fig.write_html(save_html)
        print(f"Saved interactive plot to: {save_html}")

    fig.show()
    return fig

In [None]:
def create_interactive_umap_3d(coords, metadata, save_html=None, title=None):
    """Create an interactive 3D Plotly scatter plot of UMAP coordinates, one trace per class.

    Traces are added largest class first. Hover shows the same fields as the
    2D plot: filename, class, bbox and full image path.

    Args:
        coords: array-like with one row per entry of ``metadata`` and at
            least three columns (only the first three are plotted).
        metadata: sample dicts with ``image_path``, ``class_name`` and
            ``bbox`` keys.
        save_html: if given, the figure is also written to this path as HTML.
        title: plot title. Defaults to '3D UMAP Visualization of YOLO Embeddings'.

    Returns:
        The plotly ``go.Figure`` (it is also displayed via ``fig.show()``).

    Raises:
        ValueError: if ``coords`` is not 2-D with at least three columns,
            e.g. when UMAP was run with ``n_components=2``.
    """
    coords = np.asarray(coords)
    if coords.ndim != 2 or coords.shape[1] < 3:
        raise ValueError(
            f"3D plot needs coords with at least 3 columns, got shape {coords.shape}; "
            "run run_umap(..., n_components=3)."
        )
    if title is None:
        title = '3D UMAP Visualization of YOLO Embeddings'

    filenames = [Path(m['image_path']).name for m in metadata]
    class_names = [m['class_name'] for m in metadata]
    image_paths = [m['image_path'] for m in metadata]
    bboxes = [f"cx={m['bbox'][0]:.2f}, cy={m['bbox'][1]:.2f}, w={m['bbox'][2]:.2f}, h={m['bbox'][3]:.2f}" for m in metadata]

    colors = {
        'knife': '#e63946',      # red
        'gun': '#457b9d',        # muted blue
        'rifle': '#6c757d',      # muted gray (less dominant)
        'baseball_bat': '#f4a261' # orange
    }

    # Row indices per class (first-seen order), collected in a single pass.
    class_rows = defaultdict(list)
    for i, cls in enumerate(class_names):
        class_rows[cls].append(i)
    # Largest class first; the sort is stable, so ties keep first-seen order.
    sorted_classes = sorted(class_rows, key=lambda c: len(class_rows[c]), reverse=True)

    fig = go.Figure()

    for cls in sorted_classes:
        mask = class_rows[cls]

        fig.add_trace(go.Scatter3d(
            x=coords[mask, 0],
            y=coords[mask, 1],
            z=coords[mask, 2],
            mode='markers',
            name=f'{cls} ({len(mask)})',
            marker=dict(
                size=3,
                color=colors.get(cls, '#95a5a6'),
                opacity=0.6,
                line=dict(width=0)
            ),
            text=[filenames[i] for i in mask],
            customdata=[[image_paths[i], bboxes[i]] for i in mask],
            hovertemplate=(
                '<b>%{text}</b><br>'
                'Class: ' + cls + '<br>'
                'BBox: %{customdata[1]}<br>'
                'Path: %{customdata[0]}<br>'
                '<extra></extra>'
            )
        ))

    fig.update_layout(
        title=dict(
            text=title,
            font=dict(size=16)
        ),
        scene=dict(
            xaxis_title='UMAP 1',
            yaxis_title='UMAP 2',
            zaxis_title='UMAP 3',
            camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
        ),
        legend=dict(
            yanchor='top', y=0.99,
            xanchor='left', x=0.01,
            bgcolor='rgba(255,255,255,0.8)'
        ),
        width=1000,
        height=800,
        template='plotly_white'
    )

    if save_html:
        fig.write_html(save_html)
        print(f"Saved interactive 3D plot to: {save_html}")

    fig.show()
    return fig

In [None]:
def visualize_sample(image_path, bbox=None):
    """
    Show one image, optionally with its bounding box drawn in lime.

    Call this after finding interesting points in the UMAP plots: pass the
    path (and bbox) shown in the hover tooltip or stored in a metadata entry.

    Args:
        image_path: str or Path of an image readable by OpenCV.
        bbox: optional (cx, cy, w, h) as fractions of image width/height
            (box center plus size), as stored in the samples' ``bbox`` field.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(10, 8))
    plt.imshow(img_rgb)

    # `is not None` rather than truthiness, so numpy-array bboxes also work.
    if bbox is not None:
        h, w = img.shape[:2]
        cx, cy, bw, bh = bbox
        # Convert center/size fractions to pixel corner coordinates.
        x1 = int((cx - bw / 2) * w)
        y1 = int((cy - bh / 2) * h)
        x2 = int((cx + bw / 2) * w)
        y2 = int((cy + bh / 2) * h)

        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             fill=False, edgecolor='lime', linewidth=3)
        plt.gca().add_patch(rect)

    plt.title(Path(image_path).name, fontsize=12)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

## Load Model

In [None]:
import torch

model_path = "/Users/vaibhavnakrani/yolo_dangerous_weapons/models/yolo/5_jan_2026_yolo11m/weights/best.pt"
model = YOLO(model_path)

# Pick an available accelerator instead of hard-coding "mps", which only
# exists on Apple-Silicon machines; fall back to CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(f"Using device: {device}")
print("Model loaded!")

## Load Balanced Samples

In [None]:
DATASET_PATH = "/Users/vaibhavnakrani/Desktop/workspace/yolo_dataset_4_dec"
# Samples kept per class. None keeps every annotation; use e.g. 100 for a
# quick look, or 300+ for a denser (but still balanced) plot.
SAMPLES_PER_CLASS = None

samples = load_balanced_samples(DATASET_PATH, samples_per_class=SAMPLES_PER_CLASS, split='train')

## Extract Embeddings

In [None]:
# Pool layer 10's activations into one vector per sample (640px inference).
embeddings, metadata = extract_embeddings(
    model,
    samples,
    layer_idx=10,
    imgsz=640,
)

In [None]:
# Cache embeddings + metadata to disk so later cells can run without re-extracting.
import pickle

cache_file = Path('embeddings.pkl')

with cache_file.open('wb') as fh:
    pickle.dump((embeddings, metadata), fh)

# Reload from the cache (after a kernel restart, run only this part).
# NOTE: pickle.load can execute arbitrary code; only load cache files you created yourself.
with cache_file.open('rb') as fh:
    embeddings, metadata = pickle.load(fh)


## Run UMAP

In [None]:
# Three components: the 3D plot uses all of them, the 2D plot the first two.
umap_coords = run_umap(
    embeddings,
    n_neighbors=15,
    min_dist=0.1,
    metric='cosine',
    n_components=3,
)

## Interactive Visualization

**Hover over points** to see:
- Image filename
- Class name  
- Bounding box coordinates
- Full image path

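Copy the path and pass it to `visualize_sample()` to see the actual image. Note that the 2D plot shows only the first two of the three UMAP dimensions; the 3D plot below shows all three.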

In [None]:
# 2D interactive plot, also saved as a standalone HTML file.
html_path = '/Users/vaibhavnakrani/yolo_dangerous_weapons/notebooks/umap_interactive.html'
fig = create_interactive_umap(umap_coords, metadata, save_html=html_path)

In [None]:
# 3D view of the same UMAP coordinates (displayed only, not saved).
fig = create_interactive_umap_3d(umap_coords, metadata)

## Explore Clusters

Select points by defining a rectangular region (x_min, x_max, y_min, y_max) in UMAP dimensions 1 and 2, using the axes of the 2D plot above as a guide. The bounds are inclusive.

In [None]:
def get_samples_in_region(coords, metadata, x_min, x_max, y_min, y_max):
    """Return the samples whose first two UMAP coordinates lie inside a box.

    All four bounds are inclusive. The number of matches and their per-class
    counts are printed.

    Returns:
        ``(selected, indices)``: the matching metadata entries, and their
        integer row indices into ``coords`` / ``metadata``.
    """
    xs = coords[:, 0]
    ys = coords[:, 1]
    inside_x = (xs >= x_min) & (xs <= x_max)
    inside_y = (ys >= y_min) & (ys <= y_max)
    indices = np.flatnonzero(inside_x & inside_y)
    selected = [metadata[idx] for idx in indices]

    print(f"Found {len(selected)} samples in region x=[{x_min}, {x_max}], y=[{y_min}, {y_max}]")

    tally = defaultdict(int)
    for entry in selected:
        tally[entry['class_name']] += 1

    print("Class distribution:")
    for name in sorted(tally):
        print(f"  {name}: {tally[name]}")

    return selected, indices


def show_cluster_images(samples, max_images=20, cols=5):
    """Display up to ``max_images`` samples in a grid, each with its bbox drawn in green.

    Args:
        samples: sample dicts with ``image_path``, ``class_name`` and ``bbox``
            keys. ``bbox`` is (cx, cy, w, h) as fractions of image width/height.
        max_images: maximum number of samples shown, taken from the front of
            ``samples``.
        cols: number of grid columns.
    """
    n_images = min(len(samples), max_images)
    if n_images <= 0:
        # plt.subplots needs at least one row; an empty region is a normal
        # outcome of picking a box with no points in it.
        print("No samples to display.")
        return

    rows = (n_images + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, so ravel() works for
    # every grid shape, including a single row.
    _, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)

    for i, ax in enumerate(axes.ravel()):
        if i < n_images:
            sample = samples[i]
            title = f"{sample['class_name']}\n{Path(sample['image_path']).stem[:20]}"
            img = cv2.imread(sample['image_path'])
            if img is None:
                # Missing/unreadable file: label the cell instead of aborting the whole grid.
                ax.set_title(f"{title}\n(unreadable)", fontsize=8)
            else:
                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

                h, w = img.shape[:2]
                cx, cy, bw, bh = sample['bbox']
                # Convert center/size fractions to pixel corner coordinates.
                x1 = int((cx - bw / 2) * w)
                y1 = int((cy - bh / 2) * h)
                x2 = int((cx + bw / 2) * w)
                y2 = int((cy + bh / 2) * h)

                cv2.rectangle(img_rgb, (x1, y1), (x2, y2), (0, 255, 0), 3)

                ax.imshow(img_rgb)
                ax.set_title(title, fontsize=8)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Print the extent of UMAP dimensions 1 and 2 to help choose a region below.
print("UMAP coordinate ranges:")
for axis_label, column in (("X", 0), ("Y", 1)):
    values = umap_coords[:, column]
    print(f"  {axis_label}: [{values.min():.2f}, {values.max():.2f}]")

In [None]:
# ADJUST these bounds to the cluster you want to explore:
# read approximate x/y ranges off the 2D plot (top-right cluster looks interesting).
x_min, x_max = 10, 17
y_min, y_max = -5, 10

cluster_samples, cluster_indices = get_samples_in_region(
    umap_coords, metadata, x_min, x_max, y_min, y_max
)

In [None]:
# Grid of up to 20 images from the selected region, 5 per row.
show_cluster_images(cluster_samples, cols=5, max_images=20)

In [None]:
# Static 2D view with the selected region highlighted.
highlight_fig, highlight_ax = plt.subplots(figsize=(12, 10))

# Background: every point, faint gray.
highlight_ax.scatter(umap_coords[:, 0], umap_coords[:, 1], c='lightgray', alpha=0.3, s=40)

# Foreground: points inside the selected region.
selected_xy = umap_coords[cluster_indices]
highlight_ax.scatter(
    selected_xy[:, 0],
    selected_xy[:, 1],
    c='red',
    s=80,
    edgecolors='black',
    linewidth=1,
    label=f'Selected ({len(cluster_indices)})',
)

# Dashed outline of the selection box.
rect = plt.Rectangle(
    (x_min, y_min), x_max - x_min, y_max - y_min,
    fill=False, edgecolor='red', linewidth=2, linestyle='--',
)
highlight_ax.add_patch(rect)

highlight_ax.legend(fontsize=11)
highlight_ax.set_title('Selected Cluster Region', fontsize=14)
highlight_ax.set_xlabel('UMAP Dimension 1')
highlight_ax.set_ylabel('UMAP Dimension 2')
highlight_ax.grid(True, alpha=0.3)
highlight_fig.tight_layout()
plt.show()

In [None]:
import os
import shutil

# Destination folder for the selected cluster's images (created if missing).
output_folder = "new-zealand"
os.makedirs(output_folder, exist_ok=True)

# cluster_samples holds one entry per annotation, so an image with several
# boxes appears several times. Copy (or warn about) each image once,
# preserving first-seen order.
unique_image_paths = list(dict.fromkeys(ent['image_path'] for ent in cluster_samples))

for img_path in unique_image_paths:
    if os.path.exists(img_path):
        # Copy image to the new folder, keep original filename
        shutil.copy(img_path, os.path.join(output_folder, os.path.basename(img_path)))
    else:
        print(f"Warning: {img_path} does not exist and was not copied.")

print(f"Copied {len(unique_image_paths)} unique image paths from {len(cluster_samples)} samples into {output_folder}/")
    

## Model Architecture - All Layers

In [None]:
# Print every layer of the underlying network with a short hand-written
# description, and star the layer used for the embeddings.
# NOTE(review): the descriptions assume a YOLO11-style layout; confirm against the printed types.
EMBEDDING_LAYER = 10
layers = model.model.model
rule = "=" * 70

print(f"Total layers: {len(layers)}\n")
print(rule)
print(f"{'Index':<6} {'Layer Type':<20} {'Description'}")
print(rule)

layer_descriptions = {
    0: "Stem Conv (3→64, stride 2)",
    1: "Stem Conv (64→128, stride 2)", 
    2: "C3k2 block",
    3: "Downsample Conv (stride 2)",
    4: "C3k2 block",
    5: "Downsample Conv (stride 2)",
    6: "C3k2 block",
    7: "Downsample Conv (stride 2)",
    8: "C3k2 block",
    9: "SPPF (Spatial Pyramid Pooling)",
    10: "C2PSA (Attention) ← WE USED THIS",
    11: "Upsample 2x",
    12: "Concat",
    13: "C3k2 block (Neck)",
    14: "Upsample 2x",
    15: "Concat",
    16: "C3k2 block (P3 output)",
    17: "Downsample Conv",
    18: "Concat",
    19: "C3k2 block (P4 output)",
    20: "Downsample Conv",
    21: "Concat",
    22: "C3k2 block (P5 output)",
    23: "Detect Head"
}

for idx, module in enumerate(layers):
    kind = type(module).__name__
    note = layer_descriptions.get(idx, "")
    star = " ★" if idx == EMBEDDING_LAYER else ""
    print(f"{idx:<6} {kind:<20} {note}{star}")