# ERA5 Flow Matching Pipeline End-to-End

Welcome to the all-in-one WeatherFlow workflow! This notebook shows how to go from ERA5-style atmospheric data to a trained flow matching model, generate forecasts with an ODE solver, and visualise the resulting weather fields. Each stage feeds directly into the next, so by the end you can see how the pieces fit together.


## Notebook roadmap

We'll work through the complete lifecycle:

1. Configure the experiment for your set of ERA5 variables and pressure levels.
2. Load data from either WeatherBench2/ERA5 or a fast synthetic fallback that mimics the structure of the real dataset.
3. Build a `WeatherFlowMatch` vector field model and review its architecture.
4. Train the model with `FlowTrainer`, monitoring flow matching and physics-aware losses.
5. Wrap the trained model with `WeatherFlowODE` for inference at arbitrary lead times.
6. Visualise true vs. predicted fields using the rich `WeatherVisualizer` toolkit.

Each section includes compact code you can re-use in scripts or other notebooks.


In [None]:
# Add the repository root to the Python path so imports resolve correctly.
# Assumes this notebook lives one directory below the repository root.
import sys
from pathlib import Path

NOTEBOOK_DIR = Path.cwd()
REPO_ROOT = (NOTEBOOK_DIR / '..').resolve()
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
print(f'Repository root: {REPO_ROOT}')

try:
    import weatherflow
    print(f'WeatherFlow version: {weatherflow.__version__}')
except Exception as exc:
    print(f'WeatherFlow import warning: {exc}')


In [None]:
# Install lightweight mocks for optional dependencies (cartopy, torchdiffeq) when missing
try:
    from mock_dependencies import install_all_mocks
    install_all_mocks()
except Exception as exc:
    print(f'Mock dependency installation skipped: {exc}')


In [None]:
# Core imports for the workflow
import json
import math
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

try:
    from weatherflow.data import ERA5Dataset, create_data_loaders
except Exception as exc:
    print(f'ERA5 data utilities unavailable: {exc}')
    ERA5Dataset = None
    create_data_loaders = None

from weatherflow.models import WeatherFlowMatch, WeatherFlowODE
from weatherflow.training import FlowTrainer

try:
    from weatherflow.utils import WeatherVisualizer
    HAS_VISUALIZER = True
except Exception as exc:
    print(f'WeatherVisualizer import warning: {exc}')
    WeatherVisualizer = None
    HAS_VISUALIZER = False

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch_device = torch.device(device)
print(f'Using device: {device}')

torch.manual_seed(42)
np.random.seed(42)


## 1. Configure the experiment

Customise the configuration dictionary below to match the ERA5 variables, pressure levels, and training schedule you want to explore. The synthetic fallback shares the same interface so you can switch to real data without touching the rest of the notebook.


In [None]:
config = {
    'variables': ['z', 't', 'u', 'v'],
    'pressure_levels': [500, 300],
    'resolution': (32, 64),
    'train_samples': 96,
    'val_samples': 24,
    'batch_size': 8,
    'hidden_dim': 160,
    'n_layers': 4,
    'use_attention': True,
    'physics_informed': True,
    'learning_rate': 3e-4,
    'epochs': 4,
    'loss_type': 'mse',
    'use_synthetic_data': True,  # Flip to False to connect with actual ERA5/WeatherBench2 data
    'era5_data_path': None,       # Optional local path or GCS URL when use_synthetic_data=False
    'train_years': ('2016', '2016'),
    'val_years': ('2017', '2017'),
    'solver_method': 'dopri5',
    'rtol': 1e-4,
    'atol': 1e-4,
    'visualize_variable': 'z'
}
config
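
A few quantities follow directly from this configuration, and checking them up front can catch shape mismatches before training starts. The sketch below works on a trimmed copy of the dictionary; `derive_shapes` and `example_cfg` are hypothetical names used for illustration, not part of WeatherFlow.

```python
import math

def derive_shapes(cfg):
    """Hypothetical helper: derive tensor and loader sizes from the config."""
    channels = len(cfg['variables']) * len(cfg['pressure_levels'])
    lat, lon = cfg['resolution']
    return {
        'channels': channels,
        'batch_shape': (cfg['batch_size'], channels, lat, lon),
        # DataLoader keeps a final partial batch by default, hence the ceil
        'train_batches_per_epoch': math.ceil(cfg['train_samples'] / cfg['batch_size']),
    }

example_cfg = {
    'variables': ['z', 't', 'u', 'v'],
    'pressure_levels': [500, 300],
    'resolution': (32, 64),
    'train_samples': 96,
    'batch_size': 8,
}
shapes = derive_shapes(example_cfg)
print(shapes)
```

With the defaults above this gives 8 channels, a batch shape of `(8, 8, 32, 64)`, and 12 training batches per epoch.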


## 2. Load ERA5-style data

To make the notebook runnable anywhere, we provide a physically inspired synthetic dataset that mirrors the shape and metadata of ERA5. When you have direct access to the WeatherBench2 ERA5 store (or a local subset), set `use_synthetic_data=False`; the loader cell below then wraps the real dataset into the channel-first tensors the models expect.


In [None]:
class SyntheticERA5Dataset(Dataset):
    """Smooth, repeatable toy dataset that mimics ERA5 tensors."""

    def __init__(self, num_steps, variables, pressure_levels, resolution=(32, 64), seed=0):
        self.num_steps = int(num_steps)
        self.variables = list(variables)
        self.pressure_levels = list(pressure_levels)
        self.resolution = tuple(resolution)
        self.channels = len(self.variables) * len(self.pressure_levels)
        self.generator = torch.Generator().manual_seed(int(seed))

        lat_radians = torch.linspace(-math.pi / 2, math.pi / 2, self.resolution[0])
        lon_radians = torch.linspace(0.0, 2 * math.pi, self.resolution[1])
        self.lat_grid, self.lon_grid = torch.meshgrid(lat_radians, lon_radians, indexing='ij')
        self.latitudes = torch.linspace(-90.0, 90.0, self.resolution[0])
        self.longitudes = torch.linspace(0.0, 360.0, self.resolution[1])

        self._states = self._generate_sequence()

    def _base_wave(self, var_idx, level_idx, phase):
        lat_component = torch.sin(self.lat_grid * (level_idx + 1))
        lon_component = torch.cos(self.lon_grid * (var_idx + 1))
        standing_wave = torch.sin(self.lon_grid * 0.5 + phase) * torch.cos(self.lat_grid + 0.5 * phase * (level_idx + 1))
        rotational = torch.sin(self.lat_grid * 0.3 + phase) * torch.sin(self.lon_grid * 0.7 - phase)
        amplitude = 1.0 + 0.25 * level_idx + 0.35 * var_idx
        return amplitude * (lat_component + lon_component) + 0.6 * standing_wave + 0.1 * rotational

    def _generate_state(self, step):
        phase = 2 * math.pi * step / max(self.num_steps, 1)
        channels = []
        for v_idx, _ in enumerate(self.variables):
            for l_idx, _ in enumerate(self.pressure_levels):
                wave = self._base_wave(v_idx, l_idx, phase)
                drift = torch.sin(self.lon_grid + phase) * torch.cos(self.lat_grid - 0.5 * phase)
                noise = 0.02 * torch.randn(self.resolution, generator=self.generator)
                channels.append(wave + 0.05 * step * drift / max(self.num_steps, 1) + noise)
        return torch.stack(channels).float()

    def _generate_sequence(self):
        states = []
        for step in range(self.num_steps + 1):
            states.append(self._generate_state(step))
        return torch.stack(states)

    def __len__(self):
        return self.num_steps

    def __getitem__(self, idx):
        input_state = self._states[idx]
        target_state = self._states[idx + 1]
        metadata = {
            't0_index': idx,
            't1_index': idx + 1,
            'variables': self.variables,
            'pressure_levels': self.pressure_levels
        }
        return {'input': input_state, 'target': target_state, 'metadata': metadata}

    def get_coords(self):
        return {
            'latitude': self.latitudes.numpy(),
            'longitude': self.longitudes.numpy()
        }
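
The class above generates `num_steps + 1` states so that every index has a successor: sample `idx` pairs state `idx` with state `idx + 1`. A minimal sketch of that pairing, with strings standing in for tensors (`consecutive_pairs` is a hypothetical helper):

```python
def consecutive_pairs(states):
    """Return (input, target) pairs from a list of sequential states."""
    return [(states[i], states[i + 1]) for i in range(len(states) - 1)]

num_steps = 4
states = [f'state_{step}' for step in range(num_steps + 1)]  # N + 1 states
pairs = consecutive_pairs(states)

print(len(pairs))   # 4
print(pairs[-1])    # ('state_3', 'state_4')
```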


In [None]:
def channel_first_wrapper(base_dataset, variables, pressure_levels):
    """Flatten variable/level dimensions into channels for FlowTrainer."""

    class _WrappedDataset(Dataset):
        def __init__(self, dataset):
            self.dataset = dataset
            self.variables = list(variables)
            self.pressure_levels = list(pressure_levels)
            self.channels = len(self.variables) * len(self.pressure_levels)
            coords = dataset.get_coords() if hasattr(dataset, 'get_coords') else None
            if coords:
                self.latitudes = torch.tensor(coords['latitude'], dtype=torch.float32)
                self.longitudes = torch.tensor(coords['longitude'], dtype=torch.float32)
            else:
                spatial_shape = getattr(dataset, 'shape', (len(self.variables), len(self.pressure_levels), 32, 64))
                self.latitudes = torch.linspace(-90.0, 90.0, spatial_shape[-2])
                self.longitudes = torch.linspace(0.0, 360.0, spatial_shape[-1])

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            sample = self.dataset[idx]
            x0 = sample['input']
            x1 = sample['target']
            # reshape (rather than view) also handles non-contiguous inputs
            if x0.ndim == 4:
                x0 = x0.reshape(-1, *x0.shape[-2:])
            if x1.ndim == 4:
                x1 = x1.reshape(-1, *x1.shape[-2:])
            metadata = dict(sample.get('metadata', {}))
            metadata.setdefault('variables', self.variables)
            metadata.setdefault('pressure_levels', self.pressure_levels)
            return {'input': x0.float(), 'target': x1.float(), 'metadata': metadata}

        def get_coords(self):
            return {
                'latitude': self.latitudes.numpy(),
                'longitude': self.longitudes.numpy()
            }

    return _WrappedDataset(base_dataset)


In [None]:
def collate_samples(batch):
    """Stack inputs/targets into batch tensors; gather metadata as per-key lists."""
    inputs = torch.stack([item['input'] for item in batch])
    targets = torch.stack([item['target'] for item in batch])
    metadata = {}
    for key in batch[0]['metadata']:
        metadata[key] = [item['metadata'][key] for item in batch]
    return {'input': inputs, 'target': targets, 'metadata': metadata}

if config['use_synthetic_data']:
    train_dataset = SyntheticERA5Dataset(
        num_steps=config['train_samples'],
        variables=config['variables'],
        pressure_levels=config['pressure_levels'],
        resolution=config['resolution'],
        seed=42
    )
    val_dataset = SyntheticERA5Dataset(
        num_steps=config['val_samples'],
        variables=config['variables'],
        pressure_levels=config['pressure_levels'],
        resolution=config['resolution'],
        seed=123
    )
else:
    if ERA5Dataset is None:
        raise RuntimeError('ERA5Dataset is unavailable. Install required dependencies or use synthetic data.')
    train_base = ERA5Dataset(
        variables=config['variables'],
        pressure_levels=config['pressure_levels'],
        data_path=config['era5_data_path'],
        time_slice=config['train_years'],
        normalize=True,
        cache_data=False,
        verbose=True
    )
    val_base = ERA5Dataset(
        variables=config['variables'],
        pressure_levels=config['pressure_levels'],
        data_path=config['era5_data_path'],
        time_slice=config['val_years'],
        normalize=True,
        cache_data=False,
        verbose=True
    )
    train_dataset = channel_first_wrapper(train_base, config['variables'], config['pressure_levels'])
    val_dataset = channel_first_wrapper(val_base, config['variables'], config['pressure_levels'])

train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
    collate_fn=collate_samples
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
    collate_fn=collate_samples
)

coords = train_dataset.get_coords() if hasattr(train_dataset, 'get_coords') else {'latitude': train_dataset.latitudes.numpy(), 'longitude': train_dataset.longitudes.numpy()}
latitudes = torch.tensor(coords['latitude'], dtype=torch.float32)
longitudes = torch.tensor(coords['longitude'], dtype=torch.float32)

sample_batch = next(iter(train_loader))
input_channels = sample_batch["input"].shape[1]
grid_shape = sample_batch["input"].shape[2:]

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
print(f"Batch tensor shape: {sample_batch['input'].shape}")
print(f'Channels feeding the model: {input_channels}')
print(f'Spatial grid: {grid_shape}')


The wrapped loaders now emit tensors with shape `[batch, channels, lat, lon]`, exactly what the convolutional backbone expects. Metadata (such as the time indices and pressure levels) is preserved for analysis and visualisation.
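
The flattened channel axis is variable-major: the synthetic dataset loops over variables in the outer loop and pressure levels in the inner loop, and `unflatten_channels` reverses that with a `(variables, levels)` reshape. For real ERA5 data this ordering assumes the base dataset returns `[variable, level, lat, lon]` tensors. A minimal sketch of the index arithmetic, with hypothetical helper names:

```python
def channel_index(var_idx, level_idx, n_levels):
    """Flat channel position for a (variable, level) pair, variable-major."""
    return var_idx * n_levels + level_idx

def channel_label(channel, variables, pressure_levels):
    """Inverse mapping: recover (variable, level) from a flat channel index."""
    var_idx, level_idx = divmod(channel, len(pressure_levels))
    return variables[var_idx], pressure_levels[level_idx]

variables = ['z', 't', 'u', 'v']
pressure_levels = [500, 300]

print(channel_index(2, 1, len(pressure_levels)))      # 5
print(channel_label(5, variables, pressure_levels))   # ('u', 300)
```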


In [None]:
def unflatten_channels(tensor, variables, pressure_levels):
    """Reshape flattened channels back into (variables, levels) axes."""
    vars_count = len(variables)
    levels_count = len(pressure_levels)
    if tensor.ndim not in (3, 4):
        raise ValueError('Expected a 3D or 4D tensor.')
    if tensor.shape[-3] != vars_count * levels_count:
        raise ValueError('Channel count does not match variables x pressure levels.')
    lead = tensor.shape[:-3]  # () for a single sample, (batch,) for a batch
    lat, lon = tensor.shape[-2:]
    return tensor.reshape(*lead, vars_count, levels_count, lat, lon)

def describe_batch(batch):
    print('Batch keys:', list(batch.keys()))
    print('Input shape:', batch['input'].shape)
    print('Target shape:', batch['target'].shape)
    print('Metadata example:', {k: v[0] for k, v in batch['metadata'].items()})

describe_batch(sample_batch)


## 3. Build the flow matching model

`WeatherFlowMatch` provides a spatio-temporal vector field with optional attention and physics-informed divergence control. We configure it with the number of input channels discovered above.


In [None]:
model = WeatherFlowMatch(
    input_channels=input_channels,
    hidden_dim=config['hidden_dim'],
    n_layers=config['n_layers'],
    use_attention=config['use_attention'],
    physics_informed=config['physics_informed']
)
model = model.to(device)
total_params = sum(p.numel() for p in model.parameters())
print(model)
print(f'Total parameters: {total_params:,}')


## 4. Train with flow matching objectives

`FlowTrainer` handles the stochastic time sampling, automatic mixed precision (disabled here for clarity), and optional physics regularisation. The loop below records training and validation losses so you can inspect convergence.
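
For intuition, one common form of the flow matching objective interpolates linearly between the input and target states and regresses the model output onto the constant velocity `x1 - x0`. The NumPy sketch below illustrates that idea only. It is not `FlowTrainer`'s implementation, which may sample times, weight the loss, or add physics terms differently. `flow_matching_loss`, `oracle_field`, and `zero_field` are hypothetical names.

```python
import numpy as np

def flow_matching_loss(vector_field, x0, x1, rng):
    """MSE between a predicted velocity and the straight-line target x1 - x0."""
    batch = x0.shape[0]
    # One random time in [0, 1) per sample, broadcast over channel/lat/lon
    t = rng.uniform(0.0, 1.0, size=(batch, 1, 1, 1))
    x_t = (1.0 - t) * x0 + t * x1      # point on the straight path
    target_velocity = x1 - x0          # d(x_t)/dt for the linear path
    predicted = vector_field(x_t, t)
    return float(np.mean((predicted - target_velocity) ** 2))

rng = np.random.default_rng(0)
x0 = rng.normal(size=(2, 8, 4, 8))     # [batch, channels, lat, lon]
x1 = rng.normal(size=(2, 8, 4, 8))

def oracle_field(x_t, t):
    # Toy field that already knows the answer; a trained model approximates it
    return x1 - x0

def zero_field(x_t, t):
    # Untrained stand-in that predicts no motion at all
    return np.zeros_like(x_t)

oracle_loss = flow_matching_loss(oracle_field, x0, x1, rng)
zero_loss = flow_matching_loss(zero_field, x0, x1, rng)
print(oracle_loss, zero_loss)
```

The oracle field reaches a loss of zero, while the zero field is penalised by the mean squared displacement between the two states.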


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])
trainer = FlowTrainer(
    model=model,
    optimizer=optimizer,
    device=device,
    use_amp=False,
    physics_regularization=config['physics_informed'],
    physics_lambda=0.1,
    loss_type=config['loss_type']
)

history = {
    'train_loss': [],
    'train_flow_loss': [],
    'val_loss': [],
    'val_flow_loss': []
}

for epoch in range(1, config['epochs'] + 1):
    print(f"\nEpoch {epoch}/{config['epochs']}")
    train_metrics = trainer.train_epoch(train_loader)
    val_metrics = trainer.validate(val_loader)

    history["train_loss"].append(train_metrics["loss"])
    history["train_flow_loss"].append(train_metrics.get("flow_loss", train_metrics["loss"]))
    history["val_loss"].append(val_metrics["loss"])
    history["val_flow_loss"].append(val_metrics.get("flow_loss", val_metrics["loss"]))

    print(f"Train loss: {train_metrics['loss']:.4f} | Flow: {train_metrics.get('flow_loss', train_metrics['loss']):.4f}")
    print(f"Val loss:   {val_metrics['loss']:.4f} | Flow: {val_metrics.get('flow_loss', val_metrics['loss']):.4f}")


In [None]:
plt.figure(figsize=(10, 5))
plt.plot(history['train_flow_loss'], label='Train flow loss')
plt.plot(history['val_flow_loss'], label='Validation flow loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Flow matching convergence')
plt.grid(True, linestyle='--', alpha=0.4)
plt.legend()
plt.show()


## 5. Run inference with `WeatherFlowODE`

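Once the vector field is trained, the differentiable ODE solver can evaluate forecasts at any intermediate time. We'll request five evenly spaced times from the initial state (t=0) to the target state (t=1). Here t=1 corresponds to one dataset time step, which is six hours for six-hourly ERA5 data.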
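
Conceptually, inference integrates dx/dt = v(x, t) from t=0 to t=1 and reads off the state at each requested time. The notebook delegates this to `WeatherFlowODE` with the adaptive `dopri5` solver. The fixed-step Euler loop below is a simplified stand-in that shows the mechanics on a toy constant velocity field; `euler_trajectory` and `constant_field` are hypothetical names.

```python
import numpy as np

def euler_trajectory(vector_field, x0, times, substeps=10):
    """Integrate dx/dt = vector_field(x, t) with fixed-step Euler.

    Returns an array of shape [len(times), *x0.shape]. The first entry
    equals x0, which assumes times[0] is the start time.
    """
    x = x0.astype(float)
    outputs = [x.copy()]
    for t_start, t_end in zip(times[:-1], times[1:]):
        dt = (t_end - t_start) / substeps
        for k in range(substeps):
            x = x + dt * vector_field(x, t_start + k * dt)
        outputs.append(x.copy())
    return np.stack(outputs)

x_start = np.zeros((1, 2, 3, 4))       # [batch, channels, lat, lon]
x_end = np.ones((1, 2, 3, 4))

def constant_field(x, t):
    # Straight-line flow that carries x_start to x_end over unit time
    return x_end - x_start

times = np.linspace(0.0, 1.0, 5)       # 0, 0.25, 0.5, 0.75, 1
trajectory = euler_trajectory(constant_field, x_start, times)
print(trajectory.shape)                # (5, 1, 2, 3, 4)
print(np.allclose(trajectory[-1], x_end))
```

The leading axis of the result indexes the requested times, which is why the notebook takes `trajectory[-1]` as the forecast at t=1.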


In [None]:
trainer.model.eval()
ode_model = WeatherFlowODE(
    flow_model=trainer.model,
    solver_method=config['solver_method'],
    rtol=config['rtol'],
    atol=config['atol']
)

val_iter = iter(val_loader)
batch_for_eval = next(val_iter)
x0 = batch_for_eval['input'].to(device)
x1 = batch_for_eval['target']
lead_times = torch.linspace(0.0, 1.0, steps=5, device=torch_device)

with torch.no_grad():
    trajectory = ode_model(x0, lead_times)

forecast = trajectory[-1].cpu()
truth = x1
mse_per_sample = torch.mean((forecast - truth) ** 2, dim=(1, 2, 3))

print(f'Trajectory shape: {trajectory.shape}')
print(f'MSE per sample: {mse_per_sample.numpy()}')
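
A single MSE per sample hides which variables drive the error. Reducing over the batch and spatial axes instead gives one score per channel, which can be reshaped into a variables-by-levels table under the variable-major channel order used in this notebook. A NumPy sketch with stand-in arrays follows; `per_channel_rmse` and the `*_demo` arrays are hypothetical.

```python
import numpy as np

def per_channel_rmse(forecast, truth):
    """RMSE over batch, lat, and lon for each channel of [B, C, H, W] arrays."""
    return np.sqrt(np.mean((forecast - truth) ** 2, axis=(0, 2, 3)))

truth_demo = np.zeros((2, 8, 4, 8))
forecast_demo = np.zeros((2, 8, 4, 8))
forecast_demo[:, 3] = 0.5      # inject a uniform error into channel 3 only

rmse = per_channel_rmse(forecast_demo, truth_demo)
table = rmse.reshape(4, 2)     # rows: variables, columns: pressure levels
print(table)
```

With the notebook's CPU tensors, the same function can be called as `per_channel_rmse(forecast.numpy(), truth.numpy())`.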


## 6. Visualise predictions versus truth

With the forecasts in hand, we can inspect individual variables and pressure levels. The `WeatherVisualizer` offers map projections and difference plots; when it isn't available, or when its plotting call fails, the notebook falls back to standard Matplotlib heatmaps.


In [None]:
var_name = config['visualize_variable']
var_index = config['variables'].index(var_name)
level_index = 0
truth_fields = unflatten_channels(truth[0], config['variables'], config['pressure_levels'])
forecast_fields = unflatten_channels(forecast[0], config['variables'], config['pressure_levels'])
true_slice = truth_fields[var_index, level_index]
pred_slice = forecast_fields[var_index, level_index]

if HAS_VISUALIZER and WeatherVisualizer is not None:
    visualizer = WeatherVisualizer(figsize=(16, 6))
    try:
        fig, axes = visualizer.plot_comparison(
            true_data={var_name: true_slice},
            pred_data={var_name: pred_slice},
            var_name=var_name,
            level_idx=0,
            title=f"{var_name} comparison at {config['pressure_levels'][level_index]} hPa"
        )
    except Exception as exc:
        print(f'WeatherVisualizer plotting failed ({exc}). Falling back to Matplotlib.')
        HAS_VISUALIZER = False

if not HAS_VISUALIZER or WeatherVisualizer is None:
    lon_vals = longitudes.numpy() if isinstance(longitudes, torch.Tensor) else np.asarray(longitudes)
    lat_vals = latitudes.numpy() if isinstance(latitudes, torch.Tensor) else np.asarray(latitudes)
    extent = [lon_vals.min(), lon_vals.max(), lat_vals.min(), lat_vals.max()]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    im0 = axes[0].imshow(true_slice.numpy(), extent=extent, origin='lower', aspect='auto', cmap='viridis')
    axes[0].set_title('Truth')
    fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

    im1 = axes[1].imshow(pred_slice.numpy(), extent=extent, origin='lower', aspect='auto', cmap='viridis')
    axes[1].set_title('Prediction')
    fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04)

    diff = (pred_slice - true_slice).numpy()
    im2 = axes[2].imshow(diff, extent=extent, origin='lower', aspect='auto', cmap='RdBu_r')
    axes[2].set_title('Difference')
    fig.colorbar(im2, ax=axes[2], fraction=0.046, pad=0.04)

    for ax in axes:
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')

    plt.suptitle(f"{var_name} comparison at {config['pressure_levels'][level_index]} hPa")
    plt.tight_layout()
    plt.show()


## 7. Where to go next

* Swap in the real ERA5 dataset by setting `use_synthetic_data=False` and pointing `era5_data_path` to a WeatherBench2 store or local Zarr.
* Increase the resolution or add humidity variables to stress-test the architecture.
* Integrate the training history with Weights & Biases or TensorBoard for richer monitoring.
* Extend the visual analysis with the WeatherFlow dashboard (`frontend/`) or the animation utilities in `weatherflow.utils.visualization`.

Happy forecasting!
