In [1]:

#!pip install seaborn
#!pip install plotly

#!pip install nbformat

#!pip install jupyter
#!pip install anywidget

In [2]:
#import arcpy
import pandas as pd
import re
from io import StringIO

from pathlib import Path
import sys
# Ensure the project root (parent of this folder) is importable
sys.path.insert(0, str(Path().resolve().parent))
from python_pipeline_scripts import utils, runner
config = utils.load_config(Path().resolve().parent / 'config' / 'config.yaml')



%load_ext autoreload
%autoreload 1

## One-stop function to parse SWAT `output.rch` files

In [3]:
import pandas as pd
from datetime import datetime, date, timedelta

# Column names assigned, in file order, to the whitespace-separated fields of each
# output.rch data row; "object type" is the first token of every row.
# NOTE(review): "FLOW_OUTcmscms" looks like a typo for "FLOW_OUTcms" (the spelling used
# in available_vars / how_map_defaults_all). compare_dfs maps between the two spellings;
# renaming it here would also change the keys of DataFrames already cached in pickles by
# load_or_build_dfs_for_runs — confirm before changing.
DEFAULT_RCH_COLUMNS = [
    "object type","RCH","GIS","MON","AREAkm2","FLOW_INcms","FLOW_OUTcmscms","EVAPcms","TLOSScms",
    "SED_INtons","SED_OUTtons","SEDCONCmg/L","ORGN_INkg","ORGN_OUTkg","ORGP_INkg","ORGP_OUTkg",
    "NO3_INkg","NO3_OUTkg","NH4_INkg","NH4_OUTkg","NO2_INkg","NO2_OUTkg","MINP_INkg","MINP_OUTkg",
    "CHLA_INkg","CHLA_OUTkg","CBOD_INkg","CBOD_OUTkg","DISOX_INkg","DISOX_OUTkg","SOLPST_INmg",
    "SOLPST_OUTmg","SORPST_INmg","SORPST_OUTmg","REACTPSTmg","VOLPSTmg","SETTLPSTmg","RESUSP_PSTmg",
    "DIFFUSEPSTmg","REACBEDPSTmg","BURYPSTmg","BED_PSTmg","BACTP_OUTct","BACTLP_OUTct","CMETAL#1kg",
    "CMETAL#2kg","CMETAL#3kg","TOT_Nkg","TOT_Pkg","NO3ConcMg/l","WTMPdegc","Salt1","Salt2","Salt3",
    "Salt4","Salt5","Salt6","Salt7","Salt8","Salt9","Salt10","SAR","EC"
]

# Columns removed by load_output_rch by default: bookkeeping fields of the raw file plus
# the helper columns "total_days" and "YEAR" that load_output_rch itself creates.
DEFAULT_DROP_COLS = ["object type","total_days","GIS","MON","AREAkm2","YEAR"]

def load_output_rch(
    file_path: str,
    cio_file: str,
    *,
    columns: list[str] = None,
    skiprows: int = 9,
    group_size: int = 17,        # rows per day in output.rch (typical)
    add_area_ha: bool = True,
    hectare_per_km2: float = 100.0,
    drop_cols: list[str] = None,
    reorder_date_cols: bool = True
) -> pd.DataFrame:
    """
    Load SWAT output.rch, attach datetime based on file.cio, compute area_ha, and tidy columns.

    Rows are assumed to be written day by day, ``group_size`` consecutive rows per
    simulated day (presumably one row per reach — TODO confirm for your setup). The
    0-based day number of data row ``i`` is therefore ``i // group_size``.

    Parameters
    ----------
    file_path : str
        Path to output.rch
    cio_file : str
        Path to file.cio (NYSKIP, IYR and IDAF are read to derive the first output date)
    columns : list[str], optional
        Column names for output.rch; defaults to DEFAULT_RCH_COLUMNS
    skiprows : int, optional
        Lines to skip before the data rows (default 9 for SWAT outputs)
    group_size : int, optional
        Number of rows per simulated day in output.rch (default 17); must be > 0
    add_area_ha : bool, optional
        If True, adds area_ha = AREAkm2 * hectare_per_km2, placed right after AREAkm2
    hectare_per_km2 : float, optional
        Conversion factor (default 100 ha per km²)
    drop_cols : list[str], optional
        Columns to drop at the end (default DEFAULT_DROP_COLS); absent names are ignored
    reorder_date_cols : bool, optional
        If True, moves 'date' to col 3 and 'YEAR' to col 4 (0-based)

    Returns
    -------
    pd.DataFrame
        Tidy DataFrame with a datetime64 'date' column, optional area_ha, and dropped columns.

    Raises
    ------
    ValueError
        If group_size <= 0, or NYSKIP / IYR / IDAF cannot be found in cio_file.
    KeyError
        If add_area_ha is True but there is no AREAkm2 column.
    """
    # Validate cheap arguments before reading the (potentially large) output file.
    if group_size <= 0:
        raise ValueError("group_size must be > 0")

    # ---- helpers to read file.cio ----
    def _getModelParameter(param: str, parameterfile: str) -> "str | None":
        """Return the value written before '|' on the line for ``param``, or None.

        SWAT parameter lines look like ``<value> | NAME : description``. A line whose
        label right after '|' is exactly ``param`` wins. Otherwise the first line that
        merely contains ``param`` is used (the old behaviour), so unusual layouts still
        resolve.
        """
        fallback = None
        with open(parameterfile, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                if param not in line:
                    continue
                value, sep, label = line.partition("|")
                tokens = label.replace(":", " ").split()
                if sep and tokens and tokens[0] == param:
                    return value.strip()
                if fallback is None:
                    fallback = value.strip()
        return fallback

    def _getStartDate(swatiofile: str) -> date:
        """First date printed to the output files, derived from NYSKIP, IYR and IDAF."""
        def _int_param(name: str) -> int:
            raw = _getModelParameter(name, swatiofile)
            if raw is None:
                raise ValueError(f"Parameter {name!r} not found in {swatiofile}")
            return int(raw)

        skip_year = _int_param("NYSKIP")
        start_year = _int_param("IYR")
        start_day = _int_param("IDAF")
        if skip_year > 0:
            # Only the first simulated year starts on day IDAF; later years start on
            # 1 January. Skipping NYSKIP whole years therefore makes printed output begin
            # on 1 January of IYR + NYSKIP (see NYSKIP in the SWAT 2012 I/O docs).
            return date(start_year + skip_year, 1, 1)
        return date(start_year, 1, 1) + timedelta(days=start_day - 1)

    # ---- defaults ----
    if columns is None:
        columns = DEFAULT_RCH_COLUMNS
    if drop_cols is None:
        drop_cols = DEFAULT_DROP_COLS

    # ---- read output.rch ----
    df = pd.read_csv(file_path, sep=r"\s+", skiprows=skiprows, header=None, names=columns, engine="python")

    # ---- total_days: 0-based simulated-day counter (group_size rows share one day) ----
    df.index.name = "total_days"
    df = df.reset_index(drop=False)
    df["total_days"] = df["total_days"] // group_size

    # ---- compute date & YEAR from cio (vectorised; yields datetime64 like before) ----
    start_date = _getStartDate(cio_file)
    df["date"] = pd.Timestamp(start_date) + pd.to_timedelta(df["total_days"], unit="D")
    df["YEAR"] = df["date"].dt.year

    # ---- reorder date/YEAR columns if desired ----
    if reorder_date_cols:
        # insert 'date' at position 3 and 'YEAR' at position 4 (0-based)
        df.insert(3, "date", df.pop("date"))
        df.insert(4, "YEAR", df.pop("YEAR"))

    # ---- add area_ha if requested ----
    if add_area_ha:
        if "AREAkm2" not in df.columns:
            raise KeyError("AREAkm2 column not found; cannot compute area_ha.")
        area_ha = df["AREAkm2"] * hectare_per_km2
        # Replace any pre-existing area_ha and place it right after AREAkm2. The column is
        # located by name rather than by a fixed index, so this also works with
        # reorder_date_cols=False or custom column lists.
        df = df.drop(columns="area_ha", errors="ignore")
        df.insert(df.columns.get_loc("AREAkm2") + 1, "area_ha", area_ha)

    # ---- final tidy: drop columns ----
    to_drop = [c for c in drop_cols if c in df.columns]
    df = df.drop(columns=to_drop)

    return df


In [4]:
import os
import pandas as pd
import numpy as np


def load_multiple_rch_from_folders(base_folders: list[str], **kwargs) -> dict[str, pd.DataFrame]:
    """
    Loads output.rch and file.cio from multiple SWAT TxtInOut base folders.

    Each folder is keyed as ``rch_<name>``. ``<name>`` is the parent folder's name when
    the folder's own name contains "txtinout" (case-insensitive); otherwise it is the
    folder's own name. If two folders map to the same key, the later one gets a numeric
    suffix (``_2``, ``_3``, ...) and a warning is printed, instead of silently overwriting
    the earlier result.

    Parameters
    ----------
    base_folders : list[str]
        List of paths to TxtInOut folders.
    **kwargs :
        Additional keyword arguments to pass to load_output_rch().

    Returns
    -------
    dict[str, pd.DataFrame]
        Dictionary mapping realization names to loaded DataFrames, in input order.

    Raises
    ------
    FileNotFoundError
        If a folder lacks output.rch or file.cio.
    """
    results: dict[str, pd.DataFrame] = {}

    for folder in base_folders:
        # Normalize path (also strips trailing separators, so basename is never empty)
        folder = os.path.abspath(folder)
        folder_name = os.path.basename(folder)
        if "txtinout" in folder_name.lower():
            # TxtInOut folders are generically named; use the scenario folder above instead
            realization_name = f"rch_{os.path.basename(os.path.dirname(folder))}"
        else:
            realization_name = f"rch_{folder_name}"

        # Disambiguate key collisions rather than silently dropping an earlier folder
        if realization_name in results:
            base_name = realization_name
            suffix = 2
            while f"{base_name}_{suffix}" in results:
                suffix += 1
            realization_name = f"{base_name}_{suffix}"
            print(f"Warning: duplicate realization name {base_name!r}; storing {folder} as {realization_name!r}")
        print(f"Loading from folder: {folder} as {realization_name}")

        # Expected files
        output_rch_path = os.path.join(folder, "output.rch")
        cio_path = os.path.join(folder, "file.cio")

        # Safety check
        if not os.path.isfile(output_rch_path):
            raise FileNotFoundError(f"output.rch not found in {folder}")
        if not os.path.isfile(cio_path):
            raise FileNotFoundError(f"file.cio not found in {folder}")

        results[realization_name] = load_output_rch(file_path=output_rch_path, cio_file=cio_path, **kwargs)

    return results


def compare_dfs(df1: pd.DataFrame, df2: pd.DataFrame):
    """
    Compare the numeric columns of two equally long DataFrames row by row and print a summary.

    Columns present in both frames under the same name are compared directly. A column
    that exists in only one frame is compared against its known alias in the other
    (currently FLOW_OUTcms <-> FLOW_OUTcmscms) and reported as ``"<col1> vs <col2>"``.
    Rows are matched by position, not by index label. Cells that are NaN in both frames
    count as equal.

    Parameters
    ----------
    df1, df2 : pd.DataFrame
        Frames to compare; only numeric dtypes are considered. Relative differences are
        taken with respect to ``df1``.

    Returns
    -------
    dict
        ``any_difference`` (bool), ``n_diff_columns`` (int), ``diff_columns`` (column
        names / alias labels with at least one difference), ``diff_counts`` (label ->
        number of differing rows) and ``metrics`` (label -> dict with count,
        mean_abs_diff, max_abs_diff and mean_rel_diff, where
        mean_rel_diff = mean(|df1 - df2| / |df1|) ignoring rows with df1 == 0).

    Raises
    ------
    ValueError
        If the frames have a different number of rows.
    """
    # Known spelling variants of the same SWAT column; used only when a column has no
    # same-name partner in the other frame.
    col_map = {
        "FLOW_OUTcms": "FLOW_OUTcmscms",
        "FLOW_OUTcmscms": "FLOW_OUTcms"
    }

    # Only compare numeric columns; drop the index so the comparison is purely positional
    num1 = df1.select_dtypes(include=[np.number]).reset_index(drop=True)
    num2 = df2.select_dtypes(include=[np.number]).reset_index(drop=True)

    if len(num1) != len(num2):
        raise ValueError("DataFrames must have the same shape for direct comparison.")

    cols2 = set(num2.columns)
    # Same-name columns are always compared directly. Previously the aliased FLOW_OUT
    # column was excluded here and then never compared when both frames shared its name.
    common_cols = [c for c in num1.columns if c in cols2]
    common_set = set(common_cols)

    # Alias pairs are built from df1's side only, so each pair is listed exactly once
    mapped_pairs = []
    for c1 in num1.columns:
        c2 = col_map.get(c1)
        if c2 is not None and c1 not in common_set and c2 in cols2 and c2 not in common_set:
            mapped_pairs.append((c1, c2))

    def _column_metrics(a: pd.Series, b: pd.Series):
        """Difference metrics for one column pair, or None when the columns match."""
        # NaN != NaN is True in pandas, so treat "missing in both" as equal explicitly
        differs = (a != b) & ~(a.isna() & b.isna())
        n_diff = int(differs.sum())
        if n_diff == 0:
            return None
        abs_diff = (a - b).abs()
        # |df1| as denominator so negative reference values don't give negative ratios
        rel_diff = (abs_diff / a.abs().replace(0, np.nan)).replace([np.inf, -np.inf], np.nan)
        return {
            "count": n_diff,
            "mean_abs_diff": float(abs_diff.mean()),
            "max_abs_diff": float(abs_diff.max()),
            "mean_rel_diff": float(rel_diff.mean()),
        }

    # Collect metrics for columns with differences: direct matches first, then aliases
    metrics = {}
    diff_columns = []
    for col in common_cols:
        m = _column_metrics(num1[col], num2[col])
        if m is not None:
            metrics[col] = m
            diff_columns.append(col)
    for c1, c2 in mapped_pairs:
        m = _column_metrics(num1[c1], num2[c2])
        if m is not None:
            label = f"{c1} vs {c2}"
            metrics[label] = m
            diff_columns.append(label)

    any_diff = bool(diff_columns)

    # Print easy to read stats
    print("Comparison Summary")
    print("==================")
    print(f"Any difference: {any_diff}")
    print(f"Number of numeric columns with differences: {len(diff_columns)}")
    print("Numeric columns with differences:")
    for col in diff_columns:
        m = metrics[col]
        print(f"  {col}: count={m['count']}, mean_abs_diff={m['mean_abs_diff']:.4g}, max_abs_diff={m['max_abs_diff']:.4g}, mean_rel_diff={m['mean_rel_diff']:.4g}")

    return {
        "any_difference": any_diff,
        "n_diff_columns": len(diff_columns),
        "diff_columns": diff_columns,
        "diff_counts": {col: metrics[col]["count"] for col in diff_columns},
        "metrics": metrics,
    }


In [5]:
import os

def find_run_folders(run_number, path=r"C:\SWAT\RSWAT\cubillas\mc_results"):
    """
    Find folders in the given path that match the pattern:
    runXXXXXX_realYYYYYY_*

    Parameters
    ----------
    run_number : int
        The run number to match (will be zero-padded to 6 digits).
    path : str, optional
        The folder path to search in. Defaults to
        'C:\\SWAT\\RSWAT\\cubillas\\mc_results'.

    Returns
    -------
    list of str
        Full paths (``os.path.join(path, name)``) of the matching directories, sorted by
        name. Returns an empty list (after printing an error) if ``path`` does not exist.
    """
    # Format the number as 6 digits with leading zeros
    run_prefix = f"run{run_number:06d}_"

    try:
        entries = os.listdir(path)
    except FileNotFoundError:
        print(f"Error: Path '{path}' does not exist.")
        return []

    # os.listdir order is arbitrary. Sort so the result is deterministic, including
    # folders[0], which load_or_build_dfs_for_runs uses as its pickle-cache location.
    return sorted(
        os.path.join(path, name)
        for name in entries
        if name.startswith(run_prefix) and os.path.isdir(os.path.join(path, name))
    )


In [6]:
from __future__ import annotations
from pathlib import Path
from typing import Iterable, Dict, Any
import pickle

def _coerce_to_dict(obj: Any, run: int) -> Dict[str, Any]:
    """Make sure we hand back a dict no matter what was stored."""
    if isinstance(obj, dict):
        return obj
    if isinstance(obj, list):
        # common patterns: list of (name, df) or list of dfs
        try:
            if all(isinstance(x, tuple) and len(x) == 2 for x in obj):
                return dict(obj)  # [(key, df), ...]
        except Exception:
            pass
        # fallback: enumerate
        return {f"run{run}_sim{i}": x for i, x in enumerate(obj)}
    # last resort: wrap single object
    return {f"run{run}": obj}

def load_or_build_dfs_for_runs(
    runs: int | Iterable[int],
    *,
    pickle_name_fmt: str = "all_dfs_mc_run_{run}.pkl",
    force_rebuild: bool = False,
    save_pickle: bool = True,
) -> Dict[str, Any]:
    """
    Load (or build and cache) the output.rch DataFrames of one or more MC runs.

    For each run:
      - Locate its folders with find_run_folders(run); runs without folders are skipped
        with a warning.
      - Use ``<first folder>/<pickle_name_fmt.format(run=run)>`` as the cache file. Load it
        unless force_rebuild is set or it cannot be read; otherwise build the data with
        load_multiple_rch_from_folders and (if save_pickle) write the cache.
      - Coerce the result to a dict (compat with old pickles returning lists), then merge.

    NOTE(review): the cache is not invalidated when the run's folders change; use
    force_rebuild=True after adding or re-running realizations.

    Parameters
    ----------
    runs : int or iterable of int
        A single run number (any integer type, e.g. numpy.int64) or several.
    pickle_name_fmt : str
        File name of the per-run cache; ``{run}`` is replaced by the run number.
    force_rebuild : bool
        Ignore existing caches and rebuild from output.rch.
    save_pickle : bool
        Write the cache after a rebuild. Write failures only print a warning.

    Returns
    -------
    dict
        ONE merged dict across all runs; if keys collide, later runs overwrite earlier ones.
    """
    import numbers

    # numbers.Integral also accepts numpy integer scalars, which are not ``int`` instances
    run_list = [runs] if isinstance(runs, numbers.Integral) else list(runs)
    merged: Dict[str, Any] = {}

    for run in run_list:
        folders = find_run_folders(run)
        if not folders:
            print(f"Warning: no folders found for run {run}; skipping.")
            continue
        first_folder = Path(folders[0])
        pkl_path = first_folder / pickle_name_fmt.format(run=run)

        run_obj = None
        if not force_rebuild and pkl_path.exists():
            # NOTE: pickle.load can execute arbitrary code; only load caches this pipeline wrote.
            try:
                with open(pkl_path, "rb") as f:
                    run_obj = pickle.load(f)
            except Exception as exc:
                print(f"Warning: could not read cache {pkl_path} ({exc!r}); rebuilding.")
                run_obj = None

        if run_obj is None:
            run_obj = load_multiple_rch_from_folders(folders)
            if save_pickle:
                tmp_path = pkl_path.with_name(pkl_path.name + ".tmp")
                try:
                    with open(tmp_path, "wb") as f:
                        pickle.dump(run_obj, f)
                    # Swap into place only after a complete dump, so an interrupted write
                    # never leaves a truncated file at pkl_path.
                    tmp_path.replace(pkl_path)
                except Exception as exc:
                    # Caching is best-effort: warn, but keep the freshly built data.
                    print(f"Warning: could not write cache {pkl_path} ({exc!r}).")

        run_dict = _coerce_to_dict(run_obj, run)
        # if keys collide across runs, later runs overwrite earlier ones
        merged.update(run_dict)

    return merged


In [7]:
# Candidate output.rch variables for comparison. BED_PSTmg, the salt columns, SAR and EC
# are currently excluded through the commented-out entries below.
# NOTE(review): "FLOW_OUTcms" here differs from the "FLOW_OUTcmscms" name that
# load_output_rch assigns through DEFAULT_RCH_COLUMNS — confirm downstream lookups find it.
available_vars = ["FLOW_INcms","FLOW_OUTcms","EVAPcms","TLOSScms", 
    "SED_INtons","SED_OUTtons","SEDCONCmg/L","ORGN_INkg","ORGN_OUTkg","ORGP_INkg","ORGP_OUTkg",
    "NO3_INkg","NO3_OUTkg","NH4_INkg","NH4_OUTkg","NO2_INkg","NO2_OUTkg","MINP_INkg","MINP_OUTkg",
    "CHLA_INkg","CHLA_OUTkg","CBOD_INkg","CBOD_OUTkg","DISOX_INkg","DISOX_OUTkg","SOLPST_INmg",
    "SOLPST_OUTmg","SORPST_INmg","SORPST_OUTmg","REACTPSTmg","VOLPSTmg","SETTLPSTmg","RESUSP_PSTmg",
    "DIFFUSEPSTmg","REACBEDPSTmg","BURYPSTmg",
    #"BED_PSTmg",
    "BACTP_OUTct","BACTLP_OUTct","CMETAL#1kg",
    "CMETAL#2kg","CMETAL#3kg","TOT_Nkg","TOT_Pkg","NO3ConcMg/l","WTMPdegc",
    #"Salt1","Salt2","Salt3","Salt4","Salt5","Salt6","Salt7","Salt8","Salt9","Salt10",
    #"SAR","EC"
]

# Temporal aggregation rule per output.rch variable, presumably consumed by aggregation
# code outside this view (TODO confirm which strings it accepts):
#   "sum"                -> fluxes / loads (tons, kg, mg, counts)
#   "mean"               -> flow rates (cms) and the BED_PSTmg bed-storage inventory
#   "flow_weighted_mean" -> concentrations and physical/chemical properties
# NOTE(review): the outflow key is "FLOW_OUTcms", while load_output_rch names that column
# "FLOW_OUTcmscms" (DEFAULT_RCH_COLUMNS) — confirm the lookup still finds it.
how_map_defaults_all = {
    # Flow rates
    "FLOW_INcms": "mean",
    "FLOW_OUTcms": "mean",
    "EVAPcms": "mean",
    "TLOSScms": "mean",

    # Sediment fluxes & concentrations
    "SED_INtons": "sum",
    "SED_OUTtons": "sum",
    "SEDCONCmg/L": "flow_weighted_mean",

    # Nutrient fluxes
    "ORGN_INkg": "sum",
    "ORGN_OUTkg": "sum",
    "ORGP_INkg": "sum",
    "ORGP_OUTkg": "sum",
    "NO3_INkg": "sum",
    "NO3_OUTkg": "sum",
    "NH4_INkg": "sum",
    "NH4_OUTkg": "sum",
    "NO2_INkg": "sum",
    "NO2_OUTkg": "sum",
    "MINP_INkg": "sum",
    "MINP_OUTkg": "sum",

    # Totals
    "TOT_Nkg": "sum",
    "TOT_Pkg": "sum",

    # Algae & BOD / DO fluxes
    "CHLA_INkg": "sum",
    "CHLA_OUTkg": "sum",
    "CBOD_INkg": "sum",
    "CBOD_OUTkg": "sum",
    "DISOX_INkg": "sum",
    "DISOX_OUTkg": "sum",

    # Pesticide-related
    "SOLPST_INmg": "sum",
    "SOLPST_OUTmg": "sum",
    "SORPST_INmg": "sum",
    "SORPST_OUTmg": "sum",
    "REACTPSTmg": "sum",
    "VOLPSTmg": "sum",
    "SETTLPSTmg": "sum",
    "RESUSP_PSTmg": "sum",
    "DIFFUSEPSTmg": "sum",
    "REACBEDPSTmg": "sum",
    "BURYPSTmg": "sum",
    "BED_PSTmg": "mean",  # Bed storage → inventory (a stock, so averaged rather than summed)

    # Bacteria counts (flux over time)
    "BACTP_OUTct": "sum",
    "BACTLP_OUTct": "sum",

    # Metals
    "CMETAL#1kg": "sum",
    "CMETAL#2kg": "sum",
    "CMETAL#3kg": "sum",



    # Concentrations & physical/chemical properties
    "NO3ConcMg/l": "flow_weighted_mean",
    "WTMPdegc": "flow_weighted_mean",
    "Salt1": "flow_weighted_mean",
    "Salt2": "flow_weighted_mean",
    "Salt3": "flow_weighted_mean",
    "Salt4": "flow_weighted_mean",
    "Salt5": "flow_weighted_mean",
    "Salt6": "flow_weighted_mean",
    "Salt7": "flow_weighted_mean",
    "Salt8": "flow_weighted_mean",
    "Salt9": "flow_weighted_mean",
    "Salt10": "flow_weighted_mean",
    "SAR": "flow_weighted_mean",
    "EC": "flow_weighted_mean"
}


In [8]:
# Maps each SWAT output.rch variable to candidate measured analytes, keyed by a
# match-strength score. Judging from the rationale strings, "100" is the same or closest
# quantity and "70" a weak or indirect proxy.
# Each value is a (measured analyte name, rationale) tuple; an empty analyte name means
# nothing comparable is measured at that score.
# NOTE(review): many SWAT variables here are loads (kg) or masses, while the measured
# analytes are concentrations — convert units before comparing.
# NOTE(review): the TOT_Nkg entry lists "70" before "90", unlike the other entries. This
# is harmless for key lookup but matters if callers iterate the dict in insertion order.
swat_to_measured = {

    # -------- Sediment fluxes & concentrations --------
    "SEDCONCmg/L": {
        "100": ("SOLIDOS EN SUSPENSION", "# both are suspended solids concentrations (mg/L)"),
        "90": ("FOSFORO TOTAL", "# much TP is sediment-associated; strong event co-variation"),
        "70": ("DEMANDA QUIMICA DE OXIGENO", "# COD can increase with particulate/organic matter; indirect"),
    },

    # ------------------- Nutrient fluxes -------------------

    "ORGN_OUTkg": {
        "100": ("", "# organic N flux not directly measured"),
        "90": ("NITROGENO KJELDAHL", "# TKN captures most of organic N (plus NH4)"),
        "70": ("NITROGENO TOTAL", "# broader pool; weaker proxy for organic component"),
    },
    "ORGP_OUTkg": {
        "100": ("", "# organic P flux not directly measured"),
        "90": ("FOSFORO TOTAL", "# TP best available proxy for total non-dissolved P"),
        "70": ("FOSFATOS", "# inorganic reactive P; complementary rather than equivalent"),
    },
    "NO3_OUTkg": {
        "100": ("NITRATOS", "# same species; outflow flux vs measured concentration"),
        "90": ("NITROGENO TOTAL", "# TN includes NO3"),
        "70": ("NITRITOS", "# weaker association via nitrification/denitrification dynamics"),
    },

    "NH4_OUTkg": {
        "100": ("AMONIO", "# same dissolved species (NH4+)"),
        "90": ("NITROGENO KJELDAHL", "# TKN includes NH4"),
        "70": ("NITROGENO TOTAL", "# TN includes NH4; weaker proxy"),
    },

    "NO2_OUTkg": {
        "100": ("NITRITOS", "# same dissolved species (NO2-)"),
        "90": ("NITRATOS", "# related via N redox cycling"),
        "70": ("NITROGENO TOTAL", "# broad; weak proxy"),
    },
    "MINP_OUTkg": {
        "100": ("FOSFATOS", "# orthophosphate ≈ dissolved inorganic P"),
        "90": ("FOSFORO TOTAL", "# TP co-varies; not the same species"),
        "70": ("SOLIDOS EN SUSPENSION", "# indirect via sorption/desorption with sediment"),
    },

    # --------- Algae & BOD / DO fluxes ----------
    
    "CHLA_OUTkg": {
        "100": ("", "# chlorophyll-a not in measurement list"),
        "90": ("CARBONO ORGANICO TOTAL", "# biomass–carbon linkage; indirect"),
        "70": ("DEMANDA BIOQUIMICA DE OXIGENO 5 DIAS", "# decay/oxygen demand linkage; weaker"),
    },

    "CBOD_OUTkg": {
        "100": ("DEMANDA BIOQUIMICA DE OXIGENO 5 DIAS", "# CBOD ≈ BOD5"),
        "90": ("DEMANDA QUIMICA DE OXIGENO", "# COD as broader oxygen demand metric"),
        "70": ("CARBONO ORGANICO TOTAL", "# TOC–oxygen demand linkage; weaker"),
    },
   
    "DISOX_OUTkg": {
        "100": ("OXIGENO DISUELTO \"IN SITU\"", "# same analyte (DO)"),
        "90": ("SATURACION DE OXIGENO DISUELTO \"IN SITU\"", "# closely related"),
        "70": ("TEMPERATURA \"IN SITU\"", "# inverse solubility link; weak"),
    },



    # ---------------------------- Totals ----------------------------
    # NOTE(review): "100" is TKN, a subset of TN, while TN itself is ranked "70" because
    # it has almost no measurements (per its rationale string).
    "TOT_Nkg": {
        "100": ("NITROGENO KJELDAHL", "# organic N + NH4 subset of TN"),
        "70": ("NITROGENO TOTAL", "# same aggregate analyte (all N forms), but almost no data"),
        "90": ("NITRATOS", "# dominant dissolved N species in many rivers"),
    },
    "TOT_Pkg": {
        "100": ("FOSFORO TOTAL", "# same aggregate analyte (all P forms)"),
        "90": ("FOSFATOS", "# DIP is a major fraction in some conditions; not total"),
        "70": ("SOLIDOS EN SUSPENSION", "# much TP is particulate-bound; indirect"),
    },

    # ---- Concentrations & physical/chemical properties ----
    "NO3ConcMg/l": {
        "100": ("NITRATOS", "# same dissolved species concentration"),
        "90": ("NITROGENO TOTAL", "# contains NO3; partial tracking"),
        "70": ("NITRITOS", "# related nitrogen species; weaker"),
    },
    "WTMPdegc": {
        "100": ("TEMPERATURA \"IN SITU\"", "# same variable (°C)"),
        "90": ("SATURACION DE OXIGENO DISUELTO \"IN SITU\"", "# temperature drives DO saturation; inverse relation"),
        "70": ("OXIGENO DISUELTO \"IN SITU\"", "# DO concentration linked to temperature; indirect"),
    },

    "EC": {
        "100": ("CONDUCTIVIDAD ELECTRICA A 20ºC \"IN SITU\"", "# same measurement (EC)"),
        "90": ("NITRATOS", "# ions contribute to EC; partial proxy"),
        "70": ("FOSFATOS", "# lesser ionic contributor; weak"),
    },
}


In [9]:
# Focus variables for assessing the wastewater / effluent scenario; each inline note gives
# the response expected for that variable.
key_vars_with_comments = [
    "TOT_Nkg",         # Total nitrogen load — sum of all N forms, will increase.
    "TOT_Pkg",         # Total phosphorus load — sum of all P forms, will increase.
    "MINP_OUTkg",      # Mineral P load — soluble phosphate, highly bioavailable.
    "ORGP_OUTkg",      # Organic phosphorus load — particulate P from wastewater.
    "ORGN_OUTkg",      # Organic nitrogen load — particulate N from wastewater.
    "NO3_OUTkg",       # Nitrate load — may increase with lag; seasonal patterns important.
    "NH4_OUTkg",       # Ammonium load — direct from effluent (big spike expected).
    "NO3ConcMg/l",     # Nitrate concentration — can rise downstream after nitrification; flow-weighted mean.
    "SED_OUTtons",     # Suspended solids — large TSS load increases sediment export.
    "SEDCONCmg/L",     # Sediment concentration — turbidity impact; use flow-weighted mean.
    "CBOD_OUTkg",      # Biochemical oxygen demand — strong increase, drives DO depletion.
    "DISOX_OUTkg",     # Dissolved oxygen mass exported — expect decrease due to BOD and nitrification.
    "CHLA_OUTkg"       # Chlorophyll-a load — algal biomass; can increase if nutrients + light allow.

]

# Variables to compare; switch to available_vars for the full set.
#vars_to_compare = available_vars
vars_to_compare = key_vars_with_comments
# NOTE: binds the same dict object (no copy) — mutating how_map_defaults also changes
# how_map_defaults_all.
how_map_defaults = how_map_defaults_all

In [10]:
# Measured water-chemistry samples for the Cubillas catchment (older cleaned export).
# NOTE(review): absolute, machine-specific path — consider moving it into config.yaml.
df_cubillas_chem_measur_clean = pd.read_csv(
    filepath_or_buffer=r"C:\Users\Usuario\OneDrive - UNIVERSIDAD DE HUELVA\Granada\TrabajoFM\Genil_ArcGIS_Pasca\MDA_test_data\df_cubillas_chem_measur_clean.csv"
)

In [11]:

# NOTE(review): disabled legacy helper kept as a bare string literal, so the function is
# never defined. Running this cell only evaluates (and displays) the string. Delete it, or
# move it into a real code cell or module if the mg/L -> kg/day conversion is still needed.
""" 
def clean_and_mg_L_to_kg_per_day(
    df_samples: pd.DataFrame,
    df_flow: pd.DataFrame,
    sample_date_col: str = "F_MUESTREO",
    sample_value_col: str = "RESULTADO",
    flow_date_col: str = "date",
    flow_value_col: str = "water_flow_m3_d_cubillas",
    kg_col: str = "kg_per_day",
) -> pd.DataFrame:
    
    # Add (or overwrite) a kg/day load column to df_samples.

    # Assumptions:
    #   - df_samples[sample_value_col] is in mg/L.
    #   - df_flow[flow_value_col] is the flow in m^3/day.
    #   - We join flow to samples by (day) date.

    # kg/day is computed as: (mg/L) * (m^3/day) * (1000 L/m^3) * (1 kg / 1e6 mg) = value * flow * 0.001
    
    # Work on copies to avoid changing caller's data
    s = df_samples.copy()
    f = df_flow.copy()

    # --- Force types ---
    # Dates → datetime (floored to day)
    s[sample_date_col] = pd.to_datetime(s[sample_date_col], errors="coerce").dt.floor("D")
    f[flow_date_col]   = pd.to_datetime(f[flow_date_col],   errors="coerce").dt.floor("D")

    # Numeric → float (0 if coercion fails)
    s[sample_value_col] = pd.to_numeric(s[sample_value_col], errors="coerce").astype(float).fillna(0)
    # setting negative sample_value_col to zero
    s.loc[s[sample_value_col] < 0, sample_value_col] = 0
    
    f[flow_value_col]   = pd.to_numeric(f[flow_value_col],   errors="coerce").astype(float).fillna(0)



    # If multiple flow rows per date, reduce to a single value (mean is a sensible default)
    f_reduced = (
        f[[flow_date_col, flow_value_col]]
        .groupby(flow_date_col, as_index=False)
        .mean(numeric_only=True)
    )

    # --- Join flow onto samples by date ---
    merged = s.merge(
        f_reduced,
        left_on=sample_date_col,
        right_on=flow_date_col,
        how="left",
        suffixes=("", "_flow")
    )

    # --- Compute kg/day ---
    # kg/day = (mg/L) * (m^3/day) * 0.001
    merged[kg_col] = merged[sample_value_col] * merged[flow_value_col] * 0.001

    # Drop the join key from the right side to keep original schema tidy
    merged = merged.drop(columns=[flow_date_col])

    return merged
 """

' \ndef clean_and_mg_L_to_kg_per_day(\n    df_samples: pd.DataFrame,\n    df_flow: pd.DataFrame,\n    sample_date_col: str = "F_MUESTREO",\n    sample_value_col: str = "RESULTADO",\n    flow_date_col: str = "date",\n    flow_value_col: str = "water_flow_m3_d_cubillas",\n    kg_col: str = "kg_per_day",\n) -> pd.DataFrame:\n\n    # Add (or overwrite) a kg/day load column to df_samples.\n\n    # Assumptions:\n    #   - df_samples[sample_value_col] is in mg/L.\n    #   - df_flow[flow_value_col] is the flow in m^3/day.\n    #   - We join flow to samples by (day) date.\n\n    # kg/day is computed as: (mg/L) * (m^3/day) * (1000 L/m^3) * (1 kg / 1e6 mg) = value * flow * 0.001\n\n    # Work on copies to avoid changing caller\'s data\n    s = df_samples.copy()\n    f = df_flow.copy()\n\n    # --- Force types ---\n    # Dates → datetime (floored to day)\n    s[sample_date_col] = pd.to_datetime(s[sample_date_col], errors="coerce").dt.floor("D")\n    f[flow_date_col]   = pd.to_datetime(f[flow_date_

In [12]:
# Load the output.rch DataFrames of each Monte Carlo run. A run is read from its pickle
# cache when one exists, otherwise built and cached. Each dict_<run> maps
# "rch_<realization folder>" -> DataFrame.
# Runs 81, 83, 91, 92, 106, 107 and 109 were loaded here earlier; add a matching line to
# bring any of them back.
dict_108 = load_or_build_dfs_for_runs(108, force_rebuild=False)
dict_117 = load_or_build_dfs_for_runs(117, force_rebuild=False)
dict_134 = load_or_build_dfs_for_runs(134, force_rebuild=False)
dict_136 = load_or_build_dfs_for_runs(136, force_rebuild=False)
dict_138 = load_or_build_dfs_for_runs(138, force_rebuild=False)
dict_140 = load_or_build_dfs_for_runs(140, force_rebuild=False)
dict_141 = load_or_build_dfs_for_runs(141, force_rebuild=False)
dict_142 = load_or_build_dfs_for_runs(142, force_rebuild=False)
dict_143 = load_or_build_dfs_for_runs(143, force_rebuild=False)
dict_144 = load_or_build_dfs_for_runs(144, force_rebuild=False)
dict_145 = load_or_build_dfs_for_runs(145, force_rebuild=False)
dict_155 = load_or_build_dfs_for_runs(155, force_rebuild=False)

# Reference (non-MC) scenarios, loaded straight from their TxtInOut folders:
#dict_BASE_orig = load_multiple_rch_from_folders([r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_set-219\TxtInOut_1"])
#dict_BASE_recr_default_rswat = load_multiple_rch_from_folders([r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_set-219\BASE recreated from arcswat default\TxtInOut_1"])

# Show the realization keys contributed by each run, on one line.
print(*(run_dict.keys() for run_dict in (
    dict_136, dict_141, dict_108, dict_142, dict_143, dict_144,
    dict_145, dict_140, dict_117, dict_138, dict_134,
)))

dict_keys(['rch_run000136_real000492_1']) dict_keys(['rch_run000141_real000499_1']) dict_keys(['rch_run000108_real000395_1']) dict_keys(['rch_run000142_real000500_1']) dict_keys(['rch_run000143_real000501_1']) dict_keys(['rch_run000144_real000502_1']) dict_keys(['rch_run000145_real000503_1']) dict_keys(['rch_run000140_real000497_1', 'rch_run000140_real000498_2']) dict_keys(['rch_run000117_real000404_1']) dict_keys(['rch_run000138_real000495_1']) dict_keys(['rch_run000134_real000490_1'])


In [13]:
#dict_BASE_recr_default["rch_BASE recreated from arcswat default"]

In [14]:
#print(compare_dfs(dict_144['rch_run000144_real000502_1'], dict_BASE_recr_default_recr_python['rch_BASE recreated from arcswat default']))

#print(compare_dfs(dict_91["rch_run000091_real000364_1"], dict_92["rch_run000092_real000366_1"]))
#print(compare_dfs(dict_91["rch_run000091_real000365_2"], dict_92["rch_run000092_real000367_2"]))

#print(compare_dfs(dict_107["rch_run000107_real000394_2"], dict_107["rch_run000107_real000393_1"]))
#print(compare_dfs(dict_109["rch_run000109_real000396_1"], dict_BASE_POINT_recreated["rch_TxtInOut_1_work"]))


#dfs_compare_different_point_load_runs = load_multiple_rch_from_folders([r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_POINT_set-219\TxtInOut_1_work\TxtInOut", r"C:\SWAT\RSWAT\cubillas\mc_results\run000109_real000396_1"])


In [15]:

# TxtInOut folders of the reference (non-Monte-Carlo) SWAT scenarios.
OLD_base_folders = [
    r"C:\Users\Usuario\OneDrive - UNIVERSIDAD DE HUELVA\Archivos de Cesar Ruben Fernandez De Villaran San Juan - swat_cubillas\cubillas_hru\Scenarios\Default\TxtInOut",
    r"C:\Users\Usuario\OneDrive - UNIVERSIDAD DE HUELVA\Archivos de Cesar Ruben Fernandez De Villaran San Juan - swat_cubillas\cubillas_hru\Scenarios\cubillas_original\TxtInOut",
    # Earlier candidates:
    #r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_set-219\TxtInOut_1",
    #r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_POINT_set-219\TxtInOut_1",
    #r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_DIFFUSE_set-219\TxtInOut_1"
]

# The reference scenarios followed by every realization folder of MC run 77.
folders_to_compara = [*OLD_base_folders, *find_run_folders(77)]
print(folders_to_compara)
#dfs_base_vs_point = load_multiple_rch_from_folders(OLD_base_folders)
#print(dfs_base_vs_point.keys())


['C:\\Users\\Usuario\\OneDrive - UNIVERSIDAD DE HUELVA\\Archivos de Cesar Ruben Fernandez De Villaran San Juan - swat_cubillas\\cubillas_hru\\Scenarios\\Default\\TxtInOut', 'C:\\Users\\Usuario\\OneDrive - UNIVERSIDAD DE HUELVA\\Archivos de Cesar Ruben Fernandez De Villaran San Juan - swat_cubillas\\cubillas_hru\\Scenarios\\cubillas_original\\TxtInOut']


In [16]:
import numpy as np

def tag_flow_outliers(
    df_flow: pd.DataFrame,
    *,
    date_col: str = "date",
    flow_col: str = "water_flow_m3_d_cubillas",
    col_name: str = "outliers",
    method: str = "mad_log",      # "mad_log" | "iqr" | "quantile"
    sided: str = "upper",         # "upper" | "both"
    k: float = 3.5,               # mad_log threshold (in robust-sd)
    iqr_k: float = 1.5,           # IQR multiplier for IQR method
    upper_q: float = 0.995,       # upper quantile for "quantile" method
    lower_q: float = 0.005,       # lower quantile if sided="both"
    daily_reduce: str = "sum",    # "sum" | "mean"
    inplace: bool = False,
) -> pd.DataFrame:
    """
    Add a boolean column marking extreme-flow days.

    Flows are first reduced to one value per calendar day, the chosen detector
    is applied to that daily series, and every row of ``df_flow`` inherits the
    flag of its day. A sensible default is method="mad_log", sided="upper",
    k=3.5, which flags unusually large flows.

    Parameters
    ----------
    df_flow : DataFrame containing ``date_col`` and ``flow_col``.
    date_col, flow_col : column names; dates are parsed with ``errors="coerce"``
        and floored to the day, flows are coerced to float.
    col_name : name of the boolean column to add (overwritten if present).
    method : "mad_log" (robust z-score of log1p(flow)), "iqr" (Tukey fences)
        or "quantile" (empirical quantile cut-offs).
    sided : "upper" flags only high extremes; "both" also flags low ones.
        Any value other than "both" is treated as "upper".
    k, iqr_k, upper_q, lower_q : thresholds for the respective methods.
    daily_reduce : "mean" averages rows sharing a day; anything else sums them
        (an all-NaN day stays NaN and is never flagged).
    inplace : if True, ``df_flow`` itself receives the column and is returned.

    Returns
    -------
    pd.DataFrame
        ``df_flow`` (inplace=True) or a copy, with ``col_name`` added. The index
        and row order are preserved; rows with unparseable dates are False.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    df = df_flow if inplace else df_flow.copy()

    # Per-row day key (NaT when the date cannot be parsed) and numeric flow.
    day_keys = pd.to_datetime(df[date_col], errors="coerce").dt.floor("D")
    flows = pd.to_numeric(df[flow_col], errors="coerce").astype(float)

    # Reduce duplicates per day; NaT keys are dropped by groupby.
    per_day = flows.groupby(day_keys)
    if daily_reduce == "mean":
        s = per_day.mean()
    else:
        s = per_day.sum(min_count=1)
    s = s.sort_index()
    values = s.to_numpy(dtype=float)

    method = str(method).lower()
    sided = "both" if str(sided).lower() == "both" else "upper"
    if method not in ("mad_log", "iqr", "quantile"):
        raise ValueError("Unsupported method. Use 'mad_log', 'iqr', or 'quantile'.")

    if values.size == 0:
        # Nothing to score (percentile helpers would fail on an empty array).
        flags = np.zeros(0, dtype=bool)
    elif method == "mad_log":
        x = np.log1p(values)  # robust for heavy tails
        med = np.nanmedian(x)
        mad = np.nanmedian(np.abs(x - med))
        robust_sigma = 1.4826 * mad if mad > 0 else np.nan
        if np.isnan(robust_sigma) or robust_sigma == 0:
            # Degenerate spread (e.g. mostly identical flows): flag nothing.
            flags = np.zeros_like(x, dtype=bool)
        else:
            z = (x - med) / robust_sigma
            flags = np.abs(z) > float(k) if sided == "both" else z > float(k)
    elif method == "iqr":
        q1, q3 = np.nanpercentile(values, [25, 75])
        iqr = q3 - q1
        lo = q1 - float(iqr_k) * iqr
        hi = q3 + float(iqr_k) * iqr
        flags = (values < lo) | (values > hi) if sided == "both" else values > hi
    else:  # "quantile"
        hi = np.nanquantile(values, float(upper_q))
        if sided == "both":
            lo = np.nanquantile(values, float(lower_q))
            flags = (values < lo) | (values > hi)
        else:
            flags = values > hi

    # Map day flags back onto rows without merging: keeps the caller's index,
    # does not leave helper columns behind, and simply overwrites an existing
    # `col_name` when a frame is tagged again.
    flagged_days = s.index[np.asarray(flags, dtype=bool)]
    df[col_name] = day_keys.isin(flagged_days).to_numpy()
    return df


In [17]:
# Observed data for the dashboards below:
#  - Cubillas reservoir inflow series (SAIH E45 CSV), tagged with high-flow outliers;
#  - DMA water-quality measurements (mg/L CSV).
df_water_flow_m3_d_cubillas = pd.read_csv(
    r"C:\Users\Usuario\OneDrive - UNIVERSIDAD DE HUELVA\Granada\TrabajoFM\Genil GEO_INFO_POOL\Data Zip inicial Francisco\CHGxSAIH\Embalses\E45SAIHInflowQR_most_current_Francisco\E45SAIHInflowQR.csv",
    index_col=False,
)
DMA_cubillas_measurements_131415_mg_L = pd.read_csv(
    r"C:\Users\Usuario\OneDrive - UNIVERSIDAD DE HUELVA\Granada\TrabajoFM\Genil GEO_INFO_POOL\Input Data\Water flow\Ptos_MDA_Cubillas_131517_data_mgL.csv"
)

# Flag days whose log-flow lies more than k robust SDs above the log-median.
# NOTE: k=1 is much stricter than the function default of 3.5.
df_water_flow_m3_d_cubillas = tag_flow_outliers(
    df_water_flow_m3_d_cubillas,
    inplace=True,
    daily_reduce="sum",
    k=1,
    sided="upper",
    method="mad_log",
    col_name="outliers",
    flow_col="water_flow_m3_d_cubillas",
    date_col="date",
)


In [18]:
# Display the inflow frame after tagging (includes the boolean "outliers" column).
df_water_flow_m3_d_cubillas

Unnamed: 0,date,water_flow_m3_d_cubillas,doy,smooth,outliers
0,1955-03-02,450000.0,61,364133.333333,True
1,1955-03-03,0.0,62,352937.500000,False
2,1955-03-04,500000.0,63,348294.117647,True
3,1955-03-05,324000.0,64,343388.888889,True
4,1955-03-06,649000.0,65,343894.736842,True
...,...,...,...,...,...
25486,2024-12-27,67670.0,362,74615.600000,False
25487,2024-12-28,73120.0,363,74614.105263,False
25488,2024-12-29,78610.0,364,74256.388889,False
25489,2024-12-30,78323.0,365,73739.294118,False


In [19]:
# Inspect the DMA measurements:
#  1) total phosphorus ("FOSFORO TOTAL") samples at station 30304;
#  2) every row whose RESULTADO is the non-numeric code "LC"
#     (presumably "below quantification limit" - confirm).
# Only the last expression of a cell is auto-displayed, so the first table is
# shown explicitly; previously it was computed and silently discarded.
# Station codes are compared as strings because the CSV may parse est_estaci
# as integers, in which case comparing against "30304" never matches.
# TODO: an earlier note mentioned the date 2015-02-12; no date filter is applied.
from IPython.display import display

display(
    DMA_cubillas_measurements_131415_mg_L[
        (DMA_cubillas_measurements_131415_mg_L["NOMBRE"] == "FOSFORO TOTAL")
        & (DMA_cubillas_measurements_131415_mg_L["est_estaci"].astype(str) == "30304")
    ]
)

DMA_cubillas_measurements_131415_mg_L[(DMA_cubillas_measurements_131415_mg_L["RESULTADO"] == "LC")]

Unnamed: 0,Join_Count,TARGET_FID,est_estaci,COD_UE,COD_PUNTO_,F_MUESTREO,NOMBRE,RESULTADO,UNIDAD,COD_UE_1,...,Dep1,Lat,Long_,Elev,ElevMin,ElevMax,Bname,Shape_Leng,HydroID,OutletID
32,1,13832,30303,30303-O,GV10090004,09/10/97 0:00:00,AMONIO,LC,mg NH4/l,,...,0.400558,37.291096,-3.664485,706.037444,606.0,869.0,,40000.0,300015,100001
38,1,13838,30303,30303-O,GV10090004,03/04/98 0:00:00,AMONIO,LC,mg NH4/l,,...,0.400558,37.291096,-3.664485,706.037444,606.0,869.0,,40000.0,300015,100001
100,1,13900,30303,30303-O,GV10090004,03/11/04 0:00:00,AMONIO,LC,mg NH4/l,,...,0.400558,37.291096,-3.664485,706.037444,606.0,869.0,,40000.0,300015,100001
105,1,13905,30303,30303-O,GV10090004,08/11/04 0:00:00,AMONIO,LC,mg NH4/l,,...,0.400558,37.291096,-3.664485,706.037444,606.0,869.0,,40000.0,300015,100001
106,1,13906,30303,30303-O,GV10090004,09/15/04 0:00:00,AMONIO,LC,mg NH4/l,,...,0.400558,37.291096,-3.664485,706.037444,606.0,869.0,,40000.0,300015,100001
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3087,1,16887,30304,30304,GV10090007,12/01/99 0:00:00,NITROGENO KJELDAHL,LC,mg N/l,,...,0.343950,37.311018,-3.639073,749.117173,637.0,938.0,,28300.0,300013,100005
3088,1,16888,30304,30304,GV10090007,03/14/00 0:00:00,NITROGENO KJELDAHL,LC,mg N/l,,...,0.343950,37.311018,-3.639073,749.117173,637.0,938.0,,28300.0,300013,100005
3089,1,16889,30304,30304,GV10090007,06/06/00 0:00:00,NITROGENO KJELDAHL,LC,mg N/l,,...,0.343950,37.311018,-3.639073,749.117173,637.0,938.0,,28300.0,300013,100005
3090,1,16890,30304,30304,GV10090007,09/06/00 0:00:00,NITROGENO KJELDAHL,LC,mg N/l,,...,0.343950,37.311018,-3.639073,749.117173,637.0,938.0,,28300.0,300013,100005


In [20]:
#variable_for_manual_loaded_rchs_which_needs_lots_of_time_to_load = load_multiple_rch_from_folders(OLD_base_folders) # + [r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_set-219\BASE recreated from arcswat default\TxtInOut_1", ]),#r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_POINT_set-219\TxtInOut_1",]),
    #extra_dfs = load_or_build_dfs_for_runs([150], force_rebuild=False),

In [21]:
#dict_BASE_recr_default_rswat = load_multiple_rch_from_folders([r"C:\SWAT\RSWAT\cubillas\cubillas_set_219_ruben\cubillas_BASE_set-219\BASE recreated from arcswat default\TxtInOut_1"])

In [22]:
#dict_BASE_recr_default_rswat.keys()

In [23]:
#dict_BASE_recr_default_rswat.keys()
#dict_145.keys()

# Reference runs overlaid on the Monte Carlo fan (passed as extra_dfs below).
# Disabled alternative entry:
#   "219 recreated RSWAT": dict_BASE_recr_default_rswat['rch_BASE recreated from arcswat default']
comparing_BASEs = dict(BASE=dict_145["rch_run000145_real000503_1"])

In [24]:
#variable_for_manual_loaded_rchs_which_needs_lots_of_time_to_load.keys()

In [52]:
import sys
from importlib import reload

# Make the pipeline package importable. Any earlier copy of the entry is removed
# first, so re-running this cell keeps exactly one entry, still at the front.
_PIPELINE_ROOT = r"c:\Users\Usuario\OneDrive - UNIVERSIDAD DE HUELVA\Granada\TrabajoFM\scripts\Python_Pipeline_SWAT_Pascal\swat_pipeline\trabajoFM"
while _PIPELINE_ROOT in sys.path:
    sys.path.remove(_PIPELINE_ROOT)
sys.path.insert(0, _PIPELINE_ROOT)
from python_pipeline_scripts.provenance_report import summarize_run

# 2) Import module as alias
import python_pipeline_scripts.dashboard as dash
import python_pipeline_scripts.stats as stats

monte_carlo_run_numbmer = 157

# Initial UI state for the dashboard. Each key appears exactly once.
# Previously "flow_on" (False, then True) and "event_threshold" ("p50", then
# "p75") were listed twice; a dict literal keeps only the last value, so the
# effective values True / "p75" are kept here.
defaults = {
    "variable": "KJELDAHL_OUTkg", # IMPORTANT: must match an item in your variables list
    "reach": 13,
    "freq": "D",
    "bin": 1,
    "compare_mode": "load", # "conc" for mg/L, "load" for kg/day
    "method": "mean",
    "autoscale_y_live": True,
    "show_names_in_tooltip": False,
    "show_diags": True,
    "lag_metric": "NSE",
    "max_lag": 2,
    "local_Ks": [1, 2],
    "log_metrics": False,
    "measured_on": True,
    "flow_on": True,
    #"cats": {
    #    1: {
    #        "enabled": True,
    #        "chem": "FOSFORO TOTAL", # Edit to your measured chem name as it appears in measured_df["NOMBRE"]
    #        "stations": ["30304"] # Edit to your station codes (strings)
    #    },
    #    2: {"enabled": False},
    #    3: {"enabled": False}
    #},
    "extra_visible": {
    # "Baseline": False,
    },
    "debug": True,
    "exclude_flow_outliers": False,
    "outlier_buffer_days": 3,
    "swat_flow_on": False,
    "flag_deviations": False,
    "erosion_on": False,
    "cb_ldc_sediment": False,
    "meas_negative_policy": "drop",
    # NOTE(review): the option list below does not include "half_MDL"; confirm the accepted values.
    "meas_nonnum_policy": "half_MDL", # "zero" or "half_detection_limit" or "drop"
    "event_view": "all", # "all" or "events" or "non_events"
    "event_source": "swat_avg", # "external" or "swat_avg"
    "event_min_days": 2, # minimum event duration in days
    "event_threshold": "p75", # event-detection threshold given as a percentile label (e.g. "p50", "p75")
    "event_buffer_days": 2, # days to add before/after event
    "ldc_log_scale": True,
    "flow_strat_curve": True,
}

reload(stats)
stats.SHIFT_AGG = 'both'  # or 'mean', "median", "both"
reload(dash)
dash.fan_compare_simulations_dashboard(
    load_or_build_dfs_for_runs([monte_carlo_run_numbmer], force_rebuild=False),  # dashboard_mc_dfs,
    vars_to_compare,
    measured_df=DMA_cubillas_measurements_131415_mg_L,
    water_flow_df=df_water_flow_m3_d_cubillas,
    measured_var_map=swat_to_measured,
    reach=13,
    start="1986-01-01",
    end="2001-02-25",
    freq_options=("D", "W", "M", "A"),
    max_bin_size=30,
    #extra_dfs = variable_for_manual_loaded_rchs_which_needs_lots_of_time_to_load,
    #extra_dfs = load_or_build_dfs_for_runs([145], force_rebuild=False), #dict_BASE_recr_default_rswat
    extra_dfs=comparing_BASEs,
    how_map_defaults=how_map_defaults,
    ui_defaults=defaults,
)

print(summarize_run(monte_carlo_run_numbmer))



[dash] Init dashboard {'reach': 13, 'start': '1986-01-01', 'end': '2001-02-25', 'season_months': None}
[dash] variables {'n': 14, 'has_KJELDAHL': True}
Detected water flow series in water_flow_df - will also map water flow.


HBox(children=(VBox(children=(HTML(value='Number of initializations: 2 - run 157'), Dropdown(description='Vari…



Output()

HBox(children=(Dropdown(description='Lag by:', index=1, layout=Layout(width='140px'), options=('r', 'NSE'), va…

HBox(layout=Layout(justify_content='flex-start', width='100%'))

HBox(children=(HTML(value='', layout=Layout(border_bottom='1px solid #ddd', border_left='1px solid #ddd', bord…

[dash] measured_var_map {}
[dash] compute {'var': 'KJELDAHL_OUTkg', 'reach': 13, 'freq': '1D', 'method': 'mean', 'mode': 'load'}
[info] median sampling interval = 1.000000 days (24.000 h)
[proc] 'main': threshold resolved = 135432.0 (units same as 'Q'), etmin = 2.0 days -> min_samples = 2
[result] 'main': runs found = 51, kept (>= 2 samples) = 47
[info] median sampling interval = 1.000000 days (24.000 h)
[proc] 'main': threshold resolved = 135432.0 (units same as 'Q'), etmin = 2.0 days -> min_samples = 2
[result] 'main': runs found = 51, kept (>= 2 samples) = 47
[dash] events {'mode': 'all', 'events': 1381, 'buffer': 243, 'keep': 5535}
[dash] rch_run000157_real000529_1 raw reach=13: n=14610, date=[1981-01-01 00:00:00..2020-12-31 00:00:00]
[dash] rch_run000157_real000529_1 after time slice: n=5535
[dash] rch_run000157_real000529_1 after event-mode filter: n=5535
[dash] rch_run000157_real000529_1 series after resample: n=5535, idx=[1986-01-01 00:00:00..2001-02-25 00:00:00]
[dash] rch_run

[dash] duration context {'has_q_plot': True, 'q_plot_shape': (5535, 10), 'swat_flow_valid': False, 'ext_flow_valid': True}
[dash] duration q_plot cols ['min', 'p05', 'p10', 'p25', 'p50', 'p60', 'p75', 'p90', 'p95', 'max']
[dash] duration measured overlay {'n_points': 20}
[dash] flow_strat:aligned_df_plot_info {'valid': True, 'shape': (5535, 2)}
[dash] flow_strat:event_ctx_indices {'has_events': True, 'n_events': 1624, 'has_non_events': True, 'n_non_events': 3917}
Run 157: realizations=2 ids=[529, 530]
Time span: 2025-09-20T10:12:08.913842+00:00 → 2025-09-20T10:12:15.190879+00:00
Names:
  - run000157_real000529_1
  - run000157_real000530_2
Transforms:
  - transform_init_copy
  - fn
  - ops_choices
  - split_choices
  - transform_interpolate_years_wide
  - point_mgL_choices
Inputs (union):
  - C:\Users\Usuario\OneDrive - UNIVERSIDAD DE HUELVA\Archivos de Cesar Ruben Fernandez De Villaran San Juan - swat_cubillas\cubillas_hru\Watershed\Shapes\hru1.shp
  - C:\Users\Usuario\OneDrive - UNIVE

[dash] duration context {'has_q_plot': True, 'q_plot_shape': (5535, 10), 'swat_flow_valid': False, 'ext_flow_valid': True}
[dash] duration q_plot cols ['min', 'p05', 'p10', 'p25', 'p50', 'p60', 'p75', 'p90', 'p95', 'max']
[dash] duration measured overlay {'n_points': 12}
[dash] duration measured overlay {'n_points': 12}
[dash] flow_strat:aligned_df_plot_info {'valid': True, 'shape': (5535, 2)}
[dash] flow_strat:event_ctx_indices {'has_events': True, 'n_events': 1624, 'has_non_events': True, 'n_non_events': 3917}
[dash] flow_strat:aligned_df_plot_info {'valid': True, 'shape': (5535, 2)}
[dash] flow_strat:event_ctx_indices {'has_events': True, 'n_events': 1624, 'has_non_events': True, 'n_non_events': 3917}


In [26]:
""" monte_carlo_run_numbmer = 129
dashboard_mc_dfs = load_or_build_dfs_for_runs([monte_carlo_run_numbmer], force_rebuild=False)


defaults2 = {
    "variable": "TOT_Nkg", # IMPORTANT: must match an item in your variables list
    "reach": 13,
    "freq": "D",
    "bin": 1,
    "compare_mode": "conc", # "conc" for mg/L, "load" for kg/day
    "method": "mean",
    "autoscale_y_live": True,
    "show_names_in_tooltip": False,
    "show_diags": True,
    "lag_metric": "NSE",
    "max_lag": 2,
    "local_Ks": [1, 2],
    "log_metrics": False,
    "measured_on": True,
    "flow_on": False,
    #"cats": {
    #    1: {
    #        "enabled": True,
    #        "chem": "FOSFORO TOTAL", # Edit to your measured chem name as it appears in measured_df["NOMBRE"]
    #        "stations": ["30304"] # Edit to your station codes (strings)
    #    },
    #    2: {"enabled": False},
    #    3: {"enabled": False}
    #},
    "extra_visible": {
    # "Baseline": False,
    },
    "debug": True
}



reload(dash)
dash.fan_compare_simulations_dashboard(
    dashboard_mc_dfs,
    vars_to_compare,
    measured_df=DMA_cubillas_measurements_131415_mg_L,
    measured_var_map=swat_to_measured,
    reach=13,
    freq_options=("D","W","M","A"),
    max_bin_size=30,
    #extra_dfs = dfs_base_vs_point, 
    #extra_dfs = load_or_build_dfs_for_runs([118], force_rebuild=False),
    water_flow_df = df_water_flow_m3_d_cubillas,
    how_map_defaults=how_map_defaults,
    ui_defaults=defaults2
)

print(summarize_run(monte_carlo_run_numbmer)) """

' monte_carlo_run_numbmer = 129\ndashboard_mc_dfs = load_or_build_dfs_for_runs([monte_carlo_run_numbmer], force_rebuild=False)\n\n\ndefaults2 = {\n    "variable": "TOT_Nkg", # IMPORTANT: must match an item in your variables list\n    "reach": 13,\n    "freq": "D",\n    "bin": 1,\n    "compare_mode": "conc", # "conc" for mg/L, "load" for kg/day\n    "method": "mean",\n    "autoscale_y_live": True,\n    "show_names_in_tooltip": False,\n    "show_diags": True,\n    "lag_metric": "NSE",\n    "max_lag": 2,\n    "local_Ks": [1, 2],\n    "log_metrics": False,\n    "measured_on": True,\n    "flow_on": False,\n    #"cats": {\n    #    1: {\n    #        "enabled": True,\n    #        "chem": "FOSFORO TOTAL", # Edit to your measured chem name as it appears in measured_df["NOMBRE"]\n    #        "stations": ["30304"] # Edit to your station codes (strings)\n    #    },\n    #    2: {"enabled": False},\n    #    3: {"enabled": False}\n    #},\n    "extra_visible": {\n    # "Baseline": False,\n    }

# Evaluate sedimentation effects

In [27]:
import pandas as pd
import numpy as np

def _to_datetime_if_needed(s):
    """Convert ``s`` to pandas datetimes via ``pd.to_datetime``.

    Kept as a single hook: if dates ever arrive as Julian days / SWAT ``MON``
    codes, replace this with a proper calendar conversion.
    """
    return pd.to_datetime(s)


def compute_sed_timeseries(
    df_rch: pd.DataFrame,
    df_sub: pd.DataFrame,
    reach: int,
    method: str = "A",           # "A", "B", "C1", "C2"
    dr=1.0,                      # delivery ratio: scalar, or mapping {reach: dr}, or Series by reach or by (date, reach)
    cols_rch=dict(date="date", reach="RCH", sed_in="SED_INtons", sed_out="SED_OUTtons",
                  flow_in="FLOW_INcms", flow_out="FLOW_OUTcms", area_km2="AREA"),
    cols_sub=dict(date="date", sub="SUB", area_km2="AREA", syld_t_ha="SYLD", wyld_mm="WYLD"),
    drop_feb29=True
) -> pd.Series:
    """
    Return a date-indexed Series of the chosen sediment metric for one reach.

      Method A:  Δ_channel = (OUT - IN) - DR * SYLD_tons
      Method B:  Retained  = (IN + DR * SYLD_tons) - OUT   (= -Δ_channel)
      Method C1: Δ_channel per discharge, concentration-like (mg/L)
      Method C2: Δ_channel per runoff (kg/mm/ha)

    Parameters
    ----------
    df_rch, df_sub : reach (.rch) and subbasin (.sub) tables.
    reach : reach id; the subbasin with the same id supplies SYLD/WYLD/area.
    method : one of "A", "B", "C1", "C2" (case-insensitive).
    dr : delivery ratio. A scalar, a mapping {reach: dr}, a Series indexed by
        reach, or a Series with a (date, reach) MultiIndex. Any other type
        falls back to 1.0.
    cols_rch : column-name mapping for ``df_rch``. Keys used: date, reach,
        sed_in, sed_out, flow_out (needed only by C1). The key "RCH" is also
        accepted as an alias for "reach". The flow_in and area_km2 keys are
        accepted but unused.
    cols_sub : column-name mapping for ``df_sub`` (date, sub, area_km2 [km²],
        syld_t_ha [t/ha], wyld_mm [mm]).
    drop_feb29 : drop Feb 29 rows.

    Notes
    -----
      - SYLD is tons/ha in .sub; it is converted to tons using the subbasin area (ha).
      - If WYLD is zero on a day and method='C2', the result is NaN for that day.
      - If FLOW_OUT is zero on a day and method='C1', the result is NaN for that day.

    Raises
    ------
    ValueError : unknown ``method``.
    KeyError : a required column is missing (C1 also needs the outflow column).
    """
    from collections.abc import Mapping

    method_key = method.upper()
    if method_key not in ("A", "B", "C1", "C2"):
        raise ValueError("method must be one of 'A', 'B', 'C1', 'C2'")

    # The reach column used to be looked up under the key "RCH", which the
    # default mapping never defined, so every call raised KeyError.
    # Prefer "reach" and keep "RCH" as an alias for callers that supplied it.
    reach_col = cols_rch.get("reach", cols_rch.get("RCH", "RCH"))
    r = df_rch.rename(columns={
        cols_rch["date"]: "date",
        reach_col: "RCH",
        cols_rch["sed_in"]: "SED_IN",
        cols_rch["sed_out"]: "SED_OUTtons",
        cols_rch.get("flow_out", "FLOW_OUTcms"): "FLOW_OUTcms",
    })
    s = df_sub.rename(columns={
        cols_sub["date"]: "date",
        cols_sub["sub"]: "SUB",
        cols_sub["area_km2"]: "SUB_AREA_KM2",
        cols_sub["syld_t_ha"]: "SYLD_T_HA",
        cols_sub["wyld_mm"]: "WYLD_MM",
    })

    r["date"] = _to_datetime_if_needed(r["date"])
    s["date"] = _to_datetime_if_needed(s["date"])

    # Only require the reach columns that some method actually uses.
    rch_cols = ["date", "RCH", "SED_IN", "SED_OUTtons"]
    if "FLOW_OUTcms" in r.columns:
        rch_cols.append("FLOW_OUTcms")
    elif method_key == "C1":
        raise KeyError("method 'C1' needs the reach outflow column (cols_rch['flow_out'])")

    r = r.loc[r["RCH"] == reach, rch_cols]
    s = s.loc[s["SUB"] == reach, ["date", "SUB", "SUB_AREA_KM2", "SYLD_T_HA", "WYLD_MM"]]

    # Both sides hold a single reach/subbasin, so joining on date is enough.
    m = pd.merge(r, s, on="date", how="inner")

    # Delivery ratio
    if np.isscalar(dr):
        m["DR"] = float(dr)
    elif isinstance(dr, pd.Series):
        if dr.index.nlevels == 1:
            m["DR"] = m["RCH"].map(dr).astype(float)
        else:
            keyed = m.set_index(["date", "RCH"])
            m["DR"] = dr.reindex(keyed.index).to_numpy(dtype=float)
    elif isinstance(dr, Mapping):
        m["DR"] = m["RCH"].map(dr).astype(float)
    else:
        m["DR"] = 1.0

    # Convert SYLD from tons/ha to tons per time-step (using SUB area)
    m["SUB_AREA_HA"] = m["SUB_AREA_KM2"] * 100.0
    m["SYLD_TONS"] = m["SYLD_T_HA"] * m["SUB_AREA_HA"]

    # Core channel terms
    m["delta_raw"] = m["SED_OUTtons"] - m["SED_IN"]
    m["local_delivered"] = m["DR"] * m["SYLD_TONS"]
    m["delta_channel"] = m["delta_raw"] - m["local_delivered"]
    m["retained"] = (m["SED_IN"] + m["local_delivered"]) - m["SED_OUTtons"]  # = -delta_channel

    if drop_feb29:
        is_feb29 = (m["date"].dt.month == 2) & (m["date"].dt.day == 29)
        m = m.loc[~is_feb29]

    m = m.sort_values("date").set_index("date")

    if method_key == "A":
        return m["delta_channel"].rename(f"Δ_channel_tons_R{reach}")  # tons per time-step
    if method_key == "B":
        return m["retained"].rename(f"Retained_tons_R{reach}")  # positive = deposition
    if method_key == "C1":
        # tons/day -> mg/L: 1e9 mg/ton / (Q m3/s * 86400 s/day * 1000 L/m3) = 11.574074 / Q
        q = m["FLOW_OUTcms"].replace(0, np.nan)
        return ((m["delta_channel"] * 11.574074) / q).rename(f"Δ_channel_mgL_R{reach}")
    # C2: kg per mm of runoff per ha (SUB WYLD and area)
    denom = (m["WYLD_MM"] * m["SUB_AREA_HA"]).replace(0, np.nan)  # mm * ha
    return ((m["delta_channel"] * 1000.0) / denom).rename(f"Δ_channel_kg_per_mm_ha_R{reach}")


In [28]:
# Sediment Dynamics Dashboard (methods A/B/C) — Jupyter + Plotly + ipywidgets
# ----------------------------------------------------------------------------
# Methods (one time series per reach):
#  A: Δ_channel = (SED_OUT − SED_IN) − DR * SYLD_tons
#  B: Retained  = (SED_IN + DR * SYLD_tons) − SED_OUT  ( = −Δ_channel )
#  C1: Hydrology-normalized Δ_channel per discharge  [mg/L]  (uses FLOW_OUT)
#  C2: Hydrology-normalized Δ_channel per runoff     [kg/mm/ha] (uses WYLD + area)
#
# Notes:
# - .rch must have: date, RCH, SED_INtons, SED_OUTtons (optional: FLOW_OUT)
# - .sub must have: date, SUB, AREA (km²), SYLD (tons/ha), WYLD (mm)
# - Delivery ratio DR is user-controlled (scalar) in the UI.

import warnings
# NOTE(review): filterwarnings("ignore") with no category silences *all*
# warnings in this kernel session from here on, not just in the dashboard.
# That includes pandas deprecation / SettingWithCopy warnings from other cells.
warnings.filterwarnings("ignore")  # stop showing warnings in dashboard

import pandas as pd
import numpy as np
import plotly.graph_objects as go
import ipywidgets as W
from IPython.display import display

# ---------------- Utility helpers ----------------
def to_datetime(s):
    """Parse ``s`` into pandas datetime values (delegates to ``pd.to_datetime``)."""
    parsed = pd.to_datetime(s)
    return parsed

def agg_key(name: str) -> str:
    """Translate a UI label into a pandas aggregation name.

    "Mean" maps to "mean"; every other label maps to "median". A string alias
    is returned rather than a numpy callable, which avoids pandas' FutureWarning.
    """
    if name == "Mean":
        return "mean"
    return "median"

def rolling_apply(x, win, func):
    """NaN-aware trailing rolling mean/median of ``x``.

    ``win`` of None or <= 1 returns ``x`` as an array unchanged. Otherwise each
    window needs at least ``max(1, win // 2)`` non-missing values. ``func`` ==
    "mean" selects np.nanmean; anything else selects np.nanmedian.
    """
    if win is None or win <= 1:
        return np.asarray(x)
    reducer = np.nanmean if func == "mean" else np.nanmedian
    window = pd.Series(x).rolling(win, min_periods=max(1, win // 2))
    # raw=True hands the reducer plain ndarrays, the same values the Series path produced.
    return window.apply(reducer, raw=True).values

def quantiles(a, qs):
    """Return ``{q: value}`` for each quantile in ``qs``, ignoring NaNs.

    When ``a`` has no non-NaN values every quantile maps to NaN. Keys are the
    ``qs`` items themselves (e.g. 0.1, 0.9).
    """
    values = np.asarray(a, dtype=float)
    valid = values[~np.isnan(values)]
    if not valid.size:
        return dict.fromkeys(qs, np.nan)
    return dict(zip(qs, map(float, np.quantile(valid, qs))))

def band_fill(fig, x, high, low, name, opacity=0.2):
    """Add a shaded band between ``low`` and ``high`` to ``fig``.

    Two traces are added: an invisible upper edge (no legend, no hover),
    followed by the lower edge filled "tonexty" up to it.
    """
    upper_edge = go.Scatter(x=x, y=high, mode="lines", line=dict(width=0),
                            showlegend=False, hoverinfo="skip")
    lower_edge = go.Scatter(x=x, y=low, mode="lines", line=dict(width=0),
                            fill="tonexty", name=name, opacity=opacity)
    for trace in (upper_edge, lower_edge):
        fig.add_trace(trace)

# ------------- Method engine (build metric per row) -------------
def build_metric_table(
    df_rch: pd.DataFrame,
    df_sub: pd.DataFrame,
    drop_feb29: bool,
    method: str,
    DR: float,
    cols_rch=dict(date="date", reach="RCH", sed_in="SED_INtons", sed_out="SED_OUTtons", flow_out="FLOW_OUTcms"),
    cols_sub=dict(date="date", sub="SUB", area_km2="AREA", syld_t_ha="SYLD", wyld_mm="WYLD"),
) -> pd.DataFrame:
    """
    Build the per-row sediment metric table used by the dashboard.

    Rows of ``df_rch`` and ``df_sub`` are joined on (date, reach == subbasin).
    The metric is then computed for ``method``:
      A: (SED_OUT - SED_IN) - DR * SYLD_tons         [tons/step]
      B: (SED_IN + DR * SYLD_tons) - SED_OUT          [tons/step, + = deposition]
      C1: A per discharge, tons/day -> mg/L via FLOW_OUT (m3/s)
      C2: A per runoff, kg / (mm * ha) via WYLD and subbasin area

    Parameters
    ----------
    df_rch, df_sub : reach (.rch) and subbasin (.sub) tables.
    drop_feb29 : drop Feb 29 rows. In leap years the later days are then
        renumbered, so each calendar day has the same ``doy`` every year
        (e.g. Mar 1 -> 60).
    method : "A", "B", "C1" or "C2" (case-insensitive).
    DR : scalar delivery ratio applied to SYLD_tons.
    cols_rch, cols_sub : column-name mappings. The reach outflow column
        (cols_rch["flow_out"]) is optional except for method C1.

    Returns
    -------
    DataFrame with columns date, RCH, metric, year, doy. ``attrs["ylab"]`` holds
    a y-axis label for the chosen method.

    Raises
    ------
    ValueError : unknown method.
    KeyError : a required column is missing (C1 without an outflow column included).
    """
    method = method.upper()
    if method not in ("A", "B", "C1", "C2"):
        raise ValueError("method must be one of 'A', 'B', 'C1', 'C2'")

    r = df_rch.rename(columns={
        cols_rch["date"]: "date", cols_rch["reach"]: "RCH",
        cols_rch["sed_in"]: "SED_IN", cols_rch["sed_out"]: "SED_OUT",
        cols_rch.get("flow_out", "FLOW_OUTcms"): "FLOW_OUTcms",
    }).copy()
    s = df_sub.rename(columns={
        cols_sub["date"]: "date", cols_sub["sub"]: "SUB",
        cols_sub["area_km2"]: "SUB_AREA_KM2",
        cols_sub["syld_t_ha"]: "SYLD_T_HA",
        cols_sub["wyld_mm"]: "WYLD_MM",
    }).copy()

    r["date"] = pd.to_datetime(r["date"])
    s["date"] = pd.to_datetime(s["date"])

    # FLOW_OUT is only needed by C1, so it is not required for the other methods.
    rch_cols = ["date", "RCH", "SED_IN", "SED_OUT"]
    if "FLOW_OUTcms" in r.columns:
        rch_cols.append("FLOW_OUTcms")
    elif method == "C1":
        raise KeyError("method 'C1' needs the reach outflow column (cols_rch['flow_out'])")

    # merge by (date, reach/sub)
    m = pd.merge(
        r[rch_cols],
        s[["date", "SUB", "SUB_AREA_KM2", "SYLD_T_HA", "WYLD_MM"]],
        left_on=["date", "RCH"], right_on=["date", "SUB"], how="inner",
    ).drop(columns=["SUB"])

    # Convert SYLD tons/ha -> tons (per step)
    m["SUB_AREA_HA"] = m["SUB_AREA_KM2"] * 100.0
    m["SYLD_TONS"] = m["SYLD_T_HA"] * m["SUB_AREA_HA"]

    # core deltas
    m["delta_raw"] = m["SED_OUT"] - m["SED_IN"]
    m["local_delivered"] = float(DR) * m["SYLD_TONS"]
    m["delta_channel"] = m["delta_raw"] - m["local_delivered"]              # Method A
    m["retained"] = (m["SED_IN"] + m["local_delivered"]) - m["SED_OUT"]    # Method B

    if method == "A":
        metric = m["delta_channel"]
        ylab = "Δ_channel (tons)"
    elif method == "B":
        metric = m["retained"]
        ylab = "Retained (tons)"
    elif method == "C1":
        q = m["FLOW_OUTcms"].replace(0, np.nan)
        metric = (m["delta_channel"] * 11.574074) / q   # tons/day → mg/L approx
        ylab = "Δ_channel per discharge (mg/L)"
    else:  # "C2"
        denom = (m["WYLD_MM"] * m["SUB_AREA_HA"]).replace(0, np.nan)  # mm * ha
        metric = (m["delta_channel"] * 1000.0) / denom                # tons → kg
        ylab = "Δ_channel per runoff (kg/mm/ha)"

    out = m[["date", "RCH"]].copy()
    out["metric"] = metric.to_numpy()
    out["year"] = out["date"].dt.year
    if drop_feb29:
        is_feb29 = (out["date"].dt.month == 2) & (out["date"].dt.day == 29)
        out = out.loc[~is_feb29].copy()

    doy = out["date"].dt.dayofyear
    if drop_feb29:
        # With Feb 29 removed, shift later days of leap years back by one;
        # otherwise Mar 1 would be 61 in leap years but 60 in common years.
        after_feb_in_leap = out["date"].dt.is_leap_year & (out["date"].dt.month > 2)
        doy = doy - after_feb_in_leap.astype(int)
    out["doy"] = doy

    out = out.reset_index(drop=True)
    out.attrs["ylab"] = ylab  # set after reset_index so it is on the returned frame
    return out

# ------------- Aggregators (operate on 'metric') -------------
def compose_time_aggregate(subm: pd.DataFrame, reaches, years, agg_name: str, stats, roll_win: int):
    """Date-indexed center line of ``metric`` with optional across-reach bands.

    Parameters
    ----------
    subm : table with columns date, RCH, year, metric (e.g. from build_metric_table).
    reaches : reach ids to include.
    years : years to include; a falsy value keeps all years.
    agg_name : "Mean" for the mean, anything else for the median (see agg_key).
    stats : band names to compute; any of "p10p90", "p25p75", "minmax".
        Bands are only computed when more than one reach is selected.
    roll_win : rolling-window length passed to rolling_apply (<= 1 disables).

    Returns
    -------
    DataFrame sorted by date. It has ``date`` and ``center`` plus, when
    requested and available, p10/p90, p25/p75 and vmin/vmax columns.
    """
    s = subm[subm["RCH"].isin(reaches)]
    if years:
        s = s[s["year"].isin(years)]

    multi_reach = len(reaches) > 1
    how = agg_key(agg_name)

    # Center line per date: combine reaches with the chosen aggregator; for a
    # single reach just average any duplicate rows.
    grouped = (
        s.groupby("date", as_index=False)["metric"]
        .agg(how if multi_reach else "mean")
        .rename(columns={"metric": "center"})
    )

    # Bands: spread of the per-reach means at each date (multi-reach only).
    if multi_reach:
        tmp = s.groupby(["date", "RCH"])["metric"].agg("mean").reset_index()
        rows = []
        for d, g in tmp.groupby("date"):
            vals = g["metric"].values
            row = {"date": d}
            if "p10p90" in stats:
                # quantiles() keys results by the float quantile; rename to the
                # column names the plotting code looks for.
                q = quantiles(vals, [0.10, 0.90])
                row.update({"p10": q[0.10], "p90": q[0.90]})
            if "p25p75" in stats:
                q = quantiles(vals, [0.25, 0.75])
                row.update({"p25": q[0.25], "p75": q[0.75]})
            if "minmax" in stats and len(vals):
                row.update({"vmin": float(np.nanmin(vals)), "vmax": float(np.nanmax(vals))})
            rows.append(row)
        if rows:  # an empty selection yields no rows (and a frame without "date")
            grouped = grouped.merge(pd.DataFrame(rows), on="date", how="left")

    # rolling
    grouped = grouped.sort_values("date")
    grouped["center"] = rolling_apply(grouped["center"].values, roll_win, how)
    for col in ["p10", "p90", "p25", "p75", "vmin", "vmax"]:
        if col in grouped.columns:
            grouped[col] = rolling_apply(grouped[col].values, roll_win, "mean")
    return grouped

def compose_climatology(subm: pd.DataFrame, reaches, years, agg_name: str, stats, roll_win: int):
    """Day-of-year climatology of ``metric`` with optional spread bands.

    Parameters
    ----------
    subm : table with columns RCH, year, doy, metric (e.g. from build_metric_table).
    reaches : reach ids to include.
    years : years to include; a falsy value keeps all years.
    agg_name : "Mean" for the mean, anything else for the median (see agg_key).
    stats : band names to compute; any of "p10p90", "p25p75", "minmax".
        Bands describe the spread across (year x reach) samples at each DOY.
    roll_win : rolling-window length applied along DOY via rolling_apply.

    Returns
    -------
    DataFrame sorted by doy. It has ``center`` plus, when requested and
    available, p10/p90, p25/p75 and vmin/vmax columns.
    """
    s = subm[subm["RCH"].isin(reaches)]
    if years:
        s = s[s["year"].isin(years)]

    how = agg_key(agg_name)
    center = s.groupby("doy")["metric"].agg(how).rename("center").reset_index()

    # bands across (year×reach) samples per DOY
    tmp = s.groupby(["doy", "year", "RCH"])["metric"].agg("mean").reset_index()
    rows = []
    for d, g in tmp.groupby("doy"):
        vals = g["metric"].values
        row = {"doy": d}
        if "p10p90" in stats:
            # quantiles() keys results by the float quantile; rename to the
            # column names the plotting code looks for.
            q = quantiles(vals, [0.10, 0.90])
            row.update({"p10": q[0.10], "p90": q[0.90]})
        if "p25p75" in stats:
            q = quantiles(vals, [0.25, 0.75])
            row.update({"p25": q[0.25], "p75": q[0.75]})
        if "minmax" in stats and len(vals):
            row.update({"vmin": float(np.nanmin(vals)), "vmax": float(np.nanmax(vals))})
        rows.append(row)

    out = center
    if rows:  # an empty selection yields no rows (and a frame without "doy")
        out = out.merge(pd.DataFrame(rows), on="doy", how="left")
    out = out.sort_values("doy")

    # rolling over DOY (simple)
    out["center"] = rolling_apply(out["center"].values, roll_win, how)
    for col in ["p10", "p90", "p25", "p75", "vmin", "vmax"]:
        if col in out.columns:
            out[col] = rolling_apply(out[col].values, roll_win, "mean")
    return out

def compose_all_years_overlay(subm: pd.DataFrame, reaches, agg_name: str, roll_win: int):
    """One smoothed daily ``metric`` line per year, keyed by year.

    With several reaches, each date's values are combined with the aggregator
    chosen by ``agg_name``; with one reach they are averaged. Each year's line
    is then smoothed with rolling_apply using the ``agg_name`` aggregator.
    """
    in_reaches = subm[subm["RCH"].isin(reaches)]
    per_date_how = agg_key(agg_name) if len(reaches) > 1 else "mean"
    smooth_how = agg_key(agg_name)

    overlay = {}
    for yr in sorted(in_reaches["year"].unique()):
        one_year = in_reaches[in_reaches["year"] == yr]
        per_day = (
            one_year.groupby("date")["metric"].agg(per_date_how)
            .reset_index()
            .sort_values("date")
        )
        per_day["metric"] = rolling_apply(per_day["metric"].values, roll_win, smooth_how)
        overlay[yr] = per_day
    return overlay

def compose_year_reach(subm: pd.DataFrame, reaches, years):
    """Mean ``metric`` per (year, reach), returned in a column named ``delta``.

    ``years`` filters the rows when truthy; otherwise all years are used.
    """
    keep = subm["RCH"].isin(reaches)
    if years:
        keep &= subm["year"].isin(years)
    return (
        subm.loc[keep]
        .groupby(["year", "RCH"], as_index=False)["metric"].mean()
        .rename(columns={"metric": "delta"})
    )

def compose_across_years_all_reaches(subm: pd.DataFrame, reaches, agg_name: str, roll_win: int):
    """Pooled DOY profile of ``metric`` over all years (and the given reaches).

    A falsy ``reaches`` means every reach. Returns columns doy and center, where
    center is the DOY aggregate smoothed with rolling_apply.
    """
    how = agg_key(agg_name)
    pool = subm.copy() if not reaches else subm[subm["RCH"].isin(reaches)]
    by_doy = pool.groupby("doy")["metric"].agg(how).reset_index().sort_values("doy")
    by_doy["center"] = rolling_apply(by_doy["metric"].values, roll_win, how)
    return by_doy[["doy", "center"]]

def plot_time_series(df_stats, title, ylab, bands=("p10p90","p25p75","minmax"),
                     flow_stats=None, y2lab="Flow (m³/d)"):
    """Plot a dated center line with shaded spread bands and an optional flow axis.

    Parameters
    ----------
    df_stats : DataFrame with ``date`` and ``center`` columns and optionally
        vmin/vmax, p10/p90, p25/p75 (as produced by compose_time_aggregate).
    title, ylab : figure title and primary y-axis label.
    bands : band names to draw; any of "minmax", "p10p90", "p25p75". A band is
        skipped when either of its columns is missing from ``df_stats``.
    flow_stats : optional DataFrame with ``date`` and ``center``. When it is
        non-empty it is drawn as a dotted line on a secondary y-axis.
    y2lab : secondary y-axis label.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # This cell's top-level imports do not provide make_subplots, which the
    # flow overlay needs for its secondary axis.
    from plotly.subplots import make_subplots

    has_flow = flow_stats is not None and len(flow_stats) > 0
    fig = make_subplots(specs=[[{"secondary_y": True}]]) if has_flow else go.Figure()
    # A plain go.Figure has no secondary axis, so pass secondary_y=None there.
    primary_axis = False if has_flow else None
    x = df_stats["date"]

    # Drawn widest-first. Each band is an invisible upper edge followed by the
    # lower edge filled "tonexty" up to it.
    band_specs = (
        ("minmax", "vmin", "vmax", "min–max", 0.15),
        ("p10p90", "p10", "p90", "p10–p90", 0.20),
        ("p25p75", "p25", "p75", "p25–p75", 0.30),
    )
    for band, lo_col, hi_col, label, alpha in band_specs:
        if band not in bands or not {lo_col, hi_col}.issubset(df_stats.columns):
            continue
        fig.add_trace(go.Scatter(x=x, y=df_stats[hi_col], mode="lines", line=dict(width=0),
                                 showlegend=False, hoverinfo="skip"),
                      secondary_y=primary_axis)
        fig.add_trace(go.Scatter(x=x, y=df_stats[lo_col], mode="lines", line=dict(width=0),
                                 fill="tonexty", name=label, opacity=alpha),
                      secondary_y=primary_axis)

    fig.add_trace(go.Scatter(x=x, y=df_stats["center"], mode="lines", name="Center"),
                  secondary_y=primary_axis)

    if has_flow:
        fig.add_trace(go.Scatter(x=flow_stats["date"], y=flow_stats["center"], mode="lines",
                                 name="Flow", line=dict(dash="dot")), secondary_y=True)
        fig.update_yaxes(title_text=y2lab, secondary_y=True)

    fig.add_hline(y=0, line_dash="dash", opacity=0.5, secondary_y=primary_axis)
    fig.update_layout(title=title, xaxis_title="Date", yaxis_title=ylab, legend_title=None,
                      margin=dict(l=40, r=20, t=50, b=40))
    return fig

def plot_climatology(df_stats, title, ylab, bands=("p10p90","p25p75","minmax"),
                     flow_stats=None, y2lab="Flow (m³/d)"):
    """
    Plot a day-of-year climatology: optional shaded bands, the center line and a y=0 line.

    NOTE(review): an identical ``plot_climatology`` is defined again in a later cell of this
    notebook; after Run All that later definition is the one in effect.

    Parameters
    ----------
    df_stats : DataFrame with ``doy`` and ``center``; band edges are read from
        ``vmin``/``vmax``, ``p10``/``p90`` and ``p25``/``p75`` when present.
    title, ylab : figure title and primary y-axis title.
    bands : band keys to draw ("minmax", "p10p90", "p25p75"); a band is skipped when its
        key is not requested or either edge column is missing.
    flow_stats : optional DataFrame with ``doy``/``center`` drawn dotted on a secondary
        y-axis; ``None`` or empty means a plain single-axis figure.
    y2lab : secondary y-axis title.

    Returns
    -------
    plotly Figure.
    """
    has_flow = flow_stats is not None and len(flow_stats)
    if has_flow:
        fig = make_subplots(specs=[[{"secondary_y": has_flow}]])
    else:
        fig = go.Figure()
    # with a secondary axis, metric traces target the primary one; otherwise no axis hint
    primary_axis = False if has_flow else None
    doy = df_stats["doy"]

    # (band key, upper edge column, lower edge column, legend label, opacity),
    # drawn in the order min–max, p10–p90, p25–p75
    band_specs = (
        ("minmax", "vmax", "vmin", "min–max", 0.15),
        ("p10p90", "p90", "p10", "p10–p90", 0.20),
        ("p25p75", "p75", "p25", "p25–p75", 0.30),
    )
    for key, upper, lower, label, alpha in band_specs:
        if key not in bands or not {lower, upper}.issubset(df_stats.columns):
            continue
        # invisible upper edge first; the lower edge then fills "tonexty" back up to it
        fig.add_trace(go.Scatter(x=doy, y=df_stats[upper], mode="lines", line=dict(width=0),
                                 showlegend=False, hoverinfo="skip"), secondary_y=primary_axis)
        fig.add_trace(go.Scatter(x=doy, y=df_stats[lower], mode="lines", line=dict(width=0),
                                 fill="tonexty", name=label, opacity=alpha), secondary_y=primary_axis)

    fig.add_trace(go.Scatter(x=doy, y=df_stats["center"], mode="lines", name="Center"),
                  secondary_y=primary_axis)

    if has_flow:
        flow_trace = go.Scatter(x=flow_stats["doy"], y=flow_stats["center"], mode="lines",
                                name="Flow", line=dict(dash="dot"))
        fig.add_trace(flow_trace, secondary_y=True)
        fig.update_yaxes(title_text=y2lab, secondary_y=True)

    fig.add_hline(y=0, line_dash="dash", opacity=0.5, secondary_y=primary_axis)
    fig.update_layout(title=title, xaxis_title="Day of Year (1–365)", yaxis_title=ylab, legend_title=None,
                      margin=dict(l=40, r=20, t=50, b=40))
    return fig

def plot_all_years_overlay(lines_dict, title, ylab, flow_lines=None, y2lab="Flow (m³/d)"):
    """
    Overlay one line per year, optionally with per-year flow lines on a secondary y-axis.

    NOTE(review): an identical ``plot_all_years_overlay`` is defined again in a later cell of
    this notebook; after Run All that later definition is the one in effect.

    Parameters
    ----------
    lines_dict : mapping year -> DataFrame with ``date`` and ``metric`` columns.
    title, ylab : figure title and primary y-axis title.
    flow_lines : optional mapping year -> DataFrame with ``date``/``metric``; only a
        non-empty dict enables the secondary axis.
    y2lab : secondary y-axis title.

    Returns
    -------
    plotly Figure.
    """
    has_flow = isinstance(flow_lines, dict) and len(flow_lines)
    fig = go.Figure() if not has_flow else make_subplots(specs=[[{"secondary_y": has_flow}]])
    primary_axis = None if not has_flow else False

    for year, series in lines_dict.items():
        fig.add_trace(go.Scatter(x=series["date"], y=series["metric"], mode="lines", name=str(year)),
                      secondary_y=primary_axis)

    if has_flow:
        for year, series in flow_lines.items():
            fig.add_trace(go.Scatter(x=series["date"], y=series["metric"], mode="lines",
                                     name=f"Flow {year}", line=dict(dash="dot")), secondary_y=True)
        fig.update_yaxes(title_text=y2lab, secondary_y=True)

    fig.add_hline(y=0, line_dash="dash", opacity=0.5, secondary_y=primary_axis)
    fig.update_layout(title=title, xaxis_title="Date", yaxis_title=ylab, legend_title="Year",
                      margin=dict(l=40, r=20, t=50, b=40))
    return fig



# ------------- UI: per-panel with method toggle -------------
class ChartPanel:
    """
    Interactive ipywidgets panel for the channel-balance metric.

    The user picks the method, the delivery ratio, reaches, years, the aggregate, the rolling
    window, the Feb-29 handling, the bands and a plot type. Every change re-renders a Plotly
    figure into ``self.out``.

    NOTE(review): ``class ChartPanel`` is defined again later in this notebook (with a
    constituent selector), so this version is shadowed once that cell has run.
    NOTE(review): ``__init__`` and ``render`` pass ``self.DR.value`` as the 5th positional
    argument of ``build_metric_table``. In the ``build_metric_table`` defined later in this
    notebook, that slot is ``constituent``. Confirm which signature this panel targets.
    """
    # plot types offered in the "Plot" dropdown; render() dispatches on these exact strings
    PLOT_TYPES = [
        "Time series",
        "DOY climatology",
        "All years overlay",
        "Year × Reach bars",
        "Across-years avg per reach (DOY)",
        "Across-years avg (all reaches, DOY)",
    ]
    # (label shown in the "Bands" selector, internal band key used by compose/plot helpers)
    BAND_OPTIONS = [("p10–p90","p10p90"), ("p25–p75","p25p75"), ("min–max","minmax")]
    # (label shown in the "Method" dropdown, method code forwarded to build_metric_table)
    METHOD_OPTIONS = [
                        ("Channel net balance — (OUT−IN) − DR×SYLD (tons)", "A"),
                        ("Retained in reach — deposition = IN + DR×SYLD − OUT (tons)", "B"),
                        ("Hydrology-normalized — Δ_channel per discharge (mg/L)", "C1"),
                        ("Runoff-normalized — Δ_channel per runoff (kg/mm/ha)", "C2"),
     ]

    def __init__(self, df_rch, df_sub,df_flow=None):
        """
        Store the input tables, build the widgets, wire them to ``render`` and draw once.

        Parameters
        ----------
        df_rch, df_sub : reach / subbasin tables forwarded unchanged to ``build_metric_table``.
        df_flow : optional flow table; stored as ``self.df_flow_raw`` but not read anywhere
            in this class.
        """
        self.df_rch = df_rch
        self.df_sub = df_sub
        self.df_flow_raw = df_flow

        # defaults
        self.method = W.Dropdown(options=self.METHOD_OPTIONS, value="A", description="Method")
        self.DR = W.FloatSlider(value=1.0, min=0.0, max=1.0, step=0.05, description="DR", readout_format=".2f", continuous_update=False)

        # we build an initial metric table to populate reach/year lists
        # (drop_feb29=True here regardless of the later "Drop Feb 29" checkbox value)
        base = build_metric_table(self.df_rch, self.df_sub, True, self.method.value, self.DR.value)
        self.all_reaches = sorted(base["RCH"].unique().tolist())
        self.all_years   = sorted(base["year"].unique().tolist())

        # initial selection: first three reaches, all years
        self.plot_type = W.Dropdown(options=self.PLOT_TYPES, value="Time series", description="Plot")
        self.reach_sel = W.SelectMultiple(options=self.all_reaches, value=tuple(self.all_reaches[:3]), description="Reaches", rows=7)
        self.year_sel  = W.SelectMultiple(options=self.all_years, value=tuple(self.all_years), description="Years", rows=7)
        self.agg = W.ToggleButtons(options=["Mean","Median"], value="Mean", description="Aggregate")
        self.roll = W.IntSlider(value=7, min=1, max=60, step=1, description="Rolling (d)", continuous_update=False)
        self.drop_feb = W.Checkbox(value=True, description="Drop Feb 29")
        self.band_sel = W.SelectMultiple(options=[lbl for lbl,_ in self.BAND_OPTIONS],
                                         value=("p10–p90","p25–p75"), rows=4, description="Bands")
        self.out = W.Output()

        # Wire: any "value" change on these controls triggers a full re-render
        for w in [self.method, self.DR, self.plot_type, self.reach_sel, self.year_sel, self.agg, self.roll, self.drop_feb, self.band_sel]:
            w.observe(self.render, names="value")

        # Initial render
        self.render(None)

    def _band_keys(self):
        """
        Translate the selected band labels into internal band keys (e.g. "p10–p90" -> "p10p90").

        The labels are iterated from a set, so the tuple order is not fixed; render() and the
        helpers it calls only test membership in it.
        """
        labels = set(self.band_sel.value)
        return tuple(dict(self.BAND_OPTIONS)[lbl] for lbl in labels if lbl in dict(self.BAND_OPTIONS))

    def widget(self):
        """Return the panel layout: controls row, selectors row, then the output area."""
        return W.VBox([
            W.HBox([self.plot_type, self.method, self.DR, self.agg, self.roll]),
            W.HBox([self.reach_sel, self.year_sel, W.VBox([self.drop_feb, self.band_sel])]),
            self.out
        ])

    def render(self, _):
        """
        Rebuild the metric table from the current widget values and redraw into ``self.out``.

        Used as the observer callback for every control (the argument is the change event and
        is ignored). It is also called once from ``__init__`` with ``None``.
        """
        with self.out:
            self.out.clear_output(wait=True)

            # Build table of the chosen method (contains y-label in attrs)
            subm = build_metric_table(self.df_rch, self.df_sub, self.drop_feb.value, self.method.value, self.DR.value)
            ylab = subm.attrs.get("ylab", "Metric")
            # an empty selection falls back to every reach / year present in the table
            reaches = list(self.reach_sel.value) or sorted(subm["RCH"].unique())
            years   = list(self.year_sel.value)  or sorted(subm["year"].unique())
            agg_name = self.agg.value
            roll = int(self.roll.value)
            bands = self._band_keys()

            if self.plot_type.value == "Time series":
                stats = set(bands)
                data = compose_time_aggregate(subm, reaches, years, agg_name, stats, roll)
                fig = plot_time_series(data, "Time series — aggregated over selected reaches", ylab, bands)

            elif self.plot_type.value == "DOY climatology":
                data = compose_climatology(subm, reaches, years, agg_name, set(bands), roll)
                fig = plot_climatology(data, "DOY climatology — selected reaches & years", ylab, bands)

            elif self.plot_type.value == "All years overlay":
                # note: the year selection is not applied in this branch
                lines = compose_all_years_overlay(subm, reaches, agg_name, roll)
                fig = plot_all_years_overlay(lines, "All years overlay (one line per year)", ylab)

            elif self.plot_type.value == "Year × Reach bars":
                # plot_year_reach_bars is not defined in this cell; presumably elsewhere in the notebook
                df_yr = compose_year_reach(subm, reaches, years)
                fig = plot_year_reach_bars(df_yr, "Year × Reach mean (center metric)", ylab)

            elif self.plot_type.value == "Across-years avg per reach (DOY)":
                # one smoothed DOY profile per selected reach
                fig = go.Figure()
                s = subm[subm["RCH"].isin(reaches)]
                if years: s = s[s["year"].isin(years)]
                for r in sorted(set(reaches)):
                    rsub = s[s["RCH"]==r]
                    line = rsub.groupby("doy")["metric"].agg(agg_key(agg_name)).reset_index().sort_values("doy")
                    line["metric"] = rolling_apply(line["metric"].values, roll, agg_key(agg_name))
                    fig.add_trace(go.Scatter(x=line["doy"], y=line["metric"], mode="lines", name=f"R{r}"))
                fig.add_hline(y=0, line_dash="dash", opacity=0.5)
                fig.update_layout(title=f"Across-years DOY average per reach ({agg_name})",
                                  xaxis_title="Day of Year (1–365)", yaxis_title=ylab,
                                  legend_title=None, margin=dict(l=40,r=20,t=50,b=40))

            elif self.plot_type.value == "Across-years avg (all reaches, DOY)":
                line = compose_across_years_all_reaches(subm[subm["year"].isin(years)], reaches, agg_name, roll)
                fig = go.Figure([go.Scatter(x=line["doy"], y=line["center"], mode="lines", name=f"{agg_name}")])
                fig.add_hline(y=0, line_dash="dash", opacity=0.5)
                fig.update_layout(title=f"Across-years DOY average (all selected reaches) — {agg_name}",
                                  xaxis_title="Day of Year (1–365)", yaxis_title=ylab,
                                  margin=dict(l=40,r=20,t=50,b=40))
            else:
                fig = go.Figure()

            fig.show()


In [None]:
from plotly.subplots import make_subplots  # <-- add at top with other imports

# ---- Flow helpers (Cubillas) ----
def prep_flow(df_flow: pd.DataFrame, drop_feb29: bool,
              flow_col: str = "water_flow_m3_d_cubillas", date_col: str = "date") -> pd.DataFrame:
    """
    Normalize a daily observed-flow table to the columns ``date, year, doy, flow``.

    Parameters
    ----------
    df_flow : raw flow table with a date column and a flow column; not modified.
    drop_feb29 : when True, drop 29 February rows and renumber ``doy`` in leap years so
        that every year runs 1–365 (1 March is then always doy 60, 31 December doy 365).
    flow_col : name of the flow column in ``df_flow`` (default: the Cubillas inflow series).
    date_col : name of the date column in ``df_flow``.

    Returns
    -------
    New DataFrame with columns date, year, doy, flow, sorted by date, with a fresh RangeIndex.
    """
    df = df_flow.rename(columns={date_col: "date", flow_col: "flow"})
    # pd.to_datetime directly: the notebook's to_datetime wrapper lives in a later cell
    dates = pd.to_datetime(df["date"])
    if drop_feb29:
        keep = ~((dates.dt.month == 2) & (dates.dt.day == 29))
        df, dates = df.loc[keep], dates.loc[keep]
    doy = dates.dt.dayofyear
    if drop_feb29:
        # Removing Feb 29 alone would leave leap-year dates after February one doy ahead
        # (and Dec 31 at 366); shift them back so doy lines up across years.
        after_leap_day = dates.dt.is_leap_year & (dates.dt.month > 2)
        doy = doy.where(~after_leap_day, doy - 1)
    # assign() returns a new frame, so no write goes through a filtered slice
    out = df.assign(date=dates, year=dates.dt.year, doy=doy)
    return out[["date", "year", "doy", "flow"]].sort_values("date").reset_index(drop=True)

def compose_time_aggregate_flow(df_flow_prep: pd.DataFrame, years, agg_name: str, roll_win: int):
    """
    Daily flow series (column ``center``) for the selected years, rolling-smoothed.

    ``agg_name`` ("Mean"/"Median") selects both the per-date reducer and the rolling reducer.
    An empty/None ``years`` keeps every year. Returns a new DataFrame sorted by date.
    """
    reducer = agg_key(agg_name)
    selected = df_flow_prep[df_flow_prep["year"].isin(years)] if years else df_flow_prep
    daily = (selected.groupby("date", as_index=False)["flow"]
                     .agg(reducer)
                     .rename(columns={"flow": "center"})
                     .sort_values("date"))
    daily["center"] = rolling_apply(daily["center"].values, roll_win, reducer)
    return daily

def compose_climatology_flow(df_flow_prep: pd.DataFrame, years, agg_name: str, roll_win: int):
    """
    Day-of-year flow climatology (column ``center``) over the selected years, rolling-smoothed.

    ``agg_name`` ("Mean"/"Median") selects both the per-doy reducer and the rolling reducer.
    An empty/None ``years`` keeps every year. Returns a new DataFrame sorted by doy.
    """
    reducer = agg_key(agg_name)
    selected = df_flow_prep[df_flow_prep["year"].isin(years)] if years else df_flow_prep
    profile = (selected.groupby("doy", as_index=False)["flow"]
                       .agg(reducer)
                       .rename(columns={"flow": "center"})
                       .sort_values("doy"))
    profile["center"] = rolling_apply(profile["center"].values, roll_win, reducer)
    return profile

def compose_all_years_overlay_flow(df_flow_prep: pd.DataFrame, agg_name: str, roll_win: int):
    """
    Split prepared flow into one smoothed daily line per year.

    Returns a dict ``year -> DataFrame(date, metric)`` with years in ascending order. The flow
    column is renamed to ``metric`` so the result matches the metric overlay lines.
    """
    reducer = agg_key(agg_name)
    lines = {}
    for year in sorted(df_flow_prep["year"].unique()):
        one_year = (df_flow_prep.loc[df_flow_prep["year"] == year, ["date", "flow"]]
                    .rename(columns={"flow": "metric"})
                    .groupby("date", as_index=False)["metric"].agg(reducer)
                    .sort_values("date"))
        one_year["metric"] = rolling_apply(one_year["metric"].values, roll_win, reducer)
        lines[year] = one_year
    return lines


In [None]:
# Sediment / TP / TN Dynamics Dashboard — methods A/B/C + Constituent toggle
# ----------------------------------------------------------------------------
# Methods (one time series per reach):
#  A: Δ_channel = (OUT − IN) − DR * local_yield
#  B: Retained  = (IN + DR * local_yield) − OUT  (= −Δ_channel)
#  C1: Hydrology-normalized Δ_channel per discharge [mg/L]  (uses FLOW_OUTcms)
#  C2: Hydrology-normalized Δ_channel per runoff    [kg/mm/ha] (uses WYLD + area)
#
# Constituents:
#  - Sediment (tons): uses SED_INtons / SED_OUTtons; local_yield = SYLD (tons/ha) × area_ha × DR_sed
#  - Total Phosphorus (kg): IN/OUT = ORGP + MINP (each summed separately for IN and OUT); local_yield = (ORGP+SEDP)*DR_partic + SOLP*DR_diss (all kg/ha × area_ha)
#  - Total Nitrogen (kg):   IN/OUT = ORGN + NO3 + NH4 + NO2 (each summed separately for IN and OUT); local_yield = ORGN*DR_partic + NSURQ*DR_diss (kg/ha × area_ha)

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import plotly.graph_objects as go
import ipywidgets as W
from IPython.display import display

# --------------- Helpers ---------------
def to_datetime(s):
    """Thin wrapper: convert ``s`` with ``pd.to_datetime`` using pandas' default options."""
    converted = pd.to_datetime(s)
    return converted

def agg_key(name: str) -> str:
    """Map the UI aggregate label to a pandas reducer name: "Mean" -> "mean", anything else -> "median"."""
    if name == "Mean":
        return "mean"
    return "median"

def rolling_apply(x, win, reducer_key, center: bool = False):
    """
    NaN-aware rolling mean or median over a 1-D sequence.

    Parameters
    ----------
    x : 1-D array-like of numbers.
    win : window length; ``None`` or ``<= 1`` disables smoothing.
    reducer_key : "mean" uses ``np.nanmean``; any other value uses ``np.nanmedian``.
    center : centre the window on each point instead of trailing it (default False,
        the original trailing behaviour).

    Returns
    -------
    numpy array of the same length as ``x``. When smoothing is disabled, ``np.asarray(x)``
    is returned unchanged. A window needs at least ``max(1, win // 2)`` non-NaN values,
    otherwise the result is NaN.
    """
    if win is None or win <= 1:
        return np.asarray(x)
    reducer = np.nanmean if reducer_key == "mean" else np.nanmedian
    # raw=True hands each window to the reducer as a plain ndarray instead of building a
    # pandas Series per window: same values, much less overhead on long daily series.
    return (pd.Series(x)
            .rolling(win, min_periods=max(1, win // 2), center=center)
            .apply(reducer, raw=True)
            .values)

def quantiles(a, qs):
    """
    NaN-ignoring quantiles of ``a``, returned as a dict keyed by the requested probability.

    Every key of ``qs`` maps to NaN when ``a`` has no non-NaN values; otherwise the values
    are Python floats from ``np.quantile`` (linear interpolation).
    """
    values = np.asarray(a, dtype=float)
    valid = values[~np.isnan(values)]
    if not valid.size:
        return {q: np.nan for q in qs}
    return dict(zip(qs, map(float, np.quantile(valid, qs))))

def band_fill(fig, x, high, low, name, opacity=0.2):
    """
    Shade the area between ``low`` and ``high`` on ``fig`` as a single legend entry ``name``.

    Adds two line-less traces: first the upper edge (hidden from legend and hover), then the
    lower edge with ``fill="tonexty"``, so Plotly fills back up to the previous trace.
    Returns None; ``fig`` is modified in place.
    """
    upper_edge = go.Scatter(x=x, y=high, mode="lines", line=dict(width=0),
                            showlegend=False, hoverinfo="skip")
    lower_edge = go.Scatter(x=x, y=low, mode="lines", line=dict(width=0),
                            fill="tonexty", name=name, opacity=opacity)
    fig.add_trace(upper_edge)
    fig.add_trace(lower_edge)

def _sum_cols(df, cols):
    cols = [c for c in cols if c in df.columns]
    if not cols: return pd.Series(0.0, index=df.index)
    return df[cols].fillna(0).sum(axis=1)

# ------------- Method engine (by constituent) -------------
def build_metric_table(
    df_rch: pd.DataFrame,
    df_sub: pd.DataFrame,
    drop_feb29: bool,
    method: str,
    constituent: str,         # "SED" | "TP" | "TN"
    DR_sed: float = 1.0,      # for sediment
    DR_partic: float = 0.7,   # for particulate (ORGN, ORGP, SEDP)
    DR_diss: float = 0.95,    # for dissolved (NO3, NH4, NO2, SOLP, NSURQ)
    # column mappings (keep your names); these defaults are only read, never mutated
    cols_rch=dict(
        date="date", reach="RCH",
        sed_in="SED_INtons", sed_out="SED_OUTtons", flow_out="FLOW_OUTcms",
        # P (kg)
        ORGP_IN="ORGP_INkg", ORGP_OUT="ORGP_OUTkg",
        MINP_IN="MINP_INkg", MINP_OUT="MINP_OUTkg",
        # N (kg)
        ORGN_IN="ORGN_INkg", ORGN_OUT="ORGN_OUTkg",
        NO3_IN="NO3_INkg",  NO3_OUT="NO3_OUTkg",
        NH4_IN="NH4_INkg",  NH4_OUT="NH4_OUTkg",
        NO2_IN="NO2_INkg",  NO2_OUT="NO2_OUTkg",
    ),
    cols_sub=dict(
        date="date", sub="SUB", area_km2="AREA", wyld_mm="WYLD",
        syld_t_ha="SYLD",       # sediment (tons/ha)
        ORGP="ORGP", SOLP="SOLP", SEDP="SEDP",  # P (kg/ha)
        ORGN="ORGN", NSURQ="NSURQ",             # N (kg/ha)
    ),
) -> pd.DataFrame:
    """
    Per-reach daily channel-balance metric for one constituent and one method.

    Reach rows are matched to the subbasin with the same id on the same date
    (inner join on date and RCH == SUB).

    delta_channel = (OUT - IN) - local, where local is the delivery-ratio-weighted
    subbasin yield × area.

    Parameters
    ----------
    df_rch : reach output table; source column names mapped through ``cols_rch``.
    df_sub : subbasin output table; source column names mapped through ``cols_sub``.
        Only the subbasin columns the chosen constituent/method uses must be present.
    drop_feb29 : drop 29 February rows and renumber ``doy`` in leap years so that every
        year spans 1–365 (1 March is always doy 60).
    method : "A" (delta_channel), "B" (retained = -delta_channel), "C1" (per discharge,
        mg/L, uses FLOW_OUTcms) or "C2" (per runoff, kg/mm/ha, uses WYLD × area).
        Case-insensitive.
    constituent : "SED" (tons), "TP" or "TN" (kg). Case-insensitive.
    DR_sed, DR_partic, DR_diss : delivery ratios applied to the local subbasin yield.
    cols_rch, cols_sub : source-column name mappings.

    Returns
    -------
    DataFrame with columns date, RCH, metric, year, doy and a fresh RangeIndex.
    ``attrs['ylab']`` holds the axis label; ``attrs['unit']`` is "tons" or "kg".

    Raises
    ------
    ValueError : unknown ``method`` or ``constituent``.
    KeyError : a column needed by the chosen constituent/method is missing.
    """
    method_key = str(method).upper()
    if method_key not in ("A", "B", "C1", "C2"):
        raise ValueError("method must be one of 'A','B','C1','C2'")
    # validate explicitly: previously anything other than "SED"/"TP" silently became TN
    constituent_key = str(constituent).upper()
    if constituent_key not in ("SED", "TP", "TN"):
        raise ValueError("constituent must be one of 'SED','TP','TN'")

    # ---- rename minimally
    r = df_rch.rename(columns={
        cols_rch["date"]:"date", cols_rch["reach"]:"RCH",
        cols_rch["sed_in"]:"SED_IN", cols_rch["sed_out"]:"SED_OUT", cols_rch["flow_out"]:"FLOW_OUTcms",
        # P
        cols_rch.get("ORGP_IN","ORGP_INkg"):"ORGP_INkg", cols_rch.get("ORGP_OUT","ORGP_OUTkg"):"ORGP_OUTkg",
        cols_rch.get("MINP_IN","MINP_INkg"):"MINP_INkg", cols_rch.get("MINP_OUT","MINP_OUTkg"):"MINP_OUTkg",
        # N
        cols_rch.get("ORGN_IN","ORGN_INkg"):"ORGN_INkg", cols_rch.get("ORGN_OUT","ORGN_OUTkg"):"ORGN_OUTkg",
        cols_rch.get("NO3_IN","NO3_INkg"):"NO3_INkg",   cols_rch.get("NO3_OUT","NO3_OUTkg"):"NO3_OUTkg",
        cols_rch.get("NH4_IN","NH4_INkg"):"NH4_INkg",   cols_rch.get("NH4_OUT","NH4_OUTkg"):"NH4_OUTkg",
        cols_rch.get("NO2_IN","NO2_INkg"):"NO2_INkg",   cols_rch.get("NO2_OUT","NO2_OUTkg"):"NO2_OUTkg",
    }).copy()
    s = df_sub.rename(columns={
        cols_sub["date"]:"date", cols_sub["sub"]:"SUB", cols_sub["area_km2"]:"SUB_AREA_KM2",
        cols_sub["wyld_mm"]:"WYLD_MM", cols_sub["syld_t_ha"]:"SYLD_T_HA",
        cols_sub.get("ORGP","ORGP"):"ORGP", cols_sub.get("SOLP","SOLP"):"SOLP", cols_sub.get("SEDP","SEDP"):"SEDP",
        cols_sub.get("ORGN","ORGN"):"ORGN", cols_sub.get("NSURQ","NSURQ"):"NSURQ",
    }).copy()
    r["date"] = pd.to_datetime(r["date"])
    s["date"] = pd.to_datetime(s["date"])

    # reach columns per constituent; absent ones are skipped here and fail (KeyError) on use
    rch_cols = ["date", "RCH", "FLOW_OUTcms"] + {
        "SED": ["SED_IN", "SED_OUT"],
        "TP": ["ORGP_INkg", "MINP_INkg", "ORGP_OUTkg", "MINP_OUTkg"],
        "TN": ["ORGN_INkg", "NO3_INkg", "NH4_INkg", "NO2_INkg",
               "ORGN_OUTkg", "NO3_OUTkg", "NH4_OUTkg", "NO2_OUTkg"],
    }[constituent_key]
    r_use = r[[c for c in rch_cols if c in r.columns]]

    # keep whichever subbasin columns exist, so e.g. a sediment run does not need N/P columns
    sub_cols = ["date", "SUB", "SUB_AREA_KM2", "WYLD_MM", "SYLD_T_HA",
                "ORGP", "SOLP", "SEDP", "ORGN", "NSURQ"]
    s_use = s[[c for c in sub_cols if c in s.columns]]

    # merge reach & sub
    m = pd.merge(r_use, s_use, left_on=["date","RCH"], right_on=["date","SUB"], how="inner").drop(columns=["SUB"])
    # area conversion (km² -> ha)
    m["SUB_AREA_HA"] = m["SUB_AREA_KM2"] * 100.0

    # ---- IN / OUT and local delivered load for the constituent
    if constituent_key == "SED":
        x_in = m["SED_IN"].astype(float)
        x_out = m["SED_OUT"].astype(float)
        local = DR_sed * (m["SYLD_T_HA"].astype(float) * m["SUB_AREA_HA"])  # tons
        unit, tag = "tons", ""
    elif constituent_key == "TP":
        x_in = _sum_cols(m, ["ORGP_INkg", "MINP_INkg"]).astype(float)
        x_out = _sum_cols(m, ["ORGP_OUTkg", "MINP_OUTkg"]).astype(float)
        # particulate (ORGP+SEDP) × DR_partic + dissolved (SOLP) × DR_diss, kg/ha × ha
        local = ((m["ORGP"].fillna(0) + m["SEDP"].fillna(0)) * DR_partic +
                  m["SOLP"].fillna(0) * DR_diss) * m["SUB_AREA_HA"]
        unit, tag = "kg", " TP"
    else:  # "TN"
        x_in = _sum_cols(m, ["ORGN_INkg", "NO3_INkg", "NH4_INkg", "NO2_INkg"]).astype(float)
        x_out = _sum_cols(m, ["ORGN_OUTkg", "NO3_OUTkg", "NH4_OUTkg", "NO2_OUTkg"]).astype(float)
        # ORGN × DR_partic + NSURQ × DR_diss, kg/ha × ha
        local = (m["ORGN"].fillna(0) * DR_partic + m["NSURQ"].fillna(0) * DR_diss) * m["SUB_AREA_HA"]
        unit, tag = "kg", " TN"

    delta_channel = (x_out - x_in) - local

    # ---- method transform (shared by all constituents; labels differ only by tag/unit)
    if method_key == "A":
        metric, ylab = delta_channel, f"Δ_channel{tag} ({unit})"
    elif method_key == "B":
        metric, ylab = -delta_channel, f"Retained{tag} ({unit})"
    elif method_key == "C1":
        q = m["FLOW_OUTcms"].replace(0, np.nan).astype(float)
        # load per day / (m³/s) -> mg/L: 1 t/d over Q m³/s = 1e9 mg / (Q·8.64e7 L) = 11.574/Q
        to_mg_per_l = 11.574074 if unit == "tons" else 0.011574074
        metric = delta_channel * to_mg_per_l / q
        ylab = f"Δ_channel{tag} per discharge (mg/L)"
    else:  # "C2"
        denom = (m["WYLD_MM"] * m["SUB_AREA_HA"]).replace(0, np.nan)  # mm · ha
        delta_kg = delta_channel * 1000.0 if unit == "tons" else delta_channel
        metric = delta_kg / denom  # kg/(mm·ha)
        ylab = f"Δ_channel{tag} per runoff (kg/mm/ha)"

    # ---- calendar columns
    out = m[["date", "RCH"]].copy()
    out["metric"] = metric.values
    dates = pd.to_datetime(out["date"])
    if drop_feb29:
        keep = ~((dates.dt.month == 2) & (dates.dt.day == 29))
        out, dates = out.loc[keep], dates.loc[keep]
    doy = dates.dt.dayofyear
    if drop_feb29:
        # Dropping Feb 29 alone leaves leap-year dates after February one doy ahead (and
        # Dec 31 at 366); shift them back so DOY climatologies align across years.
        after_leap_day = dates.dt.is_leap_year & (dates.dt.month > 2)
        doy = doy.where(~after_leap_day, doy - 1)
    out = out.assign(year=dates.dt.year, doy=doy).reset_index(drop=True)
    # attrs are set on the final frame so they cannot be lost by intermediate operations
    out.attrs["ylab"] = ylab
    out.attrs["unit"] = unit
    return out

# ------------- Aggregators (re-used) -------------
def compose_time_aggregate(subm: pd.DataFrame, reaches, years, agg_name: str, stats, roll_win: int):
    """
    Daily ``center`` series over the selected reaches/years, plus optional spread bands.

    Parameters
    ----------
    subm : metric table with ``date``, ``RCH``, ``year`` and ``metric`` columns.
    reaches : reach ids to keep.
    years : years to keep; empty/None keeps every year.
    agg_name : "Mean" or "Median". Combines reaches per date when more than one reach is
        selected (a single reach uses a plain mean) and is the rolling reducer for ``center``.
    stats : collection of band keys ("p10p90", "p25p75", "minmax"). Bands describe the spread
        across reaches on each date, so they are only computed when more than one reach is selected.
    roll_win : rolling window length in days (<= 1 disables smoothing).

    Returns
    -------
    DataFrame sorted by date with ``date``, ``center`` and any requested band columns among
    ``p10``/``p90``, ``p25``/``p75``, ``vmin``/``vmax``. Band columns are smoothed with a
    rolling mean.
    """
    reducer = agg_key(agg_name)
    s = subm[subm["RCH"].isin(reaches)]
    if years:
        s = s[s["year"].isin(years)]
    multi_reach = len(reaches) > 1

    grouped = (s.groupby("date", as_index=False)["metric"]
                .agg(reducer if multi_reach else "mean")
                .rename(columns={"metric": "center"}))

    if multi_reach:
        per_reach = s.groupby(["date", "RCH"])["metric"].agg("mean").reset_index()
        rows = []
        for d, g in per_reach.groupby("date"):
            vals = g["metric"].values
            row = {"date": d}
            if "p10p90" in stats:
                # quantiles() keys by probability (0.1/0.9); the plotters look for "p10"/"p90"
                q = quantiles(vals, [0.10, 0.90])
                row.update({"p10": q[0.10], "p90": q[0.90]})
            if "p25p75" in stats:
                q = quantiles(vals, [0.25, 0.75])
                row.update({"p25": q[0.25], "p75": q[0.75]})
            if "minmax" in stats and len(vals):
                row.update({"vmin": float(np.nanmin(vals)), "vmax": float(np.nanmax(vals))})
            rows.append(row)
        if rows:  # empty selection: a column-less frame would make merge raise KeyError
            grouped = grouped.merge(pd.DataFrame(rows), on="date", how="left")

    grouped = grouped.sort_values("date")
    grouped["center"] = rolling_apply(grouped["center"].values, roll_win, reducer)
    for col in ["p10", "p90", "p25", "p75", "vmin", "vmax"]:
        if col in grouped.columns:
            grouped[col] = rolling_apply(grouped[col].values, roll_win, "mean")
    return grouped

def compose_climatology(subm: pd.DataFrame, reaches, years, agg_name: str, stats, roll_win: int):
    """
    Day-of-year climatology of ``metric`` over the selected reaches/years, with spread bands.

    Parameters
    ----------
    subm : metric table with ``doy``, ``year``, ``RCH`` and ``metric`` columns.
    reaches : reach ids to keep.
    years : years to keep; empty/None keeps every year.
    agg_name : "Mean" or "Median". Reduces all selected rows per doy into ``center`` and is the
        rolling reducer for ``center``.
    stats : collection of band keys ("p10p90", "p25p75", "minmax"). For each doy the bands are
        computed over one value per (year, reach), so they show inter-annual and inter-reach
        spread together.
    roll_win : rolling window length in days (<= 1 disables smoothing).

    Returns
    -------
    DataFrame sorted by doy with ``doy``, ``center`` and any requested band columns among
    ``p10``/``p90``, ``p25``/``p75``, ``vmin``/``vmax``. Band columns are smoothed with a
    rolling mean.
    """
    reducer = agg_key(agg_name)
    s = subm[subm["RCH"].isin(reaches)]
    if years:
        s = s[s["year"].isin(years)]
    center = s.groupby("doy")["metric"].agg(reducer).rename("center").reset_index()

    per_sample = s.groupby(["doy", "year", "RCH"])["metric"].agg("mean").reset_index()
    rows = []
    for d, g in per_sample.groupby("doy"):
        vals = g["metric"].values
        row = {"doy": d}
        if "p10p90" in stats:
            # quantiles() keys by probability (0.1/0.9); the plotters look for "p10"/"p90"
            q = quantiles(vals, [0.10, 0.90])
            row.update({"p10": q[0.10], "p90": q[0.90]})
        if "p25p75" in stats:
            q = quantiles(vals, [0.25, 0.75])
            row.update({"p25": q[0.25], "p75": q[0.75]})
        if "minmax" in stats and len(vals):
            row.update({"vmin": float(np.nanmin(vals)), "vmax": float(np.nanmax(vals))})
        rows.append(row)

    # empty selection: a column-less frame would make merge raise KeyError
    out = center.merge(pd.DataFrame(rows), on="doy", how="left") if rows else center
    out = out.sort_values("doy")
    out["center"] = rolling_apply(out["center"].values, roll_win, reducer)
    for col in ["p10", "p90", "p25", "p75", "vmin", "vmax"]:
        if col in out.columns:
            out[col] = rolling_apply(out[col].values, roll_win, "mean")
    return out

def compose_all_years_overlay(subm: pd.DataFrame, reaches, agg_name: str, roll_win: int):
    """
    One smoothed daily line per year for the selected reaches (every year in ``subm``).

    With several reaches, the values for each date are combined with ``agg_name``
    ("Mean"/"Median"); with a single reach a plain mean is used. Each line is then
    rolling-smoothed with the ``agg_name`` reducer. Returns a dict
    ``year -> DataFrame(date, metric)`` in ascending year order.
    """
    reducer = agg_key(agg_name)
    per_date = reducer if len(reaches) > 1 else "mean"
    selected = subm[subm["RCH"].isin(reaches)]
    lines = {}
    for year in sorted(selected["year"].unique()):
        daily = (selected[selected["year"] == year]
                 .groupby("date")["metric"].agg(per_date)
                 .reset_index()
                 .sort_values("date"))
        daily["metric"] = rolling_apply(daily["metric"].values, roll_win, reducer)
        lines[year] = daily
    return lines

def compose_year_reach(subm: pd.DataFrame, reaches, years):
    """
    Mean ``metric`` per (year, reach), returned as columns year, RCH, delta.

    Keeps only ``reaches``; an empty/None ``years`` keeps every year.
    """
    selected = subm[subm["RCH"].isin(reaches)]
    if years:
        selected = selected[selected["year"].isin(years)]
    means = selected.groupby(["year", "RCH"])["metric"].agg("mean").reset_index()
    return means.rename(columns={"metric": "delta"})

def compose_across_years_all_reaches(subm: pd.DataFrame, reaches, agg_name: str, roll_win: int):
    """
    Single DOY profile combining every row of ``subm`` for the selected reaches.

    Parameters
    ----------
    subm : metric table with ``doy``, ``RCH`` and ``metric``; callers pre-filter years.
    reaches : reach ids to keep; empty/None keeps every reach.
    agg_name : "Mean" or "Median"; the per-doy reducer and the rolling reducer.
    roll_win : rolling window length in days (<= 1 disables smoothing).

    Returns
    -------
    DataFrame with columns ``doy`` and ``center``, sorted by doy.
    """
    reducer = agg_key(agg_name)
    s = subm[subm["RCH"].isin(reaches)] if reaches else subm.copy()
    line = s.groupby("doy")["metric"].agg(reducer).reset_index().sort_values("doy")
    # smooth with the roll_win parameter; the previous code read an undefined global `roll`
    line["center"] = rolling_apply(line["metric"].values, roll_win, reducer)
    return line[["doy", "center"]]

# ------------- Plotters -------------
def plot_time_series(df_stats, title, ylab, bands=("p10p90","p25p75","minmax"),
                     flow_stats=None, y2lab="Flow (m³/d)"):
    """
    Plot a dated series: optional shaded bands, the center line and a y=0 reference line.

    Parameters
    ----------
    df_stats : DataFrame with ``date`` and ``center``; band edges are read from
        ``vmin``/``vmax``, ``p10``/``p90`` and ``p25``/``p75`` when present.
    title, ylab : figure title and primary y-axis title.
    bands : band keys to draw ("minmax", "p10p90", "p25p75"); a band is skipped when its
        key is not requested or either edge column is missing.
    flow_stats : optional DataFrame with ``date``/``center`` drawn dotted on a secondary
        y-axis; ``None`` or empty means a plain single-axis figure.
    y2lab : secondary y-axis title.

    Returns
    -------
    plotly Figure.
    """
    has_flow = flow_stats is not None and len(flow_stats)
    if has_flow:
        fig = make_subplots(specs=[[{"secondary_y": has_flow}]])
    else:
        fig = go.Figure()
    # with a secondary axis, metric traces target the primary one; otherwise no axis hint
    primary_axis = False if has_flow else None
    dates = df_stats["date"]

    # (band key, upper edge column, lower edge column, legend label, opacity),
    # drawn in the order min–max, p10–p90, p25–p75
    band_specs = (
        ("minmax", "vmax", "vmin", "min–max", 0.15),
        ("p10p90", "p90", "p10", "p10–p90", 0.20),
        ("p25p75", "p75", "p25", "p25–p75", 0.30),
    )
    for key, upper, lower, label, alpha in band_specs:
        if key not in bands or not {lower, upper}.issubset(df_stats.columns):
            continue
        # invisible upper edge first; the lower edge then fills "tonexty" back up to it
        fig.add_trace(go.Scatter(x=dates, y=df_stats[upper], mode="lines", line=dict(width=0),
                                 showlegend=False, hoverinfo="skip"), secondary_y=primary_axis)
        fig.add_trace(go.Scatter(x=dates, y=df_stats[lower], mode="lines", line=dict(width=0),
                                 fill="tonexty", name=label, opacity=alpha), secondary_y=primary_axis)

    fig.add_trace(go.Scatter(x=dates, y=df_stats["center"], mode="lines", name="Center"),
                  secondary_y=primary_axis)

    if has_flow:
        flow_trace = go.Scatter(x=flow_stats["date"], y=flow_stats["center"], mode="lines",
                                name="Flow", line=dict(dash="dot"))
        fig.add_trace(flow_trace, secondary_y=True)
        fig.update_yaxes(title_text=y2lab, secondary_y=True)

    fig.add_hline(y=0, line_dash="dash", opacity=0.5, secondary_y=primary_axis)
    fig.update_layout(title=title, xaxis_title="Date", yaxis_title=ylab, legend_title=None,
                      margin=dict(l=40, r=20, t=50, b=40))
    return fig

def plot_climatology(df_stats, title, ylab, bands=("p10p90","p25p75","minmax"),
                     flow_stats=None, y2lab="Flow (m³/d)"):
    """
    Plot a day-of-year climatology: optional shaded bands, the center line and a y=0 line.

    Parameters
    ----------
    df_stats : DataFrame with ``doy`` and ``center``; band edges are read from
        ``vmin``/``vmax``, ``p10``/``p90`` and ``p25``/``p75`` when present.
    title, ylab : figure title and primary y-axis title.
    bands : band keys to draw ("minmax", "p10p90", "p25p75"); a band is skipped when its
        key is not requested or either edge column is missing.
    flow_stats : optional DataFrame with ``doy``/``center`` drawn dotted on a secondary
        y-axis; ``None`` or empty means a plain single-axis figure.
    y2lab : secondary y-axis title.

    Returns
    -------
    plotly Figure.
    """
    has_flow = flow_stats is not None and len(flow_stats)
    if has_flow:
        fig = make_subplots(specs=[[{"secondary_y": has_flow}]])
    else:
        fig = go.Figure()
    # with a secondary axis, metric traces target the primary one; otherwise no axis hint
    primary_axis = False if has_flow else None
    doy = df_stats["doy"]

    # (band key, upper edge column, lower edge column, legend label, opacity),
    # drawn in the order min–max, p10–p90, p25–p75
    band_specs = (
        ("minmax", "vmax", "vmin", "min–max", 0.15),
        ("p10p90", "p90", "p10", "p10–p90", 0.20),
        ("p25p75", "p75", "p25", "p25–p75", 0.30),
    )
    for key, upper, lower, label, alpha in band_specs:
        if key not in bands or not {lower, upper}.issubset(df_stats.columns):
            continue
        # invisible upper edge first; the lower edge then fills "tonexty" back up to it
        fig.add_trace(go.Scatter(x=doy, y=df_stats[upper], mode="lines", line=dict(width=0),
                                 showlegend=False, hoverinfo="skip"), secondary_y=primary_axis)
        fig.add_trace(go.Scatter(x=doy, y=df_stats[lower], mode="lines", line=dict(width=0),
                                 fill="tonexty", name=label, opacity=alpha), secondary_y=primary_axis)

    fig.add_trace(go.Scatter(x=doy, y=df_stats["center"], mode="lines", name="Center"),
                  secondary_y=primary_axis)

    if has_flow:
        flow_trace = go.Scatter(x=flow_stats["doy"], y=flow_stats["center"], mode="lines",
                                name="Flow", line=dict(dash="dot"))
        fig.add_trace(flow_trace, secondary_y=True)
        fig.update_yaxes(title_text=y2lab, secondary_y=True)

    fig.add_hline(y=0, line_dash="dash", opacity=0.5, secondary_y=primary_axis)
    fig.update_layout(title=title, xaxis_title="Day of Year (1–365)", yaxis_title=ylab, legend_title=None,
                      margin=dict(l=40, r=20, t=50, b=40))
    return fig

def plot_all_years_overlay(lines_dict, title, ylab, flow_lines=None, y2lab="Flow (m³/d)"):
    """
    Overlay one line per year, optionally with per-year flow lines on a secondary y-axis.

    Parameters
    ----------
    lines_dict : mapping year -> DataFrame with ``date`` and ``metric`` columns.
    title, ylab : figure title and primary y-axis title.
    flow_lines : optional mapping year -> DataFrame with ``date``/``metric``; only a
        non-empty dict enables the secondary axis.
    y2lab : secondary y-axis title.

    Returns
    -------
    plotly Figure.
    """
    has_flow = isinstance(flow_lines, dict) and len(flow_lines)
    fig = go.Figure() if not has_flow else make_subplots(specs=[[{"secondary_y": has_flow}]])
    primary_axis = None if not has_flow else False

    for year, series in lines_dict.items():
        fig.add_trace(go.Scatter(x=series["date"], y=series["metric"], mode="lines", name=str(year)),
                      secondary_y=primary_axis)

    if has_flow:
        for year, series in flow_lines.items():
            fig.add_trace(go.Scatter(x=series["date"], y=series["metric"], mode="lines",
                                     name=f"Flow {year}", line=dict(dash="dot")), secondary_y=True)
        fig.update_yaxes(title_text=y2lab, secondary_y=True)

    fig.add_hline(y=0, line_dash="dash", opacity=0.5, secondary_y=primary_axis)
    fig.update_layout(title=title, xaxis_title="Date", yaxis_title=ylab, legend_title="Year",
                      margin=dict(l=40, r=20, t=50, b=40))
    return fig


# ------------- UI Panel -------------
class ChartPanel:
    """Interactive ipywidgets panel for exploring reach-level Δ-dynamics metrics.

    Every control change rebuilds the metric table with ``build_metric_table`` (using the
    selected constituent, method and delivery ratios) and redraws the selected plot type
    into ``self.out``. Call :meth:`widget` to obtain the displayable layout.

    Parameters
    ----------
    df_rch : DataFrame
        Reach output, passed unchanged to ``build_metric_table``.
    df_sub : DataFrame
        Subbasin output, passed unchanged to ``build_metric_table``.
    df_flow : DataFrame, optional
        Raw flow data for the optional right-axis overlay. It is only read (through
        ``prep_flow``) when "Overlay flow" is ticked; ``None`` disables the overlay.
    """

    PLOT_TYPES = [
        "Time series",
        "DOY climatology",
        "All years overlay",
        "Year × Reach bars",
        "Across-years avg per reach (DOY)",
        "Across-years avg (all reaches, DOY)",
    ]
    # (label shown in the UI, internal key passed to compose_*/plot_* helpers)
    BAND_OPTIONS = [("p10–p90","p10p90"), ("p25–p75","p25p75"), ("min–max","minmax")]
    # (label shown in the UI, method code passed to build_metric_table)
    METHOD_OPTIONS = [
        ("Channel net balance — (OUT−IN) − DR×local", "A"),
        ("Retained in reach — (IN + DR×local − OUT)", "B"),
        ("Hydrology-normalized — per discharge (mg/L)", "C1"),
        ("Runoff-normalized — per runoff (kg/mm/ha)", "C2"),
    ]
    CONSTITUENTS = [("Sediment", "SED"), ("Total Phosphorus (TP)", "TP"), ("Total Nitrogen (TN)", "TN")]

    def __init__(self, df_rch, df_sub, df_flow=None):
        self.df_rch = df_rch
        self.df_sub = df_sub
        self.df_flow_raw = df_flow

        # --- metric controls
        self.constit = W.Dropdown(options=self.CONSTITUENTS, value="SED", description="Constituent")
        self.method  = W.Dropdown(options=self.METHOD_OPTIONS, value="A", description="Method")
        self.DR_sed = W.FloatSlider(value=1.0, min=0.0, max=1.0, step=0.05, description="DR_sed", readout_format=".2f", continuous_update=False)
        self.DR_partic = W.FloatSlider(value=0.7, min=0.0, max=1.0, step=0.05, description="DR_partic", readout_format=".2f", continuous_update=False)
        self.DR_diss   = W.FloatSlider(value=0.95, min=0.0, max=1.0, step=0.05, description="DR_diss", readout_format=".2f", continuous_update=False)

        # Initial metric table, used only to populate the reach/year selectors.
        # The positional True is drop_feb, matching the "Drop Feb 29" checkbox default below.
        base = build_metric_table(self.df_rch, self.df_sub, True, self.method.value, self.constit.value,
                                  DR_sed=self.DR_sed.value, DR_partic=self.DR_partic.value, DR_diss=self.DR_diss.value)
        self.all_reaches = sorted(base["RCH"].unique().tolist())
        self.all_years   = sorted(base["year"].unique().tolist())

        # --- selection / display controls
        self.plot_type = W.Dropdown(options=self.PLOT_TYPES, value="Time series", description="Plot")
        self.reach_sel = W.SelectMultiple(options=self.all_reaches, value=tuple(self.all_reaches[:3]), description="Reaches", rows=7)
        self.year_sel  = W.SelectMultiple(options=self.all_years, value=tuple(self.all_years), description="Years", rows=7)
        self.agg = W.ToggleButtons(options=["Mean","Median"], value="Mean", description="Aggregate")
        self.roll = W.IntSlider(value=7, min=1, max=60, step=1, description="Rolling (d)", continuous_update=False)
        self.drop_feb = W.Checkbox(value=True, description="Drop Feb 29")
        self.band_sel = W.SelectMultiple(options=[lbl for lbl,_ in self.BAND_OPTIONS],
                                         value=("p10–p90","p25–p75"), rows=4, description="Bands")
        self.show_flow = W.Checkbox(value=False, description="Overlay flow (right y-axis)")
        self.out = W.Output()

        # Redraw on any control change.
        for w in [self.constit, self.method, self.DR_sed, self.DR_partic, self.DR_diss,
                  self.plot_type, self.reach_sel, self.year_sel, self.agg, self.roll, self.drop_feb, self.band_sel, self.show_flow]:
            w.observe(self.render, names="value")

        self.render(None)

    def _band_keys(self):
        """Return the internal keys of the selected bands, always in ``BAND_OPTIONS`` order.

        Iterating ``BAND_OPTIONS`` (instead of a set of labels) keeps the order deterministic
        across kernels (string hashing is randomized). Unknown labels are ignored.
        """
        selected = set(self.band_sel.value)
        return tuple(key for label, key in self.BAND_OPTIONS if label in selected)

    def widget(self):
        """Return the panel layout: control rows followed by the plot output area."""
        return W.VBox([
            W.HBox([self.plot_type, self.constit, self.method]),
            W.HBox([self.DR_sed, self.DR_partic, self.DR_diss, self.agg, self.roll]),
            W.HBox([self.reach_sel, self.year_sel, W.VBox([self.drop_feb, self.band_sel, self.show_flow])]),
            self.out
        ])

    def _flow_overlay(self, years, agg_name, roll):
        """Prepare the optional flow overlay for the current plot type.

        Returns ``(flow_stats, flow_lines)``. Both are ``None`` unless the overlay is ticked
        and flow data was supplied. Only "Time series" and "DOY climatology" fill
        ``flow_stats``, and only "All years overlay" fills ``flow_lines``.
        """
        flow_stats = None
        flow_lines = None
        if self.show_flow.value and self.df_flow_raw is not None:
            flow_prep = prep_flow(self.df_flow_raw, self.drop_feb.value)
            plot_type = self.plot_type.value
            if plot_type == "Time series":
                flow_stats = compose_time_aggregate_flow(flow_prep, years, agg_name, roll)
            elif plot_type == "DOY climatology":
                flow_stats = compose_climatology_flow(flow_prep, years, agg_name, roll)
            elif plot_type == "All years overlay":
                flow_lines = compose_all_years_overlay_flow(flow_prep, agg_name, roll)
        return flow_stats, flow_lines

    def _per_reach_doy_figure(self, subm, reaches, years, agg_name, roll, ylab):
        """Build one DOY line per selected reach, aggregated across the selected years."""
        fig = go.Figure()
        s = subm[subm["RCH"].isin(reaches)]
        if years:
            s = s[s["year"].isin(years)]
        for rch in sorted(set(reaches)):
            rsub = s[s["RCH"] == rch]
            line = rsub.groupby("doy")["metric"].agg(agg_key(agg_name)).reset_index().sort_values("doy")
            # Smooth along DOY with rolling_apply (window `roll`, same aggregator as the groupby).
            line["metric"] = rolling_apply(line["metric"].values, roll, agg_key(agg_name))
            fig.add_trace(go.Scatter(x=line["doy"], y=line["metric"], mode="lines", name=f"R{rch}"))
        fig.add_hline(y=0, line_dash="dash", opacity=0.5)
        fig.update_layout(title=f"Across-years DOY average per reach ({agg_name})",
                          xaxis_title="Day of Year (1–365)", yaxis_title=ylab,
                          legend_title=None, margin=dict(l=40,r=20,t=50,b=40))
        return fig

    def render(self, _):
        """Rebuild the metric table from the current controls and redraw into ``self.out``.

        The argument (an ipywidgets change notification, or ``None``) is ignored. Empty
        reach/year selections fall back to every reach/year present in the metric table.
        """
        with self.out:
            self.out.clear_output(wait=True)
            subm = build_metric_table(
                self.df_rch, self.df_sub,
                self.drop_feb.value, self.method.value, self.constit.value,
                DR_sed=self.DR_sed.value, DR_partic=self.DR_partic.value, DR_diss=self.DR_diss.value
            )
            ylab = subm.attrs.get("ylab", "Metric")
            reaches = list(self.reach_sel.value) or sorted(subm["RCH"].unique())
            years   = list(self.year_sel.value)  or sorted(subm["year"].unique())
            agg_name = self.agg.value
            roll = int(self.roll.value)
            bands = self._band_keys()
            plot_type = self.plot_type.value

            # --- flow overlay prep (optional)
            flow_stats, flow_lines = self._flow_overlay(years, agg_name, roll)

            # --- plots
            if plot_type == "Time series":
                data = compose_time_aggregate(subm, reaches, years, agg_name, set(bands), roll)
                fig = plot_time_series(data, "Time series — aggregated over selected reaches",
                                       ylab, bands, flow_stats=flow_stats)

            elif plot_type == "DOY climatology":
                data = compose_climatology(subm, reaches, years, agg_name, set(bands), roll)
                fig = plot_climatology(data, "DOY climatology — selected reaches & years",
                                       ylab, bands, flow_stats=flow_stats)

            elif plot_type == "All years overlay":
                lines = compose_all_years_overlay(subm, reaches, agg_name, roll)
                fig = plot_all_years_overlay(lines, "All years overlay (one line per year)",
                                             ylab, flow_lines=flow_lines)

            elif plot_type == "Year × Reach bars":
                df_yr = compose_year_reach(subm, reaches, years)
                fig = plot_year_reach_bars(df_yr, "Year × Reach mean (center metric)", ylab)

            elif plot_type == "Across-years avg per reach (DOY)":
                fig = self._per_reach_doy_figure(subm, reaches, years, agg_name, roll, ylab)

            elif plot_type == "Across-years avg (all reaches, DOY)":
                line = compose_across_years_all_reaches(subm[subm["year"].isin(years)], reaches, agg_name, roll)
                fig = go.Figure([go.Scatter(x=line["doy"], y=line["center"], mode="lines", name=f"{agg_name}")])
                fig.add_hline(y=0, line_dash="dash", opacity=0.5)
                fig.update_layout(title=f"Across-years DOY average (all selected reaches) — {agg_name}",
                                  xaxis_title="Day of Year (1–365)", yaxis_title=ylab,
                                  margin=dict(l=40,r=20,t=50,b=40))
            else:
                fig = go.Figure()

            fig.show()




In [None]:
# Run outputs available in dict_155 (last expression -> rich display instead of print).
list(dict_155.keys())

In [None]:
from python_pipeline_scripts.sub_parser import parse_swat_sub_to_df

# Reach (.rch) output of the run under analysis
# (previously used alternative: dict_145['rch_run000145_real000503_1']).
df_sed_rch = dict_155['rch_run000155_real000526_1']

# Subbasin (.sub) output parsed from disk.
# NOTE(review): this file belongs to run000117_real000404_1 while the reach data above
# comes from run000155_real000526_1 — confirm that pairing different runs is intended.
SUB_OUTPUT_FILE = r"C:\SWAT\RSWAT\cubillas\mc_results\run000117_real000404_1\output.sub"
df_sub_yld = parse_swat_sub_to_df(SUB_OUTPUT_FILE)


In [None]:
# Net sediment difference calculations.
# Derives two helper series on df_sed_rch when the needed columns exist:
#   SED_IN_MINUS_OUT = SED_IN - SED_OUT
#   SED_OUT_MINUS_IN = SED_OUT - SED_IN
# Accepts either column naming convention (case-insensitive):
# SED_INtons/SED_OUTtons or SED_IN/SED_OUT.
# The frame is copied first, so the frame it was taken from (e.g. an entry of dict_155)
# is never mutated and re-running this cell gives the same result.
import pandas as _pd

_DATE_CANDIDATES = ["date", "DATE", "Date", "time", "TIME"]

# Identify candidate column names
_sed_in_cols = [c for c in df_sed_rch.columns if c.upper() in {"SED_INTONS", "SED_IN"}]
_sed_out_cols = [c for c in df_sed_rch.columns if c.upper() in {"SED_OUTTONS", "SED_OUT"}]

if _sed_in_cols and _sed_out_cols:
    sed_in_col = _sed_in_cols[0]
    sed_out_col = _sed_out_cols[0]
    df_sed_rch = df_sed_rch.copy()
    # Ensure a datetime index so later cells can resample by month/year.
    if not isinstance(df_sed_rch.index, _pd.DatetimeIndex):
        _date_col = next((c for c in _DATE_CANDIDATES if c in df_sed_rch.columns), None)
        if _date_col is None:
            print("Warning: no date/time column found; df_sed_rch keeps its non-datetime index, "
                  "so monthly/annual resampling in later cells will fail.")
        else:
            # set_index with a Series keeps the original date column as a column too.
            df_sed_rch = df_sed_rch.set_index(_pd.to_datetime(df_sed_rch[_date_col]))
    df_sed_rch["SED_IN_MINUS_OUT"] = df_sed_rch[sed_in_col] - df_sed_rch[sed_out_col]
    df_sed_rch["SED_OUT_MINUS_IN"] = df_sed_rch[sed_out_col] - df_sed_rch[sed_in_col]
    print(f"Added net sediment columns using {sed_in_col} and {sed_out_col} -> SED_IN_MINUS_OUT / SED_OUT_MINUS_IN")
else:
    print("Net sediment columns NOT added (required SED_IN*/SED_OUT* columns missing)")

In [None]:
# Monthly / annual aggregations of the sediment variables (including the net differences).
# NOTE: df_sed_rch holds rows for every reach (RCH) on each date. With SED_AGG_REACHES = None
# the resampled sums/means therefore pool all reaches together. Set SED_AGG_REACHES to a
# list of reach ids (e.g. [15]) to aggregate only those reaches.
import pandas as _pd

SED_AGG_REACHES = None  # None -> all reaches (previous behaviour)

# Select the working dataframe (a copy, so df_sed_rch is untouched)
sed_df = df_sed_rch.copy()

# Ensure datetime index
if not isinstance(sed_df.index, _pd.DatetimeIndex):
    for cand in ["date", "DATE", "Date", "time", "TIME"]:
        if cand in sed_df.columns:
            sed_df = sed_df.set_index(_pd.to_datetime(sed_df[cand]))
            break
if not isinstance(sed_df.index, _pd.DatetimeIndex):
    raise TypeError("sed_df needs a DatetimeIndex or one of the columns "
                    "date/DATE/Date/time/TIME to resample by month/year.")

# Optional reach filter
if SED_AGG_REACHES is not None:
    if "RCH" not in sed_df.columns:
        raise KeyError("SED_AGG_REACHES is set but df_sed_rch has no 'RCH' column.")
    sed_df = sed_df[sed_df["RCH"].isin(list(SED_AGG_REACHES))]

# Identify sediment related columns (original + derived)
_sed_cols = [c for c in sed_df.columns if c.upper().startswith("SED_")]
print(f"Sediment columns considered: {_sed_cols}")

# Build aggregations (sums need at least one non-NA value per period)
monthly_sum = sed_df[_sed_cols].resample('M').sum(min_count=1)
monthly_mean = sed_df[_sed_cols].resample('M').mean()
annual_sum = sed_df[_sed_cols].resample('Y').sum(min_count=1)
annual_mean = sed_df[_sed_cols].resample('Y').mean()

print("Monthly sum head:\n", monthly_sum.head())
print("Annual sum head:\n", annual_sum.head())

In [None]:
# Interactive selection for the sediment aggregation view: widgets and base data.
import ipywidgets as _w
import plotly.graph_objects as _go

# Candidate variables: every SED_* column (original and derived), case-insensitive.
sed_cols_all = [col for col in df_sed_rch.columns if col.upper().startswith('SED_')]
# Prefer the net-retention series when the earlier cell created it.
if 'SED_IN_MINUS_OUT' in sed_cols_all:
    _default_var = 'SED_IN_MINUS_OUT'
else:
    _default_var = sed_cols_all[0]

_dd_var = _w.Dropdown(options=sed_cols_all, value=_default_var, description='Var:')
_dd_freq = _w.Dropdown(options=['D', 'M', 'Y'], value='Y', description='Freq:')
_dd_method = _w.Dropdown(options=['sum', 'mean'], value='sum', description='Agg:')
_btn = _w.Button(description='Update', button_style='primary')
_out = _w.Output()

# Daily base frame restricted to the SED_* columns (a copy, so nothing leaks back).
_sed_base = df_sed_rch[sed_cols_all].copy()
if not isinstance(_sed_base.index, pd.DatetimeIndex):
    # The column subset dropped any date column, so read it from the source frame.
    _date_col = next((c for c in ['date', 'DATE', 'Date'] if c in df_sed_rch.columns), None)
    if _date_col is not None:
        _sed_base = _sed_base.set_index(pd.to_datetime(df_sed_rch[_date_col]))


def _agg(series: pd.Series, freq: str, how: str) -> pd.Series:
    if freq == 'D':
        return series
    rule = {'M':'M','Y':'Y'}[freq]
    if how == 'sum':
        return series.resample(rule).sum(min_count=1)
    elif how == 'mean':
        return series.resample(rule).mean()
    else:
        return series


def _update_view(*_):
    """Redraw the sediment chart in ``_out`` from the current dropdown selections.

    Extra positional arguments (the clicked button passed by ``on_click``) are ignored, so
    the function also works as a plain call. The title reflects the selected variable,
    frequency and aggregation instead of a fixed "yearly sum" caption.
    """
    freq_names = {'D': 'daily', 'M': 'monthly', 'Y': 'yearly'}
    with _out:
        # wait=True keeps the previous figure until the new one is displayed (less flicker).
        _out.clear_output(wait=True)
        v = _dd_var.value
        f = _dd_freq.value
        h = _dd_method.value
        s = _sed_base[v].dropna()
        a = _agg(s, f, h)

        # Daily values are not aggregated, so the method only matters for M / Y.
        period = freq_names.get(f, f)
        agg_desc = period if f == 'D' else f'{period} {h}'

        fig = _go.Figure()
        fig.add_trace(_go.Scatter(
            x=a.index, y=a.values, mode='lines+markers',
            marker=dict(size=5),
            line=dict(width=1.5),
            name=f'{v} ({h} {f})'
        ))
        fig.update_layout(
            height=330,
            width=640,              # narrower figure
            title=f'Cubillas sediment: {v}; {agg_desc}',
            margin=dict(l=60, r=10, t=50, b=40),
            yaxis_title='Sediment (tons)',
            xaxis_title='Date'
        )
        fig.update_yaxes(automargin=True)
        display(fig)

# Compact control widths.
for _widget, _width in ((_dd_var, '170px'), (_dd_freq, '90px'), (_dd_method, '110px'), (_btn, '80px')):
    _widget.layout.width = _width

# Redraw only when "Update" is pressed.
_btn.on_click(_update_view)

# Draw once with the default selection.
_update_view()

# Controls row above the plot, about half a typical notebook width (last expression -> displayed).
_controls_row = _w.HBox(
    [_dd_var, _dd_freq, _dd_method, _btn],
    layout=_w.Layout(justify_content='flex-start', width='640px'),
)
_w.VBox([_controls_row, _out], layout=_w.Layout(width='660px'))

### Net sediment differences & aggregation
The previous cells added:
- `SED_IN_MINUS_OUT` = sediment entering minus sediment leaving the reach.
- `SED_OUT_MINUS_IN` = sediment leaving minus entering (mirror of the first).

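We also produced monthly and annual aggregations (sum & mean). Note that `df_sed_rch` holds rows for every reach on each date, so these aggregations (and the widget) pool all reaches unless the data is filtered by `RCH` first. Use the interactive widget above to: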
1. Pick a sediment (or derived) variable.
2. Choose frequency: Daily (`D`), Monthly (`M`), Yearly (`Y`).
3. Select aggregation: `sum` (load over period) or `mean` (average daily value in period).

Implementation notes:
- Supports both column name variants: `SED_INtons`/`SED_OUTtons` or `SED_IN`/`SED_OUT`.
- If required columns are missing, a message is printed and derived columns aren't added.
- Aggregations use pandas resample; sums require at least one non-NA (`min_count=1`).

You can now integrate this logic into other analysis cells or export the aggregated frames as needed.

In [None]:

# Single Δ-dynamics panel for the selected run.
# The flow overlay uses df_water_flow_m3_d_cubillas when an earlier cell defined it. Otherwise
# df_flow is None, so the "Overlay flow" checkbox has no effect (instead of this cell
# raising NameError on a fresh kernel).
panel1 = ChartPanel(df_sed_rch, df_sub_yld, df_flow=globals().get("df_water_flow_m3_d_cubillas"))
display(W.VBox([W.HTML("<h3>Δ Dynamics — Sediment / TP / TN</h3>"), panel1.widget(), W.HTML("<hr>")]))

In [None]:
# Sediment dashboard with three independent panels.
# Inputs (normally defined by earlier cells):
# - df_sed_rch: SWAT .rch-like data (must include: date, RCH, SED_INtons, SED_OUTtons; optional: FLOW_OUTcms)
# - df_sub_yld: SWAT .sub-like data (must include: date, SUB, AREA (km2), SYLD (t/ha), WYLD (mm))
# If either is missing (e.g. earlier cells were skipped), a small seeded synthetic example
# is generated so the dashboard can still be demonstrated.

df_rch = globals().get("df_sed_rch")
df_sub = globals().get("df_sub_yld")

if df_rch is None or df_sub is None:
    _demo_rng = np.random.default_rng(42)  # seeded -> reproducible demo data
    dates = pd.date_range("2018-01-01", "2020-12-31", freq="D")
    n_days = len(dates)
    reaches = list(range(1, 18))
    rec_r, rec_s = [], []
    for rch in reaches:
        IN = _demo_rng.gamma(8, 15, n_days)
        OUT = IN + np.sin(np.linspace(0, 18, n_days)) * 30 + _demo_rng.normal(0, 15, n_days)
        FLOW_OUT = _demo_rng.uniform(3, 30, n_days)
        rec_r.append(pd.DataFrame({"date": dates, "RCH": rch, "SED_INtons": IN.clip(0),
                                   "SED_OUTtons": OUT.clip(0), "FLOW_OUTcms": FLOW_OUT}))
        area_km2 = 50 + 20*rch
        WYLD = _demo_rng.uniform(0, 8, n_days)
        SYLD_t_ha = _demo_rng.gamma(3, 0.05, n_days)
        rec_s.append(pd.DataFrame({"date": dates, "SUB": rch, "AREA": area_km2, "SYLD": SYLD_t_ha, "WYLD": WYLD}))
    df_rch = pd.concat(rec_r, ignore_index=True)
    df_sub = pd.concat(rec_s, ignore_index=True)

# Launch three independent panels
panel1 = ChartPanel(df_rch, df_sub)
panel2 = ChartPanel(df_rch, df_sub)
panel3 = ChartPanel(df_rch, df_sub)

display(W.VBox([
    W.HTML("<h3>Sediment Dynamics Dashboard — methods A/B/C with toggles</h3>"
           "<p>Δ_channel = (OUT − IN) − DR·SYLD_tons;  Retained = (IN + DR·SYLD_tons − OUT). "
           "C1: mg/L (uses FLOW_OUT). C2: kg/mm/ha (uses WYLD & area).</p>"),
    panel1.widget(),
    W.HTML("<hr>"),
    panel2.widget(),
    W.HTML("<hr>"),
    panel3.widget(),
]))




# 3 possible methods (one time series each)

## Method A — **Channel net balance corrected for local yield** (Δ\_channel)

**What it is:**
Net channel exchange (erosion/resuspension minus deposition) after removing the subbasin’s lateral supply.

$$
\boxed{
\Delta_\text{channel}(t) \;=\; \underbrace{\text{SED\_OUT}(t) - \text{SED\_IN}(t)}_{\text{net at reach}} \;-\; \underbrace{DR \cdot \text{SYLD\_tons}(t)}_{\text{local hillslope delivery}}
}
$$

* **Positive** ⇒ channel is a **source** (net erosion/resuspension that adds sediment).
* **Negative** ⇒ channel is a **sink** (net deposition/trapping).
* Units: **tons / time-step** (daily if daily outputs).
* $DR$ = delivery ratio (0–1) for how much of the subbasin yield actually makes it into the channel segment during the step. If you don’t estimate it, start with **DR=1** (upper bound).

---

## Method B — **Retained mass in reach** (deposition-focused)

**What it is:**
How much mass is kept in the reach after accounting for what came in (upstream + local) vs what left.

$$
\boxed{
\text{Retained}(t) \;=\; \underbrace{\big(\text{SED\_IN}(t) + DR \cdot \text{SYLD\_tons}(t)\big) - \text{SED\_OUT}(t)}_{\text{positive = deposition, negative = net export}}
}
$$

* **Positive** ⇒ **deposition/retention** that day;
* **Negative** ⇒ **net export** that day.
* Units: **tons / time-step**.
* This is simply the sign-flipped form of Method A: $\text{Retained} = -\Delta_\text{channel}$. It’s often easier to interpret when you care about **sedimentation**.

---

## Method C — **Hydrologically normalized channel signal**

**What it is:**
Normalize the channel net signal for hydrologic forcing so you can compare wet vs. dry periods and different-size basins. Two common normalizations; pick one depending on your question:

### (C1) Concentration-like normalization (per discharge)

$$
\boxed{
\Delta_\text{chan\_conc}(t)\;[\mathrm{mg/L}] \;\approx\; 
\frac{10^9\cdot\big(\text{SED\_OUT}-\text{SED\_IN}-DR\cdot \text{SYLD\_tons}\big)}
{ \text{FLOW\_OUT}(t)\,[\mathrm{m^3/s}] \cdot 86{,}400 \cdot 10^3 }
}
$$

(Practically: $\text{tons/day} \times 11.574074 / Q_\text{out,cms}$.)

* Tells you how much **excess channel signal per unit water** you have (mg/L).
* Use when comparing periods/places with different flows.

### (C2) Runoff normalization (per mm-runoff per ha)

$$
\boxed{
\Delta_\text{chan,\,kg/mm/ha}(t) \;=\;
\frac{\Delta_\text{channel}(t)\,[\text{tons}] \times 1000}{\text{WYLD}(t)\,[\text{mm}] \times \text{Area}_\text{sub}(t)\,[\text{ha}]}
}
$$

* Interpretable as **channel net mass per unit hydrologic production** from that subbasin.
* Great to remove the “it was a wet year” effect.

> You can compute **either C1 or C2** (or both). They each produce a single time series.

---


# Transparent explanations (what you’re actually measuring)

### Method A — Δ\_channel (tons/time)

* **Goal:** isolate the **channel** as a control volume.
* **Why subtract subbasin SYLD?** The raw `OUT−IN` includes hillslope + channel. Removing delivered local load aims to leave **net channel processes** (bed/bank erosion vs. deposition).
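* **Delivery ratio $DR$:** If you don’t know it, **DR=1** is an upper bound on local delivery. It attributes as much of `OUT−IN` as possible to hillslope supply, so $\Delta_\text{channel}$ is biased toward negative values (deposition). If you later build a simple upstream accounting, estimate $DR$ empirically (see below).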
* **Use when:** you want a direct, mass-based answer to “is this reach eroding or depositing over time?”

### Method B — Retained mass (tons/time)

* **Goal:** focus directly on **sedimentation/retention**.
* Positive values simply read as “this much sediment stayed in the reach this step.”
* **Pros:** very interpretable if you’re concerned with **sediment build-up** or **trap performance**;
  **Cons:** bounded only by inflow magnitude; can be volatile in events (use weekly/monthly smoothing if needed).

### Method C — Hydrologically normalized signal

* **Goal:** take out the effect of **varying water** so you can compare seasons/years or subbasins fairly.
* **C1 (mg/L):** treats the channel net term as a concentration-like metric per unit discharge (uses `FLOW_OUT`).

  * Good for diagnosing event-driven resuspension vs. baseline.
* **C2 (kg/mm/ha):** divides by subbasin **runoff production** (WYLD × area) — a *process-based* normalization.

  * Good for inter-year or inter-basin comparisons where hydrology differs.
* **Caveat:** normalizations move you away from strict mass balance; interpret as **indices** of channel behavior per unit water, not as total stored mass.

---

## Optional: estimating a better $DR$ later (if you get topology)

If you can map **upstream reaches** for each reach, estimate a **time-varying delivery ratio**:

1. Sum upstream `SED_OUT` into the target reach (per day).
2. The **observed lateral load** is $\max( \text{SED\_IN} - \sum \text{upstream SED\_OUT}, 0)$.
3. Then $DR \approx \frac{\text{observed lateral load}}{\text{SYLD\_tons}}$ (clip to $[0, 1]$ and smooth by event/season).

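That gives a time-varying `dr` (a Series keyed by `(date, reach)`). Note that the dashboard above currently takes scalar delivery ratios (the `DR_sed`, `DR_partic`, `DR_diss` sliders), so using a time-varying `dr` requires a metric function that accepts one.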

---

## What else to account for (quick checklist)

* **Storage features:** reservoirs/ponds/wetlands -> big effect on retention (consider separate treatment if present).
* **Timing/lag:** daily routing can shift peaks; exact day-by-day mass balance can be noisy (aggregate to weekly/monthly for interpretation).
* **Area scaling:** for cross-reach comparisons, use C2 or divide A/B by area (tons/km²/day).
* **Uncertainty:** show IQR/90% bands or confidence intervals over seasons/years.

Next step: pick a sample reach in the dashboard above and compare the Method A, B and C series for it side by side.


In [None]:
# Old dashboard calls (kept for reference).
figs = build_sediment_dashboard(df_sed_rch, reach_numbers=[10, 12, 13, 14, 15, 16, 17])

# Aggregate views.
for _fig_key in ("fig_avg_over_time", "fig_clim_aggregate", "fig_year_reach"):
    figs[_fig_key].show()

# Per-reach quick views.
# NOTE(review): reach 2 is not among reach_numbers above — confirm these per-reach
# views read the full frame rather than only the selected reaches.
_quick_reach = 2
figs["per_reach_climatology"](reach=_quick_reach).show()
figs["per_reach_full_series"](reach=_quick_reach).show()

# Fan charts across simulations (percentile bands + median)

In [None]:


def fan_compare_simulations_dashboard(
    sim_dfs: Dict[str, pd.DataFrame],
    variables: List[str],
    *,
    reach: Optional[int] = None,
    freq_options: Iterable[str] = ("D","W","M","A"),
    max_bin_size: int = 12,
    start: Optional[Union[str, datetime, date]] = None,
    end: Optional[Union[str, datetime, date]] = None,
    season_months: Optional[List[int]] = None,
    how_map_defaults: Optional[Dict[str, str]] = None,
    reach_col: str = "RCH",
    date_col: str = "date",
    flow_col: str = "FLOW_OUTcms",
    template: str = "plotly_white",
    figure_width: Optional[int] = 1200,
    figure_height: int = 650
):
    """Interactive fan chart comparing one variable across many simulation runs.

    For the selected reach and variable, each run's series is resampled (``sum``, ``mean``
    or flow-weighted mean) to the chosen frequency × bin size. The runs are aligned on a
    common time index, and the across-run percentiles are drawn as a p05–p95 band, a
    p25–p75 band and a median line. The controls and output area are displayed directly;
    the function returns ``None``.

    Parameters
    ----------
    sim_dfs : dict[str, DataFrame]
        Run name -> output frame with ``reach_col``, ``date_col`` and the variable columns.
        Runs lacking a column needed by the current selection are skipped.
    variables : list[str]
        Columns offered in the variable dropdown (the first is selected initially).
    reach : int, optional
        Initially selected reach; defaults to the smallest reach id found.
    freq_options : iterable of str
        Base frequencies offered. ``"D"`` is preselected when present, else the first option.
    max_bin_size : int
        Upper bound of the bin-size slider (the bin multiplies the base frequency).
    start, end : str | datetime | date, optional
        Time window applied to every run before resampling.
    season_months : list[int], optional
        Keep only these calendar months.
    how_map_defaults : dict[str, str], optional
        Per-variable default aggregation method, overriding the name-based heuristic.
    reach_col, date_col, flow_col : str
        Column names for reach id, date and discharge (discharge is used for flow weighting).
    template : str
        Plotly template name.
    figure_width : int, optional
        Figure width in px (``None`` lets Plotly decide).
    figure_height : int
        Figure height in px.

    Raises
    ------
    ValueError
        If no reach ids are found in any run.
    """
    if how_map_defaults is None:
        how_map_defaults = {}

    # Discover reach ids across all runs (runs without a reach column contribute none).
    all_reaches = set()
    for df in sim_dfs.values():
        if reach_col in df.columns:
            all_reaches.update(df[reach_col].dropna().unique().tolist())
    number_of_simulations = len(sim_dfs)
    reach_choices = sorted(int(r) for r in all_reaches if pd.notna(r))
    if not reach_choices:
        raise ValueError("No reaches found.")
    if reach is None:
        reach = reach_choices[0]

    # Preselect daily when offered; a hard-coded "D" would fail when it is not an option.
    freq_choices = list(freq_options)
    default_freq = "D" if ("D" in freq_choices or not freq_choices) else freq_choices[0]

    # widgets (no per-run checkboxes; too many runs)
    num_sim = widgets.HTML(value=f"Number of initializations: {number_of_simulations}")
    dd_var   = widgets.Dropdown(options=variables, value=variables[0], description="Variable:", layout=widgets.Layout(width="360px"))
    dd_reach = widgets.Dropdown(options=reach_choices, value=reach, description="Reach:", layout=widgets.Layout(width="180px"))
    dd_freq  = widgets.Dropdown(options=freq_choices, value=default_freq, description="Freq:", layout=widgets.Layout(width="140px"))
    sl_bin   = widgets.IntSlider(value=1, min=1, max=max_bin_size, step=1, description="Bin:", continuous_update=False, layout=widgets.Layout(width="300px"))
    dd_method = widgets.Dropdown(options=["sum","mean","flow_weighted_mean"], value="mean", description="Method:", layout=widgets.Layout(width="280px"))
    cb_autoscale_y_live = widgets.Checkbox(value=True, description="Auto-scale Y on zoom")
    cb_show_names_in_tooltip = widgets.Checkbox(value=False, description="Names in tooltip")

    def _default_method_for_var(v: str) -> str:
        """Default aggregation: explicit map first, then name heuristics."""
        if v in how_map_defaults:
            return how_map_defaults[v]
        # Concentrations are averaged with discharge weights.
        if "Conc" in v or "mg/L" in v:
            return "flow_weighted_mean"
        # Mass-like units accumulate over the period.
        if any(u in v.lower() for u in ["kg","tons","mg"]):
            return "sum"
        return "mean"
    dd_method.value = _default_method_for_var(dd_var.value)

    out = widgets.Output()
    # Shared state between the compute routine and the change callbacks.
    _last = {"aligned_df": None, "y_fixed": None, "fig": None}

    # Date tick formats by zoom level (dtick ranges in milliseconds).
    TICK_STOPS = [
        dict(dtickrange=[None, 1000*60*60*24], value="%Y-%m-%d\n%H:%M"),
        dict(dtickrange=[1000*60*60*24, 1000*60*60*24*28], value="%Y-%m-%d"),
        dict(dtickrange=[1000*60*60*24*28, 1000*60*60*24*365], value="%Y-%m"),
        dict(dtickrange=[1000*60*60*24*365, None], value="%Y"),
    ]

    def _hovertemplate(show_name: bool) -> str:
        """Hover text for the median trace, with or without the trace name."""
        return ("%{fullData.name}: %{y:.4g}<extra></extra>" if show_name
                else "%{y:.4g}<extra></extra>")

    def _compute_and_plot():
        """Resample every run, compute across-run percentiles and (re)display the fan chart."""
        freq_str = _make_freq_string(dd_freq.value, sl_bin.value)
        var = dd_var.value
        method = dd_method.value
        needed_cols = [date_col, var] + ([flow_col] if method == "flow_weighted_mean" else [])

        # Extract a single resampled series per run for the selected reach/variable
        per_sim = {}
        for sim_name, df in sim_dfs.items():
            # Skip runs missing any column this selection needs (instead of a KeyError).
            if reach_col not in df.columns or any(c not in df.columns for c in needed_cols):
                continue
            sub = df.loc[df[reach_col] == dd_reach.value, needed_cols].copy()
            if sub.empty:
                continue
            sub = _ensure_dt_index(sub, date_col)
            if start or end:
                sub = _slice_time(sub, start, end)
            if season_months:
                sub = _filter_season(sub, season_months)
            if sub.empty:
                continue
            s = _resample_series(sub, var, freq=freq_str, how=method, flow_col=flow_col)
            if s.empty:
                continue
            s.name = sim_name
            per_sim[sim_name] = s

        if not per_sim:
            with out:
                clear_output(wait=True)
                print(f"No data for reach {dd_reach.value} and variable '{var}'.")
            return

        # Align series to a common time index (union) and build 2D matrix (T x N)
        aligned_df = pd.concat(per_sim.values(), axis=1).sort_index()
        aligned_df.index = pd.to_datetime(aligned_df.index, utc=False)
        arr = aligned_df.to_numpy(dtype=float)  # shape: (T, N)
        x_dt = aligned_df.index.to_pydatetime()
        _last["aligned_df"] = aligned_df

        # Compute quantiles across runs (ignore NaNs)
        percs = [5, 10, 25, 50, 75, 90, 95]
        if arr.shape[1] == 0:
            with out:
                clear_output(wait=True)
                print("No aligned data after resampling.")
            return
        qs = np.nanpercentile(arr, percs, axis=1)  # shape: (7, T)
        q = {p: qs[i, :] for i, p in enumerate(percs)}  # p -> array(T,)

        # Precompute y-range for fixed scaling (5% padding; avoid a zero-height range)
        finite_vals = arr[np.isfinite(arr)]
        if finite_vals.size:
            y_min = float(np.nanmin(finite_vals))
            y_max = float(np.nanmax(finite_vals))
            if y_min == y_max:
                y_max = y_min + 1.0
        else:
            y_min, y_max = 0.0, 1.0
        pad = (y_max - y_min) * 0.05
        _last["y_fixed"] = [y_min - pad, y_max + pad]

        # Build figure
        fig = go.FigureWidget(layout=dict(template=template))
        if figure_width is not None:
            fig.layout.width = int(figure_width)
        fig.layout.height = int(figure_height)

        # Fan chart: draw wider band first, then narrower, then median
        rgba = lambda a: f"rgba(31,119,180,{a})"

        # 90% band (p05..p95): upper edge, then lower edge filled "tonexty" up to it
        fig.add_trace(go.Scatter(
            x=x_dt, y=q[95], mode="lines", line=dict(color=rgba(0.12), width=0.5),
            name="p95", showlegend=False, hoverinfo="skip"
        ))
        fig.add_trace(go.Scatter(
            x=x_dt, y=q[5], mode="lines", line=dict(color=rgba(0.12), width=0.5),
            fill="tonexty", fillcolor=rgba(0.12),
            name="p05–p95", showlegend=True, hoverinfo="skip"
        ))

        # 50% band (p25..p75)
        fig.add_trace(go.Scatter(
            x=x_dt, y=q[75], mode="lines", line=dict(color=rgba(0.28), width=0.5),
            name="p75", showlegend=False, hoverinfo="skip"
        ))
        fig.add_trace(go.Scatter(
            x=x_dt, y=q[25], mode="lines", line=dict(color=rgba(0.28), width=0.5),
            fill="tonexty", fillcolor=rgba(0.28),
            name="p25–p75", showlegend=True, hoverinfo="skip"
        ))

        # Median (the only trace with hover enabled); honours the tooltip checkbox
        fig.add_trace(go.Scatter(
            x=x_dt, y=q[50], mode="lines", line=dict(color="black", width=2),
            name="median", hovertemplate=_hovertemplate(cb_show_names_in_tooltip.value)
        ))

        fig.update_layout(
            title=f"{var} — Reach {dd_reach.value} ({freq_str}, {method})",
            xaxis_title="Date", yaxis_title=var,
            hovermode="x unified",
            hoverlabel=dict(namelength=-1, align="left", font_size=12, bgcolor="white"),
            legend=dict(orientation="h", y=1.05, x=0),
            xaxis=dict(
                type="date",
                rangeslider=dict(visible=True),
                tickformatstops=TICK_STOPS
            ),
            margin=dict(l=60, r=20, t=110, b=50)
        )

        # fixed Y; optional live update on zoom
        fig.update_yaxes(autorange=False, range=_last["y_fixed"])
        _last["fig"] = fig

        def _on_xrange_change(layout, xrange):
            """Rescale Y to the data inside the visible X window (or restore the fixed range)."""
            if _last["aligned_df"] is None:
                return
            if not cb_autoscale_y_live.value:
                fig.layout.yaxis.update(autorange=False, range=_last["y_fixed"])
                return
            try:
                x0 = pd.to_datetime(xrange[0]); x1 = pd.to_datetime(xrange[1])
            except Exception:
                return
            win = _last["aligned_df"].loc[( _last["aligned_df"].index >= x0) & (_last["aligned_df"].index <= x1)]
            if win.empty:
                return
            vals = win.to_numpy(dtype=float)
            vals = vals[np.isfinite(vals)]
            if vals.size == 0:
                return
            ymin = float(np.nanmin(vals)); ymax = float(np.nanmax(vals))
            if ymin == ymax: ymax = ymin + 1.0
            pad_local = (ymax - ymin) * 0.05
            fig.layout.yaxis.update(autorange=False, range=[ymin - pad_local, ymax + pad_local])

        fig.layout.xaxis.on_change(_on_xrange_change, 'range')

        with out:
            clear_output(wait=True)
            display(fig)

    # observers
    def _on_var_change(change):
        dd_method.value = _default_method_for_var(change["new"])
    dd_var.observe(_on_var_change, names="value")

    def _on_tooltip_toggle(change):
        """Switch the median's hover text between value-only and 'name: value'."""
        if _last["fig"] is None:
            return
        # Band traces use hoverinfo="skip", so only the median needs the new template.
        _last["fig"].update_traces(hovertemplate=_hovertemplate(change["new"]),
                                   selector=dict(name="median"))

    cb_show_names_in_tooltip.observe(_on_tooltip_toggle, names="value")

    # Any data/aggregation control change triggers a full recompute.
    for w in [dd_var, dd_reach, dd_freq, sl_bin, dd_method, cb_autoscale_y_live]:
        w.observe(lambda _: _compute_and_plot(), names="value")

    controls_left  = widgets.VBox([num_sim, dd_var, dd_method])
    controls_right = widgets.VBox([dd_reach, dd_freq, sl_bin, cb_autoscale_y_live, cb_show_names_in_tooltip])
    controls = widgets.HBox([controls_left, widgets.HBox([widgets.Label(""), controls_right])])
    display(controls, out)

    _compute_and_plot()


In [None]:
# Fan chart for reach 15, 1981–2020, across the runs in dfs_mc_run_x.
# NOTE(review): the dashboard displays its own widgets and returns None, so figs_fan1 is None.
# end="2020-12-30" (not 12-31) — confirm this is intended.
figs_fan1 = fan_compare_simulations_dashboard(
    dfs_mc_run_x,
    vars_to_compare,
    reach=15,
    start="1981-01-01",
    end="2020-12-30",
    max_bin_size=30,
    freq_options=("D", "W", "M", "A"),
    how_map_defaults=how_map_defaults,
)



# Compare measurements to SWAT

# Archive

# Data exploration

## Compare different simulations

### variance in certain columns

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
from typing import Dict, List, Optional, Literal, Union, Iterable
from datetime import datetime, date
import pandas as pd
import numpy as np
import plotly.graph_objects as go

def _ensure_dt_index(df: pd.DataFrame, date_col: str) -> pd.DataFrame:
    out = df.copy()
    out[date_col] = pd.to_datetime(out[date_col])
    return out.set_index(date_col).sort_index()

def _slice_time(df: pd.DataFrame, start: Optional[Union[str, datetime, date]], end: Optional[Union[str, datetime, date]]) -> pd.DataFrame:
    if start is not None:
        df = df.loc[pd.to_datetime(start):]
    if end is not None:
        df = df.loc[:pd.to_datetime(end)]
    return df

def _filter_season(df: pd.DataFrame, months: List[int]) -> pd.DataFrame:
    return df[df.index.month.isin(months)]

def _make_freq_string(freq: str, bin_size: int) -> str:
    base_map = {"D": "D", "DAILY": "D", "W": "W", "WEEKLY": "W",
                "M": "M", "MONTHLY": "M", "A": "A", "Y": "A", "YEARLY": "A"}
    base = base_map.get(freq.upper(), freq)
    return f"{bin_size}{base}" if bin_size and bin_size > 1 else base

def _resample_series(df: pd.DataFrame, var: str, freq: str,
                     how: Literal["sum","mean","flow_weighted_mean"],
                     flow_col: str = "FLOW_OUTcms") -> pd.Series:
    if how == "sum":
        s = df[var].resample(freq).sum(min_count=1)
    elif how == "mean":
        s = df[var].resample(freq).mean()
    elif how == "flow_weighted_mean":
        w = df[flow_col].clip(lower=0)
        num = (df[var] * w).resample(freq).sum(min_count=1)
        den = w.resample(freq).sum(min_count=1)
        s = num / den.replace(0, np.nan)
    else:
        raise ValueError("how must be 'sum','mean','flow_weighted_mean'")

    # Make 100% sure the index is a tz-naive DatetimeIndex (not PeriodIndex/object)
    if isinstance(s.index, pd.PeriodIndex):
        s.index = s.index.to_timestamp(how="end")
    s.index = pd.to_datetime(s.index, utc=False)  # tz-naive
    return s

def compare_simulations_dashboard(
    sim_dfs: Dict[str, pd.DataFrame],
    variables: List[str],
    *,
    reach: Optional[int] = None,
    freq_options: Iterable[str] = ("D","W","M","A"),
    max_bin_size: int = 12,
    start: Optional[Union[str, datetime, date]] = None,
    end: Optional[Union[str, datetime, date]] = None,
    season_months: Optional[List[int]] = None,
    how_map_defaults: Optional[Dict[str, str]] = None,
    reach_col: str = "RCH",
    date_col: str = "date",
    flow_col: str = "FLOW_OUTcms",
    template: str = "plotly_white",
    figure_width: Optional[int] = 1200,
    figure_height: int = 600
):
    """Display an interactive Plotly dashboard comparing one variable across simulations.

    Each checked simulation is filtered to the selected reach, optionally clipped to
    ``start``/``end`` and ``season_months``, resampled with ``_resample_series`` and
    drawn as one line. The figure is rebuilt whenever a control changes.

    Parameters
    ----------
    sim_dfs : mapping of simulation name -> frame with ``reach_col``, ``date_col``,
        the variable columns, and ``flow_col`` (needed for flow-weighted means).
    variables : columns offered in the "Variable" dropdown; the first is shown initially.
    reach : initial reach; defaults to the smallest reach found in ``sim_dfs``.
    freq_options : base resampling frequencies offered; "D" is preselected when
        present, otherwise the first option.
    max_bin_size : upper bound of the bin-size slider (multiplier on the base frequency).
    start, end : optional inclusive time window.
    season_months : optional month numbers to keep.
    how_map_defaults : per-variable default method ("sum", "mean",
        "flow_weighted_mean"); unmapped variables use a name-based heuristic.
    reach_col, date_col, flow_col : column names in the input frames.
    template, figure_width, figure_height : Plotly figure styling.

    Returns
    -------
    None. The controls and an output area are displayed as a side effect.

    Raises
    ------
    ValueError
        If ``variables`` or ``freq_options`` is empty, no reaches are found, or
        ``reach`` is not among the reaches present in the data.
    """
    if how_map_defaults is None:
        how_map_defaults = {}
    if not variables:
        raise ValueError("variables must contain at least one column name.")
    freq_choices = list(freq_options)
    if not freq_choices:
        raise ValueError("freq_options must contain at least one frequency.")
    # Previously hard-coded to "D", which crashed when "D" was not offered.
    default_freq = "D" if "D" in freq_choices else freq_choices[0]

    # discover reaches
    all_reaches = set()
    for df in sim_dfs.values():
        all_reaches.update(df[reach_col].dropna().unique().tolist())
    reach_choices = sorted({int(r) for r in all_reaches if pd.notna(r)})
    if not reach_choices:
        raise ValueError("No reaches found.")
    if reach is None:
        reach = reach_choices[0]
    elif reach not in reach_choices:
        raise ValueError(f"Reach {reach!r} not found; available reaches: {reach_choices}")

    # widgets
    sim_checkboxes = [widgets.Checkbox(value=True, description=name, layout=widgets.Layout(width="360px"))
                      for name in sim_dfs.keys()]
    sim_box = widgets.VBox(sim_checkboxes, layout=widgets.Layout(width="400px"))

    dd_var   = widgets.Dropdown(options=variables, value=variables[0], description="Variable:", layout=widgets.Layout(width="360px"))
    dd_reach = widgets.Dropdown(options=reach_choices, value=reach, description="Reach:", layout=widgets.Layout(width="180px"))
    dd_freq  = widgets.Dropdown(options=freq_choices, value=default_freq, description="Freq:", layout=widgets.Layout(width="140px"))
    sl_bin   = widgets.IntSlider(value=1, min=1, max=max_bin_size, step=1, description="Bin:", continuous_update=False, layout=widgets.Layout(width="300px"))
    dd_method = widgets.Dropdown(options=["sum","mean","flow_weighted_mean"], value="mean", description="Method:", layout=widgets.Layout(width="280px"))
    cb_autoscale_y_live = widgets.Checkbox(value=True, description="Auto-scale Y when zooming")
    cb_show_names_in_tooltip = widgets.Checkbox(value=True, description="Show names in tooltip")

    def _default_method_for_var(v: str) -> str:
        """Pick an aggregation method: explicit mapping, else a heuristic on the column name."""
        if v in how_map_defaults:
            return how_map_defaults[v]
        if "Conc" in v or "mg/L" in v:
            return "flow_weighted_mean"
        if any(u in v.lower() for u in ["kg","tons","mg"]):
            return "sum"
        return "mean"
    dd_method.value = _default_method_for_var(dd_var.value)

    out = widgets.Output()
    _last = {"aligned_df": None, "y_fixed": None, "fig": None}
    # True while an observer sets another widget programmatically; the generic
    # redraw observer skips that intermediate change to avoid a duplicate redraw.
    _guard = {"suspend": False}

    # “Nice” dynamic date labels without breaking rendering
    TICK_STOPS = [
        dict(dtickrange=[None, 1000*60*60*24], value="%Y-%m-%d\n%H:%M"),
        dict(dtickrange=[1000*60*60*24, 1000*60*60*24*28], value="%Y-%m-%d"),
        dict(dtickrange=[1000*60*60*24*28, 1000*60*60*24*365], value="%Y-%m"),
        dict(dtickrange=[1000*60*60*24*365, None], value="%Y"),
    ]

    def _hovertemplate(show_name: bool) -> str:
        return ("%{fullData.name}: %{y:.4g}<extra></extra>" if show_name
                else "%{y:.4g}<extra></extra>")

    def _padded_range(values) -> Optional[List[float]]:
        """[min, max] of the finite entries padded by 5 %, or None if there are none."""
        vals = np.asarray(values, dtype=float).ravel()
        vals = vals[np.isfinite(vals)]
        if vals.size == 0:
            return None
        lo, hi = float(vals.min()), float(vals.max())
        if lo == hi:
            hi = lo + 1.0
        pad = (hi - lo) * 0.05
        return [lo - pad, hi + pad]

    def _compute_and_plot():
        selected_sims = {cb.description: sim_dfs[cb.description] for cb in sim_checkboxes if cb.value}
        if not selected_sims:
            with out:
                clear_output(wait=True)
                print("No simulations selected.")
            return

        freq_str = _make_freq_string(dd_freq.value, sl_bin.value)
        var = dd_var.value
        method = dd_method.value

        # prepare per sim
        per_sim = {}
        for sim_name, df in selected_sims.items():
            sub = df[df[reach_col] == dd_reach.value].copy()
            if sub.empty:
                continue
            sub = _ensure_dt_index(sub, date_col)
            if start or end:
                sub = _slice_time(sub, start, end)
            if season_months:
                sub = _filter_season(sub, season_months)
            if not sub.empty:
                per_sim[sim_name] = sub

        if not per_sim:
            with out:
                clear_output(wait=True)
                print(f"No data for reach {dd_reach.value} in selection.")
            return

        # align series
        aligned = []
        for sim_name, sub in per_sim.items():
            s = _resample_series(sub, var, freq=freq_str, how=method, flow_col=flow_col)
            s.name = sim_name
            aligned.append(s)

        aligned_df = pd.concat(aligned, axis=1).sort_index()
        # Ensure real datetimes and build py-datetime array for Plotly (avoids “huge number” axes)
        aligned_df.index = pd.to_datetime(aligned_df.index, utc=False)
        x_dt = aligned_df.index.to_pydatetime()
        _last["aligned_df"] = aligned_df

        # y range with padding; fall back to [0, 1] (padded) when nothing is finite
        _last["y_fixed"] = (_padded_range(aligned_df.to_numpy(dtype=float))
                            or _padded_range([0.0, 1.0]))

        # figure
        fig = go.FigureWidget(layout=dict(template=template))
        if figure_width is not None:
            fig.layout.width = int(figure_width)
        fig.layout.height = int(figure_height)

        # lines only (simplified)
        ht = _hovertemplate(cb_show_names_in_tooltip.value)
        for sim_name in aligned_df.columns:
            fig.add_trace(go.Scatter(
                x=x_dt, y=aligned_df[sim_name].values,
                mode="lines", name=sim_name, hovertemplate=ht,
                connectgaps=True
            ))

        fig.update_layout(
            title=f"{var} — Reach {dd_reach.value} ({freq_str}, {method})",
            xaxis_title="Date", yaxis_title=var,
            hovermode="x unified",
            hoverlabel=dict(namelength=-1, align="left", font_size=12, bgcolor="white"),
            legend=dict(orientation="h", y=1.05, x=0),
            xaxis=dict(
                type="date",
                rangeslider=dict(visible=True),
                tickformatstops=TICK_STOPS
            ),
            margin=dict(l=60, r=20, t=110, b=50)
        )

        # fixed Y; live update on zoom
        fig.update_yaxes(autorange=False, range=_last["y_fixed"])
        _last["fig"] = fig

        def _on_xrange_change(layout, xrange):
            current = _last["aligned_df"]
            if current is None:
                return
            if not cb_autoscale_y_live.value:
                fig.layout.yaxis.update(autorange=False, range=_last["y_fixed"])
                return
            # Autorange resets can deliver an empty/None range; nothing to fit then.
            if not xrange or xrange[0] is None or xrange[1] is None:
                return
            try:
                x0 = pd.to_datetime(xrange[0]); x1 = pd.to_datetime(xrange[1])
            except (TypeError, ValueError):
                return
            win = current.loc[(current.index >= x0) & (current.index <= x1)]
            new_range = _padded_range(win.to_numpy(dtype=float))
            if new_range is not None:
                fig.layout.yaxis.update(autorange=False, range=new_range)

        fig.layout.xaxis.on_change(_on_xrange_change, 'range')

        with out:
            clear_output(wait=True)
            display(fig)

    # observers
    def _on_var_change(change):
        # Reset the method for the new variable without triggering a redraw here;
        # the redraw observer registered below on dd_var redraws once with both values.
        _guard["suspend"] = True
        try:
            dd_method.value = _default_method_for_var(change["new"])
        finally:
            _guard["suspend"] = False
    dd_var.observe(_on_var_change, names="value")

    def _on_tooltip_toggle(change):
        if _last["fig"] is None:
            return
        ht = _hovertemplate(change["new"])
        for tr in _last["fig"].data:
            tr.update(hovertemplate=ht)

    cb_show_names_in_tooltip.observe(_on_tooltip_toggle, names="value")

    def _redraw(_change):
        if not _guard["suspend"]:
            _compute_and_plot()

    for w in [dd_var, dd_reach, dd_freq, sl_bin, dd_method, cb_autoscale_y_live] + sim_checkboxes:
        w.observe(_redraw, names="value")

    controls_top = widgets.HBox([sim_box, widgets.VBox([dd_var, dd_method])])
    controls_bottom = widgets.HBox([dd_reach, dd_freq, sl_bin, cb_autoscale_y_live, cb_show_names_in_tooltip])
    display(controls_top, controls_bottom, out)

    _compute_and_plot()


In [None]:





# Interactive comparison for reach 15 over 1981-01-01 .. 2020-12-30.
# The dashboard opens at daily resolution ("D"); choose "A" in the Freq dropdown
# (optionally with a bin size) for an annual view. Methods come from
# `how_map_defaults`, falling back to the dashboard's name heuristic
# (flow-weighted mean for concentration columns, sum for kg/tons/mg columns).
# NOTE: compare_simulations_dashboard displays its widgets and returns None,
# so `figs_annually` / `figs_annually_2` hold None.
figs_annually = compare_simulations_dashboard(
    load_or_build_dfs_for_runs([81], force_rebuild=False),
    vars_to_compare,
    reach=15,
    start="1981-01-01",
    end="2020-12-30",
    freq_options=("D","W","M","A"),
    max_bin_size=30,
    how_map_defaults=how_map_defaults,
)

print("----------------------------------------------------------------------------------------------")

# Same dashboard for `runs`, with an explicit per-variable method mapping.
figs_annually_2 = compare_simulations_dashboard(
    runs,
    vars_to_compare,
    reach=15,
    start="1981-01-01",
    end="2020-12-30",
    freq_options=("D","W","M","A"),
    max_bin_size=30,
    how_map_defaults={
        "ORGN_OUTkg": "sum",
        "ORGP_OUTkg": "sum",
        "NO3_OUTkg": "sum",
        "SEDCONCmg/L": "flow_weighted_mean",
        "NO3ConcMg/l": "flow_weighted_mean"
    },
)

# Exporting: the dashboard does not return figure objects, so save a view from the
# rendered widget itself (e.g. via the Plotly modebar) rather than a returned figure.


In [None]:


import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import plotly.graph_objects as go
import ipywidgets as W
from IPython.display import display

def overlay_tp_tn_toggle(
    df_rch: pd.DataFrame,
    df_sub: pd.DataFrame,
    reaches,
    agg: str = "sum",                 # "sum"; any other value aggregates with "mean"
    resample_rule: str | None = None, # e.g. "W", "M" (None = native timestep)
    resample_agg: str = "sum",        # Resampler method name applied after resampling
    roll: int | None = None,          # rolling window in timesteps (after resampling); None/0/1 disables
    # column mappings: standard name -> column name in your frames
    rch_cols = dict(
        date="date", reach="RCH",
        ORGN_OUT="ORGN_OUTkg", ORGP_OUT="ORGP_OUTkg",
        NO3_OUT="NO3_OUTkg",  NH4_OUT="NH4_OUTkg", NO2_OUT="NO2_OUTkg",
        MINP_OUT="MINP_OUTkg",
        TOT_N="TOT_Nkg", TOT_P="TOT_Pkg",
    ),
    sub_cols = dict(
        date="date", sub="SUB", area_km2="AREA",
        ORGN="ORGN", NSURQ="NSURQ",        # N (kg/ha)
        ORGP="ORGP", SOLP="SOLP", SEDP="SEDP",  # P (kg/ha)
    ),
):
    """Show toggleable TN and TP overlays comparing three ways of totalling nutrients.

    For the selected ``reaches`` (matched against both the reach IDs of ``df_rch``
    and the subbasin IDs of ``df_sub``) this plots, per date:

    * the reach totals ``TOT_Nkg`` / ``TOT_Pkg``;
    * the sum of reach components (N: ORGN+NO3+NH4+NO2, P: ORGP+MINP);
    * subbasin yields (N: ORGN+NSURQ, P: ORGP+SOLP+SEDP) multiplied by area in ha.

    Values are combined across reaches with ``agg`` per date, inner-joined on date,
    then optionally resampled (``resample_rule`` / ``resample_agg``) and smoothed
    with a rolling mean of ``roll`` rows. Three checkboxes toggle the curves; the
    widget tree is displayed and the function returns None.

    The subbasin area column is treated as km² when its median is below 1000,
    otherwise as hectares. NOTE(review): heuristic; confirm the unit for your files.
    """
    # Rename incoming columns to the fixed labels used below.
    rch_rename = {
        rch_cols["date"]: "date", rch_cols["reach"]: "RCH",
        rch_cols["ORGN_OUT"]: "ORGN_OUTkg", rch_cols["ORGP_OUT"]: "ORGP_OUTkg",
        rch_cols["NO3_OUT"]: "NO3_OUTkg", rch_cols["NH4_OUT"]: "NH4_OUTkg", rch_cols["NO2_OUT"]: "NO2_OUTkg",
        rch_cols["MINP_OUT"]: "MINP_OUTkg",
        rch_cols["TOT_N"]: "TOT_Nkg", rch_cols["TOT_P"]: "TOT_Pkg",
    }
    sub_rename = {
        sub_cols["date"]: "date", sub_cols["sub"]: "SUB", sub_cols["area_km2"]: "AREA_KM2",
        sub_cols["ORGN"]: "ORGN", sub_cols["NSURQ"]: "NSURQ",
        sub_cols["ORGP"]: "ORGP", sub_cols["SOLP"]: "SOLP", sub_cols["SEDP"]: "SEDP",
    }
    rch = df_rch.rename(columns=rch_rename).copy()
    sub = df_sub.rename(columns=sub_rename).copy()

    rch["date"] = pd.to_datetime(rch["date"])
    sub["date"] = pd.to_datetime(sub["date"])

    # Accept a single ID or any non-string iterable of IDs.
    is_collection = hasattr(reaches, "__iter__") and not isinstance(reaches, (str, bytes))
    reaches = list(reaches) if is_collection else [reaches]
    rch = rch.loc[rch["RCH"].isin(reaches)].copy()
    sub = sub.loc[sub["SUB"].isin(reaches)].copy()

    # Reach-level component sums.
    rch["TN_components"] = rch[["ORGN_OUTkg", "NO3_OUTkg", "NH4_OUTkg", "NO2_OUTkg"]].sum(axis=1)
    rch["TP_components"] = rch[["ORGP_OUTkg", "MINP_OUTkg"]].sum(axis=1)

    how = "sum" if agg == "sum" else "mean"
    rch_cols_out = ["TN_components", "TP_components", "TOT_Nkg", "TOT_Pkg"]
    rch_by_date = rch.groupby("date")[rch_cols_out].agg(how).sort_index()

    # Subbasin yields are per hectare; scale by area (km² -> ha when it looks like km²).
    area = pd.to_numeric(sub["AREA_KM2"], errors="coerce")
    area_looks_like_km2 = np.nanmedian(area) < 1_000
    sub["AREA_HA"] = area * (100.0 if area_looks_like_km2 else 1.0)

    n_per_ha = sub["ORGN"].fillna(0) + sub["NSURQ"].fillna(0)
    p_per_ha = sub["ORGP"].fillna(0) + sub["SOLP"].fillna(0) + sub["SEDP"].fillna(0)
    sub["TN_yield_sub_kg"] = n_per_ha * sub["AREA_HA"]
    sub["TP_yield_sub_kg"] = p_per_ha * sub["AREA_HA"]
    sub_by_date = sub.groupby("date")[["TN_yield_sub_kg", "TP_yield_sub_kg"]].agg(how).sort_index()

    dfm = rch_by_date.merge(sub_by_date, left_index=True, right_index=True, how="inner")

    if resample_rule:
        dfm = getattr(dfm.resample(resample_rule), resample_agg)()

    if roll and roll > 1:
        dfm = dfm.rolling(roll, min_periods=max(1, roll // 2)).mean()

    # --- UI toggles
    show_total = W.Checkbox(value=True, description="Show RCH Total (TOT_N/TOT_P)")
    show_comp = W.Checkbox(value=True, description="Show RCH Components sum (TN: ORGN+NO3+NH4+NO2; TP: ORGP+MINP)")
    show_subyld = W.Checkbox(value=True, description="Show SUB Yield sums (TN: ORGN+NSURQ; TP: ORGP+SOLP+SEDP)")
    out_tn = W.Output()
    out_tp = W.Output()

    # One entry per panel: (output, title tag, y label, [(toggle, column, trace name, line style)]).
    panels = [
        (out_tn, "TN", "Nitrogen (kg per step)", [
            (show_total, "TOT_Nkg", "RCH Total N (TOT_Nkg)", dict(width=3)),
            (show_comp, "TN_components", "RCH N components sum (ORGN+NO3+NH4+NO2)", None),
            (show_subyld, "TN_yield_sub_kg", "SUB N yield sum (ORGN+NSURQ) — kg (area-scaled)", dict(dash="dash")),
        ]),
        (out_tp, "TP", "Phosphorus (kg per step)", [
            (show_total, "TOT_Pkg", "RCH Total P (TOT_Pkg)", dict(width=3)),
            (show_comp, "TP_components", "RCH P components sum (ORGP+MINP)", None),
            (show_subyld, "TP_yield_sub_kg", "SUB P yield sum (ORGP+SOLP+SEDP) — kg (area-scaled)", dict(dash="dash")),
        ]),
    ]

    def _render_panel(output, tag, ylabel, series_specs):
        # Redraw one nutrient panel inside its Output widget.
        with output:
            output.clear_output(wait=True)
            fig = go.Figure()
            for toggle, column, label, line_style in series_specs:
                if not toggle.value:
                    continue
                trace_kwargs = dict(x=dfm.index, y=dfm[column], mode="lines", name=label)
                if line_style is not None:
                    trace_kwargs["line"] = line_style
                fig.add_trace(go.Scatter(**trace_kwargs))
            fig.update_layout(title=f"{tag} — comparison (reaches {reaches}, agg={agg})",
                              xaxis_title="Date", yaxis_title=ylabel,
                              legend_title=None, margin=dict(l=40, r=20, t=50, b=40))
            fig.show()

    def draw():
        for output, tag, ylabel, series_specs in panels:
            _render_panel(output, tag, ylabel, series_specs)

    for toggle in (show_total, show_comp, show_subyld):
        toggle.observe(lambda _change: draw(), names="value")

    # initial draw & display
    draw()
    display(W.VBox([
        W.HTML("<b>Toggle which curves to show</b>"),
        W.HBox([show_total, show_comp, show_subyld]),
        out_tn,
        out_tp
    ]))


In [None]:
# Reach output for one realisation (from dict_109) and the matching subbasin output file.
# NOTE(review): absolute local path; consider deriving it from the project config.
RUN_109_REAL396_DIR = r"C:\SWAT\RSWAT\cubillas\mc_results\run000109_real000396_1"
df_sed_rch_cmp = dict_109["rch_run000109_real000396_1"]
df_sub_yld_cmp = parse_swat_sub_to_df(rf"{RUN_109_REAL396_DIR}\output.sub")
# Bare expressions are only shown when they are the cell's last line, so display explicitly.
display(df_sed_rch_cmp.head(), df_sub_yld_cmp.head())

# Minimal call — pick a single reach (or a list) and go.
# NOTE(review): reach 11 is not in the list — confirm that is intentional.
# roll=1 disables smoothing (only roll > 1 applies a rolling mean).
overlay_tp_tn_toggle(df_sed_rch_cmp, df_sub_yld_cmp, reaches=[1,2,3,4,5,6,7,8,9,10,12,13,14,15,16,17], agg="mean", resample_rule="W", resample_agg="mean", roll=1)

