In [1]:
#%% Setting Up
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import networkx as nx
import rioxarray as rxr

import geopandas as gpd
import seaborn as sns
import matplotlib.pyplot as plt

from shapely.geometry import Point
from shapely.geometry import Polygon

import glob
import os
import itertools
import tqdm
import gc
import time
import pickle

from joblib import Parallel, delayed

import rioxarray as rxr

import configparser
# Load the project defaults (INI file) and expose them as a plain nested dict:
# cfg[section][option] -> string value.
CONFIG_FILE = '/home/sarth/rootdir/datadir/assets/defaults.ini'
_parser = configparser.ConfigParser()
_parser.optionxform = str  # keep option names as written (the default lower-cases them)
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files it actually parsed; fail here rather than with an opaque KeyError below.
if not _parser.read(CONFIG_FILE):
    raise FileNotFoundError(f"Could not read configuration file: {CONFIG_FILE}")
cfg = {section: dict(_parser.items(section)) for section in _parser.sections()}
if 'PATHS' not in cfg:
    raise KeyError(f"Section 'PATHS' is missing from {CONFIG_FILE}")
PATHS = cfg['PATHS']

print("Setting up...")

Setting up...


In [2]:
#%% Region-Specific: CAMELS-IND
# Folder name of the dataset handled by this notebook.
DIRNAME = '03min_GloFAS_CAMELS-IND'
# Dataset root: DIRNAME under the 'devp_datasets' entry of PATHS. The cells below
# read the gauge table and per-gauge zarr stores from here and write the lumped
# stores back under it.
SAVE_PATH = os.path.join(PATHS['devp_datasets'], DIRNAME)
# NOTE(review): presumably the grid spacing in degrees (0.05 deg = 3 arc-minutes,
# in line with the '03min' prefix of DIRNAME) - confirm; unused in the visible cells.
resolution = 0.05
def lon_360_180(x):
    """Convert a longitude from the 0-360 convention to -180-180."""
    shifted = x + 180
    return shifted % 360 - 180


def lon_180_360(x):
    """Convert a longitude from the -180-180 convention to 0-360."""
    return x % 360
# Bounding box of the study region: min/max x and min/max y.
region_bounds = dict(minx=66, miny=5, maxx=100, maxy=30)

def _zero_padded(width):
    """Return a mapper that renders a value as text left-padded with zeros to ``width``."""
    def pad(value):
        return str(value).zfill(width)
    return pad


# Gauge attribute table; the first CSV column becomes the index.
graph_csv = os.path.join(SAVE_PATH, 'nested_gauges', 'graph_attributes_with_nesting.csv')
camels_graph = pd.read_csv(graph_csv, index_col=0)

# Render the identifiers as fixed-width, zero-padded strings: 5 characters for
# the index labels, 2 for 'huc_02'. Index.map keeps the index name, so
# reset_index() below turns it into a properly named column.
camels_graph.index = camels_graph.index.map(_zero_padded(5))
camels_graph['huc_02'] = camels_graph['huc_02'].map(_zero_padded(2))
# camels_graph = camels_graph[camels_graph['nesting'].isin(['not_nested', 'nested_downstream'])]
camels_graph = camels_graph.reset_index()

n_gauges = len(camels_graph)
print(f"Number of catmt's with nesting: {n_gauges}")
camels_graph

Number of catmt's with nesting: 191


Unnamed: 0,gauge_id,huc_02,gauge_lon,gauge_lat,ghi_area,cwc_lon,cwc_lat,cwc_area,cwc_site_name,ghi_stn_id,...,cwc_river,flow_availability,snapped_lon,snapped_lat,snapped_uparea,snapped_iou,area_percent_difference,num_nodes,num_edges,nesting
0,14015,14,73.11033,18.73540,125.7,73.1108,18.7367,125.0,Pen,wfrn_penxx,...,Bhogeswari,31.36,73.125,18.725,116.744995,0.650870,7.124106,4.0,3.0,not_nested
1,15006,15,74.88124,13.51876,299.6,74.8800,13.5214,253.0,Avershe,wfrs_avers,...,Seetha,39.00,74.925,13.475,329.444850,0.607418,9.961565,11.0,10.0,not_nested
2,05025,05,78.05617,11.93959,356.0,78.0572,11.9383,362.0,Thoppur,cauv_thopp,...,Cauvery/Thoppaiyar,31.52,78.125,11.975,331.356480,0.634196,6.922338,11.0,10.0,not_nested
3,15032,15,74.98123,13.29791,356.9,74.9806,13.2942,327.0,Yennehole,wfrs_yenne,...,Swarna,70.67,74.975,13.275,329.717100,0.755086,7.616389,11.0,10.0,not_nested
4,15007,15,76.84792,8.71458,555.8,76.8500,8.7150,540.0,Ayilam,wfrs_ayila,...,Vamanapuram,93.18,76.875,8.725,578.080800,0.773097,4.008784,19.0,18.0,not_nested
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
186,04038,04,75.13955,17.97291,22326.1,75.1392,17.9722,22856.0,Narsingpur,kris_narsi,...,Krishna/Bhima,95.30,75.125,17.975,22149.584000,0.953356,0.790624,758.0,757.0,nested_downstream
187,04057,04,76.96458,15.66043,23527.1,76.9628,15.6611,23500.0,T. Ramapuram (Seasonal),kris_trama,...,Krishna/Tungabhadra/Hagari,84.89,76.975,15.625,23463.662000,0.941489,0.269636,786.0,785.0,nested_downstream
188,03057,03,80.68958,19.29779,24041.7,80.6892,19.2989,24210.0,Mirdapalli,goda_mirda,...,Godavari/Pranhita/Indravati,56.17,80.675,19.275,23877.793000,0.940053,0.681758,821.0,820.0,nested_downstream
189,12004,12,79.01458,23.03126,26297.6,79.0156,23.0308,26453.0,Barmanghat,narm_barma,...,Narmada,91.85,79.025,23.025,26175.932000,0.943601,0.462658,921.0,920.0,nested_downstream


In [3]:
def idx_to_map(ds, var_name):
    """Scatter a per-node variable onto the dataset's lat/lon grid.

    A float32 grid spanning ``ds.lat`` x ``ds.lon`` is created filled with NaN.
    For every entry of ``ds.idx`` the value of ``ds[var_name]`` at that entry is
    written into the cell addressed by the entry's ``idx2lat`` / ``idx2lon``.

    Parameters
    ----------
    ds : dataset providing ``lat``, ``lon``, ``idx``, ``idx2lat``, ``idx2lon``
        and the variable ``var_name``.
    var_name : name of the variable to place on the grid.

    Returns
    -------
    xr.DataArray with dims ``('lat', 'lon')``; cells never written stay NaN.
    """
    lat_axis, lon_axis = ds.lat.values, ds.lon.values
    grid = xr.DataArray(
        np.full((len(lat_axis), len(lon_axis)), np.nan, dtype=np.float32),
        dims=['lat', 'lon'],
        coords={'lat': lat_axis, 'lon': lon_axis},
    )

    node_lat, node_lon, node_value = ds['idx2lat'], ds['idx2lon'], ds[var_name]
    for node in ds.idx.values:
        # Label-based assignment: the node's lat/lon must be present on the grid axes.
        cell = {
            'lat': node_lat.sel(idx=node).values,
            'lon': node_lon.sel(idx=node).values,
        }
        grid.loc[cell] = node_value.sel(idx=node).values
    return grid

# First and last calendar day of the period named by these constants.
START_DATE = pd.Timestamp(year=1998, month=1, day=1)
END_DATE = pd.Timestamp(year=2022, month=12, day=31)

In [4]:
def _area_weighted_mean(catmt, var_names):
    """Cell-area-weighted mean over ``idx`` of ``var_names``, variables sorted by name.

    NOTE(review): xarray's ``sum`` skips NaN by default, so a NaN cell adds
    nothing to the numerator while its area still counts in the denominator.
    If the inputs can contain NaN, ``catmt[var_names].weighted(area).mean('idx')``
    would renormalise per variable - confirm which behaviour is wanted.
    """
    cell_area = catmt['static_GloFAS_cellarea_km2']
    weighted = (catmt[var_names] * cell_area).sum(dim='idx') / cell_area.sum(dim='idx')
    return weighted[sorted(weighted.data_vars)]


def process_catmt_final(huc, gauge_id):
    """Collapse one gauge's per-node inventory into a lumped dataset and save it.

    Reads ``<SAVE_PATH>/inventory/<huc>/<gauge_id>.zarr`` and builds three parts:

    1. the cell-area-weighted mean over ``idx`` of every ``dynamic_*`` variable
       whose name contains neither ``'IndiaWRIS'`` nor ``'discharge'``, plus
       ``encoding_solar_insolation``;
    2. ``dynamic_GloFAS_discharge_mm``, ``static_uparea`` and the two
       ``encoding_transformed_*`` variables taken at ``idx=0`` (presumably the
       gauge/outlet node - confirm);
    3. the cell-area-weighted mean of all remaining variables except
       ``idx2lat``, ``idx2lon``, ``mask`` and names starting with
       ``static_ERA5_type``.

    The merged result overwrites
    ``<SAVE_PATH>/lumped_inventory/<huc>/<gauge_id>.zarr``.

    Parameters
    ----------
    huc : str
        Sub-folder name under ``inventory`` / ``lumped_inventory``.
    gauge_id : str
        Gauge identifier; the store is named ``<gauge_id>.zarr``.

    Returns
    -------
    None
    """
    warnings.filterwarnings('ignore')

    LUMPED_PATH = os.path.join(SAVE_PATH, 'lumped_inventory')
    # Variables taken from a single node instead of being averaged.
    outlet_vars = [
        'dynamic_GloFAS_discharge_mm',
        'static_uparea',
        'encoding_transformed_longitude',
        'encoding_transformed_latitude',
    ]

    # The context manager guarantees the source store is closed, also on errors.
    with xr.open_zarr(os.path.join(SAVE_PATH, 'inventory', huc, f'{gauge_id}.zarr')) as store:
        catmt = store[sorted(store.data_vars)]

        # Part 1: dynamic forcings (no IndiaWRIS / discharge series) + solar insolation.
        forcing_vars = [
            var for var in catmt.data_vars
            if var.startswith('dynamic_') and 'IndiaWRIS' not in var and 'discharge' not in var
        ]
        forcing_vars.append('encoding_solar_insolation')
        catmt_area_weighted = _area_weighted_mean(catmt, forcing_vars)

        # Everything not handled in part 1 and not a pure index helper.
        skipped = set(forcing_vars) | {'idx2lat', 'idx2lon', 'mask'}
        remaining_vars = [var for var in catmt.data_vars if var not in skipped]

        # Part 2: values at idx=0.
        catmt_lumped_sel = catmt[outlet_vars].sel(idx=0)

        # Part 3: area-weighted mean of the rest.
        # NOTE(review): 'static_ERA5_soil_type' does not start with
        # 'static_ERA5_type', so it is still averaged here (as it was before this
        # cleanup). If it is categorical it probably should be excluded as well - confirm.
        averaged_vars = [
            var for var in remaining_vars
            if var not in outlet_vars and not var.startswith('static_ERA5_type')
        ]
        catmt_area_weighted_2 = _area_weighted_mean(catmt, averaged_vars)

        catmt_lumped = xr.merge([
            catmt_area_weighted,
            catmt_lumped_sel,
            catmt_area_weighted_2
        ])

        # Save the lumped variables (mode='w' replaces an existing store).
        os.makedirs(os.path.join(LUMPED_PATH, huc), exist_ok=True)
        catmt_lumped.to_zarr(os.path.join(LUMPED_PATH, huc, f'{gauge_id}.zarr'), mode='w')

    # Clean up: drop every intermediate before forcing a collection in the worker.
    del catmt, catmt_area_weighted, catmt_area_weighted_2, catmt_lumped_sel, catmt_lumped
    gc.collect()

    return None

In [5]:
# Lump every gauge in camels_graph using 16 worker processes.
# The tqdm bar wraps the task generator, so it advances as tasks are handed to
# joblib, not as they finish; joblib's own "Done N tasks" lines report completion.
with Parallel(n_jobs=16, verbose=10) as parallel:
    parallel(
        delayed(process_catmt_final)(huc_code, gauge)
        for huc_code, gauge in tqdm.tqdm(
            zip(camels_graph['huc_02'], camels_graph['gauge_id']),
            total=len(camels_graph),
        )
    )

  0%|          | 0/191 [00:00<?, ?it/s][Parallel(n_jobs=16)]: Using backend LokyBackend with 16 concurrent workers.


 17%|█▋        | 32/191 [00:10<00:58,  2.73it/s][Parallel(n_jobs=16)]: Done   9 tasks      | elapsed:   12.4s
 25%|██▌       | 48/191 [00:27<01:38,  1.45it/s][Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:   31.0s
[Parallel(n_jobs=16)]: Done  29 tasks      | elapsed:   34.8s
 34%|███▎      | 64/191 [00:41<01:37,  1.30it/s][Parallel(n_jobs=16)]: Done  40 tasks      | elapsed:  1.0min
 42%|████▏     | 80/191 [01:05<01:54,  1.03s/it][Parallel(n_jobs=16)]: Done  53 tasks      | elapsed:  1.3min
 50%|█████     | 96/191 [01:34<02:02,  1.29s/it][Parallel(n_jobs=16)]: Done  66 tasks      | elapsed:  1.6min
[Parallel(n_jobs=16)]: Done  81 tasks      | elapsed:  2.2min
 59%|█████▊    | 112/191 [02:14<02:13,  1.69s/it][Parallel(n_jobs=16)]: Done  96 tasks      | elapsed:  3.0min
 67%|██████▋   | 128/191 [03:07<02:19,  2.21s/it][Parallel(n_jobs=16)]: Done 113 tasks      | elapsed:  4.2min
 84%|████████▍ | 160/191 [05:03<01:30,  2.93s/it][Parallel(n_jobs=16)]: Done 130 tasks      | elapsed:  

In [None]:
# def process_catmt(huc, gauge_id):
#     warnings.filterwarnings('ignore')
    
#     catmt = xr.open_zarr(os.path.join(SAVE_PATH, 'inventory', huc, f'{gauge_id}.zarr'))
#     catmt = catmt[sorted(catmt.data_vars)]

#     LUMPED_PATH = os.path.join(SAVE_PATH, 'lumped_inventory')

#     var_names = [var for var in catmt.data_vars if (var.startswith('dynamic_') and (not 'USGS' in var) and (not 'discharge' in var))]
#     var_names.extend(['encoding_solar_insolation'])
#     catmt_area_weighted = (catmt[var_names] * catmt['static_GloFAS_cellarea_km2']).sum(dim='idx') / catmt['static_GloFAS_cellarea_km2'].sum(dim='idx')
#     catmt_area_weighted = catmt_area_weighted[sorted(catmt_area_weighted.data_vars)]

#     # Save the area-weighted variables
#     os.makedirs(os.path.join(LUMPED_PATH, huc), exist_ok=True)
#     catmt_area_weighted.to_zarr(os.path.join(LUMPED_PATH, huc, f'{gauge_id}.zarr'), mode='w')

#     # Clean up
#     del catmt, catmt_area_weighted
#     gc.collect()

#     return None

In [None]:
# with Parallel(n_jobs=8, verbose=10) as parallel:
#     parallel(
#         delayed(process_catmt)(row['huc_02'], row['gauge_id']) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph))
#     )

In [None]:
# def process_catmt(huc, gauge_id):
#     warnings.filterwarnings('ignore')
    
#     LUMPED_PATH = os.path.join(SAVE_PATH, 'lumped_inventory')

#     catmt = xr.open_zarr(os.path.join(SAVE_PATH, 'inventory', huc, f'{gauge_id}.zarr'))
#     catmt = catmt[sorted(catmt.data_vars)]

#     catmt_lumped = xr.open_zarr(os.path.join(LUMPED_PATH, huc, f'{gauge_id}.zarr'))

#     var_names = [var for var in catmt.data_vars if (var.startswith('dynamic_') and (not 'USGS' in var) and (not 'discharge' in var))]
#     var_names.extend(['encoding_solar_insolation', 'idx2lat', 'idx2lon', 'mask'])
#     var_names = [var for var in catmt.data_vars if not var in var_names]

#     # For .sel(idx=0)
#     catmt_lumped_sel = catmt[var_names].copy()
#     catmt_lumped_sel = catmt_lumped_sel[['dynamic_GloFAS_discharge_mm', 'static_uparea']]
#     catmt_lumped_sel = catmt_lumped_sel.sel(idx=0)

#     # For .area_weighted
#     # Drop ['dynamic_GloFAS_discharge_mm', 'static_uparea'] from var_names
#     var_names = [var for var in var_names if not (var in ['dynamic_GloFAS_discharge_mm', 'static_uparea'])]
#     var_names = [var for var in var_names if not var.startswith('static_ERA5_type') or var.startswith('static_ERA5_soil_type')]
#     catmt_area_weighted = (catmt[var_names] * catmt['static_GloFAS_cellarea_km2']).sum(dim='idx') / catmt['static_GloFAS_cellarea_km2'].sum(dim='idx')
#     catmt_area_weighted = catmt_area_weighted[sorted(catmt_area_weighted.data_vars)]

#     catmt_lumped = xr.open_zarr(os.path.join(SAVE_PATH, 'lumped_inventory', huc, f'{gauge_id}.zarr'))
#     catmt_lumped = xr.merge([catmt_lumped, catmt_lumped_sel, catmt_area_weighted])
#     catmt_lumped[sorted(catmt_lumped.data_vars)]

#     # Replace the value for 'encoding_transformed_longitude', 'encoding_transformed_latitude'
#     catmt_lumped['encoding_transformed_longitude'] = catmt['encoding_transformed_longitude'].sel(idx=0)
#     catmt_lumped['encoding_transformed_latitude'] = catmt['encoding_transformed_latitude'].sel(idx=0)

#     # Save the area-weighted variables
#     os.makedirs(os.path.join(LUMPED_PATH, huc), exist_ok=True)
#     catmt_area_weighted.to_zarr(os.path.join(LUMPED_PATH, huc, f'{gauge_id}.zarr'), mode='a')

#     # Clean up
#     del catmt, catmt_area_weighted
#     gc.collect()

#     return None

In [None]:
# with Parallel(n_jobs=8, verbose=10) as parallel:
#     parallel(
#         delayed(process_catmt)(row['huc_02'], row['gauge_id']) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph))
#     )

In [4]:
# Inspect the lumped store of the first gauge to list the variables it holds.
row = camels_graph.iloc[0]
huc, gauge_id = row['huc_02'], row['gauge_id']
# The context manager closes the store even if reading the variable list fails.
with xr.open_zarr(os.path.join(SAVE_PATH, 'lumped_inventory', huc, f'{gauge_id}.zarr')) as catmt:
    data_vars = sorted(catmt.data_vars)  # sorted() already returns a list
    # print(f"Length of idx vars: {len(catmt.idx.values)}")
print(f"Length of data_vars: {len(data_vars)}")
data_vars

Length of data_vars: 140


['dynamic_ERA5-Land_dewpoint_temperature_2m_max',
 'dynamic_ERA5-Land_dewpoint_temperature_2m_min',
 'dynamic_ERA5-Land_leaf_area_index_high_vegetation',
 'dynamic_ERA5-Land_leaf_area_index_low_vegetation',
 'dynamic_ERA5-Land_potential_evaporation_sum',
 'dynamic_ERA5-Land_runoff_sum',
 'dynamic_ERA5-Land_snow_cover',
 'dynamic_ERA5-Land_snow_depth',
 'dynamic_ERA5-Land_snowfall_sum',
 'dynamic_ERA5-Land_snowmelt_sum',
 'dynamic_ERA5-Land_sub_surface_runoff_sum',
 'dynamic_ERA5-Land_surface_net_solar_radiation_sum',
 'dynamic_ERA5-Land_surface_net_thermal_radiation_sum',
 'dynamic_ERA5-Land_surface_pressure',
 'dynamic_ERA5-Land_surface_runoff_sum',
 'dynamic_ERA5-Land_temperature_2m_max',
 'dynamic_ERA5-Land_temperature_2m_min',
 'dynamic_ERA5-Land_total_evaporation_sum',
 'dynamic_ERA5-Land_total_precipitation_sum',
 'dynamic_ERA5-Land_u_component_of_wind_10m',
 'dynamic_ERA5-Land_v_component_of_wind_10m',
 'dynamic_ERA5-Land_volumetric_soil_water_layer_1',
 'dynamic_ERA5-Land_volum