# Figure - Pixel Densities

In [None]:
import sys, os
from pyprojroot import here

# Locate the project root: pyprojroot's `here` searches upward from the
# working directory for a directory containing a ".root" marker.
root = here(project_files=[".root"])

# Make the project package importable from this notebook. Guarded so that
# re-running the cell does not keep appending duplicate sys.path entries.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# Start from seaborn's defaults, then use the "talk" context with fonts scaled down.
sns.reset_defaults()
sns.set_context("talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels

from loguru import logger

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
# Glob matching every NATL60 reference-field file (opened with open_mfdataset below).
url = (
    "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/"
    "NATL60-CJM165_GULFSTREAM_*.nc"
)


def preprocess(ds, freq="1D"):
    """Resample a single reference-field file onto a regular time step.

    Intended as the ``preprocess=`` hook of ``xr.open_mfdataset``, which calls
    it on each opened file before the files are combined.

    Parameters
    ----------
    ds : xarray.Dataset
        One file's dataset; must have a ``time`` dimension.
    freq : str, optional
        Resampling frequency passed to ``resample(time=...)`` (default
        ``"1D"``, i.e. daily means).

    Returns
    -------
    xarray.Dataset
        Time-averaged dataset at the requested frequency.
    """
    return ds.resample(time=freq).mean()


# Open every matching reference file (lazily, via dask) and resample each one
# with `preprocess` before xarray combines them into a single dataset.
ds_field = xr.open_mfdataset(paths=url, preprocess=preprocess)
ds_field

### Example Results

In [None]:
def remove_nans(x):
    """Return the non-NaN entries of the array ``x``.

    Boolean-mask indexing always yields a 1-D result, so multi-dimensional
    input is effectively flattened (C order) with its NaNs dropped.
    """
    keep = np.logical_not(np.isnan(x))
    return x[keep]


def post_process(ds, variable, time_min="2012-10-22", time_max="2012-12-02"):
    """Bring a dataset onto the common evaluation layout.

    Steps: restrict to the evaluation period, normalise the coordinate labels,
    rename ``variable`` to ``"ssh"``, correct the longitude domain, and order
    the dimensions as ``(time, latitude, longitude)``. No regridding is done
    here; that happens separately (``oi_regrid``).

    Parameters
    ----------
    ds : xarray.Dataset
        Input dataset containing ``variable`` and a ``time`` coordinate.
    variable : str
        Name of the SSH variable in ``ds``; it is renamed to ``"ssh"``.
    time_min, time_max : str or numpy.datetime64, optional
        Bounds of the evaluation period. Label-based slicing in xarray
        includes both endpoints. Defaults are the original hard-coded window.

    Returns
    -------
    xarray.Dataset
        The harmonised dataset.
    """
    # subset temporal space
    ds = ds.sel(time=slice(np.datetime64(time_min), np.datetime64(time_max)))

    # correct coordinate labels
    logger.info("Fixing coordinate labels...")
    ds = correct_coordinate_labels(ds)

    # give the SSH variable a common name across products
    logger.info("Fixing labels")
    ds = ds.rename({variable: "ssh"})

    # correct longitude domain
    logger.info("Fixing longitude domain")
    ds = correct_longitude_domain(ds)

    # fixed dimension order so downstream operators see consistent axes
    return ds.transpose("time", "latitude", "longitude")

In [None]:
# Harmonise the reference field; its SSH variable is called "sossheig".
# NOTE: not idempotent — once renamed, "sossheig" no longer exists, so a
# second run of this cell will fail inside `post_process` at the rename step.
ds_field = post_process(ds_field, variable="sossheig")
ds_field

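Download the 2020a mapping products (only the MIOST and DUACS `en_j1_tpn_g2` maps are loaded below, from the `nadir4` results directory):

```bash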
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/

## PostProcessing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid

# Each reconstruction is harmonised with `post_process` and then regridded with
# `oi_regrid` onto the reference field so it can be stored alongside it in ds_field.
# Entries: (log title, path to the mapping output, SSH variable in that file, target name in ds_field)
results = [
    (
        "Dataset I - MIOST",
        "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc",
        "gssh",
        "ssh_miost",
    ),
    (
        "Dataset II - DUACS",
        "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc",
        "gssh",
        "ssh_duacs",
    ),
    (
        "Dataset III - SIREN",
        # alternative run: "/Users/eman/code_projects/logs/saved_data/test_res_nadir4_lb.nc"
        "/Users/eman/code_projects/logs/saved_data/test_res_nadir4_jz.nc",
        "ssh_model_predict",
        "ssh_nerf",
    ),
]

for title, url, variable, target in results:
    logger.info(title)
    ds_predict = post_process(xr.open_dataset(url), variable)
    ds_field[target] = oi_regrid(ds_predict["ssh"], ds_field["ssh"])

In [None]:
# Reference SSH plus the regridded MIOST / DUACS / NerF reconstructions.
ds_field

In [None]:
from tqdm.notebook import tqdm
from inr4ssh._src.operators.ssh import (
    ssh2uv_ds_2dt,
    ssh2uv_da_2dt,
    kinetic_energy,
    enstropy,
    ssh2rv_ds_2dt,
    ssh2rv_da_2dt,
)

# Reference field first, then the three regridded reconstructions.
variables = ["ssh", "ssh_duacs", "ssh_miost", "ssh_nerf"]

logger.info("Calculating Kinetic Energy...")
for name in tqdm(variables):
    # velocity components derived from SSH, then kinetic energy from (u, v)
    u, v = ssh2uv_da_2dt(ds_field[name])
    ds_field[f"{name}_ke"] = (("time", "latitude", "longitude"), kinetic_energy(u, v))

logger.info("Calculating Enstrophy...")
for name in tqdm(variables):
    # NOTE(review): this stores the output of `ssh2rv_da_2dt` (presumably
    # relative vorticity, which is signed) under `*_ens`; the imported
    # `enstropy` operator is never applied. Confirm which quantity is intended.
    rv = ssh2rv_da_2dt(ds_field[name])
    ds_field[f"{name}_ens"] = (("time", "latitude", "longitude"), rv)

## Metrics II - Pixel-Wise Distributions

### Density (Sea Surface Height)

In [None]:
# Pixel-wise SSH density: NATL60 reference vs. each reconstruction.
# Entries: (variable in ds_field, legend label, line colour)
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh", "NATL60", "black"),
    ("ssh_duacs", "DUACS", "green"),
    ("ssh_miost", "MIOST", "blue"),
    ("ssh_nerf", "NerF", "red"),
]:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=False,
        stat="density",
        ax=ax,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel("SSH [m]")
ax.set_ylabel("Density")
ax.legend()
# add the legend before tight_layout so the final layout accounts for it
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_dens_ssh.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

### Cumulative Density (Sea Surface Height)

In [None]:
# Cumulative SSH distribution: NATL60 reference vs. each reconstruction.
# Entries: (variable in ds_field, legend label, line colour). MIOST uses blue so
# it is distinguishable from DUACS (green).
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh", "NATL60", "black"),
    ("ssh_duacs", "DUACS", "green"),
    ("ssh_miost", "MIOST", "blue"),
    ("ssh_nerf", "NerF", "red"),
]:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=True,
        stat="density",
        ax=ax,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel("SSH [m]")
ax.set_ylabel("Cumulative Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_cdens_ssh.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

### Density (Kinetic Energy)

In [None]:
# Pixel-wise kinetic-energy density: NATL60 reference vs. each reconstruction.
# Entries: (variable in ds_field, legend label, line colour). MIOST uses blue so
# it is distinguishable from DUACS (green).
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh_ke", "NATL60", "black"),
    ("ssh_duacs_ke", "DUACS", "green"),
    ("ssh_miost_ke", "MIOST", "blue"),
    ("ssh_nerf_ke", "NerF", "red"),
]:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=False,
        stat="density",
        ax=ax,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Kinetic Energy [m$^2$s$^{-2}$]")
ax.set_ylabel("Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_dens_ke.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

### Log Density (Kinetic Energy)

In [None]:
# Kinetic-energy density with a log-scaled x axis (seaborn bins in log space).
# Only the x axis is logarithmic, so the y label is plain "Density".
# Entries: (variable in ds_field, legend label, line colour).
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh_ke", "NATL60", "black"),
    ("ssh_duacs_ke", "DUACS", "green"),
    ("ssh_miost_ke", "MIOST", "blue"),
    ("ssh_nerf_ke", "NerF", "red"),
]:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=False,
        stat="density",
        ax=ax,
        log_scale=True,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
ax.set_ylabel("Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_dens_lke.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

### Cumulative Density (Kinetic Energy)

In [None]:
# Cumulative kinetic-energy distribution on a linear x axis.
# Entries: (variable in ds_field, legend label, line colour). MIOST uses blue so
# it is distinguishable from DUACS (green).
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh_ke", "NATL60", "black"),
    ("ssh_duacs_ke", "DUACS", "green"),
    ("ssh_miost_ke", "MIOST", "blue"),
    ("ssh_nerf_ke", "NerF", "red"),
]:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=True,
        stat="density",
        ax=ax,
        log_scale=False,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Kinetic Energy [m$^2$s$^{-2}$]")
ax.set_ylabel("Cumulative Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_cdens_ke.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

### Log Cumulative Density (Kinetic Energy)

In [None]:
# Cumulative kinetic-energy distribution with a log-scaled x axis.
# Entries: (variable in ds_field, legend label, line colour). MIOST uses blue so
# it is distinguishable from DUACS (green).
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh_ke", "NATL60", "black"),
    ("ssh_duacs_ke", "DUACS", "green"),
    ("ssh_miost_ke", "MIOST", "blue"),
    ("ssh_nerf_ke", "NerF", "red"),
]:
    sns.histplot(
        data=ds_field[var_name].values.flatten(),
        cumulative=True,
        stat="density",
        ax=ax,
        log_scale=True,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Log Kinetic Energy [m$^2$s$^{-2}$]")
ax.set_ylabel("Cumulative Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_cdens_lke.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

## Enstrophy

### Density (Enstrophy)

In [None]:
# Sanity check: shape of each KE field after dropping NaN rows, presumably to
# confirm that the land/NaN masks agree across products.
tuple(
    ds_field[f"{prefix}_ke"].to_dataframe().dropna().values.shape
    for prefix in ("ssh", "ssh_duacs", "ssh_miost", "ssh_nerf")
)

In [None]:
# Quick look at the reference (NATL60) `ssh_ens` distribution on its own.
# NOTE(review): `ssh_ens` is filled from `ssh2rv_da_2dt` (presumably signed
# relative vorticity); confirm whether enstrophy was intended.
fig, ax = plt.subplots()
sns.histplot(
    data=remove_nans(ds_field["ssh_ens"].values.flatten()),
    cumulative=False,
    stat="density",
    ax=ax,
    log_scale=False,
    label="NATL60",
    color="black",
    fill=False,
    element="step",
    linewidth=3,
    alpha=0.5,
)
ax.set_xlabel(r"Enstrophy [s$^{-1}$]")
ax.set_ylabel("Density")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Pixel-wise `*_ens` density (linear axis): NATL60 reference vs. each reconstruction.
# NaNs are dropped with `remove_nans` (as in the other cells) instead of
# `to_dataframe().dropna().values`, which would also pull in any non-index
# coordinate columns. No offset is needed on a linear axis.
# NOTE(review): `*_ens` is filled from `ssh2rv_da_2dt` (presumably relative
# vorticity); confirm whether enstrophy was intended.
# Entries: (variable in ds_field, legend label, line colour).
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh_ens", "NATL60", "black"),
    ("ssh_duacs_ens", "DUACS", "green"),
    ("ssh_miost_ens", "MIOST", "blue"),
    ("ssh_nerf_ens", "NerF", "red"),
]:
    sns.histplot(
        data=remove_nans(ds_field[var_name].values.flatten()),
        cumulative=False,
        stat="density",
        ax=ax,
        log_scale=False,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Enstrophy [s$^{-1}$]")
ax.set_ylabel("Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_dens_ens.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

### Log Density (Enstrophy)

In [None]:
# `*_ens` density with a log-scaled x axis.
# The same small offset is added to every field so exact zeros stay finite on
# the log axis and the curves remain comparable.
# NOTE(review): `*_ens` comes from `ssh2rv_da_2dt` (presumably signed relative
# vorticity); non-positive values cannot be shown on a log axis.
# Entries: (variable in ds_field, legend label, line colour).
eps = 1e-10
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh_ens", "NATL60", "black"),
    ("ssh_duacs_ens", "DUACS", "green"),
    ("ssh_miost_ens", "MIOST", "blue"),
    ("ssh_nerf_ens", "NerF", "red"),
]:
    sns.histplot(
        data=remove_nans(ds_field[var_name].values.flatten()) + eps,
        cumulative=False,
        stat="density",
        ax=ax,
        log_scale=True,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Log Enstrophy [s$^{-1}$]")
ax.set_ylabel("Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_dens_lens.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

### Cumulative Density (Enstrophy)

In [None]:
# Cumulative `*_ens` distribution on a linear x axis (NaNs dropped explicitly).
# NOTE(review): `*_ens` is filled from `ssh2rv_da_2dt` (presumably relative
# vorticity); confirm whether enstrophy was intended.
# Entries: (variable in ds_field, legend label, line colour).
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh_ens", "NATL60", "black"),
    ("ssh_duacs_ens", "DUACS", "green"),
    ("ssh_miost_ens", "MIOST", "blue"),
    ("ssh_nerf_ens", "NerF", "red"),
]:
    sns.histplot(
        data=remove_nans(ds_field[var_name].values.flatten()),
        cumulative=True,
        stat="density",
        ax=ax,
        log_scale=False,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Enstrophy [s$^{-1}$]")
ax.set_ylabel("Cumulative Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_cdens_ens.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()

### Log Cumulative Density (Enstrophy)

In [None]:
import numpy.ma as ma

# Cumulative `*_ens` distribution with a log-scaled x axis.
# NaNs are removed with `remove_nans`, and the same small offset is added to
# every field so exact zeros stay finite on the log axis.
# NOTE(review): `*_ens` comes from `ssh2rv_da_2dt` (presumably signed relative
# vorticity); non-positive values cannot be shown on a log axis.
# Entries: (variable in ds_field, legend label, line colour).
eps = 1e-10
fig, ax = plt.subplots()
for var_name, label, color in [
    ("ssh_ens", "NATL60", "black"),
    ("ssh_duacs_ens", "DUACS", "green"),
    ("ssh_miost_ens", "MIOST", "blue"),
    ("ssh_nerf_ens", "NerF", "red"),
]:
    sns.histplot(
        data=remove_nans(ds_field[var_name].values.flatten()) + eps,
        cumulative=True,
        stat="density",
        ax=ax,
        log_scale=True,
        label=label,
        color=color,
        fill=False,
        element="step",
        linewidth=3,
        alpha=0.5,
    )
ax.set_xlabel(r"Log Enstrophy [s$^{-1}$]")
ax.set_ylabel("Cumulative Density")
ax.legend()
fig.tight_layout()
out_path = Path(root) / "figures" / "osse_2020a_stats_cdens_lens.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.show()