# Figure - Isotropic PSD

In [None]:
import sys, os
from pyprojroot import here

# Search upward for the directory containing the ".root" marker file; that
# directory is the project root.
root = here(project_files=[".root"])

# Put the project root on the import path so the local `inr4ssh` package imports.
sys.path.append(os.fspath(root))

In [None]:
import numpy as np
from pathlib import Path
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

# import seaborn as sns
#
# sns.reset_defaults()
# sns.set_context(context="talk", font_scale=0.7)
# `scienceplots` is not referenced directly; it is presumably imported for its
# side effect of registering the "science" matplotlib style used just below.
import scienceplots

# Apply the "science" style to every figure in this notebook.
plt.style.use("science")

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_longitude_domain

# from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)
from loguru import logger

# Render figures inline, and automatically reload edited modules (e.g. inr4ssh)
# before each cell runs, so code changes don't need a kernel restart.
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
# Root of the DC 2020a OSSE data. Set the DC20A_DATA_DIR environment variable to
# point at a local copy instead of editing this machine-specific default path.
DC20A_DATA_DIR = os.environ.get("DC20A_DATA_DIR", "/Volumes/EMANS_HDD/data/dc20a_osse")
# Glob over the NATL60 reference files used as the evaluation field.
url = f"{DC20A_DATA_DIR}/raw/dc_ref/NATL60-CJM165_GULFSTREAM_*.nc"


def preprocess(ds):
    """Reduce one NATL60 file to daily means.

    Passed as ``preprocess=`` to ``xr.open_mfdataset``, so it is applied to each
    file's dataset before the files are combined.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset of a single file; must have a ``time`` dimension.

    Returns
    -------
    xarray.Dataset
        ``ds`` resampled to 1-day bins and averaged.
    """
    daily_bins = ds.resample(time="1D")
    return daily_bins.mean()


# Open every NATL60 reference file matched by the glob; `preprocess` resamples
# each one to daily means before they are combined.
ds_field = xr.open_mfdataset(paths=url, preprocess=preprocess)

ds_field

In [None]:
def post_process(
    ds,
    variable,
    time_min=np.datetime64("2012-10-22"),
    time_max=np.datetime64("2012-12-02"),
):
    """Standardise a gridded SSH dataset for comparison over the evaluation period.

    Steps: subset to ``[time_min, time_max]``, apply ``correct_coordinate_labels``,
    rename ``variable`` to ``"ssh"``, apply ``correct_longitude_domain``, and
    reorder dimensions to ``(time, latitude, longitude)``.

    Parameters
    ----------
    ds : xarray.Dataset
        Raw dataset (reference field or a reconstruction).
    variable : str
        Name of the SSH variable in ``ds``; it is renamed to ``"ssh"``.
    time_min, time_max : np.datetime64, optional
        Bounds of the evaluation window. Label-based ``sel`` slicing includes
        both ends. Defaults to 2012-10-22 through 2012-12-02.

    Returns
    -------
    xarray.Dataset
        The cleaned dataset.
    """
    # subset to the evaluation window
    ds = ds.sel(time=slice(time_min, time_max))

    # NOTE(review): presumably renames the lon/lat coordinates to
    # "longitude"/"latitude"; the transpose below relies on those names.
    logger.info("Fixing coordinate labels...")
    ds = correct_coordinate_labels(ds)

    # give every dataset a common variable name
    logger.info("Fixing labels")
    ds = ds.rename({variable: "ssh"})

    logger.info("Fixing longitude domain")
    ds = correct_longitude_domain(ds)

    return ds.transpose("time", "latitude", "longitude")

In [None]:
# Clean the reference field; in the NATL60 files the SSH variable is "sossheig".
ds_field = post_process(ds_field, variable="sossheig")
ds_field

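### Reconstruction Results

The 2020a SSH-mapping data-challenge reconstructions (MIOST, DUACS, 4DVarNet, BFN) can be downloaded with: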

```bash
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
# List the reconstruction files available for the "nadir4" experiment.
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/

In [None]:
# List the reconstruction files available for the "swot1nadir5" experiment.
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/swot1nadir5/

In [None]:
# List the saved NeRF/SIREN predictions (the url_nerf files are read from here).
!ls /Users/eman/code_projects/logs/saved_data

In [None]:
# url = "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc"
# # url =
# ds_predict = xr.open_dataset(url)
#
# ds_predict

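## Post-processing

### Cleaning

* Subset to the evaluation period (2012-10-22 to 2012-12-02)
* Standardise the latitude/longitude coordinate labels
* Correct the longitude range
* Regrid each reconstruction onto the NATL60 evaluation grid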

In [None]:
# Observation scenario to evaluate: "nadir4" or "swot1nadir5".
experiment = "swot1nadir5"

# Machine-specific roots; override with environment variables instead of editing.
DC20A_DATA_DIR = os.environ.get("DC20A_DATA_DIR", "/Volumes/EMANS_HDD/data/dc20a_osse")
NERF_RESULTS_DIR = os.environ.get(
    "NERF_RESULTS_DIR", "/Users/eman/code_projects/logs/saved_data"
)

# experiment -> suffix used in the DC 2020a MIOST/DUACS result file names
_EXPERIMENT_SUFFIX = {
    "nadir4": "en_j1_tpn_g2",
    "swot1nadir5": "swot_en_j1_tpn_g2",
}
if experiment not in _EXPERIMENT_SUFFIX:
    raise ValueError(f"Unrecognized exp: {experiment}")

_suffix = _EXPERIMENT_SUFFIX[experiment]
_results_dir = f"{DC20A_DATA_DIR}/results/{experiment}"
url_miost = f"{_results_dir}/2020a_SSH_mapping_NATL60_MIOST_{_suffix}.nc"
url_duacs = f"{_results_dir}/2020a_SSH_mapping_NATL60_DUACS_{_suffix}.nc"
url_nerf = f"{NERF_RESULTS_DIR}/test_res_{experiment}_jz.nc"

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid


# Clean each reconstruction and regrid it onto the reference grid, storing it
# in ds_field under its own name.
# (log label, file, SSH variable name in that file, output name in ds_field)
_regrid_specs = [
    ("Dataset I - MIOST", url_miost, "gssh", "ssh_miost"),
    ("Dataset II - DUACS", url_duacs, "gssh", "ssh_duacs"),
    ("Dataset III - SIREN", url_nerf, "ssh_model_predict", "ssh_nerf"),
]

for _label, _path, _source_var, _target_var in _regrid_specs:
    logger.info(_label)
    ds_predict = xr.open_dataset(_path)
    ds_predict = post_process(ds_predict, _source_var)
    ds_field[_target_var] = oi_regrid(ds_predict["ssh"], ds_field["ssh"])

In [None]:
from tqdm.notebook import tqdm
from inr4ssh._src.operators.ssh import (
    ssh2uv_ds_2dt,
    ssh2uv_da_2dt,
    kinetic_energy,
    enstropy,
    ssh2rv_ds_2dt,
    ssh2rv_da_2dt,
)

variables = ["ssh", "ssh_duacs", "ssh_miost", "ssh_nerf"]
field_dims = ("time", "latitude", "longitude")

# Kinetic energy of the velocities that ssh2uv_da_2dt derives from each SSH map
# (the reference and every reconstruction).
logger.info("Calculating Kinetic Energy...")
for name in tqdm(variables):
    u_vel, v_vel = ssh2uv_da_2dt(ds_field[name])
    ds_field[f"{name}_ke"] = (field_dims, kinetic_energy(u_vel, v_vel))

# The "_ens" fields store the relative vorticity from ssh2rv_da_2dt; their PSD
# is what the enstrophy figures below show.
# NOTE(review): the imported `enstropy` helper is not used — confirm that storing
# the raw vorticity is intended.
logger.info("Calculating Enstropy...")
for name in tqdm(variables):
    ds_field[f"{name}_ens"] = (field_dims, ssh2rv_da_2dt(ds_field[name]))

In [None]:
# Inspect the dataset including the derived *_ke and *_ens variables.
ds_field

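## Metrics - Isotropic PSD

The latitude/longitude coordinates are first converted from degrees to metres (≈111 km per degree), so the PSD wavenumbers come out in cycles/m. The figures rescale them to cycles/km.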

In [None]:
# Convert the lon/lat coordinates from degrees to metres (~111 km per degree) so
# that psd_isotropic returns wavenumbers in cycles/m.
# The "units" attribute guards against re-running this cell, which would
# otherwise rescale already-converted coordinates a second time.
# NOTE(review): the same factor is applied to longitude and latitude, ignoring
# the cos(latitude) shrinkage of a longitude degree — confirm this flat
# approximation is acceptable for this domain.
METERS_PER_DEGREE = 111e3
for _coord in ("longitude", "latitude"):
    if ds_field[_coord].attrs.get("units") != "m":
        ds_field[_coord] = (ds_field[_coord] * METERS_PER_DEGREE).assign_attrs(units="m")

### Sea Surface Height

In [None]:
# Isotropic PSD of the reference (NATL60) SSH and of each reconstruction.
(
    ds_field_psd_natl60,
    ds_field_psd_duacs,
    ds_field_psd_miost,
    ds_field_psd_nerf,
) = [
    psd_isotropic(ds_field[name])
    for name in ("ssh", "ssh_duacs", "ssh_miost", "ssh_nerf")
]

In [None]:
import seaborn as sns

sns.set_context(context="talk", font_scale=0.7)

# NATL60 reference spectrum. Wavenumbers are multiplied by 1e3 to go from
# per-metre to per-km.
natl60_freq_km = ds_field_psd_natl60.freq_r.values * 1e3
fig, ax, secax = plot_psd_isotropic(
    natl60_freq_km, ds_field_psd_natl60.values, color="black"
)

# Overlay each reconstruction on the same axes.
for psd, colour in (
    (ds_field_psd_duacs, "tab:green"),
    (ds_field_psd_miost, "tab:blue"),
    (ds_field_psd_nerf, "tab:red"),
):
    ax.plot(psd.freq_r.values * 1e3, psd.values, color=colour)

# Show from the smallest finite wavenumber up to a quarter of the largest.
finite_freq_km = np.ma.masked_invalid(natl60_freq_km)
plt.xlim((np.ma.min(finite_freq_km), 0.25 * np.ma.max(finite_freq_km)))
plt.legend(["NATL60", "DUACS", "MIOST", "NerF"])
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/osse_2020a_psd_iso.png"))
plt.show()

### Kinetic Energy

In [None]:
# Isotropic PSD of the kinetic-energy fields (reference first, then reconstructions).
(
    ds_field_psd_natl60_ke,
    ds_field_psd_duacs_ke,
    ds_field_psd_miost_ke,
    ds_field_psd_nerf_ke,
) = [
    psd_isotropic(ds_field[name])
    for name in ("ssh_ke", "ssh_duacs_ke", "ssh_miost_ke", "ssh_nerf_ke")
]

In [None]:
# Isotropic kinetic-energy PSD: NATL60 reference vs each reconstruction.
# Wavenumbers are multiplied by 1e3 to go from per-metre to per-km.
natl60_ke_freq_km = ds_field_psd_natl60_ke.freq_r.values * 1e3
fig, ax, secax = plot_psd_isotropic(
    natl60_ke_freq_km,
    ds_field_psd_natl60_ke.values,
    color="black",
)

for psd, colour in (
    (ds_field_psd_duacs_ke, "tab:green"),
    (ds_field_psd_miost_ke, "tab:blue"),
    (ds_field_psd_nerf_ke, "tab:red"),
):
    ax.plot(psd.freq_r.values * 1e3, psd.values, color=colour)

# Show from the smallest finite wavenumber up to a quarter of the largest.
finite_ke_freq_km = np.ma.masked_invalid(natl60_ke_freq_km)
plt.xlim((np.ma.min(finite_ke_freq_km), 0.25 * np.ma.max(finite_ke_freq_km)))

# Raw string, not an f-string: the f-string evaluated "{-2}" to "-2", so mathtext
# superscripted only the minus sign. The closing "]" was also missing.
# NOTE(review): the PSD is taken of the KE field itself (m^2 s^-2); confirm the
# unit numerator should not be squared under psd_isotropic's normalisation.
ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m]")
plt.legend(["NATL60", "DUACS", "MIOST", "NerF"])
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/osse_2020a_psd_iso_ke.png"))
plt.show()

### Enstrophy

In [None]:
# Isotropic PSD of the "_ens" (relative vorticity) fields, reference first.
(
    ds_field_psd_natl60_ens,
    ds_field_psd_duacs_ens,
    ds_field_psd_miost_ens,
    ds_field_psd_nerf_ens,
) = [
    psd_isotropic(ds_field[name])
    for name in ("ssh_ens", "ssh_duacs_ens", "ssh_miost_ens", "ssh_nerf_ens")
]

In [None]:
sns.set_context(context="talk", font_scale=0.7)

# Isotropic enstrophy PSD: NATL60 reference vs each reconstruction.
# Wavenumbers are multiplied by 1e3 to go from per-metre to per-km.
natl60_ens_freq_km = ds_field_psd_natl60_ens.freq_r.values * 1e3
fig, ax, secax = plot_psd_isotropic(
    natl60_ens_freq_km,
    ds_field_psd_natl60_ens.values,
    color="black",
)

for psd, colour in (
    (ds_field_psd_duacs_ens, "tab:green"),
    (ds_field_psd_miost_ens, "tab:blue"),
    (ds_field_psd_nerf_ens, "tab:red"),
):
    ax.plot(psd.freq_r.values * 1e3, psd.values, color=colour)

# Show from the smallest finite wavenumber up to a quarter of the largest.
finite_ens_freq_km = np.ma.masked_invalid(natl60_ens_freq_km)
plt.xlim((np.ma.min(finite_ens_freq_km), 0.25 * np.ma.max(finite_ens_freq_km)))

# Raw string, not an f-string: the f-string evaluated "{-1}" to "-1", so mathtext
# superscripted only the minus sign. The closing "]" was also missing.
# NOTE(review): the PSD of vorticity (s^-1) would usually carry s^-2 in the
# numerator — confirm against psd_isotropic's normalisation.
ax.set_ylabel(r"PSD [s$^{-1}$/cycles/m]")
# Same model name as the SSH and KE PSD figures.
plt.legend(["NATL60", "DUACS", "MIOST", "NerF"])
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/osse_2020a_psd_iso_ens.png"))
plt.show()

## Metrics - Isotropic PSD Score

### Sea Surface Height

In [None]:
from inr4ssh._src.metrics.psd import psd_isotropic_score, wavelength_resolved_isotropic
from tqdm.notebook import tqdm

sns.set_context(context="talk", font_scale=0.7)
variables = ["ssh_duacs", "ssh_miost", "ssh_nerf"]
colours = ["tab:green", "tab:blue", "tab:red"]
# NOTE(review): these linestyles are zipped into the loop but never applied —
# every score curve is drawn with linestyle="-". Confirm that is intended.
linestyle = ["-", "--", "-."]

# Fed back into plot_psd_isotropic so all curves share one axes (presumably a
# new figure is created while ax is None).
ax = None

for ivariable, icolour, ilinestyle in tqdm(zip(variables, colours, linestyle)):
    # PSD score of this reconstruction against the NATL60 reference SSH.
    psd_iso_score = psd_isotropic_score(ds_field["ssh"], ds_field[ivariable])

    # Wavelength (in metres) at which the score reaches level=0.5.
    space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)
    iname = ivariable.split("_")[1].upper()
    print(
        f"Shortest Spatial Wavelength Resolved [{iname}] = {space_iso_resolved/1e3:.2f} (km)"
    )

    # Wavenumbers per km, with invalid entries masked for the axis limits.
    freq_km = np.ma.masked_invalid(psd_iso_score.freq_r.values * 1e3)
    freq_min = np.ma.min(freq_km)

    fig, ax, secax = plot_psd_isotropic(
        psd_iso_score.freq_r.values * 1e3,
        psd_iso_score.values,
        ax=ax,
        color=icolour,
        linestyle="-",
    )

    ax.set(ylabel="PSD Score", yscale="linear")
    plt.ylim((0, 1.0))
    plt.xlim((freq_min, 0.25 * np.ma.max(freq_km)))

    # Mark the resolved scale (1 / wavelength in km): a dashed drop line to the
    # x axis and a dashed guide at score 0.5 from the left edge.
    resolved_scale = 1 / (space_iso_resolved * 1e-3)
    ax.vlines(
        x=resolved_scale, ymin=0, ymax=0.5, color=icolour, linewidth=2, linestyle="--"
    )
    ax.hlines(
        y=0.5,
        xmin=freq_min,
        xmax=resolved_scale,
        color=icolour,
        linewidth=2,
        linestyle="--",
    )
    ax.set_aspect("equal", "box")

    # Raw f-string: "\l" in a non-raw string is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+).
    label = rf"{iname} - $\lambda$ > {int(space_iso_resolved*1e-3)} km"
    plt.scatter(
        resolved_scale, 0.5, color=icolour, marker=".", linewidth=5, label=label
    )

plt.legend()
plt.tight_layout()

fig.savefig(Path(root).joinpath("figures/osse_2020a_psd_score_iso_ssh.png"))
plt.show()

### Kinetic Energy

In [None]:
variables = ["ssh_duacs_ke", "ssh_miost_ke", "ssh_nerf_ke"]
colours = ["tab:green", "tab:blue", "tab:red"]
# NOTE(review): these linestyles are zipped into the loop but never applied —
# every score curve is drawn with linestyle="-". Confirm that is intended.
linestyle = ["-", "--", "-."]
sns.set_context(context="talk", font_scale=0.7)
# Fed back into plot_psd_isotropic so all curves share one axes.
ax = None

for ivariable, icolour, ilinestyle in tqdm(zip(variables, colours, linestyle)):
    # PSD score of this reconstruction's KE against the NATL60 reference KE.
    psd_iso_score = psd_isotropic_score(ds_field["ssh_ke"], ds_field[ivariable])

    # Wavelength (in metres) at which the score reaches level=0.5.
    space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)
    iname = ivariable.split("_")[1].upper()
    # The resolved quantity is a wavelength, so it is reported in km. The old
    # unit (km^2 s^-2) was wrong, and the f-string evaluated "{-2}" as code.
    print(
        f"Shortest Spatial Wavelength Resolved [{iname}] = {space_iso_resolved/1e3:.2f} (km)"
    )

    # Wavenumbers per km, with invalid entries masked for the axis limits.
    freq_km = np.ma.masked_invalid(psd_iso_score.freq_r.values * 1e3)
    freq_min = np.ma.min(freq_km)

    fig, ax, secax = plot_psd_isotropic(
        psd_iso_score.freq_r.values * 1e3,
        psd_iso_score.values,
        ax=ax,
        color=icolour,
        linestyle="-",
    )

    ax.set(ylabel="PSD Score", yscale="linear")
    plt.ylim((0, 1.0))
    plt.xlim((freq_min, 0.25 * np.ma.max(freq_km)))

    # Mark the resolved scale (1 / wavelength in km) with dashed guide lines.
    resolved_scale = 1 / (space_iso_resolved * 1e-3)
    ax.vlines(
        x=resolved_scale, ymin=0, ymax=0.5, color=icolour, linewidth=2, linestyle="--"
    )
    ax.hlines(
        y=0.5,
        xmin=freq_min,
        xmax=resolved_scale,
        color=icolour,
        linewidth=2,
        linestyle="--",
    )
    ax.set_aspect("equal", "box")

    # Raw f-string avoids the invalid "\l" escape; wavelength unit is km.
    label = rf"{iname} - $\lambda$ > {int(space_iso_resolved*1e-3)} km"
    plt.scatter(
        resolved_scale, 0.5, color=icolour, marker=".", linewidth=5, label=label
    )

plt.legend()
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/osse_2020a_psd_score_iso_ke.png"))
plt.show()

### Enstrophy

In [None]:
variables = ["ssh_duacs_ens", "ssh_miost_ens", "ssh_nerf_ens"]
colours = ["tab:green", "tab:blue", "tab:red"]
# NOTE(review): these linestyles are zipped into the loop but never applied —
# every score curve is drawn with linestyle="-". Confirm that is intended.
linestyle = ["-", "--", "-."]
sns.set_context(context="talk", font_scale=0.7)
# Fed back into plot_psd_isotropic so all curves share one axes.
ax = None

for ivariable, icolour, ilinestyle in tqdm(zip(variables, colours, linestyle)):
    # PSD score of this reconstruction's "_ens" field against the NATL60 reference.
    psd_iso_score = psd_isotropic_score(ds_field["ssh_ens"], ds_field[ivariable])

    # Wavelength (in metres) at which the score reaches level=0.5.
    space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)
    iname = ivariable.split("_")[1].upper()
    # The resolved quantity is a wavelength, so it is reported in km. The old
    # unit (s^-1) was wrong, and the f-string evaluated "{-1}" as code.
    print(
        f"Shortest Spatial Wavelength Resolved [{iname}] = {space_iso_resolved/1e3:.2f} (km)"
    )

    # Wavenumbers per km, with invalid entries masked for the axis limits.
    freq_km = np.ma.masked_invalid(psd_iso_score.freq_r.values * 1e3)
    freq_min = np.ma.min(freq_km)

    fig, ax, secax = plot_psd_isotropic(
        psd_iso_score.freq_r.values * 1e3,
        psd_iso_score.values,
        ax=ax,
        color=icolour,
        linestyle="-",
    )
    ax.set_aspect("equal", "box")

    ax.set(ylabel="PSD Score", yscale="linear")
    plt.ylim((0, 1.0))
    plt.xlim((freq_min, 0.25 * np.ma.max(freq_km)))

    # Mark the resolved scale (1 / wavelength in km) with dashed guide lines.
    resolved_scale = 1 / (space_iso_resolved * 1e-3)
    ax.vlines(
        x=resolved_scale, ymin=0, ymax=0.5, color=icolour, linewidth=2, linestyle="--"
    )
    ax.hlines(
        y=0.5,
        xmin=freq_min,
        xmax=resolved_scale,
        color=icolour,
        linewidth=2,
        linestyle="--",
    )

    # Raw f-string avoids the invalid "\l" escape; wavelength unit is km.
    label = rf"{iname} - $\lambda$ > {int(space_iso_resolved*1e-3)} km"
    plt.scatter(
        resolved_scale, 0.5, color=icolour, marker=".", linewidth=5, label=label
    )

plt.legend()
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/osse_2020a_psd_score_iso_ens.png"))
plt.show()