In [None]:
import sqlite3
import os
import glob
import pandas as pd
from dataretrieval import nwis
import geopandas as gpd
import itertools
from itertools import repeat
from tqdm import tqdm
import warnings


In [None]:
# ---------- Run configuration ----------

# Hydrologic unit code to process, and the boundary attribute that holds it.
huc_code, huc_feature = '14', 'huc2'

# Number of demand-table rows reshaped per slice when streaming inserts.
ROW_CHUNK = 1_000
# Variable id used for demand values
# (presumably the VARIABLE_ID of the DEMAND row -- confirm against VARIABLES).
VAR_ID = 5

# Silence every warning for the remainder of the session.
warnings.filterwarnings(action='ignore')

In [None]:
# functions

# Function to insert data into POI_TYPE table
def insert_poi_type(poi_type_name, poi_type_source):
    """Add a (name, source) row to POI_TYPE unless an identical row exists.

    Works through the module-level ``cursor``. The outcome is reported on
    stdout; any ``sqlite3.Error`` is printed instead of being raised.
    """
    params = (poi_type_name, poi_type_source)
    try:
        cursor.execute('''
            SELECT POI_TYPE_ID FROM POI_TYPE WHERE POI_TYPE_NAME = ? AND POI_TYPE_SOURCE = ?
        ''', params)
        existing = cursor.fetchone()
        if existing is not None:
            print(f"{poi_type_name} already exists in POI_TYPE")
        else:
            cursor.execute('''
                INSERT INTO POI_TYPE (POI_TYPE_NAME, POI_TYPE_SOURCE)
                VALUES (?, ?)
            ''', params)
            print(f"Inserted {poi_type_name} into POI_TYPE")
    except sqlite3.Error as err:
        print(f"An error occurred: {err}")

# Function to insert data into VARIABLES table
def insert_variable(variable_name, unit):
    """Add a (name, unit) row to VARIABLES unless an identical row exists.

    Works through the module-level ``cursor``. The outcome is reported on
    stdout; any ``sqlite3.Error`` is printed instead of being raised.
    """
    params = (variable_name, unit)
    try:
        cursor.execute('''
            SELECT VARIABLE_ID FROM VARIABLES WHERE VARIABLE_NAME = ? AND UNIT = ?
        ''', params)
        existing = cursor.fetchone()
        if existing is not None:
            print(f"{variable_name} already exists in VARIABLES")
        else:
            cursor.execute('''
                INSERT INTO VARIABLES (VARIABLE_NAME, UNIT)
                VALUES (?, ?)
            ''', params)
            print(f"Inserted {variable_name} into VARIABLES")
    except sqlite3.Error as err:
        print(f"An error occurred: {err}")
        
        
# ---------- FAST LOOKUP + BULK INSERT HELPERS ----------
def standardize_datetime(dt):
    """
    Return a standardized ISO-8601 string for a date/datetime-like input.
    - If only a date is present -> 'YYYY-MM-DD'
    - If time is present       -> 'YYYY-MM-DDTHH:MM:SS'

    Returns None for None, float NaN, and any value pandas cannot parse.
    Timezone-aware inputs are converted to UTC and emitted without an offset.
    """
    if dt is None or (isinstance(dt, float) and pd.isna(dt)):
        return None

    # `infer_datetime_format` is intentionally not passed: it is deprecated
    # and has no effect on a scalar parse. Passing a keyword pandas rejects
    # would raise inside this try and silently turn every value into None.
    try:
        ts = pd.to_datetime(dt, errors='coerce', utc=False)
    except Exception:
        # Best effort: anything pandas refuses to parse is treated as missing.
        ts = pd.NaT

    if pd.isna(ts):
        return None

    # tz_convert(None) shifts a tz-aware timestamp to UTC and drops the tz.
    if getattr(ts, 'tzinfo', None) is not None:
        ts = ts.tz_convert(None)

    # Emit a bare date only when the *source text* carried no time component
    # and the parsed value sits exactly on midnight.
    src = str(dt)
    has_time_in_src = ('T' in src) or (':' in src) or (len(src.strip()) > 10)
    at_midnight = ts.hour == 0 and ts.minute == 0 and ts.second == 0

    if not has_time_in_src and at_midnight:
        return ts.date().isoformat()  # 'YYYY-MM-DD'
    return ts.strftime('%Y-%m-%dT%H:%M:%S')

def rows_for_site_id(wdid: str):
    """
    Look up the rows of the module-level ``wade_df`` whose index label is
    ``wdid`` and always hand them back as a DataFrame.

    An unknown label gives an empty DataFrame with ``wade_df``'s columns;
    a single hit (which pandas returns as a Series) is turned into a
    one-row DataFrame.
    """
    try:
        matched = wade_df.loc[wdid]
    except KeyError:
        matched = pd.DataFrame(columns=wade_df.columns)
    else:
        if isinstance(matched, pd.Series):
            # Transpose the column-shaped frame so the hit becomes one row.
            matched = matched.to_frame().T
    return matched

def bulk_insert_pod_waterrights(rows, chunk=50000):
    """
    Insert many POD_WATER_RIGHTS rows through the module-level ``cursor``.

    rows:  8-tuples (POIID, SiteName, WaterRightID, Allocation_CFS,
           Allocation_Date, Use_Type, Water_Source, Source_ID)
    chunk: maximum number of rows handed to one ``executemany`` call
    """
    if not rows:
        return
    sql = '''
        INSERT INTO POD_WATER_RIGHTS
            (POI_ID, SITE_NAME, WATER_RIGHT_ID, ALLOCATION_CFS, ALLOCATION_DATE,
             USE_TYPE, WATER_SOURCE, SOURCE_ID)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    '''
    # Only the allocation date is rewritten; every other field passes through.
    prepared = [
        (poi, name, right, flow, standardize_datetime(when), use, source, source_id)
        for poi, name, right, flow, when, use, source, source_id in rows
    ]

    total = len(prepared)
    offset = 0
    while offset < total:
        upper = min(offset + chunk, total)
        cursor.executemany(sql, prepared[offset:upper])
        offset = upper

def bulk_insert_poi_values(rows, chunk=2000000):
    """
    Insert many POI_VALUES rows through the module-level ``cursor``.

    rows:  4-tuples (DataValue, LocalDateTime, POIID, VariableID)
    chunk: maximum number of rows handed to one ``executemany`` call
    """
    if not rows:
        return
    sql = '''
        INSERT INTO POI_VALUES (DATA_VALUE, LOCAL_DATE_TIME, POI_ID, VARIABLE_ID)
        VALUES (?, ?, ?, ?)
    '''
    # Standardize the timestamp of each row before it reaches the database.
    prepared = []
    for value, stamp, poi, variable in rows:
        prepared.append((value, standardize_datetime(stamp), poi, variable))

    offset = 0
    total = len(prepared)
    while offset < total:
        upper = min(offset + chunk, total)
        cursor.executemany(sql, prepared[offset:upper])
        offset = upper

def load_all_wade(root_dir: str) -> pd.DataFrame:
    """
    Load and combine every WaDE export found under ``root_dir``.

    Every folder holding both ``Sites.csv`` and ``WaterAllocations.csv`` is
    read. Column names are matched case-insensitively against the known
    spellings, allocations that list several comma-separated site UUIDs are
    exploded to one row per site, and allocations are inner-joined to sites
    on the site UUID.

    A folder is skipped (with a printed notice) when either file lacks a
    site-UUID column or Sites.csv lacks a site-native-id column. All other
    columns are optional: when absent they are filled with missing values
    (NaN for numeric columns, None otherwise), and a missing SiteName falls
    back to the site UUID.

    Parameters
    ----------
    root_dir : str
        Directory searched recursively for ``Sites.csv`` files.

    Returns
    -------
    pd.DataFrame
        Columns SiteUuid, SiteNativeId, SiteName, AllocationNativeID,
        AllocationFlow_CFS, AllocationPriorityDate, BeneficialUseCategory,
        Latitude and Longitude, indexed by SiteNativeId (as string, column
        kept). Empty, with the same columns and index, when nothing loads.
    """
    out_cols = [
        "SiteUuid", "SiteNativeId", "SiteName",
        "AllocationNativeID", "AllocationFlow_CFS",
        "AllocationPriorityDate", "BeneficialUseCategory",
        "Latitude", "Longitude",
    ]
    nan = float("nan")

    def pick(colmap, candidates):
        """Return the column in `candidates` matching the first name of `colmap` (case-insensitive), else None."""
        by_lower = {str(c).lower(): c for c in candidates}
        for name in colmap:
            found = by_lower.get(name.lower())
            if found is not None:
                return found
        return None

    def column_or(frame, col, default):
        """Return frame[col], or the scalar `default` when the column was not found."""
        return frame[col] if col is not None else default

    parts = []
    # Sorted so the row order of the result does not depend on filesystem order.
    site_files = sorted(glob.glob(os.path.join(root_dir, "**", "Sites.csv"), recursive=True))

    for sites_path in site_files:
        folder = os.path.dirname(sites_path)
        alloc_path = os.path.join(folder, "WaterAllocations.csv")
        if not os.path.exists(alloc_path):
            continue

        sites_df = pd.read_csv(sites_path)
        alloc_df = pd.read_csv(alloc_path)

        # ----- resolve keys (required) -----
        sites_siteuuid = pick(["SiteUuid", "SiteUUID"], sites_df.columns)
        alloc_siteuuid = pick(["SiteUuid", "SiteUUID"], alloc_df.columns)
        sites_sitenativeid = pick(["SiteNativeId", "SiteNativeID"], sites_df.columns)

        if sites_siteuuid is None or alloc_siteuuid is None or sites_sitenativeid is None:
            # Without the join key or the index key the folder cannot be used.
            print(f"Skipping WaDE folder {folder}: required SiteUuid/SiteNativeId column not found.")
            continue

        # ----- resolve optional columns -----
        sites_sitename   = pick(["SiteName"], sites_df.columns)
        sites_lat        = pick(["Latitude", "Latitude_DD"], sites_df.columns)
        sites_lon        = pick(["Longitude", "Longitude_DD"], sites_df.columns)

        alloc_nativeid   = pick(["AllocationNativeID", "AllocationNativeId"], alloc_df.columns)
        alloc_flow_cfs   = pick(["AllocationFlow_CFS", "AllocationFlow_Cfs"], alloc_df.columns)
        alloc_priority   = pick(["AllocationPriorityDate", "PriorityDate"], alloc_df.columns)
        alloc_beneficial = pick(["BeneficialUseCategory", "BeneficialUseCategoryCV"], alloc_df.columns)

        # ----- explode multi-site UUIDs ("a, b" -> one row per UUID) -----
        alloc_df[alloc_siteuuid] = alloc_df[alloc_siteuuid].astype(str).str.split(",")
        # Index reset so the exploded rows have unique labels for the
        # column-wise construction below.
        alloc_df = alloc_df.explode(alloc_siteuuid).reset_index(drop=True)
        alloc_df[alloc_siteuuid] = alloc_df[alloc_siteuuid].str.strip()

        # ----- normalize to canonical column names -----
        sites_norm = pd.DataFrame({
            "SiteUuid": sites_df[sites_siteuuid],
            "SiteNativeId": sites_df[sites_sitenativeid],
            # Fall back to the UUID as a display name, without losing the
            # SiteUuid column that the merge below depends on.
            "SiteName": sites_df[sites_sitename if sites_sitename is not None else sites_siteuuid],
            "Latitude": column_or(sites_df, sites_lat, nan),
            "Longitude": column_or(sites_df, sites_lon, nan),
        })

        # Placeholders are NaN/None rather than pd.NA so the values can be
        # bound as SQL NULL by sqlite3 later on.
        alloc_norm = pd.DataFrame({
            "SiteUuid": alloc_df[alloc_siteuuid],
            "AllocationNativeID": column_or(alloc_df, alloc_nativeid, None),
            "AllocationPriorityDate": column_or(alloc_df, alloc_priority, None),
            "AllocationFlow_CFS": column_or(alloc_df, alloc_flow_cfs, nan),
            "BeneficialUseCategory": column_or(alloc_df, alloc_beneficial, None),
        })

        # ----- merge -----
        parts.append(alloc_norm.merge(sites_norm, on="SiteUuid", how="inner"))

    if not parts:
        # Same index layout as the populated result, so lookups by
        # SiteNativeId behave the same way on an empty table.
        return pd.DataFrame(columns=out_cols).set_index("SiteNativeId", drop=False)

    wade_all = pd.concat(parts, ignore_index=True)
    wade_all["SiteNativeId"] = wade_all["SiteNativeId"].astype(str)
    wade_all.set_index("SiteNativeId", inplace=True, drop=False)
    return wade_all


def insert_poi(poiid, poi_type_id, poi_lat, poi_lon, poi_native_id, poi_flow_com_id):
    """Insert one row into POI through the module-level ``cursor``.

    A ``sqlite3.Error`` is printed rather than raised.
    """
    record = (poiid, poi_type_id, poi_lat, poi_lon, poi_native_id, poi_flow_com_id)
    statement = '''
            INSERT INTO POI (POI_ID, POI_TYPE_ID, POI_LAT, POI_LON, POI_NATIVE_ID, POI_FLOW_COMID)
            VALUES (?, ?, ?, ?, ?, ?)
            '''
    try:
        cursor.execute(statement, record)
    except sqlite3.Error as err:
        print(f"An error occurred: {err}")

def insert_poi_values(dataval, localtime, poiid, variableid):
    """Insert one row into POI_VALUES through the module-level ``cursor``.

    ``localtime`` is passed through ``standardize_datetime`` first.
    A ``sqlite3.Error`` is printed rather than raised.
    """
    statement = '''
            INSERT INTO POI_VALUES (DATA_VALUE, LOCAL_DATE_TIME, POI_ID, VARIABLE_ID)
            VALUES (?, ?, ?, ?)
            '''
    try:
        record = (dataval, standardize_datetime(localtime), poiid, variableid)
        cursor.execute(statement, record)
    except sqlite3.Error as err:
        print(f"An error occurred: {err}")

def poiid_exists(poiid):
    """Return True when the POI table already holds a row with this POI_ID."""
    found = cursor.execute('SELECT 1 FROM POI WHERE POI_ID = ?', (poiid,)).fetchone()
    return found is not None

def get_site_coordinates(site_number):
    """Fetch (latitude, longitude) of a site via ``nwis.get_info``.

    The first row of the ``dec_lat_va`` / ``dec_long_va`` columns is used.
    On any failure the error is printed and ``(None, None)`` is returned.
    """
    try:
        info = nwis.get_info(sites=site_number)[0]
        coords = tuple(info[column].iloc[0] for column in ('dec_lat_va', 'dec_long_va'))
    except Exception as err:
        print(f"Error retrieving coordinates for site {site_number}: {err}")
        coords = (None, None)
    return coords

def insert_pod_waterrights(poiid, site_name, right_id, allocation_cfs, allocation_date,
                           use_type, water_source, source_id):
    """Insert one row into POD_WATER_RIGHTS through the module-level ``cursor``.

    ``allocation_date`` is passed through ``standardize_datetime`` first.
    A ``sqlite3.Error`` is printed rather than raised.
    """
    statement = '''
            INSERT INTO POD_WATER_RIGHTS
                (POI_ID, SITE_NAME, WATER_RIGHT_ID, ALLOCATION_CFS, ALLOCATION_DATE,
                 USE_TYPE, WATER_SOURCE, SOURCE_ID)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            '''
    try:
        record = (
            poiid, site_name, right_id, allocation_cfs,
            standardize_datetime(allocation_date),
            use_type, water_source, source_id,
        )
        cursor.execute(statement, record)
    except sqlite3.Error as err:
        print(f"An error occurred: {err}")
        
def yield_payload():
    """
    Stream POI_VALUES rows out of the wide demand table.

    Reads the module-level ``demands_df``, ``valid_wdid_cols``, ``date_std``,
    ``wdid_to_poi``, ``ROW_CHUNK`` and ``VAR_ID``. The table is processed
    ``ROW_CHUNK`` rows at a time: each slice is melted from wide (one column
    per WDID) to long form, rows without a value are dropped, and the WDID is
    mapped to its POI id.

    Yields
    ------
    tuple
        (DataValue, date string from ``date_std``, POI_ID, VAR_ID)
    """
    total_rows = len(demands_df)
    for start in range(0, total_rows, ROW_CHUNK):
        end = min(start + ROW_CHUNK, total_rows)
        sub_vals = demands_df.iloc[start:end][valid_wdid_cols]
        sub_date = date_std.iloc[start:end]

        # wide -> long just for this slice
        sub_long = (
            sub_vals.assign(__Date=sub_date.values)
                    .melt(id_vars='__Date', var_name='WDID', value_name='DataValue')
                    .dropna(subset=['DataValue'])
        )
        if sub_long.empty:
            continue

        # map to POI ids
        sub_long['POI_ID'] = sub_long['WDID'].map(wdid_to_poi)

        # .tolist() yields native Python scalars; sqlite3 has no adapter for
        # numpy scalar types such as numpy.int64. The lists are bounded by
        # the slice size, so the extra memory is small.
        yield from zip(
            sub_long['DataValue'].tolist(),
            sub_long['__Date'].tolist(),      # already standardized strings
            sub_long['POI_ID'].tolist(),
            repeat(VAR_ID),                   # configured id, not a literal
        )

def bulk_insert_poi_values_fast(rows_iter, chunk=200000):
    """
    Drain ``rows_iter`` into POI_VALUES through the module-level ``cursor``.

    Rows are 4-tuples (DATA_VALUE, LOCAL_DATE_TIME, POI_ID, VARIABLE_ID) and
    are written as given. They are buffered and flushed with ``executemany``
    each time the buffer holds at least ``chunk`` rows, then once more for
    any remainder.
    """
    sql = '''
        INSERT INTO POI_VALUES (DATA_VALUE, LOCAL_DATE_TIME, POI_ID, VARIABLE_ID)
        VALUES (?, ?, ?, ?)
    '''
    pending = []
    for record in rows_iter:
        pending.append(record)
        if len(pending) < chunk:
            continue
        cursor.executemany(sql, pending)
        pending = []
    if pending:
        cursor.executemany(sql, pending)

In [None]:
# Directories

db_path = f'data/output/optional_db_{huc_code}.db'

# extended hydrofabric geopackage (can be HUC4-specific)
# NOTE(review): the file name hard-codes "14" rather than using huc_code --
# confirm whether it should follow the configured HUC.
extended_hydrofabric = 'data/output/enhanced_reference_14.gpkg'

# ResOpsUS reservoir attributes + folder with timeseries
csv_file_path = 'data/reservoirs/updated_reservoir_attributes.csv'
res_folder_path = 'data/reservoirs/time_series_all/'

# Demand time series for diversions
demands_df = pd.read_csv('data/USGS (Lopez) Demands/diversion_records_wadeID.csv')
demands_df.columns = demands_df.columns.map(str)   # <-- force all column names to str

# Optional aggregated-diversion table.
# NOTE(review): 'data/*' is a placeholder pattern, not a file path, so
# os.path.exists() is normally False for it; point agg_csv at the actual CSV
# to enable aggregated diversions.
agg_csv = 'data/*'
if os.path.exists(agg_csv):
    aggregated_diversions_df = pd.read_csv(agg_csv)
    aggregated_diversions_df['Aggregation ID'] = aggregated_diversions_df['Aggregation ID'].astype(str).str.strip()
else:
    # Always define the name so later cells can test for the table.
    aggregated_diversions_df = None
    print(" aggregated_table.csv not found. Continuing without aggregated diversion data.")
# Alias so code that refers to the table as `aggregated_df` gets the same object.
aggregated_df = aggregated_diversions_df

# Root folder containing per-state WaDE folders; each folder has Sites.csv + WaterAllocations.csv
# Example layout:
# data/WaDE_all_states/
#   Colorado_WaDE/Sites.csv
#   Colorado_WaDE/WaterAllocations.csv
#   Utah_WaDE/Sites.csv
#   Utah_WaDE/WaterAllocations.csv
WADE_ROOT = "data/wade data"

df = pd.read_csv(csv_file_path)

pod_layer = gpd.read_file(extended_hydrofabric, layer='DIVERSION_POINTS')
gage_layer = gpd.read_file(extended_hydrofabric, layer='event')
res_layer  = gpd.read_file(extended_hydrofabric, layer='RESERVOIR_POINTS')

# Boundary clipping for gages: keep only the configured HUC polygon(s), in WGS84
boundary = gpd.read_file('data/wbd/WBDHU2.shp')
huc_wbd = boundary[boundary[huc_feature] == huc_code].to_crs(epsg=4326)

gage_layer = gage_layer.to_crs(huc_wbd.crs)
# NOTE(review): newer GeoPandas releases deprecate `unary_union` in favor of
# `union_all()` -- confirm against the installed version before switching.
events_selected = gage_layer[gage_layer.geometry.within(huc_wbd.unary_union)]
gage_filtered = events_selected.loc[events_selected['hl_reference'] == 'type_gages']
USGS_gage_ids = gage_filtered['hl_link'].tolist()

# IDs for reservoirs & diversions
dam_name_ids = res_layer['NID_ID'].tolist()

# Filter DataFrame for existing dams in attributes table
filtered_df = df[df['NID_ID'].isin(dam_name_ids)]

In [None]:
# Blank database schema creation

# One CREATE TABLE statement per table, listed in creation order.
schema_statements = (
    # POI_TYPE: lookup of point-of-interest categories
    '''
CREATE TABLE IF NOT EXISTS POI_TYPE (
    POI_TYPE_ID INTEGER PRIMARY KEY,
    POI_TYPE_NAME TEXT,
    POI_TYPE_SOURCE TEXT
)
''',
    # POI: one row per point of interest
    '''
CREATE TABLE IF NOT EXISTS POI (
    POI_ID TEXT PRIMARY KEY,
    POI_TYPE_ID INTEGER,
    POI_LAT REAL,
    POI_LON REAL,
    POI_NATIVE_ID TEXT,
    POI_FLOW_COMID INTEGER,
    FOREIGN KEY (POI_TYPE_ID) REFERENCES POI_TYPE (POI_TYPE_ID)
)
''',
    # VARIABLES: lookup of variable names and units
    '''
CREATE TABLE IF NOT EXISTS VARIABLES (
    VARIABLE_ID INTEGER PRIMARY KEY,
    VARIABLE_NAME TEXT,
    UNIT TEXT
)
''',
    # POI_VALUES: time-series values per POI and variable
    '''
CREATE TABLE IF NOT EXISTS POI_VALUES (
    VALUE_ID INTEGER PRIMARY KEY AUTOINCREMENT,
    DATA_VALUE REAL,
    LOCAL_DATE_TIME TEXT,
    POI_ID TEXT,
    VARIABLE_ID INTEGER,
    FOREIGN KEY (POI_ID) REFERENCES POI (POI_ID),
    FOREIGN KEY (VARIABLE_ID) REFERENCES VARIABLES (VARIABLE_ID)
)
''',
    # ET_PRECIP
    '''
CREATE TABLE IF NOT EXISTS ET_PRECIP (
    ET_PRECIP_ID INTEGER PRIMARY KEY,
    ET_DA_VALUE REAL,
    PRECIP_VALUE REAL,
    LOCAL_DATE_TIME TEXT,
    POI_ID TEXT,
    VARIABLE_ID INTEGER,
    FOREIGN KEY (POI_ID) REFERENCES POI (POI_ID),
    FOREIGN KEY (VARIABLE_ID) REFERENCES VARIABLES (VARIABLE_ID)
)
''',
    # RULE_CURVES
    '''
CREATE TABLE IF NOT EXISTS RULE_CURVES (
    RULE_CURVE_ID INTEGER PRIMARY KEY,
    MIN_RELEASE_VALUE REAL,
    TARGET_STORAGE_VALUE REAL,
    LOCAL_DATE_TIME TEXT,
    POI_ID TEXT,
    VARIABLE_ID INTEGER,
    FOREIGN KEY (POI_ID) REFERENCES POI (POI_ID),
    FOREIGN KEY (VARIABLE_ID) REFERENCES VARIABLES (VARIABLE_ID)
)
''',
    # POD_WATER_RIGHTS
    '''
CREATE TABLE IF NOT EXISTS POD_WATER_RIGHTS (
    POD_WATER_RIGHTS_ID INTEGER PRIMARY KEY,
    POI_ID TEXT,
    SITE_NAME TEXT,
    WATER_RIGHT_ID TEXT,
    ALLOCATION_CFS REAL,
    ALLOCATION_DATE TEXT,
    USE_TYPE TEXT,
    WATER_SOURCE TEXT,
    SOURCE_ID REAL,
    FOREIGN KEY (POI_ID) REFERENCES POI (POI_ID)
)
''',
)

conn = sqlite3.connect(db_path)
cursor = conn.cursor()

# IF NOT EXISTS makes each statement a no-op when the table is already there.
for statement in schema_statements:
    cursor.execute(statement)

# Persist the schema, then release the connection.
conn.commit()
conn.close()

In [None]:
# fixed variable inserts
conn = sqlite3.connect(db_path)
cursor = conn.cursor()

# (name, source) rows for POI_TYPE
poi_types = [
    ('USGS_GAGE', 'USGS'),
    # ('STATE_GAGE', 'CDSS'),
    ('POD', 'WaDE'),
    ('RESERVOIR', 'ResOpsUS')
]

for type_name, type_source in poi_types:
    insert_poi_type(type_name, type_source)

# (name, unit) rows for VARIABLES
variables = [
    ('INFLOW', 'CMS'),
    ('OUTFLOW', 'CMS'),
    ('STORAGE', 'MCM'),
    ('GAGE_FLOW', 'CFS'),
    ('DEMAND', 'CM')
]

for variable_name, variable_unit in variables:
    insert_variable(variable_name, variable_unit)

# Persist both lookup tables and release the connection.
conn.commit()
conn.close()

In [None]:

# --------------------------
# LOAD *ALL STATES* WADE (STATE-AGNOSTIC)
# --------------------------
# The result is kept in the module-level name `wade_df`, which
# rows_for_site_id() looks rows up in by index label.
wade_df = load_all_wade(root_dir=WADE_ROOT)

print("Wade loaded")

In [None]:
# PROCESS RESERVOIRS (ResOpsUS)

# SQLite DB target
conn = sqlite3.connect(db_path)
cursor = conn.cursor()

try:
    cursor.execute("PRAGMA journal_mode = WAL;")
    cursor.execute("PRAGMA synchronous = NORMAL;")
    cursor.execute("PRAGMA temp_store = MEMORY;")
    cursor.execute("PRAGMA cache_size = -200000;")

    # Time-series columns to load. Their position (starting at 1) is used as
    # the VARIABLE_ID, which matches the order INFLOW/OUTFLOW/STORAGE are
    # listed for insertion -- assumes ids were assigned on a fresh database.
    reservoir_variables = ['inflow', 'outflow', 'storage']

    for _, row in tqdm(filtered_df.iterrows(), total=len(filtered_df), desc="Processing Reservoir Data"):
        poiid = f"{row['DAM_ID']}_{row['NID_ID']}"

        # Already loaded in an earlier run: skip the POI and its values.
        if poiid_exists(poiid):
            continue

        # Find the Source_comid for this reservoir. .tolist() yields native
        # Python scalars; sqlite3 has no adapter for numpy scalar types.
        comid_vals = res_layer.loc[res_layer['NID_ID'] == row['NID_ID'], 'SOURCE_COMID'].tolist()
        poi_flow_com_id = comid_vals[0] if comid_vals else None

        insert_poi(
            poiid=poiid,
            poi_type_id=3,  # RESERVOIR is the third POI type listed for insertion
            poi_lat=row['LATITUDE'],
            poi_lon=row['LONGITUDE'],
            poi_native_id=row['NID_ID'],
            poi_flow_com_id=poi_flow_com_id
        )

        file_path = os.path.join(res_folder_path, f"ResOpsUS_{row['DAM_ID']}.csv")
        if not os.path.exists(file_path):
            continue

        dam_df = pd.read_csv(file_path)
        for var_id, var_name in enumerate(reservoir_variables, start=1):
            if var_name not in dam_df.columns:
                continue
            # One executemany per variable instead of one INSERT per row;
            # bulk_insert_poi_values standardizes the dates on the way in.
            value_rows = list(zip(
                dam_df[var_name].tolist(),
                dam_df['date'].tolist(),
                repeat(poiid),
                repeat(var_id),
            ))
            try:
                bulk_insert_poi_values(value_rows)
            except sqlite3.Error as e:
                # Same best-effort reporting as the single-row insert helpers.
                print(f"An error occurred: {e}")

    conn.commit()
finally:
    # Release the database even when a reservoir fails part-way through.
    cursor.close()
    conn.close()

In [None]:
# --------------------------
# PROCESS USGS GAGES
# --------------------------

def pick_00060_mean_column(df):
    """
    Return the name of the daily-mean discharge (parameter 00060) column.

    Matching on "mean" is case-insensitive. Columns ending in "_cd" are
    ignored (presumably qualifier-code columns -- confirm against the NWIS
    frame). The unsuffixed '00060_Mean' is preferred; otherwise the first
    candidate in sorted order (e.g. '00060_2_Mean') is used.

    Raises KeyError when no candidate column exists.
    """
    candidates = [
        c for c in df.columns
        if c.startswith('00060') and 'mean' in c.lower() and not c.lower().endswith('_cd')
    ]
    if not candidates:
        raise KeyError("No 00060 mean column found.")
    for c in candidates:
        if c.lower() == '00060_mean':
            return c
    return sorted(candidates)[0]


# SQLite DB target
conn = sqlite3.connect(db_path)
cursor = conn.cursor()

try:
    cursor.execute("PRAGMA journal_mode = WAL;")
    cursor.execute("PRAGMA synchronous = NORMAL;")
    cursor.execute("PRAGMA temp_store = MEMORY;")
    cursor.execute("PRAGMA cache_size = -200000;")

    for gage_id in tqdm(USGS_gage_ids, total=len(USGS_gage_ids), desc="Processing USGS Gage Data"):
        poiid = f"USGS_{gage_id}"

        # Already loaded in an earlier run: skip the download entirely.
        if poiid_exists(poiid):
            continue

        try:
            nwis_data = nwis.get_dv(sites=gage_id, parameterCd='00060', statCd='00003', startDT='1880-01-01')[0]
            latitude, longitude = get_site_coordinates(gage_id)
        except Exception as e:
            # Keep going with the remaining gages, but say why this one was dropped.
            print(f"Skipping gage {gage_id}: could not retrieve NWIS data ({e})")
            continue

        if nwis_data.empty or latitude is None or longitude is None:
            print(f"No data found or incomplete data for gage ID: {gage_id}")
            continue

        # Resolve the flow column before writing anything, so a gage without
        # one does not leave a POI row that has no values.
        try:
            flow_col = pick_00060_mean_column(nwis_data)
        except KeyError as e:
            print(f"Skipping gage {gage_id}: {e}")
            continue

        # .tolist() yields native Python scalars; sqlite3 has no adapter for
        # numpy scalar types.
        comid_vals = gage_layer.loc[gage_layer['hl_link'] == gage_id, 'hy_id'].tolist()
        poi_flow_com_id = comid_vals[0] if comid_vals else None

        insert_poi(
            poiid=poiid,
            poi_type_id=1,  # USGS_GAGE is the first POI type listed for insertion
            poi_lat=latitude,
            poi_lon=longitude,
            poi_native_id=gage_id,
            poi_flow_com_id=poi_flow_com_id
        )

        # One executemany for the whole record instead of one INSERT per day;
        # bulk_insert_poi_values standardizes the timestamps on the way in.
        value_rows = list(zip(
            nwis_data[flow_col].tolist(),
            [stamp.isoformat() for stamp in nwis_data.index],
            repeat(poiid),
            repeat(4),  # GAGE_FLOW is the fourth variable listed for insertion
        ))
        try:
            bulk_insert_poi_values(value_rows)
        except sqlite3.Error as e:
            # Same best-effort reporting as the single-row insert helpers.
            print(f"An error occurred: {e}")

    conn.commit()
finally:
    # Release the database even when a gage fails part-way through.
    cursor.close()
    conn.close()

In [None]:
# --------------------------
# PROCESS PODs and Water Rights
# --------------------------

def _none_if_missing(value):
    """Map NaN / pd.NA / NaT to None so sqlite3 stores the field as NULL."""
    try:
        return None if pd.isna(value) else value
    except (TypeError, ValueError):
        # Not a scalar that pd.isna can judge; pass it through unchanged.
        return value


def _rights_rows(poi_id, site_id, water_source, source_gnis_id):
    """Build POD_WATER_RIGHTS tuples for every WaDE row matching `site_id`."""
    wade_fields = (
        'SiteName', 'AllocationNativeID', 'AllocationFlow_CFS',
        'AllocationPriorityDate', 'BeneficialUseCategory',
    )
    return [
        (
            poi_id,
            *(_none_if_missing(r[field]) for field in wade_fields),
            _none_if_missing(water_source),
            _none_if_missing(source_gnis_id),
        )
        for _, r in rows_for_site_id(site_id).iterrows()
    ]


# SQLite DB target
conn = sqlite3.connect(db_path)
cursor = conn.cursor()

try:
    cursor.execute("PRAGMA journal_mode = WAL;")
    cursor.execute("PRAGMA synchronous = NORMAL;")
    cursor.execute("PRAGMA temp_store = MEMORY;")
    cursor.execute("PRAGMA cache_size = -200000;")

    # WDID -> POI id for every diversion point (also read by the demands cell).
    wdid_to_poi = {str(row['WDID']): f"DIV_{str(row['WDID'])}" for _, row in pod_layer.iterrows()}

    # Checked once rather than per row. The aggregated table only exists
    # when its CSV was found at load time.
    has_aggregated = os.path.exists(agg_csv)
    if has_aggregated:
        aggregation_ids = aggregated_diversions_df['Aggregation ID'].astype(str)

    for _, diversion_row in tqdm(pod_layer.iterrows(), total=len(pod_layer), desc="Processing Diversion Points and Water Rights"):
        wdid = str(diversion_row['WDID'])
        poi_id = f"DIV_{wdid}"
        diversion_type = str(diversion_row['TYPE']).strip()
        water_source = diversion_row['WATER_SOURCE']
        source_gnis_id = diversion_row['SOURCE_GNIS_ID']

        # Create POI; a duplicate key is reported by insert_poi and skipped.
        insert_poi(
            poiid=poi_id,
            poi_type_id=2,  # POD is the second POI type listed for insertion
            poi_lat=diversion_row['LATITUDE'],
            poi_lon=diversion_row['LONGITUDE'],
            poi_native_id=diversion_row['POI_NATIVE_ID'],
            poi_flow_com_id=diversion_row['SOURCE_COMID']
        )

        # Which WaDE site ids feed this POD: the member WDIDs of an aggregated
        # diversion when the aggregation table is available, else the POD's
        # own WDID (physical diversions and every other type).
        if has_aggregated and diversion_type == 'Aggregated Diversion':
            members = aggregated_diversions_df.loc[aggregation_ids == wdid, 'WDID']
            site_ids = [str(member) for member in members.tolist()]
        else:
            site_ids = [wdid]

        payload = []  # collect rights rows for this POI, then insert in bulk
        for site_id in site_ids:
            payload.extend(_rights_rows(poi_id, site_id, water_source, source_gnis_id))

        # single fast insert for this POD
        bulk_insert_pod_waterrights(payload)

    conn.commit()
finally:
    # Release the database even when a diversion fails part-way through.
    cursor.close()
    conn.close()

In [None]:
# --------------------------
# Processing Demands
# --------------------------

conn = sqlite3.connect(db_path)
cursor = conn.cursor()
try:
    # Trade durability for speed during the bulk load.
    cursor.execute("PRAGMA foreign_keys = OFF;")
    cursor.execute("PRAGMA synchronous = OFF;")
    cursor.execute("PRAGMA journal_mode = MEMORY;")
    cursor.execute("PRAGMA temp_store = MEMORY;")
    cursor.execute("PRAGMA locking_mode = EXCLUSIVE;")
    cursor.execute("BEGIN IMMEDIATE;")

    try:
        # Demand columns that correspond to a known diversion point.
        valid_wdid_cols = [c for c in demands_df.columns if c != 'Date' and c in wdid_to_poi]
        if valid_wdid_cols:

            dt = pd.to_datetime(demands_df['Date'], errors='coerce', utc=False)
            # A row keeps the full date-time form when the parsed value is
            # off midnight or the original text contained ':' or 'T';
            # otherwise it is written as a bare date.
            has_clock_time = (
                dt.dt.hour.ne(0) | dt.dt.minute.ne(0) | dt.dt.second.ne(0)
                | demands_df['Date'].astype(str).str.contains(':|T')
            )
            date_std = dt.dt.strftime('%Y-%m-%dT%H:%M:%S').where(
                has_clock_time,
                dt.dt.date.astype(str)
            )

            bulk_insert_poi_values_fast(yield_payload(), chunk=200000)

            print("Finished bulk demand insert.")
    except Exception:
        # Undo the partial load instead of leaving the transaction open.
        conn.rollback()
        raise
    else:
        cursor.execute("COMMIT;")

    # Restore safer settings now that the transaction has ended
    cursor.execute("PRAGMA foreign_keys = ON;")
    cursor.execute("PRAGMA synchronous = NORMAL;")
    cursor.execute("PRAGMA journal_mode = WAL;")
    cursor.execute("PRAGMA locking_mode = NORMAL;")
finally:
    # --------------------------
    # CLOSE
    # --------------------------
    cursor.close()
    conn.close()