In [None]:
import xarray as xr
import pandas as pd
import ibis
from psycopg import sql

from swed_17.zone_db import CBRFCZone
from swed_17.nb_helpers import start_cluster

# Opt in to xarray's new default keyword values for combine operations; this
# applies to the open_mfdataset call further down.
xr.set_options(use_new_combine_kwarg_defaults=True)

In [None]:
# Root of the CBRFC snow model data; the trailing slash is relied on by the
# derived paths below.
BASE_DIR = "/perc10/data/cbrfcSnowModel/"
BASIN_DIR = f"{BASE_DIR}basinSetup"

# Zarr archive locations, one per gridded SWE product.
SNOBAL_ARCHIVE = f"{BASE_DIR}isnobal/zarr_archive/"
UA_ARCHIVE = "/nvm9/data/swann/zarr_archive/"
SNODAS_ARCHIVE = "/nvm9/data/snodas/zarr_archive/"

In [None]:
# Connection parameters shared by both database handles.
# NOTE(review): "service=swe_db" presumably names a libpq service entry — confirm.
DB_CONNECTION = "service=swe_db"

# ibis gets the same parameters as the query part of a bare postgres:// URL.
SWE_DB = ibis.connect(f"postgres://?{DB_CONNECTION}")
zone_db = CBRFCZone(DB_CONNECTION)

In [None]:
# Model domain to process; matched against basin_name when selecting the
# zones to extract.
ISNOBAL_DOMAIN = 'colkrem'

## iSnobal

In [None]:
# ZARR_ARCHIVE = SNOBAL_ARCHIVE + f"latest/wy20*{ISNOBAL_DOMAIN}.zarr"
# VARIABLE = ["specific_mass"]
# DB_TABLE = "isnobal_zonal_swe"

## UArizona

In [None]:
# ZARR_ARCHIVE = UA_ARCHIVE + "*.zarr"
# VARIABLE = ["SWE", "cbrfc_zone_gid"]
# DB_TABLE = "ua_zonal_swe"

## SNODAS

In [None]:
# ZARR_ARCHIVE = SNODAS_ARCHIVE + "*.zarr"
# VARIABLE = ["SWE", "cbrfc_zone_gid"]
# DB_TABLE = "snodas_zonal_swe"

## CU Boulder

In [None]:
# DB_TABLE = "cu_boulder_zonal_swe"
# tables = [
#     "cu_boulder_wy2021",
#     "cu_boulder",
#     "cu_boulder_wy2025",
# ]

## ASO

In [None]:
# Target table for the zonal values, and the product tables that are read
# to fill it.
DB_TABLE = "aso_zonal_swe"
tables = ["aso_swe_13n"]

In [None]:
# Statement template; the two {} placeholders are filled through
# psycopg's sql.SQL(...).format() where the query is built.
SWE_FUNCTION = (
    "\n"
    "SELECT swe_date, swe FROM public.swe_from_product_for_zone({}, {});"
    "\n"
)

In [None]:
# ibis table handles; queries built from them are run with .execute() in the
# cells below.
# swe_table: the zonal table selected by DB_TABLE (checked for existing rows).
swe_table = SWE_DB.table(DB_TABLE)
# zones_in_isnobal: source of the zones belonging to the chosen model domain.
zones_in_isnobal = SWE_DB.table("cbrfc_zones_in_isnobal")

## Get zones covered in the model domain

In [None]:
# Rows of the zone table whose basin matches the selected domain, fetched
# into pandas and keyed by the zone gid.
in_domain = zones_in_isnobal.basin_name == ISNOBAL_DOMAIN
zone_info = (
    zones_in_isnobal
    .filter(in_domain)
    .execute()
    .set_index("gid")
)

In [None]:
# Preview the first zones selected for the domain.
zone_info.head()

In [None]:
# Number of zones selected for the domain.
len(zone_info)

## Extract statistics

### From products in Zarr archives

In [None]:
# Start the cluster used for the Zarr processing below: 36 workers with a
# memory_limit of "8GB".
# NOTE(review): local=False presumably selects a non-local deployment —
# confirm in swed_17.nb_helpers.start_cluster.
cluster = start_cluster(n_workers=36, memory_limit="8GB", local=False)

In [None]:
# Show which archive glob is about to be opened.
# NOTE(review): in the cells visible above every ZARR_ARCHIVE assignment is
# commented out, so this raises NameError unless one of the Zarr product
# cells (iSnobal, UArizona or SNODAS) is re-enabled first.
ZARR_ARCHIVE

In [None]:
def _select_variables(ds):
    """Reduce one opened store to the variables listed in VARIABLE."""
    return ds[VARIABLE]


# Open every store matching the archive glob as a single dataset, trimming
# each store to the requested variables before they are combined.
ds_swe = xr.open_mfdataset(
    ZARR_ARCHIVE,
    engine="zarr",
    parallel=True,
    preprocess=_select_variables,
)

In [None]:
# iSnobal
# ds_swe.coords["cbrfc_zone_gid"] = (('y', 'x'), erw_topo['cbrfc_zone'].values)

In [None]:
# Cells whose zone id belongs to one of the zones selected for the domain.
zone_ids = zone_info.index.values
mask = ds_swe.cbrfc_zone_gid.isin(zone_ids).compute()

# Load only the masked cells, then average them per zone over both
# horizontal dimensions.
swe_masked = ds_swe.where(mask, drop=True).compute()
swe_data = swe_masked.groupby(swe_masked["cbrfc_zone_gid"]).mean(["lat", "lon"])

In [None]:
# The zonal means were computed in the previous cell; the cluster is no
# longer needed.
cluster.shutdown()

In [None]:
# Inspect the per-zone means before they are written to the database.
swe_data

#### Insert into SWE DB

In [None]:
# Write the zonal mean time series of every selected zone into DB_TABLE.
# This cell applies to the Zarr-based products (iSnobal, UArizona, SNODAS)
# and expects swe_data and VARIABLE from the cells above.
for cid in zone_info.index.values:
    # Zones that already have rows in the target table are left untouched.
    current_rows = swe_table.filter(swe_table.cbrfc_zone_id == cid).count().execute()
    if current_rows > 0:
        continue

    # Build the frame for this zone from the zonal means. Without this
    # assignment the loop either fails with NameError or writes whatever
    # `df` a previously executed cell left behind.
    df = swe_data.sel(cbrfc_zone_gid=cid).to_dataframe().reset_index()
    # iSnobal
    # df["isnobal_version_id"] = 2
    # df['cbrfc_zone_id'] = cid
    # iSnobal, UArizona and SNODAS
    # Map the dataset names onto the column names used by the database table.
    df = df.rename(
        columns={
            'cbrfc_zone_gid': 'cbrfc_zone_id',
            VARIABLE[0]: 'value',
            'time': 'datetime',
        }
    )
    # Timestamps read from the archive are naive; mark them as UTC.
    df["datetime"] = df["datetime"].dt.tz_localize("utc")
    df['metric_type_id'] = 1

    zone_db.write(df, DB_TABLE)

### From database-stored products

In [None]:
def entry_exists(cbrfc_id):
    """Report whether the target table already holds rows for a zone.

    Counts the rows of ``swe_table`` whose ``cbrfc_zone_id`` equals
    ``cbrfc_id``. When at least one row is found a "Skipping" notice is
    printed.

    Parameters
    ----------
    cbrfc_id
        Zone id to look up.

    Returns
    -------
    bool
        True when rows exist for the zone, False otherwise.
    """
    zone_rows = swe_table.filter(swe_table.cbrfc_zone_id == cbrfc_id)

    if zone_rows.count().execute() <= 0:
        return False

    print(f"Skipping: {cbrfc_id}")
    return True

In [None]:
# Fill DB_TABLE from products that are stored in the database: for every
# selected zone, query each product table and write the returned series.
for cbrfc_id in zone_info.index.values:
    # Zones that already have rows in the target table are left untouched.
    if entry_exists(cbrfc_id):
        continue

    # Progress line: "<zone id>: <table> <table> ..."
    print(f"{cbrfc_id}: ", end='')
    row = zone_info.loc[cbrfc_id]

    for table in tables:
        # Both arguments are passed as SQL literals rather than spliced in
        # as raw text.
        zone_query = sql.SQL(SWE_FUNCTION).format(
            sql.Literal(table), sql.Literal(row.zone)
        )

        with zone_db.query(zone_query) as query_data:
            df = pd.DataFrame(query_data.fetchall(), columns=["datetime", "value"])

            if df.empty:
                print(f"No data for ID: {cbrfc_id}", end="")
                continue

            df = df.assign(
                datetime=pd.to_datetime(df["datetime"]).dt.tz_localize("UTC"),
                cbrfc_zone_id=cbrfc_id,
                metric_type_id=1,
                # CU Boulder and ASO are in meters
                value=df["value"] * 1000,
            )
            print(f"{table} ", end="")
            zone_db.write(df, DB_TABLE)

    # Terminate the progress line for this zone.
    print("")