---
title: Gradient Considerations
date: 2023-04-01
authors:
  - name: J. Emmanuel Johnson
    affiliations:
      - MEOM Lab
    roles:
      - Primary Programmer
    email: jemanjohnson34@gmail.com
license: CC-BY-4.0
keywords: NerFs, Images
---

In [1]:
# Optional manual path setup (kept commented out): put the oceanbench checkout on
# sys.path so `oceanbench` is importable. The path below is cluster-specific.
# import sys, os

# # set up the path to the project root
# oceanbench_root = "/gpfswork/rech/cli/uvo53rl/projects/oceanbench"

# # append to path
# sys.path.append(str(oceanbench_root))

In [2]:
import autoroot
import typing as tp
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from omegaconf import OmegaConf
import hydra
import metpy


# Plot styling: start from seaborn defaults, then use the "poster" context
# scaled down (font_scale=0.7) for the figures in this notebook.
sns.reset_defaults()
sns.set_context(context="poster", font_scale=0.7)

# Inline figures; autoreload re-imports edited modules (e.g. the local `utils`)
# before each cell runs, so helper edits do not need a kernel restart.
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Processing Chain

**Part I**:

* Open Dataset
* Validate Coordinates + Variables
* Decode Time
* Select Region
* Sort by Time

**Part II**: Regrid

**Part III**:

* Interpolate NaNs
* Add Units
* Spatial Rescale
* Time Rescale

**Part IV**: Metrics

* Isotropic power spectrum (PSD)
* Isotropic PSD score

## Data

In [3]:
# Download the NATL60 Gulf Stream reference SSH file (-nc: skip if it already exists).
# !wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc

In [4]:
# !cat configs/postprocess.yaml

In [5]:
# # load config
# config_dm = OmegaConf.load('./configs/postprocess.yaml')

# # instantiate
# ds = hydra.utils.instantiate(config_dm.NATL60_GF_1Y1D)
# ds

## Reference Dataset

For the reference dataset, we use the NATL60 (NEMO) simulation of the Gulf Stream.

In [6]:
%%writefile configs/natl60.yaml

# Sub-domain used to crop the simulation: lat / lon / time slices.
# Referenced below through "${domain}".
domain:
  lat: {_target_: "builtins.slice", _args_: [32., 44.]}
  lon: {_target_: "builtins.slice", _args_: [-66., -54.]}
  time: {_target_: "builtins.slice", _args_: ["2012-10-22", "2012-12-02"]}

# Partial `xarray.Dataset.sel` that crops a dataset to `domain`.
select:
    _target_: "xarray.Dataset.sel"
    _partial_: True
    indexers: "${domain}"

# NATL60 GULFSTREAM SIMULATION - REDUCED VERSION
# `fns` are presumably applied in order to `inp` by oceanbench's `pipe`.
# NOTE(review): `inp` is a cluster-specific absolute path.
NATL60_GF_1Y1D:
  _target_: "oceanbench._src.data.pipe"
  inp: "/gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/natl60/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"
  fns:
    # OPEN DATASET (raw times kept; they are decoded explicitly below)
    - {_target_: "xarray.open_dataset", decode_times: False, _partial_: True}
    # VALIDATE COORDINATES
    - {_target_: "oceanbench._src.geoprocessing.validation.validate_latlon", _partial_: True}
    - {_target_: "oceanbench._src.geoprocessing.validation.validate_time", _partial_: True}
    # DECODE TIME (raw values are "seconds since 2012-10-01")
    - {_target_: "oceanbench._src.geoprocessing.validation.decode_cf_time", units: "seconds since 2012-10-01", _partial_: true}
    # SELECT REGION
    - "${select}"
    # SORT BY TIME
    - {_target_: "xarray.Dataset.sortby", variables: "time", _partial_: True}

Overwriting configs/natl60.yaml


In [7]:
%%time

# Build the NATL60 reference pipeline from its YAML config, run it, and pull
# the result into memory with .compute().
config_dm = OmegaConf.load('./configs/natl60.yaml')
natl60_pipeline_cfg = config_dm.NATL60_GF_1Y1D
ds_natl60 = hydra.utils.instantiate(natl60_pipeline_cfg).compute()
ds_natl60

CPU times: user 209 ms, sys: 55 ms, total: 264 ms
Wall time: 265 ms


### Prediction Datasets - NADIR

In [8]:
# List the available method outputs in the (cluster-specific) DC20a OSSE results directory.
!ls /gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/results/
!ls /gpfswork/rech/yrf/commun/data_challenges/dc20a_osse/staging/results/4DVarNet

4DVarNet  DUACS   leaderboard.csv  NerF  results.csv	   results_nerf.csv
BFNQG	  DYMOST  MIOST		   OI	 results_demo.csv
2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadir_GF_GF.nc
2020a_SSH_mapping_NATL60_4DVarNet_v2022_nadirswot_GF_GF.nc


In [9]:
# load config
results_config = OmegaConf.load(f'./configs/results.yaml')

# instantiate
ds_duacs = hydra.utils.instantiate(results_config.DUACS_NADIR.data).compute()
ds_miost = hydra.utils.instantiate(results_config.MIOST_NADIR.data).compute()
ds_bfnqg = hydra.utils.instantiate(results_config.BFNQG_NADIR.data).compute()
ds_4dvarnet = hydra.utils.instantiate(results_config.FourDVARNET_NADIR.data).compute()
# ds_nerf_siren = hydra.utils.instantiate(results_config.NERF_SIREN_NADIR.data).compute()
# ds_nerf_ffn = hydra.utils.instantiate(results_config.NERF_FFN_NADIR.data).compute()
# ds_nerf_mlp = hydra.utils.instantiate(results_config.NERF_MLP_NADIR.data).compute()

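## Post-Processing Chain

Every dataset, including the NATL60 reference, goes through the same steps:

* Resample to daily means
* Regrid onto the NATL60 grid
* Fill NaNs
* Coordinate Change (lat/lon degrees → meters, time → days since 2012-10-22)
* PSD Metrics (computed on the post-processed datasets in the next section)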

In [10]:
from oceanbench._src.geoprocessing.gridding import grid_to_regular_grid, coord_based_to_grid
from oceanbench._src.geoprocessing.interpolate import fillnan_gauss_seidel
from oceanbench._src.geoprocessing import geostrophic as geocalc
from oceanbench._src.geoprocessing.spatial import latlon_deg2m
from oceanbench._src.geoprocessing.temporal import time_rescale
from metpy.units import units
import pint_xarray

def postprocess_fn(
    ds,
    ds_reference,
    *,
    variable: str = "ssh",
    resample_freq: str = "1D",
    t0: str = "2012-10-22",
    freq_dt: int = 1,
    freq_unit: str = "D",
):
    """Put a dataset onto the reference grid and rescale its coordinates.

    Steps:
      1. resample in time to ``resample_freq`` means (daily by default);
      2. regrid onto the grid of ``ds_reference`` (pint units stripped first);
      3. fill NaNs in ``variable`` with Gauss-Seidel interpolation;
      4. convert lat/lon from degrees to meters and rescale time relative
         to ``t0`` in steps of ``freq_dt`` ``freq_unit``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to post-process. It must have a ``time`` dimension and contain
        the variable named by ``variable``.
    ds_reference : xarray.Dataset
        Dataset whose grid is the regridding target (the NATL60 reference here).
    variable : str, default "ssh"
        Variable whose NaNs are filled.
    resample_freq : str, default "1D"
        Frequency passed to ``Dataset.resample(time=...)``.
    t0 : str, default "2012-10-22"
        Time origin passed to ``time_rescale``. It matches the start of the
        ``domain.time`` slice in ``configs/natl60.yaml``.
    freq_dt : int, default 1
        Time step size passed to ``time_rescale``.
    freq_unit : str, default "D"
        Time step unit passed to ``time_rescale``.

    Returns
    -------
    xarray.Dataset
        The resampled, regridded, gap-filled and rescaled dataset.
    """
    # resample so every method shares the reference time step
    ds = ds.resample(time=resample_freq).mean()

    # regrid onto the reference grid; strip any pint units from both sides first
    ds = grid_to_regular_grid(
        src_grid_ds=ds.pint.dequantify(),
        tgt_grid_ds=ds_reference.pint.dequantify(),
        keep_attrs=False,
    )

    # fill NaNs left by regridding / missing data
    ds = fillnan_gauss_seidel(ds, variable=variable)

    # coordinate change: lat/lon degrees -> meters, time -> steps since t0
    ds = latlon_deg2m(ds, mean=True)
    ds = time_rescale(ds, t0=t0, freq_dt=freq_dt, freq_unit=freq_unit)

    return ds

In [11]:
# Post-process every dataset onto the NATL60 grid (the reference is regridded onto itself).
# NOTE: the method datasets are overwritten under their own names, so re-running this
# cell without re-running the loading cell would post-process them twice.
ds_natl60_ = postprocess_fn(ds_natl60, ds_natl60)
ds_duacs = postprocess_fn(ds_duacs, ds_natl60)
ds_miost = postprocess_fn(ds_miost, ds_natl60)
ds_bfnqg = postprocess_fn(ds_bfnqg, ds_natl60)
ds_4dvarnet = postprocess_fn(ds_4dvarnet, ds_natl60)

## Power Spectrum (Isotropic)

In [18]:
%%writefile configs/metrics.yaml

# ---------------------------------------------------------------------------
# Preprocessing steps (partials, combined in `psd_preprocess_chain` below)
# ---------------------------------------------------------------------------

# Fill NaNs in `ssh` with Gauss-Seidel interpolation.
fill_nans:
    _target_: "oceanbench._src.geoprocessing.interpolate.fillnan_gauss_seidel"
    _partial_: True
    variable: "ssh"

# Convert lat/lon coordinates from degrees to meters.
spatial_rescale:
    _target_: "oceanbench._src.geoprocessing.spatial.latlon_deg2m"
    _partial_: True
    mean: True

# Rescale time relative to t0 = 2012-10-22 in steps of 1 day.
temporal_rescale:
    _target_: "oceanbench._src.geoprocessing.temporal.time_rescale"
    _partial_: True
    t0: "2012-10-22"
    freq_dt: 1
    freq_unit: "D"

# ---------------------------------------------------------------------------
# Power-spectrum metrics. All share the same spectral settings (constant
# detrend, Tukey window, nfactor=2, window correction, true amplitude,
# truncation); only the dimensions differ.
# ---------------------------------------------------------------------------

# Isotropic PSD of `ssh` over the (lon, lat) plane.
psd_isotropic:
    _target_: "oceanbench._src.metrics.power_spectrum.psd_isotropic"
    _partial_: True
    variable: "ssh"
    dims: ["lon", "lat"]
    detrend: "constant"
    window: "tukey"
    nfactor: 2
    window_correction: True
    true_amplitude: True
    truncate: True

# Isotropic PSD score over (lon, lat), averaged over time.
# In the notebook this is called as fn(prediction, reference).
psd_isotropic_score:
    _target_: "oceanbench._src.metrics.power_spectrum.psd_isotropic_score"
    _partial_: True
    variable: "ssh"
    psd_dims: ["lon", "lat"]
    avg_dims: ["time"]
    detrend: "constant"
    window: "tukey"
    nfactor: 2
    window_correction: True
    true_amplitude: True
    truncate: True

# Space-time PSD score over (time, lon), averaged over latitude.
psd_spacetime_score:
    _target_: "oceanbench._src.metrics.power_spectrum.psd_spacetime_score"
    _partial_: True
    variable: "ssh"
    psd_dims: ["time", "lon"]
    avg_dims: ["lat"]
    detrend: "constant"
    window: "tukey"
    nfactor: 2
    window_correction: True
    true_amplitude: True
    truncate: True

# Space-time PSD of `ssh` over (time, lon).
psd_spacetime:
    _target_: "oceanbench._src.metrics.power_spectrum.psd_spacetime"
    _partial_: True
    variable: "ssh"
    dims: ["time", "lon"]
    detrend: "constant"
    window: "tukey"
    nfactor: 2
    window_correction: True
    true_amplitude: True
    truncate: True

# Average an isotropic PSD over time (xr_cond_average, drop=True).
psd_isotropic_avg:
    _target_: "oceanbench._src.preprocessing.mean.xr_cond_average"
    _partial_: True
    dims: ["time"]
    drop: True

# Average a space-time PSD over latitude (xr_cond_average, drop=True).
psd_spacetime_avg:
    _target_: "oceanbench._src.preprocessing.mean.xr_cond_average"
    _partial_: True
    dims: ["lat"]
    drop: True

    
# ---------------------------------------------------------------------------
# Pipelines: oceanbench `pipe` partials; `fns` are presumably applied in order.
# ---------------------------------------------------------------------------

# Preprocessing chain: NaN filling + coordinate rescaling.
psd_preprocess_chain:
    _target_: "oceanbench._src.data.pipe"
    _partial_: true
    fns:
        - "${fill_nans}" # FILL NANs
        - "${spatial_rescale}" # RESCALE LATLON DEGREEs -> METERS
        - "${temporal_rescale}" # RESCALE TIME -> DAYS


# Isotropic PSD, then average over time.
psd_isotropic_chain:
    _target_: "oceanbench._src.data.pipe"
    _partial_: true
    fns:
        - "${psd_isotropic}" # ISOTROPIC POWER SPECTRUM
        - "${psd_isotropic_avg}" # AVERAGE OVER TIME DIMENSION

# Space-time PSD, then average over latitude.
psd_spacetime_chain:
    _target_: "oceanbench._src.data.pipe"
    _partial_: true
    fns:
        - "${psd_spacetime}" # SPACE-TIME POWER SPECTRUM
        - "${psd_spacetime_avg}" # AVERAGE OVER LATITUDE DIMENSION

Overwriting configs/metrics.yaml


In [33]:
# load config
metrics_config = OmegaConf.load('./configs/metrics.yaml')

ds_natl60_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_chain)(ds_natl60_)
ds_duacs_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_chain)(ds_duacs)
ds_miost_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_chain)(ds_miost)
ds_bfnqg_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_chain)(ds_bfnqg)
ds_4dvarnet_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_chain)(ds_4dvarnet)


In [34]:
from utils import PlotPSDIsotropic, PlotPSDScoreIsotropic

In [35]:
from pathlib import Path

# (PSD dataset, legend label, colour) for every curve; NATL60 is the reference.
psd_iso_curves = [
    (ds_natl60_psd, "NATL60", "black"),
    (ds_duacs_psd, "DUACS", "tab:green"),
    (ds_miost_psd, "MIOST", "tab:red"),
    (ds_bfnqg_psd, "BFN-QG", "tab:blue"),
    (ds_4dvarnet_psd, "4DVarNet", "tab:olive"),
    # (ds_nerf_mlp_psd, "NERF (MLP)", "tab:cyan"),
]

psd_iso_plot = PlotPSDIsotropic()
psd_iso_plot.init_fig(figsize=(8, 7))
for ds_psd, label, color in psd_iso_curves:
    # spatial coords are in meters after `latlon_deg2m`; freq_scale=1e3 with
    # units="km" presumably rescales the axis to kilometres
    psd_iso_plot.plot_both(
        ds_psd.ssh,
        freq_scale=1e3,
        units="km",
        label=label,
        color=color,
    )

# set custom bounds
psd_iso_plot.ax.set_xlim((10**(-3) - 0.00025, 10**(-1) + 0.025))
psd_iso_plot.ax.set_ylabel("PSD [SSH]")

fig = plt.gcf()
fig.tight_layout()
out_path = Path("./figures/psd_iso/dc20a_psd_iso_ssh.png")
# savefig does not create missing directories
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.close(fig)

### PSD Isotropic Score

In [30]:
# load config
metrics_config = OmegaConf.load('./configs/metrics.yaml')

ds_duacs_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_score)(ds_duacs, ds_natl60_)
ds_miost_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_score)(ds_miost, ds_natl60_)
ds_bfnqg_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_score)(ds_bfnqg, ds_natl60_)
ds_4dvarnet_psd = hydra.utils.instantiate(metrics_config.psd_isotropic_score)(ds_4dvarnet, ds_natl60_)


In [31]:
from oceanbench._src.metrics.utils import find_intercept_1D, find_intercept_2D

In [32]:
from pathlib import Path

# Score level at which the resolved-scale threshold is drawn for every method.
psd_score_threshold = 0.50

# (score dataset, name, line colour, threshold colour) for every method.
psd_score_curves = [
    (ds_duacs_psd, "DUACS", "green", "tab:green"),
    (ds_miost_psd, "MIOST", "red", "tab:red"),
    (ds_bfnqg_psd, "BFN-QG", "blue", "tab:blue"),
    (ds_4dvarnet_psd, "4DVarNet", "olive", "tab:olive"),
]

psd_iso_plot = PlotPSDScoreIsotropic()
psd_iso_plot.init_fig(figsize=(8, 7))
for ds_score, name, color, threshold_color in psd_score_curves:
    psd_iso_plot.plot_score(
        ds_score.ssh,
        freq_scale=1e3,
        units="km",
        name=name,
        color=color,
        # keyword spelling ("threshhold") is defined by utils.PlotPSDScoreIsotropic
        threshhold=psd_score_threshold,
        threshhold_color=threshold_color,
    )

# set custom bounds
psd_iso_plot.ax.set_xlim((10**(-3) - 0.00025, 10**(-1) + 0.025))
psd_iso_plot.ax.set_ylabel("PSD Score [SSH]")
# Build the legend on the main axes explicitly: `plt.legend()` acts on the
# *current* axes, which reported "No artists with labels found".
# NOTE(review): if this still warns, `plot_score` does not label its lines.
psd_iso_plot.ax.legend()

fig = plt.gcf()
fig.tight_layout()
# plt.gcf().savefig("./figures/dc20a/psd_score/isotropic/dc20a_psd_isotropic_score_nadir.png")
out_path = Path("./figures/psd_iso/dc20a_psd_score_iso_ssh.png")
# savefig does not create missing directories
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
plt.close(fig)

No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
