# Finite Difference

In [None]:
import sys, os
from pyprojroot import here

# Locate the project root with pyprojroot's `here`, using a ".local" entry
# as the marker for the root directory.
root = here(project_files=[".local"])

# Make project-local modules importable. Only append once, so re-running
# this cell does not keep growing sys.path with duplicate entries.
root_str = str(root)
if root_str not in sys.path:
    sys.path.append(root_str)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%load_ext autoreload
%autoreload 2

## Data

In [None]:
import xarray as xr
import pandas as pd

# Root of the external data drive. Set the DATA_ROOT environment variable to
# point somewhere else when running on another machine.
DATA_ROOT = os.environ.get("DATA_ROOT", "/Volumes/EMANS_HDD/data")

# Alternative source (single NATL60 Gulf Stream SSH file), currently unused:
# path_data = f"{DATA_ROOT}/osse_oceanix/raw/sim/NATL60-CJM165_GULFSTREAM_ssh_y2013.1y.nc"

# DC22b OSSE QG training data for 2013. The glob may match several files,
# which open_mfdataset combines into one dataset.
path_data = f"{DATA_ROOT}/dc22b_osse/raw/dc_qg_train/dc_qg_train_y2013*.nc"

ds = xr.open_mfdataset(path_data)
# ds = xr.open_dataset(path_data, engine="netcdf4", decode_times=False).assign_coords(time=lambda ds: pd.to_datetime(ds.time))

In [None]:
# Display the opened dataset.
ds

In [None]:
# Inspect the `lon` entry of the dataset.
ds.lon

In [None]:
# Rebuild the SSH field as a plain DataArray labelled by time / lat / lon.
# `dims` is given explicitly: inferring dimension names from a dict of
# coords is not supported by every xarray version.
# `.values` materialises the arrays in memory.
# NOTE(review): assumes ds.ssh is ordered (time, lat, lon) and that
# nav_lat / nav_lon are 1-D — confirm against the dataset.
da = xr.DataArray(
    ds.ssh.values,
    dims=("time", "lat", "lon"),
    coords={"time": ds.time.values, "lat": ds.nav_lat.values, "lon": ds.nav_lon.values},
)

In [None]:
# Display the rebuilt SSH DataArray.
da

In [None]:
# Display the source dataset again for comparison with `da`.
ds

In [None]:
# SSH variable at time index 200.
ds.isel(time=200).ssh

In [None]:
# Map of the SSH field at time index 200.
snapshot = da.isel(time=200)

fig, ax = plt.subplots()
snapshot.plot(ax=ax, cmap="viridis")
plt.show()

In [None]:
import imageio


def load_fox(radius=256):
    """Download the fox photo and return a centred square crop.

    Args:
        radius: half-width of the crop in pixels. The default 256 gives a
            512 x 512 crop.

    Returns:
        Array of shape (2 * radius, 2 * radius, 3): the first three channels
        of the image divided by 255.

    Raises:
        ValueError: if ``radius`` is not positive or the crop does not fit
            inside the image. Without this check, negative start indices
            would silently produce a wrong or empty slice.
    """
    image_url = "https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg"
    img = imageio.imread(image_url)[..., :3] / 255.0

    height, width = img.shape[:2]
    if radius <= 0 or 2 * radius > min(height, width):
        raise ValueError(
            f"crop radius {radius} does not fit an image of shape {img.shape[:2]}"
        )

    row_c, col_c = height // 2, width // 2
    return img[row_c - radius : row_c + radius, col_c - radius : col_c + radius]


def load_earth():
    """Download the Earth PNG and return its first three channels divided by 255.

    Any extra channel (e.g. alpha, if present) is dropped.
    """
    # TODO: crop the image slightly
    url = "https://i0.wp.com/thepythoncodingbook.com/wp-content/uploads/2021/08/Earth.png?w=301&ssl=1"
    pixels = imageio.imread(url)
    rgb = pixels[..., :3]
    return rgb / 255.0


import skimage


def load_cameraman():
    """Return scikit-image's "camera" sample divided by 255, with a trailing size-1 axis."""
    scaled = skimage.data.camera() / 255.0
    # append a singleton channel axis: (H, W) -> (H, W, 1)
    return scaled.reshape(scaled.shape + (1,))

In [None]:
# Load the cropped fox image and display its shape.
img = load_fox()
img.shape

An image is an array of height $H$, width $W$ and $C$ channels:

$$
\mathbf{x} \in \mathbb{R}^{H \times W \times C}
$$

In [None]:
# jax.numpy is otherwise first imported further down the notebook; import it
# here so this cell works on a fresh kernel (Restart & Run All).
import jax.numpy as jnp

# Square grid on [-1, 1] x [-1, 1] with N_GRID points per axis.
N_GRID = 50
x, y = [jnp.linspace(-1, 1, N_GRID)] * 2
dx, dy = [x[1] - x[0]] * 2

# indexing="ij": axis 0 of X / Y follows x, axis 1 follows y.
X, Y = jnp.meshgrid(x, y, indexing="ij")

# Toy vector field F = (F1, F2) with known analytic derivatives
# (dF1/dx = 2X, dF1/dy = 3Y**2, dF2/dy = 3Y**2).
# NOTE(review): the trailing "-Y" / "+X" comments look like leftovers from a
# rotational field (-Y, X) — confirm or drop.
F1 = X**2 + Y**3  # -Y
F2 = X**4 + Y**3  # +X
F = jnp.stack([F1, F2], axis=0)  # shape (2, N_GRID, N_GRID)

In [None]:
# Quiver plot of the toy field F = (F1, F2) over the grid.
fig, ax = plt.subplots(figsize=(10, 10))
ax.quiver(X, Y, F1, F2, color="k", alpha=0.5)
ax.set_aspect(aspect="equal", adjustable="box")
plt.show()

In [None]:
# 3-D variant of the grid / field setup (disabled; kept for reference).
# x, y, z = [jnp.linspace(0,1,100)] * 3
# dx, dy, dz = x[1]-x[0], y[1]-y[0], z[1]-z[0]
# X, Y, Z = jnp.meshgrid(x,y,z,indexing="ij")
# F1 = X**2 + Y**3
# F2 = X**4

In [None]:
import jax.numpy as jnp

For a 2D function $\boldsymbol{f}(x,y)$, the partial derivative with respect to $x$ is:

$$
\partial_x \boldsymbol{f}(x,y) =
\lim_{\epsilon\rightarrow 0} \frac{\boldsymbol{f}(x+\epsilon,y)-
\boldsymbol{f}(x,y)}{\epsilon}
$$

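For discrete data sampled on a grid with spacing $\Delta x$, we approximate the limit with a finite difference (here the forward difference):

$$
\partial_x f(x,y) \approx \frac{f(x+\Delta x,y)-f(x,y)}{\Delta x}
$$

In index units ($\Delta x = 1$), this is simply the difference between neighbouring samples.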

## Laplacian

### Finite Difference

In [None]:
import serket as sk

# dF1/dx: derivative of F1 along axis 0 (x, since the grid uses indexing="ij").
dF1dx = sk.fd.difference(F1, step_size=dx, axis=0, accuracy=10)
# analytic check (disabled): dF1/dx = 2X
# np.testing.assert_allclose(dF1dx, 2*X.squeeze(), atol=1e-7)

# dF2/dy: derivative of F2 along axis 1 (y).
dF2dy = sk.fd.difference(F2, step_size=dy, axis=1, accuracy=6)

dF1dx.shape, dF2dy.shape

In [None]:
# serket finite differences (black) vs analytic derivatives (red):
# dF1/dx = 2X and dF2/dy = 3Y**2.
fig, ax = plt.subplots(figsize=(10, 10))

ax.quiver(X, Y, dF1dx, dF2dy, label="FD", color="k", alpha=0.5)
ax.quiver(X, Y, 2 * X, 3 * Y**2, label="True", color="red", alpha=0.5)

ax.legend()
ax.set_aspect(aspect="equal", adjustable="box")
plt.show()

In [None]:
# Shape of the stacked field (F1, F2 along axis 0).
F.shape

In [None]:
# Gradient and Laplacian of F1 with serket on the (dx, dy) grid, accuracy=6.
steps = (dx, dy)
dF = sk.fd.gradient(F1, step_size=steps, accuracy=6)
d2F = sk.fd.laplacian(F1, step_size=steps, accuracy=6)

dF.shape, d2F.shape

In [None]:
# serket gradient of F1 (black) vs the analytic gradient (2X, 3Y**2) (red).
fig, ax = plt.subplots(figsize=(10, 10))

ax.quiver(X, Y, dF[0], dF[1], label="FD", color="k", alpha=0.5)
ax.quiver(X, Y, 2 * X, 3 * Y**2, label="True", color="red", alpha=0.5)

ax.legend()
ax.set_aspect(aspect="equal", adjustable="box")
plt.show()

In [None]:
# Compare the shape of the finite-difference output with that of F1.
dF1dx.shape, F1.shape

### Convolution

In [None]:
import jax


def _correlate_same(array, kernel, scale):
    """Cross-correlate ``array`` with ``kernel`` (zero "SAME" padding), divided by ``scale``.

    Shared by the gradient / Laplacian helpers below. ``jax.lax.conv`` takes
    (batch, feature, *spatial) operands, so singleton batch and feature axes
    are added before the call and stripped afterwards. Both operands are cast
    to one floating dtype because ``lax.conv`` rejects mixed dtypes (e.g. an
    integer input with a float kernel).

    Raises:
        ValueError: if ``kernel`` and ``array`` have different ranks.
    """
    array = jnp.asarray(array)
    dtype = jnp.promote_types(array.dtype, jnp.float32)
    kernel = (jnp.asarray(kernel, dtype=dtype) / scale).astype(dtype)
    if kernel.ndim != array.ndim:
        raise ValueError(
            f"kernel rank {kernel.ndim} does not match array rank {array.ndim}"
        )

    lhs = array.astype(dtype)[jnp.newaxis, jnp.newaxis, ...]
    rhs = kernel[jnp.newaxis, jnp.newaxis, ...]
    result = jax.lax.conv(lhs, rhs, window_strides=(1,) * array.ndim, padding="SAME")
    return result[0, 0]


def gradient_1D(array, step=1.0):
    """Central-difference first derivative of a 1-D array.

    Computes ``(f[i+1] - f[i-1]) / (2 * step)``. The two end points see zero
    padding, so they are not meaningful derivatives.
    """
    # First derivative: divide by step (not step**2).
    return _correlate_same(array, [-0.5, 0.0, 0.5], step)


def gradient_2D(array, step=1.0):
    """Forward-difference first derivative of a 2-D array along axis 0.

    Computes ``(f[i+1, j] - f[i, j]) / step``. The last row sees zero padding,
    so it is not a meaningful derivative.
    """
    kernel = [[0, 0, 0], [0, -1, 0], [0, 1, 0]]
    return _correlate_same(array, kernel, step)


def laplacian_1D(array, step=1.0):
    """Second derivative of a 1-D array with the 3-point stencil [1, -2, 1] / step**2."""
    return _correlate_same(array, [1, -2, 1], step**2)


def laplacian_2D(array, step=1.0):
    """5-point Laplacian of a 2-D array (equal spacing ``step`` on both axes)."""
    kernel = [[0, 1, 0], [1, -4, 1], [0, 1, 0]]
    return _correlate_same(array, kernel, step**2)

In [None]:
# Shape of the fox image loaded earlier.
img.shape

In [None]:
# Apply gradient_2D (default unit step) to both components of the field.
# NOTE(review): the d2...dx2 names suggest second derivatives, but
# gradient_2D uses a first-difference kernel.
d2F1dx2, d2F2dy2 = gradient_2D(F1), gradient_2D(F2)

d2F1dx2.shape, d2F2dy2.shape

In [None]:
# Quiver plot of the gradient_2D outputs for (F1, F2).
fig, ax = plt.subplots(figsize=(10, 10))

ax.quiver(X, Y, d2F1dx2, d2F2dy2, label="FD", color="k", alpha=0.5)
# analytic reference (disabled):
# ax.quiver(X, Y, 2*X, 3*Y**2, color="red", alpha=0.5, label="True")

ax.legend()
ax.set_aspect(aspect="equal", adjustable="box")
plt.show()

### Kernex

First-Order (One-Sided, Backward), 2-Point Stencil:

$$
f'=\frac{f_i-f_{i-1}}{\Delta x}
$$

---

First-Order (One-Sided, Forward), 2-Point Stencil:

$$
f' = \frac{f_{i+1}-f_i}{\Delta x}
$$

---

Second-Order (Two-Sided, Centered), 3-Point Stencil:

$$
f' = \frac{f_{i+1}-f_{i-1}}{2\Delta x}
$$

---


Fourth-Order (Two-Sided, Centered), 5-Point Stencil:

$$
f' = \frac{-f_{i+2}+8f_{i+1}-8f_{i-1}+f_{i-2}}{12\Delta x}
$$

Sample points of the 5-point stencil (spacing $h$):

$$
(x-2h,x-h,x,x+h,x+2h)
$$

In [None]:
# kernex is otherwise first imported further down the notebook; import it
# here so this cell also runs on a fresh kernel (Restart & Run All).
import kernex as kex


@kex.kmap(kernel_size=(3, 3), padding="valid", relative=True)
def fd_forward(x):
    """Forward difference along axis 1 at the window centre: f[i, j+1] - f[i, j].

    With ``relative=True`` the 3x3 window is indexed relative to its centre,
    so ``x[0, 0]`` is the centre value and ``x[0, 1]`` its neighbour along
    axis 1. Unit spacing is assumed (no division by a step size).
    """
    # Sign follows the forward-difference formula f' = (f_{i+1} - f_i) / dx.
    return x[0, 1] - x[0, 0]


@kex.kmap(kernel_size=(3, 3), padding="valid", relative=True)
def stencil(x):
    """Mean of the four axis-aligned neighbours of the window centre."""
    return 0.25 * (x[0, 1] + x[1, 0] + x[0, -1] + x[-1, 0])


@kex.kmap(kernel_size=(3, 3), padding="valid", relative=True)
def sobel_x(x):
    """Sobel response: axis-1 offset -1 minus offset +1, with rows weighted 1, 2, 1.

    Kernel (rows = axis-0 offsets -1, 0, +1; columns = axis-1 offsets -1, 0, +1):
        [[1, 0, -1],
         [2, 0, -2],
         [1, 0, -1]]
    """
    return (
        1 * x[1, -1]
        + 0 * x[1, 0]
        + -1 * x[1, 1]
        + 2 * x[0, -1]
        + 0 * x[0, 0]
        + -2 * x[0, 1]
        + 1 * x[-1, -1]  # corner was x[-1, 1], which cancelled the -1 * x[-1, 1] term
        + 0 * x[-1, 0]
        + -1 * x[-1, 1]
    )


@kex.kmap(kernel_size=(3, 3), padding="valid", relative=True)
def sobel_y(x):
    """Sobel response: axis-0 offset +1 minus offset -1, with columns weighted 1, 2, 1.

    Kernel (rows = axis-0 offsets -1, 0, +1; columns = axis-1 offsets -1, 0, +1):
        [[-1, -2, -1],
         [ 0,  0,  0],
         [ 1,  2,  1]]
    """
    return (
        1 * x[1, -1]
        + 2 * x[1, 0]
        + 1 * x[1, 1]
        + 0 * x[0, -1]
        + 0 * x[0, 0]
        + 0 * x[0, 1]
        + -1 * x[-1, -1]  # corner was x[-1, 1], which double-counted that pixel
        + -2 * x[-1, 0]
        + -1 * x[-1, 1]
    )

5-Point Stencil Laplacian (grid spacing $h$):

$$
\frac{f(x-h,y)+f(x+h,y)+f(x,y-h)+f(x,y+h)-4f(x,y)}{h^2}
$$

In [None]:
def laplacian_1d(window_size):
    """Build a zero-sum 1-D Laplacian-style filter of odd length ``window_size``.

    Every tap is 1 except the centre tap, which is ``1 - window_size``, so the
    taps sum to zero. ``window_size=3`` gives the classic ``[1, -2, 1]``.

    Raises:
        ValueError: if ``window_size`` is not a positive odd integer. An even
            window has no centre tap, so the filter would be lopsided.
    """
    if window_size < 1 or window_size % 2 == 0:
        raise ValueError(f"window_size must be a positive odd integer, got {window_size}")
    filter_1d = jnp.ones(window_size)
    return filter_1d.at[window_size // 2].set(1 - window_size)

In [None]:
# 3-tap filter from laplacian_1d: ones with the centre set to 1 - 3 = -2.
lap_filter = laplacian_1d(3)
lap_filter

In [None]:
import kernex as kex


@kex.kmap(kernel_size=(3, 3), padding="valid", relative=True)
def laplacian(x):
    """5-point Laplacian stencil at the window centre (unit spacing, no 1/h**2).

    ``relative=True`` indexes the 3x3 window relative to its centre. The
    full 3x3 layout is written out (corner weights 0) to mirror the kernel:
        [[0,  1, 0],
         [1, -4, 1],
         [0,  1, 0]]
    """
    return (
        0 * x[1, -1]
        + 1 * x[1, 0]
        + 0 * x[1, 1]
        + 1 * x[0, -1]
        + -4 * x[0, 0]
        + 1 * x[0, 1]
        + 0 * x[-1, -1]  # was x[-1, 1] (copy-paste); harmless with weight 0
        + 1 * x[-1, 0]
        + 0 * x[-1, 1]
    )


array_ones = jnp.ones([10, 10])
lap_ones = laplacian(array_ones)

# Sanity check: the Laplacian of a constant field is zero everywhere.
# NOTE(review): "valid" padding is expected to shrink 10x10 to 8x8.
assert jnp.allclose(lap_ones, 0.0)
lap_ones