import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3'


# The following is run in parallel on each host on a GPU cluster or TPU pod slice.
import jax
# jax.distributed.initialize()  # On GPU, see above for the necessary arguments.
jax.device_count()  # total number of accelerator devices in the cluster

In [3]:
jax.local_device_count()  # number of accelerator devices attached to this host

3

In [4]:
# The psum is performed over all mapped devices across the pod slice
xs = jax.numpy.ones(jax.local_device_count())
jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# ShardedDeviceArray([32., 32., 32., 32., 32., 32., 32., 32.], dtype=float32)

Array([3., 3., 3.], dtype=float32)

In [1]:
import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)

Array([11., 20., 29.], dtype=float32)

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3'

import jax

import numpy as np
import jax.numpy as jnp

import time

In [2]:

size = 5
x = jnp.arange(size)
w = jnp.array([2., 3., 4.])


def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

iter = 2

In [3]:
for _ in range(iter):
  t0 = time.time()
  a=jax.jit(convolve)(x, w)
  print(time.time() - t0)


1.5886805057525635
[  11.   20.   29.   38.   47.   56.   65.   74.   83.   92.  101.  110.
  119.  128.  137.  146.  155.  164.  173.  182.  191.  200.  209.  218.
  227.  236.  245.  254.  263.  272.  281.  290.  299.  308.  317.  326.
  335.  344.  353.  362.  371.  380.  389.  398.  407.  416.  425.  434.
  443.  452.  461.  470.  479.  488.  497.  506.  515.  524.  533.  542.
  551.  560.  569.  578.  587.  596.  605.  614.  623.  632.  641.  650.
  659.  668.  677.  686.  695.  704.  713.  722.  731.  740.  749.  758.
  767.  776.  785.  794.  803.  812.  821.  830.  839.  848.  857.  866.
  875.  884.  893.  902.  911.  920.  929.  938.  947.  956.  965.  974.
  983.  992. 1001. 1010. 1019. 1028. 1037. 1046. 1055. 1064. 1073. 1082.
 1091. 1100. 1109. 1118. 1127. 1136. 1145. 1154. 1163. 1172. 1181. 1190.
 1199. 1208. 1217. 1226. 1235. 1244. 1253. 1262. 1271. 1280. 1289. 1298.
 1307. 1316. 1325. 1334. 1343. 1352. 1361. 1370. 1379. 1388. 1397. 1406.
 1415. 1424. 1433. 1442. 1451. 1

In [None]:
n_devices = jax.local_device_count()
xs = np.arange(size * n_devices).reshape(-1, size)
ws = np.stack([w] * n_devices)


In [None]:
for _ in range(iter):
  t0 = time.time()
  jax.vmap(convolve)(xs, ws)
  print(time.time() - t0)


In [None]:
for _ in range(iter):
  t0 = time.time()
  a = jax.pmap(convolve)(xs, ws).block_until_ready()
  print(time.time() - t0)


In [1]:
import jax
import jax.numpy as jnp
import scipy.special

def jv(v, z):
  v, z = jnp.asarray(v), jnp.asarray(z)

  # Require the order v to be integer type: this simplifies
  # the JVP rule below.
  assert jnp.issubdtype(v.dtype, jnp.integer)

  # Promote the input to inexact (float/complex).
  # Note that jnp.result_type() accounts for the enable_x64 flag.
  z = z.astype(jnp.result_type(float, z.dtype))

  # Wrap scipy function to return the expected dtype.
  _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)

  # Define the expected shape & dtype of output.
  result_shape_dtype = jax.ShapeDtypeStruct(
      shape=jnp.broadcast_shapes(v.shape, z.shape),
      dtype=z.dtype)

  # We use vectorize=True because scipy.special.jv handles broadcasted inputs.
  return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)

In [2]:
jv

<function __main__.jv(v, z)>

In [3]:
jnp.linalg.eig

<function jax._src.numpy.linalg.eig(a: Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, bool, int, float, complex]) -> Tuple[jax.Array, jax.Array]>

In [6]:
import jax
import numpy as np
import jax.numpy as jnp


X = jnp.arange(1000**2).reshape((1000,1000))

def eig(X):
    jax.config.update('jax_enable_x64', True)

    type_complex=jnp.complex128
    _eig = lambda x: np.linalg.eig(x)

    eigenvalues_shape = jax.ShapeDtypeStruct(X.shape[:-1], type_complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(X.shape, type_complex)

    result_shape_dtype = (eigenvalues_shape, eigenvectors_shape)

    return jax.pure_callback(_eig, result_shape_dtype, X)

print(eig(X))

(Array([2.12199579e-314+6.36598737e-314j, 1.06099790e-313+1.48539705e-313j,
       1.90979621e-313+2.33419537e-313j, 2.75859453e-313+3.18299369e-313j,
       3.60739285e-313+4.03179200e-313j, 4.45619116e-313+4.88059032e-313j,
       5.30498948e-313+5.72938864e-313j, 6.15378780e-313+6.57818695e-313j,
       7.00258611e-313+7.42698527e-313j, 7.85138443e-313+8.27578359e-313j,
       8.70018274e-313+9.12458190e-313j, 9.54898106e-313+9.97338022e-313j,
       1.03977794e-312+1.08221785e-312j, 1.12465777e-312+1.16709769e-312j,
       1.20953760e-312+1.25197752e-312j, 1.29441743e-312+1.33685735e-312j,
       1.37929726e-312+1.42173718e-312j, 1.46417710e-312+1.50661701e-312j,
       1.54905693e-312+1.59149684e-312j, 1.63393676e-312+1.67637668e-312j,
       1.71881659e-312+1.76125651e-312j, 1.80369642e-312+1.84613634e-312j,
       1.88857625e-312+1.93101617e-312j, 1.97345609e-312+2.01589600e-312j,
       2.05833592e-312+2.10077583e-312j, 2.14321575e-312+2.18565567e-312j,
       2.22809558e-312+2