In [1]:
import datetime
import numpy as np
import torch
from collections import defaultdict

import earthkit.data as ekd
import earthkit.regrid as ekr

from anemoi.inference.runners.simple import SimpleRunner
from anemoi.inference.outputs.printer import print_state

from ecmwf.opendata import Client as OpendataClient

# 1. Import initial conditions from ECMWF open data    
**Parameters to retrieve from ECMWF**

In [None]:
# Single-level (surface) parameters, by short name.
PARAM_SFC = [
    "10u",
    "10v",
    "2d",
    "2t",
    "msl",   # mean sea level pressure
    "skt",   # skin temperature
    "sp",    # surface pressure
    "tcw",   # total column vertically-integrated water vapour
    "lsm",   # land-sea mask
    "z",     # geopotential
    "slor",  # slope of sub-gridscale orography (step 0)
    "sdor",  # standard deviation of sub-gridscale orography (step 0)
]

# Soil parameters; requested on the layers listed in SOIL_LEVELS.
PARAM_SOIL = [
    "vsw",  # volumetric soil water
    "sot",  # soil temperature
]

# Pressure-level parameters.
PARAM_PL = [
    "gh",
    "t",
    "u",
    "v",
    "w",  # vertical velocity
    "q",  # specific humidity
]

# Pressure levels requested for every PARAM_PL entry.
LEVELS = [
    1000, 925, 850, 700, 600, 500, 400,
    300, 250, 200, 150, 100, 50,
]

# Soil layers requested for every PARAM_SOIL entry (only layers 1 and 2).
SOIL_LEVELS = [1, 2]

**Date of initial conditions**

In [3]:
# Use whatever date the ECMWF open-data client reports as its latest as the
# initial date (t); get_open_data also fetches DATE - 6h as the previous step.
DATE = OpendataClient().latest()
print(f"Initial date is {DATE}")

Initial date is 2025-09-18 06:00:00


**Fetch the data using ECMWF Open Data API**

In [4]:
def get_open_data(param, levelist=None):
    """Fetch fields from ECMWF open data at t-1 and t and regrid them to N320.

    The two dates ``DATE - 6h`` and ``DATE`` are requested in that order. Each
    returned field is rolled along axis 1 by half its length, interpolated from
    the 0.25°x0.25° regular lat-lon grid to the N320 reduced Gaussian grid, and
    collected under its field name.

    Parameters
    ----------
    param : list of str
        Parameter short names to request.
    levelist : list, optional
        Levels to request. When empty or omitted, fields are keyed by the
        parameter name alone; otherwise by ``"<param>_<level>"``.

    Returns
    -------
    collections.defaultdict
        Maps each field name to ``np.stack`` of its per-date arrays (new
        leading axis, earlier date first).

    Raises
    ------
    AssertionError
        If a retrieved field is not on the expected (721, 1440) grid.
    """
    # Avoid a shared mutable default argument; an omitted levelist is still
    # forwarded to the data source as an empty list, exactly as before.
    if levelist is None:
        levelist = []

    fields = defaultdict(list)
    # Get data at time t-1 and t:
    for date in (DATE - datetime.timedelta(hours=6), DATE):
        data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)  # <class 'earthkit.data.readers.grib.file.GRIBReader'>
        for f in data:  # <class 'earthkit.data.readers.grib.codes.GribField'>
            grid = f.to_numpy()  # decode the field once and reuse the array
            assert grid.shape == (721, 1440), f"unexpected grid shape {grid.shape}"
            # Roll axis 1 by half its length.
            # NOTE(review): presumably moves longitudes from a -180..180 to a
            # 0..360 layout before regridding — confirm.
            values = np.roll(grid, -f.shape[1] // 2, axis=1)
            # Interpolate the data from 0.25°x0.25° (regular lat-lon grid, 2D) to N320
            # (reduced Gaussian grid, 1D, see definition here:
            # https://www.ecmwf.int/en/forecasts/documentation-and-support/gaussian_n320)
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            # Add the values to the list
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields[name].append(values)

    # Create a single matrix for each field. Only values of existing keys are
    # replaced, so the dict does not change size while it is being iterated.
    for name, per_date in fields.items():
        fields[name] = np.stack(per_date)

    return fields

In [5]:
# Accumulator for all input fields, keyed by field name; filled by the cells below.
fields = {}

**Single-level fields**

In [6]:
# No levelist is passed, so the returned keys are the bare parameter names.
fields.update(get_open_data(param=PARAM_SFC))

                                                                

By downloading data from the ECMWF open data dataset, you agree to the terms: Attribution 4.0 International (CC BY 4.0). Please attribute ECMWF when downloading this data.


                                                                

In [7]:
# Retrieve the soil parameters on the requested soil layers.
soil = get_open_data(param=PARAM_SOIL, levelist=SOIL_LEVELS)

# Soil parameters need to be renamed to be consistent with training:
# "<param>_<layer>" as returned by get_open_data -> training name.
mapping = {
    "sot_1": "stl1",
    "sot_2": "stl2",
    "vsw_1": "swvl1",
    "vsw_2": "swvl2",
}
for open_data_name, stacked_values in soil.items():
    fields[mapping[open_data_name]] = stacked_values

                                                                                     

**Pressure level fields**

In [8]:
# A levelist is passed, so the returned keys have the form "<param>_<level>".
fields.update(get_open_data(param=PARAM_PL, levelist=LEVELS))

                                                                

In [9]:
# Convert geopotential height into geopotential (transform GH to Z) by
# multiplying by standard gravity.
STANDARD_GRAVITY = 9.80665  # m s^-2
for level in LEVELS:
    gh_key = f"gh_{level}"
    z_key = f"z_{level}"
    # Keep the cell re-runnable: a level that was already converted has no
    # "gh_*" entry left, so skip it instead of raising KeyError on the pop.
    if gh_key not in fields and z_key in fields:
        continue
    fields[z_key] = fields.pop(gh_key) * STANDARD_GRAVITY

In [10]:
# List every field name collected so far.
print(fields.keys())
# Each entry stacks the two time steps: fields["field_name"][0] is the array
# for t-1 and fields["field_name"][1] the array for t, both in N320 format.
print(fields["10u"])

dict_keys(['10u', '10v', '2d', '2t', 'msl', 'skt', 'sp', 'tcw', 'lsm', 'z', 'slor', 'sdor', 'swvl1', 'swvl2', 'stl1', 'stl2', 't_1000', 't_925', 't_850', 't_700', 't_600', 't_500', 't_400', 't_300', 't_250', 't_200', 't_150', 't_100', 't_50', 'u_1000', 'u_925', 'u_850', 'u_700', 'u_600', 'u_500', 'u_400', 'u_300', 'u_250', 'u_200', 'u_150', 'u_100', 'u_50', 'v_1000', 'v_925', 'v_850', 'v_700', 'v_600', 'v_500', 'v_400', 'v_300', 'v_250', 'v_200', 'v_150', 'v_100', 'v_50', 'w_1000', 'w_925', 'w_850', 'w_700', 'w_600', 'w_500', 'w_400', 'w_300', 'w_250', 'w_200', 'w_150', 'w_100', 'w_50', 'q_1000', 'q_925', 'q_850', 'q_700', 'q_600', 'q_500', 'q_400', 'q_300', 'q_250', 'q_200', 'q_150', 'q_100', 'q_50', 'z_1000', 'z_925', 'z_850', 'z_700', 'z_600', 'z_500', 'z_400', 'z_300', 'z_250', 'z_200', 'z_150', 'z_100', 'z_50'])
[[ 2.25539092  1.20666618  0.11760587 ... -6.39464971 -6.26019782
  -5.23836346]
 [-1.10034716 -2.92889286 -4.08517912 ... -6.46423058 -6.38355944
  -5.32138951]]


**Create initial state**

In [11]:
# Initial state handed to the runner: the initial date plus all input fields.
input_state = {"date": DATE, "fields": fields}

# 2. Load the model and run the forecast
**Load model checkpoint from Hugging Face and create a runner**

In [12]:
# Checkpoint reference passed to the runner, given under the "huggingface" key.
checkpoint = dict(huggingface="ecmwf/aifs-single-1.0")

In [13]:
# Report whether PyTorch can see a CUDA device on this machine.
print(torch.cuda.is_available())

False


In [14]:
# Pick the device from what PyTorch reports instead of hard-coding it:
# use CUDA when it is available and fall back to the CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
runner = SimpleRunner(checkpoint, device=device)

**Run the forecast**

In [15]:
# Run the forecast from the initial state and print each state the runner yields.
# lead_time=12 is presumably in hours — TODO confirm against the runner's documentation.
# NOTE(review): the output recorded below this cell ends in
# "ModuleNotFoundError: No module named 'flash_attn'", so the saved run did not
# complete; the environment needs that module (or a configuration that avoids it).
for state in runner.run(input_state=input_state, lead_time=12):
    print_state(state)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 152059.36it/s]


ModuleNotFoundError: No module named 'flash_attn'