In [1]:
import pytest
import warnings
import time
from pathlib import Path
import numpy as np
import xarray as xr
import pandas as pd
from importlib.resources import files

from TRITON_SWMM_toolkit.swmm_output_parser import (
    retrieve_SWMM_outputs_as_datasets,
    convert_swmm_tdeltas_to_minutes,
    return_swmm_outputs,
    return_node_time_series_results_from_rpt,
    format_rpt_section_into_dataframe,
    return_data_from_rpt,
)
from TRITON_SWMM_toolkit.utils import write_zarr
from TRITON_SWMM_toolkit.constants import (
    APP_NAME,
)
import tempfile

# Directory holding the reference inputs and expected outputs used by this notebook.
# It is resolved two directory levels above the path returned by files(APP_NAME)
# (parents[1]), then into test_data/swmm_refactoring_reference.
# NOTE(review): presumably this only resolves when the package is used from a source
# checkout / editable install where test_data lives outside the package — confirm.
REFERENCE_DATA_DIR = (
    files(APP_NAME).parents[1] / "test_data" / "swmm_refactoring_reference"  # type: ignore
)

# Reference files: the SWMM input file and its report file, both passed to
# retrieve_SWMM_outputs_as_datasets when re-parsing.
REF_INP = REFERENCE_DATA_DIR / "hydraulics.inp"
REF_HYDRAULICS_RPT = REFERENCE_DATA_DIR / "hydraulics.rpt"

# Reference zarr outputs that freshly parsed datasets are compared against.
# The two *_SUMMARY_* stores are defined here but not referenced in the visible cells.
REF_LINK_SUMMARY_ZARR = REFERENCE_DATA_DIR / "SWMM_link_summary.zarr"
REF_LINK_TSERIES_ZARR = REFERENCE_DATA_DIR / "SWMM_link_tseries.zarr"
REF_NODE_SUMMARY_ZARR = REFERENCE_DATA_DIR / "SWMM_node_summary.zarr"
REF_NODE_TSERIES_ZARR = REFERENCE_DATA_DIR / "SWMM_node_tseries.zarr"

# Tolerance for numeric comparisons; these are the default rtol/atol of
# compare_zarr_datasets, which forwards them to np.allclose.
RTOL = 1e-5  # Relative tolerance
ATOL = 1e-8  # Absolute tolerance


In [3]:
def reference_link_tseries():
    """Open the reference link time-series zarr store and return it as a dataset.

    The store at ``REF_LINK_TSERIES_ZARR`` is opened with the zarr engine and
    without consolidated metadata. The dataset is returned as opened by
    ``xr.open_dataset`` (no explicit ``.load()`` is performed here).
    """
    open_options = {"engine": "zarr", "consolidated": False}
    return xr.open_dataset(REF_LINK_TSERIES_ZARR, **open_options)

def reference_node_tseries():
    """Open the reference node time-series zarr store and return it as a dataset.

    The store at ``REF_NODE_TSERIES_ZARR`` is opened with the zarr engine and
    without consolidated metadata. The dataset is returned as opened by
    ``xr.open_dataset`` (no explicit ``.load()`` is performed here).
    """
    open_options = {"engine": "zarr", "consolidated": False}
    return xr.open_dataset(REF_NODE_TSERIES_ZARR, **open_options)

def parsed_outputs():
    """Re-parse the reference SWMM files and round-trip the results through zarr.

    The reference ``.inp``/``.rpt`` pair is parsed with
    ``retrieve_SWMM_outputs_as_datasets``; each resulting dataset is written to
    a zarr store inside a temporary directory with ``write_zarr`` and read back,
    so the returned datasets reflect what a zarr write/read cycle produces.

    Returns
    -------
    tuple
        ``(ds_nodes, ds_links)``, both fully loaded into memory. They stay
        usable after the temporary directory has been removed.
    """
    parsed_nodes, parsed_links = retrieve_SWMM_outputs_as_datasets(
        REF_INP, REF_HYDRAULICS_RPT
    )

    def _roundtrip(ds, store_path):
        # Write the dataset, then read it back entirely into memory. The
        # context manager closes the backing store before the temporary
        # directory is cleaned up, so no handle outlives the files it refers to.
        write_zarr(ds, store_path, compression_level=5)
        with xr.open_dataset(
            store_path, engine="zarr", consolidated=False
        ) as reopened:
            return reopened.load()

    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_root = Path(tmpdir)
        ds_nodes = _roundtrip(parsed_nodes, tmp_root / "nodes.zarr")
        ds_links = _roundtrip(parsed_links, tmp_root / "links.zarr")
    return ds_nodes, ds_links

# "old" = reference datasets opened from the stored zarr outputs.
ds_nodes_old = reference_node_tseries()
ds_links_old = reference_link_tseries()

# "new" = datasets freshly parsed from the reference .inp/.rpt files.
ds_nodes_new, ds_links_new = parsed_outputs()



##################################
Found problem. Orifice conduits do not return max velocity or max over full flow. Filling with empty string
Normal row vs. problem row:
  0                    CONDUIT     0.000     0  00:01      0.00    0.00    0.00

  756                  ORIFICE     0.000     0  00:00                      0.00

Properly parsed values:
['756', 'ORIFICE', '0.000', '0', '00:00', '', '', '0.00\n']
Converted variable to datatype = type, <class 'str'>
Converted variable to datatype = OutfallType, <class 'str'>
Converted variable to datatype = StageOrTimeseries, <class 'str'>
Converted variable to datatype = StorageCurve, <class 'str'>
Converted variable to datatype = Coefficient, <class 'str'>
Converted variable to datatype = type, <class 'str'>
Converted variable to datatype = InletNode, <class 'str'>
Converted variable to datatype = OutletNode, <class 'str'>
Converted variable to datatype = OrificeType, <class 'str'>
Converted variable to datatype = FlapGate, <class 'st

  ds_nodes = xr.merge(
  ds_links = xr.merge(


In [11]:
def _nan_mask(values: np.ndarray) -> np.ndarray:
    """Return a boolean mask marking the NaN entries of ``values``.

    Dtypes that cannot hold NaN (integers, booleans, ...) produce an all-False
    mask of the same shape so callers can treat every dtype uniformly.
    """
    if np.issubdtype(values.dtype, np.inexact):
        return np.isnan(values)
    return np.zeros(values.shape, dtype=bool)


def compare_zarr_datasets(
    ds_new: xr.Dataset, ds_ref: xr.Dataset, rtol: float = RTOL, atol: float = ATOL
) -> tuple[bool, dict]:
    """
    Compare two xarray Datasets for equivalence.

    The reference dataset drives the comparison: every coordinate and data
    variable in ``ds_ref`` must exist in ``ds_new`` with matching values, and
    ``ds_new`` must not contain data variables absent from ``ds_ref``.
    Numeric variables are compared with ``np.allclose`` after checking that
    NaNs sit in the same positions; all other variables are compared as
    strings.

    Parameters
    ----------
    ds_new : xr.Dataset
        The newly generated dataset
    ds_ref : xr.Dataset
        The reference dataset
    rtol : float
        Relative tolerance for numeric comparisons
    atol : float
        Absolute tolerance for numeric comparisons

    Returns
    -------
    tuple
        (is_equivalent: bool, differences: dict)

        ``differences`` is empty when the datasets are equivalent. Otherwise
        it is keyed by ``"dims"``, ``"missing_coord_<name>"``,
        ``"coord_<name>"``, ``"missing_var_<name>"``, ``"var_<name>"`` or
        ``"extra_var_<name>"``; values describe the mismatch.
    """
    differences = {}

    # Check dimensions match (names and lengths)
    if ds_new.sizes != ds_ref.sizes:
        new_dims = set(ds_new.dims)
        ref_dims = set(ds_ref.dims)
        differences["dims"] = {
            "new": new_dims,
            "ref": ref_dims,
            "missing_in_new": ref_dims - new_dims,
            "extra_in_new": new_dims - ref_dims,
            # Names alone cannot explain a length-only mismatch, so report sizes too.
            "new_sizes": dict(ds_new.sizes),
            "ref_sizes": dict(ds_ref.sizes),
        }

    # Check coordinates match
    for coord in ds_ref.coords:
        if coord not in ds_new.coords:
            differences[f"missing_coord_{coord}"] = True
            continue

        new_vals = ds_new[coord].values
        ref_vals = ds_ref[coord].values

        if new_vals.dtype != ref_vals.dtype:
            # Dtypes differ: fall back to comparing the string representations
            try:
                new_str = np.array(new_vals, dtype=str)
                ref_str = np.array(ref_vals, dtype=str)
                if not np.array_equal(new_str, ref_str):
                    differences[f"coord_{coord}"] = {
                        "reason": "values differ (compared as strings)",
                        "new_dtype": str(new_vals.dtype),
                        "ref_dtype": str(ref_vals.dtype),
                    }
            except Exception as e:
                differences[f"coord_{coord}"] = {
                    "reason": f"dtype mismatch and comparison failed: {e}",
                    "new_dtype": str(new_vals.dtype),
                    "ref_dtype": str(ref_vals.dtype),
                }
        else:
            # NaN != NaN, so identical float coordinates containing NaN would be
            # flagged without equal_nan. Only request it for dtypes that can
            # hold NaN, since the NaN test is undefined for e.g. strings.
            can_hold_nan = bool(np.issubdtype(ref_vals.dtype, np.inexact))
            if not np.array_equal(new_vals, ref_vals, equal_nan=can_hold_nan):
                differences[f"coord_{coord}"] = "values differ"

    # Check data variables
    for var in ds_ref.data_vars:
        if var not in ds_new.data_vars:
            differences[f"missing_var_{var}"] = True
            continue

        new_vals = ds_new[var].values
        ref_vals = ds_ref[var].values

        # Element-wise checks below are meaningless (or raise) on mismatched
        # shapes, so report the shape problem explicitly and move on.
        if new_vals.shape != ref_vals.shape:
            differences[f"var_{var}"] = {
                "reason": "shape mismatch",
                "new_shape": tuple(new_vals.shape),
                "ref_shape": tuple(ref_vals.shape),
            }
            continue

        if np.issubdtype(ref_vals.dtype, np.number):
            # np.allclose raises on string/object input, so a non-numeric new
            # array is reported as a difference instead of crashing the
            # comparison. Booleans are accepted because they compare as 0/1.
            new_is_numeric = (
                np.issubdtype(new_vals.dtype, np.number)
                or new_vals.dtype == np.bool_
            )
            if not new_is_numeric:
                differences[f"var_{var}"] = {
                    "reason": "dtype mismatch: reference is numeric, new is not",
                    "new_dtype": str(new_vals.dtype),
                    "ref_dtype": str(ref_vals.dtype),
                }
                continue

            new_nan = _nan_mask(new_vals)
            ref_nan = _nan_mask(ref_vals)

            # Check NaN positions match
            if not np.array_equal(new_nan, ref_nan):
                differences[f"var_{var}"] = "NaN positions differ"
                continue

            # Compare non-NaN values (the masks are identical at this point)
            mask = ~ref_nan
            if mask.any() and not np.allclose(
                new_vals[mask], ref_vals[mask], rtol=rtol, atol=atol
            ):
                # Matching infinities give inf - inf = NaN; ignore those so the
                # reported maximum reflects the elements that actually differ.
                with np.errstate(invalid="ignore"):
                    abs_diff = np.abs(new_vals[mask] - ref_vals[mask])
                max_diff = np.nanmax(abs_diff)
                differences[f"var_{var}"] = (
                    f"numeric values differ (max diff: {max_diff})"
                )
        else:
            # String/object comparison
            try:
                new_str = np.array(new_vals, dtype=str)
                ref_str = np.array(ref_vals, dtype=str)
                if not np.array_equal(new_str, ref_str):
                    # Find first difference for debugging
                    diff_mask = new_str != ref_str
                    if diff_mask.any():
                        idx = np.argwhere(diff_mask)[0]
                        differences[f"var_{var}"] = {
                            "reason": "string values differ",
                            "first_diff_idx": idx.tolist(),
                            "new_val": str(new_str[tuple(idx)]),
                            "ref_val": str(ref_str[tuple(idx)]),
                        }
            except Exception as e:
                differences[f"var_{var}"] = f"comparison failed: {e}"

    # Check for extra variables in new dataset
    for var in ds_new.data_vars:
        if var not in ds_ref.data_vars:
            differences[f"extra_var_{var}"] = True

    return len(differences) == 0, differences

In [12]:
# Result of idxmax() on 'max_flow_cms', unwrapped from the DataArray's dict form.
# NOTE(review): idxmax() is called without a dim, so this presumably relies on
# 'max_flow_cms' being one-dimensional — confirm.
node_max_flow = (
    ds_nodes_new["max_flow_cms"]
    .idxmax()
    .to_dict()["data"]
)


In [13]:
# Node outputs: freshly parsed vs. stored reference (default tolerances).
match, diffs = compare_zarr_datasets(ds_new=ds_nodes_new, ds_ref=ds_nodes_old)
match

True

In [15]:
# Link outputs: freshly parsed vs. stored reference (default tolerances).
match, diffs = compare_zarr_datasets(ds_new=ds_links_new, ds_ref=ds_links_old)
match

True