---
title: Gradient Considerations
date: 2023-04-01
authors:
  - name: J. Emmanuel Johnson
    affiliations:
      - MEOM Lab
    roles:
      - Primary Programmer
    email: jemanjohnson34@gmail.com
license: CC-BY-4.0
keywords: NerFs, Images
---

In [1]:
# import sys, os

# # spyder up to find the root
# oceanbench_root = "/gpfswork/rech/cli/uvo53rl/projects/oceanbench"

# # append to path
# sys.path.append(str(oceanbench_root))

In [2]:
import autoroot
import typing as tp
from pathlib import Path

import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from omegaconf import OmegaConf
import hydra
import metpy


sns.reset_defaults()
sns.set_context(context="poster", font_scale=0.7)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Processing Chain

**Part I**:

* Open Dataset
* Validate Coordinates + Variables
* Decode Time
* Select Region
* Sortby Time

**Part II**: Regrid

**Part III**:

* Interpolate Nans
* Add Units
* Spatial Rescale
* Time Rescale

**Part IV**: Metrics

* Normalized RMSE
* RMSE

## Data

In [3]:
# !wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc

In [4]:
# !cat configs/postprocess.yaml

In [5]:
# # load config
# config_dm = OmegaConf.load('./configs/postprocess.yaml')

# # instantiate
# ds = hydra.utils.instantiate(config_dm.NATL60_GF_1Y1D)
# ds

## Reference Dataset

As the reference dataset, we use the NATL60 NEMO simulation of the Gulf Stream.

In [6]:
%%writefile configs/natl60.yaml

# Region of interest: each entry is instantiated by hydra as a Python `slice`
# (lat/lon bounds presumably in degrees; time bounds as date strings).
domain:
  lat: {_target_: "builtins.slice", _args_: [32., 44.]}
  lon: {_target_: "builtins.slice", _args_: [-66., -54.]}
  time: {_target_: "builtins.slice", _args_: ["2012-10-22", "2012-12-02"]}

# Partial `xarray.Dataset.sel` over the `domain` slices; reused as a pipeline step below.
select:
    _target_: "xarray.Dataset.sel"
    _partial_: True
    indexers: "${domain}"

# NATL60 GULFSTREAM SIMULATION - REDUCED VERSION
# Steps listed in `fns`: open -> validate coords -> decode time -> select region -> sort by time.
NATL60_GF_1Y1D:
  _target_: "oceanbench._src.data.pipe"
  # NOTE(review): absolute cluster path; not portable — consider making it configurable.
  inp: "/gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/natl60/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"
  fns:
    # OPEN DATASET (time decoding deferred to the explicit step below)
    - {_target_: "xarray.open_dataset", decode_times: False, _partial_: True}
    # VALIDATE COORDINATES
    - {_target_: "oceanbench._src.geoprocessing.validation.validate_latlon", _partial_: True}
    - {_target_: "oceanbench._src.geoprocessing.validation.validate_time", _partial_: True}
    # DECODE TIME (raw values interpreted with the `units` given here)
    - {_target_: "oceanbench._src.geoprocessing.validation.decode_cf_time", units: "seconds since 2012-10-01", _partial_: true}
    # SELECT REGION
    - "${select}"
    # SORT BY TIME
    - {_target_: "xarray.Dataset.sortby", variables: "time", _partial_: True}

Overwriting configs/natl60.yaml


In [7]:
%%time

# load config
config_dm = OmegaConf.load('./configs/natl60.yaml')

# instantiate
ds_natl60 = hydra.utils.instantiate(config_dm.NATL60_GF_1Y1D).compute()
ds_natl60

CPU times: user 221 ms, sys: 54.1 ms, total: 275 ms
Wall time: 367 ms


### Prediction Datasets - NADIR

In [8]:
# List the available mapping results, then the 4DVarNet result files.
# NOTE(review): absolute cluster path; not portable outside this machine.
!ls /gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/results/
!ls /gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/results/4DVarNet

4DVarNet  DUACS   leaderboard.csv  NerF  results.csv	   results_nerf.csv
BFNQG	  DYMOST  MIOST		   OI	 results_demo.csv
2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc


In [9]:
# Load the NADIR mapping results described in configs/results.yaml;
# each entry's `.data` node is instantiated by hydra and loaded eagerly.
results_config = OmegaConf.load("./configs/results.yaml")

ds_duacs = hydra.utils.instantiate(results_config.DUACS_NADIR.data).compute()
ds_miost = hydra.utils.instantiate(results_config.MIOST_NADIR.data).compute()
ds_bfnqg = hydra.utils.instantiate(results_config.BFNQG_NADIR.data).compute()
ds_4dvarnet = hydra.utils.instantiate(results_config.FourDVARNET_NADIR.data).compute()

# NeRF results (SIREN / FFN / MLP) are currently disabled:
# ds_nerf_siren = hydra.utils.instantiate(results_config.NERF_SIREN_NADIR.data).compute()
# ds_nerf_ffn = hydra.utils.instantiate(results_config.NERF_FFN_NADIR.data).compute()
# ds_nerf_mlp = hydra.utils.instantiate(results_config.NERF_MLP_NADIR.data).compute()

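## Post-Processing Chain

* Coordinate change: resample to daily means, regrid onto the NATL60 reference grid, and fill remaining NaNs (Gauss–Seidel)
* PSD Metrics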

In [10]:
from oceanbench._src.geoprocessing.gridding import grid_to_regular_grid, coord_based_to_grid
from oceanbench._src.geoprocessing.interpolate import fillnan_gauss_seidel
from oceanbench._src.geoprocessing import geostrophic as geocalc
from oceanbench._src.geoprocessing.spatial import latlon_deg2m
from oceanbench._src.geoprocessing.temporal import time_rescale
from metpy.units import units
import pint_xarray

def postprocess_fn(ds, ds_reference, variable: str = "ssh", freq: str = "1D"):
    """Bring a dataset onto the reference grid so it can be compared point-wise.

    Steps:
      1. Resample along ``time`` to ``freq`` means (daily by default).
      2. Regrid onto the grid of ``ds_reference`` with ``grid_to_regular_grid``;
         pint units are stripped from both datasets first and attributes are dropped.
      3. Fill remaining NaNs in ``variable`` with ``fillnan_gauss_seidel``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to post-process; must have a ``time`` dimension.
    ds_reference : xarray.Dataset
        Dataset whose grid is the regridding target.
    variable : str, default "ssh"
        Name of the variable whose NaNs are filled.
    freq : str, default "1D"
        Resampling frequency passed to ``Dataset.resample``.

    Returns
    -------
    xarray.Dataset
        The resampled, regridded and gap-filled dataset.
    """
    # temporal mean at the requested frequency
    ds = ds.resample(time=freq).mean()

    # regrid onto the reference grid; dequantify so the regridder sees plain arrays
    ds = grid_to_regular_grid(
        src_grid_ds=ds.pint.dequantify(),
        tgt_grid_ds=ds_reference.pint.dequantify(),
        keep_attrs=False,
    )

    # fill gaps left by regridding / missing observations
    ds = fillnan_gauss_seidel(ds, variable=variable)

    return ds

In [11]:
# Run every dataset (the reference included) through the same post-processing,
# so all of them end up on the NATL60 grid. `pipe(f, **kw)` calls `f(ds, **kw)`.
ds_natl60_ = ds_natl60.pipe(postprocess_fn, ds_reference=ds_natl60)
ds_duacs = ds_duacs.pipe(postprocess_fn, ds_reference=ds_natl60)
ds_miost = ds_miost.pipe(postprocess_fn, ds_reference=ds_natl60)
ds_bfnqg = ds_bfnqg.pipe(postprocess_fn, ds_reference=ds_natl60)
ds_4dvarnet = ds_4dvarnet.pipe(postprocess_fn, ds_reference=ds_natl60)

## Normalized RMSE

In [12]:
from oceanbench._src.metrics.stats import nrmse_da, rmse_da
from matplotlib import ticker

In [13]:
# NRMSE of each method against the post-processed NATL60 reference,
# reduced over the spatial dims (presumably leaving one value per time step).
ds_duacs_nrmse = nrmse_da(
    ds_duacs, ds_natl60_, "ssh", dim=["lat", "lon"]
)
ds_miost_nrmse = nrmse_da(
    ds_miost, ds_natl60_, "ssh", dim=["lat", "lon"]
)
ds_bfnqg_nrmse = nrmse_da(
    ds_bfnqg, ds_natl60_, "ssh", dim=["lat", "lon"]
)
ds_4dvarnet_nrmse = nrmse_da(
    ds_4dvarnet, ds_natl60_, "ssh", dim=["lat", "lon"]
)

In [20]:
fig, ax = plt.subplots(figsize=(7, 6))

# one spatially-reduced NRMSE curve per method
ds_duacs_nrmse.plot(ax=ax, label="DUACS", color="tab:green")
ds_miost_nrmse.plot(ax=ax, label="MIOST", color="tab:red")
ds_bfnqg_nrmse.plot(ax=ax, label="BFNQG", color="tab:blue")
ds_4dvarnet_nrmse.plot(ax=ax, label="4DVarNet", color="tab:olive")

ax.set(
    ylim=[0.80, 1.0],
    ylabel="Normalized RMSE",
    xlabel="Date",
)
ax.grid(True, which="both", axis="both", alpha=0.5)

# Major ticks every 0.05 need two decimals: with "{x:.1f}" neighbouring ticks
# (e.g. 0.80 and 0.85) were rounded to the same label.
ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.2f}"))
ax.yaxis.set_major_locator(ticker.MultipleLocator(0.05))
ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.01))

# add the legend before laying out, and attach it to this axes explicitly
ax.legend()
fig.tight_layout()

# make sure the output directory exists on a fresh checkout
nrmse_fig_path = Path("./figures/stats/nrmse_space.png")
nrmse_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(nrmse_fig_path)
plt.close(fig)

## RMSE

In [15]:
# RMSE of each method against the post-processed NATL60 reference,
# reduced over the spatial dims (presumably leaving one value per time step).
ds_duacs_rmse = rmse_da(
    ds_duacs, ds_natl60_, "ssh", dim=["lat", "lon"]
)
ds_miost_rmse = rmse_da(
    ds_miost, ds_natl60_, "ssh", dim=["lat", "lon"]
)
ds_bfnqg_rmse = rmse_da(
    ds_bfnqg, ds_natl60_, "ssh", dim=["lat", "lon"]
)
ds_4dvarnet_rmse = rmse_da(
    ds_4dvarnet, ds_natl60_, "ssh", dim=["lat", "lon"]
)

In [16]:
fig, ax = plt.subplots(figsize=(9, 6))

# scale by 100 so the curves match the "[cm]" axis label
# (assumes ssh is stored in metres — TODO confirm against the source data)
(100 * ds_duacs_rmse).plot(ax=ax, label="DUACS", color="tab:green")
(100 * ds_miost_rmse).plot(ax=ax, label="MIOST", color="tab:red")
(100 * ds_bfnqg_rmse).plot(ax=ax, label="BFNQG", color="tab:blue")
(100 * ds_4dvarnet_rmse).plot(ax=ax, label="4DVarNet", color="tab:olive")

ax.set(
    ylabel="RMSE [cm]",
    xlabel="Date",
)
ax.grid(True, which="both", axis="both", alpha=0.5)

# one decimal is enough for major ticks every 0.5 cm
ax.yaxis.set_major_formatter(ticker.StrMethodFormatter("{x:.1f}"))
ax.yaxis.set_major_locator(ticker.MultipleLocator(0.5))

# add the legend before laying out, and attach it to this axes explicitly
ax.legend()
fig.tight_layout()

# make sure the output directory exists on a fresh checkout
rmse_fig_path = Path("./figures/stats/rmse_space.png")
rmse_fig_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(rmse_fig_path)
plt.close(fig)