# Visualizing Image Embeddings Using Hugging Face Models

This tutorial explores how to visualize the latent spaces of image embeddings using Vision Transformer models from the Hugging Face Transformers library. We'll cover what latent spaces are in the context of images, why they're important, and how to extract and visualize them from pre-trained models.

## 1. Introduction to Image Latent Spaces

A **latent space** (also known as a latent representation or embedding space) is a compressed representation of data where similar items are positioned closer together. These spaces are "latent" because they represent hidden factors that explain the observed data.

### Why are image latent spaces important?

- They enable dimensionality reduction, converting high-dimensional image data into more manageable lower-dimensional representations
- They capture visual similarities and semantic relationships between images
- In generative models, they allow for meaningful manipulations (style transfer, image editing, etc.)
- They provide insights into how vision models internally represent visual information

In computer vision, visualizing latent spaces can help us understand:
- How models cluster visually similar images
- Which visual features the model considers important
- Potential biases in the model's representations
- How different vision architectures learn different representations

## 2. Setting Up the Environment

Let's install the necessary libraries for our tutorial:

In [None]:
# Install required packages
# See notebook on text embeddings for more info on how to install the packages
#
# %pip install transformers torch scikit-learn matplotlib pandas seaborn umap-learn Pillow datasets

In [None]:
# Import libraries
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.offsetbox as offsetbox  # Import offsetbox module for thumbnail visualization
import seaborn as sns
from PIL import Image
from datasets import load_dataset
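# ViTFeatureExtractor is deprecated in recent transformers releases in favor of ViTImageProcessor;
# the alias keeps the rest of the notebook unchanged
from transformers import ViTImageProcessor as ViTFeatureExtractor, ViTModel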
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

## 3. Image Embeddings from Vision Transformers

Let's explore the latent space of a vision model. We'll use a pre-trained Vision Transformer (ViT) to generate embeddings for a set of images, and then visualize how these embeddings are organized in latent space.

In [None]:
# Load a pre-trained Vision Transformer
model_name = "google/vit-base-patch16-224"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
vit_model = ViTModel.from_pretrained(model_name)

### Loading and preprocessing images

For this tutorial, we'll use the CIFAR-10 dataset, which contains 32x32 color images from 10 different classes. The preprocessing step resizes them to the 224x224 input size the model expects. If you have your own set of images, you could use those instead.

In [None]:
# Example code for processing local images
'''
# If you have local image files
image_paths = [
    "path/to/cat1.jpg", "path/to/cat2.jpg", "path/to/cat3.jpg",
    "path/to/dog1.jpg", "path/to/dog2.jpg", "path/to/dog3.jpg",
    # etc.
]
image_labels = ["cat", "cat", "cat", "dog", "dog", "dog", ...]

images = [Image.open(path).convert("RGB") for path in image_paths]
'''

In [None]:
# Load a small subset of images from the CIFAR-10 dataset
dataset = load_dataset("cifar10", split="train[:100]")  # Load 100 images for demonstration

# Display some sample images from the dataset
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
axes = axes.flatten()

for i in range(10):
    img = dataset[i]["img"]
    label = dataset.features["label"].int2str(dataset[i]["label"])
    axes[i].imshow(img)
    axes[i].set_title(f"Class: {label}")
    axes[i].axis("off")

plt.tight_layout()
plt.show()

# Extract images and labels
# With the `datasets` library the 'img' field is normally decoded to a PIL image;
# the numpy branch is kept as a fallback for array-backed data
images = []
image_labels = []
for example in dataset:
    if isinstance(example["img"], np.ndarray):
        images.append(Image.fromarray(example["img"]))
    else:
        # Already a PIL Image, use it directly
        images.append(example["img"])
    image_labels.append(dataset.features["label"].int2str(example["label"]))
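With only 100 images spread over 10 classes, some classes may contribute just a handful of examples, which affects how tight the clusters look later. A minimal sketch for checking the class balance follows; `example_labels` is a hypothetical stand-in for the `image_labels` list built above.

```python
from collections import Counter

# Hypothetical stand-in for the image_labels list built above
example_labels = ["cat", "dog", "cat", "ship", "dog", "cat"]

# Count how many images each class contributes
label_counts = Counter(example_labels)

# Print classes from most to least frequent
for label, count in label_counts.most_common():
    print(f"{label:>10}: {count}")
```

Passing `image_labels` instead of `example_labels` gives the counts for the actual subset.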

### Generating Image Embeddings

Now, let's use our Vision Transformer model to generate embeddings for each image.

In [None]:
# Get image embeddings
def get_image_embeddings(images, model, feature_extractor):
    # Initialize an empty list to store embeddings
    embeddings = []
    
    # Set the model to evaluation mode
    model.eval()
    
    # Process each image
    with torch.no_grad():  # No need to calculate gradients
        for image in images:
            # Preprocess the image
            inputs = feature_extractor(images=image, return_tensors="pt")
            
            # Forward pass through the model
            outputs = model(**inputs)
            
            # Get the embedding (we'll use the [CLS] token embedding)
            embedding = outputs.last_hidden_state[:, 0, :].numpy()
            embeddings.append(embedding[0])
    
    # Convert list to numpy array
    return np.array(embeddings)

# Generate embeddings for our images
image_embeddings = get_image_embeddings(images, vit_model, feature_extractor)

print(f"Generated embeddings with shape: {image_embeddings.shape}")

### Understanding Vision Transformer Embeddings

The Vision Transformer (ViT) represents images by:

1. Splitting the image into fixed-size patches (e.g., 16x16 pixels)
2. Linearly embedding each patch
3. Prepending a learnable [CLS] token to the sequence of patch embeddings
4. Adding position embeddings
5. Passing these embeddings through transformer encoder blocks

The output embedding we extracted (the final hidden state of the [CLS] token) serves as a global representation of the entire image, because self-attention lets it aggregate information from every patch. These embeddings are high-dimensional (768 dimensions for the base model we're using).
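The patch and token bookkeeping can be sketched in plain NumPy. This is an illustration under stated assumptions, not the library's implementation: the projection matrix, [CLS] vector, and position embeddings below are random or zero placeholders for parameters that a real ViT learns, and `patchify` is a hypothetical helper.

```python
import numpy as np

def patchify(image, patch_size=16):
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    rows, cols = h // patch_size, w // patch_size
    # (rows, patch, cols, patch, C) -> (rows, cols, patch, patch, C)
    patches = image.reshape(rows, patch_size, cols, patch_size, c).transpose(0, 2, 1, 3, 4)
    # Flatten each patch into a single vector
    return patches.reshape(rows * cols, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))  # stand-in for a preprocessed 224x224 RGB image

patches = patchify(image)
print(patches.shape)  # (196, 768): 14 x 14 patches, each 16 * 16 * 3 values

# Placeholder parameters (assumptions for illustration; learned in a real model)
hidden_size = 768
projection = rng.normal(scale=0.02, size=(patches.shape[1], hidden_size))
cls_token = np.zeros((1, hidden_size))
position_embeddings = rng.normal(scale=0.02, size=(patches.shape[0] + 1, hidden_size))

# Prepend the [CLS] token, then add position embeddings
tokens = np.concatenate([cls_token, patches @ projection]) + position_embeddings
print(tokens.shape)  # (197, 768): 196 patch tokens plus one [CLS] token
```

Note that the flattened patch size (16 x 16 x 3 = 768) happens to equal the hidden size of the base model; the two numbers are independent in general. The sequence length of 197 is why `outputs.last_hidden_state[:, 0, :]` selects the [CLS] position.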

## 4. Dimensionality Reduction for Visualization

Let's apply dimensionality reduction techniques to visualize our image embeddings in 2D space.

In [None]:
# Apply dimensionality reduction techniques
# PCA
pca = PCA(n_components=2, random_state=42)
image_embeddings_pca = pca.fit_transform(image_embeddings)

# t-SNE
tsne = TSNE(n_components=2, perplexity=30, random_state=42)
image_embeddings_tsne = tsne.fit_transform(image_embeddings)

# UMAP
umap_reducer = UMAP(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42)
image_embeddings_umap = umap_reducer.fit_transform(image_embeddings)

In [None]:
# Create a function to visualize image embeddings
def plot_image_embeddings(embeddings, labels, title):
    # Create a DataFrame for easier plotting
    df = pd.DataFrame({
        'x': embeddings[:, 0],
        'y': embeddings[:, 1],
        'label': labels
    })
    
    # Create the plot
    plt.figure(figsize=(12, 10))
    
    # Get unique labels
    unique_labels = sorted(df['label'].unique())
    
    # Create a color map
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
    color_map = {label: colors[i] for i, label in enumerate(unique_labels)}
    
    # Create a scatter plot
    for label in unique_labels:
        label_data = df[df['label'] == label]
        plt.scatter(label_data['x'], label_data['y'], 
                    c=[color_map[label]], label=label, alpha=0.7, s=100)
    
    # Add title and legend
    plt.title(title, fontsize=16)
    plt.legend(fontsize=12, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

In [None]:
# Visualize the image embeddings with PCA
plot_image_embeddings(image_embeddings_pca, image_labels, 
                     "Image Embeddings Visualization using PCA")

In [None]:
# Visualize the image embeddings with t-SNE
plot_image_embeddings(image_embeddings_tsne, image_labels, 
                       "Image Embeddings Visualization using t-SNE")

In [None]:
# Visualize the image embeddings with UMAP
plot_image_embeddings(image_embeddings_umap, image_labels, 
                       "Image Embeddings Visualization using UMAP")

## 5. Visualizing Images in Latent Space

Let's create a more informative visualization by showing actual thumbnail images at their embedding positions. This helps us to better understand what kinds of images are grouped together.

In [None]:
def plot_images_in_latent_space(embeddings, images, labels, title, sample_size=50):
    # Sample a subset of images if we have too many
    if len(images) > sample_size:
        indices = np.random.choice(len(images), sample_size, replace=False)
        sampled_embeddings = embeddings[indices]
        sampled_images = [images[i] for i in indices]
        sampled_labels = [labels[i] for i in indices]
    else:
        sampled_embeddings = embeddings
        sampled_images = images
        sampled_labels = labels
    
    # Create figure and axis
    fig, ax = plt.subplots(figsize=(16, 14))
    
    # Assign one fixed color per class so thumbnail borders match the legend
    # (hash(label) changes between Python sessions and ignores the legend order)
    unique_labels = sorted(set(sampled_labels))
    color_map = {label: plt.cm.tab10(i % 10) for i, label in enumerate(unique_labels)}
    
    # Plot each image at its embedding position
    for x, y, img, label in zip(sampled_embeddings[:, 0], 
                                sampled_embeddings[:, 1], 
                                sampled_images,
                                sampled_labels):
        # Convert PIL image to numpy array if it's not already
        img_array = np.array(img) if isinstance(img, Image.Image) else img
            
        # Create an OffsetImage of the image
        img_box = offsetbox.OffsetImage(img_array, zoom=2)
        ab = offsetbox.AnnotationBbox(img_box, (x, y), frameon=True, 
                                      pad=0.2, bboxprops=dict(edgecolor=color_map[label]))
        ax.add_artist(ab)
    
    # Add small scatter points so the axes autoscale and the legend shows visible markers;
    # the thumbnails are drawn on top of them
    for label in unique_labels:
        mask = np.array(sampled_labels) == label
        ax.scatter(sampled_embeddings[mask, 0], sampled_embeddings[mask, 1], 
                   c=[color_map[label]], label=label, s=20)
    
    # Set title and legend
    ax.set_title(title, fontsize=16)
    ax.legend(fontsize=12, bbox_to_anchor=(1.05, 1), loc='upper left')
    
    # Remove ticks since the actual values aren't meaningful
    # (without ticks there are no grid lines to draw, so no grid is added)
    ax.set_xticks([])
    ax.set_yticks([])
    
    plt.tight_layout()
    plt.show()

In [None]:
# Convert the PIL images to numpy arrays for plotting
image_arrays = [np.array(img) for img in images]

# Plot images in t-SNE space (t-SNE often gives the clearest visual clustering)
plot_images_in_latent_space(image_embeddings_tsne, image_arrays, image_labels, 
                            "Images in Latent Space (t-SNE)", sample_size=50)

## 6. Analyzing Clusters and Relationships

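Let's analyze what kinds of relationships the model has learned by examining which images are grouped together in the latent space. Keep in mind that the centroids below are computed in the 2D t-SNE projection, which does not preserve global distances well, so distances between far-apart classes are only a rough indication.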

In [None]:
# Calculate and visualize the average embedding for each class
def plot_class_centroids(embeddings, labels, title):
    # Create a DataFrame for the embedding data
    df = pd.DataFrame({
        'x': embeddings[:, 0],
        'y': embeddings[:, 1],
        'label': labels
    })
    
    # Calculate centroids for each class
    centroids = df.groupby('label').mean().reset_index()
    
    # Plot the embeddings with class centroids
    plt.figure(figsize=(14, 12))
    
    # Get unique labels
    unique_labels = sorted(df['label'].unique())
    
    # Create a color map
    colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
    color_map = {label: colors[i] for i, label in enumerate(unique_labels)}
    
    # Plot individual points
    for label in unique_labels:
        label_data = df[df['label'] == label]
        plt.scatter(label_data['x'], label_data['y'], 
                    c=[color_map[label]], label=label, alpha=0.3, s=50)
    
    # Plot centroids with labels
    for i, row in centroids.iterrows():
        plt.scatter(row['x'], row['y'], c=[color_map[row['label']]], 
                   s=300, edgecolors='black', linewidths=2, alpha=1.0)
        plt.annotate(row['label'], (row['x'], row['y']), 
                     fontsize=14, fontweight='bold', ha='center', va='center')
    
    # Add title and legend
    plt.title(title, fontsize=16)
    plt.legend(fontsize=12, bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()
    
    return centroids

# Plot the class centroids
class_centroids = plot_class_centroids(image_embeddings_tsne, image_labels, 
                                     "Class Centroids in t-SNE Space")

In [None]:
# Calculate distances between centroids to understand class relationships
def calculate_centroid_distances(centroids):
    # Extract coordinates and labels
    labels = centroids['label'].values
    points = centroids[['x', 'y']].values
    
    # Calculate pairwise Euclidean distances with broadcasting:
    # (n, 1, 2) - (1, n, 2) -> (n, n, 2), then take the norm over the last axis
    distances = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    
    # Create a DataFrame for the distance matrix
    distance_df = pd.DataFrame(distances, index=labels, columns=labels)
    
    return distance_df

# Calculate and display centroid distances
centroid_distances = calculate_centroid_distances(class_centroids)
print("Distances between class centroids:")
display(centroid_distances)

# Visualize the distance matrix as a heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(centroid_distances, annot=True, cmap='viridis', fmt='.2f')
plt.title('Distance Between Class Centroids in Latent Space', fontsize=16)
plt.tight_layout()
plt.show()
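Because the centroids above live in a 2D t-SNE projection, a complementary check is to compare class centroids in the original embedding space. The following is a minimal sketch, assuming cosine distance as the metric; `centroid_cosine_distances` is a hypothetical helper and the data is synthetic, standing in for `image_embeddings` and `image_labels`.

```python
import numpy as np

def centroid_cosine_distances(embeddings, labels):
    """Return class names and the cosine distance matrix between class centroids."""
    labels = np.asarray(labels)
    class_names = sorted(set(labels.tolist()))
    # Mean embedding per class, computed in the full-dimensional space
    centroids = np.stack([embeddings[labels == name].mean(axis=0) for name in class_names])
    # Normalize to unit length so the dot product equals cosine similarity
    unit = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)
    return class_names, 1.0 - unit @ unit.T

# Synthetic stand-in data: 3 classes, 5 points each, 8 dimensions
rng = np.random.default_rng(0)
class_directions = rng.normal(size=(3, 8))
demo_embeddings = np.repeat(class_directions, 5, axis=0) + 0.05 * rng.normal(size=(15, 8))
demo_labels = ["cat"] * 5 + ["dog"] * 5 + ["ship"] * 5

names, distances = centroid_cosine_distances(demo_embeddings, demo_labels)
print(names)
print(distances.round(3))
```

Calling the helper with `image_embeddings` and `image_labels` would give a matrix that can be passed to `sns.heatmap` in the same way as the t-SNE distances.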

## 7. Dimensionality Reduction Techniques Compared

Let's discuss the different dimensionality reduction techniques we've used:

### 7.1 Principal Component Analysis (PCA)

**Advantages:**
- Linear and deterministic (same results on multiple runs)
- Preserves global structure and directions of maximum variance
- Computationally efficient, even for larger datasets
- Can explain how much variance each dimension captures

**Disadvantages:**
- Cannot capture non-linear relationships in the data
- May not preserve local structure well
- Two components may capture only a small share of the variance when the data lies on a non-linear manifold in a high-dimensional space

**Best used when:**
- You want to understand global variance directions
- Data has approximately linear relationships
- You need deterministic, reproducible results
- You're working with very large datasets where computational efficiency matters
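To make the "how much variance each dimension captures" point concrete, here is a minimal sketch of PCA computed through a singular value decomposition in NumPy. It runs on synthetic data and `pca_via_svd` is a hypothetical helper; scikit-learn's `PCA` exposes the same quantity as the `explained_variance_ratio_` attribute after fitting.

```python
import numpy as np

def pca_via_svd(data, n_components=2):
    """Project data onto its top principal components and report explained variance ratios."""
    centered = data - data.mean(axis=0)
    # Rows of `components` are the principal directions, ordered by singular value
    _, singular_values, components = np.linalg.svd(centered, full_matrices=False)
    explained_variance = singular_values ** 2 / (len(data) - 1)
    ratio = explained_variance / explained_variance.sum()
    projected = centered @ components[:n_components].T
    return projected, ratio[:n_components]

# Synthetic data: 100 points in 20 dimensions, with two high-variance directions
rng = np.random.default_rng(42)
scales = np.array([10.0, 5.0] + [1.0] * 18)
data = rng.normal(size=(100, 20)) * scales

projected, ratio = pca_via_svd(data)
print(projected.shape)
print(ratio.round(3), "total:", ratio.sum().round(3))
```

If the two ratios sum to a small fraction, the 2D PCA plot is hiding most of the variation in the embeddings and should be read with caution.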

### 7.2 t-Distributed Stochastic Neighbor Embedding (t-SNE)

**Advantages:**
- Excellent at preserving local structure and finding clusters
- Can reveal patterns hidden in high-dimensional space
- Handles non-linear relationships well
- Works well for visualizing natural clusters in data

**Disadvantages:**
- Stochastic (results differ between runs unless `random_state` is fixed)
- Does not preserve global structure well
- Can be computationally expensive
- Sensitive to hyperparameters (especially perplexity)
- Not suited to downstream tasks beyond visualization (scikit-learn's `TSNE` has no `transform` method for new data)

**Best used when:**
- You want to visualize clusters and local neighborhoods
- Global distances between separated clusters are less important
- You want to explore the data without strict reproducibility requirements
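Since t-SNE is sensitive to perplexity, it can help to compute several projections and compare them before drawing conclusions. The sketch below uses synthetic clustered data in place of `image_embeddings`; the perplexity values are arbitrary choices for illustration, and each must stay below the number of samples.

```python
import numpy as np
from sklearn.manifold import TSNE

# Synthetic stand-in data: 3 well-separated clusters of 20 points in 16 dimensions
rng = np.random.default_rng(0)
centers = rng.normal(scale=10.0, size=(3, 16))
points = np.vstack([center + rng.normal(size=(20, 16)) for center in centers])

projections = {}
for perplexity in (5, 15, 30):
    # Fixing random_state makes each run repeatable
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=42)
    projections[perplexity] = tsne.fit_transform(points)
    print(perplexity, projections[perplexity].shape)
```

Each entry in `projections` can be passed to `plot_image_embeddings` together with the labels to see how the layout changes with perplexity.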

### 7.3 Uniform Manifold Approximation and Projection (UMAP)

**Advantages:**
- Preserves both local and global structure better than t-SNE
- Faster than t-SNE, especially for larger datasets
- More stable across multiple runs than t-SNE
- Can be used for dimensionality reduction as a preprocessing step, not just visualization
- Supports supervised dimension reduction

**Disadvantages:**
- Still somewhat stochastic (though more stable than t-SNE)
- Has multiple hyperparameters that need tuning
- Theoretical foundations more complex than PCA or t-SNE

**Best used when:**
- You want a balance between preserving local and global structure
- You need faster performance than t-SNE for larger datasets
- You plan to use the reduced dimensions for downstream tasks
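One way to put a number on "preserves local structure" is scikit-learn's `trustworthiness` score, which ranges from 0 to 1 and measures how well nearest neighbors in the projection reflect nearest neighbors in the original space. The following is a minimal sketch on synthetic data; the cluster layout and `n_neighbors=10` are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE, trustworthiness

# Synthetic stand-in data: 4 clusters of 25 points in 32 dimensions
rng = np.random.default_rng(0)
centers = rng.normal(scale=10.0, size=(4, 32))
points = np.vstack([center + rng.normal(size=(25, 32)) for center in centers])

projections = {
    "PCA": PCA(n_components=2, random_state=42).fit_transform(points),
    "t-SNE": TSNE(n_components=2, perplexity=20, random_state=42).fit_transform(points),
}

# Higher scores mean neighbors in 2D are more often true neighbors in the original space
scores = {
    name: trustworthiness(points, projection, n_neighbors=10)
    for name, projection in projections.items()
}
for name, score in scores.items():
    print(f"{name}: {score:.3f}")
```

A UMAP projection can be added to the `projections` dictionary in the same way. The score only covers local neighborhoods, so it says nothing about how well global distances are preserved.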

## 8. Conclusion and Best Practices

### Key takeaways for visualizing image embeddings in latent space:

1. **Understand your goal:**
   - For cluster analysis and pattern discovery, t-SNE or UMAP often work best
   - For understanding variance directions, PCA is more appropriate
   - For balancing local and global structure, UMAP is a good choice

2. **Be aware of limitations:**
   - Dimensionality reduction always loses information
   - Different algorithms preserve different aspects of the data
   - Random initializations can affect results (especially for t-SNE)

3. **Tips for better visualizations:**
   - Try multiple dimensionality reduction techniques and compare results
   - Experiment with hyperparameters (perplexity for t-SNE, n_neighbors for UMAP)
   - Use appropriate color coding and labels
   - Consider using thumbnail images to directly see patterns
   - For large datasets, consider subsampling or using incremental techniques

4. **Interpreting image embedding visualizations:**
   - Proximity in the visualization generally means visual similarity
   - Clusters often represent visually similar objects or scenes
   - Directions in the latent space may correspond to visual attributes (colors, shapes, textures)
   - Outliers could be unusual images or potential errors in the dataset

5. **Beyond visualization:**
   - Image embeddings can be used for image retrieval, classification, and anomaly detection
   - Distance metrics in latent space can quantify visual similarity
   - Clustering in latent space can identify natural image groupings
   - In generative models, latent space manipulations can enable image editing and style transfer
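As an illustration of using distances in latent space for image retrieval, here is a minimal sketch of nearest-neighbor lookup with cosine similarity. `nearest_neighbors` is a hypothetical helper, and the small synthetic array stands in for `image_embeddings`.

```python
import numpy as np

def nearest_neighbors(embeddings, query_index, k=3):
    """Return indices and cosine similarities of the k items most similar to the query."""
    unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    similarities = unit @ unit[query_index]
    similarities[query_index] = -np.inf  # exclude the query itself
    top = np.argsort(similarities)[::-1][:k]
    return top, similarities[top]

# Synthetic stand-in data: two groups of 4 points around different directions
rng = np.random.default_rng(0)
group_a = np.array([1.0, 0.0, 0.0, 0.0])
group_b = np.array([0.0, 1.0, 0.0, 0.0])
demo_embeddings = np.vstack([np.tile(group_a, (4, 1)), np.tile(group_b, (4, 1))])
demo_embeddings = demo_embeddings + rng.normal(scale=0.05, size=demo_embeddings.shape)

neighbor_indices, neighbor_scores = nearest_neighbors(demo_embeddings, query_index=0, k=3)
print(neighbor_indices, neighbor_scores.round(3))
```

With the real embeddings, the returned indices can be used to look up entries in `images` and `image_labels` and check whether the neighbors share the query's class.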

## 9. Additional Exploration Ideas

Here are some additional ideas to explore image embeddings further:

1. **Compare different models:** Try different vision architectures (ViT vs. ResNet vs. EfficientNet) and see how their latent spaces differ

2. **Layer-wise analysis:** Extract embeddings from different layers of a model to see how representations evolve

3. **Fine-tuning effects:** Compare latent spaces before and after fine-tuning on a specific task

4. **Multimodal embeddings:** Explore joint text-image embedding spaces using models like CLIP

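5. **Image interpolation:** Interpolate between two embeddings and decode back to images to see the transition (this requires a model with a decoder, such as an autoencoder or another generative model; ViT on its own only encodes)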

6. **Interactive visualization:** Use tools like the TensorBoard Embedding Projector for interactive image embedding exploration

7. **Attention visualization:** Combine latent space visualization with attention maps for deeper insights

8. **Style transfer and image editing:** Use latent space manipulations to edit image features or transfer styles
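For the interpolation idea, the embedding-space part can be sketched with linear interpolation in NumPy. This only produces intermediate vectors: turning them back into images needs a model with a decoder, which the ViT encoder used here does not provide. `interpolate_embeddings` is a hypothetical helper and the two vectors are made-up examples.

```python
import numpy as np

def interpolate_embeddings(start, end, steps=5):
    """Return `steps` evenly spaced points on the straight line from start to end."""
    weights = np.linspace(0.0, 1.0, steps)[:, None]
    return (1.0 - weights) * start + weights * end

# Made-up 4-dimensional embeddings for illustration
start = np.array([1.0, 0.0, 2.0, -1.0])
end = np.array([0.0, 1.0, 0.0, 1.0])

path = interpolate_embeddings(start, end, steps=5)
print(path.shape)
print(path)
```

Without a decoder, one way to inspect such a path is to look up the nearest real image embedding for each intermediate vector.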