# Google Satellite Embedding with TorchGeo

[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/google_satellite_embedding.ipynb)

## Overview

The [AlphaEarth Foundations Satellite Embedding](https://developers.google.com/earth-engine/datasets/catalog/GOOGLE_SATELLITE_EMBEDDING_V1_ANNUAL) dataset, produced by Google and Google DeepMind, provides pre-computed 64-dimensional embedding vectors at 10-meter resolution. Each pixel encodes information from optical, radar, LiDAR, and other Earth observation sources into a unit-length vector.

Key characteristics:
- **Resolution**: 10 m per pixel
- **Dimensions**: 64 (bands A00-A63)
- **Temporal**: Annual composites from 2018-2024
- **Coverage**: Global
- **License**: CC-BY-4.0

This notebook demonstrates how to:
1. **Download** embedding data from [Source Cooperative](https://source.coop/tge-labs/aef)
2. **Load** with TorchGeo's `GoogleSatelliteEmbedding` via `geoai`
3. **Visualize** embeddings as PCA-based RGB images
4. **Cluster** embeddings to discover land cover patterns
5. **Search** for similar pixels using cosine similarity
6. **Detect change** by comparing embeddings across years

References:
- [Paper](https://arxiv.org/abs/2507.22291)
- [Blog post](https://medium.com/google-earth/ai-powered-pixels-introducing-googles-satellite-embedding-dataset-31744c1f4650)
- [Data source](https://source.coop/tge-labs/aef)

## Install Package

Uncomment the command below if needed.

In [None]:
# %pip install geoai-py scikit-learn

## Import Libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import rasterio

import geoai

## Dataset Info

View metadata about the Google Satellite Embedding dataset from the geoai registry.

In [None]:
info = geoai.get_embedding_info("google_satellite")
for key, value in info.items():
    print(f"{key}: {value}")

## Download Embedding Data

Download embeddings for a small region near Paradise, CA, for two years (2018 and 2024). The Camp Fire of November 2018 destroyed much of this area, so comparing the 2018 embeddings (largely pre-fire) with the 2024 embeddings (during rebuilding) should reveal significant change.

The download function fetches data from [Source Cooperative](https://source.coop/tge-labs/aef) using windowed reads from Cloud-Optimized GeoTIFFs, so only the requested region is transferred.

In [None]:
bbox = (-121.65, 39.73, -121.55, 39.80)
output_dir = "aef_data"

files = geoai.download_google_satellite_embedding(
    bbox=bbox,
    output_dir=output_dir,
    years=[2018, 2024],
    crs=None,
)
print(f"Downloaded {len(files)} file(s): {files}")
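
Before downloading a larger region, it can help to estimate how big the result will be. The sketch below is a rough back-of-the-envelope calculation, not part of `geoai`: the helper name `estimate_grid`, the spherical approximation of about 111.32 km per degree, and the 4 bytes per value (float32) are all assumptions made for illustration.

```python
import math


def estimate_grid(bbox, pixel_size_m=10.0, bands=64, bytes_per_value=4):
    """Roughly estimate rows, columns, and uncompressed size (MB) for a lon/lat bbox.

    Hypothetical helper: uses a spherical approximation, so treat the result
    as an order-of-magnitude guide only.
    """
    west, south, east, north = bbox
    mid_lat = math.radians((south + north) / 2)
    # One degree of latitude is about 111.32 km; longitude shrinks with cos(latitude)
    height_m = (north - south) * 111_320.0
    width_m = (east - west) * 111_320.0 * math.cos(mid_lat)
    rows = round(height_m / pixel_size_m)
    cols = round(width_m / pixel_size_m)
    # bytes_per_value=4 assumes float32 storage; the actual dtype may differ
    size_mb = rows * cols * bands * bytes_per_value / 1e6
    return rows, cols, size_mb


rows, cols, size_mb = estimate_grid((-121.65, 39.73, -121.55, 39.80))
print(f"~{rows} x {cols} pixels, ~{size_mb:.0f} MB uncompressed per year")
```

The actual raster shape depends on the projection and pixel grid of the source files, so compare this estimate with the values printed in the inspection step.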

## Inspect Downloaded Data

Check the properties of the downloaded GeoTIFF files.

In [None]:
for f in files:
    with rasterio.open(f) as src:
        print(f"File: {f}")
        print(f"  Shape: {src.count} bands x {src.height}H x {src.width}W")
        print(f"  CRS: {src.crs}")
        print(f"  Bounds: {src.bounds}")
        print(f"  Resolution: {src.res}")
        print(f"  Dtype: {src.dtypes[0]}")
        print()

## Load with TorchGeo

Load the downloaded data using TorchGeo's `GoogleSatelliteEmbedding` dataset class via `geoai.load_embedding_dataset()`.

In [None]:
ds = geoai.load_embedding_dataset("google_satellite", paths=output_dir)
print(f"Dataset type: {type(ds).__name__}")
print(f"CRS: {ds.crs}")
print(f"Resolution: {ds.res}")
print(f"Bounds: {ds.bounds}")

## Extract Pixel Embeddings

Use `extract_pixel_embeddings()` to sample patches from the dataset and flatten the pixels into an `(N, 64)` array suitable for analysis.

In [None]:
data = geoai.extract_pixel_embeddings(ds, num_samples=20, size=256, flatten=True)
embeddings = data["embeddings"]
print(f"Embeddings shape: {embeddings.shape}")
print(f"Value range: [{embeddings.min():.4f}, {embeddings.max():.4f}]")
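
The flattening step can be pictured with plain NumPy. This is a minimal sketch on synthetic data (the `patch` array is a made-up stand-in, not real embeddings) showing how a `(64, H, W)` patch maps to `(H*W, 64)` rows and how a row index maps back to a pixel position.

```python
import numpy as np

# Synthetic stand-in for one patch: 64 bands, 4 rows, 5 columns
patch = np.random.default_rng(0).normal(size=(64, 4, 5)).astype(np.float32)
c, h, w = patch.shape

# Flatten the spatial dimensions, then transpose so each row is one pixel
pixels = patch.reshape(c, -1).T  # (H*W, 64)

# Row i of `pixels` corresponds to (row, col) = divmod(i, w)
row, col = divmod(7, w)
assert np.array_equal(pixels[7], patch[:, row, col])

# The inverse operation restores the original (64, H, W) layout
restored = pixels.T.reshape(c, h, w)
print(pixels.shape, restored.shape)
```

The same index mapping is what makes it possible to write per-pixel results (cluster labels, similarity scores) back into a 2-D map later in the notebook.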

## Visualize Embeddings as RGB

Use PCA to project the 64-band embedding raster into 3 principal components for RGB visualization. Each sample patch is visualized individually.

In [None]:
# Get a single sample for visualization
from torchgeo.samplers import RandomGeoSampler

sampler = RandomGeoSampler(ds, size=512, length=1)
query = next(iter(sampler))
sample = ds[query]
print(f"Sample image shape: {sample['image'].shape}")

fig = geoai.plot_embedding_raster(
    sample["image"],
    title="Google Satellite Embedding (PCA RGB)",
)
plt.show()
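
The internals of `geoai.plot_embedding_raster()` are not shown here, so the following is only a sketch of the general idea: project each pixel onto the first three principal components and rescale each component to `[0, 1]`. The helper `pca_rgb` is hypothetical and uses NumPy's SVD rather than any `geoai` function.

```python
import numpy as np


def pca_rgb(image):
    """Project a (C, H, W) array onto 3 principal components, scaled to [0, 1].

    Pixels containing NaN are left black (0, 0, 0).
    """
    c, h, w = image.shape
    pixels = image.reshape(c, -1).T.astype(np.float64)  # (H*W, C)
    valid = ~np.isnan(pixels).any(axis=1)
    rgb = np.zeros((h * w, 3))
    if valid.sum() >= 3:
        centered = pixels[valid] - pixels[valid].mean(axis=0)
        # Rows of vt are the principal axes, sorted by explained variance
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        scores = centered @ vt[:3].T
        lo = scores.min(axis=0)
        span = scores.max(axis=0) - lo
        span[span == 0] = 1.0  # avoid division by zero for constant components
        rgb[valid] = (scores - lo) / span
    return rgb.reshape(h, w, 3)


# Synthetic demo: random 64-band patch with one NaN pixel
demo = np.random.default_rng(0).normal(size=(64, 16, 16))
demo[:, 0, 0] = np.nan
rgb = pca_rgb(demo)
print(rgb.shape, rgb.min(), rgb.max())
```

Because the principal axes are estimated from the data being displayed, the colors have no fixed meaning and can differ between patches.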

## Interactive Map: PCA RGB

Save PCA-projected embeddings as a 3-band GeoTIFF and display on an interactive map with a satellite basemap.

In [None]:
import os
import leafmap
from sklearn.decomposition import PCA

# Save derived products to a separate directory outside of output_dir
# to avoid interfering with TorchGeo's recursive directory scanning
viz_dir = "aef_viz"
os.makedirs(viz_dir, exist_ok=True)

# Create PCA RGB GeoTIFFs for both years
# (PCA is refit for each file, so colors are not directly comparable across years)
pca = PCA(n_components=3)
pca_files = []

for f in files:
    with rasterio.open(f) as src:
        data = src.read()  # (64, H, W)
        h, w = data.shape[1], data.shape[2]
        pixels = data.reshape(64, -1).T  # (H*W, 64)
        mask = ~np.isnan(pixels).any(axis=1)

        rgb = np.zeros((pixels.shape[0], 3))
        if mask.any():
            rgb[mask] = pca.fit_transform(pixels[mask])
            rgb -= rgb[mask].min(axis=0)
            maxv = rgb[mask].max(axis=0)
            maxv[maxv == 0] = 1
            rgb /= maxv
        rgb = rgb.reshape(h, w, 3).clip(0, 1)

        # Save as 3-band uint8 GeoTIFF
        basename = os.path.basename(f).replace(".tif", "_pca_rgb.tif")
        pca_output = os.path.join(viz_dir, basename)
        with rasterio.open(
            pca_output,
            "w",
            driver="GTiff",
            height=h,
            width=w,
            count=3,
            dtype="uint8",
            crs=src.crs,
            transform=src.transform,
            compress="lzw",
        ) as dst:
            for b in range(3):
                dst.write((rgb[:, :, b] * 255).astype(np.uint8), b + 1)
        pca_files.append(pca_output)

print(f"PCA RGB files: {pca_files}")

In [None]:
# Display PCA RGB on an interactive map
m = leafmap.Map()
m.add_basemap("Esri.WorldImagery")
m.add_raster(pca_files[0], layer_name="Embeddings 2018 (PCA RGB)")
m.add_raster(pca_files[1], layer_name="Embeddings 2024 (PCA RGB)")
m

## Cluster Embeddings

Use K-Means clustering to discover spatial patterns in the embeddings without any labels. Each cluster may correspond to a distinct land cover type.

In [None]:
# Remove NaN pixels before clustering
valid_mask = ~np.isnan(embeddings).any(axis=1)
valid_embeddings = embeddings[valid_mask]
print(f"Valid pixels: {valid_embeddings.shape[0]} / {embeddings.shape[0]}")

result = geoai.cluster_embeddings(valid_embeddings, n_clusters=8, method="kmeans")
cluster_labels = result["labels"]
print(f"Number of clusters: {result['n_clusters']}")
print(f"Cluster sizes: {np.bincount(cluster_labels)}")

In [None]:
# Visualize clusters in PCA space
fig = geoai.visualize_embeddings(
    valid_embeddings,
    labels=cluster_labels,
    method="pca",
    figsize=(10, 8),
    s=1,
    alpha=0.3,
    title="K-Means Clusters of Satellite Embeddings (PCA)",
)
plt.show()

In [None]:
# Visualize clusters as a spatial map from a single patch
sampler = RandomGeoSampler(ds, size=512, length=1)
query = next(iter(sampler))
sample = ds[query]
image = sample["image"].numpy()  # (64, H, W)

c, h, w = image.shape
pixels = image.reshape(c, -1).T  # (H*W, 64)

# Remove NaN pixels
pixel_valid = ~np.isnan(pixels).any(axis=1)
valid_px = pixels[pixel_valid]

# Predict clusters using the fitted model
pred_labels = result["model"].predict(valid_px)

# Reconstruct spatial map
cluster_map = np.full(h * w, -1, dtype=int)
cluster_map[pixel_valid] = pred_labels
cluster_map = cluster_map.reshape(h, w)

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(cluster_map, cmap="tab10", interpolation="nearest")
ax.set_title("Embedding Cluster Map")
ax.axis("off")
plt.colorbar(im, ax=ax, shrink=0.7, label="Cluster")
plt.tight_layout()
plt.show()
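
The prediction step relies on the fitted model assigning each pixel to its nearest cluster center, which is how scikit-learn's `KMeans.predict` behaves. Assuming `result["model"]` is such a model, the logic can be sketched in NumPy as below. The helper `assign_to_centroids` and the 2-D toy data are illustrative only; real embeddings have 64 dimensions.

```python
import numpy as np


def assign_to_centroids(pixels, centroids):
    """Return the index of the nearest centroid (Euclidean) for each row."""
    # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2; the ||x||^2 term is the same
    # for every centroid, so it can be dropped when taking the argmin.
    score = -2.0 * pixels @ centroids.T + (centroids**2).sum(axis=1)
    return score.argmin(axis=1)


# Toy 2-D example with three centroids and one NaN pixel
centroids = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
pixels = np.array([[0.9, 0.1], [0.1, 0.8], [-0.7, 0.2], [np.nan, 0.0]])

# Label only valid pixels; keep -1 as the "no data" label
valid = ~np.isnan(pixels).any(axis=1)
labels = np.full(len(pixels), -1, dtype=int)
labels[valid] = assign_to_centroids(pixels[valid], centroids)
print(labels)  # [ 0  1  2 -1]
```

Filling the label array with `-1` first mirrors the notebook's approach of reserving a sentinel value for pixels that were excluded from prediction.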

## Interactive Map: Cluster Map

Save the cluster map as a georeferenced GeoTIFF and display it on an interactive map.

In [None]:
# Save cluster map for the full 2024 embedding as a GeoTIFF
with rasterio.open(files[1]) as src:
    emb_data = src.read()  # (64, H, W)
    emb_h, emb_w = emb_data.shape[1], emb_data.shape[2]
    # Match dtype to the KMeans model's cluster centers
    target_dtype = result["model"].cluster_centers_.dtype
    emb_pixels = emb_data.reshape(64, -1).T.astype(target_dtype)

    emb_valid = ~np.isnan(emb_pixels).any(axis=1)
    emb_valid_px = emb_pixels[emb_valid]

    # Predict clusters
    full_labels = result["model"].predict(emb_valid_px)

    # Reconstruct spatial map (use 255 as nodata for uint8)
    full_cluster_map = np.full(emb_h * emb_w, 255, dtype=np.uint8)
    full_cluster_map[emb_valid] = full_labels.astype(np.uint8)
    full_cluster_map = full_cluster_map.reshape(emb_h, emb_w)

    cluster_output = os.path.join(viz_dir, "cluster_map_2024.tif")
    with rasterio.open(
        cluster_output,
        "w",
        driver="GTiff",
        height=emb_h,
        width=emb_w,
        count=1,
        dtype="uint8",
        crs=src.crs,
        transform=src.transform,
        compress="lzw",
        nodata=255,
    ) as dst:
        dst.write(full_cluster_map, 1)

print(f"Cluster map saved to {cluster_output}")

In [None]:
# Display cluster map on an interactive map
m = leafmap.Map()
m.add_basemap("Esri.WorldImagery")
m.add_raster(
    cluster_output, cmap="tab10", nodata=255, opacity=0.7, layer_name="Clusters"
)
m

## Similarity Search

Find pixels with the most similar embedding vectors to a query pixel using cosine similarity.

In [None]:
# Use the middle row of the flattened array as the query
# (an arbitrary sampled pixel, not necessarily a spatial center)
center_idx = len(valid_embeddings) // 2
query_embedding = valid_embeddings[center_idx]

results = geoai.embedding_similarity(
    query=query_embedding,
    embeddings=valid_embeddings,
    metric="cosine",
    top_k=10,
)

print("Top 10 most similar pixels:")
for rank, (idx, score) in enumerate(
    zip(results["indices"], results["scores"]), start=1
):
    print(f"  {rank}. Index {idx}: similarity={score:.4f}")
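
For intuition, a top-k cosine search can be written directly in NumPy. This is a sketch under the assumption that the search is a brute-force comparison against every row; it is not the implementation of `geoai.embedding_similarity`, and `top_k_cosine` is a hypothetical helper run on random data.

```python
import numpy as np


def top_k_cosine(query, embeddings, k=10):
    """Return (indices, scores) of the k rows most similar to query, best first."""
    q = query / np.linalg.norm(query)
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    scores = e @ q  # cosine similarity of every row with the query
    k = min(k, len(scores))
    # argpartition selects the k best without sorting everything;
    # only those k candidates are then sorted
    top = np.argpartition(-scores, k - 1)[:k]
    top = top[np.argsort(-scores[top])]
    return top, scores[top]


rng = np.random.default_rng(0)
emb = rng.normal(size=(1000, 64))
idx, sc = top_k_cosine(emb[123], emb, k=5)
print(idx, np.round(sc, 4))
```

When the query is itself a row of the searched array, as in this demo, the best match is the query with a similarity of 1.0. If the same holds for the search above, skip the first result when looking for genuinely different pixels.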

## Change Detection

Compare embeddings from two years to detect changes on the ground. We read the full raster for each year, crop both to a common shape, and compute the cosine similarity between corresponding pixels. Low similarity values indicate change.

In [None]:
# Read the two downloaded files directly
with rasterio.open(files[0]) as src1, rasterio.open(files[1]) as src2:
    data1 = src1.read()  # (64, H, W)
    data2 = src2.read()

    # Use the smaller common extent
    min_h = min(data1.shape[1], data2.shape[1])
    min_w = min(data1.shape[2], data2.shape[2])
    data1 = data1[:, :min_h, :min_w]
    data2 = data2[:, :min_h, :min_w]

print(f"Year 1 shape: {data1.shape}")
print(f"Year 2 shape: {data2.shape}")

# Flatten to (N_pixels, 64)
emb1 = data1.reshape(64, -1).T
emb2 = data2.reshape(64, -1).T

# Remove pixels where either year has NaN
valid = ~(np.isnan(emb1).any(axis=1) | np.isnan(emb2).any(axis=1))
emb1_valid = emb1[valid]
emb2_valid = emb2[valid]
print(f"Valid pixel pairs: {emb1_valid.shape[0]}")

# Compute cosine similarity
similarity = geoai.compare_embeddings(emb1_valid, emb2_valid, metric="cosine")
print(f"Mean similarity: {similarity.mean():.4f}")
print(f"Min similarity: {similarity.min():.4f}")
print(f"Max similarity: {similarity.max():.4f}")
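
The per-pixel comparison amounts to a row-wise cosine similarity between two aligned arrays. The sketch below shows that computation in NumPy on toy vectors; `paired_cosine` is a hypothetical helper and is not claimed to match how `geoai.compare_embeddings` is implemented.

```python
import numpy as np


def paired_cosine(a, b, eps=1e-12):
    """Cosine similarity between matching rows of two (N, D) arrays."""
    dot = np.einsum("ij,ij->i", a, b)  # row-wise dot products
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return dot / np.maximum(norms, eps)  # eps guards against zero vectors


# Toy pairs: identical, orthogonal, and opposite vectors
a = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
b = np.array([[1.0, 0.0], [3.0, 0.0], [-1.0, -1.0]])
sim = paired_cosine(a, b)
print(np.round(sim, 6))  # approximately 1, 0, -1
```

For unit-length vectors the norms are 1, so the similarity reduces to the dot product. Cosine similarity ranges from -1 to 1, so values can in principle fall below 0.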

In [None]:
# Visualize the similarity distribution
fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(similarity, bins=100, edgecolor="black", alpha=0.7, color="steelblue")
ax.axvline(
    similarity.mean(),
    color="red",
    linestyle="--",
    linewidth=2,
    label=f"Mean: {similarity.mean():.3f}",
)
ax.set_xlabel("Cosine Similarity")
ax.set_ylabel("Pixel Count")
ax.set_title("Embedding Similarity Between 2018 and 2024")
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Create a spatial change map
change_map = np.full(min_h * min_w, np.nan)
change_map[valid] = similarity
change_map = change_map.reshape(min_h, min_w)

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# PCA RGB for each year (PCA is fit separately per year,
# so colors are not directly comparable between panels)
from sklearn.decomposition import PCA

pca = PCA(n_components=3)
for ax_idx, (data, year) in enumerate([(data1, "2018"), (data2, "2024")]):
    pixels = data.reshape(64, -1).T
    mask = ~np.isnan(pixels).any(axis=1)
    rgb = np.zeros((pixels.shape[0], 3))
    if mask.any():
        rgb[mask] = pca.fit_transform(pixels[mask])
        rgb -= rgb[mask].min(axis=0)
        maxv = rgb[mask].max(axis=0)
        maxv[maxv == 0] = 1
        rgb /= maxv
    rgb = rgb.reshape(min_h, min_w, 3).clip(0, 1)
    axes[ax_idx].imshow(rgb)
    axes[ax_idx].set_title(f"Embeddings {year} (PCA RGB)")
    axes[ax_idx].axis("off")

# Change map
im = axes[2].imshow(change_map, cmap="RdYlGn", vmin=0, vmax=1, interpolation="nearest")
axes[2].set_title("Cosine Similarity (Change Map)")
axes[2].axis("off")
plt.colorbar(im, ax=axes[2], shrink=0.7, label="Similarity")

plt.tight_layout()
plt.show()
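
To turn a continuous similarity map into a simple changed/unchanged summary, one option is to threshold it. The sketch below is illustrative only: `changed_fraction` is a hypothetical helper, and the threshold of 0.8 is an arbitrary value that should be tuned for each study area.

```python
import numpy as np


def changed_fraction(similarity_map, threshold=0.8):
    """Return a boolean change mask and the share of valid pixels below threshold."""
    valid = ~np.isnan(similarity_map)
    # Replace NaN with +inf so invalid pixels never count as changed
    filled = np.where(valid, similarity_map, np.inf)
    changed = filled < threshold
    n_valid = int(valid.sum())
    fraction = float(changed.sum() / n_valid) if n_valid else float("nan")
    return changed, fraction


# Toy 2 x 2 similarity map with one NaN (no data) pixel
demo = np.array([[0.95, 0.40], [np.nan, 0.79]])
mask, frac = changed_fraction(demo, threshold=0.8)
print(mask)
print(f"Changed fraction: {frac:.3f}")
```

With the notebook's variables, `changed_fraction(change_map)` would apply the same idea to the full change map.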

## Save Embeddings as GeoTIFF

Export the change map as a georeferenced GeoTIFF for use in GIS software. The same pattern applies to embedding data.

In [None]:
# Save the change map as a single-band GeoTIFF
with rasterio.open(files[0]) as src:
    bounds = src.bounds
    transform = src.transform
    file_crs = src.crs

change_output = os.path.join(viz_dir, "change_map_2018_2024.tif")
with rasterio.open(
    change_output,
    "w",
    driver="GTiff",
    height=min_h,
    width=min_w,
    count=1,
    dtype="float64",
    crs=file_crs,
    transform=transform,
    compress="lzw",
    nodata=np.nan,
) as dst:
    dst.write(change_map, 1)
    dst.set_band_description(1, "cosine_similarity")

print(f"Change map saved to {change_output}")

## Interactive Map: Change Detection

Visualize the change detection results on interactive maps. Use a split map to compare PCA-projected embeddings from 2018 and 2024 side by side, and overlay the change map on satellite imagery.

In [None]:
# Split map comparing 2018 vs 2024 PCA RGB embeddings
m = leafmap.Map()
m.split_map(left_layer=pca_files[0], right_layer=pca_files[1])
m

In [None]:
# Display change map overlaid on satellite imagery
m = leafmap.Map()
m.add_basemap("Esri.WorldImagery")
m.add_raster(change_output, cmap="RdYlGn", layer_name="Change Map (Cosine Similarity)")
m

## Summary

This notebook demonstrated the end-to-end workflow for working with Google/AlphaEarth Satellite Embeddings using `geoai` and TorchGeo:

- **64-D embedding vectors** at 10 m resolution encode multi-source satellite observations
- **No GPU required**: embeddings are pre-computed and ready for analysis
- **Windowed download** fetches only the region of interest from Cloud-Optimized GeoTIFFs
- **Unsupervised clustering** reveals distinct land cover patterns
- **Cosine similarity** between years enables change detection
- **Interactive maps** via leafmap for exploring embeddings, clusters, and change maps
- Data is freely available from [Source Cooperative](https://source.coop/tge-labs/aef) under CC-BY-4.0