# Analysis I - Pixel-Wise Stats

In [None]:
import sys, os
from pyprojroot import here

# walk up the directory tree to find the project root (marked by a ".root" file)
root = here(project_files=[".root"])


# append to path
sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels
from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
url = "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/NATL60-CJM165_GULFSTREAM_*.nc"


def preprocess(ds):
    """Keep only the 2012-10-22 .. 2012-12-02 window of ``ds`` along ``time``.

    Takes a single dataset and returns the subset, i.e. the shape of callable
    presumably meant for ``xr.open_mfdataset(..., preprocess=...)``.
    """
    window_start = np.datetime64("2012-10-22")
    window_stop = np.datetime64("2012-12-02")
    return ds.sel(time=slice(window_start, window_stop))


# Open every file matching `url` as one dataset. `preprocess=None` means the
# `preprocess` helper defined just above is NOT applied, so no time subsetting
# happens at load time.
ds_field = xr.open_mfdataset(url, preprocess=None)

# The disabled steps below (time subset, renaming, daily resampling, coordinate
# labels) are performed later by `post_process`.
# ds_field = ds_field.sel(
#     time=slice(np.datetime64("2012-10-22"), np.datetime64("2012-12-02"))
# )

# ds_field = (
#     ds_field.rename({"lon": "longitude"})
#     .rename({"lat": "latitude"})
#     .rename({"sossheig": "ssh"})
# )

# ds_field = ds_field.resample(time="1D").mean()

# # ds_field = correct_coordinate_labels(ds_field)

# display the dataset summary
ds_field

### Example Results

```bash
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/

In [None]:
# Load one mapping result (the DUACS en_j1_tpn_g2 file) as the prediction.
# NOTE: this rebinds `url`, which previously held the reference-field glob.
url = "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc"
ds_predict = xr.open_dataset(url)

# display the dataset summary
ds_predict

## PostProcessing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid


def post_process(
    ds,
    ds_ref,
    variable_ref,
    variable_pred,
    time_min="2012-10-22",
    time_max="2012-12-02",
    resample_freq="1D",
):
    """Align a predicted SSH map with the reference field for evaluation.

    Both datasets are cut to the same time window, passed through
    ``correct_coordinate_labels`` and ``correct_longitude_domain``, and have
    their SSH variable renamed to ``"ssh"``. The prediction is then regridded
    with ``oi_regrid`` against the reference and stored alongside it as
    ``"ssh_predict"``.

    Args:
        ds: dataset holding the prediction; needs a ``time`` coordinate and
            the variable ``variable_pred``.
        ds_ref: reference dataset; needs a ``time`` coordinate and the
            variable ``variable_ref``.
        variable_ref: name of the SSH variable in ``ds_ref``.
        variable_pred: name of the SSH variable in ``ds``.
        time_min: lower bound of the time slice (anything ``np.datetime64``
            accepts). Defaults to the previously hard-coded value.
        time_max: upper bound of the time slice. Defaults to the previously
            hard-coded value.
        resample_freq: frequency string handed to ``ds_ref.resample``; the
            reference is averaged within each bin. Pass ``None`` to keep the
            reference at its native time resolution.

    Returns:
        The processed reference dataset, containing ``"ssh"`` and
        ``"ssh_predict"``.
    """
    # one shared window so prediction and reference cannot drift apart
    period = slice(np.datetime64(time_min), np.datetime64(time_max))
    ds = ds.sel(time=period)
    ds_ref = ds_ref.sel(time=period)

    # resample the reference grid (optional)
    if resample_freq is not None:
        ds_ref = ds_ref.resample(time=resample_freq).mean()

    # correct coordinate labels
    ds = correct_coordinate_labels(ds)
    ds_ref = correct_coordinate_labels(ds_ref)

    # give both SSH variables the same name
    ds_ref = ds_ref.rename({variable_ref: "ssh"})
    ds = ds.rename({variable_pred: "ssh"})

    # correct longitude domain
    ds = correct_longitude_domain(ds)
    ds_ref = correct_longitude_domain(ds_ref)

    # regrid the prediction and attach it to the reference dataset
    ds_ref["ssh_predict"] = oi_regrid(ds["ssh"], ds_ref["ssh"])

    return ds_ref

In [None]:
# Rebind `ds_field` to the processed reference: "sossheig" is the reference
# variable name, "gssh" the predicted one; the regridded prediction is added
# as `ssh_predict`.
# NOTE(review): likely not re-runnable as-is — after the first run `ds_field`
# no longer carries a "sossheig" variable; reload the reference field first.
ds_field = post_process(ds_predict, ds_field, "sossheig", "gssh")

In [None]:
ds_field

## Metrics - Isotropic PSD

In [None]:
def lonlat2dxdy(lon, lat):
    """Return grid spacings ``(dx, dy)`` for 2-D longitude/latitude arrays.

    Differences of ``lon`` and ``lat`` along each array axis are scaled by
    111000 per degree (with an extra ``cos(lat)`` factor on the longitude
    term) and combined into a Euclidean length. ``dx`` uses the differences
    along axis 1, ``dy`` those along axis 0. The outermost rows and columns of
    both outputs are overwritten with their nearest interior neighbour.

    Args:
        lon: 2-D array of longitudes in degrees.
        lat: 2-D array of latitudes in degrees, same shape as ``lon``.

    Returns:
        Tuple ``(dx, dy)`` of arrays shaped like the inputs.
    """
    per_degree = 111000
    grad_lon = np.gradient(lon)
    grad_lat = np.gradient(lat)
    cos_lat = np.cos(np.deg2rad(lat))

    def _spacing(axis):
        # length of one grid step along `axis`
        zonal = grad_lon[axis] * per_degree * cos_lat
        meridional = grad_lat[axis] * per_degree
        return np.sqrt(zonal ** 2 + meridional ** 2)

    dx = _spacing(1)
    dy = _spacing(0)

    # replace the border values by the adjacent interior ones (rows first,
    # then columns, for each output)
    for field in (dx, dy):
        field[0, :] = field[1, :]
        field[-1, :] = field[-2, :]
        field[:, 0] = field[:, 1]
        field[:, -1] = field[:, -2]

    return dx, dy

In [None]:
# Shape check of the coordinate vectors, the mesh grids and the grid spacings.
# NOTE(review): in the visible notebook `lon`, `lat`, `lon_grid`, `lat_grid`,
# `dx` and `dy` are only defined in the following cell, so this fails on a
# fresh top-to-bottom run.
lon.shape, lat.shape, lon_grid.shape, lat_grid.shape, dx.shape, dy.shape

In [None]:
from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian

# load the baseline mapping result and normalise its coordinate labels
url = "/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_BASELINE.nc"
ds_baseline = xr.open_dataset(url)
ds_baseline = correct_coordinate_labels(ds_baseline)

# build 2-D lon/lat grids from the 1-D coordinate vectors
lon, lat = ds_baseline.longitude, ds_baseline.latitude
lon_grid, lat_grid = np.meshgrid(lon, lat)

# per-cell grid spacings
dx, dy = lonlat2dxdy(lon_grid, lat_grid)

# NOTE(review): this writes grid *spacings* (first row of `dy`, first column of
# `dx`) into the coordinates rather than positions, and pairs `dy` with
# longitude / `dx` with latitude — confirm this is intended. The next cell
# reloads `ds_baseline` and overwrites everything computed here.
ds_baseline["longitude"] = dy[0, :]
ds_baseline["latitude"] = dx[:, 0]

ds_baseline = ds_baseline.set_coords(["longitude", "latitude"])


# ds_baseline["ssh_grad"] = calculate_gradient(ds_baseline["ssh"], "longitude", "latitude")
# ds_baseline["ssh_lap"] = calculate_laplacian(ds_baseline["ssh"], "longitude", "latitude")

# isotropic PSD of the baseline SSH
ds_baseline_psd = psd_isotropic(ds_baseline.ssh)

In [None]:
from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian

# load the baseline mapping result and normalise its coordinate labels
url = "/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_BASELINE.nc"
ds_baseline = xr.open_dataset(url)
ds_baseline = correct_coordinate_labels(ds_baseline)

# degrees -> metres with a flat 111e3 factor
# NOTE(review): longitude gets no cos(latitude) correction — confirm this is
# acceptable for the region.
ds_baseline["longitude"] = ds_baseline.longitude * 111e3
ds_baseline["latitude"] = ds_baseline.latitude * 111e3

# derived fields: the output of `calculate_gradient` is stored as "ssh_ke" and
# that of `calculate_laplacian` as "ssh_enstrophy"
ds_baseline["ssh_ke"] = calculate_gradient(ds_baseline["ssh"], "longitude", "latitude")
ds_baseline["ssh_enstrophy"] = calculate_laplacian(
    ds_baseline["ssh"], "longitude", "latitude"
)

# isotropic PSD of the baseline SSH
ds_baseline_psd = psd_isotropic(ds_baseline.ssh)

In [None]:
# ds_baseline

In [None]:
# load the DUACS mapping result and normalise its coordinate labels
url = "/Volumes/EMANS_HDD/data/dc21b/results/OSE_ssh_mapping_DUACS.nc"
ds_duacs = xr.open_dataset(url)
ds_duacs = correct_coordinate_labels(ds_duacs)

# degrees -> metres with a flat 111e3 factor
# NOTE(review): longitude gets no cos(latitude) correction — confirm this is
# acceptable for the region.
ds_duacs["longitude"] = ds_duacs.longitude * 111e3
ds_duacs["latitude"] = ds_duacs.latitude * 111e3

# derived fields: the output of `calculate_gradient` is stored as "ssh_ke" and
# that of `calculate_laplacian` as "ssh_enstrophy"
ds_duacs["ssh_ke"] = calculate_gradient(ds_duacs["ssh"], "longitude", "latitude")
ds_duacs["ssh_enstrophy"] = calculate_laplacian(
    ds_duacs["ssh"], "longitude", "latitude"
)

# isotropic PSD of the DUACS SSH
ds_duacs_psd = psd_isotropic(ds_duacs.ssh)

In [None]:
# load the SIREN result and normalise its coordinate labels
url = "/Volumes/EMANS_HDD/data/dc21b/results/siren_136.nc"
ds_siren = xr.open_dataset(url)
ds_siren = correct_coordinate_labels(ds_siren)

# degrees -> metres with a flat 111e3 factor
# NOTE(review): longitude gets no cos(latitude) correction — confirm this is
# acceptable for the region.
ds_siren["longitude"] = ds_siren.longitude * 111e3
ds_siren["latitude"] = ds_siren.latitude * 111e3

# derived fields: the output of `calculate_gradient` is stored as "ssh_ke" and
# that of `calculate_laplacian` as "ssh_enstrophy"
ds_siren["ssh_ke"] = calculate_gradient(ds_siren["ssh"], "longitude", "latitude")
ds_siren["ssh_enstrophy"] = calculate_laplacian(
    ds_siren["ssh"], "longitude", "latitude"
)

# isotropic PSD of the SIREN SSH
ds_siren_psd = psd_isotropic(ds_siren.ssh)

In [None]:
# Radial frequencies of each product, scaled by 1e3
# (presumably cycles/m -> cycles/km).
duacs_k = ds_duacs_psd.freq_r.values * 1e3
baseline_k = ds_baseline_psd.freq_r.values * 1e3
siren_k = ds_siren_psd.freq_r.values * 1e3

# DUACS sets up the figure; the other two products are overlaid on it.
fig, ax, secax = plot_psd_isotropic(duacs_k, ds_duacs_psd.values, color="black")
ax.plot(baseline_k, ds_baseline_psd.values, color="red")
ax.plot(siren_k, ds_siren_psd.values, color="blue")

# Shared x-range: extreme valid (non-NaN/inf) frequencies over all three curves.
valid_k = [np.ma.masked_invalid(k) for k in (duacs_k, baseline_k, siren_k)]
k_low = min(np.ma.min(k) for k in valid_k)
k_high = max(np.ma.max(k) for k in valid_k)
plt.xlim((k_low, k_high))

plt.legend(["DUACS", "Naive OI", "SIREN"])
plt.tight_layout()
plt.show()

In [None]:
# Recompute the isotropic PSDs on the "ssh_ke" field of each product.
# NOTE: this overwrites the SSH PSDs held in the same three variables.
ds_baseline_psd = psd_isotropic(ds_baseline.ssh_ke)
ds_duacs_psd = psd_isotropic(ds_duacs.ssh_ke)
ds_siren_psd = psd_isotropic(ds_siren.ssh_ke)

In [None]:
# Radial frequencies of each product, scaled by 1e3
# (presumably cycles/m -> cycles/km).
duacs_k = ds_duacs_psd.freq_r.values * 1e3
baseline_k = ds_baseline_psd.freq_r.values * 1e3
siren_k = ds_siren_psd.freq_r.values * 1e3

# DUACS sets up the figure; the other two products are overlaid on it.
fig, ax, secax = plot_psd_isotropic(duacs_k, ds_duacs_psd.values, color="black")
ax.plot(baseline_k, ds_baseline_psd.values, color="red")
ax.plot(siren_k, ds_siren_psd.values, color="blue")

# Shared x-range: extreme valid (non-NaN/inf) frequencies over all three curves.
valid_k = [np.ma.masked_invalid(k) for k in (duacs_k, baseline_k, siren_k)]
k_low = min(np.ma.min(k) for k in valid_k)
k_high = max(np.ma.max(k) for k in valid_k)
plt.xlim((k_low, k_high))

ax.set_ylabel(r"Kinetic Energy [m$^2$s$^{-2}$/cycles/m]")
plt.legend(["DUACS", "Naive OI", "SIREN"])
plt.tight_layout()
plt.show()

In [None]:
# Recompute the isotropic PSDs on the "ssh_enstrophy" field of each product.
# NOTE: this overwrites the PSDs previously held in the same three variables.
ds_baseline_psd = psd_isotropic(ds_baseline.ssh_enstrophy)
ds_duacs_psd = psd_isotropic(ds_duacs.ssh_enstrophy)
ds_siren_psd = psd_isotropic(ds_siren.ssh_enstrophy)

In [None]:
# Radial frequencies of each product, scaled by 1e3
# (presumably cycles/m -> cycles/km).
duacs_k = ds_duacs_psd.freq_r.values * 1e3
baseline_k = ds_baseline_psd.freq_r.values * 1e3
siren_k = ds_siren_psd.freq_r.values * 1e3

# DUACS sets up the figure; the other two products are overlaid on it.
fig, ax, secax = plot_psd_isotropic(duacs_k, ds_duacs_psd.values, color="black")
ax.plot(baseline_k, ds_baseline_psd.values, color="red")
ax.plot(siren_k, ds_siren_psd.values, color="blue")

# Shared x-range: extreme valid (non-NaN/inf) frequencies over all three curves.
valid_k = [np.ma.masked_invalid(k) for k in (duacs_k, baseline_k, siren_k)]
k_low = min(np.ma.min(k) for k in valid_k)
k_high = max(np.ma.max(k) for k in valid_k)
plt.xlim((k_low, k_high))

ax.set_ylabel(r"Vorticity [s$^{-1}$/cycles/m]")
plt.legend(["DUACS", "Naive OI", "SIREN"])
plt.tight_layout()
plt.show()

In [None]:
# grab ssh
ds_field_psd = ds_field.ssh
ds_predict_psd = ds_field.ssh_predict

# calculate
ds_field_psd = psd_isotropic(ds_field_psd)
ds_predict_psd = psd_isotropic(ds_predict_psd)

In [None]:
# Reference PSD sets up the figure; the prediction is overlaid in red.
fig, ax, secax = plot_psd_isotropic(
    ds_field_psd.freq_r.values, ds_field_psd.values, color="black"
)

ax.plot(ds_predict_psd.freq_r.values, ds_predict_psd.values, color="red")

# x-range: extreme valid (non-NaN/inf) frequencies of the prediction
plt.xlim(
    (
        np.ma.min(np.ma.masked_invalid(ds_predict_psd.freq_r.values)),
        np.ma.max(np.ma.masked_invalid(ds_predict_psd.freq_r.values)),
    )
)
# coordinates are still in degrees here, so relabel both x axes
# (label typo "cyles" corrected)
ax.set_xlabel("Wavenumber [cycles/degrees]")
secax.set_xlabel("Wavelength [degrees]")
plt.legend(["Reference", "DUACS"])
plt.tight_layout()
plt.show()

In [None]:
from inr4ssh._src.metrics.psd import psd_isotropic_score, wavelength_resolved_isotropic

In [None]:
from inr4ssh._src.metrics.psd import psd_isotropic_score, wavelength_resolved_isotropic

# grab ssh (reference and regridded prediction)
# NOTE: these rebind `ds_field_psd` / `ds_predict_psd` to the raw fields,
# replacing the PSDs computed earlier.
ds_field_psd = ds_field.ssh
ds_predict_psd = ds_field.ssh_predict

# PSD-based score; the prediction is passed first, the reference second
psd_iso_score = psd_isotropic_score(ds_predict_psd, ds_field_psd)

In [None]:
# Wavelength returned by the helper for the 0.5 score level; reported in
# degrees because the coordinates have not been converted at this point.
space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)
print(f"Shortest Spatial Wavelength Resolved = {space_iso_resolved:.2f} (degrees)")

In [None]:
# Score curve as a function of wavenumber.
fig, ax, secax = plot_psd_isotropic(
    psd_iso_score.freq_r.values, psd_iso_score.values, color="black"
)

# the score is plotted on a linear axis over [0, 1]
ax.set(ylabel="PSD Score", yscale="linear")
plt.ylim((0, 1.0))
plt.xlim(
    (
        np.ma.min(np.ma.masked_invalid(psd_iso_score.freq_r.values)),
        np.ma.max(np.ma.masked_invalid(psd_iso_score.freq_r.values)),
    )
)

# mark the resolved scale: wavenumber = 1 / wavelength, at score level 0.5
resolved_scale = 1 / (space_iso_resolved)
ax.vlines(
    x=resolved_scale, ymin=0, ymax=0.5, color="green", linewidth=2, linestyle="--"
)
ax.hlines(
    y=0.5,
    xmin=np.ma.min(np.ma.masked_invalid(psd_iso_score.freq_r.values)),
    xmax=resolved_scale,
    color="green",
    linewidth=2,
    linestyle="--",
)


# "\\lambda" keeps the rendered text unchanged while avoiding the invalid
# "\l" escape sequence; "\n" stays a real line break.
label = f"Resolved Scales \n $\\lambda$ > {space_iso_resolved:.2f} degrees"
plt.scatter(resolved_scale, 0.5, color="green", marker=".", linewidth=5, label=label)
# label typo "cyles" corrected
ax.set_xlabel("Wavenumber [cycles/degrees]")
secax.set_xlabel("Wavelength [degrees]")
plt.legend()
plt.tight_layout()
plt.show()

#### Isotropic PSD (Kilometers)

In [None]:
# grab ssh (reference and regridded prediction)
ds_field_psd = ds_field.ssh
ds_predict_psd = ds_field.ssh_predict

# correct units, degrees -> meters
# NOTE(review): longitude uses the same flat 111e3 factor as latitude (no
# cos(latitude) correction) — confirm this is acceptable for the region.
ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3
ds_predict_psd["longitude"] = ds_predict_psd.longitude * 111e3
ds_predict_psd["latitude"] = ds_predict_psd.latitude * 111e3

# calculate the isotropic PSDs; the variables now hold spectra, not fields
ds_field_psd = psd_isotropic(ds_field_psd)
ds_predict_psd = psd_isotropic(ds_predict_psd)

In [None]:
# Radial frequencies scaled by 1e3 (presumably cycles/m -> cycles/km).
ref_k = ds_field_psd.freq_r.values * 1e3
pred_k = ds_predict_psd.freq_r.values * 1e3

# Reference PSD sets up the figure; the prediction is overlaid in red.
fig, ax, secax = plot_psd_isotropic(ref_k, ds_field_psd.values, color="black")
ax.plot(pred_k, ds_predict_psd.values, color="red")

# x-range: extreme valid (non-NaN/inf) frequencies of the prediction
pred_k_valid = np.ma.masked_invalid(pred_k)
plt.xlim((np.ma.min(pred_k_valid), np.ma.max(pred_k_valid)))

plt.legend(["Reference", "DUACS"])
plt.tight_layout()
plt.show()

In [None]:
# grab ssh again (fresh views of the reference and regridded prediction)
ds_field_psd = ds_field.ssh
ds_predict_psd = ds_field.ssh_predict

# correct units, degrees -> meters
# NOTE(review): longitude uses the same flat 111e3 factor as latitude (no
# cos(latitude) correction) — confirm this is acceptable for the region.
ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3
ds_predict_psd["longitude"] = ds_predict_psd.longitude * 111e3
ds_predict_psd["latitude"] = ds_predict_psd.latitude * 111e3

# PSD-based score; the prediction is passed first, the reference second
psd_iso_score = psd_isotropic_score(ds_predict_psd, ds_field_psd)

In [None]:
# Wavelength returned by the helper for the 0.5 score level; coordinates were
# converted to metres above, so dividing by 1e3 reports kilometres.
space_iso_resolved = wavelength_resolved_isotropic(psd_iso_score, level=0.5)
print(f"Shortest Spatial Wavelength Resolved = {space_iso_resolved/1e3:.2f} (km)")

In [None]:
# Score curve as a function of wavenumber; frequencies are scaled by 1e3 to
# match the kilometre-based resolved scale used below.
fig, ax, secax = plot_psd_isotropic(
    psd_iso_score.freq_r.values * 1e3, psd_iso_score.values, color="black"
)

# the score is plotted on a linear axis over [0, 1]
ax.set(ylabel="PSD Score", yscale="linear")
plt.ylim((0, 1.0))
plt.xlim(
    (
        np.ma.min(np.ma.masked_invalid(psd_iso_score.freq_r.values * 1e3)),
        np.ma.max(np.ma.masked_invalid(psd_iso_score.freq_r.values * 1e3)),
    )
)

# mark the resolved scale: wavelength m -> km, then wavenumber = 1 / wavelength
resolved_scale = 1 / (space_iso_resolved * 1e-3)
ax.vlines(
    x=resolved_scale, ymin=0, ymax=0.5, color="green", linewidth=2, linestyle="--"
)
ax.hlines(
    y=0.5,
    xmin=np.ma.min(np.ma.masked_invalid(psd_iso_score.freq_r.values * 1e3)),
    xmax=resolved_scale,
    color="green",
    linewidth=2,
    linestyle="--",
)


# "\\lambda" keeps the rendered text unchanged while avoiding the invalid
# "\l" escape sequence; "\n" stays a real line break.
label = f"Resolved Scales \n $\\lambda$ > {int(space_iso_resolved*1e-3)} km"
plt.scatter(resolved_scale, 0.5, color="green", marker=".", linewidth=5, label=label)
plt.legend()
plt.tight_layout()
plt.show()