# Setting Up

In [1]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import networkx as nx

import geopandas as gpd
import seaborn as sns
import matplotlib.pyplot as plt

from shapely.geometry import Point
from shapely.geometry import Polygon

import glob
import os
import itertools
import tqdm
import gc
import time
import pickle

from joblib import Parallel, delayed

In [2]:
import configparser
# Load the shared project INI file and flatten it into plain nested dicts:
# {section name: {option name: value}}.
ini_parser = configparser.ConfigParser()
# ConfigParser lower-cases option names by default; `str` keeps them as written.
ini_parser.optionxform = str
ini_parser.read('/data/sarth/rootdir/assets/global.ini')
cfg = {}
for section_name in ini_parser.sections():
    cfg[section_name] = dict(ini_parser.items(section_name))
# Directory lookup table used by every later cell.
PATHS = cfg['PATHS']

In [3]:
# Dataset folder name; SAVE_PATH is the root under which this notebook reads
# 'graph_attributes.csv' and 'graph_files/' and writes 'graph_features/'.
DIRNAME = '03min_GloFAS_CAMELS-IND'
SAVE_PATH = os.path.join(PATHS['devp_datasets'], DIRNAME)
# Width (in degrees) of the lat/lon window used when sampling around a node.
# NOTE(review): presumably the 3 arc-minute GloFAS cell size (3/60 = 0.05) - confirm.
resolution = 0.05


def lon_360_180(x):
    """Wrap a longitude from the 0-360 convention into the -180-180 convention."""
    return (x + 180) % 360 - 180


def lon_180_360(x):
    """Wrap a longitude from the -180-180 convention into the 0-360 convention."""
    return x % 360


# Bounding box (degrees) used to crop every gridded dataset in this notebook.
region_bounds = dict(minx=66, miny=5, maxx=100, maxy=30)

# Load Watershed Attributes

In [4]:
def _zero_padded(width):
    """Return a formatter that renders a value as a string left-padded with zeros to `width`."""
    def _format(value):
        return str(value).zfill(width)
    return _format


camels_attributes_graph = pd.read_csv(
    os.path.join(SAVE_PATH, 'graph_attributes.csv'),
    index_col=0,
)
# Restore fixed-width codes: 5 characters for the gauge-id index, 2 for the HUC code.
# NOTE(review): presumably read_csv parsed them as numbers and dropped leading zeros - confirm.
camels_attributes_graph.index = camels_attributes_graph.index.map(_zero_padded(5))
camels_attributes_graph['huc_02'] = camels_attributes_graph['huc_02'].map(_zero_padded(2))
camels_attributes_graph

Unnamed: 0_level_0,huc_02,ghi_lon,ghi_lat,ghi_area,cwc_lon,cwc_lat,cwc_area,cwc_site_name,ghi_stn_id,river_basin,cwc_river,flow_availability,snapped_lon,snapped_lat,snapped_uparea,snapped_iou,area_percent_difference,num_nodes,num_edges
gauge_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
14015,14,73.11033,18.73540,125.7,73.1108,18.7367,125.0,Pen,wfrn_penxx,WFRN,Bhogeswari,31.36,73.125,18.725,116.744995,0.650870,7.124106,4.0,3.0
14009,14,73.28122,18.23122,292.3,73.2833,18.2317,259.0,Mangaon (Seasonal),wfrn_manga,WFRN,Savitri/Kal,67.40,73.325,18.225,380.302100,0.520864,30.106777,13.0,12.0
15006,15,74.88124,13.51876,299.6,74.8800,13.5214,253.0,Avershe,wfrs_avers,WFRS,Seetha,39.00,74.925,13.475,329.444850,0.607418,9.961565,11.0,10.0
05025,05,78.05617,11.93959,356.0,78.0572,11.9383,362.0,Thoppur,cauv_thopp,Cauvery,Cauvery/Thoppaiyar,31.52,78.125,11.975,331.356480,0.634196,6.922338,11.0,10.0
15032,15,74.98123,13.29791,356.9,74.9806,13.2942,327.0,Yennehole,wfrs_yenne,WFRS,Swarna,70.67,74.975,13.275,329.717100,0.755086,7.616389,11.0,10.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
04062,04,80.06876,16.79375,240055.4,80.0692,16.7942,235544.0,Wadenapally,kris_waden,Krishna,Krishna,95.43,80.075,16.775,240186.480000,0.981840,0.054603,8131.0,8130.0
04060,04,80.61874,16.49796,257260.0,80.6250,16.5011,251360.0,Vijayawada,kris_vijay,Krishna,Krishna,84.90,80.625,16.475,256368.720000,0.978610,0.346452,8681.0,8680.0
03073,03,80.39375,18.58542,267340.4,80.3958,18.5872,268200.0,Perur,goda_perur,Godavari,Godavari,94.32,80.375,18.625,267313.380000,0.981013,0.010111,9218.0,9217.0
03043,03,81.38540,17.48542,304628.7,81.3875,17.4875,305460.0,Koida,goda_koida,Godavari,Godavari,64.48,81.375,17.475,304718.620000,0.983208,0.029524,10496.0,10495.0


In [5]:
# Apply the catchment selection rules one after another, printing the table
# shape before the first rule and after each rule so the attrition is visible.
camels_graph = camels_attributes_graph.copy()
print(camels_graph.shape)
for keep_rows in (
    lambda frame: frame['ghi_area'] <= 30000,               # drop the largest basins
    lambda frame: frame['area_percent_difference'] < 10,    # area mismatch below 10
    lambda frame: frame['num_nodes'] > 1,                   # graph must have more than one node
):
    camels_graph = camels_graph[keep_rows(camels_graph)]
    print(camels_graph.shape)
# Number of retained catchments per 'huc_02', ordered by the code.
camels_graph.sort_values(by='huc_02', ascending=True).groupby('huc_02').size()

(242, 19)
(200, 19)
(191, 19)
(191, 19)


huc_02
03    37
04    31
05    16
06     4
07     7
08    16
09     5
10     5
11     5
12    13
13     6
14     8
15    21
16     5
17    12
dtype: int64

In [6]:
# Summary statistics of the 'ghi_area' column for the retained catchments.
camels_graph['ghi_area'].describe()

count      191.000000
mean      5765.108901
std       5827.288252
min        125.700000
25%       1757.100000
50%       3442.600000
75%       8072.400000
max      29822.500000
Name: ghi_area, dtype: float64

In [7]:
# Summary statistics of the per-catchment graph size ('num_nodes').
camels_graph['num_nodes'].describe()

count     191.000000
mean      197.267016
std       200.706979
min         4.000000
25%        60.000000
50%       117.000000
75%       270.500000
max      1034.000000
Name: num_nodes, dtype: float64

In [8]:
# Drop the unfiltered table; only `camels_graph` is used from here on.
# Re-running the filtering cell after this requires re-running the load cell first.
del camels_attributes_graph

# Create Node Features as csv

In [9]:
# Root output folder for all per-catchment feature tables written below.
os.makedirs(os.path.join(SAVE_PATH, "graph_features"), exist_ok = True)

In [10]:
# Open the GloFAS 'ldd' variable and crop it to the study region. Latitude is
# sliced from maxy down to miny, i.e. the file is indexed with descending latitude.
ldd = xr.open_dataset(os.path.join(PATHS['gis_ldd'], 'GloFAS_03min', 'ldd.nc'))['ldd']
ldd = ldd.sel(
    lat=slice(region_bounds['maxy'], region_bounds['miny']),
    lon=slice(region_bounds['minx'], region_bounds['maxx']),
)

lons = ldd['lon'].values
lats = ldd['lat'].values

# Coordinate-only dataset describing the target grid handed to the regridders below.
ds_grid = xr.Dataset({
    'lat': (['lat'], lats),
    'lon': (['lon'], lons),
})

# Round both axes to 3 decimals; later cells also round node coordinates to
# 3 decimals before label-based lookups on this grid.
for axis_name in ('lat', 'lon'):
    ds_grid[axis_name] = ds_grid[axis_name].round(3)

## ERA5

### Dynamic

In [None]:
# var_names = [
#     '2m_temperature', 
#     'evaporation', 
#     'snowfall', 
#     'surface_net_solar_radiation', 
#     'surface_net_thermal_radiation', 
#     'surface_pressure', 
#     'total_precipitation',
#     '2m_dewpoint_temperature',
#     '10m_u_component_of_wind',
#     '10m_v_component_of_wind',
#     'forecast_albedo',
#     'potential_evaporation',
#     'runoff',
#     'snow_albedo',
#     'snow_depth',
#     'snowmelt',
#     'sub_surface_runoff',
#     'surface_runoff',
#     'total_column_water',
#     'volumetric_soil_water_layer_1',
#     'volumetric_soil_water_layer_2',
#     'volumetric_soil_water_layer_3',
#     'volumetric_soil_water_layer_4'
# ]

# dates = pd.date_range('1980-01-01', '2020-12-31', freq='D')
# dates = dates[~((dates.month == 2) & (dates.day == 29))]
# print(f"Number of dates: {len(dates)}")

# def process(idx, row, var_name):
#     huc, gauge_id = row['huc_02'], row.name
#     nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
#     data = pd.DataFrame(index = dates, columns = nodes_coords.index)
#     os.makedirs(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic'), exist_ok = True)
#     os.makedirs(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'ERA5'), exist_ok = True)
#     data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'ERA5', f"{var_name}.csv"))

# for var_name in var_names:
#     print(var_name)
#     with Parallel(n_jobs = 8, verbose = 0) as parallel:
#         _ = parallel(delayed(process)(idx, row, var_name) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

In [None]:
# for var_name in itertools.islice(var_names, 0, None, 1):
#     print(var_name)
#     ds = xr.open_mfdataset(os.path.join(PATHS['RawData'], 'ERA5', var_name, f"*.nc"), combine='by_coords')
#     ds_var_name = list(ds.data_vars)[0]
#     ds = ds[ds_var_name]
#     ds = ds.rename({'longitude': 'lon', 'latitude': 'lat'})
#     ds = ds.sel(time=~((ds['time.month'] == 2) & (ds['time.day'] == 29)))
#     ds['lon'] = [lon_360_180(lon) for lon in ds['lon'].values]
#     ds = ds.sortby('lon')
#     ds = ds.sel(
#         lat = slice(region_bounds['maxy'], region_bounds['miny']), 
#         lon = slice(region_bounds['minx'], region_bounds['maxx'])
#     )
#     _, index = np.unique(ds['time'], return_index = True)
#     ds = ds.isel(time = index)

#     if os.path.exists(os.path.join(PATHS['Assets'], 'regridder_era5_to_glofas_03min_IND.nc')):
#         regridder = xe.Regridder(
#             ds, 
#             ds_grid, 
#             'bilinear', 
#             reuse_weights=True, 
#             filename = os.path.join(PATHS['Assets'], 'regridder_era5_to_glofas_03min_IND.nc')
#         )
#     else:
#         regridder = xe.Regridder(
#             ds, 
#             ds_grid, 
#             'bilinear', 
#             reuse_weights=False
#         )
#         regridder.to_netcdf(os.path.join(PATHS['Assets'], 'regridder_era5_to_glofas_03min_IND.nc'))
    
#     ds_regrided = regridder(ds)
#     ds.close()
#     start_time = time.time()
#     ds_regrided.load()
#     end_time = time.time()
#     print(f'Time: {((end_time - start_time) / 60):.4f} mins')
    
#     def process(idx, row):
#         huc, gauge_id = row['huc_02'], row.name
#         nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
#         data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'ERA5', f"{var_name}.csv"), index_col = 0, parse_dates = True)
#         for node_idx, node_row in nodes_coords.iterrows():
#             lat, lon = node_row['lat'], node_row['lon']
#             ds_window_loc = ds_regrided.sel(lat = lat, lon = lon, method = 'nearest')
#             data.loc[:, str(node_idx)] = ds_window_loc.values
#         data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'ERA5', f"{var_name}.csv"))

#     with Parallel(n_jobs = 8, verbose = 0) as parallel:
#         _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

#     ds.close()
#     del ds
#     gc.collect()

### Static

In [11]:
# Output feature names and the ERA5 file each one is read from (paired by position).
var_names = [
    'static_soil_type',
    'static_high_vegetation_cover',
    'static_low_vegetation_cover',
    'static_type_of_high_vegetation',
    'static_type_of_low_vegetation'
    ]
ds_filenames = [
    'soil_type_static.nc',
    'high_vegetation_cover_static.nc',
    'low_vegetation_cover_static.nc',
    'type_of_high_vegetation_static.nc',
    'type_of_low_vegetation_static.nc'
]
assert len(var_names) == len(ds_filenames), "var_names and ds_filenames must pair up"

# Fields holding class codes, stored as integers. The two *_cover fields are
# fractions in [0, 1]; casting them to int would collapse every value to 0 or 1,
# so they are stored as floats instead.
categorical_var_names = {
    'static_soil_type',
    'static_type_of_high_vegetation',
    'static_type_of_low_vegetation',
}


def process(idx, row):
    """Sample the current ERA5 static field at every node of one catchment and save it.

    Reads the notebook globals `ds` (the cropped field) and `var_name` (its
    output name), both set by the loop below. Each node takes the value of the
    nearest ERA5 cell. The result is a one-row CSV with one column per node,
    written to graph_features/<huc>/<gauge_id>/static/ERA5/<var_name>.csv.

    Parameters
    ----------
    idx : label of the row in `camels_graph` (unused; kept for the call signature).
    row : pandas.Series with a 'huc_02' entry and the gauge id as its `.name`.
    """
    huc, gauge_id = row['huc_02'], row.name
    nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
    is_categorical = var_name in categorical_var_names
    data = pd.DataFrame(columns = nodes_coords.index, index = [0])
    for node_idx, node_row in nodes_coords.iterrows():
        lat, lon = node_row['lat'], node_row['lon']
        value = ds.sel(lat = lat, lon = lon, method = 'nearest').values
        data.loc[0, node_idx] = int(value) if is_categorical else float(value)
    # makedirs creates the intermediate 'static' folder as well.
    out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'static', 'ERA5')
    os.makedirs(out_dir, exist_ok = True)
    data.to_csv(os.path.join(out_dir, f"{var_name}.csv"))


for var_name, ds_filename in zip(var_names, ds_filenames):
    print(var_name)
    # The context manager closes the file handle; the cropped field is loaded
    # into memory first so the per-node lookups do not hit the disk.
    with xr.open_dataset(os.path.join(PATHS['RawData'], 'ERA5', ds_filename)) as raw:
        ds_var_name = list(raw.data_vars)[0]
        ds = raw[ds_var_name]
        # Static field: keep the first time step and discard the time coordinate.
        ds = ds.isel(time = 0)
        ds = ds.drop_vars('time')
        ds = ds.rename({'longitude': 'lon', 'latitude': 'lat'})
        # Convert longitudes to -180..180 and restore ascending order before slicing.
        ds['lon'] = [lon_360_180(lon) for lon in ds['lon'].values]
        ds = ds.sortby('lon')
        ds = ds.sel(
            lat = slice(region_bounds['maxy'], region_bounds['miny']),
            lon = slice(region_bounds['minx'], region_bounds['maxx'])
        ).load()

    for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)):
        process(idx, row)

static_soil_type


100%|██████████| 191/191 [00:32<00:00,  5.86it/s]


static_high_vegetation_cover


100%|██████████| 191/191 [00:30<00:00,  6.35it/s]


static_low_vegetation_cover


100%|██████████| 191/191 [00:30<00:00,  6.35it/s]


static_type_of_high_vegetation


100%|██████████| 191/191 [00:30<00:00,  6.27it/s]


static_type_of_low_vegetation


100%|██████████| 191/191 [00:30<00:00,  6.28it/s]


## HWSD

In [12]:
# Eight HWSD fields: S_* and T_* layers, each with clay / gravel / sand / silt.
var_names = [
    f'{layer}_{fraction}'
    for layer in ('S', 'T')
    for fraction in ('CLAY', 'GRAVEL', 'SAND', 'SILT')
]

for var_name in var_names:
    print(var_name)
    ds = xr.open_dataset(os.path.join(PATHS['HWSD'], f'{var_name}.nc4'))
    ds_var_name = list(ds.data_vars)[0]
    # Latitude is sliced from miny up to maxy here (ascending), unlike the GloFAS grid.
    ds = ds[ds_var_name].sel(
        lat=slice(region_bounds['miny'], region_bounds['maxy']),
        lon=slice(region_bounds['minx'], region_bounds['maxx']),
    )
    # Scale by 1/100 (NOTE(review): presumably percent -> fraction; confirm units).
    ds = ds / 100
    ds.load()

    def process(idx, row):
        """Average the current HWSD field over a `resolution`-wide box around each node.

        Uses the notebook globals `ds` and `var_name`. Writes a one-row CSV
        (one column per node) to graph_features/<huc>/<gauge_id>/static/HWSD/.
        """
        huc, gauge_id = row['huc_02'], row.name
        coords_file = os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv')
        nodes_coords = pd.read_csv(coords_file, index_col=0)
        data = pd.DataFrame(columns=nodes_coords.index, index=[0])
        for node_idx, node_row in nodes_coords.iterrows():
            lat, lon = node_row['lat'], node_row['lon']
            window = ds.sel(
                lat=slice(lat - resolution / 2, lat + resolution / 2),
                lon=slice(lon - resolution / 2, lon + resolution / 2),
            )
            # Plain mean: any NaN cell in the window makes the node value NaN.
            data.loc[0, node_idx] = window.values.mean()
        out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'static', 'HWSD')
        os.makedirs(out_dir, exist_ok=True)
        data.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

    for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)):
        process(idx, row)

    # Release the field before moving on to the next variable.
    ds.close()
    del ds
    gc.collect()

S_CLAY


100%|██████████| 191/191 [00:10<00:00, 18.83it/s]


S_GRAVEL


100%|██████████| 191/191 [00:09<00:00, 19.65it/s]


S_SAND


100%|██████████| 191/191 [00:09<00:00, 19.48it/s]


S_SILT


100%|██████████| 191/191 [00:09<00:00, 19.44it/s]


T_CLAY


100%|██████████| 191/191 [00:09<00:00, 19.36it/s]


T_GRAVEL


100%|██████████| 191/191 [00:10<00:00, 19.09it/s]


T_SAND


100%|██████████| 191/191 [00:09<00:00, 19.57it/s]


T_SILT


100%|██████████| 191/191 [00:09<00:00, 19.53it/s]


## GLEAM

In [38]:
var_names = ['Ep', 'SMroot', 'SMsurf']

# Daily index for 1980-2020 with every 29 February removed.
all_days = pd.date_range('1980-01-01', '2020-12-31', freq='D')
dates = all_days[(all_days.month != 2) | (all_days.day != 29)]
print(f"Number of dates: {len(dates)}")


def process(idx, row, var_name):
    """Write an empty (all-NaN) dates x nodes table for one catchment and one variable.

    The table is saved to graph_features/<huc>/<gauge_id>/dynamic/GLEAM/<var_name>.csv
    and overwrites any existing file of that name; the next cell fills it in.
    """
    huc, gauge_id = row['huc_02'], row.name
    coords_file = os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv')
    nodes_coords = pd.read_csv(coords_file, index_col=0)
    out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM')
    os.makedirs(out_dir, exist_ok=True)
    empty_table = pd.DataFrame(index=dates, columns=nodes_coords.index)
    empty_table.to_csv(os.path.join(out_dir, f"{var_name}.csv"))


for var_name in var_names:
    print(var_name)
    with Parallel(n_jobs=8, verbose=0) as parallel:
        jobs = (
            delayed(process)(idx, row, var_name)
            for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph))
        )
        _ = parallel(jobs)

Number of dates: 14965
Ep


100%|██████████| 191/191 [00:02<00:00, 85.64it/s]


SMroot


100%|██████████| 191/191 [00:02<00:00, 86.09it/s]


SMsurf


100%|██████████| 191/191 [00:02<00:00, 90.00it/s]


In [39]:
# islice(..., 0, None, 1) walks the whole list; change the start index to resume part-way.
for var_name in itertools.islice(var_names, 0, None, 1):
    print(var_name)
    ds = xr.open_mfdataset(os.path.join(PATHS['GLEAM'], var_name, f"*.nc"), combine='by_coords')
    ds_var_name = list(ds.data_vars)[0]
    ds = ds[ds_var_name]
    # Drop every 29 February so the time axis lines up with the table index.
    is_leap_day = (ds['time.month'] == 2) & (ds['time.day'] == 29)
    ds = ds.sel(time=~is_leap_day)
    ds = ds.sel(
        lat=slice(region_bounds['maxy'], region_bounds['miny']),
        lon=slice(region_bounds['minx'], region_bounds['maxx']),
    )

    # Bilinear regridder onto `ds_grid`; weights are computed once and cached on disk.
    weights_file = os.path.join(PATHS['Assets'], 'regridder_gleam_to_glofas_03min_IND.nc')
    if not os.path.exists(weights_file):
        regridder = xe.Regridder(ds, ds_grid, 'bilinear', reuse_weights=False)
        regridder.to_netcdf(weights_file)
    else:
        regridder = xe.Regridder(ds, ds_grid, 'bilinear', reuse_weights=True, filename=weights_file)

    ds_regrided = regridder(ds)
    ds.close()
    # Time how long it takes to pull the regridded field into memory.
    start_time = time.time()
    ds_regrided.load()
    end_time = time.time()
    print(f'Time: {((end_time - start_time) / 60):.4f} mins')

    def process(idx, row):
        """Fill one catchment's table for `var_name` from the nearest regridded cell of each node."""
        huc, gauge_id = row['huc_02'], row.name
        coords_file = os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv')
        table_file = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv")
        nodes_coords = pd.read_csv(coords_file, index_col=0)
        data = pd.read_csv(table_file, index_col=0, parse_dates=True)
        for node_idx, node_row in nodes_coords.iterrows():
            lat, lon = node_row['lat'], node_row['lon']
            series_at_node = ds_regrided.sel(lat=lat, lon=lon, method='nearest')
            # Column labels are strings after the CSV round-trip.
            data.loc[:, str(node_idx)] = series_at_node.values
        data.to_csv(table_file)

    with Parallel(n_jobs=8, verbose=0) as parallel:
        jobs = (
            delayed(process)(idx, row)
            for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph))
        )
        _ = parallel(jobs)

    ds.close()
    del ds
    gc.collect()

Ep
Time: 1.6640 mins


100%|██████████| 191/191 [02:03<00:00,  1.55it/s]


SMroot
Time: 1.2847 mins


100%|██████████| 191/191 [02:09<00:00,  1.48it/s]


SMsurf
Time: 1.3772 mins


100%|██████████| 191/191 [02:16<00:00,  1.40it/s]


### Fix NaNs

In [40]:
var_names = ['Ep', 'SMroot', 'SMsurf']
for var_name in var_names:
    ds = xr.open_mfdataset(os.path.join(PATHS['GLEAM'], var_name, f"*.nc"), combine='by_coords')
    ds_var_name = list(ds.data_vars)[0]
    ds = ds[ds_var_name]
    ds = ds.sel(time=~((ds['time.month'] == 2) & (ds['time.day'] == 29)))
    ds = ds.sel(
        lat = slice(region_bounds['maxy'], region_bounds['miny']), 
        lon = slice(region_bounds['minx'], region_bounds['maxx'])
    )
    if os.path.exists(os.path.join(PATHS['Assets'], 'regridder_gleam_to_glofas_03min_IND.nc')):
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=True, 
            filename = os.path.join(PATHS['Assets'], 'regridder_gleam_to_glofas_03min_IND.nc')
        )
    else:
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=False
        )
        regridder.to_netcdf(os.path.join(PATHS['Assets'], 'regridder_gleam_to_glofas_03min_IND.nc'))
    ds_regrided = regridder(ds)
    ds.close()
    start_time = time.time()
    ds_regrided.load()
    end_time = time.time()
    print(f'{var_name} (Time: {((end_time - start_time) / 60):.4f} mins)')

    # Loop over catchments and find ones with issues
    issues = []
    for idx, row in tqdm.tqdm(camels_graph.iterrows()):
        huc, gauge_id = row['huc_02'], row.name
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        if data.isnull().values.any():
            issues.append([huc, gauge_id])
    issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
    print(f"Number of catchments with issues: {issues.shape[0]}")
    print("------")

    # Fix the catchments with issues
    for issue_idx, (huc, gauge_id) in enumerate(issues.values):
        print(issue_idx, huc, gauge_id)
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        nodes_coords['isNaN'] = False
        nodes_coords['nonNaNneighbours'] = 0
        # Loop over nodes and find the nodes with issues
        for node_idx in nodes_coords.index:
            if data[str(node_idx)].isnull().values.any():
                nodes_coords.loc[node_idx, 'isNaN'] = True
                node_lat = float(round(nodes_coords.loc[node_idx, 'lat'], 3))
                node_lon = float(round(nodes_coords.loc[node_idx, 'lon'], 3))
                multiplier = 1.5
                ds_slice = ds_regrided.sel(
                    lat = slice(node_lat+multiplier*resolution, node_lat-multiplier*resolution), 
                    lon = slice(node_lon-multiplier*resolution, node_lon+multiplier*resolution)
                    )
                slice_df = ds_slice.to_dataframe(name = var_name).reset_index()
                slice_df['lat'] = slice_df['lat'].round(3)
                slice_df['lon'] = slice_df['lon'].round(3)
                slice_df['location'] = list(zip(slice_df['lat'], slice_df['lon']))
                slice_df = slice_df.pivot(index='time', columns='location', values=var_name)
                num_nan_nodes = slice_df.isnull().any(axis=0).sum()
                num_nonnan_nodes = len(slice_df.columns) - num_nan_nodes
                nodes_coords.loc[node_idx, 'nonNaNneighbours'] = num_nonnan_nodes
        nodes_coords_sorted = nodes_coords.sort_values(by = 'nonNaNneighbours', ascending = False)
        nodes_coords_sorted = nodes_coords_sorted[nodes_coords_sorted['isNaN']]
        print(f"Number of nodes with NaN values: {nodes_coords_sorted.shape[0]}")
        
        for node_idx in tqdm.tqdm(nodes_coords_sorted.index):
            node_lat, node_lon = float(round(nodes_coords.loc[node_idx, 'lat'], 3)), float(round(nodes_coords.loc[node_idx, 'lon'], 3))
            multiplier = 1.5
            ds_slice = ds_regrided.sel(
                lat = slice(node_lat+multiplier*resolution, node_lat-multiplier*resolution), 
                lon = slice(node_lon-multiplier*resolution, node_lon+multiplier*resolution)
                )
            slice_df = ds_slice.to_dataframe(name = var_name).reset_index()
            slice_df['lat'] = slice_df['lat'].round(3)
            slice_df['lon'] = slice_df['lon'].round(3)
            slice_df['location'] = list(zip(slice_df['lat'], slice_df['lon']))
            slice_df = slice_df.pivot(index='time', columns='location', values=var_name)
            slice_df.columns = list(map(str, slice_df.columns))
            num_nonnan_nodes = len(slice_df.columns) - slice_df.isnull().any(axis=0).sum()
            # print(node_idx, (node_lat, node_lon), num_nonnan_nodes)
            if num_nonnan_nodes == 9:
                replacement_values = slice_df.loc[:, f"({node_lat}, {node_lon})"]
                data.loc[:, str(node_idx)] = replacement_values
                nodes_coords_sorted.loc[node_idx, 'isNaN'] = False
            elif num_nonnan_nodes > 0:
                replacement_values = np.nanmean(slice_df, axis = 1)
                data.loc[:, str(node_idx)] = replacement_values
                ds_regrided.loc[dict(lat = node_lat, lon = node_lon)] = replacement_values
                nodes_coords_sorted.loc[node_idx, 'isNaN'] = False
        print(f"Number of nodes with NaN values: {nodes_coords_sorted['isNaN'].sum()}")
        print(issue_idx, huc, gauge_id, data.isnull().values.any())
        data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"))
        print("------")

Ep (Time: 1.6138 mins)


191it [02:00,  1.58it/s]


Number of catchments with issues: 6
------
0 15 15014
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00,  9.77it/s]


Number of nodes with NaN values: 0
0 15 15014 False
------
1 15 15023
Number of nodes with NaN values: 3


100%|██████████| 3/3 [00:00<00:00, 10.33it/s]


Number of nodes with NaN values: 0
1 15 15023 False
------
2 05 05013
Number of nodes with NaN values: 6


100%|██████████| 6/6 [00:00<00:00, 10.34it/s]


Number of nodes with NaN values: 0
2 05 05013 False
------
3 15 15026
Number of nodes with NaN values: 7


100%|██████████| 7/7 [00:00<00:00,  9.58it/s]


Number of nodes with NaN values: 0
3 15 15026 False
------
4 05 05028
Number of nodes with NaN values: 6


100%|██████████| 6/6 [00:00<00:00, 10.66it/s]


Number of nodes with NaN values: 0
4 05 05028 False
------
5 11 11010
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 11.43it/s]


Number of nodes with NaN values: 0
5 11 11010 False
------
SMroot (Time: 1.5371 mins)


191it [01:29,  2.15it/s]


Number of catchments with issues: 6
------
0 15 15014
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 11.50it/s]


Number of nodes with NaN values: 0
0 15 15014 False
------
1 15 15023
Number of nodes with NaN values: 3


100%|██████████| 3/3 [00:00<00:00, 10.81it/s]


Number of nodes with NaN values: 0
1 15 15023 False
------
2 05 05013
Number of nodes with NaN values: 6


100%|██████████| 6/6 [00:00<00:00, 11.27it/s]


Number of nodes with NaN values: 0
2 05 05013 False
------
3 15 15026
Number of nodes with NaN values: 7


100%|██████████| 7/7 [00:00<00:00, 10.87it/s]


Number of nodes with NaN values: 0
3 15 15026 False
------
4 05 05028
Number of nodes with NaN values: 6


100%|██████████| 6/6 [00:00<00:00, 10.94it/s]


Number of nodes with NaN values: 0
4 05 05028 False
------
5 11 11010
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 11.47it/s]


Number of nodes with NaN values: 0
5 11 11010 False
------
SMsurf (Time: 1.2104 mins)


191it [01:51,  1.72it/s]


Number of catchments with issues: 6
------
0 15 15014
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 10.10it/s]


Number of nodes with NaN values: 0
0 15 15014 False
------
1 15 15023
Number of nodes with NaN values: 3


100%|██████████| 3/3 [00:00<00:00, 10.60it/s]


Number of nodes with NaN values: 0
1 15 15023 False
------
2 05 05013
Number of nodes with NaN values: 6


100%|██████████| 6/6 [00:00<00:00, 10.27it/s]


Number of nodes with NaN values: 0
2 05 05013 False
------
3 15 15026
Number of nodes with NaN values: 7


100%|██████████| 7/7 [00:00<00:00, 10.54it/s]


Number of nodes with NaN values: 0
3 15 15026 False
------
4 05 05028
Number of nodes with NaN values: 6


100%|██████████| 6/6 [00:00<00:00, 10.39it/s]


Number of nodes with NaN values: 0
4 05 05028 False
------
5 11 11010
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 11.82it/s]


Number of nodes with NaN values: 0
5 11 11010 False
------


In [None]:
# Read-only report: after the neighbourhood fill, list which catchments and how
# many nodes per catchment still contain NaNs in their GLEAM tables.
var_names = ['Ep', 'SMroot', 'SMsurf']
for var_name in var_names:
    # Scan every catchment table for missing values.
    issues = []
    for idx, row in tqdm.tqdm(camels_graph.iterrows()):
        huc, gauge_id = row['huc_02'], row.name
        coords_file = os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv')
        table_file = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv")
        nodes_coords = pd.read_csv(coords_file, index_col=0)
        data = pd.read_csv(table_file, index_col=0, parse_dates=True)
        if not data.isnull().values.any():
            continue
        issues.append([huc, gauge_id])
    issues = pd.DataFrame(issues, columns=['huc_02', 'gauge_id'])
    print(f"Number of catchments with issues: {issues.shape[0]}")
    print("------")

    # Count the affected nodes in each flagged catchment.
    for issue_idx, (huc, gauge_id) in enumerate(issues.values):
        print(issue_idx, huc, gauge_id)
        coords_file = os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv')
        table_file = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv")
        nodes_coords = pd.read_csv(coords_file, index_col=0)
        data = pd.read_csv(table_file, index_col=0, parse_dates=True)
        # A node is flagged when any time step of its column is missing.
        nodes_coords['isNaN'] = [
            bool(data[str(node_idx)].isnull().values.any())
            for node_idx in nodes_coords.index
        ]
        nodes_coords['nonNaNneighbours'] = 0
        print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")
        print("------")

In [None]:
# Last-resort repair: fill every remaining NaN node with the mean series of the
# nearest complete node(s) within the same catchment graph.
var_names = ['Ep', 'SMroot', 'SMsurf']
for var_name in var_names:
    # Loop over catchments and find ones with issues
    issues = []
    for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)):
        huc, gauge_id = row['huc_02'], row.name
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        if data.isnull().values.any():
            issues.append([huc, gauge_id])
    issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
    print(f"Number of catchments with issues: {issues.shape[0]}")
    print("------")

    # Fix the catchments with issues
    for issue_idx, (huc, gauge_id) in enumerate(issues.values):
        print(issue_idx, huc, gauge_id)
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        nodes_coords['isNaN'] = False
        # Loop over nodes and find the nodes with issues
        for node_idx in nodes_coords.index:
            if data[str(node_idx)].isnull().values.any():
                nodes_coords.loc[node_idx, 'isNaN'] = True
        print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")

        # Coordinates rounded once, as float columns, for vectorised distance computation.
        node_lats = nodes_coords['lat'].astype(float).round(3)
        node_lons = nodes_coords['lon'].astype(float).round(3)

        for node_idx in tqdm.tqdm(nodes_coords[nodes_coords['isNaN']].index):
            # Donors are all currently complete nodes; nodes repaired earlier in this
            # loop have been un-flagged and therefore count as donors too.
            donors = nodes_coords.index[~nodes_coords['isNaN']]
            if len(donors) == 0:
                # Nothing to borrow from: leave the remaining nodes NaN instead of crashing.
                print("No complete node available in this catchment; remaining nodes left as NaN")
                break
            # Plain Euclidean distance in degrees from this node to every donor.
            donor_distances = np.sqrt(
                (node_lats.loc[donors] - node_lats.loc[node_idx])**2
                + (node_lons.loc[donors] - node_lons.loc[node_idx])**2
            )
            # Kept as a float (NaN-filled) column: an object column seeded with None
            # cannot be reduced with idxmin/min in pandas.
            nodes_coords['distances'] = np.nan
            nodes_coords.loc[donors, 'distances'] = donor_distances
            min_distance = donor_distances.min()
            # Replace with mean of nodes having distance equal to min_distance.
            # isclose (not ==) so equidistant donors that differ only by rounding all count.
            replacement_nodes = donor_distances.index[np.isclose(donor_distances, min_distance)]
            replacement_nodes = list(map(str, replacement_nodes))
            replacement_values = data.loc[:, replacement_nodes].mean(axis = 1)
            data.loc[:, str(node_idx)] = replacement_values
            nodes_coords.loc[node_idx, 'isNaN'] = False
        print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")
        print(issue_idx, huc, gauge_id, data.isnull().values.any())
        data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM', f"{var_name}.csv"))
        print("------")

## GLEAM4

In [11]:
var_names = ['Ep', 'SMrz', 'SMs', 'Eb', 'Ei', 'Es', 'Et', 'Ew', 'S', 'H']

# Daily index for 1980-2020 without any 29 February.
full_calendar = pd.date_range('1980-01-01', '2020-12-31', freq='D')
is_leap_day = (full_calendar.month == 2) & (full_calendar.day == 29)
dates = full_calendar[~is_leap_day]
print(f"Number of dates: {len(dates)}")


def process(idx, row, var_name):
    """Create the empty dates x nodes CSV for one catchment and one GLEAM4 variable.

    Output goes to graph_features/<huc>/<gauge_id>/dynamic/GLEAM4/<var_name>.csv
    and replaces any existing file of that name.
    """
    huc, gauge_id = row['huc_02'], row.name
    nodes_coords = pd.read_csv(
        os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'),
        index_col=0,
    )
    target_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4')
    os.makedirs(target_dir, exist_ok=True)
    placeholder = pd.DataFrame(index=dates, columns=nodes_coords.index)
    placeholder.to_csv(os.path.join(target_dir, f"{var_name}.csv"))


for var_name in var_names:
    print(var_name)
    with Parallel(n_jobs=8, verbose=0) as parallel:
        catchments = tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph))
        _ = parallel(delayed(process)(idx, row, var_name) for idx, row in catchments)

Number of dates: 14965
Ep


100%|██████████| 191/191 [00:04<00:00, 44.83it/s]


SMrz


100%|██████████| 191/191 [00:02<00:00, 91.29it/s] 


SMs


100%|██████████| 191/191 [00:02<00:00, 90.51it/s] 


Eb


100%|██████████| 191/191 [00:05<00:00, 35.50it/s]


Ei


100%|██████████| 191/191 [00:02<00:00, 90.50it/s] 


Es


100%|██████████| 191/191 [00:02<00:00, 89.90it/s] 


Et


100%|██████████| 191/191 [00:02<00:00, 92.81it/s] 


Ew


100%|██████████| 191/191 [00:02<00:00, 91.77it/s] 


S


100%|██████████| 191/191 [00:02<00:00, 91.70it/s] 


H


100%|██████████| 191/191 [00:02<00:00, 92.03it/s] 


In [12]:
# Sample every GLEAM4 variable at each graph node and write it into the per-catchment CSV files.
# NOTE(review): `var_names`, `ds_grid` and `camels_graph` are not defined in this cell; they must
# already exist in the kernel when it runs.
# islice(var_names, 0, None, 1) walks the whole list; presumably the start index is edited by hand
# to resume a partial run -- TODO confirm.
for var_name in itertools.islice(var_names, 0, None, 1):
    print(var_name)
    # Open all NetCDF files of this variable as one lazily combined dataset.
    ds = xr.open_mfdataset(os.path.join(PATHS['GLEAM'], 'GLEAM4.2a', var_name, f"*.nc"), combine='by_coords')
    # Work on the first data variable found in the files.
    ds_var_name = list(ds.data_vars)[0]
    ds = ds[ds_var_name]
    # Drop every 29 February so the time axis has no leap days.
    ds = ds.sel(time=~((ds['time.month'] == 2) & (ds['time.day'] == 29)))
    # Crop to the study region. Latitude is sliced maxy -> miny, which presumably means the
    # source latitude axis is stored in descending order -- TODO confirm.
    ds = ds.sel(
        lat = slice(region_bounds['maxy'], region_bounds['miny']), 
        lon = slice(region_bounds['minx'], region_bounds['maxx'])
    )

    # Build the bilinear regridder onto `ds_grid`, reusing the weights file when it exists
    # and creating it otherwise.
    if os.path.exists(os.path.join(PATHS['Assets'], 'regridder', 'regridder_gleam4_to_glofas_03min_IND.nc')):
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=True, 
            filename = os.path.join(PATHS['Assets'], 'regridder', 'regridder_gleam4_to_glofas_03min_IND.nc')
        )
    else:
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=False
        )
        regridder.to_netcdf(os.path.join(PATHS['Assets'], 'regridder', 'regridder_gleam4_to_glofas_03min_IND.nc'))
    
    ds_regrided = regridder(ds)
    ds.close()
    # Load the regridded field and time how long that takes.
    start_time = time.time()
    ds_regrided.load()
    end_time = time.time()
    print(f'Time: {((end_time - start_time) / 60):.4f} mins')
    
    def process(idx, row):
        """Fill one catchment's CSV for the current variable from the regridded field.

        Reads `var_name` and `ds_regrided` from the enclosing cell scope rather than
        from arguments. `idx` is accepted but not used.
        """
        huc, gauge_id = row['huc_02'], row.name
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        for node_idx, node_row in nodes_coords.iterrows():
            lat, lon = node_row['lat'], node_row['lon']
            # Nearest regridded cell to the node; its whole time series becomes the node column.
            ds_window_loc = ds_regrided.sel(lat = lat, lon = lon, method = 'nearest')
            # Values are assigned by position, not by date.
            # NOTE(review): this assumes the dataset's time axis has the same length and order
            # as the CSV index -- confirm.
            data.loc[:, str(node_idx)] = ds_window_loc.values
        data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"))

    # One task per catchment, spread over 8 joblib workers.
    with Parallel(n_jobs = 8, verbose = 0) as parallel:
        _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

    # `ds` was already closed above; this second close and the del release the name before the
    # next variable. `ds_regrided` is not deleted here and is only replaced on the next iteration.
    ds.close()
    del ds
    gc.collect()

Ep
Time: 1.8767 mins


100%|██████████| 191/191 [01:46<00:00,  1.79it/s]


SMrz
Time: 1.2634 mins


100%|██████████| 191/191 [01:49<00:00,  1.74it/s]


SMs
Time: 1.2806 mins


100%|██████████| 191/191 [01:48<00:00,  1.76it/s]


Eb
Time: 1.3667 mins


100%|██████████| 191/191 [01:47<00:00,  1.78it/s]


Ei
Time: 1.2192 mins


100%|██████████| 191/191 [01:18<00:00,  2.44it/s]


Es
Time: 0.7418 mins


100%|██████████| 191/191 [00:34<00:00,  5.52it/s]


Et
Time: 1.3326 mins


100%|██████████| 191/191 [01:49<00:00,  1.74it/s]


Ew
Time: 1.3358 mins


100%|██████████| 191/191 [01:51<00:00,  1.71it/s]


S
Time: 1.3433 mins


100%|██████████| 191/191 [01:44<00:00,  1.82it/s]


H
Time: 1.3044 mins


100%|██████████| 191/191 [01:46<00:00,  1.79it/s]


In [11]:
var_names = ['Ep', 'SMrz', 'SMs', 'Eb', 'Ei', 'Es', 'Et', 'Ew', 'S', 'H']
# var_names = ['Eb', 'Ei', 'Es', 'Et', 'Ew', 'S', 'H']
for var_name in var_names:
    ds = xr.open_mfdataset(os.path.join(PATHS['GLEAM'], 'GLEAM4.2a', var_name, f"*.nc"), combine='by_coords')
    ds_var_name = list(ds.data_vars)[0]
    ds = ds[ds_var_name]
    ds = ds.sel(time=~((ds['time.month'] == 2) & (ds['time.day'] == 29)))
    ds = ds.sel(
        lat = slice(region_bounds['maxy'], region_bounds['miny']), 
        lon = slice(region_bounds['minx'], region_bounds['maxx'])
    )
    if os.path.exists(os.path.join(PATHS['Assets'], 'regridder', 'regridder_gleam4_to_glofas_03min_IND.nc')):
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=True, 
            filename = os.path.join(PATHS['Assets'], 'regridder', 'regridder_gleam4_to_glofas_03min_IND.nc')
        )
    else:
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=False
        )
        regridder.to_netcdf(os.path.join(PATHS['Assets'], 'regridder', 'regridder_gleam4_to_glofas_03min_IND.nc'))
    ds_regrided = regridder(ds)
    ds.close()
    start_time = time.time()
    ds_regrided.load()
    end_time = time.time()
    print(f'{var_name} (Time: {((end_time - start_time) / 60):.4f} mins)')

    # Loop over catchments and find ones with issues
    issues = []
    for idx, row in tqdm.tqdm(camels_graph.iterrows()):
        huc, gauge_id = row['huc_02'], row.name
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        if data.isnull().values.any():
            issues.append([huc, gauge_id])
    issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
    print(f"Number of catchments with issues: {issues.shape[0]}")
    print("------")

    # Fix the catchments with issues
    for issue_idx, (huc, gauge_id) in enumerate(issues.values):
        print(issue_idx, huc, gauge_id)
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        nodes_coords['isNaN'] = False
        nodes_coords['nonNaNneighbours'] = 0
        # Loop over nodes and find the nodes with issues
        for node_idx in nodes_coords.index:
            if data[str(node_idx)].isnull().values.any():
                nodes_coords.loc[node_idx, 'isNaN'] = True
                node_lat = float(round(nodes_coords.loc[node_idx, 'lat'], 3))
                node_lon = float(round(nodes_coords.loc[node_idx, 'lon'], 3))
                multiplier = 1.5
                ds_slice = ds_regrided.sel(
                    lat = slice(node_lat+multiplier*resolution, node_lat-multiplier*resolution), 
                    lon = slice(node_lon-multiplier*resolution, node_lon+multiplier*resolution)
                    )
                slice_df = ds_slice.to_dataframe(name = var_name).reset_index()
                slice_df['lat'] = slice_df['lat'].round(3)
                slice_df['lon'] = slice_df['lon'].round(3)
                slice_df['location'] = list(zip(slice_df['lat'], slice_df['lon']))
                slice_df = slice_df.pivot(index='time', columns='location', values=var_name)
                num_nan_nodes = slice_df.isnull().any(axis=0).sum()
                num_nonnan_nodes = len(slice_df.columns) - num_nan_nodes
                nodes_coords.loc[node_idx, 'nonNaNneighbours'] = num_nonnan_nodes
        nodes_coords_sorted = nodes_coords.sort_values(by = 'nonNaNneighbours', ascending = False)
        nodes_coords_sorted = nodes_coords_sorted[nodes_coords_sorted['isNaN']]
        print(f"Number of nodes with NaN values: {nodes_coords_sorted.shape[0]}")
        
        for node_idx in tqdm.tqdm(nodes_coords_sorted.index):
            node_lat, node_lon = float(round(nodes_coords.loc[node_idx, 'lat'], 3)), float(round(nodes_coords.loc[node_idx, 'lon'], 3))
            multiplier = 1.5
            ds_slice = ds_regrided.sel(
                lat = slice(node_lat+multiplier*resolution, node_lat-multiplier*resolution), 
                lon = slice(node_lon-multiplier*resolution, node_lon+multiplier*resolution)
                )
            slice_df = ds_slice.to_dataframe(name = var_name).reset_index()
            slice_df['lat'] = slice_df['lat'].round(3)
            slice_df['lon'] = slice_df['lon'].round(3)
            slice_df['location'] = list(zip(slice_df['lat'], slice_df['lon']))
            slice_df = slice_df.pivot(index='time', columns='location', values=var_name)
            slice_df.columns = list(map(str, slice_df.columns))
            num_nonnan_nodes = len(slice_df.columns) - slice_df.isnull().any(axis=0).sum()
            # print(node_idx, (node_lat, node_lon), num_nonnan_nodes)
            if num_nonnan_nodes == 9:
                replacement_values = slice_df.loc[:, f"({node_lat}, {node_lon})"]
                data.loc[:, str(node_idx)] = replacement_values
                nodes_coords_sorted.loc[node_idx, 'isNaN'] = False
            elif num_nonnan_nodes > 0:
                replacement_values = np.nanmean(slice_df, axis = 1)
                data.loc[:, str(node_idx)] = replacement_values
                ds_regrided.loc[dict(lat = node_lat, lon = node_lon)] = replacement_values
                nodes_coords_sorted.loc[node_idx, 'isNaN'] = False
        print(f"Number of nodes with NaN values: {nodes_coords_sorted['isNaN'].sum()}")
        print(issue_idx, huc, gauge_id, data.isnull().values.any())
        data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"))
        print("------")

Ep (Time: 1.2970 mins)


191it [01:08,  2.81it/s]


Number of catchments with issues: 0
------
SMrz (Time: 1.2745 mins)


191it [01:09,  2.75it/s]


Number of catchments with issues: 10
------
0 15 15006
Number of nodes with NaN values: 1


100%|██████████| 1/1 [00:00<00:00, 11.19it/s]


Number of nodes with NaN values: 0
0 15 15006 False
------
1 15 15014
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 11.52it/s]


Number of nodes with NaN values: 0
1 15 15014 False
------
2 14 14005
Number of nodes with NaN values: 1


100%|██████████| 1/1 [00:00<00:00, 11.54it/s]


Number of nodes with NaN values: 0
2 14 14005 False
------
3 15 15023
Number of nodes with NaN values: 12


100%|██████████| 12/12 [00:01<00:00, 11.61it/s]


Number of nodes with NaN values: 0
3 15 15023 False
------
4 05 05013
Number of nodes with NaN values: 11


100%|██████████| 11/11 [00:00<00:00, 11.57it/s]


Number of nodes with NaN values: 0
4 05 05013 False
------
5 15 15026
Number of nodes with NaN values: 12


100%|██████████| 12/12 [00:01<00:00, 11.52it/s]


Number of nodes with NaN values: 1
5 15 15026 True
------
6 14 14003
Number of nodes with NaN values: 3


100%|██████████| 3/3 [00:00<00:00, 11.57it/s]


Number of nodes with NaN values: 0
6 14 14003 False
------
7 05 05028
Number of nodes with NaN values: 11


100%|██████████| 11/11 [00:00<00:00, 11.56it/s]


Number of nodes with NaN values: 0
7 05 05028 False
------
8 16 16014
Number of nodes with NaN values: 2


100%|██████████| 2/2 [00:00<00:00, 11.35it/s]

Number of nodes with NaN values: 0
8 16 16014 False





------
9 11 11010
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 11.35it/s]


Number of nodes with NaN values: 0
9 11 11010 False
------
SMs (Time: 1.2240 mins)


191it [01:07,  2.83it/s]


Number of catchments with issues: 10
------
0 15 15006
Number of nodes with NaN values: 1


100%|██████████| 1/1 [00:00<00:00, 11.01it/s]


Number of nodes with NaN values: 0
0 15 15006 False
------
1 15 15014
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 11.31it/s]


Number of nodes with NaN values: 0
1 15 15014 False
------
2 14 14005
Number of nodes with NaN values: 1


100%|██████████| 1/1 [00:00<00:00, 11.38it/s]


Number of nodes with NaN values: 0
2 14 14005 False
------
3 15 15023
Number of nodes with NaN values: 12


100%|██████████| 12/12 [00:01<00:00, 11.24it/s]


Number of nodes with NaN values: 0
3 15 15023 False
------
4 05 05013
Number of nodes with NaN values: 11


100%|██████████| 11/11 [00:00<00:00, 11.26it/s]


Number of nodes with NaN values: 0
4 05 05013 False
------
5 15 15026
Number of nodes with NaN values: 12


100%|██████████| 12/12 [00:01<00:00, 11.19it/s]


Number of nodes with NaN values: 1
5 15 15026 True
------
6 14 14003
Number of nodes with NaN values: 3


100%|██████████| 3/3 [00:00<00:00, 11.24it/s]


Number of nodes with NaN values: 0
6 14 14003 False
------
7 05 05028
Number of nodes with NaN values: 11


100%|██████████| 11/11 [00:00<00:00, 11.19it/s]


Number of nodes with NaN values: 0
7 05 05028 False
------
8 16 16014
Number of nodes with NaN values: 2


100%|██████████| 2/2 [00:00<00:00, 11.07it/s]

Number of nodes with NaN values: 0
8 16 16014 False





------
9 11 11010
Number of nodes with NaN values: 4


100%|██████████| 4/4 [00:00<00:00, 11.10it/s]


Number of nodes with NaN values: 0
9 11 11010 False
------
Eb (Time: 1.3143 mins)


191it [01:11,  2.68it/s]


Number of catchments with issues: 0
------
Ei (Time: 1.1882 mins)


191it [00:55,  3.44it/s]


Number of catchments with issues: 0
------
Es (Time: 0.7126 mins)


191it [00:23,  8.09it/s]


Number of catchments with issues: 0
------
Et (Time: 1.3268 mins)


191it [01:12,  2.63it/s]


Number of catchments with issues: 0
------
Ew (Time: 1.3153 mins)


191it [01:14,  2.57it/s]


Number of catchments with issues: 0
------
S (Time: 1.3380 mins)


191it [01:11,  2.69it/s]


Number of catchments with issues: 0
------
H (Time: 1.3112 mins)


191it [01:10,  2.70it/s]

Number of catchments with issues: 0
------





In [None]:
# Report-only pass: list the catchments and count the nodes that still hold NaN values
# in their GLEAM4 CSVs. Nothing is written back to disk.
var_names = ['Ep', 'SMrz', 'SMs', 'Eb', 'Ei', 'Es', 'Et', 'Ew', 'S', 'H']
for var_name in var_names:
    # Pass 1: collect every catchment whose CSV for this variable has at least one NaN.
    issues = []
    for idx, row in tqdm.tqdm(camels_graph.iterrows()):
        huc = row['huc_02']
        gauge_id = row.name
        nodes_coords = pd.read_csv(
            os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'),
            index_col = 0,
        )
        data = pd.read_csv(
            os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"),
            index_col = 0,
            parse_dates = True,
        )
        if not data.isnull().values.any():
            continue
        issues.append([huc, gauge_id])
    issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
    print(f"Number of catchments with issues: {issues.shape[0]}")
    print("------")

    # Pass 2: for each flagged catchment, mark and count the nodes whose column has a NaN.
    for issue_idx, (huc, gauge_id) in enumerate(issues.values):
        print(issue_idx, huc, gauge_id)
        nodes_coords = pd.read_csv(
            os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'),
            index_col = 0,
        )
        data = pd.read_csv(
            os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"),
            index_col = 0,
            parse_dates = True,
        )
        nodes_coords['isNaN'] = False
        nodes_coords['nonNaNneighbours'] = 0
        for node_idx in nodes_coords.index:
            column_has_nan = data[str(node_idx)].isnull().values.any()
            if column_has_nan:
                nodes_coords.loc[node_idx, 'isNaN'] = True
        print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")
        print("------")

In [12]:
# Last-resort gap filling for the GLEAM4 CSVs: every node column that still contains NaN is
# replaced by the mean series of the nearest NaN-free node(s) of the same catchment.
var_names = ['Ep', 'SMrz', 'SMs', 'Eb', 'Ei', 'Es', 'Et', 'Ew', 'S', 'H']
for var_name in var_names:
    # Loop over catchments and find ones with issues (any NaN in the CSV of this variable).
    issues = []
    for idx, row in tqdm.tqdm(camels_graph.iterrows()):
        huc, gauge_id = row['huc_02'], row.name
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"), index_col = 0, parse_dates = True)
        if data.isnull().values.any():
            issues.append([huc, gauge_id])
    issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
    print(f"Number of catchments with issues: {issues.shape[0]}")
    print("------")

    # Fix the catchments with issues
    for issue_idx, (huc, gauge_id) in enumerate(issues.values):
        print(issue_idx, huc, gauge_id)
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"), index_col = 0, parse_dates = True)

        # Flag every node whose column contains at least one NaN (CSV columns are str(node id)).
        node_columns = nodes_coords.index.map(str)
        nodes_coords['isNaN'] = data.loc[:, node_columns].isnull().any(axis = 0).to_numpy()
        print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")

        # Coordinates rounded to 3 decimals, as used for all distance computations.
        coords = nodes_coords[['lat', 'lon']].round(3)
        nan_nodes = nodes_coords.index[nodes_coords['isNaN'].to_numpy()]
        for node_idx in tqdm.tqdm(nan_nodes):
            # Donors are all nodes currently NaN-free; nodes repaired earlier in this loop count too.
            donors = nodes_coords.index[~nodes_coords['isNaN'].to_numpy()]
            if len(donors) == 0:
                # Without any donor there is nothing to copy from; report it instead of crashing.
                print(f"No NaN-free node available to fill node {node_idx}; leaving remaining nodes unfilled")
                break
            # Euclidean distance in degrees from this node to every donor, computed in one shot.
            offsets = coords.loc[donors] - coords.loc[node_idx]
            distances = np.sqrt((offsets ** 2).sum(axis = 1))
            # Replace with mean of nodes having distance equal to the minimum distance.
            nearest = distances.index[(distances == distances.min()).to_numpy()]
            replacement_values = data.loc[:, nearest.map(str)].mean(axis = 1)
            data.loc[:, str(node_idx)] = replacement_values
            nodes_coords.loc[node_idx, 'isNaN'] = False
        print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")
        print(issue_idx, huc, gauge_id, data.isnull().values.any())
        data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GLEAM4', f"{var_name}.csv"))
        print("------")

191it [01:02,  3.06it/s]


Number of catchments with issues: 0
------


191it [01:03,  2.99it/s]


Number of catchments with issues: 1
------
0 15 15026
Number of nodes with NaN values: 1


100%|██████████| 1/1 [00:00<00:00, 87.95it/s]


Number of nodes with NaN values: 0
0 15 15026 False
------


191it [01:04,  2.95it/s]


Number of catchments with issues: 1
------
0 15 15026
Number of nodes with NaN values: 1


100%|██████████| 1/1 [00:00<00:00, 89.84it/s]


Number of nodes with NaN values: 0
0 15 15026 False
------


191it [01:04,  2.97it/s]


Number of catchments with issues: 0
------


191it [00:45,  4.15it/s]


Number of catchments with issues: 0
------


191it [00:21,  8.73it/s]


Number of catchments with issues: 0
------


191it [01:06,  2.89it/s]


Number of catchments with issues: 0
------


191it [01:05,  2.91it/s]


Number of catchments with issues: 0
------


191it [01:00,  3.13it/s]


Number of catchments with issues: 0
------


191it [01:02,  3.03it/s]

Number of catchments with issues: 0
------





## GPM

In [11]:
# GPM variables for which an empty per-catchment CSV template is created.
var_names = ['precipitation']

# Daily calendar 1980-2020 with every 29 February removed.
dates = pd.date_range('1980-01-01', '2020-12-31', freq='D')
is_leap_day = (dates.month == 2) & (dates.day == 29)
dates = dates[~is_leap_day]
print(f"Number of dates: {len(dates)}")

def process(idx, row, var_name):
    """Write an empty (dates x nodes) CSV template for one catchment and one GPM variable.

    `row` is a row of the catchment table: `row['huc_02']` and `row.name` (gauge id)
    select the catchment folder. `idx` is accepted but not used.
    """
    huc = row['huc_02']
    gauge_id = row.name
    nodes_file = os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv')
    node_ids = pd.read_csv(nodes_file, index_col = 0).index
    # A single makedirs call also creates the intermediate 'dynamic' directory.
    out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GPM')
    os.makedirs(out_dir, exist_ok = True)
    template = pd.DataFrame(index = dates, columns = node_ids)
    template.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

for var_name in var_names:
    print(var_name)
    with Parallel(n_jobs = 8, verbose = 0) as parallel:
        tasks = (
            delayed(process)(idx, row, var_name)
            for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph))
        )
        _ = parallel(tasks)

Number of dates: 14965
precipitation


100%|██████████| 191/191 [00:06<00:00, 28.18it/s]


In [13]:
# Sample GPM precipitation at each graph node and write it into the per-catchment CSV files,
# working through the record in 5-year windows.
# NOTE(review): `ds_grid` and `camels_graph` are not defined in this cell; they must already
# exist in the kernel when it runs.
var_names = ['precipitation']
for var_name in itertools.islice(var_names, 0, None, 1):
    print(var_name)
    ds = xr.open_zarr(os.path.join(PATHS['GPM'], 'GPM_1998_2020.zarr'), consolidated=True)
    # Work on the first data variable found in the store.
    ds_var_name = list(ds.data_vars)[0]
    ds = ds[ds_var_name]
    # Drop every 29 February.
    ds = ds.sel(time=~((ds['time.month'] == 2) & (ds['time.day'] == 29)))
    # Crop to the study region. Latitude is sliced miny -> maxy, which presumably means the
    # latitude axis of this store is ascending -- TODO confirm.
    ds = ds.sel(
        lat = slice(region_bounds['miny'], region_bounds['maxy']), 
        lon = slice(region_bounds['minx'], region_bounds['maxx'])
    )

    # 1999-09-03 is treated as absent from the store; it is added to the time axis below.
    missing_date = np.datetime64('1999-09-03')
    # Reindex the dataset to include the missing date (it will be filled with NaN)
    new_times = np.sort(np.append(ds.time.values, missing_date))
    ds = ds.reindex(time=new_times).chunk({'time': -1})
    ds = ds.sortby('time')
    # # Interpolate in time to fill NaN values using linear interpolation
    # ds = ds.interpolate_na(dim='time', method='linear')
    # # Chunk in a way that makes it faster
    # One chunk per time step, each spanning the full lat/lon extent.
    ds = ds.chunk({'time': 1, 'lat': -1, 'lon': -1})

    # Reuse the stored regridding weights when the file exists, otherwise compute and store them.
    if os.path.exists(os.path.join(PATHS['Assets'], 'regridder', 'regridder_gpm_to_glofas_03min_IND.nc')):
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=True, 
            filename = os.path.join(PATHS['Assets'], 'regridder', 'regridder_gpm_to_glofas_03min_IND.nc')
        )
    else:
        regridder = xe.Regridder(
            ds, 
            ds_grid, 
            'bilinear', 
            reuse_weights=False
        )
        regridder.to_netcdf(os.path.join(PATHS['Assets'], 'regridder', 'regridder_gpm_to_glofas_03min_IND.nc'))
    
    ds_regrided = regridder(ds)
    ds.close()

    # Windows start in 1998, 2003, ... and span up to five calendar years, capped at 2020.
    # Only the current window is loaded into memory.
    for start_year in range(1998, 2020+1, 5):
        start_date = f"{start_year}-01-01"
        end_date = f"{min(start_year+4,2020)}-12-31"
        ds_window = ds_regrided.sel(time = slice(start_date, end_date)).copy()
        start_time = time.time()
        ds_window.load()
        end_time = time.time()
        print(start_date, end_date, f"Time: {(end_time - start_time)/60:.2f} mins")
    
        def process(idx, row):
            """Write the current time window into one catchment's GPM CSV.

            Reads `var_name`, `ds_window`, `start_date` and `end_date` from the enclosing
            cell scope rather than from arguments. `idx` is accepted but not used.
            """
            huc, gauge_id = row['huc_02'], row.name
            nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
            data = pd.read_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GPM', f"{var_name}.csv"), index_col = 0, parse_dates = True)
            for node_idx, node_row in nodes_coords.iterrows():
                lat, lon = node_row['lat'], node_row['lon']
                # Nearest regridded cell to the node.
                ds_window_loc = ds_window.sel(lat = lat, lon = lon, method = 'nearest')
                # data.loc[:, str(node_idx)] = ds_window_loc.values
                # Only the rows inside the window are overwritten; values are assigned by position.
                # NOTE(review): this assumes the window's time steps match the CSV rows in the same
                # date range one-to-one -- confirm.
                data.loc[start_date:end_date, str(node_idx)] = ds_window_loc.values
            # The whole CSV is re-read and rewritten for every window.
            data.to_csv(os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GPM', f"{var_name}.csv"))

        # One task per catchment, spread over 8 joblib workers.
        with Parallel(n_jobs = 8, verbose = 0) as parallel:
            _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

        # ds.close()
        # del ds
        # gc.collect()
    # Release the datasets once every window has been written.
    ds_regrided.close()
    del ds_regrided, ds
    gc.collect()

precipitation
1998-01-01 2002-12-31 Time: 8.52 mins


100%|██████████| 191/191 [00:24<00:00,  7.75it/s]


2003-01-01 2007-12-31 Time: 3.47 mins


100%|██████████| 191/191 [00:27<00:00,  6.83it/s]


2008-01-01 2012-12-31 Time: 3.44 mins


100%|██████████| 191/191 [00:33<00:00,  5.66it/s]


2013-01-01 2017-12-31 Time: 3.27 mins


100%|██████████| 191/191 [00:39<00:00,  4.82it/s]


2018-01-01 2020-12-31 Time: 2.76 mins


100%|██████████| 191/191 [00:42<00:00,  4.54it/s]


In [14]:
def process(idx, row, var_name = 'precipitation'):
    """Fill the 1999-09-03 row of one catchment's GPM CSV from its neighbouring days.

    Parameters
    ----------
    idx : unused; kept so the function can be mapped over `DataFrame.iterrows()`.
    row : catchment row; `row['huc_02']` and `row.name` (gauge id) select the CSV.
    var_name : name of the GPM variable (CSV file stem). It is an explicit, defaulted
        parameter so this cell no longer depends on a `var_name` left behind by an
        earlier cell's loop.

    The CSV is rewritten in place.
    """
    huc, gauge_id = row['huc_02'], row.name
    csv_path = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GPM', f"{var_name}.csv")
    data = pd.read_csv(csv_path, index_col = 0, parse_dates = True)

    # Replace the missing day with the per-node mean over a centred window of +/- 1 day.
    # DataFrame.mean skips NaN by default, so an empty row for the missing day does not
    # affect the result.
    missing_date = np.datetime64('1999-09-03')
    window = pd.Timedelta(days=1)
    start_date = missing_date - window
    end_date = missing_date + window
    data_window = data.loc[start_date:end_date]
    data.loc[missing_date] = data_window.mean(axis=0)

    data.to_csv(csv_path)

with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

100%|██████████| 191/191 [00:38<00:00,  5.01it/s]


In [15]:
# Report-only pass: list the catchments and count the nodes that still hold NaN values
# in their GPM CSVs from 1998 onwards. Nothing is written back to disk.
var_names = ['precipitation']
for var_name in var_names:
    # Pass 1: collect every catchment whose CSV has at least one NaN on or after 1998-01-01.
    issues = []
    for idx, row in tqdm.tqdm(camels_graph.iterrows()):
        huc = row['huc_02']
        gauge_id = row.name
        nodes_coords = pd.read_csv(
            os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'),
            index_col = 0,
        )
        data = pd.read_csv(
            os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GPM', f"{var_name}.csv"),
            index_col = 0,
            parse_dates = True,
        )
        # Rows before 1998 are ignored by this check.
        data = data.loc['1998-01-01':]
        if not data.isnull().values.any():
            continue
        issues.append([huc, gauge_id])
    issues = pd.DataFrame(issues, columns = ['huc_02', 'gauge_id'])
    print(f"Number of catchments with issues: {issues.shape[0]}")
    print("------")

    # Pass 2: for each flagged catchment, mark and count the nodes whose column has a NaN.
    for issue_idx, (huc, gauge_id) in enumerate(issues.values):
        print(issue_idx, huc, gauge_id)
        nodes_coords = pd.read_csv(
            os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'),
            index_col = 0,
        )
        data = pd.read_csv(
            os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'dynamic', 'GPM', f"{var_name}.csv"),
            index_col = 0,
            parse_dates = True,
        )
        data = data.loc['1998-01-01':]
        nodes_coords['isNaN'] = False
        nodes_coords['nonNaNneighbours'] = 0
        for node_idx in nodes_coords.index:
            column_has_nan = data[str(node_idx)].isnull().values.any()
            if column_has_nan:
                nodes_coords.loc[node_idx, 'isNaN'] = True
        print(f"Number of nodes with NaN values: {nodes_coords['isNaN'].sum()}")
        print("------")

191it [00:32,  5.86it/s]

Number of catchments with issues: 0
------





## Solar Insolation

In [13]:
def solar_insolation(lat, lon, start_date, end_date):
    """Daily solar-noon insolation on a horizontal surface, in kW/m².

    For every day between `start_date` and `end_date` (inclusive, 29 February
    removed) the value is

        Sc * (1 + 0.033 * cos(360° * n / 365)) * cos(theta_z)

    with `n` the day of year, `Sc` = 1361 W/m² and `theta_z` the solar zenith
    angle at solar noon (hour angle 0). No atmospheric attenuation is applied
    and negative values are clipped to zero.

    Parameters
    ----------
    lat : latitude in degrees.
    lon : longitude in degrees. Accepted for interface compatibility but unused,
        because the hour angle is fixed at solar noon.
    start_date, end_date : anything `pd.to_datetime` accepts.

    Returns
    -------
    pd.Series indexed by date, named 'Solar Insolation (kW/m²)'.
    """
    Sc = 1361  # Solar constant (W/m^2)

    # Daily calendar without leap days.
    dates = pd.date_range(start=pd.to_datetime(start_date), end=pd.to_datetime(end_date), freq='D')
    dates = dates[~((dates.month == 2) & (dates.day == 29))]

    # All days are evaluated at once with NumPy instead of one Python iteration per day.
    day_of_year = dates.dayofyear.to_numpy(dtype=float)

    # Solar declination in degrees for each day of year.
    declination = 23.45 * np.sin(np.radians((360 / 365) * (day_of_year - 81)))

    # Hour angle at solar noon (local solar time = 12 hours).
    hour_angle = 0

    # Cosine of the solar zenith angle.
    lat_rad = np.radians(lat)
    decl_rad = np.radians(declination)
    cos_zenith_angle = (np.sin(lat_rad) * np.sin(decl_rad) +
                        np.cos(lat_rad) * np.cos(decl_rad) * np.cos(np.radians(hour_angle)))

    # Insolation formula (W/m^2), clipped so it is never negative.
    insolation = Sc * (1 + 0.033 * np.cos(np.radians(360 * day_of_year / 365))) * cos_zenith_angle
    insolation = np.maximum(insolation, 0)

    # Convert to kW/m² and label by date.
    return pd.Series(insolation / 1000, index=dates, name='Solar Insolation (kW/m²)')

In [14]:
# Daily calendar 1980-2020 with every 29 February removed.
dates = pd.date_range('1980-01-01', '2020-12-31', freq='D')
dates = dates[~((dates.month == 2) & (dates.day == 29))]

def process(idx, row):
    """Compute solar insolation for every node of one catchment and save it as a CSV.

    `row['huc_02']` and `row.name` (gauge id) select the catchment. `idx` is accepted
    but not used.
    """
    huc = row['huc_02']
    gauge_id = row.name
    dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
    nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)

    # One column per node, one row per day.
    data = pd.DataFrame(columns = nodes_coords.index, index = dates)
    for node_idx, node_row in nodes_coords.iterrows():
        node_series = solar_insolation(node_row['lat'], node_row['lon'], '1980-01-01', '2020-12-31')
        data.loc[:, node_idx] = node_series.values

    out_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id)
    os.makedirs(out_dir, exist_ok = True)
    data.to_csv(os.path.join(out_dir, 'solar_insolation.csv'))

with Parallel(n_jobs = 8, verbose = 0) as parallel:
    tasks = (
        delayed(process)(idx, row)
        for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph))
    )
    _ = parallel(tasks)

100%|██████████| 191/191 [09:17<00:00,  2.92s/it]


## Time Encodings

In [15]:
def sine_time_encoding(start_date, end_date):
    """Build sine encodings of month, ISO week and day of year for a daily calendar.

    The calendar runs from `start_date` to `end_date` (inclusive) with every
    29 February removed. The returned frame is indexed by date and holds the raw
    'month', 'weekofyear' and 'dayofyear' columns followed by their
    sin(2*pi*value/period) encodings, using periods 12, 52 and 365 respectively.
    """
    all_days = pd.date_range(start=start_date, end=end_date, freq='D')
    is_leap_day = (all_days.month == 2) & (all_days.day == 29)
    calendar = all_days[~is_leap_day]

    # Raw calendar fields.
    df = pd.DataFrame(index=calendar)
    df['month'] = df.index.month
    df['weekofyear'] = df.index.isocalendar().week
    df['dayofyear'] = df.index.dayofyear

    def _sine(value, period):
        # Position of `value` on a cycle of length `period`, mapped through sine.
        return np.sin(2 * np.pi * value / period)

    # Encoded columns are appended in the same order as the raw ones.
    for column, period in (('month', 12), ('weekofyear', 52), ('dayofyear', 365)):
        df[f'sine_{column}'] = df[column].apply(_sine, period=period)

    return df

In [16]:
df_encoded = sine_time_encoding('1980-01-01', '2020-12-31')

def process(idx, row):
    """Write the shared time-encoding table into one gauge's feature folder.

    `row` is a row of `camels_graph`: its name is the gauge id and its
    'huc_02' entry selects the parent folder. `idx` is unused.
    """
    gauge_id = row.name
    huc = row['huc_02']
    target_dir = os.path.join(PATHS['devp_datasets'], DIRNAME, 'graph_features', huc, gauge_id)

    os.makedirs(target_dir, exist_ok=True)
    df_encoded.to_csv(os.path.join(target_dir, 'time_encodings.csv'))

# Disabled joblib variant, kept for reference; the serial loop below is what runs.
# with Parallel(n_jobs = 8, verbose = 0) as parallel:
    # _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

# Write the time-encoding CSV for each gauge in `camels_graph`, one after another.
for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)):
    process(idx, row)

100%|██████████| 191/191 [00:18<00:00, 10.18it/s]


## Terrain Attributes

In [17]:
from shapely.geometry import Polygon
import rioxarray

def coords_to_polygon(lon, lat, resolution):
    """Return the square cell of side `resolution` centred on (lon, lat).

    Corner coordinates are rounded to 3 decimals and listed in the order
    lower-left, upper-left, upper-right, lower-right.
    """
    half_res = resolution / 2
    west = round(lon - half_res, 3)
    east = round(lon + half_res, 3)
    south = round(lat - half_res, 3)
    north = round(lat + half_res, 3)
    corners = [(west, south), (west, north), (east, north), (east, south)]
    return Polygon(corners)
def tile_filename_to_coords(filename):
    """Parse a tile file name into signed (lon, lat) integers.

    Expected layout: ``{n|s}{dd}{e|w}{ddd}_elv.tif``. A leading 'n' / 'e'
    keeps the value positive; any other letter makes it negative.
    """
    lat_hemisphere = filename[0]
    lat = int(filename[1:3])
    lon_hemisphere = filename[3]
    lon = int(filename[4:7])

    if lat_hemisphere != 'n':
        lat = -lat
    if lon_hemisphere != 'e':
        lon = -lon
    return (lon, lat)

In [None]:
import itertools
var_names = ['elv', 'slope_percentage', 'slope_riserun', 'slope_degrees', 'slope_radians', 'aspect', 'curvature', 'planform_curvature', 'profile_curvature', 'upa', 'wth']
# valid_tiles = ['n30w150', 'n30w120', 'n30w090']

# For every terrain variable, summarise the raster pixels falling inside each
# graph-node cell (mean, std, quartiles) and write one CSV per gauge.
# Failures are logged and collected in `issues` as "<var_name>-<huc_02>-<gauge_id>".
issues = []
# islice(start, stop, step) is a no-op here; adjust `start` to resume part-way through the list.
for var_name in itertools.islice(var_names,0,None,1):
    print(var_name)
    tiles_paths = sorted(glob.glob(os.path.join(PATHS['MERIT-Hydro'], var_name, '**', '*.tif'), recursive=True))
    # tiles_paths = [tile for tile in tiles_paths if os.path.basename(os.path.dirname(tile)).split('_')[-1] in valid_tiles]
    tiles_filenames = [os.path.basename(tile) for tile in tiles_paths]
    tiles_names = [tile.split('_')[0] for tile in tiles_filenames]
    tiles_lower_left_corner = [tile_filename_to_coords(tile) for tile in tiles_filenames]
    # Each tile is treated as a 5 x 5 degree square anchored at its lower-left corner.
    tiles_polygons = [Polygon([(lon, lat), (lon + 5, lat), (lon + 5, lat + 5), (lon, lat + 5)]) for lon, lat in tiles_lower_left_corner]

    def process(idx, row):
        """Write per-node summary statistics of `var_name` for one gauge.

        `row` is a row of `camels_graph` (name = gauge id, 'huc_02' = parent
        folder); `idx` is unused. Raises ValueError when no tile overlaps the
        catchment; the raster is always closed, even if processing fails.
        """
        huc, gauge_id = row['huc_02'], row.name
        nodes_coords = pd.read_csv(os.path.join(SAVE_PATH, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
        data = pd.DataFrame(columns = nodes_coords.index, index = ['mean', 'std', '25%', '50%', '75%'])

        # Union of all node cells = catchment footprint used to pick the tiles to open.
        cell_polygons = [coords_to_polygon(node['lon'], node['lat'], resolution) for _, node in nodes_coords.iterrows()]
        catmt_polygon = cell_polygons[0]
        for polygon in cell_polygons[1:]:
            catmt_polygon = catmt_polygon.union(polygon)
        intersected_tiles = [
            tile_path
            for tile_polygon, tile_path in zip(tiles_polygons, tiles_paths)
            if tile_polygon.intersects(catmt_polygon)
        ]
        if not intersected_tiles:
            raise ValueError(f"no '{var_name}' tile overlaps the catchment of gauge {gauge_id}")

        ds = rioxarray.open_rasterio(intersected_tiles[0])
        try:
            for tile in intersected_tiles[1:]:
                ds = ds.combine_first(rioxarray.open_rasterio(tile))
            ds = ds.sel(band=1)
            # Sort the x and y coordinates to be ascending so the slices below select a window
            ds = ds.sortby('x', ascending=True)
            ds = ds.sortby('y', ascending=True)
            for node_idx, node_row in nodes_coords.iterrows():
                lat, lon = node_row['lat'], node_row['lon']
                # ds_node = ds.rio.clip_box(lon - resolution/2, lat - resolution/2, lon + resolution/2, lat + resolution/2)
                ds_node = ds.sel(x = slice(lon - resolution/2, lon + resolution/2), y = slice(lat - resolution/2, lat + resolution/2))
                # Mask the nodata value so it does not contaminate the statistics.
                ds_node = ds_node.where(ds_node != ds.rio.nodata)
                ds_node_values = ds_node.values.flatten()
                q25, q50, q75 = np.nanquantile(ds_node_values, [0.25, 0.50, 0.75])
                node_stats = {
                    'mean': np.nanmean(ds_node_values),
                    'std': np.nanstd(ds_node_values),
                    '25%': q25,
                    '50%': q50,
                    '75%': q75,
                }
                for stat_name, stat_value in node_stats.items():
                    data.loc[stat_name, node_idx] = stat_value
            out_dir = os.path.join(SAVE_PATH, "graph_features", huc, gauge_id, 'static', 'MERIT-Hydro')
            os.makedirs(out_dir, exist_ok = True)
            data.to_csv(os.path.join(out_dir, f"{var_name}.csv"))
        finally:
            # Release the raster even when a node fails part-way through.
            ds.close()
            del ds
            gc.collect()

    for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)):
        try:
            process(idx, row)
        except Exception as e:
            # Best effort: record the failure and continue with the next gauge.
            issues.append(f"{var_name}-{row['huc_02']}-{row.name}")
            print(f"Error: {var_name}-{row['huc_02']}-{row.name}. {e}")

elv


100%|██████████| 191/191 [04:01<00:00,  1.26s/it]


In [25]:
# Number of variable/gauge combinations recorded as failed in `issues`.
len(issues)

0

In [20]:
# Split each "<var_name>-<huc_02>-<gauge_id>" entry of `issues` into its three fields.
issues_df = pd.DataFrame(
    [entry.split('-') for entry in issues],
    columns = ['var_name', 'huc_02', 'gauge_id'],
)
issues_df

Unnamed: 0,var_name,huc_02,gauge_id


In [21]:
# Show only the failures recorded for the 'elv' variable.
issues_df[issues_df['var_name'] == 'elv']

Unnamed: 0,var_name,huc_02,gauge_id


## Spatial Encodings

In [22]:
def process(idx, row):
    """Write scaled longitude/latitude encodings for every node of one gauge.

    Longitude (taken as -180..180) goes through a full-period sine; latitude is
    min-max scaled from the range -60..90. The result has one column per node
    and the rows 'lon_transformed' / 'lat_transformed'. `idx` is unused.
    """
    LAT_MIN, LAT_MAX = -60, 90

    huc, gauge_id = row['huc_02'], row.name
    dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
    nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)

    encodings = pd.DataFrame(columns = nodes_coords.index, index = ['lon_transformed', 'lat_transformed'])
    for node_idx, node_row in nodes_coords.iterrows():
        node_lat = node_row['lat']
        node_lon = node_row['lon']
        encodings.loc['lon_transformed', node_idx] = np.sin(2 * np.pi * (node_lon+180) / 360)
        encodings.loc['lat_transformed', node_idx] = (node_lat - LAT_MIN)/(LAT_MAX - LAT_MIN)

    out_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id)
    os.makedirs(out_dir, exist_ok = True)
    encodings.to_csv(os.path.join(out_dir, 'spatial_encodings.csv'))

# Compute the spatial encodings for every gauge in `camels_graph` on 8 joblib workers.
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

100%|██████████| 191/191 [00:00<00:00, 299.81it/s]


## uparea

In [23]:
# Open the upstream-area raster and keep its first data variable as a DataArray.
uparea = xr.open_dataset(os.path.join(PATHS['gis_ldd'], 'GloFAS_03min/upstream_area_km2.nc'))
ds_varname = list(uparea.data_vars)[0]
uparea = uparea[ds_varname]
# Crop to the study region. The lat slice runs maxy -> miny, which presumably
# matches a descending latitude axis in the file — confirm.
uparea = uparea.sel(
    lat = slice(region_bounds['maxy'], region_bounds['miny']), 
    lon = slice(region_bounds['minx'], region_bounds['maxx'])
)
# Read the cropped values into memory before the per-node lookups below.
uparea.load()

def process(idx, row):
    """Look up the upstream area at every node of one gauge and save it.

    Each node takes the value of the nearest cell of the module-level `uparea`
    grid. The output is a one-row table (index 0) with one column per node,
    written to 'uparea.csv'. `idx` is unused.
    """
    huc, gauge_id = row['huc_02'], row.name
    dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
    nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)

    node_values = pd.DataFrame(columns = nodes_coords.index, index = [0])
    for node_idx, node_row in nodes_coords.iterrows():
        node_lat, node_lon = node_row['lat'], node_row['lon']
        nearest_cell = uparea.sel(lat = node_lat, lon = node_lon, method = 'nearest')
        node_values.loc[0, node_idx] = nearest_cell.values.item()

    out_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id)
    os.makedirs(out_dir, exist_ok = True)
    node_values.to_csv(os.path.join(out_dir, 'uparea.csv'))

# Extract the upstream area for every gauge in `camels_graph` on 8 joblib workers.
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

# Drop the in-memory raster now that every gauge has been written.
uparea.close()
del uparea
gc.collect()

100%|██████████| 191/191 [00:01<00:00, 143.28it/s]


60

## IndiaWRIS

In [37]:
def process(idx, row):
    """Export one gauge's observed streamflow as 'IndiaWRIS.csv'.

    Reads the CAMELS-IND observed streamflow table, drops 29 February rows,
    keeps the column matching this gauge (column names zero-padded to five
    characters) as 'Q_m3s', and adds 'Q_mm' by scaling with the gauge's
    'ghi_area' value. `idx` is unused.
    """
    huc, gauge_id = row['huc_02'], row.name
    uparea = row['ghi_area']

    source_csv = os.path.join(PATHS['CAMELS'], 'CAMELS-IND', 'CAMELS_IND_Catchments_Streamflow_Sufficient', 'streamflow_timeseries', 'streamflow_observed.csv')
    observed = pd.read_csv(source_csv)

    # Replace the year/month/day columns by a single datetime index named 'date'.
    date_parts = ['year', 'month', 'day']
    observed = (
        observed
        .assign(date = pd.to_datetime(observed[date_parts]))
        .set_index('date')
        .drop(columns = date_parts)
    )

    is_leap_day = (observed.index.month == 2) & (observed.index.day == 29)
    observed = observed[~is_leap_day]
    observed.columns = observed.columns.astype(str).str.zfill(5)

    streamflow = observed[[gauge_id]]
    streamflow.columns = ['Q_m3s']
    # Divide by the area scaled by 1e6, then multiply by seconds per day times 1000.
    streamflow['Q_mm'] = (streamflow['Q_m3s'] / (uparea * 1e6)) * (3600*24*1000)

    out_dir = os.path.join(PATHS['devp_datasets'], DIRNAME, 'graph_features', huc, gauge_id)
    os.makedirs(out_dir, exist_ok = True)
    streamflow.to_csv(os.path.join(out_dir, 'IndiaWRIS.csv'))

# Export the observed streamflow for every gauge in `camels_graph` on 8 joblib workers.
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

100%|██████████| 191/191 [00:06<00:00, 28.11it/s]


## GloFAS Parameter Maps

In [26]:
# "Catchment_morphology_and_river_network" (14 surface fields)
# - chanbnkf_Global_03min.nc (channel bankfull depth, m);
# - chanflpn_Global_03min.nc (width of the floodplain, m);
# - changrad_Global_03min.nc (channel longitudinal gradient, m/m);
# - chanlength_Global_03min.nc (channel length within a pixel, m);
# - chanman_Global_03min.nc (channel Manning's roughness coefficient, m^(1/3)s^(-1));
# - chans_Global_03min.nc (channel side slope, m/m);
# - chanbw_Global_03min.nc (channel bottom width, m);

# "Land_use" (7 surface fields)
# - fracforest_Global_03min.nc (fraction of forest for each grid-cell, -);
# - fracirrigated_Global_03min.nc (fraction of irrigated crops [except rice] for each grid-cell, -);
# - fracrice_Global_03min.nc (fraction of rice crops for each grid-cell, -);
# - fracsealed_Global_03min.nc (fraction of urban area for each grid-cell, -);
# - fracwater_Global_03min.nc (fraction of inland water for each grid-cell, -);
# - fracother_Global_03min.nc (fraction of other land cover for each grid-cell, -);
Parameter_Maps = os.path.join(PATHS['GloFAS'], 'LISFLOOD_Parameter_Maps')

# Sub-folder of `Parameter_Maps` -> variables to extract from it.
# Both groups go through the same extraction and clean-up steps.
parameter_groups = {
    'Catchments_morphology_and_river_network': ['chanbnkf', 'chanflpn', 'changrad', 'chanlength', 'chanman', 'chans', 'chanbw'],
    'Land_use': ['fracforest', 'fracirrigated', 'fracrice', 'fracsealed', 'fracwater', 'fracother'],
}

for group_dir, var_names in parameter_groups.items():
    for var_name in var_names:
        print(var_name)
        ds = xr.open_dataset(os.path.join(Parameter_Maps, group_dir, f"{var_name}_Global_03min.nc"))['Band1']
        # Crop to the study region (lat slice runs maxy -> miny) and read it into memory.
        ds = ds.sel(
            lat = slice(region_bounds['maxy'], region_bounds['miny']),
            lon = slice(region_bounds['minx'], region_bounds['maxx'])
        )
        ds.load()

        def process(idx, row):
            """Write the nearest-cell value of `var_name` for every node of one gauge.

            Reads the module-level `ds` and `var_name` of the current loop
            iteration. Output: one-row table (index 0), one column per node,
            saved under static/GloFAS/<var_name>.csv. `idx` is unused.
            """
            huc, gauge_id = row['huc_02'], row.name
            nodes_coords = pd.read_csv(os.path.join(PATHS['devp_datasets'], DIRNAME, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
            data = pd.DataFrame(columns = nodes_coords.index, index = [0])
            for node_idx, node_row in nodes_coords.iterrows():
                lat, lon = node_row['lat'], node_row['lon']
                ds_window_loc = ds.sel(lat = lat, lon = lon, method = 'nearest')
                data.loc[0, node_idx] = ds_window_loc.values.item()
            out_dir = os.path.join(PATHS['devp_datasets'], DIRNAME, 'graph_features', huc, gauge_id, 'static', 'GloFAS')
            os.makedirs(out_dir, exist_ok = True)
            data.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

        with Parallel(n_jobs = 8, verbose = 0) as parallel:
            _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

        # Release each raster before loading the next one so memory does not
        # accumulate across variables (applies to both groups).
        ds.close()
        del ds
        gc.collect()

chanbnkf


100%|██████████| 191/191 [00:01<00:00, 146.69it/s]


chanflpn


100%|██████████| 191/191 [00:00<00:00, 360.25it/s]


changrad


100%|██████████| 191/191 [00:00<00:00, 356.51it/s]


chanlength


100%|██████████| 191/191 [00:00<00:00, 353.10it/s]


chanman


100%|██████████| 191/191 [00:00<00:00, 359.40it/s]


chans


100%|██████████| 191/191 [00:00<00:00, 200.18it/s]


chanbw


100%|██████████| 191/191 [00:00<00:00, 361.42it/s]


fracforest


100%|██████████| 191/191 [00:00<00:00, 359.86it/s]


fracirrigated


100%|██████████| 191/191 [00:00<00:00, 354.52it/s]


fracrice


100%|██████████| 191/191 [00:00<00:00, 349.19it/s]


fracsealed


100%|██████████| 191/191 [00:00<00:00, 351.04it/s]


fracwater


100%|██████████| 191/191 [00:00<00:00, 350.31it/s]


fracother


100%|██████████| 191/191 [00:00<00:00, 357.28it/s]


## Cell Area

In [27]:
Parameter_Maps = os.path.join(PATHS['GloFAS'], 'LISFLOOD_Parameter_Maps')
# Pixel-area raster divided by 1e6 — presumably m^2 -> km^2, in line with the
# 'cellarea_km2' output name below; confirm the source units.
ds = xr.open_dataset(os.path.join(Parameter_Maps, 'Main', 'pixarea_Global_03min.nc'))['Band1'] / 1e6
# Crop to the study region (lat slice runs maxy -> miny) and read it into memory.
ds = ds.sel(
    lat = slice(region_bounds['maxy'], region_bounds['miny']), 
    lon = slice(region_bounds['minx'], region_bounds['maxx'])
)
ds.load()
# Output file stem used by `process` below.
var_name = 'cellarea_km2'
def process(idx, row):
    """Write the nearest-cell value of the module-level `ds` for each node of one gauge.

    Output: one-row table (index 0), one column per node, saved under
    static/GloFAS/<var_name>.csv where `var_name` is the module-level name.
    `idx` is unused.
    """
    huc, gauge_id = row['huc_02'], row.name
    dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
    nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)

    node_values = pd.DataFrame(columns = nodes_coords.index, index = [0])
    for node_idx, node_row in nodes_coords.iterrows():
        node_lat, node_lon = node_row['lat'], node_row['lon']
        nearest_cell = ds.sel(lat = node_lat, lon = node_lon, method = 'nearest')
        node_values.loc[0, node_idx] = nearest_cell.values.item()

    out_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id, 'static', 'GloFAS')
    os.makedirs(out_dir, exist_ok = True)
    node_values.to_csv(os.path.join(out_dir, f"{var_name}.csv"))

# Extract the cell area for every gauge in `camels_graph` on 8 joblib workers.
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

# Drop the in-memory raster now that every gauge has been written.
ds.close()
del ds
gc.collect()

100%|██████████| 191/191 [00:00<00:00, 336.03it/s]


52

## GloFAS Discharge in mm

In [None]:
def process(idx, row):
    """Rescale one gauge's GloFAS discharge by upstream area and save it.

    For every node column, the values of 'discharge.csv' are divided by the
    node's entry in 'uparea.csv' times 1e6 and multiplied by 3600*24*1000; the
    result goes to 'discharge_mm.csv' next to the input. `idx` is unused.
    """
    huc, gauge_id = row['huc_02'], row.name
    dataset_root = os.path.join(PATHS['devp_datasets'], DIRNAME)
    features_dir = os.path.join(dataset_root, 'graph_features', huc, gauge_id)
    glofas_dir = os.path.join(features_dir, 'dynamic', 'GloFAS')

    discharge = pd.read_csv(os.path.join(glofas_dir, 'discharge.csv'), index_col = 0, parse_dates = True)
    nodes_coords = pd.read_csv(os.path.join(dataset_root, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
    uparea = pd.read_csv(os.path.join(features_dir, 'uparea.csv'), index_col = 0)

    discharge_mm = discharge.copy()
    for node_idx in nodes_coords.index:
        # Feature tables are keyed by the node id as a string column label.
        node_col = str(node_idx)
        uparea_node = uparea.loc[0, node_col] * 1e6
        discharge_mm[node_col] = (discharge[node_col] / uparea_node) * (3600*24*1000)

    discharge_mm.to_csv(os.path.join(glofas_dir, 'discharge_mm.csv'), index = True)

# Convert the GloFAS discharge for every gauge in `camels_graph` on 8 joblib workers.
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

100%|██████████| 191/191 [01:22<00:00,  2.31it/s]


## GloFAS Runoff in m3s and mm

In [42]:
def process(idx, row):
    """Derive two rescaled versions of one gauge's GloFAS runoff water equivalent.

    Reads 'runoff_water_equivalent.csv' (one column per node) together with the
    per-node cell area ('cellarea_km2.csv') and upstream area ('uparea.csv'),
    then writes next to the input:

    - 'runoff_water_equivalent_m3s.csv': runoff * cell area / 86.4
    - 'runoff_water_equivalent_mm.csv':  runoff * cell area / upstream area

    The second output previously divided by the cell area again (the upstream
    area was loaded but never used), which returned the input unchanged.
    NOTE(review): the 86.4 factor presumably assumes runoff in mm/day and areas
    in km^2 — confirm against the source data. `idx` is unused.
    """
    huc, gauge_id = row['huc_02'], row.name
    features_dir = os.path.join(PATHS['devp_datasets'], DIRNAME, 'graph_features', huc, gauge_id)
    glofas_dir = os.path.join(features_dir, 'dynamic', 'GloFAS')

    glofas_data = pd.read_csv(os.path.join(glofas_dir, 'runoff_water_equivalent.csv'), index_col = 0, parse_dates = True)
    nodes_coords = pd.read_csv(os.path.join(PATHS['devp_datasets'], DIRNAME, 'graph_files', huc, gauge_id, 'nodes_coords.csv'), index_col = 0)
    cellarea = pd.read_csv(os.path.join(features_dir, 'static', 'GloFAS', 'cellarea_km2.csv'), index_col = 0)
    uparea = pd.read_csv(os.path.join(features_dir, 'uparea.csv'), index_col = 0)

    glofas_Q_m3s = glofas_data.copy()
    glofas_Q_mm = glofas_data.copy()
    for node_idx in nodes_coords.index:
        # Feature tables are keyed by the node id as a string column label.
        node_col = str(node_idx)
        cellarea_node = cellarea.loc[0, node_col]
        uparea_node = uparea.loc[0, node_col]
        glofas_Q_m3s[node_col] = (glofas_data[node_col] * cellarea_node) / 86.4
        glofas_Q_mm[node_col] = (glofas_data[node_col] * cellarea_node) / uparea_node

    glofas_Q_m3s.to_csv(os.path.join(glofas_dir, 'runoff_water_equivalent_m3s.csv'), index = True)
    glofas_Q_mm.to_csv(os.path.join(glofas_dir, 'runoff_water_equivalent_mm.csv'), index = True)

# Convert the GloFAS runoff for every gauge in `camels_graph` on 8 joblib workers.
with Parallel(n_jobs = 8, verbose = 0) as parallel:
    _ = parallel(delayed(process)(idx, row) for idx, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)))

100%|██████████| 191/191 [03:11<00:00,  1.00s/it]
