In [10]:
# Skeletonise a binary channel raster (multi-pixel thickness -> 1-pixel centreline)
# Output is a georeferenced uint8 raster: 1 = skeleton, 0 = background (and preserves nodata if present)

import numpy as np
import rasterio

# ---- USER PARAMS ----
in_raster   = "/home/hector/Documents/raster_file.tif"
out_raster  = "/home/hector/Documents/skeleton.tif"

# How to binarize (edit if needed):
# - if raster is 0/1 -> (arr == 1)
# - if raster is probability 0..1 -> (arr >= 0.5)
# - if raster is 0/255 -> (arr > 0)
binarize = lambda arr: (arr > 0)
# ---------------------

try:
    from skimage.morphology import skeletonize
except ImportError as e:
    raise ImportError("This cell requires scikit-image. Install with: pip install scikit-image (or conda install scikit-image)") from e

with rasterio.open(in_raster) as src:
    prof = src.profile.copy()
    nodata = src.nodata
    arr = src.read(1, masked=True)  # masked array if nodata exists

# Build boolean mask for channels, respecting nodata
binary = np.zeros(arr.shape, dtype=bool)
valid = ~arr.mask if np.ma.isMaskedArray(arr) else np.ones(arr.shape, dtype=bool)
binary[valid] = binarize(np.asarray(arr)[valid])

# Skeletonise: 1-pixel-wide centreline
skel = skeletonize(binary).astype(np.uint8)  # 0/1

# Optionally restore nodata (so nodata isn't forced to 0)
if nodata is not None and np.any(~valid):
    skel = skel.astype(np.uint8)
    # Keep as uint8 for GIS friendliness; nodata will be set in metadata below
    skel[~valid] = 0  # keep nodata pixels as 0 in data; nodata flag handles masking in GIS

# Write output
prof.update(dtype=rasterio.uint8, count=1, compress="deflate", nodata=0 if nodata is None else nodata)
with rasterio.open(out_raster, "w", **prof) as dst:
    dst.write(skel, 1)

print(f"Done. Wrote skeleton raster: {out_raster}")


Done. Wrote skeleton raster: /home/hector/Downloads/skeleton_corridors<015.tif


In [3]:
# Connected-pixel filtering of a binary channel raster (remove small components)
# - Reads a binary raster (0/1 or 0/255 etc.)
# - Labels connected components
# - Removes components with fewer than `min_pixels` pixels
# - Writes a filtered binary raster with the same georeferencing

import numpy as np
import rasterio


# ---- USER PARAMS ----
in_raster   = "/home/hector/Documents/skeleton.tif"
out_raster  = "/home/hector/Documents/filter_skeleton_50conPix.tif"
min_pixels  = 50          # remove components smaller than this (in pixels)
connectivity = 8          # 4 or 8 (8 is usually better for channel networks)

# If your raster is not strictly 0/1, define how to binarize it:
# e.g., channels are values > 0
binarize = lambda arr: (arr > 0)

# ---------------------
try:
    from scipy.ndimage import label
except ImportError as e:
    raise ImportError("This cell requires scipy. Install it with: pip install scipy (or conda install scipy)") from e

if connectivity not in (4, 8):
    raise ValueError("connectivity must be 4 or 8")

# Structuring element defines pixel connectivity
structure = np.array([[0,1,0],
                      [1,1,1],
                      [0,1,0]], dtype=np.uint8) if connectivity == 4 else np.ones((3,3), dtype=np.uint8)

with rasterio.open(in_raster) as src:
    prof = src.profile.copy()
    nodata = src.nodata

    band = src.read(1, masked=True)  # masked array if nodata exists
    # Build boolean binary mask for channels, respecting nodata
    binary = np.zeros(band.shape, dtype=bool)
    valid = ~band.mask if np.ma.isMaskedArray(band) else np.ones(band.shape, dtype=bool)
    binary[valid] = binarize(np.asarray(band)[valid])

    # Label connected components (only on True pixels)
    labels, nlab = label(binary, structure=structure)

    if nlab == 0:
        # Nothing to filter; just write an empty binary raster (or original)
        filtered = binary.astype(np.uint8)
    else:
        # Component sizes (labels start at 1; label 0 is background)
        sizes = np.bincount(labels.ravel())
        keep = sizes >= min_pixels
        keep[0] = False  # never keep background

        filtered = keep[labels].astype(np.uint8)

    # If you want to preserve nodata (rather than forcing nodata to 0), apply it back:
    if nodata is not None and np.any(~valid):
        filtered = filtered.astype(np.uint8)
        # set nodata pixels to nodata value (commonly 0, but not always)
        filtered = filtered.astype(np.float32) if prof["dtype"] in ("float32", "float64") else filtered
        filtered[~valid] = nodata

    # Write output: keep it binary uint8 unless you explicitly need another dtype
    prof.update(dtype=rasterio.uint8, count=1, nodata=0 if nodata is None else nodata, compress="deflate")
    # If nodata exists and is not 0, you may prefer to keep prof["dtype"] instead; adjust as needed.

    with rasterio.open(out_raster, "w", **prof) as dst:
        dst.write(filtered.astype(np.uint8), 1)

print(f"Done. Wrote: {out_raster}")
print(f"Removed components with < {min_pixels} pixels (connectivity={connectivity}).")


Done. Wrote: /home/hector/Documents/Nazarij/filter_skeleton_50conPix.tif
Removed components with < 50 pixels (connectivity=8).


In [8]:
# Cell: Transform a filtered skeleton raster (0/1) into a network with:
#  - endpoint↔endpoint gap bridging (angle + distance)
#  - endpoint→vertex snapping (attach dangling endpoints to nearby line vertices)
#  - degree-2 contraction (analysis network)
#  - optional parallel/near-parallel connector creation (with proper node insertion)
#  - optional directionality from a slope-direction/aspect raster
#  - optional per-edge cost extraction from a cost raster (sum of traversed cells)


import os
import math
import numpy as np
import rasterio
import networkx as nx

# ---- USER PARAMS ------------------------------------------------------------
in_skeleton = "/home/hector/Documents/filter_skeleton_50conPix.tif"
out_gpkg    = "/home/hector/Documents/channels_50conPix_network.gpkg"

# Geometry clean-up (map units)
simplify_tol = None          # if None, auto ~ 0.5 pixel
smooth_iters = 2             # 0 disables smoothing; 1-3 typical
smooth_then_simplify = True  # re-simplify after smoothing

# ---- Gap filling: endpoint -> endpoint (done post-contraction) ----
bridge_endpoints = True
max_gap_dist = 900.0          # max distance to bridge
max_angle_deg = 65.0          # endpoints must face each other within this angle
direction_step = 3            # vertices ahead to estimate local direction
max_bridges_per_endpoint = 1
max_bridge_rounds = 2         # run bridging in multiple rounds (helps when topology changes)

# ---- Gap filling: endpoint -> vertex snapping (endpoint to an *existing vertex* on another line) ----
bridge_to_vertices = True
max_snap_dist = 600.0          # max distance endpoint->vertex
max_snap_angle_deg = 45.0      # endpoint extension must align within this angle
max_target_angle_deg = 90.0    # how "clean" the attachment is to the target line at that vertex. 90.0 = T-type attachment
max_vertex_snaps_per_endpoint = 1
vertex_candidate_stride = 1    # check every vertex (1), or subsample (e.g., 2, 3) if huge dataset

# ---- Optional directionality from a slope/aspect/flow-direction raster ----
slope_dir_raster = None        # e.g. "/path/to/aspect.tif" or None to disable
slope_dir_convention = "cw_from_north"  # "cw_from_north" or "ccw_from_east"
slope_dir_is_downhill = True   # True if raster indicates downhill direction; False if uphill

# ---- Optional costs per edge from a cost raster ----
cost_raster = None             # e.g. "/path/to/cost.tif" or None to disable
cost_nodata_to_nan = True
cost_sampling_step = None      # if None, ~0.5 pixel; else spacing in map units for sampling points
cost_unique_cells = True       # sum unique traversed cells (recommended)

# ---- Join closely running (parallel / near-parallel) lines by adding connectors ----
join_parallel_lines = False
parallel_max_dist = 40.0       # max separation between lines to consider connecting
parallel_max_angle_deg = 15.0  # max angular difference between overall directions (mod 180)
parallel_min_length = 50.0     # ignore very short edges
max_parallel_connectors = 200  # safety cap
# ----------------------------------------------------------------------------

# deps
try:
    import sknw  # pip install sknw
except ImportError as e:
    raise ImportError("Install sknw with: pip install sknw") from e

try:
    import geopandas as gpd
    from shapely.geometry import Point, LineString
    from shapely.strtree import STRtree
    from shapely.ops import substring, nearest_points
except ImportError as e:
    raise ImportError("Install geopandas + shapely with: pip install geopandas shapely") from e


# ---------------------- numeric helpers -------------------------------------
def angle_deg(v1, v2):
    v1 = np.asarray(v1, dtype=float); v2 = np.asarray(v2, dtype=float)
    n1 = np.linalg.norm(v1); n2 = np.linalg.norm(v2)
    if n1 == 0 or n2 == 0:
        return 180.0
    cosang = np.clip(np.dot(v1, v2) / (n1 * n2), -1.0, 1.0)
    return float(np.degrees(np.arccos(cosang)))


def angle_diff_180(a_deg, b_deg):
    """Smallest angular difference modulo 180 degrees (for undirected parallel comparison)."""
    d = abs((a_deg - b_deg) % 180.0)
    return min(d, 180.0 - d)


def bearing_cw_from_north(p0, p1):
    """Bearing degrees clockwise from North for (x,y) with +y = North."""
    p0 = np.asarray(p0, float); p1 = np.asarray(p1, float)
    dx = p1[0] - p0[0]
    dy = p1[1] - p0[1]
    ang = np.degrees(np.arctan2(dx, dy))  # atan2(x,y) => cw from north
    return float((ang + 360.0) % 360.0)


def convert_slope_dir_to_cw_from_north(val_deg, convention):
    """Convert raster direction to cw-from-north degrees."""
    if not np.isfinite(val_deg):
        return np.nan
    v = float(val_deg)
    if convention == "cw_from_north":
        return (v % 360.0)
    elif convention == "ccw_from_east":
        # 0=east ccw -> 90=north; convert to cw-from-north: cw = (90 - ccw_east) mod 360
        return float((90.0 - v) % 360.0)
    else:
        raise ValueError("slope_dir_convention must be 'cw_from_north' or 'ccw_from_east'")


def pixel_size_from_transform(transform):
    px = np.hypot(transform.a, transform.b)
    py = np.hypot(transform.d, transform.e)
    return float((px + py) / 2.0)


# ---------------------- raster/geometry helpers -----------------------------
def pix_to_xy(transform, rc):
    rows = rc[:, 0].astype(int)
    cols = rc[:, 1].astype(int)
    xs, ys = rasterio.transform.xy(transform, rows, cols, offset="center")
    return np.column_stack([np.asarray(xs), np.asarray(ys)])


def node_xy(G, n, transform):
    rc = np.array(G.nodes[n]["o"], dtype=float).reshape(1, 2)
    return pix_to_xy(transform, rc)[0]


def snap_edge_endpoints_to_nodes(G, u, v, k, transform):
    """
    Edge coords oriented u->v, and endpoints snapped EXACTLY to node coords.
    """
    pts_rc = G.edges[u, v, k]["pts"]
    xy = pix_to_xy(transform, pts_rc)

    xu = node_xy(G, u, transform)
    xv = node_xy(G, v, transform)

    # orient so start closer to u
    if np.linalg.norm(xy[0] - xu) <= np.linalg.norm(xy[-1] - xu):
        xy_oriented = xy.copy()
    else:
        xy_oriented = xy[::-1].copy()

    xy_oriented[0]  = xu
    xy_oriented[-1] = xv
    return xy_oriented


def chaikin_smooth(coords, n_iter=2):
    """Chaikin smoothing, keeping endpoints."""
    if n_iter <= 0 or len(coords) < 3:
        return coords
    out = coords
    for _ in range(n_iter):
        P = out
        Q = 0.75 * P[:-1] + 0.25 * P[1:]
        R = 0.25 * P[:-1] + 0.75 * P[1:]
        out2 = np.vstack([P[0], np.column_stack([Q, R]).reshape(-1, 2), P[-1]])
        _, idx = np.unique(out2, axis=0, return_index=True)
        out = out2[np.sort(idx)]
        if len(out) < 3:
            break
    return out


def simplify_coords(coords, tol):
    """Shapely simplify; preserve endpoints."""
    if tol is None or tol <= 0 or len(coords) < 3:
        return coords
    ls = LineString(coords)
    ls2 = ls.simplify(float(tol), preserve_topology=True)
    c = np.asarray(ls2.coords)
    c[0] = coords[0]
    c[-1] = coords[-1]
    return c


def split_linestring_at_point(line, pt, eps=1e-9):
    """
    Split LineString into two at projection of pt onto line.
    Returns (seg1, seg2, at_endpoint_bool).
    """
    if line.is_empty or line.length == 0:
        return line, None, True
    d = float(line.project(pt))
    L = float(line.length)
    if d <= eps or d >= L - eps:
        return line, None, True
    seg1 = substring(line, 0.0, d)
    seg2 = substring(line, d, L)
    return seg1, seg2, False


def endpoint_direction_from_linestring(endpoint_xy, line_coords, step=3):
    """
    line_coords is oriented so that coords[0] is endpoint (or very close).
    returns vector from endpoint into the line interior.
    """
    if len(line_coords) < 2:
        return np.array([0.0, 0.0])
    i1 = min(int(step), len(line_coords) - 1)
    return np.asarray(line_coords[i1], float) - np.asarray(line_coords[0], float)


def raster_value_at_xy(src, xy):
    try:
        val = next(src.sample([(float(xy[0]), float(xy[1]))]))[0]
    except Exception:
        return np.nan
    if val is None:
        return np.nan
    try:
        if np.ma.isMaskedArray(val) and val.mask:
            return np.nan
    except Exception:
        pass
    return float(val)


def cost_sum_along_line(line, cost_arr, transform, nodata, step, nodata_to_nan=True, unique_cells=True):
    """
    Approximate sum of raster cost along a line.
    - If unique_cells=True: sample points, map to cells, sum unique traversed cells.
    - Else: sum values at sample points (can double count cells).
    """
    if line.is_empty or line.length == 0:
        return 0.0

    L = float(line.length)
    n = max(2, int(np.ceil(L / float(step))) + 1)
    ds = np.linspace(ளம் := 0.0, L, n)  # keep deterministic

    pts = [line.interpolate(float(d)) for d in ds]
    xs = np.array([p.x for p in pts], dtype=float)
    ys = np.array([p.y for p in pts], dtype=float)

    rows, cols = rasterio.transform.rowcol(transform, xs, ys)
    rc = np.column_stack([rows, cols]).astype(int)

    h, w = cost_arr.shape
    ok = (rc[:, 0] >= 0) & (rc[:, 0] < h) & (rc[:, 1] >= 0) & (rc[:, 1] < w)
    rc = rc[ok]
    if rc.size == 0:
        return 0.0

    if unique_cells:
        rc = np.unique(rc, axis=0)

    vals = cost_arr[rc[:, 0], rc[:, 1]].astype(float)

    if nodata is not None:
        m = (vals == nodata)
        if m.any():
            if nodata_to_nan:
                vals[m] = np.nan
            else:
                vals[m] = 0.0

    vals = vals[np.isfinite(vals)]
    if vals.size == 0:
        return 0.0
    return float(np.sum(vals))


# ---------------------- export from sknw graph ------------------------------
def build_nodes_edges_gdfs(G, transform, crs, simplify_tol, smooth_iters):
    # nodes
    node_rows = []
    node_xy_cache = {}
    for n in G.nodes:
        xy = node_xy(G, n, transform)
        node_xy_cache[n] = xy
        node_rows.append({
            "node": int(n),
            "degree": int(G.degree[n]),
            "x": float(xy[0]),
            "y": float(xy[1]),
            "geometry": Point(xy)
        })
    nodes_gdf = gpd.GeoDataFrame(node_rows, geometry="geometry", crs=crs)

    # edges
    edge_rows = []
    for u, v, k, data in G.edges(keys=True, data=True):
        xy = snap_edge_endpoints_to_nodes(G, u, v, k, transform)

        xy = simplify_coords(xy, simplify_tol)
        if smooth_iters > 0:
            xy = chaikin_smooth(xy, n_iter=smooth_iters)
            if smooth_then_simplify:
                xy = simplify_coords(xy, simplify_tol)

        geom = LineString(xy)
        edge_rows.append({
            "u": int(u), "v": int(v), "k": int(k),
            "n_vert": int(len(xy)),
            "length": float(geom.length),
            "is_bridge": bool(data.get("is_bridge", False)),
            "bridge_type": str(data.get("bridge_type", "")) if data.get("is_bridge", False) else "",
            "geometry": geom
        })
    edges_gdf = gpd.GeoDataFrame(edge_rows, geometry="geometry", crs=crs)
    return nodes_gdf, edges_gdf


# ---------------------- degree-2 contraction --------------------------------
def contract_degree2_graph(G, nodes_gdf, edges_gdf):
    """
    Contract degree-2 nodes by merging chains into single edges between 'kept' nodes.
    """
    node_geom = {int(r.node): r.geometry for r in nodes_gdf.itertuples(index=False)}
    node_xy_  = {int(r.node): np.array([float(r.x), float(r.y)]) for r in nodes_gdf.itertuples(index=False)}
    deg = dict(G.degree)

    keep = {int(n) for n in G.nodes if deg[n] != 2}

    # pure cycles: keep one node per all-degree-2 component
    for comp in nx.connected_components(nx.Graph(G)):
        comp = {int(n) for n in comp}
        if comp and all(deg[n] == 2 for n in comp):
            keep.add(min(comp))

    coords_map = {}
    for r in edges_gdf.itertuples(index=False):
        coords_map[(int(r.u), int(r.v), int(r.k))] = np.asarray(r.geometry.coords, dtype=float)

    def norm_eid(a, b, k):
        a = int(a); b = int(b); k = int(k)
        return (a, b, k) if a <= b else (b, a, k)

    def get_seg_coords(a, b, k):
        a = int(a); b = int(b); k = int(k)
        if (a, b, k) in coords_map:
            return coords_map[(a, b, k)]
        if (b, a, k) in coords_map:
            return coords_map[(b, a, k)][::-1]
        return np.vstack([node_xy_[a], node_xy_[b]])

    visited = set()
    out_edges = []

    for s in sorted(keep):
        for nbr, keydict in G[s].items():
            for k in keydict.keys():
                eid = norm_eid(s, nbr, k)
                if eid in visited:
                    continue

                prev = int(s)
                curr = int(nbr)
                kk   = int(k)

                chain_coords = list(get_seg_coords(prev, curr, kk))
                visited.add(eid)
                chain_n_segs = 1

                while curr not in keep and deg[curr] == 2:
                    found = False
                    for nbr2, keydict2 in G[curr].items():
                        nbr2 = int(nbr2)
                        if nbr2 == prev:
                            continue
                        for k2 in keydict2.keys():
                            k2 = int(k2)
                            eid2 = norm_eid(curr, nbr2, k2)
                            if eid2 in visited:
                                continue
                            seg = list(get_seg_coords(curr, nbr2, k2))
                            chain_coords.extend(seg[1:])
                            visited.add(eid2)
                            chain_n_segs += 1
                            prev, curr = curr, nbr2
                            found = True
                            break
                        if found:
                            break
                    if not found:
                        break

                t = int(curr)
                if len(chain_coords) >= 2:
                    geom = LineString(chain_coords)
                    out_edges.append({
                        "u": int(s),
                        "v": int(t),
                        "n_segs": int(chain_n_segs),
                        "length": float(geom.length),
                        "is_bridge": False,
                        "bridge_type": "",
                        "geometry": geom
                    })

    edges2_gdf = gpd.GeoDataFrame(out_edges, geometry="geometry", crs=edges_gdf.crs)

    # degrees in contracted graph
    H = nx.Graph()
    for r in edges2_gdf.itertuples(index=False):
        H.add_edge(int(r.u), int(r.v))

    node_rows = []
    for n in sorted(keep):
        if n not in node_geom:
            continue
        xy = node_xy_[n]
        node_rows.append({
            "node": int(n),
            "degree": int(H.degree[n]) if n in H else 0,
            "x": float(xy[0]),
            "y": float(xy[1]),
            "geometry": node_geom[n]
        })
    nodes2_gdf = gpd.GeoDataFrame(node_rows, geometry="geometry", crs=nodes_gdf.crs)
    return nodes2_gdf, edges2_gdf


# ---------------------- mutable network utilities ---------------------------
def recompute_node_degrees(nodes, edges):
    deg = {int(n): 0 for n in nodes.keys()}
    for e in edges:
        deg[int(e["u"])] = deg.get(int(e["u"]), 0) + 1
        deg[int(e["v"])] = deg.get(int(e["v"]), 0) + 1
    return deg


def build_node_dict(nodes_gdf):
    nodes = {}
    for r in nodes_gdf.itertuples(index=False):
        nodes[int(r.node)] = np.array([float(r.x), float(r.y)], dtype=float)
    return nodes


def make_or_get_node_at_point(nodes, pt, tol=1e-6):
    """
    Create/reuse node at pt; reuse if within tol of an existing node.
    """
    xy = np.array([float(pt.x), float(pt.y)], dtype=float)
    # quick reuse: exact match by rounding
    key = (round(xy[0], 6), round(xy[1], 6))
    # map may not exist; build lazily
    return xy, key


def build_rounding_index(nodes, nd=6):
    idx = {}
    for nid, xy in nodes.items():
        idx[(round(float(xy[0]), nd), round(float(xy[1]), nd))] = int(nid)
    return idx


def ensure_node(nodes, idx, xy, next_id):
    key = (round(float(xy[0]), 6), round(float(xy[1]), 6))
    if key in idx:
        return idx[key], next_id
    nid = int(next_id)
    next_id += 1
    nodes[nid] = np.array([float(xy[0]), float(xy[1])], dtype=float)
    idx[key] = nid
    return nid, next_id


def edge_coords_oriented_from_node(edge_geom, node_xy):
    coords = np.asarray(edge_geom.coords, dtype=float)
    if np.linalg.norm(coords[0] - node_xy) <= np.linalg.norm(coords[-1] - node_xy):
        return coords
    return coords[::-1].copy()


def edge_bearing_mod180(edge_geom):
    coords = np.asarray(edge_geom.coords, dtype=float)
    b = bearing_cw_from_north(coords[0], coords[-1])
    return float(b % 180.0)


def split_edge_at_point(nodes, idx, edges, edge_i, pt, next_node_id,
                        cost_ctx=None, parent_cost=None, eps=1e-9):
    """
    Split edges[edge_i] at pt (projected), inserting a node if interior.
    Returns: (node_id_at_split, next_node_id, did_split_bool)
    """
    e = edges[edge_i]
    geom = e["geometry"]
    seg1, seg2, at_endpoint = split_linestring_at_point(geom, pt, eps=eps)

    if at_endpoint or seg2 is None:
        # snap to closest endpoint node id (existing)
        xy = np.array([pt.x, pt.y], float)
        uxy = nodes[int(e["u"])]
        vxy = nodes[int(e["v"])]
        if np.linalg.norm(uxy - xy) <= np.linalg.norm(vxy - xy):
            return int(e["u"]), next_node_id, False
        return int(e["v"]), next_node_id, False

    # interior split: create/reuse node at split point
    split_xy = np.array([pt.x, pt.y], dtype=float)
    nid, next_node_id = ensure_node(nodes, idx, split_xy, next_node_id)

    # replace edge with seg1 and append seg2
    u = int(e["u"]); v = int(e["v"])
    attrs = {k: e[k] for k in e.keys() if k not in ("geometry", "u", "v", "length", "cost_sum")}
    # first piece
    e1 = {
        **attrs,
        "u": u, "v": nid,
        "geometry": seg1,
        "length": float(seg1.length),
    }
    # second piece
    e2 = {
        **attrs,
        "u": nid, "v": v,
        "geometry": seg2,
        "length": float(seg2.length),
    }

    if cost_ctx is not None:
        e1["cost_sum"] = cost_sum_along_line(seg1, cost_ctx["arr"], cost_ctx["transform"], cost_ctx["nodata"],
                                             cost_ctx["step"], nodata_to_nan=cost_ctx["nodata_to_nan"],
                                             unique_cells=cost_ctx["unique_cells"])
        e2["cost_sum"] = cost_sum_along_line(seg2, cost_ctx["arr"], cost_ctx["transform"], cost_ctx["nodata"],
                                             cost_ctx["step"], nodata_to_nan=cost_ctx["nodata_to_nan"],
                                             unique_cells=cost_ctx["unique_cells"])
    elif "cost_sum" in e:
        # if you had a cost but no ctx, carry proportionally (fallback)
        pc = float(e.get("cost_sum", 0.0))
        L = float(geom.length) if geom.length else 1.0
        e1["cost_sum"] = pc * (float(seg1.length) / L)
        e2["cost_sum"] = pc * (float(seg2.length) / L)

    edges[edge_i] = e1
    edges.append(e2)
    return int(nid), next_node_id, True


# ---------------------- post-contraction: endpoint↔endpoint bridging ----------
def build_endpoint_directions(nodes, edges, direction_step=3):
    """
    Compute for degree-1 nodes:
      - outward direction vector (pointing OUT of the endpoint) = - (into-line vector)
    Returns: endpoints list, dir_map[endpoint_id] = outward_vector (not normalized)
    """
    deg = {nid: 0 for nid in nodes.keys()}
    for e in edges:
        deg[int(e["u"])] = deg.get(int(e["u"]), 0) + 1
        deg[int(e["v"])] = deg.get(int(e["v"]), 0) + 1

    endpoints = [int(n) for n, d in deg.items() if d == 1]
    dir_map = {}

    # build adjacency: endpoint -> (edge index)
    incident = {int(n): [] for n in nodes.keys()}
    for i, e in enumerate(edges):
        incident[int(e["u"])].append(i)
        incident[int(e["v"])].append(i)

    for n in endpoints:
        if len(incident[n]) != 1:
            continue
        ei = incident[n][0]
        e = edges[ei]
        nxy = nodes[n]
        coords = edge_coords_oriented_from_node(e["geometry"], nxy)
        d_in = endpoint_direction_from_linestring(nxy, coords, step=direction_step)
        d_out = -d_in
        dir_map[n] = d_out

    return endpoints, dir_map


def bridge_endpoints_round(nodes, edges, max_gap_dist, max_angle_deg, direction_step,
                           max_bridges_per_endpoint, bridge_type="endpoint_endpoint"):
    endpoints, out_dir = build_endpoint_directions(nodes, edges, direction_step=direction_step)
    used = {n: 0 for n in endpoints}
    added = 0

    # naive O(E^2) over endpoints; ok for moderate endpoint counts
    for a in endpoints:
        if used[a] >= max_bridges_per_endpoint:
            continue
        if a not in out_dir:
            continue
        xa = nodes[a]
        da = out_dir[a]

        best_b = None
        best_score = None

        for b in endpoints:
            if b == a:
                continue
            if used[b] >= max_bridges_per_endpoint:
                continue
            if b not in out_dir:
                continue

            xb = nodes[b]
            ab = xb - xa
            dist = float(np.linalg.norm(ab))
            if dist == 0 or dist > float(max_gap_dist):
                continue

            ang_a = angle_deg(da, ab)
            db = out_dir[b]
            ang_b = angle_deg(db, -ab)

            if ang_a <= float(max_angle_deg) and ang_b <= float(max_angle_deg):
                score = dist + 0.2 * (ang_a + ang_b)
                if best_score is None or score < best_score:
                    best_score = score
                    best_b = b

        if best_b is not None:
            b = best_b
            geom = LineString([tuple(nodes[a]), tuple(nodes[b])])
            edges.append({
                "u": int(a),
                "v": int(b),
                "length": float(geom.length),
                "n_segs": 1,
                "is_bridge": True,
                "bridge_type": str(bridge_type),
                "geometry": geom
            })
            used[a] += 1
            used[b] += 1
            added += 1

    return added


# ---------------------- post-contraction: endpoint→vertex snapping -----------
def build_vertex_index_for_edges(edges, stride=1):
    """
    Build STRtree over interior vertices of edges.
    Returns: (tree, meta, sig_to_i)
      meta[i] = {"edge_i": int, "vidx": int, "xy": np.array, "tangent": np.array}
      sig_to_i maps (round(x,6), round(y,6)) -> meta index (for Shapely 1.x geometry-return cases)
    """
    pts = []
    meta = []
    stride = max(1, int(stride))

    for ei, e in enumerate(edges):
        geom = e["geometry"]
        coords = np.asarray(geom.coords, dtype=float)
        if len(coords) < 3:
            continue
        for j in range(1, len(coords) - 1, stride):
            p = coords[j]
            tan = coords[j + 1] - coords[j - 1]
            g = Point(float(p[0]), float(p[1]))
            pts.append(g)
            meta.append({"edge_i": int(ei), "vidx": int(j), "xy": p.copy(), "tangent": tan.copy()})

    if not pts:
        return None, [], {}

    tree = STRtree(pts)
    sig_to_i = {(round(p.x, 6), round(p.y, 6)): i for i, p in enumerate(pts)}
    return tree, meta, sig_to_i


def snap_endpoints_to_vertices_round(nodes, edges, next_node_id, cost_ctx,
                                    max_snap_dist, max_snap_angle_deg, max_target_angle_deg,
                                    direction_step, max_snaps_per_endpoint, vertex_stride):
    """
    Endpoint -> interior vertex snapping with Shapely 1/2 STRtree compatibility.
    """
    endpoints, out_dir = build_endpoint_directions(nodes, edges, direction_step=direction_step)

    tree, meta, sig_to_i = build_vertex_index_for_edges(edges, stride=vertex_stride)
    if tree is None:
        return 0, next_node_id

    idx = build_rounding_index(nodes, nd=6)
    used = {n: 0 for n in endpoints}
    added = 0

    for a in endpoints:
        if used[a] >= int(max_snaps_per_endpoint):
            continue
        if a not in out_dir:
            continue

        xa = nodes[a]
        da = out_dir[a]

        # IMPORTANT: initialize for each endpoint
        best = None
        best_score = None

        search_geom = Point(float(xa[0]), float(xa[1])).buffer(float(max_snap_dist))
        hits = tree.query(search_geom)

        for h in hits:
            # Shapely 2 often returns integer indices; Shapely 1 returns geometries
            if isinstance(h, (int, np.integer)):
                mi = int(h)
            else:
                mi = sig_to_i.get((round(h.x, 6), round(h.y, 6)), None)

            if mi is None or mi < 0 or mi >= len(meta):
                continue

            m = meta[mi]
            ei = int(m["edge_i"])
            vxy = np.asarray(m["xy"], dtype=float)

            ab = vxy - xa
            dist = float(np.linalg.norm(ab))
            if dist == 0 or dist > float(max_snap_dist):
                continue

            ang_ext = angle_deg(da, ab)
            if ang_ext > float(max_snap_angle_deg):
                continue

            tan = np.asarray(m["tangent"], dtype=float)
            ang_target = min(angle_deg(tan, -ab), angle_deg(-tan, -ab))
            if ang_target > float(max_target_angle_deg):
                continue

            score = dist + 0.2 * ang_ext + 0.1 * ang_target
            if best_score is None or score < best_score:
                best_score = score
                best = (ei, Point(float(vxy[0]), float(vxy[1])))

        if best is None:
            continue

        target_ei, vpt = best

        parent_cost = edges[target_ei].get("cost_sum", None)
        nid, next_node_id, did_split = split_edge_at_point(
            nodes, idx, edges, target_ei, vpt, next_node_id,
            cost_ctx=cost_ctx, parent_cost=parent_cost
        )

        geom = LineString([tuple(nodes[a]), tuple(nodes[nid])])
        conn = {
            "u": int(a),
            "v": int(nid),
            "length": float(geom.length),
            "n_segs": 1,
            "is_bridge": True,
            "bridge_type": "endpoint_vertex",
            "geometry": geom
        }
        if cost_ctx is not None:
            conn["cost_sum"] = cost_sum_along_line(
                geom, cost_ctx["arr"], cost_ctx["transform"], cost_ctx["nodata"], cost_ctx["step"],
                nodata_to_nan=cost_ctx["nodata_to_nan"], unique_cells=cost_ctx["unique_cells"]
            )

        edges.append(conn)
        used[a] += 1
        added += 1

    return added, next_node_id


# ---------------------- join parallel lines by adding connectors -------------
def _edge_signature(e):
    """Stable-ish signature for STRtree geometry->edge lookup (handles copies)."""
    g = e["geometry"]
    c = np.asarray(g.coords, float)
    return (
        round(c[0,0], 6), round(c[0,1], 6),
        round(c[-1,0], 6), round(c[-1,1], 6),
        round(float(g.length), 6)
    )


def _snap_edge_endpoints_to_node_coords(nodes, e):
    """Force edge geometry endpoints to coincide with node coordinates of e['u'], e['v']."""
    g = e["geometry"]
    coords = np.asarray(g.coords, float)
    uxy = nodes[int(e["u"])]
    vxy = nodes[int(e["v"])]

    # decide which end is u by proximity
    if np.linalg.norm(coords[0] - uxy) <= np.linalg.norm(coords[-1] - uxy):
        coords[0] = uxy
        coords[-1] = vxy
    else:
        coords[0] = vxy
        coords[-1] = uxy
        # also swap u/v to keep consistent orientation with geometry (optional)
        e["u"], e["v"] = int(e["v"]), int(e["u"])

    e["geometry"] = LineString(coords)
    e["length"] = float(e["geometry"].length)


def _reroute_incident_edges(nodes, edges, old_node, new_node, exclude_edge_idx=None, cost_ctx=None):
    """
    Rewire all edges incident to old_node to use new_node instead, updating geometries so endpoints coincide.
    """
    old_node = int(old_node); new_node = int(new_node)
    old_xy = nodes[old_node]
    new_xy = nodes[new_node]

    for i, e in enumerate(edges):
        if e is None or e.get("_deleted", False):
            continue
        if exclude_edge_idx is not None and i == exclude_edge_idx:
            continue

        touched = False
        if int(e["u"]) == old_node:
            e["u"] = new_node
            touched = True
        if int(e["v"]) == old_node:
            e["v"] = new_node
            touched = True
        if not touched:
            continue

        # update geometry endpoint closest to old_xy to new_xy
        coords = np.asarray(e["geometry"].coords, float)
        d0 = np.linalg.norm(coords[0] - old_xy)
        d1 = np.linalg.norm(coords[-1] - old_xy)
        if d0 <= d1:
            coords[0] = new_xy
        else:
            coords[-1] = new_xy
        e["geometry"] = LineString(coords)
        e["length"] = float(e["geometry"].length)

        # if costs are enabled, recompute edge cost (safe; preserves correctness)
        if cost_ctx is not None:
            e["cost_sum"] = cost_sum_along_line(
                e["geometry"], cost_ctx["arr"], cost_ctx["transform"], cost_ctx["nodata"], cost_ctx["step"],
                nodata_to_nan=cost_ctx["nodata_to_nan"], unique_cells=cost_ctx["unique_cells"]
            )


def _endpoint_overlap_ratio_on_target(target_line, other_line):
    """
    How much of other_line is 'covered' along target_line by projecting other endpoints onto target.
    Used to avoid merging lines that are just locally close.
    """
    if target_line.length == 0 or other_line.length == 0:
        return 0.0
    c = np.asarray(other_line.coords, float)
    p0 = Point(float(c[0,0]), float(c[0,1]))
    p1 = Point(float(c[-1,0]), float(c[-1,1]))
    d0 = float(target_line.project(p0))
    d1 = float(target_line.project(p1))
    span = abs(d1 - d0)
    return float(span / float(other_line.length))


def merge_parallel_duplicate_edges(nodes, edges, next_node_id, cost_ctx,
                                  parallel_max_dist, parallel_max_angle_deg,
                                  parallel_min_length, max_merges,
                                  overlap_ratio_min=0.60):
    """
    True merge of near-parallel duplicate edges:
      - identify candidate edge pairs (near + parallel + sufficient overlap)
      - choose a keeper edge (longer)
      - reroute all topology incident to the duplicate's endpoints onto the keeper by:
          * splitting keeper at nearest points to duplicate endpoints (node inserted)
          * rewiring incident edges from duplicate endpoints to those new nodes
      - delete the duplicate edge and drop orphan nodes

    Costs:
      - if cost_ctx exists: duplicates have cost_sum; keeper gets averaged cost density
      - if keeper is split during merging, propagate averaged cost by segment length
    """
    # Build eligible list
    eligible = [i for i, e in enumerate(edges)
                if e is not None and not e.get("_deleted", False)
                and float(e.get("length", e["geometry"].length)) >= float(parallel_min_length)]

    if not eligible:
        return 0, next_node_id

    geoms = [edges[i]["geometry"] for i in eligible]
    tree = STRtree(geoms)

    # map signature -> edge index (works even if STRtree returns geometry copies)
    sig_to_eidx = {}
    for i in eligible:
        sig_to_eidx[_edge_signature(edges[i])] = i

    idx_round = build_rounding_index(nodes, nd=6)  # for node reuse in split_edge_at_point

    merges_done = 0
    used_pairs = set()

    for base_idx, ei in enumerate(eligible):
        if merges_done >= int(max_merges):
            break
        e1 = edges[ei]
        if e1 is None or e1.get("_deleted", False):
            continue

        g1 = e1["geometry"]
        if float(g1.length) < float(parallel_min_length):
            continue
        b1 = float(edge_bearing_mod180(g1))

        # query spatially nearby edges
        hits = tree.query(g1.buffer(float(parallel_max_dist)))

        for h in hits:
            if merges_done >= int(max_merges):
                break

            sig = _edge_signature({"geometry": h})
            ej = sig_to_eidx.get(sig, None)
            if ej is None or ej == ei:
                continue

            a, b = (ei, ej) if ei < ej else (ej, ei)
            if (a, b) in used_pairs:
                continue
            used_pairs.add((a, b))

            e2 = edges[ej]
            if e2 is None or e2.get("_deleted", False):
                continue

            g2 = e2["geometry"]
            if float(g2.length) < float(parallel_min_length):
                continue
            b2 = float(edge_bearing_mod180(g2))
            if angle_diff_180(b1, b2) > float(parallel_max_angle_deg):
                continue

            # distance check via nearest points
            p1, p2 = nearest_points(g1, g2)
            if float(p1.distance(p2)) > float(parallel_max_dist):
                continue

            # overlap check (avoid merging lines that are just briefly close)
            # project shorter onto longer
            if g1.length >= g2.length:
                overlap = _endpoint_overlap_ratio_on_target(g1, g2)
            else:
                overlap = _endpoint_overlap_ratio_on_target(g2, g1)
            if overlap < float(overlap_ratio_min):
                continue

            # Decide keeper (longer)
            if g1.length >= g2.length:
                keep_i, dup_i = ei, ej
            else:
                keep_i, dup_i = ej, ei

            keep = edges[keep_i]
            dup  = edges[dup_i]
            if keep is None or dup is None or keep.get("_deleted", False) or dup.get("_deleted", False):
                continue

            # ensure endpoints are coherent with node coords (important before splitting / rewiring)
            _snap_edge_endpoints_to_node_coords(nodes, keep)
            _snap_edge_endpoints_to_node_coords(nodes, dup)

            keep_geom = keep["geometry"]
            dup_geom  = dup["geometry"]

            # --- averaged cost density (if enabled) ---
            avg_cost_density = None
            if cost_ctx is not None:
                c_keep = float(keep.get("cost_sum", 0.0))
                c_dup  = float(dup.get("cost_sum", 0.0))
                avg_sum = 0.5 * (c_keep + c_dup)
                Lk = float(keep_geom.length) if keep_geom.length else 1.0
                avg_cost_density = avg_sum / Lk
                # set keeper cost_sum to averaged (whole edge)
                keep["cost_sum"] = float(avg_sum)

            # --- reroute duplicate endpoints onto keeper ---
            # We handle each endpoint node of the duplicate edge:
            #  - find nearest point on keeper geometry
            #  - split keeper at that point (creates node)
            #  - rewire all edges incident to that endpoint node to the new node
            #  - then delete duplicate edge
            dup_end_nodes = [int(dup["u"]), int(dup["v"])]

            for old_n in dup_end_nodes:
                # If old_n also lies on keeper already (shared node), skip
                if old_n == int(keep["u"]) or old_n == int(keep["v"]):
                    continue

                old_xy = nodes.get(old_n, None)
                if old_xy is None:
                    continue

                # nearest point on current keeper geometry
                # (keeper may have been split already; we must pick the segment that is closest)
                # Find the best segment among all edges that are part of the keeper "cluster".
                # We approximate by selecting the single closest eligible edge to old_xy with similar bearing.
                best_seg_i = None
                best_dist = None
                old_pt = Point(float(old_xy[0]), float(old_xy[1]))

                # search among all current edges: those not deleted and within buffer
                for ii, ee in enumerate(edges):
                    if ee is None or ee.get("_deleted", False):
                        continue
                    if float(ee.get("length", ee["geometry"].length)) < 1e-9:
                        continue
                    # use geometry distance as quick filter
                    d = float(ee["geometry"].distance(old_pt))
                    if d > float(parallel_max_dist) * 2.0:  # generous; we only need the keeper vicinity
                        continue
                    if best_dist is None or d < best_dist:
                        best_dist = d
                        best_seg_i = ii

                if best_seg_i is None:
                    continue

                # Split that segment at nearest point
                seg = edges[best_seg_i]
                proj_pt = seg["geometry"].interpolate(seg["geometry"].project(old_pt))
                parent_cost_before = seg.get("cost_sum", None)

                new_n, next_node_id, did_split = split_edge_at_point(
                    nodes, idx_round, edges, best_seg_i, proj_pt, next_node_id,
                    cost_ctx=cost_ctx, parent_cost=parent_cost_before
                )

                # If we have averaged cost density, enforce it on the affected segment(s)
                if avg_cost_density is not None:
                    # segment replaced in place
                    edges[best_seg_i]["cost_sum"] = float(avg_cost_density * float(edges[best_seg_i]["geometry"].length))
                    if did_split:
                        # new segment appended at the end
                        edges[-1]["cost_sum"] = float(avg_cost_density * float(edges[-1]["geometry"].length))

                # rewire incident edges (except the duplicate edge itself) from old_n -> new_n
                _reroute_incident_edges(nodes, edges, old_node=old_n, new_node=new_n,
                                        exclude_edge_idx=dup_i, cost_ctx=cost_ctx)

                # old node might become orphan later; we clean at the end

            # delete the duplicate edge itself
            dup["_deleted"] = True

            merges_done += 1

    # cleanup: remove orphan nodes (degree==0) and drop deleted edges
    deg = {}
    for e in edges:
        if e is None or e.get("_deleted", False):
            continue
        deg[int(e["u"])] = deg.get(int(e["u"]), 0) + 1
        deg[int(e["v"])] = deg.get(int(e["v"]), 0) + 1

    orphans = [nid for nid in list(nodes.keys()) if deg.get(int(nid), 0) == 0]
    for nid in orphans:
        nodes.pop(int(nid), None)

    return merges_done, next_node_id


# ---------------------- annotate directionality ------------------------------
def add_directionality(edges, slope_raster_path, convention, is_downhill):
    """
    Adds fields:
      - slope_dir_deg (converted to cw_from_north)
      - from, to (node ids)
    """
    if not slope_raster_path:
        return

    with rasterio.open(slope_raster_path) as src:
        for e in edges:
            geom = e["geometry"]
            mid = geom.interpolate(0.5, normalized=True)
            sraw = raster_value_at_xy(src, (mid.x, mid.y))
            sdir = convert_slope_dir_to_cw_from_north(sraw, convention)
            if np.isfinite(sdir) and not bool(is_downhill):
                sdir = (sdir + 180.0) % 360.0  # raster indicates uphill; convert to downhill

            e["slope_dir_deg"] = float(sdir) if np.isfinite(sdir) else np.nan

            # choose orientation closest to downhill direction
            coords = np.asarray(geom.coords, float)
            b_uv = bearing_cw_from_north(coords[0], coords[-1])
            b_vu = (b_uv + 180.0) % 360.0

            if not np.isfinite(sdir):
                e["from"] = int(e["u"]); e["to"] = int(e["v"])
                continue

            d_uv = abs(((b_uv - sdir + 180.0) % 360.0) - 180.0)
            d_vu = abs(((b_vu - sdir + 180.0) % 360.0) - 180.0)
            if d_uv <= d_vu:
                e["from"] = int(e["u"]); e["to"] = int(e["v"])
            else:
                e["from"] = int(e["v"]); e["to"] = int(e["u"])


# ============================ PIPELINE ======================================

# ---- 1) Read skeleton raster ----
with rasterio.open(in_skeleton) as src:
    skel = src.read(1, masked=True)
    transform = src.transform
    crs = src.crs

skel_bool = (np.asarray(skel.filled(0) if np.ma.isMaskedArray(skel) else skel) > 0).astype(np.uint8)

# ---- 2) Build sknw graph from skeleton ----
G = sknw.build_sknw(skel_bool, multi=True)

# ---- 3) Auto simplify tolerance if not provided ----
if simplify_tol is None:
    px = pixel_size_from_transform(transform)
    simplify_tol = 0.5 * px  # good default for staircase reduction

# ---- 4) Export nodes + edges (split at junctions) ----
nodes_gdf, edges_gdf = build_nodes_edges_gdfs(G, transform, crs, simplify_tol, smooth_iters)

# ---- 5) Contract degree-2 nodes to create analysis network ----
nodes2_gdf, edges2_gdf = contract_degree2_graph(G, nodes_gdf, edges_gdf)

# ---- 6) Convert to mutable in-memory network (dict/list) ----
nodes = build_node_dict(nodes2_gdf)
edges = []
for r in edges2_gdf.itertuples(index=False):
    edges.append({
        "u": int(r.u),
        "v": int(r.v),
        "n_segs": int(getattr(r, "n_segs", 1)),
        "length": float(r.length),
        "is_bridge": bool(getattr(r, "is_bridge", False)),
        "bridge_type": str(getattr(r, "bridge_type", "")),
        "geometry": r.geometry
    })

next_node_id = int(max(nodes.keys())) + 1 if nodes else 0

# ---- 7) Optional costs on current edges (needed for later avg-parent metrics) ----
cost_ctx = None
if cost_raster:
    with rasterio.open(cost_raster) as cs:
        carr = cs.read(1)
        cnod = cs.nodata
        ctr  = cs.transform
        if cost_sampling_step is None:
            cpx = pixel_size_from_transform(ctr)
            cstep = 0.5 * cpx
        else:
            cstep = float(cost_sampling_step)

        cost_ctx = {
            "arr": carr,
            "transform": ctr,
            "nodata": cnod,
            "step": float(cstep),
            "nodata_to_nan": bool(cost_nodata_to_nan),
            "unique_cells": bool(cost_unique_cells),
        }

    for e in edges:
        e["cost_sum"] = cost_sum_along_line(
            e["geometry"], cost_ctx["arr"], cost_ctx["transform"], cost_ctx["nodata"], cost_ctx["step"],
            nodata_to_nan=cost_ctx["nodata_to_nan"], unique_cells=cost_ctx["unique_cells"]
        )

# ---- 8) Post-contraction gap filling (multiple rounds) ----
total_ep_bridges = 0
total_vertex_snaps = 0

for round_i in range(int(max_bridge_rounds)):
    # endpoint↔endpoint bridging
    if bridge_endpoints:
        added = bridge_endpoints_round(
            nodes, edges,
            max_gap_dist=max_gap_dist,
            max_angle_deg=max_angle_deg,
            direction_step=direction_step,
            max_bridges_per_endpoint=max_bridges_per_endpoint,
            bridge_type=f"endpoint_endpoint_r{round_i+1}"
        )
        total_ep_bridges += added

        # compute costs for newly added bridge edges
        if cost_ctx is not None and added > 0:
            for e in edges[-added:]:
                e["cost_sum"] = cost_sum_along_line(
                    e["geometry"], cost_ctx["arr"], cost_ctx["transform"], cost_ctx["nodata"], cost_ctx["step"],
                    nodata_to_nan=cost_ctx["nodata_to_nan"], unique_cells=cost_ctx["unique_cells"]
                )

    # endpoint→vertex snapping
    if bridge_to_vertices:
        added_vs, next_node_id = snap_endpoints_to_vertices_round(
            nodes, edges, next_node_id, cost_ctx,
            max_snap_dist=max_snap_dist,
            max_snap_angle_deg=max_snap_angle_deg,
            max_target_angle_deg=max_target_angle_deg,
            direction_step=direction_step,
            max_snaps_per_endpoint=max_vertex_snaps_per_endpoint,
            vertex_stride=vertex_candidate_stride
        )
        total_vertex_snaps += added_vs

print(f"[Gap filling] endpoint↔endpoint bridges added: {total_ep_bridges} | endpoint→vertex snaps added: {total_vertex_snaps}")

# ---- 9) Merge close parallel duplicates (TRUE MERGE) ----
parallel_merged = 0
if join_parallel_lines:
    parallel_merged, next_node_id = merge_parallel_duplicate_edges(
        nodes, edges, next_node_id, cost_ctx,
        parallel_max_dist=parallel_max_dist,
        parallel_max_angle_deg=parallel_max_angle_deg,
        parallel_min_length=parallel_min_length,
        max_merges=max_parallel_connectors,   
        overlap_ratio_min=0.60                # internal guard; adjust if needed
    )
print(f"[Parallel merge] merged duplicates: {parallel_merged}" if join_parallel_lines else "[Parallel merge] disabled")

# ---- 10) Optional directionality on final edges ----
if slope_dir_raster:
    add_directionality(
        edges,
        slope_raster_path=slope_dir_raster,
        convention=slope_dir_convention,
        is_downhill=slope_dir_is_downhill
    )

# ---- 11) Build final GeoDataFrames and export ----
deg = recompute_node_degrees(nodes, edges)

node_rows = []
for nid, xy in nodes.items():
    node_rows.append({
        "node": int(nid),
        "degree": int(deg.get(int(nid), 0)),
        "x": float(xy[0]),
        "y": float(xy[1]),
        "geometry": Point(float(xy[0]), float(xy[1]))
    })
nodes_out = gpd.GeoDataFrame(node_rows, geometry="geometry", crs=crs)

edge_rows = []
for i, e in enumerate(edges):
    row = {
        "edge_id": int(i),
        "u": int(e["u"]),
        "v": int(e["v"]),
        "length": float(e.get("length", e["geometry"].length)),
        "n_segs": int(e.get("n_segs", 1)),
        "is_bridge": bool(e.get("is_bridge", False)),
        "bridge_type": str(e.get("bridge_type", "")),
        "geometry": e["geometry"]
    }
    if "cost_sum" in e:
        row["cost_sum"] = float(e["cost_sum"])
    if "avg_parent_cost_sum" in e:
        row["avg_parent_cost_sum"] = float(e["avg_parent_cost_sum"])
    if "slope_dir_deg" in e:
        row["slope_dir_deg"] = float(e["slope_dir_deg"]) if np.isfinite(e["slope_dir_deg"]) else np.nan
    if "from" in e and "to" in e:
        row["from"] = int(e["from"])
        row["to"] = int(e["to"])
    edge_rows.append(row)

edges_out = gpd.GeoDataFrame(edge_rows, geometry="geometry", crs=crs)

# write
if os.path.exists(out_gpkg):
    os.remove(out_gpkg)
nodes_out.to_file(out_gpkg, layer="nodes", driver="GPKG")
edges_out.to_file(out_gpkg, layer="edges", driver="GPKG")

print(f"Wrote network to: {out_gpkg} (layers: nodes, edges)")
print(f"Original sknw export: nodes={len(nodes_gdf)} edges={len(edges_gdf)}")
print(f"Contracted:          nodes={len(nodes2_gdf)} edges={len(edges2_gdf)}")
print(f"Final:              nodes={len(nodes_out)} edges={len(edges_out)}")
print(f"simplify_tol={simplify_tol:.4g} smooth_iters={smooth_iters}")
if cost_raster:
    print(f"Costs enabled: cost_sum (unique_cells={cost_unique_cells}, step={cost_ctx['step']:.4g}, nodata_to_nan={cost_nodata_to_nan})")
if slope_dir_raster:
    print(f"Directionality enabled: convention={slope_dir_convention}, slope_dir_is_downhill={slope_dir_is_downhill} -> fields: slope_dir_deg, from, to")


[Gap filling] endpoint↔endpoint bridges added: 2533 | endpoint→vertex snaps added: 2263
[Parallel merge] disabled
Wrote network to: /home/hector/Documents/Nazarij/channels_50conPix_NEW3network.gpkg (layers: nodes, edges)
Original sknw export: nodes=35890 edges=32650
Contracted:          nodes=35007 edges=31767
Final:              nodes=37143 edges=38699
simplify_tol=10.93 smooth_iters=2


In [5]:
# Add connectivity metrics directly into the existing nodes/edges tables:
#  1) Read nodes+edges from an existing GPKG
#  2) Compute metrics with NetworkX
#  3) Write new GeoPackage file with metrics,`overwrite_in_place=True` to replace the original GPKG.

import os
import re
import warnings
from pathlib import Path

import numpy as np
import geopandas as gpd
import networkx as nx

# ---- USER PARAMS ----
gpkg_in   = "/home/hector/Documents/channels_50conPix_network.gpkg"
nodes_lyr = "nodes"
edges_lyr = "edges"

# Output handling
gpkg_out = "/home/hector/Documents/channels_50conPix_network_metrics.gpkg"
overwrite_in_place = False  # if True, replaces gpkg_in with gpkg_out at the end

# Graph options
directed    = False
use_weights = True
length_col  = "length"   # created from geometry if missing

# Extra/slow metrics
compute_current_flow = False            # OFF by default (avoid warnings + heavy computation)
compute_local_edge_connectivity = False # can be expensive
# ---------------------


def _sanitize_path(p: str) -> Path:
    p = str(p).strip()
    p = re.sub(r"^[cC]/home/", "/home/", p)  # fix c/home typo
    outp = Path(p).expanduser()
    if not outp.is_absolute():
        outp = (Path.cwd() / outp).resolve()
    outp.parent.mkdir(parents=True, exist_ok=True)
    return outp


def _safe_unlink(p: Path):
    try:
        if p.exists():
            p.unlink()
    except Exception as e:
        raise OSError(f"Cannot remove existing file: {p}\nClose it in QGIS/other apps and retry.\n{e}") from e


# ---------- READ ----------
gpkg_in_p = _sanitize_path(gpkg_in)
gpkg_out_p = _sanitize_path(gpkg_out)

nodes = gpd.read_file(gpkg_in_p.as_posix(), layer=nodes_lyr)
edges = gpd.read_file(gpkg_in_p.as_posix(), layer=edges_lyr)

if "node" not in nodes.columns:
    raise ValueError(f"'{nodes_lyr}' layer must contain a 'node' column.")
for c in ("u", "v"):
    if c not in edges.columns:
        raise ValueError(f"'{edges_lyr}' layer must contain '{c}' columns.")
if "geometry" not in edges.columns:
    raise ValueError(f"'{edges_lyr}' layer must contain geometries.")

edges_out = edges.copy()
if length_col not in edges_out.columns:
    edges_out[length_col] = edges_out.geometry.length.astype(float)

has_k = "k" in edges_out.columns

# ---------- BUILD GRAPH ----------
if directed:
    G = nx.MultiDiGraph() if has_k else nx.DiGraph()
    G_simple = nx.DiGraph()
else:
    G = nx.MultiGraph() if has_k else nx.Graph()
    G_simple = nx.Graph()

G.add_nodes_from(nodes["node"].astype(int).tolist())

for r in edges_out.itertuples(index=False):
    u = int(getattr(r, "u"))
    v = int(getattr(r, "v"))
    w = float(getattr(r, length_col))
    if has_k:
        k = int(getattr(r, "k"))
        G.add_edge(u, v, key=k, **{length_col: w})
    else:
        G.add_edge(u, v, **{length_col: w})

    # Simple view: keep minimum weight between u-v
    if G_simple.has_edge(u, v):
        if w < G_simple[u][v].get(length_col, np.inf):
            G_simple[u][v][length_col] = w
    else:
        G_simple.add_edge(u, v, **{length_col: w})

node_ids = nodes["node"].astype(int)

# ---------- NODE METRICS ----------
degree = dict(G.degree())
strength = dict(G.degree(weight=length_col)) if use_weights else {n: np.nan for n in G.nodes()}

betweenness = nx.betweenness_centrality(G, weight=(length_col if use_weights else None), normalized=True)
closeness   = nx.closeness_centrality(G, distance=(length_col if use_weights else None))

try:
    eigenvector = nx.eigenvector_centrality_numpy(G, weight=None)
except Exception:
    eigenvector = nx.eigenvector_centrality(G, max_iter=2000, tol=1e-8, weight=None)

harmonic = nx.harmonic_centrality(G, distance=(length_col if use_weights else None))

try:
    pagerank = nx.pagerank(G, weight=(length_col if use_weights else None))
except Exception:
    pagerank = {n: np.nan for n in G.nodes()}

# clustering/core best on undirected view
try:
    clustering = nx.clustering(G_simple if not directed else nx.Graph(G_simple))
except Exception:
    clustering = {n: np.nan for n in G.nodes()}

try:
    core = nx.core_number(G_simple if not directed else nx.Graph(G_simple))
except Exception:
    core = {n: np.nan for n in G.nodes()}

# components
if directed:
    comps = list(nx.weakly_connected_components(G))
else:
    comps = list(nx.connected_components(G))
comp_id, comp_size = {}, {}
for i, comp in enumerate(comps):
    s = len(comp)
    for n in comp:
        comp_id[n] = i
        comp_size[n] = s

nodes2 = nodes.copy()
nodes2["degree"]      = node_ids.map(degree).astype(float)
nodes2["strength"]    = node_ids.map(strength).astype(float)
nodes2["betweenness"] = node_ids.map(betweenness).astype(float)
nodes2["closeness"]   = node_ids.map(closeness).astype(float)
nodes2["eigenvector"] = node_ids.map(eigenvector).astype(float)

nodes2["harmonic"]    = node_ids.map(harmonic).astype(float)
nodes2["pagerank"]    = node_ids.map(pagerank).astype(float)
nodes2["clustering"]  = node_ids.map(clustering).astype(float)
nodes2["core"]        = node_ids.map(core).astype(float)
nodes2["component"]   = node_ids.map(comp_id).astype(int)
nodes2["comp_size"]   = node_ids.map(comp_size).astype(int)

if directed:
    indeg  = dict(G.in_degree())
    outdeg = dict(G.out_degree())
    nodes2["in_degree"]  = node_ids.map(indeg).astype(float)
    nodes2["out_degree"] = node_ids.map(outdeg).astype(float)

# ---------- EDGE METRICS (emphasis) ----------
deg_map = nodes2.set_index("node")["degree"].to_dict()
bet_map = nodes2.set_index("node")["betweenness"].to_dict()
clo_map = nodes2.set_index("node")["closeness"].to_dict()
eig_map = nodes2.set_index("node")["eigenvector"].to_dict()

edges2 = edges_out.copy()

# Endpoint metrics
edges2["u_degree"] = edges2["u"].astype(int).map(deg_map).astype(float)
edges2["v_degree"] = edges2["v"].astype(int).map(deg_map).astype(float)
edges2["u_betweenness"] = edges2["u"].astype(int).map(bet_map).astype(float)
edges2["v_betweenness"] = edges2["v"].astype(int).map(bet_map).astype(float)
edges2["u_closeness"] = edges2["u"].astype(int).map(clo_map).astype(float)
edges2["v_closeness"] = edges2["v"].astype(int).map(clo_map).astype(float)
edges2["u_eigenvector"] = edges2["u"].astype(int).map(eig_map).astype(float)
edges2["v_eigenvector"] = edges2["v"].astype(int).map(eig_map).astype(float)

edges2["deg_sum"]  = edges2["u_degree"] + edges2["v_degree"]
edges2["deg_diff"] = (edges2["u_degree"] - edges2["v_degree"]).abs()

# Edge betweenness (kept)
edge_betw = nx.edge_betweenness_centrality(G, weight=(length_col if use_weights else None), normalized=True)

def _edge_betw_lookup(u, v, k=None):
    if has_k:
        if (u, v, k) in edge_betw: return edge_betw[(u, v, k)]
        if (v, u, k) in edge_betw and not directed: return edge_betw[(v, u, k)]
        return np.nan
    else:
        if (u, v) in edge_betw: return edge_betw[(u, v)]
        if (v, u) in edge_betw and not directed: return edge_betw[(v, u)]
        return np.nan

if has_k:
    edges2["edge_betweenness"] = [
        float(_edge_betw_lookup(int(u), int(v), int(k)))
        for u, v, k in zip(edges2["u"], edges2["v"], edges2["k"])
    ]
else:
    edges2["edge_betweenness"] = [
        float(_edge_betw_lookup(int(u), int(v)))
        for u, v in zip(edges2["u"], edges2["v"])
    ]

# Bridges (cut edges) on undirected simple graph
if not directed:
    try:
        bridge_pairs = set(tuple(sorted(e)) for e in nx.bridges(G_simple))
    except Exception:
        bridge_pairs = set()

    if has_k:
        # If parallel edges exist between u-v, it cannot be a bridge
        multiplicity = {}
        for u, v, k in G.edges(keys=True):
            a, b = (u, v) if u <= v else (v, u)
            multiplicity[(a, b)] = multiplicity.get((a, b), 0) + 1

        edges2["is_bridge"] = [
            bool((tuple(sorted((int(u), int(v)))) in bridge_pairs) and (multiplicity.get(tuple(sorted((int(u), int(v)))), 0) == 1))
            for u, v in zip(edges2["u"], edges2["v"])
        ]
    else:
        edges2["is_bridge"] = [bool(tuple(sorted((int(u), int(v)))) in bridge_pairs) for u, v in zip(edges2["u"], edges2["v"])]
else:
    edges2["is_bridge"] = False

# Geometry-based edge measures
def _straight_dist(geom):
    if geom is None or geom.is_empty:
        return np.nan
    coords = np.asarray(geom.coords)
    if coords.shape[0] < 2:
        return np.nan
    return float(np.linalg.norm(coords[-1] - coords[0]))

edges2["straight_dist"] = edges2.geometry.apply(_straight_dist).astype(float)
edges2["sinuosity"] = (edges2[length_col] / edges2["straight_dist"]).replace([np.inf, -np.inf], np.nan)
edges2["straightness"] = (edges2["straight_dist"] / edges2[length_col]).replace([np.inf, -np.inf], np.nan)

# Triangle-based edge measures (undirected)
def _common_neighbors(u, v):
    try:
        return len(list(nx.common_neighbors(G_simple, int(u), int(v))))
    except Exception:
        return 0

if not directed:
    tris = [_common_neighbors(u, v) for u, v in zip(edges2["u"], edges2["v"])]
    edges2["edge_triangles"] = np.asarray(tris, dtype=float)

    ecc = []
    for u, v, tri in zip(edges2["u"], edges2["v"], edges2["edge_triangles"]):
        du = G_simple.degree(int(u))
        dv = G_simple.degree(int(v))
        denom = min(du - 1, dv - 1)
        ecc.append(float(tri / denom) if denom > 0 else 0.0)
    edges2["edge_clustering"] = np.asarray(ecc, dtype=float)
else:
    edges2["edge_triangles"] = np.nan
    edges2["edge_clustering"] = np.nan

# Local edge connectivity (optional)
def _local_ec(u, v):
    if not compute_local_edge_connectivity:
        return np.nan
    try:
        return float(nx.local_edge_connectivity(G_simple, int(u), int(v)))
    except Exception:
        return np.nan

edges2["local_edge_connectivity"] = [_local_ec(u, v) for u, v in zip(edges2["u"], edges2["v"])]

# Current-flow edge betweenness (optional; often fragile/slow)
edges2["edge_current_flow_betweenness"] = np.nan
if (not directed) and compute_current_flow:
    edge_cf = {}
    for comp in nx.connected_components(G_simple):
        H = G_simple.subgraph(comp).copy()
        if H.number_of_nodes() < 3 or H.number_of_edges() < 2:
            continue
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=RuntimeWarning)
            try:
                part = nx.edge_current_flow_betweenness_centrality(
                    H, weight=(length_col if use_weights else None), normalized=True
                )
                for e, val in part.items():
                    edge_cf[e] = float(val) if np.isfinite(val) else np.nan
            except Exception:
                continue

    def _cf_lookup(u, v):
        if (u, v) in edge_cf: return edge_cf[(u, v)]
        if (v, u) in edge_cf: return edge_cf[(v, u)]
        return np.nan

    edges2["edge_current_flow_betweenness"] = [float(_cf_lookup(int(u), int(v))) for u, v in zip(edges2["u"], edges2["v"])]

# ---------- WRITE ----------
# This avoids Fiona "append" bugs: recreate the file and write BOTH layers in one go.
_safe_unlink(gpkg_out_p)

# Write nodes (mode="w") and edges (mode="w") to distinct layers by using separate files? No:
# For GPKG, writing a second layer usually implies append. Fiona is the flaky part.
# Workaround: use pyogrio if available; otherwise write separate temporary GPKGs then merge is messy.
#
# Practical robust approach in many conda setups: tell GeoPandas to use the "pyogrio" engine if installed.
engine = "pyogrio"
try:
    import pyogrio  # noqa: F401
except Exception:
    engine = "fiona"

print(f"Writing with engine: {engine}")
nodes2.to_file(gpkg_out_p.as_posix(), layer=nodes_lyr, driver="GPKG", engine=engine)
edges2.to_file(gpkg_out_p.as_posix(), layer=edges_lyr, driver="GPKG", engine=engine)

print(f"Written: {gpkg_out_p}")
print(f"Layers overwritten in output: {nodes_lyr}, {edges_lyr}")

# Optional: replace input file (only if everything succeeded)
if overwrite_in_place:
    # Ensure input is not open elsewhere
    tmp_backup = gpkg_in_p.with_suffix(".backup.gpkg")
    try:
        if tmp_backup.exists():
            tmp_backup.unlink()
        gpkg_in_p.rename(tmp_backup)
        gpkg_out_p.rename(gpkg_in_p)
        tmp_backup.unlink(missing_ok=True)
        print(f"Replaced original GPKG in-place: {gpkg_in_p}")
    except Exception as e:
        raise OSError(
            "Failed to replace original file. The output GPKG is still available.\n"
            f"Output: {gpkg_out_p}\nBackup (if created): {tmp_backup}\n{e}"
        ) from e


Writing with engine: fiona
Written: /home/hector/Documents/Nazarij/channels_50conPix_network_metrics.gpkg
Layers overwritten in output: nodes, edges
