### [Why You Should (or Shouldn't) be Using Google's JAX in 2023](https://www.assemblyai.com/blog/why-you-should-or-shouldnt-be-using-jax-in-2023/)

https://github.com/google/jax

https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb#scrollTo=AvXl1WDPKjmV

https://theaisummer.com/jax/#auto-differentiation-with-grad-function

#### NumPy on Accelerators

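```python
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
```

```python
def fn(x):
  return x + x*x + x*x*x

# 10,000 x 10,000 float32 matrix (about 400 MB) in host memory
x = np.random.randn(10000, 10000).astype(dtype='float32')
```

```python
# Baseline: plain NumPy on the CPU
%timeit -n5 fn(x)
```

```
480 ms ± 65.2 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
```

```python
# Copy the array to JAX's default device (GPU/TPU if one is available)
x = jnp.array(x)
# JAX dispatches work asynchronously; block_until_ready() waits for the result,
# so %timeit measures the computation rather than just the dispatch
%timeit -n5 fn(x).block_until_ready()
```

```
256 ms ± 9.81 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)
```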
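The timing above depends on `block_until_ready()` because of asynchronous dispatch: a JAX call returns once the work is queued, not once it has finished. The sketch below checks two things. First, that `jnp` reproduces the NumPy result of `fn` up to float32 rounding. Second, how long a call takes to return compared with how long the computation takes to finish. The array size and tolerances are assumptions, and the printed timings vary by hardware.

```python
import time

import numpy as np
import jax.numpy as jnp

def fn(x):
  return x + x*x + x*x*x

x_np = np.random.randn(2000, 2000).astype('float32')
x_jnp = jnp.asarray(x_np)

# Same code, same API: jnp mirrors NumPy, so results agree up to float32 rounding
result_np = fn(x_np)
result_jnp = fn(x_jnp)
print(np.allclose(result_np, np.asarray(result_jnp), rtol=1e-5, atol=1e-5))

# Time until the call returns: mostly the cost of enqueuing the work
start = time.perf_counter()
out = fn(x_jnp)
dispatch_s = time.perf_counter() - start

# Time until the result is actually ready on the device
out.block_until_ready()
total_s = time.perf_counter() - start
print(f"returned after {dispatch_s * 1e3:.1f} ms, finished after {total_s * 1e3:.1f} ms")
```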


#### XLA

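https://www.tensorflow.org/xla - XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra. JAX uses it to compile array programs for CPUs, GPUs, and TPUs.

TPUs are "XLA devices": JAX, TensorFlow, and PyTorch/XLA all run code on them by compiling it with XLA.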
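The XLA pipeline can be inspected at two stages. First, `jax.make_jaxpr` shows the jaxpr, which is JAX's traced intermediate representation of `fn`. Second, `jax.jit(fn).lower(x)` produces the program text handed to the XLA compiler. Recent JAX versions emit StableHLO, but the exact text format varies by version. The small input array is an arbitrary choice made for this sketch.

```python
import jax
import jax.numpy as jnp

def fn(x):
  return x + x*x + x*x*x

x = jnp.ones((3,), dtype=jnp.float32)

# Step 1: JAX traces fn into a jaxpr, a list of primitive ops (mul, add) on abstract shapes
jaxpr = jax.make_jaxpr(fn)(x)
print(jaxpr)

# Step 2: jit lowers the traced program to the XLA compiler's input format
lowered = jax.jit(fn).lower(x)
print(lowered.as_text())

# Step 3: XLA compiles the lowered program for the current backend
compiled = lowered.compile()
print(compiled(x))
```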


#### Just-in-Time (JIT) Compilation

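```python
def fn(x):
  return x + x*x + x*x*x

x_np = np.random.randn(5000, 5000).astype(dtype='float32')
x_jnp = jnp.array(x_np)

# Without jit: each jnp operation is dispatched separately
%timeit fn(x_jnp).block_until_ready()

# With jit: the whole function is compiled as one XLA program,
# which lets XLA fuse the element-wise operations
jitted = jit(fn)
# Warm-up call: tracing and compilation happen here, outside the timing
jitted(x_jnp).block_until_ready()
%timeit jitted(x_jnp).block_until_ready()
```

```
64.8 ms ± 1.8 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
9.97 ms ± 210 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

In this run, the jitted version is roughly 6.5× faster.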
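`jit` traces the Python function once for each combination of input shape and dtype. It caches the compiled result and reuses it on later calls, which is why the warm-up call is kept out of the timing. The sketch below counts traces using a Python side effect. The global `trace_count` is a hypothetical counter added only for illustration. Side effects like this run only while JAX traces the function, not on cached calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, for illustration only

def fn(x):
  global trace_count
  trace_count += 1  # Python side effect: executes only while jit traces fn
  return x + x*x + x*x*x

jitted = jax.jit(fn)

jitted(jnp.ones((3,)))  # first call for shape (3,): trace + compile
jitted(jnp.ones((3,)))  # same shape and dtype: cached executable, no retrace
jitted(jnp.ones((4,)))  # new shape: trace + compile again

print(trace_count)  # 2
```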


```python
# Colab TPU runtime only: connect JAX to the TPU.
# Run this cell before any other JAX code.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
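To confirm which backend JAX is using, the sketch below lists the visible devices and places one array explicitly. This works after the TPU setup or on any other runtime. On a Colab TPU runtime the list should contain TPU devices; on a plain CPU machine it contains a single CPU device. The `jax.device_put` call is added here only as an illustration.

```python
import jax
import jax.numpy as jnp

# Platform of the default device, e.g. 'cpu', 'gpu', or 'tpu'
print(jax.default_backend())

# Every device JAX can see
print(jax.devices())

# Explicitly place an array on the first device (jnp arrays otherwise go to the default device)
x = jax.device_put(jnp.ones((2, 2)), jax.devices()[0])
print(x.sum())
```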

#### Auto-differentiation

```python
def rectified_cube(x):
  # f(x) = x^3 for x >= 0 and -x^3 for x < 0, i.e. |x|^3
  r = 1
  if x < 0.:
    for _ in range(3):
      r *= x
    r = -r
  else:
    for _ in range(3):
      r *= x
  return r

# grad differentiates straight through ordinary Python control flow (if/for)
gradient_function = grad(rectified_cube)

print(f"x = 2   f(x) = {rectified_cube(2.)}   f'(x) =  3*x^2 = {gradient_function(2.)}")
print(f"x = -3  f(x) = {rectified_cube(-3.)}  f'(x) = -3*x^2 = {gradient_function(-3.)}")
```

```
x = 2   f(x) = 8.0   f'(x) =  3*x^2 = 12.0
x = -3  f(x) = 27.0  f'(x) = -3*x^2 = -27.0
```
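The Python `if` works under `grad` but not under `jit`. During jit tracing `x` is an abstract value, so `x < 0.` has no concrete truth value and JAX raises an error. The sketch below shows one jit-compatible rewrite that uses `jnp.where`. This is a swapped-in technique for illustration; `jax.lax.cond` is another option. `rectified_cube_where` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def rectified_cube(x):
  r = 1
  if x < 0.:
    for _ in range(3):
      r *= x
    r = -r
  else:
    for _ in range(3):
      r *= x
  return r

# Under jit, x is a tracer without a concrete value, so the Python `if` fails
try:
  jax.jit(rectified_cube)(2.)
  jit_failed = False
except jax.errors.ConcretizationTypeError:
  jit_failed = True
print("jit of Python-if version failed:", jit_failed)

# Branch-free version: both branches are computed and jnp.where selects one
def rectified_cube_where(x):
  return jnp.where(x < 0., -x**3, x**3)

grad_fn = jax.jit(jax.grad(rectified_cube_where))
# Expected values: f'(2) = 3*2^2 = 12 and f'(-3) = -3*(-3)^2 = -27
print(grad_fn(2.), grad_fn(-3.))
```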


```python
# For x >= 0: f(x) = x^3, so f'(x) = 3x^2, f''(x) = 6x, f'''(x) = 6
third_deriv = grad(grad(grad(rectified_cube)))
for i in range(5):
  print(third_deriv(float(i)))
```

```
6.0
6.0
6.0
6.0
6.0
```
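The `vmap` imported at the top can evaluate a derivative at many points in one call; without it, `grad` is limited to scalar-output functions. The Python-`if` version of `rectified_cube` cannot be vmapped, because the condition would be a batch of booleans. For that reason this sketch assumes the branch-free `jnp.where` variant, which is a hypothetical helper not part of the original notes. For x < 0 the third derivative of -x^3 is -6.

```python
import jax
import jax.numpy as jnp

def rectified_cube_where(x):
  # |x|^3 without Python control flow, so it can be vmapped and jitted
  return jnp.where(x < 0., -x**3, x**3)

third_deriv = jax.grad(jax.grad(jax.grad(rectified_cube_where)))

# vmap maps the scalar-input function over the leading axis of xs
xs = jnp.array([-2., -1., 0., 1., 2.])
third_at_xs = jax.vmap(third_deriv)(xs)
print(third_at_xs)  # -6 for x < 0, 6 for x >= 0
```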


#### Deep Learning

https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

https://github.com/google/flax

https://flax.readthedocs.io/en/latest/getting_started.html#install-flax

http://www.echonolan.net/posts/2021-09-06-JAX-vs-PyTorch-A-Transformer-Benchmark.html

https://github.com/pytorch/xla

https://www.kaggle.com/code/tanulsingh077/pytorch-xla-understanding-tpu-s-and-xla

https://kozodoi.me/python/deep%20learning/computer%20vision/tutorial/2020/10/30/pytorch-xla-tpu.html

https://colab.research.google.com/github/kozodoi/website/blob/master/_notebooks/2020-10-30-pytorch-xla-tpu.ipynb
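The links above cover full frameworks (Flax, PyTorch/XLA). The sketch below shows how `grad` and `jit` combine in a plain-JAX training loop, using gradient descent on a synthetic linear-regression problem. The model, data, learning rate, and step count are illustrative assumptions; they are not taken from any of the linked resources.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative choice

def predict(params, x):
  w, b = params
  return x @ w + b

def loss(params, x, y):
  # Mean squared error
  return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def update(params, x, y):
  # One gradient-descent step; grads has the same (w, b) structure as params
  grads = jax.grad(loss)(params, x, y)
  return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Hypothetical synthetic data: y = x @ true_w + 0.3, without noise
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 3))
true_w = jnp.array([1.0, -2.0, 0.5])
y = x @ true_w + 0.3

params = (jnp.zeros(3), jnp.array(0.0))
initial_loss = loss(params, x, y)
for _ in range(200):
  params = update(params, x, y)
final_loss = loss(params, x, y)
print(initial_loss, final_loss)
```

`update` is jitted once and reused on every step because the shapes never change. The pure-function style (parameters in, new parameters out) is what lets `grad` and `jit` compose.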