In [None]:
import ee
import json
import math
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
from tqdm import tqdm
from umap import UMAP
from scipy.spatial.distance import cdist

# ---- USER INPUTS ----
N = 16  # number of random points to sample
EMBED_PIXELS = 64  # patch width and height (px)
# Ground spacing between embedding samples (m/px). NOTE: the Earth Engine catalog
# lists 10 m as the native pixel size of the AlphaEarth annual embeddings, so a
# value of 20 samples the layer more coarsely than native.
EMBED_SCALE = 20
YEAR = 2024
SEED = 42

print("Initializing Google Earth Engine...")
ee.Authenticate()
ee.Initialize(project="gsapp-map")
print("✔ Earth Engine initialized\n")

# ---- LOAD INPUT COORDS ----
# Either a list of {"lat": ..., "lon": ...} records or a list of [lon, lat] pairs.
INPUT_COORDS = f"random_coords_{N}.json"
print(f"Loading coordinates from {INPUT_COORDS} ...")
with open(INPUT_COORDS, "r", encoding="utf-8") as f:
    coords_list = json.load(f)

gdf = pd.DataFrame(coords_list)
if not {"lat", "lon"}.issubset(gdf.columns):
    # No named lat/lon columns: interpret the two columns positionally as
    # (lon, lat). Any other column count cannot be interpreted safely.
    if gdf.shape[1] != 2:
        raise ValueError(
            f"{INPUT_COORDS} must hold 'lat'/'lon' records or [lon, lat] pairs; "
            f"got columns {list(gdf.columns)}"
        )
    gdf.columns = ["lon", "lat"]

print(f"Loaded {len(gdf)} coordinates.")
print(gdf.head(), "\n")

# ---- CREATE GEODATAFRAME ----
print("Creating GeoDataFrame of sample points ...")
gdf = gpd.GeoDataFrame(
    gdf, geometry=gpd.points_from_xy(gdf["lon"], gdf["lat"]), crs="EPSG:4326"
)

# ---- LOAD DATASETS ----
print("Loading image collections ...")
alphaearth = ee.ImageCollection("GOOGLE/SATELLITE_EMBEDDING/V1/ANNUAL")
print("✔ AlphaEarth loaded (64-band embeddings)\n")


def get_embedding_patch(
    lat, lon, patch_size=EMBED_PIXELS, scale=EMBED_SCALE, year=YEAR
):
    """Fetch a square patch of AlphaEarth embeddings centred on ``(lat, lon)``.

    The mosaic is sampled on a regular ``patch_size`` x ``patch_size`` grid of
    cell-centre points spaced ``scale`` metres apart, so the result is a real
    image: row 0 is the northern edge and column 0 the western edge.

    Args:
        lat: Latitude of the patch centre in decimal degrees.
        lon: Longitude of the patch centre in decimal degrees.
        patch_size: Patch width and height in pixels. NOTE: Earth Engine aborts
            ``getInfo`` on collections larger than 5000 elements, so
            ``patch_size**2`` must stay below that (64 -> 4096 is fine).
        scale: Spacing between samples in metres; also the sampling scale.
        year: Calendar year of the annual embedding mosaic.

    Returns:
        A float32 array of shape ``(patch_size, patch_size, 64)`` holding bands
        ``A00``..``A63`` on the last axis, or ``None`` if the Earth Engine
        request fails or fewer than ``patch_size**2`` samples come back
        (e.g. points on masked pixels).
    """
    print(f"→ Sampling patch at lat={lat:.5f}, lon={lon:.5f}")

    # The patch spans patch_size * scale metres edge to edge, so the buffer
    # *radius* is half of that; its bounding box is used to pick tiles.
    half_side_m = patch_size * scale / 2
    region = ee.Geometry.Point([lon, lat]).buffer(half_side_m).bounds()

    # Mosaic of the AlphaEarth annual embedding for the requested year.
    img = (
        alphaearth.filterDate(f"{year}-01-01", f"{year + 1}-01-01")
        .filterBounds(region)
        .mosaic()
    )

    band_names = [f"A{i:02d}" for i in range(64)]

    # ---- Regular grid of sample points (one per output pixel) ----
    # Local equirectangular approximation: ~111,320 m per degree of latitude,
    # scaled by cos(lat) for longitude. Adequate over a patch a few km wide.
    m_per_deg = 111_320.0
    dlat = scale / m_per_deg
    dlon = scale / (m_per_deg * max(math.cos(math.radians(lat)), 1e-6))
    centre = (patch_size - 1) / 2  # (possibly fractional) index of the centre cell
    n_points = patch_size**2
    grid_fc = ee.FeatureCollection(
        [
            ee.Feature(
                ee.Geometry.Point(
                    [lon + (col - centre) * dlon, lat + (centre - row) * dlat]
                ),
                {"row": row, "col": col},
            )
            for row in range(patch_size)
            for col in range(patch_size)
        ]
    )
    print(f"  ✔ Created {n_points} grid points for sampling")

    try:
        samples = (
            img.select(band_names)
            .sampleRegions(collection=grid_fc, scale=scale, geometries=False)
            .getInfo()
        )
    except Exception as e:  # best effort: log and let the caller skip this point
        print(f"  ✖ GEE request failed: {e}")
        return None

    features = samples.get("features", [])
    if len(features) != n_points:
        # A partial grid cannot be reshaped into an image; skip the location.
        print(f"Warning: Only {len(features)} points returned, expected {n_points}")
        print("  ✖ Incomplete patch, skipping")
        return None

    # ---- Place each sample by its row/col property (feature order not relied on) ----
    arr = np.full((patch_size, patch_size, 64), np.nan, dtype=np.float32)
    for feat in features:
        props = feat.get("properties", {})
        arr[int(props["row"]), int(props["col"])] = [props[b] for b in band_names]
    print(f"  ✔ Patch shape: {arr.shape}\n")

    return arr


# ---- COLLECT EMBEDDINGS ----
print("Collecting AlphaEarth embedding patches ...\n")
embeddings = []
valid_points = []
expected_shape = (EMBED_PIXELS, EMBED_PIXELS, 64)

for i, row in tqdm(gdf.iterrows(), total=len(gdf)):
    lat, lon = row["geometry"].y, row["geometry"].x
    patch = get_embedding_patch(lat, lon)
    # Keep only complete patches: a None (failed request) would crash on
    # `.shape`, and a mis-shaped array would break np.array(embeddings) below.
    if patch is None:
        print(f"Warning: No patch for index {i}, skipping.\n")
        continue
    if patch.shape != expected_shape:
        print(f"Warning: Invalid patch at index {i}.\n")
        continue
    embeddings.append(patch)
    valid_points.append((lat, lon))

print(f"✔ Collected {len(embeddings)} valid patches.\n")

if not embeddings:
    raise RuntimeError("No valid patches collected. Check region and parameters.")

# ---- VERIFY ARRAY SHAPES ----
embeddings_arr = np.array(embeddings, dtype=np.float32)
print("Embeddings array shape:", embeddings_arr.shape)
print("Embeddings dtype:", embeddings_arr.dtype, "\n")

# ---- COMPUTE PATCH MEANS (for similarity) ----
print("Computing patch-level mean embeddings for similarity ...")
mean_embs = embeddings_arr.reshape(len(embeddings), -1, 64).mean(axis=1)
print("Mean embedding shape:", mean_embs.shape, "\n")

# ---- COMPUTE DISTANCES & UMAP ----
print("Computing pairwise distances ...")
distances = cdist(mean_embs, mean_embs, metric="euclidean")
print("Distance matrix shape:", distances.shape)

print("Running UMAP dimensionality reduction ...")
um = UMAP(n_components=2, random_state=SEED, metric="precomputed")
umap_2d = um.fit_transform(distances)
print("UMAP output shape:", umap_2d.shape, "\n")

# ---- ARRANGE INTO GRID ----
print("Arranging patches into grid ...")
n = len(umap_2d)
grid_side = int(math.ceil(math.sqrt(n)))
grid_x = np.linspace(0, 1, grid_side)
grid_y = np.linspace(0, 1, grid_side)
# Cell j has normalised position (grid_x[j // grid_side], grid_y[j % grid_side]);
# [:n] keeps the first n cells when n is not a perfect square.
grid_coords = np.array(np.meshgrid(grid_x, grid_y)).T.reshape(-1, 2)[:n]

from scipy.optimize import linear_sum_assignment

# Hungarian assignment: sample r_idx[k] goes to grid cell grid_coords[c_idx[k]].
cost = cdist(umap_2d, grid_coords)
r_idx, c_idx = linear_sum_assignment(cost)
# For a square cost matrix r_idx is 0..n-1, so this list is in *sample* order;
# a patch's grid position must be read from c_idx, not from its list index.
ordered_embeddings = [embeddings[i] for i in r_idx]
print("✔ Grid arrangement complete.\n")

In [None]:
# ---- BUILD FILTERED GDF OF VALID POINTS ----
import os

print("Building filtered GeoDataFrame for valid points ...")

# One row per collected patch, in the same order as `embeddings` and `umap_2d`.
valid_df = pd.DataFrame(valid_points, columns=["lat", "lon"])
valid_gdf = gpd.GeoDataFrame(
    valid_df,
    geometry=gpd.points_from_xy(valid_df["lon"], valid_df["lat"]),
    crs="EPSG:4326",
)
print(f"✔ {len(valid_gdf)} valid points retained.\n")

# ---- ADD UMAP COORDINATES ----
print("Adding UMAP x, y coordinates ...")
valid_gdf["umap_x"] = umap_2d[:, 0]
valid_gdf["umap_y"] = umap_2d[:, 1]

# ---- ADD GRID POSITIONS ----
print("Adding grid positions ...")
# Assignment pair k gives sample r_idx[k] the cell grid_coords[c_idx[k]]; scatter
# by r_idx rather than relying on r_idx happening to be 0..n-1.
assigned_coords = grid_coords[c_idx]
grid_cols = np.full(len(valid_gdf), -1, dtype=int)
grid_rows = np.full(len(valid_gdf), -1, dtype=int)
grid_cols[r_idx] = (assigned_coords[:, 0] * (grid_side - 1)).round().astype(int)
grid_rows[r_idx] = (assigned_coords[:, 1] * (grid_side - 1)).round().astype(int)
valid_gdf["grid_x"] = grid_cols
valid_gdf["grid_y"] = grid_rows

print(valid_gdf.head(), "\n")

# ---- SAVE OUTPUT ----
os.makedirs("output", exist_ok=True)  # to_file / to_csv do not create folders
geojson_path = f"output/{N}_{EMBED_PIXELS}_{EMBED_SCALE}_{YEAR}_data.geojson"
csv_path = f"output/{N}_{EMBED_PIXELS}_{EMBED_SCALE}_{YEAR}_data.csv"
valid_gdf.to_file(geojson_path, driver="GeoJSON")
valid_gdf.drop(columns="geometry").to_csv(csv_path, index=False)
print(f"✔ Saved filtered data to: {geojson_path} and {csv_path}")

In [None]:
# Quick look at the first five input points.
print(gdf.head(n=5))

In [None]:
# ---- BUILD RGB IMAGES FROM EMBEDDINGS ----
import os

from PIL import Image

os.makedirs("output", exist_ok=True)
tile_px = EMBED_PIXELS


def _tile_origins():
    """Yield ``(sample_index, top, left)`` pixel offsets for every placed patch.

    The Hungarian assignment pairs sample ``r_idx[k]`` with grid cell
    ``grid_coords[c_idx[k]]`` (normalised 0..1 x/y). Columns follow grid x;
    rows are flipped so grid_y = 0 is the bottom row, the same orientation the
    "correct positions" Sentinel-2 grid uses.
    """
    cells = grid_coords[c_idx]
    cols = np.round(cells[:, 0] * (grid_side - 1)).astype(int)
    rows = np.round(cells[:, 1] * (grid_side - 1)).astype(int)
    for sample, col, row in zip(r_idx, cols, rows):
        yield int(sample), int(grid_side - 1 - row) * tile_px, int(col) * tile_px


def _minmax(arrays):
    """Scale a list of arrays to 0..1 using their shared global min and max."""
    lo = min(a.min() for a in arrays)
    hi = max(a.max() for a in arrays)
    return [(a - lo) / (hi - lo + 1e-8) for a in arrays]


def _save_grid(patches, path):
    """Paste per-sample (H, W, 3) patches in 0..1 into the grid and save a PNG."""
    canvas = np.zeros((grid_side * tile_px, grid_side * tile_px, 3), dtype=np.float32)
    for sample, top, left in _tile_origins():
        canvas[top : top + tile_px, left : left + tile_px, :] = patches[sample]
    Image.fromarray((np.clip(canvas, 0, 1) * 255).astype(np.uint8)).save(path)


# ---- (1) First three embedding bands as RGB ----
print("Creating RGB visualization from first 3 channels (A00-A02) ...")
# Embedding values can be negative, so dividing by a patch's max clips them to
# black (and fails when that max is <= 0). A shared min-max keeps all patches
# on one colour scale, like the UMAP image below.
band_patches = _minmax([p[..., :3] for p in embeddings])
out_3bands = f"output/{N}_{EMBED_PIXELS}_{EMBED_SCALE}_{YEAR}_embedding_grid_3bands.png"
_save_grid(band_patches, out_3bands)

print("✔ Saved AlphaEarth similarity grid image.")
print(f"Output: {out_3bands}\n")

# ---- (2) GLOBAL UMAP 64->3 as RGB ----
print("Creating RGB visualization from 64-band embeddings via GLOBAL UMAP ...")

# Flatten every patch to (H*W, 64) in sample order and stack them.
print("Flattening all patches for global UMAP ...")
patch_shapes = [p.shape[:2] for p in embeddings]
all_pixels = np.vstack([p.reshape(-1, p.shape[-1]) for p in embeddings])
print("All pixels flattened shape:", all_pixels.shape)

# One UMAP fit for all pixels so colours are comparable across patches.
print("Running UMAP 64->3 for global color mapping ...")
umap_rgb = UMAP(n_components=3, random_state=SEED).fit_transform(all_pixels)
print("UMAP output shape:", umap_rgb.shape)

# Split the stacked output back into per-sample (H, W, 3) patches.
print("Reassembling patches ...")
rgb_patches = []
start = 0
for H, W in patch_shapes:
    rgb_patches.append(umap_rgb[start : start + H * W].reshape(H, W, 3))
    start += H * W

print("Normalizing RGB values globally ...")
rgb_patches = _minmax(rgb_patches)

print("Assembling final grid image ...")
output_file = f"output/{N}_{EMBED_PIXELS}_{EMBED_SCALE}_{YEAR}_embedding_grid_umap.png"
_save_grid(rgb_patches, output_file)
print("✔ Saved AlphaEarth similarity grid (global UMAP 64->3):")
print(f"  {output_file}")

In [None]:
import os
import requests
from io import BytesIO
from PIL import Image
import ee
import numpy as np

# ---- USER INPUTS ----
# Thumbnails go next to the other outputs. Each tile covers the same ground
# footprint (EMBED_PIXELS * EMBED_SCALE metres) as one embedding patch.
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)
TILE_PX = EMBED_PIXELS
side_meters_patch = EMBED_SCALE * EMBED_PIXELS

print("Preparing satellite thumbnail grid ...")

# ---- Sentinel-2 true-colour composite ----
# Median of YEAR's Sentinel-2 surface-reflectance scenes whose
# CLOUDY_PIXEL_PERCENTAGE is below 20, restricted to bands B4/B3/B2.
s2 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED").filter(
    ee.Filter.calendarRange(YEAR, YEAR, "year")
)
s2 = s2.filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 20))
s2_median = s2.select(["B4", "B3", "B2"]).median()

# ---- Output grid image size ----
# Smallest square grid that holds one tile per valid sample.
grid_side = math.ceil(math.sqrt(len(valid_points)))
img_size = (TILE_PX * grid_side,) * 2
sat_grid_img = Image.new("RGB", img_size, (0, 0, 0))


# ---- Helper to fetch thumbnail ----
def fetch_satellite_thumbnail(center_lon, center_lat, side_meters, out_px):
    """Download a square true-colour thumbnail of ``s2_median`` around a point.

    Args:
        center_lon: Longitude of the tile centre in decimal degrees.
        center_lat: Latitude of the tile centre in decimal degrees.
        side_meters: Edge length of the ground footprint; the footprint is the
            bounding box of a ``side_meters / 2`` buffer around the point.
        out_px: Requested thumbnail width and height in pixels.

    Returns:
        A PIL RGB image. On any failure (Earth Engine or HTTP) the error is
        printed and a black ``out_px`` x ``out_px`` image is returned instead,
        so one bad tile does not abort the grid.
    """
    footprint = (
        ee.Geometry.Point([center_lon, center_lat]).buffer(side_meters / 2).bounds()
    )
    try:
        thumb_params = {
            "region": footprint.getInfo(),
            "dimensions": f"{out_px}x{out_px}",
            "format": "png",
            "min": 0,
            "max": 3000,
            "bands": ["B4", "B3", "B2"],
        }
        response = requests.get(s2_median.getThumbURL(thumb_params), timeout=60)
        response.raise_for_status()
        thumbnail = Image.open(BytesIO(response.content))
        return thumbnail.convert("RGB")
    except Exception as err:  # deliberate best effort: fall through to a blank tile
        print(f"  ✖ Failed at ({center_lat:.5f},{center_lon:.5f}): {err}")
    return Image.new("RGB", (out_px, out_px), (0, 0, 0))


# ---- Assemble grid ----
print("Downloading thumbnails and pasting into grid ...")
# Assignment pair i places sample r_idx[i] in grid cell grid_coords[c_idx[i]].
# Columns follow grid x; rows are flipped so grid_y = 0 is the bottom row, the
# same orientation the "correct positions" grid uses.
assigned_cells = grid_coords[c_idx]
for i, idx in enumerate(r_idx):
    lat = valid_gdf.iloc[idx]["lat"]
    lon = valid_gdf.iloc[idx]["lon"]
    gx, gy = assigned_cells[i]
    col = int(round(gx * (grid_side - 1)))
    row = grid_side - 1 - int(round(gy * (grid_side - 1)))
    left = col * TILE_PX
    upper = row * TILE_PX
    thumb = fetch_satellite_thumbnail(lon, lat, side_meters_patch, TILE_PX)
    sat_grid_img.paste(thumb, (left, upper))
    print(f"✔ Fetched thumbnail {i+1}/{len(valid_points)} at row {row}, col {col}")

# ---- Save output ----
sat_out_path = os.path.join(
    OUTPUT_DIR, f"{N}_{TILE_PX}_{EMBED_SCALE}_{YEAR}_sentinel_grid.png"
)
sat_grid_img.save(sat_out_path)
print("✔ Saved satellite thumbnail grid to:", sat_out_path)

In [None]:
# ---- ADD GRID POSITIONS (correctly mapped to sample indices) ----
print("Adding grid positions to GeoDataFrame (correct mapping) ...")

# Assignment pair k links sample r_idx[k] to grid cell c_idx[k]; that cell's
# normalised [x, y] position is grid_coords[c_idx[k]].
assigned_coords = grid_coords[c_idx]

# Normalised 0..1 positions -> integer column / row indices (still in k order).
assigned_cols = np.round(assigned_coords[:, 0] * (grid_side - 1)).astype(int)
assigned_rows = np.round(assigned_coords[:, 1] * (grid_side - 1)).astype(int)

# Scatter into per-sample arrays (index = sample index); unassigned stay -1.
grid_x_arr = np.full(len(valid_gdf), -1, dtype=int)
grid_y_arr = np.full(len(valid_gdf), -1, dtype=int)
grid_x_arr[r_idx] = assigned_cols
grid_y_arr[r_idx] = assigned_rows

# valid_gdf row i corresponds to embedding / sample i.
valid_gdf["grid_x"] = grid_x_arr
valid_gdf["grid_y"] = grid_y_arr

print("Sample valid_gdf rows with grid positions:")
print(valid_gdf[["lat", "lon", "grid_x", "grid_y"]].head())

In [None]:
# ---- Assemble grid (use valid_gdf grid_x/grid_y) ----
print("Downloading thumbnails and pasting into grid (correct positions) ...")

# Fresh black canvas of grid_side x grid_side tiles, TILE_PX pixels each.
img_size = (grid_side * TILE_PX, grid_side * TILE_PX)
sat_grid_img = Image.new("RGB", img_size, (0, 0, 0))

for sample_index, rec in enumerate(valid_gdf.itertuples(index=False)):
    lat, lon = rec.lat, rec.lon
    gx, gy = int(rec.grid_x), int(rec.grid_y)

    # -1 marks a sample that received no grid cell.
    if min(gx, gy) < 0:
        print(f"  ⚠ sample {sample_index} has no grid assignment, skipping")
        continue

    # Pixel offsets of the tile's top-left corner. The y axis is flipped, so
    # grid_y = 0 lands on the BOTTOM row of the image (plot-style orientation).
    left = gx * TILE_PX
    top = (grid_side - 1 - gy) * TILE_PX

    thumb = fetch_satellite_thumbnail(lon, lat, side_meters_patch, TILE_PX)
    sat_grid_img.paste(thumb, (left, top))
    print(
        f"✔ Pasted sample {sample_index} at grid ({gx},{gy}) -> pixels ({left},{top})"
    )

# ---- Save output ----
sat_out_path = os.path.join(
    OUTPUT_DIR, f"{N}_{TILE_PX}_{EMBED_SCALE}_{YEAR}_sentinel_grid.png"
)
sat_grid_img.save(sat_out_path)
print("✔ Saved satellite thumbnail grid to:", sat_out_path)

In [None]:
# ---- ADD GRID POSITION TO GDF (based on assignment) ----
print("Assigning grid positions ...")

# Normalised (x, y) of the cell given to each assignment pair k.
assigned_grid = grid_coords[c_idx]

# Round both axes at once onto the 0..grid_side-1 lattice, then split columns.
grid_x_idx, grid_y_idx = np.round(assigned_grid * (grid_side - 1)).astype(int).T

# Pair k belongs to sample r_idx[k], which is also its valid_gdf row label.
valid_gdf.loc[r_idx, "grid_x"] = grid_x_idx
valid_gdf.loc[r_idx, "grid_y"] = grid_y_idx

print("✔ Grid positions added to GeoDataFrame.\n")

# Quick verification: pixel position uses the flipped-y layout (grid_y = 0 at
# the bottom row of the thumbnail grid).
print("Verification (first 10 samples):")
for i, rec in enumerate(valid_gdf.head(10).itertuples(index=False)):
    lat, lon = rec.lat, rec.lon
    gx, gy = int(rec.grid_x), int(rec.grid_y)
    left, top = gx * TILE_PX, (grid_side - 1 - gy) * TILE_PX
    print(
        f"sample {i}: latlon={lat:.5f},{lon:.5f} -> grid ({gx},{gy}) -> pixel pos ({left},{top})"
    )

In [None]:
import math
import numpy as np
import pandas as pd
import geopandas as gpd
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
from umap import UMAP

# ==== INPUTS ====
# EMBED_PIXELS, EMBED_SCALE, YEAR, N and SEED are reused exactly as they were when
# the patches were fetched. Redefining them here (e.g. a 128 px patch size or an
# "example" run name) would make pixel positions and output file names disagree
# with the actual patches in `embeddings`.
import os

# Inputs produced by the earlier cells:
#   gdf            GeoDataFrame of all input points (including failed ones)
#   valid_points   (lat, lon) of every point whose patch was fetched, in gdf order
#   embeddings     list of (H, W, 64) patches aligned with valid_points
#   embeddings_arr the same patches stacked into one float32 array

# ---- 1. COMPUTE PATCH MEANS ----
print("Computing patch-level mean embeddings for similarity ...")
mean_embs = embeddings_arr.reshape(len(embeddings), -1, 64).mean(axis=1)
print("Mean embedding shape:", mean_embs.shape)

# ---- 2. COMPUTE DISTANCES & UMAP ----
print("Computing pairwise distances ...")
distances = cdist(mean_embs, mean_embs, metric="euclidean")
print("Distance matrix shape:", distances.shape)

print("Running UMAP dimensionality reduction ...")
um = UMAP(n_components=2, random_state=SEED, metric="precomputed")
umap_2d = um.fit_transform(distances)
print("UMAP output shape:", umap_2d.shape)

# ---- 3. ARRANGE INTO GRID ----
print("Arranging patches into grid ...")
n = len(umap_2d)
grid_side = int(math.ceil(math.sqrt(n)))
grid_x = np.linspace(0, 1, grid_side)
grid_y = np.linspace(0, 1, grid_side)
grid_coords = np.array(np.meshgrid(grid_x, grid_y)).T.reshape(-1, 2)[:n]

# Hungarian algorithm: sample r_idx[k] is assigned grid cell grid_coords[c_idx[k]].
cost = cdist(umap_2d, grid_coords)
r_idx, c_idx = linear_sum_assignment(cost)
print("✔ Grid arrangement complete.")


# ---- 4. ADD COORDINATES TO GDF ----
def _sample_row_labels(frame, points):
    """Return the ``frame`` index label each collected sample came from.

    Samples were gathered by walking ``frame`` in order and keeping only rows
    whose patch was fetched, recording ``(geometry.y, geometry.x)``. They are
    therefore an ordered subsequence of the rows and can be matched greedily.

    Raises:
        RuntimeError: if some sample cannot be matched (``frame`` was changed).
    """
    labels = []
    remaining = iter(points)
    target = next(remaining, None)
    for label, geom in frame.geometry.items():
        if target is None:
            break
        if (geom.y, geom.x) == tuple(target):
            labels.append(label)
            target = next(remaining, None)
    if target is not None:
        raise RuntimeError("Could not match every valid sample to a row of gdf.")
    return labels


print("Adding UMAP and grid coordinates to GeoDataFrame ...")
# gdf still contains points whose patch failed, so sample indices (r_idx,
# umap_2d rows) must be translated to gdf row labels before writing.
sample_labels = _sample_row_labels(gdf, valid_points)
assigned_labels = [sample_labels[s] for s in r_idx]

# UMAP coordinates are indexed by sample, like mean_embs.
gdf.loc[sample_labels, "umap_x"] = umap_2d[:, 0]
gdf.loc[sample_labels, "umap_y"] = umap_2d[:, 1]

# Normalized grid coordinates (0..1) of the cell each sample was assigned
assigned_coords = grid_coords[c_idx]
gdf.loc[assigned_labels, "grid_x_norm"] = assigned_coords[:, 0]
gdf.loc[assigned_labels, "grid_y_norm"] = assigned_coords[:, 1]

# Integer grid positions
grid_cols = (assigned_coords[:, 0] * (grid_side - 1)).round().astype(int)
grid_rows = (assigned_coords[:, 1] * (grid_side - 1)).round().astype(int)
gdf.loc[assigned_labels, "grid_x"] = grid_cols
gdf.loc[assigned_labels, "grid_y"] = grid_rows

# ---- 5. ADD PIXEL POSITIONS ----
# Top-left pixel of each sample's tile in the grid images. Rows are flipped so
# grid_y = 0 is the bottom row, matching how the Sentinel-2 "correct positions"
# grid and its verification print place tiles. Points without a patch stay NaN.
print("Computing pixel positions ...")
gdf["pixel_x"] = gdf["grid_x"] * EMBED_PIXELS
gdf["pixel_y"] = (grid_side - 1 - gdf["grid_y"]) * EMBED_PIXELS

# ---- 6. SAVE ----
# Same paths as the valid-points export, so this fuller table replaces it.
print("Saving GeoDataFrame to disk ...")
os.makedirs("output", exist_ok=True)
geojson_path = f"output/{N}_{EMBED_PIXELS}_{EMBED_SCALE}_{YEAR}_data.geojson"
csv_path = f"output/{N}_{EMBED_PIXELS}_{EMBED_SCALE}_{YEAR}_data.csv"

gdf.to_file(geojson_path, driver="GeoJSON")
gdf.drop(columns="geometry").to_csv(csv_path, index=False)

print(f"✔ Saved to:\n  {geojson_path}\n  {csv_path}")

# ---- 7. QUICK VERIFICATION ----
# Only rows that received a grid cell; int() on the NaN rows would raise.
print("\nVerification (first 10 samples):")
placed = gdf.dropna(subset=["grid_x", "grid_y"])
for label, rec in placed.head(10).iterrows():
    gx, gy = int(rec["grid_x"]), int(rec["grid_y"])
    px, py = int(rec["pixel_x"]), int(rec["pixel_y"])
    print(f"sample {label}: grid=({gx},{gy}) → pixel=({px},{py})")