# Figure - SpatioTemporal PSD

In [None]:
import sys, os
from pathlib import Path

import ml_collections
from pyprojroot import here

# Locate the project root: pyprojroot's `here` resolves it via the ".root"
# marker file listed in `project_files`.
root = here(project_files=[".root"])
exp = Path(root).joinpath("experiments/dc20a")

# Make the project root and the dc20a experiment directory importable
# (presumably the experiment directory provides the local `evaluation`
# module imported further down — confirm). Entries already on sys.path are
# skipped so that re-running this cell does not keep growing sys.path.
for _import_dir in (str(root), str(exp)):
    if _import_dir not in sys.path:
        sys.path.append(_import_dir)

In [None]:
# Show the resolved project root and experiment directory as a tuple.
(root, exp)

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns

# Restore seaborn's default rc settings first, then apply the "talk" plotting
# context with font sizes scaled by 0.7.
sns.reset_defaults()
sns.set_context("talk", font_scale=0.7)

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_coordinate_labels

# from inr4ssh._src.preprocess.obs import bin_observations
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)

from loguru import logger

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

### Evaluation Field

**Cleaning**

* Evaluation Period
* Lat/Lon Labels
* Longitude Range
* Regridding

In [None]:
# List the files available for the configuration cell below (IPython shell
# escapes; the absolute paths are machine-specific).
# Result files for the nadir4 experiment (the MIOST/DUACS paths below live here).
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/
# Result files for the swot1nadir5 experiment.
!ls /Volumes/EMANS_HDD/data/dc20a_osse/results/swot1nadir5/
# Locally saved model outputs (the study path below points into this folder).
!ls /Users/eman/code_projects/logs/saved_data/

In [None]:
from ml_collections import config_dict

# Reference field: NATL60 files matched by a glob pattern.
reference = config_dict.ConfigDict(
    {
        "path": "/Volumes/EMANS_HDD/data/dc20a_osse/raw/dc_ref/NATL60-CJM165_GULFSTREAM_*.nc",
        "var_name": "sossheig",
        "model_name": "natl60",
    }
)

# Model under study (saved SIREN predictions).
study = config_dict.ConfigDict(
    {
        "path": "/Users/eman/code_projects/logs/saved_data/test_res_nadir4_jz_v40.nc",
        "var_name": "ssh_model_predict",
        "model_name": "siren",
    }
)

# Comparison products, grouped in their own container (not attached to `config`).
miost = config_dict.ConfigDict(
    {
        "path": "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_MIOST_en_j1_tpn_g2.nc",
        "var_name": "gssh",
        "model_name": "miost",
    }
)
duacs = config_dict.ConfigDict(
    {
        "path": "/Volumes/EMANS_HDD/data/dc20a_osse/results/nadir4/2020a_SSH_mapping_NATL60_DUACS_en_j1_tpn_g2.nc",
        "var_name": "gssh",
        "model_name": "duacs",
    }
)
config_study = config_dict.ConfigDict()
config_study.miost = miost
config_study.duacs = duacs

# PSD settings.
psd = config_dict.ConfigDict(
    {
        "factor_time": [1, "D"],
        "factor_space": 1,
        "units": "degrees",
    }
)

# Figure output location; files are named after the studied model.
figure = config_dict.ConfigDict(
    {
        "save_path": Path(root).joinpath("figures/dc20a/"),
        "save_name": study.model_name,
    }
)

# Top-level config handed to `Evaluation`, attached in the original key order.
config = config_dict.ConfigDict()
config.reference = reference
config.study = study
config.psd = psd
config.figure = figure

In [None]:
from evaluation import Evaluation


# Build the evaluator from `config` and load the reference field.
eval_obj = Evaluation(config=config)
eval_obj.load_reference()

# With no explicit config, `add_model` presumably falls back to `config.study`
# (the SIREN run) — confirm in evaluation.py.
eval_obj.add_model()
# Register the MIOST and DUACS mapping products, in that order.
for baseline_config in (miost, duacs):
    eval_obj.add_model(config=baseline_config)

# Derived fields; presumably kinetic energy and relative vorticity — confirm.
eval_obj.add_ke()
eval_obj.add_rv()

# Display the assembled dataset (last expression of the cell).
eval_obj.ds_field

In [None]:
# Inspect the registered models alongside the reference model entry.
(eval_obj.models, eval_obj.model_ref)

## Map I - Space Time PSD

In [None]:
from evaluation import plot_psd_spacetime_all

# NOTE(review): disabled — `dict_psd` is not defined anywhere in this notebook,
# so this call cannot run as written; build the PSD dictionary first.
# plot_psd_spacetime_all(dict_psd, config)

## Map II - Space Time PSD Score

In [None]:
from evaluation import PSDSTScoreEval

# Wrap the evaluator in the space-time PSD scorer, then run its two
# preparation steps in order: coordinate standardisation, then scoring.
score_obj = PSDSTScoreEval(eval_obj=eval_obj)
for prepare_step in (score_obj.standardize_coords, score_obj.calculate_psd_score):
    prepare_step()

In [None]:
# Per-model PSD-score statistics for the "ssh" variable, followed by the
# combined summary. Model names come from the configuration cell
# (currently "siren", "miost", "duacs") instead of being re-typed here, so
# renaming a model in its config cannot silently desynchronise this cell.
# NOTE(review): only the last expression of a cell is displayed; if `stats`
# returns (rather than prints) its summary, the per-model results are
# discarded — wrap them in `display(...)` if they should be shown.
# score_obj.plot("ssh")
# score_obj.plot_all()
for model_name in (study.model_name, miost.model_name, duacs.model_name):
    score_obj.stats("ssh", model_name)
score_obj.stats_all()

In [None]:
# Display the score object (its notebook repr).
score_obj