In [None]:
# Add the packages listed below to the pixi project via a shell escape.
# NOTE(review): the pixi binary is addressed by an absolute, user-specific path,
# so this cell only runs as-is on the original author's machine.
! /home/balaji24/.pixi/bin/pixi add s3fs regionmask geopandas rioxarray pyproj intake

In [None]:
# NOTE(review): redundant - s3fs is already in the package list of the
# preceding `pixi add` cell; this cell can likely be deleted.
! /home/balaji24/.pixi/bin/pixi add s3fs

In [None]:
import os
# Export the PROJ and GDAL data directories of the pixi "analysis" environment.
# NOTE(review): presumably this has to run before the geospatial libraries are
# imported so they pick the values up - confirm. The paths are machine-specific.
for _env_name, _env_dir in (
    ("PROJ_LIB", "/home/balaji24/skagit-met/.pixi/envs/analysis/share/proj"),
    ("GDAL_DATA", "/home/balaji24/skagit-met/.pixi/envs/analysis/share/gdal"),
):
    os.environ[_env_name] = _env_dir


In [None]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import regionmask
from shapely.geometry import shape, Polygon, LinearRing
import matplotlib.pyplot as plt
from dask.diagnostics import ProgressBar

In [None]:
# CONFIG
# Everything a reader might tune lives in this cell.

# Which CONUS404 store to pull from: "daily" or "hourly".
DATASET_KIND = "daily"   # or "hourly"

# Bounds of the time slice applied to the remote dataset.
START, END = "2014-10-01", "2023-12-31"

# Variables requested from the store; names absent from the store are
# filtered out later, so this list is a wish-list rather than a guarantee.
VARS = [
    "T2",
    "Q2",
    "U10",
    "V10",
    "PSFC",
    "SWDNB",
    "LWDNB",
    "RAINC",
    "RAINNC",
    "SNOWNC",
    "SMOIS",
    "TSLB",
    "HFX",
    "LH",
    "HGT",
]

# Basin outline (GeoJSON or {"lon": [...], "lat": [...]}), relative to the notebook.
BOUNDARY_JSON = "../data/GIS/SkagitBoundary.json"

# Output store: the daily path for "daily", the hourly path for anything else.
if DATASET_KIND == "daily":
    OUT_ZARR = Path("../../../../data0/balaji24/data/CONUS/conus404_skagit_daily.zarr")
else:
    OUT_ZARR = Path("../../../../data0/balaji24/data/CONUS/conus404_skagit_hourly.zarr")

# Check file existence - fail early, before any remote data is touched.
assert Path(BOUNDARY_JSON).exists(), (
    f"Boundary file not found: {BOUNDARY_JSON}"
)

In [None]:
# Open CONUS404 from OSN

# S3-compatible endpoint that serves the stores; accessed anonymously.
endpoint = "https://usgs.osn.mghpcc.org"

# One store per temporal resolution; an unknown DATASET_KIND raises KeyError here.
_CONUS404_STORES = {
    "daily":  "s3://hytest/conus404/conus404_daily.zarr",
    "hourly": "s3://hytest/conus404/conus404_hourly.zarr",
}
store_url = _CONUS404_STORES[DATASET_KIND]

_osn_storage_options = {
    "anon": True,
    "client_kwargs": {"endpoint_url": endpoint},
}

# Lazy open using the store's consolidated metadata.
ds = xr.open_zarr(
    store=store_url,
    storage_options=_osn_storage_options,
    consolidated=True,
)

In [None]:
# Subset time & select variables

# Restrict to the configured window (label-based slice on the time coordinate).
ds = ds.sel(time=slice(START, END))

# Keep only the requested variables that the store actually provides,
# preserving the order given in VARS.
available = set(ds.data_vars)
keep = [v for v in VARS if v in available]
skipped_vars = [v for v in VARS if v not in available]
if not keep:
    raise ValueError(f"None of {VARS} found in dataset. Available: {list(ds.data_vars)[:25]}")

# Report what was dropped instead of skipping it silently: downstream derived
# fields (e.g. precipitation totals) quietly disappear when an input is missing.
if skipped_vars:
    print(f"Requested variables not found in the {DATASET_KIND} store (skipped): {skipped_vars}")

ds = ds[keep]

In [None]:
# Identify lat/lon grid

# Coordinate-name pairs to probe, in priority order: (latitude name, longitude name).
_LATLON_NAME_PAIRS = (
    ("lat", "lon"),
    ("XLAT", "XLONG"),
    ("latitude", "longitude"),
)

# First pair for which both names are present in the dataset, if any.
_latlon_match = next(
    (pair for pair in _LATLON_NAME_PAIRS if pair[0] in ds and pair[1] in ds),
    None,
)
if _latlon_match is None:
    raise KeyError("Could not find lat/lon variables in dataset.")

da_lat = ds[_latlon_match[0]]
da_lon = ds[_latlon_match[1]]

# Dimension names of the latitude array; per the original note these are
# expected to be ('y','x').
spatial_dims = da_lat.dims

In [None]:
# Load Skagit boundary JSON

def load_boundary_to_gdf(json_path: str) -> gpd.GeoDataFrame:
    """Read a basin outline from JSON and return it as a one-row GeoDataFrame.

    Two layouts are accepted:

    * GeoJSON - a ``FeatureCollection`` (all features with a geometry are
      merged into one geometry), a single ``Feature``, or a bare geometry
      object.
    * A plain mapping with parallel ``"lon"`` and ``"lat"`` arrays describing
      the ring of a polygon; the ring is closed automatically if needed.
      This layout is also the fallback when GeoJSON parsing fails.

    Parameters
    ----------
    json_path : str
        Path to the JSON file.

    Returns
    -------
    geopandas.GeoDataFrame
        One row named ``"SkagitBoundary"``, tagged as EPSG:4326. No
        reprojection is performed; the coordinates are taken as-is.

    Raises
    ------
    ValueError
        If the file matches neither layout, or if the ``lon``/``lat`` arrays
        are unusable (different lengths or fewer than three points). When
        GeoJSON parsing failed, the original error is attached as the cause.
    """
    # Local import: only needed to merge multi-feature collections.
    from shapely.ops import unary_union

    with open(json_path) as f:
        obj = json.load(f)

    basin_geom = None
    parse_error = None  # remembered so the final error can explain *why*
    if isinstance(obj, dict) and "type" in obj:
        try:
            if obj["type"] == "FeatureCollection":
                # Use every feature, not just the first, so multi-part
                # boundaries are not silently truncated.
                geoms = [
                    shape(feature["geometry"])
                    for feature in obj["features"]
                    if feature.get("geometry")
                ]
                if len(geoms) == 1:
                    basin_geom = geoms[0]
                elif geoms:
                    basin_geom = unary_union(geoms)
            elif obj["type"] == "Feature":
                basin_geom = shape(obj["geometry"])
            else:
                basin_geom = shape(obj)
        except Exception as exc:
            # Best-effort: fall through to the lon/lat layout, but keep the
            # cause so it is not lost if that fails too.
            parse_error = exc
            basin_geom = None

    if basin_geom is None and isinstance(obj, dict) and "lon" in obj and "lat" in obj:
        lons, lats = list(obj["lon"]), list(obj["lat"])
        if len(lons) != len(lats):
            raise ValueError(
                f"Boundary 'lon' and 'lat' arrays differ in length ({len(lons)} vs {len(lats)})."
            )
        if len(lons) < 3:
            raise ValueError("Boundary needs at least 3 lon/lat points to form a polygon.")
        coords = list(zip(lons, lats))
        if coords[0] != coords[-1]:
            coords.append(coords[0])  # close the ring
        basin_geom = Polygon(LinearRing(coords))

    if basin_geom is None:
        raise ValueError(
            "Invalid boundary file format. Must be GeoJSON or {'lon','lat'} arrays."
        ) from parse_error

    return gpd.GeoDataFrame({"name":["SkagitBoundary"]}, geometry=[basin_geom], crs="EPSG:4326")

# Load the outline, then split any multi-part geometry into one row per part
# (with a fresh 0..n-1 index).
_boundary_gdf = load_boundary_to_gdf(BOUNDARY_JSON)
gdf = _boundary_gdf.explode(ignore_index=True)

In [None]:
# Apply spatial mask for Skagit

# Build the regions straight from the boundary polygons.
# (`from_geopandas` is a module-level regionmask function, not an attribute of
# `regionmask.Regions`, so the former `hasattr(regionmask.Regions, ...)` probe
# could never succeed; this constructor path is the one that always ran.)
outlines = [geom for geom in gdf.geometry if geom is not None]
if not outlines:
    raise ValueError("Boundary GeoDataFrame contains no geometries to mask with.")
regions = regionmask.Regions(
    outlines=outlines,
    names=["SkagitBoundary"] * len(outlines),
    numbers=list(range(len(outlines))),
    name="SkagitBoundary",
)

# Use the coordinate names detected in the lat/lon cell instead of
# hard-coding "lon"/"lat", and read them from the current `ds` so the cell
# stays consistent if it is re-run.
mask = regions.mask(ds[da_lon.name], ds[da_lat.name])
inside = mask.notnull()
if not bool(inside.any()):
    raise ValueError(
        "Boundary does not cover any grid cell - check the boundary coordinates/CRS."
    )

# drop=True trims the grid to the bounding rows/columns of the basin; without
# it the full source grid is carried along as NaN and later written to disk.
ds = ds.where(inside, drop=True)

In [None]:
# Derived diagnostics
# Each diagnostic is added only when all of its inputs are present.

_present = set(ds.data_vars)

# 10 m wind speed from its two components.
if {"U10","V10"} <= _present:
    ds["WS10"] = (ds["U10"]**2 + ds["V10"]**2)**0.5
    ds["WS10"].attrs.update(units="m s-1", long_name="10 m wind speed")

# Total precipitation = convective (RAINC) + grid-scale (RAINNC).
# In WRF, RAINNC is the *total* grid-scale precipitation and already contains
# the frozen part that SNOWNC reports separately, so SNOWNC must not be added
# again (that double-counts snowfall). SNOWNC alone is not a total either, so
# no PRECIP_TOT is produced when the rain components are unavailable.
if {"RAINC","RAINNC"} <= _present:
    ds["PRECIP_TOT"] = ds["RAINC"] + ds["RAINNC"]
    ds["PRECIP_TOT"].attrs.update(units="mm", long_name="Total precipitation (rain+snow)")

# Relative humidity and vapour pressure deficit.
if {"T2","Q2","PSFC"} <= _present:
    T_c = ds["T2"] - 273.15          # K -> degC
    p_kpa = ds["PSFC"]/1000.0        # Pa -> kPa
    w = ds["Q2"]                     # WRF Q2 is a water-vapour *mixing ratio* (kg/kg)
    # Saturation vapour pressure (kPa), Tetens-type formula with T in degC.
    es_kpa = 0.6108 * np.exp((17.27*T_c)/(T_c+237.3))
    # Vapour pressure from mixing ratio: w = 0.622*e/(p-e)  =>  e = w*p/(0.622+w).
    # (The 0.622 + 0.378*q denominator applies to specific humidity, not to w.)
    e_kpa = (w*p_kpa)/(0.622 + w)
    ds["RH"] = (e_kpa/es_kpa).clip(0,1)*100.0
    ds["VPD"] = (es_kpa - e_kpa).clip(min=0)
    ds["RH"].attrs.update(units="%", long_name="Relative Humidity")
    ds["VPD"].attrs.update(units="kPa", long_name="Vapor Pressure Deficit")

In [None]:
# Chunking + encoding cleanup

_dim0, _dim1 = spatial_dims[0], spatial_dims[1]

# Time chunk: 30 steps by default, one week of hours for the hourly store.
_time_chunk = 24*7 if DATASET_KIND == "hourly" else 30
CHUNKS = {"time": _time_chunk, _dim0: 300, _dim1: 300}
ds = ds.chunk(CHUNKS)

# Drop the chunk layout inherited from the source store's encoding so the
# writer uses the chunks set above (per the original note, this prevents the
# overlapping-chunk error).
for _var_name in ds.variables:
    ds[_var_name].encoding.pop("chunks", None)

# Give the lat/lon coordinate arrays, when present, the same horizontal chunking.
_horizontal_chunks = {_dim0: 300, _dim1: 300}
for _coord_name in ("lat", "lon"):
    if _coord_name in ds:
        ds[_coord_name] = ds[_coord_name].chunk(_horizontal_chunks)

In [None]:
# Save to Zarr (v2, with progress bar)

# Make sure the destination directory exists before writing.
_out_dir = OUT_ZARR.parent
_out_dir.mkdir(parents=True, exist_ok=True)

# mode="w" replaces any existing store at OUT_ZARR.
_zarr_write_kwargs = dict(mode="w", consolidated=True, zarr_version=2)
with ProgressBar():
    ds.to_zarr(OUT_ZARR, **_zarr_write_kwargs)

In [None]:
# QC Visualization
# Re-open what was just written and eyeball one map and one time series.

saved = xr.open_zarr(OUT_ZARR)
print("Saved variables:", list(saved.data_vars))
print("Dataset sizes:", saved.sizes)

# Map variable: first available of a short preference list, else the first variable.
plot_var = next((v for v in ["T2","PRECIP_TOT","SWDNB","WS10"] if v in saved.data_vars), list(saved.data_vars)[0])
ts_var = "T2" if "T2" in saved.data_vars else plot_var
# Mean over the two horizontal dimensions (cells outside the basin are NaN and skipped).
ts = saved[ts_var].mean(dim=[spatial_dims[0], spatial_dims[1]])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))

# Draw the field on its lon/lat coordinates so it shares axes with the
# boundary, which is in geographic degrees. Plotting on the default dimension
# axes would put the field and the outline in different coordinate systems.
saved[plot_var].isel(time=0).plot(ax=ax1, x=da_lon.name, y=da_lat.name)
gdf.boundary.plot(ax=ax1, color="k")

# Zoom to the basin with a small margin so it is visible even if the stored
# grid extends well beyond it.
_minx, _miny, _maxx, _maxy = gdf.total_bounds
_pad_deg = 0.25
ax1.set_xlim(_minx - _pad_deg, _maxx + _pad_deg)
ax1.set_ylim(_miny - _pad_deg, _maxy + _pad_deg)
ax1.set_title(f"{plot_var} (first timestep)")

ts.plot(ax=ax2)
ax2.set_title(f"Basin-mean {ts_var} over time")
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Debug aid: show (up to 50) variable names currently in `ds`, alphabetically,
# and which of the raw precipitation inputs are absent.
_var_names_sorted = sorted(ds.data_vars)
print(_var_names_sorted[:50])

missing = [name for name in ("RAINC", "RAINNC", "SNOWNC") if name not in ds]
print("Missing precip vars:", missing)

In [None]:
import s3fs, itertools, xarray as xr

# Exploration: look for raw wrfout NetCDF files on the OSN bucket and list the
# precipitation-like variables of one of them. This cell deliberately does NOT
# reuse the name `ds`, so the Skagit dataset built above is left intact.

endpoint = "https://usgs.osn.mghpcc.org"
fs = s3fs.S3FileSystem(anon=True, client_kwargs={"endpoint_url": endpoint})


def _is_wrfout_nc(path: str) -> bool:
    """Return True for paths ending in '.nc' that mention 'wrfout' (case-insensitive)."""
    lowered = path.lower()
    return lowered.endswith(".nc") and "wrfout" in lowered


# 1) Explore likely roots
roots = [
    "hytest/conus404",
    "hytest/CONUS404",
    "hytest/CONUS_404",
    "hytest/conus404/wrfout",
    "hytest/CONUS404/wrfout",
    "hytest/CONUS_404/wrfout",
]
existing = []
for r in roots:
    try:
        fs.ls(r)
    except Exception:
        # Best-effort probe: a root that cannot be listed is treated as absent.
        continue
    existing.append(r)

print("Existing roots:", existing or "(none)")

# 2) Try to *find* any wrfout NetCDFs by scanning a few depths
wrfout_paths = []
for r in existing or ["hytest"]:
    try:
        # limit depth to keep it fast; increase maxdepth if needed
        wrfout_paths.extend(p for p in fs.find(r, maxdepth=6) if _is_wrfout_nc(p))
    except Exception as exc:
        # Keep scanning the other roots, but say what was skipped and why.
        print(f"  (skipped {r}: {exc})")

print(f"Found {len(wrfout_paths)} wrfout .nc files (showing 10):")
for p in wrfout_paths[:10]:
    print(" -", p)

# 3) If none yet, try a broader search but with a safety cap
if not wrfout_paths:
    # Keep at most the first 20 matches.
    wrfout_paths = list(
        itertools.islice((p for p in fs.find("hytest", maxdepth=7) if _is_wrfout_nc(p)), 20)
    )
    print(f"\nBroader scan found {len(wrfout_paths)} candidates (showing 10):")
    for p in wrfout_paths[:10]:
        print(" -", p)

# 4) Open one file and list precip-like variables
if not wrfout_paths:
    raise SystemExit("No wrfout NetCDFs found under hytest (at searched depths). Try increasing maxdepth or different roots.")

sample = wrfout_paths[0]
print("\nOpening sample:", sample)

# Read through the already-configured anonymous filesystem so the endpoint and
# credentials settings are guaranteed to apply, and close both the remote file
# and the dataset once the variable names have been collected.
with fs.open(sample, "rb") as _remote_file:
    with xr.open_dataset(
        _remote_file,
        engine="h5netcdf",
        backend_kwargs={"phony_dims": "access"},
    ) as sample_ds:
        prec_vars = [v for v in sample_ds.data_vars if any(k in v.upper() for k in ["RAIN","SNOW","APCP","PREC"])]

print("Precip-like vars in sample:", prec_vars)

# If you see RAINNC/RAINC/SNOWNC, you're good.