# Hands-on: Training the AIFS-ENS with Anemoi

In this tutorial we will learn how to train the AIFS-ENS (ensemble) model with the Anemoi packages. We focus on CRPS-based training (CRPS: Continuous Ranked Probability Score), an approach designed specifically for ensemble weather forecasting.

**Learning Objectives**

By the end of this tutorial, you will:
- Understand the key differences between deterministic training and CRPS-based ensemble training
- Learn how to configure the anemoi training pipeline for ensemble models
- Build a minimal training configuration step-by-step
- Execute a short training run to verify everything works


**Resources**

- [Anemoi documentation: CRPS-based training](https://anemoi.readthedocs.io/projects/training/en/latest/user-guide/kcrps-set-up.html)
- [Anemoi Documentation](https://anemoi.readthedocs.io/projects/training/en/latest/)
- [Lang et al. 2024](http://arxiv.org/abs/2412.15832)

## Background: What is CRPS Training?

The **Continuous Ranked Probability Score (CRPS)** is a proper scoring rule for evaluating probabilistic forecasts. In the context of ensemble weather forecasting, CRPS training allows us to train models that produce multiple ensemble members, each representing a different possible future state of the atmosphere.

<img src="_resources/aifs-crps_sketch.png" alt="CRPS Sketch" width="900">

### Why Ensemble Training?

- **Uncertainty Quantification**: Each ensemble member represents a different possible future
- **Probabilistic Forecasting**: Provides uncertainty estimates alongside predictions
- **Better Skill Scores**: An ensemble often scores better than a single deterministic forecast on skill metrics
- **Operational Use**: Essential for weather services that need to communicate forecast uncertainty




## CRPS Training in Anemoi

### Key Differences: Deterministic vs CRPS Training

Switching from deterministic training to CRPS-based ensemble training requires swapping the following components of the training pipeline:

| Component | Deterministic | CRPS |
|-----------|---------------|------|
| **Forecaster** | `GraphForecaster` | `GraphEnsForecaster` |
| **Strategy** | `DDPGroupStrategy` | `DDPEnsGroupStrategy` |
| **Training Loss** | `WeightedMSELoss` | `AlmostFairKernelCRPS` |
| **Model** | `AnemoiModelEncProcDec` | `AnemoiEnsModelEncProcDec` |
| **Datamodule** | `AnemoiDatasetsDataModule` | `AnemoiEnsDatasetsDataModule` |


#### The AlmostFairKernelCRPS Loss

The training uses the `AlmostFairKernelCRPS` loss function, which blends the standard ensemble CRPS with its "fair" variant (fCRPS):

$$\text{afCRPS}_\alpha := \alpha\text{fCRPS} + (1-\alpha)\text{CRPS}$$

where $\alpha$ is a trade-off parameter between the two terms: $\alpha = 1$ gives the pure fair CRPS and $\alpha = 0$ gives the standard CRPS.
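To make the blend concrete, here is a minimal pure-Python sketch for a single scalar observation. It uses the standard ensemble estimators of the CRPS and the fair CRPS, which differ only in how the member-spread term is normalised ($M^2$ versus $M(M-1)$ for $M$ members). The function names `ensemble_crps` and `almost_fair_crps` are illustrative and are not taken from Anemoi; the actual `AlmostFairKernelCRPS` implementation may differ in detail.

```python
from itertools import product


def ensemble_crps(members, obs, fair=False):
    """CRPS estimate for one scalar observation and a list of ensemble members."""
    m = len(members)
    if fair and m < 2:
        raise ValueError("the fair CRPS needs at least two ensemble members")
    # Mean absolute error of the members against the observation
    error_term = sum(abs(x - obs) for x in members) / m
    # Sum of absolute differences over all ordered member pairs
    pair_sum = sum(abs(a - b) for a, b in product(members, repeat=2))
    # The fair variant normalises the spread by M(M-1) instead of M^2
    norm = 2 * m * (m - 1) if fair else 2 * m * m
    return error_term - pair_sum / norm


def almost_fair_crps(members, obs, alpha=1.0):
    """Blend of the two estimators: alpha * fCRPS + (1 - alpha) * CRPS."""
    fair_part = ensemble_crps(members, obs, fair=True)
    standard_part = ensemble_crps(members, obs, fair=False)
    return alpha * fair_part + (1 - alpha) * standard_part


# Two members that bracket the observation symmetrically
members = [1.0, 3.0]
obs = 2.0
for alpha in (0.0, 0.5, 1.0):
    print(alpha, almost_fair_crps(members, obs, alpha))
```

For this toy case the standard CRPS is 0.5 and the fair CRPS is 0.0, so the printed values move linearly from 0.5 to 0.0 as `alpha` goes from 0 to 1.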


## Building Our Training Config

Now we'll examine our training configuration step-by-step, highlighting the key differences from deterministic training. We have a minimal configuration file ready that we'll load and examine section by section.

In [None]:
# Let's start by importing the necessary modules
import yaml
from pathlib import Path

# Load our minimal configuration file
config_path = Path("configs/aifs_ens_minimal.yaml")
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

### Step 1: Hardware Configuration

The hardware configuration needs to specify the number of GPUs per ensemble, which is crucial for the ensemble training strategy.

**Key Points:**
- `num_gpus_per_ensemble`: Number of GPUs to use per ensemble (typically 1 for small setups)
- `num_gpus_per_model`: Number of GPUs per model instance
- Total ensemble members = `ensemble_size_per_device` × `num_gpus_per_ensemble` (we'll set this later)


In [None]:
# Display the hardware configuration section
print("Hardware Configuration:")
print("=" * 50)
print(yaml.dump(config['hardware'], default_flow_style=False))


### Step 2: Datamodule Configuration

For ensemble training, we need to use the `AnemoiEnsDatasetsDataModule` instead of the regular datamodule. This handles ensemble data loading and can work with either:
- Single initial conditions for all ensemble members
- Perturbed initial conditions (if available in your dataset)


In [None]:
print("Datamodule configuration:")
print("=" * 50)
print(yaml.dump(config['datamodule'], default_flow_style=False))


### Step 3: Model Configuration

Key model changes for CRPS-based training are:

1. **Ensemble Model Class**: Uses `AnemoiEnsModelEncProcDec` instead of `AnemoiModelEncProcDec`

2. **Noise Injector**: Each ensemble member samples random noise at every time step:
   ```yaml
   noise_injector:
     _target_: anemoi.models.layers.ensemble.NoiseConditioning
     noise_std: 1
     noise_channels_dim: 4
     noise_mlp_hidden_dim: 32
     inject_noise: True
   ```

3. **Conditional Layer Normalization**: The processor uses `ConditionalLayerNorm` instead of regular `LayerNorm` to condition the latent space on the noise:
   ```yaml
   processor:
     layer_kernels:
       LayerNorm:
         _target_: anemoi.models.layers.normalization.ConditionalLayerNorm
         normalized_shape: ${model.num_channels}
         condition_shape: ${model.noise_injector.noise_channels_dim}
   ```

   Standard layer normalization treats every input the same way. Conditional layer normalization instead lets the normalization depend on additional information (in this case, the noise vector):
    - Each ensemble member gets a unique noise vector at every time step
    - This noise is embedded and used to condition the layer normalization in the processor
    - The conditioning allows the same model weights to produce different outputs for different ensemble members
    - This creates diversity in the ensemble predictions while sharing computational resources


This noise injection and conditioning is what allows each ensemble member to produce different predictions while sharing the same model weights.
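The idea can be illustrated with a small pure-Python sketch. It is not Anemoi's `ConditionalLayerNorm`: it assumes one common formulation in which the features are normalized and then rescaled and shifted by linear functions of the noise vector. All names (`layer_norm`, `linear`, `conditional_layer_norm`) and the `(1 + scale)` form are assumptions made for illustration.

```python
import math
import random


def layer_norm(x, eps=1e-5):
    """Normalize a feature vector to zero mean and (roughly) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def linear(weights, bias, cond):
    """Map the conditioning vector to one value per feature."""
    return [sum(w * c for w, c in zip(row, cond)) + b for row, b in zip(weights, bias)]


def conditional_layer_norm(x, cond, w_scale, b_scale, w_shift, b_shift):
    """Layer norm whose scale and shift are computed from the conditioning vector."""
    normed = layer_norm(x)
    scale = linear(w_scale, b_scale, cond)
    shift = linear(w_shift, b_shift, cond)
    # Assumed form: (1 + scale) keeps plain layer norm as the zero-weight case
    return [n * (1 + s) + t for n, s, t in zip(normed, scale, shift)]


random.seed(0)
features, noise_dim = 4, 2
x = [0.5, -1.0, 2.0, 0.0]  # the same latent features for both members
zeros = [0.0] * features
w_scale = [[random.gauss(0, 0.1) for _ in range(noise_dim)] for _ in range(features)]
w_shift = [[random.gauss(0, 0.1) for _ in range(noise_dim)] for _ in range(features)]

# Each ensemble member draws its own noise vector
noise_a = [random.gauss(0, 1) for _ in range(noise_dim)]
noise_b = [random.gauss(0, 1) for _ in range(noise_dim)]

out_a = conditional_layer_norm(x, noise_a, w_scale, zeros, w_shift, zeros)
out_b = conditional_layer_norm(x, noise_b, w_scale, zeros, w_shift, zeros)
print(out_a)
print(out_b)
```

Both calls share the same weights and the same input features; only the noise vector differs, and that alone is enough to produce two different outputs.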

In [None]:
# Display the model configuration section
print("Model Configuration:")
print("=" * 30)
print(yaml.dump(config['model'], default_flow_style=False))


### Step 4: Training Configuration

Now we configure the training parameters, strategy, and loss function for ensemble training.

**Key Training Parameters:**

1. **Model Task**: Set to `GraphEnsForecaster` (handled by the `ensemble` training default)
2. **Ensemble Size**: `ensemble_size_per_device: 2` means 2 ensemble members per device
3. **Total Ensemble Members**: `ensemble_size_per_device` × `num_gpus_per_ensemble` = 2 × 1 = 2 members
4. **Strategy**: Uses `DDPEnsGroupStrategy` (handled by the `ensemble` training default)
5. **Loss Function**: `AlmostFairKernelCRPS` with `alpha=1.0` (pure fair CRPS)
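The ensemble-size arithmetic above can be checked with a few lines of Python. This is only a sketch: the helper `total_ensemble_members` is hypothetical, and the placement of `ensemble_size_per_device` under `training` and `num_gpus_per_ensemble` under `hardware` is an assumption based on the sections discussed in this tutorial, so adjust the keys if your config is laid out differently.

```python
def total_ensemble_members(config):
    """Total members = members per device x GPUs per ensemble."""
    # Assumed key locations; adapt them to the layout of your own config
    per_device = config["training"]["ensemble_size_per_device"]
    gpus_per_ensemble = config["hardware"]["num_gpus_per_ensemble"]
    return per_device * gpus_per_ensemble


# Stand-in for the loaded YAML config, using the values from this step
example_config = {
    "hardware": {"num_gpus_per_ensemble": 1},
    "training": {"ensemble_size_per_device": 2},
}
print(total_ensemble_members(example_config))  # 2
```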


In [None]:
# Display the training configuration section
print("Training Configuration:")
print("=" * 30)
print(yaml.dump(config['training'], default_flow_style=False))


## Training Execution

Now that we have our configuration ready, let's execute the training. We'll run a short training session to verify that everything works correctly. From the notebook directory, the command is:

```bash
anemoi-training train --config-path "$(pwd)/configs" --config-name aifs_ens_minimal.yaml
```

In [None]:
import subprocess
from pathlib import Path

# Resolve the config directory relative to the notebook instead of
# hard-coding a user-specific absolute path
config_dir = Path("configs").resolve()

print("Starting AIFS-ENS training...")
print("=" * 50)

try:
    # Run the training command using subprocess
    result = subprocess.run(
        [
            "anemoi-training", "train",
            "--config-path", str(config_dir),
            "--config-name", "aifs_ens_minimal.yaml",
        ],
        check=True,  # Raise CalledProcessError if the command fails
        capture_output=True,  # Capture stdout and stderr for display
        text=True,  # Return output as text
    )
except subprocess.CalledProcessError as err:
    # Show the captured error output before re-raising
    print(err.stderr)
    raise

print(result.stdout)
print("\n" + "=" * 50)
print("\n✓ Training completed successfully!")

## Monitoring and Results

### Key Metrics to Monitor

- **Training Loss**: The AlmostFairKernelCRPS loss should decrease over time
- **Validation Metrics**: Similar to training loss, calculated on validation data
- **Learning Rate**: Should follow the warmup schedule
- **Memory Usage**: Monitor GPU memory usage
- **Training Speed**: Steps per second

### Output Files

The training will create several output files:

- **Checkpoints**: Model weights saved periodically (`.ckpt` files)
- **Logs**: Training logs and metrics
- **MLflow Artifacts**: If MLflow is enabled, check the tracking UI
- **Plots**: Diagnostic plots (if enabled)
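To locate the most recent checkpoint after a run, a small helper like the one below can be handy. It is a sketch only: the function `latest_checkpoint` and the directory name `output` are hypothetical, and the real location depends on the output paths set in your configuration.

```python
from pathlib import Path


def latest_checkpoint(output_dir):
    """Return the most recently modified .ckpt file below output_dir, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    # Search all subdirectories, since checkpoints are often nested per run
    checkpoints = list(root.rglob("*.ckpt"))
    if not checkpoints:
        return None
    return max(checkpoints, key=lambda path: path.stat().st_mtime)


# "output" is a placeholder; point this at your own output directory
print(latest_checkpoint("output"))
```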


### Follow-up Exercises

1. **Run the Training**: Execute the training cell above for a real training run
2. **Experiment with Parameters**: Try different values for:
   - `ensemble_size_per_device` (e.g., 4, 8)
   - `alpha` parameter in the loss function (e.g., 0.5, 0.8)
3. **Generate Ensemble Forecasts**: Use the trained model to generate ensemble forecasts