# Run Ensemble Mean upon Anemoi Models
Connecting cascade and anemoi allows for distributed and managed execution of models, and generation of products.

In [2]:
import anemoi.cascade as ac

In [3]:
import os

# Setup Environment
# Configure the PyTorch CUDA caching allocator to use expandable segments.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Prepend the extra library directory to LD_LIBRARY_PATH.
# - `.get` avoids a KeyError when LD_LIBRARY_PATH is not set at all.
# - The membership check keeps the cell idempotent, so re-running it does not
#   prepend the same directory again.
# - No trailing ':' is written when there was no previous value, because an
#   empty entry would make the loader search the current directory.
_extra_ld_path = "INSERT_LD_LIBRARY_PATH_HERE"
_current_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
if _extra_ld_path not in _current_ld_path.split(':'):
    os.environ['LD_LIBRARY_PATH'] = (
        f"{_extra_ld_path}:{_current_ld_path}" if _current_ld_path else _extra_ld_path
    )

## Additional cascade actions
`pproc-cascade` provides an action API for advanced product generation, but it depends on `pproc`, so it is imported optionally below.

In [4]:
from cascade import Cascade
import earthkit.data as ekd

# ppcascade is an optional dependency. Assume it is available, then clear the
# flag if the import fails so later cells can skip the pproc-based setup.
PPCASCADE_IMPORTED = True
try:
    from ppcascade import fluent as ppfluent
except (RuntimeError, ImportError, AttributeError) as import_error:
    PPCASCADE_IMPORTED = False
    print('pproc cascade could not be imported', import_error)

In [5]:
# Checkpoint specification passed to `ac.fluent.from_input`: the source
# ('huggingface') mapped to the model identifier.
CKPT = dict(huggingface='ecmwf/aifs-single-1.0')

With the environment ready and the checkpoint chosen, we can set up a prediction.

In [6]:
ac.fluent.from_input?

Signature:
ac.fluent.from_input(
    ckpt: 'VALID_CKPT',
    input: 'str | dict[str, Any]',
    date: 'str | tuple[int, int, int]',
    lead_time: 'Any',
    *,
    ensemble_members: 'int' = 1,
    **kwargs,
) -> 'fluent.Action'
Docstring:
Run an anemoi model from a given input source

Parameters
----------
ckpt : VALID_CKPT
    Checkpoint to load
input : str | dict[str, Any]
    `anemoi.inference` input

All that is needed to run a prediction is:
- the checkpoint
- the input type
- the start time
- the lead time

From there, the number of ensemble members can be set; it is automatically added as a dimension of the graph.

Note:
    Running out of memory can occur at higher ensemble member counts.

In [7]:
# Build the model graph: 51 ensemble members initialised from MARS input at
# 2022-01-01T00:00, with a lead time of seven days.
model_action = ac.fluent.from_input(
    CKPT,
    'mars',
    '2022-01-01T00:00',
    lead_time='7D',
    ensemble_members=51,
)
# Display the nodes of the resulting graph.
model_action.nodes

            No post_processors defined. Accumulations will be accumulated from the beginning of the forecast.

            🚧🚧🚧 In a future release, the default will be to NOT accumulate from the beginning of the forecast. 🚧🚧🚧
            Update your config if you wish to keep accumulating from the beginning.
            https://github.com/ecmwf/anemoi-inference/issues/131
            
  from .autonotebook import tqdm as notebook_tqdm
Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 34.95it/s]
  versions[name] = str(module.__version__)
Environment validation failed. The following issues were found:
  python:
    Python version mismatch: 3.11.6 != 3.11.8
  missing:
    Missing module in inference environment: anemoi.graphs
  mismatch:
    Version of module anemoi.utils was lower in training then in inference: 0.4.1 <= 0.4.12
  critical mismatch:
    CRITICAL: Version of module anemoi.datasets was greater in training then in inference: 0.5.7 > 0.4.3
    CRITICAL: Version of module 

Now we can check the coordinates of the graph.

In [8]:
# Show the coordinates (dimension names and values) that label the graph's nodes.
model_action.nodes.coords

Coordinates:
  * param            (param) object 240B 'cp' 'tp' '100u' '100v' ... 'u' 't' 'z'
  * ensemble_member  (ensemble_member) int64 408B 0 1 2 3 4 5 ... 46 47 48 49 50
  * date             (date) <U16 64B '2022-01-01T00:00'
  * step             (step) int64 224B 6 12 18 24 30 36 ... 144 150 156 162 168

In [None]:
if PPCASCADE_IMPORTED:
    from ppcascade.utils.window import Range

    # Interpolation settings naming the "O320" grid.
    interpolation = dict(grid="O320")

    # Step windows used by the windowed products: each Range pairs a window
    # label with the list of steps that belong to it.
    windows = [
        Range(label, steps)
        for label, steps in (
            ("0-24", [6, 12, 18, 24]),
            ("12-24", [12, 18, 24]),
        )
    ]


In [None]:
def _ensemble_window_mean(action):
    """Average over each step window, then over the ensemble members.

    Relies on ``ppfluent.Action`` and on ``windows``, both of which are only
    available when ppcascade was imported successfully.

    Raises
    ------
    RuntimeError
        If ppcascade could not be imported.
    """
    if not PPCASCADE_IMPORTED:
        raise RuntimeError(
            "'ensemble_window_mean' requires ppcascade, which could not be imported"
        )
    return (
        action.switch(ppfluent.Action)
        .window_operation("mean", windows, dim="step", batch_size=2)
        .ensemble_operation("mean", dim="ensemble_member")
    )


def _ensemble_mean(action):
    """Select step 24 and average it over the ensemble members."""
    # The ensemble dimension of the graph is named 'ensemble_member' (see the
    # coordinates of `model_action.nodes`); 'member' is not one of its dimensions.
    return action.select({'step': 24}).mean(dim="ensemble_member")


# Product name -> callable that extends a model action with the nodes needed
# to compute that product.
PRODUCT_CATALOG = {
    'ensemble_window_mean': _ensemble_window_mean,
    'ensemble_mean': _ensemble_mean,
}

In [None]:
# Extend the model graph with the nodes of the 'ensemble_mean' product.
graph = PRODUCT_CATALOG["ensemble_mean"](model_action)

With the ensemble-mean graph built for the selected parameters, we can visualise it.

In [None]:
# Wrap the product graph in a Cascade and write a rendering of it to
# "EnsembleMean.html".
cascade_sel = Cascade.from_actions([graph])
cascade_sel.visualise(
    "EnsembleMean.html",
    preset='blob',
    cdn_resources='in_line',
)

Now the fun part, execution...

In [None]:
%%time

from cascade.executors.dask import DaskLocalExecutor 

# Execute the graph on a local Dask executor with a single worker and a
# single thread per worker, capped at 24GB of memory.
cascade_sel.executor = DaskLocalExecutor(
    n_workers=1,
    threads_per_worker=1,
    memory_limit="24GB",
)
results = cascade_sel.execute()

In [None]:
# Combine every result of the execution into a single earthkit-data 'multi' source.
combined_data = ekd.from_source('multi', [*results.values()])

In [None]:
# List the contents of the combined data as a sanity check before plotting.
combined_data.ls()

## Plotting

In [None]:
from earthkit.regrid import interpolate
import earthkit.plots.quickmap as qmap

In [None]:
# Parameter to plot, used below as the value of the `param` selection.
VAR_OF_INTEREST = "2t"

In [None]:
# The target grid is a global 0.25x0.25 degree regular latitude-longitude grid.
out_grid = dict(grid=[0.25, 0.25])
# Linearly interpolate the selected parameter onto the target grid.
r = interpolate(
    combined_data.sel(param=VAR_OF_INTEREST),
    out_grid=out_grid,
    method="linear",
)

In [None]:
# Quick-look map of the regridded field, passing 'viridis' as the colours.
qmap.plot(r, colors = 'viridis')