# Figure - Pixel Densities

In [None]:
import sys, os
from pyprojroot import here

# search upwards for the project root, identified by a ".root" marker file
root = here(project_files=[".root"])


# append the project root to the import path
# (presumably so the local `inr4ssh` package imported in the next cell resolves - TODO confirm)
sys.path.append(str(root))

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# restore the default rc parameters before applying the plotting context
sns.reset_defaults()
# "talk" context with the font elements scaled by 0.7
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels

from loguru import logger

# render matplotlib figures inside the notebook
%matplotlib inline
# automatically reload imported modules before each cell runs (mode 2 = all modules)
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
def remove_nans(x):
    """Return the entries of ``x`` that are not NaN.

    Args:
        x: numpy array of floating point values.

    Returns:
        1D numpy array holding the non-NaN values of ``x`` in their original
        order (boolean-mask indexing flattens multi-dimensional input).
    """
    # np.isnan builds the element-wise mask; np.nan is a float constant and
    # cannot be called.
    return x[~np.isnan(x)]

In [None]:
def post_process(ds, variable):
    """Bring a mapping dataset into the common layout used by this notebook.

    The steps are applied in this order:
      1. coordinate labels are passed through ``correct_coordinate_labels``;
      2. the data variable ``variable`` is renamed to ``"ssh"``;
      3. the longitude coordinate is passed through ``convert_lon_360_180``;
      4. the time axis is restricted to 2017-02-01 .. 2017-03-01;
      5. dimensions are ordered as (time, latitude, longitude).

    No spatial subsetting is applied (earlier variants of this function
    restricted the domain to lon [-75, -45] / lat [33, 53]).

    Args:
        ds: dataset to clean (an xarray object opened from a netCDF file).
        variable: name of the SSH variable inside ``ds``.

    Returns:
        The cleaned dataset.
    """
    logger.info("Fixing coordinate labels...")
    ds = correct_coordinate_labels(ds)

    logger.info("Fixing labels")
    ds = ds.rename({variable: "ssh"})

    logger.info("Fixing longitude domain")
    from inr4ssh._src.preprocess.spatial import convert_lon_360_180

    ds["longitude"] = convert_lon_360_180(ds.longitude)

    # evaluation period
    period = slice(np.datetime64("2017-02-01"), np.datetime64("2017-03-01"))
    ds = ds.sel(time=period)

    return ds.transpose("time", "latitude", "longitude")

In [None]:
# Reference field: the DUACS mapping, cleaned with post_process.
logger.info("Dataset I - DUACS")
url = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_DUACS.nc"
ds_field = post_process(xr.open_dataset(url), "ssh")

```bash
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
# List the result files available on the external drive.
# NOTE(review): this is the dc20a_osse directory, while the datasets loaded in
# this notebook come from dc21b_ose - confirm this cell is still needed.
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/

## PostProcessing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid

# One entry per mapping product:
#   (log message, netCDF file to load, name of the variable added to ds_field)
# Alternative files that have been used for the last entry:
#   /Volumes/EMANS_HDD/data/dc21b/results/siren_136.nc
#   /Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_4dvarNet_2022.nc
prediction_sources = [
    (
        "Dataset II - BASELINE",
        "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_BASELINE.nc",
        "ssh_oi",
    ),
    (
        "Dataset III - MIOST",
        "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_MIOST.nc",
        "ssh_miost",
    ),
    (
        "Dataset IV - NerF",
        "/Users/eman/code_projects/logs/saved_data/test_dc21b_feb_pretrain.nc",
        "ssh_nerf",
    ),
]

for message, url, target_name in prediction_sources:
    logger.info(message)
    ds_predict = post_process(xr.open_dataset(url), "ssh")
    # presumably maps the prediction onto the grid of the reference field
    # (see oi_regrid) so it can be stored next to ds_field["ssh"]
    ds_field[target_name] = oi_regrid(ds_predict["ssh"], ds_field["ssh"])

In [None]:
# display the dataset, which now holds the reference field plus ssh_oi, ssh_miost and ssh_nerf
ds_field

In [None]:
from tqdm.notebook import tqdm

# SSH variables in ds_field for which the derived quantities are computed
variables = ["ssh", "ssh_oi", "ssh_miost", "ssh_nerf"]

# Store calculate_gradient(...) of every field as "<name>_ke".
# NOTE(review): the result is plotted as kinetic energy further down; confirm
# that calculate_gradient returns that quantity and not a plain gradient.
logger.info("Calculating Kinetic Energy...")
for ivar in tqdm(variables):
    ds_field[f"{ivar}_ke"] = calculate_gradient(ds_field[ivar], "longitude", "latitude")

# Store 0.5 * (calculate_laplacian(...))**2 of every field as "<name>_ens"
# (plotted as enstrophy further down).
logger.info("Calculating Enstropy...")
for ivar in tqdm(variables):
    ds_field[f"{ivar}_ens"] = (
        0.5 * calculate_laplacian(ds_field[ivar], "longitude", "latitude") ** 2
    )

## Metrics II - Pixel-Wise Dist

In [None]:
# Directory, relative to the project root, where the figures of this notebook are saved.
fig_path = Path(root).joinpath("figures/dc21a")
# fig.savefig does not create missing folders, so make sure the target exists.
fig_path.mkdir(parents=True, exist_ok=True)

### Density (Sea Surface Height)

In [None]:
# Density of the SSH values of every mapping product, overlaid on one axis.
# One entry per product: (variable in ds_field, legend label, line colour).
products = [
    ("ssh_oi", "OI", "black"),
    ("ssh", "DUACS", "green"),
    ("ssh_miost", "MIOST", "blue"),
    ("ssh_nerf", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=False,
        common_norm=True,
        stat="density",
        ax=ax,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel("SSH [m]")
# this histogram is not cumulative, so the axis shows a plain density
ax.set_ylabel("Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_density_ssh"))
plt.show()

### Cumulative Density (Sea Surface Height)

In [None]:
# Cumulative density of the SSH values of every mapping product.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi", "OI", "black"),
    ("ssh", "DUACS", "green"),
    ("ssh_miost", "MIOST", "blue"),
    ("ssh_nerf", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=True,
        common_norm=True,
        stat="density",
        ax=ax,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel("SSH [m]")
ax.set_ylabel("Cumulative Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_cdensity_ssh"))
plt.show()

### Density (Kinetic Energy)

In [None]:
# Density of the "*_ke" variables of every mapping product.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi_ke", "OI", "black"),
    ("ssh_ke", "DUACS", "green"),
    ("ssh_miost_ke", "MIOST", "blue"),
    ("ssh_nerf_ke", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=False,
        common_norm=True,
        stat="density",
        ax=ax,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Kinetic Energy [m$^2$s$^{-2}$]")
ax.set_ylabel("Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_density_ke"))
plt.show()

### Log Density (Kinetic Energy)

In [None]:
# Density of the "*_ke" variables of every mapping product on a log-scaled axis.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi_ke", "OI", "black"),
    ("ssh_ke", "DUACS", "green"),
    ("ssh_miost_ke", "MIOST", "blue"),
    ("ssh_nerf_ke", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=False,
        common_norm=True,
        stat="density",
        ax=ax,
        log_scale=True,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
ax.set_ylabel("Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_density_ke_log"))
plt.show()

### Cumulative Density (Kinetic Energy)

In [None]:
# Cumulative density of the "*_ke" variables of every mapping product.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi_ke", "OI", "black"),
    ("ssh_ke", "DUACS", "green"),
    ("ssh_miost_ke", "MIOST", "blue"),
    ("ssh_nerf_ke", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=True,
        common_norm=True,
        stat="density",
        ax=ax,
        log_scale=False,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Kinetic Energy [m$^2$s$^{-2}$]")
ax.set_ylabel("Cumulative Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_cdensity_ke"))
plt.show()

### Log Cumulative Density (Kinetic Energy)

In [None]:
# Cumulative density of the "*_ke" variables of every mapping product on a
# log-scaled axis.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi_ke", "OI", "black"),
    ("ssh_ke", "DUACS", "green"),
    ("ssh_miost_ke", "MIOST", "blue"),
    ("ssh_nerf_ke", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=True,
        common_norm=True,
        stat="density",
        ax=ax,
        log_scale=True,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
ax.set_ylabel("Cumulative Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_cdensity_ke_log"))
plt.show()

## Enstrophy

### Density (Enstrophy)

In [None]:
def remove_nans(x):
    """Drop the NaN entries of a numpy array.

    Args:
        x: numpy array of floating point values.

    Returns:
        1D numpy array with the non-NaN values of ``x`` in their original
        order (boolean-mask indexing flattens multi-dimensional input).
    """
    keep = np.logical_not(np.isnan(x))
    return x[keep]

In [None]:
# Density of the "*_ens" variables of every mapping product.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi_ens", "OI", "black"),
    ("ssh_ens", "DUACS", "green"),
    ("ssh_miost_ens", "MIOST", "blue"),
    ("ssh_nerf_ens", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        # NaNs are dropped and a small positive offset is added to every value
        data=remove_nans(ds_field[var_name].values.flatten()) + 1e-10,
        cumulative=False,
        common_norm=True,
        stat="density",
        ax=ax,
        log_scale=False,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Enstrophy [s$^{-1}$]")
ax.set_ylabel("Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_density_ens"))
plt.show()

### Log Density (Enstrophy)

In [None]:
# Density of the "*_ens" variables of every mapping product on a log-scaled axis.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi_ens", "OI", "black"),
    ("ssh_ens", "DUACS", "green"),
    ("ssh_miost_ens", "MIOST", "blue"),
    ("ssh_nerf_ens", "NerF", "red"),
]

# Offset added to every value so that exact zeros stay strictly positive on the
# log axis. The same offset is used for all products so the curves stay
# comparable.
offset = 1e-10

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        data=remove_nans(ds_field[var_name].values.flatten()) + offset,
        cumulative=False,
        common_norm=True,
        stat="density",
        ax=ax,
        log_scale=True,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Log Enstrophy [s$^{-1}$]")
ax.set_ylabel("Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_density_ens_log"))
plt.show()

### Cumulative Density (Enstrophy)

In [None]:
# Cumulative density of the "*_ens" variables of every mapping product.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi_ens", "OI", "black"),
    ("ssh_ens", "DUACS", "green"),
    ("ssh_miost_ens", "MIOST", "blue"),
    ("ssh_nerf_ens", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=True,
        common_norm=True,
        stat="density",
        ax=ax,
        log_scale=False,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Enstrophy [s$^{-1}$]")
ax.set_ylabel("Cumulative Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_cdensity_ens"))
plt.show()

### Log Cumulative Density (Enstrophy)

In [None]:
import numpy.ma as ma

# Cumulative density of the "*_ens" variables of every mapping product on a
# log-scaled axis.
# One entry per product: (variable in ds_field, legend label, line colour).
# MIOST is drawn in blue so that it can be told apart from DUACS (green).
products = [
    ("ssh_oi_ens", "OI", "black"),
    ("ssh_ens", "DUACS", "green"),
    ("ssh_miost_ens", "MIOST", "blue"),
    ("ssh_nerf_ens", "NerF", "red"),
]

fig, ax = plt.subplots()
for var_name, label, color in products:
    sns.histplot(
        # NaNs are dropped; the small offset keeps exact zeros strictly
        # positive on the log axis
        data=remove_nans(ds_field[var_name].values.flatten()) + 1e-10,
        cumulative=True,
        common_norm=True,
        stat="density",
        ax=ax,
        log_scale=True,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Log Enstrophy [s$^{-1}$]")
ax.set_ylabel("Cumulative Density")
plt.tight_layout()
plt.legend()
fig.savefig(Path(fig_path).joinpath("dc21a_cdensity_ens_log"))
plt.show()