# Multi-Run AIFS ENS v1 - Running 50 Ensemble Members

This notebook runs ECMWF's aifs-ens-v1 data-driven model for multiple ensemble members (1-50), using ECMWF's [open data](https://www.ecmwf.int/en/forecasts/datasets/open-data) dataset and the [anemoi-inference](https://anemoi-inference.readthedocs.io/en/latest/apis/level1.html) package.

aifs-ens-v1 is an inherently stochastic model: even for identical initial conditions, different noise is applied inside the model on each run, so every run yields a different forecast. This notebook runs all 50 perturbed ensemble members and saves each forecast as a separate GRIB file.

# 1. Install Required Packages and Imports

In [None]:
## Uncomment the lines below to install the required packages
#!pip install torch==2.5.0 anemoi-inference[huggingface]==0.6.0 anemoi-models==0.6.0 anemoi-graphs==0.6.0 anemoi-datasets==0.5.23
#!pip install earthkit-regrid==0.4.0 'ecmwf-opendata>=0.3.19'
#!pip install flash_attn

In [None]:
import datetime
from collections import defaultdict
import os
import time

import numpy as np
import earthkit.data as ekd
import earthkit.regrid as ekr

from anemoi.inference.runners.simple import SimpleRunner
from anemoi.inference.outputs.printer import print_state
from anemoi.inference.outputs.gribfile import GribFileOutput
from anemoi.inference.context import Context

from ecmwf.opendata import Client as OpendataClient

# 2. Configuration and Setup
## List of parameters to retrieve from ECMWF open data

In [None]:
# Single-level parameters requested for each ensemble member
PARAM_SFC = [
    "10u", "10v", "2d", "2t",
    "msl", "skt", "sp", "tcw",
]
# Single-level parameters requested without an ensemble member number
PARAM_SFC_FC = [
    "lsm", "z", "slor", "sdor",
]
# Soil parameter, requested on SOIL_LEVELS
PARAM_SOIL = ["sot"]
SOIL_LEVELS = [1, 2]
# Parameters requested on every entry of LEVELS
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]

# Multi-run settings
ENSEMBLE_MEMBERS = [*range(1, 51)]  # member numbers 1 through 50
LEAD_TIME = 72  # forecast length in hours
OUTPUT_DIR = "ensemble_outputs"

## Select a date and create output directory

In [None]:
# Initial date: whatever the open-data client's latest() returns
DATE = OpendataClient("ecmwf").latest()
print("Initial date is", DATE)

# Timestamp embedded in the output file names
date_str = f"{DATE:%Y%m%d_%H%M}"

# Ensure the output directory is present (a no-op when it already exists)
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Define Data Retrieval Function

In [None]:
def get_open_data(param, levelist=None, number=None):
    """Retrieve ECMWF open-data fields for DATE - 6h and DATE, regridded to N320.

    Parameters
    ----------
    param : list of str
        Parameter short names to request.
    levelist : list, optional
        Levels to request. When omitted or empty, result keys are the bare
        parameter name; otherwise they are ``"<param>_<levelist>"``.
    number : int, optional
        Ensemble member number. When given, the request is made with
        ``number=[number]`` and ``stream='enfo'``; when ``None`` neither key
        is sent.

    Returns
    -------
    collections.defaultdict
        Maps each field name to the retrieved arrays stacked along a new
        leading axis, in retrieval order (the DATE - 6h request first).

    Raises
    ------
    AssertionError
        If a retrieved field is not on the expected (721, 1440) grid.
    """
    # A None default avoids sharing one mutable list object between calls.
    if levelist is None:
        levelist = []

    fields = defaultdict(list)
    # Two requests: six hours before DATE, then DATE itself
    for date in [DATE - datetime.timedelta(hours=6), DATE]:
        request = dict(date=date, param=param, levelist=levelist)
        if number is not None:
            request.update(number=[number], stream='enfo')
        data = ekd.from_source("ecmwf-open-data", **request)

        for f in data:
            # Decode the field once and reuse the array
            values = f.to_numpy()
            assert values.shape == (721, 1440), f"unexpected grid shape {values.shape}"
            # Open data is between -180 and 180, we need to shift it to 0-360
            values = np.roll(values, -values.shape[1] // 2, axis=1)
            # Interpolate the data from 0.25 to N320
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            # Add the values to the list
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields[name].append(values)

    # Create a single matrix for each parameter (keys are unchanged, so
    # assigning while iterating is safe)
    for name, stacked in fields.items():
        fields[name] = np.stack(stacked)

    return fields

## Define Function to Get Input Fields for a Given Ensemble Member

In [None]:
def get_input_fields(number):
    """Assemble the input fields for one ensemble member.

    Combines the single-level fields (PARAM_SFC for the given member,
    PARAM_SFC_FC without a member number), the soil fields renamed to their
    ``stl``/``swvl`` keys, and the pressure-level fields, with every
    ``gh_<level>`` entry replaced by ``z_<level>`` = gh * 9.80665.

    Parameters
    ----------
    number : int
        Ensemble member number forwarded to ``get_open_data``.

    Returns
    -------
    dict
        Field name to array, as returned by ``get_open_data``.
    """
    soil_names = {
        'sot_1': 'stl1',
        'sot_2': 'stl2',
        'vsw_1': 'swvl1',
        'vsw_2': 'swvl2',
    }
    g0 = 9.80665  # factor turning geopotential height into geopotential

    # Single-level fields: member-specific first, then the constant ones
    fields = {
        **get_open_data(param=PARAM_SFC, number=number),
        **get_open_data(param=PARAM_SFC_FC),
    }

    # Soil fields, stored under their renamed keys
    soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS, number=number)
    fields.update({soil_names[key]: array for key, array in soil.items()})

    # Pressure-level fields
    fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS, number=number))

    # Swap each geopotential-height entry for the corresponding geopotential
    for level in LEVELS:
        fields[f"z_{level}"] = fields.pop(f"gh_{level}") * g0

    return fields

# 3. Load the Model

In [None]:
# Optional memory-saving settings: uncomment them before the runner is created
# import os
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
# os.environ['ANEMOI_INFERENCE_NUM_CHUNKS'] = '16'

# The checkpoint is identified by its Hugging Face repository name
checkpoint = dict(huggingface="ecmwf/aifs-ens-1.0")
runner = SimpleRunner(checkpoint, device="cuda")
print("Model loaded successfully!")

# 4. Run Forecasts for All Ensemble Members

In [None]:
# Run one forecast per entry of ENSEMBLE_MEMBERS, writing each to its own GRIB
# file. A failure in one member is reported and recorded, and the loop moves on.

TIME_STEP_HOURS = 6  # 6-hour time step
STEPS_PER_DAY = 24 // TIME_STEP_HOURS

# Wall-clock seconds for each member that ran to completion
processing_times = []
# Member numbers whose retrieval or forecast raised an exception
failed_members = []
total_members = len(ENSEMBLE_MEMBERS)

# `position` counts loop iterations; `member` is the member number itself and
# need not equal the position (e.g. when running a subset of members).
for position, member in enumerate(ENSEMBLE_MEMBERS, start=1):
    start_time = time.time()
    print(f"\n{'='*60}")
    print(f"Processing ensemble member {member} ({position}/{total_members})")
    print(f"{'='*60}")

    try:
        # Get input fields for this member
        print(f"Retrieving initial conditions for member {member}...")
        fields = get_input_fields(member)
        input_state = dict(date=DATE, fields=fields)

        # Setup output file
        grib_file = f"{OUTPUT_DIR}/aifs_ens_forecast_{date_str}_member{member:03d}.grib"

        # Create context for outputs
        context = Context()
        context.time_step = TIME_STEP_HOURS
        context.lead_time = LEAD_TIME
        context.reference_date = DATE

        # Initialize GRIB output
        grib_output = GribFileOutput(context, path=grib_file)

        # Run forecast
        print(f"Running forecast for member {member}...")
        outputs_initialized = False
        step_count = 0

        try:
            for state in runner.run(input_state=input_state, lead_time=LEAD_TIME):
                # Initialize output on first state
                if not outputs_initialized:
                    grib_output.open(state)
                    outputs_initialized = True

                # Write to output
                grib_output.write_step(state)
                step_count += 1

                # Print progress once per 24 forecast hours
                if step_count % STEPS_PER_DAY == 0:
                    print(f"  Member {member}: {step_count * TIME_STEP_HOURS} hours completed")
        finally:
            # Close the output even when a step raised, but only if it was opened
            if outputs_initialized:
                grib_output.close()

        # Verify file
        if os.path.exists(grib_file):
            file_size = os.path.getsize(grib_file) / (1024 * 1024)  # Size in MB
            print(f"✓ Member {member} completed: {grib_file} ({file_size:.2f} MB)")
        else:
            print(f"✗ Error: Output file not created for member {member}")

        # Track time
        member_time = time.time() - start_time
        processing_times.append(member_time)
        print(f"Member {member} processing time: {member_time:.2f} seconds")

        # Estimate remaining time from the number of members still to run
        avg_time = sum(processing_times) / len(processing_times)
        remaining_members = total_members - position
        est_remaining = avg_time * remaining_members
        print(f"Estimated time remaining: {est_remaining/60:.1f} minutes")

    except Exception as e:
        # Best effort: report the failure, remember the member, keep going
        print(f"✗ Error processing member {member}: {str(e)}")
        failed_members.append(member)

print(f"\n{'='*60}")
print("All ensemble members processed!")
if processing_times:
    print(f"Total processing time: {sum(processing_times)/60:.1f} minutes")
    print(f"Average time per member: {sum(processing_times)/len(processing_times):.2f} seconds")
else:
    # Avoid dividing by zero when every member failed (or the list was empty)
    print("No ensemble member completed successfully.")
if failed_members:
    print(f"Failed members ({len(failed_members)}): {failed_members}")

# 5. Verify Output Files

In [None]:
# List the GRIB files written for the current initial date. OUTPUT_DIR is
# reused between runs, so files carrying a different date_str are left out
# rather than being reported as generated by this run.
run_prefix = f"aifs_ens_forecast_{date_str}_member"
grib_files = sorted(
    f for f in os.listdir(OUTPUT_DIR)
    if f.startswith(run_prefix) and f.endswith('.grib')
)
print(f"\nGenerated {len(grib_files)} GRIB files:")

total_size = 0
for f in grib_files:
    file_path = os.path.join(OUTPUT_DIR, f)
    file_size = os.path.getsize(file_path) / (1024 * 1024)  # MB
    total_size += file_size
    print(f"  {f}: {file_size:.2f} MB")

print(f"\nTotal disk space used: {total_size:.2f} MB ({total_size/1024:.2f} GB)")

# 6. Optional: Quick Verification of a Sample File

In [None]:
# Sanity-check the first file in grib_files by reading it back
if grib_files:
    first_name = grib_files[0]
    sample_file = os.path.join(OUTPUT_DIR, first_name)
    print(f"\nChecking sample file: {first_name}")

    try:
        grib_data = ekd.from_source("file", sample_file)
        print(f"File contains {len(grib_data)} fields")

        # Print parameter, level and step for the leading fields
        print("\nFirst 5 fields:")
        for position, grib_field in enumerate(grib_data[:5], start=1):
            info = grib_field.metadata()
            param_name = info.get('param')
            # Fields without a 'levelist' entry are labelled 'surface'
            level = info.get('levelist', 'surface')
            step = info.get('step')
            print(f"  {position}: {param_name} at {level} - step: {step} hours")
    except Exception as err:
        print(f"Error reading file: {err}")