In [1]:
import os
import sys
import pickle
import glob
import xarray as xr
import pandas as pd
import plotly.graph_objects as go
from pathlib import Path
import torch
import matplotlib.pyplot as plt

print(f"Current working directory: {os.getcwd()}")

from neuralhydrology.utils.config import Config
from neuralhydrology.datasetzoo.onlineforecastdataset import OnlineForecastDataset
from neuralhydrology.nh_run import start_run, eval_run
from neuralhydrology.training.train import start_training

Current working directory: /home/sngrj0hn/GitHub/neuralhydrology/operational_harz/gefs_10d_sample


## Load Configuration

In [2]:
# Parse the experiment YAML into a Config object; `cfg` is reused by the
# data-loading and training cells below.
config_path = Path("basins_MDN.yml")
cfg = Config(config_path)

# Echo the two settings that identify the run.
print(
    f"Loaded config for experiment: {cfg.experiment_name}",
    f"Run directory: {cfg.run_dir}",
    sep="\n",
)

Loaded config for experiment: development_run_mdn
Run directory: None


## Inspect Data Loading
This step manually instantiates the `OnlineForecastDataset` to verify that:
1. The cached `train_data.zarr` can be accessed or rebuilt.
2. The scaler is correctly loaded or created.

In [3]:
# Initialize the dataset in training mode
# This will trigger _load_or_create_xarray_dataset and _ensure_scaler
try:
    print("Attempting to load OnlineForecastDataset...")
    ds = OnlineForecastDataset(cfg=cfg, is_train=True, period="train")
    print("Successfully loaded OnlineForecastDataset!")
    print(f"Number of samples: {len(ds)}")
    print(f"Scaler keys: {list(ds.scaler.keys())}")
except Exception as e:
    # Surface the failure in the cell output, then re-raise so "Run All" stops here.
    print(f"Error loading dataset: {e}")
    raise

# Verify Zarr cache exists.
# cfg.train_dir may be None (the training cell below guards for the same case);
# dividing None by a string would raise a TypeError, so treat it as "no cache".
# The check lives outside the try-block so a problem here is not reported as a
# dataset loading error.
zarr_path = None if cfg.train_dir is None else cfg.train_dir / "train_data.zarr"
if zarr_path is not None and zarr_path.exists():
    print(f"Zarr cache successfully created at {zarr_path}")
else:
    print("Warning: Zarr cache not found!")

Attempting to load OnlineForecastDataset...


  return cls(**configuration_parsed)


100%|██████████| 5/5 [00:01<00:00,  2.65it/s]

Successfully loaded OnlineForecastDataset!
Number of samples: 5475
Scaler keys: ['xarray_feature_scale', 'xarray_feature_center']
Zarr cache successfully created at runs/train_data.zarr
Successfully loaded OnlineForecastDataset!
Number of samples: 5475
Scaler keys: ['xarray_feature_scale', 'xarray_feature_center']
Zarr cache successfully created at runs/train_data.zarr


## Train Model
If the data loading above succeeded, we can proceed to train the model.

In [4]:
# Remember where the Zarr cache checked in the previous cell lives.
# It has to be captured now because train_dir is reset to None further down.
zarr_cache_path = None if cfg.train_dir is None else cfg.train_dir / "train_data.zarr"
cache_available = zarr_cache_path is not None and zarr_cache_path.exists()

# Config overrides for this run.
# CRITICAL: "train_dir" is set to None. This forces BaseTrainer to create a NEW train_data
# directory inside the new run directory, where the scaler will be saved.
# If train_dir kept pointing to the old location, the scaler would be saved there,
# but the validator would look for it in the new run directory, causing a FileNotFoundError.
update_dict = {"train_dir": None, "save_train_data": False}

if not cache_available:
    print("Warning: Zarr cache not found. Dataset may be rebuilt.")
else:
    print(f"Using existing Zarr cache at: {zarr_cache_path}")
    update_dict["train_data_file"] = zarr_cache_path

cfg.update_config(update_dict)

# Pick the compute device: first CUDA GPU when one is visible, CPU otherwise.
use_gpu = torch.cuda.is_available()
print(f"Starting run on GPU: {torch.cuda.get_device_name(0)}" if use_gpu else "Starting run on CPU")
cfg.device = "cuda:0" if use_gpu else "cpu"

start_training(cfg)

Using existing Zarr cache at: runs/train_data.zarr
Starting run on GPU: NVIDIA GeForce RTX 4090
2025-12-02 12:27:59,256: Logging to /home/sngrj0hn/GitHub/neuralhydrology/operational_harz/gefs_10d_sample/runs/development_run_mdn_0212_122759/output.log initialized.
2025-12-02 12:27:59,256: ### Folder structure created at /home/sngrj0hn/GitHub/neuralhydrology/operational_harz/gefs_10d_sample/runs/development_run_mdn_0212_122759
2025-12-02 12:27:59,256: ### Run configurations for development_run_mdn
2025-12-02 12:27:59,256: experiment_name: development_run_mdn
2025-12-02 12:27:59,256: run_dir: /home/sngrj0hn/GitHub/neuralhydrology/operational_harz/gefs_10d_sample/runs/development_run_mdn_0212_122759
2025-12-02 12:27:59,256: train_basin_file: basins.txt
2025-12-02 12:27:59,257: validation_basin_file: basins.txt
2025-12-02 12:27:59,257: test_basin_file: basins.txt
2025-12-02 12:27:59,257: train_start_date: 2020-10-01 00:00:00
2025-12-02 12:27:59,257: train_end_date: 2023-09-30 00:00:00
2025-

  return cls(**configuration_parsed)


100%|██████████| 5/5 [00:01<00:00,  3.29it/s]

# Epoch 1:   5%|▍         | 1/22 [00:00<00:03,  5.58it/s, Loss: 298.4598]

  sample[f'x_h{freq_suffix}'] = torch.from_numpy(x_h)


# Epoch 1: 100%|██████████| 22/22 [00:01<00:00, 18.59it/s, Loss: 127.3802]
2025-12-02 12:28:03,427: Epoch 1 average loss: avg_loss: 170.45217, avg_total_loss: 170.45217
# Epoch 1: 100%|██████████| 22/22 [00:01<00:00, 18.59it/s, Loss: 127.3802]
2025-12-02 12:28:03,427: Epoch 1 average loss: avg_loss: 170.45217, avg_total_loss: 170.45217
# Epoch 2: 100%|██████████| 22/22 [00:01<00:00, 21.06it/s, Loss: 72.4138] 
2025-12-02 12:28:04,476: Epoch 2 average loss: avg_loss: 95.10378, avg_total_loss: 95.10378
# Epoch 2: 100%|██████████| 22/22 [00:01<00:00, 21.06it/s, Loss: 72.4138]
2025-12-02 12:28:04,476: Epoch 2 average loss: avg_loss: 95.10378, avg_total_loss: 95.10378
# Epoch 3: 100%|██████████| 22/22 [00:01<00:00, 21.31it/s, Loss: 62.5879]
2025-12-02 12:28:05,514: Epoch 3 average loss: avg_loss: 66.11217, avg_total_loss: 66.11217
# Epoch 3: 100%|██████████| 22/22 [00:01<00:00, 21.31it/s, Loss: 62.5879]
2025-12-02 12:28:05,514: Epoch 3 average loss: avg_loss: 66.11217, avg_total_loss: 66.112

  return cls(**configuration_parsed)


Metrics for 1h are calculated over last 1 elements only. Ignoring 239 predictions per sequence.
Metrics for 1h are calculated over last 1 elements only. Ignoring 239 predictions per sequence.
# Validation:  25%|██▌       | 1/4 [00:02<00:06,  2.09s/it]2025-12-02 12:28:08,659: Loading cached dataset from runs/train_data.zarr
# Validation:  25%|██▌       | 1/4 [00:02<00:06,  2.09s/it]2025-12-02 12:28:08,659: Loading cached dataset from runs/train_data.zarr
2025-12-02 12:28:08,676: Validating data availability against configured time periods...
2025-12-02 12:28:08,676:   Historical coverage: 2020-09-01 00:00:00 to 2024-04-30 23:00:00
2025-12-02 12:28:08,676:   Forecast coverage: 2020-10-01 00:00:00 to 2025-12-01 00:00:00
2025-12-02 12:28:08,676:   Checking train period: 2020-10-01 to 2023-09-30
2025-12-02 12:28:08,676:   Checking validation period: 2023-10-01 to 2024-01-31
2025-12-02 12:28:08,676:   Checking test period: 2024-02-01 to 2024-04-30
2025-12-02 12:28:08,677: ✅ Data availability

  return cls(**configuration_parsed)


# Validation:  50%|█████     | 2/4 [00:04<00:04,  2.03s/it]2025-12-02 12:28:10,641: Loading cached dataset from runs/train_data.zarr
2025-12-02 12:28:10,641: Loading cached dataset from runs/train_data.zarr
2025-12-02 12:28:10,657: Validating data availability against configured time periods...
2025-12-02 12:28:10,657:   Historical coverage: 2020-09-01 00:00:00 to 2024-04-30 23:00:00
2025-12-02 12:28:10,657:   Forecast coverage: 2020-10-01 00:00:00 to 2025-12-01 00:00:00
2025-12-02 12:28:10,658:   Checking train period: 2020-10-01 to 2023-09-30
2025-12-02 12:28:10,658:   Checking validation period: 2023-10-01 to 2024-01-31
2025-12-02 12:28:10,658:   Checking test period: 2024-02-01 to 2024-04-30
2025-12-02 12:28:10,658: ✅ Data availability validation passed
2025-12-02 12:28:10,657: Validating data availability against configured time periods...
2025-12-02 12:28:10,657:   Historical coverage: 2020-09-01 00:00:00 to 2024-04-30 23:00:00
2025-12-02 12:28:10,657:   Forecast coverage: 2020-1

  return cls(**configuration_parsed)


# Validation:  75%|███████▌  | 3/4 [00:06<00:02,  2.02s/it]2025-12-02 12:28:12,658: Loading cached dataset from runs/train_data.zarr
2025-12-02 12:28:12,658: Loading cached dataset from runs/train_data.zarr
2025-12-02 12:28:12,676: Validating data availability against configured time periods...
2025-12-02 12:28:12,676:   Historical coverage: 2020-09-01 00:00:00 to 2024-04-30 23:00:00
2025-12-02 12:28:12,676:   Forecast coverage: 2020-10-01 00:00:00 to 2025-12-01 00:00:00
2025-12-02 12:28:12,676:   Checking train period: 2020-10-01 to 2023-09-30
2025-12-02 12:28:12,676:   Checking validation period: 2023-10-01 to 2024-01-31
2025-12-02 12:28:12,676:   Checking test period: 2024-02-01 to 2024-04-30
2025-12-02 12:28:12,676: ✅ Data availability validation passed
2025-12-02 12:28:12,676: Validating data availability against configured time periods...
2025-12-02 12:28:12,676:   Historical coverage: 2020-09-01 00:00:00 to 2024-04-30 23:00:00
2025-12-02 12:28:12,676:   Forecast coverage: 2020-1

  return cls(**configuration_parsed)


# Validation: 100%|██████████| 4/4 [00:08<00:00,  2.03s/it]

2025-12-02 12:28:15,147: Epoch 4 average validation loss: 323.74667 -- Median validation metrics: avg_loss: 323.74667, NSE: 0.14392, KGE: 0.27951, Alpha-NSE: 0.52393, Beta-NSE: -0.30557
# Epoch 5:   0%|          | 0/22 [00:00<?, ?it/s]2025-12-02 12:28:15,147: Epoch 4 average validation loss: 323.74667 -- Median validation metrics: avg_loss: 323.74667, NSE: 0.14392, KGE: 0.27951, Alpha-NSE: 0.52393, Beta-NSE: -0.30557
# Epoch 5: 100%|██████████| 22/22 [00:01<00:00, 20.99it/s, Loss: 61.2099]
2025-12-02 12:28:16,198: Epoch 5 average loss: avg_loss: 33.18799, avg_total_loss: 33.18799
# Epoch 5: 100%|██████████| 22/22 [00:01<00:00, 20.99it/s, Loss: 61.2099]
2025-12-02 12:28:16,198: Epoch 5 average loss: avg_loss: 33.18799, avg_total_loss: 33.18799
# Epoch 6: 100%|██████████| 22/22 [00:01<00:00, 21.40it/s, Loss: 31.2408]
2025-12-02 12:28:17,231: Epoch 6 average loss: avg_loss: 16.33436, avg_total_loss: 16.33436
# Epoch 6: 100%|████

  return cls(**configuration_parsed)


# Validation: 100%|██████████| 4/4 [00:07<00:00,  1.82s/it]

2025-12-02 12:28:26,924: Epoch 8 average validation loss: 347.69742 -- Median validation metrics: avg_loss: 347.69742, NSE: 0.22480, KGE: 0.37555, Alpha-NSE: 0.60978, Beta-NSE: -0.13922
# Epoch 9:   0%|          | 0/22 [00:00<?, ?it/s]2025-12-02 12:28:26,924: Epoch 8 average validation loss: 347.69742 -- Median validation metrics: avg_loss: 347.69742, NSE: 0.22480, KGE: 0.37555, Alpha-NSE: 0.60978, Beta-NSE: -0.13922
# Epoch 9: 100%|██████████| 22/22 [00:01<00:00, 20.47it/s, Loss: 11.7942] 
2025-12-02 12:28:28,000: Epoch 9 average loss: avg_loss: -12.04331, avg_total_loss: -12.04331
# Epoch 9: 100%|██████████| 22/22 [00:01<00:00, 20.47it/s, Loss: 11.7942] 
2025-12-02 12:28:28,000: Epoch 9 average loss: avg_loss: -12.04331, avg_total_loss: -12.04331
# Epoch 10: 100%|██████████| 22/22 [00:00<00:00, 23.83it/s, Loss: -35.5640]
2025-12-02 12:28:28,929: Epoch 10 average loss: avg_loss: -24.09969, avg_total_loss: -24.09969
# Epoch 1

## Evaluation
Now that the model is trained, we evaluate it on both the test and the train periods to generate the predictions used by the plots below.

In [None]:
# Find the latest run directory.
# Keep directories only: the glob pattern also matches any stray file that shares
# the experiment prefix, and such a file must never be handed to eval_run.
run_dirs = [d for d in glob.glob("runs/development_run_mdn*") if os.path.isdir(d)]
if not run_dirs:
    raise FileNotFoundError("No run directories found in runs/")
# Sort by modification time to ensure we get the actual latest run
# String sorting fails because the date format is DDMM_HHMMSS (Day first)
run_dirs.sort(key=os.path.getmtime)
latest_run_dir = Path(run_dirs[-1])
print(f"Evaluating run at: {latest_run_dir}")

# Run evaluation on test and train sets (the train results feed the training-period plot below)
for period in ("test", "train"):
    print(f"Starting {period.capitalize()} Evaluation...")
    eval_run(run_dir=latest_run_dir, period=period)

In [None]:
# Locate the pickled test results written by eval_run(..., period="test").
# The epoch folder is globbed instead of hard-coding "model_epoch080", so the cell
# keeps working when the number of training epochs changes.
test_results_files = sorted(latest_run_dir.glob("test/model_epoch*/test_results.p"))
if not test_results_files:
    raise FileNotFoundError(
        f"No test results found under {latest_run_dir / 'test'}. "
        "Make sure eval_run(..., period='test') has completed."
    )
# Assumes zero-padded folder names (e.g. model_epoch080), so the lexicographically
# last entry is the highest evaluated epoch.
test_results_file = test_results_files[-1]
print(f"Loading test results from: {test_results_file}")

# NOTE: pickle.load can execute arbitrary code; only load result files produced by your own runs.
with open(test_results_file, "rb") as fp:
    results = pickle.load(fp)

# Show the top-level keys of the loaded results (used below to pick `basin`).
results.keys()

In [None]:
# Quick look at the loaded results: display the 'discharge_vol_sim' entry
# stored for key 'DE4' at the '1h' frequency.
results['DE4']['1h']['xr']['discharge_vol_sim']

In [None]:
# Set the basin variable for the plots below
# This should match one of the keys returned by results.keys()
# (the plotting code indexes the loaded results with it, e.g. train_results[basin]).
basin = 'DE2'
print(f"Basin set to: {basin}")

In [None]:
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import pickle

# --- Training Period: Day 1 Forecast Plot ---
# Loads the train-period evaluation results for `basin`, keeps one forecast per day,
# takes the first 24 time steps of each forecast and plots median / 50% / 90% bands
# of the simulated discharge against the observations.

# 1. Load Training Results
# We look for the train_results.p file in the run directory. The matches are sorted so
# the choice is deterministic when several epochs were evaluated; the last entry is used
# (assumes zero-padded folder names such as model_epoch080, so "last" = highest epoch).
train_results_files = sorted(latest_run_dir.glob("train/*/train_results.p"))
if not train_results_files:
    print("No training results found. Make sure eval_run(..., period='train') has completed.")
else:
    train_results_file = train_results_files[-1]
    print(f"Loading training results from: {train_results_file}")

    # NOTE: pickle.load can execute arbitrary code; only load files produced by your own runs.
    with open(train_results_file, "rb") as fp:
        train_results = pickle.load(fp)

    if 'basin' in locals():
        ds_train = train_results[basin]['1h']['xr']

        # 2. Subsample to avoid overlaps (Daily Stride)
        dates_train = pd.to_datetime(ds_train['date'].values)
        if len(dates_train) > 1:
            time_diff_train = dates_train[1] - dates_train[0]
        else:
            # A single forecast date has no spacing to measure; nothing to thin out.
            time_diff_train = pd.Timedelta('24h')

        if time_diff_train < pd.Timedelta('24h'):
            # Forecasts are issued more often than daily: keep every stride-th one.
            stride_train = int(pd.Timedelta('24h') / time_diff_train)
            ds_train_subset = ds_train.isel(date=slice(0, None, stride_train))
        else:
            ds_train_subset = ds_train

        # 3. Select Day 1 (First 24 hours)
        ds_train_day1 = ds_train_subset.isel(time_step=slice(0, 24))

        # 4. Stack (date, time_step) into one flat axis
        ds_train_flat = ds_train_day1.stack(combined=('date', 'time_step'))

        # 5. Extract simulated values as (combined, samples) when a sample dimension exists
        if 'samples' in ds_train_flat['discharge_vol_sim'].dims:
            qsim_train = ds_train_flat['discharge_vol_sim'].transpose('combined', 'samples').values
        else:
            qsim_train = ds_train_flat['discharge_vol_sim'].values

        qobs_train = ds_train_flat['discharge_vol_obs'].values

        # 6. Reconstruct Dates: forecast date plus time_step interpreted as hours
        dates_train_flat = pd.to_datetime(ds_train_flat['date'].values) + pd.to_timedelta(ds_train_flat['time_step'].values, unit='h')

        # 7. Sort chronologically
        sort_idx_train = np.argsort(dates_train_flat)
        dates_train_flat = dates_train_flat[sort_idx_train]
        qsim_train = qsim_train[sort_idx_train]
        qobs_train = qobs_train[sort_idx_train]

        # 8. Handle NaNs: drop time stamps where every sample (or the single value) is NaN
        nan_mask_train = np.all(np.isnan(qsim_train), axis=-1) if qsim_train.ndim == 2 else np.isnan(qsim_train)
        if np.any(nan_mask_train):
            valid_mask_train = ~nan_mask_train
            dates_train_flat = dates_train_flat[valid_mask_train]
            qsim_train = qsim_train[valid_mask_train]
            qobs_train = qobs_train[valid_mask_train]

        # 9. Calculate Percentiles across samples
        if qsim_train.ndim == 2:
            y_median_train = np.nanmedian(qsim_train, axis=-1)
            y_05_train = np.nanpercentile(qsim_train, 5, axis=-1)
            y_95_train = np.nanpercentile(qsim_train, 95, axis=-1)
            y_25_train = np.nanpercentile(qsim_train, 25, axis=-1)
            y_75_train = np.nanpercentile(qsim_train, 75, axis=-1)
        else:
            # Deterministic output: all bands collapse onto the single prediction.
            y_median_train = qsim_train
            y_05_train = y_95_train = y_25_train = y_75_train = qsim_train

        # 10. Plot
        fig = go.Figure()

        # 90% CI (closed polygon: upper bound forward, lower bound backward)
        fig.add_trace(go.Scatter(
            x=np.concatenate([dates_train_flat, dates_train_flat[::-1]]),
            y=np.concatenate([y_95_train, y_05_train[::-1]]),
            fill='toself',
            fillcolor='rgba(53, 183, 121, 0.5)',
            line=dict(color='rgba(255,255,255,0)'),
            name='90% CI (5-95)',
            showlegend=True
        ))

        # 50% CI
        fig.add_trace(go.Scatter(
            x=np.concatenate([dates_train_flat, dates_train_flat[::-1]]),
            y=np.concatenate([y_75_train, y_25_train[::-1]]),
            fill='toself',
            fillcolor='rgba(68, 1, 84, 0.5)',
            line=dict(color='rgba(255,255,255,0)'),
            name='50% CI (25-75)',
            showlegend=True
        ))

        # Median
        fig.add_trace(go.Scatter(
            x=dates_train_flat,
            y=y_median_train,
            mode='lines',
            line=dict(color='red', width=2),
            name='Median'
        ))

        # Observed
        fig.add_trace(go.Scatter(
            x=dates_train_flat,
            y=qobs_train,
            mode='lines',
            line=dict(color='black', width=2, dash='dash'),
            name='Observed'
        ))

        fig.update_layout(
            title='Training Period: Discharge Prediction (Day 1 Ahead)',
            xaxis_title='Date',
            yaxis_title='Discharge [m³/s]',
            template='plotly_white',
            hovermode='x unified'
        )

        fig.show()
    else:
        print("Basin variable not defined.")

In [None]:
import plotly.graph_objects as go
import pandas as pd
import numpy as np

# --- Data Preparation with Slider Support ---
# Builds one group of four traces (90% band, 50% band, median, observed) per forecast
# day and a slider that toggles which group is visible.

if 'results' in locals() and 'basin' in locals():
    # Index with `basin` so that changing the basin selection cell actually changes this plot.
    ds = results[basin]['1h']['xr']

    # 1. Subsample to avoid overlaps (Daily Stride)
    dates = pd.to_datetime(ds['date'].values)
    if len(dates) > 1:
        time_diff = dates[1] - dates[0]
    else:
        # A single forecast date has no spacing to measure; nothing to thin out.
        time_diff = pd.Timedelta('24h')
    if time_diff < pd.Timedelta('24h'):
        stride = int(pd.Timedelta('24h') / time_diff)
        ds_subset = ds.isel(date=slice(0, None, stride))
    else:
        ds_subset = ds

    # Determine number of days in forecast horizon (incomplete trailing days are dropped)
    n_steps = len(ds_subset['time_step'])
    n_days = n_steps // 24
    print(f"Generating interactive plot for {n_days} forecast days...")

    fig = go.Figure()
    steps = []

    # Loop through each day (0 to n_days-1) to generate traces
    for day_idx in range(n_days):
        # Select Day slice (e.g., 0-24, 24-48, etc.)
        start_step = day_idx * 24
        end_step = (day_idx + 1) * 24

        # Slice, Stack, Extract
        ds_day = ds_subset.isel(time_step=slice(start_step, end_step))
        ds_flat = ds_day.stack(combined=('date', 'time_step'))

        if 'samples' in ds_flat['discharge_vol_sim'].dims:
            qsim = ds_flat['discharge_vol_sim'].transpose('combined', 'samples').values
        else:
            qsim = ds_flat['discharge_vol_sim'].values

        qobs = ds_flat['discharge_vol_obs'].values

        # Reconstruct Dates: forecast date plus time_step interpreted as hours
        dates_flat = pd.to_datetime(ds_flat['date'].values) + pd.to_timedelta(ds_flat['time_step'].values, unit='h')

        # Sort chronologically
        sort_idx = np.argsort(dates_flat)
        dates_flat = dates_flat[sort_idx]
        qsim = qsim[sort_idx]
        qobs = qobs[sort_idx]

        # Handle NaNs: drop time stamps where every sample (or the single value) is NaN
        nan_mask = np.all(np.isnan(qsim), axis=-1) if qsim.ndim == 2 else np.isnan(qsim)
        if np.any(nan_mask):
            valid_mask = ~nan_mask
            dates_flat = dates_flat[valid_mask]
            qsim = qsim[valid_mask]
            qobs = qobs[valid_mask]

        # Calculate Percentiles across samples
        if qsim.ndim == 2:
            y_median = np.nanmedian(qsim, axis=-1)
            y_05 = np.nanpercentile(qsim, 5, axis=-1)
            y_95 = np.nanpercentile(qsim, 95, axis=-1)
            y_25 = np.nanpercentile(qsim, 25, axis=-1)
            y_75 = np.nanpercentile(qsim, 75, axis=-1)
        else:
            # Deterministic output: all bands collapse onto the single prediction.
            y_median = qsim
            y_05 = y_95 = y_25 = y_75 = qsim

        # Visibility: Only Day 1 (index 0) is visible initially
        is_visible = (day_idx == 0)

        # Add Traces (4 traces per day; the slider logic below relies on this count and order)

        # 1. 90% CI
        fig.add_trace(go.Scatter(
            x=np.concatenate([dates_flat, dates_flat[::-1]]),
            y=np.concatenate([y_95, y_05[::-1]]),
            fill='toself',
            fillcolor='rgba(53, 183, 121, 0.5)',
            line=dict(color='rgba(255,255,255,0)'),
            name='90% CI (5-95)',
            visible=is_visible,
            showlegend=True
        ))

        # 2. 50% CI
        fig.add_trace(go.Scatter(
            x=np.concatenate([dates_flat, dates_flat[::-1]]),
            y=np.concatenate([y_75, y_25[::-1]]),
            fill='toself',
            fillcolor='rgba(68, 1, 84, 0.5)',
            line=dict(color='rgba(255,255,255,0)'),
            name='50% CI (25-75)',
            visible=is_visible,
            showlegend=True
        ))

        # 3. Median
        fig.add_trace(go.Scatter(
            x=dates_flat,
            y=y_median,
            mode='lines',
            line=dict(color='red', width=2),
            name='Median',
            visible=is_visible,
            showlegend=True
        ))

        # 4. Observed
        fig.add_trace(go.Scatter(
            x=dates_flat,
            y=qobs,
            mode='lines',
            line=dict(color='black', width=2, dash='dash'),
            name='Observed',
            visible=is_visible,
            showlegend=True
        ))

    # Create Slider Steps
    # Total traces = n_days * 4
    for i in range(n_days):
        step = dict(
            method="update",
            args=[{"visible": [False] * (n_days * 4)},
                  {"title": f"Discharge Prediction - Day {i + 1} Ahead"}],
            label=str(i + 1)
        )
        # Enable the 4 traces for this day
        for j in range(4):
            step["args"][0]["visible"][i * 4 + j] = True
        steps.append(step)

    sliders = [dict(
        active=0,
        currentvalue={"prefix": "Forecast Day: "},
        pad={"t": 50},
        steps=steps
    )]

    fig.update_layout(
        sliders=sliders,
        title='Discharge Prediction - Day 1 Ahead',
        xaxis_title='Date',
        yaxis_title='Discharge [m³/s]',
        template='plotly_white',
        hovermode='x unified'
    )

    fig.show()
else:
    print("Please run the previous cells to load 'results' and define 'basin'.")

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
import numpy as np
from itertools import cycle

# --- Spaghetti Plot: All Forecast Traces (Median) ---
# Draws the median of every (daily-strided) forecast as its own line and overlays the
# observations stitched together from the first 24 steps of each forecast.

if 'results' in locals() and 'basin' in locals():
    # Always derive the daily subset from results[basin]. Reusing a `ds_subset` left in
    # the kernel by another cell would silently plot a stale basin after `basin` changes.
    ds = results[basin]['1h']['xr']
    dates = pd.to_datetime(ds['date'].values)
    if len(dates) > 1:
        time_diff = dates[1] - dates[0]
    else:
        # A single forecast date has no spacing to measure; nothing to thin out.
        time_diff = pd.Timedelta('24h')
    if time_diff < pd.Timedelta('24h'):
        stride = int(pd.Timedelta('24h') / time_diff)
        ds_subset = ds.isel(date=slice(0, None, stride))
    else:
        ds_subset = ds

    fig = go.Figure()

    # 1. Plot Individual Forecast Traces
    n_forecasts = len(ds_subset['date'])
    print(f"Plotting {n_forecasts} forecast traces...")

    # Use Plotly standard qualitative colors, cycling through them
    palette = cycle(px.colors.qualitative.Plotly)
    colors = [next(palette) for _ in range(n_forecasts)]

    # Pre-calculate time steps in hours
    time_steps_hours = pd.to_timedelta(ds_subset['time_step'].values, unit='h')

    for i, date in enumerate(ds_subset['date'].values):
        # Extract forecast for this date
        # Dimensions: (time_step, samples) -> we want median over samples
        forecast_slice = ds_subset['discharge_vol_sim'].isel(date=i)

        if 'samples' in forecast_slice.dims:
            # Calculate median across samples
            forecast_median = forecast_slice.median(dim='samples').values
        else:
            forecast_median = forecast_slice.values

        # Construct x-axis (Date + Time Steps)
        start_date = pd.to_datetime(date)
        forecast_dates = start_date + time_steps_hours

        # Handle NaNs: skip forecasts that contain no valid value at all
        if np.all(np.isnan(forecast_median)):
            continue

        # Add Trace with unique color from Plotly palette
        fig.add_trace(go.Scatter(
            x=forecast_dates,
            y=forecast_median,
            mode='lines',
            line=dict(width=1.5, color=colors[i]),
            opacity=0.8,
            showlegend=False,
            name=f'Forecast {start_date.strftime("%Y-%m-%d")}'
        ))

    # 2. Plot Observations (Ground Truth)
    # We reconstruct the observations from the first day (0-24h) of each forecast
    # to ensure we have the full continuous time series starting from the first forecast date.

    obs_list = []
    date_list = []

    for i in range(n_forecasts):
        # Take first 24 steps (Day 1) to stitch together the continuous timeline
        obs_slice = ds_subset['discharge_vol_obs'].isel(date=i, time_step=slice(0, 24)).values
        d_slice = pd.to_datetime(ds_subset['date'].values[i]) + time_steps_hours[:24]

        obs_list.append(obs_slice)
        date_list.append(d_slice)

    # np.concatenate raises on an empty list, so only draw observations when there are forecasts.
    if obs_list:
        flat_obs = np.concatenate(obs_list)
        flat_dates = np.concatenate(date_list)

        # Sort to be sure
        sort_idx = np.argsort(flat_dates)
        flat_dates = flat_dates[sort_idx]
        flat_obs = flat_obs[sort_idx]

        fig.add_trace(go.Scatter(
            x=flat_dates,
            y=flat_obs,
            mode='lines',
            line=dict(color='black', width=2),
            name='Observed'
        ))

    fig.update_layout(
        title='All Forecast Traces (Median) vs Observed',
        xaxis_title='Date',
        yaxis_title='Discharge [m³/s]',
        template='plotly_white',
        hovermode='x unified'
    )

    fig.show()
else:
    print("Please run the previous cells to load 'results' and define 'basin'.")