# Figure - Isotropic PSD (Grid)

In [None]:
import sys, os
from pyprojroot import here

# Search upwards from the working directory for the project root, which is
# marked by a ".root" file, so the local `inr4ssh` package can be imported.
root = here(project_files=[".root"])

# Make the project importable. The guard keeps re-runs of this cell from
# appending the same entry to sys.path again and again.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
from pathlib import Path
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

import seaborn as sns

# Plot styling: reset seaborn to its defaults, then use the "talk" context
# with a slightly reduced font scale.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Imported for its side effect of registering extra matplotlib styles
# (e.g. "science", whose activation is commented out below).
import scienceplots

# plt.style.use("science")

# Importing these registers the `.hvplot` accessor on xarray / pandas objects.
import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

# NOTE(review): duplicate of the `correct_longitude_domain` import above; redundant.
from inr4ssh._src.preprocess.coords import correct_longitude_domain

# from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.preprocess.regrid import oi_regrid
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)
from loguru import logger

# Render figures inline and auto-reload edited project modules before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
from inr4ssh._src.preprocess.spatial import convert_lon_360_180, convert_lon_180_360

# Quick check: the 0-360 longitude bounds (295, 305) used by `post_process`,
# expressed in the converted convention.
convert_lon_360_180(295), convert_lon_360_180(305)

In [None]:
def post_process(
    ds,
    variable,
    time_bounds=("2017-01-01", "2017-12-31"),
    lon_bounds=(295, 305),
    lat_bounds=(33.0, 53.0),
):
    """Clean a gridded SSH dataset and subset it to the evaluation window.

    Steps: standardise coordinate labels, rename ``variable`` to ``"ssh"``,
    convert longitudes with ``convert_lon_360_180``, sort the spatial
    coordinates in ascending order, subset in time/space and order the
    dimensions as ``("time", "latitude", "longitude")``.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset as returned by ``xr.open_dataset``.
    variable : str
        Name of the SSH variable in ``ds``; it is renamed to ``"ssh"``.
    time_bounds : tuple of str, default ("2017-01-01", "2017-12-31")
        Inclusive start/end dates of the evaluation period.
    lon_bounds : tuple of float, default (295, 305)
        Longitude bounds in the 0-360 convention. They are converted with
        ``convert_lon_360_180``, exactly like the data longitudes.
    lat_bounds : tuple of float, default (33.0, 53.0)
        Latitude bounds in degrees.

    Returns
    -------
    xr.Dataset
        The cleaned subset, transposed to ``("time", "latitude", "longitude")``.
    """
    from inr4ssh._src.preprocess.spatial import convert_lon_360_180

    # correct coordinate labels
    logger.info("Fixing coordinate labels...")
    ds = correct_coordinate_labels(ds)

    # rename the SSH variable to the common name used throughout the notebook
    logger.info("Fixing labels")
    ds = ds.rename({variable: "ssh"})

    # correct longitude domain
    logger.info("Fixing longitude domain")
    ds["longitude"] = convert_lon_360_180(ds.longitude)

    # Label-based slicing below needs monotonically increasing coordinates.
    # The longitude conversion may break the ordering, and a descending
    # latitude axis would silently produce an empty slice.
    ds = ds.sortby(["longitude", "latitude"])

    # subset the evaluation period and region
    ds = ds.sel(
        time=slice(np.datetime64(time_bounds[0]), np.datetime64(time_bounds[1])),
        longitude=slice(
            convert_lon_360_180(lon_bounds[0]), convert_lon_360_180(lon_bounds[1])
        ),
        latitude=slice(lat_bounds[0], lat_bounds[1]),
    )

    # fixed dimension order for the downstream operators / regridding
    return ds.transpose("time", "latitude", "longitude")

In [None]:
# Reference field: the DUACS OSE mapping, cleaned and subset to the evaluation window.
logger.info("Dataset I - DUACS")
url = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_DUACS.nc"
ds_field = post_process(xr.open_dataset(url), "ssh")

In [None]:
# Sanity check: smallest longitude remaining in the subset reference field.
ds_field.longitude.min()

### Example Results

```bash
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
# List the OSE mapping result files available on disk.
!ls /Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/

In [None]:
# url = "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc"
# # url =
# ds_predict = xr.open_dataset(url)
#
# ds_predict

## Post-processing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
# Each reconstruction goes through the same post-processing as the DUACS
# reference. It is then regridded onto the reference field with `oi_regrid`
# (assumed to interpolate the first argument onto the grid of the second), so
# every field shares the reference coordinates for the PSD comparisons below.
reconstructions = [
    # (log message, output variable in ds_field, path to the mapping file)
    (
        "Dataset II - BASELINE",
        "ssh_oi",
        "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_BASELINE.nc",
    ),
    (
        "Dataset III - MIOST",
        "ssh_miost",
        "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_MIOST.nc",
    ),
    # alternative NerF results:
    #   "/Users/eman/code_projects/logs/saved_data/test_dc21b_feb_pretrain.nc"
    #   "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_4dvarNet_2022.nc"
    (
        "Dataset IV - NerF",
        "ssh_nerf",
        "/Volumes/EMANS_HDD/data/dc21b/results/siren_136.nc",
    ),
]

for message, output_name, url in reconstructions:
    logger.info(message)
    ds_predict = post_process(xr.open_dataset(url), "ssh")
    ds_field[output_name] = oi_regrid(ds_predict["ssh"], ds_field["ssh"])

ds_field

In [None]:
from tqdm.notebook import tqdm
from inr4ssh._src.operators.ssh import (
    ssh2uv_ds_2dt,
    ssh2uv_da_2dt,
    kinetic_energy,
    enstropy,
    ssh2rv_ds_2dt,
    ssh2rv_da_2dt,
)

# The DUACS reference field followed by the three reconstructions.
variables = ["ssh", "ssh_oi", "ssh_miost", "ssh_nerf"]
grid_dims = ("time", "latitude", "longitude")

# Kinetic energy from the velocity components derived from each SSH field.
logger.info("Calculating Kinetic Energy...")
for name in tqdm(variables):
    u_comp, v_comp = ssh2uv_da_2dt(ds_field[name])
    ds_field[f"{name}_ke"] = (grid_dims, kinetic_energy(u_comp, v_comp))

# NOTE(review): the "_ens" fields hold the output of ssh2rv_da_2dt (relative
# vorticity, judging by the name); `enstropy` is imported but never applied.
# Confirm which quantity is intended.
logger.info("Calculating Enstropy...")
for name in tqdm(variables):
    ds_field[f"{name}_ens"] = (grid_dims, ssh2rv_da_2dt(ds_field[name]))

## Metrics - Isotropic PSD

In [None]:
# Convert the latitude/longitude coordinates from degrees to metres
# (1 degree ~ 111 km). The score cells below divide the resolved wavelength by
# 1e3 to report km, which relies on this.
# NOTE(review): longitude is not scaled by cos(latitude), so zonal distances
# are overestimated at these latitudes; confirm the approximation is intended.
METERS_PER_DEGREE = 111e3

# Guard so that re-running this cell does not rescale the coordinates twice.
if not ds_field.attrs.get("coords_in_meters", False):
    ds_field["longitude"] = ds_field.longitude * METERS_PER_DEGREE
    ds_field["latitude"] = ds_field.latitude * METERS_PER_DEGREE
    ds_field.attrs["coords_in_meters"] = True

### Sea Surface Height

In [None]:
# Isotropic PSD of SSH for the DUACS reference and each reconstruction.
(
    ds_field_psd_duacs,
    ds_field_psd_oi,
    ds_field_psd_miost,
    ds_field_psd_nerf,
) = [
    psd_isotropic(ds_field[name])
    for name in ("ssh", "ssh_oi", "ssh_miost", "ssh_nerf")
]

In [None]:
# Frequencies are plotted in cycles/km (coordinates are in metres).
x_scale = 1e3

# OI baseline as the black reference curve, other products overlaid.
fig, ax, secax = plot_psd_isotropic(
    ds_field_psd_oi.freq_r.values * x_scale, ds_field_psd_oi.values, color="black"
)
for psd, colour in (
    (ds_field_psd_duacs, "tab:green"),
    (ds_field_psd_miost, "tab:blue"),
    (ds_field_psd_nerf, "tab:red"),
):
    ax.plot(psd.freq_r.values * x_scale, psd.values, color=colour)

plt.legend(["OI", "DUACS", "MIOST", "NerF"])
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/dc21a_psd_iso.png"))
plt.show()

### Kinetic Energy

In [None]:
# Isotropic PSD of kinetic energy for the DUACS reference and each reconstruction.
(
    ds_field_psd_duacs_ke,
    ds_field_psd_oi_ke,
    ds_field_psd_miost_ke,
    ds_field_psd_nerf_ke,
) = [
    psd_isotropic(ds_field[f"{name}_ke"])
    for name in ("ssh", "ssh_oi", "ssh_miost", "ssh_nerf")
]

In [None]:
# Isotropic PSD of kinetic energy: OI baseline as the black reference curve,
# DUACS / MIOST / NerF overlaid. Frequencies are plotted in cycles/km.
fig, ax, secax = plot_psd_isotropic(
    ds_field_psd_oi_ke.freq_r.values * 1e3,
    ds_field_psd_oi_ke.values,
    color="black",
)
for psd, colour in (
    (ds_field_psd_duacs_ke, "tab:green"),
    (ds_field_psd_miost_ke, "tab:blue"),
    (ds_field_psd_nerf_ke, "tab:red"),
):
    ax.plot(psd.freq_r.values * 1e3, psd.values, color=colour)

freq_km = np.ma.masked_invalid(ds_field_psd_duacs_ke.freq_r.values * 1e3)
plt.xlim((np.ma.min(freq_km), np.ma.max(freq_km)))
# Plain string, not an f-string: in an f-string "{-2}" is evaluated to "-2"
# and the LaTeX grouping braces are lost.
ax.set_ylabel("PSD [m$^2$s$^{-2}$/cycles/m]")
# Legend follows plotting order; the black curve is the OI baseline (ssh_oi_ke).
plt.legend(["OI", "DUACS", "MIOST", "NerF"])
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/dc21a_psd_iso_ke.png"))
plt.show()

### Enstrophy

In [None]:
# Inspect the reference dataset, which now also holds the *_ke and *_ens fields.
ds_field

In [None]:
# Isotropic PSD of the "_ens" fields for the DUACS reference and each reconstruction.
(
    ds_field_psd_duacs_ens,
    ds_field_psd_oi_ens,
    ds_field_psd_miost_ens,
    ds_field_psd_nerf_ens,
) = [
    psd_isotropic(ds_field[f"{name}_ens"])
    for name in ("ssh", "ssh_oi", "ssh_miost", "ssh_nerf")
]

In [None]:
# Isotropic PSD of the "_ens" fields: OI baseline as the black reference
# curve, DUACS / MIOST / NerF overlaid. Frequencies are plotted in cycles/km.
fig, ax, secax = plot_psd_isotropic(
    ds_field_psd_oi_ens.freq_r.values * 1e3,
    ds_field_psd_oi_ens.values,
    color="black",
)
for psd, colour in (
    (ds_field_psd_duacs_ens, "tab:green"),
    (ds_field_psd_miost_ens, "tab:blue"),
    (ds_field_psd_nerf_ens, "tab:red"),
):
    ax.plot(psd.freq_r.values * 1e3, psd.values, color=colour)

freq_km = np.ma.masked_invalid(ds_field_psd_duacs_ens.freq_r.values * 1e3)
plt.xlim((np.ma.min(freq_km), np.ma.max(freq_km)))
# Plain string, not an f-string: in an f-string "{-1}" is evaluated to "-1"
# and the LaTeX grouping braces are lost.
ax.set_ylabel("PSD [s$^{-1}$/cycles/m]")
# Legend follows plotting order; the black curve is the OI baseline, and the
# NerF label matches the other figures of this notebook.
plt.legend(["OI", "DUACS", "MIOST", "NerF"])
plt.tight_layout()
# Same "dc21a_" prefix as the other isotropic PSD figures of this notebook.
fig.savefig(Path(root).joinpath("figures/dc21a_psd_iso_ens.png"))
plt.show()

## Metrics - Isotropic PSD Score

### Sea Surface Height

In [None]:
from inr4ssh._src.metrics.psd import psd_isotropic_score, wavelength_resolved_isotropic
from tqdm.notebook import tqdm

# Reconstructions scored against the DUACS reference field ("ssh").
variables = ["ssh_oi", "ssh_nerf", "ssh_miost"]
colours = ["black", "tab:red", "tab:blue"]
linestyle = ["-", "--", "-."]

ax = None

for ivariable, icolour, ilinestyle in tqdm(
    zip(variables, colours, linestyle), total=len(variables)
):
    psd_iso_score = psd_isotropic_score(ds_field["ssh"], ds_field[ivariable])

    # Shortest resolved wavelength (in metres; divided by 1e3 below for km)
    # at a score level of 0.5.
    space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)
    if ivariable != "ssh":
        iname = ivariable.split("_")[1].upper()
    else:
        iname = "DUACS"
    print(
        f"Shortest Spatial Wavelength Resolved [{iname}] = {space_iso_resolved/1e3:.2f} (km)"
    )

    # Frequencies in cycles/km; NaN/inf masked when computing axis limits.
    freq_km = psd_iso_score.freq_r.values * 1e3
    freq_km_valid = np.ma.masked_invalid(freq_km)

    fig, ax, secax = plot_psd_isotropic(
        freq_km,
        psd_iso_score.values,
        ax=ax,
        color=icolour,
        linestyle="-",
    )

    ax.set(ylabel="PSD Score", yscale="linear")
    plt.ylim((0, 1.0))
    plt.xlim((np.ma.min(freq_km_valid), np.ma.max(freq_km_valid)))

    # Mark the resolved scale (in cycles/km) where the score crosses 0.5.
    resolved_scale = 1 / (space_iso_resolved * 1e-3)
    ax.vlines(
        x=resolved_scale, ymin=0, ymax=0.5, color=icolour, linewidth=2, linestyle="--"
    )
    ax.hlines(
        y=0.5,
        xmin=np.ma.min(freq_km_valid),
        xmax=resolved_scale,
        color=icolour,
        linewidth=2,
        linestyle="--",
    )

    # Raw f-string: "\l" is an invalid escape sequence in a normal string literal.
    label = rf"{iname} - $\lambda$ > {int(space_iso_resolved*1e-3)} km"
    plt.scatter(
        resolved_scale, 0.5, color=icolour, marker=".", linewidth=5, label=label
    )

plt.legend()
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/dc21a_psd_score_iso_ssh.png"))
plt.show()

### Kinetic Energy

In [None]:
# Reconstructions scored against the DUACS kinetic energy ("ssh_ke").
variables = ["ssh_oi_ke", "ssh_miost_ke", "ssh_nerf_ke"]
colours = ["black", "tab:blue", "tab:red"]
linestyle = ["-", "--", "-."]

ax = None

for ivariable, icolour, ilinestyle in tqdm(
    zip(variables, colours, linestyle), total=len(variables)
):
    psd_iso_score = psd_isotropic_score(ds_field["ssh_ke"], ds_field[ivariable])

    # Shortest resolved wavelength (in metres) at a score level of 0.5. It is
    # a length, so it is reported in km regardless of the field's own units.
    space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)
    iname = ivariable.split("_")[1].upper()
    print(
        f"Shortest Spatial Wavelength Resolved [{iname}] = {space_iso_resolved/1e3:.2f} (km)"
    )

    # Frequencies in cycles/km; NaN/inf masked when computing axis limits.
    freq_km = psd_iso_score.freq_r.values * 1e3
    freq_km_valid = np.ma.masked_invalid(freq_km)

    fig, ax, secax = plot_psd_isotropic(
        freq_km,
        psd_iso_score.values,
        ax=ax,
        color=icolour,
        linestyle="-",
    )

    ax.set(ylabel="PSD Score", yscale="linear")
    plt.ylim((0, 1.0))
    plt.xlim((np.ma.min(freq_km_valid), np.ma.max(freq_km_valid)))

    # Mark the resolved scale (in cycles/km) where the score crosses 0.5.
    resolved_scale = 1 / (space_iso_resolved * 1e-3)
    ax.vlines(
        x=resolved_scale, ymin=0, ymax=0.5, color=icolour, linewidth=2, linestyle="--"
    )
    ax.hlines(
        y=0.5,
        xmin=np.ma.min(freq_km_valid),
        xmax=resolved_scale,
        color=icolour,
        linewidth=2,
        linestyle="--",
    )

    # Raw f-string avoids the invalid "\l" escape; the wavelength is in km.
    label = rf"{iname} - $\lambda$ > {int(space_iso_resolved*1e-3)} km"
    plt.scatter(
        resolved_scale, 0.5, color=icolour, marker=".", linewidth=5, label=label
    )

plt.legend()
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/dc21a_psd_score_iso_ke.png"))
plt.show()

### Enstrophy

In [None]:
# Reconstructions scored against the DUACS "_ens" field ("ssh_ens").
variables = ["ssh_oi_ens", "ssh_miost_ens", "ssh_nerf_ens"]
colours = ["black", "tab:blue", "tab:red"]
linestyle = ["-", "--", "-."]

ax = None

for ivariable, icolour, ilinestyle in tqdm(
    zip(variables, colours, linestyle), total=len(variables)
):
    psd_iso_score = psd_isotropic_score(ds_field["ssh_ens"], ds_field[ivariable])

    # Shortest resolved wavelength (in metres) at a score level of 0.5. It is
    # a length, so it is reported in km regardless of the field's own units.
    space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)
    iname = ivariable.split("_")[1].upper()
    print(
        f"Shortest Spatial Wavelength Resolved [{iname}] = {space_iso_resolved/1e3:.2f} (km)"
    )

    # Frequencies in cycles/km; NaN/inf masked when computing axis limits.
    freq_km = psd_iso_score.freq_r.values * 1e3
    freq_km_valid = np.ma.masked_invalid(freq_km)

    fig, ax, secax = plot_psd_isotropic(
        freq_km,
        psd_iso_score.values,
        ax=ax,
        color=icolour,
        linestyle="-",
    )

    ax.set(ylabel="PSD Score", yscale="linear")
    plt.ylim((0, 1.0))
    plt.xlim((np.ma.min(freq_km_valid), np.ma.max(freq_km_valid)))

    # Mark the resolved scale (in cycles/km) where the score crosses 0.5.
    resolved_scale = 1 / (space_iso_resolved * 1e-3)
    ax.vlines(
        x=resolved_scale, ymin=0, ymax=0.5, color=icolour, linewidth=2, linestyle="--"
    )
    ax.hlines(
        y=0.5,
        xmin=np.ma.min(freq_km_valid),
        xmax=resolved_scale,
        color=icolour,
        linewidth=2,
        linestyle="--",
    )

    # Raw f-string avoids the invalid "\l" escape; the wavelength is in km.
    label = rf"{iname} - $\lambda$ > {int(space_iso_resolved*1e-3)} km"
    plt.scatter(
        resolved_scale, 0.5, color=icolour, marker=".", linewidth=5, label=label
    )

plt.legend()
plt.tight_layout()
fig.savefig(Path(root).joinpath("figures/dc21a_psd_score_iso_ens.png"))
plt.show()