Adapted from https://huggingface.co/ecmwf/aifs-single-1.0/blob/main/run_AIFS_v1.ipynb

In [None]:
import datetime
from collections import defaultdict

import numpy as np
import earthkit.data as ekd
import earthkit.regrid as ekr

from anemoi.inference.runners.simple import SimpleRunner
from anemoi.inference.outputs.printer import print_state

from ecmwf.opendata import Client as OpendataClient

In [None]:
# Single-level parameters requested from ECMWF open data.
PARAM_SFC = [
    "10u", "10v", "2d", "2t", "msl", "skt",
    "sp", "tcw", "lsm", "z", "slor", "sdor",
]
# Soil parameters, requested on each level in SOIL_LEVELS.
PARAM_SOIL = ["vsw", "sot"]
# Pressure-level parameters, requested on each level in LEVELS.
PARAM_PL = ["gh", "t", "u", "v", "w", "q"]
# Pressure levels (hPa).
LEVELS = [1000, 925, 850, 700, 600, 500, 400, 300, 250, 200, 150, 100, 50]
# Soil levels requested for PARAM_SOIL.
SOIL_LEVELS = [1, 2]

In [None]:
# Initial time for the forecast: whatever `latest()` reports for ECMWF open data.
# get_open_data() below also fetches DATE - 6h, so both times must be available.
DATE = OpendataClient().latest()

In [None]:
# Echo the chosen initial time so the run is traceable in the notebook output.
print(f"Initial date is {DATE}")

In [None]:
def get_open_data(param, levelist=None):
    """Fetch ``param`` from ECMWF open data for DATE - 6h and DATE, regridded to N320.

    Each downloaded field is checked to be on the 0.25-degree (721, 1440) grid,
    rolled by half its width along the longitude axis (open data runs -180..180,
    the model wants 0..360), and interpolated from 0.25 degrees to N320 with
    earthkit-regrid.

    Args:
        param: parameter short name(s) passed through to earthkit-data.
        levelist: optional levels passed through to earthkit-data. When
            non-empty, result keys are ``"<param>_<level>"``; otherwise they
            are the bare parameter name. ``None`` (the default) means no levels.

    Returns:
        A ``defaultdict`` mapping field name to ``np.stack`` of the per-date
        arrays, with the leading axis ordered (DATE - 6h, DATE).

    Raises:
        ValueError: if a retrieved field is not on the expected (721, 1440) grid.
    """
    # None instead of a mutable [] default; earthkit still receives [] as before.
    if levelist is None:
        levelist = []

    fields = defaultdict(list)
    # Previous time step first, then the current one: the stacked leading axis
    # relies on this order.
    for date in (DATE - datetime.timedelta(hours=6), DATE):
        data = ekd.from_source("ecmwf-open-data", date=date, param=param, levelist=levelist)
        for f in data:
            grid = f.to_numpy()  # decode once, reuse below
            if grid.shape != (721, 1440):
                raise ValueError(
                    f"unexpected open-data grid shape {grid.shape} for "
                    f"{f.metadata('param')}; expected (721, 1440)"
                )
            # Open data is between -180 and 180; shift it to 0-360 by rolling
            # half the longitudes from the front to the back.
            values = np.roll(grid, -(grid.shape[1] // 2), axis=1)
            # Interpolate the data from 0.25 degrees to N320.
            values = ekr.interpolate(values, {"grid": (0.25, 0.25)}, {"grid": "N320"})
            name = f"{f.metadata('param')}_{f.metadata('levelist')}" if levelist else f.metadata("param")
            fields[name].append(values)

    # Stack each field's per-date arrays along a new leading (date) axis.
    # Reassigning existing keys while iterating is safe (dict size is unchanged).
    for name, per_date in fields.items():
        fields[name] = np.stack(per_date)

    return fields

In [None]:
# Seed the model-input dictionary with the single-level fields (plain dict copy).
fields = dict(get_open_data(param=PARAM_SFC))

In [None]:
# Soil fields, keyed "<param>_<level>" (e.g. "vsw_1", "sot_2").
soil = get_open_data(PARAM_SOIL, levelist=SOIL_LEVELS)

In [None]:
# Open-data soil keys -> variable names used in the model input.
mapping = {
    'sot_1': 'stl1',
    'sot_2': 'stl2',
    'vsw_1': 'swvl1',
    'vsw_2': 'swvl2',
}
# An unexpected soil key raises KeyError rather than being silently dropped.
for soil_key, soil_values in soil.items():
    fields[mapping[soil_key]] = soil_values

In [None]:
# Pressure-level fields, keyed "<param>_<level>" (e.g. "t_850").
fields.update(
    get_open_data(levelist=LEVELS, param=PARAM_PL)
)

In [None]:
# Replace each geopotential-height field gh_<level> with geopotential z_<level>
# by multiplying by standard gravity (9.80665 m s^-2).
for level in LEVELS:
    fields[f"z_{level}"] = fields.pop(f"gh_{level}") * 9.80665

In [None]:
# Initial state handed to the runner: the analysis time plus all input fields.
input_state = {"date": DATE, "fields": fields}

In [None]:
# Checkpoint location: the AIFS Single 1.0 repository on Hugging Face.
checkpoint = dict(huggingface="ecmwf/aifs-single-1.0")

In [None]:
# Build the inference runner on the GPU.
# NOTE(review): confirm whether construction already initialises CUDA or
# whether that is deferred to runner.run() -- it matters for the env vars below.
runner = SimpleRunner(
    checkpoint,
    device="cuda",
)

In [None]:
import os

# Memory-related knobs: PyTorch allocator configuration and anemoi-inference
# chunking. NOTE(review): these only take effect if set before they are first
# read (PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA allocator starts);
# this cell runs after SimpleRunner(...) is built -- confirm that nothing has
# touched CUDA yet, or move this cell above the runner construction.
os.environ.update(
    {
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "ANEMOI_INFERENCE_NUM_CHUNKS": "16",
    }
)

In [None]:
# Step the model forward up to lead_time=12 (presumably hours -- confirm with
# anemoi-inference), printing each state. After the loop, `state` holds the
# last one, which the cells below use.
for state in runner.run(lead_time=12, input_state=input_state):
    print_state(state)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.tri as tri

In [None]:
# List every output field of the final state with its shape and dtype.
for name, a in state['fields'].items():
    print(f"{name} {a.shape} {a.dtype}")

In [None]:
def fix(lons):
    """Map longitudes in 0..360 to -180..180.

    Values strictly greater than 180 have 360 subtracted; everything else
    (including exactly 180) is returned unchanged. Always returns an ndarray.
    """
    return np.where(lons <= 180, lons, lons - 360)

latitudes = state["latitudes"]
longitudes = state["longitudes"]
values = state["fields"]["100u"]

# Triangulate the scattered output points in lon/lat space, with longitudes
# mapped to -180..180 so the map is centred on the Greenwich meridian.
triangulation = tri.Triangulation(fix(longitudes), latitudes)

# RdBu is a diverging colormap: centre the contour levels on zero so that
# white means "no u-wind" and equal colour intensity means equal magnitude.
vmax = float(np.nanmax(np.abs(values)))
if not np.isfinite(vmax) or vmax == 0.0:
    # Degenerate (all-zero / all-NaN) field: keep the levels strictly increasing.
    vmax = 1.0
levels = np.linspace(-vmax, vmax, 21)  # 21 boundaries -> 20 colour bands

fig, ax = plt.subplots(figsize=(11, 6))

contour = ax.tricontourf(triangulation, values, levels=levels, cmap="RdBu")
fig.colorbar(contour, ax=ax, orientation="vertical", shrink=0.7, label="100m u-wind (m/s)")
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.set_title("100u: u-component of wind at 100 m (last forecast step)")
plt.show()

