
# Running Ensemble Inference

The following notebook demonstrates how to use Earth-2 MIP's config schema and built-in
inference workflows to perform ensemble inference of the FourCastNetv2 small (FCNv2-sm)
weather model with an initial state pulled from the Climate Data Store (CDS) and
perturbed with random noise. The ensemble output is then loaded into an Xarray
Dataset, followed by some sample data analysis.
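
Conceptually, an ensemble is a stack of forecasts started from slightly different initial states. The sketch below builds such a stack with NumPy, assuming plain Gaussian (white) noise scaled by a noise amplitude and an unperturbed first member. It only illustrates the idea and is not Earth-2 MIP's perturbation code; `make_initial_ensemble` is a hypothetical helper.

```python
import numpy as np

def make_initial_ensemble(x0, n_members, noise_amplitude, seed=None):
    """Hypothetical helper: stack n_members perturbed copies of the initial state x0.

    Member 0 is left unperturbed (the deterministic member); every other member
    gets independent Gaussian noise with standard deviation noise_amplitude.
    """
    rng = np.random.default_rng(seed)
    # Copy x0 into a new leading "member" axis
    members = np.repeat(x0[np.newaxis, ...], n_members, axis=0).astype(float)
    noise = rng.standard_normal(members[1:].shape)
    members[1:] += noise_amplitude * noise
    return members  # shape: (n_members, *x0.shape)

# Usage: a toy 3x4 "field" with 4 members and amplitude 0.05
x0 = np.zeros((3, 4))
ens = make_initial_ensemble(x0, n_members=4, noise_amplitude=0.05, seed=12345)
print(ens.shape)  # (4, 3, 4)
```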

In summary, this notebook covers the following topics:

- Configuring and setting up the FCNv2 model registry
- Writing an ensemble configuration file
- Running ensemble inference in Earth-2 MIP to produce an Xarray Dataset
- Post-processing the results


## Set Up
We start by defining the ensemble run configuration. This assumes you have already
installed Earth-2 MIP from this repository; a few additional packages (Xarray,
Matplotlib, and Cartopy) are needed for post-processing.
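
Before running the notebook, you can check which of the extra post-processing packages are importable. This is a small standard-library sketch; the package list (taken from the imports used later in this notebook, plus `dask`, which Xarray uses when `chunks=` is passed to `open_dataset`) is an assumption about your environment, and `missing_packages` is a hypothetical helper.

```python
import importlib.util

def missing_packages(names):
    """Return the subset of module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Modules imported later in this notebook, plus dask for chunked (lazy) loading
required = ["xarray", "dask", "numpy", "matplotlib", "cartopy"]
missing = missing_packages(required)
if missing:
    print("Please install:", " ".join(missing))
else:
    print("All post-processing packages are available.")
```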



In [None]:
config = {
    "ensemble_members": 4,
    "noise_amplitude": 0.05,
    "simulation_length": 10,
    "weather_event": {
        "properties": {
            "name": "Globe",
            "start_time": "2017-08-23 12:00:00",  # Hurricane Harvey
            # "start_time": "2022-07-01 00:00:00",  # alternative start date
            "initial_condition_source": "era5",
        },
        "domains": [
            {
                "name": "global",
                "type": "Window",
                "diagnostics": [{"type": "raw", "channels": ["t2m", "u10m", "v10m", "u925", "v925"]}],
            }
        ],
    },
    "output_path": "outputs/01_ensemble_notebook",
    "output_frequency": 1,
    "weather_model": "fcnv2_sm",
    "seed": 12345,
    "use_cuda_graphs": False,
    "ensemble_batch_size": 1,
    "autocast_fp16": False,
    "perturbation_strategy": "correlated",
    "noise_reddening": 2.0,
}
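
The config selects `"perturbation_strategy": "correlated"` with `"noise_reddening": 2.0`. The sketch below illustrates one common way to make spatially correlated ("red") noise: filter white noise in Fourier space so its power spectrum falls off as a power law in wavenumber. Treating `noise_reddening` as that power-law exponent is an assumption, and this is not Earth-2 MIP's implementation; `reddened_noise` is a hypothetical helper.

```python
import numpy as np

def reddened_noise(shape, reddening=2.0, seed=None):
    """Hypothetical sketch of spatially correlated Gaussian noise on a periodic 2-D grid.

    White noise is scaled in Fourier space by k**(-reddening / 2), so the power
    spectrum decays like k**(-reddening). Larger `reddening` gives smoother fields.
    The output has zero mean and unit standard deviation.
    """
    rng = np.random.default_rng(seed)
    ny, nx = shape
    white = rng.standard_normal(shape)
    ky = np.fft.fftfreq(ny)[:, np.newaxis]   # wavenumbers along axis 0
    kx = np.fft.rfftfreq(nx)[np.newaxis, :]  # non-negative wavenumbers along axis 1
    k = np.sqrt(kx**2 + ky**2)
    # Power-law filter; the k = 0 (mean) component is set to zero
    scale = np.zeros_like(k)
    nonzero = k > 0
    scale[nonzero] = k[nonzero] ** (-reddening / 2)
    field = np.fft.irfft2(np.fft.rfft2(white) * scale, s=shape)
    return field / field.std()

# Usage: perturb a toy initial state with amplitude 0.05 (mirroring "noise_amplitude")
x0 = np.zeros((64, 64))
perturbed = x0 + 0.05 * reddened_noise(x0.shape, reddening=2.0, seed=12345)
```

Because most of the variance sits at large scales, neighboring grid points of the reddened field are strongly correlated, unlike white noise.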

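Next, we import the post-processing packages and load the ensemble output into an
Xarray Dataset. Each domain is stored in its own NetCDF group, so the `open_ensemble`
helper reads the data variables from the group named after the domain and copies the
time coordinate and global attributes from the file's root group. Note that this cell
opens the file at `datafile` (a Hurricane Harvey case) rather than reading from
`output_path`.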

In [None]:
import os
import xarray
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import numpy as np

datafile = '/e2ws/exercises/harvey_10_days.nc'


countries = cfeature.NaturalEarthFeature(
    category="cultural",
    name="admin_0_countries",
    scale="50m",
    facecolor="none",
    edgecolor="black",
)

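def open_ensemble(f, domain, chunks=None):
    """Open one domain group of an ensemble NetCDF file as an Xarray Dataset."""
    if chunks is None:
        chunks = {"time": 1}  # lazy loading, one time step per Dask chunk
    # The time coordinate and global attributes are read from the root group
    time = xarray.open_dataset(f).time
    root = xarray.open_dataset(f, decode_times=False)
    # The data variables live in a group named after the domain
    ds = xarray.open_dataset(f, chunks=chunks, group=domain)
    ds.attrs = root.attrs
    return ds.assign_coords(time=time)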


output_path = config["output_path"]
domain = config["weather_event"]["domains"][0]["name"]  # first domain ("global")
ensemble_members = config["ensemble_members"]
ds = open_ensemble(datafile, domain)
ds
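
The cell above reads the domain name with `config["weather_event"]["domains"][0]["name"]`, which only picks the first domain. If a config defines several domains, a small helper like the hypothetical `domain_channels` below can list each domain together with the channels saved by its `"raw"` diagnostic. It only walks the plain dictionary structure used in this notebook and does not call any Earth-2 MIP API.

```python
def domain_channels(config):
    """Hypothetical helper: map each domain name to the channels of its "raw" diagnostics."""
    result = {}
    for domain in config["weather_event"]["domains"]:
        channels = []
        for diagnostic in domain.get("diagnostics", []):
            if diagnostic.get("type") == "raw":
                channels.extend(diagnostic.get("channels", []))
        result[domain["name"]] = channels
    return result

# Usage with the same structure as the notebook's config (trimmed to the relevant keys)
example_config = {
    "weather_event": {
        "domains": [
            {
                "name": "global",
                "type": "Window",
                "diagnostics": [{"type": "raw", "channels": ["t2m", "u10m", "v10m", "u925", "v925"]}],
            }
        ]
    }
}
print(domain_channels(example_config))
# {'global': ['t2m', 'u10m', 'v10m', 'u925', 'v925']}
```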

In [None]:
# Crop to the region around Hurricane Harvey: 20-35°N, 110-80.5°W on the 0.25-degree grid.
# Longitudes are in degrees east (0-360), so 110°W becomes 360 - 110 = 250.
reg_ds = ds.sel(
    lat=list(np.arange(20, 35.25, 0.25)),
    lon=list(np.arange(360 - 110, 360 - 80.25, 0.25)),
)
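
The crop above converts west longitudes to the dataset's 0-360 degrees-east convention by hand (`360 - 110`). The NumPy sketch below does the same kind of selection with boolean masks on 1-D coordinate arrays. The global 0.25-degree grid it builds (latitudes from 90 to -90, longitudes from 0 to 359.75) is an assumed layout for illustration, and `to_degrees_east` and `box_indices` are hypothetical helpers.

```python
import numpy as np

def to_degrees_east(lon):
    """Convert longitudes in [-180, 180) to the [0, 360) degrees-east convention."""
    return np.mod(lon, 360.0)

def box_indices(lats, lons, lat_min, lat_max, lon_min, lon_max):
    """Return index arrays of the grid points inside a lat/lon box (bounds inclusive).

    Longitude bounds may be given as degrees west (negative); the box must not
    cross the 0/360 meridian.
    """
    lon_min, lon_max = to_degrees_east(lon_min), to_degrees_east(lon_max)
    lat_idx = np.nonzero((lats >= lat_min) & (lats <= lat_max))[0]
    lon_idx = np.nonzero((lons >= lon_min) & (lons <= lon_max))[0]
    return lat_idx, lon_idx

# Assumed global 0.25-degree grid with descending latitudes
lats = np.linspace(90, -90, 721)
lons = np.arange(0, 360, 0.25)
lat_idx, lon_idx = box_indices(lats, lons, 20, 35, -110, -80)
print(len(lat_idx), len(lon_idx))  # 61 121
print(lons[lon_idx[0]], lons[lon_idx[-1]])  # 250.0 280.0
```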

## Post-Processing
With inference complete, now comes the fun part: post-processing and analysis!
Now that the data is in an Xarray Dataset, you can manipulate it to your heart's content.
Here we demonstrate some common plotting and analysis workflows you may be
interested in. The post-processing packages used below (Matplotlib and Cartopy) were
imported earlier in this notebook.

(You may need to `pip install matplotlib cartopy` first.)



Next, let's animate the 10 m wind speed, computed from the zonal and meridional
components as `sqrt(u10m**2 + v10m**2)`, over the cropped region. Since we have an
ensemble of predictions, you can choose which member to display with `ensemble`: the
first member (index 0) is the deterministic member, and the others are perturbed. It is
also worth comparing the perturbed members against the ensemble standard deviation.
Perturbed members may look a little noisy; if so, the noise amplitude is probably too
high. Try lowering `noise_amplitude` in the config or changing `perturbation_strategy`
to see what happens.
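
The quantities described above reduce to a couple of array operations: wind speed is the magnitude of the (u, v) vector, and the ensemble mean and spread are reductions over the member axis. Below is a NumPy sketch, assuming arrays laid out as `[ensemble, time, lat, lon]` like the dataset's variables; `ensemble_mean_and_spread` is a hypothetical helper. On Xarray DataArrays, `.mean()` and `.std()` over the member dimension do the same job.

```python
import numpy as np

def wind_speed(u, v):
    """Magnitude of the horizontal wind vector, e.g. from u10m and v10m."""
    return np.hypot(u, v)  # same as sqrt(u**2 + v**2), without overflow in the squares

def ensemble_mean_and_spread(field, ddof=0):
    """Mean and standard deviation across members, assuming members are on axis 0."""
    return field.mean(axis=0), field.std(axis=0, ddof=ddof)

# Toy data: 2 members, 1 time step, 1x2 grid -> shape [ensemble, time, lat, lon]
u = np.array([[[[3.0, 0.0]]], [[[6.0, 0.0]]]])
v = np.array([[[[4.0, 2.0]]], [[[8.0, 4.0]]]])
speed = wind_speed(u, v)               # member 0: [5, 2], member 1: [10, 4]
mean, spread = ensemble_mean_and_spread(speed)
print(mean.ravel(), spread.ravel())    # [7.5 3. ] [2.5 1. ]
```

`ddof=0` matches the default of NumPy and Xarray; with small ensembles, `ddof=1` gives the unbiased sample variance instead.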



In [None]:
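scale = 1        # spatial stride; increase to speed up plotting
max_frames = 20  # maximum number of frames to plot
ensemble = 2     # ensemble member to plot (0 is the deterministic member)

time_str = 'lead time:'
projection = ccrs.PlateCarree()
# 10 m wind speed from the zonal (u10m) and meridional (v10m) components
var_ds = np.sqrt(np.square(reg_ds.u10m) + np.square(reg_ds.v10m))
min_val = 0
# Fix the color scale to the maximum over all lead times of the plotted member
max_val = float(var_ds[ensemble].max())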

# define plots
def make_figure():
    fig = plt.figure(figsize=(11,5))
    ax = fig.add_subplot(1, 1, 1, projection=projection)

    ax.add_feature(cfeature.COASTLINE,lw=.5)
    ax.add_feature(cfeature.RIVERS,lw=.5)
    ax.add_feature(cfeature.BORDERS, linewidth=0.6, edgecolor='dimgray')

    # Optional explicit tick locations, e.g.:
    # ax.set_xticks(np.arange(0., 18., 2.5), crs=ccrs.PlateCarree())
    # ax.set_yticks(np.arange(43.5, 50., 1.5), crs=ccrs.PlateCarree())
    lon_formatter = LongitudeFormatter(zero_direction_label=False)
    lat_formatter = LatitudeFormatter()
    ax.xaxis.set_major_formatter(lon_formatter)
    ax.yaxis.set_major_formatter(lat_formatter)

    return fig, ax

def make_frame(frame):
    # frame == -1 is the initial frame drawn by init_func; it shows lead time 0
    step = max(frame, 0)
    print(f'\rprocessing frame {step+1} of {min(max_frames, var_ds.shape[1])}', end='')
    plot_ds = var_ds[ensemble, step, :, :]
    pc = ax.pcolormesh(reg_ds.lon[::scale], reg_ds.lat[::scale], plot_ds[::scale, ::scale],
                       transform=projection, cmap='plasma', vmin=min_val, vmax=max_val)

    if frame == -1:
        # Add the colorbar only once, on the initial frame
        cbar = fig.colorbar(pc, extend='both', shrink=0.8, ax=ax)
        cbar.set_label('wind speed [m/s]', fontsize=12)

    # Each model step is 6 hours
    ax.set_title(f'{time_str} {step*6}:00:00', fontsize=14)

    return pc

def animate(frame):
    return make_frame(frame)

def first_frame():
    return make_frame(-1)

# make animation
%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()
ani = animation.FuncAnimation(fig,
                              animate,
                              frames=min(max_frames, var_ds.shape[1]),
                              init_func=first_frame,
                              blit=False,
                              repeat=False,
                              interval=200)  # delay between frames in milliseconds
plt.close('all')
ani
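
The animation labels frames by lead time, assuming one frame per 6-hour model step (`frame*6`). If you would rather show the valid date and time, add the lead time to the config's `start_time`. Below is a standard-library sketch; `valid_time` is a hypothetical helper, and the 6-hour step mirrors the assumption already made in the plotting code.

```python
from datetime import datetime, timedelta

def valid_time(start_time, frame, step_hours=6):
    """Valid time of a given output frame, assuming a fixed step between frames."""
    start = datetime.strptime(start_time, "%Y-%m-%d %H:%M:%S")
    return start + timedelta(hours=frame * step_hours)

start_time = "2017-08-23 12:00:00"  # config["weather_event"]["properties"]["start_time"]
for frame in (0, 2, 8):
    print(frame * 6, "h ->", valid_time(start_time, frame).strftime("%Y-%m-%d %H:%M"))
```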

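Finally, let's plot the 10 m zonal wind at time index 10 on a Nearside Perspective
projection centered near New York City, using a custom purple-white-green colormap
(`Nvidia_cmap`).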

In [None]:
nyc_lat = 40        # degrees north
nyc_lon = 360 - 74  # 74°W expressed in degrees east



def Nvidia_cmap():
    colors = ["#8946ff", "#ffffff", "#00ff00"]
    cmap = mcolors.LinearSegmentedColormap.from_list("custom_cmap", colors)
    return cmap


plt.close("all")

# Ensemble mean over the member axis (axis 0) at time index 10; dims are [ensemble, time, lat, lon]
data = ds.u10m[:, 10, :, :].mean(axis=0)
# data = ds.u925[:, 10, :, :].mean(axis=0)

fig = plt.figure(figsize=(9, 6), dpi=100)
proj = ccrs.NearsidePerspective(central_longitude=nyc_lon, central_latitude=nyc_lat)

ax = fig.add_subplot(111, projection=proj)
ax.set_title("Ensemble mean 10 m zonal wind [m/s]")
img = ax.pcolormesh(
    ds.lon,
    ds.lat,
    data,
    transform=ccrs.PlateCarree(),
    cmap=Nvidia_cmap(),
    # Diverging norm centered on 0 m/s; the norm belongs on the mappable, not the colorbar
    norm=mcolors.CenteredNorm(vcenter=0),
)
ax.coastlines(linewidth=1)
ax.add_feature(countries, edgecolor="black", linewidth=0.25)
plt.colorbar(img, ax=ax, shrink=0.40)
gl = ax.gridlines(draw_labels=True, linestyle="--")
os.makedirs(output_path, exist_ok=True)  # make sure the output directory exists
plt.savefig(f"{output_path}/global_mean_zonal_wind_contour.png")
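
The cell above uses `nyc_lat` and `nyc_lon` only to center the map. To pull a time series at that location out of the ensemble, you first need the nearest grid point. The NumPy sketch below finds it on 1-D latitude and longitude coordinates, handling the wrap-around at 0/360 degrees. The 0.25-degree grid is again an assumed layout, and `nearest_grid_index` is a hypothetical helper. With Xarray, `ds.sel(lat=..., lon=..., method="nearest")` performs a similar lookup but does not wrap longitudes, so pass them in degrees east.

```python
import numpy as np

def nearest_grid_index(lats, lons, lat0, lon0):
    """Indices (i, j) of the grid point closest to (lat0, lon0).

    Longitude differences are wrapped into [-180, 180) so that, e.g., -74 and 286
    refer to the same meridian. Each axis is searched separately, which is
    sufficient on a regular lat/lon grid.
    """
    i = int(np.argmin(np.abs(lats - lat0)))
    dlon = (lons - lon0 + 180.0) % 360.0 - 180.0
    j = int(np.argmin(np.abs(dlon)))
    return i, j

lats = np.linspace(90, -90, 721)   # assumed 0.25-degree grid, descending latitudes
lons = np.arange(0, 360, 0.25)
i, j = nearest_grid_index(lats, lons, 40.7, -74.0)  # approximate New York City location
print(lats[i], lons[j])  # 40.75 286.0
```

On the notebook's dataset, the resulting indices could then be used with, e.g., `ds.u10m.isel(lat=i, lon=j)` to get one time series per member.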