---
title: Gradient Considerations
date: 2023-04-01
authors:
  - name: J. Emmanuel Johnson
    affiliations:
      - MEOM Lab
    roles:
      - Primary Programmer
    email: jemanjohnson34@gmail.com
license: CC-BY-4.0
keywords: NerFs, Images
---

> In this notebook, we take a look at some of the derived quantities for sea surface height (SSH). These are physically meaningful quantities like velocity and vorticity. From these, we can visually assess how well our reconstruction methods perform.

In [None]:
import autoroot
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jrandom
import numpy as np
import numba as nb
import equinox as eqx
import kernex as kex
import finitediffx as fdx
import diffrax as dfx
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm, trange
from jaxtyping import Float, Array, PyTree, ArrayLike
import wandb
from omegaconf import OmegaConf
import hydra
import metpy
from sklearn.pipeline import Pipeline
from jejeqx._src.transforms.dataframe.spatial import Spherical2Cartesian
from jejeqx._src.transforms.dataframe.temporal import TimeDelta
from jejeqx._src.transforms.dataframe.scaling import MinMaxDF


# Plotting defaults: reset seaborn, then use the "talk" context at 70% font scale.
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
# Keep JAX in 32-bit precision (64-bit explicitly disabled).
jax.config.update("jax_enable_x64", False)

# IPython magics: inline figures and automatic reloading of edited modules.
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Recap Formulation

We are interested in learning non-linear functions $\boldsymbol{f}$.

$$
\begin{aligned}
\boldsymbol{f}(\mathbf{x}) &=
\mathbf{w}^\top\boldsymbol{\phi}(\mathbf{x})+\mathbf{b}
\end{aligned}
$$

where $\boldsymbol{\phi}(\cdot)$ is a basis function. Neural fields typically learn this basis function as a composition of $L$ functions of the form

$$
\boldsymbol{\phi}(\mathbf{x}) =
\boldsymbol{\phi}_L\circ\boldsymbol{\phi}_{L-1}
\circ\cdots\circ
\boldsymbol{\phi}_2\circ\boldsymbol{\phi}_{1}(\mathbf{x})
$$

## Problems

Here, we will demonstrate a problem that a naive network has.

## Data

In [None]:
# Download the NATL60 Gulf Stream SSH file (uncomment and run once):
# !wget -nc https://s3.us-east-1.wasabisys.com/melody/osse_data/ref/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc

In [None]:
from pathlib import Path

In [None]:
# Sanity check: is the NATL60 SSH file present at this (cluster-specific) path?
# The same directory is hard-coded again in ``SSHDM.paths``.
Path(
    "/gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/natl60/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"
).is_file()

In [None]:
from dataclasses import dataclass, field
from typing import List, Dict


@dataclass
class Subset:
    """Hydra-instantiable config that builds a ``slice`` object."""

    # Hydra target: instantiating this config calls ``builtins.slice``.
    _target_: str = "builtins.slice"
    # Positional arguments for ``slice(start, stop)``; by default start == stop.
    # NOTE(review): not referenced in the visible notebook — ``select`` below
    # builds its time slice directly.
    _args_: List = field(default_factory=lambda: ["2013-01-01", "2013-01-01"])


@dataclass
class SSHDM:
    """Structured Hydra/OmegaConf config for the SSH coordinate datamodule.

    Instantiated below via ``hydra.utils.instantiate``; extra arguments
    (``select``, spatial/temporal transforms) are supplied at that call site.
    """

    # Class that ``hydra.utils.instantiate`` constructs from this config.
    _target_: str = "jejeqx._src.datamodules.coords.AlongTrackDM"
    batch_size: int = 10_000
    shuffle: bool = False
    # Presumably the train fraction of the split — confirm in AlongTrackDM.
    train_size: float = 0.80
    # Presumably the fraction of points subsampled — confirm in AlongTrackDM.
    subset_size: float = 0.40
    # Presumably forwarded to xarray when opening files (times left undecoded).
    decode_times: bool = False
    spatial_coords: List = field(default_factory=lambda: ["lat", "lon"])
    temporal_coords: List = field(default_factory=lambda: ["time"])
    variables: List = field(default_factory=lambda: ["ssh"])
    # Glob over the yearly NATL60 files; cluster-specific absolute path.
    paths: str = "/gpfswork/rech/cli/uvo53rl/projects/jejeqx/data/natl60/NATL60-CJM165_GULFSTREAM_ssh_y2013*"


# Spatial transform: Spherical2Cartesian (radius 1.0, inputs in degrees) produces
# x/y/z columns, which MinMaxDF then rescales with bounds (-1, 1).
spatial_transforms = Pipeline(
    steps=[
        ("cartesian3d", Spherical2Cartesian(radius=1.0, units="degrees")),
        ("spatialminmax", MinMaxDF(["x", "y", "z"], -1, 1)),
    ]
)

# Temporal transform: TimeDelta presumably expresses time as an offset from
# 2012-10-01 in 1-second units; MinMaxDF then rescales it with bounds (-1, 1).
temporal_transforms = Pipeline(
    steps=[
        ("timedelta", TimeDelta("2012-10-01", 1, "s")),
        ("timeminmax", MinMaxDF(["time"], -1, 1)),
    ]
)

In [None]:
# Restrict the data to the first five months of 2013.
select = {"time": slice("2013-01-01", "2013-06-01")}

# Turn the dataclass schema into an OmegaConf config for Hydra.
config_dm = OmegaConf.structured(SSHDM())

# Build the datamodule, passing the selection and coordinate transforms that
# are not part of the SSHDM schema.
dm = hydra.utils.instantiate(
    config_dm,
    select=select,
    spatial_transform=spatial_transforms,
    temporal_transform=temporal_transforms,
)

dm.setup()


# Peek at the first 32 training samples to check the transformed coordinate ranges.
init = dm.ds_train[:32]
x_init, t_init, y_init = init["spatial"], init["temporal"], init["data"]
x_init.min(), x_init.max(), x_init.shape, t_init.min(), t_init.max(), t_init.shape

In [None]:
# Load the datamodule's data as an xarray object (``.ssh`` is used below).
xrda = dm.load_xrds()
xrda

In [None]:
import jejeqx._src.transforms.xarray.geostrophic as geocalc
import jejeqx._src.viz.geostrophic as geoplot
from jejeqx._src.viz.utils import get_cbar_label

In [None]:
def calculate_physical_quantities(da):
    """Derive physically meaningful diagnostics from a sea surface height field.

    Parameters
    ----------
    da : xarray.DataArray
        Sea surface height, handed to ``geocalc.get_ssh_dataset``.

    Returns
    -------
    xarray.Dataset
        The SSH dataset augmented, in order, with the stream function ``psi``,
        velocities ``u``/``v``, kinetic energy, normalized relative vorticity
        and normalized strain magnitude.
    """
    # Each step receives the running Dataset and returns it with new variables;
    # later steps depend on variables added by earlier ones.
    pipeline = (
        lambda ds: geocalc.calculate_streamfunction(ds, "ssh"),
        lambda ds: geocalc.calculate_velocities_sf(ds, "psi"),
        lambda ds: geocalc.calculate_kinetic_energy(ds, ["u", "v"]),
        lambda ds: geocalc.calculate_relative_vorticity_uv(
            ds, ["u", "v"], normalized=True
        ),
        lambda ds: geocalc.calculate_strain_magnitude(
            ds, ["u", "v"], normalized=True
        ),
    )

    result = geocalc.get_ssh_dataset(da)
    for step in pipeline:
        result = step(result)
    return result

In [None]:
# Derive SSH, velocities, kinetic energy, vorticity and strain for the NATL60 data.
ds_natl60 = calculate_physical_quantities(xrda.ssh)

In [None]:
import cmocean as cmo


def plot_analysis_vars(ds):
    """Map SSH and its derived diagnostics on a 2x3 grid of pcolormesh panels.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``ssh``, ``u``, ``v``, ``ke``, ``vort_r`` and ``strain``
        (e.g. the output of ``calculate_physical_quantities``), each reduced
        to a 2D field — e.g. a single time step via ``.isel(time=0)``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : numpy.ndarray of Axes, shape (2, 3)
    """
    # (variable, subplot position, colormap) for each panel.
    panels = (
        ("ssh", (0, 0), "viridis"),
        ("u", (0, 1), "gray"),
        ("v", (0, 2), "gray"),
        ("ke", (1, 0), "YlGnBu_r"),
        ("vort_r", (1, 1), "RdBu_r"),
        ("strain", (1, 2), cmo.cm.speed),
    )

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(14, 7))
    for name, pos, cmap in panels:
        da = ds[name]
        da.plot.pcolormesh(
            ax=ax[pos], cmap=cmap, cbar_kwargs={"label": get_cbar_label(da)}
        )

    # A single layout pass once every panel and colorbar exists; target this
    # figure explicitly rather than whatever figure happens to be current.
    fig.tight_layout()
    return fig, ax

In [None]:
# Plot the diagnostics for the first time step only.
fig, ax = plot_analysis_vars(ds_natl60.isel(time=0))
plt.show()

In [None]:
from jejeqx._src.transforms.xarray.grid import latlon_deg2m, time_rescale

# Preview the PSD preprocessing: convert lat/lon with latlon_deg2m and rescale
# time to 1-day units. NOTE(review): ``ds_psd_natl60`` is overwritten by the
# PSD cells below; ``calculate_isotropic_psd`` repeats these two steps itself.
ds_psd_natl60 = latlon_deg2m(ds_natl60, mean=True)
ds_psd_natl60 = time_rescale(ds_psd_natl60, 1, "D")
ds_psd_natl60

## Isotropic PSD

In [None]:
from jejeqx._src.transforms.xarray.psd import (
    psd_spacetime,
    psd_isotropic,
    psd_average_freq,
)


def calculate_isotropic_psd(
    ds,
    freq_dt=1,
    freq_unit="D",
    variables=("ssh", "u", "v", "ke", "vort_r", "strain"),
):
    """Compute isotropic (lat/lon) PSDs for several variables of a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``lat``/``lon``/``time`` coordinates containing every
        name in ``variables``.
    freq_dt, freq_unit
        Passed to ``time_rescale`` (default: 1-day units).
    variables : sequence of str, optional
        Data variables to transform; defaults to SSH and its diagnostics.

    Returns
    -------
    xarray.Dataset
        One variable per input name: ``psd_isotropic`` over ``["lat", "lon"]``
        followed by ``psd_average_freq``.
    """
    ds = latlon_deg2m(ds, mean=True)
    ds = time_rescale(ds, freq_dt, freq_unit)

    # Isotropic spectrum over the two spatial dims, then psd_average_freq.
    ds_psd = xr.Dataset()
    for name in variables:
        ds_psd[name] = psd_average_freq(psd_isotropic(ds[name], ["lat", "lon"]))

    return ds_psd

In [None]:
# Isotropic PSDs of SSH and all derived diagnostics (1-day time units).
ds_psd_natl60 = calculate_isotropic_psd(ds_natl60)

In [None]:
# Inspect the SSH isotropic PSD.
ds_psd_natl60.ssh

In [None]:
import cmocean as cmo
from jejeqx._src.viz.xarray.psd import plot_psd_isotropic, plot_psd_spacetime_wavenumber


def plot_analysis_psd_iso(ds):
    """Plot isotropic PSDs of SSH and its diagnostics on a 2x3 grid.

    Parameters
    ----------
    ds : xarray.Dataset
        Isotropic PSD dataset containing ``ssh``, ``u``, ``v``, ``ke``,
        ``vort_r`` and ``strain`` (e.g. from ``calculate_isotropic_psd``).

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : numpy.ndarray, shape (2, 3)
        Each entry is whatever ``plot_psd_isotropic`` returns for its panel.
    """
    scale = "km"
    # (variable, subplot position, label passed as ``units``).
    panels = (
        ("ssh", (0, 0), "$m^{2}$/cycles/m"),
        ("u", (0, 1), "U-Velocity"),
        ("v", (0, 2), "V-Velocity"),
        ("ke", (1, 0), "Kinetic Energy"),
        ("vort_r", (1, 1), "Relative Vorticity"),
        ("strain", (1, 2), "Strain"),
    )

    fig, ax = plt.subplots(ncols=3, nrows=2, figsize=(14, 7))
    for name, pos, units in panels:
        ax[pos] = plot_psd_isotropic(ds[name], units=units, scale=scale, ax=ax[pos])

    fig.tight_layout()
    return fig, ax

In [None]:
# Plot the isotropic PSD panels computed above.
fig, ax = plot_analysis_psd_iso(ds_psd_natl60)
plt.show()

In [None]:
from jejeqx._src.transforms.xarray.psd import (
    psd_spacetime,
    psd_isotropic,
    psd_average_freq,
)


def calculate_spacetime_psd(
    ds,
    freq_dt=1,
    freq_unit="D",
    variables=("ssh", "u", "v", "ke", "vort_r", "strain"),
):
    """Compute space-time (time/lon) PSDs for several variables of a dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with ``lat``/``lon``/``time`` coordinates containing every
        name in ``variables``.
    freq_dt, freq_unit
        Passed to ``time_rescale`` (default: 1-day units).
    variables : sequence of str, optional
        Data variables to transform; defaults to SSH and its diagnostics.

    Returns
    -------
    xarray.Dataset
        One variable per input name: ``psd_spacetime`` over ``["time", "lon"]``
        followed by ``psd_average_freq``.
    """
    ds = latlon_deg2m(ds, mean=True)
    ds = time_rescale(ds, freq_dt, freq_unit)

    # Space-time spectrum over (time, lon), then psd_average_freq.
    ds_psd = xr.Dataset()
    for name in variables:
        ds_psd[name] = psd_average_freq(psd_spacetime(ds[name], ["time", "lon"]))

    return ds_psd

In [None]:
# Space-time PSDs of SSH and all derived diagnostics (replaces the isotropic result).
ds_psd_natl60 = calculate_spacetime_psd(ds_natl60)
ds_psd_natl60

In [None]:
# NOTE(review): identical to the previous cell; re-running it recomputes the same result.
ds_psd_natl60 = calculate_spacetime_psd(ds_natl60)
ds_psd_natl60

In [None]:
import cmocean as cmo
from jejeqx._src.viz.xarray.psd import plot_psd_isotropic, plot_psd_spacetime_wavenumber


def plot_analysis_psd_spacetime(ds):
    """Plot space-time (wavelength/period) PSDs of SSH and its diagnostics.

    Parameters
    ----------
    ds : xarray.Dataset
        Space-time PSD dataset containing ``ssh``, ``u``, ``v``, ``ke``,
        ``vort_r`` and ``strain`` (e.g. from ``calculate_spacetime_psd``).

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : numpy.ndarray of Axes, shape (2, 3)
    """
    # Imported here because the notebook's module-level import of this helper
    # only happens in a later cell; without it, calling this function in cell
    # order raises NameError.
    from jejeqx._src.viz.xarray.psd import plot_psd_spacetime_wavelength

    scale = "km"
    # (variable, subplot position, label passed as ``psd_units``).
    panels = (
        ("ssh", (0, 0), "SSH"),
        ("u", (0, 1), "U-Velocity"),
        ("v", (0, 2), "V-Velocity"),
        ("ke", (1, 0), "Kinetic Energy"),
        ("vort_r", (1, 1), "Relative Vorticity"),
        ("strain", (1, 2), "Strain"),
    )

    fig, ax = plt.subplots(ncols=3, nrows=2, figsize=(14, 7))
    for name, pos, units in panels:
        # The helper returns (fig, ax, cbar); keep only the axes so ``ax``
        # remains an array of Axes instead of an array of tuples.
        _, ax[pos], _ = plot_psd_spacetime_wavelength(
            ds[name], space_scale=scale, psd_units=units, ax=ax[pos]
        )

    fig.tight_layout()
    return fig, ax

In [None]:
# Plot the space-time PSD panels computed above.
fig, ax = plot_analysis_psd_spacetime(ds_psd_natl60)
plt.show()

In [None]:
# Local import: the module-level import of this helper only happens in a later
# cell, so running the notebook in order would otherwise raise NameError.
from jejeqx._src.viz.xarray.psd import plot_psd_spacetime_wavelength

# Space-time PSD of SSH alone: lat/lon via latlon_deg2m, time in 1-day units,
# spectrum over (time, lon), then psd_average_freq.
out = latlon_deg2m(ds_natl60.ssh, mean=True)
out = time_rescale(out, 1, "D")
out = psd_spacetime(out, ["time", "lon"])
out = psd_average_freq(out)

# Returns (fig, ax, cbar). For a wavenumber axis instead of wavelength, use
# ``plot_psd_spacetime_wavenumber(out, "km")``.
fig, ax, _ = plot_psd_spacetime_wavelength(out, "km", "SSH")

plt.show()

In [None]:
from jejeqx._src.viz.xarray.psd import (
    plot_psd_spacetime_wavenumber,
    plot_psd_spacetime_wavelength,
)

In [None]:
# SSH space-time PSD on wavenumber/frequency axes (returns fig, ax, cbar).
fig, ax, _ = plot_psd_spacetime_wavenumber(ds_psd_natl60.ssh)

plt.show()

In [None]:
# SSH space-time PSD on wavelength/period axes (returns fig, ax, cbar).
fig, ax, _ = plot_psd_spacetime_wavelength(ds_psd_natl60.ssh)

plt.show()

In [None]:
import cmocean as cmo


def plot_analysis_psd_iso(ds):
    """Draft 2x3 PSD panel plot built on ``plot_psd_spacetime``.

    NOTE(review): this redefines ``plot_analysis_psd_iso`` from an earlier
    cell, so any later call in the same kernel uses this version. Also,
    ``plot_psd_spacetime`` is not imported or defined anywhere in the visible
    notebook (only the ``psd_spacetime`` transform is), so calling this as-is
    presumably raises NameError — confirm which plotting helper is intended.

    Parameters
    ----------
    ds : xarray.Dataset
        PSD dataset containing ``ssh``, ``u``, ``v``, ``ke``, ``vort_r`` and
        ``strain``.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : numpy.ndarray, shape (2, 3)
    """
    scale = "km"
    # (variable, subplot position, label passed as ``units``).
    panels = (
        ("ssh", (0, 0), "$m^{2}$/cycles/m"),
        ("u", (0, 1), "U-Velocity"),
        ("v", (0, 2), "V-Velocity"),
        ("ke", (1, 0), "Kinetic Energy"),
        ("vort_r", (1, 1), "Relative Vorticity"),
        ("strain", (1, 2), "Strain"),
    )

    fig, ax = plt.subplots(ncols=3, nrows=2, figsize=(14, 7))
    for name, pos, units in panels:
        ax[pos] = plot_psd_spacetime(ds[name], units=units, scale=scale, ax=ax[pos])

    fig.tight_layout()
    return fig, ax

In [None]:
import xrft
from jejeqx._src.transforms.xarray.psd import (
    psd_spacetime,
    psd_isotropic,
    psd_average_freq,
)

In [None]:
# Isotropic SSH PSD with the same preprocessing as ``calculate_isotropic_psd``:
# lat/lon via latlon_deg2m, time in 1-day units. ``out_ds`` is built here so the
# cell does not depend on a name that no other cell defines.
out_ds = latlon_deg2m(ds_natl60, mean=True)
out_ds = time_rescale(out_ds, 1, "D")

ds_psd = psd_isotropic(out_ds.ssh, ["lat", "lon"])
ds_psd_avg = psd_average_freq(ds_psd)
ds_psd_avg

In [None]:
from jejeqx._src.viz.xarray.psd import (
    plot_psd_isotropic_wavenumber,
    plot_psd_isotropic_wavelength,
    plot_psd_isotropic,
)

In [None]:
# Plot the averaged isotropic SSH PSD with km-scaled axes.
scale = "km"
units = "$m^{2}$/cycles/m"
# NOTE(review): a 3x2 grid is created but only ax[1, 1] is drawn on, so five
# panels render empty; a single ``plt.subplots()`` axes may be what's intended.
fig, ax = plt.subplots(ncols=2, nrows=3, figsize=(10, 12))

ax[1, 1] = plot_psd_isotropic(ds_psd_avg, units=units, scale=scale, ax=ax[1, 1])


plt.tight_layout()
plt.show()

In [None]:
# Inspect the isotropic PSD plotting helper (the truncated name raised NameError).
plot_psd_isotropic

In [None]:
import matplotlib.colors as colors
import matplotlib.ticker as ticker


def plot_psd_spacetime_wavenumber(freq_x, freq_y, psd, ax=None):
    """Filled log-log contour plot of a space-time PSD.

    NOTE(review): this shadows ``plot_psd_spacetime_wavenumber`` imported from
    ``jejeqx._src.viz.xarray.psd`` in earlier cells and has a different
    signature; after this cell runs, calls in the same kernel use this version.

    Parameters
    ----------
    freq_x : array-like
        Wavenumber coordinates for the x axis (labeled cycles/km).
    freq_y : array-like
        Frequency coordinates for the y axis (labeled cycles/day).
    psd : array-like
        PSD values on the (freq_y, freq_x) grid, in the layout ``contourf``
        expects for ``Z``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.

    Returns
    -------
    fig : matplotlib.figure.Figure
    ax : matplotlib.axes.Axes
    cbar : matplotlib.colorbar.Colorbar
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    # Log-spaced contour levels with a matching log color normalization.
    pts = ax.contourf(
        freq_x,
        freq_y,
        psd,
        norm=colors.LogNorm(),
        locator=ticker.LogLocator(),
        cmap="RdYlGn",
        extend="both",
    )

    ax.set(
        yscale="log",
        xscale="log",
        xlabel="Wavenumber [cycles/km]",
        ylabel="Frequency [cycles/days]",
    )

    # Colorbar tick labels rendered as powers of ten.
    cbar = fig.colorbar(
        pts,
        ax=ax,
        pad=0.02,
        format=ticker.LogFormatterMathtext(base=10),
    )
    cbar.ax.set_ylabel(r"PSD [m$^{2}$/cycles/m]")

    # Draw the grid on this plot's axes explicitly, not on whatever is current.
    ax.grid(which="both", linestyle="--", linewidth=1, color="black", alpha=0.2)

    return fig, ax, cbar

In [None]:
# average over latitude
mean_psd_signal = psd_signal.mean(dim="latitude").where(
    (psd_signal.freq_longitude > 0.0) & (psd_signal.freq_time > 0.0), drop=True
)

In [None]:
# Sum of the first element of ``out`` (whatever ``out`` currently holds in the kernel).
np.sum(out[0])

In [None]:
# Total of the grid spacings. NOTE(review): ``dx``/``dy`` are only bound by the
# metpy ``lat_lon_grid_deltas`` cell at the end of the notebook.
np.sum(dx), np.sum(dy)

In [None]:
# Mean grid spacing in x and y (requires ``dx``/``dy`` from the metpy cell).
np.mean(dx), np.mean(dy)

In [None]:
# Rough extent estimate: mean spacing times the number of rows of each delta
# array. The y term now uses ``len(dy)``; it previously reused ``len(dx)``.
np.mean(dx) * len(dx), np.mean(dy) * len(dy)

In [None]:
# ``import metpy`` only exposes ``metpy.calc`` if the package __init__ happens to
# import it, so import the submodule explicitly.
import metpy.calc

# Physical grid spacings of the NATL60 lat/lon grid. ``ds`` only exists as a
# function parameter in this notebook, so use ``ds_natl60`` at cell scope.
# lat_lon_grid_deltas returns (dx, dy); bind both names for the scratch cells.
dx, dy = metpy.calc.lat_lon_grid_deltas(ds_natl60.lon, ds_natl60.lat)
out = (dx, dy)
out