# Setting Up

In [1]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import networkx as nx

import geopandas as gpd
import seaborn as sns
import matplotlib.pyplot as plt

from shapely.geometry import Point
from shapely.geometry import Polygon

import glob
import os
import itertools
import tqdm
import gc
import time
import pickle

from joblib import Parallel, delayed

In [2]:
import configparser

# Project-wide configuration file; its [PATHS] section supplies every data directory used below.
config_path = '/home/sarth/rootdir/assets/global.ini'

cfg = configparser.ConfigParser()
cfg.optionxform = str  # keep option names exactly as written in the file
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it actually parsed, so an empty result means the config was not found.
if not cfg.read(config_path):
    raise FileNotFoundError(f"Could not read configuration file: {config_path}")
# Flatten the parser into a plain {section: {option: value}} dictionary.
cfg = {s: dict(cfg.items(s)) for s in cfg.sections()}
PATHS = cfg['PATHS']

In [3]:
# Run-level settings shared by every cell below.
DIRNAME = '30min_CWatM_India'  # sub-directory of PATHS['devp_datasets'] holding this dataset
basin_name = 'mahanadi'  # used in shapefile, graph-file and regridder-weight paths
SAVE_PATH = os.path.join(PATHS['devp_datasets'], DIRNAME)
resolution = 0.50  # node grid spacing; presumably degrees (30 arc-min) - TODO confirm
def lon_360_180(x):
    """Wrap longitude(s) given on a 0-360 scale onto the -180-180 scale."""
    return (x + 180) % 360 - 180


def lon_180_360(x):
    """Wrap longitude(s) given on a -180-180 scale onto the 0-360 scale."""
    return x % 360

In [4]:
# Read the basin outline and derive a bounding box padded by `buffer` on every side.
basins = gpd.read_file(
    os.path.join(PATHS['gis_shapefiles'], 'asia', basin_name, 'shapefile.shp'),
    crs = 'epsg:4326',
)
buffer = 0.5
minx, miny, maxx, maxy = basins.total_bounds
region_bounds = dict(
    minx = minx - buffer,
    miny = miny - buffer,
    maxx = maxx + buffer,
    maxy = maxy + buffer,
)

# Load Watershed Attributes

In [5]:
# Per-gauge attribute table for this basin (first CSV column becomes the index),
# displayed for inspection.
indiawris_attributes_graph = pd.read_csv(os.path.join(SAVE_PATH, f'graph_attributes_{basin_name}.csv'), index_col = 0)
indiawris_attributes_graph

Unnamed: 0,River Point Name,Latitude,Longitude,NaNs,Days with Observations,Zero of Gauge,MERIT_Longitude,MERIT_Latitude,Gauge_MERIT_Distance,uparea (in m2),uparea (in km2),snapped_lon,snapped_lat,snapped_uparea,snapped_iou,area_percent_difference,num_nodes,num_edges
0,andhiyar_khore,21.833889,81.5975,5.3,14184,252.0,81.5975,21.834167,30.887,2166368000.0,2166.368402,81.25,22.25,2853.6025,0.512368,31.722864,1.0,0.0
1,bamnidhi,21.908486,82.71364,4.45,14311,223.0,82.71,21.91,411.522,9811223000.0,9811.223484,82.75,22.25,11373.951,0.63943,15.927957,4.0,3.0
2,baronda,20.91,81.888056,6.64,13983,283.0,81.888333,20.91,28.853,3193768000.0,3193.768107,82.25,20.25,2891.6934,0.362836,9.458254,1.0,0.0
3,basantpur,21.7385,82.785942,4.64,14282,206.0,82.778333,21.7225,1944.963,58984890000.0,58984.885999,82.75,21.75,65959.01,0.798565,11.823573,23.0,22.0
4,ghatora,22.048592,82.221956,5.35,14176,246.0,82.2225,22.0475,133.739,3037212000.0,3037.21172,82.25,22.25,2853.6025,0.377117,6.045319,1.0,0.0
5,jondhra,21.712492,82.333106,5.66,14130,219.0,82.3325,21.715,285.855,29799010000.0,29799.010658,81.75,21.75,22964.73,0.615867,22.934586,8.0,7.0
6,kantamal,20.658333,83.732069,5.28,14186,118.0,83.73,20.660833,351.621,20265830000.0,20265.829359,83.25,20.25,14485.411,0.501618,28.52298,5.0,4.0
7,kesinga,20.285831,83.221333,5.58,14142,166.0,83.249167,20.2625,3893.397,11977870000.0,11977.8703,83.25,20.25,14485.411,0.570078,20.934782,5.0,4.0
8,kotni,21.24,81.250278,4.57,14293,268.0,81.250833,21.236667,375.095,7036693000.0,7036.69319,81.25,21.25,11511.149,0.550663,63.587482,4.0,3.0
9,kurubhata,21.981278,83.2106,12.18,13153,215.0,83.211667,21.980833,120.579,4752530000.0,4752.530476,83.25,22.25,5697.1436,0.5993,19.876007,2.0,1.0


In [6]:
# Two-stage gauge selection, printing the table shape after each stage:
#   1. 'area_percent_difference' below 15
#   2. 'num_nodes' above 5
indiawris_graph = indiawris_attributes_graph.copy()
area_ok = indiawris_graph['area_percent_difference'] < 15
indiawris_graph = indiawris_graph.loc[area_ok]
print(indiawris_graph.shape)
enough_nodes = indiawris_graph['num_nodes'] > 5
indiawris_graph = indiawris_graph.loc[enough_nodes]
print(indiawris_graph.shape)

(6, 18)
(2, 18)


In [7]:
# Summary statistics of the 'uparea (in km2)' column for the gauges kept by the filter above.
indiawris_graph['uparea (in km2)'].describe()

count        2.000000
mean     53456.624942
std       7818.141764
min      47928.363884
25%      50692.494413
50%      53456.624942
75%      56220.755470
max      58984.885999
Name: uparea (in km2), dtype: float64

In [8]:
# Renumber the surviving gauges 0..n-1 (the old index is discarded) and display them.
indiawris_graph = indiawris_graph.reset_index(drop = True)
indiawris_graph

Unnamed: 0,River Point Name,Latitude,Longitude,NaNs,Days with Observations,Zero of Gauge,MERIT_Longitude,MERIT_Latitude,Gauge_MERIT_Distance,uparea (in m2),uparea (in km2),snapped_lon,snapped_lat,snapped_uparea,snapped_iou,area_percent_difference,num_nodes,num_edges
0,basantpur,21.7385,82.785942,4.64,14282,206.0,82.778333,21.7225,1944.963,58984890000.0,58984.885999,82.75,21.75,65959.01,0.798565,11.823573,23.0,22.0
1,seorinarayan,21.718535,82.597543,18.8,12161,209.5,82.6,21.715833,393.276,47928360000.0,47928.363884,82.25,21.75,48848.535,0.811889,1.91989,17.0,16.0


In [9]:
# The unfiltered table is not referenced again in the cells below; drop the name.
del indiawris_attributes_graph

# Create Node Features as csv

In [10]:
# Root output directory for all node-feature CSVs written below.
os.makedirs(os.path.join(SAVE_PATH, "graph_features"), exist_ok = True)

In [11]:
# Crop the 'ldd' field of the CWatM 30-min file to the buffered basin box.
# The lat slice runs from maxy to miny, matching the other lat slices on this grid.
ldd = xr.open_dataset(os.path.join(PATHS['gis_ldd'], 'CWatM_30min', 'ldd.nc'))['ldd']
ldd = ldd.sel(
    lat = slice(region_bounds['maxy'], region_bounds['miny']),
    lon = slice(region_bounds['minx'], region_bounds['maxx']),
)

lats = ldd['lat'].values
lons = ldd['lon'].values

# Target grid for the regridders: just the lat/lon axes of the cropped ldd field.
ds_grid = xr.Dataset({
    'lat': (['lat'], lats),
    'lon': (['lon'], lons),
})

# Round both axes to 3 decimal places so node coordinates can be matched against them.
for axis_name in ('lat', 'lon'):
    ds_grid[axis_name] = ds_grid[axis_name].round(3)

In [12]:
# File names (under PATHS['Assets']) of the cached regridding weights, one per source dataset.
regridder_files = dict(
    ERA5 = f'regridder_era5_to_cwatm_30min_India_{basin_name}.nc',
    GLEAM = f'regridder_gleam_to_cwatm_30min_India_{basin_name}.nc',
)

## ERA5

### Dynamic

In [17]:
var_names = [
    # 'sub_surface_runoff',
    # 'surface_runoff',
    # 'total_precipitation',
    # '2m_temperature',
    # 'volumetric_soil_water_layer_1',
    # 'volumetric_soil_water_layer_2',
    # 'volumetric_soil_water_layer_3',
    # 'volumetric_soil_water_layer_4',
    # 'runoff',
    'evaporation',
    'snowfall',
    'surface_net_solar_radiation',
    'surface_net_thermal_radiation',
    'surface_pressure',
    '2m_dewpoint_temperature',
    '10m_u_component_of_wind',
    '10m_v_component_of_wind',
    'forecast_albedo',
    'potential_evaporation',
    'snow_albedo',
    'snow_depth',
    'snowmelt',
    'total_column_water',
]

# Daily index for 1980-2020 with every 29 February removed.
dates = pd.date_range('1980-01-01', '2020-12-31', freq='D')
is_leap_day = (dates.month == 2) & (dates.day == 29)
dates = dates[~is_leap_day]
print(f"Number of dates: {len(dates)}")

def process(idx, row, var_name):
    """Write an empty (all-NaN) dates x nodes CSV for one gauge and one ERA5 variable.

    The file is a placeholder that the next cell reads back and fills.
    `idx` is unused; `row` must provide 'River Point Name'.
    """
    huc, gauge_id = basin_name, row['River Point Name']
    nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
    # makedirs creates the intermediate 'dynamic' directory as well.
    out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'ERA5')
    os.makedirs(out_dir, exist_ok = True)
    empty_frame = pd.DataFrame(index = dates, columns = nodes_coords.index)
    empty_frame.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

for var_name in var_names:
    print(var_name)
    with Parallel(n_jobs = 8, verbose = 0) as parallel:
        jobs = (
            delayed(process)(idx, row, var_name)
            for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total=len(indiawris_graph))
        )
        _ = parallel(jobs)

Number of dates: 14965
evaporation


100%|██████████| 2/2 [00:00<00:00, 1796.28it/s]


snowfall


100%|██████████| 2/2 [00:00<00:00, 2654.62it/s]


surface_net_solar_radiation


100%|██████████| 2/2 [00:00<00:00, 1832.77it/s]


surface_net_thermal_radiation


100%|██████████| 2/2 [00:00<00:00, 1527.42it/s]


surface_pressure


100%|██████████| 2/2 [00:00<00:00, 1775.37it/s]


2m_dewpoint_temperature


100%|██████████| 2/2 [00:00<00:00, 1833.58it/s]


10m_u_component_of_wind


100%|██████████| 2/2 [00:00<00:00, 1709.87it/s]


10m_v_component_of_wind


100%|██████████| 2/2 [00:00<00:00, 1792.05it/s]


forecast_albedo


100%|██████████| 2/2 [00:00<00:00, 1859.59it/s]


potential_evaporation


100%|██████████| 2/2 [00:00<00:00, 1590.56it/s]


snow_albedo


100%|██████████| 2/2 [00:00<00:00, 1510.37it/s]


snow_depth


100%|██████████| 2/2 [00:00<00:00, 1557.77it/s]


snowmelt


100%|██████████| 2/2 [00:00<00:00, 1612.57it/s]


total_column_water


100%|██████████| 2/2 [00:00<00:00, 1844.06it/s]


In [18]:
# For each ERA5 variable: open all yearly files, crop to the basin box, regrid onto
# ds_grid, then copy the nearest regridded cell's series into every node column.
for var_name in var_names:
    print(var_name)
    ds = xr.open_mfdataset(os.path.join(PATHS['RawData'], 'ERA5', var_name, "*.nc"), combine='by_coords')
    ds_var_name = list(ds.data_vars)[0]  # first data variable in the files
    ds = ds[ds_var_name].rename({'longitude': 'lon', 'latitude': 'lat'})
    leap_day = (ds['time.month'] == 2) & (ds['time.day'] == 29)
    ds = ds.sel(time = ~leap_day)
    # Move longitudes onto the -180-180 scale and restore ascending order.
    ds['lon'] = [lon_360_180(lon) for lon in ds['lon'].values]
    ds = ds.sortby('lon')
    ds = ds.sel(
        lat = slice(region_bounds['maxy'], region_bounds['miny']),
        lon = slice(region_bounds['minx'], region_bounds['maxx']),
    )
    # Keep the first occurrence of every timestamp (np.unique also returns them sorted).
    _, index = np.unique(ds['time'], return_index = True)
    ds = ds.isel(time = index)

    # Reuse cached regridding weights when present; otherwise compute and cache them.
    weights_path = os.path.join(PATHS['Assets'], regridder_files['ERA5'])
    if os.path.exists(weights_path):
        regridder = xe.Regridder(ds, ds_grid, 'bilinear', reuse_weights=True, filename = weights_path)
    else:
        regridder = xe.Regridder(ds, ds_grid, 'bilinear', reuse_weights=False)
        regridder.to_netcdf(weights_path)

    ds_regrided = regridder(ds)
    ds.close()
    start_time = time.time()
    ds_regrided.load()
    end_time = time.time()
    print(f'Time: {((end_time - start_time) / 60):.4f} mins')

    def process(idx, row):
        """Fill one gauge's CSV for the current variable from the regridded field."""
        huc, gauge_id = basin_name, row['River Point Name']
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        csv_path = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'ERA5', f"{var_name}.csv")
        data = pd.read_csv(csv_path, index_col = 0, parse_dates = True)
        for node_idx, node_row in nodes_coords.iterrows():
            node_series = ds_regrided.sel(lat = node_row['lat'], lon = node_row['lon'], method = 'nearest')
            # Values are assigned positionally, so the time axis must match the CSV's rows.
            data.loc[:, str(node_idx)] = node_series.values
        data.to_csv(csv_path)

    with Parallel(n_jobs = 8, verbose = 0) as parallel:
        jobs = (
            delayed(process)(idx, row)
            for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total=len(indiawris_graph))
        )
        _ = parallel(jobs)

    ds.close()
    del ds
    gc.collect()

evaporation
Time: 6.5852 mins


100%|██████████| 2/2 [00:00<00:00, 1091.70it/s]


snowfall
Time: 5.0478 mins


100%|██████████| 2/2 [00:00<00:00, 1093.12it/s]


surface_net_solar_radiation
Time: 5.3348 mins


100%|██████████| 2/2 [00:00<00:00, 1028.52it/s]


surface_net_thermal_radiation
Time: 4.8022 mins


100%|██████████| 2/2 [00:00<00:00, 839.87it/s]


surface_pressure
Time: 6.4313 mins


100%|██████████| 2/2 [00:00<00:00, 1061.18it/s]


2m_dewpoint_temperature
Time: 4.6667 mins


100%|██████████| 2/2 [00:00<00:00, 956.40it/s]


10m_u_component_of_wind
Time: 4.7914 mins


100%|██████████| 2/2 [00:00<00:00, 760.46it/s]


10m_v_component_of_wind
Time: 4.5801 mins


100%|██████████| 2/2 [00:00<00:00, 666.61it/s]


forecast_albedo
Time: 4.6053 mins


100%|██████████| 2/2 [00:00<00:00, 920.31it/s]


potential_evaporation
Time: 4.5425 mins


100%|██████████| 2/2 [00:00<00:00, 830.72it/s]


snow_albedo
Time: 4.6815 mins


100%|██████████| 2/2 [00:00<00:00, 3003.44it/s]


snow_depth
Time: 4.5222 mins


100%|██████████| 2/2 [00:00<00:00, 1407.25it/s]


snowmelt
Time: 4.6015 mins


100%|██████████| 2/2 [00:00<00:00, 1423.73it/s]


total_column_water
Time: 4.5999 mins


100%|██████████| 2/2 [00:00<00:00, 591.16it/s]


### Static

In [42]:
var_names = [
    'static_soil_type',
    'static_high_vegetation_cover',
    'static_low_vegetation_cover',
    'static_type_of_high_vegetation',
    'static_type_of_low_vegetation'
    ]
ds_filenames = [
    'soil_type_static.nc',
    'high_vegetation_cover_static.nc',
    'low_vegetation_cover_static.nc',
    'type_of_high_vegetation_static.nc',
    'type_of_low_vegetation_static.nc'
]
# Only these layers hold integer class codes. The two vegetation-cover layers are
# fractions in the range 0-1 (see the ERA5 parameter documentation), so casting
# them to int would collapse almost every value to 0.
categorical_var_names = {
    'static_soil_type',
    'static_type_of_high_vegetation',
    'static_type_of_low_vegetation',
}

for var_name, ds_filename in zip(var_names, ds_filenames):
    print(var_name)
    ds = xr.open_dataset(os.path.join(PATHS['RawData'], 'ERA5', ds_filename))
    ds_var_name = list(ds.data_vars)[0]  # first data variable in the file
    ds = ds[ds_var_name]
    # Static field: keep the first time step and discard the time coordinate.
    ds = ds.isel(time = 0)
    ds = ds.drop('time')
    ds = ds.rename({'longitude': 'lon', 'latitude': 'lat'})
    # Move longitudes onto the -180-180 scale and restore ascending order.
    ds['lon'] = [lon_360_180(lon) for lon in ds['lon'].values]
    ds = ds.sortby('lon')
    ds = ds.sel(
        lat = slice(region_bounds['maxy'], region_bounds['miny']),
        lon = slice(region_bounds['minx'], region_bounds['maxx'])
    )
    def process(idx, row):
        """Write a one-row CSV holding the nearest-cell value of this layer for every node of one gauge."""
        huc, gauge_id = basin_name, row['River Point Name']
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.DataFrame(columns = nodes_coords.index, index = [0])
        for node_idx, node_row in nodes_coords.iterrows():
            lat, lon = node_row['lat'], node_row['lon']
            ds_window_loc = ds.sel(lat = lat, lon = lon, method = 'nearest')
            if var_name in categorical_var_names:
                data.loc[0, node_idx] = int(ds_window_loc.values)
            else:
                # Fractional layer: keep the full value instead of truncating it.
                data.loc[0, node_idx] = float(ds_window_loc.values)
        # makedirs creates the intermediate 'static' directory as well.
        out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'static', 'ERA5')
        os.makedirs(out_dir, exist_ok = True)
        data.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

    for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total=len(indiawris_graph)):
        process(idx, row)

    # Release the file handle before opening the next layer.
    ds.close()

static_soil_type


100%|██████████| 2/2 [00:00<00:00, 12.62it/s]


static_high_vegetation_cover


100%|██████████| 2/2 [00:00<00:00, 10.08it/s]


static_low_vegetation_cover


100%|██████████| 2/2 [00:00<00:00, 20.54it/s]


static_type_of_high_vegetation


100%|██████████| 2/2 [00:00<00:00, 24.15it/s]


static_type_of_low_vegetation


100%|██████████| 2/2 [00:00<00:00, 35.82it/s]


## HWSD

In [43]:
var_names = ['S_CLAY', 'S_GRAVEL', 'S_SAND', 'S_SILT', 'T_CLAY', 'T_GRAVEL', 'T_SAND', 'T_SILT']

for var_name in var_names:
    print(var_name)
    ds = xr.open_dataset(os.path.join(PATHS['HWSD'], f'{var_name}.nc4'))
    ds_var_name = list(ds.data_vars)[0]  # first data variable in the file
    ds = ds[ds_var_name]
    # This file is sliced miny -> maxy, unlike the ERA5/GLEAM/ldd grids above.
    ds = ds.sel(
        lat = slice(region_bounds['miny'], region_bounds['maxy']),
        lon = slice(region_bounds['minx'], region_bounds['maxx'])
    )
    ds = ds / 100  # presumably percent -> fraction; TODO confirm against the HWSD file metadata
    ds.load()
    def process(idx, row):
        """Write a one-row CSV with the window-mean of this layer for every node of one gauge."""
        huc, gauge_id = basin_name, row['River Point Name']
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.DataFrame(columns = nodes_coords.index, index = [0])
        for node_idx, node_row in nodes_coords.iterrows():
            lat, lon = node_row['lat'], node_row['lon']
            # All source cells inside the resolution x resolution box centred on the node.
            window = ds.sel(
                lat = slice(lat-resolution/2, lat+resolution/2),
                lon = slice(lon-resolution/2, lon+resolution/2)
            ).values
            # NaN-aware mean: a single missing source cell must not turn the whole
            # node into NaN. The result is still NaN when the window is empty or all-NaN.
            data.loc[0, node_idx] = np.nanmean(window)
        # makedirs creates the intermediate 'static' directory as well.
        out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'static', 'HWSD')
        os.makedirs(out_dir, exist_ok = True)
        data.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

    for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total=len(indiawris_graph)):
        process(idx, row)

    ds.close()
    del ds
    gc.collect()

S_CLAY


100%|██████████| 2/2 [00:00<00:00, 15.08it/s]

S_GRAVEL



100%|██████████| 2/2 [00:00<00:00, 90.44it/s]


S_SAND


100%|██████████| 2/2 [00:00<00:00, 110.76it/s]


S_SILT


100%|██████████| 2/2 [00:00<00:00, 106.09it/s]


T_CLAY


100%|██████████| 2/2 [00:00<00:00, 107.86it/s]


T_GRAVEL


100%|██████████| 2/2 [00:00<00:00, 113.70it/s]


T_SAND


100%|██████████| 2/2 [00:00<00:00, 108.93it/s]


T_SILT


100%|██████████| 2/2 [00:00<00:00, 104.10it/s]


## GLEAM

In [44]:
var_names = ['Ep', 'SMroot', 'SMsurf']

# Daily index for 1980-2020 with every 29 February removed.
dates = pd.date_range('1980-01-01', '2020-12-31', freq='D')
is_leap_day = (dates.month == 2) & (dates.day == 29)
dates = dates[~is_leap_day]
print(f"Number of dates: {len(dates)}")

def process(idx, row, var_name):
    """Write an empty (all-NaN) dates x nodes CSV for one gauge and one GLEAM variable.

    The file is a placeholder that the next cell reads back and fills.
    `idx` is unused; `row` must provide 'River Point Name'.
    """
    huc, gauge_id = basin_name, row['River Point Name']
    nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
    # makedirs creates the intermediate 'dynamic' directory as well.
    out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM')
    os.makedirs(out_dir, exist_ok = True)
    empty_frame = pd.DataFrame(index = dates, columns = nodes_coords.index)
    empty_frame.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

for var_name in var_names:
    print(var_name)
    with Parallel(n_jobs = 8, verbose = 0) as parallel:
        jobs = (
            delayed(process)(idx, row, var_name)
            for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total=len(indiawris_graph))
        )
        _ = parallel(jobs)

Number of dates: 14965
Ep


100%|██████████| 2/2 [00:00<00:00, 4826.59it/s]


SMroot


100%|██████████| 2/2 [00:00<00:00, 2142.68it/s]


SMsurf


100%|██████████| 2/2 [00:00<00:00, 2494.38it/s]


In [45]:
# For each GLEAM variable: open all files, crop to the basin box, regrid onto
# ds_grid, then copy the nearest regridded cell's series into every node column.
for var_name in var_names:
    print(var_name)
    ds = xr.open_mfdataset(os.path.join(PATHS['GLEAM'], var_name, "*.nc"), combine='by_coords')
    ds_var_name = list(ds.data_vars)[0]  # first data variable in the files
    ds = ds[ds_var_name]
    leap_day = (ds['time.month'] == 2) & (ds['time.day'] == 29)
    ds = ds.sel(time = ~leap_day)
    ds = ds.sel(
        lat = slice(region_bounds['maxy'], region_bounds['miny']),
        lon = slice(region_bounds['minx'], region_bounds['maxx']),
    )

    # Reuse cached regridding weights when present; otherwise compute and cache them.
    weights_path = os.path.join(PATHS['Assets'], regridder_files['GLEAM'])
    if os.path.exists(weights_path):
        regridder = xe.Regridder(ds, ds_grid, 'bilinear', reuse_weights=True, filename = weights_path)
    else:
        regridder = xe.Regridder(ds, ds_grid, 'bilinear', reuse_weights=False)
        regridder.to_netcdf(weights_path)

    ds_regrided = regridder(ds)
    ds.close()
    start_time = time.time()
    ds_regrided.load()
    end_time = time.time()
    print(f'Time: {((end_time - start_time) / 60):.4f} mins')

    def process(idx, row):
        """Fill one gauge's CSV for the current variable from the regridded field."""
        huc, gauge_id = basin_name, row['River Point Name']
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        csv_path = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv")
        data = pd.read_csv(csv_path, index_col = 0, parse_dates = True)
        for node_idx, node_row in nodes_coords.iterrows():
            node_series = ds_regrided.sel(lat = node_row['lat'], lon = node_row['lon'], method = 'nearest')
            # Values are assigned positionally, so the time axis must match the CSV's rows.
            data.loc[:, str(node_idx)] = node_series.values
        data.to_csv(csv_path)

    with Parallel(n_jobs = 8, verbose = 0) as parallel:
        jobs = (
            delayed(process)(idx, row)
            for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total=len(indiawris_graph))
        )
        _ = parallel(jobs)

    ds.close()
    del ds
    gc.collect()

Ep
Time: 0.3355 mins


100%|██████████| 2/2 [00:00<00:00, 1780.26it/s]


SMroot
Time: 0.3308 mins


100%|██████████| 2/2 [00:00<00:00, 1084.92it/s]


SMsurf
Time: 0.3247 mins


100%|██████████| 2/2 [00:00<00:00, 1236.35it/s]


### Fix NaNs

In [46]:
# Repair NaN columns in the GLEAM node-feature CSVs written by the previous cell.
# For every variable the regridded field is rebuilt, the gauges whose CSV contains
# any NaN are listed, and each NaN node column is then refilled from the regridded
# cells in a window of +/- 1.5 * resolution around the node.
var_names = ['Ep', 'SMroot', 'SMsurf']
for var_name in var_names:
    # Same open / leap-day removal / crop / regrid sequence as the fill cell above.
    ds = xr.open_mfdataset(os.path.join(PATHS['GLEAM'], var_name, f"*.nc"), combine='by_coords')
    ds_var_name = list(ds.data_vars)[0]
    ds = ds[ds_var_name]
    ds = ds.sel(time=~((ds['time.month'] == 2) & (ds['time.day'] == 29)))
    ds = ds.sel(
        lat = slice(region_bounds['maxy'], region_bounds['miny']), 
        lon = slice(region_bounds['minx'], region_bounds['maxx'])
    )
    # Reuse cached regridding weights when present; otherwise compute and cache them.
    if os.path.exists(os.path.join(PATHS['Assets'], regridder_files['GLEAM'])):
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=True, 
            filename = os.path.join(PATHS['Assets'], regridder_files['GLEAM'])
        )
    else:
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=False
        )
        regridder.to_netcdf(os.path.join(PATHS['Assets'], regridder_files['GLEAM']))
    ds_regrided = regridder(ds)
    ds.close()
    start_time = time.time()
    ds_regrided.load()
    end_time = time.time()
    print(f'{var_name} (Time: {((end_time - start_time) / 60):.4f} mins)')

    # Loop over catchments and find ones with issues
    issues = []
    for idx, row in tqdm.tqdm(indiawris_graph.iterrows()):
        huc, gauge_id = basin_name, row['River Point Name']
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        if data.isnull().values.any():
            issues.append([huc, gauge_id])
    # The 'huc_02' column holds basin_name here (huc is assigned from basin_name above).
    issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
    print(f"Number of catchments with issues: {issues.shape[0]}")
    print("------")

    # Fix the catchments with issues
    for issue_idx, (huc, gauge_id) in enumerate(issues.values):
        print(issue_idx, huc, gauge_id)
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        nodes_coords['isNaN'] = False
        nodes_coords['nonNaNneighbours'] = 0
        # Loop over nodes and find the nodes with issues
        for node_idx in nodes_coords.index:
            if data[str(node_idx)].isnull().values.any():
                nodes_coords.loc[node_idx, 'isNaN'] = True
                node_lat = float(round(nodes_coords.loc[node_idx, 'lat'], 3))
                node_lon = float(round(nodes_coords.loc[node_idx, 'lon'], 3))
                multiplier = 1.5
                # Window of regridded cells within +/- 1.5 * resolution of the node (node cell included).
                ds_slice = ds_regrided.sel(
                    lat = slice(node_lat+multiplier*resolution, node_lat-multiplier*resolution), 
                    lon = slice(node_lon-multiplier*resolution, node_lon+multiplier*resolution)
                    )
                # Reshape to a time x (lat, lon) table so each window cell is one column.
                slice_df = ds_slice.to_dataframe(name = var_name).reset_index()
                slice_df['lat'] = slice_df['lat'].round(3)
                slice_df['lon'] = slice_df['lon'].round(3)
                slice_df['location'] = list(zip(slice_df['lat'], slice_df['lon']))
                slice_df = slice_df.pivot(index='time', columns='location', values=var_name)
                # A window cell counts as valid only if it has no NaN at any time step.
                num_nan_nodes = slice_df.isnull().any(axis=0).sum()
                num_nonnan_nodes = len(slice_df.columns) - num_nan_nodes
                nodes_coords.loc[node_idx, 'nonNaNneighbours'] = num_nonnan_nodes
        # Process the NaN nodes with the most valid window cells first; values written
        # back into ds_regrided below can then serve nodes handled later.
        nodes_coords_sorted = nodes_coords.sort_values(by = 'nonNaNneighbours', ascending = False)
        nodes_coords_sorted = nodes_coords_sorted[nodes_coords_sorted['isNaN']]
        print(f"Number of nodes with NaN values: {nodes_coords_sorted.shape[0]}")
        
        for node_idx in tqdm.tqdm(nodes_coords_sorted.index):
            node_lat, node_lon = float(round(nodes_coords.loc[node_idx, 'lat'], 3)), float(round(nodes_coords.loc[node_idx, 'lon'], 3))
            multiplier = 1.5
            # Rebuild the window table; it may now include values patched earlier in this loop.
            ds_slice = ds_regrided.sel(
                lat = slice(node_lat+multiplier*resolution, node_lat-multiplier*resolution), 
                lon = slice(node_lon-multiplier*resolution, node_lon+multiplier*resolution)
                )
            slice_df = ds_slice.to_dataframe(name = var_name).reset_index()
            slice_df['lat'] = slice_df['lat'].round(3)
            slice_df['lon'] = slice_df['lon'].round(3)
            slice_df['location'] = list(zip(slice_df['lat'], slice_df['lon']))
            slice_df = slice_df.pivot(index='time', columns='location', values=var_name)
            # Column labels become strings such as "(21.75, 82.25)" so the node's own cell
            # can be looked up by a formatted name below.
            slice_df.columns = list(map(str, slice_df.columns))
            num_nonnan_nodes = len(slice_df.columns) - slice_df.isnull().any(axis=0).sum()
            # print(node_idx, (node_lat, node_lon), num_nonnan_nodes)
            # NOTE(review): 9 presumably corresponds to a full 3x3 window; a node near the
            # edge of the cropped grid has fewer cells and falls through to the averaging
            # branch even when its own cell is valid - confirm this is intended.
            if num_nonnan_nodes == 9:
                # Every window cell is valid, so copy the node's own regridded series.
                replacement_values = slice_df.loc[:, f"({node_lat}, {node_lon})"]
                data.loc[:, str(node_idx)] = replacement_values
                nodes_coords_sorted.loc[node_idx, 'isNaN'] = False
            elif num_nonnan_nodes > 0:
                # NaN-aware mean over all window cells, computed per time step.
                replacement_values = np.nanmean(slice_df, axis = 1)
                data.loc[:, str(node_idx)] = replacement_values
                # Write the patched series back so later nodes see this cell as filled.
                ds_regrided.loc[dict(lat = node_lat, lon = node_lon)] = replacement_values
                nodes_coords_sorted.loc[node_idx, 'isNaN'] = False
            # When no window cell is fully valid, the node is left unchanged (still NaN).
        print(f"Number of nodes with NaN values: {nodes_coords_sorted['isNaN'].sum()}")
        print(issue_idx, huc, gauge_id, data.isnull().values.any())
        data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"))
        print("------")

Ep (Time: 0.3222 mins)


2it [00:00, 13.32it/s]

Number of catchments with issues: 0
------





SMroot (Time: 0.3133 mins)


2it [00:00, 13.23it/s]

Number of catchments with issues: 0
------





SMsurf (Time: 0.3121 mins)


2it [00:00, 12.18it/s]

Number of catchments with issues: 0
------





In [None]:
# var_names = ['Ep', 'SMroot', 'SMsurf']
# for var_name in var_names:
#     # Loop over catchments and find ones with issues
#     issues = []
#     for idx, row in tqdm.tqdm(indiawris_graph.iterrows()):
#         huc, gauge_id = basin_name, row['River Point Name']
#         nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
#         data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
#         if data.isnull().values.any():
#             issues.append([huc, gauge_id])
#     issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
#     print(f"Number of catchments with issues: {issues.shape[0]}")
#     print("------")

#     # Fix the catchments with issues
#     for issue_idx, (huc, gauge_id) in enumerate(issues.values):
#         print(issue_idx, huc, gauge_id)
#         nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
#         data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
#         nodes_coords['isNaN'] = False
#         nodes_coords['nonNaNneighbours'] = 0
#         # Loop over nodes and find the nodes with issues
#         for node_idx in nodes_coords.index:
#             if data[str(node_idx)].isnull().values.any():
#                 nodes_coords.loc[node_idx, 'isNaN'] = True
#         print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")
#         print("------")

In [None]:
# var_names = ['Ep', 'SMroot', 'SMsurf']
# for var_name in var_names:
#     # Loop over catchments and find ones with issues
#     issues = []
#     for idx, row in tqdm.tqdm(indiawris_graph.iterrows()):
#         huc, gauge_id = basin_name, row['River Point Name']
#         nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
#         data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
#         if data.isnull().values.any():
#             issues.append([huc, gauge_id])
#     issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
#     print(f"Number of catchments with issues: {issues.shape[0]}")
#     print("------")

#     # Fix the catchments with issues
#     for issue_idx, (huc, gauge_id) in enumerate(issues.values):
#         print(issue_idx, huc, gauge_id)
#         nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
#         data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
#         nodes_coords['isNaN'] = False
#         # Loop over nodes and find the nodes with issues
#         for node_idx in nodes_coords.index:
#             if data[str(node_idx)].isnull().values.any():
#                 nodes_coords.loc[node_idx, 'isNaN'] = True
#         print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")

        
#         for node_idx in tqdm.tqdm(nodes_coords[nodes_coords['isNaN']].index):
#             nodes_coords['distances'] = None
#             node_lat, node_lon = float(round(nodes_coords.loc[node_idx, 'lat'], 3)), float(round(nodes_coords.loc[node_idx, 'lon'], 3))
#             for node_idx2 in nodes_coords[nodes_coords['isNaN'] == False].index:
#                 if node_idx != node_idx2:
#                     node_lat2, node_lon2 = float(round(nodes_coords.loc[node_idx2, 'lat'], 3)), float(round(nodes_coords.loc[node_idx2, 'lon'], 3))
#                     distance = np.sqrt((node_lat - node_lat2)**2 + (node_lon - node_lon2)**2)
#                     nodes_coords.loc[node_idx2, 'distances'] = distance
#             min_distance = nodes_coords.loc[nodes_coords['distances'].idxmin(), 'distances']
#             # Replace with mean of nodes having distance equal to min_distance
#             replacement_nodes = nodes_coords[nodes_coords['distances'] == min_distance].index
#             replacement_nodes = list(map(str, replacement_nodes))
#             replacement_values = data.loc[:, replacement_nodes].mean(axis = 1)
#             data.loc[:, str(node_idx)] = replacement_values
#             nodes_coords.loc[node_idx, 'isNaN'] = False
#         print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")
#         print(issue_idx, huc, gauge_id, data.isnull().values.any())
#         data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"))
#         print("------")

## Solar Insolation

In [47]:
def solar_insolation(lat, lon, start_date, end_date):
    """Daily solar-noon insolation on a horizontal surface, with leap days removed.

    Parameters
    ----------
    lat : float
        Latitude in degrees.
    lon : float
        Longitude in degrees. Currently unused: the hour angle is fixed at
        solar noon, so the result does not depend on longitude.
    start_date, end_date : str or datetime-like
        Inclusive bounds of the daily date range (anything `pd.to_datetime` accepts).

    Returns
    -------
    pd.Series
        One value per day (29 February excluded), in kW/m^2, clipped at zero,
        indexed by date and named 'Solar Insolation (kW/m²)'.

    Notes
    -----
    The day number is the calendar day of year, so in leap years days after
    28 February are numbered one higher than in other years even though
    29 February itself is dropped.
    """
    solar_constant = 1361  # W/m^2

    # Daily dates without 29 February.
    dates = pd.date_range(start=pd.to_datetime(start_date), end=pd.to_datetime(end_date), freq='D')
    dates = dates[~((dates.month == 2) & (dates.day == 29))]

    # Everything below is evaluated for all days at once instead of one date at a time.
    day_of_year = dates.dayofyear.values

    # Solar declination in degrees for each day.
    declination = 23.45 * np.sin(np.radians((360 / 365) * (day_of_year - 81)))

    # Hour angle at solar noon (local solar time = 12 h).
    hour_angle = 0

    # Cosine of the solar zenith angle.
    lat_rad = np.radians(lat)
    decl_rad = np.radians(declination)
    cos_zenith_angle = (np.sin(lat_rad) * np.sin(decl_rad) +
                        np.cos(lat_rad) * np.cos(decl_rad) * np.cos(np.radians(hour_angle)))

    # Solar constant scaled by a day-of-year cosine factor and projected onto the horizontal.
    insolation = solar_constant * (1 + 0.033 * np.cos(np.radians(360 * day_of_year / 365))) * cos_zenith_angle

    # A negative value means the sun is below the horizon at noon; report zero.
    insolation[insolation < 0] = 0

    # Convert W/m^2 to kW/m^2.
    return pd.Series(insolation / 1000, index=dates, name='Solar Insolation (kW/m²)')

In [48]:
dates = pd.date_range('1980-01-01', '2020-12-31', freq='D')
dates = dates[~((dates.month == 2) & (dates.day == 29))]

def process(idx, row):
    """Write ``solar_insolation.csv`` (dates x nodes) for one gauge's graph.

    Parameters
    ----------
    idx : hashable
        Row label from ``indiawris_graph.iterrows()``; not used.
    row : pd.Series
        Gauge record; only ``row['River Point Name']`` is read.

    Reads ``graph_files/<basin>/<gauge>/nodes_coords.csv`` (columns ``lat`` and
    ``lon``, one row per node) and writes
    ``graph_features/<basin>/<gauge>/solar_insolation.csv``.
    """
    huc, gauge_id = basin_name, row['River Point Name']
    dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
    nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)

    # Take the period from the shared `dates` index instead of repeating the
    # literals, so the series length always matches the frame index.
    start_date, end_date = dates[0], dates[-1]
    node_series = {
        node_idx: solar_insolation(node_row['lat'], node_row['lon'], start_date, end_date).to_numpy()
        for node_idx, node_row in nodes_coords.iterrows()
    }
    # One construction call yields a float frame rather than an object-dtype
    # frame filled column by column.
    data = pd.DataFrame(node_series, index = dates, columns = nodes_coords.index)

    out_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id)
    os.makedirs(out_dir, exist_ok = True)
    data.to_csv(os.path.join(out_dir, 'solar_insolation.csv'))

# One task per gauge on 8 joblib workers. tqdm wraps the task generator, so the
# bar advances as tasks are handed to the pool rather than as they finish.
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(
        delayed(process)(idx, row)
        for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total = len(indiawris_graph))
    )

100%|██████████| 2/2 [00:00<00:00, 2680.07it/s]


## Time Encodings

In [49]:
def sine_time_encoding(start_date, end_date):
    """Build raw and sine-encoded calendar features on a daily no-leap index.

    Parameters
    ----------
    start_date, end_date : str or datetime-like
        Inclusive bounds passed to ``pd.date_range``.

    Returns
    -------
    pd.DataFrame
        Indexed by date (29 February removed) with columns ``month``,
        ``weekofyear`` (ISO week), ``dayofyear`` and their sine transforms
        ``sine_month``, ``sine_weekofyear``, ``sine_dayofyear`` computed as
        ``sin(2*pi*x / period)`` with periods 12, 52 and 365.
    """
    calendar = pd.date_range(start=start_date, end=end_date, freq='D')
    leap_day = (calendar.month == 2) & (calendar.day == 29)
    calendar = calendar[~leap_day]

    # Raw calendar components
    df = pd.DataFrame(index=calendar)
    df['month'] = df.index.month
    df['weekofyear'] = df.index.isocalendar().week
    df['dayofyear'] = df.index.dayofyear

    def to_sine(value, max_val):
        # Map a cyclic quantity onto one sine period
        return np.sin(2 * np.pi * value / max_val)

    # (column, period) pairs, in the order the encoded columns are appended
    for column, period in (('month', 12), ('weekofyear', 52), ('dayofyear', 365)):
        df[f'sine_{column}'] = df[column].apply(to_sine, max_val=period)

    return df

In [50]:
# Calendar encodings for the full 1980-2020 period; the same frame is written for every gauge.
df_encoded = sine_time_encoding(start_date = '1980-01-01', end_date = '2020-12-31')

def process(idx, row):
    """Write the shared ``df_encoded`` frame as ``time_encodings.csv`` for one gauge.

    ``idx`` is not used; only ``row['River Point Name']`` is read to locate the
    gauge's ``graph_features`` directory, which is created if missing.
    """
    huc, gauge_id = basin_name, row['River Point Name']
    out_dir = os.path.join(PATHS['devp_datasets'], DIRNAME, 'graph_features', huc, gauge_id)

    os.makedirs(out_dir, exist_ok = True)
    df_encoded.to_csv(os.path.join(out_dir, 'time_encodings.csv'))

# Runs serially: each call only writes the already-computed `df_encoded` frame.
for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total = len(indiawris_graph)):
    process(idx, row)

100%|██████████| 2/2 [00:00<00:00,  8.08it/s]


## Terrain Attributes

In [51]:
from shapely.geometry import Polygon
import rioxarray

def coords_to_polygon(lon, lat, resolution):
    """Return the square cell of side ``resolution`` centred on (lon, lat).

    Corner coordinates are rounded to 3 decimals; the vertices are listed
    starting from the lower-left corner, then upper-left, upper-right and
    lower-right.
    """
    half_res = resolution / 2
    west, east = round(lon - half_res, 3), round(lon + half_res, 3)
    south, north = round(lat - half_res, 3), round(lat + half_res, 3)
    return Polygon([(west, south), (west, north), (east, north), (east, south)])
def tile_filename_to_coords(filename):
    """Parse the lower-left corner encoded in a tile file name.

    The expected name is ``{n|s}{dd}{e|w}{ddd}_<var>.tif`` (for example
    ``n20e080_elv.tif``); ``n``/``e`` are positive, ``s``/``w`` negative.
    A leading directory part is ignored and the hemisphere letters are
    matched case-insensitively.

    Parameters
    ----------
    filename : str
        Tile file name or path.

    Returns
    -------
    tuple of int
        ``(lon, lat)`` in whole degrees.

    Raises
    ------
    ValueError
        If the hemisphere letters are not ``n``/``s`` and ``e``/``w`` or the
        degree fields are not integers. (Previously an unknown letter was
        silently treated as south/west.)
    """
    name = os.path.basename(filename)
    n_s, e_w = name[0:1].lower(), name[3:4].lower()
    if n_s not in ('n', 's') or e_w not in ('e', 'w'):
        raise ValueError(f"unrecognised tile file name: {filename!r}")
    lat, lon = int(name[1:3]), int(name[4:7])
    lat = lat if n_s == 'n' else -lat
    lon = lon if e_w == 'e' else -lon
    return (lon, lat)

In [52]:
import itertools
# MERIT-Hydro derived layers; each is summarised per graph node (grid cell).
var_names = ['elv', 'slope_percentage', 'slope_riserun', 'slope_degrees', 'slope_radians', 'aspect', 'curvature', 'planform_curvature', 'profile_curvature', 'upa', 'wth']

# "<var_name>-<basin>-<row label>" for every gauge whose extraction failed
issues = []
for var_name in var_names:
    print(var_name)
    # Index all tiles of this variable; tile footprints are built as 5° x 5° squares
    # anchored at the lower-left corner encoded in the file name.
    tiles_paths = sorted(glob.glob(os.path.join(PATHS['MERIT-Hydro'], var_name, '**', '*.tif'), recursive=True))
    tiles_filenames = [os.path.basename(tile) for tile in tiles_paths]
    tiles_names = [tile.split('_')[0] for tile in tiles_filenames]
    tiles_lower_left_corner = [tile_filename_to_coords(tile) for tile in tiles_filenames]
    tiles_polygons = [Polygon([(lon, lat), (lon + 5, lat), (lon + 5, lat + 5), (lon, lat + 5)]) for lon, lat in tiles_lower_left_corner]

    def process(idx, row):
        """Write per-node statistics of the current ``var_name`` for one gauge.

        For every node in ``nodes_coords.csv`` the raster values inside the
        ``resolution``-wide cell centred on the node are reduced to mean, std
        and the 25/50/75 % quantiles (nodata excluded). The result is saved to
        ``graph_features/<basin>/<gauge>/static/MERIT-Hydro/<var_name>.csv``.
        ``idx`` is not used.

        Raises
        ------
        FileNotFoundError
            If no indexed tile intersects the gauge's cells.
        """
        huc, gauge_id = basin_name, row['River Point Name']
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.DataFrame(columns = nodes_coords.index, index = ['mean', 'std', '25%', '50%', '75%'])

        # Union of all node cells = footprint used to pick the tiles to open
        cell_polygons = [coords_to_polygon(node['lon'], node['lat'], resolution) for _, node in nodes_coords.iterrows()]
        catmt_polygon = cell_polygons[0]
        for polygon in cell_polygons[1:]:
            catmt_polygon = catmt_polygon.union(polygon)

        intersected_tiles = [
            tile_path
            for tile_polygon, tile_path in zip(tiles_polygons, tiles_paths)
            if tile_polygon.intersects(catmt_polygon)
        ]
        if not intersected_tiles:
            raise FileNotFoundError(f"no {var_name} tile intersects the cells of {gauge_id}")

        opened_tiles = []
        try:
            for tile_path in intersected_tiles:
                opened_tiles.append(rioxarray.open_rasterio(tile_path))

            # Mosaic the tiles, keep the single band, and make both axes ascending
            # so the coordinate slices below can be written low -> high.
            ds = opened_tiles[0]
            for tile in opened_tiles[1:]:
                ds = ds.combine_first(tile)
            ds = ds.sel(band=1)
            ds = ds.sortby('x', ascending=True)
            ds = ds.sortby('y', ascending=True)
            nodata = ds.rio.nodata

            for node_idx, node_row in nodes_coords.iterrows():
                lat, lon = node_row['lat'], node_row['lon']
                ds_node = ds.sel(x = slice(lon - resolution/2, lon + resolution/2), y = slice(lat - resolution/2, lat + resolution/2))
                ds_node = ds_node.where(ds_node != nodata)
                ds_node_values = ds_node.values.flatten()
                q25, q50, q75 = np.nanquantile(ds_node_values, [0.25, 0.50, 0.75])
                data.loc['mean', node_idx] = np.nanmean(ds_node_values)
                data.loc['std', node_idx] = np.nanstd(ds_node_values)
                data.loc['25%', node_idx] = q25
                data.loc['50%', node_idx] = q50
                data.loc['75%', node_idx] = q75

            out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'static', 'MERIT-Hydro')
            os.makedirs(out_dir, exist_ok = True)
            data.to_csv(os.path.join(out_dir, f"{var_name}.csv"))
        finally:
            # Release every tile that was opened, including when the extraction fails.
            for tile in opened_tiles:
                tile.close()
            gc.collect()

    for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total=len(indiawris_graph)):
        try:
            process(idx, row)
        except Exception as e:
            # The gauge table may not carry a 'huc_02' column (process() itself uses
            # basin_name); fall back so the handler cannot raise and hide the error.
            issue_key = f"{var_name}-{row.get('huc_02', basin_name)}-{row.name}"
            issues.append(issue_key)
            print(f"Error: {issue_key}. {e}")

elv


100%|██████████| 2/2 [00:08<00:00,  4.28s/it]


slope_percentage


100%|██████████| 2/2 [00:05<00:00,  2.73s/it]


slope_riserun


100%|██████████| 2/2 [00:05<00:00,  2.89s/it]


slope_degrees


100%|██████████| 2/2 [00:04<00:00,  2.39s/it]


slope_radians


100%|██████████| 2/2 [00:04<00:00,  2.24s/it]


aspect


100%|██████████| 2/2 [00:05<00:00,  2.52s/it]


curvature


100%|██████████| 2/2 [00:05<00:00,  2.63s/it]


planform_curvature


100%|██████████| 2/2 [00:04<00:00,  2.11s/it]


profile_curvature


100%|██████████| 2/2 [00:04<00:00,  2.39s/it]


upa


100%|██████████| 2/2 [00:05<00:00,  2.98s/it]


wth


100%|██████████| 2/2 [00:04<00:00,  2.37s/it]


In [53]:
# Number of (variable, gauge) combinations recorded as failed by the terrain loop's except handler.
len(issues)

0

In [54]:
# Split each "<var_name>-<huc_02>-<gauge_id>" issue key into its three fields.
issues_df = pd.DataFrame(
    [entry.split('-') for entry in issues],
    columns = ['var_name', 'huc_02', 'gauge_id'],
)
issues_df

Unnamed: 0,var_name,huc_02,gauge_id


In [55]:
# Failures recorded for the elevation layer only
issues_df.loc[issues_df['var_name'] == 'elv']

Unnamed: 0,var_name,huc_02,gauge_id


## Spatial Encodings

In [19]:
def process(idx, row):
    """Write ``spatial_encodings.csv`` (2 rows x nodes) for one gauge's graph.

    Row ``lon_transformed`` holds ``sin(2*pi*(lon+180)/360)`` and row
    ``lat_transformed`` holds ``(lat+60)/150`` for every node listed in the
    gauge's ``nodes_coords.csv``. ``idx`` is not used.
    """
    def lon_transform(x):
        # lon: -180 to 180, mapped onto one sine period
        return np.sin(2 * np.pi * (x+180) / 360)

    def lat_transform(x):
        # lat: -60 to 90, rescaled linearly to 0..1
        return (x - (-60))/(90 - (-60))

    huc, gauge_id = basin_name, row['River Point Name']
    dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
    nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)

    data = pd.DataFrame(columns = nodes_coords.index, index = ['lon_transformed', 'lat_transformed'])
    for node_idx, node_row in nodes_coords.iterrows():
        node_lat, node_lon = node_row['lat'], node_row['lon']
        data.loc['lon_transformed', node_idx] = lon_transform(node_lon)
        data.loc['lat_transformed', node_idx] = lat_transform(node_lat)

    out_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id)
    os.makedirs(out_dir, exist_ok = True)
    data.to_csv(os.path.join(out_dir, 'spatial_encodings.csv'))

# One task per gauge on 8 joblib workers
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(
        delayed(process)(idx, row)
        for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total = len(indiawris_graph))
    )

100%|██████████| 2/2 [00:00<00:00, 4017.53it/s]


## uparea

In [20]:
# Open the CWatM 30-arc-minute upstream-area grid, take its first data variable,
# crop it to the buffered basin extent and pull the crop into memory.
# NOTE(review): the lat slice runs maxy -> miny, which presumes the file stores
# latitude in descending order - confirm against the dataset.
uparea = xr.open_dataset(os.path.join(PATHS['gis_ldd'], 'CWatM_30min/upstream_area_km2.nc'))
ds_varname = list(uparea.data_vars)[0]
uparea = uparea[ds_varname].sel(
    lat = slice(region_bounds['maxy'], region_bounds['miny']),
    lon = slice(region_bounds['minx'], region_bounds['maxx']),
)
uparea.load()

def process(idx, row):
    """Write ``uparea.csv`` (1 row x nodes) for one gauge's graph.

    Each node receives the value of the ``uparea`` grid cell nearest to its
    (lat, lon), stored under row label ``0``. ``idx`` is not used.
    """
    huc, gauge_id = basin_name, row['River Point Name']
    dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
    nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)

    data = pd.DataFrame(columns = nodes_coords.index, index = [0])
    for node_idx, node_row in nodes_coords.iterrows():
        nearest_cell = uparea.sel(lat = node_row['lat'], lon = node_row['lon'], method = 'nearest')
        data.loc[0, node_idx] = nearest_cell.values.item()

    out_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id)
    os.makedirs(out_dir, exist_ok = True)
    data.to_csv(os.path.join(out_dir, 'uparea.csv'))

# One task per gauge on 8 joblib workers
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(
        delayed(process)(idx, row)
        for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total = len(indiawris_graph))
    )

# All tasks have returned: drop the upstream-area grid and reclaim its memory.
uparea.close()
del uparea
gc.collect()

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:00<00:00, 2201.73it/s]


60

## IndiaWRIS

In [21]:
# Observed flow table (dates x gauges), restricted to the gauges present in `indiawris_graph`.
indiawris_flow = pd.read_csv(os.path.join(PATHS['IndiaWRIS'], 'indiawris_flow.csv'), index_col=0, parse_dates=True)
indiawris_flow = indiawris_flow.loc[:, indiawris_flow.columns.intersection(indiawris_graph['River Point Name'])]

def process(idx, row):
    """Write ``IndiaWRIS.csv`` with the gauge's flow as ``Q_m3s`` and ``Q_mm``.

    ``Q_mm`` is ``Q_m3s * 3600 * 24 * 1000`` divided by the snapped upstream
    area, where ``row['snapped_uparea']`` is multiplied by 1e6 first
    (presumably km² -> m², giving mm per day - confirm units against the
    attribute table). ``idx`` is not used.
    """
    huc, gauge_id = basin_name, row['River Point Name']
    area = row['snapped_uparea'] * 1e6

    gauge_flow = indiawris_flow[[gauge_id]].copy()
    gauge_flow.columns = ['Q_m3s']
    gauge_flow['Q_mm'] = (gauge_flow['Q_m3s']*3600*24*1000) / area

    out_dir = os.path.join(PATHS['devp_datasets'], DIRNAME, 'graph_features', huc, gauge_id)
    os.makedirs(out_dir, exist_ok = True)
    gauge_flow.to_csv(os.path.join(out_dir, 'IndiaWRIS.csv'))

# One task per gauge on 8 joblib workers
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(
        delayed(process)(idx, row)
        for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total = len(indiawris_graph))
    )

# The flow table is no longer needed once every gauge file is written.
del indiawris_flow

100%|██████████| 2/2 [00:00<00:00, 4758.14it/s]


## GloFAS Parameter Maps

In [22]:
# "Catchment_morphology_and_river_network" (14 surface fields; the 7 used below are listed)
# - chanbnkf_Global_03min.nc (channel bankfull depth, m);
# - chanflpn_Global_03min.nc (width of the floodplain, m);
# - changrad_Global_03min.nc (channel longitudinal gradient, m/m);
# - chanlength_Global_03min.nc (channel length within a pixel, m);
# - chanman_Global_03min.nc (channel Manning's roughness coefficient, m^(1/3)s^(-1));
# - chans_Global_03min.nc (channel side slope, m/m);
# - chanbw_Global_03min.nc (channel bottom width, m);

# "Land_use" (7 surface fields; the 6 used below are listed)
# - fracforest_Global_03min.nc (fraction of forest for each grid-cell, -);
# - fracirrigated_Global_03min.nc (fraction of irrigated crops [except rice] for each grid-cell, -);
# - fracrice_Global_03min.nc (fraction of rice crops for each grid-cell, -);
# - fracsealed_Global_03min.nc (fraction of urban area for each grid-cell, -);
# - fracwater_Global_03min.nc (fraction of inland water for each grid-cell, -);
# - fracother_Global_03min.nc (fraction of other land cover for each grid-cell, -);
Parameter_Maps = os.path.join(PATHS['GloFAS'], 'LISFLOOD_Parameter_Maps')

# Sub-directory of the parameter-map archive -> variables extracted from it.
# Both groups go through one loop, so every dataset (including the land-use
# ones, which were previously left open) is closed after use.
glofas_variable_groups = {
    'Catchments_morphology_and_river_network': ['chanbnkf', 'chanflpn', 'changrad', 'chanlength', 'chanman', 'chans', 'chanbw'],
    'Land_use': ['fracforest', 'fracirrigated', 'fracrice', 'fracsealed', 'fracwater', 'fracother'],
}

for group_dir, var_names in glofas_variable_groups.items():
    for var_name in var_names:
        print(var_name)
        # Load the 3-arc-minute map cropped to the buffered basin extent
        ds = xr.open_dataset(os.path.join(Parameter_Maps, group_dir, f"{var_name}_Global_03min.nc"))['Band1']
        ds = ds.sel(
            lat = slice(region_bounds['maxy'], region_bounds['miny']),
            lon = slice(region_bounds['minx'], region_bounds['maxx'])
        )
        ds.load()

        def process(idx, row):
            """Write the node-cell mean of the current ``var_name`` map for one gauge.

            For every node in ``nodes_coords.csv`` the map is averaged over the
            ``resolution``-wide window centred on the node; the single-row
            result (row label ``0``) is saved to
            ``graph_features/<basin>/<gauge>/static/GloFAS/<var_name>.csv``.
            ``idx`` is not used.
            """
            huc, gauge_id = basin_name, row['River Point Name']
            dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
            nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
            data = pd.DataFrame(columns = nodes_coords.index, index = [0])
            for node_idx, node_row in nodes_coords.iterrows():
                lat, lon = node_row['lat'], node_row['lon']
                # lat slice runs high -> low to match the crop above
                ds_window_loc = ds.sel(
                    lat = slice(lat + 0.5*resolution, lat - 0.5*resolution),
                    lon = slice(lon - 0.5*resolution, lon + 0.5*resolution)
                ).mean()
                data.loc[0, node_idx] = ds_window_loc.values.item()
            out_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id, 'static', 'GloFAS')
            os.makedirs(out_dir, exist_ok = True)
            data.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

        with Parallel(n_jobs = 8, verbose = 0) as parallel:
            _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(indiawris_graph.iterrows(), total=len(indiawris_graph)))

        ds.close()
        del ds
        gc.collect()

chanbnkf


100%|██████████| 2/2 [00:00<00:00, 3318.28it/s]


chanflpn


100%|██████████| 2/2 [00:00<00:00, 2024.28it/s]


changrad


100%|██████████| 2/2 [00:00<00:00, 4371.34it/s]


chanlength


100%|██████████| 2/2 [00:00<00:00, 4275.54it/s]


chanman


100%|██████████| 2/2 [00:00<00:00, 3524.63it/s]


chans


100%|██████████| 2/2 [00:00<00:00, 4857.33it/s]


chanbw


100%|██████████| 2/2 [00:00<00:00, 4571.45it/s]


fracforest


100%|██████████| 2/2 [00:00<00:00, 4583.94it/s]


fracirrigated


100%|██████████| 2/2 [00:00<00:00, 5155.87it/s]


fracrice


100%|██████████| 2/2 [00:00<00:00, 5096.36it/s]


fracsealed


100%|██████████| 2/2 [00:00<00:00, 5084.00it/s]


fracwater


100%|██████████| 2/2 [00:00<00:00, 4038.81it/s]


fracother


100%|██████████| 2/2 [00:00<00:00, 4752.75it/s]
