In [None]:

import glob
import os

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from matplotlib import colors
from pysheds.grid import Grid
from pysheds.view import Raster, ViewFinder
from rasterio.features import rasterize
from rasterio.features import shapes
from rasterio.plot import show
from shapely.geometry import shape
from shapely.ops import unary_union


In [2]:
# --- Inputs -----------------------------------------------------------------
# DEM raster that the hydrological grid is built from.
raster_path = os.path.join('..', 'source_data', 'DEM_clipped_25m.tif')

# GeoPackage holding the pour points, and the layer to read from it.
catchment_points = os.path.join('..', 'source_data', 'sample_locations_2.gpkg')
catch_lyr = "pour_points"

# GeoPackage holding the stream network, and the layer to read from it.
streams = os.path.join('..', 'source_data', 'streams.gpkg')
stream_layer = "stream_net"

# --- Outputs ----------------------------------------------------------------
# Directory that receives the per-site catchment rasters.
raster_out = os.path.join('..', 'work', 'test_site_basins_2')


In [None]:
# Load the pour points from the configured GeoPackage layer.
catchment = gpd.read_file(catchment_points, layer=catch_lyr)
print(f"Original CRS read: {catchment.crs}")

# If the declared CRS is not EPSG:3005, overwrite the CRS label.
# NOTE(review): set_crs(..., allow_override=True) presumably relabels the data
# without reprojecting coordinates — confirm the stored coordinates really are
# EPSG:3005; otherwise to_crs() would be required instead.
if catchment.crs != 'EPSG:3005':
    catchment = catchment.set_crs('EPSG:3005', allow_override=True)
    print(f"CRS set to: {catchment.crs}")

In [None]:
# Summarise the pour-point layer before it is used for delineation.
print("catchment point network info:")
print(f"  Total features: {len(catchment)}")
print(f"  CRS: {catchment.crs}")
print(f"  Geometry types: {catchment.geometry.type.value_counts()}")

# Evaluate validity once and reuse the mask for both reporting and filtering.
valid_mask = catchment.geometry.is_valid
invalid_geoms = catchment[~valid_mask]
if len(invalid_geoms) > 0:
    print(f"  WARNING: {len(invalid_geoms)} invalid geometries found!")
    # Invalid geometries are dropped here, not repaired.
    catchment = catchment[valid_mask]

In [None]:
# Read the first band of the DEM together with its georeferencing metadata.
with rasterio.open(raster_path) as src:
    raster_data = src.read(1)
    raster_bounds = src.bounds
    raster_crs = src.crs
    raster_transform = src.transform
    nodata = src.nodata

# Build the pysheds objects in sequence: ViewFinder -> Raster -> Grid.
# NOTE(review): raster_crs is whatever rasterio returns for src.crs; confirm
# the installed pysheds version accepts that type for ViewFinder(crs=...).
viewfinder = ViewFinder(
    affine=raster_transform,
    shape=raster_data.shape,
    crs=raster_crs,
    nodata=nodata,
)
dem = Raster(raster_data, viewfinder=viewfinder)
grid = Grid.from_raster(dem)

# Echo the spatial reference of both inputs so a mismatch is easy to spot.
print(f"Rasterio - Raster CRS: {raster_crs}")
print(f"Rasterio - Raster bounds: {raster_bounds}")
print(f"GeoPandas - Catchment CRS: {catchment.crs}")
print(f"GeoPandas - Point coordinates: {catchment.geometry.x.iloc[0]}, {catchment.geometry.y.iloc[0]}")

# Quick-look map: DEM as the backdrop with the pour points drawn on top.
fig, ax = plt.subplots(figsize=(10, 8))
with rasterio.open(raster_path) as src:
    show(src, ax=ax, cmap='terrain')
catchment.plot(ax=ax, zorder=2, markersize=100, color='red')

ax.set_title('DEM and Catchment Points (Rasterio + GeoPandas)')
ax.set_xlabel('Easting (m)')
ax.set_ylabel('Northing (m)')
plt.tight_layout()

In [None]:
# Read the configured layer of the stream GeoPackage. The keyword must be
# `layer`; the previous `lyr=` is not a geopandas.read_file parameter, so the
# layer name was not applied as intended.
stream_network = gpd.read_file(streams, layer=stream_layer)

In [None]:
# Summarise the stream layer that was just loaded.
print("Stream network info:")
print(f"  Total features: {len(stream_network)}")
print(f"  CRS: {stream_network.crs}")
print(f"  Geometry types: {stream_network.geometry.type.value_counts()}")

# Rows whose geometry fails the validity check are reported and then removed
# (they are dropped, not repaired).
invalid_geoms = stream_network.loc[~stream_network.geometry.is_valid]
if len(invalid_geoms):
    print(f"  WARNING: {len(invalid_geoms)} invalid geometries found!")
    stream_network = stream_network.loc[stream_network.geometry.is_valid]

In [8]:
# Overview map: DEM backdrop, stream network, and labelled pour points.
fig, ax = plt.subplots(figsize=(12, 10))

# Plot DEM as background
with rasterio.open(raster_path) as src:
    show(src, ax=ax, cmap='terrain', alpha=0.7)

# Plot stream network
stream_network.plot(ax=ax, color='blue', linewidth=1.5, alpha=0.8, label='Stream Network')

# Plot catchment points
catchment.plot(ax=ax, color='red', markersize=100, zorder=3, label='Pour Points')

# Add labels for catchment points
for idx, row in catchment.iterrows():
    ax.annotate(row['Site'],
                xy=(row.geometry.x, row.geometry.y),
                xytext=(5, 5), textcoords='offset points',
                fontsize=8, color='darkred',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.7))

plt.title('DEM, Stream Network, and Pour Points')
plt.xlabel('Easting (m)')
plt.ylabel('Northing (m)')
plt.legend()
plt.tight_layout()
plt.show()
plt.close(fig)  # release the figure once it has been displayed

# Zoomed view around each pour point.
# Half-width of the zoom window, in map units (2 km buffer).
buffer_dist = 2000

# Open the DEM once for all zoom panels instead of once per pour point.
with rasterio.open(raster_path) as src:
    for idx, row in catchment.iterrows():
        fig, ax = plt.subplots(figsize=(10, 8))

        # Square window centred on the pour point
        minx = row.geometry.x - buffer_dist
        maxx = row.geometry.x + buffer_dist
        miny = row.geometry.y - buffer_dist
        maxy = row.geometry.y + buffer_dist

        # Plot DEM
        show(src, ax=ax, cmap='terrain', alpha=0.7)

        # Plot only the streams that intersect the window
        streams_clipped = stream_network.cx[minx:maxx, miny:maxy]
        if not streams_clipped.empty:
            streams_clipped.plot(ax=ax, color='blue', linewidth=2, alpha=0.8)

        # Plot the specific point
        ax.scatter(row.geometry.x, row.geometry.y, color='red', s=200, zorder=3,
                   edgecolor='white', linewidth=2)

        # Restrict the view to the window (the DEM is drawn at full extent)
        ax.set_xlim(minx, maxx)
        ax.set_ylim(miny, maxy)

        plt.title(f'Area around {row["Site"]}')
        plt.xlabel('Easting (m)')
        plt.ylabel('Northing (m)')
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per pour point

In [9]:
# Rasterize the stream geometries onto the DEM grid: stream cells are burned
# with 1, every other cell keeps the fill value 0.
stream_mask = rasterize(
    list((line, 1) for line in stream_network.geometry),
    out_shape=raster_data.shape,
    fill=0,
    transform=raster_transform,
    dtype='uint8',
)

In [10]:
# Drop the stream GeoDataFrame; the visible cells below this point use only
# the rasterized stream_mask. Re-run the stream loading cell to get it back.
del stream_network

In [11]:
# Condition DEM
# ----------------------
# Each step consumes the output of the previous one; the final surface
# (inflated_dem) is what the flow-direction cell uses.
# Fill pits in DEM
pit_filled_dem = grid.fill_pits(dem)

# Fill depressions in DEM
flooded_dem = grid.fill_depressions(pit_filled_dem)
    
# Resolve flats in DEM
inflated_dem = grid.resolve_flats(flooded_dem)

In [12]:
# Overlay the pour points on the pit-filled DEM.
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(pit_filled_dem, cmap='terrain', extent=grid.extent, zorder=1)
catchment.plot(ax=ax, zorder=2, markersize=100, color='red')
ax.set_title('Catchment Area on Filled DEM')
ax.set_xlabel('Easting (m)')
ax.set_ylabel('Northing (m)')
plt.tight_layout()

In [13]:
# Determine D8 flow directions from DEM
# ----------------------
# Specify directional mapping
# NOTE(review): presumably ordered (N, NE, E, SE, S, SW, W, NW) per the
# pysheds convention — confirm against the pysheds documentation.
dirmap = (64, 128, 1, 2, 4, 8, 16, 32)
    
# Compute flow directions
# -------------------------------------
# Uses the fully conditioned surface (pits filled, depressions filled, flats
# resolved) from the previous cell.
fdir = grid.flowdir(inflated_dem, dirmap=dirmap)

In [None]:
# Calculate flow accumulation
# --------------------------
# Derived from the flow-direction grid; the same dirmap must be passed so the
# direction codes are interpreted consistently.
acc = grid.accumulation(fdir, dirmap=dirmap)

In [None]:
# Plot flow accumulation on a logarithmic colour scale.
fig, ax = plt.subplots(figsize=(8,6))
fig.patch.set_alpha(0)
ax.grid(True, zorder=0)
im = ax.imshow(acc, extent=grid.extent, zorder=2,
               cmap='cubehelix',
               norm=colors.LogNorm(1, acc.max()),
               interpolation='bilinear')
plt.colorbar(im, ax=ax, label='Upstream Cells')
plt.title('Flow Accumulation', size=14)
# The extent is in projected map coordinates, so label the axes the same way
# as every other map in this notebook rather than as longitude/latitude.
plt.xlabel('Easting (m)')
plt.ylabel('Northing (m)')
plt.tight_layout()

In [16]:
# Burn streams into the DEM, then recompute conditioning, flow direction and
# flow accumulation on the burned surface.
# NOTE: this cell overwrites pit_filled_dem, flooded_dem, inflated_dem, fdir
# and acc from the earlier (unburned) cells.

# Work on a float copy so the subtraction cannot wrap integer elevations.
# astype() already returns a new array, so no explicit copy is needed.
dem_burned = raster_data.astype(np.float32)

# Set burn depth (meters to lower stream elevations)
burn_depth = 5  # Adjust this value as needed

# Only cells that hold real elevation data may be burned. Lowering a nodata
# cell would change its value so it no longer matches `nodata`, turning it
# into a bogus elevation for the conditioning steps below.
if nodata is None:
    has_data = np.ones(raster_data.shape, dtype=bool)
elif np.isnan(nodata):
    has_data = ~np.isnan(raster_data)
else:
    has_data = raster_data != nodata

# Burn streams by lowering elevation along stream paths
burn_cells = (stream_mask > 0) & has_data
dem_burned = np.where(burn_cells, dem_burned - burn_depth, dem_burned)

# Create new ViewFinder and Raster with burned DEM (same georeferencing as
# the original DEM)
viewfinder_burned = ViewFinder(affine=raster_transform,
                              shape=dem_burned.shape,
                              crs=raster_crs,
                              nodata=nodata)

# Create new Raster object with burned DEM
dem_burned_raster = Raster(dem_burned, viewfinder=viewfinder_burned)

# Create new Grid from burned Raster
grid_burned = Grid.from_raster(dem_burned_raster)

# Condition the burned DEM
pit_filled_dem = grid_burned.fill_pits(dem_burned_raster)
flooded_dem = grid_burned.fill_depressions(pit_filled_dem)
inflated_dem = grid_burned.resolve_flats(flooded_dem)

# Compute flow directions from burned DEM
fdir = grid_burned.flowdir(inflated_dem, dirmap=dirmap)

# Calculate flow accumulation
acc = grid_burned.accumulation(fdir, dirmap=dirmap)

In [None]:
# Report how much of the grid the rasterized streams cover.
print("Stream mask statistics:")
print(f"  Total stream pixels: {stream_mask.sum()}")
print(f"  Stream coverage: {stream_mask.sum() / stream_mask.size * 100:.2f}%")
print(f"  DEM shape: {raster_data.shape}")

# Elevation change introduced by burning (burned minus original).
burned_difference = dem_burned - raster_data
print("Burned elevation difference stats:")
print(f"  Min: {burned_difference.min():.2f}m")
print(f"  Max: {burned_difference.max():.2f}m")
print(f"  Mean where burned: {burned_difference[stream_mask > 0].mean():.2f}m")

# Three side-by-side panels: original DEM, stream mask, elevation change.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
dem_ax, mask_ax, diff_ax = axes

dem_ax.imshow(raster_data, cmap='terrain')
dem_ax.set_title('Original DEM')

mask_ax.imshow(stream_mask, cmap='Blues')
mask_ax.set_title('Stream Mask')

# Colour range is fixed at -10..0 regardless of the burn depth in use.
im = diff_ax.imshow(burned_difference, cmap='RdBu_r', vmin=-10, vmax=0)
diff_ax.set_title('Elevation Change (Burned - Original)')
plt.colorbar(im, ax=diff_ax)

plt.tight_layout()
plt.show()

In [None]:
# Delineate one catchment per pour point and export each as a GeoTIFF.
#
# For every point the loop:
#   1. looks up flow accumulation at the original location,
#   2. tries to snap the point onto a high-accumulation cell, raising the
#      accumulation threshold and the allowed snap distance together,
#   3. delineates the catchment from the chosen coordinates,
#   4. skips (with a diagnostic figure) catchments that are too small,
#   5. writes the remaining catchments into `raster_out`.

# Save original flow direction and accumulation so every iteration works from
# the same inputs.
original_fdir = fdir.copy()
original_acc = acc.copy()
original_dem = dem_burned_raster.copy()

# Area of one raster cell in km², taken from the DEM geotransform rather than
# a hard-coded 25 m cell size (assumes map units are metres, as the axis
# labels used throughout this notebook do).
cell_area_km2 = abs(raster_transform.a * raster_transform.e) / 1_000_000

for catchment_x, catchment_y, site in zip(catchment.geometry.x, catchment.geometry.y, catchment['Site']):
    # Build a file-name-safe site label; str() guards against non-string IDs.
    site = str(site).replace(' ','_').replace('(','').replace(')','').replace('-','_')
    print(f"\nProcessing: {site}")
    print(f"  Original coordinates: {catchment_x:.2f}, {catchment_y:.2f}")

    # Create a fresh grid for each catchment (clip_to below mutates the grid)
    grid_new = Grid.from_raster(original_dem)

    # First, check what the flow accumulation is at the original point
    try:
        orig_col, orig_row = grid_new.nearest_cell(catchment_x, catchment_y)
        orig_acc_value = original_acc[orig_row, orig_col]
        print(f"  Flow accumulation at original point: {orig_acc_value:.0f}")
    except Exception as e:
        # Best effort: fall back to 0 but report why the lookup failed.
        # Catching Exception (not a bare except) lets Ctrl-C interrupt the loop.
        print(f"  Cannot determine accumulation at original point ({e})")
        orig_acc_value = 0

    # Try different accumulation thresholds with stricter distance limits
    thresholds = [100, 250, 500, 1000]
    max_distances = [200, 400, 600, 800]  # Corresponding max distances
    snap_found = False
    best_snap = None

    for threshold, max_dist in zip(thresholds, max_distances):
        try:
            x_snap, y_snap = grid_new.snap_to_mask(original_acc > threshold, (catchment_x, catchment_y))
            distance = ((x_snap - catchment_x)**2 + (y_snap - catchment_y)**2)**0.5

            # Get accumulation at snap point
            snap_col, snap_row = grid_new.nearest_cell(x_snap, y_snap)
            snap_acc_value = original_acc[snap_row, snap_col]

            print(f"  Threshold {threshold}: Snapped to {x_snap:.2f}, {y_snap:.2f}")
            print(f"    Distance: {distance:.1f}m, Accumulation: {snap_acc_value:.0f}")

            # Accept if within distance and has reasonable accumulation;
            # the first acceptable threshold wins.
            if distance < max_dist and snap_acc_value > threshold:
                snap_found = True
                best_snap = (x_snap, y_snap, distance, snap_acc_value)
                break
        except Exception as e:
            print(f"  Threshold {threshold}: No suitable snap point found ({e})")

    # If snapping failed or original point has good accumulation, use original
    if not snap_found or orig_acc_value > 50:
        x_snap, y_snap = catchment_x, catchment_y
        print(f"  Using original coordinates (snap_found={snap_found}, orig_acc={orig_acc_value:.0f})")
    else:
        x_snap, y_snap, distance, snap_acc = best_snap
        print(f"  Using snapped coordinates: {x_snap:.2f}, {y_snap:.2f} (dist: {distance:.1f}m)")

    # Delineate catchment
    try:
        catch = grid_new.catchment(x=x_snap, y=y_snap, fdir=original_fdir, dirmap=dirmap,
                                 xytype='coordinate')
    except Exception as e:
        print(f"  ERROR: Failed to delineate catchment: {e}")
        continue

    # Check if catchment is reasonable size
    catchment_pixels = np.sum(catch)
    catchment_area = catchment_pixels * cell_area_km2  # km²
    print(f"  Catchment: {catchment_pixels} pixels, {catchment_area:.4f} km²")

    # Skip if catchment is too small (less than 10 pixels or 0.01 km²)
    if catchment_pixels < 10 or catchment_area < 0.01:
        print(f"  WARNING: Catchment too small ({catchment_pixels} pixels, {catchment_area:.4f} km²)")

        # Create diagnostic plot for small catchments
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Half-size of the diagnostic window around the points
        buffer = 200  # pixels

        # Convert coordinates to indices using the viewfinder
        orig_col, orig_row = grid_new.nearest_cell(catchment_x, catchment_y)
        snap_col, snap_row = grid_new.nearest_cell(x_snap, y_snap)

        # Convert to integers
        orig_row, orig_col = int(orig_row), int(orig_col)
        snap_row, snap_col = int(snap_row), int(snap_col)

        # Window covering both points plus the buffer, clamped to the grid
        r_min = max(0, min(orig_row, snap_row) - buffer)
        r_max = min(original_acc.shape[0], max(orig_row, snap_row) + buffer)
        c_min = max(0, min(orig_col, snap_col) - buffer)
        c_max = min(original_acc.shape[1], max(orig_col, snap_col) + buffer)

        # Flow accumulation
        acc_subset = original_acc[r_min:r_max, c_min:c_max]
        im1 = axes[0,0].imshow(acc_subset, cmap='Blues', norm=colors.LogNorm(vmin=1))
        axes[0,0].scatter(orig_col-c_min, orig_row-r_min, color='red', s=100, marker='o', label='Original')
        axes[0,0].scatter(snap_col-c_min, snap_row-r_min, color='orange', s=100, marker='x', label='Snapped')
        axes[0,0].set_title('Flow Accumulation')
        axes[0,0].legend()

        # DEM
        dem_subset = raster_data[r_min:r_max, c_min:c_max]
        axes[0,1].imshow(dem_subset, cmap='terrain')
        axes[0,1].scatter(orig_col-c_min, orig_row-r_min, color='red', s=100, marker='o')
        axes[0,1].scatter(snap_col-c_min, snap_row-r_min, color='orange', s=100, marker='x')
        axes[0,1].set_title('DEM')

        # Stream mask
        stream_subset = stream_mask[r_min:r_max, c_min:c_max]
        axes[1,0].imshow(stream_subset, cmap='Blues')
        axes[1,0].scatter(orig_col-c_min, orig_row-r_min, color='red', s=100, marker='o')
        axes[1,0].scatter(snap_col-c_min, snap_row-r_min, color='orange', s=100, marker='x')
        axes[1,0].set_title('Stream Mask')

        # Catchment (only sliceable if it shares the accumulation grid shape)
        catch_subset = catch[r_min:r_max, c_min:c_max] if catch.shape == original_acc.shape else None
        if catch_subset is not None:
            axes[1,1].imshow(catch_subset, cmap='Reds')
            axes[1,1].set_title(f'Catchment ({catchment_pixels} pixels)')
        else:
            axes[1,1].text(0.5, 0.5, 'Catchment shape mismatch', ha='center', va='center', transform=axes[1,1].transAxes)

        plt.suptitle(f'Diagnostic: {site} - Small Catchment ({catchment_area:.4f} km²)')
        plt.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per skipped site

        print(f"  Skipping {site} due to small catchment size...")
        continue

    # If we get here, catchment is acceptable - proceed with export
    grid_new.clip_to(catch)
    clipped_catch = grid_new.view(catch)

    # File name is "<site>catchment.tif"; the export cell relies on that suffix.
    output_path = os.path.join(raster_out, f"{site}catchment.tif")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    grid_new.to_raster(clipped_catch,
                     output_path,
                     dtype='uint8',
                     compress='lzw',
                     nodata=0)

    print(f"  Catchment raster saved to: {output_path}")

In [None]:

# Raster to polygon to GeoPackage, with each catchment as a separate layer.
output_gpkg = os.path.join('..', 'work', 'catchments.gpkg')

# Match only the "<site>catchment.tif" rasters; a looser "*catchment*" pattern
# would also pick up any sidecar files sitting next to them. Sorted so the
# processing order is deterministic.
raster_files = sorted(glob.glob(os.path.join(raster_out, '*catchment.tif')))

for raster_file in raster_files:
    site_name = os.path.basename(raster_file).replace('catchment.tif', '')
    print(f"Processing {site_name}...")
    with rasterio.open(raster_file) as src:
        image = src.read(1)
        # Cells to polygonize: everything that is not nodata. If the file
        # declares no nodata value, fall back to 0, which is the nodata value
        # the catchment rasters are written with.
        background = 0 if src.nodata is None else src.nodata
        mask = image != background
        polygons = [
            shape(geom)
            for geom, _value in shapes(image, mask=mask, transform=src.transform)
        ]

    if not polygons:
        print(f"  WARNING: no catchment cells found in {raster_file}; skipping")
        continue

    # Polygonizing can return several pieces for one catchment (for example
    # cells that touch only diagonally). Merge them all instead of keeping
    # just the first piece, which would silently truncate the catchment.
    catchment_shape = unary_union(polygons)

    # Create a new GeoDataFrame for this catchment
    catchment_gdf = gpd.GeoDataFrame(
        [{'site': site_name, 'geometry': catchment_shape}],
        crs='EPSG:3005'
    )
    catchment_gdf.to_file(output_gpkg, layer=f"{site_name}_catchment", driver="GPKG")
