# Simple CostGrow-Style Fluvial Downscaling

Lightweight notebook version of the main `costGrow` flow:
1. Resample coarse WSE to DEM grid.
2. Keep wet partials where `WSE > DEM`.
3. Grow into dry cells with terrain-penalized least-cost paths (`skimage.graph.MCP_Geometric`).
4. Remove isolated wet regions not connected to phase-2 anchors.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import rasterio
from rasterio.warp import reproject, Resampling

from skimage.graph import MCP_Geometric
from skimage.measure import label


In [None]:
# --- Input rasters ---------------------------------------------------------
# High-resolution DEM; the output is written on this raster's grid.
input_hires_dem_fp = '/path/to/high_res_dem.tif'
# Coarse water-surface-elevation (WSE) raster that gets downscaled.
input_lores_wse_fp = '/path/to/low_res_wse.tif'

# --- Output raster ---------------------------------------------------------
# Saved in the same directory as the DEM, under a fixed file name.
output_wse_fp = str(
    Path(input_hires_dem_fp).with_name('wse_downscaled_costgrow_simple.tif')
)

# --- Controls --------------------------------------------------------------
# Growth limit, expressed in coarse-grid pixels; it is multiplied by the
# coarse/fine resolution ratio before being compared with fine-grid distances.
max_grow_coarse_pixels = 10
# Multiplier on |WSE - DEM| for the extra cost of cells lying above the water.
terrain_penalty_scale = 1.0
# WSE reduction per unit of growth distance (distance is pixel count times the
# fine pixel size, so the unit follows the CRS); 0.0 disables the decay.
decay_per_meter = 0.0


In [None]:
def read_raster(fp):
    """Load band 1 of a raster as a float64 array with NaN for masked cells.

    Parameters
    ----------
    fp : path-like
        Raster file to open with ``rasterio.open``.

    Returns
    -------
    tuple
        ``(values, profile)`` where ``values`` is the band converted to
        float64 with every masked cell replaced by NaN, and ``profile`` is a
        copy of the dataset profile.
    """
    with rasterio.open(fp) as dataset:
        band = dataset.read(1, masked=True)
        profile = dataset.profile.copy()

    # The band is fully in memory, so the conversion can happen after closing.
    values = band.astype('float64').filled(np.nan)
    return values, profile


def reproject_to_match(src_arr, src_profile, dst_profile, resampling):
    """Warp ``src_arr`` onto the grid described by ``dst_profile``.

    Non-finite source cells are encoded with a -9999.0 sentinel for the warp,
    and destination cells still holding that sentinel afterwards are returned
    as NaN.

    Parameters
    ----------
    src_arr : ndarray
        Source values; non-finite entries are treated as nodata.
    src_profile, dst_profile : mapping
        Profiles providing ``'transform'`` and ``'crs'``; ``dst_profile`` also
        provides ``'height'`` and ``'width'`` of the output grid.
    resampling : rasterio.warp.Resampling
        Resampling method handed to ``reproject``.

    Returns
    -------
    ndarray
        float64 array of shape ``(height, width)`` with NaN for nodata.
    """
    sentinel = -9999.0
    out_shape = (dst_profile['height'], dst_profile['width'])

    # Work on a float64 copy and swap every non-finite cell for the sentinel.
    source = np.array(src_arr, dtype='float64')
    source[~np.isfinite(source)] = sentinel

    # Pre-fill the destination so untouched cells read as nodata.
    dst = np.empty(out_shape, dtype='float64')
    dst.fill(sentinel)

    warp_options = {
        'src_transform': src_profile['transform'],
        'src_crs': src_profile['crs'],
        'src_nodata': sentinel,
        'dst_transform': dst_profile['transform'],
        'dst_crs': dst_profile['crs'],
        'dst_nodata': sentinel,
        'resampling': resampling,
    }
    reproject(source=source, destination=dst, **warp_options)

    # Translate the sentinel back to NaN for downstream np.isfinite() tests.
    np.putmask(dst, dst == sentinel, np.nan)
    return dst


def reproject_valid_mask(src_arr, src_profile, dst_profile):
    """Warp the finite-value mask of ``src_arr`` onto the destination grid.

    The mask is encoded as uint8 (1 = finite, 0 = not finite), warped with
    nearest-neighbour resampling using 0 as nodata on both sides, and decoded
    back to booleans.

    Parameters
    ----------
    src_arr : ndarray
        Source values; finiteness defines validity.
    src_profile, dst_profile : mapping
        Profiles providing ``'transform'`` and ``'crs'``; ``dst_profile`` also
        provides ``'height'`` and ``'width'`` of the output grid.

    Returns
    -------
    ndarray of bool
        True where the warped mask is non-zero.
    """
    out_shape = (dst_profile['height'], dst_profile['width'])
    valid_src = np.isfinite(src_arr).astype('uint8')
    valid_dst = np.zeros(out_shape, dtype='uint8')

    warp_options = {
        'src_transform': src_profile['transform'],
        'src_crs': src_profile['crs'],
        'src_nodata': 0,
        'dst_transform': dst_profile['transform'],
        'dst_crs': dst_profile['crs'],
        'dst_nodata': 0,
        'resampling': Resampling.nearest,
    }
    reproject(source=valid_src, destination=valid_dst, **warp_options)

    return valid_dst != 0


def mcp_distance(seed_mask, domain_mask):
    """Accumulated unit-cost distance from the seeds, confined to the domain.

    Every in-domain cell costs 1.0 and every out-of-domain cell costs inf, so
    the ``MCP_Geometric`` cumulative cost is a path length measured in pixels.

    Parameters
    ----------
    seed_mask : ndarray of bool
        Candidate start cells; only those inside ``domain_mask`` are used.
    domain_mask : ndarray of bool
        Cells that paths are allowed to cross.

    Returns
    -------
    ndarray
        Cumulative costs as returned by ``MCP_Geometric.find_costs``.

    Raises
    ------
    ValueError
        If no seed lies inside the domain.
    """
    starts_mask = seed_mask & domain_mask
    if not starts_mask.any():
        raise ValueError('No valid seed cells available for MCP distance.')

    # Start from "blocked everywhere" and open up the domain at unit cost.
    unit_cost = np.full(seed_mask.shape, np.inf, dtype='float64')
    unit_cost[domain_mask] = 1.0

    router = MCP_Geometric(unit_cost, fully_connected=True)
    start_indices = np.argwhere(starts_mask)
    distances, _ = router.find_costs(starts=start_indices)
    return distances


def _resolve_path_origins(offset_index, offsets):
    """Map every cell to the flat index of the start its least-cost path begins at.

    ``offset_index`` is the traceback array returned by ``find_costs``: a
    non-negative entry selects the row of ``offsets`` that led from the
    predecessor to the cell, so the predecessor is ``cell - offsets[entry]``.
    Negative entries (start cells and cells that were never reached) are
    treated as their own origin.

    The predecessor links are collapsed by repeated pointer doubling
    (``origin = origin[origin]``) until nothing changes, which needs about
    log2(longest path) vectorised passes instead of one Python-level walk per
    cell.
    """
    shape = offset_index.shape
    n_cols = shape[1]

    origin = np.arange(offset_index.size, dtype=np.intp).reshape(shape)
    rows, cols = np.nonzero(offset_index >= 0)
    step = np.asarray(offsets, dtype=np.intp)[offset_index[rows, cols]]
    origin[rows, cols] = (rows - step[:, 0]) * n_cols + (cols - step[:, 1])

    origin = origin.ravel()
    while True:
        hop = origin[origin]
        if np.array_equal(hop, origin):
            return origin.reshape(shape)
        origin = hop


def mcp_fill(seed_values, seed_mask, cost_surface, domain_mask, target_mask=None):
    """Fill target cells with the value of the seed their least-cost path starts at.

    Least-cost paths are computed with ``MCP_Geometric`` (8-connected) over
    ``cost_surface``, with every cell outside ``domain_mask`` made impassable.
    Each reachable target cell then receives the value of the seed cell at
    the start of its path.

    Parameters
    ----------
    seed_values : ndarray (2-D)
        Values to propagate; only entries at valid seed cells are read as
        sources. The array is copied, not modified.
    seed_mask : ndarray of bool
        Candidate seed cells; only those inside ``domain_mask`` are used.
    cost_surface : ndarray
        Per-cell traversal cost.
    domain_mask : ndarray of bool
        Cells that paths are allowed to cross.
    target_mask : ndarray of bool, optional
        Cells to fill. Defaults to every in-domain cell that is not a seed;
        when given it is additionally restricted to in-domain, non-seed cells.

    Returns
    -------
    tuple
        ``(filled, cumulative_costs)``: a copy of ``seed_values`` with the
        reachable target cells filled, and the cumulative cost array from
        ``find_costs``. Target cells whose cumulative cost is not finite are
        left as they are in ``seed_values``.

    Raises
    ------
    ValueError
        If no seed lies inside the domain.
    """
    valid_seeds = seed_mask & domain_mask
    if not np.any(valid_seeds):
        raise ValueError('No valid seed cells available for MCP fill.')

    cost = np.array(cost_surface, dtype='float64')
    cost[~domain_mask] = np.inf

    mcp = MCP_Geometric(cost, fully_connected=True)
    cumulative_costs, offset_index = mcp.find_costs(starts=np.argwhere(valid_seeds))

    if target_mask is None:
        target_mask = domain_mask & ~valid_seeds
    else:
        target_mask = target_mask & domain_mask & ~valid_seeds

    filled = seed_values.copy()

    # Only targets that were actually reached have a path (and thus a source).
    fillable = target_mask & np.isfinite(cumulative_costs)
    if np.any(fillable):
        # Resolve all path starts at once from the traceback array rather than
        # calling mcp.traceback() per pixel, which re-walks every path in Python.
        origin = _resolve_path_origins(offset_index, mcp.offsets)
        src_rows, src_cols = np.unravel_index(origin[fillable], cost.shape)
        filled[fillable] = seed_values[src_rows, src_cols]

    return filled, cumulative_costs


def keep_components_connected_to_anchor(wet_mask, anchor_mask):
    """Keep only the wet components that contain at least one anchor cell.

    Components are labelled with ``skimage.measure.label`` using
    ``connectivity=1``; a component survives if any of its cells is also set
    in ``anchor_mask``.

    Parameters
    ----------
    wet_mask : ndarray of bool
        Cells currently considered wet.
    anchor_mask : ndarray of bool
        Cells that mark a component as one to keep.

    Returns
    -------
    ndarray of bool
        ``wet_mask`` restricted to the anchored components.
    """
    component_ids = label(wet_mask.astype('uint8'), connectivity=1)

    # Label ids seen under wet anchor cells; 0 is the background label.
    anchored_cells = wet_mask & anchor_mask
    anchored_ids = np.setdiff1d(component_ids[anchored_cells], [0])

    return wet_mask & np.isin(component_ids, anchored_ids)


In [None]:
# Main downscaling pipeline. Produces `wse_post`: the downscaled WSE on the DEM
# grid, with NaN for dry / out-of-domain cells. Names defined here
# (hires_dem, dem_profile, wse_fine_resampled, wse_post, ...) are reused by the
# plotting and output cells below.

# Load rasters (band 1 as float64, masked cells -> NaN)
hires_dem, dem_profile = read_raster(input_hires_dem_fp)
lores_wse, wse_profile = read_raster(input_lores_wse_fp)

# The pixel-size ratio computed further down compares the two transforms
# directly, so both rasters have to share one CRS.
if dem_profile['crs'] != wse_profile['crs']:
    raise ValueError('DEM and WSE CRS must match before downscaling.')

# 01: resample coarse WSE to fine DEM grid
# Bilinear values are kept only where the nearest-neighbour warp of the
# coarse validity mask is True.
wse_fine_resampled = reproject_to_match(lores_wse, wse_profile, dem_profile, resampling=Resampling.bilinear)
resampled_valid = reproject_valid_mask(lores_wse, wse_profile, dem_profile)
wse_fine_resampled[~resampled_valid] = np.nan

# domain: cells with a finite DEM value; everything else is excluded
dem_valid = np.isfinite(hires_dem)
wse_fine_resampled[~dem_valid] = np.nan

# 02: wet partials (must be above ground)
# Cells whose resampled WSE is at or below the DEM are dropped; the remaining
# finite cells are the anchors used for growth and for the final filter.
wse_wet_partials = wse_fine_resampled.copy()
wse_wet_partials[wse_wet_partials <= hires_dem] = np.nan
anchor_mask = np.isfinite(wse_wet_partials)

if not np.any(anchor_mask):
    raise ValueError('No wet cells remain after applying WSE > DEM.')

# rough coarse->fine scale factor
# Pixel size is the mean of |a| and |e| of each affine transform; the ratio is
# rounded to an integer (at least 1) and converts the coarse-pixel growth
# limit into fine pixels.
fine_pixel_size = float(np.mean([abs(dem_profile['transform'].a), abs(dem_profile['transform'].e)]))
coarse_pixel_size = float(np.mean([abs(wse_profile['transform'].a), abs(wse_profile['transform'].e)]))
downscale = max(1, int(round(coarse_pixel_size / fine_pixel_size)))
max_grow_fine_pixels = max_grow_coarse_pixels * downscale

# fill resampled WSE neutrally for terrain-penalty construction
# Uniform cost of 1 everywhere: each in-domain cell without a resampled WSE
# takes the value of the seed at the start of its least-cost path.
neutral_seed_mask = np.isfinite(wse_fine_resampled)
wse_neutral_seed_values = np.where(neutral_seed_mask, wse_fine_resampled, np.nan)
wse_neutral_filled, _ = mcp_fill(
    seed_values=wse_neutral_seed_values,
    seed_mask=neutral_seed_mask,
    cost_surface=np.ones_like(hires_dem, dtype='float64'),
    domain_mask=dem_valid,
)

# terrain-penalized cost surface
# Cost is 1 where the neutral WSE is above ground, otherwise
# 1 + |WSE - DEM| * terrain_penalty_scale; out-of-domain cells are impassable.
# NOTE(review): where the neutral fill left NaN (cells it could not reach),
# delta is NaN and so is the cost at that cell -- confirm such cells can never
# be visited by the growth MCP below.
delta = wse_neutral_filled - hires_dem
cost_surface = np.where(delta > 0.0, 1.0, 1.0 + np.abs(delta) * terrain_penalty_scale)
cost_surface[~dem_valid] = np.inf

# growth distance threshold (in fine pixels)
# Distance is the unit-cost MCP distance from the anchors, independent of the
# terrain penalty; grow_mask also contains the anchors themselves.
distance_pixels = mcp_distance(anchor_mask, dem_valid)
grow_mask = np.isfinite(distance_pixels) & (distance_pixels <= max_grow_fine_pixels)

# 03: dry partial growth with terrain-penalty MCP
# Cells inside grow_mask take the WSE of the anchor their terrain-penalized
# least-cost path starts at.
wse_seed_values = np.where(anchor_mask, wse_wet_partials, np.nan)
wse_grown, _ = mcp_fill(
    seed_values=wse_seed_values,
    seed_mask=anchor_mask,
    cost_surface=cost_surface,
    domain_mask=dem_valid,
    target_mask=grow_mask,
)

# optional linear decay
# Lowers the grown WSE by decay_per_meter for every unit of distance
# (pixels * fine pixel size); non-finite distances contribute no decay.
decay = distance_pixels * fine_pixel_size * decay_per_meter
wse_grown = wse_grown - np.where(np.isfinite(decay), decay, 0.0)

# merge growth into phase-2 wet partials
# Only non-anchor cells inside the growth limit whose grown WSE is still above
# ground are added; anchor values are left untouched.
wse_post = wse_wet_partials.copy()
add_mask = (~anchor_mask) & grow_mask & np.isfinite(wse_grown) & (wse_grown > hires_dem)
wse_post[add_mask] = wse_grown[add_mask]

# 04: remove isolated grown blobs not connected to anchors
# NOTE(review): the component filter labels with connectivity=1 while the MCP
# growth uses fully_connected=True, so cells joined to an anchor only
# diagonally are removed here -- confirm this is intended.
wet_mask_post = np.isfinite(wse_post)
keep_mask = keep_components_connected_to_anchor(wet_mask_post, anchor_mask)
wse_post[~keep_mask] = np.nan
wse_post[~dem_valid] = np.nan

# Summary counts. 'grown cells added' is counted before the step-04 filter, so
# it can exceed the number of grown cells present in the final result.
print(f'anchors: {anchor_mask.sum():,}')
print(f'grown cells added: {add_mask.sum():,}')
print(f'final wet cells: {np.isfinite(wse_post).sum():,}')


In [None]:
# Pre/post comparison figure: WSE maps on the top row, depth histograms below.
pre_depth = wse_fine_resampled - hires_dem
post_depth = wse_post - hires_dem

# Pool the finite WSE values of both maps so they share one colour range.
wse_all = np.concatenate(
    [surface[np.isfinite(surface)] for surface in (wse_fine_resampled, wse_post)]
)

if not wse_all.size:
    raise ValueError('No valid WSE values to plot.')

# Clip the colour scale to the 2nd-98th percentile of the pooled values.
vmin, vmax = np.percentile(wse_all, [2, 98])

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top row: WSE before and after downscaling, each with its own colour bar.
map_panels = (
    (axes[0, 0], wse_fine_resampled, 'Pre: Resampled Coarse WSE'),
    (axes[0, 1], wse_post, 'Post: Downscaled WSE'),
)
for map_ax, map_values, map_title in map_panels:
    map_image = map_ax.imshow(map_values, cmap='Blues', vmin=vmin, vmax=vmax)
    map_ax.set_title(map_title)
    map_ax.set_axis_off()
    fig.colorbar(map_image, ax=map_ax, fraction=0.046)

# Bottom row: distribution of WSE - DEM, with a dashed marker at zero depth.
hist_panels = (
    (axes[1, 0], pre_depth, '#5DA5DA', 'Pre Depth Histogram (WSE - DEM)'),
    (axes[1, 1], post_depth, '#60BD68', 'Post Depth Histogram (WSE - DEM)'),
)
for hist_ax, depth_values, bar_color, hist_title in hist_panels:
    hist_ax.hist(depth_values[np.isfinite(depth_values)], bins=60, color=bar_color, alpha=0.9)
    hist_ax.axvline(0.0, color='black', linestyle='--', linewidth=1)
    hist_ax.set_title(hist_title)
    hist_ax.set_xlabel('Depth')
    hist_ax.set_ylabel('Count')

plt.tight_layout()
plt.show()


In [None]:
# Write output GeoTIFF
# Start from the DEM profile so the output shares its grid (size, transform,
# CRS), then override the band layout and encoding.
out_profile = dem_profile.copy()
out_profile.update(
    # The DEM profile carries the DEM's own driver; pin GeoTIFF explicitly so
    # the '.tif' output and the 'compress' option stay valid even when the DEM
    # was read from another format.
    driver='GTiff',
    dtype='float32',
    count=1,
    nodata=-9999.0,
    compress='lzw',
)

# NaN (dry / out-of-domain) cells are encoded with the profile's nodata value.
out_arr = np.where(np.isfinite(wse_post), wse_post, out_profile['nodata']).astype('float32')
Path(output_wse_fp).parent.mkdir(parents=True, exist_ok=True)

with rasterio.open(output_wse_fp, 'w', **out_profile) as dst:
    dst.write(out_arr, 1)

print(f'Wrote downscaled WSE: {output_wse_fp}')
