In [26]:
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp  # noqa: E402
import numpy as np  # noqa: E402
import scipy.optimize  # noqa: E402

# Grid of 300 bins, each 0.01 wide, spanning [0, 3].
DZ = 0.01
# 0.01, 0.02, ..., 3.00 -- i.e. the upper edge of every bin (equal to ZHIGH).
ZVALS = DZ + np.linspace(0, 3, 301)[:-1]
# Bin edges: a leading 0 followed by ZVALS (301 edges in total).
ZBINS = np.insert(ZVALS, 0, 0.0)
# Lower and upper edge of each of the 300 bins.
ZLOW, ZHIGH = ZBINS[:-1], ZBINS[1:]

In [294]:
@jax.jit
def sompz_integral(y, x, low, high):
    """Integrate tabulated values ``y`` on the grid ``x`` over ``[low, high]``.

    Each grid bin ``[x[i], x[i+1]]`` is assigned the constant value
    ``(y[i] + y[i+1]) / 2``. Whole bins therefore integrate exactly as the
    trapezoid rule, and a partially covered bin contributes that constant
    times the covered width. For a partial bin this generally differs from
    the exact integral of the linear interpolant of ``y``.

    Parameters
    ----------
    y : array
        1-D values tabulated at the grid points ``x`` (same length as ``x``).
    x : array
        1-D grid points, assumed sorted in increasing order; spacing may vary.
    low, high : float
        Integration limits. Limits outside ``[x[0], x[-1]]`` are clipped to
        the grid. If ``high < low`` the result is the negative of the
        integral over ``[high, low]``.

    Returns
    -------
    scalar array
        The integral of ``y`` over the clipped range.
    """
    n = x.shape[0]

    # Orient the limits so lo <= hi and remember the sign for reversed limits.
    sign = jnp.where(high < low, -1.0, 1.0)
    lo = jnp.minimum(x[-1], jnp.maximum(jnp.minimum(low, high), x[0]))
    hi = jnp.minimum(x[-1], jnp.maximum(jnp.maximum(low, high), x[0]))

    # Bin k spans [x[k], x[k+1]]. digitize(right=False) places lo in bin
    # low_ind - 1 (x[low_ind-1] <= lo < x[low_ind]); digitize(right=True)
    # places hi in bin high_ind - 1 (x[high_ind-1] < hi <= x[high_ind]).
    # Clamp so the gathers below never depend on JAX's out-of-bounds
    # clamping (lo == x[-1]) or negative-index wrap-around (hi == x[0]).
    low_ind = jnp.minimum(jnp.digitize(lo, x), n - 1)
    high_ind = jnp.maximum(jnp.digitize(hi, x, right=True), 1)

    dx = x[1:] - x[:-1]
    # Constant value assigned to each bin: mean of its two endpoint samples.
    ybin = (y[1:] + y[:-1]) / 2.0

    # The two limits lie in different bins.
    not_in_single_bin = high_ind > low_ind

    # Partial bin holding lo: covered width is x[low_ind] - lo. Written
    # without dividing by dx so zero-width bins cannot produce inf/NaN.
    # When both limits share a bin, that bin contributes its value * (hi - lo).
    y_left = ybin[low_ind - 1]
    ileft = jnp.where(
        not_in_single_bin,
        y_left * (x[low_ind] - lo),
        y_left * (hi - lo),
    )

    # Partial bin holding hi: covered width is hi - x[high_ind - 1].
    iright = jnp.where(
        not_in_single_bin,
        ybin[high_ind - 1] * (hi - x[high_ind - 1]),
        0.0,
    )

    # Whole bins strictly between the two partial bins: low_ind .. high_ind - 2.
    # An empty mask simply sums to zero.
    bin_ids = jnp.arange(ybin.shape[0])
    central = (bin_ids >= low_ind) & (bin_ids < high_ind - 1)
    icen = jnp.sum(jnp.where(central, ybin * dx, 0.0))

    return sign * (ileft + icen + iright)

In [297]:
# Sanity check: integrating y == 1 on a unit-spaced grid from 3 to 5.7 must
# return the length of the interval, 2.7.
y = np.ones(10)
dx = 1
x = np.arange(10) * dx

result = sompz_integral(y, x, 3, 5.7)
np.testing.assert_allclose(result, 2.7)
result

Array(2.7, dtype=float64)

In [15]:
# Probe: the integer grid 0..9 that the index-convention checks below use.
jnp.arange(10)

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)

In [53]:
# jnp.digitize with the default right=False returns i such that
# x[i-1] <= v < x[i]; sompz_integral uses this convention for its lower limit.
jnp.digitize(0.5, jnp.arange(10))

Array(1, dtype=int32)

In [45]:
# searchsorted(side="right") - 1 is the index of the last grid point <= v
# (3 for both v=3 and v=3.5), i.e. the same as digitize(v, x) - 1.
jnp.searchsorted(jnp.arange(10), 3, side="right")-1, jnp.searchsorted(jnp.arange(10), 3.5, side="right")-1

(Array(3, dtype=int32), Array(3, dtype=int32))

In [46]:
# searchsorted(side="left") is the index of the first grid point >= v
# (4 for both v=3.5 and v=4); this matches digitize(v, x, right=True), the
# convention sompz_integral uses for its upper limit.
jnp.searchsorted(jnp.arange(10), 3.5, side="left"), jnp.searchsorted(jnp.arange(10), 4, side="left")

(Array(4, dtype=int32), Array(4, dtype=int32))

In [39]:
# NOTE(review): stale cell. Evaluating the literal `-1` displays -1, but the
# saved output is Array(3, dtype=int32). The expression was probably truncated
# from a searchsorted(...) - 1 probe like the one above; restore or delete it.
-1

Array(3, dtype=int32)

In [27]:
# NOTE(review): the saved output starts with two zeros (0., 0., 0.01, ...),
# but the grid-definition cell now yields a single leading 0 followed by 0.01
# (ZVALS is shifted by DZ). The output is stale; re-run to refresh it.
ZBINS

array([0.  , 0.  , 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09,
       0.1 , 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2 ,
       0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3 , 0.31,
       0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4 , 0.41, 0.42,
       0.43, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5 , 0.51, 0.52, 0.53,
       0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6 , 0.61, 0.62, 0.63, 0.64,
       0.65, 0.66, 0.67, 0.68, 0.69, 0.7 , 0.71, 0.72, 0.73, 0.74, 0.75,
       0.76, 0.77, 0.78, 0.79, 0.8 , 0.81, 0.82, 0.83, 0.84, 0.85, 0.86,
       0.87, 0.88, 0.89, 0.9 , 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97,
       0.98, 0.99, 1.  , 1.01, 1.02, 1.03, 1.04, 1.05, 1.06, 1.07, 1.08,
       1.09, 1.1 , 1.11, 1.12, 1.13, 1.14, 1.15, 1.16, 1.17, 1.18, 1.19,
       1.2 , 1.21, 1.22, 1.23, 1.24, 1.25, 1.26, 1.27, 1.28, 1.29, 1.3 ,
       1.31, 1.32, 1.33, 1.34, 1.35, 1.36, 1.37, 1.38, 1.39, 1.4 , 1.41,
       1.42, 1.43, 1.44, 1.45, 1.46, 1.47, 1.48, 1.

In [197]:
# Interactive help lookup left over from development (jnp.minimum is the
# clipping op used in sompz_integral); remove before sharing the notebook.
jnp.minimum?

Signature:      jnp.minimum(x1, x2, /)
Call signature: jnp.minimum(*args, **kwargs)
Type:           PjitFunction
String form:    <PjitFunction of <function jax.numpy.minimum at 0x11e5be170>>
File:           ~/mambaforge/envs/work/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py
Docstring:
Element-wise minimum of array elements.

LAX-backend implementation of :func:`numpy.minimum`.

*Original docstring below.*

Compare two arrays and return a new array containing the element-wise
minima. If one of the elements being compared is a NaN, then that
element is returned. If both elements are NaNs then the first is
returned. The latter distinction is important for complex NaN