# Demo notebook for predictions with GenSIM

This notebook shows how GenSIM and its pre-trained weights can be used for efficient Arctic-wide forecasting. You will learn what the expected data look like, how the model can be used for autoregressive prediction, and how to reproduce some of the results presented in the manuscript.

## Download demo dataset

In [None]:
# Uncomment to download demo file from Zenodo
# ! wget -O data/auxiliary/ds_demo.nc https://zenodo.org/records/17535317/files/ds_demo.nc?download=1

## Download model weights

In [None]:
# Uncomment to download model weights from HuggingFace
# When prompted for a password, use an access token with write permissions.
# Generate one from your settings: https://huggingface.co/settings/tokens

# ! git lfs install
# ! git clone https://huggingface.co/tobifinn/GenSIM data/models

# Load the needed libraries

These are the minimal libraries needed.

In [None]:
# For empty context manager
import contextlib
# To load the data
import xarray as xr
# To read in the config and initialize model
from hydra import compose, initialize
from hydra.utils import instantiate
# To load the network weights
from safetensors import safe_open
# To run the prediction
import torch
# For general computation
import numpy as np
# For nice plotting with timeaxis
import pandas as pd
# For metric estimation
from scipy import stats
# For progress bar
from tqdm.notebook import tqdm 

# To plot the predictions
import matplotlib.pyplot as plt
import matplotlib.gridspec as mpl_gs
import matplotlib.colors as mpl_c
import cartopy.crs as ccrs
import cartopy
import cmocean

# To define the flow matching model
from gensim.model import FlowMatchingModel
import gensim.network

# Settings

In [None]:
# Set the manual seed, used for the sampling in the prediction steps
torch.manual_seed(42)

# Number of ensemble members
n_ens = 8

# Device used for inference (use of single GPU only)
device = torch.device("cuda:0")

# Activate bfloat16 mixed precision, with which the model was trained (only usable on recent NVIDIA GPUs)
use_mixed = True

# Path to the model checkpoints
model_path = "data/models/model_weights_ema.safetensors"

# To override the use of flash attention (only available for recent NVIDIA GPUs)
gensim.network.USE_FLASH_ATTN = True

In [None]:
# Context manager used if mixed precision is activated
prediction_context = torch.amp.autocast("cuda", dtype=torch.bfloat16) if use_mixed else contextlib.nullcontext()

# Load data

## Auxiliary data

The auxiliary dataset is already included in this repository.

In [None]:
ds_aux = xr.open_dataset("data/auxiliary/ds_auxiliary.nc")

### Create mesh

The mesh coordinates are given as Cartesian coordinates in the cell centers in a North Polar Stereographic projection.

In [None]:
mesh = ds_aux[["x_coord", "y_coord"]].to_dataarray("coord_names").values

# Convert from metres into kilometres
mesh = mesh / 1000

# Add ensemble dimension
mesh = mesh[None].repeat(n_ens, axis=0)

# Convert into torch tensor
mesh = torch.as_tensor(mesh, device=device)
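
The mesh preparation can be reproduced on a toy grid to check the resulting array layout. The sketch below is a hypothetical NumPy-only stand-in (no xarray or torch) that assumes regularly spaced 12.5 km cell centres; the real coordinates come from `ds_auxiliary.nc`.

```python
import numpy as np

n_ens = 8          # ensemble members, as in the settings cell
ny, nx = 4, 5      # toy grid size (the real grid is 512 x 512)
dx_m = 12_500.0    # assumed grid spacing in metres

# Hypothetical cell-centre coordinates in metres on a regular stereographic grid
x_1d = (np.arange(nx) - nx / 2 + 0.5) * dx_m
y_1d = (np.arange(ny) - ny / 2 + 0.5) * dx_m
x_coord, y_coord = np.meshgrid(x_1d, y_1d, indexing="xy")  # each (ny, nx)

# Stack into (2, ny, nx), mirroring `to_dataarray("coord_names")`
mesh = np.stack([x_coord, y_coord], axis=0)

# Metres -> kilometres, then replicate along a new leading ensemble axis
mesh = mesh / 1000
mesh = mesh[None].repeat(n_ens, axis=0)  # (n_ens, 2, ny, nx)
```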

### Create mask

The mask is False (0) over land and True (1) over open ocean. It is used within the neural network to represent interactions with land and to mask the predictions over land.

In [None]:
mask = ds_aux["mask"].values

# Add ensemble dimension
mask = mask[None, None].repeat(n_ens, axis=0)

# Convert into torch tensor
mask = torch.as_tensor(mask, device=device)

# Number of valid grid points needed for metric estimation
n_grid = ds_aux["mask"].values.sum()
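
How the singleton channel axis of the mask interacts with a six-variable state can be illustrated with a small NumPy sketch. It assumes that masking means zeroing the land cells, consistent with land being filled with zeros in the data; how GenSIM applies the mask internally is not shown here.

```python
import numpy as np

n_ens = 3
# Toy land-sea mask: True (1) for ocean, False (0) for land
mask_2d = np.array([[1, 1, 0],
                    [1, 0, 0]], dtype=bool)

# Add ensemble and channel axes: (n_ens, 1, ny, nx), as in the cell above
mask = mask_2d[None, None].repeat(n_ens, axis=0)

# A fake 6-variable prediction that is non-zero everywhere
prediction = np.ones((n_ens, 6, *mask_2d.shape))

# The singleton channel axis broadcasts over all six variables
masked_prediction = np.where(mask, prediction, 0.0)

# Number of ocean cells, used as the denominator in the metrics
n_grid = mask_2d.sum()
```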

## Load the dataset

The demo dataset used here contains all sea-ice states and forcings in 12-hour steps between 2018-01-01 03:00 and 2018-01-04 15:00 (eight time steps spanning 3.5 days), extracted from the original test dataset.

We will split the data into initial conditions, forcings, and the truth.

In [None]:
# Load the netCDF dataset
ds_demo = xr.open_dataset("data/auxiliary/ds_demo.nc")["datacube"]

# Fill land cells (stored as NaN) with zeros
ds_demo = ds_demo.fillna(0)

# Get how many prediction steps we should do
n_steps = len(ds_demo["time"]) - 1
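
The number of prediction steps follows directly from the time axis described above. This stdlib-only sketch rebuilds the timestamps instead of reading them from `ds_demo.nc`; it shows that the file holds eight time steps and therefore seven 12-hour prediction steps, reaching 3.5 days of lead time.

```python
from datetime import datetime, timedelta

start = datetime(2018, 1, 1, 3)   # first time step in the demo file
end = datetime(2018, 1, 4, 15)    # last time step in the demo file
step = timedelta(hours=12)

# Enumerate all time steps between start and end (inclusive)
times = []
t = start
while t <= end:
    times.append(t)
    t += step

n_steps = len(times) - 1  # the first time step is the initial condition
lead_days = [(t - start) / timedelta(days=1) for t in times]
```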

### Initial conditions

The initial conditions are given by the six sea-ice states (thickness, concentration, damage, x-drift, y-drift, snow-on-ice thickness) at the first time step.

The shape of the initial conditions will be (n_ens, 6, 512, 512) with dimensions (ensemble members, variables, y, x).

In [None]:
# Slice first time step
ds_initial = ds_demo.isel(time=0)

# Get only state variables
ds_initial = ds_initial.sel(var_names=["sit", "sic", "sid", "siu", "siv", "snt"])

# Create an ensemble of initial conditions
initial_conditions = ds_initial.values[None].repeat(n_ens, axis=0)

# Convert the initial conditions into a torch tensor
initial_conditions = torch.as_tensor(initial_conditions)

In [None]:
# To show what the initial conditions look like
print(ds_initial)

### Forcings

The forcings are extracted from the ERA5 reanalysis and consist of four variables: 2-metre temperature, 2-metre specific humidity, and the wind components rotated into the x- and y-directions of the grid.

To save GPU memory, we keep the forcings on the CPU as a NumPy array and convert only the slice needed for each step into a torch tensor on the GPU during inference.

The forcings array has the shape (n_ens, 8, 4, 512, 512) with dimensions (ensemble members, time steps, variables, y, x). In each prediction step, the time steps t and t+12 h are sliced, giving tensors of shape (n_ens, 2, 4, 512, 512).

In [None]:
# Select forcings variables
ds_forcings = ds_demo.sel(var_names=['tus', 'huss', 'uas', 'vas'])

# Create ensemble dimension
forcings = ds_forcings.values[None].repeat(n_ens, axis=0)
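
The two-step forcing window can be checked with toy arrays. This NumPy sketch assumes the same (ensemble, time, variable, y, x) layout as `forcings` and encodes each time index in the values, so the slice used in the prediction loop visibly picks the times t and t + 12 h.

```python
import numpy as np

n_ens, n_time, n_vars, ny, nx = 2, 8, 4, 3, 3

# Toy forcings where every value equals its time index
per_time = np.arange(n_time, dtype=np.float32)[:, None, None, None]
per_time = np.broadcast_to(per_time, (n_time, n_vars, ny, nx))

# Add the ensemble axis: (n_ens, n_time, n_vars, ny, nx)
forcings = per_time[None].repeat(n_ens, axis=0)

n_steps = n_time - 1
windows = []
for step in range(n_steps):
    # Forcings at t and t + 12 h, as in the prediction loop
    curr_forcings = forcings[:, step:step + 2]
    windows.append(curr_forcings)
```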

In [None]:
# To show what the forcings look like
print(ds_forcings)

The four degree-day features (positive and freezing degree days over 30 days, and positive and freezing degree days over 366 days) are only used at the current time step. They serve as proxies for the longer-term thermodynamic evolution.
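
Degree days accumulate how far the air temperature lies above (positive) or below (freezing) a threshold over a trailing window. The sketch below uses a common textbook definition with daily-mean 2-metre temperatures and a 0 °C threshold; both choices are assumptions, and the exact recipe used to build `pdd_month`, `fdd_month`, `pdd_year`, and `fdd_year` is not given in this notebook. The helper `degree_days` is hypothetical.

```python
import numpy as np

def degree_days(t_daily_c, window, threshold_c=0.0):
    """Trailing-window positive and freezing degree days.

    Assumes daily-mean temperatures in degrees Celsius and a 0 degC
    threshold; the thresholds used for the GenSIM data may differ.
    """
    t = np.asarray(t_daily_c, dtype=float)
    warm = np.clip(t - threshold_c, 0.0, None)   # contribution above threshold
    cold = np.clip(threshold_c - t, 0.0, None)   # contribution below threshold
    # Running sums over the last `window` days (shorter at the start)
    csum_warm = np.concatenate([[0.0], np.cumsum(warm)])
    csum_cold = np.concatenate([[0.0], np.cumsum(cold)])
    idx = np.arange(1, t.size + 1)
    start = np.maximum(idx - window, 0)
    pdd = csum_warm[idx] - csum_warm[start]
    fdd = csum_cold[idx] - csum_cold[start]
    return pdd, fdd

# Toy example with a 2-day window; the notebook's features use 30 and 366 days
pdd, fdd = degree_days([-10.0, -5.0, 2.0, 3.0], window=2)
```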

In [None]:
# Select degree day variables
ds_degree_days = ds_demo.sel(var_names=['pdd_month', 'fdd_month', 'pdd_year', 'fdd_year'])

# Create ensemble dimension
degree_days = ds_degree_days.values[None].repeat(n_ens, axis=0)

In [None]:
# To show what the degree days look like
print(ds_degree_days)

### Truth

Our truth is the trajectory of sea-ice states over all time steps.

The truth stays on the CPU as an xarray DataArray.

The shape of the truth will be (8, 6, 512, 512) with dimensions (all time steps, variables, y, x).

In [None]:
# Get only state variables
ds_truth = ds_demo.sel(var_names=["sit", "sic", "sid", "siu", "siv", "snt"])

# Load the model

The model is an all-inclusive wrapper: it contains the neural network, the encoder/decoder (scaling) from/to physical space, the second-order flow matching sampler, and the domain decomposition for inference.

The model is configured with hydra and built from this configuration. After building the model, the network weights are loaded, and the model is optionally compiled with a just-in-time compiler for faster inference.
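
For intuition about a second-order flow matching sampler, a generic Heun (predictor-corrector) integrator for a flow-matching ODE is sketched below. In a common convention, a learned velocity field transports noise at t = 0 to a sample at t = 1; here a toy linear field with a known solution stands in for the network. This is not GenSIM's actual sampler, whose time stepping and conditioning are defined in the `gensim` package; `heun_sample` is a hypothetical name.

```python
import numpy as np

def heun_sample(velocity, x0, n_steps):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with Heun's method."""
    x = np.asarray(x0, dtype=float)
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = i * dt
        v1 = velocity(x, t)                # slope at the start of the step
        x_euler = x + dt * v1              # first-order (Euler) predictor
        v2 = velocity(x_euler, t + dt)     # slope at the predicted end point
        x = x + 0.5 * dt * (v1 + v2)       # second-order corrector
    return x

# Toy velocity dx/dt = -x, whose exact solution at t=1 is x0 * exp(-1)
x1 = heun_sample(lambda x, t: -x, x0=np.ones(3), n_steps=50)
```

Averaging the slopes at both ends of each step makes the global error scale with the square of the step size, which is why second-order samplers can reach a given accuracy in fewer network evaluations than Euler steps.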

## Load the config

In [None]:
with initialize(version_base=None, config_path=".", job_name="prediction"):
    cfg = compose(
        config_name="config.yaml",
        overrides=[]                          # If you want to change config options on the fly
    )

# We are only interested in the surrogate model part of the config
cfg = cfg["surrogate"]

## Instantiate the model

The model is instantiated from the loaded config, which creates all elements needed for the model. It is moved to the selected device (likely a GPU) after the weights have been loaded.

In [None]:
model = instantiate(cfg)

## Load the neural network weights

The neural network weights are stored in the safetensors format, which, unlike pickle-based checkpoints, cannot execute arbitrary code when loaded.

In [None]:
# Open the safetensor
with safe_open(model_path, framework="pt", device="cuda") as f:
    saved_keys = list(f.keys())
    
    # Get the network state dict
    network_state_dict = model.network.state_dict()

    # Update the network state dict with the weights from the tensor
    network_state_dict.update({key: f.get_tensor(key) for key in saved_keys})

    # Load the updated state dict into the network
    model.network.load_state_dict(network_state_dict)

    # Helpful message if keys are missing
    missing_keys = [key for key in network_state_dict.keys() if key not in saved_keys]
    if missing_keys:
        print("Missing keys in loaded weights:", missing_keys)
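
The key bookkeeping in the cell above can be expressed with plain Python sets. This hypothetical helper, `merge_state_dict`, operates on ordinary dictionaries instead of tensors and, unlike the cell above, also reports keys that exist in the file but not in the network (which `load_state_dict` would reject in its default strict mode).

```python
def merge_state_dict(network_state, saved_state):
    """Overlay saved weights on a network state dict and report key mismatches."""
    saved_keys = set(saved_state)
    network_keys = set(network_state)
    merged = dict(network_state)
    # Only keys the network knows are copied; extras are reported instead
    merged.update({k: v for k, v in saved_state.items() if k in network_keys})
    missing = sorted(network_keys - saved_keys)      # keep their initial values
    unexpected = sorted(saved_keys - network_keys)   # present in file, unused
    return merged, missing, unexpected

# Toy example with scalars standing in for tensors
merged, missing, unexpected = merge_state_dict(
    {"a": 0, "b": 0, "c": 0},
    {"a": 1, "b": 2, "d": 9},
)
```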

## Compile model

The model is moved to the selected device and optionally compiled.

In [None]:
# Move model components to device
model = model.to(device)

# Set the correct inference model (set compile to False to deactivate compilation of model)
model.set_inference_model(compile=True)

# Prediction

To produce the trajectory, we loop through the prediction steps. In each step, the forcings are loaded onto the GPU and the model produces the prediction, which is appended to a trajectory list. This list is later stacked and converted into an xarray.DataArray like the truth.

The explicit, readable implementation shown here is not the most efficient; for example, the fields could be pre-stored on the GPU, avoiding copies between CPU and GPU. Hence, the runtime of this loop is not representative.

In [None]:
# Resolution needed for prediction
resolution = torch.full((n_ens, 1), 12.5, device=device)

In [None]:
# Initialise the trajectory list
trajectory = [initial_conditions]

# Set the current conditions to the initial conditions
curr_conditions = initial_conditions.to(device)

# Loop through all forecasting steps
for step in tqdm(range(n_steps)):
    # Get current forcings as tensor
    curr_forcings = forcings[:, step:step+2]
    curr_forcings = torch.as_tensor(curr_forcings, device=device)

    # Get current degree days as tensor
    curr_degree_days = degree_days[:, step]
    curr_degree_days = torch.as_tensor(curr_degree_days, device=device)
    
    with prediction_context:
        # Make the prediction
        curr_conditions = model(
            curr_conditions[:, None],  # Add timeaxis (only current timestep)
            curr_forcings,
            resolution=resolution,
            mesh=mesh,
            mask=mask,
            degree_days=curr_degree_days
        )

    # Add the current prediction to the trajectory list
    trajectory.append(curr_conditions.cpu())
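
Stripped of the model specifics, the loop above is a standard autoregressive rollout. In this NumPy sketch, a hypothetical `toy_step` function replaces the GenSIM model call and simply adds the forcing increment, so the output is easy to verify; in the notebook, the step also receives the mesh, mask, resolution, and degree days.

```python
import numpy as np

def rollout(step_fn, initial_state, forcings, n_steps):
    """Autoregressive rollout: each prediction becomes the next input."""
    trajectory = [initial_state]
    state = initial_state
    for step in range(n_steps):
        window = forcings[:, step:step + 2]   # forcings at t and t + 12 h
        state = step_fn(state, window)
        trajectory.append(state)
    # Stack along a new leading time axis: (n_steps + 1, n_ens, ...)
    return np.stack(trajectory, axis=0)

def toy_step(state, window):
    # Hypothetical stand-in for the model: add the mean forcing increment
    return state + (window[:, 1] - window[:, 0]).mean(axis=1, keepdims=True)

n_ens, n_time = 2, 8
init = np.zeros((n_ens, 6, 3, 3))
forcings = np.broadcast_to(
    np.arange(n_time, dtype=float)[None, :, None, None, None],
    (n_ens, n_time, 4, 3, 3),
)
traj = rollout(toy_step, init, forcings, n_steps=n_time - 1)
```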

In [None]:
# Stack the trajectory into a time axis
predictions = torch.stack(trajectory, dim=0)

# Convert predictions from torch to numpy
predictions = predictions.numpy()

# Copy predictions into xarray
predictions = ds_truth.expand_dims(ensemble=np.arange(n_ens), axis=1).copy(data=predictions)

# Plotting

We plot metrics for the quality of the predictions as well as individual predictions.

## Reference

As a simple yet hard-to-beat reference in sea-ice modelling, we use a persistence forecast, which keeps the initial conditions constant over the whole forecast.

In [None]:
# Use the initial conditions for all time steps
persistence = ds_initial.expand_dims(time=ds_truth["time"], axis=0)

## Prediction error

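We estimate the prediction error of the individual ensemble members and of the ensemble mean. Since land cells are zero in both the truth and the predictions, summing over the full grid and dividing by the number of ocean cells `n_grid` gives the error over ocean cells only.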

In [None]:
# Error of persistence
error_persist = persistence-ds_truth
rmse_persist = np.sqrt((error_persist**2).sum(["y", "x"])/n_grid)

# Error of ensemble members
error_members = predictions-ds_truth
rmse_members = np.sqrt((error_members**2).sum(["ensemble", "y", "x"])/n_ens/n_grid)

# Error of ensemble mean
error_mean = predictions.mean("ensemble")-ds_truth
rmse_mean = np.sqrt((error_mean**2).sum(["y", "x"])/n_grid)
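
Dividing the full-grid sum by `n_grid` gives an ocean-only RMSE because land cells are zero in both fields. The NumPy sketch below checks this on random toy data (not GenSIM output) and also shows that the ensemble-mean error never exceeds the pooled member error, which follows from the convexity of the squared error.

```python
import numpy as np

rng = np.random.default_rng(0)
ocean = np.array([[1, 1, 0], [1, 1, 0]], dtype=bool)   # toy mask, 4 ocean cells
n_grid = ocean.sum()

truth = np.where(ocean, rng.normal(size=ocean.shape), 0.0)
members = np.where(ocean, truth + rng.normal(size=(5, *ocean.shape)), 0.0)

# Notebook-style RMSE: sum over the full grid, divide by the number of ocean cells
rmse_members = np.sqrt(((members - truth) ** 2).sum() / members.shape[0] / n_grid)
rmse_mean = np.sqrt(((members.mean(axis=0) - truth) ** 2).sum() / n_grid)

# The same quantities computed explicitly over ocean cells only
rmse_members_ocean = np.sqrt(((members - truth)[:, ocean] ** 2).mean())
rmse_mean_ocean = np.sqrt(((members.mean(axis=0) - truth)[ocean] ** 2).mean())
```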

In [None]:
timedelta_axis = pd.timedelta_range("0h", periods=n_steps+1, freq="12h") / pd.Timedelta("1d")

### Plot the error

In [None]:
fig, ax = plt.subplots(nrows=2, gridspec_kw={"hspace": 0.10})
ax[0].plot(
    timedelta_axis, rmse_persist.sel(var_names="sit"),
    c="0.5", ls=":", label="Persistence", lw=1
)
ax[0].plot(
    timedelta_axis, rmse_members.sel(var_names="sit"),
    c="#E62A07", ls="--", label="GenSIM member", lw=1
)
ax[0].plot(
    timedelta_axis, rmse_mean.sel(var_names="sit"),
    c="#E65007", ls="-", label="GenSIM mean", lw=1.5
)
ax[0].set_xlim(0, 3.6)
ax[0].set_xticks([0, 1, 2, 3], [])
ax[0].set_ylim(0, 0.17)
ax[0].set_ylabel("RMSE Thickness (m)")
ax[0].legend()

ax[1].plot(
    timedelta_axis, rmse_persist.sel(var_names="siu"),
    c="0.5", ls=":", label="Persistence", lw=1
)
ax[1].plot(
    timedelta_axis, rmse_members.sel(var_names="siu"),
    c="#E62A07", ls="--", label="GenSIM member", lw=1
)
ax[1].plot(
    timedelta_axis, rmse_mean.sel(var_names="siu"),
    c="#E65007", ls="-", label="GenSIM mean", lw=1.5
)
ax[1].set_xlim(0, 3.6)
ax[1].set_xticks([0, 1, 2, 3], [0, 1, 2, 3])
ax[1].set_xlabel("Lead time (days)")
ax[1].set_ylim(0, 0.08)
ax[1].set_ylabel("RMSE x-drift (m/s)")

In this demo, GenSIM improves upon persistence for both shown variables at all lead times.

## Power spectrum

We estimate the power spectrum to analyse how strongly the forecasts are smoothed. We only analyse the predictions at 3.5 days of lead time, where we expect the strongest smoothing.

In [None]:
# Slice specifying central Arctic (128x128 grid points)
SPECTRUM_SLICES = (
    slice(315, 443),
    slice(180, 308)
)

# Settings needed to estimate spectrum
norm = 1 / 128 / 128
yc, xc = np.ogrid[np.s_[-64:64],np.s_[-64:64]]
radius = np.sqrt(yc**2 + xc**2).round()
r_range = np.arange(65)
freqs = np.fft.fftfreq(128, 1)[:64]

In [None]:
# Function to estimate the spectrum
def estimate_spectrum(field):
    psd = np.fft.fftshift(np.fft.fft2(field))
    psd = np.abs(psd)**2 * norm
    psd = stats.binned_statistic(
        radius.flatten(), psd.flatten(), bins=r_range-0.5
    ).statistic
    return psd
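
The same radially averaged spectrum can be computed without SciPy by binning with `np.bincount`. This sketch mirrors `estimate_spectrum` (squared FFT amplitude normalised by the number of grid points, shifted so the zero wavenumber sits at the centre, averaged over rings of integer radius) and checks it on a pure cosine with four wavelengths across the domain, whose energy should land in ring 4. The helper name `radial_spectrum` is hypothetical.

```python
import numpy as np

def radial_spectrum(field):
    """Radially averaged power spectrum of a square 2-D field."""
    n = field.shape[0]
    psd = np.abs(np.fft.fftshift(np.fft.fft2(field))) ** 2 / n**2
    # Integer radius of each wavenumber pair relative to the shifted origin
    k = np.arange(-n // 2, n // 2)
    radius = np.rint(np.hypot(k[:, None], k[None, :])).astype(int)
    # Mean power per radius bin, restricted to radii below n // 2
    keep = radius < n // 2
    total = np.bincount(radius[keep], weights=psd[keep], minlength=n // 2)
    count = np.bincount(radius[keep], minlength=n // 2)
    return total / count

# Test field: a cosine with four wavelengths along x on a 32 x 32 grid
x = np.arange(32)
field = np.tile(np.cos(2 * np.pi * 4 * x / 32), (32, 1))
spectrum = radial_spectrum(field)
```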

In [None]:
spectrum_truth = xr.apply_ufunc(
    estimate_spectrum,
    ds_truth.isel(time=-1, y=SPECTRUM_SLICES[0], x=SPECTRUM_SLICES[1]),   # Slice last step and central Arctic
    input_core_dims=[["y", "x"]],
    output_core_dims=[["freqs"]],
    vectorize=True,
    dask='parallelized',
    output_dtypes=[np.float32],
    dask_gufunc_kwargs=dict(
        output_sizes={"freqs": 64}
    )
).compute()

spectrum_members = xr.apply_ufunc(
    estimate_spectrum,
    predictions.isel(time=-1, y=SPECTRUM_SLICES[0], x=SPECTRUM_SLICES[1]),   # Slice last step and central Arctic
    input_core_dims=[["y", "x"]],
    output_core_dims=[["freqs"]],
    vectorize=True,
    dask='parallelized',
    output_dtypes=[np.float32],
    dask_gufunc_kwargs=dict(
        output_sizes={"freqs": 64}
    )
).compute()
# Average the spectrum across the ensemble dimension in log space
spectrum_members = np.exp(np.log(spectrum_members).mean(["ensemble"]))

spectrum_mean = xr.apply_ufunc(
    estimate_spectrum,
    predictions.mean("ensemble").isel(time=-1, y=SPECTRUM_SLICES[0], x=SPECTRUM_SLICES[1]),   # Take ensemble mean and slice last step and central Arctic
    input_core_dims=[["y", "x"]],
    output_core_dims=[["freqs"]],
    vectorize=True,
    dask='parallelized',
    output_dtypes=[np.float32],
    dask_gufunc_kwargs=dict(
        output_sizes={"freqs": 64}
    )
).compute()

### Plot the spectrum

In [None]:
fig, ax = plt.subplots(nrows=2, gridspec_kw={"hspace": 0.10})
ax[0].loglog(
    64*24/np.arange(1, 65),
    spectrum_truth.sel(var_names="sit"),
    c="black", label="Truth", lw=1.,
)
ax[0].plot(
    64*24/np.arange(1, 65),
    spectrum_mean.sel(var_names="sit"),
    c="#E65007", label="GenSIM", lw=1.5
)
ax[0].plot(
    64*24/np.arange(1, 65),
    spectrum_members.sel(var_names="sit"),
    c="#E62A07", ls="--", label="GenSIM member", lw=1
)
ax[0].set_xscale("log")
ax[0].set_yscale("log")
ax[0].set_xticks([400, 200, 100, 50, 25])
ax[0].set_xticklabels([])
ax[0].set_xticks([], minor=True)
ax[0].set_xlim(500, 23)
ax[0].set_ylim(3E-3, 1E2)
ax[0].set_ylabel("Energy (m$^{2}$)")
ax[0].legend()

ax[1].loglog(
    64*24/np.arange(1, 65),
    spectrum_truth.sel(var_names="siu"),
    c="black", label="Truth", lw=1.,
)
ax[1].plot(
    64*24/np.arange(1, 65),
    spectrum_mean.sel(var_names="siu"),
    c="#E65007", label="GenSIM", lw=1.5
)
ax[1].plot(
    64*24/np.arange(1, 65),
    spectrum_members.sel(var_names="siu"),
    c="#E62A07", ls="--", label="GenSIM member", lw=1
)
ax[1].set_xscale("log")
ax[1].set_yscale("log")
ax[1].set_xticks([400, 200, 100, 50, 25])
ax[1].set_xticklabels([400, 200, 100, 50, 25])
ax[1].set_xticks([], minor=True)
ax[1].set_xlabel("Resolution (km)")
ax[1].set_xlim(500, 23)
ax[1].set_ylim(3E-6, 2E-1)
ax[1].set_ylabel("Energy (m$^{2}$ s$^{-2}$)")

fig.align_ylabels(ax)

## Extent accuracy

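We estimate how accurately the sea-ice edge is predicted. A cell counts as ice-covered if its concentration exceeds the threshold of 0.15. The accuracy is the fraction of the ocean area in which the forecast and the truth agree on whether a cell is ice-covered, with each cell weighted by its area.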

In [None]:
threshold = 0.15

In [None]:
extent_truth = ds_truth.sel(var_names="sic") > threshold
extent_persistence = persistence.sel(var_names="sic") > threshold
extent_members = predictions.sel(var_names="sic") > threshold
extent_mean = predictions.mean("ensemble").sel(var_names="sic") > threshold

In [None]:
# To estimate an area-weighted accuracy,
# the weights are the cell areas divided by the total ocean area
weights = ds_aux["cell_area"].where(ds_aux["mask"].astype(bool))
weights /= weights.sum(["x", "y"])

In [None]:
acc_persist = ((extent_persistence == extent_truth)*weights).sum(["x", "y"])
acc_members = ((extent_members == extent_truth)*weights).sum(["x", "y"]).mean("ensemble")
acc_mean = ((extent_mean == extent_truth)*weights).sum(["x", "y"])
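
The area-weighted accuracy reduces to a weighted mean of a boolean agreement field. In this toy NumPy example with invented values, land receives zero weight instead of NaN; this gives the same result as the `where`/`sum` combination above because xarray skips NaN-weighted cells when summing.

```python
import numpy as np

threshold = 0.15
sic_truth = np.array([[0.9, 0.1], [0.5, 0.0]])
sic_pred = np.array([[0.8, 0.2], [0.1, 0.0]])
ocean = np.array([[True, True], [True, False]])
cell_area = np.array([[1.0, 1.0], [2.0, 1.0]])   # arbitrary toy areas

# Area weights normalised over ocean cells only (land gets zero weight)
weights = np.where(ocean, cell_area, 0.0)
weights = weights / weights.sum()

# A cell counts as correct if both fields agree on ice / no ice
agree = (sic_truth > threshold) == (sic_pred > threshold)
accuracy = (agree * weights).sum()
```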

In [None]:
fig, ax = plt.subplots()
ax.plot(
    timedelta_axis, acc_persist,
    c="0.5", ls=":", label="Persistence", lw=1
)
ax.plot(
    timedelta_axis, acc_members,
    c="#E62A07", ls="--", label="GenSIM member", lw=1
)
ax.plot(
    timedelta_axis, acc_mean,
    c="#E65007", ls="-", label="GenSIM mean", lw=1.5
)
ax.set_xlim(0, 3.6)
ax.set_xticks([0, 1, 2, 3], [0, 1, 2, 3])
ax.set_xlabel("Lead time (days)")
ax.set_ylim(0.99, 1)
ax.set_ylabel("Accuracy")
ax.legend()

# Plot single forecasts

For the single forecasts, we will imitate what was shown in Figure 2 of the paper after 2.5 and 3 days of forecasting.

## Estimate deformation

In [None]:
def estimate_deform(siu, siv, x_coord, y_coord):
    area = 0.5 * (x_coord * y_coord.roll(abcd=-1) - x_coord.roll(abcd=-1) * y_coord).sum("abcd", skipna=False)
    du_x = (0.5 * (siu.roll(abcd=-1) + siu) * (y_coord.roll(abcd=-1) - y_coord)).sum("abcd", skipna=False) / area
    du_y = -(0.5 * (siu.roll(abcd=-1) + siu) * (x_coord.roll(abcd=-1) - x_coord)).sum("abcd", skipna=False) / area
    dv_x = (0.5 * (siv.roll(abcd=-1) + siv) * (y_coord.roll(abcd=-1) - y_coord)).sum("abcd", skipna=False) / area
    dv_y = -(0.5 * (siv.roll(abcd=-1) + siv) * (x_coord.roll(abcd=-1) - x_coord)).sum("abcd", skipna=False) / area

    div = du_x + dv_y
    shear = ((du_x - dv_y)**2 + (du_y + dv_x)**2)**0.5
    total = (div**2 + shear**2)**0.5
    return xr.Dataset({
        "area": area,
        "divergence": div,
        "shear": shear,
        "total": total
    })
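
`estimate_deform` applies Green's theorem on each quadrilateral formed by four neighbouring cell centres: the area comes from the shoelace formula and each velocity gradient from a line integral along the edges with trapezoidal averaging. The single-cell NumPy sketch below, with the hypothetical name `quad_gradients`, repeats those formulas without xarray and checks them on a linear velocity field, for which the discrete gradients are exact.

```python
import numpy as np

def quad_gradients(u, v, x, y):
    """Divergence and shear on one quadrilateral via Green's theorem.

    u, v, x, y hold the values at the four vertices, ordered counter-clockwise.
    """
    x_next, y_next = np.roll(x, -1), np.roll(y, -1)
    u_mid = 0.5 * (u + np.roll(u, -1))   # trapezoidal edge averages
    v_mid = 0.5 * (v + np.roll(v, -1))
    area = 0.5 * np.sum(x * y_next - x_next * y)   # shoelace formula
    du_dx = np.sum(u_mid * (y_next - y)) / area
    du_dy = -np.sum(u_mid * (x_next - x)) / area
    dv_dx = np.sum(v_mid * (y_next - y)) / area
    dv_dy = -np.sum(v_mid * (x_next - x)) / area
    div = du_dx + dv_dy
    shear = np.hypot(du_dx - dv_dy, du_dy + dv_dx)
    return div, shear

# Unit square with the linear field u = x + 2y, v = 3x + 4y
x = np.array([0.0, 1.0, 1.0, 0.0])
y = np.array([0.0, 0.0, 1.0, 1.0])
div, shear = quad_gradients(x + 2 * y, 3 * x + 4 * y, x, y)
```

For this field, the divergence is 1 + 4 = 5 and the shear is the square root of (1 - 4)² + (2 + 3)² = 34.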

In [None]:
idx_mesh = np.stack(np.meshgrid(np.arange(512), np.arange(512), indexing="xy"), axis=0)
idx_rect = np.stack((
    idx_mesh[:, :-1, :-1],
    idx_mesh[:, :-1, 1:],
    idx_mesh[:, 1:, 1:],
    idx_mesh[:, 1:, :-1],
), axis=-1).reshape(2, 511, 511, 4)
idx_rect = xr.Dataset({
    "x": (("yc", "xc", "abcd"), idx_rect[0]),
    "y": (("yc", "xc", "abcd"), idx_rect[1]),
})

In [None]:
deformation_truth = estimate_deform(
    ds_truth.sel(var_names="siu").isel(y=idx_rect["y"], x=idx_rect["x"]),
    ds_truth.sel(var_names="siv").isel(y=idx_rect["y"], x=idx_rect["x"]),
    ds_aux["x_coord"].isel(y=idx_rect["y"], x=idx_rect["x"]),
    ds_aux["y_coord"].isel(y=idx_rect["y"], x=idx_rect["x"]),
)

thickness_change_truth = ds_truth.sel(var_names="sit").diff("time", 1, label="lower")

In [None]:
deformation_prediction = estimate_deform(
    predictions.sel(var_names="siu").isel(y=idx_rect["y"], x=idx_rect["x"]),
    predictions.sel(var_names="siv").isel(y=idx_rect["y"], x=idx_rect["x"]),
    ds_aux["x_coord"].isel(y=idx_rect["y"], x=idx_rect["x"]),
    ds_aux["y_coord"].isel(y=idx_rect["y"], x=idx_rect["x"]),
)

thickness_change_prediction = predictions.sel(var_names="sit").diff("time", 1, label="lower")

## Plot comparison

In [None]:
# Set default projection
projection = ccrs.NorthPolarStereo()

In [None]:
# Select which ensemble member and time steps to plot
idx_mem = 0
step_1 = -3
step_2 = -2

In [None]:
fig = plt.figure(figsize=(10, 4), dpi=150)

# Add subplots
gs_truth = mpl_gs.GridSpec(nrows=2, ncols=3, left=0.04, right=0.48, wspace=0.05, hspace=0.1)
ax_truth = [fig.add_subplot(subgs, projection=projection) for subgs in gs_truth]
gs_prediction = mpl_gs.GridSpec(nrows=2, ncols=3, left=0.52, right=0.96, wspace=0.05, hspace=0.1)
ax_prediction = [fig.add_subplot(subgs, projection=projection) for subgs in gs_prediction]

for axi in ax_truth + ax_prediction:
    axi.xaxis.set_visible(False)
    axi.yaxis.set_visible(False)
    axi.set_yticks([])
    axi.spines.left.set_visible(False)
    axi.spines.right.set_visible(False)
    axi.spines.bottom.set_visible(False)
    axi.set_extent([-1_500_000, 1_500_000, -800_000, 2_200_000], ccrs.NorthPolarStereo())
    axi.add_feature(cartopy.feature.LAND, fc="xkcd:putty", zorder=99, rasterized=True)

# Sea-ice concentration
sic_transform = lambda x: np.power(10_000, x)
cf_sic = ax_truth[0].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    sic_transform(ds_truth.sel(var_names="sic")[step_1].values),
    cmap="cmo.dense_r", vmin=sic_transform(0.9), vmax=sic_transform(1.),
    rasterized=True
)
cf_sic = ax_truth[3].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    sic_transform(ds_truth.sel(var_names="sic")[step_2].values),
    cmap="cmo.dense_r", vmin=sic_transform(0.9), vmax=sic_transform(1.),
    rasterized=True
)

bbox = ax_truth[3].get_position()
cax = fig.add_axes([bbox.x0+0.02, bbox.y0-0.02, bbox.x1-bbox.x0-0.04, 0.01])
cbar = fig.colorbar(cf_sic, cax, orientation="horizontal", extend="min")
cbar.set_label(r"Concentration (1)", labelpad=8)
cbar.set_ticks(sic_transform([0.9, 0.95, 1]), labels=["0.9", "0.95", "1"])
for t in cbar.ax.get_xticklabels():
    t.set_horizontalalignment('center')
cax.xaxis.get_label().set_verticalalignment("baseline")

cf_sic = ax_prediction[0].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    sic_transform(predictions.sel(var_names="sic")[step_1, idx_mem].values),
    cmap="cmo.dense_r", vmin=sic_transform(0.9), vmax=sic_transform(1.),
    rasterized=True
)
cf_sic = ax_prediction[3].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    sic_transform(predictions.sel(var_names="sic")[step_2, idx_mem].values),
    cmap="cmo.dense_r", vmin=sic_transform(0.9), vmax=sic_transform(1.),
    rasterized=True
)

bbox = ax_prediction[3].get_position()
cax = fig.add_axes([bbox.x0+0.02, bbox.y0-0.02, bbox.x1-bbox.x0-0.04, 0.01])
cbar = fig.colorbar(cf_sic, cax, orientation="horizontal", extend="min")
cbar.set_label(r"Concentration (1)", labelpad=8)
cbar.set_ticks(sic_transform([0.9, 0.95, 1]), labels=["0.9", "0.95", "1"])
for t in cbar.ax.get_xticklabels():
    t.set_horizontalalignment('center')
cax.xaxis.get_label().set_verticalalignment("baseline")

# Divergence rate
cf_div = ax_truth[1].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    deformation_truth["divergence"][step_1].values * 86400,
    cmap="RdBu",
    norm=mpl_c.SymLogNorm(0.05, 1., vmin=-3, vmax=3),
    rasterized=True
)
cf_div = ax_truth[4].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    deformation_truth["divergence"][step_2].values * 86400,
    cmap="RdBu",
    norm=mpl_c.SymLogNorm(0.05, 1., vmin=-3, vmax=3),
    rasterized=True
)

bbox = ax_truth[4].get_position()
cax = fig.add_axes([bbox.x0+0.02, bbox.y0-0.02, bbox.x1-bbox.x0-0.04, 0.01])
cbar = fig.colorbar(cf_div, cax, orientation="horizontal", extend="both")
cbar.set_label(r"Rate (day$^{-1}$)", labelpad=8)
cbar.set_ticks([-1, 0, 1], labels=["-1", "0", "1"])
cbar.set_ticks([], minor=True)
for t in cbar.ax.get_xticklabels():
    t.set_horizontalalignment('center')
cax.xaxis.get_label().set_verticalalignment("baseline")

cf_div = ax_prediction[1].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    deformation_prediction["divergence"][step_1, idx_mem].values * 86400,
    cmap="RdBu",
    norm=mpl_c.SymLogNorm(0.05, 1., vmin=-3, vmax=3),
    rasterized=True
)
cf_div = ax_prediction[4].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    deformation_prediction["divergence"][step_2, idx_mem].values * 86400,
    cmap="RdBu",
    norm=mpl_c.SymLogNorm(0.05, 1., vmin=-3, vmax=3),
    rasterized=True
)

bbox = ax_prediction[4].get_position()
cax = fig.add_axes([bbox.x0+0.02, bbox.y0-0.02, bbox.x1-bbox.x0-0.04, 0.01])
cbar = fig.colorbar(cf_div, cax, orientation="horizontal", extend="both")
cbar.set_label(r"Rate (day$^{-1}$)", labelpad=8)
cbar.set_ticks([-1, 0, 1], labels=["-1", "0", "1"])
cbar.set_ticks([], minor=True)
for t in cbar.ax.get_xticklabels():
    t.set_horizontalalignment('center')
cax.xaxis.get_label().set_verticalalignment("baseline")

# Thickness change
cf_thick = ax_truth[2].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    thickness_change_truth[step_1].values,
    cmap="cmo.balance",
    vmin=-0.25, vmax=0.25,
    rasterized=True 
)
cf_thick = ax_truth[5].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    thickness_change_truth[step_2].values,
    cmap="cmo.balance",
    vmin=-0.25, vmax=0.25,
    rasterized=True 
)

bbox = ax_truth[5].get_position()
cax = fig.add_axes([bbox.x0+0.02, bbox.y0-0.02, bbox.x1-bbox.x0-0.04, 0.01])
cbar = fig.colorbar(cf_thick, cax, orientation="horizontal", extend="both")
cbar.set_label(r"Change (m/12 h)", labelpad=8)
cbar.set_ticks([], minor=True)
for t in cbar.ax.get_xticklabels():
    t.set_horizontalalignment('center')
cax.xaxis.get_label().set_verticalalignment("baseline")

cf_thick = ax_prediction[2].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    thickness_change_prediction[step_1, idx_mem].values,
    cmap="cmo.balance",
    vmin=-0.25, vmax=0.25,
    rasterized=True 
)
cf_thick = ax_prediction[5].pcolormesh(
    ds_aux["x_coord"], ds_aux["y_coord"],
    thickness_change_prediction[step_2, idx_mem].values,
    cmap="cmo.balance",
    vmin=-0.25, vmax=0.25,
    rasterized=True 
)

bbox = ax_prediction[5].get_position()
cax = fig.add_axes([bbox.x0+0.02, bbox.y0-0.02, bbox.x1-bbox.x0-0.04, 0.01])
cbar = fig.colorbar(cf_thick, cax, orientation="horizontal", extend="both")
cbar.set_label(r"Change (m/12 h)", labelpad=8)
cbar.set_ticks([], minor=True)
for t in cbar.ax.get_xticklabels():
    t.set_horizontalalignment('center')
cax.xaxis.get_label().set_verticalalignment("baseline")


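# Use text instead of set_ylabel, since the y-axis (and with it its label) is hidden above
ax_truth[0].text(-0.05, 0.5, "{0:%Y-%m-%d %H:%M}".format(ds_demo.indexes["time"][step_1]), rotation=90, ha="right", va="center", transform=ax_truth[0].transAxes, fontsize=10)
ax_truth[3].text(-0.05, 0.5, "{0:%Y-%m-%d %H:%M}".format(ds_demo.indexes["time"][step_2]), rotation=90, ha="right", va="center", transform=ax_truth[3].transAxes, fontsize=10)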

ax_truth[0].set_title("Concentration", fontsize=10)
ax_truth[1].set_title("Divergence rate", fontsize=10)
ax_truth[2].set_title(r"$\Delta$ Thickness", fontsize=10)
ax_truth[1].text(0.5, 1.15, s="neXtSIM-OPA", ha="center", va="bottom", transform=ax_truth[1].transAxes, fontsize=14, fontweight="bold")

ax_prediction[0].set_title("Concentration", fontsize=10)
ax_prediction[1].set_title("Divergence rate", fontsize=10)
ax_prediction[2].set_title(r"$\Delta$ Thickness", fontsize=10)
ax_prediction[1].text(0.5, 1.15, s="GenSIM", ha="center", va="bottom", transform=ax_prediction[1].transAxes, fontsize=14, fontweight="bold")

# Congratulations

You have reached the end. You now know how to use GenSIM and how to reproduce some of the results in the manuscript.

Feel free to play around with the settings or to use the model in other experiments.

For further questions, please contact Tobias Finn (tobias.finn@enpc.fr).