In [None]:
import numpy as np
import os

# Keras reads KERAS_BACKEND when it is imported, so this must precede `import keras`.
os.environ["KERAS_BACKEND"] = "jax"
import keras
import xarray as xr
import jax.numpy as jnp
from jaxtyping import Array, Float
import matplotlib.pyplot as plt
from jax import config
import metpy

# Enable 64-bit (float64/int64) arrays in JAX, which defaults to 32-bit.
config.update("jax_enable_x64", True)

In [None]:
import toolz


def validate_lon(ds):
    """Wrap longitudes into [-180, 180) and attach CF-style metadata.

    Args:
        ds: dataset with a ``lon`` coordinate in degrees.

    Returns:
        A copy of ``ds`` whose ``lon`` values are wrapped into [-180, 180),
        annotated with units/standard_name/long_name/actual_range (existing
        attributes of other names are kept), and sorted by ascending longitude.
    """
    out = ds.copy()
    # e.g. [0, 360) -> [-180, 180); a longitude of exactly 180 becomes -180.
    out["lon"] = (ds.lon + 180) % 360 - 180

    wrapped = out.lon.values
    lon_attrs = {
        **ds.lon.attrs,
        "units": "degrees_east",
        "standard_name": "longitude",
        "long_name": "Longitude",
        "actual_range": [wrapped.min(), wrapped.max()],
    }
    out["lon"] = out.lon.assign_attrs(lon_attrs)

    return out.sortby("lon")


def validate_lat(ds):
    """Normalize latitudes to [-90, 90] and attach CF-style metadata.

    Latitudes already inside [-90, 90] are kept unchanged. Values outside that
    range are folded with ``(lat + 90) % 180 - 90``. The coordinate is then
    annotated (units, standard_name, long_name, actual_range) and the dataset
    is sorted by ascending latitude.

    Args:
        ds: dataset with a ``lat`` coordinate in degrees.

    Returns:
        A copy of ``ds`` with the normalized, annotated and sorted ``lat``.
    """
    new_ds = ds.copy()

    # Applying the modulo unconditionally maps the north pole (90) onto -90.
    # That duplicates the south-pole label and moves polar data to the wrong
    # end of the sorted grid, so only fold values that are out of range.
    in_range = (ds.lat >= -90) & (ds.lat <= 90)
    new_ds["lat"] = ds.lat.where(in_range, (ds.lat + 90) % 180 - 90)

    attrs = toolz.dicttoolz.merge(
        ds.lat.attrs,
        dict(
            units="degrees_north",
            standard_name="latitude",
            long_name="Latitude",
            actual_range=[new_ds.lat.values.min(), new_ds.lat.values.max()],
        ),
    )

    new_ds["lat"] = new_ds.lat.assign_attrs(attrs)

    new_ds = new_ds.sortby("lat")

    return new_ds


def validate_sst(ds: xr.Dataset, variable: str = "sst") -> xr.Dataset:
    """Label ``ds[variable]`` with CF sea-surface-temperature attributes.

    Only metadata changes. ``units`` is set to "degC", and the CF
    ``standard_name``/``long_name`` are filled in, overriding any existing
    values of those keys and keeping all other attributes. Data values are
    not converted.

    Args:
        ds: dataset containing ``variable``.
        variable: name of the SST variable to annotate.

    Returns:
        A copy of ``ds`` with the annotated variable.
    """
    out = ds.copy()
    field = out[variable]
    field_attrs = {
        **field.attrs,
        "units": "degC",
        "standard_name": "sea_surface_temperature",
        "long_name": "Sea Surface Temperature",
    }
    out[variable] = field.assign_attrs(field_attrs)
    return out

# Data

In [None]:
# Load the `ersstv5` xarray tutorial dataset, then normalize longitude,
# latitude and SST metadata (in that order).
ds = (
    xr.tutorial.open_dataset("ersstv5")
    .pipe(validate_lon)
    .pipe(validate_lat)
    .pipe(validate_sst)
)
da = ds["sst"]
# da

In [None]:
# Inspect the processed SST field.
da

In [None]:
# Mean of the equatorial (6378 km) and polar (6357 km) radii, in meters.
EARTH_RADIUS_M = 1000 * (6357 + 6378) / 2
# Arc length of one degree along a great circle (e.g. one degree of latitude)
# on a sphere of radius EARTH_RADIUS_M, in meters.
METERS_PER_DEGREE = 2 * np.pi * EARTH_RADIUS_M / 360

In [None]:
import finitediffx as fdx

In [None]:
# First time step of SST laid out as (lon, lat), so axis 0 is longitude.
# These are xarray DataArrays, not JAX arrays.
subset: xr.DataArray = da.transpose("time", "lon", "lat").isel(time=0)
# Longitude coordinate values (degrees), not grid spacings.
dx: xr.DataArray = da.coords["lon"]

assert subset.dims == ("lon", "lat"), subset.dims
assert dx.shape[0] == subset.shape[0]

In [None]:
# Shapes of the 2D slice and the longitude coordinate (first dims must match).
subset.shape, dx.shape

In [None]:
# d(SST)/d(lon) per degree: central finite difference along axis 0 (lon),
# using the spacing between the first two longitudes as a uniform step size
# (assumes an evenly spaced longitude grid).
fd_out = fdx.difference(
    subset.values, axis=0, step_size=(dx[1] - dx[0]).values, method="central"
)

In [None]:
# xarray reference for fd_out: the same (lon, lat) slice as `subset`,
# differentiated along longitude (result is per degree of longitude).
xr_out = subset.differentiate("lon")

In [None]:
# Disabled comparison of the finitediffx and xarray derivatives.
# NOTE(review): confirm why the two disagree (possibly boundary handling at
# the first/last longitude) before re-enabling.
# np.testing.assert_array_almost_equal(fd_out, xr_out.values, decimal=0)

In [None]:
fig, ax = plt.subplots()

pts = ax.imshow(xr_out.T, origin="lower")
ax.set(
    title="xarray differentiate: d(SST)/d(lon)",
    xlabel="longitude index",
    ylabel="latitude index",
)
plt.colorbar(pts, ax=ax, label="degC per degree longitude")

plt.show()

In [None]:
# Value ranges of the xarray and finitediffx derivatives, for a quick comparison.
xr_out.min(), xr_out.max(), fd_out.min(), fd_out.max()

In [None]:
fig, ax = plt.subplots()

pts = ax.imshow(fd_out.T, origin="lower")
ax.set(
    title="finitediffx central difference: d(SST)/d(lon)",
    xlabel="longitude index",
    ylabel="latitude index",
)
plt.colorbar(pts, ax=ax, label="degC per degree longitude")

plt.show()

In [None]:
fig, ax = plt.subplots()

pts = ax.imshow(np.abs(fd_out.T - xr_out.T), origin="lower", cmap="Reds")
ax.set(
    title="|finitediffx - xarray| for d(SST)/d(lon)",
    xlabel="longitude index",
    ylabel="latitude index",
)
plt.colorbar(pts, ax=ax)

plt.show()

In [None]:
# Shapes of the xarray and finitediffx derivatives (expected to match).
xr_out.values.shape, fd_out.shape

In [None]:
import typing as tp


def derivative_longitude(field: xr.DataArray) -> xr.DataArray:
    """Zonal derivative of ``field`` per meter (d/dx on the sphere).

    ``differentiate("lon")`` gives the derivative per degree of longitude.
    A degree of longitude spans METERS_PER_DEGREE * cos(lat) meters, so the
    result is divided by cos(lat) and then by METERS_PER_DEGREE. Because of
    the cos(lat) division, the result grows without bound toward the poles.
    """
    # TODO(shoyer): use a custom calculation with roll() instead of
    # differentiate() to calculate rolling over 360 to 0 degrees properly.
    per_degree = field.differentiate("lon")
    cos_lat = np.cos(np.deg2rad(field.coords["lat"]))
    return per_degree / cos_lat / METERS_PER_DEGREE


def derivative_latitude(field: xr.DataArray) -> xr.DataArray:
    """Meridional derivative of ``field`` per meter (d/dy on the sphere).

    A degree of latitude spans METERS_PER_DEGREE meters at every latitude, so
    the per-degree derivative only needs a constant rescaling.
    """
    per_degree = field.differentiate("lat")
    return per_degree / METERS_PER_DEGREE


def derivative_radius(field: xr.DataArray) -> xr.DataArray:
    """Vertical derivative of ``field`` along its ``height`` coordinate.

    The result is per unit of the ``height`` coordinate (per meter if height
    is stored in meters). NOTE(review): confirm the coordinate's units.

    The previous version divided by ``METERS_PER_DEGREE**2``. That is an
    angular-to-metric conversion, which does not apply to a radial coordinate
    and is dimensionally wrong for a first derivative.
    """
    return field.differentiate("height")


def divergence(u: xr.DataArray, v: xr.DataArray) -> xr.DataArray:
    """Horizontal divergence of the vector field (u, v) on a sphere.

    With longitude lam, latitude phi and radius R:

        div = 1/(R cos phi) du/dlam + 1/(R cos phi) d(v cos phi)/dphi
            = 1/(R cos phi) du/dlam + 1/R dv/dphi - v tan(phi) / R

    The first two terms are ``derivative_longitude(u)`` and
    ``derivative_latitude(v)``. The last is the metric term from the
    convergence of meridians, which a plain sum of the two derivatives omits.
    Units are those of ``u``/``v`` per meter (1/s for winds in m/s).

    Args:
        u: zonal component with ``lat``/``lon`` coordinates in degrees.
        v: meridional component on the same grid as ``u``.

    Returns:
        The divergence field.
    """
    tan_lat = np.tan(np.deg2rad(v.coords["lat"]))
    metric_term = v * tan_lat / EARTH_RADIUS_M
    return derivative_longitude(u) + derivative_latitude(v) - metric_term


def curl_k(u: xr.DataArray, v: xr.DataArray) -> xr.DataArray:
    """Vertical (k) component of the curl of (u, v) on a sphere.

    This is the relative vorticity for a wind field. With longitude lam,
    latitude phi and radius R:

        curl_k = 1/(R cos phi) dv/dlam - 1/(R cos phi) d(u cos phi)/dphi
               = 1/(R cos phi) dv/dlam - 1/R du/dphi + u tan(phi) / R

    The last term is the spherical metric term, which a plain difference of
    ``derivative_longitude(v)`` and ``derivative_latitude(u)`` omits.

    Args:
        u: zonal component with ``lat``/``lon`` coordinates in degrees.
        v: meridional component on the same grid as ``u``.

    Returns:
        The k-component of the curl.
    """
    tan_lat = np.tan(np.deg2rad(u.coords["lat"]))
    metric_term = u * tan_lat / EARTH_RADIUS_M
    return derivative_longitude(v) - derivative_latitude(u) + metric_term


def _geostrophic_wind(
    geopotential: xr.DataArray,
) -> tp.Tuple[xr.DataArray, xr.DataArray]:
    """Geostrophic wind components (u_g, v_g) from a geopotential field.

    u_g = -(1/f) dPhi/dy and v_g = (1/f) dPhi/dx, with f = 2 * omega * sin(lat).

    Args:
        geopotential: field with ``lat``/``lon`` coordinates in degrees
            (presumably geopotential in m^2/s^2; TODO confirm).

    Returns:
        Tuple ``(u, v)`` of geostrophic wind components.
    """
    omega = 7.292e-5  # radians / second
    # Read "lat", the same coordinate derivative_latitude/derivative_longitude
    # use. Reading "latitude" here meant the function could never succeed.
    latitude = geopotential.coords["lat"]
    coriolis_parameter = 2 * omega * np.sin(np.deg2rad(latitude))
    # Geostrophic wind is inf on the equator. We don't clip it to ensure that the
    # user makes an intentional choice about how to handle these invalid values
    # (e.g., by evaluating over a region).
    return (
        -derivative_latitude(geopotential) / coriolis_parameter,
        +derivative_longitude(geopotential) / coriolis_parameter,
    )

In [None]:
def _lonlat_domain_fields(ds: xr.DataArray):
    """Return ``(lon_coords, lat_coords, dlon, dlat)`` for a lon/lat DataArray.

    Spacings are the mean of consecutive differences ``coords[i+1] - coords[i]``
    (via ``np.diff``), so they are positive for ascending coordinates such as
    those produced by ``validate_lon``/``validate_lat``. The previous
    ``coords[:-1] - coords[1:]`` returned the negated spacing.
    """
    lon_coords = ds.lon.values
    lat_coords = ds.lat.values
    dlon = np.mean(np.diff(lon_coords))
    dlat = np.mean(np.diff(lat_coords))
    return lon_coords, lat_coords, dlon, dlat


class SphericalDomainCartesian2D(tp.NamedTuple):
    """Lon/lat grid description: coordinate vectors plus mean spacings.

    Fields:
        lon_coords: longitude coordinate values (degrees in this notebook).
        lat_coords: latitude coordinate values (degrees in this notebook).
        dlon: mean spacing between consecutive longitudes.
        dlat: mean spacing between consecutive latitudes.
    """

    lon_coords: Float[Array, "Dx"]
    lat_coords: Float[Array, "Dy"]
    dlon: Float[Array, ""]
    dlat: Float[Array, ""]

    @classmethod
    def from_xarray(cls, ds: xr.DataArray):
        """Build the domain from ``ds.lon`` and ``ds.lat``."""
        return cls(*_lonlat_domain_fields(ds))


class SphericalDomainRectilinear2D(tp.NamedTuple):
    """Lon/lat grid description: coordinate vectors plus mean spacings.

    NOTE(review): currently identical to SphericalDomainCartesian2D.

    Fields:
        lon_coords: longitude coordinate values (degrees in this notebook).
        lat_coords: latitude coordinate values (degrees in this notebook).
        dlon: mean spacing between consecutive longitudes.
        dlat: mean spacing between consecutive latitudes.
    """

    lon_coords: Float[Array, "Dx"]
    lat_coords: Float[Array, "Dy"]
    dlon: Float[Array, ""]
    dlat: Float[Array, ""]

    @classmethod
    def from_xarray(cls, ds: xr.DataArray):
        """Build the domain from ``ds.lon`` and ``ds.lat``."""
        return cls(*_lonlat_domain_fields(ds))

In [None]:
# Optional Earth Engine setup (disabled); required before opening data with
# xarray's engine='ee'.
# import ee
# import xarray
# ee.Authenticate()
# ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com')

Candidate Earth Engine datasets:

* `ECMWF/ERA5/DAILY` or `NASA/GDDP-CMIP6` (opened as `ee://NASA/GDDP-CMIP6`)
* `HYCOM/sea_surface_elevation`
* `HYCOM/sea_temp_salinity`
* `HYCOM/sea_water_velocity`

In [None]:
# Example (disabled): open an Earth Engine ImageCollection as an xarray Dataset.
# ic = ee.ImageCollection('NASA/GDDP-CMIP6').filterDate('2022-01-05', '2022-03-31')
# ds = xarray.open_dataset(ic, engine='ee', crs='EPSG:4326', scale=0.25)
# ds

In [None]:
# Grid description of the SST field. `SphericalDomain2D` does not exist in this
# notebook (NameError); the Cartesian variant provides the same fields.
domain = SphericalDomainCartesian2D.from_xarray(da)

In [None]:
def derivative_longitude_array_2D(
    u: Float[Array, "Dx Dy"], domain: SphericalDomainCartesian2D, *args, **kwargs
):
    """Zonal derivative d(u)/dx per meter of a (lon, lat) array via finitediffx.

    Args:
        u: 2D field with longitude on axis 0 and latitude on axis 1.
        domain: grid description providing ``lon_coords`` and ``lat_coords``
            in degrees.
        *args, **kwargs: forwarded to both ``fdx.difference`` calls
            (e.g. ``method="central"``).

    Returns:
        Array shaped like ``u``. It holds d(u)/d(lon) converted to per meter
        by dividing by cos(lat) and METERS_PER_DEGREE.
    """
    cos_theta = jnp.cos(jnp.deg2rad(domain.lat_coords))
    # NOTE(review): longitude is not treated as periodic here (no wrap padding),
    # so the seam at the first/last longitude is not joined. TODO decide.
    # Local grid spacing d(lon)/d(index) in degrees, one value per longitude.
    dlon = fdx.difference(domain.lon_coords, *args, axis=0, **kwargs)
    # Broadcast the per-longitude spacing along the latitude axis (axis 1) of u.
    du_dlon = fdx.difference(u, *args, axis=0, step_size=dlon[..., None], **kwargs)
    return du_dlon / cos_theta / METERS_PER_DEGREE

In [None]:
# Zonal derivative (per meter) of the first time step, computed with finitediffx.
# NOTE(review): longitude is not padded periodically here; wrapping with
# jnp.pad(..., mode="wrap") along axis 0 was considered. TODO decide.
dt_dx_jax = derivative_longitude_array_2D(subset.values, domain, method="central")

In [None]:
# xarray-based zonal derivative (per meter) of the slice used for dt_dx_jax.
dt_dx_xr = derivative_longitude(subset)
dt_dx_xr

In [None]:
# `import metpy` alone may not load the `metpy.calc` submodule, so import it
# explicitly before use.
import metpy.calc

# MetPy's zonal derivative df/dx of the same slice, for comparison.
dt_dx_metpy = metpy.calc.geospatial_gradient(
    subset, latitude=subset.lat, longitude=subset.lon, return_only="df/dx"
)

**TODO**: sketch of a possible API for geospatial vs. Cartesian derivatives:

```python
geospatial_derivative = lambda: ...
derivative_cartesian
```

In [None]:
fig, ax = plt.subplots()

# NOTE(review): assumes `dt_dx_metpy` exposes `.magnitude`. If MetPy returned
# an xarray DataArray, the magnitude may instead be at
# `dt_dx_metpy.data.magnitude`; confirm.
pts = ax.imshow(
    np.abs(dt_dx_metpy.magnitude.T - dt_dx_jax.T), origin="lower", cmap="Reds"
)
ax.set(
    title="|MetPy - finitediffx| for d(SST)/dx",
    xlabel="longitude index",
    ylabel="latitude index",
)
plt.colorbar(pts, ax=ax)

plt.show()

In [None]:
fig, ax = plt.subplots()

# NOTE(review): assumes `dt_dx_metpy` exposes `.magnitude`. If MetPy returned
# an xarray DataArray, the magnitude may instead be at
# `dt_dx_metpy.data.magnitude`; confirm.
pts = ax.imshow(
    np.abs(dt_dx_metpy.magnitude.T - dt_dx_xr.values.T), origin="lower", cmap="Reds"
)
ax.set(
    title="|MetPy - xarray| for d(SST)/dx",
    xlabel="longitude index",
    ylabel="latitude index",
)
plt.colorbar(pts, ax=ax)

plt.show()

In [None]:
# xarray zonal derivative on its lat/lon coordinates (transposed to (lat, lon)).
dt_dx_xr.T.plot.pcolormesh()

In [None]:
# NumPy reference for `xr_out`. The original `np.gradient()` call had no
# arguments and raised TypeError. xarray's differentiate() uses second-order
# central differences like np.gradient (same default edge_order=1), so the
# two should agree along longitude (axis 0 of `subset`).
np_out = np.gradient(subset.values, dx.values, axis=0)
np.testing.assert_allclose(np_out, xr_out.values, rtol=1e-6, atol=1e-9)
np_out.shape

In [None]:
# Inspect the latitude coordinate (values and attributes).
da.coords["lat"]

In [None]:
# `_d_dx` is not defined in this notebook (NameError). Use the zonal-derivative
# helper derivative_longitude (per meter), presumably what `_d_dx` meant.
out = derivative_longitude(da)
out.isel(time=0).plot.imshow()