# WeatherFlow: Premier Flow-Matching Training Notebook

This notebook is the recommended, end-to-end workflow for training WeatherFlow flow-matching models on real ERA5 reanalysis data. It stitches together production-ready components from the repository—robust data loaders, physics-informed architectures, and research-grade trainers/solvers—without any mock data.

## What you'll accomplish
- Configure reproducible experiments for flow matching on ERA5 pressure-level data.
- Build training/validation data loaders that stream consecutive timesteps for vector-field supervision.
- Instantiate the **WeatherFlowMatch** model with physics-aware options and (optionally) spectral mixing.
- Train with **FlowTrainer** (AMP, EMA, gradient clipping, and physics regularization supported) and track metrics.
- Evaluate with the **WeatherODESolver** to roll out trajectories and compute science-relevant diagnostics.

## Prerequisites
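- ERA5 pressure-level NetCDF files (downloadable via [CDS](https://cds.climate.copernicus.eu)) stored locally. Set `config["data_root"]` to the folder that contains files like `era5_2016.nc`; by default it is taken from the `WEATHERFLOW_DATA` environment variable and falls back to `~/.cache/weatherflow/era5`.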
- Dependencies listed in `notebooks/notebook_requirements.txt` (notably `torch`, `xarray`, `torchdiffeq`, `matplotlib`, `tqdm`).
- GPU is recommended, but the notebook will automatically fall back to CPU.

In [None]:
# Optional: Colab setup (clone repo and install deps)
import os
import sys
import subprocess
from pathlib import Path

# Colab-only bootstrap: make sure a checkout exists under /content, switch
# into it, install the notebook requirements, and export WEATHERFLOW_REPO.
# Outside Colab this whole block is a no-op.
if "google.colab" in sys.modules:
    repo_path = Path('/content/weatherflow')
    if repo_path.exists():
        print('Repository already present at', repo_path)
    else:
        print('Cloning WeatherFlow repository for Colab...')
        # check=True: without a checkout nothing below can work, so fail loudly.
        subprocess.run([
            'git', 'clone', 'https://github.com/monksealseal/weatherflow.git', str(repo_path)
        ], check=True)
    os.chdir(repo_path)
    print('Installing notebook dependencies...')
    # Installation stays best-effort (check=False), but a failure is reported
    # instead of being discarded, so later ImportErrors are easy to trace.
    pip_result = subprocess.run([
        sys.executable, '-m', 'pip', 'install', '-r', 'notebooks/notebook_requirements.txt'
    ], check=False)
    if pip_result.returncode != 0:
        print(
            'WARNING: dependency installation exited with code '
            f'{pip_result.returncode}; later imports may fail.'
        )
    os.environ['WEATHERFLOW_REPO'] = str(repo_path)

In [None]:
# Environment & repository setup
from pathlib import Path
import os
import sys
import torch
import json
import warnings

# Ignore every UserWarning for the rest of the session to keep cell output
# compact. NOTE(review): this is a blanket filter, so UserWarnings raised by
# PyTorch or other libraries are hidden too — narrow it when debugging.
warnings.filterwarnings("ignore", category=UserWarning)


def resolve_repo_root() -> Path:
    """Locate the directory that should be treated as the repository root.

    Resolution order:

    1. The ``WEATHERFLOW_REPO`` environment variable, when set and non-empty.
    2. The current working directory, if it contains a ``weatherflow`` entry.
    3. The parent of the working directory, if that contains ``weatherflow``
       (e.g. when the kernel was started from inside ``notebooks/``).
    4. Otherwise the working directory itself.

    Returns:
        The resolved path chosen by the rules above.
    """
    override = os.environ.get("WEATHERFLOW_REPO")
    if override:
        return Path(override).resolve()

    here = Path.cwd().resolve()
    # First directory holding a `weatherflow` entry wins; cwd is checked first.
    for candidate in (here, here.parent):
        if (candidate / "weatherflow").exists():
            return candidate
    return here

# Put the resolved root at the front of the import path (once), so the
# `weatherflow` imports in later cells resolve against it.
REPO_ROOT = resolve_repo_root()
_root_entry = str(REPO_ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

# Prefer the GPU when CUDA is available; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Report basic environment info
print(f"Repository root: {REPO_ROOT}")
print(f"Using device: {device}")
print(f"Torch version: {torch.__version__}")

In [None]:
# Experiment configuration
from weatherflow.training.utils import set_global_seed

# Seed once, before any loader or model is created in the cells below.
# NOTE(review): presumably seeds the Python/NumPy/PyTorch RNGs — confirm in
# weatherflow.training.utils. The FlowTrainer cell passes seed=42 as well.
set_global_seed(42)

def default_data_root() -> str:
    """Return the default ERA5 data directory, creating it if necessary.

    The location is read from the ``WEATHERFLOW_DATA`` environment variable
    and falls back to ``~/.cache/weatherflow/era5``, which keeps raw data
    outside the repository. A leading ``~`` is expanded to the home directory.

    Returns:
        The directory path as a string. The directory (and any missing
        parents) exists when this function returns.
    """
    location = os.environ.get("WEATHERFLOW_DATA", "~/.cache/weatherflow/era5")
    data_dir = Path(location).expanduser()
    data_dir.mkdir(parents=True, exist_ok=True)
    return str(data_dir)

# The experiment configuration is assembled from four themed groups and then
# merged in declaration order, so the printed JSON keeps a stable key order.

# Data: what the ERA5 loaders read and how they batch it.
_data_settings = {
    "data_root": default_data_root(),
    "train_years": [2016, 2017],
    "val_years": [2018],
    "variables": [
        "geopotential",
        "temperature",
        "u_component_of_wind",
        "v_component_of_wind",
    ],
    "pressure_levels": [500, 700],
    "batch_size": 4,
    "num_workers": 2,
    "download_missing": False,  # Set to True to fetch data via CDSAPI
}

# Model: forwarded to the WeatherFlowMatch constructor.
_model_settings = {
    "hidden_dim": 192,
    "n_layers": 6,
    "use_attention": True,
    "spherical_padding": True,
    "use_graph_mp": False,
    "use_spectral_mixer": True,
    "spectral_modes": 12,
    "physics_informed": True,
    "enhanced_physics_losses": True,
}

# Training: optimizer hyper-parameters, epoch count, and FlowTrainer options.
_training_settings = {
    "learning_rate": 5e-4,
    "weight_decay": 1e-2,
    "epochs": 3,
    "loss_type": "mse",
    "loss_weighting": "time",
    "use_amp": True,
    "physics_regularization": True,
    "physics_lambda": 0.1,
    "grad_clip": 1.0,
    "ema_decay": 0.999,
}

# Checkpoints & logging: everything is written under <repo>/artifacts.
_artifacts_dir = REPO_ROOT / "artifacts"
_artifact_settings = {
    "checkpoint_dir": str(_artifacts_dir / "checkpoints"),
    "history_path": str(_artifacts_dir / "training_history.json"),
}

config = {
    **_data_settings,
    **_model_settings,
    **_training_settings,
    **_artifact_settings,
}

# Create the output folders up front so later cells can write without checks.
for _folder in (Path(config["checkpoint_dir"]), Path(config["history_path"]).parent):
    _folder.mkdir(parents=True, exist_ok=True)

print(json.dumps(config, indent=2))

## Build ERA5 training/validation loaders
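We load real ERA5 pressure-level data and pair each timestep with its successor to form `(x0, x1)` samples for vector-field supervision. The helper below wraps the repository's `ERA5Dataset`, so the repository's normalization logic is respected, and flattens the variable and pressure-level axes into a single channel axis so that channel ordering stays consistent with the model.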

In [None]:
from typing import Tuple
import torch
from torch.utils.data import Dataset, DataLoader
from weatherflow.data.era5 import ERA5Dataset

class ConsecutiveERA5Pairs(Dataset):
    """Pair each timestep of a base dataset with its immediate successor.

    Item ``i`` is ``(x0, x1)`` built from ``base[i]`` and ``base[i + 1]``,
    each with its two leading axes merged into a single channel axis. The
    wrapper therefore exposes ``len(base) - 1`` items.

    Args:
        base_dataset: Indexable dataset whose items are 4D tensors laid out
            as ``[variables, levels, lat, lon]``. The annotation is quoted so
            it is not evaluated when the class is defined.

    Raises:
        ValueError: If ``base_dataset`` holds fewer than two timesteps.
    """

    def __init__(self, base_dataset: "ERA5Dataset"):
        if len(base_dataset) < 2:
            raise ValueError("ERA5 dataset must contain at least two timesteps for pairing.")
        self.base = base_dataset

    @staticmethod
    def _flatten_channels(sample):
        """Convert ``[variables, levels, lat, lon]`` to ``[channels, lat, lon]``.

        Raises:
            ValueError: If ``sample`` is not 4-dimensional.
        """
        if sample.ndim != 4:
            raise ValueError(f"Expected 4D tensor, got {sample.shape}")
        vars_, levels, lat, lon = sample.shape
        # reshape() rather than view(): view() raises on non-contiguous input
        # (e.g. a permuted or sliced sample), while reshape() copies only when
        # it has to and is otherwise identical.
        return sample.reshape(vars_ * levels, lat, lon)

    def __len__(self) -> int:
        """Return the number of consecutive pairs (one fewer than timesteps)."""
        return len(self.base) - 1

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the flattened ``(x0, x1)`` pair starting at timestep ``idx``.

        Negative indices count from the end, as for sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_pairs = len(self)
        if idx < 0:
            idx += n_pairs
        # Without this check, index -1 would pair the last frame with the
        # first one, which is not a consecutive pair.
        if not 0 <= idx < n_pairs:
            raise IndexError(f"Pair index out of range for {n_pairs} pairs")
        x0 = self._flatten_channels(self.base[idx])
        x1 = self._flatten_channels(self.base[idx + 1])
        return x0, x1


def build_loaders(cfg):
    """Create training/validation loaders of consecutive ERA5 frame pairs.

    Args:
        cfg: Mapping that provides ``data_root``, ``train_years``,
            ``val_years``, ``variables``, ``pressure_levels``,
            ``download_missing``, ``batch_size`` and ``num_workers``.

    Returns:
        ``(train_loader, val_loader, data_shape)`` where ``data_shape`` is the
        ``torch.Size`` ``[batch, channels, lat, lon]`` of the first training
        batch.

    Raises:
        ValueError: Propagated from ``ConsecutiveERA5Pairs`` when a split
            holds fewer than two timesteps.
    """

    def _era5_for(years):
        # Both splits differ only in the years they cover.
        return ERA5Dataset(
            root_dir=cfg["data_root"],
            years=years,
            variables=cfg["variables"],
            levels=cfg["pressure_levels"],
            download=cfg["download_missing"],
        )

    train_raw = _era5_for(cfg["train_years"])
    val_raw = _era5_for(cfg["val_years"])

    train_ds = ConsecutiveERA5Pairs(train_raw)
    val_ds = ConsecutiveERA5Pairs(val_raw)

    shared_loader_options = {
        "batch_size": cfg["batch_size"],
        "num_workers": cfg["num_workers"],
        # Pinned host memory only speeds up host-to-GPU copies; skip it on
        # CPU-only machines, where requesting it merely triggers a warning.
        "pin_memory": torch.cuda.is_available(),
    }
    train_loader = DataLoader(train_ds, shuffle=True, **shared_loader_options)
    val_loader = DataLoader(val_ds, shuffle=False, **shared_loader_options)

    # Derive the batch shape from one dataset item instead of pulling a batch
    # through the loader: iterating the loader would start worker processes
    # and consume a shuffled batch just to read a shape. Without drop_last the
    # first batch holds min(batch_size, len(dataset)) samples.
    sample_x0, _ = train_ds[0]
    first_batch_size = min(cfg["batch_size"], len(train_ds))
    data_shape = torch.Size((first_batch_size, *sample_x0.shape))  # [batch, channels, lat, lon]

    return train_loader, val_loader, data_shape

In [None]:
# Build both loaders and keep the batch shape; the model cell below reads the
# channel count and grid size from it.
train_loader, val_loader, data_shape = build_loaders(config)
# len(loader) counts batches, not individual samples.
print(f"Train batches: {len(train_loader)}, Validation batches: {len(val_loader)}")
print(f"Batch shape: {data_shape}")

### (Optional) quick visual check
Use the `visualize` helper to denormalize and plot a slice of the raw ERA5 dataset.

In [None]:
# Uncomment to view a single field (uses real ERA5 data)
# _ = train_loader.dataset.base.visualize(idx=0)

## Initialize the WeatherFlowMatch model
We size the network directly from the data loader: the channel dimension equals `len(variables) * len(pressure_levels)`, and the grid size comes from the ERA5 latitude/longitude resolution. Physics-aware options (divergence regularization, enhanced physics losses, spectral mixing) are enabled to match our strongest research configurations.

In [None]:
from weatherflow.models.flow_matching import WeatherFlowMatch

# Channel count and (lat, lon) grid size come from the inspected batch shape,
# which is laid out as [batch, channels, lat, lon].
input_channels, *_grid_dims = (int(extent) for extent in data_shape[1:4])
grid_size = tuple(_grid_dims)

# Options whose constructor keyword matches the config key one-to-one.
_config_driven = (
    "hidden_dim",
    "n_layers",
    "use_attention",
    "physics_informed",
    "spherical_padding",
    "use_graph_mp",
    "use_spectral_mixer",
    "spectral_modes",
    "enhanced_physics_losses",
)
_model_options = {
    "input_channels": input_channels,
    "grid_size": grid_size,
    # Fixed for this notebook: no static or forcing inputs are supplied.
    "window_size": 8,
    "static_channels": 0,
    "forcing_dim": 0,
    **{option: config[option] for option in _config_driven},
}

model = WeatherFlowMatch(**_model_options).to(device)

num_params = sum(weight.numel() for weight in model.parameters())
print(f"Model parameters: {num_params/1e6:.2f}M")
print(f"Grid size: {grid_size}, Channels: {input_channels}")

## Trainer, optimizer, and scheduler
`FlowTrainer` brings together flow-matching losses, optional physics regularization, AMP, gradient clipping, EMA, and checkpointing in one place.

In [None]:
from weatherflow.training.flow_trainer import FlowTrainer

# AdamW with the learning rate and weight decay taken from the config.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"]
)

# Halve the learning rate once the monitored metric has stalled for more than
# two epochs. NOTE(review): the scheduler is handed to FlowTrainer below and is
# also stepped explicitly in the training-loop cell — confirm FlowTrainer does
# not step it a second time.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

# FlowTrainer options whose keyword matches the config key one-to-one.
_trainer_options = {
    option: config[option]
    for option in (
        "use_amp",
        "checkpoint_dir",
        "physics_regularization",
        "physics_lambda",
        "loss_type",
        "loss_weighting",
        "grad_clip",
        "ema_decay",
    )
}

trainer = FlowTrainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    use_wandb=False,
    seed=42,
    **_trainer_options,
)

## Training loop
This loop keeps a compact record of training/validation metrics, steps the scheduler, and writes checkpoints after each epoch.

In [None]:
from pathlib import Path

def _json_default(value):
    """Encode metric values that ``json`` cannot serialize natively.

    Tensors and NumPy scalars/arrays expose ``tolist()``, which yields plain
    Python numbers or lists; anything else is rejected as ``json`` would.
    """
    to_list = getattr(value, "tolist", None)
    if callable(to_list):
        return to_list()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _save_history(records):
    """Write the per-epoch records to ``config["history_path"]`` as JSON."""
    with open(config["history_path"], "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, default=_json_default)


history = []

for epoch in range(config["epochs"]):
    epoch_number = epoch + 1  # trainer, records and checkpoints are 1-based
    trainer.current_epoch = epoch_number

    train_metrics = trainer.train_epoch(train_loader)
    val_metrics = trainer.validate(val_loader)

    # Scheduler step (ReduceLROnPlateau expects a validation metric)
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        scheduler.step(val_metrics["val_loss"])

    epoch_record = {
        "epoch": epoch_number,
        **train_metrics,
        **val_metrics,
        # Read after the scheduler step, so this is the rate for the next epoch.
        "lr": optimizer.param_groups[0]["lr"],
    }
    history.append(epoch_record)

    checkpoint_name = f"epoch_{epoch_number}.pt"
    trainer.save_checkpoint(checkpoint_name)

    # Persist after every epoch so an interrupted run keeps its history
    # alongside the checkpoints already written.
    _save_history(history)

    print(
        f"Epoch {epoch + 1}: train_loss={train_metrics['loss']:.4f}, "
        f"val_loss={val_metrics['val_loss']:.4f}, lr={epoch_record['lr']:.2e}"
    )

# Persist history for downstream analysis (also covers a zero-epoch run)
_save_history(history)

history

## Loss curves
Plotting training and validation losses helps spot divergence or underfitting.

In [None]:
import matplotlib.pyplot as plt

# Pull the two loss series out of the per-epoch records.
train_losses = [record["loss"] for record in history]
val_losses = [record["val_loss"] for record in history]

# One axes object, one line per series, with epochs numbered from 1.
fig, ax = plt.subplots(figsize=(8, 4))
for _series, _label in ((train_losses, "Train Loss"), (val_losses, "Val Loss")):
    ax.plot(range(1, len(_series) + 1), _series, label=_label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Flow-matching loss")
ax.set_title("Training vs Validation Loss")
ax.legend()
ax.grid(True)
plt.show()

## Roll out a trajectory with the Weather ODE Solver
Use the trained vector field to integrate an initial state forward in time while enforcing physics-aware constraints. Metrics compare the final rollout against the true next-step target.

In [None]:
from weatherflow.solvers.ode_solver import WeatherODESolver
from weatherflow.training.metrics import rmse, mae, energy_ratio

# Switch to inference mode before integrating.
model.eval()

solver = WeatherODESolver(
    rtol=1e-5,
    atol=1e-5,
    physics_constraints=True,
    constraint_types=["mass", "energy", "vorticity"],
)

# Take a single validation batch and move both frames to the compute device
val_batch = next(iter(val_loader))
x0, x1 = (tensor.to(device) for tensor in val_batch)

# Time grid from 0 -> 1 (can be aligned with your lead time convention)
t_span = torch.linspace(0, 1, steps=5, device=device)


def velocity_fn(x, t_scalar):
    # t_scalar from solver is a scalar tensor; expand to batch
    # NOTE(review): assumes the solver calls this as (state, time) — confirm
    # against WeatherODESolver.solve.
    t_batch = torch.full((x.shape[0],), float(t_scalar), device=device)
    return model(x, t_batch)


# Gradients are not needed for evaluation; disable them for the whole solve.
with torch.no_grad():
    rollout, solver_stats = solver.solve(velocity_fn, x0, t_span)

# The last entry of the rollout is the state at the final time of t_span.
pred_future = rollout[-1]

rmse_value = rmse(pred_future, x1).item()
mae_value = mae(pred_future, x1).item()
energy_value = energy_ratio(pred_future, x1).item()

print("Solver stats:", solver_stats)
print(
    f"RMSE vs target: {rmse_value:.4f}, "
    f"MAE: {mae_value:.4f}, "
    f"Energy ratio: {energy_value:.4f}"
)

## Next steps
- Increase `epochs`, adjust `learning_rate`, or expand `train_years` for larger runs.
- Toggle `use_spectral_mixer` / `spectral_modes` and `enhanced_physics_losses` to explore ablations.
- Integrate with [Weights & Biases](https://wandb.ai/) by setting `use_wandb=True` in the `FlowTrainer` constructor for richer experiment tracking.
- Export checkpoints from `artifacts/checkpoints/` for downstream evaluation notebooks.