# Debug Forecast Dataset Pipeline

This notebook walks through each step of the OnlineForecastDataset creation process,
allowing you to inspect and visualize data at every stage.

In [1]:
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from pathlib import Path

# Change to the project root so the relative config paths used below resolve
# (and, presumably, so the local `neuralhydrology` checkout is importable).
# Set the NEURALHYDROLOGY_ROOT environment variable to run this on another machine;
# the original author's checkout is kept as the fallback.
PROJECT_ROOT = Path(os.environ.get('NEURALHYDROLOGY_ROOT', '/home/sngrj0hn/GitHub/neuralhydrology'))
os.chdir(PROJECT_ROOT)

from neuralhydrology.utils.config import Config
from neuralhydrology.datautils.fetch_basin_forecasts import (
    load_basin_centroids,
    fetch_forecasts_for_basins,
    interpolate_to_hourly,
)

print("✓ Imports successful")

✓ Imports successful


## 1. Load Configuration

In [None]:
# Read the run configuration (path is relative to the project root set above).
config_path = Path('operational_harz/gefs_10d_sample/config.yml')
cfg = Config(config_path)

# basins.txt lives next to config.yml: one basin id per line, blank lines skipped.
basins_file = config_path.parent / 'basins.txt'
with basins_file.open('r') as fh:
    cleaned = (raw_line.strip() for raw_line in fh)
    basins = [basin_id for basin_id in cleaned if basin_id]

# Echo the settings the rest of the notebook depends on.
print("Period: train")
print(f"Basins (from {basins_file.name}): {basins}")
print(f"Forecast inputs: {cfg.forecast_inputs}")
print(f"Hindcast inputs: {cfg.hindcast_inputs}")
print(f"Target variables: {cfg.target_variables}")
print(f"\nTrain period: {cfg.train_start_date} to {cfg.train_end_date}")

Period: train


AttributeError: 'Config' object has no attribute 'basins'

## 2. Extract Base Variable Names from Config

In [None]:
# Config forecast inputs carry ensemble-quartile suffixes (e.g. "<var>_q50").
# The NOAA store only has the base variable, so strip the suffix to know what to request.
QUARTILE_SUFFIXES = ('_q25', '_q50', '_q75')


def strip_quartile_suffix(var_name: str, suffixes=QUARTILE_SUFFIXES) -> str:
    """Return ``var_name`` without a trailing quartile suffix.

    Only a suffix at the *end* of the name is removed. A base variable whose name
    happens to contain e.g. ``_q25`` elsewhere is left intact, unlike a plain
    ``str.replace``.

    Args:
        var_name: Variable name as listed in the config.
        suffixes: Candidate suffixes to strip (first match wins).

    Returns:
        The base variable name.
    """
    for suffix in suffixes:
        if var_name.endswith(suffix):
            return var_name[:-len(suffix)]
    return var_name


base_vars_needed = {strip_quartile_suffix(var) for var in cfg.forecast_inputs}

print(f"Config forecast_inputs (with quartile suffixes):")
for v in sorted(cfg.forecast_inputs):
    print(f"  - {v}")

print(f"\nBase variables needed from NOAA (without suffixes):")
for v in sorted(base_vars_needed):
    print(f"  - {v}")

## 3. Connect to NOAA GEFS Dataset

In [None]:
# dynamical.org Zarr mirror of the NOAA GEFS 35-day forecast archive.
# The `email` query parameter holds a placeholder address.
GEFS_ZARR_URL = "https://data.dynamical.org/noaa/gefs/forecast-35-day/latest.zarr?email=optional@email.com"

print("Connecting to NOAA GEFS Zarr store...")
# Opening is lazy; decode_timedelta=True decodes timedelta-encoded variables
# (presumably lead_time) into timedelta64 values.
ds = xr.open_zarr(GEFS_ZARR_URL, decode_timedelta=True)

print(f"\n✓ Connected successfully")
# `.sizes` is the name -> length mapping; using `.dims` as a mapping is deprecated in xarray.
print(f"\nDataset dimensions: {dict(ds.sizes)}")
print(f"\nAvailable variables (first 20):")
for i, v in enumerate(list(ds.data_vars)[:20]):
    print(f"  {i+1}. {v}")

print(f"\nTime range: {ds.init_time.values[0]} to {ds.init_time.values[-1]}")

## 4. Apply Temporal Filtering

In [None]:
# Extract period dates
start_date = pd.to_datetime(cfg.train_start_date, format='%d/%m/%Y')
end_date = pd.to_datetime(cfg.train_end_date, format='%d/%m/%Y')

print(f"Filtering to period: {start_date.date()} to {end_date.date()}")

# Temporal slice
time_dim = 'init_time' if 'init_time' in ds.dims else 'time'
ds_filtered = ds.sel({time_dim: slice(start_date, end_date)})

print(f"\nBefore filtering: {len(ds[time_dim])} time steps")
print(f"After filtering: {len(ds_filtered[time_dim])} time steps")
print(f"Data reduction: {100 * (1 - len(ds_filtered[time_dim]) / len(ds[time_dim])):.1f}%")

## 5. Filter to Required Variables

In [None]:
# Split the requested base variables into those present in the store and those missing.
# Sorted so the variable order (and everything derived from it) is reproducible:
# iteration order of a set of strings can change between interpreter runs.
available_base_vars = sorted(v for v in base_vars_needed if v in ds_filtered.data_vars)
missing_vars = sorted(v for v in base_vars_needed if v not in ds_filtered.data_vars)

print(f"Requested base variables: {len(base_vars_needed)}")
print(f"Available: {len(available_base_vars)}")
print(f"Missing: {len(missing_vars)}")

if missing_vars:
    print(f"\n⚠️ Missing variables:")
    for v in missing_vars:
        print(f"  - {v}")

if not available_base_vars:
    print("\n❌ None of the requested variables exist in the store; later cells will have nothing to process.")

print(f"\nFiltering to {len(available_base_vars)} variables...")
ds_filtered = ds_filtered[available_base_vars]

print(f"\n✓ Dataset now contains: {list(ds_filtered.data_vars)}")

## 6. Load Basin Centroids and Extract Forecasts

In [None]:
# Load centroids
basin_centroids_file = cfg.data_dir / "basin_centroids" / "basin_centroids.csv"
centroids = load_basin_centroids(basin_centroids_file)

print(f"Loaded {len(centroids)} basin centroids")
print(f"\nCentroids:")
print(centroids)

# Filter to configured basins (from basins.txt)
centroids = centroids[centroids['basin_name'].isin(basins)]
print(f"\nFiltered to {len(centroids)} configured basins")

# Extract forecasts for basin locations
print(f"\nExtracting forecasts for basin centroids...")
basin_forecasts = fetch_forecasts_for_basins(ds_filtered, centroids)

print(f"\n✓ Basin forecasts extracted")
print(f"Dimensions: {dict(basin_forecasts.dims)}")
print(f"Variables: {list(basin_forecasts.data_vars)}")

## 7. Compute Ensemble Quartiles

In [None]:
# Collapse the ensemble dimension into quartile summaries. Each quartile becomes
# its own variable named "<var><suffix>", matching the config's forecast_inputs.
quartiles = [0.25, 0.5, 0.75]
quartile_suffixes = {0.25: '_q25', 0.5: '_q50', 0.75: '_q75'}

print(f"Computing quartiles {quartiles} from {len(basin_forecasts.ensemble_member)} ensemble members...")

new_data_vars = {}
for var_name in basin_forecasts.data_vars:
    # NOTE(review): for dask-backed data, xarray's quantile needs `ensemble_member`
    # in a single chunk; rechunk first if this raises.
    var_quartiles = basin_forecasts[var_name].quantile(quartiles, dim='ensemble_member')

    for i, q in enumerate(quartiles):
        suffix = quartile_suffixes.get(q, f'_q{int(q*100)}')
        # Drop the now-scalar `quantile` coordinate; `drop_vars` replaces the deprecated `drop`.
        new_data_vars[f"{var_name}{suffix}"] = var_quartiles.isel(quantile=i).drop_vars('quantile')

# Carry over every coordinate that no longer references the reduced ensemble dimension.
coords_to_keep = {k: v for k, v in basin_forecasts.coords.items()
                  if 'ensemble_member' not in v.dims}

basin_forecasts_quartiles = xr.Dataset(
    data_vars=new_data_vars,
    coords=coords_to_keep,
    attrs=basin_forecasts.attrs.copy()
)

print(f"\n✓ Computed {len(new_data_vars)} quartile variables")
print(f"Dimensions: {dict(basin_forecasts_quartiles.sizes)}")
print(f"\nQuartile variables created:")
for v in sorted(basin_forecasts_quartiles.data_vars):
    print(f"  - {v}")

## 8. Interpolate to Hourly Resolution

In [None]:
# Forecast horizon to keep after interpolation, in hours (240 h = 10 days, matching
# the gefs_10d_sample config). One constant keeps the log message and the call in sync.
MAX_LEAD_HOURS = 240

print(f"Interpolating to hourly for first {MAX_LEAD_HOURS} hours...")
print(f"\nBefore interpolation:")
print(f"  Lead times: {len(basin_forecasts_quartiles.lead_time)}")
print(f"  Lead time range: {basin_forecasts_quartiles.lead_time.values[0]} to {basin_forecasts_quartiles.lead_time.values[-1]}")

basin_forecasts_hourly = interpolate_to_hourly(basin_forecasts_quartiles, max_hours=MAX_LEAD_HOURS)

print(f"\nAfter interpolation:")
print(f"  Lead times: {len(basin_forecasts_hourly.lead_time)}")
print(f"  Lead time range: {basin_forecasts_hourly.lead_time.values[0]} to {basin_forecasts_hourly.lead_time.values[-1]}")
print(f"\n✓ Interpolation complete")

## 9. Final Variable Filtering and Validation

In [None]:
# Reconcile the dataset with the config. Keep exactly the configured forecast inputs
# (in dataset order) and report any the pipeline failed to produce.
forecast_vars = []
for name in basin_forecasts_hourly.data_vars:
    if name in cfg.forecast_inputs:
        forecast_vars.append(name)

print(f"Config expects {len(cfg.forecast_inputs)} forecast variables")
print(f"Dataset has {len(forecast_vars)} matching variables")

missing_in_dataset = [name for name in cfg.forecast_inputs
                      if name not in basin_forecasts_hourly.data_vars]
if missing_in_dataset:
    print(f"\n⚠️ Variables in config but missing from dataset:")
    for name in missing_in_dataset:
        print(f"  - {name}")

if forecast_vars:
    basin_forecasts_hourly = basin_forecasts_hourly[forecast_vars]
    print(f"\n✓ Filtered to {len(forecast_vars)} variables")
else:
    # Diagnostic dump so a naming mismatch between config and dataset is visible.
    print(f"\n❌ ERROR: No matching variables found!")
    print(f"\nConfig forecast_inputs: {cfg.forecast_inputs}")
    print(f"\nDataset variables: {list(basin_forecasts_hourly.data_vars)}")

## 10. Visualize Sample Forecast Data

In [None]:
if forecast_vars:
    # Quick visual sanity check: first configured variable, first basin.
    sample_var = forecast_vars[0]
    sample_basin = basin_forecasts_hourly.basin.values[0]

    print(f"Plotting: {sample_var} for basin {sample_basin}")

    # First 5 forecast initialization times. The data may still be lazy at this point
    # (it is only materialized in a later cell), so load this small slice once instead
    # of triggering a separate read for every plotted line.
    sample_data = (
        basin_forecasts_hourly[sample_var]
        .sel(basin=sample_basin)
        .isel(time=slice(0, 5))
        .compute()
    )

    # Lead time is the same for every line. If it is stored as timedelta64, convert it
    # to float hours so the x-axis values match the "hours" label.
    lead_values = sample_data.lead_time.values
    if np.issubdtype(lead_values.dtype, np.timedelta64):
        lead_hours = lead_values / np.timedelta64(1, 'h')
    else:
        lead_hours = lead_values

    fig, ax = plt.subplots(figsize=(12, 6))

    for t in sample_data.time.values:
        ax.plot(lead_hours, sample_data.sel(time=t).values, marker='o', markersize=2,
                label=f"Init: {pd.to_datetime(t).strftime('%Y-%m-%d %H:%M')}")

    ax.set_xlabel('Lead Time (hours)')
    ax.set_ylabel(sample_var)
    ax.set_title(f'{sample_var} Forecasts for Basin {sample_basin}\n(First 5 initialization times)')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    print(f"\n✓ Visualization complete")
else:
    print("Cannot visualize - no forecast variables available")

## 11. Check Data Chunking (Before Compute)

In [None]:
print("Checking chunking status...\n")
# Inspect only the first three variables. Any backing array that exposes `.chunks`
# is reported as dask; anything else is reported as an in-memory numpy array.
vars_to_inspect = list(basin_forecasts_hourly.data_vars)[:3]
for var in vars_to_inspect:
    backing = basin_forecasts_hourly[var].data
    is_lazy = hasattr(backing, 'chunks')
    print(f"{var}:")
    if not is_lazy:
        print("  Type: Numpy array (in-memory)")
    else:
        print("  Type: Dask array")
        print(f"  Chunks: {backing.chunks}")
    print()

## 12. Materialize Data (Compute)

⚠️ This step downloads all forecast data selected above (period, variables, and basins) from the remote server.
Monitor network activity and RAM usage.

In [None]:
import time

# Rechunk so dask gets one task per basin, with each basin's full time x lead_time
# block in a single chunk. Only dimensions that exist are passed, because `.chunk`
# rejects unknown dimension names.
if 'basin' in basin_forecasts_hourly.dims:
    print("Rechunking for parallel computation...")
    target_chunks = {
        dim: (1 if dim == 'basin' else -1)
        for dim in ('basin', 'time', 'lead_time')
        if dim in basin_forecasts_hourly.dims
    }
    basin_forecasts_hourly = basin_forecasts_hourly.chunk(target_chunks)
    print(f"✓ Rechunked")

# `nbytes` is the size of the *result* (quartiles, hourly, selected basins). It is not the
# volume fetched from the server, which presumably includes every ensemble member
# and whole remote chunks.
print(f"\nStarting compute (materialization)...")
print(f"Materialized result size: ~{basin_forecasts_hourly.nbytes / 1e9:.2f} GB "
      f"(data downloaded from the server may differ)")

# perf_counter is the monotonic clock intended for measuring elapsed time.
start_time = time.perf_counter()

basin_forecasts_hourly = basin_forecasts_hourly.compute(
    scheduler='threads',
    num_workers=4
)

elapsed = time.perf_counter() - start_time

print(f"\n✓ Compute complete in {elapsed:.1f} seconds")
if elapsed > 0:
    print(f"Result throughput: ~{(basin_forecasts_hourly.nbytes / 1e6) / elapsed:.1f} MB/s")

# Verify the data is now in memory.
if basin_forecasts_hourly.data_vars:
    sample_var = next(iter(basin_forecasts_hourly.data_vars))
    print(f"\nVerifying data type:")
    print(f"  {sample_var}: {type(basin_forecasts_hourly[sample_var].data)}")
else:
    print("\n⚠️ Dataset has no data variables to verify")

## 13. Final Dataset Summary

In [None]:
print("=" * 60)
print("FINAL FORECAST DATASET SUMMARY")
print("=" * 60)
# `.sizes` is the name -> length mapping (using `.dims` as a mapping is deprecated).
print(f"\nDimensions: {dict(basin_forecasts_hourly.sizes)}")
print(f"\nCoordinates:")
for coord in basin_forecasts_hourly.coords:
    print(f"  - {coord}: {basin_forecasts_hourly[coord].shape}")

print(f"\nData Variables ({len(basin_forecasts_hourly.data_vars)}):")
for var in sorted(basin_forecasts_hourly.data_vars):
    shape = basin_forecasts_hourly[var].shape
    size_mb = basin_forecasts_hourly[var].nbytes / 1e6
    print(f"  - {var}: {shape} ({size_mb:.1f} MB)")

print(f"\nTotal size: {basin_forecasts_hourly.nbytes / 1e9:.2f} GB")

# Check the actual backing arrays instead of assuming the compute cell was run.
all_in_memory = all(
    isinstance(basin_forecasts_hourly[var].data, np.ndarray)
    for var in basin_forecasts_hourly.data_vars
)
if all_in_memory:
    print(f"Memory type: In-memory numpy arrays")
else:
    print(f"Memory type: lazy / non-numpy arrays present (run the compute cell first)")
print(f"\n✓ Forecast dataset ready for merging with historical data")

## 14. Save for Later Use (Optional)

In [None]:
# Set to True to write the processed forecast dataset to disk as NetCDF.
SAVE_OUTPUT = False
output_file = Path('operational_harz/gefs_10d_sample/forecast_debug_output.nc')

if SAVE_OUTPUT:
    print(f"Saving to {output_file}...")
    basin_forecasts_hourly.to_netcdf(output_file)
    print(f"✓ Saved ({output_file.stat().st_size / 1e6:.1f} MB)")
else:
    print(f"Saving disabled; set SAVE_OUTPUT = True to write {output_file}")