# Inlet-Integrated Watershed Characterization

This notebook implements an inlet-integrated watershed characterization workflow for urban hydrology and stormwater analysis.

## Credits
- **Author:** Lapone Techapinyawat, Ph.D. (Texas A&M University–Corpus Christi), 2025  
- **Author:** Hua Zhang, Ph.D. (Texas A&M University–Corpus Christi), 2025

## Affiliation
Geospatial Computer Science Program, Texas A&M University–Corpus Christi (TAMUCC)

## Notes
- This repository shares the workflow and documentation for research and educational use.
- Please cite/credit appropriately if you reuse or adapt this notebook.


# Urban Fill & Spill (Jupyter Notebook)

This notebook implements a multi-step workflow for **depression-aware urban runoff modeling** using high‑resolution
topography and stormwater infrastructure. The main components are:

1. **Iterative depression merging & downspout allocation** (Algorithm 1)
2. **Inlet watershed delineation via flow tracing** (Algorithm 2)
3. **Water depth mapping via depression filling** (Algorithm 3)
4. **Drainage-tree construction and event simulation** (Algorithms 4–6)

## What you need
- Python geospatial stack (GeoPandas, Rasterio, Shapely, SciPy, Rtree, NetworkX, NumPy/Pandas, Matplotlib).
- Input layers: depression basins, lowest points (sinks), inlets, downspouts, study area, and the DEM raster (plus any other supporting rasters).

## How to run
Run the notebook **top to bottom**. Each algorithm section writes intermediate outputs that are used by the next section.
Update the **User inputs** blocks (file paths, iteration numbers, and calibration factors) before running.

## Notes
- Comments are written to be GitHub-friendly (clear assumptions, inputs/outputs, and where files come from).
- If you use or adapt this workflow in a publication, please cite the associated paper(s) and acknowledge the data sources.


In [None]:
import os
import geopandas as gpd
import pandas as pd
import numpy as np
import rasterio
from shapely.geometry import (
    Point,
    LineString,
    MultiLineString,
    Polygon,
    MultiPolygon,
    GeometryCollection,
    MultiPoint
)
from shapely.ops import (
    linemerge,
    unary_union,
    nearest_points,
    split,
    voronoi_diagram
)
from shapely.validation import make_valid
from scipy.spatial import Voronoi
from rtree import index
import warnings
import pickle
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

# Input and Output File Paths
# All paths are relative, so they resolve against the current working directory.

# Input datasets. The functions in this section take paths/frames as arguments;
# presumably a later driver cell passes these constants in (not visible here).
BASIN_SHAPEFILE_INITIAL = "BS_filtered.shp"
DEM_PATH = "DEM.tif"
INLETS_SHAPEFILE = "inlet.shp"
SINKS_SHAPEFILE = "SNK.shp"
DOWNSPOUTS_SHAPEFILE = "DS.shp"
STUDY_AREA_SHAPEFILE = "studyarea.shp"

# Output name prefixes. Where they are used below, the writer appends
# "_{iteration}" plus an extension (".shp", or ".csv" for BASIN_SEG_PNTS).
BASIN_HOLES_OUTPUT_PREFIX = "basin_holes"
BASIN_OUTPUT_PREFIX = "basin"
BOUNDARY_SEGMENTS_OUTPUT_PREFIX = "boundary_segments"
BOUNDARY_POINTS_OUTPUT_PREFIX = "boundary_points"
MID_SEGMENT_OUTPUT_PREFIX = "mid_segment"
LOWEST_SEGMENT_POINTS_OUTPUT_PREFIX = "lowest_segment_points"
BASIN_SEG_PNTS_OUTPUT_PREFIX = "basin_seg_pnts"
# Fixed (non-iterated) output file names.
TEMP_SNAPPED_DOWNSPOUTS_OUTPUT = "temp_snapped_downspouts.shp"
LOWEST_BASIN_POINTS_OUTPUT_PREFIX = "lowest_basin_points"
SINKS_PROCESSING_OUTPUT = "sinks_processing.shp"
UPDATED_SINKS_OUTPUT = "updated_sinks.shp"
# NOTE(review): the two FINAL_* prefixes are not referenced in the visible part
# of the file; presumably consumed by a later cell.
FINAL_BASINS_OUTPUT_PREFIX = "final_basins"
FINAL_SNAPPED_DOWNSPOUTS_OUTPUT_PREFIX = "final_snapped_downspouts"


def polygon_to_line(polygon):
    """Return the boundary of ``polygon`` (the value of its ``.boundary`` attribute)."""
    outline = polygon.boundary
    return outline

def create_points_along_line(line, interval=1.00):
    """Sample points along ``line`` every ``interval`` length units.

    Points are placed at distances ``interval, 2*interval, ...`` from the start
    of the line (the start point itself is not included). When the line is
    shorter than one interval, a single point at the normalized midpoint is
    returned instead, so the result always holds at least one point.
    """
    count = int(line.length // interval)
    if count >= 1:
        return [
            line.interpolate(step * interval, normalized=False)
            for step in range(1, count + 1)
        ]
    # Too short for a full interval: fall back to the midpoint.
    return [line.interpolate(0.5, normalized=True)]

def ensure_numeric_basin_id(gdf, old_col="Id", new_col="basin_id"):
    """Return a copy of ``gdf`` whose basin-id column is named ``new_col`` and numeric.

    Parameters
    ----------
    gdf : DataFrame or GeoDataFrame
        Input table. It is never modified; a new frame is returned.
    old_col : str
        Column to rename to ``new_col`` when present.
    new_col : str
        Name of the basin-id column in the result.

    Returns
    -------
    The converted frame. Values in ``new_col`` that cannot be parsed as
    numbers become NaN (``errors='coerce'``). If ``new_col`` is absent after
    the optional rename, a warning is printed and the data is returned as is.
    """
    needs_rename = (
        old_col in gdf.columns
        and old_col != new_col
        # Renaming onto an existing column would create duplicate column
        # names, and ``pd.to_numeric`` cannot handle the resulting 2-D selection.
        and new_col not in gdf.columns
    )
    if needs_rename:
        gdf = gdf.rename(columns={old_col: new_col})
    else:
        # Work on a copy so the caller's frame is not converted in place.
        gdf = gdf.copy()

    if new_col in gdf.columns:
        gdf[new_col] = pd.to_numeric(gdf[new_col], errors='coerce')
    else:
        print(f"Warning: Column '{new_col}' not found in GeoDataFrame.")
    return gdf

def get_elevation_at_point(point, dem_array, dem_transform):
    """Look up the DEM value under ``point``.

    Returns NaN when the point is ``None`` or empty, when it falls outside the
    raster, or when the sampled cell equals the raster maximum or 0 (both are
    treated as invalid values by this workflow). Otherwise returns the cell
    value as a Python float.

    NOTE: ``dem_array.max()`` is evaluated on every call that reaches a valid
    cell, i.e. one pass over the whole raster per lookup.
    """
    if point is None or point.is_empty:
        return np.nan

    row, col = rasterio.transform.rowcol(dem_transform, point.x, point.y)
    if not (0 <= row < dem_array.shape[0] and 0 <= col < dem_array.shape[1]):
        return np.nan

    value = float(dem_array[row, col])
    is_invalid = value == dem_array.max() or value == 0
    return np.nan if is_invalid else value

def assign_basin_ids_to_segments(boundary_segments, basins):
    """Tag every boundary segment with the ids of the basins whose outline contains it.

    Parameters
    ----------
    boundary_segments : GeoDataFrame
        Line segments; a ``bas_ids`` column is added to this frame in place.
    basins : GeoDataFrame
        Basin polygons with ``basin_id`` and ``geometry`` columns.

    Returns
    -------
    ``boundary_segments`` with ``bas_ids`` set to a comma-separated string of
    at most two basin ids (sorted as strings), or ``''`` when no basin
    boundary contains the segment.
    """
    print("Creating spatial index for basins...")
    basin_geoms = list(basins['geometry'])
    basin_id_values = list(basins['basin_id'])

    # Key the R-tree by row POSITION, not by the frame's index label: the
    # lookups below are positional, so labels from a filtered or re-ordered
    # frame would otherwise point at the wrong basin (or out of range).
    basin_sindex = index.Index()
    for pos, geom in enumerate(basin_geoms):
        basin_sindex.insert(pos, geom.bounds)

    # Each basin outline is derived once and reused for every candidate test.
    boundary_cache = {}

    basin_ids_list = []
    print("Assigning basin IDs to segments...")
    for segment in boundary_segments.geometry:
        matching_basins = []
        for pos in basin_sindex.intersection(segment.bounds):
            outline = boundary_cache.get(pos)
            if outline is None:
                outline = basin_geoms[pos].boundary
                boundary_cache[pos] = outline
            if outline.contains(segment):
                matching_basins.append(str(basin_id_values[pos]))
            # A shared boundary separates at most two basins; stop early.
            if len(matching_basins) >= 2:
                break
        basin_ids_list.append(','.join(sorted(matching_basins)[:2]))

    boundary_segments['bas_ids'] = basin_ids_list
    print(f"Assigned IDs to {len(boundary_segments)} segments")
    return boundary_segments

def create_midpoint_buffer(segment, buffer_distance=0.01):
    """Return a buffer of radius ``buffer_distance`` around the midpoint of ``segment``."""
    return segment.interpolate(0.5, normalized=True).buffer(buffer_distance)

def identify_and_export_hole_basins(basins, output_hole_shapefile):
    """Find the holes in the union of all basins and write them to a file.

    The basin geometries are dissolved with ``unary_union``; every interior
    ring of the resulting polygon(s) is turned into its own polygon. The holes
    are written to ``output_hole_shapefile`` and also returned as a
    GeoDataFrame in the CRS of ``basins``.
    """
    dissolved = unary_union(basins['geometry'])

    if isinstance(dissolved, Polygon):
        parts = [dissolved]
    elif isinstance(dissolved, MultiPolygon):
        parts = list(dissolved.geoms)
    else:
        # Any other result type contributes no holes.
        parts = []

    holes = [Polygon(ring) for part in parts for ring in part.interiors]
    hole_count = len(holes)
    print(f"  - Found {hole_count} holes")

    holes_gdf = gpd.GeoDataFrame(geometry=holes, crs=basins.crs)
    holes_gdf.to_file(output_hole_shapefile)
    print(f"  - Exported {hole_count} holes to {output_hole_shapefile}")
    return holes_gdf

def process_segment(row, dem_array, dem_transform, interval=0.30):
    """Sample DEM elevations along one boundary segment.

    Parameters
    ----------
    row : row-like object with ``geometry``, ``seg_id`` and ``bas_ids`` attributes
        One record of the boundary-segment table.
    dem_array : 2-D array
        Elevation raster values.
    dem_transform : affine transform
        Transform passed to ``rasterio.transform.rowcol`` to map x/y to row/col.
    interval : float
        Spacing of the sample points along the segment.

    Returns
    -------
    (points, buffer_dict)
        ``points`` is a list of dicts (``seg_id``, ``bas_ids``, ``elev``,
        ``geometry``), one per sample with a usable elevation. ``buffer_dict``
        describes a small buffer around the segment midpoint.
    """
    segment = row.geometry
    segment_id = row.seg_id
    basin_ids_str = row.bas_ids

    buffer_dict = {
        'seg_id': segment_id,
        'bas_ids': basin_ids_str,
        'geometry': create_midpoint_buffer(segment)
    }

    points_on_segment = create_points_along_line(segment, interval)
    if not points_on_segment:
        return [], buffer_dict

    xs = [pt.x for pt in points_on_segment]
    ys = [pt.y for pt in points_on_segment]
    rows, cols = rasterio.transform.rowcol(dem_transform, xs, ys)
    # Indices outside the raster are clamped to the nearest edge cell.
    # NOTE(review): get_elevation_at_point returns NaN for such points instead;
    # confirm which behaviour is intended here.
    rows = np.clip(rows, 0, dem_array.shape[0] - 1)
    cols = np.clip(cols, 0, dem_array.shape[1] - 1)
    elevations = dem_array[rows, cols]

    # The raster maximum is loop-invariant; computing it once avoids a full
    # raster scan for every sampled point.
    dem_max = dem_array.max()

    pts = []
    for pt, raw_value in zip(points_on_segment, elevations):
        elev = float(raw_value)
        # Cells equal to the raster maximum or to 0 are treated as invalid,
        # as are NaN cells.
        if elev == dem_max or elev == 0 or np.isnan(elev):
            continue
        pts.append({
            'seg_id': segment_id,
            'bas_ids': basin_ids_str,
            'elev': elev,
            'geometry': pt
        })
    return pts, buffer_dict

def process_basins_and_segments(basins, dem_path, iteration, holes_gdf=None, save_outputs=False, interval=0.30):
    """Build basin boundary segments and sample DEM elevations along them.

    Steps: repair geometries, (iteration 0 only) extract and export the holes
    of the dissolved basins, merge all basin outlines into segments, tag each
    segment with its adjoining basin ids, sample elevations along every
    segment, and drop sample points lying on hole boundaries.

    Parameters
    ----------
    basins : GeoDataFrame
        Basin polygons with a ``basin_id`` column. A copy is processed.
    dem_path : str
        Path of the DEM raster (band 1 is read).
    iteration : int
        Iteration number; used in log messages and output file names. Holes
        are recomputed only when it is 0.
    holes_gdf : GeoDataFrame or None
        Hole polygons from an earlier call; ignored when ``iteration == 0``.
    save_outputs : bool
        When true, basins, segments, points and midpoint buffers are written
        to shapefiles named with the module-level prefixes.
    interval : float
        Sampling distance along each segment.

    Returns
    -------
    (boundary_segments, filtered_points_gdf, midpoint_buffers_gdf, holes_gdf, basins)

    Raises
    ------
    ValueError
        If merging the basin outlines yields neither a LineString nor a
        MultiLineString.
    """
    print(f"Iteration {iteration}: Processing basin data")
    basins = basins.copy()
    basins['geometry'] = basins['geometry'].apply(make_valid)
    if save_outputs:
        basins[['basin_id', 'geometry']].to_file(f"{BASIN_OUTPUT_PREFIX}_{iteration}.shp")
        print(f"  - Basins saved to '{BASIN_OUTPUT_PREFIX}_{iteration}.shp'")
    print(f"  - Processing {len(basins)} basins")

    if iteration == 0:
        print("Identifying and exporting hole basins")
        all_polygons = unary_union(basins['geometry'])
        holes = []
        if isinstance(all_polygons, Polygon):
            holes.extend([Polygon(hole) for hole in all_polygons.interiors])
        elif isinstance(all_polygons, MultiPolygon):
            for polygon in all_polygons.geoms:
                holes.extend([Polygon(hole) for hole in polygon.interiors])
        holes_gdf = gpd.GeoDataFrame(geometry=holes, crs=basins.crs)
        holes_gdf.to_file(f"{BASIN_HOLES_OUTPUT_PREFIX}_{iteration}.shp")
        print(f"  - Found and exported {len(holes)} holes")
    else:
        print("Using existing hole basins")

    print("Creating basin lines")
    basin_lines = basins.geometry.boundary
    merged_line = linemerge(unary_union(basin_lines))
    if isinstance(merged_line, LineString):
        boundary_segments = gpd.GeoDataFrame(geometry=[merged_line], crs=basins.crs)
    elif isinstance(merged_line, MultiLineString):
        boundary_segments = gpd.GeoDataFrame(geometry=list(merged_line.geoms), crs=basins.crs)
    else:
        raise ValueError("Unexpected geometry type for merged_line")

    boundary_segments['seg_id'] = [f"segment_{i}" for i in range(len(boundary_segments))]
    print(f"  - Created {len(boundary_segments)} boundary segments")

    print("Assigning basin IDs to segments")
    boundary_segments = assign_basin_ids_to_segments(boundary_segments, basins)

    print("Creating points along segments and calculating elevations")
    with rasterio.open(dem_path) as src:
        dem_array = src.read(1)
        dem_transform = src.transform

    segment_rows = [row for _, row in boundary_segments.iterrows()]

    def _sample_segment(row):
        return process_segment(row, dem_array, dem_transform, interval)

    all_points = []
    midpoint_buffers = []

    # executor.map yields results in submission order, so the point numbering
    # below (and the row order of both outputs) is the same on every run.
    # Collecting with as_completed would make it depend on thread timing.
    with ThreadPoolExecutor() as executor:
        for pts, buffer_dict in executor.map(_sample_segment, segment_rows):
            all_points.extend(pts)
            midpoint_buffers.append(buffer_dict)

    for i, pt in enumerate(all_points):
        pt['pnt_id'] = f"point_{i}"

    all_points_gdf = gpd.GeoDataFrame(all_points, crs=basins.crs, geometry='geometry')
    midpoint_buffers_gdf = gpd.GeoDataFrame(midpoint_buffers, crs=basins.crs, geometry='geometry')

    print("Filtering out points on hole basin boundaries")
    if holes_gdf is not None and not holes_gdf.empty:
        hole_boundaries_gdf = gpd.GeoDataFrame(geometry=holes_gdf.boundary, crs=holes_gdf.crs)
        points_on_holes = gpd.sjoin(all_points_gdf, hole_boundaries_gdf, how="inner", predicate="intersects")
        filtered_points_gdf = all_points_gdf.drop(points_on_holes.index)
        print(f"  - Filtered out {len(points_on_holes)} points")
    else:
        filtered_points_gdf = all_points_gdf
        print("  - No holes found; no points filtered.")

    if save_outputs:
        print("Saving output files")
        boundary_segments.to_file(f"{BOUNDARY_SEGMENTS_OUTPUT_PREFIX}_{iteration}.shp")
        filtered_points_gdf.to_file(f"{BOUNDARY_POINTS_OUTPUT_PREFIX}_{iteration}.shp")
        midpoint_buffers_gdf.to_file(f"{MID_SEGMENT_OUTPUT_PREFIX}_{iteration}.shp")
        print("  - All output files saved successfully")
    else:
        print(f"Skipping output file saving for iteration {iteration}")

    return boundary_segments, filtered_points_gdf, midpoint_buffers_gdf, holes_gdf, basins

def filter_lowest_elevation_points(points_gdf):
    """Keep, for every ``seg_id``, the single row with the smallest ``elev``.

    The ``elev`` column of ``points_gdf`` is converted to numeric in place
    (unparseable values become NaN). The returned rows are ordered by
    ascending elevation.
    """
    print("Filtering lowest elevation points per segment")
    numeric_elev = pd.to_numeric(points_gdf['elev'], errors='coerce')
    points_gdf['elev'] = numeric_elev

    ranked = points_gdf.sort_values('elev', ascending=True)
    # After sorting, the first row seen for a segment is its lowest point.
    lowest_elevation_points = ranked.drop_duplicates(subset='seg_id', keep='first')

    print("  - Filtered lowest elevation points")
    return lowest_elevation_points

def merge_basins_with_redundant_first_points(basin_info_df):
    """Group basins that share the same first entry of ``seg_pnts``.

    ``seg_pnts`` is a comma-separated string; only its first token is
    compared. Rows whose ``seg_pnts`` is missing, not a string, or blank are
    ignored.

    Returns a list of ``{'first_point': ..., 'basin_ids': [...]}`` dicts, one
    per first point shared by two or more basins, or ``None`` when there is
    no such group.
    """
    owners_by_first_point = {}
    for basin_id, seg_pnts in zip(basin_info_df['basin_id'], basin_info_df['seg_pnts']):
        if not (pd.notna(seg_pnts) and isinstance(seg_pnts, str) and seg_pnts.strip()):
            continue
        head = seg_pnts.split(',')[0]
        owners_by_first_point.setdefault(head, []).append(basin_id)

    basins_to_merge = [
        {'first_point': head, 'basin_ids': owners}
        for head, owners in owners_by_first_point.items()
        if len(owners) > 1
    ]

    if not basins_to_merge:
        print("  - No basins to merge")
        return None

    print(f"  - Found {len(basins_to_merge)} groups of basins to merge")
    return basins_to_merge

def merge_basins(basins_no_inlets, basins_to_merge, basins_with_inlets_ids, original_basins):
    """Dissolve each merge group into one basin and re-attach the inlet basins.

    For every group in ``basins_to_merge`` the member basins (looked up in
    ``basins_no_inlets``) are unioned into a single geometry that takes the
    smallest member id. The replacement row carries only ``basin_id`` and
    ``geometry``. Basins listed in ``basins_with_inlets_ids`` are taken
    unchanged from ``original_basins`` and appended at the end.
    """
    working = basins_no_inlets.copy()

    for group in basins_to_merge:
        member_ids = group['basin_ids']
        merged_id = min(member_ids)

        is_member = basins_no_inlets['basin_id'].isin(member_ids)
        dissolved = unary_union(basins_no_inlets[is_member]['geometry'])

        # Replace the members with the single dissolved basin.
        working = working[~working['basin_id'].isin(member_ids)]
        replacement = gpd.GeoDataFrame(
            [{'basin_id': merged_id, 'geometry': dissolved}],
            crs=basins_no_inlets.crs,
        )
        working = pd.concat([working, replacement], ignore_index=True)

    has_inlet = original_basins['basin_id'].isin(basins_with_inlets_ids)
    return pd.concat([working, original_basins[has_inlet]], ignore_index=True)

def iterative_merge_basins(basin_info_df_func, basins, basins_with_inlets_ids, max_iter=10):
    """Repeatedly merge basins that share a first segment point.

    On each pass ``basin_info_df_func(basins)`` supplies the basin info table,
    merge groups are derived from it, and the non-inlet basins are merged.
    The loop ends when no groups remain, when a pass leaves the basin count
    unchanged, or after ``max_iter`` passes. Returns the resulting basins.
    """
    iteration = 0
    while iteration < max_iter:
        groups = merge_basins_with_redundant_first_points(basin_info_df_func(basins))
        count_before = len(basins)
        if not groups:
            break

        without_inlets = basins[~basins['basin_id'].isin(basins_with_inlets_ids)]
        basins = merge_basins(without_inlets, groups, basins_with_inlets_ids, basins)
        count_after = len(basins)
        print(f"Merge iteration {iteration}: {count_before} -> {count_after} basins")

        # No reduction means further passes cannot make progress.
        if count_after == count_before:
            break
        iteration += 1
    return basins

def get_basin_info_df(basins):
    """Return a two-column table (``basin_id``, ``seg_pnts``) derived from ``basins``.

    ``seg_pnts`` is read row by row and defaults to ``''`` for rows that have
    no such field.
    """
    def _seg_pnts_or_blank(record):
        return record.get('seg_pnts', '')

    basin_ids = basins['basin_id']
    seg_pnts = basins.apply(_seg_pnts_or_blank, axis=1)
    return pd.DataFrame({'basin_id': basin_ids, 'seg_pnts': seg_pnts})

def update_sinks_after_merging(sinks_gdf, basins_updated_gdf, dem_data, dem_transform,
                                output_sinks_shapefile, output_sinks_processing_shapefile,
                                lowest_basin_points_gdf):
    """Pick one sink per merged basin and export it with depth attributes.

    Sinks are spatially joined to the basins (``within``). For each basin the
    joined sink with the lowest DEM elevation is kept; basins without any
    sink get a synthetic one at the basin centroid. Every output record
    carries ``elev``, ``deep_elev`` (NaN when the elevation is exactly 0) and
    ``dZ`` (lowest basin point elevation minus sink elevation, NaN when
    either is missing).

    Parameters
    ----------
    sinks_gdf : GeoDataFrame
        Candidate sink points; reprojected to the basin CRS.
    basins_updated_gdf : GeoDataFrame
        Basins with ``basin_id`` and ``geometry``.
    dem_data, dem_transform
        DEM array and its affine transform, forwarded to
        ``get_elevation_at_point``.
    output_sinks_shapefile : str
        Destination of the final one-sink-per-basin layer.
    output_sinks_processing_shapefile : str
        Destination of the intermediate joined-sinks layer.
    lowest_basin_points_gdf : GeoDataFrame
        Must provide ``basin_id`` and ``elev`` columns.

    Returns
    -------
    None. Results are written to the two output files.
    """
    print("Updating sinks after merging basins")
    basins_updated_gdf = ensure_numeric_basin_id(basins_updated_gdf, old_col='basin_id', new_col='basin_id')
    sinks_gdf = sinks_gdf.to_crs(basins_updated_gdf.crs)
    sinks_with_basins = gpd.sjoin(sinks_gdf[['geometry']], basins_updated_gdf[['basin_id', 'geometry']], how='left', predicate='within')
    # Sinks outside every basin are marked with basin_id -1.
    sinks_with_basins['basin_id'] = sinks_with_basins['basin_id'].fillna(-1).astype(int)
    sinks_with_basins['elev'] = sinks_with_basins['geometry'].apply(lambda point: get_elevation_at_point(point, dem_data, dem_transform))
    sinks_with_basins.to_file(output_sinks_processing_shapefile)
    print(f"  - Intermediate sinks exported to {output_sinks_processing_shapefile}")

    valid_sinks = sinks_with_basins[sinks_with_basins['basin_id'] != -1]
    lowest_elevations = lowest_basin_points_gdf.set_index('basin_id')['elev'].to_dict()
    grouped = valid_sinks.groupby('basin_id')

    final_sinks = []
    # Counted per basin: subtracting the number of joined sink points (several
    # may fall in one basin) does not give the number of synthetic sinks.
    synthetic_count = 0
    for basin_id, basin_geom in zip(basins_updated_gdf['basin_id'], basins_updated_gdf['geometry']):
        if basin_id in grouped.groups:
            basin_sinks = grouped.get_group(basin_id)
            lowest_sink = basin_sinks.loc[basin_sinks['elev'].idxmin()]
            elev = lowest_sink['elev']
            geometry = lowest_sink['geometry']
        else:
            centroid = basin_geom.centroid
            elev = get_elevation_at_point(centroid, dem_data, dem_transform)
            geometry = centroid
            synthetic_count += 1
            print(f"  - Created synthetic sink for basin {basin_id}")

        deep_elev = elev
        if elev == 0:
            print(f"Warning: Basin {basin_id} has a sink with elevation 0. Setting 'deep_elev' to NaN.")
            deep_elev = np.nan
        elev = np.round(elev, 5) if not np.isnan(elev) else np.nan
        deep_elev = np.round(deep_elev, 5) if not np.isnan(deep_elev) else np.nan

        if basin_id in lowest_elevations:
            lowest_elev = lowest_elevations[basin_id]
        else:
            lowest_elev = np.nan
            print(f"Warning: Basin {basin_id} does not have a lowest basin point elevation.")

        if not np.isnan(lowest_elev) and not np.isnan(elev):
            dZ = np.round(lowest_elev - elev, 5)
        else:
            dZ = np.nan

        final_sinks.append({
            'basin_id': basin_id,
            'elev': elev,
            'deep_elev': deep_elev,
            'dZ': dZ,
            'geometry': geometry
        })

    final_sinks_gdf = gpd.GeoDataFrame(final_sinks, crs=basins_updated_gdf.crs)
    for col in ['elev', 'deep_elev', 'dZ']:
        final_sinks_gdf[col] = final_sinks_gdf[col].astype('float64')
    final_sinks_gdf.to_file(output_sinks_shapefile)
    print(f"  - Final sinks with 'elev', 'deep_elev', and 'dZ' exported to {output_sinks_shapefile}")
    print(f"  - Total sinks: {len(final_sinks_gdf)}, Synthetic sinks: {synthetic_count}\n")

def process_downspouts(downspouts_shapefile, all_basins, hole_basins, dem_data, dem_transform, output_snapped_downspouts_shapefile):
    """Snap every downspout to the nearest point on the hole-basin boundaries.

    Parameters
    ----------
    downspouts_shapefile : str
        Path of the downspout point layer.
    all_basins : GeoDataFrame
        Only its CRS is used: it is the CRS of the output layer.
    hole_basins : GeoDataFrame
        Hole polygons whose boundaries are the snapping target.
        NOTE(review): assumed to be in the same CRS as ``all_basins``.
    dem_data, dem_transform
        DEM array and affine transform used to read the snapped elevation.
    output_snapped_downspouts_shapefile : str
        Destination of the snapped layer.

    Returns
    -------
    GeoDataFrame with one row per downspout: snapped ``geometry``,
    ``elevation``, ``is_downsp`` (always True) and ``basin_id`` (initialised
    to -1).
    """
    print("Processing and snapping downspouts to hole basin boundaries.")
    downspouts = gpd.read_file(downspouts_shapefile)

    # The result is labelled with the basin CRS, so the coordinates must be in
    # that CRS before snapping (the sinks are reprojected the same way).
    target_crs = all_basins.crs
    if downspouts.crs is not None and target_crs is not None and downspouts.crs != target_crs:
        downspouts = downspouts.to_crs(target_crs)

    hole_boundaries = unary_union(list(hole_basins.geometry.boundary))

    snapped_downspouts = []
    for _, downspout in downspouts.iterrows():
        # nearest_points returns (point on first geom, point on second geom).
        nearest_point = nearest_points(downspout.geometry, hole_boundaries)[1]
        elevation = get_elevation_at_point(nearest_point, dem_data, dem_transform)
        snapped_downspouts.append({
            'geometry': Point(nearest_point.x, nearest_point.y),
            'elevation': elevation,
            'is_downsp': True,
            'basin_id': -1
        })

    snapped_downspouts_gdf = gpd.GeoDataFrame(snapped_downspouts, crs=target_crs)
    snapped_downspouts_gdf.to_file(output_snapped_downspouts_shapefile)
    print(f"Snapped downspouts saved to '{output_snapped_downspouts_shapefile}'")
    return snapped_downspouts_gdf

def create_mid_segment_points(basins_gdf):
    """Return the midpoint of every boundary ring of every basin.

    A basin whose boundary is a MultiLineString contributes one point per
    member line; any other boundary contributes a single point. Each point
    record stores the owning basin id as a string in ``basin_ids``.
    """
    records = []
    for _, basin in basins_gdf.iterrows():
        outline = basin.geometry.boundary
        rings = outline.geoms if isinstance(outline, MultiLineString) else [outline]
        for ring in rings:
            records.append({
                'geometry': ring.interpolate(0.5, normalized=True),
                'basin_ids': str(basin['basin_id'])
            })
    return gpd.GeoDataFrame(records, crs=basins_gdf.crs)

def _hole_basin_record(basin_id, geometry, is_split):
    """Build the output record for one (possibly split) hole basin."""
    return {
        'basin_id': basin_id,
        'geometry': geometry,
        'is_hole': True,
        'is_split': is_split,
        'shp_length': float(geometry.length),
        'shp_area': float(geometry.area)
    }


def _split_by_voronoi(basin_geometry, basin_downspouts, last_basin_id):
    """Stage a Voronoi split of one basin; returns (records, assignments, last_id, n_cells)."""
    seeds = MultiPoint(list(basin_downspouts.geometry))
    cells = list(voronoi_diagram(seeds, envelope=basin_geometry).geoms)
    records = []
    assignments = {}
    for cell in cells:
        piece = basin_geometry.intersection(cell)
        if piece.is_empty or not piece.area > 0:
            continue
        last_basin_id += 1
        records.append(_hole_basin_record(last_basin_id, piece, True))
        # NOTE(review): ``contains`` is False for points lying exactly on the
        # piece boundary; downspouts snapped onto hole boundaries may
        # therefore stay unassigned here. Confirm whether that is intended.
        for ds_idx, downspout in basin_downspouts.iterrows():
            if piece.contains(downspout.geometry):
                assignments[ds_idx] = last_basin_id
    return records, assignments, last_basin_id, len(cells)


def _split_by_buffers(basin_geometry, basin_downspouts, last_basin_id):
    """Stage the fallback split: one clipped buffer per downspout."""
    buffer_distance = np.sqrt(basin_geometry.area) / (2 * len(basin_downspouts))
    records = []
    assignments = {}
    for ds_idx, downspout in basin_downspouts.iterrows():
        last_basin_id += 1
        piece = basin_geometry.intersection(downspout.geometry.buffer(buffer_distance))
        if not piece.is_empty and piece.area > 0:
            records.append(_hole_basin_record(last_basin_id, piece, True))
            assignments[ds_idx] = last_basin_id
    return records, assignments, last_basin_id


def split_hole_basins_and_update_ids(hole_basins_gdf, snapped_downspouts_gdf, mid_segment_gdf, max_basin_id):
    """Split hole basins among their downspouts and assign new basin ids.

    For each hole basin, depending on how many snapped downspouts intersect it:
    none or one - the basin is kept whole under its original id (with one
    downspout, that downspout is assigned the basin id); two or more - the
    basin is cut by the Voronoi diagram of the downspouts, each non-empty
    piece receiving a fresh id above ``max_basin_id``. If the Voronoi split
    raises, buffers around each downspout are used instead. If anything else
    fails for a basin, it is kept whole.

    A split is staged first and committed only when it completes, so a failed
    attempt leaves no partial pieces or id assignments behind.

    Parameters
    ----------
    hole_basins_gdf : GeoDataFrame
        Hole basins with ``basin_id`` and ``geometry``.
    snapped_downspouts_gdf : GeoDataFrame
        Snapped downspouts; its ``basin_id`` column is updated in place.
    mid_segment_gdf
        Unused; kept for interface compatibility.
    max_basin_id : int
        Highest basin id already in use.

    Returns
    -------
    (split_basins_gdf, snapped_downspouts_gdf, max_basin_id)
    """
    print("Splitting hole basins...")
    split_basins = []
    for idx, hole_basin in hole_basins_gdf.iterrows():
        # Bound before the ``try`` so the error handler below never sees
        # values left over from the previous basin.
        basin_geometry = hole_basin.geometry
        original_basin_id = hole_basin['basin_id']
        try:
            basin_downspouts = snapped_downspouts_gdf[snapped_downspouts_gdf.intersects(basin_geometry)].copy()
            n_points = len(basin_downspouts)
            if n_points == 0:
                print(f"Basin {idx}: No downspouts found")
                split_basins.append(_hole_basin_record(original_basin_id, basin_geometry, False))
            elif n_points == 1:
                print(f"Basin {idx}: Single downspout - assigning whole basin")
                whole = _hole_basin_record(original_basin_id, basin_geometry, False)
                snapped_downspouts_gdf.at[basin_downspouts.index[0], 'basin_id'] = original_basin_id
                split_basins.append(whole)
            else:
                try:
                    # The diagram is computed once and reused for both the
                    # split and the summary message.
                    pieces, assignments, last_id, n_cells = _split_by_voronoi(
                        basin_geometry, basin_downspouts, max_basin_id
                    )
                    print(f"Basin {idx}: Successfully split into {n_cells} parts")
                except Exception:
                    print(f"Warning: Voronoi failed for basin {idx}, using alternative splitting")
                    pieces, assignments, last_id = _split_by_buffers(
                        basin_geometry, basin_downspouts, max_basin_id
                    )
                # Commit the staged result.
                split_basins.extend(pieces)
                for ds_idx, new_basin_id in assignments.items():
                    snapped_downspouts_gdf.at[ds_idx, 'basin_id'] = new_basin_id
                max_basin_id = last_id
        except Exception as e:
            print(f"Error processing basin {idx}: {str(e)}")
            split_basins.append(_hole_basin_record(original_basin_id, basin_geometry, False))
    split_basins_gdf = gpd.GeoDataFrame(split_basins, crs=hole_basins_gdf.crs)
    print(f"Created {len(split_basins_gdf)} split basins")
    return split_basins_gdf, snapped_downspouts_gdf, max_basin_id

def assign_basin_ids_to_downspouts(snapped_downspouts, all_basins, buffer_distance=0.01):
    """Set ``basin_id`` on each downspout from the basins near it.

    Each downspout is buffered by ``buffer_distance`` and matched against
    ``all_basins``. A basin flagged ``is_hole`` is preferred; otherwise the
    first intersecting basin is used. Downspouts that touch no basin keep
    their current id and a warning is printed. ``snapped_downspouts`` is
    updated in place and returned.
    """
    for idx, downspout in snapped_downspouts.iterrows():
        search_area = downspout.geometry.buffer(buffer_distance)
        candidates = all_basins[all_basins.intersects(search_area)]
        if candidates.empty:
            print(f"Warning: No basin found for downspout at index {idx}")
            continue

        hole_candidates = candidates[candidates['is_hole'] == True]
        chosen = candidates if hole_candidates.empty else hole_candidates
        snapped_downspouts.at[idx, 'basin_id'] = chosen.iloc[0]['basin_id']
    return snapped_downspouts

def process_basin_info(basins, lowest_segment_points, snapped_downspouts, basins_with_inlets_ids, iteration, save_outputs):
    """Build the per-basin table of outlet point ids and first-point elevation.

    For each basin:
    - basins listed in ``basins_with_inlets_ids`` get an empty ``seg_pnts``
      and NaN ``pnt1_ele``;
    - basins with snapped downspouts get ``seg_pnts`` = ``"D0,D1,..."`` and
      the elevation of their first downspout;
    - otherwise the basin's lowest segment points are listed by ascending
      elevation, with the lowest elevation as ``pnt1_ele``.

    NOTE(review): downspout ids restart at ``D0`` for every basin, so any two
    basins with downspouts share the first point ``D0``. Confirm that later
    grouping by first point is meant to treat them as equal.

    Parameters
    ----------
    basins : DataFrame
        Must contain ``basin_id``.
    lowest_segment_points : DataFrame
        Must contain ``bas_ids`` (comma-separated basin ids), ``elev`` and
        ``pnt_id`` unless empty.
    snapped_downspouts : DataFrame or None
        Must contain ``basin_id`` and ``elevation`` when given.
    basins_with_inlets_ids : container
        Ids of basins that contain inlets.
    iteration : int
        Used in the CSV file name.
    save_outputs : bool
        When true, the table is also written to CSV.

    Returns
    -------
    DataFrame with columns ``basin_id``, ``seg_pnts``, ``pnt1_ele``.
    """
    basin_points_dict = {}
    if not lowest_segment_points.empty:
        lsp = lowest_segment_points.copy()
        lsp['bas_ids_list'] = lsp['bas_ids'].str.split(',')
        lsp = lsp.explode('bas_ids_list')
        lsp['bas_ids_list'] = lsp['bas_ids_list'].str.strip()
        # Convert ids to numbers when every value parses, otherwise keep the
        # strings. This is the behaviour of the deprecated
        # ``errors='ignore'`` option, spelled out explicitly.
        try:
            lsp['bas_ids_list'] = pd.to_numeric(lsp['bas_ids_list'])
        except (ValueError, TypeError):
            pass
        basin_points_dict = {
            basin_id: group for basin_id, group in lsp.groupby('bas_ids_list')
        }

    snapped_downspouts_dict = {}
    if snapped_downspouts is not None and not snapped_downspouts.empty:
        snapped_downspouts_dict = {
            basin_id: group for basin_id, group in snapped_downspouts.groupby('basin_id')
        }

    basin_info_list = []
    for _, basin in basins.iterrows():
        basin_id = basin['basin_id']
        seg_pnts = ''
        pnt1_ele = np.nan

        if basin_id in basins_with_inlets_ids:
            pass  # inlet basins keep the empty defaults
        elif basin_id in snapped_downspouts_dict:
            basin_downspouts = snapped_downspouts_dict[basin_id]
            seg_pnts = ','.join('D' + str(i) for i in range(len(basin_downspouts)))
            pnt1_ele = basin_downspouts.iloc[0]['elevation']
        elif basin_id in basin_points_dict:
            matching_points_sorted = basin_points_dict[basin_id].sort_values('elev', ascending=True)
            seg_pnts = ','.join(matching_points_sorted['pnt_id'].tolist())
            pnt1_ele = matching_points_sorted.iloc[0]['elev']

        basin_info_list.append({
            'basin_id': basin_id,
            'seg_pnts': seg_pnts,
            'pnt1_ele': pnt1_ele
        })

    basin_info_df = pd.DataFrame(basin_info_list)

    if save_outputs:
        basin_info_df.to_csv(f"{BASIN_SEG_PNTS_OUTPUT_PREFIX}_{iteration}.csv", index=False)

    return basin_info_df

def create_final_lowest_points(lowest_segment_points, snapped_downspouts, basins, basins_with_inlets_ids, iteration):
    """Collect one outlet point per basin plus all downspout points, and export them.

    Segment points: for every basin not in ``basins_with_inlets_ids`` the
    lowest of its segment points is kept; ``to_basin`` is the first other
    basin intersecting a 0.01 buffer around that point, or -1.

    Downspout points: each snapped downspout is attributed to the first hole
    basin intersecting its 0.01 buffer (skipped with a warning when none
    does); ``to_basin`` is the first non-hole basin intersecting the same
    buffer, or -1.

    The result is written to ``{LOWEST_BASIN_POINTS_OUTPUT_PREFIX}_{iteration}.shp``
    and returned as a GeoDataFrame with columns ``basin_id``, ``geometry``,
    ``elevation``, ``type`` (``'segment'`` or ``'downspout'``) and ``to_basin``.
    """
    print("Creating final lowest points")
    lowest_points_list = []

    # One pass over the points builds a basin-id -> row-positions lookup, so
    # the per-basin search below does not rescan every point.
    positions_by_basin = {}
    for pos, ids in enumerate(lowest_segment_points.get('bas_ids', [])):
        for token in set(str(ids).split(',')):
            positions_by_basin.setdefault(token, []).append(pos)

    print("Processing segment points...")
    for _, basin in basins.iterrows():
        basin_id = basin['basin_id']
        if basin_id in basins_with_inlets_ids:
            continue
        positions = positions_by_basin.get(str(basin_id))
        if not positions:
            continue
        basin_points = lowest_segment_points.iloc[positions]
        lowest_point = basin_points.loc[basin_points['elev'].idxmin()]
        point_buffer = lowest_point.geometry.buffer(0.01)
        intersecting_basins = basins[basins.intersects(point_buffer)]
        other_basins = intersecting_basins[intersecting_basins['basin_id'] != basin_id]
        to_basin = other_basins.iloc[0]['basin_id'] if not other_basins.empty else -1
        lowest_points_list.append({
            'basin_id': basin_id,
            'geometry': lowest_point.geometry,
            'elevation': lowest_point['elev'],
            'type': 'segment',
            'to_basin': to_basin
        })

    if snapped_downspouts is not None and not snapped_downspouts.empty:
        print("Processing downspout points...")
        # A single strictly-boolean mask, built once: missing flags count as
        # "not a hole", and inverting it never operates on an object-dtype
        # column. A frame without the column is treated as having no holes.
        if 'is_hole' in basins.columns:
            is_hole_mask = basins['is_hole'] == True
        else:
            is_hole_mask = pd.Series(False, index=basins.index)
        not_hole_mask = ~is_hole_mask

        buffer_distance = 0.01
        for _, downspout in snapped_downspouts.iterrows():
            downspout_buffer = downspout.geometry.buffer(buffer_distance)
            touches_buffer = basins.intersects(downspout_buffer)
            containing_basins = basins[touches_buffer & is_hole_mask]
            if containing_basins.empty:
                print(f"Warning: No split hole basin found for downspout at {downspout.geometry}")
                continue
            from_basin_id = containing_basins.iloc[0]['basin_id']
            draining_basins = basins[touches_buffer & not_hole_mask]
            to_basin = draining_basins.iloc[0]['basin_id'] if not draining_basins.empty else -1
            lowest_points_list.append({
                'basin_id': from_basin_id,
                'geometry': downspout.geometry,
                'elevation': downspout['elevation'],
                'type': 'downspout',
                'to_basin': to_basin
            })

    lowest_points_gdf = gpd.GeoDataFrame(lowest_points_list, crs=basins.crs)
    print("\nVerifying basin assignments...")
    print(f"Total points: {len(lowest_points_gdf)}")
    print(f"Points with valid basin_id: {len(lowest_points_gdf[lowest_points_gdf['basin_id'] != -1])}")
    print(f"Points with valid to_basin: {len(lowest_points_gdf[lowest_points_gdf['to_basin'] != -1])}")
    output_file = f"{LOWEST_BASIN_POINTS_OUTPUT_PREFIX}_{iteration}.shp"
    lowest_points_gdf.to_file(output_file)
    print(f"\nResults saved to: {output_file}")
    print(f"  - Segments: {len(lowest_points_gdf[lowest_points_gdf['type'] == 'segment'])}")
    print(f"  - Downspouts: {len(lowest_points_gdf[lowest_points_gdf['type'] == 'downspout'])}")
    downspouts = lowest_points_gdf[lowest_points_gdf['type'] == 'downspout']
    print("\nDownspout Statistics:")
    print(f"  - Total downspouts: {len(downspouts)}")
    print(f"  - Downspouts with valid basin_id: {len(downspouts[downspouts['basin_id'] != -1])}")
    print(f"  - Downspouts with valid to_basin: {len(downspouts[downspouts['to_basin'] != -1])}")
    return lowest_points_gdf

def process_final_sinks(sinks_shapefile, basins, dem_data, dem_transform, lowest_points_gdf, snapped_downspouts, study_area_shapefile):
    """Choose one sink point per basin and write the result to ``UPDATED_SINKS_OUTPUT``.

    For every basin the sink is selected in this priority order:

    1. basins flagged both ``is_hole`` and ``is_split``: the basin centroid
       (type ``'synthetic'``);
    2. basins containing at least one input sink: the contained sink with the
       lowest sampled elevation (type ``'existing'``);
    3. basins with a row in ``snapped_downspouts``: the first such downspout
       (type ``'downspout'``);
    4. otherwise: the basin centroid (type ``'synthetic'``).

    Args:
        sinks_shapefile: Path of the input sink layer (read with geopandas).
        basins: GeoDataFrame of basins with a ``basin_id`` column; optional
            ``is_hole`` / ``is_split`` flag columns are added as ``False``
            when missing.
        dem_data: DEM array passed through to ``get_elevation_at_point``.
        dem_transform: DEM transform passed through to ``get_elevation_at_point``.
        lowest_points_gdf: Frame with ``basin_id`` and ``elevation`` columns,
            used to compute ``dZ``.
        snapped_downspouts: Frame with ``basin_id``, ``elevation`` and geometry.
        study_area_shapefile: Path of the study-area layer used to clip the result.

    Returns:
        GeoDataFrame with columns ``basin_id``, ``elev``, ``deep_elev``, ``dZ``,
        ``type``, ``basin_area``, ``basin_perimeter`` and ``geometry``.

    Side effects:
        Writes ``SINKS_PROCESSING_OUTPUT`` (intermediate) and
        ``UPDATED_SINKS_OUTPUT`` (final) and prints a summary.
    """
    print("Processing final sinks")
    sinks = gpd.read_file(sinks_shapefile)
    study_area = gpd.read_file(study_area_shapefile)
    # Bring both layers into the basin CRS so the spatial join and clip line up.
    sinks = sinks.to_crs(basins.crs)
    study_area = study_area.to_crs(basins.crs)
    # NOTE(review): ensure_numeric_basin_id is defined elsewhere; whether it
    # returns a copy or the caller's frame is not visible here, so the column
    # assignments below may also modify the caller's `basins` - confirm.
    basins = ensure_numeric_basin_id(basins, old_col='basin_id', new_col='basin_id')
    basins['basin_area'] = basins.geometry.area
    basins['basin_perimeter'] = basins.geometry.length
    # Left join keeps sinks that fall in no basin; those get basin_id -1 below.
    sinks_with_basins = gpd.sjoin(sinks[['geometry']], basins[['basin_id', 'geometry']], how='left', predicate='within')
    sinks_with_basins['basin_id'] = sinks_with_basins['basin_id'].fillna(-1).astype(int)
    sinks_with_basins['elev'] = sinks_with_basins['geometry'].apply(lambda point: get_elevation_at_point(point, dem_data, dem_transform))
    sinks_with_basins['type'] = 'existing'
    sinks_with_basins.to_file(SINKS_PROCESSING_OUTPUT)
    print(f"  - Intermediate sinks exported to {SINKS_PROCESSING_OUTPUT}")
    valid_sinks = sinks_with_basins[sinks_with_basins['basin_id'] != -1]
    # basin_id -> spill elevation. If a basin has several rows in
    # lowest_points_gdf, the dict keeps only the last one encountered.
    lowest_elevations = lowest_points_gdf.set_index('basin_id')['elevation'].to_dict()
    grouped = valid_sinks.groupby('basin_id')
    final_sinks = []

    if 'is_hole' not in basins.columns:
        basins['is_hole'] = False
    if 'is_split' not in basins.columns:
        basins['is_split'] = False

    for _, basin_row in basins.iterrows():
        basin_id = basin_row['basin_id']
        basin_geom = basin_row['geometry']
        # NOTE(review): a NaN flag is truthy in Python, so rows where these
        # columns are missing values would take the hole/split branch - confirm
        # the flags are always filled before this function is called.
        is_hole = basin_row['is_hole']
        is_split = basin_row['is_split']

        elev, geometry, sink_type = np.nan, None, 'unassigned'

        if is_hole and is_split:
            # Split hole basins always get a centroid sink, even if a real
            # sink lies inside them.
            geometry = basin_geom.centroid
            elev = get_elevation_at_point(geometry, dem_data, dem_transform)
            sink_type = 'synthetic'
        elif basin_id in grouped.groups:
            basin_sinks = grouped.get_group(basin_id)
            # NOTE(review): idxmin assumes at least one non-NaN elevation and
            # unique index labels in the joined frame - confirm.
            lowest_sink = basin_sinks.loc[basin_sinks['elev'].idxmin()]
            elev = lowest_sink['elev']
            geometry = lowest_sink['geometry']
            sink_type = 'existing'
        else:
            basin_downspouts = snapped_downspouts[snapped_downspouts['basin_id'] == basin_id]
            if not basin_downspouts.empty:
                # Only the first matching downspout is used.
                downspout = basin_downspouts.iloc[0]
                geometry = downspout.geometry
                elev = downspout['elevation']
                sink_type = 'downspout'
            else:
                geometry = basin_geom.centroid
                elev = get_elevation_at_point(geometry, dem_data, dem_transform)
                sink_type = 'synthetic'

        if geometry is None or geometry.is_empty:
            print(f"Warning: Could not create a valid sink geometry for basin {basin_id}. Skipping.")
            continue

        # An elevation of exactly 0 is treated like a missing value for deep_elev.
        deep_elev = elev if not np.isnan(elev) and elev != 0 else np.nan
        lowest_elev = lowest_elevations.get(basin_id, np.nan)
        # dZ = spill elevation minus sink elevation, rounded to 5 decimals.
        dZ = np.round(lowest_elev - elev, 5) if not np.isnan(lowest_elev) and not np.isnan(elev) else np.nan
        final_sinks.append({
            'basin_id': basin_id,
            'elev': np.round(elev, 5) if not np.isnan(elev) else np.nan,
            'deep_elev': np.round(deep_elev, 5),
            'dZ': dZ,
            'type': sink_type,
            'basin_area': basin_geom.area,
            'basin_perimeter': basin_geom.length,
            'geometry': geometry
        })
    # NOTE(review): if no sink was produced, `final_sinks` is an empty list and
    # the frame has no geometry column to attach the CRS to - confirm this
    # cannot happen for real inputs.
    final_sinks_gdf = gpd.GeoDataFrame(final_sinks, crs=basins.crs)
    numeric_cols = ['elev', 'deep_elev', 'dZ', 'basin_area', 'basin_perimeter']
    for col in numeric_cols:
        final_sinks_gdf[col] = final_sinks_gdf[col].astype('float64')
    if not study_area.empty:
        # Drop sinks that fall outside the study area.
        final_sinks_gdf = gpd.clip(final_sinks_gdf, study_area)
    final_sinks_gdf.to_file(UPDATED_SINKS_OUTPUT)
    print(f"  - Final sinks saved to '{UPDATED_SINKS_OUTPUT}'")
    print(f"  - Total sinks: {len(final_sinks_gdf)}")
    print(f"    - Existing: {len(final_sinks_gdf[final_sinks_gdf['type'] == 'existing'])}")
    print(f"    - Downspout: {len(final_sinks_gdf[final_sinks_gdf['type'] == 'downspout'])}")
    print(f"    - Synthetic: {len(final_sinks_gdf[final_sinks_gdf['type'] == 'synthetic'])}")
    return final_sinks_gdf

def validate_geometries(gdf, min_area=0.0001):
    """Repair geometries and keep only polygonal features of sufficient area.

    Each geometry is passed through ``make_valid``. A row is kept when the
    repaired geometry is a ``Polygon`` or ``MultiPolygon`` whose area is at
    least ``min_area``; empty, missing, non-polygonal, too-small or failing
    geometries are reported on stdout and dropped.

    Args:
        gdf: GeoDataFrame whose ``geometry`` column is validated. It is not
            modified.
        min_area: Minimum area (in the units of the frame's CRS) a repaired
            geometry must have to be kept.

    Returns:
        A copy of the surviving rows with repaired geometries, or an empty
        GeoDataFrame carrying only the CRS when nothing survives.
    """
    print("Validating geometries...")
    initial_count = len(gdf)
    kept_positions = []
    repaired_geometries = []
    # Rows are tracked by position rather than by index label so that frames
    # with repeated index labels still line up one-to-one with the repaired
    # geometries when they are written back.
    for position, (idx, row) in enumerate(gdf.iterrows()):
        try:
            geom = row.geometry
            if geom is None or geom.is_empty:
                print(f"Skipping empty geometry at index {idx}")
                continue
            valid_geom = make_valid(geom)
            if isinstance(valid_geom, (Polygon, MultiPolygon)):
                if valid_geom.area >= min_area:
                    repaired_geometries.append(valid_geom)
                    kept_positions.append(position)
                else:
                    print(f"Skipping too small geometry at index {idx}")
            else:
                print(f"Skipping non-polygon geometry at index {idx}")
        except Exception as e:
            # Best effort: one unrepairable feature must not abort the run.
            print(f"Error processing geometry at index {idx}: {e}")
            continue

    if not repaired_geometries:
        print("No valid geometries found!")
        return gpd.GeoDataFrame(geometry=[], crs=gdf.crs)

    valid_gdf = gdf.iloc[kept_positions].copy()
    # The geometries are already the output of make_valid, so no second
    # repair pass is needed.
    valid_gdf['geometry'] = repaired_geometries
    print(f"Validated geometries: {len(valid_gdf)} out of {initial_count}")
    return valid_gdf

class CheckpointManager:
    """Save and restore per-iteration state as pickle files in one directory.

    Checkpoint files are named ``checkpoint_<iteration>_<YYYYmmdd_HHMMSS>.pkl``.
    Files in the directory that do not follow this pattern are ignored.

    SECURITY: checkpoints are read with ``pickle.load``, which can execute
    arbitrary code. Only point this class at directories you trust.
    """

    def __init__(self, checkpoint_dir='checkpoints'):
        """Remember ``checkpoint_dir`` and create it if it does not exist."""
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(checkpoint_dir, exist_ok=True)

    @staticmethod
    def _parse_iteration(filename):
        """Return the iteration encoded in ``filename``, or None if malformed."""
        stem, extension = os.path.splitext(filename)
        if extension != '.pkl':
            return None
        parts = stem.split('_')
        if len(parts) < 2 or parts[0] != 'checkpoint':
            return None
        try:
            return int(parts[1])
        except ValueError:
            return None

    def _checkpoint_files(self):
        """List the well-formed checkpoint file names in the directory."""
        return [
            name for name in os.listdir(self.checkpoint_dir)
            if self._parse_iteration(name) is not None
        ]

    def save_checkpoint(self, iteration, data):
        """Pickle ``data`` as the newest checkpoint of ``iteration``.

        The pickle is written to a temporary file and renamed into place, so
        an interrupted save never leaves a truncated checkpoint behind. Old
        checkpoints of the same iteration are pruned afterwards.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_{iteration}_{timestamp}.pkl')
        temp_path = checkpoint_path + '.tmp'
        try:
            with open(temp_path, 'wb') as f:
                pickle.dump(data, f)
            os.replace(temp_path, checkpoint_path)
        finally:
            # Only present if the dump or the rename failed.
            if os.path.exists(temp_path):
                os.remove(temp_path)
        # Prune after writing so that at most `keep_last` files remain.
        self._cleanup_old_checkpoints(iteration)
        print(f"Saved checkpoint for iteration {iteration}")

    def load_latest_checkpoint(self):
        """Load the most recently created checkpoint.

        Returns:
            ``(data, iteration)`` for the newest checkpoint file, or
            ``(None, -1)`` when the directory holds no checkpoint.
        """
        checkpoints = self._checkpoint_files()
        if not checkpoints:
            return None, -1
        latest = max(checkpoints, key=lambda x: os.path.getctime(os.path.join(self.checkpoint_dir, x)))
        iteration = self._parse_iteration(latest)
        checkpoint_path = os.path.join(self.checkpoint_dir, latest)
        with open(checkpoint_path, 'rb') as f:
            # Trusted input only - see the class docstring.
            data = pickle.load(f)
        print(f"Loaded checkpoint from iteration {iteration}")
        return data, iteration

    def _cleanup_old_checkpoints(self, current_iteration, keep_last=5):
        """Delete all but the ``keep_last`` newest checkpoints of one iteration."""
        prefix = f'checkpoint_{current_iteration}_'
        checkpoints = [
            name for name in self._checkpoint_files()
            if name.startswith(prefix)
        ]
        if len(checkpoints) > keep_last:
            checkpoints.sort(key=lambda x: os.path.getctime(os.path.join(self.checkpoint_dir, x)))
            for checkpoint in checkpoints[:-keep_last]:
                os.remove(os.path.join(self.checkpoint_dir, checkpoint))

def main(max_iterations, force_restart):
    """Run Algorithm 1: iterative basin merging, then downspout snapping and sink selection.

    The function loops over process/merge iterations until no basins remain
    to merge (or the iteration cap is hit), saving a checkpoint each
    iteration. It then builds hole basins from the interior rings of the
    merged basins, snaps downspouts, splits hole basins, derives the final
    lowest points and sinks, and writes the final layers.

    Args:
        max_iterations: Upper bound on the iteration counter; a negative
            value disables the cap.
        force_restart: When True, saved checkpoints are ignored and the run
            starts from ``BASIN_SHAPEFILE_INITIAL``.

    Raises:
        Exception: Any error raised during processing is printed together
            with a traceback and then re-raised.
    """
    checkpoint_manager = CheckpointManager()
    if not force_restart:
        checkpoint_data, start_iteration = checkpoint_manager.load_latest_checkpoint()
    else:
        checkpoint_data = None
        start_iteration = -1
        print("Forced restart - ignoring existing checkpoints")

    try:
        if checkpoint_data is not None:
            print(f"Resuming from iteration {start_iteration}")
            basins = checkpoint_data['basins']
            holes_gdf = checkpoint_data['holes_gdf']
            basins_with_inlets_ids = checkpoint_data['basins_with_inlets_ids']
            # NOTE(review): the checkpoint is written before that iteration's
            # merge step (see the loop below), so the restored basins are
            # un-merged while the counter advances by one - confirm this
            # numbering is intended.
            iteration = start_iteration + 1
            merging_needed = True
            last_iteration = start_iteration
        else:
            print("Starting fresh run")
            iteration = 0
            merging_needed = True
            holes_gdf = None
            try:
                inlets_gdf = gpd.read_file(INLETS_SHAPEFILE)
                print("\n==== Inlets GeoDataFrame Info ====")
                print("Columns:", list(inlets_gdf.columns))
                print("CRS:", inlets_gdf.crs)
                print("Number of features:", len(inlets_gdf))
                if not inlets_gdf.empty:
                    inlets_gdf = ensure_numeric_basin_id(inlets_gdf)
            except Exception as e:
                # Inlets are optional: on any load error continue with none.
                print(f"Warning: Error loading inlets: {e}")
                inlets_gdf = gpd.GeoDataFrame(geometry=[], crs=None)
            basins = gpd.read_file(BASIN_SHAPEFILE_INITIAL)
            print("\n==== Basins GeoDataFrame Info ====")
            print("Columns:", list(basins.columns))
            print("CRS:", basins.crs)
            print("Number of features:", len(basins))
            if basins.empty:
                raise ValueError("No data found in basins file")
            basins = ensure_numeric_basin_id(basins)
            basins_with_inlets_ids = []
            if not inlets_gdf.empty:
                if inlets_gdf.crs != basins.crs:
                    inlets_gdf = inlets_gdf.to_crs(basins.crs)
                # Try the predicates in order and keep the first one that
                # yields any basin/inlet match.
                for predicate in ['contains', 'intersects', 'within']:
                    basins_with_inlets = gpd.sjoin(basins, inlets_gdf, how='inner', predicate=predicate)
                    if not basins_with_inlets.empty:
                        basins_with_inlets_ids = basins_with_inlets['basin_id'].unique().tolist()
                        print(f"\nFound {len(basins_with_inlets_ids)} basins containing inlets")
                        break

        while merging_needed:
            print(f"\n=== Iteration {iteration} ===")
            if max_iterations >= 0 and iteration >= max_iterations:
                # NOTE(review): if this fires on the first pass of the loop
                # (e.g. max_iterations == 0, or a resume at the cap),
                # `lowest_segment_points` is never assigned and the
                # create_final_lowest_points call below fails on the unbound
                # name.
                print(f"Reached maximum iterations ({max_iterations})")
                merging_needed = False
                last_iteration = iteration
                break
            # Intermediate layers are only written for the very first iteration.
            save_outputs = (iteration == 0)
            boundary_segments, filtered_points_gdf, midpoint_buffers_gdf, holes_gdf, basins = process_basins_and_segments(
                basins, DEM_PATH, iteration, holes_gdf, save_outputs
            )
            # Snapshot taken after segment processing but before merging.
            checkpoint_data = {
                'basins': basins,
                'holes_gdf': holes_gdf,
                'basins_with_inlets_ids': basins_with_inlets_ids,
                'boundary_segments': boundary_segments,
                'filtered_points_gdf': filtered_points_gdf,
                'iteration': iteration
            }
            checkpoint_manager.save_checkpoint(iteration, checkpoint_data)
            lowest_segment_points = filter_lowest_elevation_points(filtered_points_gdf)
            if save_outputs:
                lowest_segment_points.to_file(f"{LOWEST_SEGMENT_POINTS_OUTPUT_PREFIX}_{iteration}.shp")
            basin_info_df = process_basin_info(
                basins, lowest_segment_points, None,
                basins_with_inlets_ids, iteration, save_outputs
            )
            # Basins that contain an inlet are never merge candidates.
            basin_info_df_no_inlets = basin_info_df[~basin_info_df['basin_id'].isin(basins_with_inlets_ids)]
            basins_no_inlets = basins[~basins['basin_id'].isin(basins_with_inlets_ids)]
            basins_to_merge = merge_basins_with_redundant_first_points(basin_info_df_no_inlets)
            if basins_to_merge:
                merging_needed = True
                basins = merge_basins(basins_no_inlets, basins_to_merge, basins_with_inlets_ids, basins)
                iteration += 1
            else:
                merging_needed = False
                last_iteration = iteration

        print("\n=== Final Iteration (Snapping and Splitting) ===")
        # Interior rings of the dissolved basins become "hole" basins.
        all_polygons = unary_union(basins['geometry'])
        holes = []
        if isinstance(all_polygons, Polygon):
            holes.extend([Polygon(hole) for hole in all_polygons.interiors])
        elif isinstance(all_polygons, MultiPolygon):
            for polygon in all_polygons.geoms:
                holes.extend([Polygon(hole) for hole in polygon.interiors])
        # Hole basins receive fresh ids above the current maximum.
        max_basin_id = basins['basin_id'].max()
        hole_basins_data = []
        for hole in holes:
            max_basin_id += 1
            hole_basins_data.append({
                'basin_id': max_basin_id,
                'geometry': hole,
                'is_hole': True,
                'shp_length': hole.length,
                'shp_area': hole.area
            })
        with rasterio.open(DEM_PATH) as dem:
            dem_data = dem.read(1)
            dem_transform = dem.transform
        # NOTE(review): when no holes exist `hole_basins_data` is an empty
        # list, so the frame has no geometry column to attach the CRS to -
        # confirm the inputs always produce at least one hole.
        hole_basins_gdf = gpd.GeoDataFrame(hole_basins_data, crs=basins.crs)
        all_basins = pd.concat([basins, hole_basins_gdf], ignore_index=True)
        # Rows that came from `basins` have no is_hole value after the concat.
        all_basins['is_hole'] = all_basins['is_hole'].fillna(False)
        snapped_downspouts = process_downspouts(
            DOWNSPOUTS_SHAPEFILE, all_basins, hole_basins_gdf,
            dem_data, dem_transform, TEMP_SNAPPED_DOWNSPOUTS_OUTPUT
        )
        mid_segment_gdf = create_mid_segment_points(all_basins)
        split_hole_basins_gdf, snapped_downspouts, max_basin_id = split_hole_basins_and_update_ids(
            hole_basins_gdf, snapped_downspouts, mid_segment_gdf, max_basin_id
        )
        # Replace the un-split hole basins with their split versions.
        final_basins = all_basins[all_basins['is_hole'] == False]
        final_basins = pd.concat([final_basins, split_hole_basins_gdf], ignore_index=True)
        final_basins = ensure_numeric_basin_id(final_basins, 'basin_id')
        lowest_basin_points_gdf = create_final_lowest_points(
            lowest_segment_points,
            snapped_downspouts,
            final_basins,
            basins_with_inlets_ids,
            last_iteration
        )
        final_sinks_gdf = process_final_sinks(
            SINKS_SHAPEFILE,
            final_basins,
            dem_data,
            dem_transform,
            lowest_basin_points_gdf,
            snapped_downspouts,
            STUDY_AREA_SHAPEFILE
        )
        print("\nSaving final outputs...")
        print("Processing final basins...")
        final_basins = validate_geometries(final_basins)
        if len(final_basins) > 0:
            final_basins.to_file(f"{FINAL_BASINS_OUTPUT_PREFIX}_{last_iteration}.shp")
            print(f"Saved final basins to {FINAL_BASINS_OUTPUT_PREFIX}_{last_iteration}.shp")
        print("\nSaving snapped downspouts...")
        snapped_downspouts.to_file(f"{FINAL_SNAPPED_DOWNSPOUTS_OUTPUT_PREFIX}_{last_iteration}.shp")
        print(f"Saved snapped downspouts to {FINAL_SNAPPED_DOWNSPOUTS_OUTPUT_PREFIX}_{last_iteration}.shp")
        print("\nFinal Statistics:")
        print(f"Total Basins: {len(final_basins)}")
        print(f"Split Hole Basins: {len(split_hole_basins_gdf)}")
        print(f"Snapped Downspouts: {len(snapped_downspouts)}")
        print(f"Final Sinks: {len(final_sinks_gdf)}")
        print(f"Total Iterations: {last_iteration + 1}")

    except Exception as e:
        # Report, show the traceback, and re-raise so the failure is not hidden.
        print(f"Critical error in main execution: {e}")
        print(f"You can resume from iteration {iteration} using the checkpoint")
        import traceback
        traceback.print_exc()
        raise

if __name__ == "__main__":
    # A negative value disables the iteration cap checked inside main().
    max_iterations = -1
    # True makes main() ignore saved checkpoints and start from the initial basins.
    force_restart = True
    main(max_iterations, force_restart)

# Algorithm 2: Inlet Watershed Delineation via Flow Tracing

This section assigns each basin/depression to a **terminal inlet** by tracing surface connectivity. The output is a set of
inlet-associated basins (and optional flow edges for visualization).

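**Inputs:** final basins, lowest basin points, and updated sinks from Algorithm 1 (final iteration), plus the inlet layer. The study area is not read again here; Algorithm 1 already clips the final sinks to it.  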
**Outputs:** merged basin layer per inlet (and optional flow-edge layers).

In [None]:
import logging
import os
import geopandas as gpd
import pandas as pd
from shapely.geometry import LineString
from shapely.ops import linemerge, unary_union

# Important: Set the final iteration number from the previous script here.
# It is interpolated into the per-iteration input file names below, so it must
# match the suffix of the files Algorithm 1 actually wrote.
LAST_ITERATION_NUM = 24

# Input files
# NOTE(review): these names presumably mirror the output prefixes configured
# for Algorithm 1 - verify against that cell's user inputs.
BASINS_SHP = f"final_basins_{LAST_ITERATION_NUM}.shp"
LOWEST_POINTS_SHP = f"lowest_basin_points_{LAST_ITERATION_NUM}.shp"
SINKS_SHP = "updated_sinks.shp"
INLETS_SHP = "inlet.shp"

# Output files
MERGED_BASINS_BY_INLET_SHP = "basins_merged_by_inlet.shp"  # one catchment per terminal inlet
ALL_EDGES_SHP = "flow_paths_all.shp"  # individual sink/spill line segments
MERGED_EDGES_SHP = "flow_paths_merged.shp"  # the same segments merged into one geometry

# Set up basic logging
# NOTE: basicConfig only takes effect if the root logger has no handlers yet,
# so the first cell that calls it in a notebook session decides the format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def build_flow_dict(lowest_points_gdf):
    """Map each basin id to the id of the basin it spills into.

    A missing ``to_basin`` column, a NaN target, or a target of ``-1`` are
    all recorded as ``-1`` (no downstream basin). When a basin appears in
    several rows, the last row wins.
    """
    flow_dict = {}
    for _, record in lowest_points_gdf.iterrows():
        target = record.get('to_basin', -1)
        has_no_target = pd.isna(target) or target == -1
        flow_dict[record['basin_id']] = -1 if has_no_target else int(target)
    return flow_dict


def find_final_inlet_basin(basin_id, flow_dict, basins_with_inlets, memo, visited=None):
    """Follow the spill chain from ``basin_id`` to the basin that holds an inlet.

    The chain is walked iteratively, so arbitrarily long drainage chains do
    not hit Python's recursion limit.

    Args:
        basin_id: Basin to start from.
        flow_dict: Mapping ``basin_id -> downstream basin_id`` where ``-1``
            (or a missing key) means the basin has no downstream basin.
        basins_with_inlets: Container of basin ids that hold an inlet.
        memo: Dict of already resolved basins; it is updated in place with
            every basin visited by this call.
        visited: Optional set of basins already on the current trace; it is
            updated in place. A basin met twice marks a cycle.

    Returns:
        The id of the first inlet basin reached, or ``-1`` when the chain
        ends without an inlet or runs into a cycle.
    """
    if basin_id in memo:
        return memo[basin_id]

    if visited is None:
        visited = set()

    trail = []  # basins walked through whose answer is still pending
    current = basin_id
    while True:
        if current in memo:
            final_inlet = memo[current]
            break
        if current in visited:
            # Cycle: the chain never reaches an inlet.
            final_inlet = -1
            memo[current] = -1
            break
        visited.add(current)
        if current in basins_with_inlets:
            final_inlet = current
            memo[current] = current
            break
        trail.append(current)
        next_basin = flow_dict.get(current, -1)
        if next_basin == -1:
            final_inlet = -1
            break
        current = next_basin

    # Every basin on the trail drains to the same place.
    for upstream in trail:
        memo[upstream] = final_inlet
    return final_inlet


def merge_basins_by_final_inlet(basins_gdf, lowest_points_gdf, inlets_gdf):
    """Dissolve basins that drain to the same inlet basin into one catchment.

    Args:
        basins_gdf: Basin polygons with a ``basin_id`` column. A
            ``final_inlet_basin`` column is added to this frame in place.
        lowest_points_gdf: Spill points with ``basin_id`` and (optionally)
            ``to_basin`` columns, used to build the basin-to-basin flow map.
        inlets_gdf: Inlet points. They are reprojected to the basin CRS when
            the two layers disagree.

    Returns:
        GeoDataFrame with columns ``final_inlet_basin``, ``original_id`` and
        ``geometry``. Basins that reach an inlet are merged per inlet basin
        (``original_id`` is None); basins that reach none are kept one by one
        with ``final_inlet_basin == -1``. An empty GeoDataFrame is returned
        when the inlet join fails.
    """
    logging.info("Starting basin merging process based on final inlet.")

    try:
        # A spatial join across different CRSs silently matches nothing
        # useful, so align the inlets with the basins first.
        if (
            inlets_gdf.crs is not None
            and basins_gdf.crs is not None
            and inlets_gdf.crs != basins_gdf.crs
        ):
            logging.info("Reprojecting inlets to the basin CRS before the spatial join.")
            inlets_gdf = inlets_gdf.to_crs(basins_gdf.crs)
        # Join on geometry only: an inlet attribute that is also called
        # 'basin_id' would otherwise rename the basin column in the result.
        inlet_points = inlets_gdf[[inlets_gdf.geometry.name]]
        basins_with_inlets_gdf = gpd.sjoin(basins_gdf, inlet_points, how="inner", predicate="contains")
        basins_with_inlets = set(basins_with_inlets_gdf['basin_id'].unique())
        logging.info(f"Found {len(basins_with_inlets)} basins that directly contain an inlet.")
    except Exception as e:
        logging.error(f"Could not perform spatial join for inlets. Ensure CRS match. Error: {e}")
        return gpd.GeoDataFrame()

    flow_dict = build_flow_dict(lowest_points_gdf)

    logging.info("Tracing flow paths for all basins...")
    memo = {}  # shared across basins so each chain is traced only once
    basins_gdf['final_inlet_basin'] = basins_gdf['basin_id'].apply(
        lambda bid: find_final_inlet_basin(bid, flow_dict, basins_with_inlets, memo)
    )
    logging.info("Flow tracing complete.")

    logging.info("Merging basin geometries based on final inlet.")
    grouped = basins_gdf.groupby('final_inlet_basin')

    merged_basins = []
    for inlet_id, group in grouped:
        if inlet_id == -1:
            # Basins that never reach an inlet stay separate features.
            for _, basin_row in group.iterrows():
                merged_basins.append({
                    'final_inlet_basin': -1,
                    'original_id': basin_row['basin_id'],
                    'geometry': basin_row['geometry']
                })
        else:
            merged_geometry = unary_union(group['geometry'])
            merged_basins.append({
                'final_inlet_basin': inlet_id,
                'original_id': None,
                'geometry': merged_geometry
            })

    merged_basins_gdf = gpd.GeoDataFrame(merged_basins, crs=basins_gdf.crs)
    logging.info(f"Basin merging complete. Created {len(merged_basins_gdf)} final catchment areas.")

    return merged_basins_gdf


def create_flow_edges(sinks_gdf, lowest_points_gdf, inlets_gdf, basins_gdf):
    """Build straight line segments that visualise sink-to-spill-to-sink flow.

    For every spill point whose basin has a sink, one ``'sink_to_spill'``
    segment runs from that sink to the spill point. When the spill point has
    a valid ``to_basin`` with a sink, a ``'spill_to_sink'`` segment runs from
    the spill point to the downstream sink.

    Args:
        sinks_gdf: Sink points with a ``basin_id`` column; also supplies the
            CRS of the result.
        lowest_points_gdf: Spill points with ``basin_id`` and (optionally)
            ``to_basin`` columns.
        inlets_gdf: Unused; kept for interface compatibility.
        basins_gdf: Unused; kept for interface compatibility.

    Returns:
        GeoDataFrame with columns ``geometry``, ``from_basin``, ``to_basin``
        and ``type``. It is empty, but still has these columns and the CRS,
        when no segment could be built.
    """
    logging.info("Generating visual flow path edges.")
    edges = []
    # basin_id -> sink geometry (the last sink wins if a basin has several).
    sinks_dict = {row['basin_id']: row.geometry for _, row in sinks_gdf.iterrows()}

    for _, point in lowest_points_gdf.iterrows():
        from_basin = point['basin_id']
        to_basin = point.get('to_basin', -1)

        start_geom = sinks_dict.get(from_basin)
        if start_geom is None:
            continue

        edges.append({
            'geometry': LineString([start_geom, point.geometry]),
            'from_basin': from_basin,
            'to_basin': from_basin,
            'type': 'sink_to_spill'
        })

        if to_basin != -1 and pd.notna(to_basin):
            to_basin = int(to_basin)
            end_geom = sinks_dict.get(to_basin)
            if end_geom is not None and not end_geom.is_empty:
                edges.append({
                    'geometry': LineString([point.geometry, end_geom]),
                    'from_basin': from_basin,
                    'to_basin': to_basin,
                    'type': 'spill_to_sink'
                })

    # Naming the columns and the geometry column explicitly keeps the
    # constructor valid when `edges` is empty; a CRS cannot be attached to a
    # frame that has no geometry column.
    edges_gdf = gpd.GeoDataFrame(
        edges,
        columns=['geometry', 'from_basin', 'to_basin', 'type'],
        geometry='geometry',
        crs=sinks_gdf.crs,
    )
    logging.info(f"Generated {len(edges_gdf)} flow path segments.")
    return edges_gdf


def main():
    """Run Algorithm 2: merge basins per terminal inlet and export flow paths.

    Stops early (after logging an error) if any input shapefile is missing.
    Otherwise writes the merged basins and, when flow edges exist, both the
    individual and the merged flow-path layers.
    """
    # Refuse to start unless every input layer is on disk.
    required_inputs = (BASINS_SHP, LOWEST_POINTS_SHP, SINKS_SHP, INLETS_SHP)
    for path in required_inputs:
        if os.path.exists(path):
            continue
        logging.error(f"Input file not found: {path}. Please check the path and iteration number.")
        return

    # Step 1: load the four layers in the order listed above.
    logging.info("Loading input shapefiles...")
    basins_gdf, lowest_points_gdf, sinks_gdf, inlets_gdf = (
        gpd.read_file(path) for path in required_inputs
    )

    # Step 2: primary product - one catchment per terminal inlet.
    merged_basins_gdf = merge_basins_by_final_inlet(basins_gdf, lowest_points_gdf, inlets_gdf)
    has_catchments = not merged_basins_gdf.empty
    if has_catchments:
        merged_basins_gdf.to_file(MERGED_BASINS_BY_INLET_SHP)
        logging.info(f"Successfully saved merged basins to {MERGED_BASINS_BY_INLET_SHP}")

    # Step 3: optional product - line work for visualising flow paths.
    flow_edges_gdf = create_flow_edges(sinks_gdf, lowest_points_gdf, inlets_gdf, basins_gdf)
    has_edges = not flow_edges_gdf.empty
    if has_edges:
        flow_edges_gdf.to_file(ALL_EDGES_SHP)
        logging.info(f"Saved all flow path segments to {ALL_EDGES_SHP}")

        dissolved_edges = unary_union(flow_edges_gdf.geometry)
        merged_lines = linemerge(dissolved_edges)
        merged_edges_gdf = gpd.GeoDataFrame(geometry=[merged_lines], crs=flow_edges_gdf.crs)
        merged_edges_gdf.to_file(MERGED_EDGES_SHP)
        logging.info(f"Saved merged flow paths to {MERGED_EDGES_SHP}")

    logging.info("Processing complete.")


# Run the inlet watershed delineation when this cell/script is executed directly.
if __name__ == "__main__":
    main()

# Watershed characterization plot

This section creates summary plots to inspect watershed characteristics (e.g., area/imperviousness distributions and
basic quality checks). The plots are intended for **quick diagnostics** before running depth mapping and the dynamic
simulation.

In [None]:
import os
import logging
import pandas as pd
import geopandas as gpd
import networkx as nx
from typing import Dict, Any, Optional, Set, List
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.ticker as mticker
import numpy as np
from scipy.spatial import Delaunay
from matplotlib.colors import ListedColormap, BoundaryNorm

# Logging Configuration
# NOTE: basicConfig only takes effect if the root logger has no handlers yet;
# if an earlier cell already configured logging, this format is not applied.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(module)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Configuration: Input/Output File Paths and Parameters
# --- Input Files (Update iteration number as needed) ---
# The number is interpolated into the basin and lowest-point file names below.
ITERATION_NUMBER = 24 # IMPORTANT: Change this to match your input files.

BASINS_SHAPEFILE: str = f"final_basins_{ITERATION_NUMBER}.shp"
LOWEST_POINTS_SHAPEFILE: str = f"lowest_basin_points_{ITERATION_NUMBER}.shp"
SINKS_SHAPEFILE: str = "updated_sinks.shp"

# --- Output File ---
OUTPUT_CSV_FILE: str = "watershed_geomorphology_stats.csv"  # per-watershed statistics table
OUTPUT_3D_PLOT_FILE: str = "watershed_3d_stats.png"  # 3D scatter plot of the statistics

# Analysis Functions

def build_flow_graph(lowest_points_gdf: gpd.GeoDataFrame) -> nx.DiGraph:
    """Construct a directed graph representing basin-to-basin flow.

    Every basin id found in ``basin_id`` or ``to_basin`` becomes a node and
    every row with a valid pair becomes an edge ``basin_id -> to_basin``.
    The value ``-1`` is the "no basin" placeholder used in both columns; it
    is never added as a node or used as an edge endpoint, so it cannot show
    up later as a one-node watershed of its own.

    Args:
        lowest_points_gdf: Frame with ``basin_id`` and ``to_basin`` columns.

    Returns:
        The flow graph; basins without a downstream basin have out-degree 0.
    """
    logging.info("Building flow network graph...")
    G = nx.DiGraph()

    # Add all unique basins as nodes first, leaving out NaN and the -1 placeholder.
    all_basins = pd.concat([
        lowest_points_gdf['basin_id'],
        lowest_points_gdf['to_basin']
    ]).dropna().unique()
    G.add_nodes_from(basin for basin in all_basins if basin != -1)

    # Add directed edges for flow between two real basins.
    for from_basin, to_basin in zip(lowest_points_gdf['basin_id'], lowest_points_gdf['to_basin']):
        if pd.isna(from_basin) or pd.isna(to_basin):
            continue
        if from_basin == -1 or to_basin == -1:
            continue
        G.add_edge(from_basin, to_basin)

    logging.info(f"Graph built with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")
    return G

def get_strahler_order(graph: nx.DiGraph) -> Dict[Any, int]:
    """Calculate the Strahler order of every node in a flow graph.

    Edges point downstream, so a node's predecessors are its upstream
    neighbours. Nodes without predecessors (headwaters) get order 1. Any
    other node takes the highest order among its predecessors, incremented
    by one when that highest order occurs at least twice.

    Nodes are resolved in a single topological sweep: a node is handled as
    soon as all of its predecessors are resolved.

    Args:
        graph: Directed flow graph.

    Returns:
        Mapping ``node -> order``. Nodes that cannot be resolved because they
        sit on, or downstream of, a cycle get ``-1`` and a warning is logged.
    """
    orders = {}
    unresolved_upstream = {}  # node -> number of predecessors not yet ordered
    ready = []  # resolved nodes whose successors still need updating

    for node, in_degree in graph.in_degree():
        unresolved_upstream[node] = in_degree
        if in_degree == 0:
            orders[node] = 1
            ready.append(node)

    while ready:
        node = ready.pop()
        for downstream in graph.successors(node):
            unresolved_upstream[downstream] -= 1
            if unresolved_upstream[downstream] != 0:
                continue
            # All upstream neighbours are known: apply the Strahler rule.
            pred_orders = [orders[p] for p in graph.predecessors(downstream)]
            max_order = max(pred_orders)
            if pred_orders.count(max_order) >= 2:
                orders[downstream] = max_order + 1
            else:
                orders[downstream] = max_order
            ready.append(downstream)

    leftover = [node for node in graph.nodes() if node not in orders]
    if leftover:
        logging.warning("Stalled ordering process. Remaining nodes may be in cycles or disconnected.")
        # Assign a default order of -1 to remaining nodes
        for node in leftover:
            orders[node] = -1

    return orders

def calculate_tree_width(graph: nx.DiGraph, root: Any) -> int:
    """
    Return the largest number of nodes found at any single BFS depth from root.

    The traversal follows ``graph.neighbors``, so for a watershed the graph
    should be reversed (pointing from downstream to upstream). Each node is
    counted once, at the depth where it is first reached. Returns 0 when
    ``root`` is not in the graph.
    """
    if root not in graph:
        return 0

    seen = {root}
    frontier = [root]  # all nodes at the current depth
    widest = 0

    while frontier:
        widest = max(widest, len(frontier))
        next_frontier = []
        for node in frontier:
            for neighbor in graph.neighbors(node):
                if neighbor in seen:
                    continue
                seen.add(neighbor)
                next_frontier.append(neighbor)
        frontier = next_frontier

    return widest


def analyze_watersheds(basins_gdf: gpd.GeoDataFrame, lowest_points_gdf: gpd.GeoDataFrame, sinks_gdf: gpd.GeoDataFrame) -> pd.DataFrame:
    """
    Compute per-watershed geomorphology statistics from the basin flow network.

    A watershed is one weakly connected component of the flow graph built
    from ``lowest_points_gdf``. For each component that has a node without
    outgoing edges, one record is produced with the node count, summed basin
    ``area``, tree depth and width, the Strahler order at the outlet, the
    summed straight-line flow path length and the drainage density
    (path length divided by area, or 0 when the area is not positive).
    Components without such a node are skipped with a warning.
    """
    logging.info("Starting full watershed analysis...")

    # One graph for the whole study area; each component is one watershed.
    flow_graph = build_flow_graph(lowest_points_gdf)
    watershed_subgraphs = [
        flow_graph.subgraph(members)
        for members in nx.weakly_connected_components(flow_graph)
    ]
    logging.info(f"Identified {len(watershed_subgraphs)} distinct watershed networks.")

    # Orders are computed once on the full graph and looked up per outlet.
    strahler_orders = get_strahler_order(flow_graph)

    # Point lookups used for the flow path length (last row wins per basin).
    sink_points = {record['basin_id']: record.geometry for _, record in sinks_gdf.iterrows()}
    spill_points = {record['basin_id']: record.geometry for _, record in lowest_points_gdf.iterrows()}

    records = []

    for index, network in enumerate(watershed_subgraphs):
        if network.number_of_nodes() == 0:
            continue

        # The outlet is the first node found with no outgoing edge.
        outlet = next(
            (node for node in network.nodes() if network.out_degree(node) == 0),
            None,
        )
        if outlet is None:
            logging.warning(f"Could not determine a single outlet for watershed {index}. Skipping.")
            continue

        # Size of the watershed.
        basin_total = network.number_of_nodes()
        member_ids = list(network.nodes())
        is_member = basins_gdf['basin_id'].isin(member_ids)
        area_total = basins_gdf[is_member]['area'].sum()

        # Shape of the drainage tree, traced from the outlet upstream.
        upstream_graph = network.reverse(copy=True)
        depth = nx.dag_longest_path_length(upstream_graph)
        width = calculate_tree_width(upstream_graph, outlet)

        outlet_order = strahler_orders.get(outlet, -1)

        # Straight-line length: sink -> spill point -> downstream sink, per edge.
        path_length = 0.0
        for upstream_id, downstream_id in network.edges():
            own_sink = sink_points.get(upstream_id)
            own_spill = spill_points.get(upstream_id)
            next_sink = sink_points.get(downstream_id)

            if own_sink and own_spill:
                path_length += own_sink.distance(own_spill)
            if own_spill and next_sink:
                path_length += own_spill.distance(next_sink)

        if area_total > 0:
            density = path_length / area_total
        else:
            density = 0

        records.append({
            'outlet_id': outlet,
            'node_count': basin_total,
            'total_area_sqm': area_total,
            'tree_depth': depth,
            'tree_width': width,
            'strahler_order': outlet_order,
            'flow_path_length_m': path_length,
            'drainage_density': density
        })

    return pd.DataFrame(records)

def plot_3d_watershed_stats(df: pd.DataFrame, coeffs: np.ndarray):
    """
    Draw the watershed scatter (node count, tree depth, tree width) with its fitted surface.

    Args:
        df: Per-watershed table; the columns 'node_count', 'tree_depth',
            'tree_width' and 'strahler_order' are read.
        coeffs: Six coefficients of the surface
            z = c0*x^2 + c1*y^2 + c2*x*y + c3*x + c4*y + c5.

    Returns:
        None. The figure is saved to OUTPUT_3D_PLOT_FILE (failures are logged)
        and then shown. Nothing is drawn when `df` has fewer than 6 rows.
    """
    logging.info("Generating 3D watershed statistics plot...")
    too_few_points = df.empty or len(df) < 6  # Need enough points for a 2nd degree polynomial
    if too_few_points:
        logging.warning("DataFrame has fewer than 6 points, skipping 3D plot.")
        return

    plt.rcParams['font.family'] = 'Arial'

    fig = plt.figure(figsize=(18, 15))
    ax = fig.add_subplot(111, projection='3d')

    # --- Pane, axis-line and grid styling (identical for all three axes) ---
    white_pane = (1.0, 1.0, 1.0, 1.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(white_pane)
        axis.line.set_color("black")
        axis.line.set_linewidth(3)

    ax.grid(True, color='black', linestyle='-', linewidth=0.5)

    x_data = df['node_count']
    y_data = df['tree_depth']
    z_data = df['tree_width']
    strahler = df['strahler_order']

    # --- Discrete colour scale: one colour bin per Strahler order present ---
    order_values = sorted(strahler.unique())
    cmap = plt.get_cmap('viridis', len(order_values))
    bin_edges = np.arange(min(order_values)-0.5, max(order_values)+1.5)
    norm = BoundaryNorm(bin_edges, cmap.N)

    sc = ax.scatter(
        x_data, y_data, z_data,
        c=strahler,
        cmap=cmap, norm=norm,
        s=80, alpha=0.9, edgecolor='black', linewidth=0.7
    )

    # --- Fixed axis limits (the x axis is deliberately reversed) ---
    ax.set_xlim(160, 0)
    ax.set_ylim(0, 30)
    ax.set_zlim(0, 25)

    xlims = ax.get_xlim()
    ylims = ax.get_ylim()
    zlims = ax.get_zlim()
    x_lo, x_hi = xlims[0], xlims[1]
    y_lo, y_hi = ylims[0], ylims[1]
    z_lo, z_hi = zlims[0], zlims[1]

    # --- Drop lines and wall projections, coloured like the main markers ---
    point_colors = cmap(norm(strahler))

    for px, py, pz, colour in zip(x_data, y_data, z_data, point_colors):
        # Vertical drop line to the floor
        ax.plot([px, px], [py, py], [pz, z_lo],
                 c='black', linestyle='--', linewidth=0.8, alpha=0.6)

        # Projection dots on the YZ wall, the XZ wall and the XY floor
        ax.scatter(x_lo, py, pz, c=[colour], s=20, alpha=0.7)
        ax.scatter(px, y_hi, pz, c=[colour], s=20, alpha=0.7)
        ax.scatter(px, py, z_lo, c=[colour], s=20, alpha=0.7)

    # --- Polynomial surface, clipped to the convex hull of the (x, y) data ---
    X_surf, Y_surf = np.meshgrid(np.linspace(0, 160, 50), np.linspace(0, 30, 50))

    quad_x = coeffs[0] * X_surf**2
    quad_y = coeffs[1] * Y_surf**2
    cross_xy = coeffs[2] * X_surf * Y_surf
    lin_x = coeffs[3] * X_surf
    lin_y = coeffs[4] * Y_surf
    Z_surf = quad_x + quad_y + cross_xy + lin_x + lin_y + coeffs[5]

    hull = Delaunay(np.column_stack((x_data, y_data)))
    grid_xy = np.column_stack((X_surf.ravel(), Y_surf.ravel()))
    # find_simplex returns -1 for grid nodes outside the triangulation.
    outside_hull = hull.find_simplex(grid_xy) < 0
    Z_surf[outside_hull.reshape(Z_surf.shape)] = np.nan

    ax.plot_surface(X_surf, Y_surf, Z_surf, cmap='coolwarm', alpha=0.3, edgecolor='none')
    ax.plot_wireframe(X_surf, Y_surf, Z_surf, color='black', linewidth=0.7, alpha=0.6)

    # --- Axes labels and title ---
    ax.set_xlabel('Number of Nodes', fontsize=28, labelpad=40)
    ax.set_ylabel('Tree Depth (Longest Path)', fontsize=28, labelpad=40)
    ax.set_zlabel('Tree Width (Max Nodes at Level)', fontsize=28, labelpad=40)
    ax.set_title('Watershed Network Characteristics', fontsize=32, pad=20)

    # --- Tick font size and z-axis tick interval ---
    for axis_name in ('x', 'y', 'z'):
        ax.tick_params(axis=axis_name, labelsize=22)
    ax.zaxis.set_major_locator(mticker.MultipleLocator(5))

    # --- Colour bar ---
    cbar = plt.colorbar(sc, ticks=order_values, pad=0.1)
    cbar.set_label('Strahler Order', size=24, labelpad=20)
    cbar.ax.tick_params(labelsize=20)

    # --- Remaining box edges, drawn in the same order as before ---
    box_edges = (
        ([x_lo, x_hi], [y_lo, y_lo], [z_lo, z_lo]),
        ([x_lo, x_hi], [y_hi, y_hi], [z_lo, z_lo]),
        ([x_lo, x_lo], [y_lo, y_hi], [z_lo, z_lo]),
        ([x_hi, x_hi], [y_lo, y_hi], [z_lo, z_lo]),
        ([x_lo, x_lo], [y_lo, y_lo], [z_lo, z_hi]),
        ([x_lo, x_hi], [y_hi, y_hi], [z_hi, z_hi]),
        ([x_lo, x_lo], [y_lo, y_hi], [z_hi, z_hi]),
    )
    for edge_x, edge_y, edge_z in box_edges:
        ax.plot(edge_x, edge_y, edge_z, c='k', linewidth=3)

    try:
        plt.savefig(OUTPUT_3D_PLOT_FILE, dpi=450, bbox_inches='tight')
        logging.info(f"3D plot saved to {OUTPUT_3D_PLOT_FILE}")
    except Exception as e:
        logging.error(f"Failed to save 3D plot: {e}")
    plt.show()

# Main Execution Block
def main():
    """
    Load the input layers, run the watershed analysis, and report the results.

    The function checks that the three input shapefiles exist, loads them and
    calls ``analyze_watersheds``. When results are returned it prints the
    per-watershed table and summary statistics, fits a second-degree
    polynomial surface (tree width as a function of node count and tree
    depth), exports the table to ``OUTPUT_CSV_FILE`` and draws the 3D plot.

    Returns:
        None. Missing files and load/export failures are logged, not raised.
    """
    logging.info("--- Starting Watershed Geomorphology Analysis ---")

    # --- File Existence Check ---
    required_files = {
        "Basins": BASINS_SHAPEFILE,
        "Lowest Points": LOWEST_POINTS_SHAPEFILE,
        "Sinks": SINKS_SHAPEFILE
    }

    # Report every missing input before giving up, not just the first one.
    all_files_exist = True
    for name, path in required_files.items():
        if not os.path.exists(path):
            logging.critical(f"FATAL: Input file for {name} not found at '{path}'. Exiting.")
            all_files_exist = False
    if not all_files_exist:
        return

    # --- Load Data ---
    try:
        basins_gdf = gpd.read_file(BASINS_SHAPEFILE)
        lowest_points_gdf = gpd.read_file(LOWEST_POINTS_SHAPEFILE)
        sinks_gdf = gpd.read_file(SINKS_SHAPEFILE)
        logging.info("All input shapefiles loaded successfully.")
    except Exception as e:
        logging.critical(f"Fatal error loading input shapefiles: {e}. Exiting.")
        return

    # --- Run Analysis ---
    results_df = analyze_watersheds(basins_gdf, lowest_points_gdf, sinks_gdf)

    # --- Display and Save Results ---
    if not results_df.empty:
        logging.info("Analysis complete. Displaying results...")
        # Set pandas display options for better console output
        pd.set_option('display.max_rows', 500)
        pd.set_option('display.width', 1000)

        # Normalise the ID column name; this is a no-op when the column is
        # already called 'outlet_id'.
        results_df = results_df.rename(columns={'watershed_id': 'outlet_id'})

        print("\n--- Raw Watershed Geomorphology Statistics ---")
        print(results_df.to_string())

        # Calculate and print summary statistics, excluding the ID column
        print("\n\n--- Summary Statistics of All Watersheds ---")
        stats_df = results_df.drop(columns=['outlet_id'])
        summary_stats = stats_df.describe().loc[['mean', '50%', 'std', 'min', 'max']]
        summary_stats = summary_stats.rename(index={'50%': 'median'})
        print(summary_stats.to_string())

        # --- Fit model and print R-squared ---
        coeffs = None
        if len(results_df) >= 6:
            x_data = results_df['node_count']
            y_data = results_df['tree_depth']
            z_data = results_df['tree_width']
            # Design matrix for a 2nd degree polynomial:
            # z = c0*x^2 + c1*y^2 + c2*xy + c3*x + c4*y + c5
            A = np.c_[x_data**2, y_data**2, x_data*y_data, x_data, y_data, np.ones(len(results_df))]
            coeffs, _, _, _ = np.linalg.lstsq(A, z_data, rcond=None)

            z_pred = A @ coeffs
            ss_res = np.sum((z_data - z_pred)**2)
            ss_tot = np.sum((z_data - np.mean(z_data))**2)

            print("\n--- Goodness of Fit for 3D Surface ---")
            if ss_tot > 0:
                r2 = 1 - (ss_res / ss_tot)
                print(f"R-squared (R²): {r2:.4f}")
            else:
                # All tree widths are identical: R² would be a 0/0 or x/0
                # division, so report it as undefined instead of nan/inf.
                print("R-squared (R²): undefined (tree width is constant)")

            # Print the fitted equation (a quadratic surface, not a plane)
            print("\n--- Best-Fit Surface Equation ---")
            print(f"Tree Width = ({coeffs[0]:.4f} * NodeCount²) + ({coeffs[1]:.4f} * TreeDepth²) + ({coeffs[2]:.4f} * NodeCount * TreeDepth) + "
                  f"({coeffs[3]:.4f} * NodeCount) + ({coeffs[4]:.4f} * TreeDepth) + ({coeffs[5]:.4f})")

        else:
            logging.warning("Not enough data points to calculate a reliable best-fit surface.")

        # Export to CSV
        try:
            results_df.to_csv(OUTPUT_CSV_FILE, index=False)
            logging.info(f"Results successfully exported to {OUTPUT_CSV_FILE}")
        except Exception as e:
            logging.error(f"Failed to export results to CSV: {e}")

        # Generate the 3D plot if a model was successfully fitted
        if coeffs is not None:
            plot_3d_watershed_stats(results_df, coeffs)

    else:
        logging.warning("Analysis did not produce any results.")

    logging.info("--- Script finished. ---")


# Run the geomorphology analysis only when this code is executed as the main
# program (not when it is imported as a module).
if __name__ == "__main__":
    main()

# Algorithm 3: Water Depth Mapping via Depression Filling

This section computes spatial water-depth patterns by filling each depression basin up to its spill elevation (the sink elevation plus the spill depth `dZ`).

**Inputs:** depression polygons and elevation attributes/rasters.  
**Outputs:** depth maps (and optional figures) that can be used for interpretation and validation.

In [None]:
import logging
import os
from typing import Dict, List, Optional, Tuple
import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
from rasterio.features import rasterize

# Configuration Block
# Every user-adjustable setting for the depth-mapping step lives in this section.

# Basic logging so progress and errors are reported while the step runs.
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(module)s - %(message)s',
    level=logging.INFO,
)

# File Paths and Parameters
# Important: the input paths must point at the outputs of the previous scripts.

# Number of the final iteration produced by the iterative merging script.
LAST_ITERATION_NUM: int = 24

CONFIG = dict(
    inputs=dict(
        dem="DEM.tif",
        # Basin polygons written by the iterative merging script.
        basins=f"final_basins_{LAST_ITERATION_NUM}.shp",
        # Sinks written by the previous script; they carry the 'dZ' values.
        sinks="updated_sinks.shp",
    ),
    outputs=dict(
        output_dir="water_depth_analysis",
        depth_raster="water_depth_to_spill_point.tif",
    ),
    parameters=dict(
        # Name of the column holding the unique basin ID in the basins shapefile.
        basin_id_column="basin_id",
        # With `all_touched=True`, pixels merely touched by a polygon edge are
        # also included in the rasterized mask.
        rasterize_all_touched=True,
    ),
)


# Core Hydrological Filling Functions

def fill_single_basin_to_spill_elevation(
    dem_array: np.ndarray,
    basin_mask: np.ndarray,
    spill_elevation_value: float
) -> np.ndarray:
    """
    Calculates water depth within a single basin up to a given spill elevation.

    Depth is ``spill_elevation_value - ground elevation`` for every pixel in
    the basin, clipped at zero. Pixels whose depth is not a finite number
    (for example NaN elevations used as a nodata marker) are reported as dry
    (0) instead of propagating NaN/inf into the result.

    Args:
        dem_array: The full Digital Elevation Model as a NumPy array.
        basin_mask: A boolean NumPy array where `True` indicates the pixels
                    belonging to the current basin.
        spill_elevation_value: The elevation to which the basin will be "filled".

    Returns:
        A float32 NumPy array of the same shape as the DEM, containing the
        calculated water depths for the specified basin and 0 elsewhere.
    """
    basin_specific_water_depth = np.zeros_like(dem_array, dtype=np.float32)
    if not basin_mask.any():
        return basin_specific_water_depth

    # Spill elevation minus ground elevation, only for this basin's pixels.
    raw_depths = spill_elevation_value - dem_array[basin_mask]

    # Keep strictly positive, finite depths; everything else (negative, NaN,
    # +/-inf) counts as no water. Comparisons with NaN are False by design,
    # so the 'invalid' floating-point warning is silenced here.
    with np.errstate(invalid='ignore'):
        is_wet = np.isfinite(raw_depths) & (raw_depths > 0)

    basin_specific_water_depth[basin_mask] = np.where(is_wet, raw_depths, 0)
    return basin_specific_water_depth


def calculate_basin_water_depth_at_spill_points(config: Dict) -> None:
    """
    Calculates and saves a raster representing water depth in basins.

    The DEM, basins and sinks named in ``config['inputs']`` are loaded (vector
    layers are reprojected to the DEM CRS when they differ), the basins are
    rasterized by ID, and every basin is filled up to its spill elevation
    (``elev + dZ`` of its sink, falling back to ``elev`` alone when that sum is
    missing). DEM cells that are nodata or non-finite are left dry. The result
    is written as a single-band float32 GeoTIFF to
    ``config['outputs']['output_dir'] / config['outputs']['depth_raster']``.

    Args:
        config: Nested dictionary with the 'inputs' (dem, basins, sinks),
            'outputs' (output_dir, depth_raster) and 'parameters'
            (basin_id_column, rasterize_all_touched) sections.

    Returns:
        None. Every failure is logged and ends the function early.
    """
    dem_filepath = config['inputs']['dem']
    basins_filepath = config['inputs']['basins']
    sinks_filepath = config['inputs']['sinks']
    output_dir = config['outputs']['output_dir']

    logging.info(f"Starting basin filling. DEM: '{dem_filepath}', Basins: '{basins_filepath}'")

    # --- 1. Load and Prepare Data ---
    try:
        with rasterio.open(dem_filepath) as dem_src:
            dem_array = dem_src.read(1)
            dem_profile = dem_src.profile
            dem_transform = dem_src.transform
            dem_crs = dem_src.crs
            dem_nodata = dem_src.nodata
    except Exception as e:
        logging.critical(f"Failed to read DEM: '{dem_filepath}'. Error: {e}")
        return

    # Cells without a usable elevation must not be filled: a sentinel nodata
    # value (e.g. -9999) would otherwise be read as a very deep pit.
    valid_dem_mask = np.isfinite(dem_array)
    if dem_nodata is not None and not np.isnan(dem_nodata):
        valid_dem_mask &= (dem_array != dem_nodata)

    try:
        basins_gdf = gpd.read_file(basins_filepath)
        if basins_gdf.crs != dem_crs:
            basins_gdf = basins_gdf.to_crs(dem_crs)
        sinks_gdf = gpd.read_file(sinks_filepath)
        if sinks_gdf.crs != dem_crs:
            sinks_gdf = sinks_gdf.to_crs(dem_crs)
    except Exception as e:
        logging.critical(f"Failed to load/reproject shapefiles. Error: {e}")
        return

    # --- 2. Validate Input Data ---
    basin_id_col = config['parameters']['basin_id_column']
    if basin_id_col not in basins_gdf.columns:
        logging.error(f"Critical: Column '{basin_id_col}' not found in basins file '{basins_filepath}'.")
        return
    required_sink_cols = ['basin_id', 'dZ', 'elev', 'geometry']
    if not all(col in sinks_gdf.columns for col in required_sink_cols):
        logging.error(f"Critical: Sinks file must contain {required_sink_cols}.")
        return

    # --- 3. Rasterize Basins ---
    logging.info(f"Rasterizing basins using ID column: '{basin_id_col}'...")
    try:
        # Ensure IDs are integers for consistent matching and rasterization
        # values. Rows whose ID cannot be parsed are dropped from the frame
        # itself; assigning a shortened Series back to the column would simply
        # re-align on the index and bring the NaN rows back.
        basins_gdf[basin_id_col] = pd.to_numeric(basins_gdf[basin_id_col], errors='coerce')
        basins_gdf = basins_gdf.dropna(subset=[basin_id_col]).copy()
        basins_gdf[basin_id_col] = basins_gdf[basin_id_col].astype(int)

        sinks_gdf['basin_id'] = pd.to_numeric(sinks_gdf['basin_id'], errors='coerce')
        sinks_gdf = sinks_gdf.dropna(subset=['basin_id']).copy()
        sinks_gdf['basin_id'] = sinks_gdf['basin_id'].astype(int)

        shapes_for_rasterize = [
            (geom, int(b_id)) for geom, b_id in zip(basins_gdf.geometry, basins_gdf[basin_id_col])
            if geom is not None and not geom.is_empty
        ]
        if not shapes_for_rasterize:
            logging.error("No valid basin geometries/IDs to rasterize.")
            return

        # 0 is the "no basin" value, so a basin whose ID is 0 is not processed.
        basin_id_raster = rasterize(
            shapes=shapes_for_rasterize,
            out_shape=dem_array.shape,
            transform=dem_transform,
            dtype=rasterio.int32,
            all_touched=config['parameters']['rasterize_all_touched'],
            fill=0
        )
    except Exception as e:
        logging.critical(f"Failed during data type conversion or rasterization: {e}")
        return
    logging.info("Basin rasterization complete.")

    # --- 4. Prepare Spill Elevation Map ---
    logging.info("Calculating spill elevations for each basin...")
    # The spill elevation is the sink's elevation plus the spill depth (dZ).
    sinks_gdf['spill_elev'] = sinks_gdf['elev'] + sinks_gdf['dZ']

    # Primary map from basin_id to its calculated spill elevation. When a
    # basin has several sinks, the last row wins (dict construction).
    basin_spill_elevation_map = sinks_gdf.dropna(
        subset=['spill_elev', 'basin_id']
    ).set_index('basin_id')['spill_elev'].to_dict()

    # Fallback map using just the sink elevation, in case dZ was NaN.
    sink_elevations_direct_map = sinks_gdf.dropna(
        subset=['elev', 'basin_id']
    ).set_index('basin_id')['elev'].to_dict()

    logging.info(f"Spill elevation map created for {len(basin_spill_elevation_map)} basins.")

    # --- 5. Fill Basins and Calculate Depth ---
    logging.info("Iterating through basins to calculate water depth...")
    final_cumulative_water_depth = np.zeros_like(dem_array, dtype=np.float32)
    unique_raster_ids = np.unique(basin_id_raster[basin_id_raster != 0])

    for basin_id in unique_raster_ids:
        # Only basin pixels that carry a real elevation take part in the fill.
        current_mask = (basin_id_raster == basin_id) & valid_dem_mask

        # Determine the spill elevation for the current basin.
        spill_elev = basin_spill_elevation_map.get(basin_id)

        # If primary spill elevation is missing (e.g., dZ was NaN), use fallback.
        if spill_elev is None or pd.isna(spill_elev):
            spill_elev = sink_elevations_direct_map.get(basin_id)
            if spill_elev is not None:
                logging.info(f"Basin {basin_id}: Using fallback sink elevation {spill_elev:.2f} as spill level.")
            else:
                logging.warning(f"No spill elevation or fallback found for basin {basin_id}. Skipping.")
                continue

        if pd.isna(spill_elev):
            logging.warning(f"Spill elevation for basin {basin_id} is NaN. Skipping.")
            continue

        # Calculate depth for this single basin.
        depth_this_basin = fill_single_basin_to_spill_elevation(dem_array, current_mask, spill_elev)

        # Merge into the cumulative raster in place (no extra full-size copy).
        np.maximum(final_cumulative_water_depth, depth_this_basin, out=final_cumulative_water_depth)

    # --- 6. Save Final Output Raster ---
    logging.info("Saving final water depth raster...")
    os.makedirs(output_dir, exist_ok=True)
    output_raster_path = os.path.join(output_dir, config['outputs']['depth_raster'])

    output_profile = dem_profile.copy()
    # Only one band is written, so do not inherit a multi-band count from the DEM.
    output_profile.update(dtype=rasterio.float32, count=1, compress='lzw', nodata=0.0)

    try:
        with rasterio.open(output_raster_path, 'w', **output_profile) as dst:
            dst.write(final_cumulative_water_depth.astype(rasterio.float32), 1)
        logging.info(f"Final water depth raster saved to: {output_raster_path}")
    except Exception as e:
        logging.error(f"Failed to save output raster: {e}")

    logging.info("Basin filling process completed.")


# Script Execution
if __name__ == "__main__":
    logging.info("Script started directly for basin water depth calculation.")

    # --- File Validation ---
    # Every missing input is reported before deciding whether to run.
    all_files_ok = True
    for key, path in CONFIG['inputs'].items():
        if os.path.exists(path):
            continue
        logging.critical(f"Input file for '{key}' not found at: '{path}'")
        all_files_ok = False

    if not all_files_ok:
        logging.error("One or more input files are missing. Processing aborted.")
    else:
        logging.info("All input files found. Starting process...")
        calculate_basin_water_depth_at_spill_points(CONFIG)

In [None]:
import os
import glob
import logging
import re
from typing import Tuple, Optional, List, Any
import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
from rasterio.mask import mask

# Logging Configuration
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(module)s - %(message)s',
    level=logging.INFO,
)

# Configuration: Input/Output File Paths and Parameters

# Number of the final merging iteration.
LAST_ITERATION_NUM: int = 24

CONFIG = dict(
    inputs=dict(
        dem="DEM.tif",
        basins=f"final_basins_{LAST_ITERATION_NUM}.shp",
        sinks="updated_sinks.shp",
    ),
    outputs=dict(
        output_dir="water_depth_analysis",
        depth_raster="water_depth_to_spill_point.tif",
    ),
    parameters=dict(
        basin_id_column="basin_id",
        rasterize_all_touched=True,
    ),
)

# Input File Paths
# The basin shapefile is located by prefix inside this directory ("" = current
# working directory); the depth raster path is derived from CONFIG['outputs'].
BASIN_SHAPEFILE_DIRECTORY: str = ""
BASIN_SHAPEFILE_PREFIX: str = "final_basins_"
BASIN_ID_COLUMN: str = "basin_id"
WATER_DEPTH_RASTER_PATH: str = os.path.join(
    CONFIG['outputs']['output_dir'],
    CONFIG['outputs']['depth_raster'],
)

# Output File Names
OUTPUT_BASINS_WITH_STATS_SHP: str = "basins_with_water_stats.shp"
OUTPUT_BASINS_STATS_CSV: str = "basins_water_stats.csv"

# Processing Parameters
# When None, the CRS of the depth raster is used as the target CRS.
TARGET_CRS: Optional[str] = None


# Helper Functions

def find_latest_iteration_shp_file(
    directory: str,
    prefix: str,
    extension: str = ".shp"
) -> Optional[str]:
    """
    Locate the file named ``<prefix><number><extension>`` with the largest number.

    Args:
        directory: Folder that is searched (not recursively).
        prefix: Leading part shared by the filenames (e.g., "final_basins_").
        extension: Trailing part of the filenames (e.g., ".shp").

    Returns:
        The path (as produced by the glob search) of the file carrying the
        highest iteration number, or None when nothing matches.
    """
    search_glob = os.path.join(directory, f"{prefix}*{extension}")
    candidates = glob.glob(search_glob)

    if not candidates:
        logging.warning(f"No files found matching pattern: {search_glob}")
        return None

    # The glob also returns names such as "<prefix>abc<extension>"; the regex
    # keeps only those with a purely numeric middle part.
    numbered_name = re.compile(rf"{re.escape(prefix)}(\d+){re.escape(extension)}")

    best_path = None
    best_iteration = -1
    for candidate in candidates:
        base_name = os.path.basename(candidate)
        hit = numbered_name.match(base_name)
        if not hit:
            continue
        try:
            number = int(hit.group(1))
        except ValueError:
            logging.warning(f"Could not parse iteration number from: {base_name}")
            continue
        # Strict comparison: on a tie the first file encountered is kept.
        if number > best_iteration:
            best_iteration, best_path = number, candidate

    if not best_path:
        logging.warning(f"Could not determine latest file for prefix '{prefix}'.")
    else:
        logging.info(f"Found latest iteration file: {best_path}")
    return best_path


# Core Water Statistics Calculation Functions

def calculate_water_stats_for_single_basin(
    basin_geometry: Any,
    water_depth_raster_dataset: rasterio.DatasetReader
) -> Tuple[float, float]:
    """
    Compute ponded volume and wet surface area inside one basin polygon.

    The depth raster is clipped to the polygon; pixels deeper than 1e-6 count
    as wet. Area is the wet pixel count times the pixel area, volume is the
    sum of wet depths times the pixel area.

    Args:
        basin_geometry: The Shapely geometry for the basin.
        water_depth_raster_dataset: An opened Rasterio dataset for the water depth raster.

    Returns:
        A tuple of (total_water_volume, total_water_surface_area). (0.0, 0.0)
        is returned for a missing/empty/invalid geometry and when any error
        occurs while clipping or summing.
    """
    geometry_usable = (
        basin_geometry is not None
        and not basin_geometry.is_empty
        and basin_geometry.is_valid
    )
    if not geometry_usable:
        logging.warning("Input basin geometry is invalid. Returning zero for stats.")
        return 0.0, 0.0

    try:
        # Clip to the polygon; everything outside it is filled with 0 depth.
        clipped_bands, _ = mask(
            dataset=water_depth_raster_dataset,
            shapes=[basin_geometry],
            crop=True,
            nodata=0.0,
            filled=True
        )
        depths = clipped_bands[0]
        wet = depths > 1e-6

        # Pixel footprint from the affine transform's x and y cell sizes.
        raster_transform = water_depth_raster_dataset.transform
        pixel_area = abs(raster_transform.a * raster_transform.e)

        wet_area = np.sum(wet) * pixel_area
        wet_volume = np.sum(depths[wet]) * pixel_area
    except Exception as e:
        logging.error(f"Error processing basin geometry for stats: {e}", exc_info=True)
        return 0.0, 0.0

    return float(wet_volume), float(wet_area)


def process_all_basins_for_water_stats(
    basins_shapefile_path: str,
    water_depth_raster_path: str,
    basin_id_col_name: str
) -> Optional[gpd.GeoDataFrame]:
    """
    Attach water volume and wet area to every basin and export the result.

    Basins are reprojected to TARGET_CRS (or, when that is unset, to the CRS
    of the depth raster) if needed. The enriched layer is written next to the
    input shapefile as OUTPUT_BASINS_WITH_STATS_SHP, together with a CSV
    summary named OUTPUT_BASINS_STATS_CSV, and totals are logged.

    Args:
        basins_shapefile_path: Path to the input basin shapefile.
        water_depth_raster_path: Path to the water depth raster.
        basin_id_col_name: The name of the column that uniquely identifies each basin.

    Returns:
        A GeoDataFrame with 'water_volume' and 'water_area' columns, or None on failure.
    """
    logging.info("Starting water depth statistics calculation.")
    logging.info(f"Using basins: '{basins_shapefile_path}' (ID column: '{basin_id_col_name}')")
    logging.info(f"Using depth raster: '{water_depth_raster_path}'")

    try:
        basins_gdf = gpd.read_file(basins_shapefile_path)
    except Exception as e:
        logging.critical(f"Failed to load basins shapefile: {e}")
        return None

    if basin_id_col_name not in basins_gdf.columns:
        logging.critical(f"ID column '{basin_id_col_name}' not found in basins file.")
        return None

    try:
        with rasterio.open(water_depth_raster_path) as depth_raster_src:
            target_crs = TARGET_CRS or depth_raster_src.crs

            if basins_gdf.crs != target_crs:
                logging.info(f"Reprojecting basins to target CRS: '{target_crs}'.")
                basins_gdf = basins_gdf.to_crs(target_crs)

            logging.info("Processing each basin...")
            per_basin_stats = [
                calculate_water_stats_for_single_basin(geom, depth_raster_src)
                for geom in basins_gdf['geometry']
            ]
            basins_gdf['water_volume'] = [volume for volume, _ in per_basin_stats]
            basins_gdf['water_area'] = [area for _, area in per_basin_stats]
            logging.info("Statistics calculation complete.")

            # --- Export Results ---
            # NOTE(review): Shapefile field names are limited to 10 characters, so
            # 'water_volume' is presumably shortened on write - confirm that
            # downstream readers look for the shortened name.
            export_dir = os.path.dirname(basins_shapefile_path)
            output_shp_path = os.path.join(export_dir, OUTPUT_BASINS_WITH_STATS_SHP)
            output_csv_path = os.path.join(export_dir, OUTPUT_BASINS_STATS_CSV)

            logging.info(f"Exporting results to Shapefile: '{output_shp_path}'")
            basins_gdf.to_file(output_shp_path)

            summary_columns = [basin_id_col_name, 'water_volume', 'water_area']
            csv_summary = basins_gdf[summary_columns]
            csv_summary.to_csv(output_csv_path, index=False)
            logging.info(f"Exporting summary to CSV: '{output_csv_path}'")

            # --- Log Summary ---
            total_volume = csv_summary['water_volume'].sum()
            total_area = csv_summary['water_area'].sum()
            wet_basins = csv_summary[csv_summary['water_volume'] > 1e-6]
            logging.info("\n--- Summary Water Statistics ---")
            logging.info(f"Total Basins Processed: {len(basins_gdf)}")
            logging.info(f"Basins with Water: {len(wet_basins)}")
            logging.info(f"Total Water Volume: {total_volume:,.2f} cubic units")
            logging.info(f"Total Water Surface Area: {total_area:,.2f} square units")
            if not wet_basins.empty:
                logging.info(f"Mean Volume (wet basins): {wet_basins['water_volume'].mean():,.2f} cubic units")
            logging.info("---------------------------------")

            return basins_gdf

    except FileNotFoundError:
        logging.critical(f"Water depth raster not found at '{water_depth_raster_path}'.")
        return None
    except Exception as e:
        logging.critical(f"An error occurred during processing: {e}", exc_info=True)
        return None


# Script Execution
if __name__ == "__main__":
    logging.info("Script started directly.")

    # Pick the basin layer with the highest iteration number in its filename.
    input_basins_path = find_latest_iteration_shp_file(
        BASIN_SHAPEFILE_DIRECTORY,
        BASIN_SHAPEFILE_PREFIX
    )
    basin_id_column = BASIN_ID_COLUMN

    # Abort with `raise SystemExit(1)` rather than the interactive helper
    # `exit(1)`: `exit` is provided for interactive shells and is not
    # guaranteed to stop execution at this point in every environment
    # (NOTE(review): notebook kernels replace it - confirm), whereas raising
    # SystemExit always ends the run here with status 1.
    if not os.path.exists(WATER_DEPTH_RASTER_PATH):
        logging.critical(f"Water depth raster not found: '{WATER_DEPTH_RASTER_PATH}'")
        raise SystemExit(1)
    if not input_basins_path or not os.path.exists(input_basins_path):
        logging.critical(f"Basin shapefile could not be found: '{input_basins_path}'")
        raise SystemExit(1)

    process_all_basins_for_water_stats(
        basins_shapefile_path=input_basins_path,
        water_depth_raster_path=WATER_DEPTH_RASTER_PATH,
        basin_id_col_name=basin_id_column
    )

    logging.info("Script finished.")

In [None]:
import logging
import os
import time
from collections import deque
from typing import Dict, List, Optional, Tuple, Any

import geopandas as gpd
import networkx as nx
import pandas as pd
from shapely.geometry import Point, LineString, MultiLineString

# Logging Configuration
logging.basicConfig(
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(module)s - %(message)s',
    level=logging.INFO,
)

# Configuration: Input/Output File Paths and Parameters

# Input layers: the flow-path lines and the basin polygons that already carry
# the water statistics columns.
DEFAULT_EDGES_FILE: str = "flow_paths_all.shp"
DEFAULT_BASINS_FILE: str = "basins_with_water_stats.shp"

# Column that identifies each basin.
BASIN_ID_COLUMN_NAME: str = "basin_id"

# Output layer name.
OUTPUT_FLOW_ACCUMULATION_SHP: str = "basins_with_flow_accumulation.shp"

# Processing parameter: optional target CRS (None = not set).
TARGET_CRS: Optional[str] = None


# Data Loading and Graph Building Functions

def load_and_prepare_data(
    edges_filepath: str,
    basins_filepath: str,
    basin_id_col: str
) -> Tuple[Optional[gpd.GeoDataFrame], Optional[gpd.GeoDataFrame]]:
    """
    Loads and prepares the edges and basins shapefiles for analysis.

    The basins layer gets a 'geo_area' column (polygon area in the layer's
    CRS units), numeric 'water_volume' / 'water_area' columns (missing or
    unparsable values become 0.0) and zero-initialised 'area_accum' /
    'stor_accum' columns; all other attribute columns are dropped.

    Args:
        edges_filepath: Path to the flow-path line shapefile.
        basins_filepath: Path to the basin polygon shapefile.
        basin_id_col: Name of the basin ID column expected in the basins file.

    Returns:
        (edges_gdf, basins_gdf), or (None, None) when loading fails, the ID
        column is missing, or the edges cannot be reprojected.
    """
    logging.info(f"Loading data: Edges='{edges_filepath}', Basins='{basins_filepath}'")
    try:
        edges_gdf = gpd.read_file(edges_filepath)
        basins_gdf = gpd.read_file(basins_filepath)
    except Exception as e:
        logging.critical(f"Failed to load input shapefiles: {e}")
        return None, None

    if basin_id_col not in basins_gdf.columns:
        logging.error(f"Specified basin ID column '{basin_id_col}' not found in '{basins_filepath}'.")
        return None, None

    # Edge end points are later matched against basin polygons, which is only
    # meaningful when both layers share one CRS. Basins are left untouched so
    # that 'geo_area' keeps its original units.
    if (
        edges_gdf.crs is not None
        and basins_gdf.crs is not None
        and edges_gdf.crs != basins_gdf.crs
    ):
        try:
            logging.info(f"Reprojecting edges to basins CRS: '{basins_gdf.crs}'.")
            edges_gdf = edges_gdf.to_crs(basins_gdf.crs)
        except Exception as e:
            logging.critical(f"Failed to reproject edges to basins CRS: {e}")
            return None, None

    # Standardize basin data
    basins_gdf['geo_area'] = basins_gdf.geometry.area
    for col in ['water_volume', 'water_area']:
        # Shapefile (DBF) field names hold at most 10 characters, so a layer
        # written with 'water_volume' comes back as 'water_volu'. Accept the
        # shortened name instead of silently zeroing the stored volumes.
        dbf_name = col[:10]
        if col not in basins_gdf.columns and dbf_name in basins_gdf.columns:
            logging.info(f"Column '{col}' not found; using truncated Shapefile field '{dbf_name}'.")
            basins_gdf = basins_gdf.rename(columns={dbf_name: col})

        if col in basins_gdf.columns:
            basins_gdf[col] = pd.to_numeric(basins_gdf[col], errors='coerce').fillna(0.0)
        else:
            logging.warning(f"Column '{col}' not found in basins. Initializing to 0.0.")
            basins_gdf[col] = 0.0

    basins_gdf['area_accum'] = 0.0
    basins_gdf['stor_accum'] = 0.0

    # Keep only necessary columns to avoid clutter; copy so later column
    # assignments act on an independent frame.
    required_cols = [basin_id_col, 'geo_area', 'water_volume', 'water_area', 'area_accum', 'stor_accum', 'geometry']
    basins_gdf = basins_gdf[required_cols].copy()

    logging.info("Data loading and preparation complete.")
    return edges_gdf, basins_gdf


def find_basin_containing_point(
    point: Point,
    basins_sindex: Any,
    basins_geometries: gpd.GeoSeries
) -> Optional[int]:
    """
    Return the index label of the first basin polygon that contains `point`.

    The spatial index narrows the search to basins whose bounding box meets
    the point; an exact `contains` test is then run on those candidates only.
    None is returned for a missing/empty point or when no basin contains it.
    """
    if point is None or point.is_empty:
        return None

    # Positional indices of basins whose bounding boxes intersect the point.
    candidate_positions = list(basins_sindex.intersection(point.bounds))
    if not candidate_positions:
        return None

    nearby_basins = basins_geometries.iloc[candidate_positions]
    is_container = nearby_basins.contains(point)
    if not is_container.any():
        return None

    containing_basins = nearby_basins[is_container]
    return containing_basins.index[0]


def build_flow_graph(
    edges_gdf: gpd.GeoDataFrame,
    basins_gdf: gpd.GeoDataFrame,
    basin_id_col: str
) -> nx.DiGraph:
    """
    Build a directed basin-to-basin graph from the flow-path lines.

    Every basin becomes a node keyed by its DataFrame index, with the basin
    ID stored as the node attribute 'id_val'. Each LineString edge adds an
    arc from the basin containing its first vertex to the basin containing
    its last vertex, provided both are found and differ. Geometries that are
    not LineStrings, or are empty, are ignored.
    """
    logging.info("Building flow direction graph from edges...")
    graph = nx.DiGraph()

    # Nodes: one per basin row, identified by the row's index label.
    for basin_index, basin_row in basins_gdf.iterrows():
        graph.add_node(basin_index, id_val=basin_row[basin_id_col])

    spatial_index = basins_gdf.sindex

    def basin_at(coordinate):
        """Look up the basin index containing the given vertex coordinate."""
        return find_basin_containing_point(Point(coordinate), spatial_index, basins_gdf.geometry)

    for _, edge_row in edges_gdf.iterrows():
        path_geom = edge_row.geometry
        if not isinstance(path_geom, LineString) or path_geom.is_empty:
            continue

        upstream_idx = basin_at(path_geom.coords[0])
        downstream_idx = basin_at(path_geom.coords[-1])

        # Skip paths with an unmatched end and paths that stay in one basin.
        if upstream_idx is None or downstream_idx is None:
            continue
        if upstream_idx == downstream_idx:
            continue
        graph.add_edge(upstream_idx, downstream_idx)

    logging.info(f"Flow graph built with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges.")
    return graph


# Flow Accumulation Calculation

def calculate_flow_accumulation(
    flow_graph: nx.DiGraph,
    basins_gdf: gpd.GeoDataFrame
) -> gpd.GeoDataFrame:
    """Accumulate basin area and storage downstream along the flow graph.

    Each basin starts with its own ``geo_area`` and ``water_volume``. Nodes are
    visited in topological order and every node adds its running total to each
    of its successors. The totals are written to the ``area_accum`` and
    ``stor_accum`` columns.

    Side effects:
        * ``basins_gdf`` is modified in place (two columns added) and returned.
        * If ``flow_graph`` contains cycles, edges are removed from it in place
          to break them; a warning and one info line per removed edge are logged.

    Args:
        flow_graph: Directed graph whose node keys are index labels of
            ``basins_gdf``.
        basins_gdf: Basins with ``geo_area`` and ``water_volume`` columns.

    Returns:
        ``basins_gdf`` with ``area_accum`` and ``stor_accum`` populated. If the
        topological sort still fails after cycle breaking, both columns hold
        the local (non-accumulated) values so downstream reporting keeps working.
    """
    logging.info("Calculating flow accumulation.")

    # Initialize accumulators with each basin's local values
    area_accumulator = basins_gdf['geo_area'].to_dict()
    storage_accumulator = basins_gdf['water_volume'].to_dict()

    if not nx.is_directed_acyclic_graph(flow_graph):
        logging.warning("Graph contains cycles! Accumulation results may be incorrect for basins in cycles.")
        # Materialize the cycle list first: the graph is mutated while we walk it.
        cycles = list(nx.simple_cycles(flow_graph))
        for cycle in cycles:
            # A self-loop is reported as a one-node cycle, so its only edge is
            # (n, n); indexing cycle[1] unconditionally would raise IndexError.
            second_node = cycle[1] if len(cycle) > 1 else cycle[0]
            edge_to_remove = (cycle[0], second_node)
            # The edge may already be gone if an earlier cycle shared it.
            if flow_graph.has_edge(*edge_to_remove):
                flow_graph.remove_edge(*edge_to_remove)
                logging.info(f"Removed edge {edge_to_remove} to break a cycle.")

    try:
        # Process nodes in topological order (from upstream to downstream).
        # NOTE(review): where flow paths diverge and later re-converge, an
        # upstream basin's total reaches the confluence once per path.
        for node_idx in nx.topological_sort(flow_graph):
            for successor_idx in flow_graph.successors(node_idx):
                area_accumulator[successor_idx] += area_accumulator[node_idx]
                storage_accumulator[successor_idx] += storage_accumulator[node_idx]
    except nx.NetworkXUnfeasible:
        logging.error("Topological sort failed even after attempting to break cycles. Accumulation aborted.")
        # The accumulators may be partially updated at this point, so discard
        # them and publish the local values; callers read these columns
        # unconditionally and would otherwise fail with KeyError.
        basins_gdf['area_accum'] = basins_gdf['geo_area']
        basins_gdf['stor_accum'] = basins_gdf['water_volume']
        return basins_gdf

    # Update the GeoDataFrame with the final accumulated values
    basins_gdf['area_accum'] = basins_gdf.index.map(area_accumulator)
    basins_gdf['stor_accum'] = basins_gdf.index.map(storage_accumulator)

    logging.info("Flow accumulation calculation complete.")
    return basins_gdf


# Main Orchestration Function

def run_flow_accumulation_analysis(
    edges_input_file: str,
    basins_input_file: str,
    output_shapefile: str,
    basin_id_column: str
):
    """Run the flow-accumulation pipeline end to end and export the result.

    Steps: load both layers, build the basin flow graph, accumulate area and
    storage along it, write the basins to ``output_shapefile`` and log a short
    summary together with the elapsed wall-clock time.

    Args:
        edges_input_file: Path to the flow-path edges layer.
        basins_input_file: Path to the basins layer.
        output_shapefile: Destination path for the accumulated basins.
        basin_id_column: Name of the basin identifier column.

    Returns:
        None. The function returns early (after logging an error) when the
        loader yields no basins.
    """
    started_at = time.time()
    logging.info("--- Starting Flow Accumulation Analysis ---")

    edges, basins = load_and_prepare_data(
        edges_input_file,
        basins_input_file,
        basin_id_col=basin_id_column,
    )
    if basins is None:
        logging.error("Failed to load or prepare data. Aborting.")
        return

    graph = build_flow_graph(edges, basins, basin_id_col=basin_id_column)

    # An edgeless graph is still processed; it just cannot move anything downstream.
    if not graph.number_of_edges():
        logging.warning("Flow graph has no edges. Accumulation will only reflect local values.")

    accumulated = calculate_flow_accumulation(graph, basins)

    logging.info(f"Exporting results to: {output_shapefile}")
    accumulated.to_file(output_shapefile)

    logging.info("\n--- Flow Accumulation Summary ---")
    logging.info(f"Total basins processed: {len(accumulated)}")
    logging.info(f"Max accumulated area: {accumulated['area_accum'].max():,.2f} sq. units")
    logging.info(f"Max accumulated storage: {accumulated['stor_accum'].max():,.2f} cubic units")

    logging.info(f"--- Analysis finished in {time.time() - started_at:.2f} seconds ---")


# Script Execution
# Script Execution
# Runs the flow-accumulation analysis with the configured default inputs,
# but only when both input files are actually present on disk.
if __name__ == "__main__":
    logging.info("Script started directly for flow accumulation.")

    edges_file = DEFAULT_EDGES_FILE
    basins_file = DEFAULT_BASINS_FILE
    id_column = BASIN_ID_COLUMN_NAME

    if os.path.exists(edges_file) and basins_file and os.path.exists(basins_file):
        run_flow_accumulation_analysis(
            edges_input_file=edges_file,
            basins_input_file=basins_file,
            output_shapefile=OUTPUT_FLOW_ACCUMULATION_SHP,
            basin_id_column=id_column
        )
    elif not os.path.exists(edges_file):
        # A missing edges file is reported first, matching the check order above.
        logging.critical(f"Edges file not found: '{edges_file}'")
    else:
        logging.critical(f"Basins file not found or determined: '{basins_file}'")

    logging.info("Script finished.")

# Algorithms 4–6: Drainage Tree, Initial Hydrograph, and Dynamic Fill–Spill–Merge Simulation

- **Algorithm 4** builds a drainage-tree network from depression connectivity (nodes/edges).
- **Algorithm 5** generates an initial hydrograph per depression for a rainfall event.
- **Algorithm 6** performs a dynamic simulation where depressions **fill, spill, and merge** over time.

These algorithms are typically run after Algorithms 1–3 have produced consistent basin/inlet assignments and
diagnostic checks look reasonable.

In [None]:
import math
import geopandas as gpd
import networkx as nx
import pandas as pd
import numpy as np
import rasterio
from rasterio.mask import mask
import matplotlib.pyplot as plt
from matplotlib.patches import Wedge, Circle, FancyArrowPatch
from matplotlib.colors import Normalize
import matplotlib.colors as colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Set, Tuple, Union
import logging
import matplotlib.patches as mpatches
from shapely.geometry import Point, Polygon, MultiPoint, LineString, MultiPolygon
from shapely.ops import nearest_points
import matplotlib.cm as cm
import traceback
import os
import sys
import time
from typing import List, Dict, Tuple, Optional, Set
import matplotlib.ticker as ticker
import matplotlib.cm as cm
from scipy.spatial import ConvexHull
import matplotlib.patches as patches

# User inputs (edit as needed)
# File Paths
# NOTE(review): every path except rainfall_csv is relative and is therefore
# resolved against the current working directory. rainfall_csv is an absolute,
# machine-specific Windows path and must be edited before running elsewhere.
lowest_points_file = r"lowest_basin_points_iter_24.shp"
basins_file = r"basins_flow_accumulation.shp"
inlets_file = r"inlet.shp"
study_area_file = r"studyarea.shp"
impervious_tif = r"TAMUCC_Imperviousness.tif"
dem_file = r"DEM.tif"
sinks_file = r"updated_sinks.shp"
rainfall_csv = r"E:\GIS_code\weatherstation_plot\export\rain_2025-06-16_02-20-00_2025-06-16_12-30-00.csv"

# Script Parameters
# Inlet identifiers are given as strings.
selected_inlets = ['7','444','100','472','486','493','484','2170','622','561','2461']
# Presumably minutes, per the `_min` suffix -- confirm against the simulation code.
simulation_time_min = 240
timestep_min = 0.1
rain_unit = 'in/hr'
P2_USER_VALUE = 4.53 # P2 value used for Tc/Lag calculations

# Model Coefficients (Base Values)
# Base Manning's n for impervious / pervious surfaces, and the base infiltration
# rate (presumably cm/s, per the name).
N_IMPERV_BASE: float = 0.011
N_PERV_BASE: float = 0.15
INFILTRATION_RATE_CM_S_BASE: float = 0.000964667


# Calibration Adjustments (Multiplicative Factors)
# These parameters allow for sensitivity analysis or calibration by adjusting
# the base values. A factor of 1.0 means no change from the base value.

# 1. Slope Adjustment (for Tc calculation):
#    Multiplicative factor for the calculated land slope (S_tc) used in Tc calculations.
#    e.g., 1.02 means increase slope by 2% (S_tc_adjusted = S_tc_original * 1.02).
SLOPE_ADJUSTMENT_FACTOR: float = 1.00 # Example: 1.0 (no change)

# 2. Depression Storage Adjustment:
#    Multiplicative factor applied to the SUM of base depression storage,
#    texture storage, and interception.
#    e.g., 1.1 means increase total initial storage by 10%.
DEPRESSION_STORAGE_ADJUSTMENT_FACTOR: float = 1.00 # Example: 1.0 (no change)

# 3. Infiltration Rate Adjustment:
#    Multiplicative factor for the base infiltration rate (INFILTRATION_RATE_CM_S_BASE).
#    e.g., 0.9 means decrease infiltration rate by 10%
#    (infil_rate_adj = infil_rate_base * 0.9).
INFILTRATION_RATE_ADJUSTMENT_FACTOR: float = 1.00 # Example: 1.0 (no change)

# 4. Manning's N Impervious Adjustment:
#    Multiplicative factor for the base Manning's n for impervious surfaces
#    (N_IMPERV_BASE).
#    e.g., 1.1 means N_imperv_adjusted = N_imperv_base * 1.1.
N_IMPERV_ADJUSTMENT_FACTOR: float = 1.00 # Example: 1.0 (no change)

# 5. Manning's N Pervious Adjustment:
#    Multiplicative factor for the base Manning's n for pervious surfaces (N_PERV_BASE).
#    e.g., 1.1 means N_perv_adjusted = N_perv_base * 1.1.
N_PERV_ADJUSTMENT_FACTOR: float = 1.00    # Example: 1.0 (no change)

# 6. Storage Depth Adjustments (Absolute Values in Meters)
#    These parameters define a uniform depth of initial storage across surface types.
#    The resulting volume is added to the base depression storage read from the
#    shapefile's 'water_volu' field.

#    a) Texture Storage on Impervious Surfaces:
#    Represents water held in the texture of impervious surfaces (e.g., pavement).
TEXTURE_STORAGE_DEPTH_IMPERVIOUS_M: float = 0.003 # meters (0.003 m = 3 mm)

#    b) Interception on Pervious Surfaces:
#    Represents water intercepted by vegetation on pervious surfaces.
INTERCEPTION_STORAGE_DEPTH_PERVIOUS_M: float = 0.002 # meters (0.002 m = 2 mm)


# Derived Parameters (incorporating adjustments for use in simulation)
# Each derived value is the base value multiplied by its adjustment factor.
N_IMPERV: float = N_IMPERV_BASE * N_IMPERV_ADJUSTMENT_FACTOR
N_PERV: float = N_PERV_BASE * N_PERV_ADJUSTMENT_FACTOR
INFILTRATION_RATE_CM_S: float = INFILTRATION_RATE_CM_S_BASE * INFILTRATION_RATE_ADJUSTMENT_FACTOR

# Conversion Factors
METERS_TO_FEET = 3.28084

# Plotting/Layout Constants
VERTICAL_SPACING = 1.5
HORIZONTAL_SPACING = 1.5

# Configure root logging for the cell: timestamped INFO-level messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

# Prefer Arial for all figures. This is purely cosmetic, so any failure only
# logs a warning and leaves matplotlib's default font in place.
try:
    # The import cell above only imports matplotlib *submodules* under aliases
    # (pyplot, colors, cm, ...), which does not bind the bare name
    # ``matplotlib``. Import it here so rcParams is reachable even when no
    # earlier cell has done so; otherwise the NameError is swallowed below and
    # the font is silently never applied.
    import matplotlib

    matplotlib.rcParams.update({'font.sans-serif': 'Arial', 'font.family': 'sans-serif'})
    logging.info("Attempting to set plot font to Arial.")
except Exception as e:
    logging.warning(f"Could not set font to Arial, using default: {e}")
    
# Time-Series and Geometric Helper Functions

def prep(d, tkey, qkey):
    """Extract a paired time/flow series from ``d`` and scale the flow by 60.

    The flow is multiplied by 60 (the original note describes this as a
    conversion to m³/min, i.e. it presumably arrives in m³/s).

    Args:
        d: Mapping holding the two series, or ``None``.
        tkey: Key of the time series.
        qkey: Key of the flow series.

    Returns:
        ``(times, flows * 60.0)`` as float arrays truncated to the shorter of
        the two inputs, or ``(None, None)`` when ``d`` is ``None``, a key is
        missing, a series is not a ``list``/``ndarray``, conversion to float
        fails, any value is NaN/Inf, or either series is empty.
    """
    nothing = (None, None)
    if d is None:
        return nothing

    raw_times = d.get(tkey)
    raw_flows = d.get(qkey)
    if raw_times is None or raw_flows is None:
        return nothing

    accepted_types = (list, np.ndarray)
    if not (isinstance(raw_times, accepted_types) and isinstance(raw_flows, accepted_types)):
        return nothing

    try:
        times = np.array(raw_times, dtype=float)
        flows = np.array(raw_flows, dtype=float)
        # Non-finite means NaN or +/-Inf; either one invalidates the whole series.
        if not np.isfinite(times).all():
            logging.warning(f"prep: Found NaN/Inf in time array for key {tkey}. Returning None.")
            return nothing
        if not np.isfinite(flows).all():
            logging.warning(f"prep: Found NaN/Inf in quantity array for key {qkey}. Returning None.")
            return nothing
    except (ValueError, TypeError) as e:
        logging.warning(f"prep: Error converting data for keys {tkey}/{qkey}: {e}. Returning None.")
        return nothing

    common_len = min(times.size, flows.size)
    if common_len == 0:
        return nothing
    return times[:common_len], flows[:common_len] * 60.0

def safe_interp(x_new, x_old, y_old, default_val=0.0, label=""):
    """Linearly interpolate ``y_old`` (sampled at ``x_old``) onto ``x_new`` without raising.

    The inputs are cleaned before interpolation:
      * NaN/Inf in ``y_old`` are replaced by ``default_val``.
      * ``x_old`` is sorted if it is not already non-decreasing.
      * For duplicate ``x_old`` values only one sample is kept. The sort is
        stable, so the sample kept is the one that appeared first in the input.

    Args:
        x_new: Target sample positions.
        x_old: Source sample positions.
        y_old: Source values, same length as ``x_old``.
        default_val: Value used outside the range of ``x_old`` and whenever
            interpolation is not possible.
        label: Tag included in log messages to identify the caller.

    Returns:
        A float array shaped like ``x_new``. It is filled with ``default_val``
        when the source is empty, the lengths differ, ``x_old`` contains
        NaN/Inf, fewer than two distinct ``x_old`` values remain, or an
        unexpected error occurs.
    """
    try:
        x_new = np.asarray(x_new, dtype=float)
        x_old = np.asarray(x_old, dtype=float)
        y_old = np.asarray(y_old, dtype=float)
        if x_new.size == 0: return np.full_like(x_new, default_val, dtype=float)
        if x_old.size == 0 or y_old.size == 0: return np.full(x_new.shape, default_val, dtype=float)
        if x_old.size != y_old.size:
            logging.warning(f"safe_interp ({label}): x_old ({x_old.size}) and y_old ({y_old.size}) size mismatch. Returning default.")
            return np.full(x_new.shape, default_val, dtype=float)
        if np.any(np.isnan(x_old)) or np.any(np.isinf(x_old)):
            logging.warning(f"safe_interp ({label}): NaN/Inf found in x_old. Returning default.")
            return np.full(x_new.shape, default_val, dtype=float)
        if np.any(np.isnan(y_old)) or np.any(np.isinf(y_old)):
            logging.warning(f"safe_interp ({label}): NaN/Inf found in y_old. Replacing with {default_val} before interp.")
            y_old = np.nan_to_num(y_old, nan=default_val, posinf=default_val, neginf=default_val)
        if not np.all(np.diff(x_old) >= 0):
            # A stable sort keeps equal x values in input order, which makes the
            # duplicate that survives the de-duplication below deterministic.
            sort_indices = np.argsort(x_old, kind="stable")
            x_old = x_old[sort_indices]
            y_old = y_old[sort_indices]
        # np.interp needs strictly increasing x; keep the first sample per x value.
        unique_indices = np.unique(x_old, return_index=True)[1]
        if len(unique_indices) < len(x_old):
            x_old = x_old[unique_indices]
            y_old = y_old[unique_indices]
        if x_old.size < 2:
            logging.warning(f"safe_interp ({label}): Too few points ({x_old.size}) left after cleaning x_old. Cannot interp, returning default.")
            return np.full(x_new.shape, default_val, dtype=float)
        interp_result = np.interp(x_new, x_old, y_old, left=default_val, right=default_val)
        interp_result = np.nan_to_num(interp_result.astype(float), nan=default_val, posinf=default_val, neginf=default_val)
        return interp_result
    except Exception as e:
        logging.error(f"safe_interp ({label}) failed unexpectedly: {e}")
        # Best-effort fallback. Catch Exception rather than using a bare
        # ``except`` so KeyboardInterrupt/SystemExit are never swallowed here.
        try:
            return np.full(np.asarray(x_new).shape, default_val, dtype=float)
        except Exception:
            return np.array([default_val], dtype=float)

def find_farthest_vertex(geometry: Union[Polygon, MultiPolygon], point: Point) -> Optional[Point]:
    """Return the boundary vertex of ``geometry`` that lies farthest from ``point``.

    Vertices are collected from the exterior ring and every interior ring of
    each polygon part, de-duplicated, and compared by their distance to
    ``point``.

    Args:
        geometry: A valid ``Polygon`` or ``MultiPolygon``.
        point: Reference location.

    Returns:
        The farthest vertex as a ``Point``, or ``None`` when the inputs are
        missing/invalid, the geometry type is unsupported, or no vertex could
        be evaluated. Every failure is logged.
    """
    if not geometry or not geometry.is_valid or not point or not isinstance(point, Point):
        logging.warning("Invalid input geometry for find_farthest_vertex.")
        return None

    # Normalize single- and multi-part input to a list of polygon parts.
    if isinstance(geometry, Polygon):
        parts = [geometry]
    elif isinstance(geometry, MultiPolygon):
        parts = list(geometry.geoms)
    else:
        logging.warning(f"Unsupported geometry type for find_farthest_vertex: {type(geometry)}")
        return None

    ring_coords = []
    for part in parts:
        outer_ring = part.exterior
        if outer_ring and len(outer_ring.coords) > 0:
            ring_coords.extend(outer_ring.coords)
        for hole in part.interiors:
            if len(hole.coords) > 0:
                ring_coords.extend(hole.coords)

    if not ring_coords:
        logging.warning(f"No coordinates extracted from geometry. Cannot find farthest vertex.")
        return None

    try:
        # De-duplicate first: rings repeat their closing vertex.
        candidates = [Point(xy) for xy in set(ring_coords)]
        if not candidates:
            raise ValueError("No valid Point objects created from unique coordinates.")
    except Exception as geom_err:
        logging.error(f"Error creating Point objects from polygon vertices: {geom_err}")
        return None

    best_vertex = None
    best_distance = -1.0
    for candidate in candidates:
        try:
            separation = point.distance(candidate)
        except Exception as dist_err:
            logging.warning(f"Error calculating distance for vertex {candidate}: {dist_err}")
            continue
        else:
            if separation > best_distance:
                best_distance = separation
                best_vertex = candidate

    if best_vertex is None:
        logging.warning("Could not determine a farthest vertex after checking all points.")

    return best_vertex

def calculate_tc_parameters_farthest_path(basin_poly: Union[Polygon, MultiPolygon],
                                           sink_pt: Point,
                                           dem_file: str) -> Tuple[Optional[float], Optional[float]]:
    """Estimate flow-path length and slope from the farthest basin vertex to the sink.

    The path is the straight line between ``sink_pt`` and the basin boundary
    vertex farthest from it. The slope is the absolute DEM elevation difference
    between the two endpoints divided by that length, with a floor of ``1e-5``.

    Args:
        basin_poly: Valid basin polygon (single- or multi-part).
        sink_pt: Valid sink location inside/near the basin.
        dem_file: Path to the DEM raster sampled at both endpoints.
            NOTE(review): coordinates are sampled as-is, so the geometries are
            assumed to share the DEM's CRS -- confirm.

    Returns:
        ``(length, slope)``. ``(None, None)`` when the inputs are invalid, no
        farthest vertex exists, or an unexpected error occurs.
        ``(length, None)`` when the length is under 0.1, DEM sampling fails,
        or either sampled elevation is non-finite or equals the raster's
        NoData value.
    """
    if not isinstance(basin_poly, (Polygon, MultiPolygon)) or not isinstance(sink_pt, Point):
        logging.warning(f"Invalid input types for Tc parameters: poly={type(basin_poly)}, sink={type(sink_pt)}")
        return None, None
    if not basin_poly.is_valid or not sink_pt.is_valid:
        logging.warning(f"Invalid geometry provided for Tc parameter calculation.")
        return None, None

    try:
        far_pt = find_farthest_vertex(basin_poly, sink_pt)
        if not far_pt:
            logging.warning("Could not determine farthest vertex for Tc path.")
            return None, None

        L_tc_m = far_pt.distance(sink_pt)
        if L_tc_m < 0.1:
            logging.debug(f"Farthest vertex distance to sink < 0.1m ({L_tc_m:.3f}m). Cannot reliably calculate slope. Returning L={L_tc_m}, slope=None.")
            return L_tc_m, None

        with rasterio.open(dem_file) as dem:
            sample_xy = [(far_pt.x, far_pt.y), (sink_pt.x, sink_pt.y)]
            try:
                # Each sample yields one value per band; only the first band is used.
                elevations = [band_values[0] for band_values in dem.sample(sample_xy)]
            except Exception as sample_err:
                logging.error(f"DEM sampling failed for Tc path calculation: {sample_err}")
                return L_tc_m, None

            if len(elevations) < 2:
                logging.error("DEM sampling returned fewer than 2 elevations.")
                return L_tc_m, None

            elev_farthest, elev_sink = elevations[0], elevations[1]
            nodata_val = dem.nodata

            def is_usable(elevation):
                # Usable means finite and not the raster's NoData marker.
                return np.isfinite(elevation) and (nodata_val is None or elevation != nodata_val)

            if not is_usable(elev_farthest) or not is_usable(elev_sink):
                logging.warning(f"Invalid elevation data. Farthest: {elev_farthest} (type {type(elev_farthest)}), Sink: {elev_sink} (type {type(elev_sink)}), DEM NoData value: {nodata_val}. Cannot calculate slope.")
                return L_tc_m, None

            elev_farthest = float(elev_farthest)
            elev_sink = float(elev_sink)
            # Floor the slope so flat paths never produce a zero slope downstream.
            S_tc_dim = max(abs(elev_farthest - elev_sink) / L_tc_m, 1e-5)
            logging.debug(f"Tc Params: L={L_tc_m:.2f}m, Elevs=[{elev_farthest:.3f}, {elev_sink:.3f}], S={S_tc_dim:.5f}")
            return L_tc_m, S_tc_dim
    except Exception as e:
        logging.error(f"Error calculating Tc parameters from farthest path: {e}")
        traceback.print_exc()
        return None, None

def compute_tc_sheet_flow(n_dimensionless: float, L_ft: float, P2_in: float, S_ft_ft: float) -> Optional[float]:
    """Compute the sheet-flow travel time in minutes.

    Evaluates ``0.007 * (n * L)**0.8 / (P2**0.5 * S**0.4)`` (hours) and
    converts the result to minutes. The form matches the NRCS TR-55 sheet-flow
    equation -- presumably the intended source; confirm.

    Args:
        n_dimensionless: Manning's roughness coefficient.
        L_ft: Flow length (feet, per the parameter name).
        P2_in: Rainfall depth P2 (inches, per the parameter name).
        S_ft_ft: Land slope (ft/ft, per the parameter name).

    Returns:
        Travel time in minutes, or ``None`` when any argument is not a
        positive ``int``/``float`` or a math error occurs.
    """
    try:
        arguments = (n_dimensionless, L_ft, P2_in, S_ft_ft)
        # Every argument must be a plain number and strictly positive.
        if not all(isinstance(value, (int, float)) and value > 0 for value in arguments):
            logging.debug(f"[SheetFlow] Invalid input (n={n_dimensionless}, L={L_ft}, P2={P2_in}, S={S_ft_ft}).")
            return None
        numerator = 0.007 * (n_dimensionless * L_ft)**0.8
        denominator = P2_in**0.5 * S_ft_ft**0.4
        hours = numerator / denominator
        return hours * 60.0
    except (ValueError, OverflowError, RuntimeWarning) as e:
        logging.error(f"[SheetFlow] Math error: {e}", exc_info=False)
        return None

def compute_mannings_n(imperv_area, total_area):
    """Return the area-weighted composite Manning's n for a basin.

    The impervious fraction (``imperv_area / total_area``, limited to the
    range 0..1) weights ``N_IMPERV``; the remainder weights ``N_PERV``.
    A total area of ``1e-6`` or less returns ``N_PERV`` unchanged.
    """
    if total_area <= 1e-6:
        return N_PERV
    raw_fraction = float(imperv_area) / float(total_area)
    # Keep the fraction inside [0, 1] even if the areas are inconsistent.
    imperv_fraction = max(0.0, min(1.0, raw_fraction))
    return imperv_fraction * N_IMPERV + (1 - imperv_fraction) * N_PERV

def get_rainfall_start_time(rainfall_csv_path: str) -> Optional[pd.Timestamp]:
    """Return the earliest timestamp in the ``time`` column of a rainfall CSV.

    The result serves as the absolute start time for the simulation.

    Args:
        rainfall_csv_path: Path to a CSV file with a ``time`` column.

    Returns:
        The earliest timestamp, or ``None`` when the file cannot be read, has
        no ``time`` column, has no rows, contains values that cannot be parsed
        as datetimes, or contains no non-missing timestamp at all. Every
        failure is logged; nothing is raised.
    """
    try:
        # We only need the 'time' column to find the minimum (start) time.
        df = pd.read_csv(rainfall_csv_path, usecols=['time'])
        if df.empty or 'time' not in df.columns:
            logging.warning(f"Cannot find 'time' column or data in {rainfall_csv_path} to determine simulation start time.")
            return None
        start_time = pd.to_datetime(df['time']).min()
        # A column holding only blank cells parses to NaT; report that as
        # "no start time" instead of handing NaT to the caller.
        if pd.isna(start_time):
            logging.warning(f"No valid timestamps found in 'time' column of {rainfall_csv_path}; cannot determine simulation start time.")
            return None
        return start_time
    except Exception as e:
        # Deliberately broad: this is a best-effort lookup and must not abort the run.
        logging.error(f"Could not determine start time from rainfall CSV '{rainfall_csv_path}': {e}")
        return None
# 1) Data Classes for Drainage Network
@dataclass
class DrainageNode:
    """One basin in the drainage tree, linked to its parent and children.

    Identity is defined by ``basin_id`` alone: two nodes with the same id
    compare equal and hash identically regardless of their other fields.
    """

    basin_id: str
    # Surface areas of the basin (units follow the input layers).
    area: float
    impervious_area: float
    pervious_area: float
    water_volume: float
    effective_depth: float
    storage_capacity: float
    # Nodes that were attached through add_child().
    children: List['DrainageNode']
    parent: Optional['DrainageNode'] = None
    is_inlet: bool = False
    runoff_data: Optional[Dict] = None

    def __hash__(self):
        return hash(self.basin_id)

    def __eq__(self, other):
        return isinstance(other, DrainageNode) and self.basin_id == other.basin_id

    def add_child(self, child: 'DrainageNode'):
        """Attach ``child`` to this node and point its ``parent`` back here."""
        child.parent = self
        self.children.append(child)

    def get_subtree_stats(self) -> tuple[float, float, float]:
        """Aggregate this node and all of its descendants.

        Returns:
            ``(total_area, total_storage_capacity, average_effective_depth)``
            where the average depth is capacity divided by area, or ``0`` when
            the total area is not greater than ``1e-9``.
        """
        subtree_area = self.area
        subtree_capacity = self.storage_capacity
        # Children are folded in one at a time, in list order.
        for branch_area, branch_capacity, _ in (
            child.get_subtree_stats() for child in self.children
        ):
            subtree_area += branch_area
            subtree_capacity += branch_capacity
        if subtree_area > 1e-9:
            mean_depth = subtree_capacity / subtree_area
        else:
            mean_depth = 0
        return subtree_area, subtree_capacity, mean_depth

@dataclass
class BasinState:
    """Mutable per-basin state record.

    After construction ``merged_from`` always is a non-empty list (it falls
    back to ``[basin_id]``) and ``runoff_state`` always is a dict.
    """

    basin_id: str
    current_area: float
    impervious_area: float
    pervious_area: float
    max_volume: float
    current_volume: float
    effective_depth: float
    alpha: float
    infiltration_rate: float = 0.0
    infiltrated_volume: float = 0.0
    parent_id: Optional[str] = None
    # Ids of the basins combined into this state; defaults to just this basin.
    merged_from: List[str] = field(default_factory=list)
    is_merged: bool = False
    spilled_volume: float = 0.0
    runoff_state: Dict = field(default_factory=dict)
    lag_time: float = 0.0
    response_start_time: float = 0.0
    initial_transit_water: float = 0.0

    def __post_init__(self):
        """Normalize ``merged_from`` and ``runoff_state`` after construction."""
        has_merge_history = isinstance(self.merged_from, list) and len(self.merged_from) > 0
        if not has_merge_history:
            self.merged_from = [self.basin_id]
        # An explicit None is accepted and replaced with an empty dict.
        if self.runoff_state is None:
            self.runoff_state = {}

# 2) DrainageForest: Building the Drainage Network
class DrainageForest:
    def __init__(self):
        """Create an empty forest; both lookup tables start out empty."""
        # Every DrainageNode created for the forest, keyed by basin id string.
        self.all_nodes: Dict[str, DrainageNode] = {}
        # Presumably the nodes of inlet basins, which act as tree roots -- confirm
        # against the code that fills it.
        self.inlet_roots: Dict[str, DrainageNode] = {}

    def calculate_effective_areas(self, basins_gdf: gpd.GeoDataFrame, impervious_tif: str) -> Dict[str, tuple]:
        """Split each basin's area into impervious and pervious parts.

        For every basin the imperviousness raster is clipped to the basin
        polygon. Pixels equal to 1 are counted as impervious and pixels equal
        to 2 as pervious; the impervious share of those pixels is applied to
        the basin's ``area`` attribute. Finally 2% of that impervious area is
        reclassified as pervious, so the reported impervious area is 98% of
        the raster-derived value.

        Args:
            basins_gdf: Basins with ``basin_id`` and ``area`` columns. It is
                reprojected to the raster CRS for clipping when the CRS differs.
            impervious_tif: Path to the imperviousness raster.

        Returns:
            Mapping of basin id (as string) to ``(impervious_area,
            pervious_area)``, both clamped at zero. A basin whose clipping
            fails is logged and reported as fully pervious.

        Raises:
            rasterio.RasterioIOError: If the raster cannot be opened.
            Exception: Any other unexpected error is logged and re-raised.
        """
        areas_by_basin = {}
        try:
            with rasterio.open(impervious_tif) as src:
                if basins_gdf.crs != src.crs:
                    logging.info(f"Reprojecting basins GDF from {basins_gdf.crs} to raster CRS {src.crs}")
                    basins_in_raster_crs = basins_gdf.to_crs(src.crs)
                else:
                    basins_in_raster_crs = basins_gdf

                logging.info(f"Calculating effective areas using CRS: {src.crs}")
                for idx, projected_row in basins_in_raster_crs.iterrows():
                    # Attributes come from the un-projected frame; only the
                    # geometry is taken from the projected one.
                    source_row = basins_gdf.loc[idx]
                    basin_id_str = str(source_row['basin_id'])
                    tot_area_m2 = float(source_row['area'])

                    raster_imperv_m2 = 0.0
                    try:
                        clipped, _ = mask(src, [projected_row.geometry], crop=True, nodata=src.nodata)

                        if clipped.size > 0:
                            # Collapse a single-band (1, rows, cols) stack to 2-D.
                            if clipped.ndim == 3 and clipped.shape[0] == 1:
                                clipped = clipped.squeeze(axis=0)

                            usable = (clipped != src.nodata) & (~np.isnan(clipped))
                            usable_values = clipped[usable]
                            imperv_pix = np.sum(usable_values == 1)
                            perv_pix = np.sum(usable_values == 2)
                            classified_pix = imperv_pix + perv_pix

                            if classified_pix > 0:
                                raster_imperv_m2 = tot_area_m2 * (imperv_pix / classified_pix)

                        # Everything that is not raster-impervious starts out pervious.
                        perv_before_shift_m2 = tot_area_m2 - raster_imperv_m2
                        # Move 2% of the impervious area over to the pervious side.
                        shifted_m2 = raster_imperv_m2 * 0.02
                        imperv_final_m2 = raster_imperv_m2 * 0.98
                        perv_final_m2 = perv_before_shift_m2 + shifted_m2

                        areas_by_basin[basin_id_str] = (max(0.0, imperv_final_m2), max(0.0, perv_final_m2))

                    except (ValueError, IndexError) as ve:
                        logging.error(f"Geometry or mask error for basin {basin_id_str}: {ve}")
                        areas_by_basin[basin_id_str] = (0.0, tot_area_m2)
                    except Exception as e_mask:
                        logging.error(f"Error processing basin {basin_id_str} for effective area (masking part): {e_mask}")
                        traceback.print_exc()
                        areas_by_basin[basin_id_str] = (0.0, tot_area_m2)

        except rasterio.RasterioIOError as e_rio:
            logging.error(f"Error opening impervious raster {impervious_tif}: {e_rio}")
            raise
        except Exception as e_outer:
            logging.error(f"An unexpected error occurred during effective area calculation: {e_outer}")
            traceback.print_exc()
            raise
        return areas_by_basin

    def build_forest(self,
                     lowest_points_gdf: gpd.GeoDataFrame,
                     basins_gdf: gpd.GeoDataFrame,
                     inlets_gdf: gpd.GeoDataFrame,
                     study_area_gdf: gpd.GeoDataFrame,
                     impervious_tif: str
                     ) -> None:
        logging.info("Building drainage forest...")
        target_crs = basins_gdf.crs
        if target_crs is None:
            logging.warning("Basins GDF has no CRS defined. Assuming compatibility or check inputs.")
            
        if target_crs:
            logging.info(f"Using target CRS for forest building: {target_crs}")
            try:
                if lowest_points_gdf.crs != target_crs: lowest_points_gdf = lowest_points_gdf.to_crs(target_crs)
                if inlets_gdf.crs != target_crs: inlets_gdf = inlets_gdf.to_crs(target_crs)
                if study_area_gdf.crs != target_crs: study_area_gdf = study_area_gdf.to_crs(target_crs)
            except Exception as e:
                logging.error(f"Error reprojecting input GDFs to {target_crs}: {e}")
                return
        else:
            logging.warning("Cannot reproject layers as target CRS is unknown. Assuming CRS match.")

        eff_areas = self.calculate_effective_areas(basins_gdf, impervious_tif)

        try:
            study_area_geom = study_area_gdf.geometry.unary_union
        except Exception as e:
            logging.error(f"Error creating study area union: {e}")
            return

        if not inlets_gdf.has_sindex: inlets_gdf.sindex
        possible_inlets_idx = list(inlets_gdf.sindex.intersection(study_area_geom.bounds))
        possible_inlets = inlets_gdf.iloc[possible_inlets_idx]
        inlets_within = possible_inlets[possible_inlets.intersects(study_area_geom)]
        logging.info(f"Found {len(inlets_within)} inlets within study area geometry.")

        inlet_basins = set()
        if not basins_gdf.has_sindex: basins_gdf.sindex
        for idx, inlet in inlets_within.iterrows():
            possible_matches_idx = list(basins_gdf.sindex.intersection(inlet.geometry.bounds))
            if not possible_matches_idx: continue
            possible_matches = basins_gdf.iloc[possible_matches_idx]
            containing_basins = possible_matches[possible_matches.contains(inlet.geometry)]
            if not containing_basins.empty:
                basin_id = str(containing_basins.iloc[0]['basin_id'])
                inlet_basins.add(basin_id)
            else:
                logging.warning(f"Inlet {idx} at {inlet.geometry.coords[0]} not strictly contained within any basin polygon.")
        logging.info(f"Identified {len(inlet_basins)} unique inlet basins: {sorted(list(inlet_basins))}")

        G = nx.DiGraph()
        basins_gdf['basin_id_str'] = basins_gdf['basin_id'].astype(str)
        lowest_points_gdf['basin_id_str'] = lowest_points_gdf['basin_id'].astype(str)
        lowest_points_gdf['to_basin_str'] = lowest_points_gdf['to_basin'].astype(str)
        
        valid_basin_ids = set(basins_gdf['basin_id_str'])
        valid_conn = lowest_points_gdf[lowest_points_gdf['basin_id_str'].isin(valid_basin_ids) & lowest_points_gdf['to_basin_str'].isin(valid_basin_ids)]
        logging.info(f"Found {len(valid_conn)} valid connections between existing basins.")
        for _, row in valid_conn.iterrows():
            G.add_edge(row['basin_id_str'], row['to_basin_str'])
        logging.info(f"Connectivity graph built with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")

        min_area_node_init = 1e-3
        min_volume_node_init = 1e-5

        nodes_to_process = valid_basin_ids
        logging.info(f"Attempting to create DrainageNode objects for {len(nodes_to_process)} basins...")
        nodes_created_count = 0
        for b_id_str in nodes_to_process:
            bdata_rows = basins_gdf[basins_gdf['basin_id_str'] == b_id_str]
            if bdata_rows.empty:
                logging.warning(f"Basin data unexpectedly not found for basin_id '{b_id_str}'. Skipping node creation.")
                continue
            bdata = bdata_rows.iloc[0]
            
            # Get base area and effective area info
            original_area_val = float(bdata['area'])
            a = max(min_area_node_init, original_area_val)
            impA, pervA = eff_areas.get(b_id_str, (0, a if a > 1e-9 else 0.0))

            # Calculate Storage Capacity
            base_storage_volume_from_input: float
            try:
                base_storage_volume_from_input = float(bdata['water_volu'])
            except (KeyError, ValueError, TypeError) as e:
                logging.warning(f"Basin {b_id_str}: Missing or invalid 'water_volu'. Defaulting to min_volume_node_init. Error: {e}")
                base_storage_volume_from_input = min_volume_node_init

            # NEW: Calculate additional storage from impervious texture and pervious
            # interception
            # These parameters are now defined in the global USER INPUT section.

            # Volume from impervious texture storage
            volume_from_impervious_texture = impA * TEXTURE_STORAGE_DEPTH_IMPERVIOUS_M
            logging.debug(f"Basin {b_id_str}: Impervious texture volume (from {impA:.2f} m²): {volume_from_impervious_texture:.4e} m³")

            # Volume from pervious interception storage
            volume_from_pervious_interception = pervA * INTERCEPTION_STORAGE_DEPTH_PERVIOUS_M
            logging.debug(f"Basin {b_id_str}: Pervious interception volume (from {pervA:.2f} m²): {volume_from_pervious_interception:.4e} m³")

            # Combine base storage with the two new additional storage components
            combined_initial_storage_volume = (base_storage_volume_from_input +
                                               volume_from_impervious_texture +
                                               volume_from_pervious_interception)
            # End Of New Logic
            
            # Apply adjustments and final enforcement
            pre_factor_storage_volume = max(min_volume_node_init, combined_initial_storage_volume)
            storage_after_factor_application = pre_factor_storage_volume * DEPRESSION_STORAGE_ADJUSTMENT_FACTOR
            final_storage_capacity_for_node = max(min_volume_node_init, storage_after_factor_application)
            
            eff_depth = final_storage_capacity_for_node / a if a > 1e-9 else 0.0
            
            # Create the node with all calculated values
            node = DrainageNode(
                basin_id=b_id_str, 
                area=a,
                impervious_area=impA, 
                pervious_area=pervA, 
                water_volume=final_storage_capacity_for_node,
                effective_depth=eff_depth,
                storage_capacity=final_storage_capacity_for_node,
                children=[],
                is_inlet=(b_id_str in inlet_basins), 
                runoff_data=None
            )
            self.all_nodes[b_id_str] = node
            nodes_created_count += 1
        logging.info(f"Created {nodes_created_count} DrainageNode objects with potentially adjusted parameters.")

        valid_inlet_roots = {}
        for b_id_str in inlet_basins:
            if b_id_str in self.all_nodes:
                valid_inlet_roots[b_id_str] = self.all_nodes[b_id_str]
            else:
                logging.warning(f"Identified inlet basin {b_id_str} not found in created nodes. Skipping as root.")
        self.inlet_roots = valid_inlet_roots
        logging.info(f"Identified {len(self.inlet_roots)} valid inlet roots: {list(self.inlet_roots.keys())}")

        processed_for_tree = set()
        for i_id in self.inlet_roots:
            if i_id in G:
                self.build_upstream_tree(G, i_id, processed_for_tree)
            else:
                logging.warning(f"Inlet root {i_id} not found in the connectivity graph G. Tree might be isolated.")
                processed_for_tree.add(i_id) 

        init_count = len(self.all_nodes)
        self.remove_disconnected_nodes()
        final_count = len(self.all_nodes)
        if init_count != final_count:
            logging.info(f"Removed {init_count - final_count} disconnected nodes (not upstream of any inlet).")
        logging.info(f"Final forest has {len(self.all_nodes)} nodes across {len(self.inlet_roots)} trees.")

        for i_id, root in self.inlet_roots.items():
            try:
                ar, cap, dpth = root.get_subtree_stats()
                n_count = self.count_tree_nodes(root)
                logging.info(f"Tree {i_id}: {n_count} nodes, Area={ar:.2f} m², Total Capacity={cap:.2f} m³, Avg Depth={dpth:.3f} m")
            except Exception as stat_err:
                logging.error(f"Error getting stats for tree {i_id}: {stat_err}")
                
    def build_upstream_tree(self, G: nx.DiGraph, current_id: str, processed: Set[str]):
        """Recursively attach the upstream basins of ``current_id`` as its children.

        Every predecessor of ``current_id`` in ``G`` that exists in
        ``self.all_nodes`` and has no parent yet becomes a child of the current
        node and is then expanded the same way.

        Args:
            G: Directed connectivity graph; predecessors of a node are treated
                as the basins draining into it.
            current_id: ID of the node whose upstream side is being built.
            processed: IDs that have already been expanded. Updated in place so
                that each node is expanded at most once.

        Notes:
            - A predecessor that already has a different parent is left where it
              is and a warning is logged.
            - A predecessor that is the current node itself (self-loop) or one
              of its ancestors (cycle in ``G``) is skipped with a warning;
              attaching it would turn the tree into a cyclic structure and make
              the recursive tree walkers in this class recurse without bound.
        """
        if current_id in processed or current_id not in self.all_nodes:
            return
        processed.add(current_id)
        c_node = self.all_nodes[current_id]
        try:
            preds = list(G.predecessors(current_id))
        except nx.NetworkXError:
            # Raised when current_id is not a node of G: nothing upstream.
            preds = []

        for p_id in preds:
            if p_id not in self.all_nodes:
                logging.warning(f"Predecessor node {p_id} for {current_id} not found in all_nodes. Connection ignored.")
                continue
            p_node = self.all_nodes[p_id]

            if p_node in c_node.children:
                # Already a child, just make sure its own upstream side is built.
                self.build_upstream_tree(G, p_id, processed)
                continue

            if c_node != p_node.parent:
                if p_node.parent is not None:
                    # p_node already hangs under a DIFFERENT parent.
                    logging.warning(f"Node {p_id} already has parent {p_node.parent.basin_id if p_node.parent else 'Unknown'}, cannot add as child to {current_id}. Check connectivity data for potential conflicts or multiple downstream paths not handled by tree structure.")
                    continue

                # p_node has no parent. Before attaching it, walk up from the
                # current node to make sure p_node is not the current node or
                # one of its ancestors (the root of a tree always has
                # parent None, so the check above cannot catch this).
                closes_cycle = False
                seen_ids = set()
                ancestor = c_node
                while ancestor is not None and id(ancestor) not in seen_ids:
                    if ancestor is p_node:
                        closes_cycle = True
                        break
                    seen_ids.add(id(ancestor))
                    ancestor = ancestor.parent
                if closes_cycle:
                    logging.warning(f"Skipping connection {p_id} -> {current_id}: {p_id} is {current_id} itself or one of its ancestors, so adding it as a child would create a cycle in the drainage tree.")
                    continue

                c_node.add_child(p_node)
                self.build_upstream_tree(G, p_id, processed)


    def remove_disconnected_nodes(self):
        """Drop every node that cannot be reached from an inlet root.

        Reachability is determined by ``_mark_connected`` starting from each
        entry of ``self.inlet_roots``. Unreachable nodes are detached from their
        parent's ``children`` list (when present there) and deleted from
        ``self.all_nodes``. If there are no inlet roots nothing is removed.
        """
        if not self.inlet_roots:
            logging.warning("No inlet roots found, cannot determine connected nodes. Skipping removal.")
            return

        reachable = set()
        for tree_root in self.inlet_roots.values():
            self._mark_connected(tree_root, reachable)

        orphan_ids = set(self.all_nodes.keys()) - reachable
        if not orphan_ids:
            logging.info("No disconnected nodes found.")
            return

        logging.info(f"Removing {len(orphan_ids)} nodes not connected upstream from any inlet root.")
        for orphan_id in orphan_ids:
            if orphan_id not in self.all_nodes:
                continue
            orphan = self.all_nodes[orphan_id]
            owner = orphan.parent
            if owner and orphan in owner.children:
                try:
                    owner.children.remove(orphan)
                except ValueError:
                    # Already gone from the parent's list; nothing to detach.
                    pass
            del self.all_nodes[orphan_id]

    def _mark_connected(self, node: Optional[DrainageNode], connected: Set[str]):
        """Add the ``basin_id`` of ``node`` and of all its descendants to ``connected``.

        ``None`` entries and nodes whose ID is already in ``connected`` are
        skipped (and not expanded). ``connected`` is modified in place.
        """
        pending = [node]
        while pending:
            current = pending.pop()
            if current is None or current.basin_id in connected:
                continue
            connected.add(current.basin_id)
            # Reversed so that children are visited in their list order,
            # depth-first, exactly like a recursive walk would.
            pending.extend(reversed(current.children))

    def count_tree_nodes(self, node: Optional[DrainageNode]) -> int:
        """Return the number of nodes in the subtree rooted at ``node`` (0 for ``None``)."""
        if node is None:
            return 0
        # One for this node plus the size of each child's subtree.
        return 1 + sum(map(self.count_tree_nodes, node.children))
# 3) Runoff Response Functions
def calculate_average_intensity(rainfall_csv, rain_unit):
    """Return the mean rainfall rate of a CSV, converted to m/s.

    The CSV must have a ``time`` column (parseable by ``pd.to_datetime``) and a
    ``Rain`` column holding rates in ``rain_unit`` ('cm/hr', 'mm/hr' or
    'in/hr'). The value is the plain mean of ``Rain`` (NaN skipped); the log
    message marks it as informational only.

    Returns 0.0 (after logging) when the file is missing or unreadable, a
    required column is absent, the record spans no time, every rate is NaN, or
    the unit is not supported.
    """
    metres_per_unit = {'cm/hr': 0.01, 'mm/hr': 0.001, 'in/hr': 0.0254}
    try:
        records = pd.read_csv(rainfall_csv)
        if not {'time', 'Rain'}.issubset(records.columns):
            logging.error(f"Rainfall CSV {rainfall_csv} must contain 'time' and 'Rain' columns.")
            return 0.0

        records['time'] = pd.to_datetime(records['time'])
        records = records.sort_values('time')

        record_span = records['time'].max() - records['time'].min()
        if record_span.total_seconds() <= 0:
            logging.warning("Invalid duration (<= 0) in rainfall data. Cannot calculate average intensity.")
            return 0.0

        avg_rate = records['Rain'].mean(skipna=True)
        if pd.isna(avg_rate):
            logging.warning("Could not calculate average rain rate (all NaN?).")
            return 0.0

        to_metres = metres_per_unit.get(rain_unit)
        if to_metres is None:
            logging.warning(f"Unsupported rain unit: {rain_unit}. Cannot calculate average intensity.")
            avg_intensity_m_s = 0.0
        else:
            # depth-per-hour in metres, then per second
            avg_intensity_m_s = (avg_rate * to_metres) / 3600.0

        logging.info(f"Calculated average rainfall intensity: {avg_intensity_m_s:.6e} m/s (from {avg_rate:.2f} {rain_unit}) - FOR INFO ONLY")
        return avg_intensity_m_s
    except FileNotFoundError:
        logging.error(f"Rainfall CSV file not found: {rainfall_csv}.")
        return 0.0
    except Exception as e:
        logging.error(f"Error reading or processing rainfall CSV {rainfall_csv}: {e}.")
        return 0.0

def build_SCS_UH(T_p, D, dt_sec):
    """Build a unit hydrograph from the tabulated dimensionless curve, sampled every ``dt_sec``.

    The table gives ordinates (peak = 1.0) at ratios ``t / T_p`` from 0 to 5;
    it is scaled to ``T_p``, linearly interpolated onto the time grid
    ``0, dt_sec, 2*dt_sec, ...`` covering the 5*T_p time base, and normalized
    so that ``sum(UH) * dt_sec == 1``.

    Args:
        T_p: Time to peak in seconds.
        D: Not used by this function; kept so existing callers keep working.
        dt_sec: Sampling interval of the returned hydrograph in seconds.

    Returns:
        ``(t_array_uh, UH_normalized)``: sample times in seconds and the
        normalized ordinates (units 1/s), both of length
        ``ceil(5*T_p / dt_sec) + 1``.
    """
    time_base_factor = 5.0
    T_base_sec = time_base_factor * T_p
    x_tab = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.2, 2.4, 2.6, 2.8, 3.0, 3.2, 3.4, 3.6, 3.8, 4.0, 4.5, 5.0])
    Q_tab = np.array([0.000, 0.030, 0.100, 0.190, 0.310, 0.470, 0.660, 0.820, 0.930, 0.990, 1.000, 0.990, 0.930, 0.860, 0.780, 0.680, 0.560, 0.460, 0.390, 0.330, 0.280, 0.207, 0.147, 0.107, 0.077, 0.055, 0.040, 0.029, 0.021, 0.015, 0.011, 0.005, 0.000])
    t_tab = x_tab * T_p
    n_steps_uh = int(np.ceil(T_base_sec / dt_sec)) + 1
    # Sample on the k*dt_sec grid. The ordinates are later used as if they were
    # spaced by exactly dt_sec (normalization below, convolution with rainfall
    # steps), so an evenly stretched grid over [0, T_base] would distort the
    # hydrograph in time whenever T_base is not a multiple of dt_sec.
    # Samples past the end of the table evaluate to 0 (right=0).
    t_array_uh = np.arange(n_steps_uh, dtype=float) * dt_sec
    UH_ordinates = np.interp(t_array_uh, t_tab, Q_tab, left=0, right=0)
    total_area = np.sum(UH_ordinates) * dt_sec
    if total_area > 1e-9:
        UH_normalized = UH_ordinates / total_area
    else:
        # Degenerate case: route everything through in the first step.
        logging.warning(f"Unit hydrograph area near zero (Tp={T_p:.2f}s, dt={dt_sec:.2f}s). Creating single pulse UH.")
        UH_normalized = np.zeros_like(t_array_uh)
        if len(UH_normalized) > 0:
            UH_normalized[0] = 1.0 / dt_sec
    return t_array_uh, UH_normalized

def read_rainfall_csv(rainfall_csv, sim_time_min, timestep_min, rain_unit='cm/hr', start_time_offset=0.0):
    """Read a rainfall-rate CSV and return per-step rainfall depths on the simulation grid.

    The CSV needs a ``time`` column (parsed with ``pd.to_datetime``) and a
    ``Rain`` column of rates in ``rain_unit`` ('cm/hr', 'mm/hr' or 'in/hr').
    Record times are converted to minutes since the first record and the rates
    are forward-filled onto the grid
    ``start_time_offset, start_time_offset + timestep_min, ...`` up to
    ``start_time_offset + sim_time_min``; grid points with no earlier record
    get 0. Because of the forward fill, the last record's rate also applies to
    every grid point after it.

    Returns:
        ``(timeline_min, rain_depth_m, dt_sec)``: grid times in minutes, rainfall
        depth in metres for each step, and the step length in seconds.
        On any failure (file missing, columns missing, unsupported unit, other
        read/processing error) the error is logged and an all-zero depth array
        on the same grid is returned instead.
    """
    def _zero_rain():
        # Fallback result: the requested grid with no rainfall at all.
        step_sec = timestep_min * 60.0
        grid = np.arange(start_time_offset, start_time_offset + sim_time_min + timestep_min, timestep_min)
        return grid, np.zeros_like(grid), step_sec

    metres_per_unit = {'cm/hr': 0.01, 'mm/hr': 0.001, 'in/hr': 0.0254}
    try:
        frame = pd.read_csv(rainfall_csv)
        if not {'time', 'Rain'}.issubset(frame.columns):
            logging.error(f"Rainfall CSV {rainfall_csv} must contain 'time' and 'Rain' columns.")
            return _zero_rain()

        frame['time'] = pd.to_datetime(frame['time'])
        frame = frame.sort_values('time')
        first_stamp = frame['time'].min()
        frame['time_min_from_start'] = (frame['time'] - first_stamp).dt.total_seconds() / 60.0

        end_time = start_time_offset + sim_time_min
        timeline_min = np.arange(start_time_offset, end_time + timestep_min, timestep_min)

        frame = frame.set_index('time_min_from_start')
        rate_on_grid = frame['Rain'].reindex(timeline_min, method='ffill').fillna(0)

        if rain_unit not in ('cm/hr', 'mm/hr', 'in/hr'):
            logging.error(f"Unsupported rain unit: {rain_unit}. Returning zero rainfall.")
            dt_sec = timestep_min * 60.0
            return timeline_min, np.zeros_like(timeline_min), dt_sec

        rain_m_s = (rate_on_grid * metres_per_unit[rain_unit]) / 3600.0
        dt_sec = timestep_min * 60.0
        return timeline_min, rain_m_s.to_numpy() * dt_sec, dt_sec
    except FileNotFoundError:
        logging.error(f"Rainfall CSV file not found: {rainfall_csv}.")
        return _zero_rain()
    except Exception as e:
        logging.error(f"Error reading rainfall CSV {rainfall_csv}: {e}.")
        return _zero_rain()

def calculate_infiltration(runoff_vol_step, perv_area, total_area, dt_sec):
    """Return the volume infiltrated during one time step.

    The share of ``runoff_vol_step`` that falls on pervious ground
    (``perv_area / total_area``) is capped by what the pervious area can take
    up in ``dt_sec`` at the module-level rate ``INFILTRATION_RATE_CM_S``
    (converted from cm/s to m/s). The result is never negative, and is 0.0
    when either area is effectively zero.
    """
    no_usable_area = total_area <= 1e-9 or perv_area <= 1e-9
    if no_usable_area:
        return 0.0

    rate_m_per_s = INFILTRATION_RATE_CM_S * 0.01
    uptake_limit_vol = perv_area * rate_m_per_s * dt_sec

    pervious_share = perv_area / total_area
    available_vol = runoff_vol_step * pervious_share

    absorbed_vol = np.minimum(available_vol, uptake_limit_vol)
    return max(0.0, absorbed_vol)

# 4) Runoff Processor Class - Incorporates Sheet Flow Tc
class RunoffProcessor:
    """Handles calculation of runoff hydrographs for individual and merged basins."""
    def __init__(self,
                 forest: 'DrainageForest',
                 basins_gdf: gpd.GeoDataFrame,
                 sinks_gdf: gpd.GeoDataFrame,
                 lowest_points_file: str,
                 dem_file: str,
                 P2_in: float):
        """Store the inputs and bring all point layers into the DEM's CRS.

        Steps, each of which logs and re-raises on failure:
          1. Read the CRS from ``dem_file``.
          2. Load ``lowest_points_file``, reproject it to the DEM CRS, check the
             ``basin_id``/``to_basin``/``geometry`` columns and add string
             versions of the two ID columns when missing.
          3. Validate ``sinks_gdf``, reproject it to the DEM CRS when its CRS
             differs, check ``basin_id``/``geometry`` and add ``basin_id_str``
             when missing.

        Args:
            forest: Drainage forest whose nodes the processor works on.
            basins_gdf: Basin geometries (stored as given).
            sinks_gdf: Sink geometries, one per basin.
            lowest_points_file: Path of the lowest-points vector file.
            dem_file: Path of the DEM raster; only its CRS is read here.
            P2_in: Rainfall depth in inches used for Tc/Lag calculations.
        """
        # Plain inputs
        self.forest = forest
        self.basins_gdf = basins_gdf
        self.sinks_gdf = sinks_gdf
        self.lowest_points_file = lowest_points_file
        self.dem_file = dem_file
        self.P2_in = P2_in # This is P2_USER_VALUE from global scope

        # State filled in later
        self.basin_runoff_data: Dict[str, Dict] = {}
        self.lowest_points_gdf: Optional[gpd.GeoDataFrame] = None
        self.dem_crs = None
        self.rainfall_csv: Optional[str] = None
        self.rain_unit: Optional[str] = None

        logging.info("--- Initializing RunoffProcessor ---")
        logging.info(f"Using P2 = {self.P2_in} inches for Tc/Lag calculations.")

        # 1) Reference CRS comes from the DEM.
        try:
            with rasterio.open(self.dem_file) as dem_src:
                self.dem_crs = dem_src.crs
                if self.dem_crs is None:
                    raise ValueError("DEM CRS could not be read.")
                logging.info(f"DEM CRS read: {self.dem_crs.to_string()}")
        except Exception as e:
            logging.error(f"Failed opening DEM '{self.dem_file}': {e}")
            raise

        # 2) Lowest points layer.
        try:
            logging.info(f"Loading lowest points file: {self.lowest_points_file}")
            raw_points = gpd.read_file(self.lowest_points_file)
            self.lowest_points_gdf = raw_points.to_crs(self.dem_crs)
            logging.info(f"Lowest points reprojected to {self.dem_crs}")
            for needed in ('basin_id', 'to_basin', 'geometry'):
                if needed not in self.lowest_points_gdf.columns:
                    raise KeyError(f"Lowest points missing required column: '{needed}'")
            for str_col, src_col in (('basin_id_str', 'basin_id'), ('to_basin_str', 'to_basin')):
                if str_col not in self.lowest_points_gdf.columns:
                    self.lowest_points_gdf[str_col] = self.lowest_points_gdf[src_col].astype(str)
            logging.info("Processed lowest points file successfully.")
        except Exception as e:
            logging.error(f"Failed loading/processing lowest points file {self.lowest_points_file}: {e}")
            raise

        # 3) Sinks layer.
        try:
            logging.info("Checking/Reprojecting sinks GDF...")
            if self.sinks_gdf is None or not isinstance(self.sinks_gdf, gpd.GeoDataFrame):
                raise ValueError("Sinks GDF invalid.")
            if self.sinks_gdf.crs != self.dem_crs:
                self.sinks_gdf = self.sinks_gdf.to_crs(self.dem_crs)
            for needed in ('basin_id', 'geometry'):
                if needed not in self.sinks_gdf.columns:
                    raise KeyError(f"Sinks file missing required column: '{needed}'")
            if 'basin_id_str' not in self.sinks_gdf.columns:
                if 'basin_id' not in self.sinks_gdf.columns:
                    raise KeyError("Sinks GDF missing required 'basin_id' column.")
                self.sinks_gdf['basin_id_str'] = self.sinks_gdf['basin_id'].astype(str)
            logging.info("Sinks GDF processed successfully.")
        except Exception as e:
            logging.error(f"Failed processing/reprojecting sinks GDF: {e}")
            raise

        logging.info("--- RunoffProcessor Initialization Complete ---")


    def _get_geometry_for_basin(self, basin_id_str: str, get_sink: bool = False) -> Optional[Union[Polygon, MultiPolygon, Point]]:
        """Look up a basin polygon or sink point and return it in the DEM CRS.

        Only the part of ``basin_id_str`` before the first ``'+'`` is used for
        the lookup. The geometry is taken from ``self.sinks_gdf`` when
        ``get_sink`` is true, otherwise from ``self.basins_gdf``; a
        ``basin_id_str`` column is added to that frame when it is missing.

        Returns:
            The geometry (reprojected if the frame's CRS differs from the DEM
            CRS, repaired with ``buffer(0)`` if invalid), or ``None`` when the
            DEM CRS is unset, the frame or ID column is missing, no row
            matches, the geometry stays invalid, it has the wrong type (sinks
            must be ``Point``, basins ``Polygon``/``MultiPolygon``), or any
            error occurs. Every ``None`` case is logged.
        """
        if not self.dem_crs:
            logging.error("DEM CRS not set in RunoffProcessor, cannot get geometry.")
            return None

        lookup_id = basin_id_str.split('+')[0]
        source_gdf = self.sinks_gdf if get_sink else self.basins_gdf

        # Labels used in the log messages below.
        plural = 'sinks' if get_sink else 'basins'
        singular = 'sink' if get_sink else 'basin'
        titled = 'Sink' if get_sink else 'Basin'

        if source_gdf is None:
            logging.error(f"Source GeoDataFrame for {plural} is None.")
            return None

        if 'basin_id_str' not in source_gdf.columns:
            if 'basin_id' not in source_gdf.columns:
                logging.error(f"'basin_id_str' column missing and cannot be created in {plural} GDF.")
                return None
            source_gdf['basin_id_str'] = source_gdf['basin_id'].astype(str)

        matches = source_gdf[source_gdf['basin_id_str'] == lookup_id]
        if matches.empty:
            logging.warning(f"{titled} geometry not found for original ID '{lookup_id}' in source GDF.")
            return None

        try:
            shape = matches.iloc[0]['geometry']
            if source_gdf.crs != self.dem_crs:
                shape = gpd.GeoSeries([shape], crs=source_gdf.crs).to_crs(self.dem_crs).iloc[0]

            if not shape.is_valid:
                # buffer(0) is the usual attempt at repairing an invalid geometry.
                shape = shape.buffer(0)
                if not shape.is_valid:
                    logging.warning(f"Invalid {singular} geometry for {lookup_id} after buffer(0)..")
                    return None

            if get_sink:
                if not isinstance(shape, Point):
                    logging.warning(f"Expected Point for sink {lookup_id}, found {type(shape)}. Returning None.")
                    return None
            elif not isinstance(shape, (Polygon, MultiPolygon)):
                logging.warning(f"Expected Polygon/MultiPolygon for basin {lookup_id}, found {type(shape)}. Returning None.")
                return None

            return shape
        except Exception as e:
            logging.error(f"Error getting/projecting/validating geometry for {singular} {lookup_id}: {e}")
            traceback.print_exc()
            return None

    def _collect_basin_ids(self, node: Optional['DrainageNode'], basin_ids_set: set):
        """Add the ``basin_id`` of ``node`` and all its descendants to ``basin_ids_set``.

        Falsy nodes and nodes whose ID is already in the set are skipped and
        not expanded. The set is modified in place.
        """
        to_visit = [node]
        while to_visit:
            current = to_visit.pop()
            if not current or current.basin_id in basin_ids_set:
                continue
            basin_ids_set.add(current.basin_id)
            # Reversed so children are handled in list order, depth-first.
            to_visit.extend(reversed(current.children))

    def calculate_runoff_for_all_basins(self, rainfall_csv: str, sim_time_min: float,
                                         timestep_min: float, rain_unit: str = 'cm/hr',
                                         inlet_ids: Optional[List[str]] = None):
        """Compute and cache runoff data for every relevant basin.

        Args:
            rainfall_csv: Path of the rainfall CSV passed on to the per-basin calculation.
            sim_time_min: Simulation length in minutes.
            timestep_min: Time step in minutes.
            rain_unit: Unit of the rainfall rates in the CSV.
            inlet_ids: When given, only basins in the trees rooted at these
                inlets are processed; otherwise every node of the forest is.

        Results are stored in ``self.basin_runoff_data`` keyed by basin ID.
        Basins that already have an entry are not recomputed. A result that
        lacks any key needed by later steps is discarded and the basin is
        counted as failed. Nothing is returned; outcomes are reported through
        the log.
        """
        logging.info("Calculating runoff responses for relevant basins...")
        if not self.dem_crs:
            logging.error("DEM CRS not set.")
            return
        self.rainfall_csv = rainfall_csv
        self.rain_unit = rain_unit

        # Decide which basins to work on.
        relevant_basins = set()
        if inlet_ids:
            logging.info(f"Calculating runoff for trees associated with inlets: {inlet_ids}")
            for inlet_id in inlet_ids:
                inlet_id_str = str(inlet_id)
                if inlet_id_str in self.forest.inlet_roots:
                    self._collect_basin_ids(self.forest.inlet_roots[inlet_id_str], relevant_basins)
                else:
                    logging.warning(f"Specified inlet ID {inlet_id_str} not found in forest roots.")
            logging.info(f"Processing {len(relevant_basins)} basins in selected tree(s).")
        else:
            relevant_basins = set(self.forest.all_nodes.keys())
            logging.info(f"Processing all {len(relevant_basins)} basins.")
        if not relevant_basins:
            logging.error("No relevant basins found.")
            return

        # Keys that downstream merging expects in every runoff result.
        required_keys_for_merge = (
            'convolution_timeline_min', 'direct_flow_m3_s', 'infiltration_vol_step',
            'max_potential_infiltration_vol_step', 'timeline_min',
            'actual_direct_runoff_vol_step', 'impervious_area', 'pervious_area',
            'basin_area', 'dt_sec',
        )

        processed_count = 0
        failed_basins = set()
        skipped_basins = set()
        total_basins_to_process = len(relevant_basins)
        logging.info(f"Starting runoff calculation loop for {total_basins_to_process} basins.")

        for position, basin_id in enumerate(relevant_basins, start=1):
            node = self.forest.all_nodes.get(basin_id)
            if not node:
                logging.warning(f"Basin ID {basin_id} node not found in forest. Skipping runoff calc.")
                skipped_basins.add(basin_id)
                continue
            if basin_id in self.basin_runoff_data:
                logging.debug(f"Runoff data already exists for basin {basin_id}. Skipping.")
                processed_count += 1
                continue

            logging.debug(f"Attempting runoff calculation for basin {basin_id} ({position}/{total_basins_to_process})...")
            runoff_data = self.calculate_basin_runoff(
                basin_id=basin_id,
                basin_area=node.area,
                imperv_area=node.impervious_area,
                perv_area=node.pervious_area,
                rainfall_csv=rainfall_csv,
                sim_time_min=sim_time_min,
                timestep_min=timestep_min,
                rain_unit=rain_unit,
                P2_in=self.P2_in,
            )

            if not runoff_data:
                logging.warning(f"Runoff calculation recorded as FAILED for basin {basin_id}.")
                failed_basins.add(basin_id)
            else:
                missing_keys = [k for k in required_keys_for_merge if k not in runoff_data]
                if missing_keys:
                    logging.error(f"Runoff calculation for {basin_id} SUCCEEDED but is MISSING keys needed later: {missing_keys}. Discarding.")
                    failed_basins.add(basin_id)
                else:
                    self.basin_runoff_data[basin_id] = runoff_data
                    processed_count += 1
                    logging.debug(f"Successfully calculated runoff for basin {basin_id}.")

            # Progress message every 100 basins that reach this point.
            if position % 100 == 0:
                logging.info(f"Processed runoff calculations for {position}/{total_basins_to_process} relevant basins...")

        logging.info("Runoff calculation loop finished.")
        logging.info(f"Successfully generated runoff data for {processed_count} basins.")
        if skipped_basins:
            logging.warning(f"Skipped runoff calculation for {len(skipped_basins)} basins (missing node): {sorted(skipped_basins)}")
        if failed_basins:
            logging.error(f"Runoff calculation FAILED for {len(failed_basins)} basins: {sorted(failed_basins)}")
        if failed_basins or skipped_basins:
            logging.warning("Some relevant basins were skipped or failed runoff calculation. Simulation may halt.")

    def calculate_basin_runoff(self, basin_id: str, basin_area: float, imperv_area: float, perv_area: float,
                                 rainfall_csv: str, sim_time_min: float, timestep_min: float, rain_unit: str,
                                 P2_in: float) -> Optional[dict]:
        logging.debug(f"--- Starting Runoff Calculation for Basin: {basin_id} ---")
        try:
            # 1. Get Geometry
            basin_poly = self._get_geometry_for_basin(basin_id, get_sink=False)
            sink_pt = self._get_geometry_for_basin(basin_id, get_sink=True)
            if basin_poly is None or sink_pt is None:
                raise ValueError("Failed to retrieve valid basin or sink geometry.")

            # 2. Calculate Manning's n
            total_effective_area_for_n = imperv_area + perv_area
            n_total = N_PERV
            if total_effective_area_for_n > 1e-6:
                n_total = compute_mannings_n(imperv_area, total_effective_area_for_n)
            logging.debug(f"Basin {basin_id}: n_total={n_total:.4f} used for Tc calc.")

            # 3. Compute Time of Concentration (Tc)
            tc_sec = 300.0
            L_tc_m_final, S_tc_dim_final = None, None
            try:
                L_tc_m, S_tc_dim_original = calculate_tc_parameters_farthest_path(basin_poly, sink_pt, self.dem_file)
                L_tc_m_final = L_tc_m
                S_tc_dim_adjusted = S_tc_dim_original
                if S_tc_dim_original is not None:
                    S_tc_dim_adjusted = S_tc_dim_original * SLOPE_ADJUSTMENT_FACTOR
                    S_tc_dim_adjusted = max(S_tc_dim_adjusted, 1e-5)
                    S_tc_dim_final = S_tc_dim_adjusted
                if L_tc_m is not None and S_tc_dim_adjusted is not None:
                    L_tc_ft = L_tc_m * METERS_TO_FEET
                    S_tc_ftft = S_tc_dim_adjusted
                    calculated_tc_min = compute_tc_sheet_flow(n_total, L_tc_ft, P2_in, S_tc_ftft)
                    if calculated_tc_min is not None and calculated_tc_min > 0:
                        tc_sec = calculated_tc_min * 60.0
                    else:
                        logging.warning(f"[Tc Calc {basin_id}] Sheet flow calc failed. Using default Tc.")
                else:
                    logging.warning(f"[Tc Calc {basin_id}] L_tc_m or S_tc_dim_adjusted is None. Using default Tc.")
            except Exception as tc_err:
                logging.warning(f"Error during Tc calculation for {basin_id}: {tc_err}. Using default Tc.")
            
            tc_sec = max(tc_sec, 60.0)
            T_p = tc_sec
            T_base_sec = 5.0 * T_p
            logging.debug(f"Basin {basin_id}: Tc={tc_sec:.1f}s, Tp={T_p:.1f}s, Tbase={T_base_sec:.1f}s")

            # 4. Read Rainfall Data
            extended_sim_time_min = sim_time_min + (T_base_sec / 60.0)
            timeline_min, rain_depth_m_step, dt_sec_from_rain = read_rainfall_csv(
                rainfall_csv, extended_sim_time_min, timestep_min, rain_unit
            )
            if dt_sec_from_rain <= 0 or len(timeline_min) == 0:
                raise ValueError("Rainfall reading failed or returned empty/invalid data.")
            dt_sec = dt_sec_from_rain

            # 5. Build SCS Unit Hydrograph
            t_uh_sec, UH_norm = build_SCS_UH(T_p, T_base_sec, dt_sec)
            if len(UH_norm) == 0:
                raise ValueError("Unit Hydrograph generation failed.")

            # 6. Calculate Simplified Excess & Actual Infiltration
            input_rate_m_s = INFILTRATION_RATE_CM_S * 0.01
            threshold_rate_m_s = 0.02 * input_rate_m_s
            initial_effective_rate_m_s = 0.0
            if basin_area > 1e-9 and input_rate_m_s > 1e-12:
                initial_effective_rate_m_s = (input_rate_m_s * perv_area / basin_area)
            perv_area_for_infiltration = perv_area
            adjusted_infiltration_flag = False
            if initial_effective_rate_m_s < threshold_rate_m_s and threshold_rate_m_s > 1e-12:
                adjusted_infiltration_flag = True
                if input_rate_m_s > 1e-12:
                    equivalent_perv_area = (threshold_rate_m_s * basin_area) / input_rate_m_s
                    perv_area_for_infiltration = min(basin_area, max(0.0, equivalent_perv_area))
                else: # Should not happen if threshold_rate_m_s > 1e-12 unless input_rate_m_s is zero
                    adjusted_infiltration_flag = False
            
            rain_rate_m_s = np.divide(rain_depth_m_step, dt_sec, out=np.zeros_like(rain_depth_m_step), where=dt_sec!=0)
            simplified_excess_rate_m_s = np.maximum(0, rain_rate_m_s - effective_max_infil_rate_m_s)
            simplified_excess_depth_m_step = simplified_excess_rate_m_s * dt_sec
            total_rainfall_vol_step = rain_depth_m_step * basin_area
            infiltration_vol_step = np.array([
                calculate_infiltration(rv, perv_area_for_infiltration, basin_area, dt_sec)
                for rv in total_rainfall_vol_step
            ])
            actual_direct_runoff_vol_step = np.maximum(0, total_rainfall_vol_step - infiltration_vol_step)
            max_potential_infiltration_vol_step_val = perv_area_for_infiltration * input_rate_m_s * dt_sec if perv_area_for_infiltration > 1e-6 else 0.0
            max_potential_infiltration_vol_step = np.full_like(total_rainfall_vol_step, max_potential_infiltration_vol_step_val)
            deficit_rate_m_s = np.maximum(0, effective_max_infil_rate_m_s - rain_rate_m_s)
            infiltration_deficit_vol_step = deficit_rate_m_s * basin_area * dt_sec

            # 7. Convolve with UH using SIMPLIFIED Excess Depth
            len_rain = len(simplified_excess_depth_m_step); len_uh = len(UH_norm)
            direct_flow_m3_s = np.array([]); direct_flow_vol_step_conv = np.array([]); convolution_timeline_min = np.array([])
            if len_rain == 0 or len_uh == 0:
                num_steps_out = len(timeline_min) + len_uh - 1 if len(timeline_min) > 0 and len_uh > 0 else len(timeline_min)
                direct_flow_m3_s = np.zeros(num_steps_out); direct_flow_vol_step_conv = np.zeros(num_steps_out)
                if num_steps_out > 0 and len(timeline_min) > 0:
                    convolution_timeline_min = timeline_min[0] + np.arange(num_steps_out) * timestep_min
                else:
                    convolution_timeline_min = np.arange(num_steps_out) * timestep_min
            else:
                convolved_discharge_depth = np.convolve(simplified_excess_depth_m_step, UH_norm * dt_sec)
                num_conv_steps = len(convolved_discharge_depth)
                if len(timeline_min) > 0:
                    convolution_timeline_min = timeline_min[0] + np.arange(num_conv_steps) * timestep_min
                else:
                    convolution_timeline_min = np.arange(num_conv_steps) * timestep_min
                direct_flow_vol_step_conv = convolved_discharge_depth * basin_area
                direct_flow_m3_s = np.divide(direct_flow_vol_step_conv, dt_sec, out=np.zeros_like(direct_flow_vol_step_conv), where=dt_sec!=0)
            
            # MODIFICATION START: Inlet basin self-abstraction
            node_obj = self.forest.all_nodes.get(basin_id)
            if node_obj and node_obj.is_inlet: # Check if it's an inlet node
                node_storage_capacity_for_abstraction = node_obj.storage_capacity
                logging.debug(f"Basin {basin_id} is an inlet node. Applying self-storage abstraction (Vmax={node_storage_capacity_for_abstraction:.3f} m³).")
                if node_storage_capacity_for_abstraction > 1e-9:
                    abstracted_direct_flow_vol_step = np.copy(direct_flow_vol_step_conv)
                    remaining_storage_to_fill = node_storage_capacity_for_abstraction
                    for k_step in range(len(abstracted_direct_flow_vol_step)):
                        if remaining_storage_to_fill <= 1e-9:
                            break 
                        volume_in_this_step = abstracted_direct_flow_vol_step[k_step]
                        can_be_stored = min(volume_in_this_step, remaining_storage_to_fill)
                        abstracted_direct_flow_vol_step[k_step] = volume_in_this_step - can_be_stored
                        remaining_storage_to_fill -= can_be_stored
                    
                    original_total_vol_before_abs = np.sum(direct_flow_vol_step_conv)
                    direct_flow_vol_step_conv = abstracted_direct_flow_vol_step # Update with abstracted values
                    direct_flow_m3_s = np.divide(direct_flow_vol_step_conv, dt_sec, out=np.zeros_like(direct_flow_vol_step_conv), where=dt_sec!=0) # Recalculate rate
                    new_total_vol_after_abs = np.sum(direct_flow_vol_step_conv)
                    logging.debug(f"Basin {basin_id} (Inlet): Self-storage abstraction reduced hydrograph volume from {original_total_vol_before_abs:.3f} m³ to {new_total_vol_after_abs:.3f} m³.")
            # Modification End

            # 8. Store Results
            len_conv_time = len(convolution_timeline_min); len_input_time = len(timeline_min)
            basin_runoff_data = {
                'basin_id': basin_id, 'basin_area': basin_area,
                'impervious_area': imperv_area, 'pervious_area': perv_area,
                'perv_area_used_for_infiltration': perv_area_for_infiltration,
                'adjusted_infiltration_flag': adjusted_infiltration_flag,
                'L_tc_m': L_tc_m_final, 'S_tc_dim': S_tc_dim_final, 'n_total': n_total, 'tc_sec': tc_sec, 'P2_in_used': P2_in,
                'timestep_min': timestep_min, 'dt_sec': dt_sec,
                'timeline_min': timeline_min[:len_input_time].tolist(),
                'rain_depth_m_step': rain_depth_m_step[:len_input_time].tolist(),
                'total_rainfall_vol_step': total_rainfall_vol_step[:len_input_time].tolist(),
                'cum_rainfall_vol': np.cumsum(total_rainfall_vol_step[:len_input_time]).tolist(),
                'total_rainfall_vol': np.sum(total_rainfall_vol_step[:len_input_time]),
                'max_potential_infiltration_vol_step': max_potential_infiltration_vol_step[:len_input_time].tolist(),
                'infiltration_vol_step': infiltration_vol_step[:len_input_time].tolist(),
                'cum_infiltration_vol': np.cumsum(infiltration_vol_step[:len_input_time]).tolist(),
                'total_infiltration_vol': np.sum(infiltration_vol_step[:len_input_time]),
                'actual_direct_runoff_vol_step': actual_direct_runoff_vol_step[:len_input_time].tolist(),
                'simplified_excess_depth_m_step': simplified_excess_depth_m_step[:len_input_time].tolist(),
                'infiltration_deficit_vol_step': infiltration_deficit_vol_step[:len_input_time].tolist(),
                'convolution_timeline_min': convolution_timeline_min[:len_conv_time].tolist(),
                'direct_flow_m3_s': direct_flow_m3_s[:len_conv_time].tolist(), # Uses potentially modified hydrograph
                'direct_flow_vol_step': direct_flow_vol_step_conv[:len_conv_time].tolist(), # Uses potentially modified hydrograph
                'cum_direct_runoff_vol': np.cumsum(direct_flow_vol_step_conv[:len_conv_time]).tolist(), # Uses potentially modified hydrograph
                'total_direct_runoff_vol': np.sum(direct_flow_vol_step_conv[:len_conv_time]), # Uses potentially modified hydrograph
                't_uh_sec': t_uh_sec.tolist(), 'UH_norm': UH_norm.tolist()
            }
            logging.debug(f"--- Finished Runoff Calculation for Basin: {basin_id} ---")
            return basin_runoff_data

        except Exception as e:
            logging.error(f"!!! Failed Runoff Calculation for Basin {basin_id}: {e}")
            traceback.print_exc()
            return None

    def calculate_merged_basin_runoff(self,
                                      upstream_id: str,
                                      downstream_id: str,
                                      current_time_sec: float,
                                      timestep_min: float,
                                      sim_time_min: float,
                                      rainfall_csv: str, # Though not directly used, kept for signature consistency if planned
                                      rain_unit: str,    # Though not directly used, kept for signature consistency if planned
                                      P2_in: float) -> Optional[dict]:
        """
        Calculates the runoff hydrograph for a newly merged basin resulting from
        an upstream basin spilling into a downstream basin.

        Processing steps:
          1. Resolve the timestep (the given ``timestep_min``, otherwise the
             ``dt_sec`` recorded with the downstream or upstream basin data).
          2. Estimate a lag between the upstream outlet and the downstream sink
             via ``compute_tc_sheet_flow``; a 5-minute default is used whenever
             that estimate cannot be computed.
          3. Build a unified timeline and resample both hydrographs onto it; the
             upstream flow from the merge time onward is shifted by the lag.
          4. From the arrival time onward, infiltrate the lagged upstream volume
             into the downstream infiltration deficit (run-on infiltration).
          5. Combine the flows and store the result in ``self.basin_runoff_data``
             under the ID ``"<downstream_id>+<upstream_id>"``.

        Args:
            upstream_id: The ID of the basin spilling (source of additional flow).
                May be a compound ID ("X+Y"); the last component supplies the outlet.
            downstream_id: The ID of the basin receiving the spill (base basin).
            current_time_sec: The simulation time (in seconds) at which the merge occurs.
            timestep_min: The simulation timestep in minutes. If None or not
                positive, it is derived from the components' stored ``dt_sec``.
            sim_time_min: The total simulation duration in minutes (None is treated as 0).
            rainfall_csv: Path to the rainfall CSV (not used by this method).
            rain_unit: Unit of rainfall (not used by this method).
            P2_in: The 2-year, 24-hour rainfall depth in inches, used for Tc/lag calculations.

        Returns:
            A dictionary containing the merged runoff data (arrays converted to
            lists, NaN/Inf replaced by 0.0), or None if the calculation fails.
            Failures are logged rather than raised.
        """
        logging.info(f"MERGE EVENT: Basin '{upstream_id}' merging into '{downstream_id}' at t={current_time_sec/60.0:.2f} min.")
        merged_id = f"{downstream_id}+{upstream_id}" # Define merged ID early for logging

        def _valid_elevation(samples, nodata_val):
            """Return the first sampled band value as a float, or None when it is missing or NoData."""
            if not samples or len(samples[0]) == 0:
                return None
            value = float(samples[0][0])
            # NaN/Inf can never be a usable elevation, whatever the declared NoData is.
            if not np.isfinite(value):
                return None
            # A NaN NoData value cannot be matched by subtraction (every comparison
            # with NaN is False), so only compare against a real number.
            if nodata_val is not None and not np.isnan(nodata_val) and abs(value - nodata_val) <= 1e-6:
                return None
            return value

        # Determine dt_sec (simulation timestep in seconds)
        dt_sec = None
        timestep_given = timestep_min is not None and timestep_min > 0
        if timestep_given:
            dt_sec = timestep_min * 60.0
        else:
            # Fallback: reuse the timestep stored with a component basin,
            # preferring the downstream one.
            for component_data in (self.basin_runoff_data.get(downstream_id),
                                   self.basin_runoff_data.get(upstream_id)):
                recorded_dt = component_data.get('dt_sec') if component_data else None
                if recorded_dt is not None and recorded_dt > 0:
                    dt_sec = recorded_dt
                    break

        if dt_sec is None or dt_sec <= 0:
            logging.error(f"Cannot determine valid dt_sec for merge {upstream_id}->{downstream_id}. Aborting merge calculation.")
            return None

        # Keep the minute-based step consistent with dt_sec when the fallback was
        # used; the unified timeline below is built from timestep_min.
        if not timestep_given:
            timestep_min = dt_sec / 60.0
        sim_end_min = sim_time_min if sim_time_min is not None else 0.0

        merge_time_min = current_time_sec / 60.0

        try:
            # 1. Retrieve runoff data for component basins
            up_data = self.basin_runoff_data.get(upstream_id)
            dn_data = self.basin_runoff_data.get(downstream_id)
            if up_data is None or dn_data is None:
                missing_basin = 'upstream' if up_data is None else 'downstream'
                raise ValueError(f"Merge failed: Missing base runoff data for {missing_basin} basin ('{upstream_id if missing_basin=='upstream' else downstream_id}').")

            # 2. Check for essential keys in component basin data
            # Missing keys are only reported; later lookups use .get() defaults.
            required_keys = [
                'convolution_timeline_min', 'direct_flow_m3_s', 'timeline_min',
                'basin_area', 'pervious_area', 'impervious_area',
                'max_potential_infiltration_vol_step', 'infiltration_vol_step',
                'infiltration_deficit_vol_step', 'dt_sec'
            ]
            missing_up_keys = [key for key in required_keys if key not in up_data]
            missing_dn_keys = [key for key in required_keys if key not in dn_data]
            if missing_up_keys:
                logging.warning(f"Upstream basin ({upstream_id}) runoff data missing required keys for merge: {missing_up_keys}. Proceeding with defaults where possible.")
            if missing_dn_keys:
                logging.warning(f"Downstream basin ({downstream_id}) runoff data missing required keys for merge: {missing_dn_keys}. Proceeding with defaults where possible.")

            # 3. Calculate Lag Time for routing upstream flow to downstream sink
            lag_sec = 5.0 * 60.0  # Default lag time (5 minutes) if detailed calculation fails
            try:
                # The upstream_id might be a compound ID like "X+Y"; the outlet of
                # the merged upstream basin is taken from its last component "Y".
                outlet_component_id_of_upstream = upstream_id.split('+')[-1]

                # Ensure basin_id_str column exists, create if not (should be handled
                # during GDF loading)
                if 'basin_id_str' not in self.lowest_points_gdf.columns and 'basin_id' in self.lowest_points_gdf.columns:
                    self.lowest_points_gdf['basin_id_str'] = self.lowest_points_gdf['basin_id'].astype(str)

                out_row_df = self.lowest_points_gdf[self.lowest_points_gdf['basin_id_str'] == outlet_component_id_of_upstream]
                if out_row_df.empty:
                    raise ValueError(f"Lowest point data not found for upstream component outlet: '{outlet_component_id_of_upstream}'.")

                outlet_pt = out_row_df.iloc[0].geometry
                if not isinstance(outlet_pt, Point) or not outlet_pt.is_valid:
                    raise ValueError(f"Invalid outlet geometry for upstream component: '{outlet_component_id_of_upstream}'.")

                # Get sink point of the downstream basin (its ultimate outlet for Tc
                # purposes)
                sink_pt_downstream = self._get_geometry_for_basin(downstream_id, get_sink=True)
                if sink_pt_downstream is None:
                    raise ValueError(f"Sink point geometry not found or invalid for downstream basin: '{downstream_id}'.")

                # Sample elevations from DEM
                with rasterio.open(self.dem_file) as src:
                    nodata_val = src.nodata
                    elev_out = _valid_elevation(list(src.sample([(outlet_pt.x, outlet_pt.y)])), nodata_val)
                    elev_snk = _valid_elevation(list(src.sample([(sink_pt_downstream.x, sink_pt_downstream.y)])), nodata_val)

                if elev_out is None:
                    raise ValueError(f"Outlet elevation is invalid or NoData for '{outlet_component_id_of_upstream}'.")
                if elev_snk is None:
                    raise ValueError(f"Sink elevation is invalid or NoData for '{downstream_id}'.")

                L_m = outlet_pt.distance(sink_pt_downstream)
                slope_val = abs(elev_out - elev_snk) / max(L_m, 1e-6) # Avoid division by zero for L_m
                slope_val = max(slope_val, 1e-5) # Ensure a minimum slope

                # Manning's n is taken from the downstream basin's land cover as a
                # stand-in for the path between upstream outlet and downstream sink.
                dn_imperv_area = dn_data.get('impervious_area', 0.0)
                dn_perv_area = dn_data.get('pervious_area', 0.0)
                dn_total_effective_area = dn_imperv_area + dn_perv_area
                n_downstream_effective = compute_mannings_n(dn_imperv_area, dn_total_effective_area)

                # Global SLOPE_ADJUSTMENT_FACTOR is applied here
                adjusted_slope_for_tc = slope_val * SLOPE_ADJUSTMENT_FACTOR
                adjusted_slope_for_tc = max(adjusted_slope_for_tc, 1e-5)

                # Calculate Tc for the reach, which serves as the lag time
                # (the helper's result is treated as minutes below).
                calculated_lag_tc_min = compute_tc_sheet_flow(n_downstream_effective, L_m * METERS_TO_FEET, P2_in, adjusted_slope_for_tc)

                if calculated_lag_tc_min is not None and calculated_lag_tc_min > 0:
                    lag_sec = max(5.0, calculated_lag_tc_min * 60.0) # Ensure a minimum lag of 5 seconds
                else:
                    logging.warning(f"Lag calculation using sheet flow returned invalid Tc ({calculated_lag_tc_min}). Using default lag {lag_sec/60.0:.2f} min.")
            except Exception as e_lag:
                # Best effort: any failure above keeps the default lag.
                logging.error(f"Lag calculation failed for {upstream_id}->{downstream_id}: {e_lag}. Using default lag {lag_sec/60.0:.2f} min.")
                traceback.print_exc()

            # 4. Prepare Timelines and Unified Timeline for Merged Hydrograph
            t_up_conv = np.array(up_data.get('convolution_timeline_min', []), dtype=float)
            t_dn_conv = np.array(dn_data.get('convolution_timeline_min', []), dtype=float)
            t_up_in = np.array(up_data.get('timeline_min', []), dtype=float) # For deficits/infiltration
            t_dn_in = np.array(dn_data.get('timeline_min', []), dtype=float) # For deficits/infiltration

            if not all(arr.size > 0 for arr in [t_up_conv, t_dn_conv, t_up_in, t_dn_in]):
                raise ValueError(f"Merge failed: One or more essential timelines are empty for '{upstream_id}' or '{downstream_id}'.")

            # Unified timeline extending to cover the full simulation plus any lag
            # effects
            start_time_unified = min(t_dn_conv[0], t_up_conv[0]) # Should ideally be 0 or simulation start
            max_conv_time_components = max(t_dn_conv[-1], t_up_conv[-1])
            end_time_unified = max(sim_end_min, max_conv_time_components + (lag_sec / 60.0))

            unified_timeline = np.arange(start_time_unified, end_time_unified + timestep_min, timestep_min)
            if unified_timeline.size == 0:
                raise ValueError("Merge failed: Unified timeline generation resulted in an empty array.")

            merge_idx_on_unified = np.searchsorted(unified_timeline, merge_time_min, side='left')

            # 5. Interpolate and Prepare Key Hydrograph Arrays onto Unified Timeline
            q_up_orig_m3s = np.array(up_data.get('direct_flow_m3_s', []), dtype=float)
            q_dn_orig_m3s = np.array(dn_data.get('direct_flow_m3_s', []), dtype=float)

            # Downstream flow interpolated on unified timeline
            dn_q_interpolated_m3s = safe_interp(unified_timeline, t_dn_conv, q_dn_orig_m3s, label=f"Downstream Q {downstream_id}")

            # Only upstream flow from the merge time onward spills downstream.
            spill_start_idx_on_up_timeline = np.searchsorted(t_up_conv, merge_time_min, side='left')

            # Prepare lagged upstream flow
            if q_up_orig_m3s.size == 0 or spill_start_idx_on_up_timeline >= len(t_up_conv):
                spill_times_lagged, spill_q_lagged = np.array([]), np.array([]) # No flow to lag
            else:
                spill_times_lagged = t_up_conv[spill_start_idx_on_up_timeline:] + (lag_sec / 60.0) # Lagged time
                spill_q_lagged = q_up_orig_m3s[spill_start_idx_on_up_timeline:]

            up_q_lagged_interpolated_m3s = safe_interp(unified_timeline, spill_times_lagged, spill_q_lagged, label=f"Lagged Upstream Q {upstream_id}")
            up_vol_step_lagged_m3 = up_q_lagged_interpolated_m3s * dt_sec

            # Prepare downstream deficit for run-on calculation
            dn_deficit_vol_step_orig = np.array(dn_data.get('infiltration_deficit_vol_step', []), dtype=float)
            aligned_dn_deficit_vol_step = safe_interp(unified_timeline, t_dn_in, dn_deficit_vol_step_orig, label=f"Aligned Downstream Deficit {downstream_id}")

            # Other arrays for weighted averages or combined properties
            up_deficit_vol_step_orig = np.array(up_data.get('infiltration_deficit_vol_step', []), dtype=float)
            up_deficit_interpolated_for_weighting = safe_interp(unified_timeline, t_up_in, up_deficit_vol_step_orig, label=f"Upstream Deficit {upstream_id}")

            dn_infiltration_vol_step_orig = np.array(dn_data.get('infiltration_vol_step', []), dtype=float)
            dn_infiltration_interpolated = safe_interp(unified_timeline, t_dn_in, dn_infiltration_vol_step_orig, label=f"Downstream Infiltration {downstream_id}")

            up_max_potential_infil_orig = np.array(up_data.get('max_potential_infiltration_vol_step', []), dtype=float)
            dn_max_potential_infil_orig = np.array(dn_data.get('max_potential_infiltration_vol_step', []), dtype=float)
            up_max_potential_infil_interpolated = safe_interp(unified_timeline, t_up_in, up_max_potential_infil_orig, label=f"Upstream Max Potential Infil {upstream_id}")
            dn_max_potential_infil_interpolated = safe_interp(unified_timeline, t_dn_in, dn_max_potential_infil_orig, label=f"Downstream Max Potential Infil {downstream_id}")
            merged_max_potential_infiltration_step_m3 = dn_max_potential_infil_interpolated + up_max_potential_infil_interpolated

            # 6. Calculate Final Run-on Infiltration onto Downstream Area
            runon_inf_vol_step_merged_m3 = np.zeros_like(unified_timeline, dtype=float)
            # Arrival index on unified timeline (when lagged upstream flow starts
            # arriving)
            arrival_idx_on_unified = np.searchsorted(unified_timeline, merge_time_min + (lag_sec / 60.0), side='left')

            # From arrival onward, each step infiltrates the smaller of the
            # remaining downstream deficit and the incoming lagged volume.
            # fmax/fmin ignore NaN operands, matching Python's max(0.0, x)/min(a, b).
            common_len_for_runon = min(len(aligned_dn_deficit_vol_step), len(up_vol_step_lagged_m3))
            runon_window = slice(arrival_idx_on_unified, common_len_for_runon)
            available_deficit = np.fmax(0.0, aligned_dn_deficit_vol_step[runon_window])
            runon_inf_vol_step_merged_m3[runon_window] = np.fmax(
                0.0, np.fmin(available_deficit, up_vol_step_lagged_m3[runon_window])
            )

            total_calculated_runon_m3 = np.sum(runon_inf_vol_step_merged_m3)

            # 7. Calculate Adjusted Upstream Flow (after run-on infiltration) and Combined Flow
            adj_up_vol_step_m3 = np.maximum(0.0, up_vol_step_lagged_m3 - runon_inf_vol_step_merged_m3)
            adj_up_q_m3s = np.divide(adj_up_vol_step_m3, dt_sec, out=np.zeros_like(adj_up_vol_step_m3), where=dt_sec != 0)

            # Combined flow: downstream's own flow before the merge time; afterwards
            # downstream flow plus the adjusted (post-infiltration) lagged upstream flow.
            # NOTE(review): the switch happens at merge time, not at arrival time
            # (merge + lag); this relies on the lagged series contributing nothing
            # before arrival -- confirm against safe_interp's out-of-range behaviour.
            combined_q_m3s = np.where(unified_timeline < merge_time_min,
                                      dn_q_interpolated_m3s,
                                      dn_q_interpolated_m3s + adj_up_q_m3s)
            combined_vol_step_m3 = combined_q_m3s * dt_sec

            # 8. Calculate NEW Weighted Average Deficit for the Merged Basin (for storage/later use if needed)
            # Deficit in downstream area is reduced by the run-on infiltration it
            # received.
            dn_deficit_post_runon_m3 = np.maximum(0.0, aligned_dn_deficit_vol_step - runon_inf_vol_step_merged_m3)

            area_up = up_data.get('basin_area', 0.0)
            area_dn = dn_data.get('basin_area', 0.0)
            total_area_merged = area_up + area_dn

            merged_weighted_deficit_vol_step_m3 = np.zeros_like(unified_timeline, dtype=float)
            if total_area_merged > 1e-9:
                weight_dn = area_dn / total_area_merged
                weight_up = area_up / total_area_merged

                # Before merge, use a weighted average of individual deficits
                merged_weighted_deficit_vol_step_m3[:merge_idx_on_unified] = (
                    safe_interp(unified_timeline[:merge_idx_on_unified], t_dn_in, dn_deficit_vol_step_orig, label="PreMerge Deficit DN") * weight_dn +
                    up_deficit_interpolated_for_weighting[:merge_idx_on_unified] * weight_up
                )
                # After merge, only the downstream component's post-run-on deficit
                # is kept (area-weighted); the upstream deficit is not carried over.
                merged_weighted_deficit_vol_step_m3[merge_idx_on_unified:] = dn_deficit_post_runon_m3[merge_idx_on_unified:] * weight_dn
            else:
                logging.warning(f"Merged area is zero for {upstream_id}->{downstream_id}. Deficit will be zero.")
            merged_weighted_deficit_vol_step_m3 = np.maximum(0.0, merged_weighted_deficit_vol_step_m3)

            # 9. Store Results for the Merged Basin
            cum_combined_vol_m3 = np.cumsum(combined_vol_step_m3) if combined_vol_step_m3.size > 0 else np.array([0.0])
            total_combined_vol_final_m3 = cum_combined_vol_m3[-1] if cum_combined_vol_m3.size > 0 else 0.0

            # Total actual infiltration for the merged entity per step:
            # downstream's original infiltration plus the run-on infiltration it received.
            min_len_actual_infiltration = min(len(dn_infiltration_interpolated), len(runon_inf_vol_step_merged_m3))
            total_actual_infiltration_step_m3 = np.zeros_like(unified_timeline, dtype=float)
            total_actual_infiltration_step_m3[:min_len_actual_infiltration] = (
                dn_infiltration_interpolated[:min_len_actual_infiltration] +
                runon_inf_vol_step_merged_m3[:min_len_actual_infiltration]
            )

            final_imperv_area_merged = up_data.get('impervious_area', 0.0) + dn_data.get('impervious_area', 0.0)
            final_perv_area_merged = up_data.get('pervious_area', 0.0) + dn_data.get('pervious_area', 0.0)
            final_timeline_len = len(unified_timeline)

            merged_data = {
                'basin_id': merged_id,
                'merged_from_ids': [downstream_id, upstream_id], # Convention: [receiving, spilling]
                'basin_area': total_area_merged,
                'impervious_area': final_imperv_area_merged,
                'pervious_area': final_perv_area_merged,
                'timestep_min': timestep_min,
                'dt_sec': dt_sec,
                'merge_time_min': merge_time_min,
                'lag_used_sec': lag_sec,
                'convolution_timeline_min': unified_timeline[:final_timeline_len],
                'direct_flow_m3_s': combined_q_m3s[:final_timeline_len],
                'direct_flow_vol_step': combined_vol_step_m3[:final_timeline_len],
                'cum_direct_runoff_vol': cum_combined_vol_m3[:final_timeline_len],
                'total_direct_runoff_vol': total_combined_vol_final_m3,
                'max_potential_infiltration_vol_step': merged_max_potential_infiltration_step_m3[:final_timeline_len],
                'infiltration_deficit_vol_step': merged_weighted_deficit_vol_step_m3[:final_timeline_len],
                'infiltration_vol_step': total_actual_infiltration_step_m3[:final_timeline_len], # Total infiltration for the merged entity
                'runon_infiltration_vol_step': runon_inf_vol_step_merged_m3[:final_timeline_len], # Specifically the run-on part
                'total_runon_infiltration_calculated': total_calculated_runon_m3,
                'timeline_min': unified_timeline[:final_timeline_len], # Input timeline reference
                'component_down_flow_m3s': dn_q_interpolated_m3s[:final_timeline_len],
                'original_lagged_flow_m3s': up_q_lagged_interpolated_m3s[:final_timeline_len], # Before run-on infiltration
                'adjusted_upstream_lagged_flow_m3s': adj_up_q_m3s[:final_timeline_len], # After run-on infiltration
            }

            # 10. Clean up NaN/Inf and Convert arrays to lists for storage
            final_merged_data_cleaned = {}
            for key, value in merged_data.items():
                if isinstance(value, np.ndarray):
                    if np.issubdtype(value.dtype, np.number): # Check if numeric array
                        value_clean = np.nan_to_num(value, nan=0.0, posinf=0.0, neginf=0.0)
                        final_merged_data_cleaned[key] = value_clean.tolist()
                    else: # For non-numeric arrays (not expected here for timeseries)
                        final_merged_data_cleaned[key] = value.tolist()
                elif isinstance(value, (int, float, np.number)): # Handles Python scalars and NumPy scalars
                    if np.isnan(value) or np.isinf(value):
                        final_merged_data_cleaned[key] = 0.0
                    else:
                        final_merged_data_cleaned[key] = float(value)
                elif isinstance(value, list) and key == 'merged_from_ids': # Ensure IDs are strings
                    final_merged_data_cleaned[key] = [str(item) for item in value]
                else:
                    final_merged_data_cleaned[key] = value

            # 11. Store the processed data for the new merged ID
            self.basin_runoff_data[merged_id] = final_merged_data_cleaned
            logging.info(f"Successfully calculated and stored merged runoff data for '{merged_id}'.")
            return final_merged_data_cleaned

        except KeyError as ke:
            logging.error(f"Merge calculation failed for {merged_id} due to missing key: {ke}. Upstream: {upstream_id}, Downstream: {downstream_id}.")
            traceback.print_exc()
            return None
        except ValueError as ve:
            logging.error(f"Merge calculation failed for {merged_id} due to value error: {ve}. Upstream: {upstream_id}, Downstream: {downstream_id}.")
            traceback.print_exc()
            return None
        except Exception as e:
            logging.error(f"Unexpected error during merge calculation for {merged_id}: {e}. Upstream: {upstream_id}, Downstream: {downstream_id}.")
            traceback.print_exc()
            return None

    def get_volume_in_timespan(self, basin_id: str, start_time_min: float, end_time_min: float) -> float:
        """
        Return the direct-runoff volume accumulated between two times.

        The stored cumulative hydrograph ('cum_direct_runoff_vol' over
        'convolution_timeline_min') is linearly interpolated at both times and
        the difference is returned. Times before the timeline count as zero
        volume; times after it take the final cumulative value.

        Args:
            basin_id: Key into ``self.basin_runoff_data``.
            start_time_min: Start of the interval, in minutes.
            end_time_min: End of the interval, in minutes.

        Returns:
            The non-negative interval volume, or 0.0 when the basin is unknown,
            its series are missing/empty, or interpolation fails.
        """
        runoff_store = self.basin_runoff_data
        if basin_id not in runoff_store:
            return 0.0

        record = runoff_store[basin_id]
        times = record.get('convolution_timeline_min')
        cumulative = record.get('cum_direct_runoff_vol')
        if times is None or cumulative is None:
            logging.warning(f"Missing timeline/cum_vol for {basin_id}")
            return 0.0

        times = np.array(times)
        cumulative = np.array(cumulative)

        # Work on the common prefix in case the two series differ in length.
        usable = min(len(times), len(cumulative))
        if usable == 0:
            return 0.0
        times = times[:usable]
        cumulative = cumulative[:usable]

        try:
            final_volume = cumulative[-1]

            def cumulative_at(t_min):
                return np.interp(t_min, times, cumulative, left=0.0, right=final_volume)

            volume_at_start = cumulative_at(start_time_min)
            volume_at_end = cumulative_at(end_time_min)
            # Clamp so a reversed interval never yields a negative volume.
            return float(max(0.0, volume_at_end - volume_at_start))
        except Exception as e:
            logging.error(f"Error interpolating volume for {basin_id}: {e}")
            return 0.0


    # Point query: interpolated flow rate of a stored hydrograph at a given time.
    def get_flow_at_time(self, basin_id: str, time_min: float) -> float:
        """
        Return the stored direct-flow rate (m3/s) of a basin at ``time_min``.

        The 'direct_flow_m3_s' series is linearly interpolated over
        'convolution_timeline_min'. Times outside the timeline give 0.0.

        Args:
            basin_id: Key into ``self.basin_runoff_data``.
            time_min: Query time, in minutes.

        Returns:
            The interpolated flow rate, or 0.0 when the basin is unknown, its
            series are missing/empty, or interpolation fails.
        """
        runoff_store = self.basin_runoff_data
        if basin_id not in runoff_store:
            return 0.0

        record = runoff_store[basin_id]
        times = record.get('convolution_timeline_min')
        rates = record.get('direct_flow_m3_s')
        if times is None or rates is None:
            return 0.0

        times = np.array(times)
        rates = np.array(rates)

        # Work on the common prefix in case the two series differ in length.
        usable = min(len(times), len(rates))
        if usable == 0:
            return 0.0

        try:
            interpolated = np.interp(time_min, times[:usable], rates[:usable], left=0.0, right=0.0)
            return float(interpolated)
        except Exception as e:
            logging.error(f"Interpolation error in get_flow_at_time for {basin_id} @ {time_min}: {e}")
            return 0.0

# 5) Integrated Simulation Class

@dataclass
class IntegratedSimulation:
    """Fill-and-spill simulator for the basins draining to a set of inlets.

    The first three fields are required constructor arguments; every other
    field is working state that the methods of this class populate.
    """

    # Drainage forest; the class reads its `inlet_roots` and `all_nodes` mappings.
    forest: 'DrainageForest'
    # Runoff provider; the class reads `basin_runoff_data` and calls
    # `get_volume_in_timespan` and `calculate_merged_basin_runoff` on it.
    runoff_processor: 'RunoffProcessor'
    # Inlet IDs whose trees are simulated (coerced to str in __post_init__).
    selected_inlet_ids: List[str]
    # merged_id -> (total run-on volume, distribution steps, arrival time in min);
    # written by _merge_basins_lag_route.
    pending_runon: Dict[str, Tuple[float, int, float]] = field(default_factory=dict)
    # basin_id -> state of every basin (original or merged) that is currently active.
    active_states: Dict[str, BasinState] = field(default_factory=dict)
    # One record per merge: time_min, upstream_id, downstream_id, merged_id,
    # lag_calculated_sec.
    merged_states_history: List[Dict] = field(default_factory=list)
    # Per-step records appended by _log_state_and_balance.
    time_series: List[Dict] = field(default_factory=list)
    water_balance: List[Dict] = field(default_factory=list)
    # (time, snapshot) pairs; the initial snapshot is stored at 0.0 and the
    # snapshots taken in run_simulation are keyed by time in seconds.
    tree_snapshots: List[Tuple[float, Dict]] = field(default_factory=list)
    # IDs of the original basins reached from the selected inlet roots.
    original_basin_ids_in_sim: Set[str] = field(default_factory=set)
    # Simulation duration and step size; populated during initialization.
    simulation_time_min: Optional[float] = None
    timestep_min: Optional[float] = None
    dt_sec: Optional[float] = None
    # Set when a run stops early, together with a short explanation.
    halted_prematurely: bool = False
    halt_reason: str = ""
    # Total storage at the end of the previously logged step (for storage change).
    previous_total_storage: float = 0.0
    # Result of _initialize_basin_states; run_simulation refuses to run if False.
    initialization_successful: bool = False
    # Running totals maintained by _add_newly_connected_areas.
    connected_original_basin_ids: Set[str] = field(default_factory=set)
    connected_total_area: float = 0.0
    connected_impervious_area: float = 0.0
    connected_pervious_area: float = 0.0

    # Step count stored with each pending run-on entry. Because it is annotated,
    # the dataclass also treats it as an optional constructor argument.
    RUNON_DISTRIBUTION_STEPS: int = 5

    def __post_init__(self):
        logging.info("--- Initializing Integrated Simulator ---")
        self.selected_inlet_ids = [str(iid) for iid in self.selected_inlet_ids]
        self.initialization_successful = self._initialize_basin_states()
        if self.initialization_successful:
            logging.info(f"Simulation initialized with {len(self.active_states)} active states.")
            logging.info(f"Tracking {len(self.original_basin_ids_in_sim)} original basins.")
            if self.simulation_time_min is not None and self.timestep_min is not None:
                logging.info(f"Derived Sim Time: {self.simulation_time_min} min, Timestep: {self.timestep_min} min ({self.dt_sec} sec)")
                logging.info(f"Run-on infiltration will be distributed over {self.RUNON_DISTRIBUTION_STEPS} steps starting after lag time.")
            else: logging.warning("Could not derive simulation time/timestep during init.")
            logging.info(f"Selected Inlet Trees Rooted At: {self.selected_inlet_ids}")
            logging.info("--- Integrated Simulation Initialization Complete ---")
        else: logging.error("--- Integrated Simulation Initialization FAILED ---")

    def _initialize_basin_states(self) -> bool:
        # Clear all state from previous runs
        self.active_states.clear(); self.time_series.clear(); self.water_balance.clear()
        self.tree_snapshots.clear(); self.previous_total_storage = 0.0; self.original_basin_ids_in_sim.clear()
        self.pending_runon.clear()
        # NEW: Reset connected area trackers
        self.connected_original_basin_ids.clear()
        self.connected_total_area = 0.0
        self.connected_impervious_area = 0.0
        self.connected_pervious_area = 0.0

        nodes_to_initialize = set()
        for inlet_id in self.selected_inlet_ids:
            if inlet_id in self.forest.inlet_roots:
                root_node = self.forest.inlet_roots[inlet_id]; queue = [root_node]; visited_init = set()
                while queue:
                    current_node = queue.pop(0)
                    if current_node and current_node.basin_id not in visited_init:
                        visited_init.add(current_node.basin_id)
                        if current_node.basin_id in self.forest.all_nodes:
                            nodes_to_initialize.add(current_node.basin_id); self.original_basin_ids_in_sim.add(current_node.basin_id)
                            for child in current_node.children:
                                if child: queue.append(child)
                        else: logging.warning(f"Node {current_node.basin_id} not in forest.all_nodes during init.")
            else: logging.warning(f"Selected inlet {inlet_id} not found in forest roots.")
        if not nodes_to_initialize: logging.error("No valid nodes found for selected inlets."); return False
        logging.info(f"Initializing states for {len(nodes_to_initialize)} nodes (from {len(self.original_basin_ids_in_sim)} unique basins).")
        first_valid_runoff_data = None
        for node_id in self.original_basin_ids_in_sim:
                if node_id in self.runoff_processor.basin_runoff_data: first_valid_runoff_data = self.runoff_processor.basin_runoff_data[node_id]; break
        if not first_valid_runoff_data: logging.error("Cannot initialize: No runoff data found."); return False
        global simulation_time_min; self.simulation_time_min = simulation_time_min
        self.timestep_min = first_valid_runoff_data.get('timestep_min'); self.dt_sec = first_valid_runoff_data.get('dt_sec')
        if self.timestep_min is None or self.dt_sec is None or self.dt_sec <= 0 or self.simulation_time_min is None: logging.error(f"Invalid simulation parameters derived."); return False
        logging.info(f"Using Timestep: {self.timestep_min:.4f} min ({self.dt_sec:.2f} sec), Sim Duration: {self.simulation_time_min} min.")
        initialized_count = 0
        for node_id in nodes_to_initialize:
            node = self.forest.all_nodes.get(node_id);
            if not node: logging.warning(f"Node object {node_id} not found."); continue
            runoff_data = self.runoff_processor.basin_runoff_data.get(node_id)
            if not runoff_data:
                logging.warning(f"Runoff data missing for {node_id}. Skipping & removing.");
                if node_id in self.original_basin_ids_in_sim: self.original_basin_ids_in_sim.remove(node_id)
                continue
            if abs(runoff_data.get('dt_sec', self.dt_sec) - self.dt_sec) > 1e-6: logging.warning(f"Inconsistent dt_sec for {node_id}.")
            original_tc_sec = runoff_data.get('tc_sec', 300.0)
            parent_id_in_sim = node.parent.basin_id if node.parent and node.parent.basin_id in nodes_to_initialize else None
            initial_state = BasinState(
                basin_id=node.basin_id, current_area=node.area, impervious_area=node.impervious_area, pervious_area=node.pervious_area,
                max_volume=node.storage_capacity, current_volume=0.0, effective_depth=node.effective_depth, alpha=1.0,
                parent_id=parent_id_in_sim, merged_from=[node.basin_id], is_merged=False, spilled_volume=0.0,
                runoff_state=runoff_data, lag_time=original_tc_sec, response_start_time=0.0, infiltrated_volume=0.0, initial_transit_water=0.0
            )
            self.active_states[node.basin_id] = initial_state; initialized_count += 1
        if initialized_count == 0: logging.error("Initialization failed: Zero states created."); return False
        self._log_state_and_balance(0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        self.tree_snapshots.append((0.0, self._capture_tree_state()))
        return True

    def _add_newly_connected_areas(self, original_ids_to_connect: List[str]):
        """
        Adds the area of original basins to the running total if they haven't been added before.
        """
        for orig_id in original_ids_to_connect:
            if orig_id not in self.connected_original_basin_ids:
                # Find the node in the original forest to get its properties
                node = self.forest.all_nodes.get(orig_id)
                if node:
                    self.connected_total_area += node.area
                    self.connected_impervious_area += node.impervious_area
                    self.connected_pervious_area += node.pervious_area
                    self.connected_original_basin_ids.add(orig_id)
                    logging.info(f"Area of basin {orig_id} is now hydrologically connected.")
                else:
                    logging.warning(f"Could not find original node {orig_id} in forest to add its area.")
                    
    # _capture_tree_state method
    def _capture_tree_state(self) -> Dict:
        snapshot: Dict[str, Dict] = {};
        for b_id, st in self.active_states.items():
            data: Dict[str, Union[float, bool, List[str], List[float]]] = {'area': st.current_area, 'impervious_area': st.impervious_area, 'pervious_area': st.pervious_area, 'current_volume': st.current_volume, 'max_volume': st.max_volume, 'infiltrated_volume': st.infiltrated_volume, 'is_merged': st.is_merged, 'effective_depth': st.effective_depth, 'parent_id': st.parent_id, 'spilled_volume': st.spilled_volume, 'merged_from': list(st.merged_from), 'is_inlet': any(orig in st.merged_from for orig in self.selected_inlet_ids), 'lag_time': st.lag_time}
            rs = st.runoff_state or {}; data['convolution_timeline_min'] = rs.get('convolution_timeline_min', []); data['runon_infiltration_vol_step'] = rs.get('runon_infiltration_vol_step', [])
            snapshot[b_id] = data
        return snapshot

    # _calculate_total_storage method
    def _calculate_total_storage(self) -> float:
        total_storage = sum(st.current_volume for st in self.active_states.values() if st and not st.is_merged)
        return total_storage

    # _log_state_and_balance method
    def _log_state_and_balance(self, current_time_min: float, step_rainfall: float, step_initial_infiltration: float,
                                 step_runon_infiltration: float, step_runoff_gen: float, step_discharge_vol: float):
        storage_end_of_step = self._calculate_total_storage()
        storage_change = storage_end_of_step - self.previous_total_storage
        total_step_infiltration = step_initial_infiltration + step_runon_infiltration
        expected_storage_change = step_runoff_gen - (step_discharge_vol + step_runon_infiltration)
        internal_mass_balance_error = storage_change - expected_storage_change
        overall_mass_balance_error = step_rainfall - total_step_infiltration - step_discharge_vol - storage_change
        current_outlet_flow_m3_s = step_discharge_vol / self.dt_sec if self.dt_sec is not None and self.dt_sec > 0 else 0.0
        
        self.time_series.append({
            'time': current_time_min, 
            'inlet_total_area': self.connected_total_area, 
            'inlet_impervious_area': self.connected_impervious_area, 
            'inlet_pervious_area': self.connected_pervious_area, 
            'inlet_effective_area': self.connected_impervious_area + self.connected_pervious_area, 
            'inlet_volume_rate': current_outlet_flow_m3_s, 
            'current_storage': storage_end_of_step
        })
        
        self.water_balance.append({'time': current_time_min, 'total_rainfall': step_rainfall, 'initial_infiltration': step_initial_infiltration, 'runon_infiltration': step_runon_infiltration, 'total_infiltration': total_step_infiltration, 'active_runoff_generated': step_runoff_gen, 'inlet_discharge': step_discharge_vol, 'total_storage': storage_end_of_step, 'storage_change': storage_change, 'internal_mass_balance_error': internal_mass_balance_error, 'overall_mass_balance_error': overall_mass_balance_error})
        self.previous_total_storage = storage_end_of_step
        rel_error_threshold = 0.01; abs_error_threshold = 1e-4; check_value = abs(internal_mass_balance_error)
        input_check = abs(step_runoff_gen) if abs(step_runoff_gen) > 1e-9 else abs(storage_change)
        if check_value > abs_error_threshold:
            is_relative_error_significant = (input_check > 1e-9 and (check_value / input_check) > rel_error_threshold)
            if check_value > abs_error_threshold * 5 or is_relative_error_significant : logging.warning(f"Internal Mass Balance Error at t={current_time_min:.2f} min: {internal_mass_balance_error: .6f} m³ (RunoffGen={step_runoff_gen:.4f}, RunonInf={step_runon_infiltration:.4f}, Disch={step_discharge_vol:.4f}, dS={storage_change:.4f})")

    def calculate_eia_by_arrival_at_inlet(self) -> Optional[pd.DataFrame]:
            """
            Recalculates timeseries for Total, Impervious, and Pervious connected areas
            based on the time runoff from each source basin physically arrives at the
            final inlet basin. This is a post-simulation analysis.
            CORRECTED: Now correctly handles no-spill scenarios, resulting in zero
            area if no runoff is ever generated.
            """
            logging.info("Recalculating all connected areas based on time of arrival at the final inlet basin...")
            if not self.time_series or not self.original_basin_ids_in_sim:
                logging.error("Cannot calculate area by arrival: initial simulation data is missing.")
                return None

            # This dictionary will store: {original_basin_id: arrival_time_at_inlet_min}
            arrival_times = {}

            # --- Step 1: For every original basin, find when its own runoff hydrograph
            # starts ---
            first_runoff_times = {}
            for basin_id in self.original_basin_ids_in_sim:
                runoff_data = self.runoff_processor.basin_runoff_data.get(basin_id)
                if not runoff_data: continue

                flow = np.array(runoff_data.get('direct_flow_m3_s', []))
                times = np.array(runoff_data.get('convolution_timeline_min', []))

                # Only consider basins that actually produced a meaningful amount of
                # runoff
                if flow.size > 0 and times.size > 0 and np.sum(flow) > 1e-9:
                    first_flow_idx_list = np.where(flow > 1e-9)[0]
                    if first_flow_idx_list.size > 0:
                        first_runoff_times[basin_id] = times[first_flow_idx_list[0]]

            # If no basins produced any runoff at all, we can stop and return zeros.
            if not first_runoff_times:
                logging.warning("No runoff was generated in any basin. Connected area will be zero for the entire simulation.")
                timeline = np.array([d['time'] for d in self.time_series])
                return pd.DataFrame({
                    'time': timeline,
                    'total_area_by_arrival': np.zeros_like(timeline),
                    'impervious_area_by_arrival': np.zeros_like(timeline),
                    'pervious_area_by_arrival': np.zeros_like(timeline)
                })

            # --- Step 2: For each basin that DID produce runoff, find its total travel
            # time to the inlet ---
            for basin_id, start_runoff_time in first_runoff_times.items():
                total_lag_min = 0.0
                node = self.forest.all_nodes.get(basin_id)
                
                # Trace path from this basin up to the root of its tree
                path_node = node
                while path_node and path_node.parent:
                    parent_id = path_node.parent.basin_id
                    
                    # Find the merge event where this path segment was the upstream
                    # component
                    for merge in self.merged_states_history:
                        up_id = merge.get('upstream_id')
                        down_id = merge.get('downstream_id')
                        
                        if not (up_id and down_id): continue

                        # Find the original components of the upstream and downstream
                        # parts of the merge
                        up_data = self.runoff_processor.basin_runoff_data.get(up_id, {})
                        down_data = self.runoff_processor.basin_runoff_data.get(down_id, {})
                        up_components = up_data.get('merged_from', [up_id])
                        down_components = down_data.get('merged_from', [down_id])

                        # If our current node is in the upstream part and its parent is
                        # in the downstream part...
                        if path_node.basin_id in up_components and parent_id in down_components:
                            # ...add the lag time for this step and break from the merge
                            # history loop.
                            total_lag_min += merge.get('lag_calculated_sec', 0.0) / 60.0
                            break
                            
                    path_node = path_node.parent

                # The final arrival time is when its own runoff starts, plus all travel
                # delays.
                arrival_times[basin_id] = start_runoff_time + total_lag_min

            # --- Step 3: Build the new area timeseries using the calculated arrival
            # times ---
            timeline = np.array([d['time'] for d in self.time_series])
            total_area_series = np.zeros_like(timeline, dtype=float)
            imperv_area_series = np.zeros_like(timeline, dtype=float)
            perv_area_series = np.zeros_like(timeline, dtype=float)

            for i, t in enumerate(timeline):
                current_total = 0.0
                current_imperv = 0.0
                current_perv = 0.0
                for basin_id, arrival_time in arrival_times.items():
                    if t >= arrival_time:
                        node = self.forest.all_nodes.get(basin_id)
                        if node:
                            current_total += node.area
                            current_imperv += node.impervious_area
                            current_perv += node.pervious_area
                total_area_series[i] = current_total
                imperv_area_series[i] = current_imperv
                perv_area_series[i] = current_perv

            return pd.DataFrame({
                'time': timeline, 
                'total_area_by_arrival': total_area_series,
                'impervious_area_by_arrival': imperv_area_series,
                'pervious_area_by_arrival': perv_area_series
            })

    # _merge_basins_lag_route method (also records the pending run-on volume for each merge)
    def _merge_basins_lag_route(self, up_id: str, down_id: str) -> bool:
        """Merges basins and UPDATES the dynamically tracked connected area."""
        up_state = self.active_states.get(up_id); down_state = self.active_states.get(down_id)
        if not up_state or not down_state: logging.error(f"Merge Error: State not found for {up_id} or {down_id}."); return False
        if up_state.is_merged or down_state.is_merged: logging.warning(f"Merge attempt skipped: {up_id} or {down_id} already merged."); return False
        
        # NEW: Update connected areas BEFORE the merge calculation
        # The upstream state's original basins are now connected.
        self._add_newly_connected_areas(up_state.merged_from)
        # The downstream state's original basins are also now officially connected.
        self._add_newly_connected_areas(down_state.merged_from)
        # End New

        current_time_sec = self.time_series[-1]['time'] * 60.0 if self.time_series else 0.0
        current_merge_time_min = current_time_sec / 60.0

        upstream_volume_to_transfer = min(up_state.current_volume, up_state.max_volume)
        initial_merged_volume = upstream_volume_to_transfer + down_state.current_volume

        rainfall_csv_path = getattr(self.runoff_processor, 'rainfall_csv', None)
        rain_unit_val = getattr(self.runoff_processor, 'rain_unit', None)
        p2_val = getattr(self.runoff_processor, 'P2_in', None)
        if rainfall_csv_path is None or rain_unit_val is None or p2_val is None:
            logging.error("Merge Error: Missing params in RunoffProcessor."); self.halt_reason = "Missing params for merge."; return False

        merged_runoff_data = self.runoff_processor.calculate_merged_basin_runoff(
            upstream_id=up_id, downstream_id=down_id, current_time_sec=current_time_sec, timestep_min=self.timestep_min,
            sim_time_min=self.simulation_time_min, rainfall_csv=rainfall_csv_path, rain_unit=rain_unit_val, P2_in=p2_val
        )

        if merged_runoff_data is None: self.halt_reason = f"Merge hydrograph calc failed for {up_id}->{down_id}"; logging.critical(f"CRITICAL HALT: {self.halt_reason}"); return False

        merged_id = merged_runoff_data['basin_id']

        new_area = up_state.current_area + down_state.current_area
        new_max_volume = up_state.max_volume + down_state.max_volume
        new_imperv_area = up_state.impervious_area + down_state.impervious_area
        new_perv_area = up_state.pervious_area + down_state.pervious_area
        new_merged_state = BasinState(
            basin_id=merged_id, current_area=new_area, impervious_area=new_imperv_area, pervious_area=new_perv_area,
            max_volume=new_max_volume, current_volume=initial_merged_volume, effective_depth=new_max_volume / new_area if new_area > 1e-6 else 0,
            alpha=1.0, parent_id=down_state.parent_id, merged_from=sorted(list(set(down_state.merged_from + up_state.merged_from))),
            is_merged=False, spilled_volume=0.0, runoff_state=merged_runoff_data, lag_time=merged_runoff_data.get('lag_used_sec', 0.0),
            response_start_time=current_time_sec, infiltrated_volume=(up_state.infiltrated_volume + down_state.infiltrated_volume), initial_transit_water=0.0
        )

        up_state.is_merged = True; down_state.is_merged = True
        self.active_states[merged_id] = new_merged_state
        if up_id in self.active_states: del self.active_states[up_id]
        if down_id in self.active_states: del self.active_states[down_id]

        children_to_update = [cid for cid, cstate in self.active_states.items() if cstate.parent_id == up_id or cstate.parent_id == down_id]
        for child_id in children_to_update:
                if child_id in self.active_states: self.active_states[child_id].parent_id = merged_id

        calculated_lag_sec = merged_runoff_data.get('lag_used_sec', 0.0)
        self.merged_states_history.append({'time_min': current_merge_time_min, 'upstream_id': up_id, 'downstream_id': down_id, 'merged_id': merged_id, 'lag_calculated_sec': calculated_lag_sec})

        total_runon_for_this_merge = merged_runoff_data.get('total_runon_infiltration_calculated', 0.0)
        if isinstance(total_runon_for_this_merge, (int, float)) and total_runon_for_this_merge > 1e-9:
            arrival_time_min = current_merge_time_min + (calculated_lag_sec / 60.0)
            self.pending_runon[merged_id] = (total_runon_for_this_merge, self.RUNON_DISTRIBUTION_STEPS, arrival_time_min)
            logging.debug(f"  Stored {total_runon_for_this_merge:.6f} m³ run-on from {merged_id} to be distributed over {self.RUNON_DISTRIBUTION_STEPS} steps starting after t={arrival_time_min:.2f} min.")
        return True

    # run_simulation method
    def run_simulation(self):
        """Runs the fill-and-spill simulation. Uses TOTAL run-on stored from merge events,
            logged immediately in the merge timestep for mass balance.
            (No diagnostic print/debug lines)."""
        if not self.initialization_successful: logging.error("Sim init failed."); return self.time_series, self.water_balance, self.tree_snapshots
        logging.info("--- Starting Integrated Simulation Run ---")
        if (self.simulation_time_min is None or self.timestep_min is None or self.timestep_min <= 0 or self.dt_sec is None or self.dt_sec <= 0):
            logging.error("Invalid sim time/step."); self.halted_prematurely = True; self.halt_reason = "Invalid time/step."; return self.time_series, self.water_balance, self.tree_snapshots
        timeline_mins = np.arange(0, self.simulation_time_min + self.timestep_min, self.timestep_min); timeline_mins = timeline_mins[timeline_mins <= self.simulation_time_min + 1e-9]; num_steps = len(timeline_mins) - 1
        if num_steps <= 0: logging.error("Zero steps."); self.halted_prematurely = True; self.halt_reason = "Zero steps."; return self.time_series, self.water_balance, self.tree_snapshots
        logging.info(f"Total simulation steps: {num_steps}")

        # Main loop
        for i in range(num_steps):
            current_time_min = timeline_mins[i]; next_time_min = timeline_mins[i+1]; current_time_sec = current_time_min * 60.0
            logging.info(f"--- Sim Step {i+1}/{num_steps} (Time: {current_time_min:.2f} -> {next_time_min:.2f} min) ---")

            # 1: Rainfall and initial infiltration
            step_rain_vol = 0.0; step_init_inf = 0.0
            for oid in self.original_basin_ids_in_sim:
                rd = self.runoff_processor.basin_runoff_data.get(oid);
                if not rd: continue
                times = np.array(rd.get('timeline_min', [])); rain_arr = np.array(rd.get('total_rainfall_vol_step', [])); inf_arr = np.array(rd.get('infiltration_vol_step', []))
                if times.size == 0 or rain_arr.size != times.size or inf_arr.size != times.size: continue
                si = int(np.searchsorted(times, current_time_min - 1e-9, side='left')); ei = int(np.searchsorted(times, next_time_min - 1e-9, side='left'))
                if si < ei and ei <= len(rain_arr): step_rain_vol += np.sum(rain_arr[si:ei]); step_init_inf += np.sum(inf_arr[si:ei])

            # 2: Runoff inflow to storage
            step_runoff_gen = 0.0
            active_ids_before_inflow = list(self.active_states.keys())
            for bid in active_ids_before_inflow:
                if bid not in self.active_states: continue
                state = self.active_states[bid]; vol_in = self.runoff_processor.get_volume_in_timespan(bid, current_time_min, next_time_min)
                state.current_volume += vol_in; step_runoff_gen += vol_in

            # 3: Merges at current time
            merged_ids_this_step = [] # Track NEWLY created merge IDs this step
            merge_occurred = False; check_for_merges = True; merge_iteration = 0; max_merge_iterations = len(self.active_states) + 5
            while check_for_merges and merge_iteration < max_merge_iterations:
                    merge_iteration += 1; found_merge_in_iteration = False
                    active_ids_before_iteration = list(self.active_states.keys())
                    for bid in active_ids_before_iteration:
                        if bid not in self.active_states: continue
                        state = self.active_states[bid]
                        if state.current_volume > state.max_volume + 1e-9 and state.parent_id:
                            parent_id = state.parent_id
                            if parent_id in self.active_states:
                                success = self._merge_basins_lag_route(bid, parent_id)
                                if success:
                                    merge_occurred = True; found_merge_in_iteration = True; merged_id_created = f"{parent_id}+{bid}"
                                    if merged_id_created in self.active_states:
                                        if merged_id_created not in merged_ids_this_step: merged_ids_this_step.append(merged_id_created)
                                    else: self.halted_prematurely = True; self.halt_reason = f"Merged state {merged_id_created} missing"; break
                                else: self.halted_prematurely = True; self.halt_reason = f"Merge failed {bid}->{parent_id}"; break
                    if self.halted_prematurely: break
                    if not found_merge_in_iteration: check_for_merges = False
            if merge_iteration >= max_merge_iterations: self.halted_prematurely = True; self.halt_reason = "Max merge iterations reached."
            if self.halted_prematurely: break
            if merge_occurred: self.tree_snapshots.append((current_time_sec, self._capture_tree_state()))

            # 4: Run-on infiltration calculation (Summing Stored TOTALS)
            step_runon_inf = 0.0
            if merged_ids_this_step: # Only process NEW merges created in this step
                for mid in merged_ids_this_step:
                    ms = self.active_states.get(mid)
                    if not ms or not ms.runoff_state:
                        logging.warning(f"State/runoff_state missing for {mid} when summing total runon.")
                        continue
                    # Get the TOTAL run-on calculated and stored during the merge
                    total_runon_for_this_merge = ms.runoff_state.get('total_runon_infiltration_calculated', 0.0) # Default to 0
                    if not isinstance(total_runon_for_this_merge, (int, float)):
                        logging.warning(f"Invalid type stored total runon for {mid}.")
                        total_runon_for_this_merge = 0.0
                    step_runon_inf += total_runon_for_this_merge # Add the total lump sum

            # 5: Discharge from outlet basins
            step_discharge = 0.0
            active_ids_after_merge = list(self.active_states.keys())
            for bid in active_ids_after_merge:
                    if bid not in self.active_states: continue
                    state = self.active_states[bid]; is_outlet_path = any(orig_inlet in state.merged_from for orig_inlet in self.selected_inlet_ids)
                    if is_outlet_path:
                        discharge_potential_vol = self.runoff_processor.get_volume_in_timespan(bid, current_time_min, next_time_min)
                        actual_discharge_vol = min(discharge_potential_vol, state.current_volume); actual_discharge_vol = max(0.0, actual_discharge_vol)
                        state.current_volume -= actual_discharge_vol; step_discharge += actual_discharge_vol

            # 6: Log state and mass balance
            self._log_state_and_balance(
                next_time_min, step_rain_vol, step_init_inf,
                step_runon_inf, # Log the total "lump sum" calculated this step
                step_runoff_gen, step_discharge
            )

            if self.halted_prematurely:
                logging.warning(f"--- Halting simulation loop at step {i+1} due to: {self.halt_reason} ---")
                break

        # Final Snapshot Logic
        if not self.halted_prematurely:
            final_sim_time_min = timeline_mins[-1]; final_sec = final_sim_time_min * 60.0
            if not self.tree_snapshots or abs(self.tree_snapshots[-1][0] - final_sec) > 1e-6:
                    self.tree_snapshots.append((final_sec, self._capture_tree_state()))

        logging.info(f"--- Simulation Run Finished. Halted Prematurely: {self.halted_prematurely} ---")
        return self.time_series, self.water_balance, self.tree_snapshots
# 6) Visualization Functions

def plot_basin_runoff_stages(basin_runoff_data: Optional[Dict],
                             simulation_time_min: float, # Required Parameter
                             rain_unit_pref: str = 'cm/hr'):
    """
    Generates a 5-panel plot visualizing rainfall-runoff stages for a single basin.

    Panels (top to bottom, sharing the x-axis in minutes):
        1. Rainfall rate with the maximum and the effective (area-weighted) infiltration rate.
        2. Simplified excess rainfall rate.
        3. Surface runoff hydrograph (m³/min), drawn on the convolution timeline.
        4. Available infiltration capacity rate.
        5. Infiltration volume per timestep.

    Args:
        basin_runoff_data: Result dictionary of one basin.
            Required keys: 'timeline_min', 'convolution_timeline_min',
            'rain_depth_m_step', 'direct_flow_m3_s', 'dt_sec', 'basin_area'.
            Optional keys: 'basin_id', 'simplified_excess_depth_m_step',
            'infiltration_deficit_vol_step', 'infiltration_vol_step' (each defaults
            to zeros), 'perv_area_used_for_infiltration' (falls back to
            'pervious_area', then 0), 'total_direct_runoff_vol',
            'total_infiltration_vol' (both default to 0.0).
        simulation_time_min: Upper limit of the shared x-axis.
        rain_unit_pref: 'cm/hr', 'mm/hr' or 'in/hr'; any other value plots rates in m/s.

    Returns:
        None. The figure is displayed with ``plt.show()``. Invalid input and plotting
        failures are logged (with a traceback) and the function returns without raising.

    Notes:
        The module-level global ``INFILTRATION_RATE_CM_S`` must be defined; it is read
        as the maximum infiltration rate on pervious area.
    """
    if not basin_runoff_data:
        logging.warning("No basin runoff data provided for plotting stages.")
        print("No basin runoff data provided for plotting stages.")
        return

    basin_id = basin_runoff_data.get('basin_id', 'N/A')
    logging.info(f"Generating formatted runoff stage plots for Basin ID: {basin_id}")

    # Define Font Sizes
    base_fontsize = 10
    fontsize_title = base_fontsize + 4
    fontsize_subtitle = base_fontsize + 2
    fontsize_label = base_fontsize + 1
    fontsize_legend = base_fontsize
    fontsize_tick = base_fontsize

    # Define Color Palette (Consistent with other plots)
    colors = {
        'Rainfall': '#a6cee3',
        'Excess Rainfall': '#e31a1c',
        'Runoff Hydrograph': '#2ca02c',
        'Available Capacity': '#fdbf6f',
        'Infiltration Volume': '#ff7f0e',
        'Max Infil Line': '#e31a1c',
        'Effective Max Line': '#ff7f0e',
    }
    # Define alphas separately for fills
    alphas = {
        'Rainfall': 0.6,
        'Excess Rainfall': 0.6,
        'Runoff Hydrograph': 0.6,
        'Available Capacity': 0.6,
        'Infiltration Volume': 0.5,
    }

    # Extract Data & Validate
    try:
        # Every series is coerced to float. An integer timeline used to make the
        # zero-filled fallbacks integer arrays, and the later rate division into an
        # integer `out=` buffer raised a casting error that aborted the whole plot.
        timeline_min = np.asarray(basin_runoff_data['timeline_min'], dtype=float)
        conv_timeline_min = np.asarray(basin_runoff_data['convolution_timeline_min'], dtype=float)
        rain_depth_m_step = np.asarray(basin_runoff_data['rain_depth_m_step'], dtype=float)
        direct_flow_m3_s = np.asarray(basin_runoff_data['direct_flow_m3_s'], dtype=float)
        dt_sec = basin_runoff_data['dt_sec']
        basin_area = basin_runoff_data['basin_area']

        def _optional_series(key: str) -> np.ndarray:
            """Return the series stored under *key* as floats, or zeros when it is absent."""
            values = basin_runoff_data.get(key)
            if values is None:
                return np.zeros_like(timeline_min)
            return np.asarray(values, dtype=float)

        def _aligned(series: np.ndarray, label: str) -> np.ndarray:
            """Replace a series whose length differs from the timeline by zeros (with a warning)."""
            if len(series) == len(timeline_min):
                return series
            logging.warning(f"{basin_id}: {label} length mismatch.")
            return np.zeros_like(timeline_min)

        simplified_excess_depth_m_step = _optional_series('simplified_excess_depth_m_step')
        deficit_volume_step = _optional_series('infiltration_deficit_vol_step')
        infiltration_vol_step = _optional_series('infiltration_vol_step')  # Actual infiltration

        perv_area_eff = basin_runoff_data.get(
            'perv_area_used_for_infiltration', basin_runoff_data.get('pervious_area', 0))
        total_runoff_volume = basin_runoff_data.get('total_direct_runoff_vol', 0.0)
        total_actual_infiltration = basin_runoff_data.get('total_infiltration_vol', 0.0)

        if 'INFILTRATION_RATE_CM_S' not in globals():
            raise NameError("Global INFILTRATION_RATE_CM_S not defined.")
        input_infiltration_rate_cm_s = globals()['INFILTRATION_RATE_CM_S']

        if dt_sec <= 0:
            raise ValueError("dt_sec must be positive.")
        if basin_area <= 1e-9:
            basin_area = 1e-9  # Avoid division by zero but proceed
        if len(timeline_min) == 0:
            raise ValueError("Input timeline empty.")

        # Ensure lengths match or handle gracefully
        rain_depth_m_step = _aligned(rain_depth_m_step, "Rain depth")
        simplified_excess_depth_m_step = _aligned(simplified_excess_depth_m_step, "Simplified excess")
        deficit_volume_step = _aligned(deficit_volume_step, "Deficit volume")
        infiltration_vol_step = _aligned(infiltration_vol_step, "Infiltration volume")
        if len(conv_timeline_min) != len(direct_flow_m3_s):
            logging.warning(f"{basin_id}: Convolution timeline/flow mismatch.")
            conv_timeline_min = np.array([])
            direct_flow_m3_s = np.array([])
    except KeyError as ke:
        logging.error(f"Plot Stages {basin_id}: Missing key {ke}")
        traceback.print_exc()
        return
    except Exception as e:
        logging.error(f"Plot Stages {basin_id}: Data extraction error: {e}")
        traceback.print_exc()
        return

    # Calculations for Plotting
    try:
        # dt_sec was validated as positive above, so plain division is safe.
        rain_rate_m_s = rain_depth_m_step / dt_sec
        simplified_excess_rate_m_s = simplified_excess_depth_m_step / dt_sec
        max_infil_rate_on_perv_m_s = input_infiltration_rate_cm_s * 0.01
        effective_max_infil_rate_m_s = (
            (max_infil_rate_on_perv_m_s * perv_area_eff / basin_area) if basin_area > 1e-9 else 0.0
        )
        area_time = basin_area * dt_sec
        if area_time != 0:
            deficit_rate_m_s = deficit_volume_step / area_time
        else:
            deficit_rate_m_s = np.zeros_like(deficit_volume_step)

        # Conversion factors from m/s to the requested display unit.
        unit_factors = {
            'cm/hr': 360000,               # 100cm/m * 3600s/hr
            'mm/hr': 3600000,              # 1000mm/m * 3600s/hr
            'in/hr': (1 / 0.0254) * 3600,
        }
        if rain_unit_pref in unit_factors:
            unit_factor, rate_unit_label = unit_factors[rain_unit_pref], rain_unit_pref
        else:
            unit_factor, rate_unit_label = 1.0, 'm/s'

        rain_rate_plot = rain_rate_m_s * unit_factor
        simplified_excess_rate_plot = simplified_excess_rate_m_s * unit_factor
        orig_infil_rate_plot = max_infil_rate_on_perv_m_s * unit_factor
        eff_infil_rate_plot = effective_max_infil_rate_m_s * unit_factor
        deficit_rate_plot = deficit_rate_m_s * unit_factor
        runoff_response_m3_min = direct_flow_m3_s * 60.0
    except Exception as e:
        logging.error(f"Plot Stages {basin_id}: Calculation error: {e}")
        traceback.print_exc()
        return

    legend_props = {'loc': 'upper right', 'fontsize': fontsize_legend, 'frameon': True,
                    'facecolor': 'white', 'edgecolor': 'black'}
    legend_frame_lw = 0.8
    fill_edge_color = 'black'
    fill_edge_lw = 0.5

    def _draw_filled_steps(ax, x, y, series_key, label, line_width=1.0):
        """Draw y(x) as a filled step curve on *ax* and return the matching legend patch."""
        color = colors[series_key]
        alpha = alphas[series_key]
        ax.plot(x, y, color=color, drawstyle='steps-post', lw=line_width)
        ax.fill_between(x, 0, y, color=color, alpha=alpha, step='post',
                        edgecolor=fill_edge_color, lw=fill_edge_lw)
        return mpatches.Patch(facecolor=color, alpha=alpha, edgecolor=fill_edge_color, label=label)

    def _decorate_panel(ax, handles, title, ylabel, tick_format):
        """Apply the title, y-label, legend, tick and formatter settings shared by all panels."""
        ax.set_ylabel(ylabel, fontsize=fontsize_label)
        ax.set_title(title, fontsize=fontsize_subtitle)
        if handles:  # an empty handle list would only add an empty legend box
            ax.legend(handles=handles, **legend_props).get_frame().set_linewidth(legend_frame_lw)
        ax.grid(False)
        ax.tick_params(axis='both', labelsize=fontsize_tick)
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter(tick_format))

    def _upper_limit(values, floor):
        """Return 110% of the series peak, or *floor* when the series is empty or non-positive."""
        if len(values) == 0:
            return floor
        peak = np.max(values)
        return peak * 1.1 if peak > 0 else floor

    # Create Figure & Plot
    fig = None
    try:
        fig, axes = plt.subplots(5, 1, figsize=(10, 15), sharex=True)
        fig.suptitle(f'Hydrologic Processes for Basin: {basin_id}', fontsize=fontsize_title, y=0.99)
        ax1, ax2, ax3, ax4, ax5 = axes
        rate_ylabel = f'Rate ({rate_unit_label})'

        # Plot 1 (Rainfall & Capacity) - Use Light Blue for Rainfall
        handles1 = [_draw_filled_steps(ax1, timeline_min, rain_rate_plot, 'Rainfall', 'Rainfall Rate')]
        handles1.append(ax1.axhline(
            orig_infil_rate_plot, color=colors['Max Infil Line'], linestyle='--',
            label=f'Max Infil. Rate ({orig_infil_rate_plot:.2f} {rate_unit_label})'))
        handles1.append(ax1.axhline(
            eff_infil_rate_plot, color=colors['Effective Max Line'], linestyle=':', linewidth=2,
            label=f'Effective Max Rate ({eff_infil_rate_plot:.2f} {rate_unit_label})'))
        _decorate_panel(ax1, handles1, '1. Rainfall & Infiltration Capacity', rate_ylabel, '%.2f')

        # Plot 2 (Excess Rate) - Use Red
        handles2 = [_draw_filled_steps(ax2, timeline_min, simplified_excess_rate_plot,
                                       'Excess Rainfall', 'Excess Rate (Simplified)')]
        _decorate_panel(ax2, handles2, '2. Excess Rainfall Rate', rate_ylabel, '%.2f')

        # Plot 3 (Runoff Response) - Use Green
        handles3 = []
        if len(conv_timeline_min) > 0 and len(runoff_response_m3_min) > 0:
            handles3.append(_draw_filled_steps(
                ax3, conv_timeline_min, runoff_response_m3_min, 'Runoff Hydrograph',
                f'Surface Runoff (Vol: {total_runoff_volume:.3f} m³)', line_width=1.5))
        else:
            ax3.text(0.5, 0.5, "No runoff data", ha='center', va='center',
                     transform=ax3.transAxes, fontsize=fontsize_tick)
        _decorate_panel(ax3, handles3, '3. Surface Runoff Hydrograph', 'Flow Rate (m³/min)', '%.3f')

        # Plot 4 (Available Capacity Rate) - Use Yellow/Light Orange
        handles4 = [_draw_filled_steps(ax4, timeline_min, deficit_rate_plot,
                                       'Available Capacity', 'Available Capacity Rate')]
        handles4.append(ax4.axhline(
            eff_infil_rate_plot, color=colors['Effective Max Line'], linestyle=':', linewidth=2,
            label=f'Effective Max Rate ({eff_infil_rate_plot:.2f} {rate_unit_label})'))
        _decorate_panel(ax4, handles4, '4. Available Infiltration Capacity Rate', rate_ylabel, '%.2f')

        # Plot 5 (Actual Infiltration Volume) - Use Orange
        handles5 = [_draw_filled_steps(
            ax5, timeline_min, infiltration_vol_step, 'Infiltration Volume',
            f'Infiltration Volume (Total: {total_actual_infiltration:.3f} m³)')]
        ax5.set_xlabel('Time (min)', fontsize=fontsize_label)
        _decorate_panel(ax5, handles5, '5. Actual Infiltration Volume per Timestep',
                        'Volume (m³ / step)', '%.4f')

        # Final Formatting
        axes[-1].set_xlim(0, simulation_time_min)  # sharex propagates the limit to all panels

        # Synchronize Rate Axes (1, 2, 4) so the three rate panels are directly comparable
        max_y_rate = max(0.01, ax1.get_ylim()[1], ax2.get_ylim()[1], ax4.get_ylim()[1])
        for rate_ax in (ax1, ax2, ax4):
            rate_ax.set_ylim(0, max_y_rate)

        # Auto-scale flow axis (3) and volume axis (5)
        ax3.set_ylim(0, _upper_limit(runoff_response_m3_min, 0.001))
        ax5.set_ylim(0, _upper_limit(infiltration_vol_step, 1e-5))

        plt.tight_layout(rect=[0, 0.03, 1, 0.96])
        plt.show()

    except Exception as e:
        logging.error(f"Failed during plotting setup/execution for {basin_id}: {e}")
        traceback.print_exc()
        # Release the half-built figure so it does not linger in pyplot's registry.
        if fig is not None and plt.fignum_exists(fig.number):
            plt.close(fig)

def _style_legend_and_place_text(ax, fig, text_str, loc='upper right'):
    leg = ax.legend(loc=loc, facecolor='white', edgecolor='black', fontsize=9, frameon=True)
    if leg:
        leg.get_frame().set_linewidth(0.8); leg.get_frame().set_edgecolor('black')
        try:
            fig.canvas.draw_idle()
            bbox = leg.get_window_extent(renderer=fig.canvas.get_renderer())
            bbox_data = bbox.transformed(ax.transAxes.inverted())
            x_text = bbox_data.x0; y_text = bbox_data.y0 - 0.02
            ha = 'left'; va = 'top'
        except Exception:
            logging.debug("Could not get legend extent, placing text at default position.")
            x_text, y_text, ha, va = 0.98, 0.02, 'right', 'bottom'
        ax.text(x_text, y_text, text_str, transform=ax.transAxes, ha=ha, va=va, fontsize=9, bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=0.2))
    else:
        ax.text(0.98, 0.02, text_str, transform=ax.transAxes, ha='right', va='bottom', fontsize=9, bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=0.2))

def build_hierarchical_layout(current_state: Dict[str, Dict], original_inlet_ids: List[str]) -> Dict[str, Tuple[float, float]]:
    """
    Computes node positions for a hierarchical tree-like layout.

    Each node is linked to the node named by its 'parent_id' (when that parent is part
    of the same snapshot). Roots are placed on the top row (y = 0) and every further
    level is drawn VERTICAL_SPACING lower. Horizontally, each subtree reserves a block
    wide enough for all of its leaves and a parent is centred over its children.

    Args:
        current_state: A dictionary where keys are node IDs and values are dictionaries
                        containing at least a 'parent_id' key.
        original_inlet_ids: Kept for interface compatibility. It does not influence the
                        layout: an inlet without an in-snapshot parent is already a root
                        through its 'parent_id'.
    Returns:
        A dictionary mapping node IDs to (x, y) coordinates. Empty if `current_state`
        is empty.

    Notes:
        Reads the module-level globals HORIZONTAL_SPACING (leaf width; half of it
        separates sibling subtrees, all of it separates independent trees) and
        VERTICAL_SPACING (distance between levels).
    """
    active_nodes = list(current_state.keys())
    logging.debug(f"[Layout Build] Attempting layout for {len(active_nodes)} active nodes: {active_nodes}")
    if not active_nodes:
        logging.warning("[Layout Build] No active nodes provided to build_hierarchical_layout. Returning empty positions.")
        return {}

    # Set for O(1) membership tests; the list is kept because its order is used below.
    active_set = set(active_nodes)
    parent_map = {nid: current_state[nid].get('parent_id') for nid in active_nodes}
    logging.debug(f"[Layout Build] Parent map generated: {parent_map}")

    # A root is a node whose parent is None or not part of this snapshot; every other
    # node with a truthy parent id becomes a child of that parent.
    adj: Dict[str, List[str]] = {nid: [] for nid in active_nodes}
    roots: List[str] = []
    for nid in active_nodes:
        pid = parent_map[nid]
        if pid is None or pid not in active_set:
            roots.append(nid)
        elif pid:
            adj[pid].append(nid)
    logging.debug(f"[Layout Build] Initial roots based on parent_map: {roots}")

    if not roots:
        # Every node has its parent inside the snapshot (i.e. only cycles): start anywhere.
        roots = [active_nodes[0]]
        logging.warning(f"[Layout Build] No distinct roots found. Using fallback root: {roots[0]}")

    sorted_roots_list = sorted(set(roots))  # sorted for a deterministic layout
    logging.debug(f"[Layout Build] Final sorted roots for DFS: {sorted_roots_list}")

    node_size = HORIZONTAL_SPACING
    sibling_gap = HORIZONTAL_SPACING / 2.0
    subtree_widths: Dict[str, float] = {}
    depth: Dict[str, int] = {}  # also serves as the "visited" marker of the first pass

    def measure_subtree(node_id: str, level: int) -> float:
        """First pass (post-order): record the node's depth and return its subtree width."""
        depth[node_id] = level
        children = sorted(adj[node_id])
        if not children:
            subtree_widths[node_id] = node_size
            return node_size

        children_width = 0.0
        measured = 0
        for child_id in children:
            if child_id in depth:  # Avoid cycles
                logging.warning(f"[Layout Build] Cycle detected or re-visiting node {child_id} from {node_id} in post_order_dfs. Skipping.")
                continue
            children_width += measure_subtree(child_id, level + 1)
            measured += 1
        if measured > 1:
            children_width += (measured - 1) * sibling_gap  # spacing between sibling subtrees

        width = max(node_size, children_width)
        subtree_widths[node_id] = width
        return width

    total_graph_width_estimate = 0.0
    for index, root_id in enumerate(sorted_roots_list):
        if root_id in depth:
            continue
        total_graph_width_estimate += measure_subtree(root_id, 0)
        if index > 0:  # spacing between separate trees
            total_graph_width_estimate += HORIZONTAL_SPACING

    # Components unreachable from any root (cycles) are measured as extra trees.
    for node_id in active_nodes:
        if node_id not in depth:
            logging.warning(f"[Layout Build] Node {node_id} was not visited in post_order_dfs from main roots. Processing as a new component.")
            total_graph_width_estimate += measure_subtree(node_id, 0) + HORIZONTAL_SPACING

    pos: Dict[str, Tuple[float, float]] = {}

    def place_subtree(node_id: str, x_center: float) -> None:
        """Second pass (pre-order): place the node, then spread its children beneath it."""
        if node_id in pos:
            return
        # Roots sit at y = 0; deeper levels get increasingly negative y (top-down drawing).
        pos[node_id] = (x_center, -depth[node_id] * VERTICAL_SPACING)

        children = [child_id for child_id in sorted(adj[node_id]) if child_id not in pos]
        if not children:
            return

        span = sum(subtree_widths[child_id] for child_id in children)
        if len(children) > 1:
            span += (len(children) - 1) * sibling_gap

        left_edge = x_center - (span / 2.0)
        for child_id in children:
            child_width = subtree_widths[child_id]
            place_subtree(child_id, left_edge + (child_width / 2.0))
            left_edge += child_width + sibling_gap  # start of the next child's block

    # Centre the whole forest around x = 0.
    current_root_x_offset = -total_graph_width_estimate / 2.0
    for root_id in sorted_roots_list:
        if root_id not in pos:
            root_width = subtree_widths.get(root_id, node_size)
            place_subtree(root_id, current_root_x_offset + (root_width / 2.0))
            current_root_x_offset += root_width + HORIZONTAL_SPACING

    # Place the components that have no root of their own, to the right of the trees.
    for node_id in active_nodes:
        if node_id not in pos:
            logging.warning(f"[Layout Build] Node {node_id} was not visited in pre_order_dfs. Placing.")
            comp_width = subtree_widths.get(node_id, node_size)
            place_subtree(node_id, current_root_x_offset + (comp_width / 2.0))
            current_root_x_offset += comp_width + HORIZONTAL_SPACING

    if len(pos) != len(active_nodes):
        logging.warning(f"[Layout Build] Generated positions for {len(pos)} nodes, but expected {len(active_nodes)}. Missing: {active_set - set(pos.keys())}")

    return pos

def plot_drainage_snapshot_with_ids(ax: plt.Axes,
                                     forest: 'DrainageForest', # Unused in current drawing, but kept for signature
                                     snapshot: Tuple[float, Dict],
                                     inlet_id: str, # Main inlet ID for titling
                                     time_series: List[Dict], # Unused in current drawing
                                     global_cap_min: float,
                                     global_cap_max: float,
                                     original_inlet_ids: List[str]): # All selected inlets for highlighting paths
    """
    Draw one drainage-network snapshot on *ax* as a hierarchical node/arrow diagram.

    Each basin is drawn as a circle whose blue shade encodes its fill percentage,
    surrounded by a ring split into impervious (red) and pervious (green) shares.
    Arrows point from a parent basin to its children, and a colorbar for the fill
    percentage is appended to the right of the axes.

    Args:
        ax: Axes to draw on.
        forest: Not used by the drawing code.
        snapshot: ``(time_sec, state)`` where ``state`` maps a node id to a dict. The
            keys read here are 'parent_id', 'merged_from', 'max_volume',
            'current_volume', 'impervious_area', 'pervious_area' and 'area'.
        inlet_id: Inlet id shown in the title.
        time_series: Not used by the drawing code.
        global_cap_min: Lower capacity bound of the log-scaled node radius.
        global_cap_max: Upper capacity bound of the log-scaled node radius.
        original_inlet_ids: Selected inlets; nodes whose 'merged_from' lineage contains
            one of them are drawn with the fixed inlet radius and tagged "(INLET PATH)".
    """
    RADIUS_MIN = 0.25
    RADIUS_MAX = 0.50
    INLET_NODE_RADIUS = 0.50  # fixed size for nodes on a selected inlet path
    RING_THICKNESS_FACTOR = 0.2
    WHITE_GAP_FACTOR = 0.05   # gap between the fill circle and the area ring

    fill_cmap = cm.Blues
    fill_norm = Normalize(vmin=0.0, vmax=100.0)  # fill percentage 0-100

    time_sec, state = snapshot
    elapsed_min = time_sec / 60.0

    logging.debug(f"[Snapshot Plot {inlet_id} t={elapsed_min:.2f}] Received state with keys: {state.keys() if state else 'None'}")
    if not state:
        ax.text(0.5, 0.5, f"No active basins at t={elapsed_min:.2f} min", ha='center', va='center', transform=ax.transAxes)
        ax.set_title(f"Drainage Network State - Inlet {inlet_id} (t={elapsed_min:.2f} min) - No State Data")
        ax.axis('off')
        return

    def touches_selected_inlet(basin_key: str, basin_info: Dict) -> bool:
        """Return True when the basin's merge lineage contains one of the selected inlets."""
        if not basin_info:
            return False
        lineage = basin_info.get('merged_from', [basin_key])  # an unmerged basin is its own lineage
        for selected in original_inlet_ids:
            if str(selected) in lineage:
                return True
        return False

    def radius_for_capacity(capacity: float) -> float:
        """Map a storage capacity onto [RADIUS_MIN, RADIUS_MAX] using a log10 scale."""
        if capacity <= 0:
            return RADIUS_MIN
        # Floor both bounds so log10 never sees zero.
        lower_bound = max(global_cap_min, 1e-9)
        upper_bound = max(global_cap_max, 1e-9)
        if lower_bound >= upper_bound:
            # Degenerate range: every node gets the mid-size radius.
            return (RADIUS_MIN + RADIUS_MAX) * 0.5
        try:
            log_capacity = math.log10(max(capacity, 1e-9))
            log_lower = math.log10(lower_bound)
            log_upper = math.log10(upper_bound)
            log_span = log_upper - log_lower
            if abs(log_span) < 1e-9:
                share = 0.5
            else:
                share = (log_capacity - log_lower) / log_span
            share = max(0.0, min(1.0, share))  # clamp to the valid range
            return RADIUS_MIN + (RADIUS_MAX - RADIUS_MIN) * share
        except ValueError:
            logging.warning(f"ValueError in capacity_to_radius for capacity={capacity}. Using RADIUS_MIN.")
            return RADIUS_MIN

    # Layout calculation (relies on the HORIZONTAL_SPACING / VERTICAL_SPACING globals).
    node_positions = None
    try:
        node_positions = build_hierarchical_layout(state, original_inlet_ids)
    except NameError as ne:
        detail = str(ne)
        if 'build_hierarchical_layout' in detail:
            logging.error("[Snapshot Plot] LAYOUT FAILED: 'build_hierarchical_layout' function not found!")
        elif 'HORIZONTAL_SPACING' in detail or 'VERTICAL_SPACING' in detail:
            logging.error("[Snapshot Plot] LAYOUT FAILED: HORIZONTAL_SPACING or VERTICAL_SPACING global variables not defined.")
        else:
            logging.error(f"[Snapshot Plot] Layout failed due to NameError: {ne}")
        node_positions = None
    except Exception as layout_err:
        logging.error(f"[Snapshot Plot] Layout generation failed for t={elapsed_min:.2f} min: {layout_err}", exc_info=True)
        node_positions = None

    if not node_positions:  # None or empty
        ax.text(0.5, 0.5, "Layout Error or No Nodes", ha='center', va='center', transform=ax.transAxes)
        ax.set_title(f"Drainage Network State - Inlet {inlet_id} (t={elapsed_min:.2f} min) - LAYOUT FAILED")
        ax.axis('off')
        return

    # Edges: one arrow from each parent position to the child position.
    for basin_key, basin_info in state.items():
        parent_key = basin_info.get('parent_id')
        if not (parent_key and parent_key in node_positions and basin_key in node_positions):
            continue
        ax.add_patch(FancyArrowPatch(node_positions[parent_key], node_positions[basin_key],
                                     arrowstyle='-|>', mutation_scale=15, color='gray',
                                     lw=0.8, alpha=0.7, zorder=1))

    # Nodes
    id_label_box = dict(boxstyle="round,pad=0.15", fc="ivory", ec="black", lw=0.5, alpha=0.85)
    stats_label_box = dict(boxstyle="round,pad=0.2", fc="white", ec="grey", lw=0.5, alpha=0.75)
    max_id_display_len = 15

    for basin_key, basin_info in state.items():
        if basin_key not in node_positions:
            logging.warning(f"[Snapshot Plot] Node {basin_key} has no position data. Skipping drawing.")
            continue

        cx, cy = node_positions[basin_key]
        centre = (cx, cy)
        on_inlet_path = touches_selected_inlet(basin_key, basin_info)

        capacity = basin_info.get('max_volume', 0.0)
        stored_volume = basin_info.get('current_volume', 0.0)

        # Inlet-path nodes use a fixed radius; all others scale with capacity.
        if on_inlet_path:
            outer_radius = INLET_NODE_RADIUS
        else:
            outer_radius = radius_for_capacity(capacity)

        ring_width = outer_radius * RING_THICKNESS_FACTOR
        ring_gap = outer_radius * WHITE_GAP_FACTOR
        inner_radius = max(0.01, outer_radius - ring_width - ring_gap)

        # Impervious/pervious ring: the impervious share starts at 0 degrees.
        impervious_area = basin_info.get('impervious_area', 0.0)
        pervious_area = basin_info.get('pervious_area', 0.0)
        surface_total = impervious_area + pervious_area
        impervious_share = impervious_area / surface_total if surface_total > 1e-9 else 0.0
        split_angle_deg = 360.0 * impervious_share

        ring_segments = (
            (0, split_angle_deg, 'red', 'darkred'),          # impervious
            (split_angle_deg, 360.0, 'green', 'darkgreen'),  # pervious
        )
        for start_deg, end_deg, face, edge in ring_segments:
            ax.add_patch(Wedge(centre, outer_radius, start_deg, end_deg,
                               width=ring_width, facecolor=face, edgecolor=edge, lw=0.5, zorder=3))

        # Central fill circle, shaded by fill percentage.
        fill_pct = (stored_volume / capacity) * 100.0 if capacity > 1e-9 else 0.0
        fill_pct = max(0.0, min(fill_pct, 100.0))  # clamp
        ax.add_patch(Circle(centre, radius=inner_radius,
                            facecolor=fill_cmap(fill_norm(fill_pct)),
                            edgecolor='darkgray', linewidth=0.7, zorder=4))

        # Basin ID label (long ids keep only their tail).
        if len(basin_key) > max_id_display_len:
            id_text = f"ID: ..{basin_key[-(max_id_display_len-3):]}"
        else:
            id_text = f"ID: {basin_key}"
        ax.text(cx, cy, id_text, ha='center', va='center', fontsize=7, fontweight='normal',
                color='black', bbox=id_label_box, zorder=6)

        # Stats label below the node.
        basin_area_total = basin_info.get('area', 0.0)
        stats_text = (f"A:{basin_area_total:.1f} Vmx:{capacity:.2f} Vol:{stored_volume:.2f}\n"
                      f"Imp:{impervious_share*100:.0f}% Fill:{fill_pct:.1f}%")
        if on_inlet_path:
            stats_text += "\n(INLET PATH)"
        ax.text(cx, cy - outer_radius - 0.15, stats_text, ha='center', va='top', fontsize=6,
                color='black', bbox=stats_label_box, zorder=5)

    # Legend, title, axis limits, colorbar
    legend_entries = [
        mpatches.Patch(color='red', label='Impervious Area (Ring %)'),
        mpatches.Patch(color='green', label='Pervious Area (Ring %)'),
        mpatches.Patch(color=fill_cmap(0.5), label='Fill % (Center Circle)', ec='darkgray'),
        mpatches.Patch(facecolor='ivory', edgecolor='black', label='Basin ID Label'),
    ]
    legend = ax.legend(handles=legend_entries, loc='upper left', fontsize=8, bbox_to_anchor=(0.01, 0.99),
                       frameon=True, facecolor='white', edgecolor='black')
    legend.set_zorder(10)

    ax.set_title(f"Drainage Network State - Inlet {inlet_id} (t={elapsed_min:.2f} min)", fontsize=14)
    ax.axis('off')  # no axis lines or ticks

    # Fit the view to the node positions, padded by 75% of the layout spacing when
    # the spacing globals exist (1.0 otherwise).
    xs = [point[0] for point in node_positions.values()]
    ys = [point[1] for point in node_positions.values()]
    if xs and ys:
        module_globals = globals()
        pad_x = module_globals['HORIZONTAL_SPACING'] * 0.75 if 'HORIZONTAL_SPACING' in module_globals else 1.0
        pad_y = module_globals['VERTICAL_SPACING'] * 0.75 if 'VERTICAL_SPACING' in module_globals else 1.0
        ax.set_xlim(min(xs) - pad_x, max(xs) + pad_x)
        ax.set_ylim(min(ys) - pad_y, max(ys) + pad_y)

    ax.set_aspect('equal', adjustable='box')  # keep circles round

    # Colorbar for the fill percentage (patches are not mappable, so use a bare ScalarMappable).
    fill_mappable = cm.ScalarMappable(norm=fill_norm, cmap=fill_cmap)
    fill_mappable.set_array([])
    cbar_axes = make_axes_locatable(ax).append_axes("right", size="3%", pad=0.1)
    fill_colorbar = plt.colorbar(fill_mappable, cax=cbar_axes)
    fill_colorbar.set_label("Fill Percentage (%)", fontsize=10)
    fill_colorbar.ax.tick_params(labelsize=8)

def plot_simulation_results(time_series: List[dict],
                            outlet_runoff_data: Optional[Dict],
                            primary_inlet_runoff_data: Optional[Dict],
                            rain_unit: str,
                            simulation_time_min: float,
                            selected_inlets: List[str],
                            recalculated_eia_df: Optional[pd.DataFrame] = None,
                            export_discharge_csv: bool = True,
                            export_interval_sec: int = 10,
                           ):
    """
    Draw a three-panel figure (connected area, rainfall input, outlet
    discharge) and optionally export the discharge series to a CSV file.

    Args:
        time_series: Per-step records. Each must carry 'time' (plotted on an
            axis labelled minutes); 'inlet_volume_rate' and the
            'inlet_*_area' keys are read with a default of 0.
        outlet_runoff_data: Accepted but not read anywhere in this function.
        primary_inlet_runoff_data: Optional mapping supplying 'timeline_min',
            'rain_depth_m_step' and 'dt_sec' for the rainfall panel.
        rain_unit: 'cm/hr', 'mm/hr' or 'in/hr'; any other value plots the
            unconverted rate and labels the axis "m/s".
        simulation_time_min: Upper x-axis limit and export horizon.
        selected_inlets: The first entry names the exported CSV file.
        recalculated_eia_df: When given and non-empty, its 'time' and
            '*_by_arrival' columns replace the per-record areas in panel 1.
        export_discharge_csv: Write 'outlet_discharge_<inlet>.csv' when True.
        export_interval_sec: Sampling interval of the exported series.

    Returns:
        None. The figure is shown with plt.show(); export failures are logged
        rather than raised.
    """
    if not time_series:
        logging.warning("Time series data is empty. Cannot plot simulation results.")
        return
    logging.info("Plotting simulation results (Area, Rainfall, Discharge)...")

    # Font sizes are all derived from one base value.
    base_size = 10
    subtitle_size = base_size + 2
    label_size = base_size + 1
    legend_size = base_size
    tick_size = base_size

    palette = {
        'Total Area': 'black',
        'Impervious Area': '#e31a1c',
        'Pervious Area': '#2ca02c',
        'Rainfall': '#a6cee3',
        'Outlet Discharge': '#1f77b4',
    }
    _figure, panels = plt.subplots(3, 1, figsize=(9, 10), sharex=True, dpi=450)
    legend_kwargs = {
        'loc': 'upper right',
        'fontsize': legend_size,
        'frameon': True,
        'facecolor': 'white',
        'edgecolor': 'black',
    }

    times_min = np.array([record['time'] for record in time_series])

    # ---- Panel 1: connected area ----
    area_ax = panels[0]
    # (legend label, column in recalculated_eia_df, key in time_series records)
    area_series = (
        ('Total Area', 'total_area_by_arrival', 'inlet_total_area'),
        ('Impervious Area', 'impervious_area_by_arrival', 'inlet_impervious_area'),
        ('Pervious Area', 'pervious_area_by_arrival', 'inlet_pervious_area'),
    )
    if recalculated_eia_df is not None and not recalculated_eia_df.empty:
        logging.info("Using recalculated areas (by arrival time) for plotting.")
        for series_name, column, _unused_key in area_series:
            area_ax.step(recalculated_eia_df['time'], recalculated_eia_df[column],
                         color=palette[series_name], linestyle='-',
                         label=series_name, where='post')
        area_ax.set_title("Outlet Basin Connected Area (by Arrival Time)", fontsize=subtitle_size)
    else:
        # No recalculated table: fall back to the areas logged in each record.
        logging.info("Recalculated area data not available, using spill-time connected area for plotting.")
        for series_name, _unused_column, record_key in area_series:
            area_values = np.array([record.get(record_key, 0) for record in time_series])
            area_ax.step(times_min, area_values,
                         color=palette[series_name], linestyle='-',
                         label=series_name, where='post')
        area_ax.set_title("Outlet Basin Area Composition (by Spill Time)", fontsize=subtitle_size)

    area_ax.set_ylabel("Area (m²)", fontsize=label_size)
    area_ax.legend(**legend_kwargs).get_frame().set_linewidth(0.8)
    area_ax.grid(False)
    area_ax.set_ylim(bottom=0)
    area_ax.tick_params(axis='both', labelsize=tick_size)
    area_ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.0f'))

    # ---- Panel 2: rainfall ----
    rain_ax = panels[1]
    rain_timeline_min = None
    rain_rates = None
    ylabel_rain = f"Rain Rate ({rain_unit})"
    if primary_inlet_runoff_data:
        raw_timeline = primary_inlet_runoff_data.get('timeline_min')
        raw_depth_step = primary_inlet_runoff_data.get('rain_depth_m_step')
        dt_sec = primary_inlet_runoff_data.get('dt_sec')
        inputs_present = raw_timeline is not None and raw_depth_step is not None and dt_sec is not None
        if inputs_present and dt_sec > 1e-6:
            raw_timeline = np.array(raw_timeline)
            raw_depth_step = np.array(raw_depth_step)
            common_len = min(len(raw_timeline), len(raw_depth_step))
            if common_len > 0:
                rain_timeline_min = raw_timeline[:common_len]
                # Depth per step divided by step length gives a rate per second.
                rain_m_s = raw_depth_step[:common_len] / dt_sec
                metres_per_unit = {'cm/hr': 0.01, 'mm/hr': 0.001, 'in/hr': 0.0254}
                divisor = metres_per_unit.get(rain_unit)
                if divisor is None:
                    rain_rates = rain_m_s
                    ylabel_rain = "Rain Rate (m/s)"
                else:
                    rain_rates = (rain_m_s * 3600.0) / divisor

    if rain_timeline_min is not None and rain_rates is not None:
        within_run = rain_timeline_min <= simulation_time_min + 1e-9
        rain_ax.plot(rain_timeline_min[within_run], rain_rates[within_run],
                     color=palette['Rainfall'], drawstyle='steps-post')
    else:
        rain_ax.text(0.5, 0.5, "Rainfall data not available",
                     transform=rain_ax.transAxes, ha='center', fontsize=tick_size)
    rain_ax.set_ylabel(ylabel_rain, fontsize=label_size)
    rain_ax.set_title("Rainfall Input", fontsize=subtitle_size)
    rain_ax.grid(False)
    rain_ax.set_ylim(bottom=0)
    rain_ax.tick_params(axis='both', labelsize=tick_size)
    rain_ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))

    # ---- Panel 3: outlet hydrograph ----
    flow_ax = panels[2]
    system_discharge_m3_s = np.array([record.get('inlet_volume_rate', 0.0) for record in time_series])
    system_discharge_m3_min = system_discharge_m3_s * 60.0
    flow_ax.plot(times_min, system_discharge_m3_min, color=palette['Outlet Discharge'],
                 linewidth=2, label='Outlet Discharge', drawstyle='steps-post')
    flow_ax.set_xlabel("Time (min)", fontsize=label_size)
    flow_ax.set_ylabel("Flow Rate (m³/min)", fontsize=label_size)
    flow_ax.set_title("Outlet Discharge Hydrograph", fontsize=subtitle_size)
    flow_ax.grid(False)
    flow_ax.set_ylim(bottom=0)
    flow_ax.tick_params(axis='both', labelsize=tick_size)
    flow_ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))

    # ---- Final layout ----
    flow_ax.set_xlim(0, simulation_time_min)
    plt.tight_layout(rect=[0, 0.03, 1, 0.97])
    plt.show()

    # ---- Optional CSV export of the discharge series ----
    if not export_discharge_csv:
        return
    try:
        primary_inlet_id = str(selected_inlets[0]) if selected_inlets else "unknown_inlet"
        dynamic_export_filename = f"outlet_discharge_{primary_inlet_id}.csv"

        logging.info(f"Exporting outlet discharge to '{dynamic_export_filename}' (interval: {export_interval_sec} sec)...")
        if not time_series or len(times_min) == 0:
            logging.warning("Cannot export discharge: time_series empty.")
            return

        source_times_sec = times_min * 60.0
        source_flow_m3s = system_discharge_m3_s
        if source_times_sec.size == 0:
            logging.warning("Cannot export discharge: Time array empty.")
            return

        # Regular sampling grid, trimmed so it never runs past the last data point.
        horizon_sec = simulation_time_min * 60.0
        export_times_sec = np.arange(0, horizon_sec + export_interval_sec, export_interval_sec)
        last_data_time_sec = source_times_sec[-1] if source_times_sec.size > 0 else 0
        export_times_sec = export_times_sec[export_times_sec <= last_data_time_sec + 1e-6]
        if export_times_sec.size == 0:
            logging.warning("Cannot export discharge: Export time array empty.")
            return

        if 'safe_interp' not in globals():
            raise NameError("'safe_interp' function is needed for CSV export but not found.")
        resampled_flow_m3s = safe_interp(export_times_sec, source_times_sec, source_flow_m3s, label="DischargeExport")

        export_table = pd.DataFrame({'Time (s)': export_times_sec, 'Outlet Discharge (m3/s)': resampled_flow_m3s})
        export_path = os.path.abspath(dynamic_export_filename)
        export_table.to_csv(export_path, index=False, float_format='%.6f')
        logging.info(f"Successfully exported outlet discharge to {export_path}")

    except Exception as e:
        logging.error(f"Failed to export discharge data to CSV: {e}", exc_info=True)

def create_corrected_basin_contributions_plot(snapshots, simulator: 'IntegratedSimulation', runoff_processor: 'RunoffProcessor') -> dict:
    """
    Plot each basin's direct-flow hydrograph as a stacked area under the total
    system discharge and save the figure to 'basin_contributions_plot.png'.

    Contributing basins are found by starting from the node(s) in the final
    snapshot whose 'merged_from' list contains a selected inlet, then following
    each basin's 'merged_from_ids' breadth-first through
    runoff_processor.basin_runoff_data.

    Args:
        snapshots: Sequence of (time, state) pairs; only the last one is read.
        simulator: Provides 'time_series' (records with 'time' and
            'inlet_volume_rate') and 'selected_inlet_ids'.
        runoff_processor: Provides 'basin_runoff_data', a mapping of basin id
            to a dict read for 'convolution_timeline_min', 'direct_flow_m3_s',
            'merge_time_min' and 'merged_from_ids'.

    Returns:
        Mapping of basin id -> {'flow', 'label', 'color', 'merge_time'} for
        every traced basin, or {} when inputs are missing, nothing could be
        traced, no flow exceeds 1e-6, or the stack plot fails.
    """
    from collections import deque  # local import: O(1) pops from the queue head

    logging.info("Creating corrected basin contributions plot...")
    if not simulator or not runoff_processor or not snapshots or not simulator.time_series:
        logging.error("Missing required data (simulator, processor, snapshots, or time_series) for contributions plot.")
        return {}
    inlet_ids = simulator.selected_inlet_ids
    if not inlet_ids:
        logging.error("No inlet IDs found in simulator.")
        return {}
    primary_inlet_id = inlet_ids[0]
    reference_timeline = np.array([d['time'] for d in simulator.time_series])
    if len(reference_timeline) == 0:
        logging.error("Reference timeline from simulator is empty.")
        return {}

    basin_data = runoff_processor.basin_runoff_data
    all_basin_ids_in_data = list(basin_data.keys())
    if not all_basin_ids_in_data:
        logging.warning("Runoff processor contains no basin data. Cannot plot contributions.")
        return {}
    distinct_colors = plt.cm.viridis(np.linspace(0, 1, len(all_basin_ids_in_data)))
    color_map = dict(zip(all_basin_ids_in_data, distinct_colors))

    # Locate the node(s) of the final state that absorbed the selected inlets.
    _, final_state = snapshots[-1]
    if final_state:
        final_outlet_ids = [
            basin_id
            for basin_id, state_data in final_state.items()
            if any(orig_inlet in (state_data.get('merged_from') or []) for orig_inlet in inlet_ids)
        ]
    else:
        logging.warning("Final snapshot state is empty. Attempting fallback to find outlet.")
        final_outlet_ids = [bid for bid in inlet_ids if bid in basin_data]
    if not final_outlet_ids:
        logging.error("Could not identify final outlet node(s) containing original inlets. Cannot trace contributions.")
        return {}
    logging.info(f"Final outlet node(s) identified for contribution tracing: {final_outlet_ids}")

    # Breadth-first walk over merged components; each id is handled once.
    contribution_details = {}
    queue = deque(dict.fromkeys(final_outlet_ids))
    seen = set(queue)
    max_label_len = 40
    while queue:
        current_id = queue.popleft()
        if current_id not in basin_data:
            logging.warning(f"Runoff data missing for basin {current_id} during contribution trace.")
            continue
        data = basin_data[current_id]
        timeline = data.get('convolution_timeline_min')
        flow_m3_s = data.get('direct_flow_m3_s')
        # A stored None is treated like a missing merge time.
        merge_time = data.get('merge_time_min') or 0.0

        # Coerce to arrays so list-typed inputs can be scaled, and require at
        # least one paired sample because np.interp rejects empty sample points.
        if timeline is not None and flow_m3_s is not None:
            timeline = np.asarray(timeline, dtype=float)
            flow_m3_s = np.asarray(flow_m3_s, dtype=float)
            min_len = min(timeline.size, flow_m3_s.size)
        else:
            min_len = 0
        if min_len == 0:
            logging.warning(f"Missing timeline/flow for basin {current_id}. Skipping contribution.")
            continue
        timeline = timeline[:min_len]
        flow_m3_s = flow_m3_s[:min_len]
        interp_flow_m3_min = np.interp(reference_timeline, timeline, flow_m3_s * 60.0, left=0, right=0)

        label = f"Inlet {current_id}" if current_id in inlet_ids else f"Basin {current_id}"
        if len(label) > max_label_len:
            label = label[:max_label_len - 3] + "..."
        if merge_time > 0:
            label += f" (M @ {merge_time:.1f}m)"
        contribution_details[current_id] = {
            'flow': interp_flow_m3_min,
            'label': label,
            'color': color_map.get(current_id, 'gray'),
            'merge_time': merge_time,
        }
        for component_id in (data.get('merged_from_ids') or ()):
            if component_id not in seen:
                seen.add(component_id)
                queue.append(component_id)
    logging.info(f"Identified {len(contribution_details)} contributing basins/hydrographs for plot.")
    if not contribution_details:
        return {}

    fig, ax = plt.subplots(figsize=(14, 8))
    # Stack in order of merge time; drop layers whose flow never exceeds 1e-6.
    sorted_basin_ids = sorted(contribution_details, key=lambda bid: contribution_details[bid]['merge_time'])
    visible = [
        contribution_details[bid]
        for bid in sorted_basin_ids
        if np.max(contribution_details[bid]['flow']) > 1e-6
    ]
    if not visible:
        logging.warning("No significant contributing flows found to plot.")
        plt.close(fig)
        return {}
    stack_flows = [item['flow'] for item in visible]
    stack_labels = [item['label'] for item in visible]
    stack_colors = [item['color'] for item in visible]
    try:
        ax.stackplot(reference_timeline, stack_flows, labels=stack_labels, colors=stack_colors, alpha=0.7, step='post')
    except Exception as stack_err:
        logging.error(f"Error during stackplot generation: {stack_err}")
        plt.close(fig)
        return {}

    # Missing rates default to 0, matching how the hydrograph panel reads them.
    total_system_response_m3_min = np.array([d.get('inlet_volume_rate', 0.0) * 60 for d in simulator.time_series])
    ax.plot(reference_timeline, total_system_response_m3_min, 'k-', linewidth=2.5, label='Total System Discharge', drawstyle='steps-post')
    ax.set_xlabel('Time (min)', fontsize=12)
    ax.set_ylabel('Flow Rate (m³/min)', fontsize=12)
    ax.set_title(f"Basin Contributions to Discharge Hydrograph (Outlet Tree: {primary_inlet_id})", fontsize=14)
    ax.set_xlim(0, reference_timeline[-1])
    max_y_stack = np.max(np.sum(stack_flows, axis=0))
    max_y_total = np.max(total_system_response_m3_min)
    max_y = max(max_y_stack, max_y_total) * 1.1
    ax.set_ylim(0, max(max_y, 0.1))
    # One legend column per 15 entries (rounded up), at least one.
    num_labels = len(stack_labels)
    legend_cols = max(1, -(-num_labels // 15))
    ax.legend(loc='upper right', fontsize=8, ncol=legend_cols)
    ax.grid(True, linestyle='--', alpha=0.5)
    fig.tight_layout()
    figname = "basin_contributions_plot.png"
    # Save through the figure handle so the result does not depend on which
    # figure pyplot currently considers active.
    fig.savefig(figname)
    plt.close(fig)
    logging.info(f"Basin contributions plot saved as {figname}")
    return contribution_details

def plot_timestep_water_balance(
    water_balance_log: List[dict],
    runon_time_series_data: List[Dict],
    simulation_time_min_param: float,
    figsize=(8, 8),
    dpi=450,
    max_plot_time_min: Optional[float] = 250.0,
):
    """
    Plot non-cumulative water balance components per 1-minute bin as step lines.

    Three logged quantities ('inlet_discharge', 'storage_change',
    'active_runoff_generated') are each accumulated, interpolated at 1-minute
    bin edges with safe_interp, and differenced to give a per-bin volume.
    The orange "Runoff Loss (Transfer & Run-on)" line is the residual
    generated - max(storage change, 0) - discharge, clipped at zero.
    No total-rainfall line is drawn.

    Args:
        water_balance_log: Per-step records; each must contain 'time'.
            Missing, non-numeric or non-finite values count as 0.
        runon_time_series_data: Not used. Kept so existing calls keep working;
            the loss line is obtained as a residual, so the run-on series is
            not needed to draw it.
        simulation_time_min_param: Simulation length; bins cover
            [0, simulation_time_min_param] in 1-unit steps.
        figsize: Figure size passed to plt.subplots.
        dpi: Figure resolution passed to plt.subplots.
        max_plot_time_min: Upper cap for the x-axis limit. Pass None to show
            the full timeline. Defaults to the previous fixed cap of 250.

    Returns:
        (fig, ax) on success, or (None, None) when the log is empty.
    """
    if not water_balance_log:
        logging.warning("plot_timestep_water_balance (1-min lines): Water balance data empty; skipping plot.")
        return None, None

    logging.info("Plotting NON-CUMULATIVE 1-min aggregated water balance components as LINES...")

    log_times = np.array([entry['time'] for entry in water_balance_log])

    # 1-minute bins. The timeline holds the left edge of each bin and always
    # has at least one point, even for a zero-length simulation.
    bin_edges = np.arange(0, simulation_time_min_param + 1, 1.0)
    plot_timeline_1min = bin_edges[:-1] if len(bin_edges) > 1 else np.array([0.0])
    n_bins = len(plot_timeline_1min)

    def _aggregate_per_1min_step(key: str, fill_val=0.0) -> np.ndarray:
        """Return the per-bin total of `key`, one value per timeline point."""
        def _clean(value):
            if isinstance(value, (int, float, np.number)) and np.isfinite(value):
                return float(value)
            return fill_val

        times = log_times
        raw_values = np.array([_clean(entry.get(key, fill_val)) for entry in water_balance_log])

        # Anchor the cumulative curve at t=0 so the first bin starts from zero.
        if abs(times[0]) > 1e-6:
            times = np.insert(times, 0, 0.0)
            raw_values = np.insert(raw_values, 0, 0.0)

        aggregated = np.zeros(n_bins, dtype=float)
        if times.size < 2:
            # A single sample cannot be interpolated; drop it into its bin.
            bin_index = np.searchsorted(bin_edges, times[0], side='right') - 1
            if 0 <= bin_index < n_bins:
                aggregated[bin_index] = raw_values[0]
            return aggregated

        cumulative = np.cumsum(raw_values)
        cumulative_at_edges = safe_interp(bin_edges, times, cumulative, default_val=0.0, label=f"{key}_agg_interp")
        per_bin = np.diff(cumulative_at_edges)

        # safe_interp is defined elsewhere, so reconcile lengths defensively.
        if len(per_bin) == n_bins:
            return per_bin
        if len(per_bin) > n_bins:
            logging.warning(f"Aggregation for key '{key}' (diff method) resulted in more values than bins. Truncating.")
            return per_bin[:n_bins]
        if len(per_bin) > 0:
            logging.warning(f"Aggregation for key '{key}' (diff method) resulted in fewer values than bins. Padding with last.")
            aggregated[:len(per_bin)] = per_bin
            aggregated[len(per_bin):] = per_bin[-1]
        else:
            logging.warning(f"Aggregation for key '{key}' (diff method) resulted in zero values for non-zero bins. Using zeros.")
        return aggregated

    step_discharge_1min = _aggregate_per_1min_step('inlet_discharge')
    step_storage_change_1min = _aggregate_per_1min_step('storage_change')
    step_active_runoff_generated_1min = _aggregate_per_1min_step('active_runoff_generated')

    # Residual loss: generated runoff that neither increased storage nor left
    # through the outlet in the same bin.
    combined_loss_per_1min_step = np.maximum(
        0,
        step_active_runoff_generated_1min
        - np.maximum(0, step_storage_change_1min)
        - step_discharge_1min,
    )

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    fontsize_title = 14
    fontsize_labels = 12
    fontsize_legend = 10
    line_linewidth = 1.8
    ref_linewidth = 1.5

    # (values, legend label, colour, extra line properties), in draw order.
    line_specs = [
        (step_storage_change_1min, 'Surface Storage Change', '#7f7f7f',
         {'linewidth': line_linewidth}),
        (combined_loss_per_1min_step, 'Runoff Loss (Transfer & Run-on)', '#ff7f0e',
         {'linewidth': line_linewidth}),
        (step_discharge_1min, 'Outlet Discharge', '#1f77b4',
         {'linewidth': line_linewidth}),
        (step_active_runoff_generated_1min, 'Active Runoff Gen.', '#5DBB63',
         {'linewidth': ref_linewidth, 'linestyle': '--', 'alpha': 0.8}),
    ]
    for values, label, color, line_props in line_specs:
        ax.plot(plot_timeline_1min, values, color=color, label=label,
                drawstyle='steps-post', **line_props)

    ax.set_xlabel("Time (min)", fontsize=fontsize_labels)
    ax.set_ylabel("Volume per 1-min Timestep (m³)", fontsize=fontsize_labels)
    ax.set_title("Water Balance Components per 1-min Timestep (Lines, Non-Cumulative)", fontsize=fontsize_title)
    legend = ax.legend(loc='upper right', fontsize=fontsize_legend, frameon=True,
                       facecolor='white', edgecolor='black', ncol=1)
    if legend:
        legend.get_frame().set_linewidth(0.8)

    ax.grid(False)
    ax.tick_params(axis='both', which='major', labelsize=fontsize_labels - 1)
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))

    # Y range spans all finite plotted values, never collapsing below 0.001.
    min_y_overall = 0
    max_y_overall = 0.001
    for values, _label, _color, _props in line_specs:
        finite_values = values[np.isfinite(values)]
        if finite_values.size > 0:
            min_y_overall = min(min_y_overall, np.min(finite_values))
            max_y_overall = max(max_y_overall, np.max(finite_values))

    if min_y_overall < -1e-6:
        bottom_y_limit = min(0, min_y_overall - 0.05 * abs(min_y_overall))
    else:
        # Slightly below zero so a line lying on the axis stays visible.
        bottom_y_limit = -0.0001
    top_y_limit = max_y_overall * 1.1 if max_y_overall > 1e-6 else 0.01
    ax.set_ylim(bottom_y_limit, top_y_limit)

    plot_end_time = plot_timeline_1min[-1] + 0.5
    if max_plot_time_min is not None:
        plot_end_time = min(plot_end_time, max_plot_time_min)
    ax.set_xlim(0, plot_end_time)

    # Draw on this figure's axes explicitly rather than pyplot's current axes.
    ax.axhline(0, color='black', linewidth=0.7)
    fig.tight_layout()
    return fig, ax

def plot_mass_balance(water_balance: List[dict],
                      runon_time_series_data: List[Dict],
                      total_potential_runoff_input: Optional[float] = None):
    """
    Plot cumulative water balance components as a stacked area chart and print
    a textual summary.

    Stacking order (bottom to top):
    1. Surface Depression Storage (gray)
    2. Surface Runoff Transfer & Run-on Infiltration (orange)
    3. Outlet Discharge (blue)
    A black step line shows the cumulative potential runoff input.

    The orange band is the clipped residual
    (input - storage - run-on - discharge, floored at 0) plus the cumulative
    run-on infiltration interpolated from `runon_time_series_data`.

    Args:
        water_balance: Per-step records. The first record must contain 'time',
            'active_runoff_generated', 'inlet_discharge' and 'total_storage'.
            Values may be Python or NumPy real scalars; anything else counts
            as 0.0.
        runon_time_series_data: Entries with 'times' and 'runon_inf_vol_step'
            (equal length) and optionally 'merged_id' for log messages.
        total_potential_runoff_input: Total input volume used for the summary
            and to rescale the input line. Defaults to the final cumulative
            'active_runoff_generated'.

    Returns:
        (fig, ax) on success. Returns None (not a tuple) when the log is empty
        or the first record lacks a required key.
    """
    if not water_balance:
        logging.warning("plot_mass_balance: Water balance data empty; skipping plot.")
        return
    required_keys = ['time', 'active_runoff_generated', 'inlet_discharge', 'total_storage']
    missing_keys = [k for k in required_keys if k not in water_balance[0]]
    if missing_keys:
        logging.error(f"plot_mass_balance: Log missing required keys: {missing_keys}. Cannot proceed.")
        return

    # 1) Time axis, anchored at t=0 so every curve starts from the origin.
    times_log = np.array([d['time'] for d in water_balance])
    prepend_zeros = bool(times_log[0] != 0)
    if prepend_zeros:
        times_log = np.insert(times_log, 0, 0.0)

    # NumPy scalars (e.g. np.float32, np.int64) are not instances of the
    # builtin int/float, so they are listed explicitly.
    numeric_types = (int, float, np.integer, np.floating)

    def _series(key: str, cumulative: bool) -> np.ndarray:
        """Return the logged values for `key`, aligned with times_log."""
        values = np.array([
            float(v) if isinstance(v, numeric_types) else 0.0
            for v in (d.get(key, 0.0) for d in water_balance)
        ])
        if cumulative:
            values = np.cumsum(values)
        if prepend_zeros:
            values = np.insert(values, 0, 0.0)
        return values

    cum_runoff_active = _series('active_runoff_generated', cumulative=True)
    cum_discharge_log = _series('inlet_discharge', cumulative=True)
    storage_timeseries = _series('total_storage', cumulative=False)

    # 2) Run-on infiltration: interpolate each event's per-step volumes onto
    #    the log timeline, sum over events, then accumulate.
    logging.info("plot_mass_balance: Calculating cumulative run-on infiltration from detailed time series...")
    total_runon_step_interp = np.zeros_like(times_log, dtype=float)
    if runon_time_series_data:
        for i, entry in enumerate(runon_time_series_data):
            merged_id = entry.get('merged_id', f'Unknown_{i}')
            event_times = np.array(entry.get('times', []), dtype=float)
            event_vols = np.array(entry.get('runon_inf_vol_step', []), dtype=float)
            if event_times.size > 0 and event_times.size == event_vols.size:
                order = np.argsort(event_times)
                times_sorted = event_times[order]
                vols_sorted = event_vols[order]
                # np.interp needs increasing sample points; keep the first
                # occurrence of each duplicated time.
                unique_indices = np.unique(times_sorted, return_index=True)[1]
                if len(unique_indices) >= 2:
                    total_runon_step_interp += np.interp(
                        times_log,
                        times_sorted[unique_indices],
                        vols_sorted[unique_indices],
                        left=0.0,
                        right=0.0,
                    )
            else:
                logging.warning(f"Invalid data for run-on recalc '{merged_id}'.")
    else:
        logging.warning("No runon_time_series_data provided.")
    cum_runon_inf_recalculated = np.cumsum(total_runon_step_interp)

    # 3) Summary totals and the (optionally rescaled) input curve.
    final_active_runoff = cum_runoff_active[-1]
    if total_potential_runoff_input is not None:
        total_input_vol_summary = total_potential_runoff_input
    else:
        total_input_vol_summary = float(final_active_runoff)
    final_storage = storage_timeseries[-1]
    total_runon_inf_final = cum_runon_inf_recalculated[-1]
    total_discharge_final = cum_discharge_log[-1]

    # Keep the shape of the logged runoff curve but scale it to end at the
    # supplied total input volume.
    scale_factor_input = (total_input_vol_summary / final_active_runoff) if final_active_runoff > 1e-9 else 1.0
    cum_potential_runoff_plot = cum_runoff_active * scale_factor_input

    # 4) Residual ("Surface Runoff Transfer"), floored at zero, then merged
    #    with run-on infiltration into a single band.
    residual = np.maximum(
        cum_potential_runoff_plot - storage_timeseries - cum_runon_inf_recalculated - cum_discharge_log,
        0,
    )
    combined_transfer_runon_plot = residual + cum_runon_inf_recalculated
    final_combined_transfer_runon_val = residual[-1] + total_runon_inf_final

    # Tops of the stacked bands.
    y0_base = np.zeros_like(times_log)
    y1_storage_top = storage_timeseries
    y2_combined_top = y1_storage_top + combined_transfer_runon_plot
    y3_discharge_top = y2_combined_top + cum_discharge_log

    # 5) Plot.
    fig, ax = plt.subplots(figsize=(9, 9), dpi=450)
    fontsize_title = 16
    fontsize_labels = 14
    fontsize_legend = 11

    # (lower edge, upper edge, legend label, colour), bottom to top.
    bands = [
        (y0_base, y1_storage_top, 'Surface Depression Storage', '#7f7f7f'),
        (y1_storage_top, y2_combined_top, 'Surface Runoff Transfer & Run-on Infiltration', '#ff7f0e'),
        (y2_combined_top, y3_discharge_top, 'Outlet Discharge', '#1f77b4'),
    ]
    for lower, upper, label, color in bands:
        ax.fill_between(times_log, lower, upper, step='post', color=color, alpha=0.8, label=label)

    ax.step(times_log, cum_potential_runoff_plot, where='post',
            label='Cumulative Potential Runoff Input',
            color='black', linewidth=2.0, linestyle='-')

    ax.set_xlabel("Time (min)", fontsize=fontsize_labels)
    ax.set_ylabel("Cumulative Volume (m³)", fontsize=fontsize_labels)
    ax.set_title("Cumulative Water Balance Components", fontsize=fontsize_title)
    legend = ax.legend(loc='upper left', fontsize=fontsize_legend, frameon=True, facecolor='white', edgecolor='black')
    if legend:
        legend.get_frame().set_linewidth(1.0)
    ax.grid(False)
    ax.tick_params(axis='both', which='major', labelsize=fontsize_labels)
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
    max_y_plot = np.max(cum_potential_runoff_plot)
    # With zero input the limits would collapse to (0, 0); use a unit axis.
    ax.set_ylim(bottom=0, top=max_y_plot * 1.05 if max_y_plot > 0 else 1.0)
    ax.set_xlim(0, times_log[-1])
    plt.tight_layout()

    # 6) Printed summary; percentages are relative to the total input.
    input_is_zero = abs(total_input_vol_summary) < 1e-9

    def _percent_suffix(volume) -> str:
        if input_is_zero:
            return "(  N/A %)"
        return f"({volume/total_input_vol_summary*100:6.1f}%)"

    summary_residual_error = (total_input_vol_summary - final_combined_transfer_runon_val
                              - total_discharge_final - final_storage)

    print("\n--- Water Balance Summary (Recalculated Run-on, Logged Discharge) ---")
    print(f"Total Potential Runoff Input: {total_input_vol_summary:10.2f} m³ (100.00%)")
    print("    Accounted Outputs:")
    print(f"      Surface Runoff Transfer & Run-on Infiltration: {final_combined_transfer_runon_val:10.2f} m³ "
          + _percent_suffix(final_combined_transfer_runon_val))
    print(f"      Outlet Discharge:        {total_discharge_final:10.2f} m³ "
          + _percent_suffix(total_discharge_final))
    print(f"      Surf. Depress. Storage:  {final_storage:10.2f} m³ "
          + _percent_suffix(final_storage))
    print("      ---------------------------------------------------")
    print(f"      Mass Balance Difference (Error): {summary_residual_error:10.2f} m³ "
          + _percent_suffix(summary_residual_error))
    print("--------------------------------------------------------------------------------")

    return fig, ax

def visualize_merge_framework(snapshots: List[Tuple[float, Dict]],
                              simulator: Optional['IntegratedSimulation'],
                              runoff_processor: Optional['RunoffProcessor'],
                              selected_inlets: List[str],
                              non_merged_ids: Set[str],
                              simulation_time_min: float
                             ) -> Tuple[List[Dict], List[Dict]]:
    """
    Plot one five-panel hydrograph figure per merge event and tabulate run-on infiltration.

    Merge events are discovered in ``runoff_processor.basin_runoff_data``: an entry is an
    event when its id contains '+', its value is a dict holding 'merged_from_ids',
    'merge_time_min', 'convolution_timeline_min', 'original_lagged_flow_m3s',
    'adjusted_upstream_lagged_flow_m3s', 'dt_sec' and 'lag_used_sec', and
    'merged_from_ids' is a list of at least two ids. ``merged_from_ids[0]`` is used as the
    downstream basin and ``merged_from_ids[1]`` as the upstream basin. Entries with
    non-numeric values, non 1-D arrays or ``dt_sec <= 1e-9`` are skipped with a warning.

    For every event the per-timestep run-on infiltration is recalculated as
    ``min(downstream 'infiltration_deficit_vol_step' interpolated onto the event timeline,
    original lagged flow * dt_sec)``; the remainder of the incoming volume is shown as
    surface runoff transfer. Each figure is displayed with ``plt.show()`` and then closed,
    and a summary table of run-on totals is printed at the end.

    Parameters:
    - snapshots, simulator, selected_inlets, non_merged_ids: accepted for interface
      compatibility; this function does not read them.
    - runoff_processor: object exposing a ``basin_runoff_data`` mapping (see above).
    - simulation_time_min: upper x-axis limit used for every panel.

    Returns:
    - merge_volume_details: one dict per event with 'merged_id', 'up_id', 'down_id'
      and 'infiltration_total' (sum of the per-step run-on volumes).
    - runon_time_series: one dict per event with 'merged_id', 'times' and
      'runon_inf_vol_step' (both plain lists).
      Both lists are empty when no processor data or no merge event is available.
    """
    import logging
    import numpy as np

    logging.info("Starting merge framework visualization (Hydrograph Focus)...")

    merge_volume_details = []
    runon_time_series = []

    if runoff_processor is None or not hasattr(runoff_processor, 'basin_runoff_data'):
        logging.warning("visualize_merge_framework: RunoffProcessor/data missing.")
        return merge_volume_details, runon_time_series
    if not runoff_processor.basin_runoff_data:
        logging.warning("visualize_merge_framework: runoff_processor.basin_runoff_data is empty.")
        return merge_volume_details, runon_time_series

    basin_data = runoff_processor.basin_runoff_data

    # === Internal Helper Functions ===
    def match_length(x: np.ndarray, y: np.ndarray):
        """Trim both arrays to their common (shorter) length."""
        n = min(x.size, y.size)
        return x[:n], y[:n]

    def prep(d, tkey, qkey):
        """DEFENSIVE: return (timeline, flow * 60) for one basin record, or (None, None)."""
        if d is None:
            return None, None
        try:
            t = d.get(tkey)
            q = d.get(qkey)
        except Exception as e_get:
            # d may be something other than a mapping; treat it as "no data".
            logging.warning(f"prep: Error retrieving {tkey}/{qkey}: {e_get}")
            return None, None
        if t is None or q is None:
            return None, None

        if not isinstance(t, (list, tuple, np.ndarray)) or \
           not isinstance(q, (list, tuple, np.ndarray)):
            logging.warning(f"prep: Invalid data type for {tkey}({type(t)}) or {qkey}({type(q)}).")
            return None, None

        try:
            ta = np.asarray(t, dtype=float)
            qa = np.asarray(q, dtype=float)  # m³/s
            if not np.all(np.isfinite(ta)):
                logging.warning(f"prep: NaN/Inf time {tkey}.")
                return None, None
            if not np.all(np.isfinite(qa)):
                logging.warning(f"prep: NaN/Inf quantity {qkey}.")
                return None, None
            if min(ta.size, qa.size) == 0:
                logging.warning(f"prep: Zero length array {tkey}/{qkey}.")
                return None, None
            t_out, q_out = match_length(ta, qa)
            return t_out, q_out * 60.0  # m³/min
        except (ValueError, TypeError, MemoryError) as e:
            logging.warning(f"prep: Conversion/Processing error {tkey}/{qkey}: {e}")
            return None, None

    def safe_interp(x_new, x_old, y_old, default_val=0.0, label=""):
        """Linear interpolation that falls back to ``default_val`` on unusable input."""
        try:
            x_new = np.asarray(x_new, dtype=float)
            x_old = np.asarray(x_old, dtype=float)
            y_old = np.asarray(y_old, dtype=float)
            if x_new.size == 0:
                return np.full_like(x_new, default_val, dtype=float)
            if x_old.size == 0 or y_old.size == 0:
                return np.full(x_new.shape, default_val, dtype=float)
            if x_old.size != y_old.size:
                logging.warning(f"safe_interp ({label}): x/y size mismatch. Defaulting.")
                return np.full(x_new.shape, default_val, dtype=float)
            if not np.all(np.isfinite(x_old)):
                logging.warning(f"safe_interp ({label}): NaN/Inf x_old. Defaulting.")
                return np.full(x_new.shape, default_val, dtype=float)
            if not np.all(np.isfinite(y_old)):
                y_old = np.nan_to_num(y_old, nan=default_val, posinf=default_val, neginf=default_val)
                logging.warning(f"safe_interp ({label}): NaN/Inf y_old replaced.")
            # np.interp needs increasing x: sort if required, then drop duplicate x values.
            if not np.all(np.diff(x_old) >= 0):
                sort_indices = np.argsort(x_old)
                x_old = x_old[sort_indices]
                y_old = y_old[sort_indices]
            unique_indices = np.unique(x_old, return_index=True)[1]
            if len(unique_indices) < len(x_old):
                x_old = x_old[unique_indices]
                y_old = y_old[unique_indices]
            if x_old.size < 2:
                logging.warning(f"safe_interp ({label}): < 2 unique points. Cannot interp.")
                return np.full(x_new.shape, default_val, dtype=float)
            interp_result = np.interp(x_new, x_old, y_old, left=default_val, right=default_val)
            return np.nan_to_num(interp_result.astype(float), nan=default_val, posinf=default_val, neginf=default_val)
        except Exception as e:
            logging.error(f"safe_interp ({label}) failed: {e}", exc_info=True)
            return np.full(np.asarray(x_new).shape, default_val, dtype=float)

    def runon_split(ev, label=""):
        """Return (times, lagged flow, deficit, incoming vol, run-on vol) per step for one event."""
        dn_data = basin_data.get(ev['down_id'], {})
        t_dn_def = np.array(dn_data.get('timeline_min', []), float)
        dn_def_vol = np.array(dn_data.get('infiltration_deficit_vol_step', []), float)
        # Timeline and flow can differ in length; trim so the element-wise min cannot fail.
        t_ev, q_ev = match_length(ev['ut'], ev['orig_flow_m3s'])
        deficit = safe_interp(t_ev, t_dn_def, dn_def_vol, label=label)
        incoming = q_ev * ev['dt']
        return t_ev, q_ev, deficit, incoming, np.minimum(deficit, incoming)

    # End Internal Helper Functions

    # Gather Merge Events
    required_keys = ('merged_from_ids', 'merge_time_min', 'convolution_timeline_min',
                     'original_lagged_flow_m3s', 'dt_sec', 'lag_used_sec',
                     'adjusted_upstream_lagged_flow_m3s')
    merges = []
    for bid, data in basin_data.items():
        if not isinstance(data, dict) or '+' not in str(bid):
            continue
        if not all(k in data for k in required_keys):
            continue
        source_ids = data['merged_from_ids']
        if not isinstance(source_ids, list) or len(source_ids) < 2:
            continue
        try:
            merge_time = float(data['merge_time_min'])
            lag_sec = float(data['lag_used_sec'])
            dt_sec = float(data['dt_sec'])
            ut_arr = np.array(data['convolution_timeline_min'], float)
            orig_arr = np.array(data['original_lagged_flow_m3s'], float)
            adj_arr = np.array(data['adjusted_upstream_lagged_flow_m3s'], float)
        except (TypeError, ValueError) as e_conv:
            logging.warning(f"visualize_merge_framework: skipping merge '{bid}' (non-numeric data): {e_conv}")
            continue
        # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
        if not dt_sec > 1e-9 or any(a.ndim != 1 for a in (ut_arr, orig_arr, adj_arr)):
            logging.warning(f"visualize_merge_framework: skipping merge '{bid}' (invalid dt_sec or array shape).")
            continue
        merges.append({'merged_id': bid, 'up_id': str(source_ids[1]), 'down_id': str(source_ids[0]),
                       'time': merge_time, 'lag_sec': lag_sec, 'ut': ut_arr,
                       'orig_flow_m3s': orig_arr, 'adj_flow_m3s': adj_arr, 'dt': dt_sec})
    merges.sort(key=lambda e: e['time'])
    if not merges:
        print("No merge events found.")
        logging.warning("No merge events found.")
        return [], []

    # Plotting stack is only needed from here on.
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    import matplotlib.patches as mpatches
    import matplotlib.ticker as ticker

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists.
    integrate = getattr(np, 'trapezoid', None) or np.trapz

    # Define Colors
    colors = {
        'Flow Rate': '#6c757d', 'Flow Rate Post-Merge': '#b2df8a',  # Upstream: Gray base, Lt Green hatch
        'Downstream Flow': '#1f77b4',  # Blue for Downstream Base Flow
        'Combined Flow': '#1f77b4',    # Blue
        'Run-on Infiltration': '#ff7f0e', 'Surface Runoff Transfer': '#2ca02c',  # Green
        'Capacity Limit': '#adb5bd', 'Incoming Lagged Vol': '#6a3d9a',
        'Original Lagged Flow': '#6a3d9a', 'Merge Time': '#e31a1c', 'Arrival Time': '#ff7f0e'
    }

    # Determine Axis Limits (shared by all figures so events are visually comparable)
    max_y_flow = 0.01
    max_y_vol = 0.01
    for ev in merges:
        for basin_key in (ev['up_id'], ev['down_id'], ev['merged_id']):
            _, q_m3min = prep(basin_data.get(basin_key), 'convolution_timeline_min', 'direct_flow_m3_s')
            if q_m3min is not None and q_m3min.size > 0:
                max_y_flow = max(max_y_flow, np.max(q_m3min))
        for lagged_m3s in (ev['orig_flow_m3s'], ev['adj_flow_m3s']):
            if lagged_m3s.size > 0:
                max_y_flow = max(max_y_flow, np.max(lagged_m3s * 60.0))
        _, _, deficit, incoming, runon = runon_split(ev)
        for vol_series in (deficit, incoming, runon):
            if vol_series.size > 0:
                max_y_vol = max(max_y_vol, np.max(vol_series))

    x_limit = simulation_time_min
    max_y_flow *= 1.1
    max_y_vol *= 1.1

    # Define Consistent Font Size
    base_fontsize = 9
    fontsize_title = base_fontsize + 2
    fontsize_label = base_fontsize + 1
    fontsize_tick = base_fontsize

    legend_props = {'loc': 'upper right', 'fontsize': base_fontsize, 'frameon': True,
                    'facecolor': 'white', 'edgecolor': 'black'}
    legend_frame_lw = 0.8
    fill_edge_color = 'black'
    fill_edge_lw = 0.5
    flow_ylabel = "Flow Rate (m³/min)"

    def merge_marker(ax, t_merge):
        """Draw the dashed vertical line at the merge time and return its handle."""
        return ax.axvline(t_merge, color=colors['Merge Time'], linestyle='--', linewidth=1.5,
                          label=f"Merge Time ({t_merge:.2f} min)")

    def arrival_marker(ax, t_arrival):
        """Draw the dash-dot vertical line at merge time + lag and return its handle."""
        return ax.axvline(t_arrival, color=colors['Arrival Time'], linestyle='-.', linewidth=1.5,
                          label=f"Arrival Time ({t_arrival:.2f} min)")

    def finish_panel(ax, handles, y_label, y_max, y_fmt, label_size, tick_size):
        """Apply the shared label/legend/limit/tick formatting to one panel."""
        ax.set_ylabel(y_label, fontsize=label_size)
        ax.legend(handles=handles, **legend_props).get_frame().set_linewidth(legend_frame_lw)
        ax.grid(False)
        ax.set_xlim(0, x_limit)
        ax.set_ylim(0, y_max)
        ax.tick_params(axis='both', labelsize=tick_size)
        ax.yaxis.set_major_formatter(ticker.FormatStrFormatter(y_fmt))

    # Plotting Loop
    for i, ev in enumerate(merges):
        up_id, down_id, merged_id = ev['up_id'], ev['down_id'], ev['merged_id']
        t_merge = ev['time']
        t_arrival = t_merge + ev['lag_sec'] / 60.0
        adj_flow_m3s = ev['adj_flow_m3s']

        ut, orig_flow_m3s, aligned_def_vol, incoming_lagged_vol, runon_inf_vol_step = \
            runon_split(ev, label=f"AlignDef {i}")
        transfer_vol_step = np.maximum(0, incoming_lagged_vol - runon_inf_vol_step)
        runon_time_series.append({'merged_id': merged_id, 'times': ut.tolist(),
                                  'runon_inf_vol_step': runon_inf_vol_step.tolist()})
        total_runon_calculated = float(np.sum(runon_inf_vol_step))
        merge_volume_details.append({'merged_id': merged_id, 'up_id': up_id, 'down_id': down_id,
                                     'infiltration_total': total_runon_calculated})

        fig = plt.figure(figsize=(10, 11), dpi=450)
        gs = gridspec.GridSpec(3, 2, height_ratios=[1, 1, 1.2], width_ratios=[1, 1], hspace=0.3, wspace=0.25)

        # --- Panel (0, 0): upstream basin hydrograph, hatched after the merge time ---
        ax_up = fig.add_subplot(gs[0, 0])
        ax_up.set_title(f"Upstream Response ({up_id})", fontsize=fontsize_title)
        up_t, up_q_m3min = prep(basin_data.get(up_id), 'convolution_timeline_min', 'direct_flow_m3_s')
        leg_handles_up = []
        if up_t is not None and up_q_m3min is not None:
            ax_up.fill_between(up_t, 0, up_q_m3min, step='post', color=colors['Flow Rate'], alpha=0.7,
                               edgecolor=fill_edge_color, lw=fill_edge_lw)
            leg_handles_up.append(mpatches.Patch(facecolor=colors['Flow Rate'], alpha=0.7,
                                                 edgecolor=fill_edge_color, label='Depression Storage'))
            idx_merge = np.searchsorted(up_t, t_merge, 'left')
            ax_up.fill_between(up_t[idx_merge:], 0, up_q_m3min[idx_merge:], step='post',
                               color=colors['Flow Rate Post-Merge'], alpha=0.8, hatch='///',
                               edgecolor=fill_edge_color, lw=fill_edge_lw)
            leg_handles_up.append(mpatches.Patch(facecolor=colors['Flow Rate Post-Merge'], alpha=0.8, hatch='///',
                                                 edgecolor=fill_edge_color, label='Flow Post-Merge Trigger'))
        leg_handles_up.append(merge_marker(ax_up, t_merge))
        finish_panel(ax_up, leg_handles_up, flow_ylabel, max_y_flow, '%.3f', fontsize_label, fontsize_tick)

        # --- Panel (0, 1): downstream basin hydrograph ---
        ax_dn = fig.add_subplot(gs[0, 1])
        ax_dn.set_title(f"Downstream Response ({down_id})", fontsize=fontsize_title)
        dn_t, dn_q_m3min = prep(basin_data.get(down_id), 'convolution_timeline_min', 'direct_flow_m3_s')
        leg_handles_dn = []
        if dn_t is not None and dn_q_m3min is not None:
            ax_dn.fill_between(dn_t, 0, dn_q_m3min, step='post', color=colors['Downstream Flow'], alpha=0.7,
                               edgecolor=fill_edge_color, lw=fill_edge_lw)
            leg_handles_dn.append(mpatches.Patch(facecolor=colors['Downstream Flow'], alpha=0.7,
                                                 edgecolor=fill_edge_color, label='Flow Rate'))
        leg_handles_dn.append(merge_marker(ax_dn, t_merge))
        finish_panel(ax_dn, leg_handles_dn, flow_ylabel, max_y_flow, '%.3f', fontsize_label, fontsize_tick)

        # --- Panel (1, 0): per-step run-on infiltration stacked under the transferred volume ---
        ax_runon = fig.add_subplot(gs[1, 0])
        ax_runon.set_title("Run-on Calculation Detail", fontsize=fontsize_title)
        ax_runon.fill_between(ut, 0, runon_inf_vol_step, step='post', color=colors['Run-on Infiltration'],
                              alpha=0.8, edgecolor=fill_edge_color, lw=fill_edge_lw)
        ax_runon.fill_between(ut, runon_inf_vol_step, runon_inf_vol_step + transfer_vol_step, step='post',
                              interpolate=False, color=colors['Surface Runoff Transfer'], alpha=0.7,
                              edgecolor=fill_edge_color, lw=fill_edge_lw)
        leg_handles_runon = [
            mpatches.Patch(facecolor=colors['Run-on Infiltration'], alpha=0.8, edgecolor=fill_edge_color,
                           label=f'Run-on Infiltration ({total_runon_calculated:.3f} m³)'),
            mpatches.Patch(facecolor=colors['Surface Runoff Transfer'], alpha=0.7, edgecolor=fill_edge_color,
                           label='Surface Runoff Transfer'),
        ]
        line_cap, = ax_runon.step(ut, aligned_def_vol, where='post', color=colors['Capacity Limit'],
                                  linestyle='--', linewidth=1.5, label='Downstream Infil. Capacity')
        leg_handles_runon.extend([line_cap, merge_marker(ax_runon, t_merge), arrival_marker(ax_runon, t_arrival)])
        finish_panel(ax_runon, leg_handles_runon, "Volume per Timestep (m³)", max_y_vol, '%.4f',
                     fontsize_label, fontsize_tick)

        # --- Panel (1, 1): adjusted (post-transfer) vs original lagged upstream flow ---
        ax_lag = fig.add_subplot(gs[1, 1])
        ax_lag.set_title("Lagged Upstream Flow (Post-Transfer)", fontsize=fontsize_title)
        n_lag = min(ut.size, adj_flow_m3s.size)
        ut_lag = ut[:n_lag]
        orig_lag_m3min = orig_flow_m3s[:n_lag] * 60.0
        adj_lag_m3min = adj_flow_m3s[:n_lag] * 60.0
        ax_lag.fill_between(ut_lag, 0, adj_lag_m3min, step='post', color=colors['Surface Runoff Transfer'],
                            alpha=0.7, edgecolor=fill_edge_color, lw=fill_edge_lw)
        leg_handles_lag = [mpatches.Patch(facecolor=colors['Surface Runoff Transfer'], alpha=0.7,
                                          edgecolor=fill_edge_color, label='Surface Runoff Transfer')]
        line_orig, = ax_lag.step(ut_lag, orig_lag_m3min, where='post', color=colors['Original Lagged Flow'],
                                 linestyle=':', linewidth=1.5, label='Original Lagged Flow')
        leg_handles_lag.extend([line_orig, merge_marker(ax_lag, t_merge), arrival_marker(ax_lag, t_arrival)])
        finish_panel(ax_lag, leg_handles_lag, flow_ylabel, max_y_flow, '%.3f', fontsize_label, base_fontsize)

        # --- Panel (2, :): merged basin hydrograph with its two components overlaid ---
        ax_comb = fig.add_subplot(gs[2, :])
        ax_comb.set_title(f"Combined Outlet Response ({merged_id})", fontsize=fontsize_title + 1)
        mg_t, mg_q_m3min = prep(basin_data.get(merged_id), 'convolution_timeline_min', 'direct_flow_m3_s')
        leg_handles_comb = []
        if mg_t is not None and mg_q_m3min is not None:
            ax_comb.fill_between(mg_t, 0, mg_q_m3min, step='post', color=colors['Combined Flow'], alpha=0.7,
                                 edgecolor=fill_edge_color, lw=fill_edge_lw)
            # Trapezoidal integral of flow (m³/min) over time (min).
            total_comb_vol = float(integrate(mg_q_m3min, mg_t)) if len(mg_t) > 1 else 0.0
            leg_handles_comb.append(mpatches.Patch(facecolor=colors['Combined Flow'], alpha=0.7,
                                                   edgecolor=fill_edge_color,
                                                   label=f"Combined Flow (Vol: {total_comb_vol:.3f} m³)"))
            if dn_t is not None and dn_q_m3min is not None:
                line_dn_c, = ax_comb.step(dn_t, dn_q_m3min, where='post', color=colors['Downstream Flow'],
                                          linestyle='--', linewidth=1.5, label='Downstream Component')
                leg_handles_comb.append(line_dn_c)
            line_trn_c, = ax_comb.step(ut_lag, adj_lag_m3min, where='post',
                                       color=colors['Surface Runoff Transfer'], linestyle=':', linewidth=1.5,
                                       label='Surface Runoff Transfer')
            leg_handles_comb.append(line_trn_c)
        else:
            ax_comb.text(0.5, 0.5, "Combined data not available", ha='center', va='center', fontsize=base_fontsize)
        ax_comb.set_xlabel("Time (min)", fontsize=fontsize_label + 1)
        finish_panel(ax_comb, leg_handles_comb, flow_ylabel, max_y_flow, '%.3f',
                     fontsize_label + 1, fontsize_tick + 1)

        fig.suptitle(f"Merge Event {i+1} (t={t_merge:.2f} min)", fontsize=fontsize_title + 2, y=0.98)
        plt.tight_layout(rect=[0, 0.02, 1, 0.95])
        plt.show()
        plt.close(fig)

    # Summary Table
    table_rule = "-" * 45 + "-+-" + "-" * 18
    print("\n--- Run-on Infiltration Volume per Merge Event ---")
    print(f"{'Merge Event (Upstream -> Downstream)':<45} | {'Run-on Infiltration (m³)':<18}")
    print(table_rule)
    total_all_runon = 0.0
    for d in merge_volume_details:
        desc = f"{d['up_id']} -> {d['down_id']}"
        vol = d['infiltration_total']
        print(f"{desc:<45} | {vol:<18.4f}")
        total_all_runon += vol
    print(table_rule)
    print(f"{'TOTAL':<45} | {total_all_runon:<18.4f}\n")

    logging.info("Finished visualizing merge events.")
    return merge_volume_details, runon_time_series

def plot_runon_time_series(runon_time_series: List[Dict],
                           true_total_runon: float,
                           simulation_time_min: float,
                           figsize=(10,6)):
    """
    Draw the cumulative run-on infiltration summed over all merge events.

    Every event's per-step volumes are linearly resampled onto one shared time grid
    (0 to 1.5 * simulation_time_min, zero outside the event's own time range), summed
    across events and accumulated. The cumulative curve is then rescaled so that its
    last value equals ``true_total_runon``; the x-axis ends at 150% of the simulation
    duration. The figure is shown with ``plt.show()`` and nothing is returned.

    Parameters:
    - runon_time_series: List of dicts with keys:
        'merged_id', 'times', 'runon_inf_vol_step'
    - true_total_runon: The sum of infiltration totals from merge_volume_details.
    - simulation_time_min: total simulation duration (minutes)
    - figsize: tuple for figure size
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import logging

    if not runon_time_series:
        print("No run-on infiltration data available.")
        return

    # At least one event must carry a non-empty time vector.
    time_vectors = [
        np.array(raw_times, dtype=float)
        for raw_times in (record.get('times') for record in runon_time_series)
        if raw_times is not None and len(raw_times) > 0
    ]
    if not time_vectors:
        logging.warning("No valid time data in runon_time_series for plotting.")
        return

    # Grid spacing: smallest positive gap of the first event that has one (1.0 otherwise).
    dt = 1.0
    for record in runon_time_series:
        record_times = np.array(record.get('times', []), dtype=float)
        if len(record_times) <= 1:
            continue
        gaps = np.diff(record_times)
        positive_gaps = gaps[gaps > 1e-9]
        if positive_gaps.size > 0:
            dt = np.min(positive_gaps)
            break

    # Shared grid from 0 to 1.5 * simulation_time_min
    t_end = simulation_time_min * 1.5
    common_t = np.arange(0.0, t_end + dt, dt)
    if common_t.size == 0:
        logging.warning("Could not create valid common time grid for runon plot.")
        return

    def resample(record):
        """Interpolate one event's step volumes onto common_t (zeros if unusable)."""
        event_t = np.array(record.get('times', []), dtype=float)
        event_v = np.array(record.get('runon_inf_vol_step', []), dtype=float)
        if event_t.size == 0 or event_t.size != event_v.size:
            logging.warning(f"Invalid/empty time/volume data for merge '{record.get('merged_id', 'Unknown')}' in runon plot.")
            return np.zeros_like(common_t)

        # np.interp needs increasing, duplicate-free sample positions.
        order = np.argsort(event_t)
        t_ordered = event_t[order]
        v_ordered = event_v[order]
        first_occurrence = np.unique(t_ordered, return_index=True)[1]
        if len(first_occurrence) < 2:
            logging.debug(f"Skipping interpolation for merge '{record.get('merged_id', 'Unknown')}' due to insufficient unique time points.")
            return np.zeros_like(common_t)

        return np.interp(common_t, t_ordered[first_occurrence], v_ordered[first_occurrence],
                         left=0.0, right=0.0)

    stacked_steps = [resample(record) for record in runon_time_series]
    if not stacked_steps:
        logging.warning("No valid runon data found to stack for plot.")
        return

    # Sum over events, then accumulate over time.
    total_step = np.vstack(stacked_steps).sum(axis=0)
    cum_total_calculated = np.cumsum(total_step)

    # Rescale so the curve ends at the externally supplied total.
    if cum_total_calculated.size > 0:
        calculated_final_val = cum_total_calculated[-1]
    else:
        calculated_final_val = 0.0

    scale_factor = 1.0
    if abs(calculated_final_val) > 1e-9:
        scale_factor = true_total_runon / calculated_final_val
    elif abs(true_total_runon) > 1e-9:
        # Nothing to scale: the curve stays at (near) zero although a total was expected.
        logging.warning(f"Calculated final run-on volume is near zero ({calculated_final_val:.4e}), "
                        f"but true total is {true_total_runon:.4f}. Cannot scale accurately.")

    logging.info(f"Run-on Plot: True Total={true_total_runon:.4f}, Calculated Final={calculated_final_val:.4f}, Scale Factor={scale_factor:.4f}")
    cum_total_scaled = cum_total_calculated * scale_factor

    # Plot the scaled cumulative curve.
    fig, ax = plt.subplots(figsize=figsize)
    curve_label = f"Cumulative Run-on Volume (Final ≈ {true_total_runon:.2f} m³)"
    ax.plot(common_t, cum_total_scaled, linewidth=2, label=curve_label)
    ax.set_xlim(0, t_end)
    ax.set_xlabel("Time (min)")
    ax.set_ylabel("Cumulative Infiltration Volume (m³)")
    ax.set_title("Cumulative Run-on Infiltration Over Time (Scaled to Match Total)")
    ax.grid(True, linestyle="--", alpha=0.5)
    ax.legend(loc="upper left")
    plt.tight_layout()
    plt.show()

def export_impervious_area_timeseries(
    all_scenario_time_series: Dict[str, List[Dict]],
    simulation_time_min_param: float,
    start_datetime: pd.Timestamp,
    export_interval_min: float = 1.0,
    output_filename: str = "all_inlets_impervious_area_timeseries.csv") -> None:
    """
    Exports a CSV file with impervious area timeseries for all inlets,
    using absolute datetime for the time column.

    The output has one 'Time' column (``start_datetime`` plus the elapsed minutes,
    formatted as '%m/%d/%Y %I:%M:%S %p') and one
    'Inlet_<id>_Impervious_Area (m^2)' column per inlet. Each inlet's log entries are
    read through ``d['time']`` and ``d.get('inlet_impervious_area', 0.0)`` and
    interpolated onto the common timeline with ``safe_interp`` (default value 0.0).
    An inlet whose log is empty, malformed, or fails to process gets an all-NaN column.

    Parameters:
    - all_scenario_time_series: mapping inlet id -> list of per-timestep dicts.
    - simulation_time_min_param: last timeline value to export (minutes).
    - start_datetime: timestamp corresponding to minute 0.
    - export_interval_min: spacing of the exported timeline in minutes; must be > 0.
    - output_filename: CSV path; resolved with os.path.abspath before writing.

    Returns None. Problems are logged (and the final write error is also printed)
    rather than raised.
    """
    if not all_scenario_time_series:
        logging.warning("No scenario time series data provided for export.")
        return

    # A zero step makes np.arange raise ZeroDivisionError and a negative/NaN step
    # yields no timeline at all, so reject those up front.
    if not export_interval_min > 0:
        logging.warning(f"Invalid export interval ({export_interval_min}) for CSV export. Aborting.")
        return

    logging.info(f"Exporting combined impervious area timeseries with datetime format to '{output_filename}'...")

    # 1. Create a common output timeline in minutes
    common_timeline_min = np.arange(0, simulation_time_min_param + export_interval_min, export_interval_min)
    # Drop a possible final sample that overshoots the simulation end.
    common_timeline_min = common_timeline_min[common_timeline_min <= simulation_time_min_param + 1e-9]

    if common_timeline_min.size == 0:
        logging.warning("Could not create a valid common timeline for CSV export. Aborting.")
        return

    # 2. Convert the relative minutes timeline to absolute, formatted datetimes
    absolute_datetimes = start_datetime + pd.to_timedelta(common_timeline_min, unit='m')
    formatted_datetimes = absolute_datetimes.strftime('%m/%d/%Y %I:%M:%S %p')

    # Columns are collected first and the DataFrame is built once: inserting one column
    # at a time fragments the frame and triggers pandas PerformanceWarnings for many inlets.
    n_rows = common_timeline_min.size
    columns = {'Time': formatted_datetimes}

    # 3. For each inlet scenario, interpolate its impervious area onto the common timeline
    for inlet_id, scenario_ts_log in all_scenario_time_series.items():
        column_name = f'Inlet_{inlet_id}_Impervious_Area (m^2)'
        # Start from NaN; only replaced when the inlet's data is processed successfully.
        columns[column_name] = np.full(n_rows, np.nan)

        if not scenario_ts_log:
            logging.warning(f"Time series log for inlet '{inlet_id}' is empty. Skipping.")
            continue

        try:
            original_times_min = np.array([d['time'] for d in scenario_ts_log], dtype=float)
            impervious_areas = np.array([d.get('inlet_impervious_area', 0.0) for d in scenario_ts_log], dtype=float)

            if original_times_min.size == 0 or impervious_areas.size == 0 or original_times_min.size != impervious_areas.size:
                logging.warning(f"Invalid time or impervious area array for inlet '{inlet_id}'. Skipping.")
                continue

            # NOTE(review): safe_interp is not defined in this function; it must exist at
            # module scope. If it does not, the NameError is swallowed by the handler
            # below and every column is silently exported as NaN - confirm.
            interp_impervious_areas = safe_interp(
                common_timeline_min,  # Interpolate onto the numeric minute timeline
                original_times_min,
                impervious_areas,
                default_val=0.0,
                label=f"ImpervAreaInlet_{inlet_id}"
            )
            # reshape raises ValueError on a length mismatch, which keeps the NaN column.
            columns[column_name] = np.asarray(interp_impervious_areas, dtype=float).reshape(n_rows)

        except Exception as e:
            logging.error(f"Failed to process impervious area for inlet '{inlet_id}': {e}", exc_info=True)

    df_export = pd.DataFrame(columns)

    # 4. Export the combined DataFrame to CSV
    try:
        export_path = os.path.abspath(output_filename)
        df_export.to_csv(export_path, index=False, float_format='%.3f')
        logging.info(f"Successfully exported combined impervious area timeseries to {export_path}")
        print(f"Successfully exported combined impervious area timeseries to {export_path}")
    except Exception as e_export:
        logging.error(f"Failed to export impervious area timeseries CSV: {e_export}", exc_info=True)
        print(f"ERROR: Failed to export impervious area timeseries CSV. Details: {traceback.format_exc()}")

def export_total_area_timeseries(
    all_scenario_time_series: Dict[str, List[Dict]],
    simulation_time_min_param: float,
    start_datetime: pd.Timestamp,
    export_interval_min: float = 1.0,
    output_filename: str = "all_inlets_total_area_timeseries.csv") -> None:
    """
    Exports a CSV file with total connected area timeseries for all inlets,
    using absolute datetime for the time column.

    The output has one 'Time' column (``start_datetime`` plus the elapsed minutes,
    formatted as '%m/%d/%Y %I:%M:%S %p') and one 'Inlet_<id>_Total_Area (m^2)'
    column per inlet. Each inlet's log entries are read through ``d['time']`` and
    ``d.get('inlet_total_area', 0.0)`` and interpolated onto the common timeline with
    ``safe_interp`` (default value 0.0). An inlet whose log is empty, malformed, or
    fails to process gets an all-NaN column.

    Parameters:
    - all_scenario_time_series: mapping inlet id -> list of per-timestep dicts.
    - simulation_time_min_param: last timeline value to export (minutes).
    - start_datetime: timestamp corresponding to minute 0.
    - export_interval_min: spacing of the exported timeline in minutes; must be > 0.
    - output_filename: CSV path; resolved with os.path.abspath before writing.

    Returns None. Problems are logged (and the final write error is also printed)
    rather than raised.
    """
    if not all_scenario_time_series:
        logging.warning("No scenario time series data provided for total area export.")
        return

    # A zero step makes np.arange raise ZeroDivisionError and a negative/NaN step
    # yields no timeline at all, so reject those up front.
    if not export_interval_min > 0:
        logging.warning(f"Invalid export interval ({export_interval_min}) for total area CSV export. Aborting.")
        return

    logging.info(f"Exporting combined total area timeseries with datetime format to '{output_filename}'...")

    # 1. Create a common output timeline in minutes
    common_timeline_min = np.arange(0, simulation_time_min_param + export_interval_min, export_interval_min)
    # Drop a possible final sample that overshoots the simulation end.
    common_timeline_min = common_timeline_min[common_timeline_min <= simulation_time_min_param + 1e-9]

    if common_timeline_min.size == 0:
        logging.warning("Could not create a valid common timeline for total area CSV export. Aborting.")
        return

    # 2. Convert the relative minutes timeline to absolute, formatted datetimes
    absolute_datetimes = start_datetime + pd.to_timedelta(common_timeline_min, unit='m')
    formatted_datetimes = absolute_datetimes.strftime('%m/%d/%Y %I:%M:%S %p')

    # Columns are collected first and the DataFrame is built once: inserting one column
    # at a time fragments the frame and triggers pandas PerformanceWarnings for many inlets.
    n_rows = common_timeline_min.size
    columns = {'Time': formatted_datetimes}

    # 3. For each inlet scenario, interpolate its total area onto the common timeline
    for inlet_id, scenario_ts_log in all_scenario_time_series.items():
        column_name = f'Inlet_{inlet_id}_Total_Area (m^2)'
        # Start from NaN; only replaced when the inlet's data is processed successfully.
        columns[column_name] = np.full(n_rows, np.nan)

        if not scenario_ts_log:
            logging.warning(f"Time series log for inlet '{inlet_id}' is empty. Skipping total area export.")
            continue

        try:
            original_times_min = np.array([d['time'] for d in scenario_ts_log], dtype=float)
            total_areas = np.array([d.get('inlet_total_area', 0.0) for d in scenario_ts_log], dtype=float)

            if original_times_min.size == 0 or total_areas.size == 0 or original_times_min.size != total_areas.size:
                logging.warning(f"Invalid time or total area array for inlet '{inlet_id}'. Skipping.")
                continue

            # NOTE(review): safe_interp is not defined in this function; it must exist at
            # module scope. If it does not, the NameError is swallowed by the handler
            # below and every column is silently exported as NaN - confirm.
            interp_total_areas = safe_interp(
                common_timeline_min,
                original_times_min,
                total_areas,
                default_val=0.0,
                label=f"TotalAreaInlet_{inlet_id}"
            )
            # reshape raises ValueError on a length mismatch, which keeps the NaN column.
            columns[column_name] = np.asarray(interp_total_areas, dtype=float).reshape(n_rows)

        except Exception as e:
            logging.error(f"Failed to process total area for inlet '{inlet_id}': {e}", exc_info=True)

    df_export = pd.DataFrame(columns)

    # 4. Export the combined DataFrame to CSV
    try:
        export_path = os.path.abspath(output_filename)
        df_export.to_csv(export_path, index=False, float_format='%.3f')
        logging.info(f"Successfully exported combined total area timeseries to {export_path}")
        print(f"Successfully exported combined total area timeseries to {export_path}")
    except Exception as e_export:
        logging.error(f"Failed to export total area timeseries CSV: {e_export}", exc_info=True)
        print(f"ERROR: Failed to export total area timeseries CSV. Details: {traceback.format_exc()}")

def get_node_color_and_size(state_data: dict, min_cap: float, max_cap: float, min_size=500, max_size=5000) -> Tuple[str, float]:
    """
    Determines node color based on fill percentage and size based on capacity.

    Args:
        state_data: Node state dict; reads 'current_volume' and 'max_volume'
            (each defaults to 0.0 when the key is missing).
        min_cap: Capacity at or below which the node gets ``min_size``.
        max_cap: Capacity at or above which the node gets ``max_size``.
        min_size: Smallest marker size returned.
        max_size: Largest marker size returned.

    Returns:
        (color, size): a hex color string taken from the 'Blues' colormap at
        the node's fill percentage, and a marker size interpolated on a log10
        scale between ``min_size`` and ``max_size``.
    """
    volume = state_data.get('current_volume', 0.0)
    capacity = state_data.get('max_volume', 0.0)

    # Fill percentage; a (near-)zero capacity is treated as empty to avoid
    # dividing by zero.
    fill_percentage = (volume / capacity) * 100 if capacity > 1e-9 else 0
    fill_percentage = max(0, min(fill_percentage, 100))  # Clamp between 0 and 100

    # Color: Blue scale for fill percentage.
    # NOTE: matplotlib.cm.get_cmap() was removed in Matplotlib 3.9;
    # pyplot.get_cmap() is available in both old and new releases.
    norm = colors.Normalize(vmin=0, vmax=100)
    cmap = plt.get_cmap('Blues')
    color = colors.to_hex(cmap(norm(fill_percentage)))

    # Size: log scale based on capacity relative to min/max capacity.
    if capacity <= min_cap:
        size = min_size
    elif capacity >= max_cap:
        size = max_size
    elif min_cap > 0 and max_cap > min_cap:  # Avoid log(0) or division by zero
        log_min = np.log10(min_cap)
        log_max = np.log10(max_cap)
        log_cap = np.log10(max(capacity, 1e-9))
        scale_factor = (log_cap - log_min) / (log_max - log_min)
        size = min_size + (max_size - min_size) * scale_factor
        size = max(min_size, min(size, max_size))  # Clamp size
    else:
        # Fallback if the capacity bounds are unusable for a log scale.
        size = (min_size + max_size) / 2

    return color, size

def _is_well_formed_snapshot(snapshot) -> bool:
    """Return True when *snapshot* is a non-empty sequence holding at least (time_sec, state_dict)."""
    return bool(snapshot) and len(snapshot) > 1


def _snapshot_parent_links(state_dict: Optional[Dict]) -> Dict:
    """Return {node_id: parent_id} for a snapshot state ({} for a None/empty state).

    Two states have the same structure exactly when these mappings are equal:
    the mapping's keys are the node IDs, so a separate node-set comparison is
    not needed.
    """
    if not state_dict:
        return {}
    return {node_id: node_state.get('parent_id') for node_id, node_state in state_dict.items()}


def _select_snapshots_for_plots(snapshots: List[Tuple[float, Dict]],
                                structure_changes_only: bool,
                                max_plots: Optional[int]) -> List[Tuple[float, Dict]]:
    """Choose which snapshots to plot (structure changes, evenly spaced subset, or all)."""
    if structure_changes_only and len(snapshots) > 1:
        selected: List[Tuple[float, Dict]] = []
        last_kept_links: Optional[Dict] = None
        for idx, snap in enumerate(snapshots):
            if not _is_well_formed_snapshot(snap):
                logging.warning(f"Skipping invalid snapshot at index {idx} during filtering.")
                continue
            if snap[1] is None:
                logging.warning(f"Snapshot at index {idx} (time {snap[0]}) has None state_dict. Comparing as empty.")
            links = _snapshot_parent_links(snap[1])
            # The first well-formed snapshot is always kept; later ones only
            # when their node set or parent links differ from the last kept one.
            if last_kept_links is None or links != last_kept_links:
                selected.append(snap)
                last_kept_links = links
        return selected

    total = len(snapshots)
    if max_plots is not None and 0 < max_plots < total:
        # Evenly spaced over the whole run; linspace includes both endpoints
        # when max_plots >= 2 and yields just index 0 when max_plots == 1.
        indices = sorted({int(i) for i in np.linspace(0, total - 1, num=max_plots, dtype=int)})
    else:
        indices = range(total)
    return [
        snapshots[i] for i in indices
        if _is_well_formed_snapshot(snapshots[i]) and snapshots[i][1] is not None
    ]


def _estimate_rainfall_stop_sec(runoff_processor: Optional['RunoffProcessor'],
                                original_inlet_ids: List[str]) -> Optional[float]:
    """Estimate, in seconds, when rainfall stops for the first inlet in *original_inlet_ids*.

    Reads 'timeline_min', 'rain_depth_m_step' and 'dt_sec' from
    ``runoff_processor.basin_runoff_data[<inlet id>]``. Returns None when the
    stop time cannot be determined.
    """
    if not (runoff_processor and original_inlet_ids):
        logging.warning("RunoffProcessor or original_inlet_ids not available, cannot add rainfall end plot.")
        return None

    stop_sec = None
    try:
        primary_inlet = str(original_inlet_ids[0])
        inlet_rp_data = runoff_processor.basin_runoff_data.get(primary_inlet)

        if inlet_rp_data:
            rain_timeline = np.array(inlet_rp_data.get('timeline_min', []))
            rain_depths_step = np.array(inlet_rp_data.get('rain_depth_m_step', []))
            dt_seconds = inlet_rp_data.get('dt_sec')

            if rain_timeline.size > 0 and rain_depths_step.size == rain_timeline.size and dt_seconds is not None and dt_seconds > 0:
                significant_rain_threshold = 1e-7
                significant_rain_indices = np.where(rain_depths_step > significant_rain_threshold)[0]

                if significant_rain_indices.size > 0:
                    last_significant_rain_idx = significant_rain_indices[-1]
                    if last_significant_rain_idx + 1 < len(rain_timeline):
                        # Rain is considered over at the first step after the last wet one.
                        stop_min = rain_timeline[last_significant_rain_idx + 1]
                    else:
                        # Still raining on the final step: extrapolate one timestep.
                        stop_min = rain_timeline[-1] + (dt_seconds / 60.0)
                    stop_sec = stop_min * 60.0
                    logging.info(f"Rainfall for primary inlet '{primary_inlet}' effectively stops around t={stop_min:.2f} min.")
                else:
                    logging.info(f"No significant rainfall found in data for primary inlet '{primary_inlet}'.")
            else:
                logging.warning("Rainfall data for primary inlet is incomplete for determining stop time.")
        else:
            logging.warning(f"No runoff data found for primary inlet '{primary_inlet}' to determine rainfall stop time.")

        if stop_sec is None:
            logging.warning("Could not determine rainfall stop time for extra snapshot plot.")
    except Exception as e_rain_stop:
        logging.error(f"Error determining rainfall stop time: {e_rain_stop}")
    return stop_sec


def _render_drainage_snapshot_figure(forest: 'DrainageForest',
                                     snapshot: Tuple[float, Dict],
                                     title_inlet_id: str,
                                     time_series: List[Dict],
                                     global_cap_min: float,
                                     global_cap_max: float,
                                     original_inlet_ids: List[str],
                                     error_label: str,
                                     title_override: Optional[str] = None) -> None:
    """Draw one snapshot on a fresh figure, show it, and always close the figure afterwards."""
    fig, ax = plt.subplots(figsize=(14, 10))
    try:
        plot_drainage_snapshot_with_ids(
            ax=ax,
            forest=forest,
            snapshot=snapshot,
            inlet_id=title_inlet_id,
            time_series=time_series,
            global_cap_min=global_cap_min,
            global_cap_max=global_cap_max,
            original_inlet_ids=original_inlet_ids
        )
        if title_override is not None:
            ax.set_title(title_override, fontsize=12)
        plt.show()
    except Exception as plot_err:
        logging.error(f"{error_label}: {plot_err}")
        traceback.print_exc()
    finally:
        # Close the figure to free memory; this runs once per plotted snapshot.
        if plt.fignum_exists(fig.number):
            plt.close(fig)


def plot_drainage_snapshots(forest: 'DrainageForest',
                            snapshots: List[Tuple[float, Dict]],
                            inlet_id: str,
                            time_series: List[Dict],
                            global_cap_min: float,
                            global_cap_max: float,
                            original_inlet_ids: List[str],
                            output_dir: str = "drainage_snapshots",
                            structure_changes_only: bool = False,
                            max_plots: Optional[int] = None,
                            runoff_processor: Optional['RunoffProcessor'] = None):
    """
    Filters and plots drainage network snapshots based on specified criteria.
    Calls plot_drainage_snapshot_with_ids for each selected snapshot.
    Includes logic for an optional extra plot around the time rainfall stops.

    Selection rules:
      * structure_changes_only=True (and more than one snapshot): the first
        well-formed snapshot plus every later snapshot whose node set or
        parent links differ from the previously kept one. Malformed entries
        are skipped with a warning.
      * otherwise, if max_plots is a positive number smaller than the number
        of snapshots: max_plots snapshots evenly spaced over the whole list
        (first and last included when max_plots >= 2).
      * otherwise: every snapshot.
      In the last two modes snapshots whose state is None are dropped; in
      structure mode they are compared as empty and skipped at plot time.

    Args:
        forest: The DrainageForest object.
        snapshots: A list of snapshot tuples (time_sec, state_dict).
        inlet_id: Inlet ID used in plot titles when original_inlet_ids is empty.
        time_series: The simulation time series log; forwarded unchanged to
                    plot_drainage_snapshot_with_ids.
        global_cap_min: Minimum capacity found across all nodes for consistent sizing.
        global_cap_max: Maximum capacity found across all nodes for consistent sizing.
        original_inlet_ids: List of all originally selected inlet IDs for path highlighting.
                    The first entry is used for titles and for the rainfall-end lookup.
        output_dir: Directory to save snapshots (currently not used for saving in this function).
        structure_changes_only: If True, plots only when network structure changes.
        max_plots: Maximum number of plots to generate if structure_changes_only is False.
                    If None and structure_changes_only is False, all snapshots are plotted.
        runoff_processor: The RunoffProcessor instance, used for the rainfall end plot.
    """
    if not snapshots:
        logging.warning("No snapshots available to plot in plot_drainage_snapshots.")
        return

    # Snapshot filtering
    logging.info(f"Filtering {len(snapshots)} snapshots for plotting. Max_plots={max_plots}, Structure_only={structure_changes_only}")
    filtered_snapshots_to_plot = _select_snapshots_for_plots(snapshots, structure_changes_only, max_plots)
    logging.info(f"Plotting {len(filtered_snapshots_to_plot)} filtered snapshots (out of {len(snapshots)} total).")

    # The first originally selected inlet names the plots; fall back to inlet_id.
    plot_title_inlet_id = str(original_inlet_ids[0]) if original_inlet_ids else inlet_id

    # Plotting loop
    for i, snapshot_data_tuple in enumerate(filtered_snapshots_to_plot):
        time_sec_val, state_dict_val = snapshot_data_tuple[0], snapshot_data_tuple[1]
        if state_dict_val is None:
            logging.warning(f"Snapshot {i} (original time {time_sec_val/60.0:.2f} min) has None state. Skipping plot.")
            continue

        time_min_val = time_sec_val / 60.0
        print(f"\n--- Generating Plot for Filtered Snapshot {i+1}/{len(filtered_snapshots_to_plot)} (Simulation Time t={time_min_val:.2f} min) ---")
        _render_drainage_snapshot_figure(
            forest, snapshot_data_tuple, plot_title_inlet_id, time_series,
            global_cap_min, global_cap_max, original_inlet_ids,
            error_label=f"Error plotting snapshot {i} (t={time_min_val:.2f} min)"
        )

    # Extra plot for the snapshot closest to the end of rainfall
    time_rain_stops_sec_calc = _estimate_rainfall_stop_sec(runoff_processor, original_inlet_ids)

    target_snapshot_for_rain_end = None
    if time_rain_stops_sec_calc is not None:
        # Only well-formed tuples can be ranked by their timestamp.
        candidates = [s for s in snapshots if _is_well_formed_snapshot(s)]
        if candidates:
            try:
                target_snapshot_for_rain_end = min(candidates, key=lambda s_item: abs(s_item[0] - time_rain_stops_sec_calc))
                logging.info(f"Snapshot at t={target_snapshot_for_rain_end[0]/60.0:.2f} min selected as closest to rainfall end (target: {time_rain_stops_sec_calc/60.0:.2f} min).")
            except Exception as e_find_snap:
                logging.error(f"Error finding snapshot closest to rainfall end: {e_find_snap}")
                target_snapshot_for_rain_end = None

    if target_snapshot_for_rain_end is not None:
        logging.info("Generating extra plot for snapshot around rainfall end time.")
        time_sec_extra, state_dict_extra = target_snapshot_for_rain_end[0], target_snapshot_for_rain_end[1]
        if state_dict_extra is None:
            logging.warning("State dictionary for rainfall-end snapshot is None. Skipping extra plot.")
        else:
            time_min_extra = time_sec_extra / 60.0
            print(f"\n--- Generating Plot for Rainfall End Snapshot (Approx. t={time_min_extra:.2f} min) ---")
            title_str = f"Network State Approx. When Rainfall Stops (Snapshot at t={time_min_extra:.2f} min)"
            subtitle_str = f"(Rainfall stop for primary inlet '{plot_title_inlet_id}' estimated at t={time_rain_stops_sec_calc/60.0:.2f} min)"
            _render_drainage_snapshot_figure(
                forest, target_snapshot_for_rain_end, plot_title_inlet_id, time_series,
                global_cap_min, global_cap_max, original_inlet_ids,
                error_label=f"Error plotting rainfall-end snapshot (t={time_min_extra:.2f} min)",
                title_override=title_str + "\n" + subtitle_str
            )
    else:
        logging.info("Skipping extra snapshot plot as rainfall end time/snapshot was not determined or no snapshots available.")

    logging.info(f"Finished plotting drainage snapshots.")

def filter_snapshots_for_plotting(tree_snapshots_log: List[Tuple[float, Dict]],
                                   max_plots: Optional[int] = 5,
                                   structure_only: bool = False) -> List[Tuple[float, Dict]]:
    """
    Filters snapshots for plotting based on structure changes or a maximum number.

    Args:
        tree_snapshots_log: Snapshot tuples (time, state_dict), where state_dict
            maps node IDs to per-node dicts (each read via ``.get('parent_id')``).
        max_plots: Upper bound on the number of snapshots returned when
            structure_only is False. None or a non-positive value means "all".
        structure_only: If True (and there is more than one snapshot), keep the
            first well-formed snapshot plus every later one whose node set or
            parent links differ from the previously kept snapshot. max_plots is
            ignored in this mode.

    Returns:
        The selected snapshots, in their original order. Malformed entries
        (None, empty, or shorter than two elements) are never returned. In the
        max_plots / plot-all modes snapshots whose state is None are dropped
        as well; in structure mode a None state is compared as an empty network.
    """
    if not tree_snapshots_log:
        logging.warning("filter_snapshots_for_plotting: No snapshots to filter.")
        return []

    total = len(tree_snapshots_log)
    logging.info(f"Filtering {total} snapshots. Max_plots={max_plots}, Structure_only={structure_only}")

    def is_well_formed(snapshot) -> bool:
        # A usable snapshot is a non-empty sequence with at least (time, state).
        return bool(snapshot) and len(snapshot) > 1

    if structure_only and total > 1:
        filtered_snapshots = []
        last_kept_links = None
        for i, snapshot in enumerate(tree_snapshots_log):
            if not is_well_formed(snapshot):
                logging.warning(f"Skipping invalid snapshot at index {i} during filtering.")
                continue
            state = snapshot[1]
            if state is None:
                logging.warning(f"Snapshot at index {i} (time {snapshot[0]}) has None state_dict. Comparing as empty.")
                state = {}
            # The keys of this mapping are the node IDs, so comparing the
            # mapping detects both node-set and parent-link changes.
            parent_links = {k: v.get('parent_id') for k, v in state.items()}
            if last_kept_links is None or parent_links != last_kept_links:
                filtered_snapshots.append(snapshot)
                last_kept_links = parent_links
    else:
        if max_plots and 0 < max_plots < total:
            # First, last (when max_plots >= 2) and evenly spaced in between.
            indices = sorted({int(i) for i in np.linspace(0, total - 1, max_plots, dtype=int)})
        else:
            # Plot all if max_plots is None/0 or already covers every snapshot.
            indices = range(total)
        filtered_snapshots = [
            tree_snapshots_log[i] for i in indices
            if is_well_formed(tree_snapshots_log[i]) and tree_snapshots_log[i][1] is not None
        ]

    logging.info(f"Selected {len(filtered_snapshots)} snapshots for plotting.")
    return filtered_snapshots


# 7) Main Functions
def create_drainage_forest(lowest_points_file: str, basins_file: str,
                           inlets_file: str, study_area_file: str,
                           impervious_tif: str) -> Optional[DrainageForest]:
    """
    Load the input layers, validate them, and build a DrainageForest.

    Args:
        lowest_points_file: Vector file read with geopandas; must contain the
            columns 'basin_id' and 'to_basin'.
        basins_file: Vector file read with geopandas; must contain the columns
            'basin_id', 'water_volu' and 'area'.
        inlets_file: Vector file read with geopandas.
        study_area_file: Vector file read with geopandas.
        impervious_tif: Passed through unchanged to ``DrainageForest.build_forest``.

    Returns:
        The built DrainageForest, or None when any layer is empty, a required
        column is missing, the forest ends up with no nodes, or any exception
        is raised (errors are logged, not propagated).
    """
    try:
        logging.info("Loading input files for forest creation...")
        lowest_points_gdf = gpd.read_file(lowest_points_file)
        basins_gdf = gpd.read_file(basins_file)
        inlets_gdf = gpd.read_file(inlets_file)
        study_area_gdf = gpd.read_file(study_area_file)

        loaded_layers = (lowest_points_gdf, basins_gdf, inlets_gdf, study_area_gdf)
        if any(layer.empty for layer in loaded_layers):
            logging.error("One or more input shapefiles are empty.")
            return None

        req_bas_cols = ['basin_id', 'water_volu', 'area']
        if any(col not in basins_gdf.columns for col in req_bas_cols):
            logging.error(f"Basins shapefile missing required columns: {req_bas_cols}.")
            return None

        req_lp_cols = ['basin_id', 'to_basin']
        if any(col not in lowest_points_gdf.columns for col in req_lp_cols):
            logging.error(f"Lowest points shapefile missing required columns: {req_lp_cols}.")
            return None

        forest = DrainageForest()
        forest.build_forest(lowest_points_gdf, basins_gdf, inlets_gdf, study_area_gdf, impervious_tif)
        if not forest.all_nodes:
            logging.error("Drainage forest construction resulted in no nodes.")
            return None
        return forest
    except ImportError as e:
        logging.error(f"Missing required Python package: {e}.")
        return None
    except FileNotFoundError as e:
        logging.error(f"Input file not found during forest creation: {e}")
        return None
    except Exception as e:
        logging.error(f"Error creating drainage forest: {e}")
        import traceback
        traceback.print_exc()
        return None

def run_integrated_simulation(forest: DrainageForest,
                              inlet_ids: List[str],
                              basins_file: str,
                              sinks_file: str,
                              lowest_points_file: str,
                              dem_file: str,
                              rainfall_csv: str,
                              simulation_time_min: float,
                              timestep_min: float,
                              rain_unit: str = 'cm/hr',
                              P2_in: float = 4.53
                              ) -> Optional[Tuple[List[Dict], List[Dict], bool, str, Optional[IntegratedSimulation], RunoffProcessor]]:
    """
    Run runoff generation and the integrated simulation for the given inlets.

    Steps: raise every node's storage capacity to a small floor (mutating the
    forest's nodes in place), load the basin and sink layers, compute runoff
    with a RunoffProcessor, then run an IntegratedSimulation.

    Returns:
        None when ``forest`` is None. Otherwise a 6-tuple
        (time_series_log, water_balance_log, halted_prematurely, halt_reason,
        simulator, runoff_processor). On failure the two logs are None, the
        halted flag is True, and the reason string describes the failure; the
        simulator / runoff processor entries hold whatever had been created
        by then (possibly None).
    """
    if forest is None:
        logging.error("No valid DrainageForest object provided.")
        return None

    # Bound up front so the exception handler can return whatever exists so far.
    runoff_processor = None
    simulator = None
    try:
        min_cap = 1e-5
        for node in forest.all_nodes.values():
            if node.storage_capacity < min_cap:
                logging.debug(f"Adjusting storage capacity for node {node.basin_id} from {node.storage_capacity:.2e} to {min_cap:.1e}")
                node.storage_capacity = min_cap
                node.water_volume = min_cap
                node.effective_depth = node.storage_capacity / node.area if node.area > 1e-9 else 0.0

        basins_gdf = gpd.read_file(basins_file)
        sinks_gdf = gpd.read_file(sinks_file)
        # Add a string version of the basin ID to each layer unless it is already present.
        for layer in (basins_gdf, sinks_gdf):
            if 'basin_id_str' not in layer.columns:
                layer['basin_id_str'] = layer['basin_id'].astype(str)

        logging.info("Initializing Runoff Processor...")
        runoff_processor = RunoffProcessor(
            forest=forest,
            basins_gdf=basins_gdf,
            sinks_gdf=sinks_gdf,
            lowest_points_file=lowest_points_file,
            dem_file=dem_file,
            P2_in=P2_in
        )
        runoff_processor.rainfall_csv = rainfall_csv
        runoff_processor.rain_unit = rain_unit

        runoff_processor.calculate_runoff_for_all_basins(
            rainfall_csv=rainfall_csv,
            sim_time_min=simulation_time_min,
            timestep_min=timestep_min,
            rain_unit=rain_unit,
            inlet_ids=inlet_ids
        )
        if not runoff_processor.basin_runoff_data:
            logging.error("Runoff processor failed to generate data for any relevant basins.")
            return None, None, True, "Runoff calculation failed", None, runoff_processor

        logging.info("Initializing Integrated Simulator...")
        simulator = IntegratedSimulation(
            forest=forest,
            runoff_processor=runoff_processor,
            selected_inlet_ids=inlet_ids
        )
        if not simulator.initialization_successful:
            logging.error("IntegratedSimulation failed to initialize (e.g., could not derive sim time/step from runoff data).")
            return None, None, True, "Simulator initialization failed", simulator, runoff_processor

        # run_simulation() yields three logs; the snapshot log is not returned from here.
        time_series_log, water_balance_log, _ = simulator.run_simulation()

        return (time_series_log, water_balance_log, simulator.halted_prematurely,
                simulator.halt_reason, simulator, runoff_processor)
    except Exception as e:
        logging.error(f"An error occurred during the integrated simulation run: {e}")
        import traceback
        traceback.print_exc()
        return None, None, True, f"Unexpected Error: {e}", simulator, runoff_processor


# 8) Main Execution Block (using main function)

def main(P2_value: float) -> None:
    """
    Run the drainage simulation, plotting and CSV export for each selected inlet.

    Workflow:
      1. Read the absolute simulation start time from the rainfall CSV
         (needed only for the datetime-stamped CSV exports at the end).
      2. Build the master drainage forest once; it is shared by every scenario.
      3. For each entry of the module-level ``selected_inlets`` list, run
         ``run_integrated_simulation`` for that single inlet, collect its
         time-series log, recalculate the EIA time series via
         ``simulator.calculate_eia_by_arrival_at_inlet()`` and produce the plots
         whose plotting functions are defined in the notebook.
      4. After all scenarios, export the combined impervious-area and
         total-area time series to CSV.

    Besides ``P2_value`` this function reads its configuration from module-level
    globals (``rainfall_csv``, ``lowest_points_file``, ``basins_file``,
    ``inlets_file``, ``study_area_file``, ``impervious_tif``, ``sinks_file``,
    ``dem_file``, ``simulation_time_min``, ``timestep_min``, ``rain_unit``,
    ``selected_inlets``).

    Args:
        P2_value: Forwarded unchanged to ``run_integrated_simulation`` as
            ``P2_in``.

    Returns:
        None. Results are emitted as log messages, plots and CSV files.
        A failure while building the forest aborts the run; a failure inside
        one scenario is logged and the loop continues with the next inlet.
    """
    logging.info("--- Main Execution Started (Sequential Inlet Processing) ---")
    overall_start_time_main = time.time()

    # Absolute start time of the rainfall record; without it the exports at the
    # end of this function are skipped.
    simulation_start_datetime = get_rainfall_start_time(rainfall_csv)
    if simulation_start_datetime is None:
        logging.error("Could not determine simulation start time from rainfall CSV. Datetime export will not be possible.")
    else:
        logging.info(f"Simulation absolute start time determined: {simulation_start_datetime}")

    # Step 1 (Overall): Build the Master Drainage Forest (once for all scenarios)
    master_forest: Optional[DrainageForest] = None
    try:
        logging.info("Building Master Drainage Forest (once for all scenarios)...")
        master_forest = create_drainage_forest(
            lowest_points_file, basins_file, inlets_file,
            study_area_file, impervious_tif
        )
        if master_forest is None or not master_forest.all_nodes:
            raise RuntimeError("Failed to build a non-empty master drainage forest. Aborting all scenarios.")
        logging.info("Master drainage forest built successfully.")
    except Exception as e_forest:
        logging.critical(f"Failed to create master drainage forest: {e_forest}", exc_info=True)
        print(f"CRITICAL ERROR: Failed to create master drainage forest. Details: {traceback.format_exc()}", flush=True)
        return  # Nothing can be simulated without the forest.

    # Step 2: Loop through each selected inlet and run the scenario
    if not selected_inlets or not isinstance(selected_inlets, list):
        logging.warning("No inlets specified in 'selected_inlets' list or it's not a list. Nothing to simulate.")
        return

    # Per-inlet time-series logs, keyed by inlet id, consumed by the CSV exports.
    all_inlets_scenario_time_series_data: Dict[str, List[Dict]] = {}

    for i, current_inlet_id_str in enumerate(selected_inlets):
        scenario_start_time = time.time()
        logging.info(f"--- Starting Scenario {i+1}/{len(selected_inlets)} for Inlet: '{current_inlet_id_str}' ---")
        print("\n\n======================================================================")
        print(f"PROCESSING SCENARIO FOR INLET: {current_inlet_id_str}")
        print("======================================================================\n")

        # Each scenario simulates exactly one inlet.
        current_selected_inlets_for_scenario = [str(current_inlet_id_str)]

        # Per-scenario defaults; they describe a failed run until the
        # simulation result overwrites them.
        simulator: Optional[IntegratedSimulation] = None
        runoff_processor: Optional[RunoffProcessor] = None
        time_series_log: Optional[List[Dict]] = None
        water_balance_log: Optional[List[Dict]] = None
        tree_snapshots_log: Optional[List[Tuple[float, Dict]]] = None
        halted: bool = True
        halt_reason: str = "Initialization or simulation did not complete for scenario"
        merge_details: List[Dict] = []
        runon_series: List[Dict] = []
        final_outlet_id_scenario: Optional[str] = None
        originals_scenario: Set[str] = set()
        scenario_cap_min: float = 1e-3
        scenario_cap_max: float = 1e-2

        try:
            sim_results = run_integrated_simulation(
                forest=master_forest,
                inlet_ids=current_selected_inlets_for_scenario,
                basins_file=basins_file,
                sinks_file=sinks_file,
                lowest_points_file=lowest_points_file,
                dem_file=dem_file,
                rainfall_csv=rainfall_csv,
                simulation_time_min=simulation_time_min,
                timestep_min=timestep_min,
                rain_unit=rain_unit,
                P2_in=P2_value
            )

            if sim_results is None:
                raise RuntimeError(f"Integrated simulation returned None for inlet(s) {current_selected_inlets_for_scenario}.")

            time_series_log, water_balance_log, halted, halt_reason, simulator, runoff_processor = sim_results

            if time_series_log:
                all_inlets_scenario_time_series_data[current_inlet_id_str] = time_series_log
            else:
                # The branch is taken for both a missing (None) and an empty log.
                logging.warning(f"Time series log for inlet {current_inlet_id_str} is empty or None. It will not be included in the combined CSV.")

            if simulator:
                tree_snapshots_log = getattr(simulator, 'tree_snapshots', [])
                originals_scenario = getattr(simulator, 'original_basin_ids_in_sim', set())
            else:
                tree_snapshots_log = []
                originals_scenario = set()
                logging.warning(f"Simulator object is None after simulation run for {current_selected_inlets_for_scenario}.")

            if halted:
                logging.warning(f"Simulation for {current_selected_inlets_for_scenario} halted prematurely: {halt_reason}")
            else:
                logging.info(f"Simulation for {current_selected_inlets_for_scenario} completed.")

            # Capacity range over all forest nodes with non-negligible storage;
            # passed to the snapshot plot as global_cap_min / global_cap_max.
            if master_forest and master_forest.all_nodes:
                node_capacities = [node.storage_capacity for node in master_forest.all_nodes.values() if node.storage_capacity > 1e-9]
                if node_capacities:
                    scenario_cap_min = max(min(node_capacities), 1e-6)
                    scenario_cap_max = max(
                        max(node_capacities),
                        (scenario_cap_min * 10.0 if scenario_cap_min > 0 else scenario_cap_min + 1e-2)
                    )
            logging.info(f"Node capacity range for snapshot sizing: min={scenario_cap_min:.3e}, max={scenario_cap_max:.3e}")

            # Basins that never appear as upstream/downstream in a merge event.
            merged_components_scenario = set()
            if simulator:
                history = getattr(simulator, 'merged_states_history', [])
                for ev in history:
                    if isinstance(ev, dict):
                        merged_components_scenario.add(ev.get('upstream_id'))
                        merged_components_scenario.add(ev.get('downstream_id'))
            non_merged_ids_scenario = originals_scenario - merged_components_scenario

            # Final outlet: the basin whose 'merged_from_ids' contains the
            # primary inlet and that has the latest 'merge_time_min'; falls
            # back to the primary inlet itself when no such basin exists.
            primary_inlet_for_scenario = str(current_selected_inlets_for_scenario[0]) if current_selected_inlets_for_scenario else None
            final_outlet_id_scenario = primary_inlet_for_scenario
            if simulator and runoff_processor and runoff_processor.basin_runoff_data and primary_inlet_for_scenario:
                candidate_outlet = primary_inlet_for_scenario
                latest_merge_time = -1.0
                for basin_id_rp, data_rp in runoff_processor.basin_runoff_data.items():
                    merged_from_ids_rp = [str(m_id) for m_id in data_rp.get('merged_from_ids', [])]
                    if primary_inlet_for_scenario in merged_from_ids_rp:
                        current_merge_time = data_rp.get('merge_time_min', 0.0)
                        if current_merge_time > latest_merge_time:
                            latest_merge_time = current_merge_time
                            candidate_outlet = basin_id_rp
                final_outlet_id_scenario = candidate_outlet
            logging.info(f"Final outlet ID for scenario {current_selected_inlets_for_scenario} plots: {final_outlet_id_scenario}")

            if simulator and runoff_processor and time_series_log and water_balance_log:
                logging.info(f"Starting plotting section for {current_selected_inlets_for_scenario}...")

                # Recalculate EIA based on arrival time
                corrected_eia_df = simulator.calculate_eia_by_arrival_at_inlet()
                if corrected_eia_df is not None:
                    logging.info("Successfully recalculated EIA timeseries.")
                else:
                    logging.warning("Failed to recalculate EIA timeseries.")

                # Every plotting helper is optional: it is looked up in
                # globals() so that a notebook cell that was not run only
                # produces a log message instead of a NameError.
                if 'visualize_merge_framework' in globals() and callable(visualize_merge_framework):
                    merge_details, runon_series = visualize_merge_framework(
                        tree_snapshots_log if tree_snapshots_log else [],
                        simulator, runoff_processor,
                        current_selected_inlets_for_scenario,
                        non_merged_ids_scenario,
                        simulation_time_min=simulation_time_min
                    )
                else:
                    logging.warning("visualize_merge_framework function not found.")
                    merge_details, runon_series = [], []

                if 'plot_runon_time_series' in globals() and callable(plot_runon_time_series) and runon_series:
                    true_total_runon_for_plot = sum(d.get('infiltration_total', 0.0) for d in merge_details)
                    plot_runon_time_series(runon_series, true_total_runon_for_plot, simulation_time_min)
                else:
                    logging.warning("plot_runon_time_series function not found or no runon_series data.")

                if tree_snapshots_log:
                    if 'plot_drainage_snapshots' in globals() and callable(plot_drainage_snapshots):
                        plot_drainage_snapshots(
                            forest=master_forest,
                            snapshots=tree_snapshots_log,
                            inlet_id=str(current_selected_inlets_for_scenario[0]) if current_selected_inlets_for_scenario else "N/A",
                            time_series=time_series_log,
                            global_cap_min=scenario_cap_min,
                            global_cap_max=scenario_cap_max,
                            original_inlet_ids=current_selected_inlets_for_scenario,
                            structure_changes_only=True,
                            max_plots=None,
                            runoff_processor=runoff_processor
                        )
                    else:
                        logging.warning("plot_drainage_snapshots function not found.")
                else:
                    # Belongs to the snapshot check, not to the function lookup.
                    logging.warning(f"No tree snapshots for {current_selected_inlets_for_scenario} to plot.")

                if 'plot_timestep_water_balance' in globals() and callable(plot_timestep_water_balance) and water_balance_log:
                    fig_ts_wb, ax_ts_wb = plot_timestep_water_balance(
                        water_balance_log=water_balance_log,
                        runon_time_series_data=runon_series,
                        simulation_time_min_param=simulation_time_min
                    )
                    if fig_ts_wb:
                        plt.show()
                else:
                    logging.warning("'plot_timestep_water_balance' not found or log missing.")

                if 'plot_mass_balance' in globals() and callable(plot_mass_balance) and water_balance_log:
                    # Sum of 'total_direct_runoff_vol' over the scenario's
                    # original basins, ignoring missing, non-numeric or
                    # non-finite entries.
                    total_potential_runoff_for_mb = 0.0
                    if originals_scenario and runoff_processor and runoff_processor.basin_runoff_data:
                        basin_runoff_lookup = runoff_processor.basin_runoff_data
                        potential_volumes = []
                        for b_id in originals_scenario:
                            if b_id not in basin_runoff_lookup:
                                continue
                            runoff_vol = basin_runoff_lookup[b_id].get('total_direct_runoff_vol', 0.0)
                            if isinstance(runoff_vol, (int, float, np.number)) and np.isfinite(runoff_vol):
                                potential_volumes.append(float(runoff_vol))
                        total_potential_runoff_for_mb = sum(potential_volumes)
                    fig_stacked_mb, ax_stacked_mb = plot_mass_balance(
                        water_balance=water_balance_log,
                        runon_time_series_data=runon_series,
                        total_potential_runoff_input=total_potential_runoff_for_mb
                    )
                    if fig_stacked_mb:
                        plt.show()
                else:
                    logging.warning("'plot_mass_balance' not found or log missing.")

                if 'plot_simulation_results' in globals() and callable(plot_simulation_results):
                    primary_inlet_rp_data_scenario = runoff_processor.basin_runoff_data.get(str(current_selected_inlets_for_scenario[0])) if current_selected_inlets_for_scenario and runoff_processor else None
                    final_outlet_rp_data_scenario = runoff_processor.basin_runoff_data.get(final_outlet_id_scenario) if final_outlet_id_scenario and runoff_processor else None
                    plot_simulation_results(
                        time_series=time_series_log,
                        outlet_runoff_data=final_outlet_rp_data_scenario,
                        primary_inlet_runoff_data=primary_inlet_rp_data_scenario,
                        rain_unit=rain_unit,
                        simulation_time_min=simulation_time_min,
                        selected_inlets=current_selected_inlets_for_scenario,
                        recalculated_eia_df=corrected_eia_df,
                        export_discharge_csv=True
                    )
                else:
                    logging.error("'plot_simulation_results' function not found.")

                if 'plot_basin_runoff_stages' in globals() and callable(plot_basin_runoff_stages):
                    if originals_scenario and runoff_processor and runoff_processor.basin_runoff_data:
                        # Sorted so the per-basin plots appear in a stable order.
                        for b_id_orig in sorted(originals_scenario):
                            basin_data_for_stage = runoff_processor.basin_runoff_data.get(b_id_orig)
                            if basin_data_for_stage:
                                plot_basin_runoff_stages(
                                    basin_runoff_data=basin_data_for_stage,
                                    simulation_time_min=simulation_time_min,
                                    rain_unit_pref=rain_unit
                                )
                else:
                    logging.error("'plot_basin_runoff_stages' function not found.")
                logging.info(f"Plotting section for {current_selected_inlets_for_scenario} finished.")
            else:
                reason_skip = "Simulation halted for scenario" if halted else "Simulator/RunoffProcessor/Logs missing for scenario"
                logging.warning(f"Skipping plotting for {current_selected_inlets_for_scenario} because: {reason_skip}.")

        except RuntimeError as re_scenario:
            # RuntimeErrors are logged without a traceback (exc_info=False).
            logging.critical(f"A RuntimeError occurred during processing for inlet(s) {current_selected_inlets_for_scenario}: {re_scenario}", exc_info=False)
            print(f"CRITICAL RuntimeError for inlet(s) {current_selected_inlets_for_scenario}: {re_scenario}. Check logs for details.", flush=True)
        except Exception as e_scenario:
            # Scenario boundary: log and continue with the next inlet.
            logging.critical(f"An unexpected error occurred during processing for inlet(s) {current_selected_inlets_for_scenario}: {e_scenario}", exc_info=True)
            print(f"ERROR for inlet(s) {current_selected_inlets_for_scenario}: {e_scenario}. Details: {traceback.format_exc()}", flush=True)
        finally:
            elapsed_scenario = time.time() - scenario_start_time
            logging.info(f"--- Scenario for Inlet(s): {current_selected_inlets_for_scenario} finished in {elapsed_scenario:.2f}s ---")
            print(f"--- Scenario for Inlet(s): {current_selected_inlets_for_scenario} finished in {elapsed_scenario:.2f}s ---", flush=True)

    # After all scenarios, call the export functions
    if simulation_start_datetime:
        if all_inlets_scenario_time_series_data:
            # Export Impervious Area
            if 'export_impervious_area_timeseries' in globals() and callable(export_impervious_area_timeseries):
                export_impervious_area_timeseries(
                    all_scenario_time_series=all_inlets_scenario_time_series_data,
                    simulation_time_min_param=simulation_time_min,
                    start_datetime=simulation_start_datetime,
                    export_interval_min=0.5,
                    output_filename="combined_inlet_impervious_areas.csv"
                )
            else:
                logging.error("'export_impervious_area_timeseries' function not defined. Cannot export impervious area CSV.")

            # Export Total Area
            if 'export_total_area_timeseries' in globals() and callable(export_total_area_timeseries):
                export_total_area_timeseries(
                    all_scenario_time_series=all_inlets_scenario_time_series_data,
                    simulation_time_min_param=simulation_time_min,
                    start_datetime=simulation_start_datetime,
                    export_interval_min=0.5,
                    output_filename="combined_inlet_total_areas.csv"
                )
            else:
                logging.error("'export_total_area_timeseries' function not defined. Cannot export total area CSV.")
        else:
            logging.warning("No data collected from any scenario; skipping CSV exports.")
    else:
        logging.error("Cannot export with datetime format because simulation start time could not be determined.")

    overall_elapsed_main = time.time() - overall_start_time_main
    logging.info(f"--- Main execution (all scenarios) finished in {overall_elapsed_main:.2f}s ---")
    print(f"--- Main execution (all scenarios) finished in {overall_elapsed_main:.2f}s ---", flush=True)

# 9) Script entry point: calls main() when run as a script (if __name__ == "__main__":)
if __name__ == "__main__":
    # Entry point: announce start-up, run main() with the user-supplied P2
    # value, report any top-level failure, then log the total run time.
    print(f"--- PYTHON SCRIPT ENTRY POINT REACHED (Version: {sys.version}) ---", flush=True)
    script_start_time_overall = time.time()
    
    try:
        # P2_USER_VALUE is a module-level global expected to be defined in the
        # USER INPUT SECTION; if it is missing, the NameError handler below
        # reports it.
        logging.info(f"Global P2 value for Tc/Lag calculation being passed to main: {P2_USER_VALUE} inches")
        print(f"Script using P2 value: {P2_USER_VALUE} inches", flush=True)

        # main() loops over all selected inlets internally.
        main(P2_USER_VALUE)

    # The three handlers differ only in their message; each logs the error
    # with its traceback and echoes it to stdout, and none re-raises.
    except NameError as ne_global:
        logging.critical(f"A NameError occurred at the global/entry point level: {ne_global}. Ensure all required global variables are defined.", exc_info=True)
        print(f"CRITICAL NameError (Global Scope): {ne_global}. Check script's USER INPUT SECTION. Details: {traceback.format_exc()}", flush=True)
    except RuntimeError as rt_global:
        logging.critical(f"A RuntimeError occurred: {rt_global}", exc_info=True)
        print(f"CRITICAL RuntimeError: {rt_global}. Details: {traceback.format_exc()}", flush=True)
    except Exception as e_global:
        logging.critical(f"A critical error occurred at the top level of script execution: {e_global}", exc_info=True)
        print(f"CRITICAL ERROR (Global Scope): {e_global}. Details: {traceback.format_exc()}", flush=True)
    finally:
        # Runs on success and failure alike: report the total wall-clock time.
        script_end_time_overall = time.time()
        elapsed_time_overall = script_end_time_overall - script_start_time_overall
        logging.info(f"--- Total Script Execution Time: {elapsed_time_overall:.2f} seconds ---")
        print(f"--- PYTHON SCRIPT FINISHING --- Total Elapsed Time: {elapsed_time_overall:.2f} seconds", flush=True)
        # Best-effort flush/close of the logging handlers; a failure here is
        # deliberately ignored so it cannot mask the outcome reported above.
        try:
            if logging.getLogger().hasHandlers():
                logging.shutdown()
        except Exception:
            pass