In [7]:
import xarray as xr
import numpy as np
import rioxarray
import geopandas as gpd
from ipyleaflet import Map, ImageOverlay, GeoData, WidgetControl, LayersControl, basemaps
from PIL import Image
from io import BytesIO
import base64
import ipywidgets as widgets
from matplotlib.colors import Normalize
import matplotlib.cm as cm


In [8]:
# -----------------
# Helper functions: 
# -----------------

# Convert bounds convention from rioxarray to ipyleaflet
def rioxarray_to_leaflet_bounds(da):
    """
    Convert rioxarray bounds to ipyleaflet bounds format.
    
    Parameters:
    - da: rioxarray DataArray
        The input raster data.
    
    Returns:
    - bounds: list
        Bounds in ipyleaflet format [[min_lat, min_lon], [max_lat, max_lon]].
    """
    rio_bounds = da.rio.bounds()  # (left, bottom, right, top)
    return [[rio_bounds[1], rio_bounds[0]], [rio_bounds[3], rio_bounds[2]]]

# Convert RGBA DataArray to base64 image

def transpose_da(da):
    """Transpose DataArray from any order to (..., y, x)"""
    dims = da.dims
    other_dims = [d for d in dims if d not in ('y', 'x')]
    new_order = other_dims + ['y', 'x']
    return da.transpose(*new_order)

def rgba_to_base64_image(da_rgba):
    """Convert 4-band RGBA DataArray to base64 PNG"""
    da_rgba = transpose_da(da_rgba)
    rgba_arr = da_rgba.values  # shape: (4, y, x)
    rgba_arr = np.transpose(rgba_arr, (1, 2, 0))
    
    if rgba_arr.dtype != np.uint8:
        rgba_arr = rgba_arr.astype(np.uint8)
    
    img = Image.fromarray(rgba_arr, mode='RGBA')
    
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    img_str = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()
    return img_str

# -------------------------------
# Helper: convert scalar DataArray to colormap overlay
# -------------------------------

def scalar_to_base64_image(da, cmap='viridis', vmin=None, vmax=None):
    """Convert scalar DataArray to colored base64 PNG with transparency for masked values"""
    da = transpose_da(da)
    arr = da.values
    
    if vmin is None:
        vmin = np.nanmin(arr)
    if vmax is None:
        vmax = np.nanmax(arr)

    print(f"DEBUG:{vmax=}, {vmin=}")
    
    norm = Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=norm, cmap=cmap)
    
    rgba = mapper.to_rgba(arr, bytes=False)  # shape: (y, x, 4)
    
    mask = np.isnan(arr)
    rgba[mask, 3] = 0
    
    rgba_uint8 = (rgba * 255).astype(np.uint8)
    
    img = Image.fromarray(rgba_uint8, mode='RGBA')
    
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    img_str = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode()
    return img_str

In [None]:
#| label: nb:map-one
# -------------------------------
# Data Paths
# -------------------------------
terrain_path = "/Users/juliusbusecke/Library/Caches/em_recharge/data_download/data/cv_terrain.tiff"
vector_path = "/Users/juliusbusecke/Library/Caches/em_recharge/data_download/data/shp/cv_rivers.geojson"
metric_zarr_path = "consolidated_metric_output.zarr"

# -------------------------------
# Data Loading
# -------------------------------

# Load vector layer with GeoPandas
gdf = gpd.read_file(vector_path)
gdf = gdf.to_crs(epsg=4326)

# Load RGBA terrain raster
da_terrain = rioxarray.open_rasterio(terrain_path)  # shape: (band=4, y, x)
da_terrain = da_terrain.rio.reproject("EPSG:4326")

# Load consolidated metric dataset
ds = xr.open_zarr(metric_zarr_path)

# -------------------------------
# Until crs/transform issue is resolved, interpolate terrain to match ds dimensions
# just to get a rough overlay working.
# -------------------------------

# Interpolate the terrain coordinates to match the number of elements in ds coordinates
lat_dim = 'y'
lon_dim = 'x'

# Interpolating the terrain coordinates to match the ds coordinate sizes
# Todo: This is a temporary workaround. Proper reprojection should be implemented.
#Todo: More robust handling of increasing vs decreasing coordinates
interp_lat = np.linspace(da_terrain[lat_dim].min(), da_terrain[lat_dim].values.max(), ds[lat_dim].size)
interp_lon = np.linspace(da_terrain[lon_dim].min(), da_terrain[lon_dim].values.max(), ds[lon_dim].size)

# Assign the interpolated coordinates to the dataset
ds = ds.assign_coords({
    lat_dim: interp_lat,
    lon_dim: interp_lon
})

ds = ds.sortby(lat_dim, ascending=False)
ds = ds.sortby(lon_dim, ascending=True)


# make up the bounds manually for now (Todo: Replace with correct bounds once crs/transform issue is resolved)
lat_dim, lon_dim = 'y', 'x'
center = [float(ds[lat_dim].mean()), float(ds[lon_dim].mean())]
# mock_bounds = [[float(ds[lat_dim].min()), float(ds[lon_dim].min())],
#           [float(ds[lat_dim].max()), float(ds[lon_dim].max())]]
mock_bounds = [
    [float(ds[lat_dim].isel({lat_dim: -1})), float(ds[lon_dim].isel({lon_dim: 0}))],
    [float(ds[lat_dim].isel({lat_dim: 0})), float(ds[lon_dim].isel({lon_dim: -1}))]
]

# -------------------------------
# Create map
# -------------------------------
m = Map(
    center=[37.66335291403956, -120.69523554193438], 
    zoom=5, 
    basemap=basemaps.CartoDB.Positron, 
    # scroll_wheel_zoom=True
)

# Create a GeoData layer from the GeoDataFrame
geo_data_layer = GeoData(
    geo_dataframe=gdf[['geometry']], 
    style={'color': 'blue', 'weight': 1, 'opacity': 0.7},
    name="Rivers"
)

# Add RGBA terrain as base layer (static)
terrain_overlay = ImageOverlay(
    url=rgba_to_base64_image(da_terrain), 
    bounds=rioxarray_to_leaflet_bounds(da_terrain), 
    opacity=0.8, 
    name="Terrain (RGBA)"
)
m.add_layer(terrain_overlay)

# Start with first dataset
# current_dataset_name = "fraction_coarse"
current_dataset_name = "path_length_norm"
da_scalar = ds[current_dataset_name].sel(fraction=0.2).load()
# Add scalar data overlay (will be updated by threshold and dataset selection)
initial_scalar = scalar_to_base64_image(da_scalar, cmap='plasma')
scalar_overlay = ImageOverlay(
    url=initial_scalar, 
    bounds=mock_bounds, 
    opacity=1.0, 
    name="Scalar Data"
)
m.add_layer(scalar_overlay)

# Add vector layer
m.add_layer(geo_data_layer)

# -------------------------------
# Controls: Dataset dropdown and threshold slider
# -------------------------------
def update_scalar_overlay():
    """Update the scalar overlay based on current dataset and threshold"""
    threshold = slider.value
    # da_masked = da_scalar.where(da_scalar >= threshold)
    if 'fraction' in da_scalar.dims:
        da_overlay = da_scalar.sel(fraction=threshold)
    else:
        da_overlay = da_scalar

    scalar_overlay.url = scalar_to_base64_image(
        da_overlay, 
        cmap='plasma',
        vmin=float(np.nanmin(da_scalar.values)),
        vmax=float(np.nanmax(da_scalar.values))
    )

def on_dataset_change(change):
    """Handle dataset selection change"""
    global da_scalar, current_dataset_name
    current_dataset_name = change['new']
    da_scalar = ds[current_dataset_name]
    
    # Update the overlay
    update_scalar_overlay()

def on_threshold_change(change):
    """Handle threshold slider change"""
    update_scalar_overlay()

# Dataset dropdown
dropdown = widgets.Dropdown(
    options=list(ds.data_vars),
    value=current_dataset_name,
    description='Dataset:',
    style={'description_width': 'initial'}
)
dropdown.observe(on_dataset_change, names='value')

slider = widgets.SelectionSlider(
    options=ds.fraction.values,
    value=0.1,  # Must be one of the values in options
    description='FCD Threshold',
    style={'description_width': 'initial'}
)
slider.observe(on_threshold_change, names='value')

# Combine controls in a VBox
controls = widgets.VBox([dropdown, slider])
widget_control = WidgetControl(widget=controls, position='topright')
m.add_control(widget_control)

# -------------------------------
# 9. Layer toggle control
# -------------------------------
m.add_control(LayersControl(position='topright'))

m

  img = Image.fromarray(rgba_arr, mode='RGBA')


DEBUG:vmax=415.1632665954467, vmin=1.4959182908308493


  img = Image.fromarray(rgba_uint8, mode='RGBA')


Map(center=[37.66335291403956, -120.69523554193438], controls=(ZoomControl(options=['position', 'zoom_in_text'â€¦