# Bakaano-Hydro Full Workflow (Colab)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/confidence-duku/bakaano-hydro/blob/main/Bakaano-Hydro%20on%20Google%20Colab.ipynb)

This notebook runs the end-to-end workflow:
1. Input data preprocessing
2. Runoff computation and routing
3. Interactive exploration
4. Training
5. Evaluation
6. Simulation/inference

Before running, upload your basin shapefile and your observed station data (a GRDC NetCDF file, or a folder of station CSV files plus a lookup CSV) to Google Drive.

**Required user inputs before running:**
- `study_area`: path to your basin shapefile (`.shp`)
- choose one observed-data mode:
  - `grdc_netcdf` path, or
  - `csv_dir` + `lookup_csv` paths


## Set Colab Runtime to GPU

Before running install/training cells:
1. Go to **Runtime -> Change runtime type**.
2. Set **Hardware accelerator** to **GPU**.
3. Click **Save**.

Then run the install cell and the GPU check cell below. If the check reports no GPU, restart the runtime and try again.


For this notebook, we use TensorFlow GPU, so the install cell removes `torch` packages from the current runtime before installing Bakaano.

If you need PyTorch later, use a separate Colab runtime/session for PyTorch workloads.


In [None]:
!pip -q uninstall -y torch torchvision torchaudio
!pip -q install "bakaano-hydro[gpu] @ git+https://github.com/confidence-duku/bakaano-hydro.git"


In [None]:
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
print('GPUs:', gpus)
if not gpus:
    raise RuntimeError('No GPU detected. In Colab, set Runtime -> Change runtime type -> GPU.')

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from pathlib import Path
import shutil

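shapefile_name = ' '  # <- replace with your shapefile filename (.shp) in shapes/
grdc_filename = ' '   # <- replace with your GRDC NetCDF filename in data/ (GRDC mode)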

working_dir_drive = Path('/content/drive/MyDrive/bakaano_workflow')
working_dir_local = Path('/content/bakaano_workflow')
working_dir = working_dir_drive
study_area = working_dir_drive / 'shapes' / shapefile_name

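# GRDC mode (default). For CSV mode, set grdc_netcdf = None and point
# csv_dir and lookup_csv at Path(...) locations on Drive.
grdc_netcdf = working_dir_drive / 'data' / grdc_filename
csv_dir = None
lookup_csv = None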

climate_data_source = 'ERA5'
routing_method = 'mfd'

batch_size = 32
num_epochs = 300
learning_rate = 1e-3
loss_function = 'mse'
area_normalize = True
model_overwrite = True

RUN_PREPROCESS = True
RUN_RUNOFF_ROUTING = True
RUN_TRAIN = True
RUN_EVAL = True
RUN_SIM_GRDC = True
RUN_SIM_POINTS = True

working_dir_drive.mkdir(parents=True, exist_ok=True)
(working_dir_drive / 'shapes').mkdir(parents=True, exist_ok=True)
(working_dir_drive / 'data').mkdir(parents=True, exist_ok=True)
working_dir_local.mkdir(parents=True, exist_ok=True)
(working_dir_local / 'shapes').mkdir(parents=True, exist_ok=True)
(working_dir_local / 'data').mkdir(parents=True, exist_ok=True)
print('working_dir (drive):', working_dir_drive)
print('working_dir (local):', working_dir_local)
print('study_area exists:', study_area.exists())
print('grdc_netcdf exists:', grdc_netcdf.exists() if grdc_netcdf is not None else False)


## Configuration

Set values in the configuration code cell above:
- `shapefile_name`: shapefile filename in `.../bakaano_workflow/shapes/`
- `grdc_filename`: GRDC NetCDF filename in `.../bakaano_workflow/data/` (GRDC mode)
- `csv_dir` + `lookup_csv` (CSV mode)
- run toggles: `RUN_PREPROCESS`, `RUN_RUNOFF_ROUTING`, `RUN_TRAIN`, `RUN_EVAL`, `RUN_SIM_GRDC`, `RUN_SIM_POINTS`

## Observed Data Mode (Required)

Choose exactly one mode before running heavy steps:
1. **GRDC mode**: set `grdc_netcdf=Path(...)`, keep `csv_dir=None`, `lookup_csv=None`.
2. **CSV mode**: set `grdc_netcdf=None`, and set both `csv_dir=Path(...)` and `lookup_csv=Path(...)`.
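
As a minimal sketch of the file layout CSV mode appears to expect, the snippet below writes a tiny lookup table and one per-station file to a temporary folder. The column names (`id`, `latitude`, `longitude`, `date`, `discharge`) and the `{id}.csv` file pattern are taken from the training call later in this notebook. The station ID, coordinates, date format, and discharge values are made up for illustration, and the temporary folder stands in for your Drive paths.

```python
import csv
import tempfile
from pathlib import Path

# Temporary folder used only for illustration; point these at Drive in practice.
demo_root = Path(tempfile.mkdtemp())
demo_csv_dir = demo_root / 'stations'
demo_csv_dir.mkdir()
demo_lookup_csv = demo_root / 'lookup.csv'

# Lookup table: one row per station (hypothetical station ID and coordinates).
with demo_lookup_csv.open('w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['id', 'latitude', 'longitude'])
    writer.writerow(['STATION_A', 13.8, 3.0])

# Per-station series named after file_pattern='{id}.csv' (made-up values).
with (demo_csv_dir / 'STATION_A.csv').open('w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['date', 'discharge'])
    writer.writerow(['2001-01-01', 12.5])
    writer.writerow(['2001-01-02', 11.9])

# Check that every station in the lookup table has a matching series file.
with demo_lookup_csv.open(newline='') as f:
    station_ids_demo = [row['id'] for row in csv.DictReader(f)]
missing_files = [s for s in station_ids_demo if not (demo_csv_dir / f'{s}.csv').exists()]
print('stations:', station_ids_demo, 'missing files:', missing_files)
```

Running the same lookup-versus-files check on your real `csv_dir` and `lookup_csv` is a cheap way to catch misnamed station files before training.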


In [None]:
if not study_area.exists():
    raise FileNotFoundError(f'study_area not found: {study_area}')

grdc_mode = grdc_netcdf is not None
csv_mode = (csv_dir is not None) and (lookup_csv is not None)

if grdc_mode == csv_mode:
    raise ValueError(
        'Choose exactly one observed-data mode:\n'
        "  1) GRDC mode: set grdc_netcdf=Path(...), csv_dir=None, lookup_csv=None\n"
        "  2) CSV mode: set grdc_netcdf=None, csv_dir=Path(...), lookup_csv=Path(...)"
    )

if grdc_mode and not grdc_netcdf.exists():
    raise FileNotFoundError(f'grdc_netcdf not found: {grdc_netcdf}')

if csv_mode:
    if not csv_dir.exists():
        raise FileNotFoundError(f'csv_dir not found: {csv_dir}')
    if not lookup_csv.exists():
        raise FileNotFoundError(f'lookup_csv not found: {lookup_csv}')

print('Input validation passed.')
print('Mode:', 'GRDC' if grdc_mode else 'CSV')


## 1) Download and preprocess input data

This stage prepares all static and climate inputs under `working_dir`. Run this once per basin (or when basin/resolution/time range changes).

Recommended order used here:
1. DEM first (creates the reference grid at `elevation/dem_clipped.tif`)
2. Tree cover
3. NDVI
4. Soil
5. AlphaEarth
6. Meteorology

Dataset availability notes:
- Tree cover (MODIS VCF): 2001 onward
- NDVI (MODIS 16-day): 2001 onward
- AlphaEarth embeddings: 2017 onward
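
The availability limits above can be checked before launching downloads. The helper below is a sketch and not part of Bakaano-Hydro: `DATASET_START_YEAR` and `check_start_date` are hypothetical names, and it assumes availability begins on 1 January of each listed year.

```python
from datetime import date

# Earliest years, copied from the availability notes above.
DATASET_START_YEAR = {'tree_cover': 2001, 'ndvi': 2001, 'alpha_earth': 2017}

def check_start_date(dataset, start_date):
    """Return start_date, or the dataset's first available day if start_date is earlier."""
    requested = date.fromisoformat(start_date)
    # Assumption: data are available from 1 January of the listed year.
    earliest = date(DATASET_START_YEAR[dataset], 1, 1)
    if requested < earliest:
        print(f'{dataset}: {start_date} precedes availability; using {earliest.isoformat()}')
        return earliest.isoformat()
    return start_date

print(check_start_date('ndvi', '2001-01-01'))
print(check_start_date('alpha_earth', '2013-01-01'))
```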

Expected outputs:
- `elevation/`, `vcf/`, `ndvi/`, `soil/`, `alpha_earth/` folders
- climate NetCDF files under your selected meteorology source folder


Resume behavior:
- If a previous run was interrupted, rerunning this section reuses existing files and processes missing pieces before downloading more data.
- For ERA5/NDVI/Tree cover, missing-date checks are applied so only missing raw timesteps are downloaded when possible.
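
The missing-date check can be pictured as a set difference between the requested dates and the dates already on disk. This is a sketch of the idea only: Bakaano-Hydro's actual file naming and check logic are not shown in this notebook, and `missing_dates` is a hypothetical helper.

```python
from datetime import date, timedelta

def missing_dates(start, end, existing):
    """List ISO dates in [start, end] that are not in the `existing` collection."""
    first, last = date.fromisoformat(start), date.fromisoformat(end)
    have = set(existing)
    days = (first + timedelta(days=i) for i in range((last - first).days + 1))
    return [d.isoformat() for d in days if d.isoformat() not in have]

# Two of five requested days are already present, so three remain to download.
todo = missing_dates('2001-01-01', '2001-01-05', ['2001-01-01', '2001-01-03'])
print(todo)
```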

If Earth Engine authentication is requested, sign in and paste the code back into Colab.


In [None]:
import ee
ee.Authenticate(auth_mode="notebook")   # forces link + paste code flow
ee.Initialize()

In [None]:
if RUN_PREPROCESS:
    from bakaano.dem import DEM
    from bakaano.tree_cover import TreeCover
    from bakaano.ndvi import NDVI
    from bakaano.soil import Soil
    from bakaano.alpha_earth import AlphaEarth
    from bakaano.meteo import Meteo

    dd = DEM(
        working_dir=str(working_dir),
        study_area=str(study_area),
        local_data=False,
        local_data_path=None,
    )
    dd.get_dem_data()

    vf = TreeCover(
        working_dir=str(working_dir),
        study_area=str(study_area),
        start_date='2001-01-01',
        end_date='2010-12-31',
    )
    vf.get_tree_cover_data()

    nd = NDVI(
        working_dir=str(working_dir),
        study_area=str(study_area),
        start_date='2001-01-01',
        end_date='2010-12-31',
    )
    nd.get_ndvi_data()

    sgd = Soil(
        working_dir=str(working_dir),
        study_area=str(study_area),
    )
    sgd.get_soil_data()

    ae = AlphaEarth(
        working_dir=str(working_dir),
        study_area=str(study_area),
        start_date='2017-01-01',  # AlphaEarth embeddings are available from 2017 onward
        end_date='2024-01-01',
    )
    ae.get_alpha_earth()

    cd = Meteo(
        working_dir=str(working_dir_local),
        study_area=str(study_area),
        start_date='2001-01-01',
        end_date='2010-12-31',
        local_data=False,
        data_source=climate_data_source,
    )
    cd.get_meteo_data()
    src_climate = working_dir_local / climate_data_source
    dst_climate = working_dir_drive / climate_data_source
    if src_climate.exists():
        shutil.copytree(src_climate, dst_climate, dirs_exist_ok=True)
        print(f'Synced climate outputs to Drive: {dst_climate}')

    print('Preprocessing complete.')
else:
    print('Skipping preprocessing step.')

### Optional quick plots

Run any of these in a new code cell if needed:
- `vf.plot_tree_cover(variable='tree_cover')`
- `nd.plot_ndvi(interval_num=10)`
- `dd.plot_dem()`
- `sgd.plot_soil(variable='wilting_point')`
- `ae.plot_alpha_earth('A35')`
- `cd.plot_meteo(variable='tasmin', date='2006-12-01')`


## 2) Compute runoff and route to river network

This stage runs VegET runoff and routing. For speed, it copies required inputs from Drive to local disk (`/content`), computes locally, then syncs outputs back to Drive.

Expected outputs:
- `runoff_output/wacc_sparse_arrays.pkl`
- routed runoff rasters and diagnostics in `runoff_output/`

Resume behavior:
- If interrupted, rerun this section; checkpoint files allow continuation instead of starting over.
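
After this stage finishes, it can help to confirm that the expected runoff file is present before moving on. The sketch below uses only `pathlib`; `missing_outputs` is a hypothetical helper, and the relative path is the one listed under expected outputs.

```python
from pathlib import Path

def missing_outputs(base_dir, relative_paths):
    """Return the relative paths that do not exist under base_dir."""
    base = Path(base_dir)
    return [p for p in relative_paths if not (base / p).exists()]

# '/content/bakaano_workflow' is the local working copy used in this notebook;
# pass the Drive working directory instead to check the synced outputs.
print(missing_outputs('/content/bakaano_workflow', ['runoff_output/wacc_sparse_arrays.pkl']))
```

An empty list means every listed output was found.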


In [None]:
if RUN_RUNOFF_ROUTING:
    from bakaano.veget import VegET

    for folder in ['shapes', 'elevation', 'soil', 'ndvi', 'vcf', 'alpha_earth', climate_data_source]:
        src = working_dir_drive / folder
        dst = working_dir_local / folder
        if src.exists():
            shutil.copytree(src, dst, dirs_exist_ok=True)
    local_study_area = working_dir_local / 'shapes' / shapefile_name

    vg = VegET(
        working_dir=str(working_dir_local),
        study_area=str(local_study_area),
        start_date='2001-01-01',
        end_date='2010-12-31',
        climate_data_source=climate_data_source,
        routing_method=routing_method,
    )
    vg.compute_veget_runoff_route_flow()
    for folder in ['runoff_output', 'catchment']:
        src = working_dir_local / folder
        dst = working_dir_drive / folder
        if src.exists():
            shutil.copytree(src, dst, dirs_exist_ok=True)
    print('Runoff and routing complete.')
else:
    print('Skipping runoff/routing step.')

In [None]:
from bakaano.plot_runoff import RoutedRunoff

rr = RoutedRunoff(
    working_dir=str(working_dir),
    study_area=str(study_area),
)

rr.map_routed_runoff(date='2010-01-03', vmax=7)

## 3) Interactive exploration

Use this section to inspect routed runoff and basin behavior before training.

Typical checks:
- visualize routed runoff time slices
- verify station alignment and basin coverage
- confirm outputs exist before model training


In [None]:
from bakaano.runner import BakaanoHydro

bk = BakaanoHydro(
    working_dir=str(working_dir),
    study_area=str(study_area),
    climate_data_source=climate_data_source,
)

if grdc_netcdf is not None and grdc_netcdf.exists():
    bk.explore_data_interactively(
        start_date='1989-01-01',
        end_date='1989-12-31',
        grdc_netcdf=str(grdc_netcdf),
    )
else:
    print('Skipping explore_data_interactively (no GRDC NetCDF path set).')

### Routed runoff timeseries

Plot routed runoff time series at one or more stations.

Instructions:
- Set `station_ids` below to IDs available in your GRDC NetCDF or lookup CSV.
- Keep the date range inside your runoff output period.
- If you do not have observed-data metadata loaded, use the lat/lon example in `rr.plot_routed_runoff_timeseries(...)`.


In [None]:
station_ids = ['6203100']  # <- replace with your station id(s)

if grdc_netcdf is not None and grdc_netcdf.exists():
    rr.plot_routed_runoff_timeseries(
        start_date='2010-01-01',
        end_date='2010-12-31',
        station_id=station_ids,
        grdc_netcdf=str(grdc_netcdf),
    )
elif lookup_csv is not None and lookup_csv.exists():
    rr.plot_routed_runoff_timeseries(
        start_date='2010-01-01',
        end_date='2010-12-31',
        station_id=station_ids,
        lookup_csv=str(lookup_csv),
    )
else:
    print('Set grdc_netcdf or lookup_csv, or call rr.plot_routed_runoff_timeseries with lat/lon directly.')


## 4) Train model

This section trains the Bakaano model using prepared runoff/features and observed streamflow data.

Training control:
- `model_overwrite=True`: start a fresh model
- `model_overwrite=False`: load existing model (if present) and continue training

Expected output:
- `models/bakaano_model.keras`
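
The training call in this notebook passes `lr_schedule='cosine'`, `warmup_epochs=5`, and `min_learning_rate=1e-5`. The sketch below shows a generic linear-warmup plus cosine-decay schedule with those values, to give a feel for how the learning rate might evolve. It is an assumption about the general shape, not Bakaano-Hydro's exact implementation, and `cosine_lr` is a hypothetical helper.

```python
import math

def cosine_lr(epoch, total_epochs=300, base_lr=1e-3, min_lr=1e-5, warmup_epochs=5):
    """Learning rate for a 0-based epoch: linear warmup, then cosine decay to min_lr."""
    if epoch < warmup_epochs:
        # Ramp linearly so the last warmup epoch reaches base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    progress = min(progress, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

for epoch in (0, 4, 5, 150, 299):
    print(epoch, cosine_lr(epoch))
```

With this shape, lowering `num_epochs` compresses the decay, so the rate reaches `min_learning_rate` sooner.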


In [None]:
if RUN_TRAIN:
    if grdc_netcdf is not None and grdc_netcdf.exists():
        bk.train_streamflow_model(
            train_start='1981-01-01',
            train_end='2020-12-31',
            grdc_netcdf=str(grdc_netcdf),
            batch_size=batch_size,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            loss_function=loss_function,
            lr_schedule='cosine',
            warmup_epochs=5,
            min_learning_rate=1e-5,
            routing_method=routing_method,
            area_normalize=area_normalize,
            model_overwrite=model_overwrite,
        )
    else:
        bk.train_streamflow_model(
            train_start='1981-01-01',
            train_end='2020-12-31',
            grdc_netcdf=None,
            batch_size=batch_size,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            loss_function=loss_function,
            lr_schedule='cosine',
            warmup_epochs=5,
            min_learning_rate=1e-5,
            routing_method=routing_method,
            area_normalize=area_normalize,
            model_overwrite=model_overwrite,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col='id',
            lat_col='latitude',
            lon_col='longitude',
            date_col='date',
            discharge_col='discharge',
            file_pattern='{id}.csv',
        )
else:
    print('Skipping training.')

## 5) Evaluate model

This section evaluates a saved model against observed data for held-out periods/stations.

Expected outputs:
- evaluation metrics
- comparison plots/figures for observed vs simulated flow
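
This notebook does not state which metrics the evaluation reports. As an illustration of one metric commonly used for streamflow, the sketch below computes the Nash-Sutcliffe efficiency (NSE) in plain Python. `nse` is a hypothetical helper and the discharge values are made up.

```python
def nse(observed, simulated):
    """Nash-Sutcliffe efficiency: 1 is a perfect fit, 0 matches the observed mean."""
    mean_obs = sum(observed) / len(observed)
    residual = sum((o - s) ** 2 for o, s in zip(observed, simulated))
    variance = sum((o - mean_obs) ** 2 for o in observed)
    return 1.0 - residual / variance

obs = [10.0, 12.0, 15.0, 11.0]
print(nse(obs, obs))          # perfect simulation
print(nse(obs, [12.0] * 4))   # constant simulation equal to the observed mean
```

Values below zero indicate a simulation that performs worse than simply predicting the observed mean.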


In [None]:
model_path = working_dir / 'models' / 'bakaano_model.keras'
print('model_path exists:', model_path.exists())

if RUN_EVAL and model_path.exists():
    if grdc_netcdf is not None and grdc_netcdf.exists():
        bk.evaluate_streamflow_model_interactively(
            model_path=str(model_path),
            val_start='2001-01-01',
            val_end='2010-12-31',
            grdc_netcdf=str(grdc_netcdf),
            routing_method=routing_method,
            area_normalize=area_normalize,
        )
    else:
        bk.evaluate_streamflow_model_interactively(
            model_path=str(model_path),
            val_start='2001-01-01',
            val_end='2010-12-31',
            grdc_netcdf=None,
            routing_method=routing_method,
            area_normalize=area_normalize,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col='id',
            lat_col='latitude',
            lon_col='longitude',
            date_col='date',
            discharge_col='discharge',
            file_pattern='{id}.csv',
        )
else:
    print('Skipping evaluation.')

## 6) Simulate streamflow

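Use a trained model to simulate streamflow at observed stations or at arbitrary points.

Run notes:
- requires `models/bakaano_model.keras`
- `RUN_SIM_GRDC` simulates at the stations from your observed-data mode (GRDC NetCDF or CSV lookup)
- `RUN_SIM_POINTS` simulates at the latitude/longitude pairs given in `latlist` and `lonlist`, independent of the observed-data mode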

Expected outputs:
- simulated discharge series and saved prediction artifacts


In [None]:
if model_path.exists() and RUN_SIM_GRDC:
    if grdc_netcdf is not None and grdc_netcdf.exists():
        bk.simulate_grdc_csv_stations(
            model_path=str(model_path),
            sim_start='1981-01-01',
            sim_end='2020-12-31',
            grdc_netcdf=str(grdc_netcdf),
            routing_method=routing_method,
            area_normalize=area_normalize,
        )
    else:
        bk.simulate_grdc_csv_stations(
            model_path=str(model_path),
            sim_start='1981-01-01',
            sim_end='2020-12-31',
            grdc_netcdf=None,
            routing_method=routing_method,
            area_normalize=area_normalize,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col='id',
            lat_col='latitude',
            lon_col='longitude',
            date_col='date',
            discharge_col='discharge',
            file_pattern='{id}.csv',
        )
else:
    print('Skipping station simulation.')

In [None]:
if model_path.exists() and RUN_SIM_POINTS:
    bk.simulate_streamflow(
        model_path=str(model_path),
        sim_start='1981-01-01',
        sim_end='1990-12-31',
        latlist=[13.8, 13.9],
        lonlist=[3.0, 4.0],
        routing_method=routing_method,
        area_normalize=area_normalize,
    )
else:
    print('Skipping point simulation.')

In [None]:
import glob
import pandas as pd

pred_files = sorted(glob.glob(str(working_dir / 'predicted_streamflow_data' / '*.csv')))
print('Prediction files:', len(pred_files))
if pred_files:
    print('Example:', pred_files[0])
    df = pd.read_csv(pred_files[0])
    display(df.head())