# Figure - Isotropic PSD (Alongtrack)

In [None]:
import sys, os
from pyprojroot import here

# Walk up the directory tree to the project root (the folder that contains a
# ``.root`` marker file) — presumably so the local ``inr4ssh`` package imported
# below can be found.
root = here(project_files=[".root"])

# Only append once: re-running this cell must not keep adding duplicate
# entries to the module search path.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
from pathlib import Path
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

import seaborn as sns

# Start from seaborn's default style, then use the "talk" context with font sizes scaled by 0.7.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
import scienceplots

# plt.style.use("science")

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_longitude_domain

# from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.preprocess.regrid import oi_regrid
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)
from loguru import logger

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
from inr4ssh._src.preprocess.spatial import convert_lon_360_180, convert_lon_180_360


def post_process(
    ds,
    variable,
    time_min="2017-01-01",
    time_max="2017-12-31",
    lon_min=295,
    lon_max=305,
    lat_min=33.0,
    lat_max=43.0,
):
    """Clean a gridded SSH dataset and crop it to the evaluation window.

    Steps: fix coordinate labels (``correct_coordinate_labels``), rename
    ``variable`` to ``"ssh"``, select the time/longitude/latitude window and
    transpose to ``(time, latitude, longitude)``.

    Parameters
    ----------
    ds : xr.Dataset
        Gridded dataset containing ``variable``.
    variable : str
        Name of the SSH variable in ``ds``; it is renamed to ``"ssh"``.
    time_min, time_max : str or np.datetime64, optional
        Time bounds, passed through ``np.datetime64``. Defaults cover 2017.
    lon_min, lon_max : float, optional
        Longitude bounds in the 0-360 convention; they are converted with
        ``convert_lon_360_180`` before slicing. Defaults: 295-305.
    lat_min, lat_max : float, optional
        Latitude bounds in degrees. Defaults: 33-43.

    Returns
    -------
    xr.Dataset
        Cropped dataset whose longitudes have been passed back through
        ``convert_lon_180_360``.

    Notes
    -----
    Slicing with ``Dataset.sel`` assumes increasing coordinates; presumably
    true for these grids after the longitude conversion — verify for grids
    that straddle the +/-180 meridian.
    """
    logger.info("Fixing coordinate labels...")
    ds = correct_coordinate_labels(ds)

    logger.info("Fixing labels")
    ds = ds.rename({variable: "ssh"})

    # Slice in the convert_lon_360_180 convention, then convert back below.
    logger.info("Fixing longitude domain")
    ds["longitude"] = convert_lon_360_180(ds.longitude)

    ds = ds.sel(
        time=slice(np.datetime64(time_min), np.datetime64(time_max)),
        longitude=slice(convert_lon_360_180(lon_min), convert_lon_360_180(lon_max)),
        latitude=slice(lat_min, lat_max),
    )

    ds["longitude"] = convert_lon_180_360(ds.longitude)

    return ds.transpose("time", "latitude", "longitude")

In [None]:
logger.info("Dataset I - DUACS")
# DUACS map: after cropping it becomes the evaluation grid (``ds_field``) that
# the other maps are regridded onto further down.
url = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_DUACS.nc"
ds_field = post_process(xr.open_dataset(url), variable="ssh")

In [None]:
# Inspect the cropped DUACS evaluation field.
ds_field

### Example Results

Gridded reference maps from the 2020a SSH mapping data challenge (NATL60 OSSE). They are kept for reference only and are not used by the OSE evaluation below.

```bash
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
# List the dc21b results directory (url_nerf and url_miost below point into it).
!ls /Volumes/EMANS_HDD/data/dc21b/results

In [None]:
# url = "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc"
# # url =
# ds_predict = xr.open_dataset(url)
#
# ds_predict

## Post-processing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
# Paths to the gridded SSH maps that are compared against the along-track data.
# url_nerf = "/Users/eman/code_projects/logs/saved_data/test_dc21b_feb_pretrain.nc"
url_4dvarnet = (
    "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_4dvarNet_2022.nc"
)
url_bfn = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_BFN.nc"
url_nerf = "/Volumes/EMANS_HDD/data/dc21b/results/siren_136.nc"
# DUACS map: the same file the evaluation field was loaded from (it used to
# duplicate the MIOST path by mistake).
url_duacs = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_DUACS.nc"
url_miost = "/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_MIOST.nc"

In [None]:
from inr4ssh._src.preprocess.spatial import convert_lon_360_180, convert_lon_180_360

# Load each comparison map, crop it with post_process and regrid it onto the
# DUACS evaluation grid so every field in ``ds_field`` shares coordinates.
# ``ssh_oi`` is the BASELINE (OI) map, the same file the PSD loop further down
# uses for "oi". Reading ``url_duacs`` here pointed at the MIOST file, which
# made ``ssh_oi`` a silent copy of ``ssh_miost``.
comparison_fields = [
    (
        "Dataset II - BASELINE",
        "ssh_oi",
        "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_BASELINE.nc",
    ),
    ("Dataset III - MIOST", "ssh_miost", url_miost),
    ("Dataset IV - NerF", "ssh_nerf", url_nerf),
]

for log_message, output_name, path in comparison_fields:
    logger.info(log_message)
    ds_predict = xr.open_dataset(path)
    ds_predict = post_process(ds_predict, "ssh")
    ds_field[output_name] = oi_regrid(ds_predict["ssh"], ds_field["ssh"])

In [None]:
# Inspect ds_field together with the regridded comparison maps added in the previous cell.
ds_field

In [None]:
# Independent along-track altimetry used as the reference for the PSD metrics
# (presumably CryoSat-2, going by the "c2" tag in the filename).
filename = "/Volumes/EMANS_HDD/data/dc21b/test/dt_gulfstream_c2_phy_l3_20161201-20180131_285-315_23-53.nc"
ds_alongtrack = correct_coordinate_labels(xr.open_dataset(filename, engine="netcdf4"))

# Longitudes go through convert_lon_360_180 here; a later cell converts them back.
ds_alongtrack["longitude"] = convert_lon_360_180(ds_alongtrack.longitude)

# SSH = unfiltered SLA + MDT - lwe (presumably the long-wavelength error correction).
ds_alongtrack["ssh"] = (
    ds_alongtrack["sla_unfiltered"] + ds_alongtrack["mdt"] - ds_alongtrack["lwe"]
)
ds_alongtrack

In [None]:
# from tqdm.notebook import tqdm
# variables = [
#     "ssh",
#     "ssh_oi",
#     "ssh_miost",
#     "ssh_nerf"
# ]
#
# logger.info("Calculating Kinetic Energy...")
# for ivar in tqdm(variables):
#     ds_field[f"{ivar}_ke"] = calculate_gradient(ds_field[ivar], "longitude", "latitude")
#
# logger.info("Calculating Enstropy...")
# for ivar in tqdm(variables):
#     ds_field[f"{ivar}_ens"] = 0.5 * calculate_laplacian(ds_field[ivar], "longitude", "latitude")**2

## Metrics - Isotropic PSD

In [None]:
from inr4ssh._src.metrics.psd import compute_psd_scores, select_track_segments
from inr4ssh._src.interp import interp_on_alongtrack
from inr4ssh._src.preprocess.spatial import convert_lon_360_180, convert_lon_180_360

In [None]:
# ds_field["longitude"] = convert_lon_180_360(ds_field.longitude)
# Convert the along-track longitudes back with convert_lon_180_360 so they match the
# maps returned by post_process (which ends with the same conversion) and the
# lon_min=295 / lon_max=305 bounds passed to interp_on_alongtrack.
# NOTE(review): mutates ds_alongtrack in place; re-running this cell converts again —
# presumably a no-op for already-converted values, verify.
ds_alongtrack["longitude"] = convert_lon_180_360(ds_alongtrack.longitude)

In [None]:
# Inspect the along-track dataset after the longitude back-conversion.
ds_alongtrack

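Alternative (wider) subset window, kept for reference — it is not applied anywhere in this notebook. The evaluation below uses 2017-01-01 to 2017-12-31, longitude 295–305 (0–360 convention) and latitude 33–43.

```python
time=slice(
    np.datetime64("2017-02-01"), np.datetime64("2017-03-31")
),
longitude=slice(-75.0, -45.0),
latitude=slice(33.0, 53.0)
```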

In [None]:
from tqdm.notebook import tqdm

# Along-track PSD evaluation. For each gridded map: interpolate it onto the
# along-track observations, cut the tracks into segments, then compute the PSD
# scores (reference = along-track SSH, study = interpolated map).
urls = [
    "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_DUACS.nc",
    # "/Volumes/EMANS_HDD/data/dc21b/results/siren_136.nc",
    "/Users/eman/code_projects/logs/saved_data/test_dc21b_feb_pretrain.nc",
    "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_BASELINE.nc",
    "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_MIOST.nc",
]
names = ["duacs", "nerf", "oi", "miost"]

# Along-track sample spacing (km) passed to the spectral estimate as delta_x.
# Presumably satellite ground speed (km/s) x along-track sampling interval (s),
# following the 2021a OSE data-challenge evaluation — TODO confirm.
ALONGTRACK_VELOCITY_KM_S = 6.77
ALONGTRACK_DELTA_T_S = 0.9434
ALONGTRACK_DELTA_X_KM = ALONGTRACK_VELOCITY_KM_S * ALONGTRACK_DELTA_T_S

psd_metrics = {}

for iurl, iname in tqdm(list(zip(urls, names))):
    # open_dataset loads lazily, so keep the file open until every step that may
    # read from it has finished, then close it instead of leaking the handle.
    with xr.open_dataset(iurl, engine="netcdf4") as ds_raw:
        ds_predict = post_process(ds_raw, "ssh")

        # Same window as post_process: lon 295-305 (0-360 convention), lat 33-43, 2017.
        alongtracks = interp_on_alongtrack(
            gridded_dataset=ds_predict,
            ds_alongtrack=ds_alongtrack,
            lon_min=295,
            lon_max=305,
            lat_min=33.0,
            lat_max=43.0,
            time_min="2017-01-01",
            time_max="2017-12-31",
            variable="ssh",
        )

        tracks = select_track_segments(
            time_alongtrack=alongtracks.time,
            lat_alongtrack=alongtracks.lat,
            lon_alongtrack=convert_lon_360_180(alongtracks.lon),
            ssh_alongtrack=alongtracks.ssh_alongtrack,
            ssh_map_interp=alongtracks.ssh_map,
        )

        psd_metrics[iname] = compute_psd_scores(
            ssh_true=tracks.ssh_alongtrack,
            ssh_pred=tracks.ssh_map,
            delta_x=ALONGTRACK_DELTA_X_KM,
            npt=tracks.npt,
            scaling="density",
            noverlap=0,
        )

In [None]:
import seaborn as sns

sns.set_context(context="talk", font_scale=0.7)

# (psd_metrics key, legend label, colour) for each study curve. The legend is
# built from this same list, so labels cannot drift out of order with curves.
# The BASELINE/OI curve (psd_metrics["oi"]) is computed but not plotted.
study_curves = [
    ("duacs", "DUACS", "tab:green"),
    ("miost", "MIOST", "tab:blue"),
    ("nerf", "NerF", "tab:red"),
]

# Reference PSD of the along-track data, taken from the DUACS run (presumably the
# same for every method, since all are scored against the same tracks).
fig, ax, secax = plot_psd_isotropic(
    psd_metrics["duacs"].wavenumber,
    psd_metrics["duacs"].psd_ref,
    color="black",
    linestyle="-",
)

for key, _label, colour in study_curves:
    ax.plot(
        psd_metrics[key].wavenumber,
        psd_metrics[key].psd_study,
        color=colour,
        linestyle="-",
    )

ax.set_xlim((10e-4, 10e-2))
plt.legend(["Reference"] + [label for _key, label, _colour in study_curves])
plt.tight_layout()

# Make sure the output directory exists before saving.
figure_path = Path(root).joinpath("figures/dc21a_psd_iso_alongtrack.png")
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path)
plt.show()

In [None]:
from inr4ssh._src.viz.psd.psd import plot_psd_score

In [None]:
# plot_psd_score(
#     psd_metrics["miost"].psd_diff,
#     psd_metrics["miost"].psd_ref,
#     psd_metrics["miost"].wavenumber,
#     psd_metrics["miost"].resolved_scale
# )

In [None]:
# plot_psd_score(
#     psd_metrics["nerf"].psd_diff,
#     psd_metrics["nerf"].psd_ref,
#     psd_metrics["nerf"].wavenumber,
#     psd_metrics["nerf"].resolved_scale
# )

In [None]:
# plot_psd_score(
#     psd_metrics["oi"].psd_diff,
#     psd_metrics["oi"].psd_ref,
#     psd_metrics["oi"].wavenumber,
#     psd_metrics["oi"].resolved_scale
# )

In [None]:
variables = ["duacs", "miost", "nerf"]
colours = ["tab:green", "tab:blue", "tab:red"]
# One entry per variable (zip would silently drop any extras).
linestyle = ["-", "-", "-"]

# Wavenumber extent shared by the x-limits and the horizontal guide lines.
WAVENUMBER_MIN, WAVENUMBER_MAX = 10e-4, 10e-2

ax = None

for ivariable, icolour, ilinestyle in tqdm(list(zip(variables, colours, linestyle))):
    metrics = psd_metrics[ivariable]
    x = metrics.wavenumber
    # PSD score = 1 - psd_diff / psd_ref (equals 1 when psd_diff is zero).
    y = 1.0 - (metrics.psd_diff / metrics.psd_ref)
    fig, ax, secax = plot_psd_isotropic(
        x,
        y,
        ax=ax,
        color=icolour,
        linestyle=ilinestyle,
    )

    ax.set(ylabel="PSD Score", yscale="linear")
    plt.ylim((0, 1.0))
    plt.xlim((WAVENUMBER_MIN, WAVENUMBER_MAX))

    print(
        f"Shortest Spatial Wavelength Resolved [{ivariable}] = {metrics.resolved_scale:.2f} (km)"
    )

    # resolved_scale is a wavelength (km); the x-axis is wavenumber, so invert it
    # to place the resolved point where the score crosses 0.5.
    resolved_wavenumber = 1 / metrics.resolved_scale

    ax.vlines(
        x=resolved_wavenumber,
        ymin=0,
        ymax=0.5,
        color=icolour,
        linewidth=2,
        linestyle="--",
    )
    ax.hlines(
        y=0.5,
        xmin=WAVENUMBER_MIN,
        xmax=resolved_wavenumber,
        color=icolour,
        linewidth=2,
        linestyle="--",
    )

    # Raw f-string: "\l" in a normal string literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning from Python 3.12).
    label = rf"{ivariable.upper()} - $\lambda$ > {int(metrics.resolved_scale)} km"
    plt.scatter(
        resolved_wavenumber,
        0.5,
        color=icolour,
        marker=".",
        linewidth=5,
        label=label,
        zorder=5,
    )

plt.legend()
plt.tight_layout()

# Make sure the output directory exists before saving.
figure_path = Path(root).joinpath("figures/dc21a_psd_iso_score_alongtrack.png")
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path)
plt.show()