# Image Embeddings with FiftyOne - MNIST Edition

This notebook explores image embeddings using FiftyOne, following the pattern from the [official tutorial](https://docs.voxel51.com/tutorials/image_embeddings.html) but adapted for MNIST.

**What we'll cover:**
1. Loading the MNIST dataset into FiftyOne
2. Computing embeddings (raw pixels and neural network-based)
3. Visualizing embeddings with UMAP, t-SNE, and PCA
4. Interactive exploration and analysis
5. Finding outliers and potentially mislabeled samples
6. Comparing MobileNet and DINOv2 embeddings
7. Training a classifier head on frozen embeddings

In [None]:
import fiftyone as fo
import fiftyone.brain as fob
import fiftyone.zoo as foz
import numpy as np
from sklearn.decomposition import PCA
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

In [None]:
# Helper function to launch FiftyOne App in browser
def launch_in_browser(dataset_or_view):
    """Launch FiftyOne App and print URL for browser access"""
    session = fo.launch_app(dataset_or_view, auto=False)
    print("\n🌐 Open FiftyOne App in browser:")
    print(f"   {session.url}")
    print(f"\nOr just visit: http://localhost:5151\n")
    return session

## 1. Load MNIST Dataset

FiftyOne has MNIST built-in through its dataset zoo. We'll load a subset for faster experimentation.

In [None]:
# Load MNIST test set (or specify split="train" for training set)
# Using max_samples to keep it manageable
dataset = foz.load_zoo_dataset(
    "mnist",
    split="test",
    max_samples=5000,  # Adjust as needed
    dataset_name="mnist_embeddings_tutorial",
)

print(f"Loaded {len(dataset)} samples")
print(f"Dataset info: {dataset}")

In [None]:
# Launch the FiftyOne App to explore the dataset
# session = fo.launch_app(dataset)

# To open in browser instead of embedded, use:
session = launch_in_browser(dataset)

## 2. Computing Embeddings - Method 1: Raw Pixels

Since MNIST images are small (28x28), we can use raw pixel values as embeddings. This is simple but can still reveal interesting patterns.
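
As a rough illustration of what "raw pixels as embeddings" means, the sketch below flattens a small batch of 28x28 arrays into 784-dimensional vectors. The random `images` array is a hypothetical stand-in for real MNIST digits, and the division by 255 is an optional scaling step that the cell below does not apply.

```python
import numpy as np

# Hypothetical stand-in for a batch of MNIST images: 4 random 28x28 uint8 arrays.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Flatten each image to a 784-dimensional vector and scale pixel values to [0, 1].
pixel_embeddings = images.reshape(len(images), -1).astype(np.float32) / 255.0

print(pixel_embeddings.shape)
```

Scaling does not change which pixels are active, but it keeps distances in a range that is easier to compare across experiments.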

In [None]:
# Compute raw pixel embeddings
# For MNIST (28x28 grayscale), each image becomes a 784-dimensional vector

embeddings = []
for sample in dataset:
    # Load image and flatten to 1D array
    img = Image.open(sample.filepath).convert('L')  # Ensure grayscale
    img_array = np.array(img).flatten()
    embeddings.append(img_array)

embeddings = np.array(embeddings)
print(f"Raw pixel embeddings shape: {embeddings.shape}")

In [None]:
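# Compute a 2D UMAP projection of the raw pixel embeddings so they
# appear in the App's Embeddings panel
results_pixels = fob.compute_visualization(
    dataset,
    embeddings=embeddings,
    brain_key="raw_pixels_umap",
    method="umap",
    verbose=True,
)

# Reopen the App and select "raw_pixels_umap" in the Embeddings panel;
# you should see clusters corresponding to different digits
# session = launch_in_browser(dataset)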

## 3. Computing Embeddings - Method 2: Neural Network Features (MobileNet)

Now let's use a pre-trained neural network to compute more sophisticated embeddings. We'll use FiftyOne's Model Zoo.

In [None]:
# Load a pre-trained model from FiftyOne's zoo
# MobileNet is lightweight and good for getting started
model = foz.load_zoo_model("mobilenet-v2-imagenet-torch")

print(f"Loaded model: {model}")

In [None]:
# Compute embeddings using the model
# FiftyOne will extract features from the penultimate layer
embeddings_nn = dataset.compute_embeddings(model, batch_size=32)

print(f"Neural network embeddings shape: {embeddings_nn.shape}")

In [None]:
# Compute UMAP visualization for neural network embeddings
results_nn = fob.compute_visualization(
    dataset,
    embeddings=embeddings_nn,
    brain_key="mobilenet_umap",
    method="umap",
    verbose=True,
)

print("Neural network embeddings visualization ready!")

## 4. Comparing Dimensionality Reduction Methods

Let's compute visualizations using different methods: UMAP, t-SNE, and PCA.
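
Of the three methods, PCA is the simplest to write out by hand. The following is a minimal sketch, assuming row-vector embeddings, that projects onto the top two principal components using an SVD. The helper name `pca_2d` and the random `toy_embeddings` are hypothetical; this is not how FiftyOne implements `method="pca"` internally.

```python
import numpy as np

def pca_2d(embeddings):
    """Project row-vector embeddings onto their top two principal components."""
    centered = embeddings - embeddings.mean(axis=0)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

rng = np.random.default_rng(0)
toy_embeddings = rng.normal(size=(100, 16))  # hypothetical stand-in for real embeddings
points = pca_2d(toy_embeddings)
print(points.shape)
```

PCA is linear and deterministic, so it tends to preserve global layout, whereas UMAP and t-SNE are nonlinear and usually show tighter local clusters.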

In [None]:
# # t-SNE visualization
# results_tsne = fob.compute_visualization(
#     dataset,
#     embeddings=embeddings_nn,
#     brain_key="mobilenet_tsne",
#     method="tsne",
#     verbose=True,
# )

# print("t-SNE visualization computed!")

In [None]:
# # PCA visualization
# results_pca = fob.compute_visualization(
#     dataset,
#     embeddings=embeddings_nn,
#     brain_key="mobilenet_pca",
#     method="pca",
#     verbose=True,
# )

# print("PCA visualization computed!")

In [None]:
# # List all available visualizations
# print("Available embeddings visualizations:")
# for key in dataset.list_brain_runs():
#     print(f"  - {key}")

# # You can switch between them in the App's embeddings panel!

## 5. Interactive Exploration

**In the FiftyOne App:**
- Use the **Box Select** or **Lasso Select** tools to select clusters
- Selected samples will appear in the grid view
- Color by `ground_truth.label` to see how digits cluster
- Look for:
  - **Tight clusters**: Well-separated digit classes
  - **Outliers**: Unusual samples that don't fit their cluster
  - **Mixed clusters**: Potential mislabeling or ambiguous digits
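
One simple way to look for outliers programmatically is to rank the samples of a single class by their distance from the class centroid. The sketch below shows the idea on toy 2D points; `farthest_from_centroid` and `toy_class` are hypothetical names, and Euclidean distance is an assumption that may not suit every embedding space.

```python
import numpy as np

def farthest_from_centroid(embeddings, top_n):
    """Indices of the top_n rows farthest (Euclidean) from the mean row."""
    center = embeddings.mean(axis=0)
    distances = np.linalg.norm(embeddings - center, axis=1)
    # argsort is ascending, so the last top_n entries are the farthest rows.
    return np.argsort(distances)[-top_n:]

# Four points near the origin plus one obvious outlier at (3, 3).
toy_class = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [0.1, 0.1], [3.0, 3.0]])
outlier_indices = farthest_from_centroid(toy_class, top_n=1)
print(outlier_indices)
```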

In [None]:
# # Programmatic analysis: Find samples far from their cluster centers
# # This can help identify outliers

# from sklearn.neighbors import NearestNeighbors
# from collections import defaultdict

# # Group embeddings by label
# label_to_embeddings = defaultdict(list)
# label_to_ids = defaultdict(list)

# for sample, embedding in zip(dataset, embeddings_nn):
#     label = sample.ground_truth.label
#     label_to_embeddings[label].append(embedding)
#     label_to_ids[label].append(sample.id)

# # For each label, find outliers (samples far from their cluster)
# outlier_ids = []

# for label, embs in label_to_embeddings.items():
#     if len(embs) < 10:  # Skip if too few samples
#         continue
    
#     embs = np.array(embs)
#     ids = label_to_ids[label]
    
#     # Compute center of cluster
#     center = embs.mean(axis=0)
    
#     # Find samples furthest from center
#     distances = np.linalg.norm(embs - center, axis=1)
    
#     # Get top 5 outliers for this digit
#     outlier_indices = np.argsort(distances)[-5:]
#     outlier_ids.extend([ids[i] for i in outlier_indices])

# print(f"Found {len(outlier_ids)} potential outliers")

In [None]:
# # Create a view of outliers
# outliers_view = dataset.select(outlier_ids)

# print(f"Outliers view contains {len(outliers_view)} samples")
# print("\nLaunching App with outliers view...")

# # session = fo.launch_app(view=outliers_view)
# session = launch_in_browser(outliers_view)

# # These are the samples that are most different from their digit class!
# # Look for potential mislabeling or unusual handwriting

## 6. Finding Similar Samples

Use embeddings to find visually similar samples - useful for data cleaning and understanding your dataset.
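
The core of a similarity search is a nearest-neighbor lookup in embedding space. Here is a minimal sketch using cosine similarity in plain NumPy; `cosine_nearest` and the `toy` vectors are hypothetical, and for thousands of samples a library index such as scikit-learn's `NearestNeighbors` is likely the better choice.

```python
import numpy as np

def cosine_nearest(embeddings, query_index, k):
    """Indices of the k rows most cosine-similar to row query_index, excluding itself."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)  # guard against zero vectors
    similarity = unit @ unit[query_index]
    similarity[query_index] = -np.inf  # never return the query itself
    return np.argsort(-similarity)[:k]

toy = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
neighbors = cosine_nearest(toy, query_index=0, k=2)
print(neighbors)
```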

In [None]:
# # Use the first sample as the query and find its nearest neighbors
# query_sample = dataset.first()
# query_id = query_sample.id
# query_label = query_sample.ground_truth.label

# print(f"Query sample: ID={query_id}, Label={query_label}")

# # Get query embedding (rows of embeddings_nn follow dataset order)
# sample_ids = dataset.values("id")
# query_idx = sample_ids.index(query_id)
# query_embedding = embeddings_nn[query_idx].reshape(1, -1)

# # Find 10 nearest neighbors
# nbrs = NearestNeighbors(n_neighbors=11, metric='cosine').fit(embeddings_nn)
# distances, indices = nbrs.kneighbors(query_embedding)

# # Get neighbor IDs (skip first one as it's the query itself)
# neighbor_ids = [sample_ids[i] for i in indices[0][1:]]

# print(f"\nFound {len(neighbor_ids)} similar samples")

In [None]:
# # View the query sample and its neighbors
# similar_view = dataset.select([query_id] + neighbor_ids)

# # session = fo.launch_app(view=similar_view)
# session = launch_in_browser(similar_view)

# # Notice how visually similar these samples are!

## 7. Analysis: Label Quality Investigation

Find potential mislabeled samples by looking for samples whose nearest neighbors have different labels.
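
The idea can be sketched on toy data: for each sample, compute the fraction of its k nearest neighbors that carry a different label, then inspect the samples with the highest fraction. The helper `neighbor_disagreement` and the toy points are hypothetical, and the sketch uses Euclidean distance on a dense pairwise matrix, which is only practical for small datasets.

```python
import numpy as np

def neighbor_disagreement(embeddings, labels, k):
    """Fraction of each sample's k nearest neighbors that have a different label."""
    labels = np.asarray(labels)
    diffs = embeddings[:, None, :] - embeddings[None, :, :]
    dists = np.linalg.norm(diffs, axis=2)
    np.fill_diagonal(dists, np.inf)  # a sample is not its own neighbor
    nearest = np.argsort(dists, axis=1)[:, :k]
    return (labels[nearest] != labels[:, None]).mean(axis=1)

toy = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                [5.0, 5.0], [5.1, 5.0], [0.05, 0.05]])
toy_labels = ["a", "a", "a", "b", "b", "b"]  # the last point sits inside the "a" cluster
ratios = neighbor_disagreement(toy, toy_labels, k=3)
print(ratios)
```

The last point is labeled "b" but all three of its neighbors are labeled "a", so it receives the highest disagreement ratio and would be the first candidate to review.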

In [None]:
# # For each sample, check if its neighbors have the same label
# k = 10  # Number of neighbors to check
# nbrs = NearestNeighbors(n_neighbors=k+1, metric='cosine').fit(embeddings_nn)

# potential_errors = []

# for idx, (sample, embedding) in enumerate(zip(dataset, embeddings_nn)):
#     # Find neighbors
#     distances, indices = nbrs.kneighbors(embedding.reshape(1, -1))
    
#     # Get neighbor labels (skip self)
#     neighbor_labels = [dataset[sample_ids[i]].ground_truth.label 
#                       for i in indices[0][1:]]
    
#     # Check if majority of neighbors have different label
#     sample_label = sample.ground_truth.label
#     different_count = sum(1 for l in neighbor_labels if l != sample_label)
    
#     if different_count > k * 0.6:  # More than 60% different
#         potential_errors.append({
#             'id': sample.id,
#             'label': sample_label,
#             'common_neighbor_label': max(set(neighbor_labels), 
#                                         key=neighbor_labels.count),
#             'different_ratio': different_count / k
#         })

# print(f"Found {len(potential_errors)} potential labeling errors")

# # Show first few
# for error in potential_errors[:5]:
#     print(f"  Sample labeled '{error['label']}' but neighbors are mostly '{error['common_neighbor_label']}'")

In [None]:
# # View potential labeling errors
# if potential_errors:
#     error_ids = [e['id'] for e in potential_errors]
#     errors_view = dataset.select(error_ids)
    
#     # session = fo.launch_app(view=errors_view)
#     session = launch_in_browser(errors_view)
    
#     print("Check these out - they might be mislabeled or ambiguous!")
# else:
#     print("No obvious labeling errors found - dataset looks clean!")

## 8. Computing Embeddings - Method 3: Neural Network Features (DINOv2)

In [None]:
# Load a DINOv2 model from FiftyOne's zoo
# DINOv2 is a self-supervised vision transformer that produces high-quality embeddings
model_dino = foz.load_zoo_model("dinov2-vits14-torch")

print(f"Loaded DINOv2 model: {model_dino}")

# Compute DINOv2 embeddings for the entire dataset
embeddings_dino = dataset.compute_embeddings(model_dino, batch_size=32)

print(f"DINOv2 embeddings shape: {embeddings_dino.shape}")

In [None]:
# Compute UMAP visualization for DINOv2 embeddings
results_dino_umap = fob.compute_visualization(
    dataset,
    embeddings=embeddings_dino,
    brain_key="dinov3_umap",
    method="umap",
    verbose=True,
)

print("DINOv2 UMAP visualization ready!")

In [None]:
# # PCA visualization for DINOv2 embeddings
# results_dino_pca = fob.compute_visualization(
#     dataset,
#     embeddings=embeddings_dino,
#     brain_key="dinov3_pca",
#     method="pca",
#     verbose=True,
# )

# print("DINOv2 PCA visualization computed!")

In [None]:
# Launch the FiftyOne App to explore the DINOv2 embeddings
# You can switch between the computed visualizations (mobilenet_umap, dinov3_umap,
# and dinov3_pca if you ran the PCA cell) in the embeddings panel
session = launch_in_browser(dataset)

# In the App:
# - Open the embeddings panel
# - Select "dinov3_umap" (or "dinov3_pca", if computed) from the dropdown
# - Color by ground_truth.label to see how the DINOv2 embeddings cluster digits

In [None]:
# Compare embedding quality: DINOv2 vs MobileNet
# We compare how well each embedding space separates the digit classes

from sklearn.metrics import silhouette_score

# Get labels for all samples
labels = [sample.ground_truth.label for sample in dataset]
# Assumes each label string starts with its digit character, so label[0] is the digit
labels_int = [int(label[0]) for label in labels]

# Compute silhouette scores (higher is better, range: -1 to 1)
# Measures how similar samples are to their own cluster vs other clusters
score_mobilenet = silhouette_score(embeddings_nn, labels_int, metric='cosine')
score_dino = silhouette_score(embeddings_dino, labels_int, metric='cosine')

print("Embedding Quality Comparison (Silhouette Score):")
print(f"  MobileNet: {score_mobilenet:.4f}")
print(f"  DINOv2:    {score_dino:.4f}")
print("\nHigher scores indicate better cluster separation.")
print(f"DINOv2 is {'better' if score_dino > score_mobilenet else 'worse'} than MobileNet by {abs(score_dino - score_mobilenet):.4f}")
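
For intuition about what the silhouette score measures, here is a minimal NumPy sketch of the definition: for each sample, `a` is the mean distance to the other members of its own class, `b` is the smallest mean distance to any other class, and the per-sample score is `(b - a) / max(a, b)`. The `silhouette` helper and toy points are hypothetical, the sketch uses Euclidean rather than cosine distance, and it assumes at least two classes; scikit-learn's `silhouette_score` remains the function to use in practice.

```python
import numpy as np

def silhouette(embeddings, labels):
    """Mean silhouette coefficient using Euclidean distance (needs >= 2 classes)."""
    labels = np.asarray(labels)
    dists = np.linalg.norm(embeddings[:, None, :] - embeddings[None, :, :], axis=2)
    classes = np.unique(labels)
    scores = []
    for i, label in enumerate(labels):
        same = labels == label
        same[i] = False  # exclude the sample itself from its own class
        if not same.any():
            scores.append(0.0)  # convention for single-member classes
            continue
        a = dists[i, same].mean()
        b = min(dists[i, labels == other].mean() for other in classes if other != label)
        scores.append((b - a) / max(a, b))
    return float(np.mean(scores))

toy = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
good_score = silhouette(toy, [0, 0, 1, 1])  # labels match the two spatial groups
bad_score = silhouette(toy, [0, 1, 0, 1])   # labels cut across the spatial groups
print(good_score, bad_score)
```

Labels that agree with the geometry give a score close to 1, while labels that cut across it push the score below 0.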

## 9. Training a Classifier Head on Embeddings

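Now let's train a simple classifier on top of frozen embeddings. The cells below use the MobileNet embeddings by default; to use the DINOv2 embeddings instead, change the assignment to `X`. The resulting test accuracy indicates how much digit-class information the embeddings capture.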
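
Before training an MLP, it can help to have a simpler reference point. The sketch below fits a linear probe (softmax regression with full-batch gradient descent) in plain NumPy. This is a swapped-in baseline rather than the classifier trained in this notebook, and `train_linear_probe`, the learning rate, and the synthetic `features` are all hypothetical choices for illustration.

```python
import numpy as np

def train_linear_probe(features, targets, num_classes, lr=0.1, epochs=200):
    """Fit softmax regression with full-batch gradient descent; returns (weights, bias)."""
    n, d = features.shape
    weights = np.zeros((d, num_classes))
    bias = np.zeros(num_classes)
    one_hot = np.eye(num_classes)[targets]
    for _ in range(epochs):
        logits = features @ weights + bias
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - one_hot) / n  # gradient of mean cross-entropy w.r.t. logits
        weights -= lr * (features.T @ grad)
        bias -= lr * grad.sum(axis=0)
    return weights, bias

# Synthetic "embeddings": three well-separated Gaussian blobs, 50 points each.
rng = np.random.default_rng(0)
centers = np.array([[3.0, 0.0], [-3.0, 0.0], [0.0, 3.0]])
targets = np.repeat(np.arange(3), 50)
features = centers[targets] + rng.normal(scale=0.5, size=(150, 2))

weights, bias = train_linear_probe(features, targets, num_classes=3)
accuracy = ((features @ weights + bias).argmax(axis=1) == targets).mean()
print(f"Training accuracy: {accuracy:.3f}")
```

If a linear probe already separates the classes well, most of the useful structure is in the embeddings themselves and the extra hidden layers add comparatively little.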

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# embeddings_dino.shape, embeddings_nn.shape
# ((5000, 384), (5000, 1280))

In [None]:
# Extract embeddings and labels
# X = embeddings_dino  # DINOv2 embeddings
X = embeddings_nn  # MobileNet embeddings
y = np.array([sample.ground_truth.label for sample in dataset])

# Encode labels as integers
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

print(f"Data shape: {X.shape}")
print(f"Labels: {np.unique(y)}")
print(f"Number of classes: {len(label_encoder.classes_)}")

In [None]:
# Split data: 70% train, 15% validation, 15% test

train_indices, temp_indices, y_train, y_temp = train_test_split(
    np.arange(X.shape[0]), y_encoded, test_size=0.3, random_state=42, stratify=y_encoded
)
val_indices, test_indices, y_val, y_test = train_test_split(
    temp_indices, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

# y_train, y_val, and y_test already come from the two splits above
X_train, X_val, X_test = X[train_indices], X[val_indices], X[test_indices]

print(f"Train set: {X_train.shape[0]} samples")
print(f"Validation set: {X_val.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

# Convert to PyTorch tensors
X_train_tensor = torch.FloatTensor(X_train)
y_train_tensor = torch.LongTensor(y_train)
X_val_tensor = torch.FloatTensor(X_val)
y_val_tensor = torch.LongTensor(y_val)
X_test_tensor = torch.FloatTensor(X_test)
y_test_tensor = torch.LongTensor(y_test)

# Create data loaders
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Define a simple MLP classifier head
class ClassifierHead(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(ClassifierHead, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, num_classes)
        )
    
    def forward(self, x):
        return self.network(x)

# Initialize model
embedding_dim = X_train.shape[1]
num_classes = len(label_encoder.classes_)
model_classifier = ClassifierHead(embedding_dim, hidden_dim=256, num_classes=num_classes)

print(f"Model architecture:")
print(model_classifier)
print(f"\nTotal parameters: {sum(p.numel() for p in model_classifier.parameters()):,}")

In [None]:
# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_classifier.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Training function
def train_epoch(model, loader, criterion, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for X_batch, y_batch in loader:
        optimizer.zero_grad()
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += y_batch.size(0)
        correct += (predicted == y_batch).sum().item()
    
    return total_loss / len(loader), 100 * correct / total

# Validation function
def validate(model, loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for X_batch, y_batch in loader:
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()
    
    return total_loss / len(loader), 100 * correct / total

print("Training functions defined!")

In [None]:
# Train the classifier
num_epochs = 50
best_val_acc = 0
patience_counter = 0
early_stop_patience = 10

train_losses = []
val_losses = []
train_accs = []
val_accs = []

print("Starting training...")
print("-" * 60)

for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model_classifier, train_loader, criterion, optimizer)
    val_loss, val_acc = validate(model_classifier, val_loader, criterion)
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    
    # Learning rate scheduling
    scheduler.step(val_loss)
    
    # Print progress every 5 epochs
    if (epoch + 1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    
    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
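        # Save best model; clone the tensors so later epochs cannot overwrite the snapshot
        best_model_state = {k: v.detach().clone() for k, v in model_classifier.state_dict().items()}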
    else:
        patience_counter += 1
        if patience_counter >= early_stop_patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

# Load best model
model_classifier.load_state_dict(best_model_state)

print("-" * 60)
print(f"Training complete! Best validation accuracy: {best_val_acc:.2f}%")

In [None]:
# Plot training curves
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax1.plot(train_losses, label='Train Loss', linewidth=2)
ax1.plot(val_losses, label='Validation Loss', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy curves
ax2.plot(train_accs, label='Train Accuracy', linewidth=2)
ax2.plot(val_accs, label='Validation Accuracy', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final training accuracy: {train_accs[-1]:.2f}%")
print(f"Final validation accuracy: {val_accs[-1]:.2f}%")

In [None]:
# Get predictions on test set
model_classifier.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for X_batch, y_batch in test_loader:
        outputs = model_classifier(X_batch)
        _, predicted = torch.max(outputs.data, 1)
        all_preds.extend(predicted.numpy())
        all_labels.extend(y_batch.numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

In [None]:
# Evaluate on test set
test_loss, test_acc = validate(model_classifier, test_loader, criterion)

print("=" * 60)
print("FINAL TEST SET RESULTS")
print("=" * 60)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")
print("=" * 60)

Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=label_encoder.classes_,
            yticklabels=label_encoder.classes_)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix - Embedding Classifier on MNIST')
plt.tight_layout()
plt.show()

# Classification report
print("\nClassification Report:")
print("=" * 60)
print(classification_report(all_labels, all_preds, 
                          target_names=label_encoder.classes_,
                          digits=4))

View the misclassified test examples

In [None]:
mismatch_sample_mask = (all_labels != all_preds)
mismatch_sample_ids = np.argwhere(mismatch_sample_mask).flatten()

# Map positions in the test split back to positions in the FiftyOne dataset
full_dataset_mismatch_test_sample_ids = test_indices[mismatch_sample_ids]
mismatch_positions = set(full_dataset_mismatch_test_sample_ids.tolist())

# Flag every sample; assumes iteration order matches the order used for the embeddings
for idx, sample in enumerate(dataset):
    sample["misclassified_test_sample"] = idx in mismatch_positions
    sample.save()
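
The index mapping in the cell above is easy to get wrong, so here is a small sketch of it in isolation. All values are hypothetical: `test_indices` holds the full-dataset positions that ended up in the test split, and the label and prediction arrays are in test-split order.

```python
import numpy as np

# Hypothetical values: positions in the full dataset that landed in the test split.
test_indices = np.array([7, 2, 9, 4])
all_labels = np.array([3, 1, 8, 0])  # true labels, in test-split order
all_preds = np.array([3, 7, 8, 6])   # predictions, in test-split order

# Positions within the test split where prediction and label disagree.
mismatch_positions = np.flatnonzero(all_labels != all_preds)

# The same samples expressed as positions in the full dataset.
mismatch_dataset_positions = test_indices[mismatch_positions]
print(mismatch_dataset_positions)
```

The second and fourth test samples are misclassified, and indexing `test_indices` with those positions recovers where they sit in the full dataset.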

In [None]:
session = launch_in_browser(dataset)

In [None]:
# Summary
print("=" * 60)
print("SUMMARY: NN Classifier Training")
print("=" * 60)
print(f"\nEmbedding dimension: {embedding_dim}")
print(f"Number of classes: {num_classes}")
print(f"Training samples: {len(X_train)}")
print(f"Validation samples: {len(X_val)}")
print(f"Test samples: {len(X_test)}")
print(f"\nBest validation accuracy: {best_val_acc:.2f}%")
print(f"Final test accuracy: {test_acc:.2f}%")
print(f"\nModel parameters: {sum(p.numel() for p in model_classifier.parameters()):,}")