<center>
    <h1>JaxTon</h1>
    <i>💯 JAX exercises</i>
    <br>
    <br>
    <a href='https://github.com/vopani/jaxton/blob/master/LICENSE'>
        <img src='https://img.shields.io/badge/license-Apache%202.0-blue.svg?logo=apache'>
    </a>
    <a href='https://github.com/vopani/jaxton'>
        <img src='https://img.shields.io/github/stars/vopani/jaxton?color=yellowgreen&logo=github'>
    </a>
    <a href='https://twitter.com/vopani'>
        <img src='https://img.shields.io/twitter/follow/vopani'>
    </a>
</center>

<center>
    This is Set 4: Just-In-Time (JIT) Compilation (Exercises 31-40) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>

**Prerequisites**

* Configure JAX as shown in the code snippet below to use TPUs. If no TPU is configured, the snippet leaves the default backend in place.
* The sample function `cube` is used throughout the exercises.

In [None]:
!python3 -m pip install jax

In [1]:
import jax
import jax.numpy as jnp
import os
import requests

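# Connect to a Colab TPU when TPU_NAME is set; otherwise keep the default backend.
try:
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
except Exception:
    # No TPU available (e.g. TPU_NAME is unset): fall back to CPU/GPU.
    pass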

jax.devices()

[GpuDevice(id=0, process_index=0)]

In [2]:
def cube(x):
    return x**3

cube(2.1)

9.261000000000001

**Exercise 31: JIT-compile the `cube` function and assign it to `cube_jit`**

In [3]:
cube_jit = jax.jit(cube)
cube_jit

<CompiledFunction of <function cube at 0x7fa81eee2550>>
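
`jax.jit` traces the Python function once per input shape and dtype and caches the compiled result. The sketch below makes that visible with a Python side effect; `cube_traced` and `trace_count` are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical list used only to count how many times JAX traces the function.
trace_count = []

@jax.jit  # decorator form, equivalent to wrapping the function with jax.jit(...)
def cube_traced(x):
    # Python side effects like this append run only while JAX traces the
    # function, not on every call of the compiled version.
    trace_count.append(1)
    return x**3

cube_traced(2.0)  # first call: trace and compile for a float32 scalar
cube_traced(3.0)  # same shape and dtype: reuses the cached compilation
print(len(trace_count))  # 1

cube_traced(jnp.ones(3))  # new input shape (3,): traced and compiled again
print(len(trace_count))  # 2
```

This is also why the first call in Exercise 32 is slower than the second in Exercise 33: only the first call pays for tracing and compilation.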

**Exercise 32: Display the execution time of the first run of `cube_jit` (including compilation overhead) with input=10.24**

In [5]:
%%time
cube_jit(10.24)

CPU times: user 47.7 ms, sys: 55.7 ms, total: 103 ms
Wall time: 133 ms


DeviceArray(1073.7418, dtype=float32, weak_type=True)

**Exercise 33: Display the execution time of the second run of `cube_jit` (reusing the compiled function, without compilation overhead) with input=10.24**

In [6]:
%%time
cube_jit(10.24)

CPU times: user 1.6 ms, sys: 10 ms, total: 11.6 ms
Wall time: 11.1 ms


DeviceArray(1073.7418, dtype=float32, weak_type=True)
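
JAX dispatches device work asynchronously, so a timer can stop before the result is actually computed, especially for larger inputs. A minimal sketch that times both runs with `time.perf_counter` and waits on `block_until_ready()`; the helper `timed_call` is hypothetical, and `cube`/`cube_jit` are redefined so the block runs on its own.

```python
import time

import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

def timed_call(f, x):
    """Return (result, seconds), waiting for the device computation to finish."""
    start = time.perf_counter()
    # block_until_ready() waits until the value has actually been computed,
    # so the measured time covers the device work, not just the dispatch.
    result = f(x).block_until_ready()
    return result, time.perf_counter() - start

first_result, first_seconds = timed_call(cube_jit, 10.24)    # includes tracing and compilation
second_result, second_seconds = timed_call(cube_jit, 10.24)  # reuses the compiled executable
print(f"first: {first_seconds * 1e3:.2f} ms, second: {second_seconds * 1e3:.2f} ms")
```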

**Exercise 34: Run `cube_jit` with input=10.24 and assign it to `cube_value`**

In [8]:
cube_value = cube_jit(10.24)
cube_value

DeviceArray(1073.7418, dtype=float32, weak_type=True)
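
The value returned in Exercise 34 is a 0-d JAX array rather than a Python number. A short sketch, assuming JAX's default configuration (64-bit mode disabled), that inspects its dtype and copies it back to the host; `cube_value_py` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def cube(x):
    return x**3

cube_value = jax.jit(cube)(10.24)

# With the default configuration (64-bit mode disabled), the Python float
# input is treated as a 32-bit float, so the result is float32.
print(cube_value.dtype)  # float32

# float() copies the 0-d device array back to the host as a Python float.
cube_value_py = float(cube_value)
print(type(cube_value_py))  # <class 'float'>
```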

**Exercise 35: Run `cube_jit` with JIT disabled and input=10.24 and assign the result to `cube_value_nojit`**

In [9]:
with jax.disable_jit():
    cube_value_nojit = cube_jit(10.24)

cube_value_nojit

1073.7418240000002
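
The two results differ because, with JIT disabled, `cube` runs as ordinary Python on a Python float (64-bit arithmetic), while the compiled version computes in float32. A sketch that reproduces the jitted value with NumPy float32 multiplications, mirroring the `x * (x * x)` structure of the HLO printed in Exercise 38; the variable names are hypothetical.

```python
import numpy as np
import jax

def cube(x):
    return x**3

cube_jit = jax.jit(cube)

jit_value = float(cube_jit(10.24))  # computed in float32 by XLA
with jax.disable_jit():
    nojit_value = cube_jit(10.24)   # runs cube directly as plain Python

# Repeat the same multiplications in float32 on the host.
x32 = np.float32(10.24)
float32_value = float(x32 * (x32 * x32))

print(jit_value, float32_value, nojit_value)
```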

**Exercise 36: Evaluate the output shape of `cube_jit` with input=10.24 and assign it to `cube_shape`**

In [12]:
cube_shape = jax.eval_shape(cube_jit, 10.24)
cube_shape

ShapeDtypeStruct(shape=(), dtype=float32)
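
`jax.eval_shape` only traces the function with abstract values, so it can report output shapes and dtypes without allocating or computing anything. A sketch using an abstract `jax.ShapeDtypeStruct` input; the (1000, 1000) shape is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp

def cube(x):
    return x**3

# Abstract input: describes shape and dtype only, no data is allocated.
abstract_input = jax.ShapeDtypeStruct((1000, 1000), jnp.float32)

# The function is traced with abstract values, so nothing is computed.
out_struct = jax.eval_shape(jax.jit(cube), abstract_input)
print(out_struct.shape, out_struct.dtype)  # (1000, 1000) float32
```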

**Exercise 37: Create the jaxpr of `cube_jit` with input=10.24 and assign it to `cube_jaxpr`**

In [14]:
cube_jaxpr = jax.make_jaxpr(cube_jit)(10.24)
cube_jaxpr

{ lambda ; a:f32[]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[]. let d:f32[] = integer_pow[y=3] c in (d,) }
      name=cube
    ] a
  in (b,) }
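
Tracing the un-jitted `cube` exposes the primitive directly, without the `xla_call` wrapper shown above (newer JAX releases may print a different wrapper primitive for jitted functions). A sketch that walks the equations of the resulting jaxpr; in the notebook's output the body is a single `integer_pow` with `y=3`.

```python
import jax

def cube(x):
    return x**3

# make_jaxpr returns a ClosedJaxpr; its .jaxpr holds the list of equations.
closed_jaxpr = jax.make_jaxpr(cube)(10.24)
print(closed_jaxpr)

# Each equation records the primitive it applies and its static parameters.
for eqn in closed_jaxpr.jaxpr.eqns:
    print(eqn.primitive.name, eqn.params)
```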

**Exercise 38: Assign the XLA computation of `cube_jit` with input=10.24 to `cube_xla` and print its XLA HLO text**

In [16]:
cube_xla = jax.xla_computation(cube_jit)(10.24)
print(cube_xla.as_hlo_text())

HloModule xla_computation_cube__1.10

jit_cube__1.3 {
  constant.5 = pred[] constant(false)
  parameter.4 = f32[] parameter(0)
  multiply.6 = f32[] multiply(parameter.4, parameter.4)
  ROOT multiply.7 = f32[] multiply(parameter.4, multiply.6)
}

ENTRY xla_computation_cube__1.10 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  call.8 = f32[] call(parameter.1), to_apply=jit_cube__1.3
  ROOT tuple.9 = (f32[]) tuple(call.8)
}
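
`jax.xla_computation` is deprecated in newer JAX releases. If it is unavailable in your installation, the ahead-of-time API offers a comparable view. A sketch, assuming a recent JAX: `lower()` specializes the jitted function to the example argument, and `as_text()` returns the lowered module, by default as StableHLO (MLIR) text rather than the classic HLO text shown above.

```python
import jax

def cube(x):
    return x**3

# Lower (trace and convert to IR) without compiling yet.
lowered = jax.jit(cube).lower(10.24)
lowered_text = lowered.as_text()
print(lowered_text)

# Compiling the lowered module yields a callable executable for the same
# input signature (a float32 scalar).
compiled = lowered.compile()
print(compiled(10.24))
```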

**Exercise 39: Use the name `jaxton_cube_fn` internally for the `cube_jit` function and assign the named function to `cube_named_jit`**

In [17]:
cube_named_jit = jax.named_call(cube_jit, name='jaxton_cube_fn')
cube_named_jit

<function __main__.cube(x)>
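
`jax.named_call` names a whole function. To name a region inside a function, `jax.named_scope` is a context-manager alternative that attaches the name to the operations created within the block (it shows up in metadata such as profiler traces) without changing the result. The names `cube_scoped` and `jaxton_cube_scope` below are hypothetical.

```python
import jax

def cube_scoped(x):
    # Operations created inside this block are tagged with the scope name;
    # the computed value is unchanged.
    with jax.named_scope("jaxton_cube_scope"):
        return x**3

cube_scoped_jit = jax.jit(cube_scoped)
print(cube_scoped_jit(10.24))
```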

**Exercise 40: Assign the XLA computation of `cube_named_jit` with input=10.24 to `cube_named_xla` and print its XLA HLO text**

In [18]:
cube_named_xla = jax.xla_computation(cube_named_jit)(10.24)
print(cube_named_xla.as_hlo_text())

HloModule xla_computation_cube__2.14

jit_cube__2.3 {
  constant.5 = pred[] constant(false)
  parameter.4 = f32[] parameter(0)
  multiply.6 = f32[] multiply(parameter.4, parameter.4)
  ROOT multiply.7 = f32[] multiply(parameter.4, multiply.6)
}

jaxton_cube_fn.8 {
  constant.10 = pred[] constant(false)
  parameter.9 = f32[] parameter(0)
  ROOT call.11 = f32[] call(parameter.9), to_apply=jit_cube__2.3
}

ENTRY xla_computation_cube__2.14 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  call.12 = f32[] call(parameter.1), to_apply=jaxton_cube_fn.8
  ROOT tuple.13 = (f32[]) tuple(call.12)
}

<center>
    This completes Set 4: Just-In-Time (JIT) Compilation (Exercises 31-40) of <b>JaxTon</b>: <i>💯 JAX exercises</i>
    <br>
    You can find all the exercises and solutions on <a href="https://github.com/vopani/jaxton#exercises-">GitHub</a>
</center>