# Define & Justify Covariate & Process Layers
Identify and acquire wall-to-wall raster data for:
- Balancing Covariates: (e.g., slope, relief, curvature). Justify selection based on potential variance reduction and geomorphic interpretability.
- Process-Based Metrics: (e.g., Topographic Wetness Index, Stream Power Index).

In [None]:
# === 1. Dynamic Resource Allocation & Dask Setup with Auto-Installation ===

# --- Auto-Install Function (define first) ---
def install_package(package_name, conda_name=None, channel="conda-forge"):
    """Install a package into the ``app`` environment using micromamba.

    Args:
        package_name: Package name; used as the conda package name when
            ``conda_name`` is not given.
        conda_name: Optional conda package name to install instead of
            ``package_name``.
        channel: Conda channel passed to ``micromamba install -c``.

    Returns:
        True when the micromamba command exits successfully, False when the
        command fails or micromamba itself cannot be launched.
    """
    # Resolved before the try block so the failure message can always use it.
    conda_package = conda_name if conda_name else package_name
    print(f"Installing {conda_package} from {channel}...")
    try:
        subprocess.check_call([
            "micromamba", "install", "-n", "app",
            "-c", channel, conda_package, "-y"
        ], env=os.environ.copy())
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing or non-executable micromamba binary
        # (FileNotFoundError / PermissionError); callers rely on the boolean
        # result to decide whether to raise their own ImportError.
        print(f"Failed to install {conda_package}: {e}")
        return False
    print(f"Successfully installed {conda_package}")
    return True

# --- Core Libraries with Auto-Install ---
import os
import sys
import json
import warnings
import datetime as dt
from pathlib import Path
from typing import List, Dict, Optional
import time
import subprocess

# Standard scientific libraries (should already be installed)
import numpy as np
import psutil
import math

# Dask
try:
    import dask
    import dask.array as da
    from dask.diagnostics import ProgressBar
    from dask.distributed import Client, LocalCluster
    from dask import delayed
    print("Dask found.")
except ImportError:
    print("Dask not found. Installing...")
    if install_package("dask") and install_package("distributed"):
        import dask
        import dask.array as da
        from dask.diagnostics import ProgressBar
        from dask.distributed import Client, LocalCluster
        from dask import delayed
        print("Dask successfully installed and imported.")
    else:
        raise ImportError("Could not install Dask. Please install manually.")

# Rasterio
try:
    import rasterio
    from rasterio.windows import Window
    print("Rasterio found.")
except ImportError:
    print("Rasterio not found. Installing...")
    if install_package("rasterio"):
        import rasterio
        from rasterio.windows import Window
        print("Rasterio successfully installed and imported.")
    else:
        raise ImportError("Could not install Rasterio. Please install manually.")

# Rioxarray
try:
    import rioxarray as rxr
    print("Rioxarray found.")
except ImportError:
    print("Rioxarray not found. Installing...")
    if install_package("rioxarray"):
        import rioxarray as rxr
        print("Rioxarray successfully installed and imported.")
    else:
        raise ImportError("Could not install Rioxarray. Please install manually.")

# Xarray
try:
    import xarray as xr
    print("Xarray found.")
except ImportError:
    print("Xarray not found. Installing...")
    if install_package("xarray"):
        import xarray as xr
        print("Xarray successfully installed and imported.")
    else:
        raise ImportError("Could not install Xarray. Please install manually.")

# GeoPandas
try:
    import geopandas as gpd
    print("GeoPandas found.")
except ImportError:
    print("GeoPandas not found. Installing...")
    if install_package("geopandas"):
        import geopandas as gpd
        print("GeoPandas successfully installed and imported.")
    else:
        raise ImportError("Could not install GeoPandas. Please install manually.")

# SciPy
try:
    from scipy.ndimage import maximum_filter, minimum_filter, generic_filter
    print("SciPy found.")
except ImportError:
    print("SciPy not found. Installing...")
    if install_package("scipy"):
        from scipy.ndimage import maximum_filter, minimum_filter, generic_filter
        print("SciPy successfully installed and imported.")
    else:
        raise ImportError("Could not install SciPy. Please install manually.")

# Scikit-image
try:
    from skimage.transform import resize
    print("Scikit-image found.")
except ImportError:
    print("Scikit-image not found. Installing...")
    if install_package("scikit-image"):
        from skimage.transform import resize
        print("Scikit-image successfully installed and imported.")
    else:
        raise ImportError("Could not install Scikit-image. Please install manually.")

# --- GPU Libraries (Install Later via Environment) ---
# Probe for the optional GPU stack. _HAS_CUDA ends up True only when the
# libraries import AND device 0 can actually be queried.
try:
    import cupy as cp
    import cupyx.scipy.ndimage as ndimage_gpu
    from rmm.allocators.cupy import rmm_cupy_allocator
except ImportError as e:
    print(f"CUDA libraries not found: {e}")
    print("Add 'cupy' and 'rmm' to environment.yml for GPU acceleration.")
    _HAS_CUDA = False
else:
    try:
        cp.cuda.set_allocator(rmm_cupy_allocator)
        # Get properties of the first GPU
        gpu = cp.cuda.Device(0)
        # cupy.cuda.Device exposes no `name` attribute; the device name comes
        # from the runtime properties dict (returned as bytes).
        gpu_name = cp.cuda.runtime.getDeviceProperties(gpu.id)["name"]
        if isinstance(gpu_name, bytes):
            gpu_name = gpu_name.decode("utf-8", errors="replace")
        # mem_info is indexed as [0] = free, [1] = total.
        gpu_free_bytes, gpu_total_bytes = gpu.mem_info
    except Exception as e:
        # Capability probe boundary: libraries installed but no usable device
        # (or a driver/runtime mismatch) must disable the GPU path, not crash
        # the setup cell.
        print(f"CUDA libraries imported, but no usable GPU was found: {e}")
        _HAS_CUDA = False
    else:
        _HAS_CUDA = True
        print("CUDA libraries found. GPU acceleration is available.")
        print(f"   GPU Name: {gpu_name}")
        print(f"   GPU Memory: {gpu_total_bytes / 1024**3:.2f} GB total, {gpu_free_bytes / 1024**3:.2f} GB free")

# --- WhiteboxTools (Install Later via Environment) ---
try:
    from whitebox.whitebox_tools import WhiteboxTools
    _HAS_WBT = True
    print("WhiteboxTools found.")
except ImportError:
    print("WhiteboxTools not found. Add 'whitebox' to environment.yml if needed.")
    _HAS_WBT = False

# --- PyWavelets ---
try:
    import pywt
    _HAS_PYWT = True
    print("PyWavelets found.")
except ImportError:
    print("PyWavelets not found. Installing...")
    if install_package("pywavelets"):
        try:
            import pywt
            _HAS_PYWT = True
            print("PyWavelets successfully installed and imported.")
        except ImportError:
            print("PyWavelets installed. Please restart kernel to use it.")
            _HAS_PYWT = False
    else:
        print("Could not install PyWavelets.")
        _HAS_PYWT = False

# --- Scikit-TDA (Ripser + Persim) ---
try:
    import ripser
    import persim
    _HAS_SCIKIT_TDA = True
    print("Scikit-TDA (ripser, persim) found.")
except ImportError:
    print("Scikit-TDA libraries not found. Installing...")
    ripser_success = install_package("ripser")
    persim_success = install_package("persim")

    if ripser_success and persim_success:
        try:
            import ripser
            import persim
            _HAS_SCIKIT_TDA = True
            print("Scikit-TDA successfully installed and imported.")
        except ImportError:
            print("Scikit-TDA installed. Please restart kernel to use it.")
            _HAS_SCIKIT_TDA = False
    else:
        print("Could not install all Scikit-TDA components.")
        _HAS_SCIKIT_TDA = False

# --- Robust Path Resolution ---
def find_project_root(marker='README.md'):
    """Find the project root by searching upwards for a marker file.

    The search starts at the resolved current working directory and visits
    every ancestor, including the filesystem root.

    Args:
        marker: File or directory name whose presence identifies the root.

    Returns:
        The first directory (closest to the working directory) that
        contains ``marker``.

    Raises:
        FileNotFoundError: If no directory on the path contains ``marker``.
    """
    start = Path.cwd().resolve()
    # `parents` ends at the filesystem root, so the root itself is checked too.
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"Project root with marker '{marker}' not found.")

# Project-relative directories: everything below hangs off <root>/data/processed.
PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / "data"
PROC_DIR = DATA_DIR / "processed"

# --- Analysis Grid & CRS ---
TARGET_CRS = "EPSG:5070"  # NAD83 / Conus Albers
TARGET_RES_M = 30.0  # Target pixel size passed to gdalwarp -tr

# --- Covariate Parameters (Centralized Configuration) ---
# Relief window widths; each is divided by TARGET_RES_M and forced to an odd
# pixel count before filtering.
RELIEF_WINDOWS_M: List[int] = [90, 450, 1500]
WAVELET: str = "db2"
# One roughness raster is produced per level in 1..WAVELET_LEVELS.
WAVELET_LEVELS: int = 3
# Divided by TARGET_RES_M to get the sigma handed to the WhiteboxTools gaussian filter.
GAUSSIAN_SIGMA_M: float = 60.0
MIN_SLOPE_DEG: float = 0.1
TDA_N_SAMPLES: int = 5000  # Number of pixels to sample for TDA calculation

# --- Inputs (from previous notebooks) ---
DEM_TILES_DIR = PROC_DIR / "dem_1arcsec_tiles"
AOI_GPKG_PATH = PROC_DIR / "study_area_provinces.gpkg"

AOI_LAYER = "provinces"
# Sorted so the tile list is deterministic; empty when the directory has no .tif files.
DEM_TILES = sorted(DEM_TILES_DIR.glob("*.tif"))

# --- Outputs ---
COVARIATE_OUT_DIR = PROC_DIR / "covariates"
# Side effect at import/run time: the output directory is created if missing.
COVARIATE_OUT_DIR.mkdir(parents=True, exist_ok=True)
ALIGNED_DEM_PATH = COVARIATE_OUT_DIR / "dem_aligned.tif"
VRT_PATH = COVARIATE_OUT_DIR / "source_dem.vrt"
SLOPE_PATH = COVARIATE_OUT_DIR / "slope_deg.tif"
# Keyed by window width in metres.
RELIEF_PATHS = {w: COVARIATE_OUT_DIR / f"relief_{w}m.tif" for w in RELIEF_WINDOWS_M}
FILLED_PATH = COVARIATE_OUT_DIR / "dem_filled.tif"
SMOOTH_PATH = COVARIATE_OUT_DIR / "dem_smooth_for_hydro.tif"
TWI_PATH = COVARIATE_OUT_DIR / "twi.tif"
SPI_PATH = COVARIATE_OUT_DIR / "spi.tif"
# Keyed by wavelet level, starting at 1 (there is no key 0).
ROUGH_PATHS = {
    lvl: COVARIATE_OUT_DIR / f"rough_{WAVELET}_L{lvl}.tif"
    for lvl in range(1, WAVELET_LEVELS + 1)
}
TDA_PIT_PATH = COVARIATE_OUT_DIR / "tda_pit_persistence.tif"
TDA_RIDGE_PATH = COVARIATE_OUT_DIR / "tda_ridge_persistence.tif"

# --- Centralized Tiling Configuration for GPU ---
# Tile edge length and overlap, both in pixels (reported in the setup summary).
TILE_SIZE = 4096
HALO = 256

# --- Dask & Parallel Processing Configuration ---
# psutil.cpu_count(logical=False) returns None when the physical core count
# cannot be determined; fall back to logical cores, then to a single worker,
# so the divisions below never see None.
PHYSICAL_CORES = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) or 1
AVAILABLE_MEMORY_BYTES = psutil.virtual_memory().available
# Leave 20% of the currently available memory to the OS and the notebook kernel.
MEMORY_FOR_DASK_BYTES = AVAILABLE_MEMORY_BYTES * 0.80
MEMORY_LIMIT_PER_WORKER = MEMORY_FOR_DASK_BYTES / PHYSICAL_CORES
N_WORKERS = PHYSICAL_CORES
SAFETY_FACTOR = 12
BYTES_PER_PIXEL = np.dtype('float32').itemsize
# A chunk is budgeted at 1/SAFETY_FACTOR of a worker's memory limit.
target_chunk_bytes = MEMORY_LIMIT_PER_WORKER / SAFETY_FACTOR
pixels_per_chunk = target_chunk_bytes / BYTES_PER_PIXEL

# Re-running this cell (or running another setup cell) would otherwise leave
# the previous cluster's worker processes alive.
for _stale_name in ("client", "cluster"):
    if _stale_name in globals():
        globals()[_stale_name].close()

# One single-threaded process per core, each capped at its share of memory.
cluster = LocalCluster(
    n_workers=N_WORKERS,
    threads_per_worker=1,
    memory_limit=MEMORY_LIMIT_PER_WORKER
)
client = Client(cluster)

# --- Print Setup Summary ---
# Human-readable recap of the resolved paths, cluster sizing and optional
# dependency flags computed above; nothing here changes state.
print("\n--- Configuration Summary ---")
print(f"Project Root:        {PROJECT_ROOT}")
print(f"Found {len(DEM_TILES)} DEM tiles.")
print(f"Output Covariate Dir:  {COVARIATE_OUT_DIR}")
print("-" * 29)
print(f"Dask Dashboard:      {client.dashboard_link}")
# NOTE(review): bytes2human lives in psutil's private `_common` module.
print(f"Dask Cluster Size:   {N_WORKERS} workers @ {psutil._common.bytes2human(MEMORY_LIMIT_PER_WORKER)} each")
print(f"Optional deps -> WhiteboxTools: {_HAS_WBT} | PyWavelets: {_HAS_PYWT} | Scikit-TDA: {_HAS_SCIKIT_TDA}")
print("-" * 29)
print(f"GPU Available:       {_HAS_CUDA}")
print(f"GPU Tile Size:       {TILE_SIZE}x{TILE_SIZE} with {HALO}px halo")
print("-----------------------------")

# Any False flag means at least one optional covariate family will be skipped.
if not all([_HAS_WBT, _HAS_PYWT, _HAS_SCIKIT_TDA, _HAS_CUDA]):
    print("\nSome packages missing. Consider adding to environment.yml and rebuilding.")

print("\nSetup complete! Ready for geospatial analysis.")

In [None]:
# === 1. Dynamic Resource Allocation & Dask Setup ===

# --- Core Libraries ---
import os
import sys
import json
import warnings
import datetime as dt
from pathlib import Path
from typing import List, Dict, Optional
import time

import dask
import dask.array as da
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster
from dask import delayed

import numpy as np
import rasterio
from rasterio.windows import Window
import rioxarray as rxr
import xarray as xr
import geopandas as gpd
import psutil
import math
import subprocess
from scipy.ndimage import maximum_filter, minimum_filter, generic_filter
from skimage.transform import resize

# Probe for the optional GPU stack. _HAS_CUDA ends up True only when the
# libraries import AND device 0 can actually be queried.
try:
    import cupy as cp
    import cupyx.scipy.ndimage as ndimage_gpu
    from rmm.allocators.cupy import rmm_cupy_allocator
except ImportError:
    _HAS_CUDA = False
    print("❌ CUDA libraries (cupy, rmm) not found. GPU acceleration is disabled.")
else:
    try:
        cp.cuda.set_allocator(rmm_cupy_allocator)
        # Get properties of the first GPU
        gpu = cp.cuda.Device(0)
        # cupy.cuda.Device exposes no `name` attribute; the device name comes
        # from the runtime properties dict (returned as bytes).
        gpu_name = cp.cuda.runtime.getDeviceProperties(gpu.id)["name"]
        if isinstance(gpu_name, bytes):
            gpu_name = gpu_name.decode("utf-8", errors="replace")
        # mem_info is indexed as [0] = free, [1] = total.
        gpu_free_bytes, gpu_total_bytes = gpu.mem_info
    except Exception as e:
        # Capability probe boundary: libraries installed but no usable device
        # (or a driver/runtime mismatch) must disable the GPU path, not crash
        # the setup cell.
        _HAS_CUDA = False
        print(f"❌ CUDA libraries imported, but no usable GPU was found: {e}")
    else:
        _HAS_CUDA = True
        print("✅ CUDA libraries found. GPU acceleration is available.")
        print(f"   GPU Name: {gpu_name}")
        print(f"   GPU Memory: {gpu_total_bytes / 1024**3:.2f} GB total, {gpu_free_bytes / 1024**3:.2f} GB free")

# --- Optional Heavy Dependencies (Checked at Runtime) ---
try:
    from whitebox.whitebox_tools import WhiteboxTools
    _HAS_WBT = True
except ImportError:
    _HAS_WBT = False
try:
    import pywt
    _HAS_PYWT = True
except ImportError:
    _HAS_PYWT = False
# Scikit-TDA (ripser + persim) availability check
try:
    import ripser
    import persim
    _HAS_SCIKIT_TDA = True
except ImportError:
    _HAS_SCIKIT_TDA = False

# --- Robust Path Resolution ---
def find_project_root(marker='README.md'):
    """Find the project root by searching upwards for a marker file.

    Starts at the resolved current working directory and checks it and each
    of its ancestors in turn, the filesystem root included.

    Args:
        marker: Name of the file or directory that marks the project root.

    Returns:
        The nearest directory on the upward path that contains ``marker``.

    Raises:
        FileNotFoundError: If ``marker`` is not found in any directory on
            the path.
    """
    here = Path.cwd().resolve()
    # Path.parents runs up to and includes the filesystem root, which a
    # `while path.parent != path` loop would exit before checking.
    for directory in [here, *here.parents]:
        if (directory / marker).exists():
            return directory
    raise FileNotFoundError(f"Project root with marker '{marker}' not found.")

# Project-relative directories: everything below hangs off <root>/data/processed.
PROJECT_ROOT = find_project_root()
DATA_DIR = PROJECT_ROOT / "data"
PROC_DIR = DATA_DIR / "processed"

# --- Analysis Grid & CRS ---
TARGET_CRS = "EPSG:5070"  # NAD83 / Conus Albers
TARGET_RES_M = 30.0

# --- Covariate Parameters (Centralized Configuration) ---
# Relief window widths; each is divided by TARGET_RES_M and forced to an odd
# pixel count before filtering.
RELIEF_WINDOWS_M: List[int] = [90, 450, 1500]
WAVELET: str = "db2"
# One roughness raster is produced per level in 1..WAVELET_LEVELS.
WAVELET_LEVELS: int = 3
# Divided by TARGET_RES_M to get the sigma handed to the WhiteboxTools gaussian filter.
GAUSSIAN_SIGMA_M: float = 60.0
MIN_SLOPE_DEG: float = 0.1
# Upper bound on the valid pixels per chunk passed to the TDA helpers.
TDA_N_SAMPLES: int = 5000 # Number of pixels to sample for TDA calculation

# --- Inputs (from previous notebooks) ---
DEM_TILES_DIR = PROC_DIR / "dem_1arcsec_tiles"
AOI_GPKG_PATH = PROC_DIR / "study_area_provinces.gpkg"

AOI_LAYER = "provinces"
# Sorted so the tile list is deterministic; empty when the directory has no .tif files.
DEM_TILES = sorted(DEM_TILES_DIR.glob("*.tif"))

# --- Outputs (All defined here to avoid NameError later) ---
COVARIATE_OUT_DIR = PROC_DIR / "covariates"
# Side effect at run time: the output directory is created if missing.
COVARIATE_OUT_DIR.mkdir(parents=True, exist_ok=True)
ALIGNED_DEM_PATH = COVARIATE_OUT_DIR / "dem_aligned.tif"
VRT_PATH = COVARIATE_OUT_DIR / "source_dem.vrt"
SLOPE_PATH = COVARIATE_OUT_DIR / "slope_deg.tif"
# Keyed by window width in metres.
RELIEF_PATHS = {w: COVARIATE_OUT_DIR / f"relief_{w}m.tif" for w in RELIEF_WINDOWS_M}
FILLED_PATH = COVARIATE_OUT_DIR / "dem_filled.tif"
SMOOTH_PATH = COVARIATE_OUT_DIR / "dem_smooth_for_hydro.tif"
TWI_PATH = COVARIATE_OUT_DIR / "twi.tif"
SPI_PATH = COVARIATE_OUT_DIR / "spi.tif"
# Keyed by wavelet level, starting at 1 (there is no key 0).
ROUGH_PATHS = {
    lvl: COVARIATE_OUT_DIR / f"rough_{WAVELET}_L{lvl}.tif"
    for lvl in range(1, WAVELET_LEVELS + 1)
}
TDA_PIT_PATH = COVARIATE_OUT_DIR / "tda_pit_persistence.tif"
TDA_RIDGE_PATH = COVARIATE_OUT_DIR / "tda_ridge_persistence.tif"

# --- Centralized Tiling Configuration for GPU ---
# Define a tile size that will comfortably fit in GPU memory with overhead
# 4096x4096 is a good starting point for GPUs with >12GB VRAM. Adjust if needed.
TILE_SIZE = 4096
# Define overlap/halo for focal operations to avoid edge artifacts
HALO = 256

# --- Dask & Parallel Processing Configuration ---
# MEMORY OPTIMIZATION: Conservative settings for 64GB usable memory
# psutil.cpu_count(logical=False) returns None when the physical core count
# cannot be determined; fall back to logical cores, then to 1, so the integer
# division below never sees None.
PHYSICAL_CORES = psutil.cpu_count(logical=False) or psutil.cpu_count(logical=True) or 1
TOTAL_USABLE_MEMORY_GB = 64  # Set to your safe memory limit
# Never budget more than the machine physically has, whatever is configured.
TOTAL_USABLE_MEMORY_BYTES = min(
    TOTAL_USABLE_MEMORY_GB * 1024**3,
    psutil.virtual_memory().total,
)

# Use only 50% of cores to reduce memory pressure (but never fewer than 2 workers)
N_WORKERS = max(PHYSICAL_CORES // 2, 2)

# Reserve memory for OS and other processes
MEMORY_FOR_DASK_BYTES = TOTAL_USABLE_MEMORY_BYTES * 0.70  # 70% of 64GB = ~45GB
MEMORY_LIMIT_PER_WORKER = MEMORY_FOR_DASK_BYTES / N_WORKERS

# More conservative safety factor for large rasters
SAFETY_FACTOR = 20  # Increased from 12 to reduce chunk sizes
BYTES_PER_PIXEL = np.dtype('float32').itemsize
# A chunk is budgeted at 1/SAFETY_FACTOR of a worker's memory limit.
target_chunk_bytes = MEMORY_LIMIT_PER_WORKER / SAFETY_FACTOR
pixels_per_chunk = target_chunk_bytes / BYTES_PER_PIXEL

# Re-running this cell (or running it after another setup cell) would
# otherwise leave the previous cluster's worker processes alive.
for _stale_name in ("client", "cluster"):
    if _stale_name in globals():
        globals()[_stale_name].close()

# Single-threaded worker processes, each capped at its share of the budget.
cluster = LocalCluster(
    n_workers=N_WORKERS,
    threads_per_worker=1,
    memory_limit=MEMORY_LIMIT_PER_WORKER
)
client = Client(cluster)

# --- Print Setup Summary ---
# Human-readable recap of the resolved paths, cluster sizing and optional
# dependency flags computed above; nothing here changes state.
print("--- Configuration Summary ---")
print(f"Project Root:        {PROJECT_ROOT}")
print(f"Found {len(DEM_TILES)} DEM tiles.")
print(f"Output Covariate Dir:  {COVARIATE_OUT_DIR}")
print("-" * 29)
print(f"Dask Dashboard:      {client.dashboard_link}")
# NOTE(review): bytes2human lives in psutil's private `_common` module.
print(f"Dask Cluster Size:   {N_WORKERS} workers @ {psutil._common.bytes2human(MEMORY_LIMIT_PER_WORKER)} each")
# Shows the configured ceiling next to the amount actually handed to Dask.
print(f"Total Memory Budget: {TOTAL_USABLE_MEMORY_GB}GB (usable), {MEMORY_FOR_DASK_BYTES / 1024**3:.1f}GB (for Dask)")
print(f"Optional deps -> WhiteboxTools: {_HAS_WBT} | PyWavelets: {_HAS_PYWT} | Scikit-TDA: {_HAS_SCIKIT_TDA}")
print("-" * 29)
print(f"GPU Available:       {_HAS_CUDA}")
print(f"GPU Tile Size:       {TILE_SIZE}x{TILE_SIZE} with {HALO}px halo")
print("-----------------------------\n")

In [2]:
# CV-00b (optional): quick scanners to help locate DEMs
def scan_for_geotiffs(root=".", max_print=25):
    """Recursively list raster-like paths (.tif, .tiff, .vrt) under ``root``.

    Suffix matching is case-insensitive. A summary line and up to
    ``max_print`` paths are printed.

    Args:
        root: Directory to search.
        max_print: Maximum number of paths to echo before truncating.

    Returns:
        A list of matching ``Path`` objects, or an empty list when ``root``
        does not exist.
    """
    search_dir = Path(root)
    if not search_dir.exists():
        print(f"scan_for_geotiffs: '{search_dir}' does not exist.")
        return []

    raster_suffixes = (".tif", ".tiff", ".vrt")
    found = []
    for candidate in search_dir.rglob("*"):
        if candidate.suffix.lower() in raster_suffixes:
            found.append(candidate)

    print(f"Found {len(found)} candidate DEM file(s) under '{search_dir}'.")
    preview = found[:max_print]
    for shown_path in preview:
        print(" -", shown_path)
    if max_print < len(found):
        print(" ... (truncated)")
    return found

# Examples:
# scan_for_geotiffs("data")
# scan_for_geotiffs("D:/")


In [None]:
# === 2. Build Aligned Analysis-Ready DEM ===
# This cell creates a single, analysis-ready DEM by:
#  1. Building a VRT from the downloaded tiles for efficient access.
#  2. Warping (reprojecting, clipping, resampling) the VRT to the final analysis grid.

# Build the aligned DEM only when it is not already on disk.
if not ALIGNED_DEM_PATH.exists():
    if not DEM_TILES:
        raise FileNotFoundError(f"No DEM tiles found in {DEM_TILES_DIR}. Please run the previous notebook.")

    # 1. Build VRT from all source DEM tiles
    print(f"Building VRT from {len(DEM_TILES)} tiles -> {VRT_PATH}")
    # The tile list goes through a file so the command line stays short no
    # matter how many tiles there are.
    input_list_path = COVARIATE_OUT_DIR / "gdal_input_file_list.txt"
    with open(input_list_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{tile_path}\n" for tile_path in DEM_TILES)

    gdalbuildvrt_command = [
        'gdalbuildvrt', '-input_file_list', str(input_list_path), str(VRT_PATH)
    ]

    try:
        subprocess.run(gdalbuildvrt_command, check=True, capture_output=True, text=True)
        print("...VRT built successfully.")
    except subprocess.CalledProcessError as e:
        print("--- GDAL VRT ERROR ---")
        print(e.stderr)
        raise
    finally:
        # The list file is only needed while gdalbuildvrt runs.
        if input_list_path.exists():
            input_list_path.unlink()

    # 2. Use gdalwarp for memory-efficient reprojection, clipping, and resampling.
    print(f"Warping VRT to target grid using gdalwarp: {TARGET_CRS} @ {TARGET_RES_M}m resolution.")

    # Load the Area of Interest to get the target bounds
    aoi_gdf = gpd.read_file(AOI_GPKG_PATH, layer=AOI_LAYER).to_crs(TARGET_CRS)
    bounds = aoi_gdf.total_bounds

    # gdalwarp writes to a scratch file that is renamed only after success.
    # Writing straight to ALIGNED_DEM_PATH would leave a truncated raster
    # behind on failure, which the `exists()` check above would then accept
    # as a finished DEM on the next run.
    partial_dem_path = ALIGNED_DEM_PATH.with_name(
        f"{ALIGNED_DEM_PATH.stem}.partial{ALIGNED_DEM_PATH.suffix}"
    )

    # Build the gdalwarp command with memory-efficient settings
    gdalwarp_command = [
        'gdalwarp',
        '-t_srs', TARGET_CRS,       # Target CRS
        '-tr', str(TARGET_RES_M), str(TARGET_RES_M), # Target resolution
        '-te', str(bounds[0]), str(bounds[1]), str(bounds[2]), str(bounds[3]), # Target extent
        '-r', 'bilinear',           # Resampling method
        '-co', 'COMPRESS=LZW',      # Output compression
        '-co', 'TILED=YES',         # Output tiling
        '-co', 'BLOCKXSIZE=512',    # Smaller block size for memory efficiency
        '-co', 'BLOCKYSIZE=512',    # Smaller block size for memory efficiency
        '-co', 'BIGTIFF=IF_SAFER',
        '-wm', '2048',              # Limit warp memory to 2GB
        '-multi',                   # Use multithreading
        '-wo', 'NUM_THREADS=ALL_CPUS',
        '-of', 'GTiff',             # Output format
        '-overwrite',               # Overwrite if exists
        str(VRT_PATH),              # Input VRT
        str(partial_dem_path)       # Scratch output, renamed on success
    ]

    try:
        # This is a long-running process, so output is not captured.
        # This allows GDAL's progress bar to print to the console.
        print("Writing final aligned DEM (this may take several minutes)...")
        subprocess.run(gdalwarp_command, check=True)
    except subprocess.CalledProcessError as e:
        print("--- GDAL WARP ERROR ---")
        # stderr was not captured (it went to the console), so report the
        # exit status instead of printing `None`.
        print(f"gdalwarp exited with status {e.returncode}; see the console output above.")
        if partial_dem_path.exists():
            partial_dem_path.unlink()
        raise
    partial_dem_path.replace(ALIGNED_DEM_PATH)
    print("...Warping complete.")
else:
    print(f"Using existing analysis-ready DEM: {ALIGNED_DEM_PATH}")

# --- OPTIMIZED CHUNK CALCULATION ---
# Picks a square Dask chunk size for the aligned DEM and then opens the DEM
# lazily with that chunking. Overwrites the `target_chunk_bytes` and
# `pixels_per_chunk` values computed in the setup cell.
print("Calculating optimal chunk size for Dask (memory-efficient)...")
with rasterio.open(ALIGNED_DEM_PATH) as src:
    # Get native block size from the file
    # (only reported below; the alignment itself uses the ALIGNMENT constant)
    block_y, block_x = src.block_shapes[0]
    print(f"Native block size: {block_y} x {block_x}")

    # Calculate chunk size based on conservative memory limits
    # Target: each chunk should be ~200MB to allow for intermediate operations
    TARGET_CHUNK_MB = 200
    target_chunk_bytes = TARGET_CHUNK_MB * 1024 * 1024
    pixels_per_chunk = target_chunk_bytes / BYTES_PER_PIXEL
    # Side length of a square chunk holding that many pixels.
    chunk_dim_raw = int(math.sqrt(pixels_per_chunk))

    # Align to block size for efficient I/O
    ALIGNMENT = 512  # Standard GeoTIFF block size
    # Round to the nearest multiple of ALIGNMENT, then clamp to [ALIGNMENT, 4096].
    # With a 200 MB target and 4-byte pixels the raw side is ~7240, so the
    # 4096 cap is what determines the final value.
    chunk_dim_aligned = int(round(chunk_dim_raw / ALIGNMENT) * ALIGNMENT)
    chunk_dim_aligned = max(chunk_dim_aligned, ALIGNMENT)
    chunk_dim_aligned = min(chunk_dim_aligned, 4096)  # Cap at 4096 for safety

    OPTIMAL_CHUNKS = {'y': chunk_dim_aligned, 'x': chunk_dim_aligned}
    chunk_memory_mb = (chunk_dim_aligned ** 2 * BYTES_PER_PIXEL) / (1024 * 1024)
    print(f"Calculated Optimal Chunk Size: {OPTIMAL_CHUNKS}")
    print(f"Per-chunk memory: {chunk_memory_mb:.1f} MB")
    # Total Dask memory budget divided by the size of one chunk.
    print(f"Max concurrent chunks with {N_WORKERS} workers: ~{int(MEMORY_FOR_DASK_BYTES / (chunk_memory_mb * 1024 * 1024))}")
# --- END OPTIMIZED CHUNK CALCULATION ---

# Load the final, aligned DEM as a Dask-backed DataArray for all subsequent steps
# Use smaller chunks to prevent memory overflow
# The "band" dimension is dropped, leaving the (y, x) dimensions used downstream.
dem_aligned = rxr.open_rasterio(ALIGNED_DEM_PATH, chunks=OPTIMAL_CHUNKS).squeeze("band", drop=True)
print("\nFinal Analysis-Ready DEM:")
print(dem_aligned)
print(f"\nEstimated memory if fully loaded: {dem_aligned.nbytes / (1024**3):.2f} GB")
# Chunk count = (number of chunks along y) * (number of chunks along x).
print(f"Number of chunks: {len(dem_aligned.data.chunks[0]) * len(dem_aligned.data.chunks[1])}")

In [4]:
# === TDA helper functions ===

import numpy as np
import warnings

# ripser/persim are optional: the setup cell records their availability in
# _HAS_SCIKIT_TDA, and every TDA stage is gated on that flag. Importing them
# unconditionally here would turn the optional dependency into a hard one.
if _HAS_SCIKIT_TDA:
    from ripser import ripser
    from persim.landscapes import PersLandscapeExact

def calculate_tda_pit_persistence(dem_chunk: np.ndarray, n_samples: int) -> np.ndarray:
    """Summarise one DEM chunk as a single H0 persistence value.

    The chunk's non-NaN elevations are treated as a set of 1-D points,
    thinned by regular striding to roughly ``n_samples`` points, and passed
    to ``ripser`` with ``maxdim=0``. The finite H0 pairs are summarised by
    the 1-norm of an exact persistence landscape.

    Args:
        dem_chunk: 2-D elevation block; NaN marks missing pixels.
        n_samples: Approximate upper bound on the number of points used.

    Returns:
        A ``(1, 1)`` float32 array (one cell per input block, as required
        by the ``chunks`` passed to ``dask.array.map_blocks``). The value is
        NaN when the block is all-NaN, smaller than 10 pixels on either
        axis, has fewer than 10 valid pixels, or the computation raises
        ``ValueError``/``RuntimeError``; it is 0.0 when there are no finite
        H0 pairs.
    """
    def _as_cell(value):
        # Every exit path must produce the same (1, 1) float32 shape.
        return np.array([[value]], dtype=np.float32)

    missing = np.isnan(dem_chunk)
    if np.all(missing) or dem_chunk.shape[0] < 10 or dem_chunk.shape[1] < 10:
        return _as_cell(np.nan)

    # Column vector of valid elevations: one 1-D point per pixel.
    points = dem_chunk[~missing].reshape(-1, 1)
    n_valid = points.shape[0]
    if n_valid < 10:
        return _as_cell(np.nan)

    # Deterministic thinning: keep every `stride`-th valid pixel.
    if n_valid > n_samples:
        stride = int(np.ceil(n_valid / n_samples))
        points = points[::stride, :]

    try:
        diagrams = ripser(points, maxdim=0)['dgms']
        if len(diagrams) == 0 or diagrams[0].shape[0] == 0:
            return _as_cell(0.0)

        h0_pairs = diagrams[0]
        # Drop pairs whose death time is infinite.
        finite_pairs = h0_pairs[np.isfinite(h0_pairs[:, 1])]
        if finite_pairs.shape[0] == 0:
            return _as_cell(0.0)

        landscape = PersLandscapeExact(dgms=[finite_pairs], hom_deg=0)
        summary = landscape.p_norm(p=1)
    except (ValueError, RuntimeError) as e:
        warnings.warn(f"TDA calculation failed on a chunk and will be replaced with NaN. Error: {e}")
        summary = np.nan

    return _as_cell(summary)


def calculate_tda_ridge_persistence(dem_chunk: np.ndarray, n_samples: int) -> np.ndarray:
    """Compute the pit-persistence summary of the sign-flipped DEM chunk.

    Returns the same ``(1, 1)`` float32 array shape as
    ``calculate_tda_pit_persistence``.
    """
    # NOTE(review): the pit metric is built from a point cloud of elevation
    # values; if ripser only uses pairwise distances between those values,
    # negating them leaves the result unchanged and this equals the pit
    # metric. Confirm this is the intended ridge measure.
    return calculate_tda_pit_persistence(dem_chunk * -1.0, n_samples=n_samples)

In [5]:
# === CELL 5: STAGE 1 - GENERATE INTERMEDIATE TDA RASTERS ===
# Builds (but does not execute) lazy write tasks that reduce every Dask chunk
# of the aligned DEM to one pixel of a low-resolution TDA summary raster.

print("--- STAGE 1: Generating low-resolution TDA summary rasters ---")
tda_intermediate_tasks = []

if _HAS_SCIKIT_TDA:
    TDA_PIT_LOW_RES_PATH = COVARIATE_OUT_DIR / "tda_pit_persistence_low_res.tif"
    TDA_RIDGE_LOW_RES_PATH = COVARIATE_OUT_DIR / "tda_ridge_persistence_low_res.tif"

    # Define the explicit output chunking scheme
    # (each input chunk maps to a 1x1 output chunk)
    num_y_blocks = len(dem_aligned.chunksizes['y'])
    num_x_blocks = len(dem_aligned.chunksizes['x'])
    output_chunks = ((1,) * num_y_blocks, (1,) * num_x_blocks)

    # Calculate coords once to reuse
    chunk_sizes_y = dem_aligned.chunksizes['y']
    chunk_sizes_x = dem_aligned.chunksizes['x']
    # cumsum gives each chunk's end index; subtracting half the chunk size
    # gives the index of its centre pixel.
    y_mid_indices = np.cumsum(chunk_sizes_y) - np.array(chunk_sizes_y) / 2
    x_mid_indices = np.cumsum(chunk_sizes_x) - np.array(chunk_sizes_x) / 2
    # NOTE(review): when the last chunk along an axis is smaller than the
    # others these centres are not evenly spaced - confirm the raster
    # written from them gets the intended georeferencing.
    y_coords = dem_aligned.y.isel(y=y_mid_indices.astype(int)).data
    x_coords = dem_aligned.x.isel(x=x_mid_indices.astype(int)).data

    # --- Task for TDA Pit Persistence ---
    if not TDA_PIT_LOW_RES_PATH.exists():
        print(f"GRAPHING: Low-res TDA Pit Persistence -> {TDA_PIT_LOW_RES_PATH.name}")
        tda_low_res_dask = da.map_blocks(
            calculate_tda_pit_persistence, dem_aligned.data,
            n_samples=TDA_N_SAMPLES, dtype=np.float32, chunks=output_chunks
        )
        tda_low_res_da = xr.DataArray(
            tda_low_res_dask, dims=('y', 'x'), coords={'y': y_coords, 'x': x_coords}
        )
        # Carry the DEM's CRS over to the summary raster.
        tda_low_res_da.rio.write_crs(dem_aligned.rio.crs, inplace=True)

        # Get the lazy write task object by passing compute=False
        pit_write_task = tda_low_res_da.rio.to_raster(
            TDA_PIT_LOW_RES_PATH, tiled=True, lock=True, compress='LZW', windowed=True, compute=False
        )
        tda_intermediate_tasks.append(pit_write_task)
    else:
        print(f"EXISTS: Low-res TDA Pit Persistence")

    # --- Task for TDA Ridge Persistence ---
    # Same pipeline as the pit task, using the ridge helper.
    if not TDA_RIDGE_LOW_RES_PATH.exists():
        print(f"GRAPHING: Low-res TDA Ridge Persistence -> {TDA_RIDGE_LOW_RES_PATH.name}")
        tda_low_res_ridge_dask = da.map_blocks(
            calculate_tda_ridge_persistence, dem_aligned.data,
            n_samples=TDA_N_SAMPLES, dtype=np.float32, chunks=output_chunks
        )
        tda_low_res_ridge_da = xr.DataArray(
            tda_low_res_ridge_dask, dims=('y', 'x'), coords={'y': y_coords, 'x': x_coords}
        )
        tda_low_res_ridge_da.rio.write_crs(dem_aligned.rio.crs, inplace=True)

        ridge_write_task = tda_low_res_ridge_da.rio.to_raster(
            TDA_RIDGE_LOW_RES_PATH, tiled=True, lock=True, compress='LZW', windowed=True, compute=False
        )
        tda_intermediate_tasks.append(ridge_write_task)
    else:
        print(f"EXISTS: Low-res TDA Ridge Persistence")

# --- Execute only the TDA intermediate tasks ---
if tda_intermediate_tasks:
    from dask.distributed import progress

    print(f"\nExecuting {len(tda_intermediate_tasks)} TDA intermediate tasks...")
    # dask.diagnostics.ProgressBar only reports for the local schedulers; with
    # the distributed Client created in the setup cell it shows nothing.
    # Submit through the client and use distributed's own progress display.
    stage1_futures = client.compute(tda_intermediate_tasks)
    progress(stage1_futures)
    # gather() blocks until the writes finish and re-raises any worker error.
    client.gather(stage1_futures)
    print("...Stage 1 complete.")
elif not _HAS_SCIKIT_TDA:
    # No tasks were built because the optional dependency is missing; saying
    # the files "already exist" would be wrong in that case.
    print("\nSKIPPING: TDA intermediate rasters (scikit-tda not found)")
else:
    print("\nAll TDA intermediate files already exist.")

--- STAGE 1: Generating low-resolution TDA summary rasters ---
GRAPHING: Low-res TDA Pit Persistence -> tda_pit_persistence_low_res.tif
GRAPHING: Low-res TDA Ridge Persistence -> tda_ridge_persistence_low_res.tif

Executing 2 TDA intermediate tasks...
...Stage 1 complete.


In [None]:
# === CELL 6: STAGE 2 - GENERATE ALL FINAL COVARIATES ===
import numpy as np
# PyWavelets is optional: the setup cell records its availability in
# _HAS_PYWT and the wavelet section below is skipped when it is False, so
# the import must not be unconditional.
if _HAS_PYWT:
    import pywt
from scipy.ndimage import generic_filter
from skimage.transform import resize

print("\n--- STAGE 2: Generating all final, full-resolution covariates ---")
# Lazy Dask write tasks collected here are executed together at the end of the cell.
final_write_tasks = []
wbt_tasks = [] # Keep wbt tasks separate as they are handled differently

# --- Helper: wavelet-based roughness for a single DEM block ---
def calculate_wavelet_roughness(dem_chunk: np.ndarray, wavelet_name: str, level: int) -> np.ndarray:
    """Compute a wavelet roughness surface for one DEM block.

    The block is reflect-padded, decomposed with a 2-D discrete wavelet
    transform to ``level`` levels, and the magnitude of the detail
    coefficients selected by ``coeffs[-level]`` is taken as a local energy.
    Roughness is the standard deviation of that energy in a 3x3
    neighbourhood, linearly resized back to the padded block size and then
    cropped to the input shape.

    Args:
        dem_chunk: 2-D elevation block.
        wavelet_name: Wavelet name passed to ``pywt.wavedec2``.
        level: Decomposition depth.

    Returns:
        A float32 array with the same shape as ``dem_chunk``. An all-NaN
        block is returned unchanged apart from the float32 cast.
    """
    if np.isnan(dem_chunk).all():
        return dem_chunk.astype(np.float32)

    # Reflect-pad so the transform has context at the block edges.
    margin = 16
    padded = np.pad(dem_chunk, pad_width=margin, mode='reflect')

    decomposition = pywt.wavedec2(padded, wavelet=wavelet_name, level=level, mode='symmetric')
    # Horizontal, vertical and diagonal detail bands at the selected scale.
    horizontal, vertical, diagonal = decomposition[-level]

    detail_energy = np.sqrt(horizontal**2 + vertical**2 + diagonal**2)

    # Local variability of the energy is the roughness measure.
    local_std = generic_filter(detail_energy, np.std, size=3, mode='reflect')

    # order=1 is linear interpolation; preserve_range keeps the original units.
    upsampled = resize(
        local_std,
        padded.shape,
        order=1,
        preserve_range=True,
        anti_aliasing=True,
    )
    upsampled = upsampled.astype(np.float32)

    # Strip the padding again so the output lines up with the input block.
    return upsampled[margin:-margin, margin:-margin]

print("--- Defining Final Covariate Generation Tasks ---")

# --- Tier 1: Foundational Covariates ---
# Each section below only builds a lazy write task (compute=False) and skips
# outputs that already exist on disk.
if not SLOPE_PATH.exists():
    print(f"GRAPHING: Slope -> {SLOPE_PATH.name}")
    # NOTE(review): `xrs_slope` is not defined or imported in the visible
    # cells - presumably a slope function imported elsewhere; confirm.
    slope_deg = xrs_slope(dem_aligned, name='slope_deg')
    task = slope_deg.rio.to_raster(SLOPE_PATH, tiled=True, lock=True, compress='LZW', windowed=True, compute=False)
    final_write_tasks.append(task)
else: print(f"EXISTS: Slope")

# Local relief = moving-window max minus moving-window min, one raster per window width.
for w in RELIEF_WINDOWS_M:
    out_path = RELIEF_PATHS[w]
    if not out_path.exists():
        print(f"GRAPHING: Relief {w}m -> {out_path.name}")
        # Convert metres to pixels (at least 3) and bump even sizes to the next odd value.
        win_pix = max(3, round(w / TARGET_RES_M)); win_pix += (win_pix % 2 == 0)
        # Overlap of half a window so the filter sees across chunk borders.
        depth = win_pix // 2
        def relief_on_chunk(chunk, window_size):
            return maximum_filter(chunk, size=window_size, mode='reflect') - minimum_filter(chunk, size=window_size, mode='reflect')
        relief_dask_array = da.map_overlap(relief_on_chunk, dem_aligned.data, depth=depth, boundary='reflect', dtype=np.float32, window_size=win_pix)
        relief_da = xr.DataArray(relief_dask_array, coords=dem_aligned.coords, name=f"relief_{w}m")
        task = relief_da.rio.to_raster(out_path, tiled=True, lock=True, compress='LZW', windowed=True, compute=False)
        final_write_tasks.append(task)
    else: print(f"EXISTS: Relief {w}m")

# Wavelet roughness, one raster per decomposition level; skipped without PyWavelets.
if _HAS_PYWT:
    for lvl in range(1, WAVELET_LEVELS + 1):
        out_path = ROUGH_PATHS[lvl]
        if not out_path.exists():
            print(f"GRAPHING: Wavelet Roughness (L{lvl}) -> {out_path.name}")
            depth = 32 # Heuristic overlap for padding + filter
            roughness_dask_array = da.map_overlap(calculate_wavelet_roughness, dem_aligned.data, depth=depth, boundary='reflect', dtype=np.float32, wavelet_name=WAVELET, level=lvl)
            rough_xra = xr.DataArray(roughness_dask_array, coords=dem_aligned.coords, name=f"rough_L{lvl}")
            task = rough_xra.rio.to_raster(out_path, tiled=True, lock=True, compress='LZW', windowed=True, compute=False)
            final_write_tasks.append(task)
        else: print(f"EXISTS: Wavelet Roughness L{lvl}")
else: print("SKIPPING: Wavelet Roughness (pywt not found)")

# --- Tier 2: Process-Based Metrics (Executed Synchronously) ---
print("\n--- Executing Synchronous WhiteboxTools Processes ---")
if _HAS_WBT:
    wbt = WhiteboxTools()
    wbt.verbose = False

    # WhiteboxTools derives TWI and SPI from a specific-contributing-area
    # raster and a slope raster (there is no single-input TWI tool), so two
    # extra intermediates are produced from the smoothed DEM.
    SCA_PATH = COVARIATE_OUT_DIR / "sca_dinf.tif"
    HYDRO_SLOPE_PATH = COVARIATE_OUT_DIR / "slope_deg_for_hydro.tif"

    def _run_wbt_step(label, out_path, tool, **tool_args):
        """Run one WhiteboxTools step unless its output already exists."""
        if out_path.exists():
            return
        print(f"WBT: {label}...")
        tool(output=str(out_path), **tool_args)
        # With verbose output disabled a failed tool is easy to miss, and
        # each step depends on the previous one, so stop at the first
        # missing output rather than failing later with a confusing error.
        if not out_path.exists():
            raise RuntimeError(f"WhiteboxTools step '{label}' did not produce {out_path}")

    # The DEM needs to be filled and smoothed before TWI/SPI can be calculated.
    # These steps are run sequentially as they depend on each other.
    _run_wbt_step(
        "Breaching depressions", FILLED_PATH, wbt.breach_depressions_least_cost,
        dem=str(ALIGNED_DEM_PATH), dist=5, fill=True,
    )
    _run_wbt_step(
        "Smoothing DEM for hydro analysis", SMOOTH_PATH, wbt.gaussian_filter,
        i=str(FILLED_PATH), sigma=(GAUSSIAN_SIGMA_M / TARGET_RES_M),
    )
    _run_wbt_step(
        "Calculating specific contributing area (D-infinity)", SCA_PATH,
        wbt.d_inf_flow_accumulation,
        i=str(SMOOTH_PATH), out_type="sca",
    )
    _run_wbt_step(
        "Calculating slope for hydro indices", HYDRO_SLOPE_PATH, wbt.slope,
        dem=str(SMOOTH_PATH), units="degrees",
    )
    _run_wbt_step(
        "Calculating TWI", TWI_PATH, wbt.wetness_index,
        sca=str(SCA_PATH), slope=str(HYDRO_SLOPE_PATH),
    )
    _run_wbt_step(
        "Calculating SPI", SPI_PATH, wbt.stream_power_index,
        sca=str(SCA_PATH), slope=str(HYDRO_SLOPE_PATH), exponent=1.0,
    )
    print("...WhiteboxTools processing complete.")
else:
    print("SKIPPING: Tier 2 Metrics (whitebox-tools not found)")

# --- Tier 3: Upsampling TDA Summaries ---
if not _HAS_SCIKIT_TDA:
    print("SKIPPING: TDA Upsampling (scikit-tda not found)")
else:
    @dask.delayed
    def run_gdal_upsample(in_path, out_path, template_path):
        """Resample a low-res raster onto the template raster's grid with gdalwarp.

        The output takes its pixel size and extent from ``template_path``
        and is bilinearly interpolated from ``in_path``. Returns True both
        when the output already exists and after a successful warp; a
        failing gdalwarp call prints its stderr and re-raises.
        """
        # Nothing to do when the full-resolution raster is already on disk.
        if Path(out_path).exists():
            return True
        print(f"GDAL_WARP TASK: Upsampling {in_path.name} -> {out_path.name}")

        # Read the target grid definition from the template.
        with rasterio.open(template_path) as template:
            extent = template.bounds
            pixel_w, pixel_h = template.res

        resolution_args = ['-tr', str(pixel_w), str(abs(pixel_h))]
        extent_args = [
            '-te', str(extent.left), str(extent.bottom), str(extent.right), str(extent.top),
        ]
        creation_args = ['-co', 'COMPRESS=LZW', '-co', 'TILED=YES', '-co', 'BIGTIFF=IF_SAFER']
        warp_call = (
            ['gdalwarp', '-r', 'bilinear', '-multi']
            + resolution_args
            + extent_args
            + creation_args
            + ['-overwrite', str(in_path), str(out_path)]
        )
        try:
            subprocess.run(warp_call, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as warp_error:
            print(f"--- GDAL WARP (Upsample) ERROR for {out_path.name} ---\n{warp_error.stderr}")
            raise
        return True

    # Queue one delayed upsampling task per TDA summary raster.
    print("GRAPHING: TDA Upsampling tasks using gdalwarp...")
    task_pit_upsample = run_gdal_upsample(TDA_PIT_LOW_RES_PATH, TDA_PIT_PATH, ALIGNED_DEM_PATH)
    task_ridge_upsample = run_gdal_upsample(TDA_RIDGE_LOW_RES_PATH, TDA_RIDGE_PATH, ALIGNED_DEM_PATH)
    final_write_tasks.append(task_pit_upsample)
    final_write_tasks.append(task_ridge_upsample)

# --- Execute All Final Tasks ---
all_final_tasks = final_write_tasks
if all_final_tasks:
    from dask.distributed import progress

    print(f"\nExecuting a total of {len(all_final_tasks)} final tasks...")
    # dask.diagnostics.ProgressBar only reports for the local schedulers; with
    # the distributed Client created in the setup cell it shows nothing.
    # Submit through the client and use distributed's own progress display.
    stage2_futures = client.compute(all_final_tasks)
    progress(stage2_futures)
    # gather() blocks until every task finishes and re-raises worker errors.
    client.gather(stage2_futures)
    print("...Stage 2 complete. All covariates generated.")
else:
    print("\nAll final covariate layers already exist. Skipping generation.")


--- STAGE 2: Generating all final, full-resolution covariates ---
--- Defining Final Covariate Generation Tasks ---
GRAPHING: Slope -> slope_deg.tif
GRAPHING: Relief 90m -> relief_90m.tif
GRAPHING: Relief 450m -> relief_450m.tif
GRAPHING: Relief 1500m -> relief_1500m.tif
GRAPHING: Wavelet Roughness (L1) -> rough_db2_L1.tif
GRAPHING: Wavelet Roughness (L2) -> rough_db2_L2.tif
GRAPHING: Wavelet Roughness (L3) -> rough_db2_L3.tif

--- Executing Synchronous WhiteboxTools Processes ---
WBT: Breaching depressions...


In [None]:
# CV-08: quick sanity plots (optional); safe to skip in headless runs
import matplotlib.pyplot as plt

def quickplot(path: Path, title: str) -> None:
    """Display a raster's ``band``-squeezed data as an image with a colorbar.

    Args:
        path: Raster file opened with ``rxr.open_rasterio``.
        title: Text used as the axes title.
    """
    # Open the raster as a context manager so its file handle is released as
    # soon as the pixel values have been pulled out. The local is deliberately
    # not named ``da``: that is this notebook's alias for ``dask.array``.
    with rxr.open_rasterio(path) as raster:
        pixels = raster.squeeze("band", drop=True).data

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(pixels, origin="upper")
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    plt.show()

# Examples:
# quickplot(SLOPE_PATH, "Slope (deg)")
# quickplot(list(RELIEF_PATHS.values())[0], "Local relief (fine)")
# if _HAS_PYWT: quickplot(ROUGH_PATHS[1], "Wavelet roughness (L1)")
# if _HAS_WBT: quickplot(TWI_PATH, "Topographic Wetness Index (TWI)")
# if _HAS_WBT: quickplot(SPI_PATH, "Stream Power Index (SPI)")


In [None]:
# === 4. Final Manifest ===
# This cell writes a JSON manifest recording the configuration, inputs,
# output paths and library versions used for this run.

# Collect all output file paths for the manifest under descriptive keys.
output_files = {
    "dem_aligned": ALIGNED_DEM_PATH,
    "dem_vrt": VRT_PATH,
    "slope_deg": SLOPE_PATH,
}
output_files.update({f"relief_{w}m": p for w, p in RELIEF_PATHS.items()})

# Optional outputs are listed only when the corresponding library flag is set.
if _HAS_PYWT:
    output_files.update({
        f"roughness_{WAVELET}_L{level}": ROUGH_PATHS[level]
        for level in range(1, WAVELET_LEVELS + 1)
    })
if _HAS_WBT:
    output_files.update({
        "dem_filled": FILLED_PATH,
        "dem_smooth_for_hydro": SMOOTH_PATH,
        "twi": TWI_PATH,
        "spi": SPI_PATH,
    })
if _HAS_SCIKIT_TDA:
    # The upsampled TDA rasters are produced in Stage 2, so document them too.
    output_files.update({
        "tda_pit_persistence": TDA_PIT_PATH,
        "tda_ridge_persistence": TDA_RIDGE_PATH,
    })

# datetime.utcnow() is deprecated; take an aware UTC timestamp and swap the
# "+00:00" offset for "Z" so the string format matches earlier manifests.
created_utc = dt.datetime.now(dt.timezone.utc).isoformat().replace("+00:00", "Z")

manifest = {
    "created_utc": created_utc,
    "notebook": "1.4_covariate_process_layers.ipynb",
    "config": {
        "target_crs": TARGET_CRS,
        "target_resolution_m": TARGET_RES_M,
        "relief_windows_m": RELIEF_WINDOWS_M,
        "wavelet": WAVELET,
        "wavelet_levels": WAVELET_LEVELS,
        "gaussian_sigma_m": GAUSSIAN_SIGMA_M,
        "min_slope_deg": MIN_SLOPE_DEG,
        "dask_cluster": {
            "n_workers": N_WORKERS,
            # NOTE(review): psutil._common is a private module and may change
            # between psutil releases; kept so the recorded format is unchanged.
            "memory_limit_per_worker": str(psutil._common.bytes2human(MEMORY_LIMIT_PER_WORKER)),
        },
    },
    "inputs": {
        "dem_tiles_dir": str(DEM_TILES_DIR),
        "num_dem_tiles": len(DEM_TILES),
        "aoi_gpkg": str(AOI_GPKG_PATH),
        "aoi_layer": AOI_LAYER,
    },
    # Paths are stringified because json cannot serialize Path objects.
    "outputs": {k: str(v) for k, v in output_files.items()},
    "software": {
        "python_version": sys.version,
        "rasterio": rasterio.__version__,
        "xarray": xr.__version__,
        "rioxarray": rxr.__version__,
        "geopandas": gpd.__version__,
        "numpy": np.__version__,
        "pywavelets": (pywt.__version__ if _HAS_PYWT else "not found"),
        "whitebox": ("installed" if _HAS_WBT else "not found"),
    },
}

manifest_path = COVARIATE_OUT_DIR / "manifest_1_4_covariates.json"
# Explicit encoding keeps the file identical regardless of platform locale.
with open(manifest_path, "w", encoding="utf-8") as f:
    json.dump(manifest, f, indent=2)

print(f"Wrote manifest: {manifest_path}")

In [None]:
# === 5. (Optional) Final Output Validation ===
# This cell verifies that all expected files were created: each expected
# path must exist and be non-empty.

print("--- Verifying Output Files ---")

# Build the checklist with descriptive, collision-free labels. Merging
# RELIEF_PATHS and ROUGH_PATHS directly would reuse their bare keys, which
# could overwrite one another and print uninformative names.
all_output_paths = {f"relief_{w}m": p for w, p in RELIEF_PATHS.items()}
if _HAS_PYWT:
    # Only expect roughness rasters when PyWavelets was available, matching
    # how the manifest cell gates the same outputs.
    all_output_paths.update({
        f"roughness_{WAVELET}_L{level}": p for level, p in ROUGH_PATHS.items()
    })
all_output_paths["slope"] = SLOPE_PATH
if _HAS_WBT:
    all_output_paths.update({"twi": TWI_PATH, "spi": SPI_PATH})
if _HAS_SCIKIT_TDA:
    # Both upsampled TDA rasters are generated in Stage 2, so check both.
    all_output_paths.update({
        "tda_pit_persistence": TDA_PIT_PATH,
        "tda_ridge_persistence": TDA_RIDGE_PATH,
    })

success_count = 0
fail_count = 0

for name, path in all_output_paths.items():
    # A zero-byte file counts as a failure, not just a missing one.
    if path.exists() and path.stat().st_size > 0:
        print(f"  [ OK ] {name:<25} -> {path.name}")
        success_count += 1
    else:
        print(f"  [FAIL] {name:<25} -> FILE NOT FOUND OR EMPTY")
        fail_count += 1

print("-" * 30)
if fail_count == 0:
    print(f"✅ Success! All {success_count} expected files were generated.")
else:
    print(f"❌ Warning: {fail_count} file(s) are missing or empty. Please review logs.")