# Figure - SpatioTemporal PSD

In [None]:
import sys, os
from pyprojroot import here

# Locate the project root: pyprojroot.here searches for a directory containing
# the ".root" marker file.
root = here(project_files=[".root"])

# Make the project root importable (for the `inr4ssh` package). Only add it
# once so re-executing this cell does not keep growing sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# Start from seaborn's default styling, then use the "talk" context with
# fonts scaled down to 70%.
sns.reset_defaults()
sns.set_context("talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels

# from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)

from loguru import logger

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
# NATL60 reference SSH files (glob over the Gulf Stream extraction).
url = "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/NATL60-CJM165_GULFSTREAM_*.nc"


def preprocess(ds):
    """Average every variable of one file onto daily time steps.

    Passed to ``xr.open_mfdataset`` so it runs on each file before the files
    are combined. The temporal subsetting to the evaluation period is done
    later in ``post_process``.
    """
    daily = ds.resample(time="1D")
    return daily.mean()


ds_field = xr.open_mfdataset(url, preprocess=preprocess)

ds_field

In [None]:
def post_process(
    ds,
    variable,
    time_min=np.datetime64("2012-10-22"),
    time_max=np.datetime64("2012-12-02"),
):
    """Bring a dataset onto the common layout used for evaluation.

    Steps, in order: restrict to the evaluation period, apply
    ``correct_coordinate_labels``, rename ``variable`` to ``"ssh"``, apply
    ``correct_longitude_domain`` and order the dimensions as
    ``(time, latitude, longitude)``.

    Parameters
    ----------
    ds : xr.Dataset
        Raw dataset (reference field or a mapping product).
    variable : str
        Name of the SSH variable in ``ds``; it is renamed to ``"ssh"``.
    time_min, time_max : np.datetime64, optional
        Bounds of the evaluation window (label-based ``sel`` slice, so both
        ends are inclusive). The defaults are the window this notebook
        previously hard-coded: 2012-10-22 to 2012-12-02.

    Returns
    -------
    xr.Dataset
        The cleaned dataset. No regridding happens here; predictions are
        regridded separately with ``oi_regrid``.
    """
    # subset temporal space
    ds = ds.sel(time=slice(time_min, time_max))

    # correct coordinate labels (the transpose below relies on the
    # "latitude"/"longitude" names)
    logger.info("Fixing coordinate labels...")
    ds = correct_coordinate_labels(ds)

    # give every dataset the same SSH variable name
    logger.info("Fixing labels")
    ds = ds.rename({variable: "ssh"})

    # correct longitude domain
    logger.info("Fixing longitude domain")
    ds = correct_longitude_domain(ds)

    return ds.transpose("time", "latitude", "longitude")

In [None]:
# Clean the reference field; its SSH variable is called "sossheig" in the
# raw NATL60 files.
ds_field = post_process(ds_field, variable="sossheig")
ds_field

### Example Results

```bash
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
# List the nadir-4 mapping results available on disk (IPython shell escape).
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/

## PostProcessing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid

# Each entry: (log label, file path, SSH variable in the file, name of the
# regridded variable stored in ds_field).
predictions = [
    (
        "Dataset I - MIOST",
        "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc",
        "gssh",
        "ssh_miost",
    ),
    (
        "Dataset II - DUACS",
        "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc",
        "gssh",
        "ssh_duacs",
    ),
    (
        "Dataset III - SIREN",
        "/Users/eman/code_projects/logs/saved_data/test_res_nadir4_lb.nc",
        "ssh_model_predict",
        "ssh_nerf",
    ),
]

for label, url, source_var, target_var in predictions:
    logger.info(label)
    ds_predict = xr.open_dataset(url)

    # Same cleaning as the reference field (period, labels, longitudes).
    ds_predict = post_process(ds_predict, source_var)

    # Regrid the prediction onto the reference field's grid so it can be
    # stored alongside the reference SSH in ds_field.
    ds_field[target_var] = oi_regrid(ds_predict["ssh"], ds_field["ssh"])

In [None]:
from tqdm.notebook import tqdm
from inr4ssh._src.operators.ssh import (
    ssh2uv_ds_2dt,
    ssh2uv_da_2dt,
    kinetic_energy,
    enstropy,
    ssh2rv_ds_2dt,
    ssh2rv_da_2dt,
)

# Reference SSH and the three regridded reconstructions.
variables = ["ssh", "ssh_duacs", "ssh_miost", "ssh_nerf"]

logger.info("Calculating Kinetic Energy...")
for var_name in tqdm(variables):
    # velocity components from SSH, then their kinetic energy -> "<var>_ke"
    u, v = ssh2uv_da_2dt(ds_field[var_name])
    ds_field[var_name + "_ke"] = (
        ("time", "latitude", "longitude"),
        kinetic_energy(u, v),
    )

logger.info("Calculating Enstropy...")
for var_name in tqdm(variables):
    # NOTE(review): "<var>_ens" stores the output of ssh2rv_da_2dt, presumably
    # relative vorticity (the PSD plots label it in s^-1); the imported
    # `enstropy` helper is not applied here.
    vort = ssh2rv_da_2dt(ds_field[var_name])
    ds_field[var_name + "_ens"] = (("time", "latitude", "longitude"), vort)

In [None]:
# Inspect ds_field now that the "_ke" and "_ens" fields have been added.
ds_field

## Metrics II - Space-Time PSD

#### Absolute Values

In [None]:
from inr4ssh._src.metrics.psd import (
    psd_isotropic_score,
    psd_spacetime_score,
    wavelength_resolved_spacetime,
    wavelength_resolved_isotropic,
)

In [None]:
# Replace absolute timestamps by the elapsed time (in days, as floats) since
# the first snapshot, so temporal frequencies of the PSD are per day.
time_norm = np.timedelta64(1, "D")
ds_field["time"] = (ds_field["time"] - ds_field["time"].isel(time=0)) / time_norm

#### Degrees

In [None]:
# Space-time PSD with longitude still in degrees and time in days.
# Load everything into memory, computing one time step per dask chunk with
# the full spatial domain in each chunk.
ds_field_ = ds_field.chunk(
    {
        "time": 1,
        "longitude": ds_field["longitude"].size,
        "latitude": ds_field["latitude"].size,
    }
).compute()

ds_field_psd = psd_spacetime(ds_field_["ssh"])
# ds_field has no "ssh_predict" variable; the SIREN reconstruction is stored
# as "ssh_nerf".
ds_predict_psd = psd_spacetime(ds_field_["ssh_nerf"])

In [None]:
# Space-time PSD of the reference SSH, longitude wavelengths in degrees.
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_field_psd["freq_longitude"],
    ds_field_psd["freq_time"],
    ds_field_psd,
)
ax.set_xlabel("Wavelength [degrees]")

plt.tight_layout()
plt.show()

In [None]:
# Space-time PSD of the predicted SSH, longitude wavelengths in degrees.
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_predict_psd["freq_longitude"],
    ds_predict_psd["freq_time"],
    ds_predict_psd,
)
ax.set_xlabel("Wavelength [degrees]")

plt.tight_layout()
plt.show()

### PSD Score

In [None]:
# PSD score of the SIREN reconstruction ("ssh_nerf"; ds_field has no
# "ssh_predict" variable) against the NATL60 reference SSH.
psd_score = psd_spacetime_score(ds_field_["ssh_nerf"], ds_field_["ssh"])

In [None]:
# PSD score as a function of wavelength (longitude in degrees).
fig, ax, cbar = plot_psd_spacetime_score_wavelength(
    psd_score["freq_longitude"],
    psd_score["freq_time"],
    psd_score,
)
ax.set_xlabel("Wavelength [degrees]")

plt.tight_layout()
plt.show()

In [None]:
# PSD score as a function of wavenumber (cycles per degree of longitude).
fig, ax, cbar = plot_psd_spacetime_score_wavenumber(
    psd_score["freq_longitude"],
    psd_score["freq_time"],
    psd_score,
)
ax.set_xlabel("Wavenumber [cycles/degrees]")

plt.tight_layout()
plt.show()

In [None]:
# Shortest spatial (degrees of longitude) and temporal (days) wavelengths
# resolved, as derived from the PSD score by wavelength_resolved_spacetime.
spatial_resolved, time_resolved = wavelength_resolved_spacetime(psd_score)

In [None]:
# Report the resolved scales, one per line.
print(
    f"Shortest Spatial Wavelength Resolved = {spatial_resolved:.2f} (degree lon)",
    f"Shortest Temporal Wavelength Resolved = {time_resolved:.2f} (days)",
    sep="\n",
)

#### Kilometers

In [None]:
# Convert coordinates from degrees to metres with 1 degree ~ 111 km.
# NOTE(review): for longitude this ignores the cos(latitude) factor, so zonal
# distances are overestimated away from the equator. The cell is not
# idempotent: re-running it rescales the coordinates a second time.
ds_field["longitude"] = ds_field["longitude"] * 111e3
ds_field["latitude"] = ds_field["latitude"] * 111e3

# Load everything into memory, computing one time step per dask chunk with
# the full spatial domain in each chunk.
ds_field = ds_field.chunk(
    dict(
        time=1,
        longitude=ds_field["longitude"].size,
        latitude=ds_field["latitude"].size,
    )
).compute()

In [None]:
# Space-time PSD of the reference SSH and of each reconstruction, evaluated
# in the order listed.
(
    ds_field_psd_ssh,
    ds_predict_psd_ssh_duacs,
    ds_predict_psd_ssh_miost,
    ds_predict_psd_ssh_nerf,
) = (
    psd_spacetime(ds_field[name])
    for name in ("ssh", "ssh_duacs", "ssh_miost", "ssh_nerf")
)

In [None]:
from pathlib import Path

# Output directory for saved figures; create it so savefig does not raise
# FileNotFoundError on a fresh checkout.
fig_dir = Path(root) / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)

data = [
    ds_field_psd_ssh,
    ds_predict_psd_ssh_duacs,
    ds_predict_psd_ssh_miost,
    ds_predict_psd_ssh_nerf,
]
names = ["natl60", "duacs", "miost", "nef"]

for idata, iname in zip(data, names):
    # Coordinates were converted to metres, so frequencies are in cycles/m;
    # scale by 1e3 to put the axis in cycles/km.
    fig, ax, cbar = plot_psd_spacetime_wavelength(
        idata.freq_longitude * 1e3,
        idata.freq_time,
        idata,
    )
    ax.set_xlabel("Wavelength [km]")

    plt.tight_layout()
    fig.savefig(fig_dir.joinpath(f"osse_2020a_psd_{iname.lower()}.png"))
    plt.show()

In [None]:
from pathlib import Path

# Space-time PSD of the kinetic-energy fields (reference + reconstructions).
ds_field_psd_ssh = psd_spacetime(ds_field["ssh_ke"])
ds_predict_psd_ssh_duacs = psd_spacetime(ds_field["ssh_duacs_ke"])
ds_predict_psd_ssh_miost = psd_spacetime(ds_field["ssh_miost_ke"])
ds_predict_psd_ssh_nerf = psd_spacetime(ds_field["ssh_nerf_ke"])

# Output directory for saved figures; create it so savefig does not raise
# FileNotFoundError on a fresh checkout.
fig_dir = Path(root) / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)

data = [
    ds_field_psd_ssh,
    ds_predict_psd_ssh_duacs,
    ds_predict_psd_ssh_miost,
    ds_predict_psd_ssh_nerf,
]
names = ["natl60", "duacs", "miost", "nef"]

for idata, iname in zip(data, names):
    # frequencies in cycles/m -> cycles/km on the axis
    fig, ax, cbar = plot_psd_spacetime_wavelength(
        idata.freq_longitude * 1e3,
        idata.freq_time,
        idata,
    )
    # Label previously lacked its closing bracket and misspelled "cycles".
    cbar.ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m]")

    plt.tight_layout()
    fig.savefig(fig_dir.joinpath(f"osse_2020a_psd_{iname.lower()}_ke.png"))
    plt.show()

In [None]:
from pathlib import Path

# Space-time PSD of the "_ens" fields (output of ssh2rv_da_2dt) for the
# reference and each reconstruction.
ds_field_psd_ssh = psd_spacetime(ds_field["ssh_ens"])
ds_predict_psd_ssh_duacs = psd_spacetime(ds_field["ssh_duacs_ens"])
ds_predict_psd_ssh_miost = psd_spacetime(ds_field["ssh_miost_ens"])
ds_predict_psd_ssh_nerf = psd_spacetime(ds_field["ssh_nerf_ens"])

# Output directory for saved figures; create it so savefig does not raise
# FileNotFoundError on a fresh checkout.
fig_dir = Path(root) / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)

data = [
    ds_field_psd_ssh,
    ds_predict_psd_ssh_duacs,
    ds_predict_psd_ssh_miost,
    ds_predict_psd_ssh_nerf,
]
names = ["natl60", "duacs", "miost", "nef"]

for idata, iname in zip(data, names):
    # frequencies in cycles/m -> cycles/km on the axis
    fig, ax, cbar = plot_psd_spacetime_wavelength(
        idata.freq_longitude * 1e3,
        idata.freq_time,
        idata,
    )
    # Label previously lacked its closing bracket and misspelled "cycles".
    cbar.ax.set_ylabel(r"PSD [s$^{-1}$/cycles/m]")

    plt.tight_layout()
    fig.savefig(fig_dir.joinpath(f"osse_2020a_psd_{iname.lower()}_ens.png"))
    plt.show()

### PSD Score

In [None]:
# # grab ssh
# ds_field_psd = ds_field.ssh
# ds_predict_psd = ds_field.ssh_predict
#
# # correct units, degrees -> meters
# ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
# ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3
# ds_predict_psd["longitude"] = ds_predict_psd.longitude * 111e3
# ds_predict_psd["latitude"] = ds_predict_psd.latitude * 111e3
#
# # Time-Longitude (Lat avg) PSD Score
# ds_field_psd = ds_field_psd.chunk(
#     {
#         "time": 1,
#         "longitude": ds_field_psd["longitude"].size,
#         "latitude": ds_field_psd["latitude"].size,
#     }
# ).compute()
# ds_predict_psd = ds_predict_psd.chunk(
#     {
#         "time": 1,
#         "longitude": ds_predict_psd["longitude"].size,
#         "latitude": ds_predict_psd["latitude"].size,
#     }
# ).compute()
#
#
# psd_score = psd_spacetime_score(ds_predict_psd, ds_field_psd)

In [None]:
# psd_score_duacs = psd_spacetime_score(ds_field["ssh_duacs"], ds_field["ssh"])
# psd_score_miost = psd_spacetime_score(ds_field["ssh_miost"], ds_field["ssh"])
# psd_score_nerf = psd_spacetime_score(ds_field["ssh_nerf"], ds_field["ssh"])

In [None]:
# fig, ax, cbar = plot_psd_spacetime_score_wavenumber(
#     psd_score.freq_longitude * 1e3,
#     psd_score.freq_time,
#     psd_score,
# )
#
# plt.tight_layout()
# plt.show()

### Sea Surface Height

In [None]:
# PSD score of each reconstruction's SSH against the reference SSH.
psd_score_duacs = psd_spacetime_score(ds_field["ssh_duacs"], ds_field["ssh"])
psd_score_miost = psd_spacetime_score(ds_field["ssh_miost"], ds_field["ssh"])
psd_score_nerf = psd_spacetime_score(ds_field["ssh_nerf"], ds_field["ssh"])

# Output directory for saved figures; create it so savefig does not raise
# FileNotFoundError on a fresh checkout.
fig_dir = Path(root) / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)

data = [
    psd_score_duacs,
    psd_score_miost,
    psd_score_nerf,
]
names = ["duacs", "miost", "nef"]

for idata, iname in zip(data, names):
    # frequencies in cycles/m -> cycles/km on the axis
    fig, ax, cbar = plot_psd_spacetime_score_wavelength(
        idata.freq_longitude * 1e3,
        idata.freq_time,
        idata,
    )

    plt.tight_layout()
    fig.savefig(fig_dir.joinpath(f"osse_2020a_psd_score_{iname.lower()}.png"))
    plt.show()

### Kinetic Energy

In [None]:
# PSD score of each reconstruction's kinetic energy against the reference.
psd_score_duacs = psd_spacetime_score(ds_field["ssh_duacs_ke"], ds_field["ssh_ke"])
psd_score_miost = psd_spacetime_score(ds_field["ssh_miost_ke"], ds_field["ssh_ke"])
psd_score_nerf = psd_spacetime_score(ds_field["ssh_nerf_ke"], ds_field["ssh_ke"])

# Output directory for saved figures; create it so savefig does not raise
# FileNotFoundError on a fresh checkout.
fig_dir = Path(root) / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)

data = [
    psd_score_duacs,
    psd_score_miost,
    psd_score_nerf,
]
names = ["duacs", "miost", "nef"]

for idata, iname in zip(data, names):
    # frequencies in cycles/m -> cycles/km on the axis
    fig, ax, cbar = plot_psd_spacetime_score_wavelength(
        idata.freq_longitude * 1e3,
        idata.freq_time,
        idata,
    )

    plt.tight_layout()
    fig.savefig(fig_dir.joinpath(f"osse_2020a_psd_score_{iname.lower()}_ke.png"))
    plt.show()

### Enstrophy

In [None]:
# PSD score of each reconstruction's "_ens" field against the reference.
psd_score_duacs = psd_spacetime_score(ds_field["ssh_duacs_ens"], ds_field["ssh_ens"])
psd_score_miost = psd_spacetime_score(ds_field["ssh_miost_ens"], ds_field["ssh_ens"])
psd_score_nerf = psd_spacetime_score(ds_field["ssh_nerf_ens"], ds_field["ssh_ens"])

# Output directory for saved figures; create it so savefig does not raise
# FileNotFoundError on a fresh checkout.
fig_dir = Path(root) / "figures"
fig_dir.mkdir(parents=True, exist_ok=True)

data = [
    psd_score_duacs,
    psd_score_miost,
    psd_score_nerf,
]
names = ["duacs", "miost", "nef"]

for idata, iname in zip(data, names):
    # frequencies in cycles/m -> cycles/km on the axis
    fig, ax, cbar = plot_psd_spacetime_score_wavelength(
        idata.freq_longitude * 1e3,
        idata.freq_time,
        idata,
    )

    plt.tight_layout()
    fig.savefig(fig_dir.joinpath(f"osse_2020a_psd_score_{iname.lower()}_ens.png"))
    plt.show()

In [None]:
# fig, ax, cbar = plot_psd_spacetime_score_wavelength(
#     psd_score.freq_longitude * 1e3,
#     psd_score.freq_time,
#     psd_score,
# )
#
# plt.tight_layout()
# plt.show()

In [None]:
# spatial_resolved, time_resolved = wavelength_resolved_spacetime(psd_score)

In [None]:
# print(f"Shortest Spatial Wavelength Resolved = {spatial_resolved*1e-3:.2f} (km lon)")
# print(f"Shortest Temporal Wavelength Resolved = {time_resolved:.2f} (days)")