Well, I guess you use JAX not just for fun but to actually make your code faster.
The way to do that is to jit your functions, but it is not always as easy as one would like.
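For reference, here is a minimal sketch of what jitting looks like. The function names `scaled_sum` and `squared_norm` are made up for illustration. `jax.jit` can be used either as a decorator or as a plain function that wraps an existing function:

```python
import jax.numpy as jnp
from jax import jit

# jit used as a decorator
@jit
def scaled_sum(x):
    return jnp.sum(2.0 * x + 1.0)

# jit used as a function wrapping an existing Python function
def squared_norm(x):
    return jnp.dot(x, x)

squared_norm_jit = jit(squared_norm)

x = jnp.arange(4.0)
# The first call traces and compiles; later calls with the same
# input shape and dtype reuse the compiled version.
print(scaled_sum(x))        # 16.0
print(squared_norm_jit(x))  # 14.0
```

The trouble, as the rest of this notebook shows, starts when the function does things that depend on the concrete values of its inputs.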

In [5]:
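import jax.numpy as jnp
from jax import random, jit
# jax.ops.index_update / jax.ops.index were removed from recent JAX releases;
# functional updates are written with the x.at[...] syntax instead.

import numpy as onp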

#### Slicing and jitting

In [18]:
# Same example as before: JAX arrays are immutable, so in-place assignment
# is not allowed and we had to use a workaround
def neg_part_of_the_array(n):
    an_array = jnp.arange(0, 10, 1)
    # an_array[n:] *= -1  -> fails: JAX arrays do not support item assignment
    # Workaround: rebuild the array from the untouched head and the negated tail
    an_array = jnp.append(an_array[:n], -1 * an_array[n:])
    return an_array

In [25]:
my_array = neg_part_of_the_array(3)
print(my_array)

[ 0  1  2 -3 -4 -5 -6 -7 -8 -9]
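JAX's own answer to item assignment is the functional `.at` syntax, which took over from the older `jax.ops.index_update` helper. A minimal sketch of the same computation with it (the name `neg_part_with_at` is illustrative); note that `n` is still used as a slice bound, so under jit it would still have to be static:

```python
import jax.numpy as jnp

def neg_part_with_at(n):
    an_array = jnp.arange(0, 10, 1)
    # Instead of an_array[n:] *= -1, .at[...] returns a *new* array with the
    # selected entries multiplied by -1; an_array itself is left unchanged.
    return an_array.at[n:].multiply(-1)

print(neg_part_with_at(3))  # [ 0  1  2 -3 -4 -5 -6 -7 -8 -9]
```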


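One would like to jit it, but nope:

neg_part_of_the_array_jit = jit(neg_part_of_the_array)

# Raises an error: under jit, n is a tracer, and the slice bounds in
# an_array[:n] must be static (known at compile time)
my_array_jit = neg_part_of_the_array_jit(3)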
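Since the error output of that cell is not shown, here is a small self-contained way to reproduce it. The exact exception class and message differ between JAX versions, so the sketch catches a generic `Exception` and only prints the class name:

```python
import jax.numpy as jnp
from jax import jit

def neg_part_of_the_array(n):
    an_array = jnp.arange(0, 10, 1)
    return jnp.append(an_array[:n], -1 * an_array[n:])

try:
    jit(neg_part_of_the_array)(3)
    failed = False
except Exception as err:  # exact exception type varies between JAX versions
    failed = True
    print(type(err).__name__)
```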

In [22]:
def a_work_around(n):
    an_array = jnp.arange(0, 10, 1)
    m = an_array.shape[0] - n
    # Sign mask: n ones followed by m minus-ones, so no slicing by n is needed
    my_ones = jnp.append(jnp.ones([n]), -1 * jnp.ones([m]), axis=0)
    return my_ones * an_array

In [24]:
my_array_work_around = a_work_around(3)
print(my_array_work_around)

[ 0.  1.  2. -3. -4. -5. -6. -7. -8. -9.]


Can we jit it now?

a_work_around_jit = jit(a_work_around)

# Still an error: jnp.ones([n]) needs a concrete shape, but n is a tracer
my_array_work_around_jit = a_work_around_jit(3)

In [33]:
# Not so fast... n determines the array shapes, so it has to be static:
# static_argnums=0 makes jit treat n as a compile-time constant
a_work_around_jit = jit(a_work_around, static_argnums=0)
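One consequence of static_argnums worth knowing: jit compiles a separate version for every distinct value of the static argument. A minimal sketch that makes this visible, using a hypothetical global counter that is incremented by a Python side effect (such side effects only run while jit is tracing, not when the cached compiled version is reused):

```python
import jax.numpy as jnp
from jax import jit

trace_count = 0  # hypothetical counter, only for illustration

def a_work_around_counted(n):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    an_array = jnp.arange(0, 10, 1)
    m = an_array.shape[0] - n
    my_ones = jnp.append(jnp.ones([n]), -1 * jnp.ones([m]), axis=0)
    return my_ones * an_array

counted_jit = jit(a_work_around_counted, static_argnums=0)

counted_jit(3)  # traces and compiles for n=3
counted_jit(3)  # same static value: reuses the cached compilation
counted_jit(4)  # new static value: traces and compiles again
print(trace_count)  # 2
```

So static arguments are fine when they take only a few distinct values, but calling with many different values means many compilations.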

In [35]:
my_array_work_around = a_work_around_jit(3)
print(my_array_work_around)

[ 0.  1.  2. -3. -4. -5. -6. -7. -8. -9.]
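A different route, not the one used above: avoid making array shapes depend on n at all. Comparing positions against n gives a mask of fixed shape, and jnp.where picks the sign, so n can stay a traced (non-static) argument. A minimal sketch, with the illustrative name `neg_part_masked`:

```python
import jax.numpy as jnp
from jax import jit

@jit
def neg_part_masked(n):
    an_array = jnp.arange(0, 10, 1)
    # The comparison works on a traced n, and the mask always has shape (10,),
    # so no argument needs to be declared static.
    return jnp.where(an_array < n, an_array, -an_array)

print(neg_part_masked(3))  # [ 0  1  2 -3 -4 -5 -6 -7 -8 -9]
# Same shape and dtype of n, so the compiled version is reused.
print(neg_part_masked(5))  # [ 0  1  2  3  4 -5 -6 -7 -8 -9]
```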


#### Index conditioning and jitting

In [38]:
def bigger_than_zero(arr):
    return arr[jnp.where(arr > 0)]

In [39]:
arr = jnp.array([1,-1,2,-2])
arr_positive = bigger_than_zero(arr)
print(arr_positive)

[1 2]


Can we jit it?

bigger_than_zero_jit = jit(bigger_than_zero)

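# Raises an error: the size of arr[jnp.where(arr > 0)] depends on the values
# in arr, so jit cannot know the output shape at compile time
arr_positive_jit = bigger_than_zero_jit(arr)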

In [42]:
# Of course not! JAX really doesn't like boolean indexing when jitting.
# Instead, keep the shape fixed and fill the unwanted entries with NaN:

def bigger_than_zero_that_will_jit(arr):
    my_nans = jnp.nan * jnp.ones(arr.shape)
    return jnp.where(arr > 0, arr, my_nans)

In [43]:
bigger_than_zero_that_will_jit_jit = jit(bigger_than_zero_that_will_jit)

In [45]:
arr_positive_jit = bigger_than_zero_that_will_jit_jit(arr)
print(arr_positive_jit)

[ 1. nan  2. nan]
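If you would rather have the positive values packed at the front (as plain boolean indexing gives) while still keeping a fixed shape, jnp.nonzero accepts a static `size` and a `fill_value` for padding. A minimal sketch under that assumption, using the array length as the upper bound and the illustrative name `positive_values_padded`; unlike the version above, it moves the kept values to the front instead of leaving them in place:

```python
import jax.numpy as jnp
from jax import jit

@jit
def positive_values_padded(arr):
    size = arr.shape[0]  # static under jit: upper bound on the number of hits
    # With a static size, nonzero returns fixed-length index arrays,
    # padded with fill_value when there are fewer matches.
    (idx,) = jnp.nonzero(arr > 0, size=size, fill_value=0)
    count = jnp.sum(arr > 0)
    # Keep the first `count` gathered values and mark the padding with NaN.
    return jnp.where(jnp.arange(size) < count, arr[idx], jnp.nan)

arr = jnp.array([1, -1, 2, -2])
print(positive_values_padded(arr))  # [ 1.  2. nan nan]
```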


In short, if you want to jit, your arrays must keep a shape that does not depend on their values. To mark the entries you don't want, you can put inf, zeros, NaNs or any other value that makes sense for the computation you want to do later.
For example, if you want to sum the array, putting 0s in place of the entries you want to filter out makes sense.
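A minimal sketch of that sum example (the name `sum_of_positives` is illustrative). Zero is neutral for addition, so masking with 0 keeps the shape fixed without changing the result; with the NaN-filled version, jnp.nansum skips the NaNs instead:

```python
import jax.numpy as jnp
from jax import jit

@jit
def sum_of_positives(arr):
    # Replace the entries we want to ignore with 0, then sum as usual.
    return jnp.sum(jnp.where(arr > 0, arr, 0))

arr = jnp.array([1, -1, 2, -2])
print(sum_of_positives(arr))  # 3

# NaN-filled alternative: nansum ignores the NaN entries.
print(jnp.nansum(jnp.where(arr > 0, arr, jnp.nan)))  # 3.0
```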