<a href="https://colab.research.google.com/github/hsudhakaran/test_jax/blob/main/Jax_tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax.numpy as jnp
import jax

In [None]:
def f(x):
    """Evaluate the cubic 4x^3 + 3x^2 + 2x + 1 at ``x``."""
    return 4 * x**3 + 3 * x**2 + 2 * x + 1


# Trace f at a scalar float sample to display its jaxpr.
trace_f = jax.make_jaxpr(f)
trace_f(2.0)

{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=3] a
    c:f32[] = mul 4.0 b
    d:f32[] = integer_pow[y=2] a
    e:f32[] = mul 3.0 d
    f:f32[] = add c e
    g:f32[] = mul 2.0 a
    h:f32[] = add f g
    i:f32[] = add h 1.0
  in (i,) }

In [None]:
# Reverse-mode derivative of f with respect to its first argument:
# f'(x) = 12*x**2 + 6*x + 2.
grad_f = jax.grad(f, argnums=0)
jax.make_jaxpr(grad_f)(2.0)

{ lambda ; a:f32[]. let
    b:f32[] = integer_pow[y=3] a
    c:f32[] = integer_pow[y=2] a
    d:f32[] = mul 3.0 c
    e:f32[] = mul 4.0 b
    f:f32[] = integer_pow[y=2] a
    g:f32[] = integer_pow[y=1] a
    h:f32[] = mul 2.0 g
    i:f32[] = mul 3.0 f
    j:f32[] = add e i
    k:f32[] = mul 2.0 a
    l:f32[] = add j k
    _:f32[] = add l 1.0
    m:f32[] = mul 2.0 1.0
    n:f32[] = mul 3.0 1.0
    o:f32[] = mul n h
    p:f32[] = add_any m o
    q:f32[] = mul 4.0 1.0
    r:f32[] = mul q d
    s:f32[] = add_any p r
  in (s,) }

In [None]:
# f'(2) = 12*4 + 6*2 + 2 = 62 — reuse the gradient function built above.
grad_f(2.0)

Array(62., dtype=float32, weak_type=True)

In [None]:
def matrix_mul(a, b):
    """Return the matrix product ``a @ b`` computed with ``jnp.matmul``."""
    return jnp.matmul(a, b)


key = jax.random.PRNGKey(42)
# JAX random draws are a deterministic function of the key, so reusing one
# key for both operands makes them correlated rather than independent.
# Split it into one subkey per array instead.
key_a, key_b = jax.random.split(key)
a = jax.random.normal(key_a, shape=(1000, 5000))
b = jax.random.normal(key_b, shape=(5000, 1000))
jax.make_jaxpr(matrix_mul)(a, b)

{ lambda ; a:f32[1000,5000] b:f32[5000,1000]. let
    c:f32[1000,1000] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a b
  in (c,) }

In [None]:
# Normal computation
%timeit -n5 matrix_mul(a, b).block_until_ready()

The slowest run took 31.46 times longer than the fastest. This could mean that an intermediate result is being cached.
14.6 ms ± 27.5 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [None]:
# Wrap matrix_mul with jax.jit; tracing the wrapper shows the same
# dot_general nested inside a `pjit` call named after the function.
jit_matrix_mul = jax.jit(matrix_mul)
jax.make_jaxpr(jit_matrix_mul)(a, b)

{ lambda ; a:f32[1000,5000] b:f32[5000,1000]. let
    c:f32[1000,1000] = pjit[
      jaxpr={ lambda ; d:f32[1000,5000] e:f32[5000,1000]. let
          f:f32[1000,1000] = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] d e
        in (f,) }
      name=matrix_mul
    ] a b
  in (c,) }

In [None]:
# warmup
warmup_results = jit_matrix_mul(a, b)
# ⚡️ speed em up!
%timeit -n5 jit_matrix_mul(a, b).block_until_ready()

2.3 ms ± 220 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [None]:
def f_def(x):
    """Return ``x`` multiplied by itself (elementwise for arrays)."""
    squared = x * x
    return squared

g_def = jax.vmap(f_def)
x_test = jnp.array([2,4,6])
%timeit -n5 g_def(a).block_until_ready()
jitted_g = jax.jit(g_def)
jitted_g(b)
%timeit -n5 jitted_g(a).block_until_ready()

1.32 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)
The slowest run took 22.53 times longer than the fastest. This could mean that an intermediate result is being cached.
1.27 ms ± 2.34 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)


In [None]:
# NOTE(review): this cell raised RuntimeError. In JAX >= 0.4.0,
# `jax.tools.colab_tpu.setup_tpu()` only raises, because the legacy Colab
# TPU runtime is unsupported. Newer releases drop the module entirely.
# Report the failure instead of aborting Restart & Run All. Presumably, on a
# TPU VM runtime `jax.devices()` lists the TPUs without any setup call.
try:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
except (ImportError, RuntimeError) as err:
    print(f"Skipping legacy Colab TPU setup: {err}")
jax.devices()

RuntimeError: ignored