In [1]:
from functools import reduce

import jax
import jax.numpy as jnp
from jax import Array

from pyutils.jax.differential import compute_forward_derivative, kron, speye

In [2]:
# Forward-difference operator on a small, non-uniformly spaced 4-point axis;
# the cell output is the resulting sparse matrix.
compute_forward_derivative(jnp.asarray([1.0, 2.0, 3.0, 5.0]))

BCOO(float32[4, 4], nse=7)

In [5]:
# Coordinate axes for the examples: 100, 101 and 40 points respectively.
x, y, z = (
    jnp.linspace(0, 10, 100),
    jnp.linspace(0, 5, 101),
    jnp.linspace(-2, 2, 40),
)

In [6]:
def build_differential(axis: list[Array]):
    """Build forward-difference operators on a tensor-product grid.

    For axes ``a_0, ..., a_{N-1}`` with lengths ``n_0, ..., n_{N-1}``, the
    operator for direction ``k`` is the left-folded Kronecker product

        kron(...kron(kron(F_0, F_1), F_2)..., F_{N-1})

    where ``F_k = compute_forward_derivative(a_k)`` and every other factor is
    ``speye(n_j)``. In this notebook the operators are applied to fields
    sampled with ``jnp.meshgrid(*axis, indexing='ij')`` and flattened with
    ``.flatten()`` (row-major), so the last axis varies fastest.

    Args:
        axis: 1-D coordinate arrays, one per grid dimension.

    Returns:
        For a single axis, the 1-D operator itself. Otherwise, a tuple with one
        operator per axis, in the same order as ``axis``.

    Raises:
        ValueError: if ``axis`` is empty.
    """
    if len(axis) == 0:
        raise ValueError("build_differential needs at least one axis.")
    if len(axis) == 1:
        # Kept as a bare operator (not a 1-tuple) for backward compatibility.
        return compute_forward_derivative(axis[0])

    identities = [speye(a.shape[0]) for a in axis]
    operators = []
    for k, a in enumerate(axis):
        factors = list(identities)
        factors[k] = compute_forward_derivative(a)
        # reduce folds left: kron(kron(F0, F1), F2), ... -- the same nesting the
        # hand-written 2-D/3-D versions used, now for any number of axes.
        operators.append(reduce(kron, factors))
    return tuple(operators)


build_differential([x, y, z])

(BCOO(float32[404000, 404000], nse=803960),
 BCOO(float32[404000, 404000], nse=804000),
 BCOO(float32[404000, 404000], nse=797900))

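## 1D

*TODO: no standalone 1-D example yet. The 1-D operator is exercised through `jax.vmap` in the 2-D and 3-D sections below.*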

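## 2D

This section applies the 2-D Kronecker-product operators to a field on a 100 × 101 grid. It then compares their timing with mapping the 1-D operator over the grid columns using `jax.vmap`.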

In [32]:
# 2-D test field on a 100 x 101 grid.
x = jnp.linspace(0, 10, 100)
y = jnp.linspace(0, 5, 101)


def f(x, y):
    """Evaluate the 2-D test field sin(x) * cos(y) elementwise."""
    return jnp.sin(x) * jnp.cos(y)


xx, yy = jnp.meshgrid(x, y, indexing='ij')
f_grid = f(xx, yy)

D_x, D_y = build_differential([x, y])

# Flatten the field (row-major), apply each operator, and restore the grid shape.
Dxf, Dyf = ((op @ f_grid.flatten()).reshape(f_grid.shape) for op in (D_x, D_y))
Dxf.shape

(100, 101)

In [27]:
# NOTE: this rebinds ``D_x`` from the 2-D Kronecker operator above to the 1-D
# x-operator (a len(x) x len(x) matrix), which the vmap benchmark below expects.
D_x = build_differential([x])
# Map the 1-D operator over axis 1 of f_grid (one x-derivative per y index);
# out_axes=1 puts the result back in f_grid's layout.
Dxf_vmap = jax.vmap(lambda f: D_x @ f, in_axes=1, out_axes=1)(f_grid)

In [30]:
%%timeit
Dxf_vmap = jax.vmap(lambda f: D_x@f, in_axes = 1, out_axes = 1)(f_grid)

559 μs ± 8.11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [33]:
%%timeit
Dxf = (D_x @ f_grid.flatten()).reshape(f_grid.shape)

224 μs ± 10.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


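## 3D

This section repeats the comparison on a 100 × 101 × 40 grid. It times the 3-D Kronecker-product operator with and without `jax.jit`, and a nested `jax.vmap` of the 1-D operator.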

In [34]:
# 3-D test field on a 100 x 101 x 40 grid.
x = jnp.linspace(0, 10, 100)
y = jnp.linspace(0, 5, 101)
z = jnp.linspace(-2, 2, 40)


def f(x, y, z):
    """Evaluate the 3-D test field sin(x) * cos(y) + z**2 elementwise.

    NOTE(review): this replaces the 2-D ``f`` defined in the previous section;
    consider a distinct name if both are needed.
    """
    return jnp.sin(x) * jnp.cos(y) + z**2


xx, yy, zz = jnp.meshgrid(x, y, z, indexing='ij')
f_grid = f(xx, yy, zz)

D_x, D_y, D_z = build_differential([x, y, z])

# Flatten the field (row-major), apply each operator, and restore the grid shape.
Dxf, Dyf, Dzf = (
    (op @ f_grid.flatten()).reshape(f_grid.shape) for op in (D_x, D_y, D_z)
)

In [59]:
%%timeit
Dxf = (D_x @ f_grid.flatten()).reshape(f_grid.shape)

664 μs ± 4.06 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [57]:
# Rebuild the 3-D operators and wrap the x-derivative in a jitted function.
D_x, D_y, D_z = build_differential([x, y, z])


def foo(f):
    """Apply the 3-D x-derivative operator ``D_x`` to a flattened field ``f``.

    ``D_x`` is a global that is read when ``foo`` is traced. Later calls with
    the same input shape reuse that trace.
    """
    return D_x @ f


foo = jax.jit(foo)

In [58]:
%%timeit
Dxf = (D_x @ f_grid.flatten()).reshape(f_grid.shape)

668 μs ± 11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [40]:
# Rebind D_x to the 1-D x-operator (len(x) x len(x)). The nested-vmap cells
# below apply it along axis 0 of the 3-D f_grid.
D_x = build_differential(axis=[x])

In [50]:
def foo(f):
    """Apply the 1-D operator ``D_x`` to a single vector ``f``.

    This redefines the earlier 3-D ``foo``. ``D_x`` is read when the function is
    traced, so it picks up the 1-D operator bound in the cell above.
    """
    return D_x @ f


foo = jax.jit(foo)

In [56]:
%%timeit
Dxf_vmap = jax.vmap(jax.vmap(foo, in_axes = 1, out_axes = 1), in_axes = 1, out_axes = 1)(f_grid)

448 μs ± 13.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [53]:
%%timeit
Dxf_vmap = jax.vmap(jax.vmap(lambda f: D_x@f, in_axes = 1, out_axes = 1), in_axes = 1, out_axes = 1)(f_grid)

1.11 ms ± 11.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
