# Google JAX on Windows
#### Celso Axelrud
#### Revision 1.0 - 4/24/2021

This document describes my efforts to run Google JAX (https://github.com/google/jax) on Windows.

JAX is available for Linux, including the Google Colab environment.
However, most client solutions are implemented on Windows.
I have been contributing by compiling and testing JAX on Windows.

Finally, I got the CPU version running correctly.
I am still working on tests for the GPU version.


JAX is much more than just a GPU-backed NumPy.
JAX is Autograd and XLA, brought together for high-performance machine learning research.
JAX can automatically differentiate native Python and NumPy functions. 
It can differentiate through loops, branches, recursion, and closures, and it can take 
derivatives of derivatives of derivatives. 
It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as
forward-mode differentiation, and the two can be composed arbitrarily to any order.
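A minimal sketch of these points, using a made-up polynomial `f` (my own example, not from the JAX documentation) that is evaluated with a Python loop. Reverse mode (`jax.grad`) and forward mode (`jax.jacfwd`) should give the same first derivative, and the two can be nested for higher-order derivatives.

```python
import jax

def f(x):
    # Hypothetical example: Horner's rule in a Python loop,
    # which computes x**3 + x**2 + x + 1
    y = 0.0
    for _ in range(4):
        y = y * x + 1.0
    return y

df_rev = jax.grad(f)(2.0)                    # reverse mode: 3x^2 + 2x + 1 -> 17.0
df_fwd = jax.jacfwd(f)(2.0)                  # forward mode: same value
d2f = jax.jacfwd(jax.grad(f))(2.0)           # forward-over-reverse: 6x + 2 -> 14.0
d3f = jax.grad(jax.grad(jax.grad(f)))(2.0)   # third derivative: 6.0
print(df_rev, df_fwd, d2f, d3f)
```

The loop is unrolled while JAX traces `f`, so both differentiation modes see an ordinary sequence of multiplications and additions.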

#### Libraries on top of JAX
My personal interest is NumPyro (Uber).

Other libraries include Flax (Google Brain), Haiku (DeepMind), Trax (Google Brain), Objax (Google), Stax (Google), Elegy (PoetsAI), RLax (DeepMind), Optax (DeepMind), Jraph (DeepMind), JAX-M.D. (Google), and Oryx (Google).

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  # Multilayer perceptron: params is a list of (W, b) pairs, one per layer
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)
  return outputs  # pre-activation output of the last layer

def logprob_fun(params, inputs, targets):
  # Sum of squared errors between predictions and targets
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads
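The cell above defines `grad_fun` and `perex_grads` but never calls them. A sketch of how they might be used, assuming a hypothetical two-layer network (3 inputs, 4 hidden units, 2 outputs) and a batch of 5 random examples; the sizes and data are my own choices.

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def logprob_fun(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.sum((preds - targets) ** 2)

grad_fun = jit(grad(logprob_fun))
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))

# Hypothetical sizes: 3 inputs -> 4 hidden units -> 2 outputs, batch of 5
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = [(jax.random.normal(k1, (3, 4)), jnp.zeros(4)),
          (jax.random.normal(k2, (4, 2)), jnp.zeros(2))]
inputs = jax.random.normal(k3, (5, 3))
targets = jax.random.normal(k4, (5, 2))

batch_grads = grad_fun(params, inputs, targets)       # same structure as params
example_grads = perex_grads(params, inputs, targets)  # extra leading axis of size 5
print(batch_grads[0][0].shape, example_grads[0][0].shape)  # (3, 4) (5, 3, 4)
```

Because `logprob_fun` sums over examples, the per-example gradients should add up to the batch gradient.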


#from jax import grad

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743

print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673 in the JAX README; the float32 result below differs only in the last digits



0.4199743
0.6216266
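Both printed values can be cross-checked against closed forms. With t = tanh(x), the first derivative is 1 - t², and differentiating twice more gives (6t² - 2)(1 - t²). A short check, using `jnp.tanh` as the reference:

```python
import jax
import jax.numpy as jnp

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

x = 1.0
t = jnp.tanh(x)
first = jax.grad(tanh)(x)
third = jax.grad(jax.grad(jax.grad(tanh)))(x)

print(first, 1.0 - t**2)                          # both about 0.41997
print(third, (6.0 * t**2 - 2.0) * (1.0 - t**2))   # both about 0.62163
```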


In [2]:
from jax import jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # JAX README: ~4.5 ms / loop on a Titan X GPU
%timeit -n10 -r3 slow_f(x)  # JAX README: ~14.5 ms / loop (also on GPU via JAX)
#24.5 ms ± 17.8 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
#159 ms ± 2.32 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

1.0
-1.0
The slowest run took 5.49 times longer than the fastest. This could mean that an intermediate result is being cached.
27.8 ms ± 18.4 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
180 ms ± 12.6 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
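Two details likely affect these timings. `fast_f` is not called before the first `%timeit` run, so that run includes XLA compilation, which would explain the large spread (± 18.4 ms) and the "slowest run" warning. JAX also dispatches work asynchronously, so timing a call without `.block_until_ready()` can return before the computation finishes. A plain-Python sketch of a more careful measurement; the helper `time_call`, the repeat count, and the smaller 1000 x 1000 array are my own choices:

```python
import time
import jax
import jax.numpy as jnp

def slow_f(x):
    return x * x + x * 2.0

fast_f = jax.jit(slow_f)

def time_call(fn, x, repeats=10):
    """Hypothetical helper: mean seconds per call, measured after one warm-up call."""
    fn(x).block_until_ready()          # warm-up: compiles fn if it is jitted
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()      # wait for the asynchronous result
    return (time.perf_counter() - start) / repeats

x = jnp.ones((1000, 1000))
print(f"jit:    {time_call(fast_f, x) * 1e3:.2f} ms")
print(f"no jit: {time_call(slow_f, x) * 1e3:.2f} ms")
```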


In [3]:
def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)
  return outputs

from functools import partial
#predictions = jnp.stack(list(map(partial(predict, params), input_batch)))


from jax import vmap
#predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
#predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

#per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
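The commented-out lines above refer to `params`, `input_batch`, and `loss`, which the cell never defines. A runnable sketch, assuming a hypothetical two-layer network, a batch of 8 random inputs, and a squared-error `loss` of my own choosing. Note that `assert input_vec.ndim == 1` still passes under `vmap`, because `vmap` hides the batch axis from `predict`.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import grad, vmap

def predict(params, input_vec):
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs

def loss(params, input_vec, target):
    # Hypothetical loss (not defined in the original cell): squared error
    return jnp.sum((predict(params, input_vec) - target) ** 2)

# Hypothetical sizes: 3 inputs -> 4 hidden units -> 2 outputs, batch of 8
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = [(jax.random.normal(k1, (4, 3)), jnp.zeros(4)),
          (jax.random.normal(k2, (2, 4)), jnp.zeros(2))]
input_batch = jax.random.normal(k3, (8, 3))
targets = jax.random.normal(k4, (8, 2))

looped = jnp.stack(list(map(partial(predict, params), input_batch)))
vmapped = vmap(partial(predict, params))(input_batch)
vmapped2 = vmap(predict, in_axes=(None, 0))(params, input_batch)
per_example_gradients = vmap(partial(grad(loss), params))(input_batch, targets)
```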


#https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
#[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
# -0.6713536  -0.59086424  0.73168874  0.56730247]

[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]
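JAX random functions are deterministic given a key: calling `random.normal` twice with the same `key` and shape returns the same numbers. Later cells reuse `key` for several arrays; the JAX documentation instead recommends deriving fresh subkeys with `random.split`. A short sketch:

```python
from jax import random

key = random.PRNGKey(0)
a = random.normal(key, (3,))
b = random.normal(key, (3,))       # same key and shape -> identical values

key, subkey = random.split(key)    # new key to carry forward, subkey to consume
c = random.normal(subkey, (3,))    # different values from a
```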


In [4]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the default device (GPU in the JAX quickstart; CPU in this Windows build)
#267 ms ± 3.53 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

274 ms ± 5.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
#299 ms ± 4.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

436 ms ± 9.92 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
#293 ms ± 9.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

310 ms ± 11.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
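Whether these timings ran on a CPU or a GPU depends on the JAX build. A short sketch for checking which devices JAX sees and for placing a NumPy array on a specific device with `device_put`:

```python
import numpy as np
import jax

print(jax.devices())            # lists the devices visible to this JAX build
dev = jax.devices()[0]
print(dev.platform)             # a platform name such as 'cpu' or 'gpu'

x = np.random.normal(size=(4, 4)).astype(np.float32)
x_dev = jax.device_put(x, dev)  # copy the NumPy array to an explicit device
print(type(x_dev))
```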


In [7]:
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
#224 ms ± 7.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

229 ms ± 4.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
#7.77 ms ± 180 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

6.93 ms ± 352 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
#1.57 ms ± 26.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

1.3 ms ± 40.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
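`jit` speeds up `selu` because the Python function is traced once per input shape and dtype, compiled with XLA, and the compiled version is reused on later calls. A sketch that makes the tracing visible; the `trace_count` counter is hypothetical instrumentation, a Python side effect that only runs while JAX traces the function:

```python
import jax
import jax.numpy as jnp

trace_count = 0

def selu(x, alpha=1.67, lmbda=1.05):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
x = jnp.linspace(-2.0, 2.0, 1000)

selu_jit(x).block_until_ready()             # first call: trace + compile
selu_jit(x).block_until_ready()             # same shape/dtype: cached, no retrace
selu_jit(jnp.ones(10)).block_until_ready()  # new shape: traced again
print(trace_count)                          # 2
```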


In [10]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
#[0.25       0.19661197 0.10499357]

[0.25       0.19661197 0.10499357]
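Each term of `sum_logistic` depends on a single entry of `x`, so the gradient is element-wise: with σ(x) = 1 / (1 + e^(-x)), the derivative is σ(x)(1 - σ(x)). At x = 0 this is 0.25, matching the first printed entry. A check against that closed form, using `jax.nn.sigmoid` as the reference:

```python
import jax
import jax.numpy as jnp
from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
s = jax.nn.sigmoid(x_small)   # sigma(x) for reference
analytic = s * (1.0 - s)      # d sigma / dx = sigma(x) (1 - sigma(x))
autodiff = grad(sum_logistic)(x_small)
print(jnp.max(jnp.abs(analytic - autodiff)))  # small, at float32 rounding level
```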


In [11]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])

In [12]:
print(first_finite_differences(sum_logistic, x_small))
#[0.24998187 0.1964569  0.10502338]

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
#-0.035325594

[0.24998187 0.1964569  0.10502338]
-0.035325594
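The last value is the third derivative of the logistic function at 1.0. Writing s = σ(x), the derivatives are σ' = s(1 - s), σ'' = s(1 - s)(1 - 2s), and σ''' = s(1 - s)(1 - 6s + 6s²). At x = 1 the last expression is about -0.035326, matching the printed value. A sketch of the check:

```python
import jax
import jax.numpy as jnp
from jax import grad, jit

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

third = grad(jit(grad(jit(grad(sum_logistic)))))(1.0)

s = jax.nn.sigmoid(1.0)
closed_form = s * (1 - s) * (1 - 6 * s + 6 * s**2)
print(third, closed_form)  # both about -0.035326
```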


In [13]:
#from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
#Naively batched
#5.28 ms ± 1.59 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


Naively batched
4.09 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
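The `hessian` helper defined in this cell composes forward mode over reverse mode, but the notebook never calls it. A small sketch on a made-up function (my own example), f(x) = Σ x_i³, whose Hessian is diagonal with entries 6 x_i. It also compares against `jax.hessian`, which JAX builds the same way (forward-over-reverse).

```python
import jax
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

def cubic_sum(x):
    # Hypothetical example function: sum of cubes
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0, 3.0])
H = hessian(cubic_sum)(x)             # expected: diag(6, 12, 18)
H_builtin = jax.hessian(cubic_sum)(x)
print(H)
```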


In [14]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
#Manually batched
#76.3 µs ± 872 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Manually batched
57.2 µs ± 556 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [15]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
#Auto-vectorized with vmap
#94.4 µs ± 612 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


Auto-vectorized with vmap
75.5 µs ± 2.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
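The three timings are only comparable if the three approaches compute the same result. A sketch that checks the naive loop, the manual `jnp.dot(v_batched, mat.T)`, and the `vmap` version agree; unlike the cells above, it draws `mat` and `batched_x` from split keys (my change):

```python
import jax.numpy as jnp
from jax import jit, vmap, random

key_mat, key_x = random.split(random.PRNGKey(0))
mat = random.normal(key_mat, (150, 100))
batched_x = random.normal(key_x, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

@jit
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched, mat.T)

@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

naive = naively_batched_apply_matrix(batched_x)
manual = batched_apply_matrix(batched_x)
auto = vmap_batched_apply_matrix(batched_x)
print(naive.shape)  # (10, 150)
```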
