# In [None]:
import ee
import geopandas as gpd
from shapely.geometry import shape
import geemap
import folium

def get_basin_and_rivers(hydrobasins_col, rivers_col, coordinates):
    """Select the basin(s) containing a point and the rivers crossing them.

    Args:
        hydrobasins_col: ee.FeatureCollection of basin polygons.
        rivers_col: ee.FeatureCollection of river reaches.
        coordinates: ``[x, y]`` (longitude, latitude) passed to
            ``ee.Geometry.Point``.

    Returns:
        ``(basin, rivers)``: the basin features whose bounds contain the point,
        and the river features intersecting the basin geometry. The basin
        collection is also fetched and printed as a side effect.
    """
    basin = hydrobasins_col.filterBounds(ee.Geometry.Point(coordinates))
    basin_geometry = basin.geometry()
    rivers = rivers_col.filterBounds(basin_geometry)
    # Debug output: downloads the full basin feature(s) from Earth Engine.
    print('Intersecting Basin:', basin.getInfo())
    return basin, rivers

def get_river_elev_slope(dem_img, basin, rivers):
    """Clip elevation and slope to the river network and print their ranges.

    Args:
        dem_img: ee.Image DEM (the file passes a single 'elevation' band).
        basin: region over which the printed min/max are reduced.
        rivers: ee.FeatureCollection of reaches the images are clipped to.

    Returns:
        ``(elevation_rivers, slopes_rivers)``: ee.Images clipped to ``rivers``.
    """
    def print_min_max(region, image, label):
        # Print (not return) the per-band min/max of `image` over `region`.
        min_max = image.reduceRegion(
            reducer=ee.Reducer.minMax(),
            geometry=region,
            scale=30,
            maxPixels=1e9,
        )
        print(f'{label} range:', min_max.getInfo())

    elevation_rivers = dem_img.clip(rivers)
    # ee.Terrain.slope derives the gradient from each pixel's 4-connected
    # neighbours, so values go missing at image edges. It must therefore run on
    # the full DEM *before* clipping: on a DEM already clipped to thin river
    # lines almost every pixel is an edge pixel.
    slopes_rivers = ee.Terrain.slope(dem_img).clip(rivers)
    print_min_max(basin, elevation_rivers, 'Elevation')
    print_min_max(basin, slopes_rivers, 'Slopes')
    return elevation_rivers, slopes_rivers

def map_basin_river(Map, basin, rivers):
    """Draw the basin (orange) and the river reaches (blue), then zoom to the basin.

    Args:
        Map: geemap map, or anything exposing ``addLayer`` / ``centerObject``.
        basin: Earth Engine object drawn as the 'Basin' layer; also the zoom target.
        rivers: Earth Engine object drawn as the 'Rivers' layer.

    Returns:
        The same ``Map`` object, so calls can be chained.
    """
    layer_specs = [
        (basin, 'orange', 'Basin'),
        (rivers, 'blue', 'Rivers'),
    ]
    for ee_object, line_color, layer_name in layer_specs:
        Map.addLayer(ee_object, {'color': line_color, 'strokeWidth': 1}, layer_name)
    Map.centerObject(basin, 12)  # second argument is the zoom level
    return Map

def map_elevation_slope(elevation_rivers, slopes_rivers, Map):
    """Add the river elevation and slope rasters as styled map layers.

    Args:
        elevation_rivers: elevation image, stretched over 0-5000.
        slopes_rivers: slope image, stretched over 11-28.
        Map: geemap map (anything with ``addLayer``).

    Returns:
        The same ``Map`` object.
    """
    def stretch(lower, upper):
        # Fresh dict/palette per layer; both layers share the same colour ramp.
        return {
            'min': lower,
            'max': upper,
            'palette': ['blue', 'green', 'yellow', 'brown', 'white'],
        }

    Map.addLayer(elevation_rivers, stretch(0, 5000), 'Elevation rivers')
    Map.addLayer(slopes_rivers, stretch(11, 28), 'Slope rivers')
    return Map

def feature_collection_to_gdf(feature_collection, crs=None):
    """Download an Earth Engine FeatureCollection into a GeoDataFrame.

    Each feature that has a geometry becomes one row: a ``geometry`` column
    (shapely object) followed by one column per feature property. Features
    without a geometry are skipped.

    Args:
        feature_collection: ee.FeatureCollection, fetched with ``getInfo()``.
            NOTE(review): Earth Engine limits how many elements one getInfo
            call may return, so very large collections may need exporting.
        crs: CRS assigned to the result (e.g. ``"EPSG:4326"``); a falsy value
            leaves it unset.

    Returns:
        geopandas.GeoDataFrame. An empty collection yields an empty frame that
        still has a geometry column (and the requested CRS).
    """
    features = feature_collection.getInfo().get('features') or []
    records = []
    for feature in features:
        geometry = feature.get('geometry')
        if not geometry:
            continue  # nothing to place on a map
        # 'properties' may be missing or explicitly null in the GeoJSON.
        properties = feature.get('properties') or {}
        records.append({'geometry': shape(geometry), **properties})

    if not records:
        # GeoDataFrame([]) has no geometry column, and assigning a CRS to such
        # a frame fails, so build an explicit empty geometry column instead.
        return gpd.GeoDataFrame(geometry=[], crs=crs or None)
    return gpd.GeoDataFrame(records, geometry='geometry', crs=crs or None)

def get_length(rivers, gdf):
    """Add ``STRAIGHT_LINE_DISTANCE_KM``: endpoint-to-endpoint distance per reach.

    The distance is measured between the first and last vertex of each reach
    geometry. All reaches are processed server-side in a single Earth Engine
    request, instead of one filter plus one ``getInfo`` round trip per row.

    Args:
        rivers: ee.FeatureCollection of reaches carrying a ``REACH_ID`` property
            (assumed to be LineString or MultiLineString geometries).
        gdf: (Geo)DataFrame with a ``REACH_ID`` column; modified in place.

    Returns:
        ``gdf`` with the new column. If a ``REACH_ID`` occurs several times in
        ``rivers``, the first feature in collection order is used. Reaches
        absent from ``rivers`` get a missing value.
    """
    def endpoint_distance_pair(feature):
        # Server-side: [REACH_ID, straight-line distance in km] for one reach.
        feature = ee.Feature(feature)
        # Flattening turns LineString ([[x, y], ...]) and MultiLineString
        # ([[[x, y], ...], ...]) coordinates into [x0, y0, x1, y1, ...]. The
        # first vertex is therefore always the first two numbers, and the last
        # vertex is always the last two.
        flat = feature.geometry().coordinates().flatten()
        start_point = ee.Geometry.Point([flat.get(0), flat.get(1)])
        end_point = ee.Geometry.Point([flat.get(-2), flat.get(-1)])
        distance_km = start_point.distance(end_point).divide(1000)  # m -> km
        return ee.List([feature.get('REACH_ID'), distance_km])

    pairs = rivers.toList(rivers.size()).map(endpoint_distance_pair).getInfo()
    distance_by_reach = {}
    for reach_id, distance_km in pairs:
        # Keep the first feature per REACH_ID, matching the former `.first()`.
        distance_by_reach.setdefault(reach_id, distance_km)
    gdf['STRAIGHT_LINE_DISTANCE_KM'] = gdf['REACH_ID'].map(distance_by_reach.get)
    return gdf

def get_mean_elevation(elevation, rivers, gdf):
    """Add an ``ELEVATION`` column: mean of the 'elevation' band along each reach.

    All reaches are reduced in one Earth Engine request, instead of one
    filter plus one ``getInfo`` round trip per row.

    Args:
        elevation: ee.Image with an ``'elevation'`` band.
        rivers: ee.FeatureCollection of reaches carrying a ``REACH_ID`` property.
        gdf: (Geo)DataFrame with a ``REACH_ID`` column; modified in place.

    Returns:
        ``gdf`` with ``ELEVATION`` added. The value is missing when no pixel
        is reduced for the reach, or when the reach is absent from ``rivers``.
        Duplicate ``REACH_ID``s use the first feature in collection order.
    """
    def mean_elevation_pair(feature):
        # Server-side: [REACH_ID, mean elevation at 30 m scale] for one reach.
        feature = ee.Feature(feature)
        geometry = feature.geometry()
        mean_elevation = elevation.clip(geometry).reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=geometry,
            scale=30,
        ).get('elevation')
        return ee.List([feature.get('REACH_ID'), mean_elevation])

    pairs = rivers.toList(rivers.size()).map(mean_elevation_pair).getInfo()
    mean_by_reach = {}
    for reach_id, mean_elevation in pairs:
        # Keep the first feature per REACH_ID, matching the former `.first()`.
        mean_by_reach.setdefault(reach_id, mean_elevation)
    gdf['ELEVATION'] = gdf['REACH_ID'].map(mean_by_reach.get)
    return gdf

def get_mean_slope(elevation, rivers, gdf):
    """Add a ``SLOPE`` column: mean terrain slope along each reach.

    Slope is derived once from the full DEM with ``ee.Terrain.slope`` (before
    any clipping, so reach pixels keep their neighbours). All reaches are
    then reduced in one Earth Engine request instead of one per row.

    Args:
        elevation: ee.Image DEM passed to ``ee.Terrain.slope``.
        rivers: ee.FeatureCollection of reaches carrying a ``REACH_ID`` property.
        gdf: (Geo)DataFrame with a ``REACH_ID`` column; modified in place.

    Returns:
        ``gdf`` with ``SLOPE`` added. The value is missing when no pixel is
        reduced for the reach, or when the reach is absent from ``rivers``.
        Duplicate ``REACH_ID``s use the first feature in collection order.
    """
    slope = ee.Terrain.slope(elevation)

    def mean_slope_pair(feature):
        # Server-side: [REACH_ID, mean slope at 30 m scale] for one reach.
        feature = ee.Feature(feature)
        geometry = feature.geometry()
        mean_slope = slope.clip(geometry).reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=geometry,
            scale=30,
        ).get('slope')
        return ee.List([feature.get('REACH_ID'), mean_slope])

    pairs = rivers.toList(rivers.size()).map(mean_slope_pair).getInfo()
    slope_by_reach = {}
    for reach_id, mean_slope in pairs:
        # Keep the first feature per REACH_ID, matching the former `.first()`.
        slope_by_reach.setdefault(reach_id, mean_slope)
    gdf['SLOPE'] = gdf['REACH_ID'].map(slope_by_reach.get)
    return gdf

def get_sinuosity(gdf):
    """Add ``SINUOSITY`` = ``LENGTH_KM / STRAIGHT_LINE_DISTANCE_KM``, floored at 1.

    If a reach's straight-line distance is 0 (its first and last vertex
    coincide), the ratio is undefined. Plain division would give ``inf``,
    which would then pass any "greater than" sinuosity threshold, so such
    reaches get NaN instead. NaN inputs stay NaN.

    Args:
        gdf: DataFrame with ``LENGTH_KM`` and ``STRAIGHT_LINE_DISTANCE_KM``;
            modified in place.

    Returns:
        ``gdf`` with the ``SINUOSITY`` column.
    """
    ratio = gdf['LENGTH_KM'] / gdf['STRAIGHT_LINE_DISTANCE_KM']
    ratio = ratio.mask(ratio == float('inf'))  # zero straight-line distance -> NaN
    # TODO: Check with Natasha whether sinuosity always needs to be at least 1.
    gdf['SINUOSITY'] = ratio.clip(lower=1)
    return gdf

def plot_sinuosity(gdf, threshold):
    """Render reaches on an interactive map, coloured by sinuosity.

    Reaches with ``SINUOSITY >= threshold`` are green. All others, including
    NaN sinuosity, are red.

    Args:
        gdf: GeoDataFrame with a ``SINUOSITY`` column.
        threshold: cut-off between red and green.

    Returns:
        Whatever ``gdf.explore`` returns (an interactive map).
    """
    def colour_for(value):
        return "green" if value >= threshold else "red"

    reach_colours = gdf['SINUOSITY'].apply(colour_for)
    return gdf.explore(color=reach_colours)

def plot_csi(gdf, threshold):
    """Plot reaches coloured by their ``CSI`` value relative to a threshold.

    Reaches with ``CSI >= threshold`` are green. All others, including
    missing CSI, are red. A fixed-position HTML legend labelled with the actual
    threshold is attached. ``gdf`` itself is not modified.

    Args:
        gdf: GeoDataFrame with a ``CSI`` column (presumably the dataset's
            connectivity status index — confirm against the source data).
        threshold: cut-off between red and green.

    Returns:
        The folium map returned by ``gdf.explore``, with the legend added.
    """
    def colour_for(value):
        return "green" if value >= threshold else "red"

    # Computed locally rather than stored in gdf: no side effect on the
    # caller's frame, and no stray 'color' column in the map tooltips.
    reach_colours = gdf['CSI'].apply(colour_for)

    # Legend labels follow the threshold actually used (previously fixed at 99).
    legend_dict = {
        f"CSI>={threshold}": "green",
        f"CSI<{threshold}": "red",
    }
    legend_html = """
    <div style='
        position: fixed; 
        top: 10px; 
        right: 10px; 
        width: 220px; 
        z-index: 1000; 
        background-color: white; 
        border: 1px solid black; 
        border-radius: 5px; 
        padding: 10px; 
        font-size: 14px; 
        box-shadow: 2px 2px 5px rgba(0,0,0,0.3);
        line-height: 1.5;'>
        <b>Legend:</b><br>
    """
    legend_html += "".join(
        f"<i style='background:{color}; width:12px; height:12px; display:inline-block; margin-right:8px;'></i>{name}<br>"
        for name, color in legend_dict.items()
    )
    legend_html += "</div>"

    m = gdf.explore(color=reach_colours)
    m.get_root().html.add_child(folium.Element(legend_html))
    return m

def plot_fpz(gdf):
    """Plot reaches coloured by their FPZ label, with an HTML legend.

    A single label -> colour table drives both the feature colours and the
    legend, so the two cannot drift apart. Any label not in the table (e.g.
    'Unclassified', which get_classification assigns) is drawn red. ``gdf``
    itself is not modified.

    Args:
        gdf: GeoDataFrame with an ``FPZs`` column of zone labels.

    Returns:
        The folium map returned by ``gdf.explore``, with the legend added.
    """
    zone_colours = {
        "Lowland Alluvial": "blue",
        "Open-valley Mid-altitude": "pink",
        "Open-Valley Highland": "orange",
        "Highland High-Energy": "green",
    }
    fallback_colour = "red"

    # Computed locally so the caller's frame gains no 'color' column.
    reach_colours = gdf['FPZs'].apply(lambda zone: zone_colours.get(zone, fallback_colour))

    # The legend entry uses the label get_classification writes for no match.
    legend_dict = {**zone_colours, "Unclassified": fallback_colour}
    legend_html = """
    <div style='
        position: fixed; 
        top: 10px; 
        right: 10px; 
        width: 220px; 
        z-index: 1000; 
        background-color: white; 
        border: 1px solid black; 
        border-radius: 5px; 
        padding: 10px; 
        font-size: 14px; 
        box-shadow: 2px 2px 5px rgba(0,0,0,0.3);
        line-height: 1.5;'>
        <b>Legend:</b><br>
    """
    legend_html += "".join(
        f"<i style='background:{color}; width:12px; height:12px; display:inline-block; margin-right:8px;'></i>{name}<br>"
        for name, color in legend_dict.items()
    )
    legend_html += "</div>"

    m = gdf.explore(color=reach_colours)
    m.get_root().html.add_child(folium.Element(legend_html))
    return m

def get_classification(gdf):
    """Label every reach with the first matching FPZ rule.

    Each rule tests elevation, slope and sinuosity. Rules are tried in the
    order listed, and the first rule whose three tests all pass supplies the
    label. A reach that matches no rule gets 'Unclassified'. Because any
    comparison with NaN is False, this includes reaches with NaN inputs.

    Args:
        gdf: DataFrame with ``ELEVATION``, ``SLOPE`` and ``SINUOSITY``
            columns; modified in place.

    Returns:
        ``gdf`` with a new ``FPZs`` column.
    """
    # (label, elevation test, slope test, sinuosity test) -- order matters.
    rules = (
        ('Lowland Alluvial',
         lambda elev: elev < 200,
         lambda slope: slope < 2,
         lambda sinu: sinu > 1.4),
        ('Open-valley Mid-altitude',
         lambda elev: 200 <= elev <= 800,
         lambda slope: 2 <= slope <= 4,
         lambda sinu: sinu > 1.2),
        ('Open-Valley Highland',
         lambda elev: elev > 800,
         lambda slope: 4 <= slope <= 10,
         lambda sinu: 1.0 <= sinu <= 1.2),
        ('Highland High-Energy',
         lambda elev: elev > 800,
         lambda slope: slope > 10,
         lambda sinu: 1.0 <= sinu <= 1.1),
    )

    def first_matching_rule(row):
        for label, elevation_ok, slope_ok, sinuosity_ok in rules:
            if (elevation_ok(row['ELEVATION'])
                    and slope_ok(row['SLOPE'])
                    and sinuosity_ok(row['SINUOSITY'])):
                return label
        return 'Unclassified'

    gdf['FPZs'] = gdf.apply(first_matching_rule, axis=1)
    return gdf

def get_fpz(hydrobasins_col, rivers_col, dem_img, coordinates,
            output_path="fpzs_output.geojson"):
    """Run the full FPZ pipeline for the basin containing ``coordinates``.

    Steps:
        1. Select the basin and its rivers.
        2. Download the reaches.
        3. Add straight-line distance, sinuosity, mean elevation and mean
           slope.
        4. Classify each reach.
        5. Optionally write the result to GeoJSON.

    Args:
        hydrobasins_col: ee.FeatureCollection of basin polygons.
        rivers_col: ee.FeatureCollection of river reaches.
        dem_img: ee.Image DEM with an 'elevation' band.
        coordinates: ``[longitude, latitude]`` of the point of interest.
        output_path: GeoJSON file to write; a falsy value skips writing.
            Defaults to the previously hard-coded ``"fpzs_output.geojson"``.

    Returns:
        ``(basin, rivers, gdf_rivers)``: the Earth Engine basin and river
        collections, plus the GeoDataFrame restricted to the output columns.
    """
    output_columns = [
        'REACH_ID', 'LENGTH_KM', 'STRAIGHT_LINE_DISTANCE_KM', 'SINUOSITY',
        'ELEVATION', 'SLOPE', 'FPZs', 'CSI', 'CSI_FF1', 'CSI_FF2', 'CSI_FFID',
        'geometry',
    ]
    basin, rivers = get_basin_and_rivers(hydrobasins_col, rivers_col, coordinates)
    gdf_rivers = feature_collection_to_gdf(rivers, crs="EPSG:4326")
    gdf_rivers = get_length(rivers, gdf_rivers)
    gdf_rivers = get_sinuosity(gdf_rivers)  # needs STRAIGHT_LINE_DISTANCE_KM
    gdf_rivers = get_mean_elevation(dem_img, rivers, gdf_rivers)
    gdf_rivers = get_mean_slope(dem_img, rivers, gdf_rivers)
    gdf_rivers = get_classification(gdf_rivers)  # needs all three metrics
    gdf_rivers = gdf_rivers[output_columns]
    if output_path:
        gdf_rivers.to_file(output_path, driver="GeoJSON")
    return basin, rivers, gdf_rivers

def plot_elevation_slope_river_network(dem_img, basin, rivers):
    """Build a geemap map with the basin, rivers and river elevation/slope layers.

    Also prints the elevation and slope ranges, through get_river_elev_slope.

    Args:
        dem_img: ee.Image DEM.
        basin: basin collection (outline layer and zoom target).
        rivers: river reaches collection.

    Returns:
        The populated ``geemap.Map``.
    """
    elevation_rivers, slopes_rivers = get_river_elev_slope(dem_img, basin, rivers)
    return map_elevation_slope(
        elevation_rivers,
        slopes_rivers,
        map_basin_river(geemap.Map(), basin, rivers),
    )


# In [None]:
# ee.Authenticate()  # uncomment on first use to store Earth Engine credentials
ee.Initialize()

# Earth Engine inputs: HydroBASINS level-12 polygons, HydroSHEDS free-flowing
# river reaches, and the SRTM DEM restricted to its 'elevation' band.
hydrobasins_col = ee.FeatureCollection('WWF/HydroSHEDS/v1/Basins/hybas_12')
rivers_col = ee.FeatureCollection("WWF/HydroSHEDS/v1/FreeFlowingRivers")
dem_img = ee.Image('USGS/SRTMGL1_003').select('elevation')

# Sites of interest as [longitude, latitude] (the x, y order ee.Geometry.Point expects).
coord_Aracataca = [-73.9122858271122, 10.649659492838959]
coord_Fundacion = [-74.1775961943024, 10.507908590834537]

basin, rivers, gdf_rivers = get_fpz(hydrobasins_col, rivers_col, dem_img, coord_Aracataca)
# Optional visualisations (uncomment one at a time in a notebook cell):
# plot_elevation_slope_river_network(dem_img, basin, rivers)
# plot_csi(gdf_rivers, 99)
# plot_sinuosity(gdf_rivers,1)
# plot_fpz(gdf_rivers)