# DOWSER V0 - Feature Extraction from Google Earth Engine

This notebook extracts environmental features for each water point using Google Earth Engine.

**Input:** `waterpoints_multiclass.parquet` (35,273 points with labels)

**Output:** `waterpoints_with_features.parquet` (points + extracted features)

**Features extracted:**
- Topography: elevation, slope, aspect, TWI, curvature
- Vegetation: NDVI, NDWI
- Climate: annual precipitation, precipitation seasonality
- Soil: clay %, sand %, soil organic carbon
- Hydrology: distance to water, water occurrence
- Land cover: ESA WorldCover class

In [35]:
# ==============================================================
# CELL 1: Imports and Configuration
# ==============================================================

import ee
import pandas as pd
import geopandas as gpd
import numpy as np
from pathlib import Path
import time
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Project paths
PROJECT_ROOT = Path.cwd().parent
PROCESSED = PROJECT_ROOT / "data/processed"
PROCESSED.mkdir(parents=True, exist_ok=True)

print(f"Project root: {PROJECT_ROOT}")
print(f"Output dir: {PROCESSED}")

Project root: /Users/leonardovannoli/work/dowser/dowser-v0
Output dir: /Users/leonardovannoli/work/dowser/dowser-v0/data/processed


In [36]:
# ==============================================================
# CELL 2: Initialize Google Earth Engine
# ==============================================================
GEE_PROJECT_ID = "northern-cooler-426712-n6"
try:
    ee.Initialize(project=GEE_PROJECT_ID)
    print("✅ GEE already initialized")
except Exception as e:
    print("Authenticating GEE...")
    ee.Authenticate()
    ee.Initialize()
    print("✅ GEE authenticated and initialized")

✅ GEE already initialized


In [37]:
# ==============================================================
# CELL 3: Load Water Points Dataset
# ==============================================================

# Load the multiclass dataset
input_path = PROCESSED / "waterpoints_multiclass.parquet"
gdf = gpd.read_parquet(input_path)

print(f"Loaded {len(gdf)} water points")
print(f"\nClass distribution:")
class_names = {0: "NO_WATER", 1: "SURFACE (0-5m)", 2: "SHALLOW (5-20m)", 3: "DEEP (30-100m)"}
for cls, name in class_names.items():
    count = (gdf["class"] == cls).sum()
    print(f"  {name}: {count}")

print(f"\nCountry distribution:")
print(gdf["country"].value_counts())

Loaded 35273 water points

Class distribution:
  NO_WATER: 26556
  SURFACE (0-5m): 1138
  SHALLOW (5-20m): 4273
  DEEP (30-100m): 3306

Country distribution:
country
KEN    18315
TZA    16958
Name: count, dtype: int64


In [None]:
# ==============================================================
# CELL 4: Define Feature Extraction Functions - OLD VERSION
# ==============================================================

def get_image_stack():
    """
    Build a multi-band image stack with all environmental features.
    
    Returns:
        ee.Image: Multi-band image with all features
    """
    
    # -------------------------
    # 1. TOPOGRAPHY (Copernicus DEM 30m)
    # -------------------------
    dem = ee.ImageCollection("COPERNICUS/DEM/GLO30").select("DEM").mosaic()
    
    # Terrain derivatives
    slope = ee.Terrain.slope(dem).rename("slope")
    aspect = ee.Terrain.aspect(dem).rename("aspect")
    
    # Topographic Wetness Index (TWI) - proxy for water accumulation
    # TWI = ln(a / tan(b)) where a = flow accumulation, b = slope
    flow_acc = ee.Image("WWF/HydroSHEDS/15ACC")
    slope_rad = slope.multiply(np.pi / 180)  # Convert to radians
    twi = flow_acc.log().divide(slope_rad.tan().add(0.001)).rename("twi")
    
    # Curvature - profile curvature indicates convergence/divergence
    # Positive = concave (water accumulates), Negative = convex (water disperses)
    curvature = dem.convolve(ee.Kernel.laplacian8()).rename("curvature")
    
    # -------------------------
    # 2. VEGETATION INDICES (Sentinel-2, 2020-2023 composite)
    # -------------------------
    s2 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED") \
        .filterDate("2020-01-01", "2023-12-31") \
        .filter(ee.Filter.lt("CLOUDY_PIXEL_PERCENTAGE", 20)) \
        .median()
    
    # NDVI - Normalized Difference Vegetation Index
    # High NDVI in arid areas can indicate groundwater presence
    ndvi = s2.normalizedDifference(["B8", "B4"]).rename("ndvi")
    
    # NDWI - Normalized Difference Water Index
    # Detects surface water and moisture
    ndwi = s2.normalizedDifference(["B3", "B8"]).rename("ndwi")
    
    # -------------------------
    # 3. PRECIPITATION (CHIRPS 2014-2023)
    # -------------------------
    chirps = ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY") \
        .filterDate("2014-01-01", "2024-01-01")
    
    # Mean annual precipitation (mm/year)
    precip_annual = chirps.mean().multiply(365).rename("precip_annual_mm")
    
    # Precipitation seasonality (coefficient of variation of monthly totals)
    # High seasonality = more variable rainfall = less reliable recharge
    def monthly_total(month):
        return chirps.filter(ee.Filter.calendarRange(month, month, "month")).sum().set("month", month)
    
    monthly_precip = ee.ImageCollection([monthly_total(m) for m in range(1, 13)])
    precip_mean = monthly_precip.mean()
    precip_std = monthly_precip.reduce(ee.Reducer.stdDev())
    precip_cv = precip_std.divide(precip_mean.add(1)).rename("precip_seasonality")
    
    # -------------------------
    # 4. SOIL PROPERTIES (SoilGrids 250m)
    # -------------------------
    # Clay content (%) - high clay = low permeability
    clay = ee.Image("projects/soilgrids-isric/clay_mean") \
        .select("clay_0-5cm_mean").rename("clay_pct")
    
    # Sand content (%) - high sand = high permeability
    sand = ee.Image("projects/soilgrids-isric/sand_mean") \
        .select("sand_0-5cm_mean").rename("sand_pct")
    
    # Soil Organic Carbon - indicator of soil structure
    soc = ee.Image("projects/soilgrids-isric/soc_mean") \
        .select("soc_0-5cm_mean").rename("soc")
    
    # -------------------------
    # 5. SURFACE WATER (JRC Global Surface Water) - WITH UNMASK
    # -------------------------
    jrc = ee.Image("JRC/GSW1_4/GlobalSurfaceWater")
    
    # Water occurrence (% of time with water) - unmask to 0 where no data
    water_occurrence = jrc.select("occurrence").unmask(0).rename("water_occurrence")
    
    # Water seasonality (months per year with water)
    water_seasonality = jrc.select("seasonality").unmask(0).rename("water_seasonality")
    
    # -------------------------
    # 6. LAND COVER (ESA WorldCover 2021)
    # -------------------------
    landcover = ee.Image("ESA/WorldCover/v200/2021").select("Map").rename("landcover")
    
    # -------------------------
    # 7. LAND SURFACE TEMPERATURE (MODIS)
    # -------------------------
    lst = ee.ImageCollection("MODIS/061/MOD11A1") \
        .filterDate("2020-01-01", "2023-12-31") \
        .select("LST_Day_1km") \
        .mean() \
        .multiply(0.02).subtract(273.15) \
        .rename("lst_celsius")
    
    # -------------------------
    # BUILD FINAL STACK
    # -------------------------
    stack = dem.rename("elevation").unmask(-9999) \
        .addBands(slope.unmask(-9999)) \
        .addBands(aspect.unmask(-9999)) \
        .addBands(twi.unmask(-9999)) \
        .addBands(curvature.unmask(0)) \
        .addBands(ndvi.unmask(-9999)) \
        .addBands(ndwi.unmask(-9999)) \
        .addBands(precip_annual.unmask(0)) \
        .addBands(precip_cv.unmask(0)) \
        .addBands(clay.unmask(-9999)) \
        .addBands(sand.unmask(-9999)) \
        .addBands(soc.unmask(-9999)) \
        .addBands(water_occurrence) \
        .addBands(water_seasonality) \
        .addBands(landcover.unmask(0)) \
        .addBands(lst.unmask(-9999))
    
    return stack

# Test the stack
print("Building image stack...")
stack = get_image_stack()
band_names = stack.bandNames().getInfo()
print(f"\n✅ Image stack created with {len(band_names)} bands:")
for i, name in enumerate(band_names):
    print(f"  {i+1}. {name}")

# Quick test
test_lon = 34.3580455285
test_lat = -1.0370774269
point = ee.Geometry.Point([test_lon, test_lat])
result = stack.sample(point, scale=30).first().getInfo()
print(f"\n✅ Test sampling: {len(result['properties'])} values extracted")

Building image stack...

✅ Image stack created with 16 bands:
  1. elevation
  2. slope
  3. aspect
  4. twi
  5. curvature
  6. ndvi
  7. ndwi
  8. precip_annual_mm
  9. precip_seasonality
  10. clay_pct
  11. sand_pct
  12. soc
  13. water_occurrence
  14. water_seasonality
  15. landcover
  16. lst_celsius

✅ Test sampling: 16 values extracted


In [39]:
# ==============================================================
# CELL 4: SIMPLIFIED Image Stack (10 bands - no heavy computations)
# ==============================================================

# --- ELEVATION & TERRAIN (3 bands) ---
dem = ee.Image("USGS/SRTMGL1_003").select('elevation')
elevation = dem.rename('elevation')
slope = ee.Terrain.slope(dem).rename('slope')
aspect = ee.Terrain.aspect(dem).rename('aspect')

# --- VEGETATION (2 bands) ---
s2 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED") \
    .filterDate('2020-01-01', '2023-12-31') \
    .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 20)) \
    .median()

ndvi = s2.normalizedDifference(['B8', 'B4']).rename('ndvi')
ndwi = s2.normalizedDifference(['B3', 'B8']).rename('ndwi')

# --- CLIMATE (1 band - simple annual mean) ---
precip = ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY") \
    .filterDate('2019-01-01', '2023-12-31') \
    .mean() \
    .multiply(365) \
    .rename('precip_annual_mm')

# --- SOIL (3 bands) ---
soil = ee.Image("OpenLandMap/SOL/SOL_TEXTURE-CLASS_USDA-TT_M/v02")
clay = ee.Image("projects/soilgrids-isric/clay_mean").select('clay_0-5cm_mean').rename('clay_pct')
sand = ee.Image("projects/soilgrids-isric/sand_mean").select('sand_0-5cm_mean').rename('sand_pct')
soc = ee.Image("projects/soilgrids-isric/soc_mean").select('soc_0-5cm_mean').rename('soc')

# --- LAND COVER (1 band) ---
landcover = ee.Image("ESA/WorldCover/v200/2021").select('Map').rename('landcover')

# === BUILD STACK (10 bands) ===
stack = ee.Image.cat([
    elevation.unmask(-9999),
    slope.unmask(-9999),
    aspect.unmask(-9999),
    ndvi.unmask(-9999),
    ndwi.unmask(-9999),
    precip.unmask(0),
    clay.unmask(-9999),
    sand.unmask(-9999),
    soc.unmask(-9999),
    landcover.unmask(0)
])

print("✅ Simplified stack: 10 bands")
print("Bands:", stack.bandNames().getInfo())

✅ Simplified stack: 10 bands
Bands: ['elevation', 'slope', 'aspect', 'ndvi', 'ndwi', 'precip_annual_mm', 'clay_pct', 'sand_pct', 'soc', 'landcover']


In [40]:
# ==============================================================
# CELL 5: Feature Extraction Function
# ==============================================================

def extract_features_batch(df, image_stack, batch_size=500):
    """
    Extract features for a batch of points using GEE.
    
    Args:
        df: DataFrame with #lon_deg, #lat_deg columns
        image_stack: ee.Image with all bands
        batch_size: Number of points per GEE request
    
    Returns:
        DataFrame with extracted features
    """
    all_results = []
    n_batches = (len(df) + batch_size - 1) // batch_size
    
    print(f"Extracting features for {len(df)} points in {n_batches} batches...")
    
    for batch_idx in tqdm(range(n_batches), desc="Processing batches"):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, len(df))
        batch_df = df.iloc[start_idx:end_idx]
        
        # Create FeatureCollection for this batch
        features = []
        for idx, row in batch_df.iterrows():
            point = ee.Geometry.Point([row["#lon_deg"], row["#lat_deg"]])
            feat = ee.Feature(point, {"original_idx": idx})
            features.append(feat)
        
        fc = ee.FeatureCollection(features)
        
        # Sample the image stack at point locations
        try:
            sampled = image_stack.sampleRegions(
                collection=fc,
                scale=30,  # 30m resolution
                geometries=False
            )
            
            # Download results
            batch_results = sampled.getInfo()
            
            for feat in batch_results["features"]:
                props = feat["properties"]
                all_results.append(props)
            
            # Rate limiting to avoid quota issues
            time.sleep(0.5)
            
        except Exception as e:
            print(f"\n⚠️ Error in batch {batch_idx}: {e}")
            # Add None values for failed batch
            for idx in batch_df.index:
                all_results.append({"original_idx": idx})
            continue
    
    return pd.DataFrame(all_results)

print("✅ Feature extraction function defined")

✅ Feature extraction function defined


In [30]:
# ==============================================================
# CELL 6: Test on Small Sample
# ==============================================================

# Test with 100 random points first
test_sample = gdf.sample(n=100, random_state=42).copy()
test_sample = test_sample.reset_index(drop=True)

print(f"Testing with {len(test_sample)} points...")

# Extract features
test_features = extract_features_batch(test_sample, stack, batch_size=100)

print(f"\n✅ Test complete!")
print(f"Extracted features shape: {test_features.shape}")
print(f"\nFeature columns:")
print(test_features.columns.tolist())
print(f"\nSample values:")
test_features.head()

Testing with 100 points...
Extracting features for 100 points in 1 batches...


Processing batches: 100%|██████████| 1/1 [09:40<00:00, 580.87s/it]



✅ Test complete!
Extracted features shape: (100, 17)

Feature columns:
['aspect', 'clay_pct', 'curvature', 'elevation', 'landcover', 'lst_celsius', 'ndvi', 'ndwi', 'original_idx', 'precip_annual_mm', 'precip_seasonality', 'sand_pct', 'slope', 'soc', 'twi', 'water_occurrence', 'water_seasonality']

Sample values:


Unnamed: 0,aspect,clay_pct,curvature,elevation,landcover,lst_celsius,ndvi,ndwi,original_idx,precip_annual_mm,precip_seasonality,sand_pct,slope,soc,twi,water_occurrence,water_seasonality
0,263.786499,340,3.020142,1275.702881,30,30.705205,0.549524,-0.563215,0,1265.683838,0.493165,439,0.326275,338,103.537679,0,0
1,354.577759,356,0.703979,1137.974731,40,37.196564,0.273618,-0.388735,1,650.671265,0.933702,547,0.030591,157,5444.012429,0,0
2,74.002762,349,-1.501465,1519.063721,20,28.30662,0.651097,-0.637993,2,1534.985229,0.621208,440,0.074039,363,0.0,0,0
3,26.899609,167,9.543213,393.460693,50,31.008779,0.313085,-0.400952,3,1065.606201,1.091736,790,0.237935,130,0.0,0,0
4,120.278488,207,0.191772,787.056396,40,27.634907,0.264884,-0.392227,4,1033.258423,0.949164,684,0.236481,437,135.185045,0,0


In [32]:
# Check the test results
print("Feature statistics:")
print(test_features.describe())

print("\nMissing values:")
print(test_features.isnull().sum())

print("\nNo-data values (-9999):")
for col in test_features.columns:
    if col != "original_idx":
        n_nodata = (test_features[col] == -9999).sum()
        if n_nodata > 0:
            print(f"  {col}: {n_nodata}")

Feature statistics:
            aspect     clay_pct   curvature    elevation   landcover  \
count   100.000000   100.000000  100.000000   100.000000  100.000000   
mean     15.573912   248.460000    1.335820  1312.416439   26.400000   
std    1440.575506  1039.802064    7.195326   352.292168   13.067038   
min   -9999.000000 -9999.000000  -19.807251    55.721268   10.000000   
25%     124.366226   263.750000   -2.811707  1162.700256   17.500000   
50%     263.786499   375.000000    0.817566  1316.806152   25.000000   
75%     272.386383   434.500000    5.407389  1527.738892   40.000000   
max     354.577759   533.000000   25.134277  2269.140381   50.000000   

       lst_celsius        ndvi        ndwi  original_idx  precip_annual_mm  \
count   100.000000  100.000000  100.000000    100.000000        100.000000   
mean     30.661728    0.513151   -0.541862     49.500000       1448.105358   
std       3.163654    0.187601    0.115074     29.011492        575.036768   
min      24.201843 

In [None]:
# ==============================================================
# CELL 8: Extract Features for ALL Points
# ==============================================================

# WARNING: This will take ~1-2 hours for 35k points
# Uncomment to run

print("="*60)
print("FULL FEATURE EXTRACTION")
print("="*60)
print(f"Total points: {len(gdf)}")
print("="*60)

# Reset index for proper merging
gdf_reset = gdf.reset_index(drop=True)

# Extract all features
all_features = extract_features_batch(gdf_reset, stack, batch_size=50)

print(f"\n✅ Extraction complete!")
print(f"Features shape: {all_features.shape}")

FULL FEATURE EXTRACTION
Total points: 35273
Estimated time: ~1-2 hours
Extracting features for 35273 points in 706 batches...


Processing batches:   1%|          | 8/706 [27:55<31:44:08, 163.68s/it]

In [None]:
# ==============================================================
# CELL 9: Merge Features with Original Data
# ==============================================================

# Merge extracted features with original dataframe
gdf_reset["original_idx"] = gdf_reset.index

# Merge on original_idx
merged = gdf_reset.merge(all_features, on="original_idx", how="left")

print(f"Merged dataset shape: {merged.shape}")
print(f"\nColumns:")
print(merged.columns.tolist())

# Check for missing features
feature_cols = ["elevation", "slope", "aspect", "twi", "curvature", 
                "ndvi", "ndwi", "precip_annual_mm", "precip_seasonality",
                "clay_pct", "sand_pct", "soc", "water_occurrence", 
                "water_seasonality", "dist_to_water_m", "landcover", "lst_celsius"]

print(f"\nMissing values in feature columns:")
for col in feature_cols:
    if col in merged.columns:
        n_missing = merged[col].isnull().sum()
        pct_missing = n_missing / len(merged) * 100
        if n_missing > 0:
            print(f"  {col}: {n_missing} ({pct_missing:.1f}%)")

In [None]:
# ==============================================================
# CELL 10: Save Final Dataset
# ==============================================================

# Select relevant columns for final dataset
cols_to_keep = [
    # Identifiers
    "#lon_deg", "#lat_deg", "country", "#clean_adm1", "#clean_adm2",
    # Labels
    "class", "source_type", "success", "#status_clean",
    # Features - Topography
    "elevation", "slope", "aspect", "twi", "curvature",
    # Features - Vegetation
    "ndvi", "ndwi",
    # Features - Climate
    "precip_annual_mm", "precip_seasonality",
    # Features - Soil
    "clay_pct", "sand_pct", "soc",
    # Features - Water
    "water_occurrence", "water_seasonality", "dist_to_water_m",
    # Features - Land
    "landcover", "lst_celsius",
    # Geometry
    "geometry"
]

# Filter to existing columns
cols_to_keep = [c for c in cols_to_keep if c in merged.columns]
final_df = merged[cols_to_keep].copy()

# Convert back to GeoDataFrame
final_gdf = gpd.GeoDataFrame(final_df, geometry="geometry", crs="EPSG:4326")

# Save
output_path = PROCESSED / "waterpoints_with_features.parquet"
final_gdf.to_parquet(output_path, index=False)

print(f"✅ Saved: {output_path}")
print(f"Shape: {final_gdf.shape}")

In [None]:
# ==============================================================
# CELL 11: Final Summary
# ==============================================================

print("="*60)
print("DATASET SUMMARY")
print("="*60)

print(f"\nTotal samples: {len(final_gdf)}")

print(f"\nClass distribution:")
for cls, name in class_names.items():
    count = (final_gdf["class"] == cls).sum()
    pct = count / len(final_gdf) * 100
    print(f"  {name}: {count} ({pct:.1f}%)")

print(f"\nFeature statistics:")
feature_cols = ["elevation", "slope", "twi", "ndvi", "precip_annual_mm", 
                "clay_pct", "sand_pct", "dist_to_water_m"]
feature_cols = [c for c in feature_cols if c in final_gdf.columns]
print(final_gdf[feature_cols].describe().round(2))

print(f"\n✅ Dataset ready for training!")
print(f"Next step: Run notebook 02_train_model.ipynb")

In [None]:
# ==============================================================
# CELL 12: Quick Visualization (Optional)
# ==============================================================

import matplotlib.pyplot as plt

# Plot feature distributions by class
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

features_to_plot = ["elevation", "slope", "ndvi", "precip_annual_mm", 
                    "clay_pct", "sand_pct", "twi", "dist_to_water_m"]
features_to_plot = [f for f in features_to_plot if f in final_gdf.columns]

colors = {0: "red", 1: "lightblue", 2: "blue", 3: "darkblue"}

for i, feat in enumerate(features_to_plot):
    ax = axes[i]
    for cls in [0, 1, 2, 3]:
        data = final_gdf[final_gdf["class"] == cls][feat].dropna()
        ax.hist(data, bins=30, alpha=0.5, label=class_names[cls], color=colors[cls])
    ax.set_xlabel(feat)
    ax.set_ylabel("Count")
    if i == 0:
        ax.legend(fontsize=8)

plt.tight_layout()
plt.savefig(PROCESSED.parent.parent / "outputs/figures/feature_distributions.png", dpi=150)
plt.show()

print("✅ Figure saved to outputs/figures/feature_distributions.png")