This notebook creates gauge groups for different types of experiments for a given provider. Gauge groups are stored as text files and run commands are generated for each gauge group.

The experiments currently supported are:
* Gauged
* Ungauged
* Random k-fold cross-validation
* Hydrologically-separated cross-validation

# Imports

In [1]:
import json

In [2]:
import experiment_definition as utils
import uruguay as provider

# Gauge Groups

## Global Gauge Group

In [3]:
# Load the global training gauge group through the helper that, per its name,
# leaves out gauges from the provider's country.
global_gauge_group = utils.get_global_model_training_gauge_group_without_provider_country(
    gauge_group_file_path=provider.GLOBAL_GAUGE_GROUP_FILE,
)
num_global_gauges = len(global_gauge_group)
print(f'There are {num_global_gauges} gauges in the filtered global gauge group.')

There are 15277 gauges in the filtered global gauge group.


## Provider Gauge Group

In [4]:
# Report how many gauges the provider module lists for training and testing.
for group_label, group_gauges in (
    ('train', provider.TRAIN_GAUGES),
    ('test', provider.TEST_GAUGES),
):
  print(f'There are {len(group_gauges)} gauges in the provider {group_label} gauge group.')

There are 19 gauges in the provider train gauge group.
There are 19 gauges in the provider test gauge group.


# Experiment Gauge Groups

In [5]:
# Containers keyed by experiment name; the cells below fill them in with the
# train and test gauge groups of each experiment.
experiment_train_gauge_groups, experiment_test_gauge_groups = {}, {}

## Fully Gauged Run

In [6]:
experiment = 'gauged'
# Build the train/test gauge groups from the global gauges together with the
# provider's own train and test gauges.
gauged_train_gauges, gauged_test_gauges = utils.create_gauged_run_gauge_groups(
    global_gauges=global_gauge_group,
    train_gauges=provider.TRAIN_GAUGES,
    test_gauges=provider.TEST_GAUGES,
)
experiment_train_gauge_groups[experiment] = gauged_train_gauges
experiment_test_gauge_groups[experiment] = gauged_test_gauges
print(f'There are {len(gauged_train_gauges)} gauges in the "{experiment}" train gauge group.')

There are 15296 gauges in the "gauged" train gauge group.


## Fully Ungauged Run

In [7]:
experiment = 'ungauged'
# Only the global gauges and the provider's test gauges are passed in here;
# the provider's train gauges are not supplied to this helper.
ungauged_train_gauges, ungauged_test_gauges = utils.create_ungauged_run_gauge_groups(
    global_gauges=global_gauge_group,
    test_gauges=provider.TEST_GAUGES,
)
experiment_train_gauge_groups[experiment] = ungauged_train_gauges
experiment_test_gauge_groups[experiment] = ungauged_test_gauges
print(f'There are {len(ungauged_train_gauges)} gauges in the "{experiment}" train gauge group.')

There are 15277 gauges in the "ungauged" train gauge group.


## Random Cross Validation

In [8]:
experiment = 'random_cross_validation'
# Build per-fold train/test gauge groups. Each returned object is iterated and
# indexed by fold key below.
(random_cv_train_folds, random_cv_test_folds) = (
    utils.create_random_cross_validation_gauge_groups(
        global_gauges=global_gauge_group,
        train_gauges=provider.TRAIN_GAUGES,
        num_cross_validation_folds=provider.NUM_CROSS_VALIDATION_FOLDS,
    )
)
experiment_train_gauge_groups[experiment] = random_cv_train_folds
experiment_test_gauge_groups[experiment] = random_cv_test_folds
for fold in random_cv_train_folds:
  # Fold keys already start with "fold_" (e.g. "fold_0"), so the label is
  # "<experiment>_<fold>", the same form the hydrologically separated
  # cross-validation cell prints. Adding another "fold_" here produced
  # "..._fold_fold_0".
  num_fold_gauges = len(random_cv_train_folds[fold])
  print(f'There are {num_fold_gauges} gauges in the "{experiment}_{fold}" train gauge group.')

There are 15292 gauges in the "random_cross_validation_fold_fold_0" train gauge group.
There are 15293 gauges in the "random_cross_validation_fold_fold_1" train gauge group.
There are 15293 gauges in the "random_cross_validation_fold_fold_2" train gauge group.
There are 15293 gauges in the "random_cross_validation_fold_fold_3" train gauge group.
There are 15293 gauges in the "random_cross_validation_fold_fold_4" train gauge group.
There are 15293 gauges in the "random_cross_validation_fold_fold_5" train gauge group.


## Hydrologically Separated Cross Validation

In [9]:
experiment = 'hydrologically_separated_cross_validation'
# Build per-fold train/test gauge groups from the provider's graph-based
# cross-validation split mapping.
(hydro_train_folds, hydro_test_folds) = (
    utils.create_hydrography_separated_cross_validation_gauge_groups(
        global_gauges=global_gauge_group,
        train_gauges=provider.TRAIN_GAUGES,
        graph_crossval_split_to_gauge_mapping=provider.GRAPH_CROSSVAL_SPLIT_TO_GAUGE_MAPPING,
    )
)
experiment_train_gauge_groups[experiment] = hydro_train_folds
experiment_test_gauge_groups[experiment] = hydro_test_folds
for fold in hydro_train_folds:
  num_fold_gauges = len(hydro_train_folds[fold])
  print(f'There are {num_fold_gauges} gauges in the "{experiment}_{fold}" train gauge group.')

There are 15294 gauges in the "hydrologically_separated_cross_validation_fold_52" train gauge group.
There are 15294 gauges in the "hydrologically_separated_cross_validation_fold_46_1" train gauge group.
There are 15292 gauges in the "hydrologically_separated_cross_validation_fold_44" train gauge group.
There are 15294 gauges in the "hydrologically_separated_cross_validation_fold_28" train gauge group.
There are 15293 gauges in the "hydrologically_separated_cross_validation_fold_2215" train gauge group.
There are 15294 gauges in the "hydrologically_separated_cross_validation_fold_14" train gauge group.
There are 15295 gauges in the "hydrologically_separated_cross_validation_fold_71_0" train gauge group.
There are 15295 gauges in the "hydrologically_separated_cross_validation_fold_97" train gauge group.
There are 15295 gauges in the "hydrologically_separated_cross_validation_fold_10_1" train gauge group.
There are 15295 gauges in the "hydrologically_separated_cross_validation_fold_176_1

# Save Gauge Groups

In [10]:
# Please don't overwrite the gauge groups that we actually used for the experiments!!!!
# Flag consulted only by the (currently commented-out) save logic below: when
# True, that logic would delete the existing gauge group directory before
# re-creating it.
OVERWRITE_EXISTING_GAUGE_GROUPS = False

# NOTE(review): everything below is commented out, so running this cell only
# defines the flag above and writes nothing. Re-enabling it would also require
# `os` and `gfile` to be imported; the import cells visible at the top of this
# notebook import neither.
# if os.path.exists(provider.CNS_GAUGE_GROUP_PATH) and OVERWRITE_EXISTING_GAUGE_GROUPS:
#   gfile.DeleteRecursively(provider.CNS_GAUGE_GROUP_PATH)

# if not os.path.exists(provider.CNS_GAUGE_GROUP_PATH):
#   gfile.MakeDirs(provider.CNS_GAUGE_GROUP_PATH)

#   # Save train and test gauge groups.
#   utils.save_all_experiment_gauge_groups_as_text_files(
#       experiment_train_gauge_groups=experiment_train_gauge_groups,
#       provider=provider.PREFIX,
#       base_path=provider.CNS_GAUGE_GROUP_PATH,
#   )
#   utils.save_test_gauge_group_as_text_file(
#       test_gauges=provider.TEST_GAUGES,
#       provider=provider.PREFIX,
#       base_path=provider.CNS_GAUGE_GROUP_PATH,
#   )

# Cross Validation Mapping

In [11]:
# Build a mapping from test gauges to model run paths, using the per-experiment
# test gauge groups and the provider's model run directory. Presumably this is
# used later to pull each gauge's results from the matching model run (TODO
# confirm; the original comment here was cut off mid-sentence).
gauge_to_model_path = utils.create_test_gauge_to_model_path_mapping(
    all_test_gauges=provider.TEST_GAUGES,
    experiment_test_gauge_groups=experiment_test_gauge_groups,
    base_path=provider.MODEL_RUN_DIRECTORY,
)

# Write the mapping to the provider's mapping file as JSON.
with open(provider.GAUGE_TO_MODEL_PATH_MAPPING_FILE, 'w') as mapping_file:
  json.dump(obj=gauge_to_model_path, fp=mapping_file)

# Run Commands
This creates run commands for launching the model training and testing pipeline.

This is only relevant internally at Google.

In [12]:
# Launch command template. The {train_gauge_group}, {test_gauge_group} and
# {output_path} placeholders are filled in with str.format.
# The fragments are joined with six spaces, so consecutive options are
# separated by the same wide gaps that show up in the printed commands.
template_run_command = (' ' * 6).join((
    'blaze run --run_under="cd $PWD &&" -c ',
    'opt intelligence/flood_forecasting/hydro_model/training_pipeline:xmanager_launch ',
    '-- --xm_resource_pool=research-training ',
    '--xm_resource_alloc=research-training/karmel-tpu ',
    '--setup_name=lstm_mean_embedding_ungauged_compatible ',
    '--keep_split_margin=True ',
    # This fragment ends in two spaces, not one.
    '--run_inference_on_training_dataset=False  ',
    '--xmanager_config=global_model ',
    '--xm_enable_build_isolation ',
    '--gfs_user=flood-forecasting-dev ',
    '--flags="target_gauge_group={train_gauge_group}, ungauged_gauge_group={test_gauge_group}" ',
    '--output_path={output_path}',
))

def create_run_command(
    experiment: str,
    split: str | None = None,
) -> str:
  """Builds the launch command string for one experiment (and optional split).

  Args:
    experiment: Experiment name, forwarded to `utils.get_gauge_group_name` and
      `utils.model_path_for_gauge_group`.
    split: Optional split identifier forwarded to the same helpers for the
      train gauge group name and the output path. Defaults to None.

  Returns:
    `template_run_command` with its train gauge group, test gauge group and
    output path placeholders filled in.
  """
  # Both gauge group names get this prefix before going into the command.
  group_name_prefix = 'wmo_pilot_only_'

  train_name = utils.get_gauge_group_name(
      experiment=experiment,
      provider=provider.PREFIX,
      split=split,
  )
  # The test gauge group is looked up under a fixed name and without a split.
  test_name = utils.get_gauge_group_name(
      experiment=utils.TEST_GAUGE_GROUP_NAME,
      provider=provider.PREFIX,
  )
  model_output_path = utils.model_path_for_gauge_group(
      experiment=experiment,
      base_model_run_directory=provider.MODEL_RUN_DIRECTORY,
      split=split,
  )

  placeholders = {
      'train_gauge_group': f'{group_name_prefix}{train_name}',
      'test_gauge_group': f'{group_name_prefix}{test_name}',
      'output_path': model_output_path,
  }
  return template_run_command.format(**placeholders)

In [13]:
# Print one quoted run command per experiment, or one per split for
# experiments whose train gauge group is a dict keyed by split.
for experiment, gauges_or_split in experiment_train_gauge_groups.items():
  # No run commands are generated for this experiment.
  if experiment == 'leave_one_out_cross_validation':
    continue
  # isinstance (rather than an exact type comparison) also recognises dict
  # subclasses. Only the split keys are needed; a non-dict entry is handled as
  # a single run with no split, which matches create_run_command's default.
  splits = list(gauges_or_split) if isinstance(gauges_or_split, dict) else [None]
  for split in splits:
    run_command = create_run_command(experiment=experiment, split=split)
    print(f"'{run_command}'")

'blaze run --run_under="cd $PWD &&" -c       opt intelligence/flood_forecasting/hydro_model/training_pipeline:xmanager_launch       -- --xm_resource_pool=research-training       --xm_resource_alloc=research-training/karmel-tpu       --setup_name=lstm_mean_embedding_ungauged_compatible       --keep_split_margin=True       --run_inference_on_training_dataset=False        --xmanager_config=global_model       --xm_enable_build_isolation       --gfs_user=flood-forecasting-dev       --flags="target_gauge_group=wmo_pilot_only_DNAUY_gauged, ungauged_gauge_group=wmo_pilot_only_DNAUY_all_gauges"       --output_path=/home/gsnearing/data/model_runs/gauged'
'blaze run --run_under="cd $PWD &&" -c       opt intelligence/flood_forecasting/hydro_model/training_pipeline:xmanager_launch       -- --xm_resource_pool=research-training       --xm_resource_alloc=research-training/karmel-tpu       --setup_name=lstm_mean_embedding_ungauged_compatible       --keep_split_margin=True       --run_inference_on_train