In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

### Simple function

In [2]:
def linear(x, w, b):
    """Apply an affine map: the dot product of ``x`` and ``w`` plus ``b``.

    Args:
        x: input array.
        w: weight array, combined with ``x`` via ``jnp.dot``.
        b: bias added to the dot product (broadcast against it).

    Returns:
        The array ``jnp.dot(x, w) + b``.
    """
    weighted = jnp.dot(x, w)
    return weighted + b

In [3]:
# Sample float32 inputs for `linear`. Because `b` has one entry per element,
# the scalar `jnp.dot(x, w)` is broadcast against it and the result is a 5-vector.
x = jnp.arange(1, 6, dtype=jnp.float32)
w = jnp.array([0.5, 0.4, 0.6, 0.8, 0.9], dtype=jnp.float32)
b = jnp.array([-1.0, -3.0, -6.0, -19.0, -2.0], dtype=jnp.float32)



In [4]:
# Evaluate the plain (not JIT-compiled) function on the sample inputs.
linear(x=x, w=w, b=b)

DeviceArray([ 9.8,  7.8,  4.8, -8.2,  8.8], dtype=float32)

### Just-in-time (JIT) compilation

In [5]:
# Wrap `linear` with `jit`: `optimized_linear` takes the same arguments as
# `linear`, but JAX traces and compiles it (compilation happens on first call).
optimized_linear = jit(linear)

In [6]:
# First call of the jitted function: JAX compiles it here, so the %timeit of
# `optimized_linear` below measures the already-compiled version. Arguments are
# kept positional to match that %timeit call (same jit cache entry).
# The output should equal the un-jitted `linear(x, w, b)` result above.
optimized_linear(x, w, b) 

DeviceArray([ 9.8,  7.8,  4.8, -8.2,  8.8], dtype=float32)

In [7]:
%timeit -n 100 linear(x, w, b).block_until_ready()

244 µs ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit -n 100 optimized_linear(x, w, b).block_until_ready()

25.1 µs ± 2.04 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Other ways to apply `jit`

In [9]:
# Form 1: pass a lambda straight to `jit`.
plus_one = jit(lambda x: x + 1)


# Form 2: use `jit` as a decorator.
@jit
def plus_two(x):
    """Return ``x + 2``, JIT-compiled via the ``@jit`` decorator."""
    return x + 2


(plus_one(7), plus_two(7))

(DeviceArray(8, dtype=int32), DeviceArray(9, dtype=int32))