# Sea Surface Height - Derived Quantities

In this notebook, we derive physical quantities from sea surface height (SSH). Namely, we look at:

* $u, v$ - velocity components
* RV - relative vorticity
* KE - kinetic energy
* Enstrophy

In [None]:
import sys, os

import ml_collections
from pyprojroot import here

# Locate the project root by searching parent directories for a `.root` marker file.
root = here(project_files=[".root"])

# Make the project package importable. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

In [None]:
from pathlib import Path
import numpy as np
from pathlib import Path
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt

import seaborn as sns

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)
import scienceplots

# plt.style.use("science")

import hvplot.xarray
import hvplot.pandas

from inr4ssh._src.operators.finite_diff import calculate_gradient, calculate_laplacian
from inr4ssh._src.preprocess.subset import temporal_subset, spatial_subset
from inr4ssh._src.preprocess.coords import (
    correct_coordinate_labels,
    correct_longitude_domain,
)
from inr4ssh._src.data.ssh_obs import load_ssh_altimetry_data_train

from inr4ssh._src.preprocess.coords import correct_longitude_domain
from inr4ssh._src.preprocess.obs import bin_observations_xr, bin_observations_coords
from inr4ssh._src.preprocess.grid import create_spatiotemporal_grid
from inr4ssh._src.viz.movie import create_movie
from inr4ssh._src.metrics.psd import psd_isotropic
from inr4ssh._src.viz.psd.isotropic import plot_psd_isotropic
from inr4ssh._src.viz.obs import plot_obs_demo
from inr4ssh._src.metrics.psd import psd_spacetime, psd_spacetime_dask
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_wavelength,
    plot_psd_spacetime_wavenumber,
)
from inr4ssh._src.viz.psd.spacetime import (
    plot_psd_spacetime_score_wavelength,
    plot_psd_spacetime_score_wavenumber,
)
from loguru import logger

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Data

In [None]:
from inr4ssh._src.preprocess.spatial import convert_lon_360_180, convert_lon_180_360


def post_process(ds, variable):
    """Standardise a gridded SSH dataset for the derived-quantity analysis.

    Steps:
      1. normalise coordinate names via ``correct_coordinate_labels``;
      2. rename the data variable ``variable`` to ``"ssh"``;
      3. convert longitudes with ``convert_lon_180_360``;
      4. order dimensions as ``(time, latitude, longitude)``.

    No regridding or spatial/temporal subsetting is performed here.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset as returned by ``xr.open_dataset``.
    variable : str
        Name of the SSH variable in ``ds``; it is renamed to ``"ssh"``.

    Returns
    -------
    xr.Dataset
        The processed dataset.
    """
    # correct coordinate labels
    logger.info("Fixing coordinate labels...")
    ds = correct_coordinate_labels(ds)

    # rename the SSH variable to the canonical name used by the rest of the notebook
    logger.info("Fixing labels")
    ds = ds.rename({variable: "ssh"})

    # correct longitude domain
    # NOTE(review): the bounds used elsewhere in this notebook (e.g. -75..-45,
    # -65..-55) are in the [-180, 180) convention, and an earlier variant used
    # convert_lon_360_180 here instead — confirm which convention is intended.
    logger.info("Fixing longitude domain")
    ds["longitude"] = convert_lon_180_360(ds.longitude)

    # fixed dimension order expected by the plotting / derivative cells
    ds = ds.transpose("time", "latitude", "longitude")

    return ds

In [None]:
# Output directory for figures (not referenced in the cells shown here).
# NOTE(review): the data loaded below comes from dc21b; confirm "dc21a" is the intended folder.
fig_path = Path(root) / "figures" / "dc21a"

## Reference Grid

In [None]:
# List the contents of the results directory (hard-coded absolute local path — adjust for your machine).
!ls /Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/

In [None]:
from inr4ssh._src.preprocess.regrid import oi_regrid  # NOTE(review): unused in the cells shown

# Gridded SSH reconstruction to analyse (hard-coded absolute local path).
# Alternative input (MIOST OSE mapping):
# url = "/Volumes/EMANS_HDD/data/dc21b_ose/test_2/results/OSE_ssh_mapping_MIOST.nc"
url = "/Volumes/EMANS_HDD/data/dc21b/results/siren_136.nc"

logger.info(f"Loading SSH field from {url}")
ds_raw = xr.open_dataset(url)

# standardise coordinates / variable name / dimension order
ds_field = post_process(ds_raw, "ssh")
ds_field

In [None]:
# NOTE: disabled — low/high/super-resolution target-grid configurations, kept for
# reference; only used by the commented-out regridding cells below.
# NOTE(review): the longitude bounds here (e.g. -65 .. -55) are in the [-180, 180)
# convention, while `post_process` converts longitudes with `convert_lon_180_360`.
# Confirm the conventions agree before re-enabling.
# from ml_collections import config_dict
#
# # create configuration
# def get_lowres_config():
#     config = config_dict.ConfigDict()
#
#     config.lon_min = -65  # -75.0
#     config.lon_max = -55.0  # -45.0
#     config.dlon = 0.1
#     config.lat_min = 33.0
#     config.lat_max = 43.0  # 53.0
#     config.dlat = 0.1
#     config.time_min = np.datetime64("2017-02-01")
#     config.time_max = np.datetime64("2017-03-01")
#     config.dt_freq = 1
#     config.dt_unit = "D"
#     config.dtime = "1_D"  # np.timedelta64(1, "D")
#     config.time_buffer = np.timedelta64(1, "D")
#     return config
#
#
# def get_hires_config():
#     config = get_lowres_config()
#     config.dlon = 0.05
#     config.dlat = 0.05
#     config.dtime = "12_h"
#     return config
#
#
# def get_superres_config():
#     config = get_lowres_config()
#     config.dlon = 0.01
#     config.dlat = 0.01
#     config.dtime = "6_h"
#     return config

In [None]:
# import pyinterp
# from einops import rearrange
# from inr4ssh._src.preprocess.regrid import (
#     create_pyinterp_grid_2dt,
#     regrid_2dt_from_grid,
#     regrid_2dt_from_da,
# )
# from inr4ssh._src.interp import interp_2dt
# from inr4ssh._src.preprocess.coords import Bounds2DT

In [None]:
# NOTE: disabled — would regrid `ds_field.ssh` onto the high-resolution target grid
# and fill gaps around the edges; requires the config and import cells above to be
# uncommented first.
# # init config
# config = get_hires_config()
#
# # create target grid
# grid_target = Bounds2DT.init_from_config(config).create_coordinates().create_grid()
#
# # regrid
# ds_field_hires = regrid_2dt_from_grid(
#     ds_field.ssh,
#     grid_target,
#     is_circle=True,
# )
#
# # fill gaps (around edges)
# ds_field_hires = interp_2dt(ds_field_hires, is_circle=True, method="gauss_seidel")
#
# ds_field_hires = xr.Dataset({"ssh": ds_field_hires})
#
# ds_field.ssh.shape, ds_field_hires.ssh.shape

In [None]:
# Colour limits shared across all time steps. np.nanmin / np.nanmax skip NaN
# cells; plain np.min / np.max would return NaN (and break the colour scale)
# if the field contains any NaN.
vmin = np.nanmin(ds_field.ssh.values)
vmax = np.nanmax(ds_field.ssh.values)

# with plt.style.context('science'):
fig, ax = plt.subplots()

# robust=True is omitted: xarray only uses it when vmin/vmax are not supplied.
ds_field.ssh.sel(time="2017-02-01").plot(
    ax=ax, cmap="viridis", cbar_kwargs={"label": ""}, vmin=vmin, vmax=vmax
)
ax.set(xlabel="Longitude", ylabel="Latitude", title="SSH (2017-02-01)")
plt.show()

In [None]:
from inr4ssh._src.operators.ssh import (
    enstropy,
    kinetic_energy,
    ssh2rv_ds_2dt,
    ssh2uv_ds_2dt,
)

# Add the derived diagnostics to `ds_field`, one stage at a time:
#   u, v -> velocity components from SSH        (ssh2uv_ds_2dt)
#   ke   -> kinetic energy from (u, v)          (kinetic_energy)
#   rv   -> relative vorticity                  (ssh2rv_ds_2dt)
#   ens  -> enstrophy from rv                   (library spells it `enstropy`)
ds_field = ssh2uv_ds_2dt(ds_field)
ds_field = ds_field.assign(ke=kinetic_energy(ds_field.u, ds_field.v))

ds_field = ssh2rv_ds_2dt(ds_field)
ds_field = ds_field.assign(ens=enstropy(ds_field.rv))

In [None]:
# with plt.style.context('science'):
# Symmetric colour limits shared by u and v, computed over all time steps, so
# the diverging "coolwarm" map is centred on zero (white = zero velocity
# component). nanmax skips NaN cells, which would otherwise make the limit NaN.
vlim = max(
    np.nanmax(np.abs(ds_field.u.values)),
    np.nanmax(np.abs(ds_field.v.values)),
)
vmin, vmax = -vlim, vlim

fig, ax = plt.subplots(ncols=2, figsize=(15, 5))

# robust=True is omitted: xarray ignores it when vmin/vmax are supplied.
for axis, name in zip(ax, ["u", "v"]):
    ds_field[name].sel(time="2017-02-01").plot(
        ax=axis,
        cmap="coolwarm",
        cbar_kwargs={"label": ""},
        vmin=vmin,
        vmax=vmax,
    )
    axis.set(xlabel="Longitude", ylabel="Latitude", title=f"{name} (2017-02-01)")
plt.show()

In [None]:
# with plt.style.context('science'):
# Kinetic-energy snapshot; with no vmin/vmax given, robust=True lets xarray
# choose percentile-based colour limits.
ke_snapshot = ds_field.ke.sel(time="2017-02-01")

fig, ax = plt.subplots()
ke_snapshot.plot(ax=ax, cmap="coolwarm", robust=True, cbar_kwargs={"label": ""})
ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.show()

In [None]:
# with plt.style.context('science'):
# Relative-vorticity snapshot with percentile-based (robust) colour limits.
rv_plot_kwargs = dict(cmap="coolwarm", robust=True, cbar_kwargs={"label": ""})

fig, ax = plt.subplots()
ds_field.rv.sel(time="2017-02-01").plot(ax=ax, **rv_plot_kwargs)
ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.show()

In [None]:
# with plt.style.context('science'):
# Enstrophy snapshot with percentile-based (robust) colour limits.
ens_snapshot = ds_field.ens.sel(time="2017-02-01")

fig, ax = plt.subplots()
ens_snapshot.plot(
    ax=ax,
    cmap="coolwarm",
    robust=True,
    cbar_kwargs={"label": ""},
)
ax.set(xlabel="Longitude", ylabel="Latitude", title="")
plt.show()