# Create Cross Validation Configuration Files

This notebook creates files containing per-basin time splits for a set of temporal cross-validation model runs. Each cross-validation split uses one calendar year (Jan 1 through Dec 31) for testing. The remaining years of data are used for training, except that one year of data on either side of the test year is withheld from both training and testing. The model uses 365 days of data for spinup, so a 1-year buffer between the train and test periods is needed to ensure that no input data is used for both training and testing. The dates in the train/test splits listed below refer to dates of target (streamflow) data, not dates of input data.

Data are available from 1980 through 2023, inclusive. However, ECMWF HRES data and GraphCast data are only available from 2016 through 2023. These are the two main forecast data sources, so testing on any period outside of this range would not accurately represent model performance in a real-world setting. We can train on data prior to 2016, but we will not test on data prior to 2016.

The train/test cross validation splits are:
1) Test: 2016; Train: 1980-2014 & 2018-2023
2) Test: 2017; Train: 1980-2015 & 2019-2023
3) Test: 2018; Train: 1980-2016 & 2020-2023
4) Test: 2019; Train: 1980-2017 & 2021-2023
5) Test: 2020; Train: 1980-2018 & 2022-2023
6) Test: 2021; Train: 1980-2019 & 2023
7) Test: 2022; Train: 1980-2020
8) Test: 2023; Train: 1980-2021
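
The split table above follows a simple rule: drop the test year plus one buffer year on each side, then discard any range that falls outside the data record. A minimal sketch of that rule is shown here. The helper name `train_year_ranges` and its parameters are illustrative only and are not part of this notebook or of `googlehydrology`.

```python
def train_year_ranges(test_year, first_year=1980, last_year=2023, buffer_years=1):
    """Return inclusive (start, end) year ranges that remain for training.

    Hypothetical helper: removes the test year plus `buffer_years` on each
    side, and drops ranges that fall outside the data record.
    """
    before = (first_year, test_year - buffer_years - 1)
    after = (test_year + buffer_years + 1, last_year)
    # A range is only valid if it starts no later than it ends.
    return [(start, end) for start, end in (before, after) if start <= end]


for year in (2016, 2021, 2023):
    print(year, train_year_ranges(year))
```

For the last two test years, the range after the test year starts beyond 2023 and is dropped, which is why splits 7 and 8 have a single training period.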

## Notebook Setup

In [None]:
from datetime import datetime
import itertools
import os
import pandas as pd
from pathlib import Path
import re
import shutil
from tqdm.notebook import tqdm

In [2]:
from googlehydrology.utils.config import Config
from googlehydrology.datautils.utils import load_basin_file
from googlehydrology.datasetzoo.caravan import load_caravan_attributes

## User-Defined Variables

* `BASE_CONFIG_PATH`: Path to a config file with all model, data, and training settings except basin lists and time periods.
* `EXPERIMENT_NAME`: Name of experiment to use for config files, etc.
* `BASIN_LIST_DIRECTORY`: Directory where basin lists for this cross-validation experiment are stored. These basin list files must be generated separately.
* `CONFIG_FILE_DIRECTORY`: Directory where this notebook writes the config files for this cross-validation experiment. This is the notebook's output directory, and any existing contents are deleted.
* `TEST_YEARS`: List of years to use as individual test periods (2016-2023 for MultiMet).
* `FIRST_DATE`: First date of valid data in the dataset, in `DD/MM/YYYY` format (MultiMet starts in 1980).
* `LAST_DATE`: Last date of valid data in the dataset, in `DD/MM/YYYY` format (MultiMet ends in 2023).

In [3]:
BASE_CONFIG_PATH = '/home/gsnearing/ecmwf/configs/template_config.yml'
EXPERIMENT_NAME = 'gauged-caravan'
CONFIG_FILE_DIRECTORY = f'/home/gsnearing/ecmwf/configs/{EXPERIMENT_NAME}/'
BASIN_LIST_DIRECTORY = f'/home/gsnearing/ecmwf/basin_lists/{EXPERIMENT_NAME}/'

TEST_YEARS = list(range(2016, 2024))
FIRST_DATE = '01/01/1980'
LAST_DATE = '31/12/2023'

In [4]:
# Turn the user-supplied path strings into pathlib.Path objects.
base_config_path = Path(BASE_CONFIG_PATH)
basin_list_dir = Path(BASIN_LIST_DIRECTORY)
config_file_dir = Path(CONFIG_FILE_DIRECTORY)

## Construct Time Splits
Constructs time splits where each year in `TEST_YEARS` is withheld for testing exactly once. Training periods that fall outside the available date range are dropped.

In [5]:
date_format = '%d/%m/%Y'
splits = {
    year: 
    {
        'train': [[FIRST_DATE, f'31/12/{year-2}'], [f'01/01/{year+2}', LAST_DATE]],
        'test':  [[f'01/01/{year}', f'31/12/{year}']],
    } for year in TEST_YEARS
}

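def _compare_dates_gt(datestr1: str, datestr2: str) -> bool:
    """Returns True if `datestr1` is strictly later than `datestr2` (both DD/MM/YYYY)."""
    return pd.to_datetime(datestr1, format=date_format) > pd.to_datetime(datestr2, format=date_format)

# Remove periods whose start date is later than their end date, i.e., periods
# that fall outside of the available date range.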
def _filter_date_tuples(data: dict) -> dict:
    """
    Filters a nested dictionary to remove date tuples where the first date is later than the second.
    """
    filtered_data = {}
    for outer_key, inner_dict in data.items():
        filtered_inner_dict = {}
        for inner_key, date_tuples in inner_dict.items():
            # Use a list comprehension to keep only valid tuples
            filtered_tuples = [
                (d1, d2) for d1, d2 in date_tuples if not _compare_dates_gt(d1, d2)
            ]
            filtered_inner_dict[inner_key] = filtered_tuples
        filtered_data[outer_key] = filtered_inner_dict
    return filtered_data               

In [6]:
time_splits = _filter_date_tuples(splits)
time_splits

{2016: {'train': [('01/01/1980', '31/12/2014'), ('01/01/2018', '31/12/2023')],
  'test': [('01/01/2016', '31/12/2016')]},
 2017: {'train': [('01/01/1980', '31/12/2015'), ('01/01/2019', '31/12/2023')],
  'test': [('01/01/2017', '31/12/2017')]},
 2018: {'train': [('01/01/1980', '31/12/2016'), ('01/01/2020', '31/12/2023')],
  'test': [('01/01/2018', '31/12/2018')]},
 2019: {'train': [('01/01/1980', '31/12/2017'), ('01/01/2021', '31/12/2023')],
  'test': [('01/01/2019', '31/12/2019')]},
 2020: {'train': [('01/01/1980', '31/12/2018'), ('01/01/2022', '31/12/2023')],
  'test': [('01/01/2020', '31/12/2020')]},
 2021: {'train': [('01/01/1980', '31/12/2019'), ('01/01/2023', '31/12/2023')],
  'test': [('01/01/2021', '31/12/2021')]},
 2022: {'train': [('01/01/1980', '31/12/2020')],
  'test': [('01/01/2022', '31/12/2022')]},
 2023: {'train': [('01/01/1980', '31/12/2021')],
  'test': [('01/01/2023', '31/12/2023')]}}
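
As a sanity check on the buffer, the sketch below measures how many full days are withheld between each training period and the test period of a split. It assumes the 365-day spinup stated in the introduction. The helper `min_gap_days` is hypothetical, and the example dictionary simply copies the 2016 entry from the output above.

```python
import pandas as pd


def min_gap_days(split):
    """Smallest number of withheld days between any train period and the test period."""
    def to_date(datestr):
        return pd.to_datetime(datestr, format='%d/%m/%Y')

    test_start, test_end = (to_date(d) for d in split['test'][0])
    gaps = []
    for start, end in split['train']:
        start, end = to_date(start), to_date(end)
        if end < test_start:
            # Days strictly between the last train day and the first test day.
            gaps.append((test_start - end).days - 1)
        elif start > test_end:
            # Days strictly between the last test day and the first train day.
            gaps.append((start - test_end).days - 1)
        else:
            raise ValueError('Train and test periods overlap.')
    return min(gaps)


example_split = {
    'train': [('01/01/1980', '31/12/2014'), ('01/01/2018', '31/12/2023')],
    'test': [('01/01/2016', '31/12/2016')],
}
assert min_gap_days(example_split) >= 365
```

In the notebook, the same check could be applied to every entry of `time_splits.values()`.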

## Locate Basin List Files
Creates a data structure containing the train/test file pairs in a given basin list directory.

In [7]:
def _organize_basin_files(basin_list_dir: Path) -> dict:
    """
    Searches a directory for '<prefix>_train.txt' and '<prefix>_test.txt' files
    (the prefix is optional) and organizes them into a nested dictionary.

    Args:
        basin_list_dir: A pathlib.Path object representing the directory to search.

    Returns:
        A dictionary that maps each prefix (None for files without a prefix) to a
        dictionary that maps 'train' and 'test' to the corresponding file paths.
    """
    organized_files = {}

    # Capture the optional prefix and the file type (train/test).
    pattern = re.compile(r'(?:(.+)_)?(train|test)\.txt')

    for file_path in basin_list_dir.iterdir():
        if file_path.is_file() and file_path.suffix == '.txt':
            match = pattern.match(file_path.name)
            if match:
                prefix = match.group(1)     # e.g., 'basin_a', 'another_basin'
                file_type = match.group(2)  # 'train' or 'test'

                if prefix not in organized_files:
                    organized_files[prefix] = {}
                
                # Store the Path object
                organized_files[prefix][file_type] = file_path
                
    return organized_files

In [8]:
basin_lists = _organize_basin_files(basin_list_dir)
basin_lists

{None: {'test': PosixPath('/home/gsnearing/ecmwf/basin_lists/gauged-caravan/test.txt'),
  'train': PosixPath('/home/gsnearing/ecmwf/basin_lists/gauged-caravan/train.txt')}}
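
The `None` key in the output above comes from the optional prefix group in the regular expression: files named plain `train.txt` and `test.txt` have no prefix to capture. The short sketch below illustrates this with a few made-up file names. The helper `split_basin_file_name` and the file names `fold_1_train.txt` and `notes.txt` are hypothetical.

```python
import re

# Same pattern as in `_organize_basin_files`.
pattern = re.compile(r'(?:(.+)_)?(train|test)\.txt')


def split_basin_file_name(name):
    """Return (prefix, file_type) for a basin list file name, or None if it does not match."""
    match = pattern.match(name)
    return match.groups() if match else None


for name in ['train.txt', 'test.txt', 'fold_1_train.txt', 'notes.txt']:
    print(name, '->', split_basin_file_name(name))
```

Each distinct prefix becomes one basin split, so a directory with several prefixed train/test pairs would yield several basin splits.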

## Create Configs
Creates separate config files for each run in the cross-product of basin and time splits.

In [9]:
def delete_all_contents(directory_path: Path):
    """
    Deletes all files and subdirectories within the specified directory.
    """
    if not directory_path.is_dir():
        print(f"Error: {directory_path} is not a directory.")
        return

    for item in directory_path.iterdir():
        if item.is_dir():
            shutil.rmtree(item)
        else:
            item.unlink()

In [None]:
def create_cross_validation_config_files(
    base_config_path: Path,
    output_dir: Path,
    basin_lists: dict | None = None,
    time_splits: dict | None = None,
) -> list[Path]:
    """Create configs for spatiotemporal cross-validation splits.

    Optionally supply a set of basin lists and/or a set of time splits, and this
    function will create config files that implement a cross product of the basin (space)
    and time splits. These config files are stored in `output_dir` and the
    `run_scheduler.py` script in this library can be used to run the whole set.

    Parameters
    ----------
    base_config_path : Path
        Path to a base config file (.yml)
    output_dir : Path
        Path to a folder where the generated configs will be stored. Any existing
        contents of this folder are deleted.
    basin_lists : dict | None
        Maps a basin split ID to a dictionary with 'train' and 'test' (and optionally
        'validation') basin file paths.
    time_splits : dict | None
        Maps a time split ID to a dictionary with 'train' and 'test' (and optionally
        'validation') lists of (start date, end date) pairs.

    Returns
    -------
    list[Path]
        Paths of all generated config files.
    """
    # Validate the arguments before touching the output directory.
    if basin_lists is None and time_splits is None:
        raise ValueError('Must supply either a time split or a basin split.')

    # Since the output directory is intended to represent one cross-validation experiment,
    # delete any configs that already exist.
    if output_dir.is_dir():
        delete_all_contents(output_dir)
    else:
        output_dir.mkdir(parents=True)

    if basin_lists is None:
        basin_lists = {'dummy': []}

    if time_splits is None:
        time_splits = {'dummy': []}
        
    # load base config as dictionary
    base_config = Config(base_config_path)

    # keep a list of all configs generated
    all_configs = []

    # iterate over each possible combination of basin list and time split
    for basin_split in basin_lists:
        
        # update basin list files
        if basin_split != 'dummy':
            
            # Copy so that the caller's dictionary is not modified by the pops below.
            update_config = dict(basin_lists[basin_split])
            
            if 'train' in update_config:
                update_config['train_basin_file'] = update_config.pop('train')
            if 'validation' in update_config:            
                update_config['validation_basin_file'] = update_config.pop('validation')
            if 'test' in update_config:
                update_config['test_basin_file'] = update_config.pop('test')
            if any([bf not in update_config for bf in ['train_basin_file', 'test_basin_file']]):
                raise ValueError('Basin list inner dictionary must contain files for train and test. Validation is optional.')
            
            base_config.update_config(update_config)

        for time_split in time_splits:

            # update experiment name with split IDs
            split_name = f'basin_split_{basin_split}_time_split_{time_split}'
            name = f'{EXPERIMENT_NAME}_{split_name}'
            base_config.update_config({"experiment_name": name})

            # update period dates
            if time_split != 'dummy':
                
                # Copy so that the caller's dictionary is not modified by the pops below.
                update_config = dict(time_splits[time_split])
                
                if 'train' in update_config:
                    update_config['train_start_date'] = [dates[0] for dates in update_config['train']]
                    update_config['train_end_date'] = [dates[1] for dates in update_config.pop('train')]
                if 'test' in update_config:
                    update_config['test_start_date'] = [dates[0] for dates in update_config['test']]
                    update_config['test_end_date'] = [dates[1] for dates in update_config.pop('test')]                 
                if any([tp not in update_config for tp in ['train_start_date', 'train_end_date', 'test_start_date', 'test_end_date']]):
                    raise ValueError('Time split inner dictionary must contain start and end dates for train and test. Validation is optional.')
      
                if 'validation' in update_config:            
                    update_config['validation_start_date'] = [dates[0] for dates in update_config['validation']]
                    update_config['validation_end_date'] = [dates[1] for dates in update_config.pop('validation')]
                elif 'validation_start_date' not in update_config:
                    update_config['validation_start_date'] = update_config['test_start_date']
                    update_config['validation_end_date'] = update_config['test_end_date']
                
                base_config.update_config(update_config)

            # write new config to output directory
            base_config.dump_config(output_dir, f"{name}.yml")
            all_configs.append(output_dir/f"{name}.yml")
        
    print(f"Finished. Configs are stored in {output_dir}")

    return all_configs

In [11]:
all_configs = create_cross_validation_config_files(
    base_config_path=base_config_path, 
    output_dir=config_file_dir,
    basin_lists=basin_lists,
    time_splits=time_splits
)
print(f'{len(all_configs)} configs were generated.')

8 configs were generated.
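
The count of 8 follows from the cross product: one basin split (the `None` prefix) times eight time splits. The sketch below reproduces only the naming scheme used by the function above, without creating any files. The helper `split_names` is hypothetical.

```python
import itertools


def split_names(experiment_name, basin_split_ids, time_split_ids):
    """Experiment names for every (basin split, time split) pair."""
    return [
        f'{experiment_name}_basin_split_{basin_split}_time_split_{time_split}'
        for basin_split, time_split in itertools.product(basin_split_ids, time_split_ids)
    ]


names = split_names('gauged-caravan', [None], range(2016, 2024))
print(len(names))
print(names[0])
```

With more than one basin list prefix, the number of configs would grow to the number of basin splits times the number of time splits.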


## Remove Zero-Variance Static Features
Static features that are constant across all basins in the train set will have zero-variance scalers and cause NaNs in training. It is therefore necessary to remove them.
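
To see why a constant attribute is a problem, the sketch below applies ordinary z-score standardization to a toy attribute table. It assumes the scaler divides by the standard deviation of each attribute, which is not shown in this notebook. The basin IDs and attribute names are made up.

```python
import pandas as pd

# Toy attribute table: `constant_attr` has the same value for every basin.
attributes = pd.DataFrame(
    {'area': [10.0, 20.0, 30.0], 'constant_attr': [5.0, 5.0, 5.0]},
    index=['basin_a', 'basin_b', 'basin_c'],
)

# Z-score standardization: a zero standard deviation gives 0/0 = NaN.
standardized = (attributes - attributes.mean()) / attributes.std()
print(standardized)

# The same variance test used below to flag unsafe attributes.
zero_variance = attributes.columns[attributes.var() == 0].tolist()
print(zero_variance)
```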

In [12]:
# Preload all attributes.
cfg = Config(all_configs[0])
attributes_df = load_caravan_attributes(data_dir=Path(cfg.statics_data_dir))[cfg.static_attributes].to_dataframe()

In [13]:
# Find any attribute that is unsafe in any split.
unsafe_attributes = []
for config_path in tqdm(all_configs):
    cfg = Config(config_path)
    basins = load_basin_file(cfg.train_basin_file)
    basin_attributes_df = attributes_df.loc[basins]
    variances = basin_attributes_df.var()    
    unsafe_attributes.extend(variances[variances == 0].index.tolist())


In [14]:
# Combine unsafe attributes and create list of safe attributes.
unsafe_attributes = set(unsafe_attributes)
safe_attributes = set(attributes_df.columns) - unsafe_attributes
print(f'There are {len(cfg.static_attributes)} attributes total.')
print(f'There are {len(safe_attributes)} attributes with non-zero variance.')
print('Attributes with zero variance in at least one split:', unsafe_attributes)

There are 86 attributes total.
There are 85 attributes with non-zero variance.
Attributes with zero variance in at least one split: {'glc_pc_s05'}


In [15]:
# Overwrite config files with new static attributes list.
for config_path in tqdm(all_configs):
    cfg = Config(config_path)
    config_path.unlink()
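    # Write a list (not a set) so that the attribute order is deterministic
    # and matches the order in the original config.
    static_attributes = [attr for attr in cfg.static_attributes if attr in safe_attributes]
    cfg.update_config({'static_attributes': static_attributes})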
    cfg.dump_config(folder=config_path.parent, filename=config_path.name)
