In [None]:
!pip install -r requirements.txt

## Sample GeoAI Wrapper for Clay

In [None]:
from geoai.clay import Clay
import torch

In [None]:
# Instantiate the geoai Clay wrapper for Sentinel-2 L2A imagery
# (presumably looks up that sensor's band metadata by name — confirm in geoai.clay).
clay_model = Clay(sensor_name="sentinel-2-l2a")
# Display the wrapped module for inspection.
clay_model.module

In [None]:
# Smoke test: embed a random 256x256 image with 10 channels, laid out as
# (H, W, C) like the real images passed to `generate` further below.
t1 = torch.rand(256, 256, 10)
embedding = clay_model.generate(t1)
print(embedding.shape)

## Loading Data

In [None]:
!pip install 'stackstac[viz]'

In [None]:
import math
import os

import geopandas as gpd
import numpy as np
import pandas as pd
import pystac_client
import stackstac
import torch
import yaml
from box import Box
from matplotlib import pyplot as plt
from rasterio.enums import Resampling
from shapely import Point
from sklearn import decomposition, svm
from torchvision.transforms import v2

from claymodel.module import ClayMAEModule

In [None]:
# Point of interest over Monchique, Portugal (WGS84 latitude / longitude).
lat = 37.30939
lon = -8.57207

# Date window covering the large 2018 Monchique forest fire.
start, end = "2018-07-01", "2018-09-01"

In [None]:
STAC_API = "https://earth-search.aws.element84.com/v1"
COLLECTION = "sentinel-2-l2a"

# Search the catalogue for scenes intersecting a tiny box around the point,
# inside the date window and with less than 80% cloud cover.
catalog = pystac_client.Client.open(STAC_API)
search = catalog.search(
    collections=[COLLECTION],
    datetime=f"{start}/{end}",
    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),
    max_items=100,
    query={"eo:cloud_cover": {"lt": 80}},
)

# `get_all_items()` is deprecated in pystac-client; `item_collection()` is the
# supported way to materialize every matching item.
all_items = search.item_collection()

# Reduce to one item per acquisition date (overlapping tiles can yield several
# items for the same day). The first item seen for a date wins; the set gives
# O(1) membership checks while `dates` keeps the original order.
items = []
dates = []
seen_dates = set()
for item in all_items:
    acquisition_date = item.datetime.date()
    if acquisition_date not in seen_dates:
        seen_dates.add(acquisition_date)
        items.append(item)
        dates.append(acquisition_date)

if not items:
    raise RuntimeError(
        f"No {COLLECTION} items found for {start}/{end} at ({lat}, {lon})"
    )

print(f"Found {len(items)} items")

In [None]:
# Inspect the STAC properties of the first scene; its projection code is read below.
items[0].properties

In [None]:
# Extract the coordinate reference system from the first item.
# The STAC projection extension v2 publishes `proj:code` (e.g. "EPSG:32629");
# items following v1.x publish an integer `proj:epsg` instead. Accept both.
first_props = items[0].properties
if "proj:code" in first_props:
    epsg_str = first_props["proj:code"]
else:
    epsg_str = f"EPSG:{first_props['proj:epsg']}"
epsg = int(epsg_str.split(":")[-1])  # 'EPSG:32629' -> 32629

# Convert the point of interest into the image projection
# (assumes all images are in the same projection as the first one).
poidf = gpd.GeoDataFrame(
    pd.DataFrame(),
    crs="EPSG:4326",
    geometry=[Point(lon, lat)],
).to_crs(epsg_str)

coords = poidf.iloc[0].geometry.coords[0]

# Square bounds of `size` pixels at `gsd` projection units per pixel,
# centred on the point of interest.
size = 256
gsd = 10
half_extent = (size * gsd) // 2
bounds = (
    coords[0] - half_extent,
    coords[1] - half_extent,
    coords[0] + half_extent,
    coords[1] + half_extent,
)

In [None]:
# Build a lazy pixel stack for the bounding box in the target projection.
# Only the blue, green, red and NIR assets are requested; values are kept
# unscaled (rescale=False) as float64, with NaN for missing pixels.
stack = stackstac.stack(
    items,
    assets=["blue", "green", "red", "nir"],
    epsg=epsg,
    bounds=bounds,
    snap_bounds=False,
    resolution=gsd,
    resampling=Resampling.nearest,
    dtype="float64",
    rescale=False,
    fill_value=np.nan,
)

print(stack)  # summary of the lazy array before loading

# Load the pixel values into memory.
stack = stack.compute()

In [None]:
# Preview every scene as an RGB image, one panel per acquisition time.
rgb_stack = stack.sel(band=["red", "green", "blue"])
rgb_stack.plot.imshow(row="time", rgb="band", vmin=0, vmax=2000, col_wrap=6)

## Clay Embeddings

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Expand "~" explicitly: the path is handed to the checkpoint loader as-is,
# and home-directory expansion should not depend on that loader.
ckpt = os.path.expanduser("~/.cache/clay/clay-v1.5.ckpt")
if not os.path.exists(ckpt):
    raise FileNotFoundError(f"Clay checkpoint not found at {ckpt}")
# Tensors created without an explicit device land on `device` from here on.
torch.set_default_device(device)

# Load the pretrained Clay MAE module. mask_ratio=0.0 / shuffle=False presumably
# keep every patch visible and in order for embedding extraction — confirm in claymodel.
model = ClayMAEModule.load_from_checkpoint(
    ckpt,
    model_size="large",
    metadata_path="geoai/config/clay_metadata.yaml",
    dolls=[16, 32, 64, 128, 256, 768, 1024],
    doll_weights=[1, 1, 1, 1, 1, 1, 1],
    mask_ratio=0.0,
    shuffle=False,
)
model.eval()

model = model.to(device)

In [None]:
# Extract per-band mean, std and wavelength from the Clay metadata file.
platform = "sentinel-2-l2a"
# Context manager so the file handle is closed once the YAML is parsed.
with open("geoai/config/clay_metadata.yaml") as metadata_file:
    metadata = Box(yaml.safe_load(metadata_file))

# Look the values up by band name so they follow the band order of `stack`.
band_meta = metadata[platform].bands
band_names = [str(band) for band in stack.band.values]
mean = [band_meta.mean[name] for name in band_names]
std = [band_meta.std[name] for name in band_names]
waves = [band_meta.wavelength[name] for name in band_names]

# Prepare the normalization transform function using the mean and std values.
transform = v2.Compose(
    [
        v2.Normalize(mean=mean, std=std),
    ]
)

In [None]:
# Prep the datetime encoding, mirroring the normalization function from the model code.
def normalize_timestamp(date):
    """Encode a datetime as cyclic (sin, cos) features.

    The ISO week number is placed on a 52-step circle and the hour on a
    24-step circle, so values that wrap around end up close together.

    Args:
        date: A ``datetime.datetime`` (uses ``isocalendar().week`` and ``hour``).

    Returns:
        ``((sin(week), cos(week)), (sin(hour), cos(hour)))``, angles in radians.
    """
    week_angle = date.isocalendar().week * 2 * np.pi / 52
    hour_angle = date.hour * 2 * np.pi / 24
    week_encoding = (math.sin(week_angle), math.cos(week_angle))
    hour_encoding = (math.sin(hour_angle), math.cos(hour_angle))
    return week_encoding, hour_encoding


# Python datetimes for every scene, then their (week, hour) cyclic encodings.
datetimes = stack.time.values.astype("datetime64[s]").tolist()
times = list(map(normalize_timestamp, datetimes))
week_norm = [week_pair for week_pair, _ in times]
hour_norm = [hour_pair for _, hour_pair in times]


# Prep the lat/lon encoding using the same sin/cos scheme as the timestamps.
def normalize_latlon(lat, lon):
    """Encode a latitude/longitude pair given in degrees as (sin, cos) features.

    Returns:
        ``((sin(lat), cos(lat)), (sin(lon), cos(lon)))`` with angles in radians.
    """
    lat_rad = lat * np.pi / 180
    lon_rad = lon * np.pi / 180
    lat_encoding = (math.sin(lat_rad), math.cos(lat_rad))
    lon_encoding = (math.sin(lon_rad), math.cos(lon_rad))
    return lat_encoding, lon_encoding


# All scenes share one point of interest, so its encoding is repeated per scene.
latlons = [normalize_latlon(lat, lon)] * len(times)
lat_norm = [lat_pair for lat_pair, _ in latlons]
lon_norm = [lon_pair for _, lon_pair in latlons]

# Normalize pixels: cast to float32, then apply the per-band mean/std transform.
raw_pixels = torch.from_numpy(stack.data.astype(np.float32))
pixels = transform(raw_pixels)

In [None]:
# Assemble the encoder input: normalized pixels plus the metadata Clay uses.
# Each hstack joins two lists of (sin, cos) pairs into one row of 4 values per scene.
time_features = np.hstack((week_norm, hour_norm))
latlon_features = np.hstack((lat_norm, lon_norm))

datacube = {
    "platform": platform,
    "time": torch.tensor(time_features, dtype=torch.float32, device=device),
    "latlon": torch.tensor(latlon_features, dtype=torch.float32, device=device),
    "pixels": pixels.to(device),
    "gsd": torch.tensor(stack.gsd.values, device=device),
    "waves": torch.tensor(waves, device=device),
}

In [None]:
# Run the encoder on the datacube without tracking gradients (inference only).
with torch.no_grad():
    encoder_out = model.model.encoder(datacube)
unmsk_patch, unmsk_idx, msk_idx, msk_matrix = encoder_out

# Position 0 along the token axis is the class token: a single, overall
# embedding per scene. Keep it on the CPU for the PCA below.
embeddings_clay = unmsk_patch[:, 0, :].cpu()

In [None]:
# Project the Clay class-token embeddings onto their first principal component
# and plot it over time; cloudy and post-fire scenes should stand out.
CLOUDY_IDX = [0, 2]  # indices of cloudy scenes in this search result
N_POST_FIRE = 5  # number of trailing scenes acquired after the fire

pca = decomposition.PCA(n_components=1)
pca_result = pca.fit_transform(embeddings_clay)
scene_times = stack.time.values

fig, ax = plt.subplots()
# Plot all points in blue first, then re-plot cloudy (green) and post-fire (red) on top.
ax.scatter(scene_times, pca_result[:, 0], color="blue", label="other")
ax.scatter(
    scene_times[CLOUDY_IDX], pca_result[CLOUDY_IDX, 0], color="green", label="cloudy"
)
ax.scatter(
    scene_times[-N_POST_FIRE:],
    pca_result[-N_POST_FIRE:, 0],
    color="red",
    label="after fire",
)
ax.set(
    title="Clay embeddings: first principal component",
    xlabel="Acquisition date",
    ylabel="PC 1",
)
ax.tick_params(axis="x", labelrotation=-45)
ax.legend()
plt.show()

In [None]:
# Label the images we downloaded
# 0 = Cloud
# 1 = Forest
# 2 = Fire
labels = np.array([0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
# These hand labels only fit this exact search result; fail loudly if the
# STAC query returned a different number of scenes.
assert len(labels) == len(embeddings_clay), (
    f"Expected {len(labels)} scenes but got {len(embeddings_clay)}; update `labels`."
)

# Split into fit and test manually, ensuring we have all 3 classes in both sets
fit = [0, 1, 3, 4, 7, 8, 9]
test = [2, 5, 6, 10, 11]

# Train a Support Vector Machine model (default RBF kernel, gamma="scale").
# The +100 offset is kept from the original example; the RBF kernel and the
# "scale" gamma are both shift-invariant, so it only affects float32 rounding.
clf = svm.SVC()
clf.fit(embeddings_clay[fit] + 100, labels[fit])

# Predict classes on the held-out scenes and count correct predictions.
prediction = clf.predict(embeddings_clay[test] + 100)
match = np.sum(labels[test] == prediction)
print(f"Matched {match} out of {len(test)} correctly")

In [None]:
# Shape of the direct-model class-token embeddings: (scenes, embedding size).
embeddings_clay.shape

## GeoAI Wrapper Embeddings

In [None]:
from geoai.clay import Clay
import torch

In [None]:
# Initialize the geoai Clay wrapper with custom metadata describing only the
# four bands loaded into `stack` (blue, green, red, nir).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-band statistics (presumably used by the wrapper for normalization) and wavelengths.
band_stats = {
    "mean": {"blue": 1105.0, "green": 1355.0, "red": 1552.0, "nir": 2743.0},
    "std": {"blue": 1809.0, "green": 1757.0, "red": 1888.0, "nir": 1742.0},
    "wavelength": {"blue": 0.493, "green": 0.56, "red": 0.665, "nir": 0.842},
}
custom_metadata = dict(
    band_order=["blue", "green", "red", "nir"],
    rgb_indices=[2, 1, 0],  # positions of red, green, blue within band_order
    gsd=10,
    bands=band_stats,
)

clay_model = Clay(custom_metadata=custom_metadata, device=str(device))

In [None]:
# Convert the stack's projected bounds back to WGS84 (lon/lat) for the geoai wrapper.
# Only the lower-left and upper-right corners are reprojected; this ignores the
# slight rotation of a projected box in lon/lat, presumably negligible for this small tile.
bounds_gdf = gpd.GeoDataFrame(
    geometry=[Point(bounds[0], bounds[1]), Point(bounds[2], bounds[3])],
    crs=epsg_str,
).to_crs("EPSG:4326")

lower_left, upper_right = bounds_gdf.geometry
wgs84_bounds = (
    lower_left.x,  # min_lon
    lower_left.y,  # min_lat
    upper_right.x,  # max_lon
    upper_right.y,  # max_lat
)

In [None]:
# Process each scene through the Clay model via the geoai wrapper.
# `.astype("datetime64[s]").tolist()` already yields `datetime.datetime`
# objects, so they can be passed to `generate` as-is (the previous conversion
# branch was unreachable and referenced an unimported `datetime` module).
embeddings_list = []
datetimes = stack.time.values.astype("datetime64[s]").tolist()

for i, date in enumerate(datetimes):
    # stack[i] is (band, y, x); move bands last to get [H, W, C].
    image = stack[i].values.transpose(1, 2, 0)

    # Generate embedding using the geoai wrapper.
    embed = clay_model.generate(
        image=image,
        bounds=wgs84_bounds,
        date=date,
        gsd=gsd,
        only_cls_token=True,  # only the class token (global embedding)
    )

    embeddings_list.append(embed.squeeze(0))

# Stack all per-scene embeddings into one tensor on the CPU.
embeddings_geoai = torch.stack(embeddings_list).cpu()

In [None]:
# Shape of the geoai wrapper embeddings; presumably (scenes, embedding size),
# matching `embeddings_clay` — verify before comparing the two.
embeddings_geoai.shape

In [None]:
# Project the geoai wrapper embeddings onto their first principal component
# and plot it over time; cloudy and post-fire scenes should stand out.
CLOUDY_IDX = [0, 2]  # indices of cloudy scenes in this search result
N_POST_FIRE = 5  # number of trailing scenes acquired after the fire

pca = decomposition.PCA(n_components=1)
pca_result = pca.fit_transform(embeddings_geoai)
scene_times = stack.time.values

fig, ax = plt.subplots()
# Plot all points in blue first, then re-plot cloudy (green) and post-fire (red) on top.
ax.scatter(scene_times, pca_result[:, 0], color="blue", label="other")
ax.scatter(
    scene_times[CLOUDY_IDX], pca_result[CLOUDY_IDX, 0], color="green", label="cloudy"
)
ax.scatter(
    scene_times[-N_POST_FIRE:],
    pca_result[-N_POST_FIRE:, 0],
    color="red",
    label="after fire",
)
ax.set(
    title="geoai wrapper embeddings: first principal component",
    xlabel="Acquisition date",
    ylabel="PC 1",
)
ax.tick_params(axis="x", labelrotation=-45)
ax.legend()
plt.show()

In [None]:
print(f"Generated embeddings shape: {embeddings_geoai.shape}")

# Label the images we downloaded
# 0 = Cloud
# 1 = Forest
# 2 = Fire
labels = np.array([0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
# These hand labels only fit this exact search result; fail loudly if the
# STAC query returned a different number of scenes.
assert len(labels) == len(embeddings_geoai), (
    f"Expected {len(labels)} scenes but got {len(embeddings_geoai)}; update `labels`."
)

# Split into fit and test manually, ensuring we have all 3 classes in both sets
fit = [0, 1, 3, 4, 7, 8, 9]
test = [2, 5, 6, 10, 11]

# Train a Support Vector Machine model (default RBF kernel, gamma="scale").
# The +100 offset is kept from the original example; the RBF kernel and the
# "scale" gamma are both shift-invariant, so it only affects float32 rounding.
clf = svm.SVC()
clf.fit(embeddings_geoai[fit] + 100, labels[fit])

# Predict classes on the held-out scenes and count correct predictions.
prediction = clf.predict(embeddings_geoai[test] + 100)
match = np.sum(labels[test] == prediction)
print(f"Matched {match} out of {len(test)} correctly")

## Comparing GeoAI and Clay

In [None]:
# Do the wrapper embeddings match the direct-model embeddings within
# torch.allclose's default tolerances (rtol=1e-05, atol=1e-08)?
torch.allclose(embeddings_geoai, embeddings_clay)

In [None]:
# Largest element-wise deviation between the two embedding sets.
# Absolute values are required: a signed `.max()` would hide large negative
# deviations. The relative error can blow up where an embedding value is
# close to zero, so the absolute error is reported alongside it.
abs_diff = (embeddings_geoai - embeddings_clay).abs()
max_abs_diff = abs_diff.max()
max_rel_diff = (abs_diff / embeddings_clay.abs()).max()
max_abs_diff, max_rel_diff