In [1]:
#%% Setting Up
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import networkx as nx
import rioxarray as rxr

import geopandas as gpd
import seaborn as sns
import matplotlib.pyplot as plt

from shapely.geometry import Point
from shapely.geometry import Polygon

import glob
import os
import itertools
import tqdm
import gc
import time
import pickle

from joblib import Parallel, delayed

import rioxarray as rxr

import configparser
DEFAULTS_INI = '/home/sarth/rootdir/datadir/assets/defaults.ini'
_cfg_parser = configparser.ConfigParser()
_cfg_parser.optionxform = str  # keep option names case-sensitive (ConfigParser lower-cases them by default)
# ConfigParser.read silently skips files it cannot open and returns the list it did read;
# fail here with a clear message instead of a confusing KeyError on 'PATHS' below.
if not _cfg_parser.read(DEFAULTS_INI):
    raise FileNotFoundError(f"Could not read config file: {DEFAULTS_INI}")
# Plain {section: {option: value}} view of the config.
cfg = {section: dict(_cfg_parser.items(section)) for section in _cfg_parser.sections()}
PATHS = cfg['PATHS']

print("Setting up...")

Setting up...


In [2]:
#%% Region-Specific: CAMELS-IND
# Working directory for this region.
DIRNAME = '03min_GloFAS_CAMELS-IND'
SAVE_PATH = os.path.join(PATHS['devp_datasets'], DIRNAME)
resolution = 0.05  # presumably the grid spacing in degrees (0.05 deg = 3 arc-min, cf. DIRNAME) — confirm


def lon_360_180(x):
    """Convert longitudes from the 0..360 convention to -180..180."""
    return (x + 180) % 360 - 180


def lon_180_360(x):
    """Convert longitudes from the -180..180 convention to 0..360."""
    return x % 360


# Region bounding box in degrees: longitude minx..maxx, latitude miny..maxy.
region_bounds = dict(minx=66, miny=5, maxx=100, maxy=30)

# Gauge graph attributes. The first CSV column becomes the index (zero-padded to 5 characters;
# it surfaces as the `gauge_id` column after reset_index), and HUC codes are zero-padded to 2.
camels_graph = pd.read_csv(
    os.path.join(SAVE_PATH, 'nested_gauges', 'graph_attributes_with_nesting.csv'),
    index_col=0,
)
camels_graph.index = camels_graph.index.map(lambda gid: str(gid).zfill(5))
camels_graph['huc_02'] = camels_graph['huc_02'].map(lambda code: str(code).zfill(2))
# Optional: restrict to the 'not_nested' and 'nested_downstream' categories.
# camels_graph = camels_graph[camels_graph['nesting'].isin(['not_nested', 'nested_downstream'])]
camels_graph = camels_graph.reset_index()
print(f"Number of catmt's with nesting: {len(camels_graph)}")
camels_graph

Number of catmt's with nesting: 191


Unnamed: 0,gauge_id,huc_02,gauge_lon,gauge_lat,ghi_area,cwc_lon,cwc_lat,cwc_area,cwc_site_name,ghi_stn_id,...,cwc_river,flow_availability,snapped_lon,snapped_lat,snapped_uparea,snapped_iou,area_percent_difference,num_nodes,num_edges,nesting
0,14015,14,73.11033,18.73540,125.7,73.1108,18.7367,125.0,Pen,wfrn_penxx,...,Bhogeswari,31.36,73.125,18.725,116.744995,0.650870,7.124106,4.0,3.0,not_nested
1,15006,15,74.88124,13.51876,299.6,74.8800,13.5214,253.0,Avershe,wfrs_avers,...,Seetha,39.00,74.925,13.475,329.444850,0.607418,9.961565,11.0,10.0,not_nested
2,05025,05,78.05617,11.93959,356.0,78.0572,11.9383,362.0,Thoppur,cauv_thopp,...,Cauvery/Thoppaiyar,31.52,78.125,11.975,331.356480,0.634196,6.922338,11.0,10.0,not_nested
3,15032,15,74.98123,13.29791,356.9,74.9806,13.2942,327.0,Yennehole,wfrs_yenne,...,Swarna,70.67,74.975,13.275,329.717100,0.755086,7.616389,11.0,10.0,not_nested
4,15007,15,76.84792,8.71458,555.8,76.8500,8.7150,540.0,Ayilam,wfrs_ayila,...,Vamanapuram,93.18,76.875,8.725,578.080800,0.773097,4.008784,19.0,18.0,not_nested
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
186,04038,04,75.13955,17.97291,22326.1,75.1392,17.9722,22856.0,Narsingpur,kris_narsi,...,Krishna/Bhima,95.30,75.125,17.975,22149.584000,0.953356,0.790624,758.0,757.0,nested_downstream
187,04057,04,76.96458,15.66043,23527.1,76.9628,15.6611,23500.0,T. Ramapuram (Seasonal),kris_trama,...,Krishna/Tungabhadra/Hagari,84.89,76.975,15.625,23463.662000,0.941489,0.269636,786.0,785.0,nested_downstream
188,03057,03,80.68958,19.29779,24041.7,80.6892,19.2989,24210.0,Mirdapalli,goda_mirda,...,Godavari/Pranhita/Indravati,56.17,80.675,19.275,23877.793000,0.940053,0.681758,821.0,820.0,nested_downstream
189,12004,12,79.01458,23.03126,26297.6,79.0156,23.0308,26453.0,Barmanghat,narm_barma,...,Narmada,91.85,79.025,23.025,26175.932000,0.943601,0.462658,921.0,920.0,nested_downstream


In [3]:
def idx_to_map(ds, var_name):
    """Scatter a per-cell variable indexed by ``idx`` onto the dataset's regular (lat, lon) grid.

    Args:
        ds: Dataset with ``lat``/``lon`` coordinates, an ``idx`` dimension coordinate, and
            ``idx2lat``/``idx2lon`` variables giving each idx's cell coordinates. These must
            match ``lat``/``lon`` labels exactly.
        var_name: Name of a variable in ``ds`` holding exactly one value per ``idx``.

    Returns:
        float32 DataArray with dims ``('lat', 'lon')``; cells not covered by any idx are NaN.

    Raises:
        KeyError: if some idx maps to a lat/lon label that is not on the grid.
    """
    lats = ds.lat.values
    lons = ds.lon.values
    idx_values = ds.idx.values
    cell_lats = np.asarray(ds['idx2lat'].sel(idx=idx_values).values)
    cell_lons = np.asarray(ds['idx2lon'].sel(idx=idx_values).values)
    cell_values = np.asarray(ds[var_name].sel(idx=idx_values).values)

    # One vectorized exact-label lookup instead of three .sel calls and a .loc write per idx.
    lat_pos = pd.Index(lats).get_indexer(cell_lats)
    lon_pos = pd.Index(lons).get_indexer(cell_lons)
    missing = (lat_pos < 0) | (lon_pos < 0)
    if missing.any():
        # get_indexer signals "not found" with -1, which would otherwise silently hit the last row/col.
        raise KeyError(f"idx values {idx_values[missing].tolist()} map to lat/lon labels not on the grid")

    grid = np.full((len(lats), len(lons)), np.nan, dtype=np.float32)
    grid[lat_pos, lon_pos] = cell_values
    return xr.DataArray(grid, dims=['lat', 'lon'], coords={'lat': lats, 'lon': lons})

# Notebook-level date range. NOTE(review): process_catmt and process_related_vars use their own
# (different) date windows, so these two constants are not read by them.
START_DATE = pd.Timestamp('1998-01-01')
END_DATE = pd.Timestamp('2022-12-31')

# Helper Functions

In [4]:
def get_fao_pet(
    surface_pressure_mean: pd.Series,
    temperature_2m_mean: pd.Series,
    dewpoint_temperature_2m_mean: pd.Series,
    u_component_of_wind_10m_mean: pd.Series,
    v_component_of_wind_10m_mean: pd.Series,
    surface_net_solar_radiation_mean: pd.Series,
    surface_net_thermal_radiation_mean: pd.Series,
) -> pd.Series:
    """Reference evapotranspiration ET0 [mm/day] with the FAO-56 Penman-Monteith equation (eq. 6).

    Every step is element-wise arithmetic. The inputs may therefore be ``pd.Series``,
    ``xr.DataArray``, ``np.ndarray`` or scalars of mutually compatible shape, and the result
    follows the inputs' container type (e.g. a Series in gives a Series out, same index).

    Args:
        surface_pressure_mean: Surface pressure [Pa].
        temperature_2m_mean: 2 m air temperature [K].
        dewpoint_temperature_2m_mean: 2 m dew-point temperature [K].
        u_component_of_wind_10m_mean: 10 m eastward wind component (FAO eq. 47 expects m/s).
        v_component_of_wind_10m_mean: 10 m northward wind component (FAO eq. 47 expects m/s).
        surface_net_solar_radiation_mean: Net shortwave radiation.
        surface_net_thermal_radiation_mean: Net longwave radiation.
            The radiation sum is multiplied by 24 and divided by 1e6 to get MJ m-2 day-1. This
            presumably means hourly-mean energy in J m-2 — confirm against the data source.

    Returns:
        ET0 in mm/day.
    """
    # Wind speed at 2 m from the 10 m components via the log wind profile (FAO eq. 47, z = 10 m).
    temp_windspeed10m_m_s = np.sqrt(u_component_of_wind_10m_mean**2 + v_component_of_wind_10m_mean**2)
    windspeed2m_m_s = temp_windspeed10m_m_s * 4.87 / np.log(67.8 * 10 - 5.42)

    # Net radiation [MJ m-2 day-1].
    net_radiation_mj_m2 = (surface_net_solar_radiation_mean + surface_net_thermal_radiation_mean) * 24 / 1e6

    # Unit conversions.
    surface_pressure_kpa = surface_pressure_mean / 1e3  # Pa -> kPa
    temperature2m_c = temperature_2m_mean - 273.15  # K -> degC
    dewpoint2m_c = dewpoint_temperature_2m_mean - 273.15  # K -> degC

    # Constants.
    lmbda = 2.45  # latent heat of vaporization [MJ kg-1] (FAO simplification: value at about 20 degC)
    cp = 1.013e-3  # specific heat at constant pressure [MJ kg-1 degC-1]
    eps = 0.622  # ratio of molecular weights, water vapour / dry air

    # Soil heat flux density [MJ m-2 day-1] is taken as 0 at daily time steps (FAO eq. 42).
    # A scalar broadcasts against any input container and needs no array allocation.
    soil_heat_flux = 0.0

    # Atmospheric pressure [kPa] (FAO eq. 7 is replaced by the observed surface pressure).
    P_kpa = surface_pressure_kpa

    # Psychrometric constant, gamma [kPa degC-1] (FAO eq. 8).
    psychometric_kpa_c = cp * P_kpa / (eps * lmbda)

    # Saturation vapour pressure [kPa] (FAO eq. 11), evaluated at the mean temperature.
    svp_kpa = 0.6108 * np.exp((17.27 * temperature2m_c) / (temperature2m_c + 237.3))

    # Slope of the saturation vapour pressure curve, Delta [kPa degC-1] (FAO eq. 13).
    delta_kpa_c = 4098.0 * svp_kpa / (temperature2m_c + 237.3)**2

    # Actual vapour pressure from the dew point [kPa] (FAO eq. 14).
    avp_kpa = 0.6108 * np.exp((17.27 * dewpoint2m_c) / (dewpoint2m_c + 237.3))

    # Saturation vapour pressure deficit [kPa].
    svpdeficit_kpa = svp_kpa - avp_kpa

    # ET0 (FAO eq. 6).
    numerator = (0.408 * delta_kpa_c * (net_radiation_mj_m2 - soil_heat_flux)
                 + psychometric_kpa_c * (900 / (temperature2m_c + 273)) * windspeed2m_m_s * svpdeficit_kpa)
    denominator = delta_kpa_c + psychometric_kpa_c * (1 + 0.34 * windspeed2m_m_s)

    ET0_mm_day = numerator / denominator

    return ET0_mm_day

In [5]:
from typing import Dict, List
from numba import njit

@njit
def _split_list(a_list: List) -> List:
    """Group an integer sequence into runs of consecutive values.

    Intended for ascending sequences such as ``np.where(...)[0]``. For example,
    [1, 2, 3, 7, 8, 10] -> [[1, 2, 3], [7, 8], [10]]; each run is a slice of the input.
    An empty input yields an empty list.
    """
    runs = []
    run_start = 0
    n_items = len(a_list)
    for pos in range(n_items - 1):
        # A jump of more than 1 to the next element closes the current run.
        if a_list[pos + 1] > a_list[pos] + 1:
            runs.append(a_list[run_start:pos + 1])
            run_start = pos + 1
    if n_items > 0:
        # Whatever remains after the last jump is the final run.
        runs.append(a_list[run_start:n_items])
    return runs

def _get_moisture_and_seasonality_index(precipitation, pet) -> tuple[float, float]:

    mean_monthly_precip = precipitation.groupby(precipitation.index.month).mean()
    mean_monthly_pet = pet.groupby(pet.index.month).mean()

    # Average annual moisture index (see Knoben)
    p_gt_et = 1 - mean_monthly_pet.loc[mean_monthly_precip > mean_monthly_pet] / mean_monthly_precip.loc[
        mean_monthly_precip > mean_monthly_pet]
    srs = pd.Series(np.zeros((12), dtype=np.float32), index=mean_monthly_pet.index, name='dummy')
    p_eq_et = srs.loc[mean_monthly_precip == mean_monthly_pet]
    p_lt_et = mean_monthly_precip.loc[mean_monthly_precip < mean_monthly_pet] / mean_monthly_pet.loc[
        mean_monthly_precip < mean_monthly_pet] - 1
    monthly_moisture_index = pd.concat([p_gt_et, p_eq_et, p_lt_et])

    annual_moisture_index = monthly_moisture_index.mean()

    # Seasonality (see Knoben)
    seasonality = monthly_moisture_index.max() - monthly_moisture_index.min()

    return annual_moisture_index, seasonality

def get_climate_indices(df: pd.DataFrame) -> Dict[str, float]:
    """Compute climate summary indices for one time series.

    The indices are Knoben-style moisture/seasonality plus precipitation statistics.

    Args:
        df: Frame with a DatetimeIndex (monthly groupings use ``df.index.month``) and columns:
            - ``total_precipitation_sum``,
            - ``potential_evaporation_sum_ERA5_LAND``,
            - ``potential_evaporation_sum_FAO_PENMAN_MONTEITH``,
            - ``temperature_2m_mean`` in Kelvin.
            The ``< 1`` low-precipitation threshold assumes precipitation in mm.
            ``df`` is not modified.

    Returns:
        Dict with these keys:
        - ``p_mean``, ``pet_mean_ERA5``, ``pet_mean_FAO_PM``: time means.
        - ``aridity_ERA5`` / ``aridity_FAO_PM``: mean PET / mean P.
        - ``frac_snow``: share of mean monthly precipitation in months whose mean temperature
          is below 0 degC.
        - ``moisture_index_*`` / ``seasonality_*``.
        - ``high_prec_freq`` / ``high_prec_dur``: fraction of days with P >= 5 * mean P, and
          the mean length of such consecutive-day runs.
        - ``low_prec_freq`` / ``low_prec_dur``: the same for P < 1.
    """
    precip = df["total_precipitation_sum"]
    pet_era5 = df["potential_evaporation_sum_ERA5_LAND"]
    pet_fao = df["potential_evaporation_sum_FAO_PENMAN_MONTEITH"]
    # Convert to Celsius in a local series so the caller's Kelvin column is never overwritten
    # (mutating df in place made repeated calls convert twice).
    temperature_c = df["temperature_2m_mean"] - 273.15

    # Mean daily precipitation and PET.
    p_mean = precip.mean()
    pet_mean_era5 = pet_era5.mean()
    pet_mean_fao = pet_fao.mean()

    # Aridity index (PET / P).
    aridity_era5 = pet_mean_era5 / p_mean
    aridity_fao = pet_mean_fao / p_mean

    # Moisture and seasonality index, once with ERA5 PET and once with FAO PM PET.
    annual_moisture_index_era5, seasonality_era5 = _get_moisture_and_seasonality_index(
        precipitation=precip, pet=pet_era5)
    annual_moisture_index_fao, seasonality_fao = _get_moisture_and_seasonality_index(
        precipitation=precip, pet=pet_fao)

    # Fraction of mean monthly precipitation falling in sub-zero months (snow proxy, see Knoben).
    months = df.index.month
    mean_monthly_precip = precip.groupby(months).mean()
    mean_monthly_temp = temperature_c.groupby(months).mean()
    frac_snow = mean_monthly_precip.loc[mean_monthly_temp < 0].sum() / mean_monthly_precip.sum()

    def _mean_run_length(mask):
        """Mean length of runs of consecutive True values in ``mask``; 0.0 when there are none."""
        runs = _split_list(np.flatnonzero(mask))
        return np.mean(np.array([len(run) for run in runs])) if runs else 0.0

    # Frequency (fraction of days) and mean duration (consecutive days) of high / low precipitation.
    n_days = len(df)
    precip_values = precip.to_numpy()
    high_mask = precip_values >= 5 * p_mean
    low_mask = precip_values < 1
    high_prec_freq = np.count_nonzero(high_mask) / n_days
    low_prec_freq = np.count_nonzero(low_mask) / n_days
    high_prec_dur = _mean_run_length(high_mask)
    low_precip_dur = _mean_run_length(low_mask)

    climate_indices_dict = {
        'p_mean': p_mean,
        'pet_mean_ERA5': pet_mean_era5,
        'pet_mean_FAO_PM': pet_mean_fao,
        'aridity_ERA5': aridity_era5,
        'aridity_FAO_PM': aridity_fao,
        'frac_snow': frac_snow,
        'moisture_index_ERA5': annual_moisture_index_era5,
        'seasonality_ERA5': seasonality_era5,
        'moisture_index_FAO_PM': annual_moisture_index_fao,
        'seasonality_FAO_PM': seasonality_fao,
        'high_prec_freq': high_prec_freq,
        'high_prec_dur': high_prec_dur,
        'low_prec_freq': low_prec_freq,
        'low_prec_dur': low_precip_dur
    }

    return climate_indices_dict

# Execution

In [6]:
def process_catmt(huc, gauge_id, start_date=pd.Timestamp('1999-01-01'),
                  end_date=pd.Timestamp('2009-12-31'), n_jobs=32):
    """Compute per-idx climate summary indices for one catchment and append them to its zarr store.

    Steps:
    1. Read ``<SAVE_PATH>/inventory/<huc>/<gauge_id>.zarr``.
    2. Restrict the ERA5 variables to ``[start_date, end_date]`` and convert units.
    3. Derive FAO Penman-Monteith PET and run ``get_climate_indices`` for every ``idx`` in parallel.
    4. Append each index to the same store as ``static_ClimSumm_<name>`` (dim ``idx``).

    Args:
        huc: Zero-padded HUC region code (sub-directory of ``inventory``).
        gauge_id: Gauge identifier (zarr store name without the extension).
        start_date: Start of the inclusive time window used for the statistics.
        end_date: End of the window. The defaults reproduce the previously hard-coded 1999-2009 window.
        n_jobs: Number of joblib workers; lower it if workers run out of memory.
    """
    zarr_path = os.path.join(SAVE_PATH, 'inventory', huc, f'{gauge_id}.zarr')

    ERA5_var_names = [
        'dynamic_ERA5_total_precipitation',
        'dynamic_ERA5_potential_evaporation',
        'dynamic_ERA5_2m_temperature',
        'dynamic_ERA5_surface_pressure',
        'dynamic_ERA5_2m_dewpoint_temperature',
        'dynamic_ERA5_10m_u_component_of_wind',
        'dynamic_ERA5_10m_v_component_of_wind',
        'dynamic_ERA5_surface_net_solar_radiation',
        'dynamic_ERA5_surface_net_thermal_radiation'
    ]
    rename_columns = {
        'dynamic_ERA5_surface_pressure': 'surface_pressure_mean',
        'dynamic_ERA5_2m_temperature': 'temperature_2m_mean',
        'dynamic_ERA5_2m_dewpoint_temperature': 'dewpoint_temperature_2m_mean',
        'dynamic_ERA5_10m_u_component_of_wind': 'u_component_of_wind_10m_mean',
        'dynamic_ERA5_10m_v_component_of_wind': 'v_component_of_wind_10m_mean',
        'dynamic_ERA5_surface_net_solar_radiation': 'surface_net_solar_radiation_mean',
        'dynamic_ERA5_surface_net_thermal_radiation': 'surface_net_thermal_radiation_mean',
        'dynamic_ERA5_total_precipitation': 'total_precipitation_sum',
        'dynamic_ERA5_potential_evaporation': 'potential_evaporation_sum_ERA5_LAND',
    }

    catmt = xr.open_zarr(zarr_path)
    try:
        idx_coord = catmt.idx.values
        catmt_slice = catmt[ERA5_var_names].sel(time=slice(start_date, end_date))
        # Radiation is divided by 24 here; get_fao_pet multiplies the sum by 24 again.
        # NOTE(review): confirm the stored radiation unit matches that round trip.
        for varname in ['dynamic_ERA5_surface_net_solar_radiation', 'dynamic_ERA5_surface_net_thermal_radiation']:
            catmt_slice[varname] = catmt_slice[varname] / 24
        catmt_slice = catmt_slice.rename(rename_columns)

        # m -> mm. The PET sign is flipped (presumably ERA5's negative-evaporation convention).
        catmt_slice['potential_evaporation_sum_ERA5_LAND'] = catmt_slice['potential_evaporation_sum_ERA5_LAND'] * 1e3 * -1
        catmt_slice['total_precipitation_sum'] = catmt_slice['total_precipitation_sum'] * 1e3

        def process_catmt_idx(idx):
            """Climate indices for a single idx, returned as a Series named by that idx."""
            catmt_slice_idx = catmt_slice.sel(idx=idx).drop_vars(['idx']).to_dataframe()
            ET0 = get_fao_pet(
                surface_pressure_mean=catmt_slice_idx['surface_pressure_mean'],
                temperature_2m_mean=catmt_slice_idx['temperature_2m_mean'],
                dewpoint_temperature_2m_mean=catmt_slice_idx['dewpoint_temperature_2m_mean'],
                u_component_of_wind_10m_mean=catmt_slice_idx['u_component_of_wind_10m_mean'],
                v_component_of_wind_10m_mean=catmt_slice_idx['v_component_of_wind_10m_mean'],
                surface_net_solar_radiation_mean=catmt_slice_idx['surface_net_solar_radiation_mean'],
                surface_net_thermal_radiation_mean=catmt_slice_idx['surface_net_thermal_radiation_mean']
            )
            catmt_slice_idx['potential_evaporation_sum_FAO_PENMAN_MONTEITH'] = ET0
            return pd.Series(get_climate_indices(catmt_slice_idx), name=idx)

        with Parallel(n_jobs=n_jobs, verbose=0) as parallel:
            rows = parallel(delayed(process_catmt_idx)(idx) for idx in catmt_slice.idx.values)
    finally:
        catmt.close()

    if not rows:
        return

    # One row per idx (in idx order), one column per climate index.
    climate_indices = pd.DataFrame(rows)
    new_vars = xr.Dataset(
        {
            f"static_ClimSumm_{name}": ('idx', climate_indices[name].to_numpy())
            for name in climate_indices.columns
        },
        coords={'idx': idx_coord},
    )
    del rows, climate_indices
    gc.collect()

    # Append only the new variables. Writing the whole dataset back re-encoded every existing
    # variable while it was being read lazily from this very store.
    new_vars.to_zarr(zarr_path, mode='a')

In [None]:
# Earlier partial runs used islice ranges 0-25 and 25-50 (stopped at 35/191), then resumed 35-50
# with fewer workers; the final run covered every row.
# To restart an interrupted run, set RESUME_FROM to the first unprocessed row.
RESUME_FROM = 0
for idx, row in itertools.islice(tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)), RESUME_FROM, None):
    huc, gauge_id = row['huc_02'], row['gauge_id']
    process_catmt(huc, gauge_id)

100%|██████████| 191/191 [35:25<00:00, 11.13s/it]


In [8]:
row = camels_graph.iloc[0]
huc, gauge_id = row['huc_02'], row['gauge_id']
# List the variables stored for the first catchment (after the climate indices were appended).
with xr.open_zarr(os.path.join(SAVE_PATH, 'inventory', huc, f'{gauge_id}.zarr')) as catmt:
    data_vars = sorted(catmt.data_vars)
print(f"Length of data_vars: {len(data_vars)}")
data_vars

Length of data_vars: 145


['dynamic_ERA5-Land_dewpoint_temperature_2m_max',
 'dynamic_ERA5-Land_dewpoint_temperature_2m_min',
 'dynamic_ERA5-Land_leaf_area_index_high_vegetation',
 'dynamic_ERA5-Land_leaf_area_index_low_vegetation',
 'dynamic_ERA5-Land_potential_evaporation_sum',
 'dynamic_ERA5-Land_runoff_sum',
 'dynamic_ERA5-Land_snow_cover',
 'dynamic_ERA5-Land_snow_depth',
 'dynamic_ERA5-Land_snowfall_sum',
 'dynamic_ERA5-Land_snowmelt_sum',
 'dynamic_ERA5-Land_sub_surface_runoff_sum',
 'dynamic_ERA5-Land_surface_net_solar_radiation_sum',
 'dynamic_ERA5-Land_surface_net_thermal_radiation_sum',
 'dynamic_ERA5-Land_surface_pressure',
 'dynamic_ERA5-Land_surface_runoff_sum',
 'dynamic_ERA5-Land_temperature_2m_max',
 'dynamic_ERA5-Land_temperature_2m_min',
 'dynamic_ERA5-Land_total_evaporation_sum',
 'dynamic_ERA5-Land_total_precipitation_sum',
 'dynamic_ERA5-Land_u_component_of_wind_10m',
 'dynamic_ERA5-Land_v_component_of_wind_10m',
 'dynamic_ERA5-Land_volumetric_soil_water_layer_1',
 'dynamic_ERA5-Land_volum

In [8]:
# Sanity check for the variable count printed above. 14 is the number of static_ClimSumm_*
# variables appended per catchment (one per key of get_climate_indices' result dict).
# 131 is presumably the number of variables before they were added — TODO confirm.
131+14

145

# Target Related Variables

In [6]:
def process_related_vars(huc, gauge_id, start_date=pd.Timestamp('1999-01-01'),
                         end_date=pd.Timestamp('2019-12-31')):
    """Return FAO Penman-Monteith PET and ERA5 precipitation for one lumped catchment.

    Reads ``<SAVE_PATH>/lumped_inventory/<huc>/<gauge_id>.zarr`` and restricts it to
    ``[start_date, end_date]`` (inclusive). The defaults reproduce the previously
    hard-coded 1999-2019 window.

    Args:
        huc: Zero-padded HUC region code (sub-directory of ``lumped_inventory``).
        gauge_id: Gauge identifier (zarr store name without the extension).
        start_date: Start of the inclusive time window.
        end_date: End of the inclusive time window.

    Returns:
        Tuple ``(PET, Prcp)`` of NumPy arrays: ET0 from ``get_fao_pet`` [mm/day] and total
        precipitation converted from m to mm, one value per time step in the window.
    """
    zarr_path = os.path.join(SAVE_PATH, 'lumped_inventory', huc, f'{gauge_id}.zarr')
    # Only the variables needed for PET and precipitation, mapped to get_fao_pet's argument names.
    rename_columns = {
        'dynamic_ERA5_surface_pressure': 'surface_pressure_mean',
        'dynamic_ERA5_2m_temperature': 'temperature_2m_mean',
        'dynamic_ERA5_2m_dewpoint_temperature': 'dewpoint_temperature_2m_mean',
        'dynamic_ERA5_10m_u_component_of_wind': 'u_component_of_wind_10m_mean',
        'dynamic_ERA5_10m_v_component_of_wind': 'v_component_of_wind_10m_mean',
        'dynamic_ERA5_surface_net_solar_radiation': 'surface_net_solar_radiation_mean',
        'dynamic_ERA5_surface_net_thermal_radiation': 'surface_net_thermal_radiation_mean',
        'dynamic_ERA5_total_precipitation': 'total_precipitation_sum',
    }

    catmt = xr.open_zarr(zarr_path)
    try:
        catmt_slice = catmt[list(rename_columns)].sel(time=slice(start_date, end_date))
        # Radiation is divided by 24 here; get_fao_pet multiplies the sum by 24 again.
        for varname in ['dynamic_ERA5_surface_net_solar_radiation', 'dynamic_ERA5_surface_net_thermal_radiation']:
            catmt_slice[varname] = catmt_slice[varname] / 24
        catmt_slice = catmt_slice.rename(rename_columns)
        catmt_slice['total_precipitation_sum'] = catmt_slice['total_precipitation_sum'] * 1e3  # m -> mm

        PET = get_fao_pet(
            surface_pressure_mean=catmt_slice['surface_pressure_mean'],
            temperature_2m_mean=catmt_slice['temperature_2m_mean'],
            dewpoint_temperature_2m_mean=catmt_slice['dewpoint_temperature_2m_mean'],
            u_component_of_wind_10m_mean=catmt_slice['u_component_of_wind_10m_mean'],
            v_component_of_wind_10m_mean=catmt_slice['v_component_of_wind_10m_mean'],
            surface_net_solar_radiation_mean=catmt_slice['surface_net_solar_radiation_mean'],
            surface_net_thermal_radiation_mean=catmt_slice['surface_net_thermal_radiation_mean']
        )
        Prcp = catmt_slice['total_precipitation_sum']

        # .values materializes the lazy data, so it must happen before the store is closed.
        return PET.values, Prcp.values
    finally:
        catmt.close()

# Collect one PET and one precipitation series per catchment (rows follow camels_graph order).
PET_list = []
Prcp_list = []
for _, row in tqdm.tqdm(camels_graph.iterrows(), total=len(camels_graph)):
    PET, Prcp = process_related_vars(row['huc_02'], row['gauge_id'])
    PET_list.append(PET)
    Prcp_list.append(Prcp)

# np.stack raises immediately if catchments return series of different lengths, rather than
# building a ragged/object array that only fails later (or not at all).
PET_array = np.stack(PET_list)
Prcp_array = np.stack(Prcp_list)

100%|██████████| 191/191 [01:37<00:00,  1.96it/s]


In [7]:
import torch

# float32 tensor copies of the (catchment, time) arrays.
PET_tensor, Prcp_tensor = (torch.tensor(arr, dtype=torch.float32) for arr in (PET_array, Prcp_array))
print(PET_tensor.shape, Prcp_tensor.shape)

torch.Size([191, 7665]) torch.Size([191, 7665])


In [8]:
# Load the daily time axis stored with the batched lumped catchments.
# NOTE(review): absolute path is hard-coded; it appears to be the same directory the final save
# cell builds from PATHS['datasets'] — consider deriving it from PATHS (TODO confirm they match).
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which may refuse
# non-tensor pickles such as timestamp objects; pass weights_only=False if this trusted file fails to load.
timestamps = torch.load('/home/sarth/rootdir/datadir/data/datasets/batched_catchments/CAMELS-IND_HUCAll_lumped/timestamps.pt')
# Wrap in a column-less DataFrame so the timestamps become a pandas index (get_loc lookups below).
timestamps = pd.DataFrame(index=timestamps)
timestamps.index.name = 'timestamp'
timestamps

1998-01-02
1998-01-03
1998-01-04
1998-01-05
1998-01-06
...
2020-01-06
2020-01-07
2020-01-08
2020-01-09
2020-01-10


In [9]:
# Positions of the first and last day of the PET/precipitation window on the timestamps axis.
start_idx, end_idx = (
    timestamps.index.get_loc(pd.Timestamp(day)) for day in ('1999-01-01', '2019-12-31')
)
start_idx, end_idx

(364, 8028)

In [10]:
# Place the [start_idx, end_idx] window onto the full `timestamps` axis; steps outside it stay zero.
# NOTE(review): zero is also a valid PET / precipitation value, so the padding cannot be told apart
# from data downstream — confirm consumers mask steps outside the window.
# NOTE(review): alignment is by position, i.e. the store's time steps are assumed to equal
# timestamps[start_idx:end_idx + 1].
window_len = end_idx - start_idx + 1
# Fail loudly instead of relying on torch broadcasting: a length-1 time axis would otherwise be
# silently copied across the whole window.
assert PET_tensor.shape[1] == window_len, f"PET has {PET_tensor.shape[1]} steps, window needs {window_len}"
assert Prcp_tensor.shape[1] == window_len, f"Prcp has {Prcp_tensor.shape[1]} steps, window needs {window_len}"
PET = torch.zeros((PET_tensor.shape[0], len(timestamps)), dtype=torch.float32)
PET[:, start_idx:end_idx + 1] = PET_tensor
Prcp = torch.zeros((Prcp_tensor.shape[0], len(timestamps)), dtype=torch.float32)
Prcp[:, start_idx:end_idx + 1] = Prcp_tensor
print(PET.shape, Prcp.shape)

torch.Size([191, 8039]) torch.Size([191, 8039])


In [11]:
# Dedicated name: rebinding the global SAVE_PATH here would silently redirect process_catmt /
# process_related_vars (which read SAVE_PATH at call time) if their cells were re-run afterwards.
BATCHED_SAVE_PATH = os.path.join(PATHS['datasets'], 'batched_catchments', 'CAMELS-IND_HUCAll_lumped')
os.makedirs(BATCHED_SAVE_PATH, exist_ok=True)
torch.save(PET, os.path.join(BATCHED_SAVE_PATH, 'PET.pt'))
torch.save(Prcp, os.path.join(BATCHED_SAVE_PATH, 'Prcp.pt'))