# Figure - RMSE Stats

In [None]:
import sys, os
from pyprojroot import here

# walk up the directory tree until the ".root" marker file is found
root = here(project_files=[".root"])


# make the project root importable; guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# start from the default rc parameters, then apply seaborn's "talk" plotting
# context with its font sizes scaled by 0.7
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels

# from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)

from loguru import logger

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

In [None]:
# Glob pattern for the NATL60 Gulf Stream reference files.
# NOTE: machine-specific absolute path — point it at the local copy of the
# dc20a_osse data before running.
url = "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/NATL60-CJM165_GULFSTREAM_*.nc"


def preprocess(ds):
    """Reduce a dataset to daily means along its ``time`` axis.

    Handed to ``xr.open_mfdataset`` as the ``preprocess`` callback in this
    notebook. Temporal subsetting and 5x5 spatial coarsening were also
    tried at this stage and are currently disabled.

    Args:
        ds: dataset exposing ``resample(time=...)``.

    Returns:
        The mean of ``ds`` over each 1-day (``"1D"``) resampling bin.
    """
    daily_bins = ds.resample(time="1D")
    return daily_bins.mean()


# open every file matching the glob as a single dataset; `preprocess` is
# applied to each file's dataset before they are combined
ds_field = xr.open_mfdataset(url, preprocess=preprocess)

ds_field

### Example Results

In [None]:
def post_process(ds, variable):
    """Bring a dataset into the common layout used for the comparisons below.

    Steps, in order: select the evaluation window along ``time``, fix the
    coordinate labels, rename ``variable`` to ``"ssh"``, fix the longitude
    domain, and reorder the dimensions to (time, latitude, longitude).
    No regridding is done here.

    Args:
        ds: dataset with a ``time`` coordinate and a data variable named
            ``variable``.
        variable: name of the sea-surface-height variable inside ``ds``.

    Returns:
        The processed dataset, with the field stored under ``"ssh"``.
    """
    # evaluation window: 2012-10-22 through 2012-12-02
    eval_window = slice(np.datetime64("2012-10-22"), np.datetime64("2012-12-02"))
    ds = ds.sel(time=eval_window)

    logger.info("Fixing coordinate labels...")
    ds = correct_coordinate_labels(ds)

    # give the field a common name, whatever it was called in the input
    logger.info("Fixing labels")
    ds = ds.rename({variable: "ssh"})

    logger.info("Fixing longitude domain")
    ds = correct_longitude_domain(ds)

    return ds.transpose("time", "latitude", "longitude")

In [None]:
# NOTE: this rebinds ds_field to its post-processed version, so the cell cannot
# simply be run twice — after the first run "sossheig" has been renamed to
# "ssh". Re-run the loading cell first if this needs to be executed again.
ds_field = post_process(ds_field, "sossheig")
ds_field

```bash
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_MIOST_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_DUACS_swot_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_en_j1_tpn_g2.nc
!wget https://ige-meom-opendap.univ-grenoble-alpes.fr/thredds/fileServer/meomopendap/extract/ocean-data-challenges/dc_data1/dc_mapping/2020a_SSH_mapping_NATL60_BFN_Steady_State_QG1L_swot_en_j1_tpn_g2.nc
```

In [None]:
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/

In [None]:
# quick look at one mapping result (DUACS) exactly as stored on disk, before
# any post-processing (machine-specific absolute path)
url = "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc"
ds_predict = xr.open_dataset(url)

ds_predict

## PostProcessing

### Cleaning

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid

# Each entry: (log label, result file, name of the SSH variable in that file,
# name under which the regridded field is stored in ds_field).
# An alternative SIREN result, ".../saved_data/test_res_nadir4_lb.nc", was also
# used here at some point.
# NOTE: the paths are machine-specific absolute paths.
prediction_sources = [
    (
        "Dataset I - MIOST",
        "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc",
        "gssh",
        "ssh_miost",
    ),
    (
        "Dataset II - DUACS",
        "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc",
        "gssh",
        "ssh_duacs",
    ),
    (
        "Dataset III - SIREN",
        "/Users/eman/code_projects/logs/saved_data/test_res_nadir4_jz.nc",
        "ssh_model_predict",
        "ssh_nerf",
    ),
]

# Same pipeline for every product: load, clean with post_process, regrid onto
# the reference field's grid, and attach the result to ds_field.
for source_label, url, source_variable, target_variable in prediction_sources:
    logger.info(source_label)
    ds_predict = xr.open_dataset(url)

    ds_predict = post_process(ds_predict, source_variable)

    ds_field[target_variable] = oi_regrid(ds_predict["ssh"], ds_field["ssh"])

## Metric I - RMSE

In [None]:
from inr4ssh._src.metrics.field.stats import nrmse_spacetime, rmse_space, nrmse_time

In [None]:
# One space-time nRMSE score per mapping method.
# NOTE(review): the reference field is passed as the FIRST argument here, but as
# the SECOND argument in the nrmse_time / rmse_space calls further down this
# notebook — confirm which order the metric functions expect.
for method_label, method_variable in (
    ("DUACS", "ssh_duacs"),
    ("MIOST", "ssh_miost"),
    ("NerF", "ssh_nerf"),
):
    nrmse_xyt = nrmse_spacetime(ds_field["ssh"], ds_field[method_variable]).values
    print(f"Leaderboard SSH RMSE score [{method_label}] =  {nrmse_xyt:.2f}")

#### Error Variability (Temporal)

In [None]:
# nRMSE as a function of time for each mapping method, keyed by method name;
# the spread (standard deviation) of each series is reported as its temporal
# error variability.
rmse_t = {}
for method_key, method_label in (
    ("duacs", "DUACS"),
    ("miost", "MIOST"),
    ("nerf", "NerF"),
):
    rmse_t[method_key] = nrmse_time(ds_field[f"ssh_{method_key}"], ds_field["ssh"])
    err_var_time = rmse_t[method_key].std().values
    print(f"Error Variability [{method_label}] =  {err_var_time:.2f}")

In [None]:
fig, ax = plt.subplots()

# one nRMSE time series per mapping method: (key in rmse_t, legend label, colour)
for series_key, series_label, series_color in (
    ("duacs", "DUACS", "tab:green"),
    ("miost", "MIOST", "tab:blue"),
    ("nerf", "NerF", "tab:red"),
):
    rmse_t[series_key].plot(ax=ax, label=series_label, color=series_color)

# set after plotting so these labels replace the ones added by the plot calls
ax.set(xlabel="Time", ylabel="nRMSE")
ax.set_ylim((0.75, 1.0))
plt.legend()
plt.tight_layout()

# written below the project root; the "figures" directory must already exist
figure_path = Path(root).joinpath(f"figures/osse_2020a_stats_nrsme.png")
fig.savefig(figure_path)
plt.show()

#### Error Variability (Spatial)

In [None]:
# ds_field never gets an "ssh_predict" variable in this notebook (only
# "ssh_duacs", "ssh_miost" and "ssh_nerf" are added), so indexing it raised a
# KeyError. The NerF reconstruction is used here; change `spatial_variable` to
# inspect another method.
spatial_variable = "ssh_nerf"
rmse_xy = rmse_space(ds_field[spatial_variable], ds_field["ssh"])
err_var_space = rmse_xy.std().values
print(f"Error Variability (Spatial)=  {err_var_space:.2f}")

In [None]:
fig, ax = plt.subplots()

# draw the spatial RMSE map as an image, with its dimensions transposed first
rmse_map = rmse_xy.transpose()
rmse_map.plot.imshow(ax=ax)

fig.tight_layout()
plt.show()

## Multivariate Statistics

In [None]:
from hyppo.independence import RV
from hyppo.d_variate import dHsic
from hyppo.ksample import Energy

In [None]:
from tqdm.notebook import tqdm

# Mapping methods compared below: key used in `stats` / the "ssh_<key>"
# variable in ds_field -> label shown in the progress bar.
methods = {"duacs": "DUACS", "miost": "MIOST", "nerf": "NerF"}
metric_names = ("rv", "rvd", "hsic", "energy")

times = []
# kept for backward compatibility; nothing in this cell fills it
stats_dict = {metric: list() for metric in metric_names}
# stats[method][metric] -> one value per time step, in the order of `times`.
# The "energy" lists stay empty: the Energy statistic is not computed here.
stats = {
    method: {metric: list() for metric in metric_names} for method in methods
}

with tqdm(ds_field.groupby("time")) as pbar:
    for time_stamp, ds_t in pbar:
        ssh_ref = ds_t["ssh"].values
        # column vector holding every grid point of the reference field
        ssh_ref_flat = ssh_ref.flatten()[:, None]
        # HSIC of the reference with itself does not depend on the method,
        # so it is computed once per time step
        hsic_ref = dHsic().statistic(ssh_ref, ssh_ref)

        for method, method_label in methods.items():
            # each method is compared against the reference using ITS OWN
            # field (the MIOST RV coefficient used to be computed from the
            # NerF field by mistake)
            ssh_method = ds_t[f"ssh_{method}"].values

            # RV coefficient on the flattened fields
            pbar.set_description(f"RV Coeff [{method_label}]...")
            stats[method]["rv"].append(
                RV().statistic(ssh_ref_flat, ssh_method.flatten()[:, None])
            )

            # RV coefficient on the fields as stored (not flattened)
            pbar.set_description(f"Spatial RV Coeff [{method_label}]...")
            stats[method]["rvd"].append(RV().statistic(ssh_ref, ssh_method))

            # normalised HSIC: HSIC(x, y) / (sqrt(HSIC(x, x)) * sqrt(HSIC(y, y)))
            pbar.set_description(f"nHSIC [{method_label}]...")
            hsic_cross = dHsic().statistic(ssh_ref, ssh_method)
            hsic_method = dHsic().statistic(ssh_method, ssh_method)
            stats[method]["hsic"].append(
                hsic_cross / (np.sqrt(hsic_ref) * np.sqrt(hsic_method))
            )

        times.append(time_stamp)

In [None]:
fig, ax = plt.subplots()

# RV coefficient per time step for each mapping method. The x-axis uses the
# timestamps collected alongside the statistics (`times`), as the other
# statistic figures do.
for series_key, series_label, series_color in (
    ("duacs", "DUACS", "tab:green"),
    ("miost", "MIOST", "tab:blue"),
    ("nerf", "NerF", "tab:red"),
):
    ax.plot(times, stats[series_key]["rv"], label=series_label, color=series_color)

ax.set(xlabel="Time", ylabel="RV coefficient")
ax.set_ylim((0.75, 1.0))
# the labels passed to ax.plot only show up once a legend is drawn
ax.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
fig, ax = plt.subplots()

# spatial RV coefficient per time step for each mapping method
for series_key, series_label, series_color in (
    ("duacs", "DUACS", "tab:green"),
    ("miost", "MIOST", "tab:blue"),
    ("nerf", "NerF", "tab:red"),
):
    ax.plot(times, stats[series_key]["rvd"], label=series_label, color=series_color)

ax.set(xlabel="Time", ylabel="Spatial RV coefficient")
ax.set_ylim((0.90, 1.0))
# the labels passed to ax.plot only show up once a legend is drawn
ax.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
fig, ax = plt.subplots()

# normalised HSIC per time step for each mapping method
for series_key, series_label, series_color in (
    ("duacs", "DUACS", "tab:green"),
    ("miost", "MIOST", "tab:blue"),
    ("nerf", "NerF", "tab:red"),
):
    ax.plot(times, stats[series_key]["hsic"], label=series_label, color=series_color)

ax.set(xlabel="Time", ylabel="nHSIC")
ax.set_ylim((0.75, 1.1))
# the labels passed to ax.plot only show up once a legend is drawn
ax.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()