# Analysis I - Pixel-Wise Stats

In [None]:
import sys, os
from pyprojroot import here

# Locate the project root with pyprojroot, using the `.root` marker file.
root = here(project_files=[".root"])

# Make the project importable. Skip the append when the entry is already
# present so that re-running this cell does not pile up duplicate paths.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels
from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
url = "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/NATL60-CJM165_GULFSTREAM_*.nc"


def preprocess(ds):
    """Restrict ``ds`` to the evaluation time window.

    Args:
        ds: object exposing ``.sel(time=...)`` (e.g. an xarray dataset).

    Returns:
        The result of selecting ``time`` between 2012-10-22 and 2012-12-02.
    """
    window_start = np.datetime64("2012-10-22")
    window_end = np.datetime64("2012-12-02")
    return ds.sel(time=slice(window_start, window_end))


# Open all files matching `url` as one dataset. No per-file hook is applied
# (`preprocess=None`), so the `preprocess` function defined above is
# currently unused.
ds_field = xr.open_mfdataset(url, preprocess=None)

# NOTE: the commented-out steps below are kept for reference only. The time
# subset, daily resampling, label correction and the rename to "ssh" are
# applied later by `post_process`.

# ds_field = ds_field.sel(
#     time=slice(np.datetime64("2012-10-22"), np.datetime64("2012-12-02"))
# )

# ds_field = (
#     ds_field.rename({"lon": "longitude"})
#     .rename({"lat": "latitude"})
#     .rename({"sossheig": "ssh"})
# )

# ds_field = ds_field.resample(time="1D").mean()

# # ds_field = correct_coordinate_labels(ds_field)

# Display the dataset as the cell output.
ds_field

### Example Results

```bash
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/

In [None]:
url = "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc"
ds_predict = xr.open_dataset(url)

ds_predict

## PostProcessing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid


def post_process(
    ds,
    ds_ref,
    variable_ref,
    variable_pred,
    *,
    time_min="2012-10-22",
    time_max="2012-12-02",
    resample_freq="1D",
):
    """Align a predicted SSH dataset with the reference field for evaluation.

    Both datasets are cut to the same time window, their coordinate labels
    and longitude domain are corrected, the data variables are renamed to
    ``"ssh"``, and the prediction is regridded with ``oi_regrid`` and stored
    next to the reference.

    Args:
        ds: dataset holding the prediction.
        ds_ref: dataset holding the reference field.
        variable_ref: name of the reference variable in ``ds_ref``.
        variable_pred: name of the predicted variable in ``ds``.
        time_min: start of the evaluation window (anything accepted by
            ``np.datetime64``). Defaults to the original hard-coded value.
        time_max: end of the evaluation window. Defaults to the original
            hard-coded value.
        resample_freq: frequency used to resample the reference in time
            (mean over each bin). Pass ``None`` to skip resampling.

    Returns:
        The processed reference dataset, containing ``"ssh"`` (reference)
        and ``"ssh_predict"`` (output of ``oi_regrid``).
    """
    # one shared window so prediction and reference cannot drift apart
    eval_period = slice(np.datetime64(time_min), np.datetime64(time_max))

    # subset temporal space
    ds = ds.sel(time=eval_period)
    ds_ref = ds_ref.sel(time=eval_period)

    # resample the reference grid (optional)
    if resample_freq is not None:
        ds_ref = ds_ref.resample(time=resample_freq).mean()

    # correct coordinate labels
    ds = correct_coordinate_labels(ds)
    ds_ref = correct_coordinate_labels(ds_ref)

    # give both data variables the same name
    ds_ref = ds_ref.rename({variable_ref: "ssh"})
    ds = ds.rename({variable_pred: "ssh"})

    # correct longitude domain
    ds = correct_longitude_domain(ds)
    ds_ref = correct_longitude_domain(ds_ref)

    # regrid data
    # NOTE(review): presumably maps the prediction onto the reference grid;
    # confirm against `oi_regrid`.
    ds_ref["ssh_predict"] = oi_regrid(ds["ssh"], ds_ref["ssh"])

    return ds_ref

In [None]:
# Build the evaluation dataset: `ds_field` is rebound to the processed
# reference returned by `post_process`.
# NOTE(review): because `ds_field` is overwritten, re-running this cell feeds
# the already-processed dataset back in, and the "sossheig" rename inside
# `post_process` will presumably fail. Reload `ds_field` before re-running.
ds_field = post_process(ds_predict, ds_field, "sossheig", "gssh")

In [None]:
ds_field

## Metric I - RMSE

In [None]:
from inr4ssh._src.metrics.field.stats import nrmse_spacetime, rmse_space, nrmse_time

In [None]:
# Score of the regridded prediction against the reference over the whole
# dataset. NOTE(review): assumes `nrmse_spacetime` returns a scalar (0-d)
# object; the `.2f` format below needs a single value.
nrmse_xyt = nrmse_spacetime(ds_field["ssh_predict"], ds_field["ssh"]).values
print(f"Leaderboard SSH RMSE score =  {nrmse_xyt:.2f}")

#### Error Variability (Temporal)

In [None]:
# Score returned by `nrmse_time` (presumably one value per time step; it is
# plotted against time in the next cell).
rmse_t = nrmse_time(ds_field["ssh_predict"], ds_field["ssh"])
# Error variability: standard deviation over all values of that score.
err_var_time = rmse_t.std().values
print(f"Error Variability =  {err_var_time:.2f}")

In [None]:
# Time series of the temporal nRMSE on a fixed [0, 1] axis.
fig, ax = plt.subplots()
rmse_t.plot(ax=ax)

ax.set_xlabel("Time")
ax.set_ylabel("nRMSE")
ax.set_ylim(0, 1.0)

fig.tight_layout()
plt.show()

#### Error Variability (Spatial)

In [None]:
# Error returned by `rmse_space` (presumably one value per grid point; it is
# drawn as an image in the next cell).
rmse_xy = rmse_space(ds_field["ssh_predict"], ds_field["ssh"])
# Error variability: standard deviation over all values of that map.
err_var_space = rmse_xy.std().values
print(f"Error Variability (Spatial)=  {err_var_space:.2f}")

In [None]:
# Image of the spatial RMSE; the array is transposed before plotting.
fig, ax = plt.subplots()
rmse_map = rmse_xy.transpose()
rmse_map.plot.imshow(ax=ax)

fig.tight_layout()
plt.show()

## Multivariate Statistics

In [None]:
from hyppo.independence import RV
from hyppo.d_variate import dHsic
from hyppo.ksample import Energy

In [None]:
from tqdm.notebook import tqdm

# Per-time-step similarity statistics between reference and prediction.
times = []
stats = {
    "rv": [],
    "rvd": [],
    "hsic": [],
    # The Energy statistic is currently disabled, so this list stays empty.
    "energy": [],
}

for itime, ds_t in tqdm(ds_field.groupby("time")):

    # Pull both fields out once per time step. Calling `.values` repeatedly
    # would redo the conversion (and any lazy computation) for every
    # statistic below.
    ssh_true = ds_t["ssh"].values
    ssh_pred = ds_t["ssh_predict"].values

    # RV coefficient on the flattened fields, passed as single-column arrays
    stats["rv"].append(
        RV().statistic(ssh_true.flatten()[:, None], ssh_pred.flatten()[:, None])
    )

    # RV coefficient on the fields as yielded by the groupby (no flattening)
    stats["rvd"].append(RV().statistic(ssh_true, ssh_pred))

    # normalised HSIC: HSIC(x, y) / sqrt(HSIC(x, x) * HSIC(y, y))
    hsic_xy = dHsic().statistic(ssh_true, ssh_pred)
    hsic_xx = dHsic().statistic(ssh_true, ssh_true)
    hsic_yy = dHsic().statistic(ssh_pred, ssh_pred)
    stats["hsic"].append(hsic_xy / (np.sqrt(hsic_xx) * np.sqrt(hsic_yy)))

    times.append(itime)

In [None]:
# RV coefficient (flattened fields, `stats["rv"]`) over time.
fig, ax = plt.subplots()

ax.plot(times, stats["rv"], label="RV Coeff.")
ax.set(xlabel="Time", ylabel="RV Coeff.")
ax.set_ylim((0.75, 1.0))
# ax.set_yscale("log")
# Draw the legend; without it the `label` given to `plot` is never shown.
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# RV coefficient on the un-flattened fields (`stats["rvd"]`) over time.
fig, ax = plt.subplots()

ax.plot(times, stats["rvd"], label="RV Coeff.")
ax.set(xlabel="Time", ylabel="RV Coeff.")
ax.set_ylim((0.75, 1.0))
# ax.set_yscale("log")
# Draw the legend; without it the `label` given to `plot` is never shown.
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Normalised HSIC (`stats["hsic"]`) over time.
fig, ax = plt.subplots()

ax.plot(times, stats["hsic"], label="nHSIC")
ax.set(xlabel="Time", ylabel="nHSIC")
ax.set_ylim((0.75, 1.0))
# ax.set_yscale("log")
# Draw the legend; without it the `label` given to `plot` is never shown.
ax.legend()
plt.tight_layout()
plt.show()