In [1]:
import xarray as xr
import numpy as np
import rioxarray
import geopandas as gpd
from ipyleaflet import Map, ImageOverlay, WidgetControl, LayersControl, basemaps, GeoJSON, TileLayer, Heatmap, Marker, Popup
import ipywidgets as widgets
from utils import leaflet_bounds, scalar_to_base64_image, find_intersections
import pandas as pd
from shapely.geometry import Point

In [28]:
# -------------------------------
# Data Paths
# -------------------------------
# Root folders: raw downloads and intermediate (derived) products.
download_path = "./data_download"
intermediate_path = "./data_intermediate"

# Vector inputs (GeoJSON).
vector_rivers = f"{download_path}/data/shp/cv_rivers.geojson"
vector_subbasin = f"{download_path}/data/shp/subbasins_cv_clip.geojson"

# Point input (CSV) and the consolidated metric store (Zarr).
points_resistivity = f"{download_path}/data/aem/em_resistivity.csv"
metric_zarr_path = f"{intermediate_path}/consolidated_metric_output.zarr"

# Currently unused terrain raster, kept for reference:
# terrain_path = f"{download_path}/data/cv_terrain.tiff"

# -------------------------------
# Map settings
# -------------------------------
# EPSG code every layer is reprojected to before it is handed to the maps.
target_epsg = 4326

# Initial view of the single-map figure, as [lat, lon] plus zoom level.
center = [37.66335291403956, -120.69523554193438]
zoom = 7

# Basemap selection is disabled for now (see the TODO where the Map is built).
# basemap = basemaps.CartoDB.Positron
basemap = None

# CSS size of each map widget.
map_width = '500px'
map_height = '800px'

In [29]:
# -------------------------------
# Data Loading
# -------------------------------

# River lines: read the GeoJSON with GeoPandas and reproject it to the
# map CRS (target_epsg) in a single chained expression.
rivers = gpd.read_file(vector_rivers).to_crs(epsg=target_epsg)

def combine_rivers_gdf(river_gdf, name_column='GNIS_Name'):
    """
    Combine river fragments with the same name into single features.

    Parameters:
    -----------
    river_gdf : GeoDataFrame with river geometries
    name_column : str, column containing river names

    Returns:
    --------
    GeoDataFrame with one row per distinct value of `name_column` and
    exactly two columns: `name_column` (under its original name - it is
    NOT renamed to 'name') and the dissolved 'geometry'. All other
    attribute columns are dropped.
    """
    # dissolve() moves the grouping column into the index; reset_index()
    # turns it back into a regular column so it can be selected below.
    # NOTE(review): with GeoPandas' default dissolve settings, rows whose
    # name is missing are presumably dropped - confirm this is intended.
    rivers_combined = river_gdf.dissolve(by=name_column).reset_index()
    return rivers_combined[[name_column, 'geometry']]

# Merge river fragments that share a name into one feature each.
rivers = combine_rivers_gdf(rivers)

# Subbasin polygons, reprojected to the map CRS.
subbasins = gpd.read_file(vector_subbasin)
subbasins = subbasins.to_crs(epsg=target_epsg)

# find outer basin intersections with rivers
# (done before the geometries are simplified further down)
river_intersections = find_intersections(rivers, subbasins)

# EPSG code assumed for BOTH the resistivity CSV coordinates and the metric
# grid. TODO: Confirm EPSG!!! (the CSV columns are called UTMX/UTMY, which
# may indicate a UTM zone rather than EPSG:3310 - verify with the data source)
source_epsg = 3310

df = pd.read_csv(points_resistivity)
# Create GeoDataFrame from the projected coordinates (vectorised point
# construction instead of a Python-level loop over rows).
resistivity_profiles = gpd.GeoDataFrame(
    df,
    geometry=gpd.points_from_xy(df['UTMX'], df['UTMY']),
    crs=f'EPSG:{source_epsg}'
)
resistivity_profiles = resistivity_profiles.to_crs(f'EPSG:{target_epsg}')

# For now subsample the points to keep the map layers responsive.
# TODO: Discuss better performance options
# - the sample size is capped at the number of available rows, because
#   sampling more rows than exist (without replacement) raises ValueError;
# - a fixed seed makes the figure reproducible across kernel restarts.
max_resistivity_points = 10000
resistivity_sample_seed = 42
resistivity_profiles = resistivity_profiles.sample(
    n=min(max_resistivity_points, len(resistivity_profiles)),
    random_state=resistivity_sample_seed,
)

# Load consolidated metric dataset
ds = xr.open_zarr(metric_zarr_path)
ds = ds.transpose('fraction', 'y', 'x')
# Order rows by descending y and columns by ascending x before reprojecting.
ds = ds.sortby('y', ascending=False)
ds = ds.sortby('x', ascending=True)
# Attach the source CRS (reassigning instead of mutating in place), then
# reproject the whole dataset to the map CRS.
ds = ds.rio.write_crs(source_epsg)
ds_reprojected = ds.rio.reproject(f"EPSG:{target_epsg}")

## TESTING: Simplify geometries before building layers to improve performance

# Simplify geometries (tolerance in degrees, ~0.001 = ~100m)
subbasins['geometry'] = subbasins.geometry.simplify(tolerance=0.001, preserve_topology=True)
rivers['geometry'] = rivers.geometry.simplify(tolerance=0.001, preserve_topology=True)


In [30]:
# Create layers
## Terrain context
# Both tile layers are served from ArcGIS Online and drawn fully opaque.
# Ocean basemap (includes bathymetry)
l_ocean = TileLayer(
    name='Ocean/Water',
    url='https://server.arcgisonline.com/ArcGIS/rest/services/Ocean/World_Ocean_Base/MapServer/tile/{z}/{y}/{x}',
    opacity=1.0,
)
# Hillshade with water areas
l_elevation = TileLayer(
    name='Hillshade',
    url='https://server.arcgisonline.com/ArcGIS/rest/services/Elevation/World_Hillshade/MapServer/tile/{z}/{y}/{x}',
    opacity=1.0,
)

# Vector layers
_river_style = {'color': 'blue', 'weight': 1, 'opacity': 0.7}
l_rivers = GeoJSON(data=rivers.__geo_interface__, style=_river_style, name="Rivers")

_subbasin_style = {'color': 'black', 'weight': 1, 'fill':False}
l_subbasins = GeoJSON(data=subbasins.__geo_interface__, style=_subbasin_style, name="Subbasins")


# Marker layers
def _inflow_marker(river_name, point):
    """Build a fixed marker at `point` whose popup shows the river name in bold."""
    inflow_marker = Marker(
        location=(point.y, point.x),  # ipyleaflet expects (lat, lon)
        title=river_name,
        draggable=False
    )
    inflow_marker.popup = Popup(
        child=widgets.HTML(value=f"<b>{river_name}</b>"),
        close_button=False,
        auto_close=False
    )
    return inflow_marker


# One marker per river/subbasin intersection, with the river name as tooltip.
river_inflow_layers = [
    _inflow_marker(river_name, point)
    for river_name, point in zip(
        river_intersections['river_name'], river_intersections['geometry']
    )
]

# intersections_layer = GeoJSON(
#     data=river_intersections.__geo_interface__,
#     point_style={'radius': 8, 'color': 'red', 'fillColor': 'orange', 'fillOpacity': 0.8},
#     name="River Exits"
# )


# Point layers
def _profile_point_layer(colour, layer_name):
    """GeoJSON layer of the subsampled profile points, drawn in one colour."""
    return GeoJSON(
        data=resistivity_profiles.__geo_interface__,
        point_style={
            'radius': 0.01,
            'color': colour,
            'fillColor': colour,
            'fillOpacity': 0.6,
            'weight': 0.1
        },
        name=layer_name,
    )


# Both point layers currently draw the same resistivity profile locations.
l_resistivity = _profile_point_layer('red', 'Resistivity Profiles (subsampled)')
l_sediment = _profile_point_layer('brown', 'Sediment Type Profiles (subsampled)')

# Heatmap layers
# Heatmap wants plain [lat, lon] pairs.
resistivity_locations = [[pt.y, pt.x] for pt in resistivity_profiles.geometry]

l_resistivity_heatmap = Heatmap(
    name='Resistivity Heatmap (subsampled)',
    locations=resistivity_locations,
    radius=5,
    blur=2,
)

l_sediment_heatmap = Heatmap(
    name='Sediment Heatmap (subsampled)',
    locations=resistivity_locations,  # Using same data for now
    radius=10,
    blur=2,
    gradient={0.4: 'blue', 0.6: 'cyan', 0.7: 'lime', 0.8: 'yellow', 1.0: 'red'},
)

In [31]:
#| label: interactive:fig-1

##########
# Figure 1
##########
# Single map with terrain context, rivers and subbasins, plus a dropdown
# that swaps which one of the four data layers is shown.
m = Map(
    center=center,
    zoom=zoom,
    # basemap=basemap, #TODO fix
    scroll_wheel_zoom=True,
    layout=widgets.Layout(width=map_width, height=map_height),
)

# Selectable data layers keyed by their display name; insertion order
# determines the order of the dropdown entries.
layer_map = {}
for data_layer in (l_resistivity, l_resistivity_heatmap, l_sediment, l_sediment_heatmap):
    layer_map[data_layer.name] = data_layer
init_key = next(iter(layer_map))

# Context layers first, then the resistivity point layer that is shown initially.
for static_layer in (l_ocean, l_elevation, l_rivers, l_subbasins, l_resistivity):
    m.add_layer(static_layer)

# Data layer currently on the map; rebound by on_layer_change.
current_layer = layer_map[init_key]

# Dropdown used to switch between the data layers.
layer_dropdown = widgets.Dropdown(
    description='Data Layer:',
    options=list(layer_map),
    value=init_key,
    style={'description_width': 'initial'},
)


def on_layer_change(change):
    """Swap the displayed data layer for the one picked in the dropdown."""
    global current_layer
    replacement = layer_map[change['new']]

    # Take the old layer off the map before the replacement goes on.
    m.remove_layer(current_layer)
    m.add_layer(replacement)
    current_layer = replacement


layer_dropdown.observe(on_layer_change, names='value')

# Dropdown in the top-right corner, layer toggles in the top-left.
widget_control = WidgetControl(widget=layer_dropdown, position='topright')
m.add_control(widget_control)
m.add_control(LayersControl(position='topleft'))
m

Map(center=[37.66335291403956, -120.69523554193438], controls=(ZoomControl(options=['position', 'zoom_in_text'…

In [33]:
from ipyleaflet import Map, GeoJSON, WidgetControl, LayersControl, basemaps
import ipywidgets as widgets
import geopandas as gpd
from IPython.display import display
from ipywidgets import jslink

class DualMapController:
    """Controller for synchronized dual maps with interactive controls.

    Two ipyleaflet maps are created side by side, with their `center` and
    `zoom` traits linked through `jslink`. The left map (`m1`) receives the
    FCD image layer and a subbasin selector; the right map (`m2`) receives a
    scalar image overlay together with a dataset dropdown and a threshold
    slider that re-render that overlay.

    Expected call order: `add_base_layers` -> `add_controls` ->
    `create_dataset_selector` -> `create_subbasin_selector` -> `display`.
    """

    # Lower/upper percentiles used as colour limits for the scalar overlay.
    # np.nanpercentile expects q on a 0-100 scale; the previous values of
    # 0.1 / 0.9 therefore selected the 0.1th and 0.9th percentile, which
    # saturated nearly every pixel. 10 / 90 is the presumed intent.
    COLOR_PERCENTILES = (10, 90)

    def __init__(self, width, height, center, zoom, subbasins, subbasin_column):
        """Create the two linked maps.

        Parameters
        ----------
        width, height : str
            CSS sizes applied to both map widgets.
        center : list
            Initial [lat, lon] of both maps.
        zoom : int
            Initial zoom level of both maps.
        subbasins : GeoDataFrame
            Polygons used for the subbasin selector and highlight.
        subbasin_column : str
            Column of `subbasins` holding the names shown in the selector.
        """
        self.subbasins = subbasins
        self.subbasin_column = subbasin_column

        # Create maps (both share one Layout object)
        layout = widgets.Layout(width=width, height=height)
        self.m1 = Map(center=center, zoom=zoom, basemap=basemaps.CartoDB.Positron, layout=layout)
        self.m2 = Map(center=center, zoom=zoom, basemap=basemaps.CartoDB.Positron, layout=layout)

        # Sync views
        jslink((self.m1, 'center'), (self.m2, 'center'))
        jslink((self.m1, 'zoom'), (self.m2, 'zoom'))

        # Style of the transient highlight drawn around a selected subbasin
        self.highlight_style = {
            'color': 'orange',
            'weight': 3,
            'opacity': 0.8,
            'fillColor': 'orange',
            'fillOpacity': 0.1,
            'dashArray': '5, 5'
        }

        # Initialize empty highlight - don't add to map yet
        self.current_highlight = None
        # Set by add_base_layers(); None means "no overlay to update yet".
        self.scalar_overlay = None
        # (vmin, vmax) per dataset name, filled lazily by _color_limits().
        self._limits_cache = {}
        self.debug_output = widgets.Output()

    def add_base_layers(self, rivers, subbasins, fcd_layer, scalar_overlay):
        """Add the base layers that appear in LayersControl.

        `rivers` and `subbasins` are added to both maps, `fcd_layer` only to
        the left map and `scalar_overlay` only to the right map.
        """
        self.m1.add_layer(rivers)
        self.m2.add_layer(rivers)
        self.m1.add_layer(subbasins)
        self.m2.add_layer(subbasins)
        self.m1.add_layer(fcd_layer)
        self.m2.add_layer(scalar_overlay)

        # Store scalar overlay reference for updates
        self.scalar_overlay = scalar_overlay

    def add_controls(self):
        """Add LayersControl - call this AFTER base layers"""
        self.m1.add_control(LayersControl(position='bottomleft', collapsed=False))
        self.m2.add_control(LayersControl(position='bottomleft', collapsed=False))

    def create_dataset_selector(self, ds, initial_dataset, initial_fraction=0.1):
        """Create dataset selection dropdown and threshold slider for right map.

        Parameters
        ----------
        ds : xarray.Dataset
            Source of the selectable variables; must have a `fraction`
            coordinate. The variable 'fraction_coarse' is not offered.
        initial_dataset : str
            Variable selected when the dropdown is created.
        initial_fraction : float, optional
            Starting value of the threshold slider (default 0.1, the value
            that used to be hard-coded). If it is not one of
            `ds.fraction.values`, the first available fraction is used
            instead of failing on an invalid slider value.
        """
        self.ds = ds
        self.current_dataset_name = initial_dataset
        self.da_scalar = ds[initial_dataset]
        # Colour limits cached for a previous dataset object are stale.
        self._limits_cache = {}

        # Dataset dropdown
        dataset_dropdown = widgets.Dropdown(
            options=[name for name in ds.data_vars if name != 'fraction_coarse'],
            value=initial_dataset,
            description='Dataset:',
            style={'description_width': 'initial'}
        )
        dataset_dropdown.observe(self._on_dataset_change, names='value')

        # Threshold slider
        fraction_options = list(ds.fraction.values)
        if initial_fraction not in fraction_options:
            initial_fraction = fraction_options[0]
        self.slider = widgets.SelectionSlider(
            options=fraction_options,
            value=initial_fraction,
            description='FCD Threshold',
            style={'description_width': 'initial'}
        )
        self.slider.observe(self._on_threshold_change, names='value')

        # Combine and add to map
        controls = widgets.VBox([dataset_dropdown, self.slider])
        widget_control = WidgetControl(widget=controls, position='topright')
        self.m2.add_control(widget_control)

    def _on_dataset_change(self, change):
        """Handle dataset selection"""
        self.current_dataset_name = change['new']
        self.da_scalar = self.ds[self.current_dataset_name]
        self._update_scalar_overlay()

    def _on_threshold_change(self, change):
        """Handle threshold change"""
        self._update_scalar_overlay()

    def _color_limits(self):
        """Return (vmin, vmax) for the active dataset.

        The limits are the COLOR_PERCENTILES of the whole variable (all
        fractions), so the colour scale stays fixed while the slider moves.
        They are computed with a single pass over the values and cached per
        dataset name, instead of being recomputed on every slider event.
        """
        key = self.current_dataset_name
        if key not in self._limits_cache:
            low, high = np.nanpercentile(self.da_scalar.values, self.COLOR_PERCENTILES)
            self._limits_cache[key] = (float(low), float(high))
        return self._limits_cache[key]

    def _update_scalar_overlay(self):
        """Update scalar overlay with new data"""
        if self.scalar_overlay is None:
            # add_base_layers() has not been called yet: nothing to redraw.
            return

        threshold = self.slider.value
        # Variables without a 'fraction' dimension are drawn as they are.
        if 'fraction' in self.da_scalar.dims:
            da_overlay = self.da_scalar.sel(fraction=threshold)
        else:
            da_overlay = self.da_scalar

        vmin, vmax = self._color_limits()
        self.scalar_overlay.url = scalar_to_base64_image(
            da_overlay,
            cmap='Greens',
            vmin=vmin,
            vmax=vmax
        )

    def create_subbasin_selector(self, center, zoom):
        """Create subbasin selection dropdown on the left map.

        `center` and `zoom` are the view restored when 'All Regions' is
        selected.
        """
        self.default_center = center
        self.default_zoom = zoom

        subbasin_names = ['All Regions'] + self.subbasins[self.subbasin_column].tolist()
        dropdown = widgets.Dropdown(
            options=subbasin_names,
            description='Subbasin:',
            style={'description_width': 'initial'}
        )
        dropdown.observe(self._on_subbasin_change, names='value')

        control = WidgetControl(widget=dropdown, position='topright')
        self.m1.add_control(control)

    def _on_subbasin_change(self, change):
        """Handle subbasin selection - use transient highlight overlay"""
        selected_name = change['new']

        # Remove previous highlight if exists
        if self.current_highlight is not None:
            self.m1.remove_layer(self.current_highlight)
            self.m2.remove_layer(self.current_highlight)
            self.current_highlight = None

        if selected_name == 'All Regions':
            # Only the left map is set here; center/zoom are jslink-ed to
            # the right map in __init__.
            self.m1.center = self.default_center
            self.m1.zoom = self.default_zoom
            return

        selected_subbasin = self.subbasins[
            self.subbasins[self.subbasin_column] == selected_name
        ]

        if selected_subbasin.empty:
            # Message goes to debug_output, which display() does not show.
            with self.debug_output:
                print("No matching subbasin found!")
            return

        # Zoom to bounds: total_bounds is (minx, miny, maxx, maxy) while
        # fit_bounds takes [[south, west], [north, east]].
        bounds = selected_subbasin.total_bounds
        self.m1.fit_bounds([[bounds[1], bounds[0]], [bounds[3], bounds[2]]])

        # Create temporary highlight with name; the same layer object is
        # added to both maps.
        self.current_highlight = GeoJSON(
            data=selected_subbasin.__geo_interface__,
            style=self.highlight_style,
            name="Region Highlight"
        )
        self.m1.add_layer(self.current_highlight)
        self.m2.add_layer(self.current_highlight)

    def display(self):
        """Display the dual map setup"""
        # display(self.debug_output)
        display(widgets.HBox([self.m1, self.m2]))


# Initial view of the dual-map figure (one zoom level further out than the
# single-map figure) and initial state of the right-hand overlay.
dual_map_zoom = 6
current_dataset_name = "path_length_norm"
# The FCD Threshold slider created by DualMapController.create_dataset_selector
# starts at 0.1, so the first overlay must be rendered for the same fraction;
# otherwise slider and image disagree until the first interaction.
initial_fraction = 0.1

# Initialize controller
controller = DualMapController(
    width=map_width,
    height=map_height,
    center=center,
    zoom=dual_map_zoom,
    subbasins=subbasins,
    subbasin_column="Basin_Su_1"
)

# Add base layers
# Left map: fraction-coarse image on a fixed 0..1 colour scale.
fcd_ave = ds_reprojected['fraction_coarse']
fcd_layer = ImageOverlay(
    url=scalar_to_base64_image(fcd_ave, cmap='RdBu_r', vmin=0, vmax=1),
    bounds=leaflet_bounds(fcd_ave),
    opacity=1.0,
    name="Fraction Coarse Dominated [%]",
)

# Right map: selected metric at the initial fraction.
# NOTE(review): this first render passes no vmin/vmax, whereas the controller
# applies percentile-based limits on every later update - the colour scale
# may therefore shift on the first widget interaction.
da_scalar = ds_reprojected[current_dataset_name].sel(fraction=initial_fraction).load()
scalar_overlay = ImageOverlay(
    url=scalar_to_base64_image(da_scalar, cmap='Greens'),
    bounds=leaflet_bounds(da_scalar),
    opacity=1.0,
    name="Scalar Data"
)

controller.add_base_layers(l_rivers, l_subbasins, fcd_layer, scalar_overlay)
controller.add_controls()  # Add LayersControl AFTER base layers
controller.create_dataset_selector(ds_reprojected, current_dataset_name)
# "All Regions" in the subbasin selector restores this view.
controller.create_subbasin_selector(center=center, zoom=dual_map_zoom)
controller.display()

HBox(children=(Map(center=[37.66335291403956, -120.69523554193438], controls=(ZoomControl(options=['position',…

In [34]:
# Set the left map's zoom level to 7 from Python. The right map's zoom
# trait is tied to it by the jslink created in DualMapController.__init__.
controller.m1.zoom = 7

In [9]:
# DEBUG: Minimal test for LayersControl
# Two bare maps share one named GeoJSON layer; each map gets its own
# LayersControl so the layer listing can be compared side by side.
test_m1, test_m2 = (
    Map(center=[37.66, -120.69], zoom=6, basemap=basemaps.CartoDB.Positron)
    for _ in range(2)
)

# Add a simple named layer (the same object on both maps)
test_layer = GeoJSON(data=rivers.__geo_interface__, name="Test Rivers")
for test_map in (test_m1, test_m2):
    test_map.add_layer(test_layer)

# Add LayersControl (a separate control instance per map)
for test_map in (test_m1, test_m2):
    test_map.add_control(LayersControl(position='topleft'))

display(widgets.HBox([test_m1, test_m2]))

HBox(children=(Map(center=[37.66, -120.69], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_i…