In [5]:
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
from jax import random
from jax.tree_util import register_pytree_node
import jax
import jax.tools

### Checking available devices

In [6]:
# Show the devices JAX reports for this session.
jax.devices()



[CpuDevice(id=0)]

### Simple batched function computed in parallel

In [10]:
def pow2(x):
    """Square ``x`` element-wise using ``jnp.power``."""
    squared = jnp.power(x, 2)
    return squared

In [11]:
# pmap maps over the leading axis of its input, and that axis must not be
# larger than the number of local devices, so size the batch from the
# device count instead of hard-coding it.
n_devices = jax.local_device_count()

# in_axes=0: split the (only) argument along its leading axis.
parallel_pow2 = pmap(pow2, in_axes=0)

# Four values per device; on a single device this is [1., 2., 3., 4.].
X = jnp.arange(1, 4 * n_devices + 1, dtype=jnp.float32)

# One row per device.
X_batched = X.reshape(n_devices, -1)

parallel_pow2(X_batched)

ShardedDeviceArray([[ 1.,  4.,  9., 16.]], dtype=float32)

### Joining parallel computations (collective ops `jax.lax.p*`)

In [11]:
def normalized_convolution(x, w):
    """Dot every 3-element window of ``x`` with ``w`` and normalize the result.

    The per-window responses are divided by their sum over the mapped axis
    named ``'p'`` (via ``jax.lax.psum``), so this function has to be called
    under a transform that binds that axis name.
    """
    # One dot product per interior position; each window is centred on `centre`.
    responses = jnp.array(
        [jnp.dot(x[centre - 1:centre + 2], w) for centre in range(1, len(x) - 1)]
    )
    # Sum of the responses over every member of axis 'p'.
    total = jax.lax.psum(responses, axis_name='p')
    return responses / total


In [14]:
# One row of inputs and one copy of the kernel per local device.
n_devices = jax.local_device_count()
xs = jnp.arange(5 * n_devices).reshape(-1, 5)

# Kernel replicated to every device. Defined in this cell so the notebook
# runs on a fresh kernel; the values match the printed output below.
w = jnp.array([2., 3., 4.])
ws = jnp.stack([w] * n_devices)

print(xs)
print(ws)

[[0 1 2 3 4]]
[[2. 3. 4.]]


In [15]:
# vmap binds the axis name 'p', so the psum inside normalized_convolution
# sums over the leading (batch) axis of xs and ws.
vmap(
    normalized_convolution,
    in_axes=(0, 0),
    axis_name='p',
)(xs, ws)

DeviceArray([[1., 1., 1.]], dtype=float32)