# OlmoEarth Embeddings using data from Google Earth Engine (GEE)

**Author**: Ivan Zvonkov (ivan.zvonkov@gmail.com)

**Last modified**: Nov 30, 2025

**Description**: One-stop shop for generating OlmoEarth embeddings using data from Google Earth Engine. The notebook is intended to be run on Google Colab with a GPU.

1. **Setup**: Specifies the inference run configuration and shows how many embeddings have already been generated.
2. **GEE data exports**: Exports Earth observation data from GEE, in tiles, to a Cloud Storage bucket.
3. **OlmoEarth Setup**: Loads the OlmoEarth model and defines a dataset class that converts Google Earth Engine data into the OlmoEarth input format.

4. **OlmoEarth Inference**: Runs the OlmoEarth model on the Earth observation data and uploads the embeddings to a bucket.

OlmoEarth inference can be run inside the notebook (step 4) or in Google Cloud Run following the instructions in `scripts/tools/cloud_run/README.md`.


## 1. Setup



In [None]:
# Inference run configuration
#-------------------------------------------------------------------------------
# RUN identifies one (name, date range) combination; it is used as the folder
# prefix for tiles in both buckets and in GEE task descriptions.
NAME = "Togo_v2_nano"
START_DATE, END_DATE = '2019-03-01', '2020-03-01'
RUN = "_".join((NAME, START_DATE, END_DATE))

GCLOUD_PROJECT = "ai2-ivan"
IN_BUCKET = "ai2-ivan-helios-input-data"    # Bucket for GEE input data
OUT_BUCKET = "ai2-ivan-helios-output-data"  # Bucket for embedding outputs


# General setup
#-------------------------------------------------------------------------------
from google.colab import auth
from google.cloud import storage
from tqdm.notebook import tqdm
from pathlib import Path

# Authenticate the Colab user, then open handles to the input (GEE exports) and
# output (embeddings) buckets.
auth.authenticate_user()
client = storage.Client(project=GCLOUD_PROJECT)
in_bucket, out_bucket = client.bucket(IN_BUCKET), client.bucket(OUT_BUCKET)

# Band names per modality. They are used to select bands from the GEE composites,
# and each band's position in its list fixes the band axis of the arrays fed to OlmoEarth.
BANDS = dict(
  sentinel1=["VV", "VH"],
  sentinel2=[f"B{n}" for n in range(1, 9)] + ["B8A", "B9", "B11", "B12"],
  landsat=[f"B{n}" for n in range(1, 12)],
)

def remaining_tiles(mod=None, index=None):
  """Return the input tiles of the current RUN that have no embeddings yet.

  Lists the `{RUN}/*.tif` blobs in `in_bucket` and `out_bucket`, prints a
  progress line, and returns the sorted names present only in the input bucket.

  Args:
    mod, index: optional sharding. When both are given, only tiles whose integer
      file stem `i` satisfies `i % mod == index` are returned, so several workers
      can split the remaining tiles between them.
  """
  # Trailing slash: a bare `RUN` prefix would also match other runs whose name
  # merely starts with RUN (e.g. "<RUN>_test/...").
  prefix = f"{RUN}/"

  def _tif_names(bucket):
    # Only .tif blobs are tiles; anything else would also break int(stem) below.
    return {b.name for b in bucket.list_blobs(prefix=prefix) if b.name.endswith(".tif")}

  in_tifs = _tif_names(in_bucket)
  out_tifs = _tif_names(out_bucket)
  print(f"Embeddings generated:  {len(out_tifs)}/{len(in_tifs)}")
  remaining = sorted(in_tifs - out_tifs)  # sorted: deterministic processing order
  if mod is not None and index is not None:
    remaining = [t for t in remaining if int(Path(t).stem) % mod == index]
  return remaining

In [None]:
# Print embedding progress for the current RUN; the trailing ";" suppresses the notebook's echo of the returned list.
remaining_tiles();

## 2. GEE data exports (only run once)

In [None]:
import ee
import google

# Earth Engine authentication: reuse the application-default credentials (set up by
# the Colab auth cell) with Cloud Platform + Earth Engine scopes, on the high-volume endpoint.
import google.auth  # `import google` alone does not guarantee the `google.auth` submodule is loaded

SCOPES = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/earthengine"]
CREDENTIALS, _ = google.auth.default(default_scopes=SCOPES)
ee.Initialize(CREDENTIALS, project=GCLOUD_PROJECT, opt_url='https://earthengine-highvolume.googleapis.com')

# Region of interest: union of all GAUL level-2 administrative units of Togo.
roi = ee.FeatureCollection("FAO/GAUL/2015/level2").filter(ee.Filter.eq('ADM0_NAME', 'Togo')).geometry()
GEE_TILE_SIZE = 10*1000  # Grid cell side in metres: 10 km x 10 km (100 km²) tiles

start = ee.Date(START_DATE)
end = ee.Date(END_DATE)

# Sentinel-1 Data
#-------------------------------------------------------------------------------
# The date window is padded by 31 days on each side so that getCloseImages() has
# candidates outside [start, end) for the first and last months.
S1_all = (
  ee.ImageCollection('COPERNICUS/S1_GRD')
  .filterBounds(roi)
  .filterDate(start.advance(-31, 'days'), end.advance(31, 'days'))
)
# Keep a single orbit direction (that of the first image found) and IW mode only.
S1 = (
  S1_all
  .filter(ee.Filter.eq("orbitProperties_pass", S1_all.first().get("orbitProperties_pass")))
  .filter(ee.Filter.eq("instrumentMode", "IW"))
)
# Per-polarisation subsets.
S1_VV = S1.filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VV"))
S1_VH = S1.filter(ee.Filter.listContains("transmitterReceiverPolarisation", "VH"))

def getCloseImages(middleDate, imageCollection):
  """Return the images of `imageCollection` acquired close to `middleDate`.

  Keeps every image within 15 days of `middleDate`. If there is none, it keeps the
  image(s) at the smallest distance instead, so the result is non-empty whenever
  the collection is. Each returned image carries a "dateDist" property equal to
  |system:time_start - middleDate| in milliseconds.
  """
  def setDate(img):
    dateDist = ee.Number(img.get("system:time_start")).subtract(middleDate.millis()).abs()
    return img.set("dateDist", dateDist)
  fromMiddleDate = imageCollection.map(setDate).sort("dateDist", True)
  fifteenDaysInMs = ee.Number(15 * 24 * 60 * 60 * 1000)  # = 1296000000
  maxDiff = ee.Number(fromMiddleDate.first().get("dateDist")).max(fifteenDaysInMs)
  # filterMetadata() is deprecated; ee.Filter.lte is its "not_greater_than" operator.
  return fromMiddleDate.filter(ee.Filter.lte("dateDist", maxDiff))

def get_S1_img(date1, date2):
  """Build a 2-band (VV, VH) Sentinel-1 composite centred on the date1..date2 midpoint.

  Each polarisation is the per-pixel median of the acquisitions getCloseImages()
  selects around the midpoint; the result is clipped to `roi` and cast to float.
  """
  halfSpanDays = date2.difference(date1, 'days').divide(2)
  middleDate = date1.advance(halfSpanDays, 'days')
  medians = [
    getCloseImages(middleDate, collection).select(band).median()
    for collection, band in ((S1_VV, "VV"), (S1_VH, "VH"))
  ]
  return ee.Image.cat(medians).select(BANDS["sentinel1"]).clip(roi).float()  # S1 ranges from -50 to 1


# Sentinel-2 data
#-------------------------------------------------------------------------------
# For Togo, Cloud Score+ gives better mosaics than sorting by CLOUD_COVERAGE_ASSESSMENT.
QA_BAND = 'cs_cdf'  # Better than cs here
S2 = (ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
      .filterBounds(roi)
      .filterDate(start, end))
csPlus = (ee.ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED')
          .filterBounds(roi)
          .filterDate(start, end))
# Attach the Cloud Score+ QA band to each matching Sentinel-2 image.
S2_cf = S2.linkCollection(csPlus, [QA_BAND])

def get_S2_img(date1, date2):
  """Sentinel-2 composite for [date1, date2) with the bands in BANDS["sentinel2"].

  qualityMosaic keeps, per pixel, the observation with the highest QA_BAND
  (Cloud Score+) value; the result is clipped to `roi` and cast to float.
  """
  window = S2_cf.filterDate(date1, date2)
  best = window.qualityMosaic(QA_BAND)
  return best.select(BANDS["sentinel2"]).clip(roi).float()


# Landsat 8 data
#-------------------------------------------------------------------------------
# Tier 1 + Tier 2 Collection-2 scenes over the ROI, Level-2 (surface) and TOA.
LANDSAT_SR = (
  ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
  .merge(ee.ImageCollection('LANDSAT/LC08/C02/T2_L2'))
  .filterBounds(roi)
  .filterDate(start, end)
)
LANDSAT_TOA = (
  ee.ImageCollection('LANDSAT/LC08/C02/T1_TOA')
  .merge(ee.ImageCollection('LANDSAT/LC08/C02/T2_TOA'))
  .filterBounds(roi)
  .filterDate(start, end)
)
# Level-2 scenes with bands B8, B9 and B11 linked in from the matching TOA scene.
landsat = LANDSAT_SR.linkCollection(LANDSAT_TOA, ["B8", "B9", "B11"])

def get_landsat_img(date1, date2):
  """Landsat 8 mosaic for [date1, date2) with the 11 bands named in BANDS["landsat"].

  Where unmasked scenes overlap, the one with the lowest CLOUD_COVER is kept.
  SR_B1-SR_B7 and ST_B10 keep their raw Collection-2 Level-2 values (renamed to
  B1-B7 and B10). The linked TOA bands B8, B9 and B11 are mapped with
  (x + 0.2) / 0.0000275, the inverse of the Level-2 reflectance scaling, so they sit
  on a comparable digital-number scale. The result is cast to float.
  """
  # mosaic() draws the LAST image of the collection on top, so sort by *descending*
  # cloud cover: the clearest scene comes last and wins wherever scenes overlap.
  # (Ascending order would have put the cloudiest scene on top.)
  landsat_img = (
    landsat.filterDate(date1, date2)
    .sort("CLOUD_COVER", False)
    .mosaic()
    .clip(roi)
    .set("system:index", "")
  )
  SR_BANDS = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7", "ST_B10"]
  landsat_SR = landsat_img.select(SR_BANDS).rename(["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B10"])
  # NOTE(review): confirm against the data OlmoEarth was trained on:
  #  * toInt16() tops out at 32767, i.e. reflectance ~0.70, so brighter B8/B9 pixels
  #    (clouds, bright soil) cannot be represented; toUint16() may have been intended.
  #  * In the LC08 TOA collection B11 is presumably brightness temperature (K), not
  #    reflectance, so this rescale would push it far outside the int16 range.
  landsat_TOA = landsat_img.select(["B8", "B9", "B11"]).add(0.2).divide(0.0000275).toInt16()
  return landsat_SR.addBands(landsat_TOA).float()

# Use the CRS of the first Landsat scene for the tiling grid and for every export.
proj = landsat.first().select("SR_B1").projection()
crs = proj.crs().getInfo()  # client-side CRS string (e.g. an EPSG code)
print(crs)

In [None]:
# Create GEE tasks
#-------------------------------------------------------------------------------
# One composite per modality for each month step from START_DATE; the number of
# steps is the whole number of months between start and end.
numMonths = end.difference(start, 'month').toInt().getInfo()
S1_img_list, S2_img_list, landsat_img_list = [], [], []
for i in range(numMonths):
  # Month window [d1, d2), starting i months after START_DATE.
  d1 = start.advance(i, 'month')
  d2 = d1.advance(1, 'month')
  S1_img_list.append(get_S1_img(d1, d2))
  S2_img_list.append(get_S2_img(d1, d2))
  landsat_img_list.append(get_landsat_img(d1, d2))

def imageFromList(pre, imgList):
  """Stack `imgList` into one multi-band image whose band names start with `pre`.

  ImageCollection.toBands() names each band "<image id>_<band>"; this prepends
  "<pre>_". NOTE(review): the GeoTIFF parser expects "<modality>_<timestep>_<band>",
  which relies on the image ids being "0", "1", ... -- confirm.
  """
  stacked = ee.ImageCollection.fromImages(imgList).toBands()
  prefixed = stacked.bandNames().map(lambda name: ee.String(f"{pre}_").cat(name))
  return stacked.rename(prefixed)

# Single multi-band image exported per tile: latitude/longitude plus every monthly
# band of each modality, prefixed with the modality name by imageFromList. The
# GeoTIFF parser later locates bands by name ("latitude", "longitude",
# "<modality>_<timestep>_<band>"), so these names must stay consistent with BANDS.
theBigInputImage = ee.Image.cat([
  ee.Image.pixelLonLat().clip(roi).select("latitude", "longitude").float(),
  imageFromList("sentinel2",  S2_img_list),
  imageFromList("sentinel1",  S1_img_list),
  imageFromList("landsat", landsat_img_list)
])

# Tile the (slightly shrunk) ROI into GEE_TILE_SIZE cells in the export CRS.
grid = roi.buffer(-100).coveringGrid(crs, GEE_TILE_SIZE) # Buffered to avoid almost empty tiles
grid_list = grid.toList(grid.size())

# List the already-exported tiles once, instead of one `blob.exists()` request per tile.
exported = {b.name for b in in_bucket.list_blobs(prefix=f"{RUN}/")}

tasks = []
already_exist = 0
print("Preparing EarthEngine tasks...")
for i in tqdm(range(grid.size().getInfo())):
  if f"{RUN}/{i}.tif" in exported:
    already_exist += 1
    continue

  # One GeoTIFF per grid cell, written to gs://IN_BUCKET/<RUN>/<i>.tif at 10 m.
  tile = ee.Feature(grid_list.get(i)).geometry()
  task = ee.batch.Export.image.toCloudStorage(
    image=theBigInputImage.clip(tile),
    description=f"{RUN}_{i}", bucket=IN_BUCKET, fileFormat='GeoTIFF',
    fileNamePrefix=f"{RUN}/{i}", scale=10, crs=crs, region=tile,
  )
  tasks.append(task)

if already_exist > 0:
  print(f"{already_exist} tiles already exist in Cloud Storage.")
if tasks:
  print("Run next cell to start exports.")

In [None]:
# Start gee tasks for tiles not yet exported
# NOTE: `tasks` is built by the previous cell and never cleared; re-running this cell
# without re-running the previous one submits the same exports again.
for task in tqdm(tasks):
  task.start()
print(f"Started {len(tasks)} export tasks, see: https://code.earthengine.google.com/tasks")

## 3. OlmoEarth Setup

In [None]:
# 2 minutes to setup
from datetime import datetime as dt
import math
import numpy as np
import pandas as pd
import time
import torch

# Clone the OlmoEarth pretraining repo only if it is not already on this VM
# (assumes the notebook's working directory is /content, the Colab default).
if not Path("/content/olmoearth_pretrain").exists():
  !git clone https://github.com/allenai/olmoearth_pretrain.git

# NOTE: %cd changes the working directory for all later cells, so the relative
# in.tif / out.tif files of the inference loop are written inside the repo folder.
%cd /content/olmoearth_pretrain

# Install the repo as a package so `olmoearth_pretrain` can be imported below.
!pip install -q .

from torch.utils.data import Dataset, DataLoader
from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, Modality
from olmoearth_pretrain.data.normalize import Normalizer, Strategy
from olmoearth_pretrain.model_loader import ModelID, load_model_from_id

import rasterio as rio
from rasterio.windows import Window

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model
model = load_model_from_id(ModelID.OLMOEARTH_V1_NANO)
model.eval()
model = model.encoder.to(device)
EMBEDDINGS_SIZE = model.project_and_aggregate.projection[0].out_features
print("Embedding size: ", EMBEDDINGS_SIZE)

In [None]:
# Derive timestamps
# One [year, month - 1, day] entry per monthly composite. This should match the GEE
# export loop, where month i starts at START_DATE advanced by i months and only whole
# month steps before END_DATE are exported. For first-of-month dates this is exactly
# the month starts in [START_DATE, END_DATE). The previous date_range(freq="MS")[:-1]
# disagreed with the exports (count and labels) whenever START_DATE was not a first day.
# NOTE(review): confirm that the [year, month - 1, day] ordering is what
# MaskedOlmoEarthSample.timestamps expects.
to_date_obj = lambda d: dt.strptime(d, "%Y-%m-%d").date()

def _month_steps(start_date, end_date):
  """Return the start date of every whole month step from start_date before end_date."""
  first, last = to_date_obj(start_date), to_date_obj(end_date)
  n_months = (last.year - first.year) * 12 + (last.month - first.month)
  if last.day < first.day:  # the final month step would end after end_date
    n_months -= 1
  # Offsets are applied from `first` each time (not chained), like start.advance(i, 'month').
  return pd.DatetimeIndex([pd.Timestamp(first) + pd.DateOffset(months=i) for i in range(max(n_months, 0))])

timestamps_pd = _month_steps(START_DATE, END_DATE)
timestamps = [[t.year, t.month - 1, t.day] for t in timestamps_pd]

In [None]:
# Normalizers: "computed" statistics for the imagery modalities, "predefined" for lat/lon.
computed, predefined = Normalizer(Strategy.COMPUTED), Normalizer(Strategy.PREDEFINED)

class OlmoEarthGEEDataset(Dataset):
  """Per-pixel dataset built from one GEE-exported GeoTIFF tile.

  The whole tile is read and normalised up front. Item `idx` is a dict of tensors
  ("timestamps", "sentinel2", "sentinel1", "landsat", "latlon") for the idx-th pixel
  in row-major order (idx = row * width + col). `profile` is the rasterio profile of
  the source file, so outputs can be written with the same georeferencing.
  """

  def __init__(self, tif_path):
    profile, input_dict = self._read_geotiff(tif_path)
    self.profile = profile
    self.tensors = {
      "timestamps": torch.from_numpy(input_dict["timestamps"]),
      "sentinel2":  torch.from_numpy(computed.normalize(Modality.SENTINEL2_L2A, input_dict["sentinel2"])).float(),
      "sentinel1":  torch.from_numpy(computed.normalize(Modality.SENTINEL1, input_dict["sentinel1"])).float(),
      "landsat":    torch.from_numpy(computed.normalize(Modality.LANDSAT, input_dict["landsat"])).float(),
      "latlon":     torch.from_numpy(predefined.normalize(Modality.LATLON, input_dict["latlon"])).float(),
    }
    for tensor in self.tensors.values():
      tensor.share_memory_()

  @staticmethod
  def _read_geotiff(tif_path):
    """Read `tif_path`; return (rasterio profile, un-normalised model-input dict)."""
    with rio.open(tif_path) as src:
      profile = src.profile
      bands = src.descriptions  # band names written by the GEE export
      tile = src.read()         # (bands, height, width)
    return profile, OlmoEarthGEEDataset._tile_to_input_dict(tile, bands)

  @staticmethod
  def _tile_to_input_dict(tile, bands, band_spec=None, timesteps=None):
    """Split a (bands, height, width) array into per-pixel model inputs.

    `bands` are the band names: "latitude", "longitude" and
    "<modality>_<timestep>_<band>". `band_spec` / `timesteps` default to the notebook
    globals BANDS / timestamps. Returns numpy arrays keyed like the dataset tensors;
    modality arrays have shape (pixels, 1, 1, timesteps, bands).
    """
    band_spec = BANDS if band_spec is None else band_spec
    timesteps = timestamps if timesteps is None else timesteps
    height, width = tile.shape[1], tile.shape[2]
    num_pixels = height * width
    input_data = tile.reshape(len(bands), num_pixels)

    def empty(modality):
      # float32 matches the exported GeoTIFF data and halves memory vs. float64.
      # Bands/timesteps missing from the file stay 0. NOTE(review): pixels masked in
      # GEE arrive with whatever value the export wrote and are not flagged as missing.
      return np.zeros((num_pixels, 1, 1, len(timesteps), len(band_spec[modality])), dtype=np.float32)

    input_dict = {
      # Same (pixels, timesteps, 3) array as np.array([timesteps] * num_pixels), built
      # without a million-element Python list.
      "timestamps": np.tile(np.asarray(timesteps)[np.newaxis], (num_pixels, 1, 1)),
      "latlon":     input_data[[bands.index("latitude"), bands.index("longitude")]].transpose(1, 0),
      "landsat":    empty("landsat"),
      "sentinel1":  empty("sentinel1"),
      "sentinel2":  empty("sentinel2"),
    }

    for i, key in enumerate(bands):
      if key in ("latitude", "longitude"):
        continue
      modality, timestep_str, band = key.split("_")
      band_index = band_spec[modality].index(band)
      input_dict[modality][:, 0, 0, int(timestep_str), band_index] = input_data[i]
    return input_dict

  def __len__(self):
    return self.profile["width"] * self.profile["height"]

  def __getitem__(self, idx):
    return {name: tensor[idx] for name, tensor in self.tensors.items()}


## 4. OlmoEarth Inference

In [None]:
BATCH_SIZE = 128*128

remaining = remaining_tiles()
while len(remaining) > 0:
  for tile in tqdm(remaining):
    print(tile)
    print(f"\n\tDownloading input data ...\t", end="")
    start = time.perf_counter()
    !gcloud storage cp gs://{IN_BUCKET}/{tile} in.tif 2>/dev/null
    duration = time.perf_counter() - start
    print(f"{duration:.2f}s\t ✓")

    print(f"\tReading & normalizing data ...\t", end="")
    start = time.perf_counter()
    dataset = OlmoEarthGEEDataset("in.tif")
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False
    )
    duration = time.perf_counter() - start
    print(f"{duration:.2f}s\t ✓")
    batches = math.ceil(len(dataset) / BATCH_SIZE)
    embeddings_list = []

    # Go through pixels in file in batches
    print(f"\tInference ...\t\t\t\t\t", end="")
    start = time.perf_counter()
    for batch in loader:
      gpu_batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

      # Create OlmoEarth sample
      masked_sample = MaskedOlmoEarthSample(
          timestamps=           gpu_batch["timestamps"],
          sentinel2_l2a=        gpu_batch["sentinel2"],
          sentinel1=            gpu_batch["sentinel1"],
          landsat=              gpu_batch["landsat"],
          latlon=               gpu_batch["latlon"],
          sentinel2_l2a_mask=   torch.zeros_like(gpu_batch["sentinel2"], device=device),
          sentinel1_mask=       torch.zeros_like(gpu_batch["sentinel1"], device=device),
          landsat_mask=         torch.zeros_like(gpu_batch["landsat"], device=device),
          latlon_mask=          torch.zeros_like(gpu_batch["latlon"], device=device)
      )

      # Make predictions
      with torch.no_grad():
          preds = model(masked_sample, patch_size=1, fast_pass=True)
          preds_projected = model.project_and_aggregate(preds["tokens_and_masks"])
          embeddings = preds_projected.cpu().numpy()
      embeddings_list.append(embeddings)
    duration = time.perf_counter() - start
    print(f"{duration:.2f}s\t ✓")

    print(f"\tWriting to file ...\t\t\t", end="")
    start = time.perf_counter()
    profile = dataset.profile
    profile.update(count=EMBEDDINGS_SIZE, dtype="float32", compress="deflate", bigtiff="YES")
    all_embeddings = np.concatenate(embeddings_list).transpose(1, 0)
    embeddings_reshaped = all_embeddings.reshape(EMBEDDINGS_SIZE, profile["height"], profile["width"])
    with rio.open("out.tif", "w", **profile) as dst:
      dst.write(embeddings_reshaped.astype("float32"))
    duration = time.perf_counter() - start
    print(f"{duration:.2f}s\t ✓")


    print(f"\tUploading embeddings ...\t\t", end="")
    start = time.perf_counter()
    !gcloud storage cp out.tif gs://{OUT_BUCKET}/{tile} 2>/dev/null
    duration = time.perf_counter() - start
    print(f"{duration:.2f}s\t ✓")

    !rm "in.tif" "out.tif"

    remaining = remaining_tiles()