##### Creating a parameter map

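This notebook shows how to create a parameter map: it runs the trained KAN on the normalized spatial attributes and plots a learned parameter (for example, Manning's roughness `n`) for every hydrofabric divide, so you can see how the learned values are distributed in space.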

In [None]:
import logging
from pathlib import Path

import contextily as cx
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import yaml
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1 import make_axes_locatable

from ddr._version import __version__
from ddr.dataset import LargeScaleDataset
from ddr.dataset import StreamflowReader as streamflow
from ddr.dataset import utils as ds_utils
from ddr.nn import kan
from ddr.routing.torch_mc import dmc
from ddr.validation import Config, Metrics, plot_time_series, utils, validate_config

# The root logger defaults to WARNING, so without explicit configuration the
# `log.info` messages emitted later in this notebook would be dropped.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
# `basicConfig` is a no-op when the root logger already has handlers, so the
# level is also set directly on this notebook's logger.
log.setLevel(logging.INFO)

In [None]:
config_path = "../"

# Parse the example YAML into plain Python data, then unpack it into the
# project's `Config` object.
with open("./example_config.yaml") as f:
    raw_config = yaml.safe_load(f)
config = Config(**raw_config)

In [None]:
# Prefer the first CUDA device, but fall back to the CPU so that the tensors
# moved with `.to(device)` in later cells do not fail on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): `nn` rebinds the `torch.nn` alias imported at the top of the
# notebook. The name is kept because later cells refer to the model as `nn`.
# NOTE(review): the KAN is built with `config.device` while every other object
# uses `device` -- confirm that both refer to the same device.
nn = kan(
    input_var_names=config.kan.input_var_names,
    learnable_parameters=config.kan.learnable_parameters,
    hidden_size=config.kan.hidden_size,
    num_hidden_layers=config.kan.num_hidden_layers,
    grid=config.kan.grid,
    k=config.kan.k,
    seed=config.seed,
    device=config.device,
)

# Routing model, streamflow reader and dataset, all built from the same config.
routing_model = dmc(cfg=config, device=device)
flow = streamflow(config)
dataset = LargeScaleDataset(cfg=config)

In [None]:
model_states = Path("./ddr_v0.1.0a2_trained_model_weights.pt")
log.info(f"Loading spatial_nn from checkpoint: {model_states.stem}")

# NOTE(review): `torch.load` unpickles the file; only load checkpoints from a
# trusted source.
state = torch.load(model_states, map_location=device)
state_dict = state["model_state_dict"]

# Move every stored tensor onto the target device. The mapping is updated in
# place (no keys are added or removed), so the same object is handed to the
# network below.
for key, tensor in state_dict.items():
    state_dict[key] = tensor.to(device)

nn.load_state_dict(state_dict)

In [None]:
# Switch the network to evaluation mode, then run a single forward pass on the
# normalized spatial attributes with gradient tracking disabled.
nn = nn.eval()
spatial_attributes = dataset.hydrofabric.normalized_spatial_attributes.to(device)
with torch.no_grad():
    spatial_params = nn(inputs=spatial_attributes)

In [None]:
# Divide identifiers in dataset order, written with the "cat-" prefix used by
# the `divide_id` column of the "divides" layer.
divide_ids = np.array([f"cat-{hf_id}" for hf_id in dataset.hf_ids])

# Read the divides, align their rows with the dataset order, attach the
# predicted parameters as columns, and reproject to EPSG:4326.
gdf = (
    gpd.read_file(config.data_sources.hydrofabric_gpkg, layer="divides")
    .set_index("divide_id")
    .reindex(divide_ids)
    .assign(
        n=spatial_params["n"].cpu().numpy(),
        q_spatial=spatial_params["q_spatial"].cpu().numpy(),
    )
    .to_crs(epsg=4326)
)

In [None]:
def param_plot(
    gdf: gpd.GeoDataFrame,
    var: str,
    save_name: str | Path,
    cmap: str = "plasma",
    unit_label: str | None = None,
    title: str | None = None,
    vmin: float | None = None,
    vmax: float | None = None,
    ascending: bool = False,
    dpi: int = 100,
) -> tuple[Figure, Axes]:
    """
    Create a parameter plot for geospatial data with a basemap and colorbar.

    The figure is drawn, cropped to fixed CONUS bounds, and written to
    ``save_name`` at 600 DPI before being returned.

    Parameters
    ----------
    gdf : gpd.GeoDataFrame
        GeoDataFrame containing the data to plot. The axis limits are
        hard-coded in degrees, so the frame is assumed to be in a
        longitude/latitude CRS such as EPSG:4326.
    var : str
        Column name to visualize.
    save_name : str or pathlib.Path
        Path of the image file the plot is saved to.
    cmap : str, default 'plasma'
        Colormap name for the plot.
    unit_label : str, optional
        Unit label for the colorbar; rendered as ``"{var} ({unit_label})"``.
    title : str, optional
        Title for the plot.
    vmin : float, optional
        Minimum value for color scaling. If None, uses data minimum.
    vmax : float, optional
        Maximum value for color scaling. If None, uses data maximum.
    ascending : bool, default False
        Whether to sort data in ascending order before plotting.
    dpi : int, default 100
        DPI for the figure display (the saved file always uses 600 DPI).

    Returns
    -------
    tuple of (matplotlib.figure.Figure, matplotlib.axes.Axes)
        Figure and axes objects for further customization.

    Raises
    ------
    KeyError
        If the specified variable column doesn't exist in the GeoDataFrame.
    ValueError
        If the GeoDataFrame is empty after dropping NaN values.

    """
    # Validate inputs before any figure is created so that a failed call does
    # not leave an orphaned, empty figure open.
    if var not in gdf.columns:
        raise KeyError(f"Column '{var}' not found in GeoDataFrame")

    # Drop NaNs and validate data
    gdf_clean = gdf.dropna(subset=[var])
    if gdf_clean.empty:
        raise ValueError(f"No valid data found for variable '{var}' after dropping NaN values")

    # Sort data for visualization
    gdf_clean = gdf_clean.sort_values(by=var, ascending=ascending)
    data = gdf_clean[var].values

    # Default the color limits to the data range when not provided
    if vmin is None:
        vmin = np.nanmin(data)
    if vmax is None:
        vmax = np.nanmax(data)

    # Create figure
    fig, ax = plt.subplots(figsize=(7, 4), dpi=dpi)

    # Create the plot with direct vmin/vmax limits
    gdf_clean.plot(
        ax=ax,
        column=var,
        cmap=cmap,
        linewidth=0.3,
        vmin=vmin,
        vmax=vmax,
        zorder=1,
    )

    # Set bounds for CONUS. This must happen before the basemap is added:
    # the basemap tiles are fetched for the axes extent at call time, so
    # setting the limits afterwards would leave the visible area outside the
    # data extent without tiles.
    ax.set_xlim(-125, -66)
    ax.set_ylim(24, 53)

    # Add basemap underneath the polygons (zorder=0 < zorder=1)
    cx.add_basemap(
        ax,
        crs=gdf_clean.crs,
        source=cx.providers.CartoDB.Positron,
        alpha=0.6,
        zorder=0,
        attribution=False,
    )

    # Remove axis ticks
    ax.set_xticks([])
    ax.set_yticks([])

    # Set plot title
    if title is not None:
        ax.set_title(title, fontsize=14)

    # Add colorbar in a slim axes attached to the right of the map. A
    # standalone ScalarMappable carries the same colormap and limits as the
    # plotted polygons.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="3%", pad=0.1)
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_array([])
    sm.set_clim(vmin, vmax)
    cbar = fig.colorbar(sm, cax=cax)

    # Set colorbar label
    label_text = var
    if unit_label:
        label_text = f"{var} ({unit_label})"
    cbar.set_label(label_text)

    # Switch tick labels to scientific notation outside 1e-2 .. 1e2
    cbar.formatter.set_powerlimits((-2, 2))
    cbar.update_ticks()

    # Save figure. Operate on `fig` explicitly rather than on pyplot's
    # "current figure" so the right figure is always laid out and written.
    # `save_name` is used as given; callers supply the full output path.
    fig.tight_layout()
    fig.savefig(save_name, dpi=600, bbox_inches="tight")

    return fig, ax

In [None]:
# Map of the learned roughness `n`, saved next to the notebook.
param_plot(
    gdf=gdf,
    var="n",
    save_name="n_train.png",
    cmap="plasma_r",
    title="Manning's Roughness (m⁻¹/³s)",
    vmax=0.2,
    ascending=True,
    dpi=200,
)

In [None]:
# Same map with a wider color range, saved under the configured output directory.
# NOTE(review): this repeats column "n" and the file name "n_train.png" from
# the previous cell; only `vmax` and the output directory differ -- confirm
# that this is intended and that `config.params.save_path` is defined.
param_plot(
    gdf=gdf,
    var="n",
    save_name=Path(config.params.save_path) / "n_train.png",
    cmap="plasma_r",
    title="Manning's Roughness (m⁻¹/³s)",
    vmax=0.4,
    ascending=True,
    dpi=200,
)