In [1]:
import os
from dask.distributed import Client

# os.cpu_count() returns None when the number of CPUs cannot be determined,
# so fall back to 1 before subtracting instead of raising a TypeError.
# The outer max(..., 1) keeps at least one worker on single-CPU machines.
n_workers = max((os.cpu_count() or 1) - 1, 1)
client = Client(n_workers=n_workers)

In [2]:
from typing import Callable, Optional, Iterable, Literal
import pandas as pd
import geopandas as gpd

class NLDIClient:
    """Client tool to retrieve data from the Network Linked Data Index API.

    Every public method builds a URL under ``base_url`` and hands that URL to
    a reader callable (``pd.read_json`` or ``gpd.read_file``), which performs
    the actual request and parsing.
    """
    # Root of the NLDI service; endpoints are appended to it verbatim.
    base_url: str = "https://labs.waterdata.usgs.gov/api/nldi"

    def get_data(
        self,
        endpoint: str,
        data_handler: Callable[[str], pd.DataFrame],
        parameters: Optional[dict[str, str]] = None
    ) -> pd.DataFrame:
        """Generic data retrieval method.

        Parameters
        ----------
        endpoint:
            Path appended to ``base_url`` (expected to start with ``/``).
        data_handler:
            Callable that receives the final URL and returns the parsed data.
        parameters:
            Optional query-string parameters. Keys and values are URL-encoded
            before being appended to the URL.

        Returns
        -------
        Whatever ``data_handler`` returns for the assembled URL.
        """
        # Function-scope import so this cell does not depend on an import
        # living in another cell.
        from urllib.parse import urlencode

        url = self.base_url + endpoint
        if parameters:
            # urlencode escapes reserved characters (&, =, spaces, ...) that a
            # plain "k=v" join would leave in place and corrupt the query.
            url += "?" + urlencode(parameters)
        return data_handler(url)

    def get_data_sources(self) -> pd.DataFrame:
        """Get list of data sources.

        Reads the ``/linked-data`` endpoint with ``pd.read_json``.
        """
        return self.get_data("/linked-data", pd.read_json)

    def get_registered_feature(self, feature_source: str, feature_id: str) -> gpd.GeoDataFrame:
        """Get site information.

        Reads ``/linked-data/{feature_source}/{feature_id}`` with
        ``gpd.read_file``. Both arguments are inserted into the URL path
        as given.
        """
        return self.get_data(f"/linked-data/{feature_source}/{feature_id}", gpd.read_file)

    def get_basin(self, feature_source: str, feature_id: str, simplified: bool, split_catchment: bool) -> gpd.GeoDataFrame:
        """Get upstream catchment boundary.

        Reads ``/linked-data/{feature_source}/{feature_id}/basin`` with
        ``gpd.read_file``. The two booleans are sent as the lowercase strings
        ``"true"``/``"false"`` in the ``simplified`` and ``splitCatchment``
        query parameters.
        """
        return self.get_data(
            f"/linked-data/{feature_source}/{feature_id}/basin",
            gpd.read_file,
            parameters = {
                "simplified": str(simplified).lower(),
                "splitCatchment": str(split_catchment).lower()
            }
        )

In [3]:
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path

@dataclass
class RetroConfiguration:
    """Settings for one retrospective point-data request.

    ``load_retrospective_points`` forwards every field to
    ``nwm_retro.nwm_retro_to_parquet`` under the lowercase keyword of the
    same meaning (``nwm_version``, ``variable_name``, ``start_date``,
    ``end_date``, ``location_ids``, ``output_parquet_dir``, ``chunk_by``).
    """
    # Forwarded as nwm_version.
    NWM_VERSION: str
    # Forwarded as variable_name.
    VARIABLE_NAME: str
    # Forwarded as start_date / end_date.
    # NOTE(review): whether END_DATE is inclusive is decided by the loader - confirm.
    START_DATE: datetime
    END_DATE: datetime
    # Forwarded as location_ids.
    LOCATION_IDS: Iterable[int]
    # Forwarded as output_parquet_dir. load_retrospective_points also treats
    # an existing path here as a cache and reads it instead of fetching.
    OUTPUT_DIR: Path
    # Forwarded as chunk_by.
    CHUNK_BY: Literal["week", "month", "year"]

@dataclass
class USGSConfiguration:
    """Settings for one USGS point-data request.

    ``load_usgs_points`` forwards every field to ``usgs.usgs_to_parquet``
    under the lowercase keyword of the same meaning (``sites``,
    ``start_date``, ``end_date``, ``output_parquet_dir``, ``chunk_by``).
    """
    # Forwarded as start_date / end_date.
    # NOTE(review): whether END_DATE is inclusive is decided by the loader - confirm.
    START_DATE: datetime
    END_DATE: datetime
    # Forwarded as sites.
    SITES: list[str]
    # Forwarded as output_parquet_dir. load_usgs_points also treats an
    # existing path here as a cache and reads it instead of fetching.
    OUTPUT_DIR: Path
    # Forwarded as chunk_by.
    CHUNK_BY: Literal["week", "month", "year"]

In [4]:
import teehr.loading.nwm.retrospective_points as nwm_retro
from teehr.loading.usgs import usgs
import dask.dataframe as dd

def load_retrospective_points(
    config: RetroConfiguration,
) -> pd.DataFrame:
    """Return retrospective point data, fetching it only when it is not cached.

    If ``config.OUTPUT_DIR`` does not exist,
    ``nwm_retro.nwm_retro_to_parquet`` is called with the settings in
    ``config`` to populate it. In either case the parquet data under
    ``config.OUTPUT_DIR`` is then read with dask and materialised via
    ``.compute()``.

    Only the existence of the path is checked, so an existing but empty or
    partially written directory is still treated as a cache hit.
    """
    output_dir = config.OUTPUT_DIR

    # Cache miss: nothing on disk yet, so write the parquet files first.
    if not output_dir.exists():
        nwm_retro.nwm_retro_to_parquet(
            nwm_version=config.NWM_VERSION,
            variable_name=config.VARIABLE_NAME,
            start_date=config.START_DATE,
            end_date=config.END_DATE,
            location_ids=config.LOCATION_IDS,
            output_parquet_dir=output_dir,
            chunk_by=config.CHUNK_BY,
        )

    return dd.read_parquet(output_dir).compute()

def load_usgs_points(
    config: USGSConfiguration,
) -> pd.DataFrame:
    """Return USGS point data, fetching it only when it is not cached.

    If ``config.OUTPUT_DIR`` does not exist, ``usgs.usgs_to_parquet`` is
    called with the settings in ``config`` to populate it. In either case the
    parquet data under ``config.OUTPUT_DIR`` is then read with dask and
    materialised via ``.compute()``.

    Only the existence of the path is checked, so an existing but empty or
    partially written directory is still treated as a cache hit.
    """
    output_dir = config.OUTPUT_DIR

    # Cache miss: nothing on disk yet, so write the parquet files first.
    if not output_dir.exists():
        usgs.usgs_to_parquet(
            sites=config.SITES,
            start_date=config.START_DATE,
            end_date=config.END_DATE,
            output_parquet_dir=output_dir,
            chunk_by=config.CHUNK_BY,
        )

    return dd.read_parquet(output_dir).compute()

In [5]:
# Site code kept as a string so the leading zero is preserved.
usgs_site_code = "02146470"

# Look the site up in NLDI under the "nwissite" feature source.
nldi_client = NLDIClient()
site_info = nldi_client.get_registered_feature(
    feature_source="nwissite",
    feature_id=f"USGS-{usgs_site_code}",
)

In [6]:
# Retrospective "nwm30" streamflow over the full range exposed by nwm_retro,
# for the comid value(s) returned by the NLDI site lookup.
retro_config = RetroConfiguration(
    NWM_VERSION="nwm30",
    VARIABLE_NAME="streamflow",
    START_DATE=nwm_retro.NWM30_MIN_DATE,
    END_DATE=nwm_retro.NWM30_MAX_DATE,
    LOCATION_IDS=site_info.comid.astype(int).values,
    # Output/cache directory under the user's home directory.
    OUTPUT_DIR=Path.home() / f"temp/USGS-{usgs_site_code}-NWM-V30",
    CHUNK_BY="year",
)

sim = load_retrospective_points(retro_config)

In [7]:
# USGS data for the same site and the same date range as the retrospective
# request, so the two can be compared.
usgs_config = USGSConfiguration(
    START_DATE=nwm_retro.NWM30_MIN_DATE,
    END_DATE=nwm_retro.NWM30_MAX_DATE,
    SITES=[usgs_site_code],
    # Output/cache directory under the user's home directory.
    OUTPUT_DIR=Path.home() / f"temp/USGS-{usgs_site_code}-OBS",
    CHUNK_BY="year",
)

obs = load_usgs_points(usgs_config)

In [8]:
from hydrotools.events.event_detection import decomposition as ev

In [9]:
# Keep only the timestamp and value columns, drop fully duplicated rows, and
# index by timestamp. The result is a one-column DataFrame (column "value")
# indexed by "value_time".
obs_ts = (
    obs[["value_time", "value"]]
    .drop_duplicates()
    .set_index("value_time")
)

In [10]:
from hydrotools.events.event_detection._version import __version__

In [12]:
# Detect events
events = ev.list_events(
    obs_ts,
    halflife='6H', 
    window='7D',
    minimum_event_duration='6H',
    start_radius='6H'
)

ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().