In [1]:
#%% Setting Up
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import networkx as nx
import rioxarray as rxr

import geopandas as gpd
import seaborn as sns
import matplotlib.pyplot as plt

from shapely.geometry import Point
from shapely.geometry import Polygon

import glob
import os
import itertools
import tqdm
import gc
import time
import pickle

from joblib import Parallel, delayed

import rioxarray as rxr

import configparser
# Project-wide settings (e.g. the PATHS section) come from an INI file. The location can be
# overridden with the DEFAULTS_INI environment variable so the notebook runs on other machines.
DEFAULTS_INI = os.environ.get('DEFAULTS_INI', '/home/sarth/rootdir/datadir/assets/defaults.ini')
cfg = configparser.ConfigParser()
cfg.optionxform = str  # keep option names case-sensitive (ConfigParser lower-cases them by default)
if not cfg.read(DEFAULTS_INI):
    # ConfigParser.read silently skips files it cannot open and returns an empty list;
    # fail here with a clear message instead of a KeyError on cfg['PATHS'] below.
    raise FileNotFoundError(f"Config file not found or unreadable: {DEFAULTS_INI}")
cfg = {s: dict(cfg.items(s)) for s in cfg.sections()}
PATHS = cfg['PATHS']

print("Setting up...")

Setting up...


In [2]:
#%% Region-Specific: CAMELS-IND
DIRNAME = '03min_GloFAS_CAMELS-IND'
SAVE_PATH = os.path.join(PATHS['devp_datasets'], DIRNAME)
resolution = 0.05  # grid spacing in degrees (0.05 deg = 3 arc-minutes, as in DIRNAME)


def lon_360_180(x):
    """Map longitudes from the 0..360 convention to -180..180."""
    return (x + 180) % 360 - 180


def lon_180_360(x):
    """Map longitudes from the -180..180 convention to 0..360."""
    return x % 360


# Region bounding box in degrees (x = longitude, y = latitude).
region_bounds = dict(minx=66, miny=5, maxx=100, maxy=30)

nesting_csv = os.path.join(SAVE_PATH, 'nested_gauges', 'graph_attributes_with_nesting.csv')
camels_graph = (
    pd.read_csv(nesting_csv, index_col=0)
    # Re-pad IDs to a fixed width: leading zeros are lost if pandas parses them as integers.
    .pipe(lambda df: df.set_axis(df.index.map(lambda gid: str(gid).zfill(5)), axis=0))
    .assign(huc_02=lambda df: df['huc_02'].map(lambda code: str(code).zfill(2)))
    # .pipe(lambda df: df[df['nesting'].isin(['not_nested', 'nested_downstream'])])
    .reset_index()
)
print(f"Number of catmt's with nesting: {len(camels_graph)}")
camels_graph

Number of catmt's with nesting: 191


Unnamed: 0,gauge_id,huc_02,gauge_lon,gauge_lat,ghi_area,cwc_lon,cwc_lat,cwc_area,cwc_site_name,ghi_stn_id,...,cwc_river,flow_availability,snapped_lon,snapped_lat,snapped_uparea,snapped_iou,area_percent_difference,num_nodes,num_edges,nesting
0,14015,14,73.11033,18.73540,125.7,73.1108,18.7367,125.0,Pen,wfrn_penxx,...,Bhogeswari,31.36,73.125,18.725,116.744995,0.650870,7.124106,4.0,3.0,not_nested
1,15006,15,74.88124,13.51876,299.6,74.8800,13.5214,253.0,Avershe,wfrs_avers,...,Seetha,39.00,74.925,13.475,329.444850,0.607418,9.961565,11.0,10.0,not_nested
2,05025,05,78.05617,11.93959,356.0,78.0572,11.9383,362.0,Thoppur,cauv_thopp,...,Cauvery/Thoppaiyar,31.52,78.125,11.975,331.356480,0.634196,6.922338,11.0,10.0,not_nested
3,15032,15,74.98123,13.29791,356.9,74.9806,13.2942,327.0,Yennehole,wfrs_yenne,...,Swarna,70.67,74.975,13.275,329.717100,0.755086,7.616389,11.0,10.0,not_nested
4,15007,15,76.84792,8.71458,555.8,76.8500,8.7150,540.0,Ayilam,wfrs_ayila,...,Vamanapuram,93.18,76.875,8.725,578.080800,0.773097,4.008784,19.0,18.0,not_nested
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
186,04038,04,75.13955,17.97291,22326.1,75.1392,17.9722,22856.0,Narsingpur,kris_narsi,...,Krishna/Bhima,95.30,75.125,17.975,22149.584000,0.953356,0.790624,758.0,757.0,nested_downstream
187,04057,04,76.96458,15.66043,23527.1,76.9628,15.6611,23500.0,T. Ramapuram (Seasonal),kris_trama,...,Krishna/Tungabhadra/Hagari,84.89,76.975,15.625,23463.662000,0.941489,0.269636,786.0,785.0,nested_downstream
188,03057,03,80.68958,19.29779,24041.7,80.6892,19.2989,24210.0,Mirdapalli,goda_mirda,...,Godavari/Pranhita/Indravati,56.17,80.675,19.275,23877.793000,0.940053,0.681758,821.0,820.0,nested_downstream
189,12004,12,79.01458,23.03126,26297.6,79.0156,23.0308,26453.0,Barmanghat,narm_barma,...,Narmada,91.85,79.025,23.025,26175.932000,0.943601,0.462658,921.0,920.0,nested_downstream


In [3]:
def idx_to_map(ds, var_name):
    """Scatter a per-cell variable indexed by ``idx`` back onto a 2-D lat/lon grid.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset with ``lat``/``lon`` coordinates, an ``idx`` dimension, and the
        lookup variables ``idx2lat`` / ``idx2lon`` giving each idx's grid cell.
    var_name : str
        Name of a variable in ``ds`` holding one value per ``idx``.

    Returns
    -------
    xr.DataArray
        float32 array with dims ('lat', 'lon'); cells not referenced by any idx are NaN.
        If several idx map to the same cell, the one appearing last wins.

    Raises
    ------
    KeyError
        If an idx points at a lat/lon value that is not on the grid.
    """
    lats = ds.lat.values
    lons = ds.lon.values
    idx_vals = ds.idx.values

    # Fetch every idx at once instead of one .sel() call per element.
    cell_lats = ds['idx2lat'].sel(idx=idx_vals).values
    cell_lons = ds['idx2lon'].sel(idx=idx_vals).values
    cell_vals = ds[var_name].sel(idx=idx_vals).values

    # Exact-match label -> position lookup (same matching rule as a scalar .loc); -1 = not found.
    rows = pd.Index(lats).get_indexer(cell_lats)
    cols = pd.Index(lons).get_indexer(cell_lons)
    missing = (rows < 0) | (cols < 0)
    if missing.any():
        raise KeyError(
            f"{int(missing.sum())} idx entries map to lat/lon values that are not on the grid"
        )

    grid = np.full((len(lats), len(lons)), np.nan, dtype=np.float32)
    grid[rows, cols] = cell_vals
    return xr.DataArray(grid, dims=['lat', 'lon'], coords={'lat': lats, 'lon': lons})

# Date range bounds (inclusive): 1998-01-01 through 2022-12-31.
START_DATE = pd.Timestamp('1998-01-01')
END_DATE = pd.Timestamp('2022-12-31')

# Helper Function: aggregate one catchment's daily inventory to monthly

In [5]:
def _prefixed(prefix, names):
    """Return a list with ``prefix`` prepended to every name in ``names``."""
    return [f"{prefix}{name}" for name in names]


# Daily variables aggregated to monthly SUMS (accumulated quantities: precipitation,
# evaporation, radiation, runoff, discharge depth).
MONTHLY_SUM_VARS = tuple(
    _prefixed('dynamic_ERA5_', [
        'total_precipitation',
        'surface_net_solar_radiation',
        'surface_net_thermal_radiation',
        'evaporation',
        'potential_evaporation',
        'runoff',
        'surface_runoff',
        'sub_surface_runoff',
    ])
    + _prefixed('dynamic_ERA5-Land_', [
        'total_precipitation_sum',
        'total_evaporation_sum',
        'potential_evaporation_sum',
        'surface_net_solar_radiation_sum',
        'surface_net_thermal_radiation_sum',
        'snowfall_sum',
        'snowmelt_sum',
        'runoff_sum',
        'surface_runoff_sum',
        'sub_surface_runoff_sum',
    ])
    + _prefixed('dynamic_GLEAM4_', ['Ep'])
    + _prefixed('dynamic_GPM_', ['Early_Run', 'Late_Run', 'Final_Run'])
    + _prefixed('dynamic_GloFAS_', ['discharge_mm', 'runoff_water_equivalent'])
    # IndiaWRIS discharge is stored without the 'dynamic_' prefix.
    + ['outlet_IndiaWRIS_Q_mm']
)

# Daily variables aggregated to monthly MEANS (states, temperatures, winds, encodings).
# The ERA5-Land *_min/*_max series are daily extremes; their monthly mean is the mean
# daily min/max.
MONTHLY_MEAN_VARS = tuple(
    _prefixed('dynamic_ERA5_', [
        '2m_temperature',
        'surface_pressure',
        '2m_dewpoint_temperature',
        '10m_u_component_of_wind',
        '10m_v_component_of_wind',
        # NOTE(review): ERA5 snowfall/snowmelt are averaged here while the ERA5-Land
        # *_sum equivalents are summed -- confirm this matches how the daily fields were stored.
        'snowfall',
        'snow_depth',
        'snowmelt',
        'volumetric_soil_water_layer_1',
        'volumetric_soil_water_layer_2',
        'volumetric_soil_water_layer_3',
        'volumetric_soil_water_layer_4',
    ])
    + _prefixed('dynamic_ERA5-Land_', [
        'temperature_2m_min',
        'temperature_2m_max',
        'surface_pressure',
        'u_component_of_wind_10m',
        'v_component_of_wind_10m',
        'snow_depth',
        'snow_cover',
        'dewpoint_temperature_2m_min',
        'dewpoint_temperature_2m_max',
        'leaf_area_index_high_vegetation',
        'leaf_area_index_low_vegetation',
        'volumetric_soil_water_layer_1',
        'volumetric_soil_water_layer_2',
        'volumetric_soil_water_layer_3',
        'volumetric_soil_water_layer_4',
    ])
    + _prefixed('dynamic_GLEAM4_', ['SMs', 'SMrz'])
    + _prefixed('dynamic_GloFAS_', ['snow_depth_water_equivalent', 'soil_wetness_index'])
    # Encoding features carry no source prefix.
    + [
        'encoding_solar_insolation',
        'encoding_sine_dayofyear',
        'encoding_sine_weekofyear',
        'encoding_sine_month',
    ]
)


def process_catmt(huc, gauge_id):
    """Aggregate one catchment's daily inventory to monthly resolution and save it.

    Reads ``<SAVE_PATH>/inventory/<huc>/<gauge_id>.zarr`` and writes (overwriting)
    ``<SAVE_PATH>/monthly_inventory/<huc>/<gauge_id>.zarr``:

    * variables in ``MONTHLY_SUM_VARS`` are summed per calendar month; a month with
      no valid values becomes NaN rather than 0,
    * variables in ``MONTHLY_MEAN_VARS`` are averaged per calendar month,
    * remaining variables without a ``time`` dimension are copied unchanged;
      ``dynamic_HRES*`` variables and any other unlisted time-dependent variable are
      left out.

    Monthly timestamps are month starts ('1MS'); output variables are sorted by name.

    Parameters
    ----------
    huc : str
        Two-digit basin code used as the sub-directory name, e.g. '05'.
    gauge_id : str
        Zero-padded gauge identifier used as the zarr store name, e.g. '05025'.

    Raises
    ------
    KeyError
        If any listed variable is missing from the input store.
    """
    src_path = os.path.join(SAVE_PATH, 'inventory', huc, f'{gauge_id}.zarr')
    out_dir = os.path.join(SAVE_PATH, 'monthly_inventory', huc)

    sum_vars = list(MONTHLY_SUM_VARS)
    mean_vars = list(MONTHLY_MEAN_VARS)
    aggregated = set(sum_vars) | set(mean_vars)

    # The context manager closes the store even if aggregation or writing fails.
    with xr.open_zarr(src_path) as catmt:
        # min_count=1: an all-NaN month (e.g. no gauge observations) sums to NaN instead
        # of 0, which would otherwise be indistinguishable from a genuine zero total.
        # TODO(review): partially observed months still yield a partial (biased-low) sum.
        catmt_sum = catmt[sum_vars].resample(time='1MS').sum(dim='time', min_count=1)
        catmt_mean = catmt[mean_vars].resample(time='1MS').mean(dim='time')

        # Only time-independent variables can be merged with the monthly aggregates: a
        # leftover daily variable would outer-join its daily time axis into the output.
        static_vars = [
            var for var in catmt.data_vars
            if var not in aggregated
            and not var.startswith('dynamic_HRES')
            and 'time' not in catmt[var].dims
        ]
        catmt_static = catmt[static_vars]

        catmt_monthly = xr.merge([catmt_sum, catmt_mean, catmt_static])
        catmt_monthly = catmt_monthly[sorted(catmt_monthly.data_vars)]

        # Write inside the `with`: the aggregates may still be lazy views on the source store.
        os.makedirs(out_dir, exist_ok=True)
        catmt_monthly.to_zarr(os.path.join(out_dir, f'{gauge_id}.zarr'), mode='w', consolidated=True)

    # Release memory promptly; this runs once per catchment inside long parallel jobs.
    del catmt, catmt_sum, catmt_mean, catmt_static, catmt_monthly
    gc.collect()

In [None]:
# camels_graph = camels_graph.sort_values(by=['area_geospa_fabric'], ascending=True)

In [6]:
# Smoke test: process only the first catchment before launching the full parallel run.
# head(1) (rather than islice over a 191-item tqdm iterator) keeps the progress bar accurate.
for _, row in tqdm.tqdm(camels_graph.head(1).iterrows(), total=1):
    huc, gauge_id = row['huc_02'], row['gauge_id']
    process_catmt(huc, gauge_id)

  0%|          | 0/191 [03:11<?, ?it/s]


In [7]:
# Process every catchment in parallel. joblib's verbose log reports finished tasks; the
# tqdm bar advances as tasks are dispatched to workers, not as they finish.
catmt_keys = zip(camels_graph['huc_02'], camels_graph['gauge_id'])
with Parallel(n_jobs=16, verbose=10) as parallel:
    parallel(
        delayed(process_catmt)(huc, gauge_id)
        for huc, gauge_id in tqdm.tqdm(catmt_keys, total=len(camels_graph))
    )

[Parallel(n_jobs=16)]: Done 395 out of 395 | elapsed: 128.2min finished


In [None]:
# huc = '05'
# for gauge_id in ['03140000', '03384450', '03280700', '03078000', '03187500', '03021350', '03010655', '03186500']:
#     print(f"Processing HUC: {huc}, Gauge ID: {gauge_id}")
#     process_catmt(huc, gauge_id)

Processing HUC: 05, Gauge ID: 03140000
Processing HUC: 05, Gauge ID: 03384450
Processing HUC: 05, Gauge ID: 03280700
Processing HUC: 05, Gauge ID: 03078000
Processing HUC: 05, Gauge ID: 03187500
Processing HUC: 05, Gauge ID: 03021350
Processing HUC: 05, Gauge ID: 03010655
Processing HUC: 05, Gauge ID: 03186500


# Development