In [13]:
import jax.numpy as jnp

# Construct three identical rows of primitive states: a linear profile
# (value = x) plus a unit jump for x < 0.5, sampled at cell centers that are
# the squares of evenly spaced points (so the grid is non-uniform, denser near 0).
cell_centers = jnp.arange(0.0, 1.0, 0.05) ** 2

# Distance from each interior cell to its left / right neighbour.
cell_distances_left = cell_centers[1:-1] - cell_centers[:-2]
cell_distances_right = cell_centers[2:] - cell_centers[1:-1]

# Every row carries the same profile, so tile one row instead of writing the
# step and the ramp into a zero-initialised array.
primitive_states = jnp.tile(
    jnp.where(cell_centers < 0.5, 1.0, 0.0) + cell_centers,
    (3, 1),
)

# formulation 1: ratio form phi(r) * a with r = b / a and the minmod limiter
# phi(r) = max(0, min(1, r)), where a / b are the left / right one-sided slopes.
a = (primitive_states[:, 1:-1] - primitive_states[:, :-2]) / cell_distances_left
b = (primitive_states[:, 2:] - primitive_states[:, 1:-1]) / cell_distances_right
# Guard the divisor, not just the result: jnp.where evaluates both branches, so
# a raw b / a yields inf/nan wherever a == 0. That can trip jax_debug_nans and
# turns into NaN gradients under jax.grad even though the branch is discarded.
nonzero_a = a != 0
g = jnp.where(nonzero_a, b / jnp.where(nonzero_a, a, 1.0), 0.0)
slope_limited = jnp.maximum(0, jnp.minimum(1, g))  # minmod limiter phi(r)
limited_gradients_1 = slope_limited * a

# formulation 2:
def _minmod(a, b):
    return 0.5 * (jnp.sign(a) + jnp.sign(b)) * jnp.minimum(jnp.abs(a), jnp.abs(b))
# One-sided slopes between each interior cell and its neighbours.
left_slopes = (primitive_states[:, 1:-1] - primitive_states[:, :-2]) / cell_distances_left
right_slopes = (primitive_states[:, 2:] - primitive_states[:, 1:-1]) / cell_distances_right
limited_gradients_2 = _minmod(left_slopes, right_slopes)

for limited_gradients in (limited_gradients_1, limited_gradients_2):
    print(limited_gradients)

# The ratio-limiter form and the closed-form minmod must agree up to rounding
# (signed zeros such as -0. vs 0. compare equal).
assert jnp.allclose(limited_gradients_1, limited_gradients_2)

[[ 0.99999106  0.99999106  0.9999955   0.9999955   1.0000011   0.9999975
   0.9999975   1.0000002   1.0000004   1.          1.          0.9999995
   0.9999995   0.         -0.          1.          1.          1.        ]
 [ 0.99999106  0.99999106  0.9999955   0.9999955   1.0000011   0.9999975
   0.9999975   1.0000002   1.0000004   1.          1.          0.9999995
   0.9999995   0.         -0.          1.          1.          1.        ]
 [ 0.99999106  0.99999106  0.9999955   0.9999955   1.0000011   0.9999975
   0.9999975   1.0000002   1.0000004   1.          1.          0.9999995
   0.9999995   0.         -0.          1.          1.          1.        ]]
[[0.99999106 0.99999106 0.9999955  0.9999955  1.0000011  0.9999975
  0.9999975  1.0000002  1.0000004  1.         1.         0.9999995
  0.9999995  0.         0.         1.         1.         1.        ]
 [0.99999106 0.99999106 0.9999955  0.9999955  1.0000011  0.9999975
  0.9999975  1.0000002  1.0000004  1.         1.         0.9999995