# Extract WBMsed data

The following code takes the reach shapefile for a given river and extracts sediment flux metrics from the global WBMsed database. For each reach, the code finds the maximum-valued pixel within the reach polygon, then fits a best-fit curve (a second-degree polynomial by default) through the reach values of each metric. This allows outlier pixels (such as those that belong to an adjacent, larger river system) to be identified. The outlier values are then removed and replaced by linear interpolation between the adjacent upstream and downstream reaches. The code was developed using the Rio Bermejo in Argentina and compared with field data from Repasch et al. (2020); the WBMsed data are based on the model developed and validated in Cohen et al. (2013). Outputs include the mean flux of bedload, suspended bed material, washload, and total sediment (all in kg per s), mean particle size (in m), and mean discharge (in cubic meters per second).

WBMsed model: 
Cohen, S., Kettner, A.J., Syvitski, J.P.M., Fekete, B.M., 2013. WBMsed, a distributed global-scale riverine sediment flux model: Model description and validation. Computers & Geosciences, Modeling for Environmental Change 53, 80–93. https://doi.org/10.1016/j.cageo.2011.08.011

Rio Bermejo case study with sediment flux estimates:
Repasch, M., Wittmann, H., Scheingross, J.S., Sachse, D., Szupiany, R., Orfeo, O., Fuchs, M., Hovius, N., 2020. Sediment Transit Time and Floodplain Storage Dynamics in Alluvial Rivers Revealed by Meteoric 10Be. Journal of Geophysical Research: Earth Surface 125, e2019JF005419. https://doi.org/10.1029/2019JF005419

Based on code developed by Evan Greenberg, PhD:
https://github.com/evan-greenbrg/MeanderMigration/blob/main/GetWBMdata_new_wbm.py

Code Author: James (Huck) Rees; PhD Student, UCSB Geography

Date: October 17, 2024

## Import packages

In [1]:
import os
import glob
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
from rasterio.features import geometry_mask
from shapely.geometry import mapping
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

## Initialize functions

In [2]:
def load_asc_rasters(directory_path):
    """
    Load all .asc raster files from a directory, assigning a WGS 84 CRS where none is defined.

    Args:
        directory_path (str): Path to the directory containing the .asc files.

    Returns:
        dict: Maps each file's base name (extension removed) to a tuple
            (raster data, metadata). The raster data is band 1 of the file and the
            metadata is the dictionary reported by rasterio. Entries are inserted
            in sorted file-path order.

    Raises:
        FileNotFoundError: If no .asc files are found in directory_path.
    """
    # Sort so the dictionary order (and therefore the column order of anything
    # built from it) does not depend on the order glob happens to return.
    asc_files = sorted(glob.glob(os.path.join(directory_path, '*.asc')))

    if not asc_files:
        raise FileNotFoundError(f"No .asc files found in {directory_path}")

    rasters = {}

    # Fallback CRS definition, used only when rasterio reports no CRS for a file.
    epsg_4326_wkt = 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["degree",0.0174532925199433]]'

    for file_path in asc_files:
        with rasterio.open(file_path) as src:
            raster_data = src.read(1)
            metadata = src.meta

            if not metadata.get('crs'):
                print(f"No CRS found for {file_path}. Manually assigning EPSG:4326 (WGS 84).")
                metadata['crs'] = epsg_4326_wkt

            # splitext removes only the trailing extension (whatever its case);
            # str.replace would also strip '.asc' appearing mid-name.
            file_key = os.path.splitext(os.path.basename(file_path))[0]
            rasters[file_key] = (raster_data, metadata)

    return rasters

def extract_raster_values(rasters, lat, lon):
    """
    Extract the value of every raster in the dictionary at a given latitude and longitude.

    Args:
        rasters (dict): Dictionary mapping raster names to (raster data, metadata) tuples.
            The metadata must contain the raster's 'transform'.
        lat (float): Latitude of the point.
        lon (float): Longitude of the point.

    Returns:
        dict: Maps each raster name to the value at the given lat/lon, or to None when
            the point falls outside that raster.
    """
    extracted_values = {}

    for raster_name, (raster_data, metadata) in rasters.items():
        transform = metadata['transform']
        row, col = rasterio.transform.rowcol(transform, lon, lat)

        # Check the bounds explicitly: a point west or north of the raster gives a
        # negative row/col, which NumPy would silently wrap around to the opposite
        # edge instead of raising IndexError.
        n_rows, n_cols = raster_data.shape[0], raster_data.shape[1]
        if 0 <= row < n_rows and 0 <= col < n_cols:
            extracted_values[raster_name] = raster_data[row, col]
        else:
            print(f"Coordinates (lat: {lat}, lon: {lon}) are out of bounds for {raster_name}.")
            extracted_values[raster_name] = None

    return extracted_values

def extract_reach_raster_values(root_dir, river_name, rasters):
    """
    Extract the maximum raster value within each polygon reach of a river.

    The reach shapefile is read from <root_dir>/<river_name>/<river_name>.shp and
    reprojected to EPSG:4326 when it is in a different CRS.

    Args:
        root_dir (str): Root directory where the river folders are stored.
        river_name (str): Name of the river.
        rasters (dict): Dictionary mapping raster names to (raster data, metadata) tuples.

    Returns:
        pd.DataFrame: One row per reach with its 'ds_order' and one 'max_<raster name>'
            column per raster. The value is NaN when the reach covers no valid pixel
            (nodata and NaN cells are ignored) or when extraction fails for that raster.

    Raises:
        FileNotFoundError: If the shapefile does not exist.
        ValueError: If the shapefile has no 'ds_order' column or no CRS.
    """
    shapefile_path = os.path.join(root_dir, river_name, f"{river_name}.shp")

    if not os.path.exists(shapefile_path):
        raise FileNotFoundError(f"Shapefile not found: {shapefile_path}")

    gdf = gpd.read_file(shapefile_path)

    if 'ds_order' not in gdf.columns:
        raise ValueError(f"'ds_order' column not found in {shapefile_path}")

    target_crs = "EPSG:4326"
    if gdf.crs is None:
        # Without a source CRS the geometries cannot be reprojected; fail with a
        # message that names the file.
        raise ValueError(f"Shapefile has no CRS and cannot be reprojected to {target_crs}: {shapefile_path}")
    if gdf.crs != target_crs:
        print(f"Reprojecting {river_name} shapefile from {gdf.crs} to {target_crs}.")
        gdf = gdf.to_crs(target_crs)

    data = []

    for reach_id, reach_geometry in zip(gdf['ds_order'], gdf['geometry']):
        geometry = [mapping(reach_geometry)]
        reach_data = {'ds_order': reach_id}

        for raster_name, (raster_data, metadata) in rasters.items():
            column = f'max_{raster_name}'
            # Best-effort per raster: a failure is reported and recorded as NaN so the
            # remaining rasters and reaches are still processed.
            try:
                mask_array = geometry_mask(
                    geometries=geometry,
                    transform=metadata['transform'],
                    invert=True,
                    out_shape=raster_data.shape
                )

                # Select only the pixels inside the reach rather than building a
                # NaN-filled copy of the whole raster.
                pixels = raster_data[mask_array]

                # Nodata cells are not measurements; drop them so a reach lying entirely
                # on nodata does not report the nodata sentinel as its maximum.
                nodata = metadata.get('nodata')
                if nodata is not None:
                    pixels = pixels[pixels != nodata]

                pixels = pixels.astype(float)
                pixels = pixels[~np.isnan(pixels)]

                reach_data[column] = pixels.max() if pixels.size else np.nan
            except Exception as e:
                print(f"Error extracting raster values for {raster_name} in reach {reach_id}: {e}")
                reach_data[column] = np.nan

        data.append(reach_data)

    return pd.DataFrame(data)

def remove_outliers_and_interpolate(df, raster_name, degree=2, std_dev_threshold=2):
    """
    Replace outliers in one raster column by linear interpolation between neighbouring reaches.

    A polynomial of the given degree is fitted to the column as a function of
    'ds_order'. Rows whose residual exceeds std_dev_threshold standard deviations of
    the residuals are treated as outliers. Each outlier is replaced by linear
    interpolation (in 'ds_order') between the nearest preceding and following rows
    that are neither outliers nor NaN. Neighbours are found by row position, so the
    rows are expected to be in downstream order. Outliers without a valid neighbour
    on both sides are left unchanged.

    Args:
        df (pd.DataFrame): Table with a 'ds_order' column and the column to clean.
            The cleaned column is written back into this DataFrame.
        raster_name (str): Name of the column to clean.
        degree (int): Degree of the polynomial used for the fit.
        std_dev_threshold (float): Outlier threshold, in standard deviations of the residuals.

    Returns:
        pd.DataFrame: The same DataFrame, with outliers in raster_name replaced. It is
            returned untouched when the column is entirely NaN or has no outliers.
    """
    x = df['ds_order'].to_numpy(dtype=float)
    # Work on a private float copy: the array behind the column may be read-only
    # (pandas copy-on-write) and None entries must become NaN for np.isnan.
    y = df[raster_name].to_numpy(dtype=float, copy=True)

    valid = ~np.isnan(y)
    if not valid.any():
        # Nothing to fit and nothing to interpolate from.
        return df

    poly = PolynomialFeatures(degree=degree)
    model = LinearRegression()
    model.fit(poly.fit_transform(x[valid].reshape(-1, 1)), y[valid])

    y_pred_initial = model.predict(poly.transform(x.reshape(-1, 1)))

    residuals = y - y_pred_initial
    std_dev = np.nanstd(residuals)
    # NaN residuals compare False here, so missing values are never flagged as outliers.
    outliers = np.abs(residuals) > std_dev_threshold * std_dev

    if not np.any(outliers):
        return df

    # Rows usable as interpolation endpoints: not an outlier and not missing.
    # Using a NaN row as an endpoint would propagate NaN into the replaced values.
    anchor_indices = np.flatnonzero(valid & ~outliers)

    for j in np.flatnonzero(outliers):
        # Number of anchors before row j; the anchors on either side of that
        # insertion point bracket the outlier.
        position = np.searchsorted(anchor_indices, j)
        if position == 0 or position == anchor_indices.size:
            # No valid neighbour on one side: leave the value as it is.
            continue

        upstream_index = anchor_indices[position - 1]
        downstream_index = anchor_indices[position]

        span = x[downstream_index] - x[upstream_index]
        if span == 0:
            # Duplicate ds_order values on both sides; interpolation is undefined.
            continue

        fraction = (x[j] - x[upstream_index]) / span
        y[j] = y[upstream_index] + fraction * (y[downstream_index] - y[upstream_index])

    df[raster_name] = y

    return df

def remove_outliers_and_interpolate_all_rasters(df, degree=2, std_dev_threshold=2):
    """
    Run outlier removal and linear interpolation on every column except 'ds_order'.

    Args:
        df (pd.DataFrame): Table with a 'ds_order' column plus the raster columns.
        degree (int): Polynomial degree passed to remove_outliers_and_interpolate.
        std_dev_threshold (float): Outlier threshold passed to remove_outliers_and_interpolate.

    Returns:
        pd.DataFrame: The DataFrame returned by the last per-column call, or df itself
            when there are no raster columns.
    """
    # Fix the list of columns up front, before any column is processed.
    targets = tuple(name for name in df.columns if name != 'ds_order')

    cleaned = df
    for column in targets:
        cleaned = remove_outliers_and_interpolate(cleaned, column, degree, std_dev_threshold)

    return cleaned

def _wbmsed_column_name(col):
    """Build the exported column name, mean_<metric>_<unit>, for an extracted raster column."""
    # Extracted columns look like 'max_<prefix>_<metric>_...', so the metric is the
    # third '_'-separated part. For shorter names fall back to the last part so the
    # metric is always defined.
    parts = col.split("_")
    metric = parts[2] if len(parts) >= 3 else parts[-1]

    # Unit suffix chosen from the metric name; the first matching keyword wins.
    unit = ""
    for keyword, keyword_unit in (("Flux", "kg_s"), ("Size", "m"), ("Discharge", "cms")):
        if keyword in metric:
            unit = keyword_unit
            break

    # strip("_") drops the trailing underscore left behind when there is no unit.
    return f"mean_{metric}_{unit}".strip("_")


def extract_wbmsed(root_dir, river_name, raster_dir, degree=2, std_dev_threshold=2):
    """
    Wrapper function to load rasters, extract values for reaches, remove outliers using interpolation,
    rename columns, and export the DataFrame as a CSV file in the root directory.

    Args:
        root_dir (str): Root directory for river shapefiles (also used for saving the CSV output).
        river_name (str): Name of the river.
        raster_dir (str): Directory containing the .asc raster files.
        degree (int): Degree of the polynomial for curve fitting.
        std_dev_threshold (float): Threshold for identifying outliers (in standard deviations).

    Returns:
        pd.DataFrame: DataFrame with raster values, outliers removed, interpolated, and renamed columns.
            The same table is written to <root_dir>/<river_name>_wbmsed.csv.
    """
    # Step 1: Load rasters
    rasters = load_asc_rasters(raster_dir)

    # Step 2: Extract raster values for each reach
    df = extract_reach_raster_values(root_dir, river_name, rasters)

    # Step 3: Remove outliers and interpolate for all raster columns
    df_corrected = remove_outliers_and_interpolate_all_rasters(df, degree, std_dev_threshold)

    # Step 4: Rename columns (except 'ds_order') to mean_<metric>_<unit> in one pass
    rename_map = {
        col: _wbmsed_column_name(col)
        for col in df_corrected.columns
        if col != 'ds_order'
    }
    df_corrected.rename(columns=rename_map, inplace=True)

    # Step 5: Export the DataFrame as a CSV file to the root directory
    csv_output_path = os.path.join(root_dir, f"{river_name}_wbmsed.csv")
    df_corrected.to_csv(csv_output_path, index=False)

    print(f"DataFrame exported to {csv_output_path}")

    return df_corrected


## Enter input arguments and run

In [3]:
# Inputs: folder of .asc rasters, folder holding one sub-folder per river
# (containing <river_name>.shp), and the river to process.
raster_directory = r"D:\Dissertation\Data\WBMsed\RivMapperASCs"
reach_directory = r"D:\Dissertation\Data\RiverMapping\Reaches"
river_name = 'Beni'

# Run the full workflow; the resulting table is also written as a CSV under reach_directory.
df_final = extract_wbmsed(
    root_dir=reach_directory,
    river_name=river_name,
    raster_dir=raster_directory,
)

No CRS found for D:\Dissertation\Data\WBMsed\RivMapperASCs\Global_BedloadFlux_4p4p1+Dist_06min_aTS1990-2019.asc. Manually assigning EPSG:4326 (WGS 84).
No CRS found for D:\Dissertation\Data\WBMsed\RivMapperASCs\Global_Discharge_4p4p1+Dist_06min_aTS1990-2019.asc. Manually assigning EPSG:4326 (WGS 84).
No CRS found for D:\Dissertation\Data\WBMsed\RivMapperASCs\Global_ParticleSize_4p4p1+Dist_06min_aTS1990-2019.asc. Manually assigning EPSG:4326 (WGS 84).
No CRS found for D:\Dissertation\Data\WBMsed\RivMapperASCs\Global_SedimentFlux_4p4p1+Dist_06min_aTS1990-2019.asc. Manually assigning EPSG:4326 (WGS 84).
No CRS found for D:\Dissertation\Data\WBMsed\RivMapperASCs\Global_SuspendedBedFlux_4p4p1+Dist_06min_aTS1990-2019.asc. Manually assigning EPSG:4326 (WGS 84).
No CRS found for D:\Dissertation\Data\WBMsed\RivMapperASCs\Global_WashloadFlux_4p4p1+Dist_06min_aTS1990-2019.asc. Manually assigning EPSG:4326 (WGS 84).
Reprojecting Beni shapefile from PROJCS["WGS_84_World_Mercator",GEOGCS["WGS 84",DA