In [6]:
import xarray as xr
import glob
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

In [3]:
# little helper function
def xr_opener(file):
    """Open ``file`` with ``xr.open_dataset`` and normalise its coordinate name.

    If the opened dataset has a coordinate called ``'index'`` it is renamed
    to ``'date'``; otherwise the dataset is returned untouched.
    """
    dataset = xr.open_dataset(file)
    # Nothing to normalise: hand the dataset back as opened.
    if 'index' not in dataset.coords:
        return dataset
    return dataset.rename({'index': 'date'})

#### Read all files

In [4]:
# Root folder of the geo database (relative path, resolved against the
# notebook's working directory); the loading cell reads from its nc_concat/ subfolder.
geo_folder = '../geo_data/great_db'

In [7]:
# glob returns paths in arbitrary (filesystem-dependent) order; sort them so
# the order of gauge_id in the concatenated dataset is reproducible across runs.
files = sorted(glob.glob(f'{geo_folder}/nc_concat/*.nc'))
if not files:
    # Fail early with the searched location instead of an opaque concat error.
    raise FileNotFoundError(f'no .nc files found in {geo_folder}/nc_concat')

# One dataset per file, stacked along the gauge_id dimension.
big_file = xr.concat([xr_opener(file) for file in files],
                     dim='gauge_id')
big_file


#### Select by gauge

In [8]:
# Gauge identifiers to keep; they are passed as labels of the gauge_id coordinate.
gauges = ['5746', '3001', '1001']
# Label-based selection: keeps only the listed gauges from the full dataset.
ds = big_file.sel(gauge_id=gauges)
# sel is possible on date coordinate as well
ds

#### Select by other variables and split data

In [9]:
# Names of the input (forcing) variables.
# NOTE(review): this list is not referenced by split_ds below, which feeds
# every variable of the dataset into the past window — confirm intent.
predictors = ['t_max_e5l', 't_max_e5', 't_min_e5l', 't_min_e5', 'prcp_e5l',
             'prcp_e5', 'prcp_gpcp', 'prcp_imerg', 'prcp_mswep', 'Eb', 'Es',
             'Et', 'SMsurf', 'SMroot', 'Ew', 'Ei', 'S', 'E', 'Ep']
# Names of the variables extracted for the future (forecast) window.
target = ['lvl_sm', 'q_cms_s', 'lvl_mbs', 'q_mm_day']

# Window lengths, counted in steps along the date coordinate
# (presumably days, assuming a daily time step — TODO confirm).
future_window = 7
past_window = 365

# Period boundaries used to slice the date coordinate (written as MM/DD/YYYY).
train_start = '01/01/2008'
train_end = '12/31/2015'

val_start = '01/01/2016'
val_end = '12/31/2018'

test_start = '01/01/2019'
test_end = '12/31/2020'

In [10]:
# Preview of the training period: label-based slice on the date coordinate.
ds.sel(date=slice(train_start, train_end))

In [13]:
# Same training-period slice, narrowed to a single gauge and to the
# target variables only.
ds.sel(date=slice(train_start, train_end),
       gauge_id='5746')[target]

In [None]:
def split_ds(ds: xr.Dataset) -> dict:
    """Cut a dataset into sliding (past, future) windows along ``date``.

    For every split position ``i`` one sample is produced: all variables of
    ``ds`` over the ``past_window`` steps before ``i``, and the ``target``
    variables over the ``future_window`` steps starting at ``i``.
    ``past_window``, ``future_window`` and ``target`` are read from the
    notebook's global configuration.

    Args:
        ds: dataset with a ``date`` dimension that contains every variable
            listed in ``target``.

    Returns:
        dict with two float32 tensors:
            ``'past_seq'``:   (n_samples, n_variables, ..., past_window)
            ``'future_seq'``: (n_samples, len(target), ..., future_window)
        where ``...`` stands for the remaining dimensions of ``ds``.
        Both tensors are empty when ``ds`` has fewer than
        ``past_window + future_window`` steps along ``date``.
    """
    # Materialise the arrays once, outside the loop (converting the whole
    # dataset on every iteration is needlessly expensive). The transpose
    # guarantees that ``date`` is the last axis whatever the dimension order
    # of ``ds`` is, so the window slicing below is always along time.
    # NOTE(review): the past window contains every variable of ``ds``
    # (targets included), not only the ``predictors`` list — confirm intent.
    all_values = ds.to_array().transpose(..., 'date').values
    target_values = ds[target].to_array().transpose(..., 'date').values

    n_dates = len(ds['date'])

    past_predictors = list()
    future_targets = list()

    # ``i`` is the first future step. The last valid position is
    # ``n_dates - future_window`` (its future window ends exactly at the end
    # of the data), hence the ``+ 1`` on the exclusive upper bound.
    for i in range(past_window, n_dates - future_window + 1):
        past_predictors.append(all_values[..., i - past_window:i])
        future_targets.append(target_values[..., i:i + future_window])

    past_predictors = np.array(past_predictors)
    future_targets = np.array(future_targets)

    data = {'past_seq': torch.tensor(past_predictors, dtype=torch.float),
            'future_seq': torch.tensor(future_targets, dtype=torch.float)}

    return data


In [None]:
# Slice the selected gauges into the three chronological periods defined in
# the configuration cell, so this cell also runs on a fresh kernel.
train_ds = ds.sel(date=slice(train_start, train_end))
val_ds = ds.sel(date=slice(val_start, val_end))
test_ds = ds.sel(date=slice(test_start, test_end))

# split_ds returns a dict; tuple-unpacking it would yield its keys (strings),
# so the tensors are read out by key instead.
# train
train_data = split_ds(train_ds)
train_past, train_future = train_data['past_seq'], train_data['future_seq']
# validation
val_data = split_ds(val_ds)
val_past, val_future = val_data['past_seq'], val_data['future_seq']
# test
test_data = split_ds(test_ds)
test_past, test_future = test_data['past_seq'], test_data['future_seq']