In [5]:
import os
import re
import zipfile
import numpy as np
import pandas as pd
import rasterio
from rasterio.transform import from_origin
from whitebox.whitebox_tools import WhiteboxTools
import earthaccess
from tqdm import tqdm
import contextlib
import io
import geopandas as gpd

# ---------------------------
# Tile ID utilities
# ---------------------------
def tile_id_from_coords(lat, lon):
    """Convert coords to tile ID (e.g. N40W106)."""
    if pd.isna(lat) or pd.isna(lon):
        return None
    ns = "N" if lat >= 0 else "S"
    ew = "E" if lon >= 0 else "W"
    return f"{ns}{abs(int(lat)):02d}{ew}{abs(int(lon)):03d}"

# ---------------------------
# DEM Download
# ---------------------------
def download_dem_bbox(min_lon, min_lat, max_lon, max_lat, out_dir="dem_tiles", prefer="SRTMGL1"):
    os.makedirs(out_dir, exist_ok=True)
    earthaccess.login(strategy="environment", persist=True)

    dataset = ("SRTMGL1", "003") if prefer == "SRTMGL1" else ("COPDEM_GLO_30", "001")

    try:
        results = earthaccess.search_data(
            short_name=dataset[0],
            version=dataset[1],
            bounding_box=(min_lon, min_lat, max_lon, max_lat),
            count=10
        )
    except IndexError:
        return []

    if not results or len(results) == 0:
        return []

    buf = io.StringIO()
    with contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
        paths = earthaccess.download(results, out_dir)

    return paths

def download_dem_point(lat, lon, out_dir="dem_tiles", buffer=0.1):
    # Clamp to valid ranges
    min_lon = max(-180.0, lon - buffer)
    max_lon = min(180.0, lon + buffer)
    min_lat = max(-90.0, lat - buffer)
    max_lat = min(90.0, lat + buffer)

    paths = download_dem_bbox(min_lon, min_lat, max_lon, max_lat, out_dir=out_dir, prefer="SRTMGL1")
    if paths:
        return paths, "SRTM"
    paths = download_dem_bbox(min_lon, min_lat, max_lon, max_lat, out_dir=out_dir, prefer="COPDEM")
    if paths:
        return paths, "Copernicus"
    return [], "None"

# ---------------------------
# HGT → GeoTIFF
# ---------------------------
def parse_hgt_bounds(hgt_path):
    name = os.path.splitext(os.path.basename(hgt_path))[0]
    m = re.match(r'([NS])(\d{1,2})([EW])(\d{1,3})', name, re.IGNORECASE)
    if not m:
        raise ValueError(f"Cannot parse HGT name: {hgt_path}")
    lat_sign = 1 if m.group(1).upper() == 'N' else -1
    lon_sign = 1 if m.group(3).upper() == 'E' else -1
    lat0 = lat_sign * int(m.group(2))
    lon0 = lon_sign * int(m.group(4))
    west, south = float(lon0), float(lat0)
    east, north = west + 1.0, south + 1.0
    return west, south, east, north

def hgt_to_gtiff(hgt_path, tif_path):
    west, south, east, north = parse_hgt_bounds(hgt_path)
    nbytes = os.path.getsize(hgt_path)
    side = int(np.sqrt(nbytes // 2))
    if side not in (3601, 1201):
        raise ValueError(f"Unexpected HGT side length: {side}")
    data = np.fromfile(hgt_path, dtype=">i2").reshape((side, side))
    data = data[:-1, :-1]
    res = 1.0 / (side - 1)

    transform = from_origin(west, north, res, res)
    profile = {
        "driver": "GTiff",
        "height": data.shape[0],
        "width": data.shape[1],
        "count": 1,
        "dtype": "int16",
        "crs": "EPSG:4326",
        "transform": transform,
        "nodata": -32768,
        "tiled": True,
        "compress": "LZW"
    }

    with rasterio.open(tif_path, "w", **profile) as dst:
        dst.write(data, 1)

def prepare_tif(path):
    """Unpack zip/HGT and convert to GeoTIFF. Remove raw files after processing."""
    if path.lower().endswith(".tif"):
        return os.path.abspath(path)

    if path.lower().endswith(".zip"):
        tif_out, hgt_out = None, None
        with zipfile.ZipFile(path, "r") as z:
            tifs = [m for m in z.namelist() if m.lower().endswith(".tif")]
            if tifs:
                tif_out = os.path.join(os.path.dirname(path), os.path.basename(tifs[0]))
                if not os.path.exists(tif_out):
                    z.extract(tifs[0], os.path.dirname(path))
                tif_out = os.path.abspath(tif_out)
            else:
                hgts = [m for m in z.namelist() if m.lower().endswith(".hgt")]
                if hgts:
                    hgt_out = os.path.join(os.path.dirname(path), os.path.basename(hgts[0]))
                    if not os.path.exists(hgt_out):
                        z.extract(hgts[0], os.path.dirname(path))
                    tif_out = hgt_out.replace(".hgt", ".tif")
                    if not os.path.exists(tif_out):
                        hgt_to_gtiff(hgt_out, tif_out)
                    try:
                        os.remove(hgt_out)
                    except PermissionError:
                        pass
                    tif_out = os.path.abspath(tif_out)
        try:
            os.remove(path)
        except PermissionError:
            pass
        if tif_out:
            return tif_out
        else:
            raise FileNotFoundError(f"No .tif or .hgt in {path}")
    raise FileNotFoundError(f"Unsupported DEM format: {path}")

# ---------------------------
# Whitebox
# ---------------------------
wbt = WhiteboxTools()
wbt.verbose = False

def run_whitebox(tif_file):
    tif_file = os.path.abspath(tif_file).replace("\\", "/")
    slope_tif = tif_file.replace(".tif", "_slope.tif")
    aspect_tif = tif_file.replace(".tif", "_aspect.tif")
    geomorph_tif = tif_file.replace(".tif", "_geomorph.tif")

    if not os.path.exists(slope_tif):
        wbt.slope(dem=tif_file, output=slope_tif, zfactor=1.0, units="degrees")
    if not os.path.exists(aspect_tif):
        wbt.aspect(dem=tif_file, output=aspect_tif)
    if not os.path.exists(geomorph_tif):
        wbt.geomorphons(dem=tif_file, output=geomorph_tif, search=3, threshold=0.0, forms=True)

    return slope_tif, aspect_tif, geomorph_tif

# ---------------------------
# Extract raster value
# ---------------------------
def extract_value(raster, lat, lon):
    if raster is None or not os.path.exists(raster):
        return None
    with rasterio.open(raster) as src:
        for val in src.sample([(lon, lat)]):
            return float(val[0])

# ---------------------------
# Main pipeline for GeoJSON
# ---------------------------
def enrich_geojson(input_geojson, output_geojson, out_dir="dem_tiles"):
    os.makedirs(out_dir, exist_ok=True)

    gdf = gpd.read_file(input_geojson)

    # 🔑 Ensure correct CRS (must be WGS84 lat/lon)
    if gdf.crs is None:
        print("⚠️ No CRS found, assuming EPSG:4326")
        gdf.set_crs(epsg=4326, inplace=True)
    else:
        gdf = gdf.to_crs(epsg=4326)

    # Add expected cols
    for col in ["dem", "slope", "aspect", "geomorphon", "dem_source", "geomorphon_class"]:
        if col not in gdf.columns:
            gdf[col] = None

    # Collect centroids
    centroids = gdf.geometry.centroid
    coords = [(pt.y, pt.x) for pt in centroids]  # lat, lon
    print("🔍 Sample centroid coordinates:", coords[:5])

    # Step 1: collect needed tiles
    needed_tiles = {}
    for (lat, lon) in tqdm(coords, desc="Collecting tiles"):
        tid = tile_id_from_coords(lat, lon)
        if tid and tid not in needed_tiles:
            needed_tiles[tid] = (lat, lon)

    # Step 2: prepare tiles
    downloaded = {}
    for tid, (lat, lon) in tqdm(needed_tiles.items(), desc="Preparing tiles"):
        tif_path = os.path.join(out_dir, f"{tid}.tif")
        if os.path.exists(tif_path):
            downloaded[tid] = ([tif_path], "Local")
        else:
            zip_paths, source = download_dem_point(lat, lon, out_dir=out_dir)
            if zip_paths:
                tifs = [prepare_tif(zp) for zp in zip_paths]
                downloaded[tid] = (tifs, source)

    # Step 3: run Whitebox
    tile_results = {}
    for tid, (tifs, source) in tqdm(downloaded.items(), desc="Running Whitebox"):
        for tif in tifs:
            slope_tif, aspect_tif, geomorph_tif = run_whitebox(tif)
            tile_results[tid] = (tif, slope_tif, aspect_tif, geomorph_tif, source)

    # Step 4: extract values for each centroid
    geomorph_classes = {
        1: "flat", 2: "summit", 3: "ridge", 4: "shoulder", 5: "spur",
        6: "slope", 7: "hollow", 8: "footslope", 9: "valley", 10: "pit"
    }

    for idx, (lat, lon) in enumerate(tqdm(coords, desc="Extracting values")):
        tid = tile_id_from_coords(lat, lon)
        if tid is None or tid not in tile_results:
            continue
        tif, slope_tif, aspect_tif, geomorph_tif, source = tile_results[tid]
        gdf.at[idx, "dem"] = extract_value(tif, lat, lon)
        gdf.at[idx, "slope"] = extract_value(slope_tif, lat, lon)
        gdf.at[idx, "aspect"] = extract_value(aspect_tif, lat, lon)
        gdf.at[idx, "geomorphon"] = extract_value(geomorph_tif, lat, lon)
        gdf.at[idx, "geomorphon_class"] = geomorph_classes.get(gdf.at[idx, "geomorphon"], None)

    # Save enriched GeoJSON
    gdf.to_file(output_geojson, driver="GeoJSON")
    print(f"✅ Done! Saved {output_geojson}")

# ---------------------------
# Run
# ---------------------------
if __name__ == "__main__":
    enrich_geojson(
        "data/grid_tuscany_forest.geojson",
        "data/grid_tuscany_with_topography.geojson"
    )



  centroids = gdf.geometry.centroid


🔍 Sample centroid coordinates: [(44.386179452275954, 9.864970345902769), (44.38629011836419, 9.86976997901685), (44.35718832429243, 9.913952012201083), (44.35425125363808, 9.914213167113978), (44.35749819170338, 9.919000002801898)]


Collecting tiles: 100%|██████████| 133826/133826 [00:00<00:00, 1186612.62it/s]
Preparing tiles: 100%|██████████| 9/9 [01:12<00:00,  8.09s/it]
Running Whitebox: 100%|██████████| 9/9 [00:09<00:00,  1.02s/it]
Extracting values: 100%|██████████| 133826/133826 [31:10<00:00, 71.53it/s] 


✅ Done! Saved data/grid_tuscany_with_topography.geojson
