In [4]:
import jax.numpy as jnp
from jax import Array
from pyutils.jax.differential import compute_forward_derivative, kron, speye

In [5]:
compute_forward_derivative(jnp.array([1.0, 2.0, 3.0, 5.0]))

BCOO(float32[4, 4], nse=7)

In [6]:
x = jnp.linspace(0, 10, 100)
y = jnp.linspace(0, 5, 101)
z = jnp.linspace(-2, 2, 40)

## 1D

## 2D

In [None]:
import jax
from pyutils.jax import vvmap
from pyutils.jax import build_differential_matrices

def f(x, y):
    return jnp.sin(x) * jnp.cos(y)

x = jnp.linspace(0, 10, 100)
y = jnp.linspace(0, 5, 101)


In [None]:
xx, yy = jnp.meshgrid(x, y, indexing='ij')
f_grid = f(xx, yy)
f_x_analytiacl = vvmap(jax.grad(f, argnums=0))(xx,yy)
f_y_analytical = vvmap(jax.grad(f, argnums=1))(xx,yy)
f_xx_analytical = vvmap(jax.grad(jax.grad(f, argnums=0), argnums=0))(xx,yy)
f_yy_analytical = vvmap(jax.grad(jax.grad(f, argnums=1), argnums=1))(xx,yy)
f_xy_analytical = vvmap(jax.grad(jax.grad(f, argnums=1), argnums=0))(xx,yy)

In [8]:
D_x_forward, D_x_backward, \
    D_y_forward, D_y_backward,\
        D_xx, D_yy, D_xi, D_eta = build_differential_matrices(x,y)

## A simple example

In [None]:
from pyutils.jax.differential.differential import compute_D_x



def f(x, y):
    return jnp.sin(x) * jnp.cos(y)

xx, yy = jnp.meshgrid(x, y, indexing='ij')
f_grid = f(xx, yy)

D_x, D_y = build_differential([x, y])

Dxf = (D_x @ f_grid.flatten()).reshape(f_grid.shape)
Dyf = (D_y @ f_grid.flatten()).reshape(f_grid.shape)
Dxf.shape

(100, 101)

In [6]:
from pyutils.jax.differential import compute_D_xi, compute_D_eta
D_xi = compute_D_xi(x, y)
D_eta = compute_D_eta(x, y)
hx = x[1] - x[0]
hy = y[1] - y[0]
D_xy = (hx**2 + hy**2) / (4 * hx * hy) * (D_xi - D_eta)
f_xy = (D_xy @ f_grid.flatten()).reshape(f_grid.shape)
f_xy_exact = -jnp.cos(xx)*jnp.sin(yy)


f_xy

Array([[ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        , -0.04961872, -0.09911537, ...,  0.975373  ,
         0.9649062 ,  0.        ],
       [ 0.        , -0.0488596 , -0.09759521, ...,  0.96043277,
         0.9501257 ,  0.        ],
       ...,
       [ 0.        ,  0.04644012,  0.0927639 , ..., -0.9128742 ,
        -0.9030776 ,  0.        ],
       [ 0.        ,  0.04437065,  0.08863068, ..., -0.8721881 ,
        -0.8628278 ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)

In [7]:
f_xy_exact

Array([[-0.        , -0.04997917, -0.09983341, ...,  0.9824527 ,
         0.9719031 ,  0.9589243 ],
       [-0.        , -0.04972441, -0.09932454, ...,  0.97744495,
         0.9669491 ,  0.9540365 ],
       [-0.        , -0.04896275, -0.09780313, ...,  0.9624728 ,
         0.9521377 ,  0.9394229 ],
       ...,
       [ 0.        ,  0.04653884,  0.09296137, ..., -0.9148255 ,
        -0.90500206, -0.8929167 ],
       [ 0.        ,  0.04446411,  0.08881709, ..., -0.874042  ,
        -0.86465645, -0.85310984],
       [ 0.        ,  0.04193609,  0.08376738, ..., -0.8243481 ,
        -0.8154962 , -0.8046061 ]], dtype=float32)

Array([[ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ],
       [ 0.        , -0.04961872, -0.09911537, ...,  0.975373  ,
         0.9649062 ,  0.        ],
       [ 0.        , -0.0488596 , -0.09759521, ...,  0.96043277,
         0.9501257 ,  0.        ],
       ...,
       [ 0.        ,  0.04644012,  0.0927639 , ..., -0.9128742 ,
        -0.9030776 ,  0.        ],
       [ 0.        ,  0.04437065,  0.08863068, ..., -0.8721881 ,
        -0.8628278 ,  0.        ],
       [ 0.        ,  0.        ,  0.        , ...,  0.        ,
         0.        ,  0.        ]], dtype=float32)

In [26]:
f_xy_exact - f_xy

Array([[-0.00000000e+00, -4.99791652e-02, -9.98334140e-02, ...,
         9.82452691e-01,  9.71903086e-01,  9.58924294e-01],
       [-0.00000000e+00, -1.05690211e-04, -2.09167600e-04, ...,
         2.07191706e-03,  2.04288960e-03,  9.54036474e-01],
       [-0.00000000e+00, -1.03157014e-04, -2.07915902e-04, ...,
         2.04002857e-03,  2.01201439e-03,  9.39422905e-01],
       ...,
       [ 0.00000000e+00,  9.87201929e-05,  1.97470188e-04, ...,
        -1.95127726e-03, -1.92445517e-03, -8.92916679e-01],
       [ 0.00000000e+00,  9.34600830e-05,  1.86413527e-04, ...,
        -1.85388327e-03, -1.82867050e-03, -8.53109837e-01],
       [ 0.00000000e+00,  4.19360921e-02,  8.37673768e-02, ...,
        -8.24348092e-01, -8.15496206e-01, -8.04606080e-01]],      dtype=float32)

In [None]:
from pyutils.jax.differential import compute_D_x, compute_D_xy
D_x = compute_D_x(x,y,"forward")
D_xy = compute_D_xy(x,y)

In [None]:
mask = D_xy.indices[:, 0] == 3
D_xy.indices[mask], D_xy.data[mask]

(Array([[  3, 105],
        [  3, 103],
        [  3,   4],
        [  3,   2]], dtype=int32),
 Array([ 99.00001, -99.00001, -99.00001,  99.00001], dtype=float32))

ImportError: cannot import name 'compute_D_xi' from 'pyutils.jax.differential' (/Users/marcdelabarrera/IESE Dropbox/Marc de la Barrera/repos/pyutils/src/pyutils/jax/differential/__init__.py)

In [None]:
Vxy_approx = (hx**2 + hy**2) * ((D_xi @ fv) - (D_eta @ fv)) / (4 * hx * hy)


In [36]:
D_xy

BCOO(float32[10100, 10100], nse=40400)

In [32]:
mask = D_x.indices[:, 0] == 3
D_x.indices[mask], D_x.data[mask]

(Array([[  3, 104],
        [  3,   3]], dtype=int32),
 Array([ 9.900001, -9.900001], dtype=float32))

In [16]:
compute_forward_derivative(x)

BCOO(float32[100, 100], nse=199)

In [17]:
speye(len(y))

BCOO(float32[101, 101], nse=101)

In [18]:
kron(compute_forward_derivative(x), speye(y.shape[0]))

BCOO(float32[10100, 10100], nse=20099)

In [27]:
import jax
D_x = build_differential([x])
Dxf_vmap = jax.vmap(lambda f: D_x@f, in_axes = 1, out_axes = 1)(f_grid)

In [30]:
%%timeit
Dxf_vmap = jax.vmap(lambda f: D_x@f, in_axes = 1, out_axes = 1)(f_grid)

559 μs ± 8.11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [33]:
%%timeit
Dxf = (D_x @ f_grid.flatten()).reshape(f_grid.shape)

224 μs ± 10.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## 3D

In [34]:
x = jnp.linspace(0, 10, 100)
y = jnp.linspace(0, 5, 101)
z = jnp.linspace(-2, 2, 40)

def f(x, y, z):
    return jnp.sin(x) * jnp.cos(y)+ z**2

xx, yy, zz = jnp.meshgrid(x, y, z, indexing='ij')
f_grid = f(xx, yy, zz)

D_x, D_y, D_z = build_differential([x, y, z])

Dxf = (D_x @ f_grid.flatten()).reshape(f_grid.shape)
Dyf = (D_y @ f_grid.flatten()).reshape(f_grid.shape)
Dzf = (D_z @ f_grid.flatten()).reshape(f_grid.shape)

In [59]:
%%timeit
Dxf = (D_x @ f_grid.flatten()).reshape(f_grid.shape)

664 μs ± 4.06 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [57]:
D_x, D_y, D_z = build_differential([x, y, z])
@jax.jit
def foo(f):
    return D_x @ f

In [58]:
%%timeit
Dxf = (D_x @ f_grid.flatten()).reshape(f_grid.shape)

668 μs ± 11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [40]:
D_x = build_differential([x])

In [50]:
@jax.jit
def foo(f):
    return D_x@f

In [56]:
%%timeit
Dxf_vmap = jax.vmap(jax.vmap(foo, in_axes = 1, out_axes = 1), in_axes = 1, out_axes = 1)(f_grid)

448 μs ± 13.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [53]:
%%timeit
Dxf_vmap = jax.vmap(jax.vmap(lambda f: D_x@f, in_axes = 1, out_axes = 1), in_axes = 1, out_axes = 1)(f_grid)

1.11 ms ± 11.9 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
