# Bakaano-Hydro Full Workflow (Colab)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/confidence-duku/bakaano-hydro/blob/main/Bakaano-Hydro%20on%20Google%20Colab.ipynb)

This notebook runs the end-to-end workflow:
1. Input data preprocessing
2. Runoff computation and routing
3. Interactive exploration
4. Training and evaluation
5. Simulation/inference

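Before running, upload your basin shapefile to `MyDrive/<folder_name>/shapes/` and your observed station data (a GRDC NetCDF file, or a CSV folder plus a station lookup CSV) to `MyDrive/<folder_name>/data/` on Google Drive. Observed data is optional for preprocessing and runoff routing, but required for training and evaluation.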

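**Required user inputs before running** (set in the **User Inputs** cell):
- `folder_name` and `shapefile_name`: your workspace folder and basin shapefile (`.shp`); the notebook builds the `study_area` path from them
- choose one observed-data mode via `OBSERVED_DATA_MODE`:
  - `'GRDC'`: set `grdc_filename` (the notebook builds the `grdc_netcdf` path), or
  - `'CSV'`: set `csv_dir_name` and `lookup_csv_name` (the notebook builds the `csv_dir` and `lookup_csv` paths)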


## Set Colab Runtime to GPU

Before running install/training cells:
1. Go to **Runtime -> Change runtime type**.
2. Set **Hardware accelerator** to **GPU**.
3. Click **Save**.

Then run the next GPU check cell. If it reports no GPU, restart runtime and try again.


For this notebook, we use TensorFlow GPU, so the install cell removes `torch` packages from the current runtime before installing Bakaano.

If you need PyTorch later, use a separate Colab runtime/session for PyTorch workloads.


In [None]:
# Install Bakaano-Hydro (with its GPU extra) directly from GitHub.
# torch packages are removed first because this notebook runs TensorFlow on the GPU.
!pip -q uninstall -y torch torchvision torchaudio
!pip -q install "bakaano-hydro[gpu] @ git+https://github.com/confidence-duku/bakaano-hydro.git"
# h5netcdf is the fallback xarray backend tried when opening the GRDC NetCDF file.
!pip -q install h5netcdf


In [None]:
import tensorflow as tf

# List the GPUs TensorFlow can see and stop early when none are attached,
# since the rest of the workflow expects a GPU runtime.
gpus = tf.config.list_physical_devices(device_type='GPU')
print('GPUs:', gpus)
if len(gpus) == 0:
    raise RuntimeError('No GPU detected. In Colab, set Runtime -> Change runtime type -> GPU.')

In [None]:
import google.colab.drive as drive

# Mount Google Drive at /content/drive; later cells read inputs from and write
# persistent outputs to /content/drive/MyDrive/<folder_name>.
drive.mount('/content/drive')

In [None]:
# User Inputs (edit this cell only)
# ----------------------------------------------------------------------------
# 1) Workspace and basin shapefile
folder_name = 'bakaano_workflow'
shapefile_name = 'your_basin.shp'  # must exist in MyDrive/<folder_name>/shapes/

# 2) Observed-data mode (choose one): 'GRDC' or 'CSV'
OBSERVED_DATA_MODE = 'GRDC'

# GRDC mode input (used when OBSERVED_DATA_MODE='GRDC')
grdc_filename = 'your_grdc_data.nc'  # file in MyDrive/<folder_name>/data/

# CSV mode inputs (used when OBSERVED_DATA_MODE='CSV')
csv_dir_name = 'csv_timeseries'      # folder in MyDrive/<folder_name>/data/
lookup_csv_name = 'station_lookup.csv'  # file in MyDrive/<folder_name>/data/

# 3) Global model settings
climate_data_source = 'ERA5'  # ERA5, CHIRPS, CHELSA
routing_method = 'mfd'        # mfd, d8, dinf

# 4) Run toggles
RUN_PREPROCESS = True
RUN_RUNOFF_ROUTING = True
RUN_TRAIN = True
RUN_EVAL = True
RUN_SIM_GRDC = True
RUN_SIM_POINTS = True

# 5) Dates (single shared window for beginner workflow)
# These are reused by preprocessing, runoff routing, interactive exploration,
# AlphaEarth, and station/point simulation. Use ISO 'YYYY-MM-DD' strings.
# Note: AlphaEarth embeddings are only available from 2017 onward.
WORKFLOW_START_DATE = '2001-01-01'
WORKFLOW_END_DATE = '2010-12-31'

# Derived dates: training and evaluation default to the workflow window, because
# climate forcing and routed runoff are only prepared for that window.
# If you override them, keep them inside the workflow window.
TRAIN_START_DATE = WORKFLOW_START_DATE
TRAIN_END_DATE = WORKFLOW_END_DATE

EVAL_START_DATE = WORKFLOW_START_DATE
EVAL_END_DATE = WORKFLOW_END_DATE

# Catch windows that fall outside the prepared data early. ISO date strings
# compare chronologically, so plain string comparison is sufficient here.
for _label, _start, _end in (
    ('TRAIN', TRAIN_START_DATE, TRAIN_END_DATE),
    ('EVAL', EVAL_START_DATE, EVAL_END_DATE),
):
    if not (WORKFLOW_START_DATE <= _start <= _end <= WORKFLOW_END_DATE):
        raise ValueError(
            f'{_label} window {_start}..{_end} must lie inside the workflow window '
            f'{WORKFLOW_START_DATE}..{WORKFLOW_END_DATE}.'
        )




In [None]:
# Optional Advanced Settings (optional; beginners can skip editing this cell)
batch_size = 32
num_epochs = 300
learning_rate = 1e-3
loss_function = 'mse'
area_normalize = True
model_overwrite = True

LR_SCHEDULE = 'cosine'
WARMUP_EPOCHS = 5
MIN_LEARNING_RATE = 1e-5

CSV_ID_COL = 'id'
CSV_LAT_COL = 'latitude'
CSV_LON_COL = 'longitude'
CSV_DATE_COL = 'date'
CSV_DISCHARGE_COL = 'discharge'
CSV_FILE_PATTERN = '{id}.csv'


In [None]:
from pathlib import Path
import shutil

# Refuse to continue if Google Drive is not mounted. Otherwise the mkdir calls below
# would silently create a plain local /content/drive/MyDrive folder, and every
# "Drive" output would be lost when the Colab runtime resets.
drive_root = Path('/content/drive/MyDrive')
if not drive_root.is_dir():
    raise RuntimeError(
        f'Google Drive is not mounted ({drive_root} not found). '
        'Run the drive.mount cell first.'
    )

# Persistent Drive workspace plus a local mirror under /content. The local mirror
# is used for staging climate data, runoff routing and the GRDC copy.
working_dir_drive = drive_root / folder_name
working_dir_local = Path('/content') / folder_name
working_dir = working_dir_drive
study_area = working_dir_drive / 'shapes' / shapefile_name

# Resolve observed-data paths for the selected mode; the other mode's paths stay None.
if OBSERVED_DATA_MODE == 'GRDC':
    grdc_netcdf = working_dir_drive / 'data' / grdc_filename
    csv_dir = None
    lookup_csv = None
elif OBSERVED_DATA_MODE == 'CSV':
    grdc_netcdf = None
    csv_dir = working_dir_drive / 'data' / csv_dir_name
    lookup_csv = working_dir_drive / 'data' / lookup_csv_name
else:
    raise ValueError("OBSERVED_DATA_MODE must be 'GRDC' or 'CSV'.")

# Create the same folder layout on Drive and on local disk
# (parents=True also creates each workspace root).
for _root in (working_dir_drive, working_dir_local):
    for _sub in ('shapes', 'data'):
        (_root / _sub).mkdir(parents=True, exist_ok=True)

print('working_dir (drive):', working_dir_drive)
print('working_dir (local):', working_dir_local)
print('study_area:', study_area)
print('mode:', OBSERVED_DATA_MODE)
print('grdc_netcdf:', grdc_netcdf)
print('csv_dir:', csv_dir)
print('lookup_csv:', lookup_csv)


## Configuration

Beginner run order:
1. Edit the **User Inputs** cell.
2. (Optional) edit **Optional Advanced Settings**.
3. Run setup/validation, then run sections top-to-bottom.

Observed-data mode:
- `OBSERVED_DATA_MODE = 'GRDC'`: set `grdc_filename`
- `OBSERVED_DATA_MODE = 'CSV'`: set `csv_dir_name` and `lookup_csv_name`


In [None]:
if not study_area.exists():
    raise FileNotFoundError(f'study_area not found: {study_area}')

# Exactly one observed-data source must be configured; the setup cell derives
# these paths from OBSERVED_DATA_MODE.
grdc_mode = grdc_netcdf is not None
csv_mode = (csv_dir is not None) and (lookup_csv is not None)

if grdc_mode == csv_mode:
    raise ValueError(
        'Choose exactly one observed-data mode in the User Inputs cell:\n'
        "  1) GRDC mode: OBSERVED_DATA_MODE='GRDC' and set grdc_filename\n"
        "  2) CSV mode: OBSERVED_DATA_MODE='CSV' and set csv_dir_name and lookup_csv_name"
    )

# Local copy of the GRDC file; stays None in CSV mode. Later cells branch on this
# variable to choose between GRDC and CSV observed data.
grdc_netcdf_runtime = None

# Free space to keep in /content on top of the GRDC file size when staging it.
GRDC_STAGING_HEADROOM_BYTES = 200 * 1024 * 1024


def open_grdc_with_fallback(nc_path):
    """Open a NetCDF file with xarray, trying the default backend, then h5netcdf.

    Returns:
        A ``(dataset, backend_label)`` tuple for the first backend that succeeds.

    Raises:
        RuntimeError: if every backend fails; the message lists each backend's error.
    """
    import xarray as xr

    errors = []
    for engine in (None, 'h5netcdf'):
        name = 'netcdf4(default)' if engine is None else engine
        kwargs = {} if engine is None else {'engine': engine}
        try:
            return xr.open_dataset(nc_path, **kwargs), name
        except Exception as e:  # backend failures vary by library; collect all for the report
            errors.append(f'{name}: {e}')
    raise RuntimeError('Unable to open GRDC NetCDF with available backends:\n' + '\n'.join(errors))


if grdc_mode:
    if not grdc_netcdf.exists():
        raise FileNotFoundError(f'grdc_netcdf not found: {grdc_netcdf}')

    src_size = grdc_netcdf.stat().st_size
    if src_size <= 0:
        raise RuntimeError(f'GRDC file is empty: {grdc_netcdf}')

    required_bytes = src_size + GRDC_STAGING_HEADROOM_BYTES
    if shutil.disk_usage('/content').free < required_bytes:
        raise RuntimeError(
            'Not enough free space in /content to stage GRDC NetCDF. '
            f'Need at least {required_bytes} bytes.'
        )

    # Work from a local copy rather than reading the file through the Drive mount.
    grdc_netcdf_runtime = working_dir_local / 'grdc' / grdc_netcdf.name
    grdc_netcdf_runtime.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(grdc_netcdf, grdc_netcdf_runtime)

    dst_size = grdc_netcdf_runtime.stat().st_size
    if dst_size != src_size:
        raise RuntimeError(
            'GRDC file copy size mismatch; copy may be incomplete. '
            f'source={src_size} bytes, copied={dst_size} bytes'
        )

    # Only the open itself is guarded; it is the step that can fail on backends.
    try:
        ds_chk, engine_used = open_grdc_with_fallback(grdc_netcdf_runtime)
    except Exception as e:
        raise RuntimeError(
            f'Failed to open GRDC NetCDF after local copy: {grdc_netcdf_runtime}\n'
            'If this file opens on HPC but not on Colab, install h5netcdf and retry:\n'
            '  !pip -q install h5netcdf\n'
            f'Details: {e}'
        ) from e

    # Dataset.sizes is the stable name -> length mapping. Using Dataset.dims as a
    # mapping (e.g. .keys()) is deprecated in recent xarray.
    with ds_chk:
        grdc_dims = dict(ds_chk.sizes)
    print('GRDC runtime copy ready:', grdc_netcdf_runtime)
    print('GRDC backend:', engine_used)
    print('GRDC dimensions:', grdc_dims)

if csv_mode:
    if not csv_dir.exists():
        raise FileNotFoundError(f'csv_dir not found: {csv_dir}')
    if not lookup_csv.exists():
        raise FileNotFoundError(f'lookup_csv not found: {lookup_csv}')

print('Input validation passed.')
print('Mode:', 'GRDC' if grdc_mode else 'CSV')
print('Mode:', 'GRDC' if grdc_mode else 'CSV')


## 1) Download and preprocess input data

Can run independently: **Yes**, after setup/configuration cells are run.

Required user data:
- `study_area` shapefile in `working_dir_drive/shapes/`

Required configuration:
- `RUN_PREPROCESS=True`
- `climate_data_source` (`ERA5`, `CHIRPS`, or `CHELSA`)
- preprocessing dates: `WORKFLOW_START_DATE` / `WORKFLOW_END_DATE` in the **User Inputs** cell

Recommended order used here:
1. DEM first (creates reference grid `elevation/dem_clipped.tif`)
2. Tree cover
3. NDVI
4. Soil
5. AlphaEarth
6. Meteorology

Dataset availability:
- Tree cover (MODIS VCF): 2001 onward
- NDVI (MODIS 16-day): 2001 onward
- AlphaEarth embeddings: 2017 onward (years before 2017 in your workflow window have no AlphaEarth coverage)

Expected outputs:
- `elevation/`, `vcf/`, `ndvi/`, `soil/`, `alpha_earth/`
- climate NetCDF files under `working_dir_drive/{climate_data_source}/`

Resume behavior:
- Reruns reuse existing outputs and process/download only missing pieces where supported.


In [None]:
import ee

# Earth Engine access is tied to a Google Cloud project. If ee.Initialize() reports
# that no project was found, set EE_PROJECT to your Cloud project ID.
# None keeps the library's default project resolution (same as calling Initialize()).
EE_PROJECT = None

ee.Authenticate(auth_mode="notebook")   # forces link + paste code flow
ee.Initialize(project=EE_PROJECT)

In [None]:
if not RUN_PREPROCESS:
    print('Skipping preprocessing step.')
else:
    from bakaano.dem import DEM
    from bakaano.tree_cover import TreeCover
    from bakaano.ndvi import NDVI
    from bakaano.soil import Soil
    from bakaano.alpha_earth import AlphaEarth
    from bakaano.meteo import Meteo

    # Arguments shared by every downloader; time-varying products also get the workflow window.
    _site_kwargs = {'working_dir': str(working_dir), 'study_area': str(study_area)}
    _dated_kwargs = {
        **_site_kwargs,
        'start_date': WORKFLOW_START_DATE,
        'end_date': WORKFLOW_END_DATE,
    }

    # 1) DEM first: it produces the reference grid used by later products.
    dd = DEM(**_site_kwargs, local_data=False, local_data_path=None)
    dd.get_dem_data()

    # 2) Tree cover and 3) NDVI for the workflow window.
    vf = TreeCover(**_dated_kwargs)
    vf.get_tree_cover_data()

    nd = NDVI(**_dated_kwargs)
    nd.get_ndvi_data()

    # 4) Soil takes no dates.
    sgd = Soil(**_site_kwargs)
    sgd.get_soil_data()

    # 5) AlphaEarth embeddings.
    ae = AlphaEarth(**_dated_kwargs)
    ae.get_alpha_earth()

    # 6) Meteorology is written to the local runtime disk, then mirrored to Drive.
    cd = Meteo(
        working_dir=str(working_dir_local),
        study_area=str(study_area),
        start_date=WORKFLOW_START_DATE,
        end_date=WORKFLOW_END_DATE,
        local_data=False,
        data_source=climate_data_source,
    )
    cd.get_meteo_data()

    climate_local = working_dir_local / climate_data_source
    climate_drive = working_dir_drive / climate_data_source
    if climate_local.exists():
        shutil.copytree(climate_local, climate_drive, dirs_exist_ok=True)
        print(f'Synced climate outputs to Drive: {climate_drive}')

    print('Preprocessing complete.')


### Optional quick plots

Run any of these in a new code cell if needed:
- `vf.plot_tree_cover(variable='tree_cover')`
- `nd.plot_ndvi(interval_num=10)`
- `dd.plot_dem()`
- `sgd.plot_soil(variable='wilting_point')`
- `ae.plot_alpha_earth('A35')`
- `cd.plot_meteo(variable='tasmin', date='2006-12-01')`


## 2) Compute runoff and route to river network

Can run independently: **Yes**, if preprocessing outputs already exist in Drive.

Required user data:
- basin shapefile in `working_dir_drive/shapes/`

Required existing preprocessed inputs in Drive:
- `elevation/`, `soil/`, `ndvi/`, `vcf/`, `alpha_earth/`, and `{climate_data_source}/`

Required configuration:
- `RUN_RUNOFF_ROUTING=True`
- `routing_method` (`mfd`, `d8`, or `dinf`)
- runoff dates: `WORKFLOW_START_DATE` / `WORKFLOW_END_DATE` in the **User Inputs** cell

Execution behavior:
- Inputs are staged to local runtime (`/content/...`) for speed
- VegET runs locally
- `runoff_output/` and `catchment/` are synced back to Drive

Expected outputs:
- `runoff_output/wacc_sparse_arrays.pkl`
- routed runoff outputs under `runoff_output/`

Resume behavior:
- If interrupted, rerun this section; checkpoint files allow continuation.


In [None]:

if RUN_RUNOFF_ROUTING:
    from bakaano.veget import VegET

    # Inputs needed by VegET, staged from Drive to the faster local runtime disk.
    input_folders = ['shapes', 'elevation', 'soil', 'ndvi', 'vcf', 'alpha_earth', climate_data_source]
    # VegET outputs (including checkpoint files) that are persisted back to Drive.
    output_folders = ['runoff_output', 'catchment']

    for folder in input_folders:
        src = working_dir_drive / folder
        dst = working_dir_local / folder
        if src.exists():
            shutil.copytree(src, dst, dirs_exist_ok=True)

    # Restore previously synced outputs/checkpoints from Drive only when the local
    # copy is absent (e.g. after a runtime reset), so a rerun can resume from them.
    # An existing local copy is never overwritten, because it may be newer than Drive.
    for folder in output_folders:
        src = working_dir_drive / folder
        dst = working_dir_local / folder
        if src.exists() and not dst.exists():
            shutil.copytree(src, dst)

    local_study_area = working_dir_local / 'shapes' / shapefile_name

    vg = VegET(
        working_dir=str(working_dir_local),
        study_area=str(local_study_area),
        start_date=WORKFLOW_START_DATE,
        end_date=WORKFLOW_END_DATE,
        climate_data_source=climate_data_source,
        routing_method=routing_method,
    )
    try:
        vg.compute_veget_runoff_route_flow()
    finally:
        # Sync to Drive even if the run fails or is interrupted, so partial outputs
        # and checkpoints survive a runtime reset.
        for folder in output_folders:
            src = working_dir_local / folder
            dst = working_dir_drive / folder
            if src.exists():
                shutil.copytree(src, dst, dirs_exist_ok=True)
    print('Runoff and routing complete.')
else:
    print('Skipping runoff/routing step.')


In [None]:
from bakaano.plot_runoff import RoutedRunoff

# Date to map. It must fall inside the workflow window used for runoff routing.
# ISO 'YYYY-MM-DD' strings compare chronologically, so a string comparison works.
ROUTED_RUNOFF_MAP_DATE = '2003-07-07'
if not (WORKFLOW_START_DATE <= ROUTED_RUNOFF_MAP_DATE <= WORKFLOW_END_DATE):
    print(
        f'Map date {ROUTED_RUNOFF_MAP_DATE} is outside the workflow window '
        f'({WORKFLOW_START_DATE} to {WORKFLOW_END_DATE}); using {WORKFLOW_START_DATE} instead.'
    )
    ROUTED_RUNOFF_MAP_DATE = WORKFLOW_START_DATE

rr = RoutedRunoff(
    working_dir=str(working_dir),
    study_area=str(study_area),
)

rr.map_routed_runoff(date=ROUTED_RUNOFF_MAP_DATE, vmax=7)


## 3) Interactive exploration

Can run independently: **Yes**, if required raster/runoff outputs already exist.

Required user data:
- `study_area` shapefile
- optional GRDC NetCDF (for station overlays)

Required existing outputs:
- `elevation/dem_clipped.tif`
- `soil/`, `vcf/`, `elevation/slope_clipped.tif`
- runoff/catchment outputs from Section 2

Required configuration:
- `grdc_netcdf_runtime` (the local GRDC copy created by the input-validation cell) for GRDC station overlays
- map date range: `WORKFLOW_START_DATE` / `WORKFLOW_END_DATE` in the **User Inputs** cell

Typical checks:
- inspect DEM/tree cover/soil/river network layers
- verify station coverage and missingness
- confirm runoff outputs look reasonable before training


In [None]:
from IPython.display import display
from bakaano.runner import BakaanoHydro

# Runner object reused by the training, evaluation and simulation cells below.
bk = BakaanoHydro(
    working_dir=str(working_dir),
    study_area=str(study_area),
    climate_data_source=climate_data_source,
)

# The interactive explorer needs GRDC stations for its overlays.
has_grdc_copy = grdc_netcdf_runtime is not None and grdc_netcdf_runtime.exists()
if not has_grdc_copy:
    print('Skipping explore_data_interactively (no GRDC NetCDF path set).')
else:
    explore_map = bk.explore_data_interactively(
        start_date=WORKFLOW_START_DATE,
        end_date=WORKFLOW_END_DATE,
        grdc_netcdf=str(grdc_netcdf_runtime),
    )
    display(explore_map)


### Routed runoff timeseries

This section is fully interactive.
It lists available station IDs in your study area and prompts you to enter one.


In [None]:
# Station source: the staged GRDC NetCDF takes precedence, then the CSV lookup table.
_ts_window = {'start_date': WORKFLOW_START_DATE, 'end_date': WORKFLOW_END_DATE}

if grdc_netcdf_runtime is not None and grdc_netcdf_runtime.exists():
    rr.interactive_plot_routed_runoff_timeseries(
        **_ts_window,
        grdc_netcdf=str(grdc_netcdf_runtime),
    )
elif lookup_csv is not None and lookup_csv.exists():
    rr.interactive_plot_routed_runoff_timeseries(
        **_ts_window,
        lookup_csv=str(lookup_csv),
        id_col=CSV_ID_COL,
        lat_col=CSV_LAT_COL,
        lon_col=CSV_LON_COL,
    )
else:
    print('Set grdc_netcdf or lookup_csv before running interactive routed runoff timeseries.')


## 4) Train model

Can run independently: **Yes**, if runoff/catchment outputs and observed data are ready.

Required user data (choose one mode):
- GRDC mode: `grdc_netcdf` file
- CSV mode: `csv_dir` + `lookup_csv`

Required existing outputs:
- Section 2 outputs in `runoff_output/` and `catchment/`
- predictor rasters from preprocessing

Required configuration:
- `RUN_TRAIN=True`
- `batch_size`, `num_epochs`, `learning_rate`, `loss_function`
- `routing_method`, `area_normalize`, and the train date range (`TRAIN_START_DATE` / `TRAIN_END_DATE` in the **User Inputs** cell)
- `model_overwrite`

Training control:
- `model_overwrite=True`: train fresh model
- `model_overwrite=False`: continue from existing model if present

Expected output:
- `models/bakaano_model.keras`

GRDC backend note:
- Notebook stages GRDC NetCDF to local runtime and uses backend fallback (`netcdf4`/`h5netcdf`).


In [None]:
if not RUN_TRAIN:
    print('Skipping training.')
else:
    # Arguments shared by both observed-data modes.
    train_kwargs = dict(
        train_start=TRAIN_START_DATE,
        train_end=TRAIN_END_DATE,
        batch_size=batch_size,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        loss_function=loss_function,
        lr_schedule=LR_SCHEDULE,
        warmup_epochs=WARMUP_EPOCHS,
        min_learning_rate=MIN_LEARNING_RATE,
        routing_method=routing_method,
        area_normalize=area_normalize,
        model_overwrite=model_overwrite,
    )
    if grdc_netcdf_runtime is not None:
        # GRDC mode: use the local copy staged by the input-validation cell.
        # Do not fall through to CSV mode when that copy has gone missing.
        if not grdc_netcdf_runtime.exists():
            raise FileNotFoundError(
                f'Staged GRDC file missing: {grdc_netcdf_runtime}. '
                'Re-run the input validation cell.'
            )
        bk.train_streamflow_model(grdc_netcdf=str(grdc_netcdf_runtime), **train_kwargs)
    else:
        # CSV mode. Check first: str(None) would silently pass the literal path 'None'.
        if csv_dir is None or lookup_csv is None:
            raise ValueError(
                'No observed data configured: set OBSERVED_DATA_MODE and re-run the '
                'setup and input validation cells.'
            )
        bk.train_streamflow_model(
            grdc_netcdf=None,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col=CSV_ID_COL,
            lat_col=CSV_LAT_COL,
            lon_col=CSV_LON_COL,
            date_col=CSV_DATE_COL,
            discharge_col=CSV_DISCHARGE_COL,
            file_pattern=CSV_FILE_PATTERN,
            **train_kwargs,
        )


## 5) Evaluate model

Can run independently: **Yes**, if a trained model file exists.

Required user data (choose one mode):
- GRDC mode: `grdc_netcdf` file
- CSV mode: `csv_dir` + `lookup_csv`

Required existing outputs:
- `models/bakaano_model.keras`
- runoff/catchment/predictor data used by the model

Required configuration:
- `RUN_EVAL=True`
- validation period: `EVAL_START_DATE` / `EVAL_END_DATE` in the **User Inputs** cell (passed as `val_start` / `val_end`)
- `routing_method`, `area_normalize`

Expected outputs:
- evaluation plots and metrics for observed vs predicted flow


In [None]:
model_path = working_dir / 'models' / 'bakaano_model.keras'
print('model_path exists:', model_path.exists())

if not RUN_EVAL:
    print('Skipping evaluation.')
elif not model_path.exists():
    print(f'Skipping evaluation: no trained model at {model_path}.')
else:
    # Arguments shared by both observed-data modes.
    eval_kwargs = dict(
        model_path=str(model_path),
        val_start=EVAL_START_DATE,
        val_end=EVAL_END_DATE,
        routing_method=routing_method,
        area_normalize=area_normalize,
    )
    if grdc_netcdf_runtime is not None:
        # GRDC mode: use the local copy staged by the input-validation cell.
        if not grdc_netcdf_runtime.exists():
            raise FileNotFoundError(
                f'Staged GRDC file missing: {grdc_netcdf_runtime}. '
                'Re-run the input validation cell.'
            )
        bk.evaluate_streamflow_model_interactively(
            grdc_netcdf=str(grdc_netcdf_runtime), **eval_kwargs
        )
    else:
        # CSV mode. Check first: str(None) would silently pass the literal path 'None'.
        if csv_dir is None or lookup_csv is None:
            raise ValueError(
                'No observed data configured: set OBSERVED_DATA_MODE and re-run the '
                'setup and input validation cells.'
            )
        bk.evaluate_streamflow_model_interactively(
            grdc_netcdf=None,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col=CSV_ID_COL,
            lat_col=CSV_LAT_COL,
            lon_col=CSV_LON_COL,
            date_col=CSV_DATE_COL,
            discharge_col=CSV_DISCHARGE_COL,
            file_pattern=CSV_FILE_PATTERN,
            **eval_kwargs,
        )


## 6) Simulate streamflow

Can run independently: **Yes**, if a trained model file exists.

Required user data:
- for station simulation: GRDC NetCDF or CSV station data
- for point simulation: `SIM_POINT_LATLIST` / `SIM_POINT_LONLIST` in the point-simulation code cell

Required existing outputs:
- `models/bakaano_model.keras`
- runoff/catchment/predictor data

Required configuration:
- `RUN_SIM_GRDC` and/or `RUN_SIM_POINTS`
- simulation period: `WORKFLOW_START_DATE` / `WORKFLOW_END_DATE` in the **User Inputs** cell (passed as `sim_start` / `sim_end`)
- `routing_method`, `area_normalize`

Expected outputs:
- simulated discharge series and prediction artifacts under `predicted_streamflow_data/`


In [None]:
# Defined here as well so this section can run without the evaluation cell.
model_path = working_dir / 'models' / 'bakaano_model.keras'

if not RUN_SIM_GRDC:
    print('Skipping station simulation.')
elif not model_path.exists():
    print(f'Skipping station simulation: no trained model at {model_path}.')
else:
    # Arguments shared by both observed-data modes.
    sim_kwargs = dict(
        model_path=str(model_path),
        sim_start=WORKFLOW_START_DATE,
        sim_end=WORKFLOW_END_DATE,
        routing_method=routing_method,
        area_normalize=area_normalize,
    )
    if grdc_netcdf_runtime is not None:
        # GRDC mode: use the local copy staged by the input-validation cell.
        if not grdc_netcdf_runtime.exists():
            raise FileNotFoundError(
                f'Staged GRDC file missing: {grdc_netcdf_runtime}. '
                'Re-run the input validation cell.'
            )
        bk.simulate_grdc_csv_stations(grdc_netcdf=str(grdc_netcdf_runtime), **sim_kwargs)
    else:
        # CSV mode. Check first: str(None) would silently pass the literal path 'None'.
        if csv_dir is None or lookup_csv is None:
            raise ValueError(
                'No observed data configured: set OBSERVED_DATA_MODE and re-run the '
                'setup and input validation cells.'
            )
        bk.simulate_grdc_csv_stations(
            grdc_netcdf=None,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col=CSV_ID_COL,
            lat_col=CSV_LAT_COL,
            lon_col=CSV_LON_COL,
            date_col=CSV_DATE_COL,
            discharge_col=CSV_DISCHARGE_COL,
            file_pattern=CSV_FILE_PATTERN,
            **sim_kwargs,
        )


In [None]:
SIM_POINT_LATLIST = [13.8, 13.9]  # user-provided latitudes for point simulation
SIM_POINT_LONLIST = [3.0, 4.0]    # user-provided longitudes, paired by position with the latitudes

# Defined here as well so this section can run without the evaluation cell.
model_path = working_dir / 'models' / 'bakaano_model.keras'

if not RUN_SIM_POINTS:
    print('Skipping point simulation.')
elif not model_path.exists():
    print(f'Skipping point simulation: no trained model at {model_path}.')
else:
    if len(SIM_POINT_LATLIST) != len(SIM_POINT_LONLIST):
        raise ValueError(
            'SIM_POINT_LATLIST and SIM_POINT_LONLIST must have the same length '
            f'(got {len(SIM_POINT_LATLIST)} and {len(SIM_POINT_LONLIST)}).'
        )
    bk.simulate_streamflow(
        model_path=str(model_path),
        sim_start=WORKFLOW_START_DATE,
        sim_end=WORKFLOW_END_DATE,
        latlist=SIM_POINT_LATLIST,
        lonlist=SIM_POINT_LONLIST,
        routing_method=routing_method,
        area_normalize=area_normalize,
    )


In [None]:
import glob

import pandas as pd
from IPython.display import display  # explicit, so this cell does not rely on an earlier one

# List the CSV outputs in predicted_streamflow_data/ and preview the first one.
pred_dir = working_dir / 'predicted_streamflow_data'
pred_files = sorted(glob.glob(str(pred_dir / '*.csv')))
print('Prediction files:', len(pred_files))
if pred_files:
    print('Example:', pred_files[0])
    df = pd.read_csv(pred_files[0])
    display(df.head())
else:
    print(f'No prediction CSVs found in {pred_dir}; run a simulation cell first.')