# Calculate channel form index

The following code operates on a local directory containing 2-D water-mask rasters. For each water mask, it starts by creating a "skeleton" of the mask. It then dilates the tips of the skeleton to improve the connectivity of the channel network, re-skeletonizes, and reduces the skeleton to only the identifiable river channels. From the final skeleton, channel links and nodes are created. The links are filtered (redundant diagonal links and short spurs are removed), and a shortest-path line, or "main channel", is extracted along with a simplified main channel that acts as a valley center-line, enabling sinuosity calculations (sinuosity = main-channel length / valley center-line length). Finally, cross-sections of the channel belt are created and the channel-count index (CCI, the mean number of channels crossed per cross-section) is calculated across them. From sinuosity and the channel-count index, the channel form index is calculated (CFI = sinuosity / CCI). These river metrics are exported to a .csv. The processed skeleton, nodes, channel links, main channel, valley center-line, and channel-belt cross-sections are output as shapefiles, and the network is plotted for the user's convenience.

Author: James (Huck) Rees; PhD Student, UCSB Geography

Date: June 7th, 2024

## Import packages

In [1]:
import os
import re
import logging
from glob import glob
from itertools import combinations

import numpy as np
import pandas as pd
import geopandas as gpd

import rasterio
from rasterio.plot import show
from rasterio.transform import xy

import matplotlib.pyplot as plt

from skimage.morphology import skeletonize, label
from skimage.measure import regionprops

from shapefile import Reader, Writer
from shapely.geometry import LineString, Point, MultiLineString, MultiPoint
from shapely.ops import split, linemerge, snap, nearest_points
from shapely.strtree import STRtree

from rtree import index as rtree_index

import networkx as nx

## Initialize functions

In [2]:
# Function to load raster data
def load_raster(file_path):
    """
    Read the first band of a raster file together with its metadata.

    Parameters:
    file_path (str): Path of the raster to open with rasterio.

    Returns:
    tuple: ``(data, metadata)`` where ``data`` is the band-1 array
    (numpy.ndarray) and ``metadata`` is the dataset's ``meta`` dict.
    """
    # The tuple is built while the dataset is still open; the context
    # manager closes it before the function actually returns.
    with rasterio.open(file_path) as src:
        return src.read(1), src.meta

# Function to save a raster file
def save_raster(output_path, data, metadata):
    """
    Write ``data`` as a single-band uint8 GeoTIFF.

    Only the 'crs' and 'transform' entries of ``metadata`` are used; every
    other field (e.g. nodata) is ignored. ``data`` is cast to uint8 before
    writing, and its first two dimensions give the raster height and width.

    Parameters:
    output_path (str): Destination path of the GeoTIFF.
    data (numpy.ndarray): Array to write as band 1.
    metadata (dict): Raster metadata providing 'crs' and 'transform'.

    Returns:
    None
    """
    profile = {
        'driver': 'GTiff',
        'height': data.shape[0],
        'width': data.shape[1],
        'count': 1,
        'dtype': 'uint8',
        'crs': metadata['crs'],
        'transform': metadata['transform'],
    }
    with rasterio.open(output_path, 'w', **profile) as dst:
        dst.write(data.astype('uint8'), 1)

# Function to perform conditional dilation
def conditional_dilation(image, radius=5):
    """
    Dilate the tips of a binary skeleton.

    A foreground pixel (value == 1) qualifies as a tip when its 3x3 window
    sums to 2 or less. The window is the pixel itself plus its eight
    neighbours, so a tip is an isolated pixel or a pixel with exactly one
    foreground neighbour. Each tip stamps a (2*radius + 1) x (2*radius + 1)
    square of 1s, centred on itself and clipped to the image, into the output.
    Pixels on the outermost row/column are never treated as tips.

    The tip test is always evaluated on the *input* image, so squares stamped
    during the pass cannot create new tips.

    Parameters:
    image (numpy.ndarray): Binary 2-D image (0/1 or bool) to be processed.
    radius (int, optional): Half-width of the square stamped around each tip. Default is 5.

    Returns:
    numpy.ndarray: A new array with the same shape and dtype as ``image`` and the
    tips dilated. ``image`` itself is not modified.
    """
    img = np.asarray(image)
    dilated_image = np.copy(img)
    n_rows, n_cols = img.shape
    if n_rows < 3 or n_cols < 3:
        # No interior pixels, so nothing can qualify as a tip.
        return dilated_image

    # Sum of every interior pixel's 3x3 window (self included), built from
    # nine shifted views instead of a per-pixel Python loop.
    values = img.astype(np.float64)
    window_sum = np.zeros((n_rows - 2, n_cols - 2), dtype=np.float64)
    for dr in range(3):
        for dc in range(3):
            window_sum += values[dr:dr + n_rows - 2, dc:dc + n_cols - 2]

    interior = img[1:-1, 1:-1]
    tip_rows, tip_cols = np.nonzero((interior == 1) & (window_sum <= 2))
    # +1 converts interior-view indices back to full-image indices.
    for row, col in zip(tip_rows + 1, tip_cols + 1):
        dilated_image[max(0, row - radius):min(row + radius + 1, n_rows),
                      max(0, col - radius):min(col + radius + 1, n_cols)] = 1
    return dilated_image

# Function to keep only the largest connected component
def keep_largest_component(image):
    """
    Keep only the largest 8-connected foreground component of a binary image.

    Parameters:
    image (numpy.ndarray): Binary 2-D image.

    Returns:
    numpy.ndarray: A boolean mask of the largest component. If ``image`` has
    no foreground at all, ``image`` itself is returned unchanged. Ties are
    broken in favour of the lowest component label.
    """
    labels, n_components = label(image, connectivity=2, return_num=True)
    if not n_components:
        return image
    # Pixel count per label. Label 0 is background and must never win.
    sizes = np.bincount(labels.ravel())
    sizes[0] = 0
    biggest = int(np.argmax(sizes))  # argmax returns the first (lowest) label on ties
    return labels == biggest

# Function to create node shapefile and return node points
def create_nodes(image, metadata):
    """
    Identifies and creates nodes in a binary image. Nodes are classified as 'endpoint' or 'junction' based on their connectivity.

    Only interior pixels are examined; the outermost row and column never
    become nodes. For each foreground pixel (value == 1), its 8-neighbours
    that are also 1 are counted:

    - exactly 1 neighbour  -> 'endpoint'
    - exactly 3 neighbours -> 'junction', unless all three fill one whole side
      of the 3x3 window, or any two of them are 4-adjacent to each other
    - 4 or more neighbours -> 'junction' when fewer than two neighbour pairs
      are 4-adjacent to each other
    - 0 or 2 neighbours    -> no node

    Parameters:
    image (numpy.ndarray): The input binary image (2D array) to be processed.
    metadata (dict): The metadata of the raster file; only 'transform' is read.

    Returns:
    geopandas.GeoDataFrame: Columns 'node_id' (sequential from 1 in row-major
    scan order), 'type' ('endpoint' / 'junction') and 'geometry' (Point from
    rasterio.transform.xy of the pixel). The CRS is always set to EPSG:4326.
    """
    node_points = []
    transform = metadata['transform']
    node_id = 1

    # 8-neighbourhood offsets (row, col), excluding the centre pixel.
    adjacent_positions = [
        (-1, -1), (-1, 0), (-1, 1),
        (0, -1),          (0, 1),
        (1, -1),  (1, 0),  (1, 1)
    ]

    # Iterate over each pixel in the image, excluding the borders
    for row in range(1, image.shape[0] - 1):
        for col in range(1, image.shape[1] - 1):
            if image[row, col] == 1:
                # Find neighbors that are part of the segment
                neighbors = [
                    (row + dr, col + dc)
                    for dr, dc in adjacent_positions
                    if image[row + dr, col + dc] == 1
                ]
                count = len(neighbors)
                if count == 1:  # Endpoints
                    x, y = xy(transform, row, col)
                    node_points.append((node_id, 'endpoint', Point(x, y)))
                    node_id += 1
                elif count == 3:  # Potential junctions
                    # True unless the three neighbours form a complete row or
                    # column of the 3x3 window (a thick line, not a fork).
                    is_not_on_one_side = not (
                        (image[row - 1, col - 1] == 1 and image[row - 1, col] == 1 and image[row - 1, col + 1] == 1) or  # above
                        (image[row + 1, col - 1] == 1 and image[row + 1, col] == 1 and image[row + 1, col + 1] == 1) or  # below
                        (image[row - 1, col - 1] == 1 and image[row, col - 1] == 1 and image[row + 1, col - 1] == 1) or  # left
                        (image[row - 1, col + 1] == 1 and image[row, col + 1] == 1 and image[row + 1, col + 1] == 1)     # right
                    )

                    if is_not_on_one_side:
                        # Reject as a junction if any two neighbours touch each
                        # other 4-connectedly (offset (0, 1) or (1, 0)).
                        is_junction = True
                        for i in range(len(neighbors)):
                            for j in range(i + 1, len(neighbors)):
                                if (abs(neighbors[i][0] - neighbors[j][0]), abs(neighbors[i][1] - neighbors[j][1])) in [(0, 1), (1, 0)]:
                                    is_junction = False
                                    break
                            if not is_junction:
                                break
                        if is_junction:
                            x, y = xy(transform, row, col)
                            node_points.append((node_id, 'junction', Point(x, y)))
                            node_id += 1
                elif count >= 4:  # Nodes with 4 or more adjacent pixels
                    # Number of neighbour pairs that are 4-adjacent to each other.
                    direct_pairs = sum(
                        1 for i in range(len(neighbors))
                        for j in range(i + 1, len(neighbors))
                        if (abs(neighbors[i][0] - neighbors[j][0]), abs(neighbors[i][1] - neighbors[j][1])) in [(0, 1), (1, 0)]
                    )

                    if direct_pairs < 2:
                        x, y = xy(transform, row, col)
                        node_points.append((node_id, 'junction', Point(x, y)))
                        node_id += 1

    # Create a GeoDataFrame
    gdf = gpd.GeoDataFrame(node_points, columns=['node_id', 'type', 'geometry'])
    # NOTE(review): the CRS is hard-coded to EPSG:4326 and metadata['crs'] is
    # ignored; confirm the input rasters really are in EPSG:4326.
    gdf.set_crs(epsg=4326, inplace=True)
    
    return gdf

# Function to create links shapefile
def create_links(image, metadata):
    """
    Create one two-point LineString for every pair of 8-connected foreground pixels.

    Each undirected pixel pair is generated exactly once. Only the "forward"
    neighbours of a pixel are visited: below, right, below-left and
    below-right. The other four directions were already covered by an earlier
    pixel in the row-major scan. Links are emitted in the same order and
    orientation as the previous generate-all-then-deduplicate approach, but ids
    are now numbered contiguously from 1.

    Parameters:
    image (numpy.ndarray): The input binary image (2D array); pixels == 1 are foreground.
    metadata (dict): The metadata of the raster file; only 'transform' is read.

    Returns:
    geopandas.GeoDataFrame: Columns 'id' (1..N) and 'geometry' (LineString
    from the earlier pixel in row-major order to its neighbour, using
    rasterio.transform.xy coordinates). The CRS is always set to EPSG:4326.
    """
    transform = metadata['transform']
    n_rows, n_cols = image.shape[0], image.shape[1]
    forward_offsets = [(1, 0), (0, 1), (1, -1), (1, 1)]

    lines = []
    seen = set()
    for row in range(n_rows):
        for col in range(n_cols):
            if image[row, col] != 1:
                continue
            origin = None  # spatial coords of (row, col), computed once if needed
            for dr, dc in forward_offsets:
                nr, nc = row + dr, col + dc
                if not (0 <= nr < n_rows and 0 <= nc < n_cols) or image[nr, nc] != 1:
                    continue
                if origin is None:
                    x1, y1 = xy(transform, row, col)
                    origin = (x1, y1)
                x2, y2 = xy(transform, nr, nc)
                line = LineString([origin, (x2, y2)])
                # Coordinate-level dedup only matters for degenerate transforms
                # that map different pixels to the same point.
                key = tuple(sorted(line.coords))
                if key not in seen:
                    seen.add(key)
                    lines.append(line)

    unique_links = [(link_id, line) for link_id, line in enumerate(lines, start=1)]
    gdf = gpd.GeoDataFrame(unique_links, columns=['id', 'geometry'])

    # NOTE(review): the CRS is hard-coded to EPSG:4326 and metadata['crs'] is
    # ignored; confirm the input rasters really are in EPSG:4326.
    gdf.set_crs(epsg=4326, inplace=True)

    return gdf

# Function to filter links
def filter_links(gdf):
    """
    Remove diagonal links that cut the corner of an 8-connected "staircase".

    Each link is classified from its first two vertices:
    - 'horizontal' if the y coordinates are equal;
    - otherwise 'vertical' if the x coordinates are equal;
    - otherwise 'diagonal'.

    Take a diagonal from A=(x0, y0) to B=(x1, y1), and let C=(x1, y0) and
    D=(x0, y1) be the other corners of its bounding box. The diagonal is
    dropped when both of these hold:
    - a horizontal link lies on the box's bottom or top edge (A-C or D-B);
    - a vertical link lies on its right or left edge (C-B or A-D).
    Such a diagonal closes a triangle with axis-aligned links and is redundant.
    Coordinates are compared with exact float equality.

    Parameters:
    gdf (geopandas.GeoDataFrame): Line segments in a 'geometry' column, as
        produced by create_links. It is NOT modified.

    Returns:
    geopandas.GeoDataFrame: A new frame with the same columns as ``gdf`` and
    the redundant diagonal rows dropped. An empty input yields an empty output.
    """
    def _category(start, end):
        if start[1] == end[1]:
            return 'horizontal'
        if start[0] == end[0]:
            return 'vertical'
        return 'diagonal'

    # Axis-aligned links keyed by their unordered endpoint pair. Box-edge
    # lookups are then O(1) and independent of the frame's index type.
    horizontal = set()
    vertical = set()
    diagonals = []
    for idx, geom in gdf['geometry'].items():
        start = tuple(geom.coords[0][:2])
        end = tuple(geom.coords[1][:2])
        category = _category(start, end)
        if category == 'horizontal':
            horizontal.add(frozenset((start, end)))
        elif category == 'vertical':
            vertical.add(frozenset((start, end)))
        else:
            diagonals.append((idx, start, end))

    diagonals_to_remove = []
    for idx, (x0, y0), (x1, y1) in diagonals:
        a, b = (x0, y0), (x1, y1)
        c, d = (x1, y0), (x0, y1)
        has_horizontal = frozenset((a, c)) in horizontal or frozenset((d, b)) in horizontal
        has_vertical = frozenset((c, b)) in vertical or frozenset((a, d)) in vertical
        if has_horizontal and has_vertical:
            diagonals_to_remove.append(idx)

    return gdf.drop(index=diagonals_to_remove)

# Function to find furthest endpoints
def find_furthest_endpoints(gdf_points):
    """
    Find the two points (of any node type) that are furthest apart.

    Distances are planar Euclidean distances in the frame's coordinate units.
    When several pairs share the maximum distance, the first pair in
    (i, j), i < j row order is returned.

    Parameters:
    gdf_points (geopandas.GeoDataFrame): Point nodes in a 'geometry' column.

    Returns:
    geopandas.GeoDataFrame: Two rows holding the furthest pair's geometries,
    with the same CRS as ``gdf_points``.

    Raises:
    ValueError: If there are fewer than two points, or all points coincide.
    """
    if len(gdf_points) < 2:
        raise ValueError("Not enough points to find the furthest pair.")

    geometries = gdf_points['geometry']
    coords = np.array([geom.coords[0][:2] for geom in geometries], dtype=float)

    best_distance = 0.0
    best_pair = None
    # One vectorised row of distances per point (O(n) memory, not an n x n matrix).
    for i in range(len(coords) - 1):
        diff = coords[i + 1:] - coords[i]
        distances = np.sqrt(diff[:, 0] * diff[:, 0] + diff[:, 1] * diff[:, 1])
        j = int(np.argmax(distances))  # first maximum within this row
        if distances[j] > best_distance:  # strict '>' keeps the earliest pair on ties
            best_distance = distances[j]
            best_pair = (i, i + 1 + j)

    if best_pair is None:
        raise ValueError("All points coincide; cannot find a furthest pair.")

    furthest_geometries = [geometries.iloc[best_pair[0]], geometries.iloc[best_pair[1]]]
    start_end_pts = gpd.GeoDataFrame(geometry=furthest_geometries, crs=gdf_points.crs)
    return start_end_pts

# Function to prune network
def prune_network(nodes, filtered_links, start_end_pts):
    """
    Prunes a network by removing short spur paths that hang off the main network.

    Each 'endpoint' node is considered, except those intersecting one of
    ``start_end_pts`` (which are always kept). The node Euclidean-nearest to
    that endpoint is found, and the fewest-hop path between them is taken
    through the link graph. If the path has at most 20 links, its geometry is
    treated as a spur. All spurs are then subtracted from ``filtered_links``
    with a geopandas 'difference' overlay.

    Parameters:
    nodes (geopandas.GeoDataFrame): Nodes with a 'type' column ('endpoint' / 'junction') and Point geometries.
    filtered_links (geopandas.GeoDataFrame): Line segments; the graph uses each link's first two vertices.
    start_end_pts (geopandas.GeoDataFrame): Points whose coincident endpoints must not be pruned.

    Returns:
    geopandas.GeoDataFrame: ``filtered_links`` with the spur geometries removed.

    Raises:
    networkx.NodeNotFound: Not caught here. It is presumably raised when an
        endpoint or its nearest node is not a vertex of any link; confirm.
    """
    endpoints = nodes[nodes['type'] == 'endpoint']
    # Endpoints that are not one of the protected start/end points.
    small_ends = endpoints[~endpoints.geometry.apply(lambda x: any(start_end_pts.geometry.intersects(x)))]
    spurs = gpd.GeoDataFrame(columns=['geometry'], geometry='geometry', crs=nodes.crs)
    # Undirected graph keyed by coordinate tuples; each edge keeps its link's geometry.
    G = nx.Graph()
    for idx, row in filtered_links.iterrows():
        coords = list(row.geometry.coords)
        G.add_edge(coords[0], coords[1], index=idx, geometry=row.geometry)

    for idx, p1 in small_ends.iterrows():
        pn1 = nodes[nodes.geometry == p1.geometry].iloc[0]
        nodes_excluding_pn1 = nodes[nodes.geometry != pn1.geometry]
        # NOTE(review): the union is rebuilt for every endpoint (quadratic
        # overall), and `unary_union` is deprecated in newer geopandas
        # releases in favour of `union_all()`.
        nearest_node_geom = nearest_points(pn1.geometry, nodes_excluding_pn1.unary_union)[1]
        nearest_node = nodes[nodes.geometry == nearest_node_geom].iloc[0]
        try:
            # Unweighted: fewest hops, not shortest length.
            path = nx.shortest_path(G, source=tuple(pn1.geometry.coords[0]), target=tuple(nearest_node.geometry.coords[0]))
            path_geometries = [G.edges[path[i], path[i+1]]['geometry'] for i in range(len(path)-1)]
            # Only short paths (<= 20 links) count as spurs.
            if len(path_geometries) <= 20:
                # Stored link orientation may not follow the path direction, so
                # the vertex sequence can zig-zag; the segments covered are the same.
                spur_geometry = LineString([point for line in path_geometries for point in line.coords])
                new_spur = gpd.GeoDataFrame([{'geometry': spur_geometry}], geometry='geometry', crs=nodes.crs)
                spurs = pd.concat([spurs, new_spur], ignore_index=True)
        except nx.NetworkXNoPath:
            continue
            
    pruned_links = filtered_links.overlay(spurs, how='difference')
    return pruned_links

# Function to find shortest path
def find_shortest_path(start_end_pts, filtered_links):
    """
    Find the length-weighted shortest path between two points in a link network.

    Every consecutive vertex pair of every link becomes an undirected graph
    edge, keyed by coordinate tuple and weighted by its Euclidean length.

    Parameters:
    start_end_pts (geopandas.GeoDataFrame): First row is the source point, second row the target.
    filtered_links (geopandas.GeoDataFrame): Network of line segments in a 'geometry' column.

    Returns:
    geopandas.GeoDataFrame: One row holding the shortest path as a LineString,
    with the same CRS as ``filtered_links``.

    Raises:
    networkx.NetworkXNoPath / networkx.NodeNotFound: Propagated from
        ``nx.shortest_path`` when the points are disconnected or not on the network.
    """
    G = nx.Graph()
    for line in filtered_links['geometry']:
        vertices = list(line.coords)
        for a, b in zip(vertices[:-1], vertices[1:]):
            # Point.distance keeps edge weights bit-identical to earlier runs,
            # which matters for tie-breaking between equal-length grid paths.
            G.add_edge(a, b, weight=Point(a).distance(Point(b)))

    start_point = tuple(start_end_pts.geometry.iloc[0].coords[0])
    end_point = tuple(start_end_pts.geometry.iloc[1].coords[0])
    shortest_path = nx.shortest_path(G, source=start_point, target=end_point, weight='weight')
    shortest_path_line = LineString(shortest_path)
    shortest_path_gdf = gpd.GeoDataFrame({'geometry': [shortest_path_line]}, crs=filtered_links.crs)
    return shortest_path_gdf

# Function to classify channels
def classify_channels(filtered_links, shortest_path):
    """
    Label each link as part of the main channel or not.

    A link is labelled 'main_channel' when its geometry lies within the
    shortest-path line (``geometry.within``), and 'other' otherwise.

    Parameters:
    filtered_links (geopandas.GeoDataFrame): Links in a 'geometry' column. Not modified.
    shortest_path (geopandas.GeoDataFrame): Frame whose first geometry is the main-channel LineString.

    Returns:
    geopandas.GeoDataFrame: A copy of ``filtered_links`` with an added 'chnl_cat' column.
    """
    classified_links = filtered_links.copy()
    shortest_path_line = shortest_path.geometry.iloc[0]
    # A plain list lets an empty links frame get an empty column.
    # DataFrame.apply(axis=1) on an empty frame can return a DataFrame,
    # which cannot be assigned to a single column.
    classified_links['chnl_cat'] = [
        'main_channel' if geom.within(shortest_path_line) else 'other'
        for geom in classified_links['geometry']
    ]
    return classified_links

# Function to simplify shortest path
def simplify_shortest_path(shortest_path, num_vertices=10):
    """
    Resample the shortest path to ``num_vertices`` evenly spaced vertices.

    Vertex i is placed at fraction i / (num_vertices - 1) of the path length,
    so the first and last vertices coincide with the path's ends.

    Parameters:
    shortest_path (geopandas.GeoDataFrame): Frame whose first geometry is the path LineString.
    num_vertices (int, optional): Number of vertices in the simplified path (>= 2). Default is 10.

    Returns:
    geopandas.GeoDataFrame: One row holding the simplified LineString, with the
    same CRS as ``shortest_path``.

    Raises:
    ValueError: If ``num_vertices`` is less than 2. Previously a value of 1
        raised ZeroDivisionError and 0 produced an empty line.
    """
    if num_vertices < 2:
        raise ValueError("num_vertices must be at least 2.")

    original_line = shortest_path.geometry.iloc[0]
    simplified_coords = [
        original_line.interpolate(i / (num_vertices - 1), normalized=True).coords[0]
        for i in range(num_vertices)
    ]
    simplified_line = LineString(simplified_coords)
    simplified_path_gdf = gpd.GeoDataFrame({'geometry': [simplified_line]}, crs=shortest_path.crs)
    return simplified_path_gdf

# Function to create perpendicular lines
def create_perpendicular_lines(simplified_path, num_lines=10, fraction_length=1/5):
    """
    Creates cross-section lines perpendicular to the simplified path at equal intervals.

    Cross-section i (0 <= i < num_lines) is centred on the point at distance
    i * length / num_lines along the path. The first one therefore sits on the
    path's start vertex and none sits on its end vertex. Each line is
    perpendicular to the path segment containing that distance, and its total
    length is fraction_length * path length.

    Parameters:
    simplified_path (geopandas.GeoDataFrame): A GeoDataFrame containing the simplified path as a LineString.
    num_lines (int): Number of perpendicular lines to create (must be positive).
    fraction_length (float): Fraction of the total path length for the length of each perpendicular line.

    Returns:
    geopandas.GeoDataFrame: A GeoDataFrame containing the perpendicular lines.

    Raises:
    ValueError: If ``num_lines`` is not positive.
    """
    if num_lines <= 0:
        raise ValueError("num_lines must be a positive integer.")

    line = simplified_path.geometry.iloc[0]
    line_length = line.length
    spacing = line_length / num_lines
    half_length = (line_length * fraction_length) / 2

    # (start distance, end distance, start coord, end coord) for each segment.
    # The containing segment is found by distance *along* the path. Projecting
    # the point onto each segment instead would pick an earlier segment
    # whenever the path bends back alongside it, giving the wrong orientation.
    coords = list(line.coords)
    segments = []
    travelled = 0.0
    for a, b in zip(coords[:-1], coords[1:]):
        seg_len = float(np.hypot(b[0] - a[0], b[1] - a[1]))
        segments.append((travelled, travelled + seg_len, a, b))
        travelled += seg_len

    perpendicular_lines = []
    for idx in range(num_lines):
        distance_along = idx * spacing
        point = line.interpolate(distance_along, normalized=False)
        # Zero-length segments can never satisfy start <= d < end, so they are skipped.
        segment = next(
            ((a, b) for seg_start, seg_end, a, b in segments
             if seg_start <= distance_along < seg_end),
            None,
        )
        if segment is None:
            print(f"No segment found for point {idx}: {point}")
            continue

        (x0, y0), (x1, y1) = segment[0][:2], segment[1][:2]
        dx = x1 - x0
        dy = y1 - y0
        length = np.sqrt(dx**2 + dy**2)
        # Unit normal to the segment (segment direction rotated by +90 degrees).
        perp_x, perp_y = -dy / length, dx / length

        start_point = Point(point.x + half_length * perp_x, point.y + half_length * perp_y)
        end_point = Point(point.x - half_length * perp_x, point.y - half_length * perp_y)
        perpendicular_lines.append(LineString([start_point, end_point]))

    channel_belt_cross_sections = gpd.GeoDataFrame({'geometry': perpendicular_lines}, crs=simplified_path.crs)

    return channel_belt_cross_sections

# Function to calculate channel count index
def calc_channel_count_index(filtered_links, cross_sections):
    """
    Calculates the Channel Count Index (CCI): the mean number of channels crossed per cross-section.

    For each cross-section, every link row whose geometry intersects it is
    counted. All intersecting 'main_channel' links together count as a single
    channel.

    NOTE(review): 'other' links are counted row by row. If links are short
    per-pixel segments (as create_links produces), a secondary channel crossed
    exactly at a shared vertex may be counted twice. Confirm this is intended.

    Parameters:
    filtered_links (geopandas.GeoDataFrame): Links with a 'chnl_cat' column ('main_channel' / other).
    cross_sections (geopandas.GeoDataFrame): Cross-section lines in a 'geometry' column.

    Returns:
    tuple: A tuple containing:
        - cci (float): The Channel Count Index.
        - cross_sections (geopandas.GeoDataFrame): The same object, modified in
          place with an added 'channel_count' column.

    Raises:
    ValueError: If ``cross_sections`` is empty (the mean would be undefined).
    """
    if len(cross_sections) == 0:
        raise ValueError("cross_sections is empty; cannot compute the Channel Count Index.")

    is_main = np.asarray(filtered_links['chnl_cat'] == 'main_channel', dtype=bool)
    channel_counts = []
    for cross_section_geom in cross_sections['geometry']:
        # One intersection test per cross-section, reused for both counts.
        hits = np.asarray(filtered_links.intersects(cross_section_geom), dtype=bool)
        total_count = int(hits.sum())
        main_channel_count = int((hits & is_main).sum())
        if main_channel_count > 1:
            total_count -= main_channel_count - 1
        channel_counts.append(total_count)
    cross_sections['channel_count'] = channel_counts
    cci = sum(channel_counts) / len(channel_counts)
    print(f"Channel Count Index (CCI): {cci}")
    return cci, cross_sections

# Function to calculate sinuosity
def calc_sinuosity(shortest_path, simplified_path):
    """
    Ratio of main-channel length to valley (simplified path) length.

    Parameters:
    shortest_path (geopandas.GeoDataFrame): Frame whose first geometry is the main channel.
    simplified_path (geopandas.GeoDataFrame): Frame whose first geometry is the valley center-line.

    Returns:
    float: length(shortest path) / length(simplified path). The value is also printed.
    """
    channel_length = shortest_path.geometry.iloc[0].length
    valley_length = simplified_path.geometry.iloc[0].length
    sinuosity = channel_length / valley_length
    print(f"Sinuosity: {sinuosity}")
    return sinuosity

# Function to calculate channel form index
def calculate_channel_form_index(sinuosity, cci):
    """
    Channel Form Index: sinuosity divided by the Channel Count Index.

    Parameters:
    sinuosity (float): Channel sinuosity.
    cci (float): Channel Count Index (a zero value raises ZeroDivisionError).

    Returns:
    float: sinuosity / cci. The value is also printed.
    """
    form_index = sinuosity / cci
    print(f"Channel Form Index (CFI): {form_index}")
    return form_index

# Main function to process network
def process_network_folder(river, 
                           radius, 
                           year_range="All", 
                           reach_range="All", 
                           num_lines=10, 
                           num_vertices=10, 
                           fraction_length=1/5, 
                           root_input="C:/Users/huckr/Desktop/UCSB/Dissertation/Data/RiverMapping/RiverMasks", 
                           root_output="C:/Users/huckr/Desktop/UCSB/Dissertation/Data/RiverMapping/Channels"):
    """
    Processes a folder containing water mask rasters to extract river channel networks and calculate metrics.

    Masks are looked up as ``<root_input>/<river>/reach_<N>/<river>_reach_<N>_<YYYY>_*.tif``. For each
    selected mask, outputs are written to ``<root_output>/<river>/reach_<N>/<YYYY>/``. Sinuosity, CCI and
    CFI for every successfully processed mask are then written to ``<root_output>/<river>_metrics.xlsx``,
    with one sheet per reach and years sorted ascending.

    Parameters:
    river (str): Name of the river (matched literally against folder and file names).
    radius (int): Radius for conditional dilation.
    year_range (tuple, int or str): Inclusive (start, end) year range, a single year, or "All" (default).
    reach_range (tuple, int or str): Inclusive (start, end) reach range, a single reach, or "All" (default).
    num_lines (int): Number of perpendicular lines (cross-sections) (default is 10).
    num_vertices (int): Number of vertices for simplifying the shortest path (default is 10).
    fraction_length (float): Fraction length for creating cross-sections (default is 1/5).
    root_input (str): Root input directory (default is the specified path).
    root_output (str): Root output directory (default is the specified path).

    Returns:
    None

    Raises:
    ValueError: If year_range or reach_range is a string other than "All".
    """
    input_folder = os.path.join(root_input, river)
    output_folder_base = os.path.join(root_output, river)

    # "All" expands to an arbitrarily wide range; a single int selects just that year/reach.
    year_start, year_end = _normalize_range(year_range, (1000, 9999))
    reach_start, reach_end = _normalize_range(reach_range, (1, 9999))

    # Escape the river name so characters such as "." or "(" are matched literally.
    river_pattern = re.escape(river)

    # metrics[reach][year] -> {'Sinuosity': ..., 'CCI': ..., 'CFI': ...}
    metrics = {}

    # sorted() makes processing order deterministic (glob order is filesystem-dependent).
    for reach_folder in sorted(glob(os.path.join(input_folder, 'reach_*'))):
        match_reach = re.match(r"reach_(\d+)", os.path.basename(reach_folder))
        if not match_reach:
            continue
        reach = int(match_reach.group(1))
        if not reach_start <= reach <= reach_end:
            continue

        file_regex = re.compile(rf"{river_pattern}_reach_{reach}_(\d{{4}})_.*\.tif")
        for file_path in sorted(glob(os.path.join(reach_folder, '*.tif'))):
            match_year = file_regex.match(os.path.basename(file_path))
            if not match_year:
                continue
            year = int(match_year.group(1))
            if not year_start <= year <= year_end:
                continue

            output_folder = os.path.join(output_folder_base, f"reach_{reach}", str(year))
            os.makedirs(output_folder, exist_ok=True)

            try:
                file_metrics = _process_water_mask(file_path, output_folder, radius,
                                                   num_lines, num_vertices, fraction_length)
            except Exception as e:
                # Best-effort batch run: log with traceback and move on to the next mask.
                logging.exception("Error processing file %s: %s", file_path, e)
                continue

            metrics.setdefault(reach, {})[year] = file_metrics

    # An Excel workbook needs at least one sheet; don't fail at the very end of a run.
    if not metrics:
        logging.warning("No water masks were processed for %s; metrics workbook not written.", river)
        return

    # Save metrics to an Excel workbook: one sheet per reach (numeric order), years ascending.
    metrics_output_path = os.path.join(root_output, f'{river}_metrics.xlsx')
    with pd.ExcelWriter(metrics_output_path) as writer:
        for reach in sorted(metrics):
            df = pd.DataFrame.from_dict(metrics[reach], orient='index').sort_index()
            df.to_excel(writer, sheet_name=f"reach_{reach}")


def _normalize_range(value, default):
    """Turn a range spec ("All", a single int, or a (start, end) pair) into an inclusive (start, end) tuple."""
    if isinstance(value, str):
        if value == "All":
            return default
        raise ValueError(f'Range must be "All", an integer, or a (start, end) pair; got {value!r}')
    if isinstance(value, (int, np.integer)):
        return (int(value), int(value))
    start, end = value
    return (start, end)


def _process_water_mask(file_path, output_folder, radius, num_lines, num_vertices, fraction_length):
    """Run the skeleton/network/metrics pipeline on one mask, write its outputs, and return its metrics."""
    water_mask, metadata = load_raster(file_path)
    skeleton = skeletonize(water_mask > 0)
    dilated_skeleton = conditional_dilation(skeleton, radius)
    reskeletonized = skeletonize(dilated_skeleton > 0)
    largest_component = keep_largest_component(reskeletonized)
    save_raster(os.path.join(output_folder, 'largest_component.tif'), largest_component, metadata)

    node_points = create_nodes(largest_component, metadata)
    initial_links = create_links(largest_component, metadata)
    filtered_links = filter_links(initial_links)
    start_end_pts = find_furthest_endpoints(node_points)
    pruned_links = prune_network(node_points, filtered_links, start_end_pts)
    shortest_path_gdf = find_shortest_path(start_end_pts, pruned_links)
    classified_links = classify_channels(pruned_links, shortest_path_gdf)
    valley_center_line = simplify_shortest_path(shortest_path_gdf, num_vertices)
    channel_belt_cross_sections = create_perpendicular_lines(valley_center_line, num_lines, fraction_length)

    sinuosity_value = calc_sinuosity(shortest_path_gdf, valley_center_line)
    cci, updated_cross_sections = calc_channel_count_index(classified_links, channel_belt_cross_sections)
    cfi_value = calculate_channel_form_index(sinuosity_value, cci)

    classified_links.to_file(os.path.join(output_folder, 'channel_links.shp'))
    # Write the cross-sections returned by calc_channel_count_index (they carry 'channel_count').
    # NOTE(review): that name exceeds the 10-char shapefile field limit and is presumably the
    # source of the truncation warning seen in the notebook output.
    updated_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))
    node_points.to_file(os.path.join(output_folder, 'nodes.shp'))
    shortest_path_gdf.to_file(os.path.join(output_folder, 'main_channel.shp'))
    valley_center_line.to_file(os.path.join(output_folder, 'valley_center_line.shp'))

    return {
        'Sinuosity': sinuosity_value,
        'CCI': cci,
        'CFI': cfi_value,
    }

## Execute code for a river, a reach, or specific years

In [3]:
# Required user inputs

# Name of the river; it must match the name of the folder that holds its water masks
river = "Brahmaputra"

# Years to process: a (start, end) tuple such as (1997, 2017), a single year such as 2017,
# or "All" to process every available year for the selected reach(es)
year_range = "All"

# Reaches to process: a (start, end) tuple such as (1, 40), a single reach such as 8,
# or "All" to process every reach of the river
reach_range = (25, 28)

# Optional user inputs
root_input = "C:/Users/huckr/Desktop/UCSB/Dissertation/Data/RiverMapping/RiverMasks"
root_output = "C:/Users/huckr/Desktop/UCSB/Dissertation/Data/RiverMapping/Channels"
radius = 3
num_lines = 10
num_vertices = 5
fraction_length = 1 / 3

# Process network
process_network_folder(
    river,
    radius,
    year_range=year_range,
    reach_range=reach_range,
    num_lines=num_lines,
    num_vertices=num_vertices,
    fraction_length=fraction_length,
    root_input=root_input,
    root_output=root_output,
)

Sinuosity: 1.2372219337513015
Channel Count Index (CCI): 10.6
Channel Form Index (CFI): 0.11671905035389638


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2335158798187458
Channel Count Index (CCI): 12.6
Channel Form Index (CFI): 0.09789808569990047


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2271631251943798
Channel Count Index (CCI): 7.4
Channel Form Index (CFI): 0.16583285475599727


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.215190930048138
Channel Count Index (CCI): 11.9
Channel Form Index (CFI): 0.10211688487799479


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2134020148285742
Channel Count Index (CCI): 14.8
Channel Form Index (CFI): 0.08198662262355232


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.1976217188575813
Channel Count Index (CCI): 9.9
Channel Form Index (CFI): 0.12097189079369508


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2085707840791324
Channel Count Index (CCI): 13.8
Channel Form Index (CFI): 0.08757759304921249


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2223609987470696
Channel Count Index (CCI): 12.1
Channel Form Index (CFI): 0.10102157014438592


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2033042420572442
Channel Count Index (CCI): 7.3
Channel Form Index (CFI): 0.16483619754208825


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2290428228541121
Channel Count Index (CCI): 10.9
Channel Form Index (CFI): 0.11275622228019377


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2276142973623139
Channel Count Index (CCI): 6.7
Channel Form Index (CFI): 0.18322601453168863


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.24888923546072
Channel Count Index (CCI): 7.7
Channel Form Index (CFI): 0.1621934072026909


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.22747542293799
Channel Count Index (CCI): 12.1
Channel Form Index (CFI): 0.10144424982958596


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.1961433827704657
Channel Count Index (CCI): 11.3
Channel Form Index (CFI): 0.10585339670535095


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.2311470043903034
Channel Count Index (CCI): 8.3
Channel Form Index (CFI): 0.1483309643843739


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.210467265652506
Channel Count Index (CCI): 12.0
Channel Form Index (CFI): 0.10087227213770883


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


Sinuosity: 1.236727570460706
Channel Count Index (CCI): 12.4
Channel Form Index (CFI): 0.09973609439199241


  channel_belt_cross_sections.to_file(os.path.join(output_folder, 'channel_belt_cross_sections.shp'))


KeyboardInterrupt: 