# Analysis I - Pixel-Wise Stats

In [None]:
import sys, os
from pyprojroot import here

# spider up the directory tree to find the project root (marked by a ".root" file)
root = here(project_files=[".root"])


# append to path; skip the append when the entry is already registered so that
# re-running this cell does not pile up duplicate entries in sys.path
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels
from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
url = "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/NATL60-CJM165_GULFSTREAM_*.nc"


def preprocess(ds):
    """Restrict `ds` to the evaluation window along its `time` coordinate.

    The window runs from 2012-10-22 to 2012-12-02 and is selected with a
    label-based slice, so `ds` only needs to expose `.sel(time=...)`.

    Returns:
        Whatever `ds.sel` returns for that time slice.
    """
    window_start = np.datetime64("2012-10-22")
    window_stop = np.datetime64("2012-12-02")
    return ds.sel(time=slice(window_start, window_stop))


# Open every file matched by `url` as a single dataset.
# NOTE(review): `preprocess=None` means the `preprocess` helper defined in this
# cell is NOT applied on load, so no time subsetting happens here — confirm
# whether that is intentional.
ds_field = xr.open_mfdataset(url, preprocess=None)

# ds_field = ds_field.sel(
#     time=slice(np.datetime64("2012-10-22"), np.datetime64("2012-12-02"))
# )

# ds_field = (
#     ds_field.rename({"lon": "longitude"})
#     .rename({"lat": "latitude"})
#     .rename({"sossheig": "ssh"})
# )

# ds_field = ds_field.resample(time="1D").mean()

# # ds_field = correct_coordinate_labels(ds_field)

ds_field

### Example Results

```bash
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/

In [None]:
# Load a single mapped SSH product to evaluate against the reference field
# (the DUACS result, going by the file name).
# NOTE(review): this rebinds `url`, which previously held the reference-field glob.
url = "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc"
ds_predict = xr.open_dataset(url)

# display the dataset summary
ds_predict

## PostProcessing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid


def post_process(
    ds,
    ds_ref,
    variable_ref,
    variable_pred,
    time_min=np.datetime64("2012-10-22"),
    time_max=np.datetime64("2012-12-02"),
    resample_freq="1D",
):
    """Align a predicted dataset with the reference dataset for evaluation.

    Both datasets are cut to the same time window, their coordinate labels and
    longitude domain are passed through the project helpers
    `correct_coordinate_labels` and `correct_longitude_domain`, the variables
    of interest are renamed to "ssh", and the prediction is regridded with
    `oi_regrid` and stored next to the reference.

    Args:
        ds: dataset holding the prediction (must support `.sel(time=...)`).
        ds_ref: dataset holding the reference field.
        variable_ref: name of the reference variable in `ds_ref`.
        variable_pred: name of the predicted variable in `ds`.
        time_min: start of the evaluation window (default 2012-10-22).
        time_max: end of the evaluation window (default 2012-12-02).
        resample_freq: frequency used to resample-and-average the reference
            along time (default "1D"). Pass `None` to keep the reference at
            its native temporal resolution.

    Returns:
        The processed reference dataset with two variables: "ssh" (reference)
        and "ssh_predict" (output of `oi_regrid(ds["ssh"], ds_ref["ssh"])`).
    """
    # subset temporal space (same window for both datasets)
    period = slice(time_min, time_max)
    ds = ds.sel(time=period)
    ds_ref = ds_ref.sel(time=period)

    # resample the reference grid (optional)
    if resample_freq is not None:
        ds_ref = ds_ref.resample(time=resample_freq).mean()

    # correct coordinate labels
    ds = correct_coordinate_labels(ds)
    ds_ref = correct_coordinate_labels(ds_ref)

    # correct variable labels; nothing to rename when it is already "ssh"
    if variable_ref != "ssh":
        ds_ref = ds_ref.rename({variable_ref: "ssh"})
    if variable_pred != "ssh":
        ds = ds.rename({variable_pred: "ssh"})

    # correct longitude domain
    ds = correct_longitude_domain(ds)
    ds_ref = correct_longitude_domain(ds_ref)

    # regrid data
    # presumably maps the prediction onto the reference grid — the behaviour of
    # `oi_regrid` is not visible from this notebook; confirm against its source
    ds_ref["ssh_predict"] = oi_regrid(ds["ssh"], ds_ref["ssh"])

    return ds_ref

In [None]:
# Align the prediction with the reference field. The result carries both
# "ssh" (reference, renamed from "sossheig") and "ssh_predict" (regridded "gssh").
# NOTE(review): this rebinds `ds_field` to the processed dataset, so the cell is
# not idempotent — on a second run `ds_field` no longer has a "sossheig"
# variable. Reload the raw field before re-running.
ds_field = post_process(ds_predict, ds_field, "sossheig", "gssh")

In [None]:
ds_field

## Metrics II - Space-Time PSD

#### Absolute Values

In [None]:
from inr4ssh._src.metrics.psd import (
    psd_isotropic_score,
    psd_spacetime_score,
    wavelength_resolved_spacetime,
    wavelength_resolved_isotropic,
)

In [None]:
time_norm = np.timedelta64(1, "D")
# Express the time axis as elapsed days since the first timestamp.
# The conversion overwrites `ds_field["time"]` in place, so it is guarded on
# the dtype: once the axis is numeric, re-running this cell is a no-op instead
# of trying to divide plain numbers by a timedelta.
if np.issubdtype(ds_field["time"].dtype, np.datetime64):
    ds_field["time"] = (ds_field.time - ds_field.time[0]) / time_norm

#### Degrees

In [None]:
# Time-Longitude PSD (original note: "Lat avg" — presumably the latitude
# averaging happens inside `psd_spacetime`; confirm against its source).
# Rechunk so each chunk holds one time step over the full longitude/latitude
# extent, then load the result into memory with `.compute()`.
ds_field_ = ds_field.chunk(
    {
        "time": 1,
        "longitude": ds_field["longitude"].size,
        "latitude": ds_field["latitude"].size,
    }
).compute()

# PSD of the reference ("ssh") and of the prediction ("ssh_predict")
ds_field_psd = psd_spacetime(ds_field_["ssh"])
ds_predict_psd = psd_spacetime(ds_field_["ssh_predict"])

In [None]:
# Space-time PSD of the reference field, plotted against wavelength.
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_field_psd.freq_longitude,
    ds_field_psd.freq_time,
    ds_field_psd,
)
# ax.set_xlim((1000, 10))
# coordinates have not been converted to meters in this section, hence degrees
ax.set_xlabel("Wavelength [degrees]")
# cbar.ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/degree]")
# cbar.ax.set_ylabel(r"PSD [s$^{-1}$/cycles/m]")

plt.tight_layout()
plt.show()

In [None]:
# Space-time PSD of the prediction, plotted against wavelength.
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_predict_psd.freq_longitude,
    ds_predict_psd.freq_time,
    ds_predict_psd,
)
# ax.set_xlim((1000, 10))
# coordinates have not been converted to meters in this section, hence degrees
ax.set_xlabel("Wavelength [degrees]")
# cbar.ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m]")
# cbar.ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/degree]")
# cbar.ax.set_ylabel(r"PSD [s$^{-1}$/cycles/m]")

plt.tight_layout()
plt.show()

### PSD Score

In [None]:
# Space-time PSD score; argument order is (prediction, reference).
psd_score = psd_spacetime_score(ds_field_["ssh_predict"], ds_field_["ssh"])

In [None]:
# PSD score as a function of wavelength (spatial axis in degrees).
fig, ax, cbar = plot_psd_spacetime_score_wavelength(
    psd_score.freq_longitude,
    psd_score.freq_time,
    psd_score,
)

ax.set_xlabel("Wavelength [degrees]")

plt.tight_layout()
plt.show()

In [None]:
# PSD score as a function of wavenumber (spatial axis in cycles per degree).
fig, ax, cbar = plot_psd_spacetime_score_wavenumber(
    psd_score.freq_longitude,
    psd_score.freq_time,
    psd_score,
)

ax.set_xlabel("Wavenumber [cycles/degrees]")

plt.tight_layout()
plt.show()

In [None]:
# Resolved spatial and temporal scales derived from the PSD score; the next
# cell reports them as the shortest resolved wavelengths (degrees lon, days).
spatial_resolved, time_resolved = wavelength_resolved_spacetime(psd_score)

In [None]:
print(f"Shortest Spatial Wavelength Resolved = {spatial_resolved:.2f} (degree lon)")
print(f"Shortest Temporal Wavelength Resolved = {time_resolved:.2f} (days)")

#### Kilometers

In [None]:
# grab ssh (reference) and ssh_predict (prediction) as separate arrays
ds_field_psd = ds_field.ssh
ds_predict_psd = ds_field.ssh_predict

# correct units, degrees -> meters
# NOTE(review): a flat 111e3 m per degree is applied to both axes; for
# longitude this ignores the cos(latitude) factor — confirm the approximation
# is acceptable for this region.
ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3
ds_predict_psd["longitude"] = ds_predict_psd.longitude * 111e3
ds_predict_psd["latitude"] = ds_predict_psd.latitude * 111e3


# Time-Longitude (Lat avg) PSD Score
# Rechunk to one time step per chunk over the full spatial extent, then load
# into memory with `.compute()`.
ds_field_psd = ds_field_psd.chunk(
    {
        "time": 1,
        "longitude": ds_field_psd["longitude"].size,
        "latitude": ds_field_psd["latitude"].size,
    }
).compute()
ds_predict_psd = ds_predict_psd.chunk(
    {
        "time": 1,
        "longitude": ds_predict_psd["longitude"].size,
        "latitude": ds_predict_psd["latitude"].size,
    }
).compute()

In [None]:
# Replace the SSH arrays with their space-time PSDs.
# NOTE(review): both names are rebound from SSH fields to PSD results, so this
# cell cannot be re-run without first re-running the cell above.
ds_field_psd = psd_spacetime(ds_field_psd)
ds_predict_psd = psd_spacetime(ds_predict_psd)

In [None]:
# Space-time PSD of the reference field against wavelength.
# The spatial frequency is multiplied by 1e3: with coordinates in meters this
# presumably turns cycles/m into cycles/km, matching the km axis label.
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_field_psd.freq_longitude * 1e3,
    ds_field_psd.freq_time,
    ds_field_psd,
)
# ax.set_xlim((1000, 10))
ax.set_xlabel("Wavelength [km]")
# cbar.ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m]")
# cbar.ax.set_ylabel(r"PSD [s$^{-1}$/cycles/m]")

plt.tight_layout()
plt.show()

In [None]:
# Space-time PSD of the prediction against wavelength.
# The spatial frequency is multiplied by 1e3: with coordinates in meters this
# presumably turns cycles/m into cycles/km, matching the km axis label.
fig, ax, cbar = plot_psd_spacetime_wavelength(
    ds_predict_psd.freq_longitude * 1e3,
    ds_predict_psd.freq_time,
    ds_predict_psd,
)
# ax.set_xlim((1000, 10))
ax.set_xlabel("Wavelength [km]")
# cbar.ax.set_ylabel(r"PSD [m$^2$s$^{-2}$/cycles/m]")
# cbar.ax.set_ylabel(r"PSD [s$^{-1}$/cycles/m]")

plt.tight_layout()
plt.show()

### PSD Score

In [None]:
# grab ssh
# The fields are extracted again because `ds_field_psd` / `ds_predict_psd`
# were rebound to PSD results in an earlier cell.
ds_field_psd = ds_field.ssh
ds_predict_psd = ds_field.ssh_predict

# correct units, degrees -> meters
# NOTE(review): flat 111e3 m per degree on both axes (no cos(latitude) factor
# for longitude) — confirm the approximation is intended.
ds_field_psd["longitude"] = ds_field_psd.longitude * 111e3
ds_field_psd["latitude"] = ds_field_psd.latitude * 111e3
ds_predict_psd["longitude"] = ds_predict_psd.longitude * 111e3
ds_predict_psd["latitude"] = ds_predict_psd.latitude * 111e3

# Time-Longitude (Lat avg) PSD Score
# Rechunk to one time step per chunk over the full spatial extent, then load
# into memory with `.compute()`.
ds_field_psd = ds_field_psd.chunk(
    {
        "time": 1,
        "longitude": ds_field_psd["longitude"].size,
        "latitude": ds_field_psd["latitude"].size,
    }
).compute()
ds_predict_psd = ds_predict_psd.chunk(
    {
        "time": 1,
        "longitude": ds_predict_psd["longitude"].size,
        "latitude": ds_predict_psd["latitude"].size,
    }
).compute()


# score; argument order is (prediction, reference)
psd_score = psd_spacetime_score(ds_predict_psd, ds_field_psd)

In [None]:
# PSD score as a function of wavenumber. The coordinates were converted to
# meters before the score was computed, so scaling the spatial frequency by
# 1e3 expresses it per kilometre.
fig, ax, cbar = plot_psd_spacetime_score_wavenumber(
    psd_score.freq_longitude * 1e3,
    psd_score.freq_time,
    psd_score,
)

# label the axis explicitly, as the other plot cells in this notebook do,
# instead of relying on the plotting helper's default
ax.set_xlabel("Wavenumber [cycles/km]")

plt.tight_layout()
plt.show()

In [None]:
# PSD score as a function of wavelength. The coordinates were converted to
# meters before the score was computed, so scaling the spatial frequency by
# 1e3 expresses it per kilometre.
fig, ax, cbar = plot_psd_spacetime_score_wavelength(
    psd_score.freq_longitude * 1e3,
    psd_score.freq_time,
    psd_score,
)

# label the axis explicitly, as the other plot cells in this notebook do,
# instead of relying on the plotting helper's default
ax.set_xlabel("Wavelength [km]")

plt.tight_layout()
plt.show()

In [None]:
# Resolved scales from the meter-based PSD score; the spatial value is
# converted to km (x 1e-3) when printed in the next cell.
spatial_resolved, time_resolved = wavelength_resolved_spacetime(psd_score)

In [None]:
# Report the resolved scales; `* 1e-3` converts the spatial value from meters
# to kilometres (coordinates were scaled to meters before computing the score).
print(f"Shortest Spatial Wavelength Resolved = {spatial_resolved*1e-3:.2f} (km lon)")
print(f"Shortest Temporal Wavelength Resolved = {time_resolved:.2f} (days)")