# Geospatial Embedding Datasets with TorchGeo

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/torchgeo_embeddings.ipynb)

## Overview

[TorchGeo v0.9.0](https://github.com/torchgeo/torchgeo/releases/tag/v0.9.0) introduces **Earth Embeddings**: pre-computed outputs of geospatial foundation models that encode satellite and aerial imagery as compact vectors. Because the foundation model has already been run over the imagery, you can analyze these embeddings on an ordinary CPU, without needing a GPU for model inference.

This notebook demonstrates how to use the `geoai` embeddings module to:

1. **Browse** available embedding datasets
2. **Load** patch-based embedding datasets (Clay Foundation Model)
3. **Visualize** high-dimensional embeddings using PCA
4. **Cluster** embeddings to discover spatial patterns
5. **Search** for similar locations using cosine similarity
6. **Classify** land use types using lightweight classifiers on embeddings

### Embedding Dataset Types

| Type | Format | Examples | Use Case |
| :--- | :----- | :------- | :------- |
| **Patch-based** | GeoParquet | Clay, Major TOM, Earth Index | Global-scale analysis, classification |
| **Pixel-based** | GeoTIFF | Google Satellite, Tessera, Presto | High-resolution mapping, change detection |
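
Whichever format a dataset uses, the analysis steps later in this notebook work on a 2-D matrix with one row per location and one column per embedding dimension. The sketch below illustrates the two layouts with random stand-in arrays. It assumes a pixel-based tile has been read as a channels-first `(D, H, W)` array (the band-first layout rasterio returns); `pixel_tile_to_rows` is a hypothetical helper, not part of `geoai` or TorchGeo, and the dimension sizes are arbitrary.

```python
import numpy as np


def pixel_tile_to_rows(tile):
    """Flatten a channels-first (D, H, W) embedding raster into an (H*W, D) matrix.

    Row r * W + c of the result holds the embedding of pixel (r, c).
    """
    d, h, w = tile.shape
    return tile.reshape(d, h * w).T


rng = np.random.default_rng(0)

# Patch-based (GeoParquet): one D-dimensional vector per patch geometry.
patch_vectors = [rng.normal(size=768) for _ in range(5)]
patch_matrix = np.stack(patch_vectors)  # shape (5, 768)

# Pixel-based (GeoTIFF): one D-dimensional vector per pixel, stored as D bands.
pixel_tile = rng.normal(size=(64, 10, 12)).astype(np.float32)  # (D, H, W)
pixel_matrix = pixel_tile_to_rows(pixel_tile)  # shape (120, 64)

print(patch_matrix.shape, pixel_matrix.shape)  # (5, 768) (120, 64)
```

Because the reshape is a plain row-major flattening, per-pixel results such as cluster labels can be put back on the grid with `labels.reshape(H, W)`.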

## Install Package

Uncomment the command below if needed.

In [None]:
# %pip install geoai-py scikit-learn

## Import Libraries

In [None]:
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download

import geoai

## 1. Browse Available Embedding Datasets

The `geoai` package provides a registry of all embedding datasets available in TorchGeo v0.9.0. Use `list_embedding_datasets()` to see what's available.

In [None]:
# List all embedding datasets
df = geoai.list_embedding_datasets(verbose=False)
df

In [None]:
# Filter by type: patch-based datasets
geoai.list_embedding_datasets(kind="patch", verbose=False)

In [None]:
# Filter by type: pixel-based datasets
geoai.list_embedding_datasets(kind="pixel", verbose=False)

In [None]:
# Get detailed info about a specific dataset
info = geoai.get_embedding_info("google_satellite")
for key, value in info.items():
    print(f"{key}: {value}")

## 2. Download Clay Embeddings (SF Bay Area)

We'll use Clay Foundation Model embeddings for the San Francisco Bay Area from [Hugging Face](https://huggingface.co/datasets/made-with-clay/classify-embeddings-sf-baseball-marinas). This dataset contains 768-dimensional embeddings computed from NAIP aerial imagery across 20 tiles, along with labeled locations for baseball fields (class 0) and marinas (class 1).

### Download all embedding tiles

In [None]:
repo_id = "made-with-clay/classify-embeddings-sf-baseball-marinas"

# List all embedding GeoParquet files
api = HfApi()
embedding_files = [
    f.path
    for f in api.list_repo_tree(repo_id, repo_type="dataset")
    if f.path.endswith(".gpq")
]
print(f"Found {len(embedding_files)} embedding tiles")

In [None]:
# Download all tiles and concatenate into a single GeoDataFrame
all_gdfs = []
for f in embedding_files:
    path = hf_hub_download(repo_id, f, repo_type="dataset")
    gdf = gpd.read_parquet(path)
    all_gdfs.append(gdf)

embeddings_gdf = pd.concat(all_gdfs, ignore_index=True)
embeddings_gdf = gpd.GeoDataFrame(
    embeddings_gdf, geometry="geometry", crs=all_gdfs[0].crs
)
print(f"Combined: {len(embeddings_gdf)} patches")
print(f"Bounds: {embeddings_gdf.total_bounds}")
print(f"Embedding dimension: {len(embeddings_gdf.iloc[0]['embeddings'])}")

In [None]:
# Download the labeled locations (baseball fields and marinas)
labels_file = hf_hub_download(repo_id, "baseball.geojson", repo_type="dataset")
labels_gdf = gpd.read_file(labels_file)
print(f"Labeled locations: {len(labels_gdf)}")
print("Class distribution:")
print(labels_gdf["class"].value_counts())

### Extract embedding vectors

Convert the embedding column to a NumPy array and extract coordinates for analysis.

In [None]:
# Stack the per-patch embeddings into an (N, D) array and use patch centroids as coordinates
embeddings = np.stack(embeddings_gdf["embeddings"].values)
centroids = embeddings_gdf.geometry.centroid
coords_x = centroids.x.values
coords_y = centroids.y.values

print(f"Embeddings shape: {embeddings.shape}")
print(f"X range: [{coords_x.min():.4f}, {coords_x.max():.4f}]")
print(f"Y range: [{coords_y.min():.4f}, {coords_y.max():.4f}]")
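
`np.stack` is needed because each GeoParquet row stores its embedding as a separate 1-D array, so the column is a sequence of arrays rather than a 2-D block. A minimal sketch of the same conversion in plain NumPy, using a hypothetical `stack_embedding_column` helper that also checks every row has the same length before stacking (the vectors are random stand-ins):

```python
import numpy as np


def stack_embedding_column(column):
    """Stack a sequence of 1-D embedding arrays into an (N, D) float32 matrix."""
    vectors = [np.asarray(v, dtype=np.float32) for v in column]
    shapes = {v.shape for v in vectors}
    if len(shapes) != 1:
        raise ValueError(f"Embeddings have inconsistent shapes: {sorted(shapes)}")
    return np.stack(vectors)


# Stand-in for embeddings_gdf["embeddings"].values: an object array of 1-D arrays.
rng = np.random.default_rng(42)
column = np.empty(4, dtype=object)
for i in range(4):
    column[i] = rng.normal(size=768)

matrix = stack_embedding_column(column)
print(matrix.shape, matrix.dtype)  # (4, 768) float32
```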

## 3. Visualize Embeddings

### Plot individual embedding vectors

In [None]:
# Plot a few embedding vectors to see their patterns
fig, axes = plt.subplots(1, 3, figsize=(15, 3))
for i, ax in enumerate(axes):
    idx = i * (len(embeddings) // 3)
    ax.plot(embeddings[idx], linewidth=0.5)
    ax.set_title(f"Patch {idx} (lat {coords_y[idx]:.3f}, lon {coords_x[idx]:.3f})")
    ax.set_xlabel("Dimension")
    ax.set_ylabel("Value")
plt.tight_layout()
plt.show()

### PCA projection of all embeddings

In [None]:
# Visualize the embedding space using PCA
fig = geoai.visualize_embeddings(
    embeddings,
    method="pca",
    figsize=(8, 8),
    s=3,
    alpha=0.4,
    title="PCA of Clay Embeddings (SF Bay Area)",
)
plt.show()
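
PCA projects the 768-dimensional vectors onto the two directions of greatest variance so they can be drawn as a scatter plot. A minimal NumPy sketch of that projection (centering followed by a singular value decomposition) is shown below. It illustrates the technique and does not claim to match how `geoai.visualize_embeddings` is implemented; `pca_2d` is a hypothetical helper and the input is random.

```python
import numpy as np


def pca_2d(x):
    """Project the rows of x onto their first two principal components.

    Returns the (N, 2) projected coordinates and the fraction of total
    variance explained by each of the two components.
    """
    centered = x - x.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:2].T
    explained = (s[:2] ** 2) / np.sum(s**2)
    return projected, explained


rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(500, 768))  # stand-in for the (N, 768) Clay matrix
coords_2d, explained = pca_2d(fake_embeddings)
print(coords_2d.shape)  # (500, 2)
print(f"Variance explained by PC1, PC2: {explained.round(3)}")
```

The sign of each principal component is arbitrary, so PCA plots of the same data produced by different tools can appear mirrored along either axis.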

## 4. Cluster Embeddings

Use unsupervised clustering to discover patterns in the embeddings without any labels.

In [None]:
# Cluster the embeddings into groups
result = geoai.cluster_embeddings(embeddings, n_clusters=8, method="kmeans")
cluster_labels = result["labels"]
print(f"Number of clusters: {result['n_clusters']}")
print(f"Cluster sizes: {np.bincount(cluster_labels)}")
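
K-means alternates between assigning each embedding to its nearest centroid and moving each centroid to the mean of its assigned embeddings. The NumPy sketch below implements that loop (Lloyd's algorithm) with a farthest-first initialization. It is illustrative only and is not assumed to reproduce the initialization or stopping rules of `geoai.cluster_embeddings` or scikit-learn; `kmeans` is a hypothetical helper and the data are synthetic blobs.

```python
import numpy as np


def kmeans(x, n_clusters, n_iter=100, seed=0):
    """Cluster the rows of x with Lloyd's algorithm; returns (labels, centroids)."""
    x = np.asarray(x, dtype=float)
    rng = np.random.default_rng(seed)
    # Farthest-first initialization: start from one random point, then keep
    # adding the point farthest from every centroid chosen so far.
    centroids = [x[rng.integers(len(x))]]
    for _ in range(n_clusters - 1):
        d2 = np.min([((x - c) ** 2).sum(axis=1) for c in centroids], axis=0)
        centroids.append(x[d2.argmax()])
    centroids = np.array(centroids)

    labels = None
    for _ in range(n_iter):
        # Squared Euclidean distances (N, K) via |x|^2 - 2 x.c + |c|^2, which
        # avoids building an (N, K, D) intermediate array.
        dists = (
            (x**2).sum(axis=1)[:, None]
            - 2.0 * x @ centroids.T
            + (centroids**2).sum(axis=1)[None, :]
        )
        new_labels = dists.argmin(axis=1)
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments stopped changing, so the algorithm has converged
        labels = new_labels
        # Move each centroid to the mean of its members (leave empty ones in place).
        for k in range(n_clusters):
            members = x[labels == k]
            if len(members) > 0:
                centroids[k] = members.mean(axis=0)
    return labels, centroids


# Synthetic stand-in: three well-separated 16-D blobs of 100 points each.
rng = np.random.default_rng(1)
blobs = np.concatenate(
    [center + rng.normal(size=(100, 16)) for center in (0.0, 10.0, -10.0)]
)
labels, centroids = kmeans(blobs, n_clusters=3)
print(np.bincount(labels))  # each blob becomes its own cluster of 100 points
```

Cluster IDs are arbitrary, so even when two implementations find the same partition of `embeddings`, their label numbers need not match.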

In [None]:
# Visualize clusters in PCA space
fig = geoai.visualize_embeddings(
    embeddings,
    labels=cluster_labels,
    method="pca",
    figsize=(10, 8),
    s=5,
    alpha=0.5,
    title="K-Means Clusters of Clay Embeddings",
)
plt.show()

In [None]:
# Map clusters geographically
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
    coords_x,
    coords_y,
    c=cluster_labels,
    cmap="tab10",
    s=3,
    alpha=0.6,
)
plt.colorbar(scatter, ax=ax, label="Cluster")
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_title("Geographic Distribution of Embedding Clusters")
ax.set_aspect("equal")
plt.tight_layout()
plt.show()

## 5. Similarity Search

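Find the patches whose embeddings are closest to a query embedding, ranked by cosine similarity. Because the query patch is itself part of `embeddings`, it may appear in the results with a similarity of 1.0.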

In [None]:
# Pick a query embedding (first patch)
query_idx = 0
query = embeddings[query_idx]
print(f"Query location: (lat {coords_y[query_idx]:.4f}, lon {coords_x[query_idx]:.4f})")

# Find top-10 most similar locations
results = geoai.embedding_similarity(
    query=query, embeddings=embeddings, metric="cosine", top_k=10
)

print("\nTop 10 most similar locations:")
for rank, (idx, score) in enumerate(
    zip(results["indices"], results["scores"]), start=1
):
    print(
        f"  {rank}. Index {idx}: similarity={score:.4f}, "
        f"location=(lat {coords_y[idx]:.4f}, lon {coords_x[idx]:.4f})"
    )
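
Cosine similarity compares the direction of two vectors and ignores their length: it is the dot product of the two vectors after each is scaled to unit L2 norm. The sketch below ranks every row against a query in NumPy. `cosine_top_k` is a hypothetical helper (not `geoai.embedding_similarity`), the data are random, and the optional `exclude` argument drops the query's own row, since a vector always has similarity 1.0 with itself.

```python
import numpy as np


def cosine_top_k(query, embeddings, top_k=10, exclude=None):
    """Return (indices, scores) of the top_k rows most cosine-similar to query."""
    eps = 1e-12  # guards against division by zero for all-zero vectors
    q = query / (np.linalg.norm(query) + eps)
    e = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + eps)
    scores = e @ q  # one cosine similarity per row, in [-1, 1]
    if exclude is not None:
        scores[exclude] = -np.inf  # never return the excluded row
    # argpartition finds the top_k candidates in O(N); only those are then sorted.
    top = np.argpartition(-scores, top_k - 1)[:top_k]
    top = top[np.argsort(-scores[top])]
    return top, scores[top]


rng = np.random.default_rng(0)
fake = rng.normal(size=(1000, 768))
# Row 42 is a scaled, slightly noisy copy of row 0, so it should rank first:
# scaling changes the vector's length but not its direction.
fake[42] = 3.0 * fake[0] + 0.01 * rng.normal(size=768)
idx, sims = cosine_top_k(fake[0], fake, top_k=5, exclude=0)
print(idx[0])  # 42
```

On the notebook's data the equivalent call would be `cosine_top_k(embeddings[query_idx], embeddings, exclude=query_idx)`.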

In [None]:
# Visualize the query and its nearest neighbors on a map
fig, ax = plt.subplots(figsize=(10, 8))

# Background: all embeddings in gray
ax.scatter(coords_x, coords_y, c="lightgray", s=1, alpha=0.3)

# Highlight nearest neighbors
nn_indices = results["indices"]
ax.scatter(
    coords_x[nn_indices],
    coords_y[nn_indices],
    c="blue",
    s=50,
    marker="o",
    label="Nearest Neighbors",
    edgecolors="black",
    linewidths=0.5,
)

# Highlight the query point
ax.scatter(
    coords_x[query_idx],
    coords_y[query_idx],
    c="red",
    s=100,
    marker="*",
    label="Query",
    edgecolors="black",
    linewidths=0.5,
    zorder=5,
)

ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_title("Similarity Search: Query and Nearest Neighbors")
ax.legend()
ax.set_aspect("equal")
plt.tight_layout()
plt.show()

## 6. Classification with Embeddings

Train a lightweight k-NN classifier on the Clay embeddings using labeled data. The dataset includes labeled locations for baseball fields (class 0) and marinas (class 1) in the San Francisco Bay Area.

### Prepare training data

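We match each labeled point to the embedding patch that contains it using a spatial join with the `within` predicate. Points that fall outside every patch are dropped rather than matched to the nearest patch.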

In [None]:
# Ensure both GeoDataFrames use the same CRS
if labels_gdf.crs != embeddings_gdf.crs:
    labels_gdf = labels_gdf.to_crs(embeddings_gdf.crs)

# Spatial join: find which embedding patch each labeled point falls within
joined = gpd.sjoin(labels_gdf, embeddings_gdf, how="inner", predicate="within")
print(f"Matched {len(joined)} labeled points to embedding patches")
print(f"Class distribution: {joined['class'].value_counts().to_dict()}")
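
With `predicate="within"`, a labeled point is kept only if it lies strictly inside a patch polygon; `gpd.sjoin_nearest` would instead snap every point to its closest patch. Under the simplifying assumption that each patch is an axis-aligned rectangle, the same test reduces to bounds comparisons. The NumPy sketch below uses a hypothetical `match_points_to_boxes` helper and made-up coordinates. It only illustrates the predicate and is not a replacement for `gpd.sjoin`, which handles arbitrary polygons, returns every match for overlapping patches, and uses a spatial index.

```python
import numpy as np


def match_points_to_boxes(px, py, boxes):
    """For each point, return the index of the first box containing it, or -1.

    boxes is an (M, 4) array of (minx, miny, maxx, maxy), the same column order
    as GeoSeries.bounds. Points on a box edge do not count as inside, mirroring
    the strict-interior behavior of the "within" predicate.
    """
    px = np.asarray(px)[:, None]
    py = np.asarray(py)[:, None]
    inside = (
        (px > boxes[:, 0]) & (px < boxes[:, 2])
        & (py > boxes[:, 1]) & (py < boxes[:, 3])
    )  # (N, M) boolean matrix: point i lies inside box j
    return np.where(inside.any(axis=1), inside.argmax(axis=1), -1)


# Two hypothetical 0.01-degree patches side by side and three labeled points.
boxes = np.array([
    [-122.42, 37.77, -122.41, 37.78],
    [-122.41, 37.77, -122.40, 37.78],
])
px = np.array([-122.415, -122.405, -122.300])
py = np.array([37.775, 37.775, 37.775])
print(match_points_to_boxes(px, py, boxes))  # [ 0  1 -1]: third point is outside both
```

On the notebook's data the boxes could come from `embeddings_gdf.geometry.bounds.to_numpy()` and the points from `labels_gdf.geometry.x` and `labels_gdf.geometry.y`, provided both layers share a CRS and the patches really are axis-aligned in that CRS.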

In [None]:
# Extract embeddings and labels for matched points
labeled_embeddings = np.stack(
    [embeddings_gdf.iloc[idx]["embeddings"] for idx in joined["index_right"]]
)
class_labels = joined["class"].values

print(f"Labeled embeddings shape: {labeled_embeddings.shape}")
print(f"Labels shape: {class_labels.shape}")

In [None]:
# Split into train/validation sets
from sklearn.model_selection import train_test_split

X_train, X_val, y_train, y_val = train_test_split(
    labeled_embeddings,
    class_labels,
    test_size=0.3,
    random_state=42,
    stratify=class_labels,
)
print(f"Train: {X_train.shape[0]} samples")
print(f"Val:   {X_val.shape[0]} samples")

### Train a k-NN classifier

In [None]:
label_names = ["Baseball Field", "Marina"]

# Train using geoai's convenience function
result = geoai.train_embedding_classifier(
    train_embeddings=X_train,
    train_labels=y_train,
    val_embeddings=X_val,
    val_labels=y_val,
    method="knn",
    n_neighbors=5,
    label_names=label_names,
)

print(f"\nTrain accuracy: {result['train_accuracy']:.2%}")
print(f"Val accuracy:   {result['val_accuracy']:.2%}")
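
A k-NN classifier has no training step beyond storing the labeled embeddings: each new embedding takes the majority label of its `k` closest training embeddings. A minimal NumPy sketch, assuming Euclidean distance and uniform majority voting (the defaults of scikit-learn's `KNeighborsClassifier`; it is not assumed that `geoai.train_embedding_classifier` uses exactly these settings). `knn_predict` is hypothetical and the two classes are synthetic.

```python
import numpy as np


def knn_predict(train_x, train_y, query_x, k=5):
    """Predict a label for each row of query_x by majority vote of its k nearest
    training rows (Euclidean distance; ties go to the smallest label)."""
    # Squared distances (Q, N) between every query row and every training row.
    d2 = (
        (query_x**2).sum(axis=1)[:, None]
        - 2.0 * query_x @ train_x.T
        + (train_x**2).sum(axis=1)[None, :]
    )
    nearest = np.argsort(d2, axis=1)[:, :k]  # indices of the k closest rows
    votes = train_y[nearest]  # (Q, k) labels of those neighbors
    n_classes = int(train_y.max()) + 1
    counts = np.stack([(votes == c).sum(axis=1) for c in range(n_classes)], axis=1)
    return counts.argmax(axis=1)  # argmax picks the first (smallest) label on ties


rng = np.random.default_rng(0)
# Two synthetic classes (0 and 1) in 32-D, offset by 4 units along every axis.
train_x = np.concatenate([rng.normal(0, 1, (50, 32)), rng.normal(4, 1, (50, 32))])
train_y = np.array([0] * 50 + [1] * 50)
val_x = np.concatenate([rng.normal(0, 1, (10, 32)), rng.normal(4, 1, (10, 32))])
val_y = np.array([0] * 10 + [1] * 10)

pred = knn_predict(train_x, train_y, val_x, k=5)
print(f"Val accuracy: {(pred == val_y).mean():.2%}")
```

With two classes and an odd `k`, a vote can never tie, which is one reason odd values such as `n_neighbors=5` are common for binary problems like baseball fields versus marinas.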

### Compare different classifiers

In [None]:
# Try different classifiers
methods = ["knn", "random_forest", "logistic_regression"]
results_summary = []

for method in methods:
    res = geoai.train_embedding_classifier(
        train_embeddings=X_train,
        train_labels=y_train,
        val_embeddings=X_val,
        val_labels=y_val,
        method=method,
        label_names=label_names,
        verbose=False,
    )
    results_summary.append(
        {
            "Method": method,
            "Train Acc": f"{res['train_accuracy']:.2%}",
            "Val Acc": f"{res['val_accuracy']:.2%}",
        }
    )

pd.DataFrame(results_summary)

### Visualize classified embeddings

In [None]:
# Visualize labeled embeddings in PCA space
fig = geoai.visualize_embeddings(
    labeled_embeddings,
    labels=class_labels,
    label_names=label_names,
    method="pca",
    figsize=(8, 8),
    s=30,
    alpha=0.8,
    title="PCA of Labeled Embeddings (Baseball vs Marina)",
)
plt.show()

## 7. Comparing Embeddings for Change Detection

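Embedding vectors of the same location from different time periods can be compared to detect change. The `compare_embeddings` function computes one similarity score per pair of corresponding embeddings: row `i` of the first set against row `i` of the second.

This dataset covers only a single time period, so we demonstrate the mechanics by comparing two different sets of spatial patches (the first half of the patches against the second half). The resulting scores therefore reflect differences between locations, not change over time.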

In [None]:
# Compare first half vs second half of patches to simulate temporal comparison
n = len(embeddings)
half = n // 2
emb_a = embeddings[:half]
emb_b = embeddings[half : half + half]  # same number of samples

similarity = geoai.compare_embeddings(emb_a, emb_b, metric="cosine")

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(similarity, bins=50, edgecolor="black", alpha=0.7)
ax.axvline(
    similarity.mean(),
    color="red",
    linestyle="--",
    label=f"Mean: {similarity.mean():.3f}",
)
ax.set_xlabel("Cosine Similarity")
ax.set_ylabel("Count")
ax.set_title("Embedding Similarity Distribution")
ax.legend()
plt.tight_layout()
plt.show()
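
In a real change-detection workflow the two arrays must be co-registered, so that row `i` of each describes the same location on two different dates. Each location then gets the cosine similarity of its pair of rows, and locations whose similarity falls below a threshold are flagged as candidate changes. The NumPy sketch below shows this with a hypothetical `rowwise_cosine` helper, synthetic "before" and "after" embeddings, and an arbitrary threshold of 0.5; in practice the threshold would need calibrating against known changes.

```python
import numpy as np


def rowwise_cosine(a, b, eps=1e-12):
    """Cosine similarity between row i of a and row i of b, for every i."""
    num = (a * b).sum(axis=1)
    den = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + eps
    return num / den


rng = np.random.default_rng(0)
n_locations, dim = 200, 768
before = rng.normal(size=(n_locations, dim))
# "After" embeddings: small perturbations everywhere ...
after = before + 0.1 * rng.normal(size=(n_locations, dim))
# ... except at 10 simulated changed locations, which get unrelated vectors.
changed = rng.choice(n_locations, size=10, replace=False)
after[changed] = rng.normal(size=(10, dim))

sim = rowwise_cosine(before, after)
flagged = np.flatnonzero(sim < 0.5)  # 0.5 is an arbitrary illustrative threshold
print(sorted(flagged.tolist()) == sorted(changed.tolist()))  # True
```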

## 8. Using TorchGeo Dataset Classes Directly

For more advanced usage, you can use the TorchGeo dataset classes directly through `geoai.load_embedding_dataset()`. This gives you access to TorchGeo features such as transforms, samplers, and the dataset's `plot` method.

Note: The TorchGeo `ClayEmbeddings` class expects a `date` or `datetime` column in the GeoParquet file. Some community-contributed embedding files may not include this column. In such cases, load the data with geopandas directly (as shown above).

In [None]:
# Load using geoai's unified interface
single_file = hf_hub_download(repo_id, embedding_files[0], repo_type="dataset")
ds = geoai.load_embedding_dataset("clay", root=single_file)

print(f"Dataset length: {len(ds)}")
print(f"Dataset type: {type(ds).__name__}")

In [None]:
# Access a sample - may fail if the file lacks a 'datetime' column
try:
    sample = ds[0]
    print(f"Sample keys: {list(sample.keys())}")
    print(f"Embedding shape: {sample['embedding'].shape}")
    print(f"Location: (lat {sample['y'].item():.4f}, lon {sample['x'].item():.4f})")

    fig = ds.plot(sample)
    plt.show()
except KeyError as e:
    print(
        f"Note: This parquet file is missing the '{e.args[0]}' column "
        f"expected by TorchGeo's ClayEmbeddings class."
    )
    print("For such files, use geopandas directly (as shown above).")
    print("The TorchGeo class works best with official Clay data products.")

## Summary

In this notebook, we demonstrated the `geoai` embeddings module which provides a unified interface to TorchGeo v0.9.0's embedding datasets. Key takeaways:

- **9 embedding datasets** are available, spanning patch-based and pixel-based formats
- **No GPU required** for analysis — embeddings are pre-computed
- **Lightweight classifiers** (k-NN, Random Forest, logistic regression) can be trained directly on embeddings on a CPU, without fine-tuning a foundation model
- **Unsupervised clustering** reveals spatial patterns without labels
- **Similarity search** enables content-based spatial retrieval
- **Change detection** is possible by comparing embeddings across time periods

For pixel-based datasets (Google Satellite Embedding, Tessera, etc.), download GeoTIFF files and use `geoai.load_embedding_dataset()` with the `paths` parameter.