---
title: Gradient Considerations
date: 2023-04-01
authors:
  - name: J. Emmanuel Johnson
    affiliations:
      - MEOM Lab
    roles:
      - Primary Programmer
    email: jemanjohnson34@gmail.com
license: CC-BY-4.0
keywords: NeRFs, Images
---

In [None]:
import sys, os

# Location of the oceanbench checkout on the cluster (machine-specific path).
oceanbench_root = "/gpfswork/rech/cli/uvo53rl/projects/oceanbench"

# Make ``oceanbench`` importable. Appending puts it last on the search path,
# so packages found earlier on sys.path still take precedence.
sys.path.extend([os.fspath(oceanbench_root)])

In [None]:
import autoroot
import typing as tp
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jrandom
import numpy as np
import numba as nb
import pandas as pd
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from omegaconf import OmegaConf
import hydra
import metpy
from sklearn.pipeline import Pipeline
from jejeqx._src.transforms.dataframe.spatial import Spherical2Cartesian
from jejeqx._src.transforms.dataframe.temporal import TimeDelta
from jejeqx._src.transforms.dataframe.scaling import MinMaxDF


# Plot styling: start from seaborn defaults, then use the large "poster"
# context with all font sizes scaled by 0.7.
sns.reset_defaults()
sns.set_context(context="poster", font_scale=0.7)
# Keep JAX in 32-bit precision (64-bit arrays explicitly disabled).
jax.config.update("jax_enable_x64", False)

# IPython: draw figures inline and auto-reload edited modules before each cell.
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Processing Chain

**Part I**:

* Open Dataset
* Validate Coordinates + Variables
* Decode Time
* Select Region
* Sort by Time

**Part II**: Regrid

**Part III**:

* Interpolate Nans
* Add Units
* Spatial Rescale
* Time Rescale

**Part IV**: Metrics

## Data

In [None]:
# !wget wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc

In [None]:
!ls /gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/results/DUACS

In [None]:
# !cat configs/postprocess.yaml

In [None]:
# # load config
# config_dm = OmegaConf.load('./configs/postprocess.yaml')

# # instantiate
# ds = hydra.utils.instantiate(config_dm.NATL60_GF_1Y1D)
# ds

## Reference Dataset

For the reference dataset, we use the NATL60 NEMO simulation over the Gulf Stream region.

In [None]:
%%time

# load config
config_dm = OmegaConf.load("./configs/postprocess.yaml")

# instantiate
ds_natl60 = hydra.utils.instantiate(config_dm.NATL60_GF_FULL).compute()
ds_natl60

In [None]:
# One config file holds both along-track observation datasets.
nadir4_config = OmegaConf.load("./configs/natl60_obs.yaml")

ds_nadir4 = hydra.utils.instantiate(nadir4_config["ALONGTRACK_NADIR4"].data).compute()
ds_swot1nadir5 = hydra.utils.instantiate(
    nadir4_config["ALONGTRACK_SWOT1NADIR5"].data
).compute()
ds_swot1nadir5

## Regridding: Along-Track -> Uniform Grid

In [None]:
from oceanbench._src.geoprocessing.gridding import (
    grid_to_regular_grid,
    coord_based_to_grid,
)

In [None]:
%%time

ds_nadir4 = coord_based_to_grid(
    coord_based_ds=ds_nadir4,
    target_grid_ds=ds_natl60.pint.dequantify(),
)
ds_swot1nadir5 = coord_based_to_grid(
    coord_based_ds=ds_swot1nadir5,
    target_grid_ds=ds_natl60.pint.dequantify(),
)

#### AlongTrack -> Uniform Grid

In [None]:
# Metrics config; its ``fill_nans`` entry instantiates to a callable applied
# to a dataset (presumably a NaN filler, per its name).
psd_config = OmegaConf.load("./configs/metrics.yaml")

# Strip units from the reference field, then fill its missing values.
ds_natl60 = hydra.utils.instantiate(psd_config["fill_nans"])(
    ds_natl60.pint.dequantify()
)

In [None]:
def correct_labels(ds):
    """Attach plotting labels to the ``lon``, ``lat`` and ``ssh`` variables.

    Sets ``units``, ``standard_name`` and ``long_name`` on each of the three
    variables. The variables' ``attrs`` are updated in place and the same
    ``ds`` object is returned, so calls can be chained.
    """
    label_table = {
        "lon": {
            "units": "degrees",
            "standard_name": "longitude",
            "long_name": "Longitude",
        },
        "lat": {
            "units": "degrees",
            "standard_name": "latitude",
            "long_name": "Latitude",
        },
        "ssh": {
            "units": "m",
            "standard_name": "sea_surface_height",
            "long_name": "Sea Surface Height",
        },
    }
    for var_name, new_attrs in label_table.items():
        ds[var_name].attrs.update(new_attrs)

    return ds

In [None]:
def plot_obs(ds, variable: str = "ssh", **kwargs):
    """Scatter-plot gridded observations of ``variable`` for one time step.

    Each grid cell is drawn as a small square marker. NaN cells are masked
    out, so only cells holding a value are shown.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose ``lon``/``lat`` coordinates and ``variable`` all carry
        ``long_name`` and ``units`` attributes. Presumably already reduced to
        a single time step, since the title is built from ``ds.time``.
    variable : str
        Name of the variable used for the marker colours.
    **kwargs
        Optional ``vmin``/``vmax`` (colour limits), ``xlim``/``ylim`` (axis
        limits) and ``cmap`` (default ``"viridis"``). Other keys are ignored.

    Returns
    -------
    fig, ax
        The created matplotlib figure and axes.
    """
    vmin = kwargs.pop("vmin", None)
    vmax = kwargs.pop("vmax", None)
    xlim = kwargs.pop("xlim", None)
    ylim = kwargs.pop("ylim", None)
    cmap = kwargs.pop("cmap", "viridis")

    fig, ax = plt.subplots(figsize=(7, 5.5))

    # "ij" indexing makes X/Y shaped (n_lon, n_lat); the field is transposed to
    # match, which assumes it is stored as (lat, lon) -- TODO confirm.
    X, Y = np.meshgrid(ds[variable].lon, ds[variable].lat, indexing="ij")
    values = np.ma.masked_invalid(ds[variable]).T

    pts = ax.scatter(
        X,
        Y,
        c=values,
        marker="s",
        s=0.25,
        vmin=vmin,
        vmax=vmax,
        # The colormap must be set on the mappable itself: a colorbar built
        # from an existing mappable ignores any ``cmap`` passed to it.
        cmap=cmap,
    )
    ax.set(
        xlim=xlim,
        ylim=ylim,
        xlabel=f"{ds.lon.attrs['long_name']} [{ds.lon.attrs['units']}]",
        ylabel=f"{ds.lat.attrs['long_name']} [{ds.lat.attrs['units']}]",
    )
    var_attrs = ds[variable].attrs
    fig.colorbar(pts, ax=ax, label=f"{var_attrs['long_name']} [{var_attrs['units']}]")

    ax.set_title(pd.to_datetime(ds.time.values).strftime("%Y-%m-%d"))
    fig.tight_layout()

    return fig, ax

In [None]:
# Shared colour scale: min/max of the NATL60 reference SSH over the whole record.
ssh_ref = correct_labels(ds_natl60).ssh
vmin, vmax = ssh_ref.min().pint.dequantify(), ssh_ref.max().pint.dequantify()
xlim = [ds_natl60.lon.min().values, ds_natl60.lon.max().values]
ylim = [ds_natl60.lat.min().values, ds_natl60.lat.max().values]
itime = "2012-10-27"
variable = "ssh"

# One map per observing system; the tag goes into the output file name.
for name, obs in (("swot1nadir5", ds_swot1nadir5), ("nadir4", ds_nadir4)):
    fig, ax = plot_obs(
        correct_labels(obs).sel(time=itime).pint.dequantify(),
        variable,
        vmin=vmin,
        vmax=vmax,
        xlim=xlim,
        ylim=ylim,
        cmap="viridis",
    )
    fig.savefig(f"./figures/dc20a/maps/dc20a_ssh_{name}_{itime}.png")
    plt.close(fig)

## Coarsened Versions

In [None]:
# Average non-overlapping 3x3 blocks of grid cells to coarsen the reference.
ds_natl60 = ds_natl60.coarsen(lon=3, lat=3).mean()
ds_natl60

### Prediction Datasets 

In [None]:
%%time

# load config

experiment = "nadir"  # "swot" #
if experiment == "nadir":
    # load config
    results_config = OmegaConf.load(f"./configs/results_dc20a_nadir.yaml")

    # instantiate
    ds_duacs = hydra.utils.instantiate(results_config.DUACS_NADIR.data).compute()
    ds_miost = hydra.utils.instantiate(results_config.MIOST_NADIR.data).compute()
    ds_nerf_siren = hydra.utils.instantiate(
        results_config.NERF_SIREN_NADIR.data
    ).compute()
    ds_nerf_ffn = hydra.utils.instantiate(results_config.NERF_FFN_NADIR.data).compute()
    ds_nerf_mlp = hydra.utils.instantiate(results_config.NERF_MLP_NADIR.data).compute()
elif experiment == "swot":
    # load config
    results_config = OmegaConf.load(f"./configs/results_dc20a_swot.yaml")

    # instantiate
    ds_duacs = hydra.utils.instantiate(results_config.DUACS_SWOT.data).compute()
    ds_miost = hydra.utils.instantiate(results_config.MIOST_SWOT.data).compute()
    ds_nerf_siren = hydra.utils.instantiate(
        results_config.NERF_SIREN_SWOT.data
    ).compute()
    ds_nerf_ffn = hydra.utils.instantiate(results_config.NERF_FFN_SWOT.data).compute()
    ds_nerf_mlp = hydra.utils.instantiate(results_config.NERF_MLP_SWOT.data).compute()

In [None]:
!ls /gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/ml_ready/

## Regridding

#### Uniform Grid --> Uniform Grid

In [None]:
%%time

ds_duacs = grid_to_regular_grid(
    src_grid_ds=ds_duacs.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_miost = grid_to_regular_grid(
    src_grid_ds=ds_miost.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_siren = grid_to_regular_grid(
    src_grid_ds=ds_nerf_siren.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_ffn = grid_to_regular_grid(
    src_grid_ds=ds_nerf_ffn.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)
ds_nerf_mlp = grid_to_regular_grid(
    src_grid_ds=ds_nerf_mlp.pint.dequantify(),
    tgt_grid_ds=ds_natl60.pint.dequantify(),
    keep_attrs=False,
)

### Preprocess Chain

In [None]:
%%time

# load config
psd_config = OmegaConf.load("./configs/metrics.yaml")

ds_duacs = hydra.utils.instantiate(psd_config.fill_nans)(ds_duacs.pint.dequantify())
ds_miost = hydra.utils.instantiate(psd_config.fill_nans)(ds_miost.pint.dequantify())
ds_nerf_siren = hydra.utils.instantiate(psd_config.fill_nans)(
    ds_nerf_siren.pint.dequantify()
)
ds_nerf_ffn = hydra.utils.instantiate(psd_config.fill_nans)(
    ds_nerf_ffn.pint.dequantify()
)
ds_nerf_mlp = hydra.utils.instantiate(psd_config.fill_nans)(
    ds_nerf_mlp.pint.dequantify()
)

## Sea Surface Height

In [None]:
def plot_map(ds, variable: str = "ssh", **kwargs):
    """Plot a gridded field as a colour mesh with a faint contour overlay.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding ``variable``. Presumably already reduced to a single
        time step, since the title is built from ``ds.time``.
    variable : str
        Name of the variable to plot.
    **kwargs
        ``vmin``/``vmax`` and ``cmap`` (default ``"viridis"``) set the colour
        scale of the mesh. Every remaining keyword (e.g. ``xlim``, ``ylim``,
        ``robust``) is forwarded to both the pcolormesh and contour calls.

    Returns
    -------
    fig, ax
        The created matplotlib figure and axes.
    """
    vmin = kwargs.pop("vmin", None)
    vmax = kwargs.pop("vmax", None)
    cmap = kwargs.pop("cmap", "viridis")

    fig, ax = plt.subplots(figsize=(7, 5.5))
    field = ds[variable]

    field.plot.pcolormesh(
        ax=ax,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        **kwargs,
    )
    # Overlay 5 translucent black iso-lines. "black" is a colour, not a
    # colormap name, so it goes through xarray's ``colors`` argument.
    field.plot.contour(
        ax=ax,
        levels=5,
        alpha=0.25,
        linewidths=1,
        colors="black",
        vmin=vmin,
        vmax=vmax,
        **kwargs,
    )
    ax.set_title(pd.to_datetime(ds.time.values).strftime("%Y-%m-%d"))
    fig.tight_layout()

    return fig, ax

In [None]:
# Shared colour scale: min/max of the NATL60 reference SSH over the whole record.
ssh_ref = correct_labels(ds_natl60).ssh
vmin, vmax = ssh_ref.min().pint.dequantify(), ssh_ref.max().pint.dequantify()
xlim = [ds_natl60.lon.min().values, ds_natl60.lon.max().values]
ylim = [ds_natl60.lat.min().values, ds_natl60.lat.max().values]
itime = "2012-10-27"
variable = "ssh"

# Reference first, then each reconstruction; the tag goes into the file name.
for name, dataset in (
    ("natl60", ds_natl60),
    ("duacs", ds_duacs),
    ("miost", ds_miost),
    ("nerf_mlp", ds_nerf_mlp),
    ("nerf_ffn", ds_nerf_ffn),
    ("nerf_siren", ds_nerf_siren),
):
    fig, ax = plot_map(
        correct_labels(dataset).sel(time=itime),
        variable,
        vmin=vmin,
        vmax=vmax,
        xlim=xlim,
        ylim=ylim,
        cmap="viridis",
    )
    fig.savefig(f"./figures/dc20a/maps/dc20a_ssh_{name}_{experiment}_{itime}.png")
    plt.close(fig)

## Kinetic Energy

In [None]:
from oceanbench._src.geoprocessing import geostrophic as geocalc
from metpy.units import units

In [None]:
def calculate_physical_quantities(da):
    """Derive SSH-based geostrophic diagnostics and add them to the dataset.

    Starting from ``ssh`` (tagged as metres), this calls the ``geocalc``
    helpers in order:
    streamfunction ``psi``, geostrophic velocities ``u``/``v``, kinetic
    energy, divergence ``div``, relative vorticity ``vort_r`` and strain
    magnitude ``strain``. ``div``, ``vort_r`` and ``strain`` are each also
    passed through ``geocalc.coriolis_normalized`` (presumably scaling by the
    Coriolis parameter -- see that helper).

    Parameters
    ----------
    da : xarray.Dataset
        Must contain an ``ssh`` variable. It is assumed to be in metres and
        unit-less (e.g. after ``.pint.dequantify()``).

    Returns
    -------
    xarray.Dataset
        The dataset returned by the last ``geocalc`` step. The caller's
        ``ssh`` variable is not overwritten.
    """
    # ``assign`` returns a new Dataset; item assignment would have replaced
    # the caller's ``ssh`` with the unit-tagged copy.
    da = da.assign(ssh=da.ssh * units.meters)
    da = geocalc.streamfunction(da, "ssh")
    da = geocalc.geostrophic_velocities(da, variable="psi")
    da = geocalc.kinetic_energy(da, variables=["u", "v"])
    da = geocalc.divergence(da, variables=["u", "v"])
    da = geocalc.coriolis_normalized(da, "div")
    da = geocalc.relative_vorticity(da, variables=["u", "v"])
    da = geocalc.coriolis_normalized(da, "vort_r")
    da = geocalc.strain_magnitude(da, variables=["u", "v"])
    da = geocalc.coriolis_normalized(da, variable="strain")
    return da

In [None]:
# Add the SSH-derived diagnostics to the (unit-stripped) NATL60 reference.
ds_natl60 = calculate_physical_quantities(da=ds_natl60.pint.dequantify())
ds_natl60

In [None]:
%%time

ds_natl60 = calculate_physical_quantities(ds_natl60.pint.dequantify())
ds_duacs = calculate_physical_quantities(ds_duacs.pint.dequantify())
ds_miost = calculate_physical_quantities(ds_miost.pint.dequantify())
ds_nerf_siren = calculate_physical_quantities(ds_nerf_siren.pint.dequantify())
ds_nerf_ffn = calculate_physical_quantities(ds_nerf_ffn.pint.dequantify())
ds_nerf_mlp = calculate_physical_quantities(ds_nerf_mlp.pint.dequantify())

In [None]:
variable = "ke"
itime = "2012-10-27"
cmap = "YlGnBu_r"
robust = True

# No explicit colour limits: with ``robust=True`` xarray derives them from the
# 2nd/98th percentiles of each field.
vmin = vmax = None
xlim = [ds_natl60.lon.min().values, ds_natl60.lon.max().values]
ylim = [ds_natl60.lat.min().values, ds_natl60.lat.max().values]

# Every panel (reference included) gets the same label fix-up and unit
# stripping so the figures are directly comparable.
for name, dataset in (
    ("natl60", ds_natl60),
    ("duacs", ds_duacs),
    ("miost", ds_miost),
    ("nerf_mlp", ds_nerf_mlp),
    ("nerf_ffn", ds_nerf_ffn),
    ("nerf_siren", ds_nerf_siren),
):
    fig, ax = plot_map(
        correct_labels(dataset).sel(time=itime).pint.dequantify(),
        variable,
        vmin=vmin,
        vmax=vmax,
        xlim=xlim,
        ylim=ylim,
        cmap=cmap,
        robust=robust,
    )
    fig.savefig(
        f"./figures/dc20a/maps/dc20a_{variable}_{name}_{experiment}_{itime}.png"
    )
    plt.close(fig)

## Relative Vorticity

In [None]:
variable = "vort_r"
itime = "2012-10-27"
cmap = "RdBu_r"
robust = True

# No explicit colour limits: with ``robust=True`` xarray derives them from the
# 2nd/98th percentiles of each field.
vmin = vmax = None
xlim = [ds_natl60.lon.min().values, ds_natl60.lon.max().values]
ylim = [ds_natl60.lat.min().values, ds_natl60.lat.max().values]

# Every panel (reference included) gets the same label fix-up and unit
# stripping so the figures are directly comparable.
for name, dataset in (
    ("natl60", ds_natl60),
    ("duacs", ds_duacs),
    ("miost", ds_miost),
    ("nerf_mlp", ds_nerf_mlp),
    ("nerf_ffn", ds_nerf_ffn),
    ("nerf_siren", ds_nerf_siren),
):
    fig, ax = plot_map(
        correct_labels(dataset).sel(time=itime).pint.dequantify(),
        variable,
        vmin=vmin,
        vmax=vmax,
        xlim=xlim,
        ylim=ylim,
        cmap=cmap,
        robust=robust,
    )
    fig.savefig(
        f"./figures/dc20a/maps/dc20a_{variable}_{name}_{experiment}_{itime}.png"
    )
    plt.close(fig)

## Divergence

In [None]:
variable = "div"
itime = "2012-10-27"
cmap = "RdBu_r"
robust = True

# No explicit colour limits: with ``robust=True`` xarray derives them from the
# 2nd/98th percentiles of each field.
vmin = vmax = None
xlim = [ds_natl60.lon.min().values, ds_natl60.lon.max().values]
ylim = [ds_natl60.lat.min().values, ds_natl60.lat.max().values]

# Every panel (reference included) gets the same label fix-up and unit
# stripping so the figures are directly comparable.
for name, dataset in (
    ("natl60", ds_natl60),
    ("duacs", ds_duacs),
    ("miost", ds_miost),
    ("nerf_mlp", ds_nerf_mlp),
    ("nerf_ffn", ds_nerf_ffn),
    ("nerf_siren", ds_nerf_siren),
):
    fig, ax = plot_map(
        correct_labels(dataset).sel(time=itime).pint.dequantify(),
        variable,
        vmin=vmin,
        vmax=vmax,
        xlim=xlim,
        ylim=ylim,
        cmap=cmap,
        robust=robust,
    )
    fig.savefig(
        f"./figures/dc20a/maps/dc20a_{variable}_{name}_{experiment}_{itime}.png"
    )
    plt.close(fig)

## Strain

In [None]:
import cmocean as cmo

variable = "strain"
itime = "2012-10-27"
cmap = cmo.cm.speed
robust = True

# No explicit colour limits: with ``robust=True`` xarray derives them from the
# 2nd/98th percentiles of each field.
vmin = vmax = None
xlim = [ds_natl60.lon.min().values, ds_natl60.lon.max().values]
ylim = [ds_natl60.lat.min().values, ds_natl60.lat.max().values]

# Every panel (reference included) gets the same label fix-up and unit
# stripping so the figures are directly comparable.
for name, dataset in (
    ("natl60", ds_natl60),
    ("duacs", ds_duacs),
    ("miost", ds_miost),
    ("nerf_mlp", ds_nerf_mlp),
    ("nerf_ffn", ds_nerf_ffn),
    ("nerf_siren", ds_nerf_siren),
):
    fig, ax = plot_map(
        correct_labels(dataset).sel(time=itime).pint.dequantify(),
        variable,
        vmin=vmin,
        vmax=vmax,
        xlim=xlim,
        ylim=ylim,
        cmap=cmap,
        robust=robust,
    )
    fig.savefig(
        f"./figures/dc20a/maps/dc20a_{variable}_{name}_{experiment}_{itime}.png"
    )
    plt.close(fig)