# Bakaano-Hydro Full Workflow (Colab)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/confidence-duku/bakaano-hydro/blob/main/colab_full_workflow.ipynb)

This notebook runs the end-to-end workflow:
1. Input data download and preprocessing
2. Runoff computation and routing
3. Interactive exploration
4. Model training
5. Model evaluation
6. Streamflow simulation (inference) at gauging stations and at arbitrary points

Before running, upload your basin shapefile and your observed streamflow data (either a GRDC NetCDF file, or station CSV files plus a station lookup CSV) to Google Drive.

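**Required user inputs before running:**
- `study_area`: path to your basin shapefile (`.shp`); keep its companion files (`.shx`, `.dbf`, `.prj`) in the same folder
- choose exactly one observed-data mode (see **Observed Data Mode** below):
  - `grdc_netcdf` path, or
  - `csv_dir` + `lookup_csv` paths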


In [None]:
# Install bakaano-hydro, including its `gpu` extra, from the repository's default branch on GitHub.
# NOTE(review): this installs whatever the latest commit is; pin a tag or commit
# (e.g. `...bakaano-hydro.git@<ref>`) if you need reproducible runs.
!pip -q install "bakaano-hydro[gpu] @ git+https://github.com/confidence-duku/bakaano-hydro.git"

In [None]:
# Fail fast when the Colab runtime has no GPU attached, before any heavy work starts.
import tensorflow as tf

gpu_devices = tf.config.list_physical_devices('GPU')
print('GPUs:', gpu_devices)
if len(gpu_devices) == 0:
    raise RuntimeError('No GPU detected. In Colab, set Runtime -> Change runtime type -> GPU.')

In [None]:
# Mount Google Drive; every input/output path in the configuration cell lives under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
from pathlib import Path
from datetime import date
import warnings

# ------------------------------
# User configuration
# ------------------------------
working_dir = Path('/content/drive/MyDrive/bakaano_workflow')
study_area = Path('/content/drive/MyDrive/bakaano_workflow/shapes/study_area.shp')  # REQUIRED: user must set this

# Observed streamflow inputs -- configure exactly ONE mode (enforced by the validation cell):
#   GRDC mode: grdc_netcdf=Path(...), csv_dir=None, lookup_csv=None
#   CSV mode:  grdc_netcdf=None, csv_dir=Path(...), lookup_csv=Path(...)
grdc_netcdf = Path('/content/drive/MyDrive/bakaano_workflow/data/GRDC.nc')  # REQUIRED for GRDC mode; set None for CSV mode
csv_dir = None      # REQUIRED for CSV mode: Path('/content/drive/MyDrive/bakaano_workflow/data/station_csvs')
lookup_csv = None   # REQUIRED for CSV mode: Path('/content/drive/MyDrive/bakaano_workflow/data/station_lookup.csv')

# Core time windows ('YYYY-MM-DD').
# The preprocessing window is passed to the meteorology download and to the
# runoff/routing computation, so the training, validation and simulation windows
# should lie inside it. Training and validation must not overlap, otherwise
# validation metrics are computed on data the model was trained on.
prep_start = '1981-01-01'
prep_end = '2020-12-31'
train_start = '1981-01-01'
train_end = '2000-12-31'
val_start = '2001-01-01'
val_end = '2010-12-31'
sim_start = '1981-01-01'
sim_end = '2020-12-31'

# Model/data options
climate_data_source = 'ERA5'   # ERA5, CHIRPS, or CHELSA
routing_method = 'mfd'         # mfd, d8, or dinf

# Training options
batch_size = 32
num_epochs = 300
learning_rate = 1e-3
loss_function = 'mse'
area_normalize = True

# Toggle heavy steps as needed
RUN_PREPROCESS = True
RUN_RUNOFF_ROUTING = True
RUN_TRAIN = True
RUN_EVAL = True
RUN_SIM_GRDC = True     # station simulation (GRDC or CSV stations, whichever mode is configured)
RUN_SIM_POINTS = True

# Sanity-check the time windows now: a mistake here would otherwise only surface
# after hours of downloading, runoff computation or training.
# date.fromisoformat also rejects malformed date strings with a ValueError.
_windows = {
    'prep': (date.fromisoformat(prep_start), date.fromisoformat(prep_end)),
    'train': (date.fromisoformat(train_start), date.fromisoformat(train_end)),
    'val': (date.fromisoformat(val_start), date.fromisoformat(val_end)),
    'sim': (date.fromisoformat(sim_start), date.fromisoformat(sim_end)),
}
_prep_lo, _prep_hi = _windows['prep']
for _name, (_lo, _hi) in _windows.items():
    if _lo > _hi:
        raise ValueError(f'{_name} window is inverted: start {_lo} is after end {_hi}')
    if _name != 'prep' and (_lo < _prep_lo or _hi > _prep_hi):
        warnings.warn(
            f'{_name} window {_lo}..{_hi} extends beyond the preprocessing window '
            f'{_prep_lo}..{_prep_hi}; meteorology/runoff data may be missing for those dates.'
        )
_train_lo, _train_hi = _windows['train']
_val_lo, _val_hi = _windows['val']
if _train_lo <= _val_hi and _val_lo <= _train_hi:
    warnings.warn(
        'train and val windows overlap; validation metrics will be computed partly on training data.'
    )

working_dir.mkdir(parents=True, exist_ok=True)
print('working_dir:', working_dir)
print('study_area exists:', study_area.exists())
print('grdc_netcdf exists:', grdc_netcdf.exists() if grdc_netcdf is not None else False)

## Observed Data Mode (Required)

Choose **exactly one** mode before running heavy steps:

1. **GRDC mode**: set `grdc_netcdf=Path(...)`, and keep `csv_dir=None`, `lookup_csv=None`.
2. **CSV mode**: set `grdc_netcdf=None`, and set both `csv_dir=Path(...)` and `lookup_csv=Path(...)`.


In [None]:
# Validate required inputs and enforce exactly one observed-data mode.
from pathlib import Path

# Accept plain strings as well as Path objects for every user-supplied path.
study_area = Path(study_area)
grdc_netcdf = None if grdc_netcdf is None else Path(grdc_netcdf)
csv_dir = None if csv_dir is None else Path(csv_dir)
lookup_csv = None if lookup_csv is None else Path(lookup_csv)

MODE_HELP = (
    'Choose exactly one observed-data mode:\n'
    "  1) GRDC mode: set grdc_netcdf=Path(...), csv_dir=None, lookup_csv=None\n"
    "  2) CSV mode: set grdc_netcdf=None, csv_dir=Path(...), lookup_csv=Path(...)"
)

if not study_area.exists():
    raise FileNotFoundError(f'study_area not found: {study_area}')

grdc_mode = grdc_netcdf is not None
csv_mode = (csv_dir is not None) and (lookup_csv is not None)

# Reject a half-configured CSV mode explicitly. Otherwise a lone csv_dir (or
# lookup_csv) set next to grdc_netcdf would pass as GRDC mode and be silently ignored.
if (csv_dir is None) != (lookup_csv is None):
    raise ValueError(
        f'CSV mode needs both csv_dir and lookup_csv (got csv_dir={csv_dir}, lookup_csv={lookup_csv}).\n'
        + MODE_HELP
    )

if grdc_mode == csv_mode:
    raise ValueError(MODE_HELP)

if grdc_mode and not grdc_netcdf.is_file():
    raise FileNotFoundError(f'grdc_netcdf not found: {grdc_netcdf}')

if csv_mode:
    if not csv_dir.is_dir():
        raise FileNotFoundError(f'csv_dir not found or not a directory: {csv_dir}')
    if not lookup_csv.is_file():
        raise FileNotFoundError(f'lookup_csv not found: {lookup_csv}')

print('Input validation passed.')
print('Mode:', 'GRDC' if grdc_mode else 'CSV')


## 1) Download and preprocess input data

In [None]:
if not RUN_PREPROCESS:
    print('Skipping preprocessing step.')
else:
    from bakaano.tree_cover import TreeCover
    from bakaano.ndvi import NDVI
    from bakaano.dem import DEM
    from bakaano.soil import Soil
    from bakaano.alpha_earth import AlphaEarth
    from bakaano.meteo import Meteo

    # Keyword arguments shared by every downloader/preprocessor below.
    site_kwargs = {'working_dir': str(working_dir), 'study_area': str(study_area)}
    # Time window for the time-dependent layers (tree cover, NDVI, meteorology).
    prep_window = {'start_date': prep_start, 'end_date': prep_end}

    # Tree cover
    vf = TreeCover(**site_kwargs, **prep_window)
    vf.get_tree_cover_data()

    # NDVI
    nd = NDVI(**site_kwargs, **prep_window)
    nd.get_ndvi_data()

    # DEM (downloaded, not read from a local file)
    dd = DEM(**site_kwargs, local_data=False, local_data_path=None)
    dd.get_dem_data()

    # Soil (static layer, no time window)
    sgd = Soil(**site_kwargs)
    sgd.get_soil_data()

    # AlphaEarth uses its own fixed window rather than prep_start/prep_end.
    ae = AlphaEarth(**site_kwargs, start_date='2013-01-01', end_date='2024-01-01')
    ae.get_alpha_earth()

    # Meteorology from the configured climate source
    cd = Meteo(
        **site_kwargs,
        **prep_window,
        local_data=False,
        data_source=climate_data_source,
    )
    cd.get_meteo_data()

    print('Preprocessing complete.')

In [None]:
# Optional quick plots after preprocessing
# Uncomment the lines you want. They use the objects (vf, nd, dd, sgd, ae, cd)
# created by the preprocessing cell, so that cell must have run in this session
# with RUN_PREPROCESS = True. The meteo date should fall inside prep_start..prep_end.
# vf.plot_tree_cover(variable='tree_cover')
# nd.plot_ndvi(interval_num=10)
# dd.plot_dem()
# sgd.plot_soil(variable='wilting_point')
# ae.plot_alpha_earth('A35')
# cd.plot_meteo(variable='tasmin', date='2006-12-01')

## 2) Compute runoff and route to river network

In [None]:
if not RUN_RUNOFF_ROUTING:
    print('Skipping runoff/routing step.')
else:
    from bakaano.veget import VegET

    # Runoff is computed over the same window that was preprocessed above.
    veget_kwargs = {
        'working_dir': str(working_dir),
        'study_area': str(study_area),
        'start_date': prep_start,
        'end_date': prep_end,
        'climate_data_source': climate_data_source,
        'routing_method': routing_method,
    }
    vg = VegET(**veget_kwargs)
    vg.compute_veget_runoff_route_flow()
    print('Runoff and routing complete.')

In [None]:
from bakaano.plot_runoff import RoutedRunoff

# Day to map. It must exist in the runoff output, which the runoff/routing cell
# computes for prep_start..prep_end; update it to match your data.
runoff_map_date = '2010-01-03'

rr = RoutedRunoff(study_area=str(study_area), working_dir=str(working_dir))
rr.map_routed_runoff(date=runoff_map_date, vmax=7)

## 3) Interactive exploration

In [None]:
from bakaano.runner import BakaanoHydro

# Runner object reused by the training, evaluation and simulation cells below.
bk = BakaanoHydro(
    working_dir=str(working_dir),
    study_area=str(study_area),
    climate_data_source=climate_data_source,
)

# This cell only drives the interactive explorer from a GRDC NetCDF file.
# NOTE(review): the explorer presumably needs preprocessed data for these dates;
# keep them inside prep_start..prep_end.
has_grdc_file = grdc_netcdf is not None and grdc_netcdf.exists()
if not has_grdc_file:
    print('Skipping explore_data_interactively (no GRDC NetCDF path set).')
else:
    explore_window = {'start_date': '1989-01-01', 'end_date': '1989-12-31'}
    bk.explore_data_interactively(grdc_netcdf=str(grdc_netcdf), **explore_window)

## 4) Train model

In [None]:
if RUN_TRAIN:
    # Arguments shared by both observed-data modes.
    train_kwargs = dict(
        train_start=train_start,
        train_end=train_end,
        batch_size=batch_size,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
        loss_function=loss_function,
        lr_schedule='cosine',
        warmup_epochs=5,
        min_learning_rate=1e-5,
        routing_method=routing_method,
        area_normalize=area_normalize,
    )

    if grdc_netcdf is not None and Path(grdc_netcdf).exists():
        # GRDC mode
        train_kwargs['grdc_netcdf'] = str(grdc_netcdf)
    elif csv_dir is not None and lookup_csv is not None:
        # CSV mode: one file per station, named after the lookup table's id column.
        train_kwargs.update(
            grdc_netcdf=None,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col='id',
            lat_col='latitude',
            lon_col='longitude',
            date_col='date',
            discharge_col='discharge',
            file_pattern='{id}.csv',
        )
    else:
        # Without this guard, str(None) would pass the literal string 'None' as csv paths.
        raise ValueError(
            'No usable observed streamflow data: set an existing grdc_netcdf (GRDC mode) '
            'or both csv_dir and lookup_csv (CSV mode) in the configuration cell.'
        )

    bk.train_streamflow_model(**train_kwargs)
else:
    print('Skipping training.')

## 5) Evaluate model

In [None]:
model_path = working_dir / 'models' / 'bakaano_model.keras'
print('model_path exists:', model_path.exists())

if not RUN_EVAL:
    print('Skipping evaluation.')
elif not model_path.exists():
    print(f'Skipping evaluation: no trained model found at {model_path}.')
else:
    # Arguments shared by both observed-data modes.
    eval_kwargs = dict(
        model_path=str(model_path),
        val_start=val_start,
        val_end=val_end,
        routing_method=routing_method,
        area_normalize=area_normalize,
    )

    if grdc_netcdf is not None and Path(grdc_netcdf).exists():
        # GRDC mode
        eval_kwargs['grdc_netcdf'] = str(grdc_netcdf)
    elif csv_dir is not None and lookup_csv is not None:
        # CSV mode: one file per station, named after the lookup table's id column.
        eval_kwargs.update(
            grdc_netcdf=None,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col='id',
            lat_col='latitude',
            lon_col='longitude',
            date_col='date',
            discharge_col='discharge',
            file_pattern='{id}.csv',
        )
    else:
        # Without this guard, str(None) would pass the literal string 'None' as csv paths.
        raise ValueError(
            'No usable observed streamflow data: set an existing grdc_netcdf (GRDC mode) '
            'or both csv_dir and lookup_csv (CSV mode) in the configuration cell.'
        )

    bk.evaluate_streamflow_model_interactively(**eval_kwargs)

## 6) Simulate streamflow

In [None]:
# Station simulation. Despite its name, RUN_SIM_GRDC gates both GRDC and CSV stations.
if not RUN_SIM_GRDC:
    print('Skipping station simulation.')
elif not model_path.exists():
    print(f'Skipping station simulation: no trained model found at {model_path}.')
else:
    # Arguments shared by both observed-data modes.
    sim_kwargs = dict(
        model_path=str(model_path),
        sim_start=sim_start,
        sim_end=sim_end,
        routing_method=routing_method,
        area_normalize=area_normalize,
    )

    if grdc_netcdf is not None and Path(grdc_netcdf).exists():
        # GRDC mode
        sim_kwargs['grdc_netcdf'] = str(grdc_netcdf)
    elif csv_dir is not None and lookup_csv is not None:
        # CSV mode: one file per station, named after the lookup table's id column.
        sim_kwargs.update(
            grdc_netcdf=None,
            csv_dir=str(csv_dir),
            lookup_csv=str(lookup_csv),
            id_col='id',
            lat_col='latitude',
            lon_col='longitude',
            date_col='date',
            discharge_col='discharge',
            file_pattern='{id}.csv',
        )
    else:
        # Without this guard, str(None) would pass the literal string 'None' as csv paths.
        raise ValueError(
            'No usable observed streamflow data: set an existing grdc_netcdf (GRDC mode) '
            'or both csv_dir and lookup_csv (CSV mode) in the configuration cell.'
        )

    bk.simulate_grdc_csv_stations(**sim_kwargs)

In [None]:
# Arbitrary coordinate simulation
# Points to simulate, in decimal degrees. Presumably paired element-wise
# (point i = (point_lats[i], point_lons[i])); confirm against simulate_streamflow.
# Replace these example values with locations inside your study area.
point_lats = [13.8, 13.9]
point_lons = [3.0, 4.0]

if model_path.exists() and RUN_SIM_POINTS:
    if len(point_lats) != len(point_lons):
        raise ValueError(
            f'point_lats ({len(point_lats)}) and point_lons ({len(point_lons)}) must have the same length'
        )
    # Use the configured simulation window instead of a hard-coded one.
    bk.simulate_streamflow(
        model_path=str(model_path),
        sim_start=sim_start,
        sim_end=sim_end,
        latlist=point_lats,
        lonlist=point_lons,
        routing_method=routing_method,
        area_normalize=area_normalize,
    )
else:
    print('Skipping point simulation.')

In [None]:
# Inspect generated prediction files: count the CSVs in predicted_streamflow_data
# (presumably written by the simulation cells) and preview the first one.
import glob
import pandas as pd

pred_dir = working_dir / 'predicted_streamflow_data'
pred_files = sorted(glob.glob(str(pred_dir / '*.csv')))
print('Prediction files:', len(pred_files))
if len(pred_files) > 0:
    first_pred_file = pred_files[0]
    print('Example:', first_pred_file)
    df = pd.read_csv(first_pred_file)
    display(df.head())