# GoogleHydrology Finetuning Tutorial Notebook

This notebook loads test results from a **base model** and a corresponding **fine-tuned model** run. It follows a specific workflow to compare their performance:

1. **Configuration:** Set all necessary local paths.
2. **Base Model**
    * **Choose Base Model:** Interactively choose a base model from available model runs in your run directory.
    * **Visualize Experiment:** Read the base model's `config.yml` to find and plot the train/test basins on a map.
    * **Analyze Base Model:** Load `test_results.p` from the base model run directory. Calculate metrics for the base model and create a map color-coded by skill metric.
3.  **Fine-Tuning**
    * **Choose Basin for Fine-Tuning:** From the base model results, choose a test basin to fine-tune on. Generate the config file and run command for fine-tuning.
    * **Analyze Fine-Tuned Model:** Interactively choose a fine-tuned model from available experiments. Load `test_results.p` from both the base and fine-tuned run directories. Calculate metrics.
4. **Compare Models (Skill):** Use an interactive plot to compare a metric vs. lead time for the base and fine-tuned models on the target basin.
5. **Compare Models (Hydrograph):** Use an interactive plot to compare the hydrographs (Observed vs. Base vs. Fine-Tuned) for the target basin.

## 0. Imports

In [2]:
# Standard Library Imports
import glob
import os
import pickle
import re
import sys
import yaml
from typing import Any, Dict, List, Optional, Tuple, Set

# Third-Party Library Imports
import geopandas as gpd
import ipywidgets as widgets
from IPython.display import display, Markdown
from ipywidgets import interactive, VBox
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

# Local tutorial Module Imports
# import tutorial.metrics
import metrics

# Set plot style
plt.style.use('seaborn-v0_8-notebook')

ModuleNotFoundError: No module named 'geopandas'

## 0. User-Defined Paths
1.  **Set Shapefile Path:** Set the variable `SHAPEFILE_PATH` to the path of your local Caravan basin shapefile.
2.  **Set Model Run Directory:** Set the variable `MODEL_RUN_DIR` to your local model-run directory.

In [None]:
# Define the path to the shapefile containing basin geometries.
# This shapefile is used for plotting maps of basin locations and model skill.
# NOTE(review): this points at the '.shx' component of the shapefile; confirm
# that the reader used below accepts it, or point at the '.shp' file instead.
SHAPEFILE_PATH = '/home/gsnearing/data/Caravan-nc/shapefiles/camels/camels_basin_shapes.shx'

# Path to a base directory containing one or more model run directories.
# The code below will allow you to select which model run in this directory
# you want to evaluate.
MODEL_RUN_DIR = '/home/gsnearing/tutorial/model-runs/'

# Path to the streamflow gauge data to use for model evaluation.
# This directory must be in the Caravan format.
# If None, will use the observation data contained in the model output files.
EVALUATION_DATA_DIR = '/home/gsnearing/data/camels-updated-2025'

## 1. Base Model

### Choose Base Model

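**Select Base Model:** The interactive widget below lists every model run found under `MODEL_RUN_DIR` that contains test results. Pick the run you want to use as the base model; the selection is stored for the cells that follow.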


In [None]:
# --- Configuration: Paths & Model Names ---

# This cell handles finding available model run directories and allows the user
# to interactively select the base model run directory using a dropdown widget.

# The base directory to search for model runs is defined in the cell above (MODEL_RUN_DIR).

def find_model_run_dirs(base_dir: str) -> Dict[str, str]:
    """
    Locate model run directories underneath a base directory.

    A directory counts as a model run when it has a 'test' subdirectory in
    which at least one 'model_epoch*' entry holds a 'test_results.p' file.

    Args:
        base_dir: Directory whose tree is walked.

    Returns:
        A dictionary whose keys are the run paths relative to base_dir (used
        as display names) and whose values are the run directory paths as
        produced by os.walk (absolute only when base_dir is absolute).
    """
    print(f"Searching for model run directories in: {base_dir}")
    found_runs = {}
    for current_dir, subdir_names, _ in os.walk(base_dir):
        # Only directories that own a 'test' folder can be model runs.
        if 'test' not in subdir_names:
            continue
        epoch_pattern = os.path.join(current_dir, 'test', 'model_epoch*')
        # One epoch folder with a results file is enough to qualify the run.
        has_results = any(
            os.path.isfile(os.path.join(candidate, 'test_results.p'))
            for candidate in glob.glob(epoch_pattern)
        )
        if has_results:
            found_runs[os.path.relpath(current_dir, base_dir)] = current_dir
    return found_runs

# Find available run directories using the path set in the cell above
available_run_dirs = find_model_run_dirs(MODEL_RUN_DIR)

# Check if any run directories were found
if not available_run_dirs:
    print("\nError: No model run directories containing test results found.")
    print(f"Please ensure MODEL_RUN_DIR ({MODEL_RUN_DIR}) is set correctly and contains valid runs.")
    # You might want to exit or disable further processing here if no runs are found
else:
    print(f"\nFound {len(available_run_dirs)} potential model run directories.")

    # --- Interactive Selection ---
    def select_base_model(base_model_selection):
        """
        Store the selected base model run in the notebook's global state.

        Called by the dropdown whenever its value changes. Sets the globals
        `base_model_run_dir` and `base_model_name`, then reports either the
        resulting configuration or a validation error (never both).

        Args:
            base_model_selection: Display name chosen in the dropdown; a key
                of `available_run_dirs`.
        """
        global base_model_run_dir, base_model_name

        # Look up the run directory that belongs to the chosen display name
        base_model_run_dir = available_run_dirs.get(base_model_selection)
        base_model_name = f'Base Model ({base_model_selection})'

        # --- Validation Checks ---
        # Stop here on an invalid selection so the success message below is
        # not printed for a directory that does not exist.
        if not base_model_run_dir or not os.path.isdir(base_model_run_dir):
            print(f"\nError: Base model directory not found or invalid: {base_model_selection}")
            return

        # Print the configuration that has been set
        print("\nConfiguration set:")
        print(f"  Base Model: {base_model_name} ({base_model_run_dir})")

    # Create a dropdown widget for selecting the base model
    run_dir_options = sorted(available_run_dirs)  # Sort options alphabetically
    base_dropdown = widgets.Dropdown(
        options=run_dir_options,  # Use the found run directories as options
        description='Select Base Model:',  # Label for the dropdown
        disabled=False,  # The widget should be enabled
        style={'description_width': 'initial'},  # Keep the full label visible
    )

    # Link the dropdown to the selection function. The function's printed
    # output appears below the widget.
    interactive_selection = interactive(
        select_base_model,
        base_model_selection=base_dropdown,
    )

    # Display the interactive widget
    display(interactive_selection)

### Visualize Experiment

In [None]:
def read_basin_list(file_path: str) -> Set[str]:
    """
    Load basin IDs from a text file that lists one ID per line.

    Blank lines are skipped and every remaining entry is passed through
    normalize_id. A missing file produces a warning and an empty set.

    Args:
        file_path: Path of the basin list file.

    Returns:
        The set of normalized basin IDs found in the file.
    """
    if os.path.isfile(file_path):
        with open(file_path, 'r') as handle:
            stripped_entries = (raw_line.strip() for raw_line in handle)
            return {normalize_id(entry) for entry in stripped_entries if entry}

    print(f"Warning: Basin file not found: {file_path}")
    return set()

def normalize_id(basin_id: str) -> str:
    """
    Reduce a basin ID to the zero-padded form used for comparisons.

    The value is converted to a string, everything up to and including the
    last underscore is discarded, and the remainder is left-padded with
    zeros to at least 8 characters.

    Examples:
        'camels_01013500' -> '01013500'
        12345             -> '00012345'

    The result must match the ID format produced for the shapefile ID column.
    """
    id_without_prefix = str(basin_id).rsplit('_', 1)[-1]
    return id_without_prefix.zfill(8)

# Read the selected base model's config.yml and load the train/test basin ID
# sets it points to. Both sets start out empty so later cells can run even if
# the config or one of the basin files cannot be read.
# NOTE(review): base_model_run_dir is set by the selection widget above; if no
# selection has been made yet this line fails before the try block is entered.
base_config_path = os.path.join(base_model_run_dir, 'config.yml')
train_basin_ids = set()
test_basin_ids = set()

try:
    with open(base_config_path, 'r') as f:
        base_config = yaml.safe_load(f)

    # Paths are used exactly as written in the config.
    # NOTE(review): relative paths would resolve against the notebook's
    # working directory — confirm the config stores absolute paths.
    train_basin_file = base_config.get('train_basin_file')
    test_basin_file = base_config.get('test_basin_file')

    if train_basin_file:
        print(f"Found train basin file: {train_basin_file}")
        train_basin_ids = read_basin_list(train_basin_file)
        print(f"Loaded {len(train_basin_ids)} training basin IDs.")
    else:
        print("Warning: 'train_basin_file' not found in config.yml")

    # The test list is loaded after the train list, so an error here leaves
    # the already-loaded training IDs in place.
    if test_basin_file:
        print(f"Found test basin file: {test_basin_file}")
        test_basin_ids = read_basin_list(test_basin_file)
        print(f"Loaded {len(test_basin_ids)} test basin IDs.")
    else:
        print("Warning: 'test_basin_file' not found in config.yml")

except FileNotFoundError:
    print(f"Error: config.yml not found in {base_model_run_dir}")
except Exception as e:
    # Any other failure (e.g. malformed YAML) is reported, not raised.
    print(f"Error reading config.yml: {e}")

In [None]:
gdf_all_basins = gpd.read_file(SHAPEFILE_PATH)


### Simplify Basin Geometries (Optional)

If plotting the shapefile is causing performance issues or crashes, you can try simplifying the basin geometries. This reduces the number of vertices in each polygon, making them less complex to render. The `simplify` method from GeoPandas is used here with a `tolerance` parameter that controls the degree of simplification. A larger tolerance results in more aggressive simplification. You may need to adjust the `tolerance` value depending on your data and desired level of detail.

In [None]:
# Optional: simplify the basin polygons in place to make plotting cheaper.
# Adjust the tolerance value to control the degree of simplification;
# a larger tolerance means more simplification.
try:
    # Only simplify when the shapefile has already been loaded
    if 'gdf_all_basins' in locals() and gdf_all_basins is not None:
        # Replace each geometry with its simplified version
        gdf_all_basins['geometry'] = gdf_all_basins['geometry'].simplify(tolerance=0.01, preserve_topology=True)
        print("Basin geometries simplified for plotting.")
    else:
        print("GeoDataFrame not loaded. Skipping simplification.")
except Exception as e:
    print(f"Error simplifying geometries: {e}")
    print("Proceeding with original geometries.")

# Run the plotting cells below after this one.
# NOTE(review): the "Plot Train & Test Basins" cell reads the shapefile again,
# which replaces the simplified geometries produced here.

In [None]:
# Plot the geometry of the first basin
gdf_all_basins.iloc[[0]].plot()

In [None]:
asdf

### Plot Train & Test Basins

In [None]:
# Ensure geopandas is imported
import geopandas as gpd

shapefile_basin_id_column='gauge_id'

def plot_colored_shapefile(
    gdf: gpd.GeoDataFrame,
    column: str,
    title: str,
    legend_title: str,
    cmap: Optional[str] = None,
    colors: Optional[Dict[Any, str]] = None,
    figsize: Tuple[int, int] = (12, 12),
    missing_kwds: Optional[Dict[str, Any]] = None
):
    """
    Plots a GeoDataFrame with polygons colored based on a specified column.

    Two modes are supported:
      * Categorical: when `colors` is given, each listed category is drawn in
        its own fixed color and a legend with one entry per plotted category
        is added.
      * Colormap: otherwise the column is passed to GeoDataFrame.plot together
        with `cmap`.

    Args:
        gdf: The GeoDataFrame to plot.
        column: The name of the column in the GeoDataFrame to use for coloring.
        title: The title of the plot.
        legend_title: The title for the legend (or the colorbar label for
            numeric columns).
        cmap: Colormap name (if using continuous data).
        colors: Dictionary mapping unique values in 'column' to specific colors.
                Overrides cmap if provided.
        figsize: Figure size.
        missing_kwds: Dictionary of keyword arguments for handling missing values
                      (e.g., {'color': 'lightgrey', 'label': 'No Data'}).
                      Only used in colormap mode.
    """
    fig, ax = plt.subplots(figsize=figsize)

    if colors:
        # Polygon collections do not register their `label` with the legend,
        # so one proxy patch per plotted category is collected instead.
        legend_handles = []
        for category, color in colors.items():
            subset = gdf[gdf[column] == category]
            if subset.empty:
                continue
            subset.plot(
                ax=ax,
                color=color,
                alpha=0.7,
                edgecolor='black',
                linewidth=0.5
            )
            legend_handles.append(
                plt.Rectangle(
                    (0, 0), 1, 1,
                    facecolor=color,
                    edgecolor='black',
                    alpha=0.7,
                    label=str(category)
                )
            )
        if legend_handles:
            ax.legend(handles=legend_handles, title=legend_title)
    else:
        # Numeric columns get a colorbar, which takes 'label'; categorical
        # columns get a regular legend, which takes 'title'.
        if pd.api.types.is_numeric_dtype(gdf[column]):
            legend_kwds = {'label': legend_title}
        else:
            legend_kwds = {'title': legend_title}
        gdf.plot(
            ax=ax,
            column=column,
            legend=True,
            legend_kwds=legend_kwds,
            cmap=cmap,
            alpha=0.7,
            edgecolor='black',
            linewidth=0.5,
            missing_kwds=missing_kwds
        )

    ax.set_title(title)
    ax.set_axis_off()
    plt.show()


# --- Map of basins by train/test membership ---

# Load the basin polygons from the configured shapefile
gdf_all_basins = gpd.read_file(SHAPEFILE_PATH)

# The ID column is required to match basins against the basin lists
if shapefile_basin_id_column not in gdf_all_basins.columns:
    raise KeyError(f"ID column '{shapefile_basin_id_column}' not found in shapefile. Available columns: {gdf_all_basins.columns}")

# Bring shapefile IDs into the same format as the basin list IDs
gdf_all_basins['normalized_id'] = gdf_all_basins[shapefile_basin_id_column].apply(normalize_id)

is_train = gdf_all_basins['normalized_id'].isin(train_basin_ids)
is_test = gdf_all_basins['normalized_id'].isin(test_basin_ids)

# Label each basin by its (in train set, in test set) combination
membership_labels = {
    (False, False): 'Not Used',
    (True, False): 'Train Only',
    (False, True): 'Test Only',
    (True, True): 'Train & Test',
}
gdf_all_basins['dataset'] = [
    membership_labels[flags]
    for flags in zip(is_train.tolist(), is_test.tolist())
]

# Fixed color per membership category
dataset_colors = {
    'Train Only': 'blue',
    'Test Only': 'red',
    'Train & Test': 'purple',
    'Not Used': 'lightgrey'
}

# Fall back to a generic name when no base model has been selected yet
model_label_for_title = base_model_name if 'base_model_name' in globals() else 'Base Model'

plot_colored_shapefile(
    gdf=gdf_all_basins,
    column='dataset',
    title=f"Train & Test Basin Sets from {model_label_for_title}",
    legend_title='Basin Set',
    colors=dataset_colors,
)

In [None]:
asdfadsfasdf

## 5. Load Model Results

This section finds and loads the `test_results.p` file from both your base model and fine-tuned model directories.

In [None]:
def load_test_results(run_dir: str) -> Tuple[Optional[Dict[str, Any]], Optional[int]]:
    """
    Finds and loads the 'test_results.p' file from a specific model run directory.

    When several 'model_epoch*' folders contain results, the one with the
    highest epoch number is used, so the choice does not depend on the order
    in which the file system lists the folders.

    Args:
        run_dir: The specific model run directory (e.g., .../camels-2014-run).

    Returns:
        A tuple of (data, epoch_number). 'data' is the loaded pickle file (a dict)
        and 'epoch_number' is parsed from the 'model_epoch*' folder name (0 if
        it carries no number). Returns (None, None) when no results file is
        found or the file cannot be loaded.
    """
    search_path = os.path.join(run_dir, 'test', 'model_epoch*', 'test_results.p')
    result_files = glob.glob(search_path)

    if not result_files:
        print(f"Warning: No 'test_results.p' file found in {run_dir}/test/model_epoch*/")
        return None, None

    def _epoch_of(path: str) -> int:
        # Look only at the epoch folder name so that a 'model_epoch<N>'
        # fragment elsewhere in the path cannot be picked up by mistake.
        epoch_dir_name = os.path.basename(os.path.dirname(path))
        match = re.search(r'model_epoch(\d+)', epoch_dir_name)
        return int(match.group(1)) if match else 0

    # Highest epoch wins; the path itself breaks ties deterministically.
    file_path = max(result_files, key=lambda path: (_epoch_of(path), path))
    epoch_number = _epoch_of(file_path)

    if len(result_files) > 1:
        print(f"Warning: Multiple 'test_results.p' files found. Using the highest epoch: {file_path}")

    try:
        # SECURITY: pickle can execute arbitrary code while loading. Only
        # open result files from model runs you trust.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
    except Exception as e:
        # Unpickling can fail in many ways (truncated file, missing classes,
        # ...); report the problem and let the caller handle the missing data.
        print(f"Error loading {file_path}: {e}")
        return None, None

    print(f"Successfully loaded results from: {file_path}")
    return data, epoch_number

# --- Execute Data Loading ---
# Load the pickled test results of the base run and of the fine-tuned run.
# NOTE(review): finetune_model_name and finetune_model_run_dir are not set by
# any cell visible in this part of the notebook — confirm that a fine-tuned
# model selection cell defines them before this cell is run.
print(f"Loading base model results for: {base_model_name}")
base_model_data, base_model_epoch = load_test_results(base_model_run_dir)

print(f"\nLoading fine-tuned model results for: {finetune_model_name}")
finetune_model_data, finetune_model_epoch = load_test_results(finetune_model_run_dir)

# load_test_results returns None for the data when loading failed
if base_model_data is None or finetune_model_data is None:
    print("\nError: Failed to load one or both model results. Please check paths.")
else:
    print("\nAll model results loaded successfully.")

## 6. Analysis & Plotting Functions

Below are the core helper functions we will use for our analysis. We define them all here in one place so they can be called later.

In [None]:
def calculate_metrics_by_lead_time(gauge_data: Dict[str, Any]) -> Optional[pd.DataFrame]:
    """
    Compute skill metrics for one gauge, separately for every lead time.

    The observed and simulated streamflow are read from
    gauge_data['1D']['xr'], and the imported `metrics` module does the
    metric computation.

    Args:
        gauge_data: Results entry of a single gauge.

    Returns:
        A DataFrame indexed by 'lead_time' with one column per metric, or
        None when gauge_data does not have the expected structure. Lead times
        for which the metric calculation fails are filled with NaN.
    """
    try:
        daily_arrays = gauge_data['1D']['xr']
        observed = daily_arrays['streamflow_obs']
        simulated = daily_arrays['streamflow_sim']
    except KeyError as e:
        print(f"Warning: Data structure unexpected. Missing key {e}. Skipping metric calculation.")
        return None
    except TypeError:
        print(f"Warning: gauge_data is not in the expected format. Skipping.")
        return None

    try:
        metric_names = metrics.get_available_metrics()
    except Exception as e:
        print(f"Error: Could not get metrics list from metrics.py: {e}")
        print("Defaulting to KGE and NSE.")
        metric_names = ['KGE', 'NSE']

    def _metrics_at(lead_time):
        # Selection errors are not caught; only a failing metric calculation
        # is replaced by NaN placeholders.
        obs_slice = observed.sel(time_step=lead_time)
        sim_slice = simulated.sel(time_step=lead_time)
        try:
            return metrics.calculate_metrics(
                obs_slice,
                sim_slice,
                metrics=metric_names,
                resolution="1D",
                datetime_coord="date"
            )
        except Exception:
            return dict.fromkeys(metric_names, np.nan)

    per_lead_time = {
        lead_time: _metrics_at(lead_time)
        for lead_time in observed['time_step'].values
    }

    result_table = pd.DataFrame.from_dict(per_lead_time, orient='index')
    result_table.index.name = 'lead_time'
    return result_table

def calculate_metrics_for_run(run_data: Dict[str, Any]) -> Dict[str, pd.DataFrame]:
    """
    Compute the per-lead-time metric table for every gauge of a model run.

    Args:
        run_data: The loaded test_results.p dictionary {gauge_id: data}.

    Returns:
        A dictionary {gauge_id: DataFrame of metrics}. Gauges for which
        calculate_metrics_by_lead_time returns None are left out. An empty
        dictionary is returned when run_data is empty or None.
    """
    if not run_data:
        print("No run data provided. Skipping metric calculation.")
        return {}

    computed = (
        (gauge_id, calculate_metrics_by_lead_time(gauge_entry))
        for gauge_id, gauge_entry in run_data.items()
    )
    metrics_by_gauge = {
        gauge_id: table for gauge_id, table in computed if table is not None
    }

    print(f"Metric calculation complete for {len(metrics_by_gauge)} gauges.")
    return metrics_by_gauge

## 7. Pre-calculate All Metrics

To make the plots fast and responsive, we will pre-calculate the metrics for *every* loaded model and basin.

In [None]:
# Pre-compute metric tables for both runs and derive the values the
# comparison cells below rely on: base_model_metrics, finetune_model_metrics,
# finetune_basin_id and base_model_all_basins_df.
if base_model_data and finetune_model_data:
    print(f"Calculating metrics for: {base_model_name}")
    base_model_metrics = calculate_metrics_for_run(base_model_data)

    print(f"\nCalculating metrics for: {finetune_model_name}")
    finetune_model_metrics = calculate_metrics_for_run(finetune_model_data)

    # --- Identify the fine-tuned basin ---
    if len(finetune_model_metrics) != 1:
        print(f"Warning: Fine-tuned model has {len(finetune_model_metrics)} basins. Expected 1.")

    # Take the first basin with metrics. None (instead of an IndexError) when
    # no metrics could be calculated, so the cells below can report the
    # problem themselves.
    finetune_basin_id = next(iter(finetune_model_metrics), None)
    if finetune_basin_id is None:
        print("\nWarning: No metrics available for the fine-tuned model; fine-tuning basin could not be identified.")
    else:
        print(f"\nFine-tuning basin identified as: {finetune_basin_id}")

    # --- Create a combined DataFrame for all base model basins ---
    # Rows are indexed by (basin_id, lead_time); columns are the metrics.
    try:
        base_model_all_basins_df = pd.concat(
            base_model_metrics.values(),
            keys=base_model_metrics.keys(),
            names=['basin_id', 'lead_time']
        )
        print("Created combined DataFrame for all base model basins.")
    except Exception as e:
        # e.g. no base model metrics at all (nothing to concatenate)
        print(f"Could not create combined DataFrame: {e}")
        base_model_all_basins_df = None

else:
    print("Model data not loaded. Skipping metric calculation.")
    base_model_metrics = {}
    finetune_model_metrics = {}
    base_model_all_basins_df = None
    finetune_basin_id = None

## 8. Base Model Skill Map

This map shows the performance of your base model across all test basins. Basins are color-coded by their KGE skill at lead time 0.

In [None]:
# Draw the base model skill map: every basin with metrics is colored by the
# chosen metric at the chosen lead time.
# geopandas is imported unconditionally at the top of the notebook, so its
# availability is detected by checking whether that import succeeded.
if 'gpd' not in globals():
    print("Geospatial libraries not found. Skipping skill map.")
elif base_model_all_basins_df is None:
    print("Base model metrics not calculated. Skipping skill map.")
else:
    print("Plotting base model skill map...")
    try:
        SKILL_METRIC = 'KGE'
        SKILL_LEAD_TIME = 0

        # Get skills for the specified lead time (one value per basin)
        skills = base_model_all_basins_df.xs(SKILL_LEAD_TIME, level='lead_time')[SKILL_METRIC]
        skills.name = f'{SKILL_METRIC} (Lead Time {SKILL_LEAD_TIME})'
        skills_df = skills.reset_index()

        # Normalize basin IDs for merging
        skills_df['normalized_id'] = skills_df['basin_id'].apply(normalize_id)

        # Reuse the shapefile if an earlier cell loaded it, otherwise load it
        if 'gdf_all_basins' not in globals():
            gdf_all_basins = gpd.read_file(SHAPEFILE_PATH)
        # The merge key may be missing if only the plain loading cell was run
        if 'normalized_id' not in gdf_all_basins.columns:
            gdf_all_basins['normalized_id'] = gdf_all_basins[shapefile_basin_id_column].apply(normalize_id)

        # Merge skills with geometry (keeps only basins present in both)
        gdf_with_skills = gdf_all_basins.merge(skills_df, on='normalized_id')

        if gdf_with_skills.empty:
            print("No matching basins found between metrics and shapefile. Check ID formats.")
        else:
            # Create the plot. No web basemap is added: contextily is not
            # imported by this notebook, so the map is drawn in the
            # shapefile's own coordinate system.
            fig, ax = plt.subplots(figsize=(12, 12))

            gdf_with_skills.plot(
                ax=ax,
                column=skills.name,
                legend=True,
                legend_kwds={'label': f'{SKILL_METRIC} Score', 'orientation': 'horizontal'},
                cmap='viridis',
                alpha=0.8,
                edgecolor='black',
                linewidth=0.5,
                missing_kwds={'color': 'lightgrey', 'label': 'No Data'}
            )

            ax.set_title(f"{base_model_name} Performance: {skills.name}")
            ax.set_axis_off()
            plt.show()

    except FileNotFoundError:
        print(f"Error: Shapefile not found at: {SHAPEFILE_PATH}")
    except Exception as e:
        print(f"An unexpected error occurred during skill map plotting: {e}")

## 9. Base Model Skills (Table)

This table shows the basin IDs and their corresponding skill scores for the metric and lead time plotted above, sorted from best to worst.

In [None]:
# Show the per-basin skill values from the skill map cell, best score first.
if 'skills' not in locals():
    print("Skills data not available to display.")
else:
    print(f"Displaying skills for: {skills.name}")
    ranked_skills = skills.reset_index().sort_values(by=skills.name, ascending=False)
    display(ranked_skills.set_index('basin_id'))

## 10. Fine-Tuning Comparison: Skill vs. Lead Time

This is the key comparison plot. Use the dropdown to select a metric.

-   **Lines:** Show the performance of the **Base Model** (blue) and the **Fine-Tuned Model** (orange) for the target basin.
-   **Box Plots (Background):** Show the distribution (median, quartiles) of that metric for *all other basins* in the base model's test set. This tells you if the base model was already performing well or poorly for this basin compared to others.

In [None]:
if not finetune_basin_id:
    display(Markdown("**Error: Could not identify fine-tuning basin. Cannot create comparison plot.**"))
elif finetune_basin_id not in base_model_metrics:
    display(Markdown(f"**Error: Fine-tuning basin '{finetune_basin_id}' not found in base model results. Cannot compare.**"))
elif base_model_all_basins_df is None:
    display(Markdown("**Error: Combined base model metrics are not available. Cannot create comparison plot.**"))
else:
    # Get metrics for the specific fine-tuned basin
    ft_basin_base_metrics = base_model_metrics[finetune_basin_id]
    ft_basin_finetune_metrics = finetune_model_metrics[finetune_basin_id]

    # Get metrics for all OTHER basins from the base model
    other_basins_base_metrics = base_model_all_basins_df.drop(finetune_basin_id, level='basin_id')

    def plot_finetune_comparison(metric_name):
        """
        Plot one metric against lead time for the fine-tuned basin.

        Draws the base and fine-tuned model as lines on top of box plots
        that summarize the same metric over all other base model basins.

        Args:
            metric_name: Column name of the metric to plot.
        """
        plt.figure(figsize=(14, 7))
        ax = plt.gca()

        base_line = ft_basin_base_metrics[metric_name]
        ft_line = ft_basin_finetune_metrics[metric_name]

        # Tick positions default to the target basin's lead times so the
        # axis can still be labelled when the background cannot be drawn.
        lead_times = list(base_line.index)

        # 1. Plot background boxplots for all OTHER basins
        try:
            # Rows: lead_time, columns: basin_id
            boxplot_data_all = other_basins_base_metrics[metric_name].unstack(level='basin_id')

            # One array per lead time (row), holding the non-NaN values of
            # all basins, so the number of boxes matches the positions.
            boxplot_data = [
                row.dropna().values for _, row in boxplot_data_all.iterrows()
            ]
            boxplot_positions = list(boxplot_data_all.index)

            ax.boxplot(
                boxplot_data,
                positions=boxplot_positions,
                patch_artist=True,
                showfliers=False,
                widths=0.6,
                boxprops=dict(facecolor='lightgray', alpha=0.7),
                whiskerprops=dict(color='gray'),
                capprops=dict(color='gray'),
                medianprops=dict(color='black')
            )
            lead_times = boxplot_positions
        except Exception as e:
            print(f"Could not plot boxplot background: {e}")

        # 2. Plot lines for the target basin
        ax.plot(base_line.index, base_line.values, label=base_model_name, marker='o', lw=2, zorder=10)
        ax.plot(ft_line.index, ft_line.values, label=finetune_model_name, marker='s', lw=2, zorder=10)

        # Legend entry for the boxplots: a proxy patch that is handed to the
        # legend only and never added to the axes, so it is not drawn.
        legend_handles, _ = ax.get_legend_handles_labels()
        legend_handles.append(
            plt.Rectangle((0, 0), 1, 1, fc='lightgray', alpha=0.7, label='Base Model (All Other Basins)')
        )

        ax.set_title(f"{metric_name} vs. Lead Time for Basin {finetune_basin_id}")
        ax.set_xlabel("Lead Time (days)")
        ax.set_ylabel(f"{metric_name} Score")
        ax.legend(handles=legend_handles)
        ax.grid(axis='y', linestyle='--', alpha=0.7)
        ax.set_xticks(lead_times)
        plt.show()

    # --- Create Widget ---
    available_metrics = list(ft_basin_base_metrics.columns)
    metric_widget = widgets.Dropdown(
        options=available_metrics,
        value='KGE' if 'KGE' in available_metrics else available_metrics[0],
        description='Metric:',
        disabled=False,
    )

    display(interactive(plot_finetune_comparison, metric_name=metric_widget))

## 11. Fine-Tuning Comparison: Hydrographs

This plot lets you visually inspect the hydrographs for the fine-tuned basin. Use the slider and text boxes to select the lead time and date range you want to see.

In [None]:
if not finetune_basin_id:
    display(Markdown("**Error: Could not identify fine-tuning basin. Cannot create hydrograph plot.**"))
elif (finetune_basin_id not in base_model_data) or (finetune_basin_id not in finetune_model_data):
    display(Markdown(f"**Error: Data for basin '{finetune_basin_id}' not found in both models. Cannot compare.**"))
else:
    # Upper bound of the lead-time slider, taken from the base model data
    try:
        max_lt = int(base_model_data[finetune_basin_id]['1D']['xr']['time_step'].max())
    except Exception:
        # Fall back to a default when the lead times cannot be read or are
        # not convertible to an integer.
        max_lt = 10

    def plot_comparison_hydrograph(lead_time, start_date_str, end_date_str):
        """
        Plot observed, base model and fine-tuned hydrographs for the target basin.

        Args:
            lead_time: Lead time (value of the 'time_step' coordinate) to plot.
            start_date_str: Start of the date range; empty means open-ended.
            end_date_str: End of the date range; empty means open-ended.
        """
        try:
            # Check that both entries parse as dates before slicing with
            # them. The strings themselves are kept for the slice.
            for date_text in (start_date_str, end_date_str):
                if date_text:
                    pd.to_datetime(date_text)
            data_slice = slice(start_date_str or None, end_date_str or None)
        except (ValueError, TypeError) as e:
            print(f"Invalid date format. Using full range. Error: {e}")
            data_slice = slice(None, None)

        try:
            # Get data arrays for the fine-tuned basin from both models
            obs_da = base_model_data[finetune_basin_id]['1D']['xr']['streamflow_obs']
            base_sim_da = base_model_data[finetune_basin_id]['1D']['xr']['streamflow_sim']
            ft_sim_da = finetune_model_data[finetune_basin_id]['1D']['xr']['streamflow_sim']

            # Select lead time and date range, drop NaNs from obs
            obs_lt = obs_da.sel(time_step=lead_time, date=data_slice).dropna(dim='date')
            if obs_lt.size == 0:
                print(f"No observed data for basin {finetune_basin_id} in this date range.")
                return

            # Align sim data to the valid observation dates
            base_sim_lt = base_sim_da.sel(time_step=lead_time).reindex(date=obs_lt['date'])
            ft_sim_lt = ft_sim_da.sel(time_step=lead_time).reindex(date=obs_lt['date'])

            # --- Plotting ---
            plt.figure(figsize=(15, 7))
            plt.plot(obs_lt['date'], obs_lt.values, label='Observed', color='black', lw=2.5)
            plt.plot(base_sim_lt['date'], base_sim_lt.values, label=base_model_name, linestyle='--', alpha=0.8)
            plt.plot(ft_sim_lt['date'], ft_sim_lt.values, label=finetune_model_name, linestyle='-', alpha=0.8)

            plt.title(f'Hydrograph for {finetune_basin_id} (Lead Time {lead_time} days)')
            plt.xlabel('Date')
            plt.ylabel('Streamflow')
            plt.legend()
            plt.grid(True, linestyle='--', alpha=0.6)
            plt.show()

        except KeyError as e:
            print(f"Error: Data structure missing. Could not find key: {e}")
        except Exception as e:
            print(f"An error occurred during plotting: {e}")

    # --- Create Widgets ---
    lead_time_widget = widgets.IntSlider(
        value=0, min=0, max=max_lt, step=1, description='Lead Time (days):'
    )
    # Redraw only when a date entry is submitted, not on every keystroke of
    # a half-typed date.
    start_date_widget = widgets.Text(
        value='2022-01-01', description='Start Date:', continuous_update=False
    )
    end_date_widget = widgets.Text(
        value='2022-03-01', description='End Date:', continuous_update=False
    )

    ui = VBox([lead_time_widget, start_date_widget, end_date_widget])
    out = interactive(
        plot_comparison_hydrograph,
        lead_time=lead_time_widget,
        start_date_str=start_date_widget,
        end_date_str=end_date_widget
    )
    # Show the custom control layout followed by the plot output area
    display(ui, out.children[-1])