In [None]:
import xarray as xr
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.colors as mpl_c
import matplotlib.gridspec as mpl_gs
import matplotlib.dates as mpl_dates
import cartopy.crs as ccrs
import cartopy
import cmocean

In [None]:
# Apply the project matplotlib styles in order; settings from the later
# style ("wiley") take precedence over the earlier one ("paper").
plt.style.use(["paper", "wiley"])

# Load the data

In [None]:
# Open the three consistency runs, in the order nextsim, deterministic, diffusion.
# The diffusion dataset carries an extra `ens` dimension (selected via
# `idx_ens` when plotting).
ds_nextsim, ds_deterministic, ds_diffusion = map(
    xr.open_dataset,
    (
        "data/consistency_nextsim.nc",
        "data/consistency_deterministic.nc",
        "data/consistency_diffusion.nc",
    ),
)

In [None]:
# Last time index shown in the figure (titled "+ last_t/2 days" below).
last_t = 100
# Which ensemble member of the diffusion run to plot.
idx_ens = 0

In [None]:
# Sea-ice thickness snapshots: one row per model, one column per lead time.
# Panel titles assume 12 hours per time index, consistent with the original
# hand-written titles ("+ 12 hours" at t=1, "+ 10 days" at t=20, last_t/2 days).


def lead_time_label(t, hours_per_step=12):
    """Return a panel title such as "+ 12 hours" or "+ 10 days" for time index ``t``."""
    hours = t * hours_per_step
    if hours < 24:
        return f"+ {hours:.0f} hours"
    return f"+ {hours / 24:.0f} days"


time_indices = [1, 20, last_t]
# (row label, dataset, extra isel selectors) — only the diffusion run has an ensemble dim.
rows = [
    ("neXtSIM", ds_nextsim, {}),
    ("Deterministic", ds_deterministic, {}),
    ("ResDiffusion", ds_diffusion, {"ens": idx_ens}),
]
# Shared colour scale for every panel, so one colorbar describes them all.
pcolor_kw = dict(cmap="cmo.ice", vmin=0, vmax=2.5, shading="nearest")

fig, ax = plt.subplots(
    nrows=len(rows), ncols=len(time_indices), figsize=(4, 4), dpi=120
)
for axij in ax.flat:
    axij.set_facecolor(cmocean.cm.ice(0.))
    axij.xaxis.set_visible(False)
    for side in ("left", "right", "bottom"):
        axij.spines[side].set_visible(False)

for col, t in enumerate(time_indices):
    ax[0, col].set_title(lead_time_label(t))

for row, (label, ds, extra_sel) in enumerate(rows):
    ax[row, 0].set_ylabel(label)
    for col, t in enumerate(time_indices):
        panel = ax[row, col]
        field = ds["sit"].isel(time=t, **extra_sel)
        # Cell-centre coordinates sized from the field itself (columns -> x, rows -> y).
        cf = panel.pcolormesh(
            np.arange(field.shape[-1]), np.arange(field.shape[-2]),
            field,
            **pcolor_kw,
        )
        # Keep the y-axis (it carries the row label) but drop its ticks.
        panel.set_yticks([])
        # Panel letters (a)-(i), row-major, styled identically in every row.
        letter = chr(97 + row * len(time_indices) + col)
        panel.text(
            0.02, 0.98, f"({letter:s})", transform=panel.transAxes,
            fontweight="heavy", ha="left", va="top", c="white",
        )

cax = fig.add_axes([0.92, 0.2, 0.02, 0.6])
cbar = fig.colorbar(cf, cax=cax)
cbar.set_label("Sea-ice thickness (m)")
fig.savefig("figures/fig_05_spatial_sit.png", dpi=300)