
Generating random numbers with jax.random.split can be >200x slower than np.random.normal #968

Closed · lxuechen opened this issue Jul 3, 2019 · 7 comments
Labels: performance (make things lean and fast), question (Questions for the JAX team)

Comments

lxuechen commented Jul 3, 2019

Currently, I'm relying on jax.random.split and jax.random.normal for random number generation. I expected this combination to be slower than NumPy, but the size of the gap in the following results (on CPU) is still surprising:

import time
from jax import random
import numpy.random as npr

def sample_repeatedly_with_split(key):
    for _ in range(10000):
        key, subkey = random.split(key)
        random.normal(subkey, shape=(3,))


def sample_repeatedly():
    for _ in range(10000):
        npr.normal(size=(3,))


key = random.PRNGKey(0)
now = time.time()
sample_repeatedly_with_split(key=key)
print('sample with split takes {:.4f} secs'.format(time.time() - now))

now = time.time()
sample_repeatedly()
print('`npr.normal` takes {:.4f} secs'.format(time.time() - now))

Results:

sample with split takes 8.6022 secs
`npr.normal` takes 0.0296 secs

Some profiling results (with cProfile and pstats) show:

myscript.cprof% stats 20
Wed Jul  3 10:13:14 2019    myscript.cprof

         7879396 function calls (7264157 primitive calls) in 8.870 seconds

   Ordered by: cumulative time
   List reduced from 1909 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    293/1    0.002    0.000    8.877    8.877 {built-in method builtins.exec}
        1    0.000    0.000    8.877    8.877 split_prof.py:1(<module>)
        1    0.091    0.091    8.602    8.602 split_prof.py:12(sample_repeatedly_with_split)
20003/20000    0.197    0.000    4.365    0.000 api.py:109(f_jitted)
    30000    0.034    0.000    4.076    0.000 xla.py:518(<genexpr>)
    20001    0.102    0.000    4.042    0.000 lax_numpy.py:2161(_rewriting_take)
    20004    0.089    0.000    3.580    0.000 lax.py:1206(index_in_dim)
20007/20000    0.143    0.000    3.062    0.000 core.py:656(call_bind)
20003/20000    0.090    0.000    2.574    0.000 xla.py:604(xla_call_impl)
40274/40273    0.081    0.000    2.410    0.000 core.py:139(bind)
    10000    0.028    0.000    2.295    0.000 random.py:376(normal)
    10000    0.021    0.000    2.121    0.000 random.py:161(split)
    40006    0.134    0.000    2.114    0.000 xla.py:50(apply_primitive)
    20005    0.048    0.000    1.781    0.000 lax.py:1192(slice_in_dim)
    20009    0.271    0.000    1.733    0.000 lax.py:586(slice)
20003/20000    0.085    0.000    1.660    0.000 linear_util.py:199(memoized_fun)
    40006    0.983    0.000    1.318    0.000 xla.py:83(execute_compiled_primitive)
    20013    0.128    0.000    1.284    0.000 lax.py:549(reshape)
    20003    0.503    0.000    0.691    0.000 xla.py:629(execute_compiled)
    59994    0.116    0.000    0.598    0.000 linear_util.py:173(__eq__)

Note that I named my script split_prof.py. There seems to be considerable XLA-related overhead, even though I'm not explicitly jitting any functions.

mattjj self-assigned this Jul 3, 2019
Member

mattjj commented Jul 3, 2019

Thanks for the clear benchmark!

I think this is essentially timing dispatch overheads. NumPy's dispatch overheads are a lot lower than JAX's, which makes it much faster at doing lots of small operations. (Interestingly, lots of small operations is what Python+NumPy is already considered bad at compared to something like pure C. One way to think about JAX, at least in its current state, is that it pushes that contrast further, in that it's even better than NumPy at large array-oriented operations because of its jit compilation and use of accelerators, but it's even worse at doing lots of small operations.)

One way to make things faster is to use jit, e.g.:

from jax import jit, random

@jit
def split_and_sample(key):
  key, subkey = random.split(key)
  val = random.normal(subkey, shape=(3,))
  return key, val

def sample_repeatedly_with_split(key):
  for _ in range(10000):
    key, _ = split_and_sample(key)

That sped things up, but only by a factor of 2. (That's also including compilation time, though that's probably small.)
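The jitted version above still runs the Python `for` loop outside the compiled code, so every iteration pays dispatch cost. A minimal sketch of moving the whole loop under one `jit` with `jax.lax.scan` follows; `sample_with_scan` is a hypothetical name, and `static_argnames` assumes a much more recent JAX than the one in this 2019 thread. Because the key is split in the same sequence as in the Python loop, the draws should match it.

```python
from functools import partial

import jax
from jax import random


# Hypothetical helper: every split-and-sample step runs inside one compiled
# program, so Python dispatches once per call rather than once per step.
@partial(jax.jit, static_argnames=("num_steps", "shape"))
def sample_with_scan(key, num_steps, shape):
    def step(key, _):
        key, subkey = random.split(key)
        return key, random.normal(subkey, shape=shape)

    # lax.scan carries `key` from step to step and stacks each step's sample.
    return jax.lax.scan(step, key, xs=None, length=num_steps)


final_key, samples = sample_with_scan(random.PRNGKey(0), num_steps=10000, shape=(3,))
samples.block_until_ready()
print(samples.shape)  # (10000, 3)
```

`num_steps` and `shape` are static because they determine array shapes; changing either triggers a new compilation.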

To measure something other than dispatch overheads (which aren't specific to PRNG code and would show up in pretty much any JAX vs. NumPy micro-benchmark like this one), we can make the arrays bigger. Here are a few different sizes, the largest being 30000, which, don't forget, is still smaller than a 200x200 array (40000 elements), so these aren't very big sizes:

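[Figure issue968.png: JAX vs. NumPy sampling times (log scale) for array sizes 3 to 30000, produced by the script below]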

Here's the full script (check for bugs!):

import time
from functools import partial

import numpy.random as npr
from jax import jit, random

# `shape` must be a static argument: a global read inside a jitted function is
# baked in at trace time, so every size would otherwise reuse the first shape.
@partial(jit, static_argnums=1)
def split_and_sample(key, shape):
  key, subkey = random.split(key)
  val = random.normal(subkey, shape=shape)
  return key, val

def sample_repeatedly_with_split(key, shape):
  for _ in range(10000):
    key, _ = split_and_sample(key, shape)
  return key


def sample_repeatedly(shape):
  for _ in range(10000):
    npr.normal(size=shape)


jax_times, np_times = [], []
sizes = [3, 30, 300, 3000, 30000]
for size in sizes:
  shape = (size,)

  key = random.PRNGKey(0)
  now = time.time()
  sample_repeatedly_with_split(key, shape).block_until_ready()  # async!
  jax_times.append(time.time() - now)

  now = time.time()
  sample_repeatedly(shape)
  np_times.append(time.time() - now)

import matplotlib.pyplot as plt
plt.semilogy(sizes, jax_times, label="jax times")
plt.semilogy(sizes, np_times, label="np times")
plt.legend()
plt.savefig('issue968.png')

The block_until_ready call keeps JAX from getting an unfair advantage from its async dispatch, which can hide dispatch overheads and device latencies behind the real numerical work going on. In this case it makes no difference: essentially all the time is spent in Python overheads, so overlapping the compute with the Python doesn't buy us anything.
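The same idea as a reusable helper: a minimal sketch, assuming `fn` returns a single JAX array (`time_jax_call` is a hypothetical name). It warms up once so compilation is not timed, then blocks on every result before stopping the clock, since each call may return before its computation has finished.

```python
import time

from jax import jit, random


def time_jax_call(fn, *args, repeats=100):
    """Hypothetical helper: average wall-clock seconds per call of fn(*args)."""
    fn(*args).block_until_ready()  # warm-up call so compilation is not timed
    start = time.perf_counter()
    outs = [fn(*args) for _ in range(repeats)]  # these calls may return early
    for out in outs:
        out.block_until_ready()  # wait until every dispatched computation is done
    return (time.perf_counter() - start) / repeats


@jit
def sample(key):
    return random.normal(key, shape=(3000,))


print(time_jax_call(sample, random.PRNGKey(0)))
```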

Still, if you want to generate lots of small arrays and can't stick everything under a jit, that's the kind of workload for which NumPy is better than JAX.
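When the samples don't depend on each other, another way around the overhead (a sketch, not something proposed in the thread) is to request them all in a single call, turning many tiny operations into one larger one. The values differ from the split-in-a-loop version, but they are still independent standard normal draws.

```python
import numpy.random as npr
from jax import random

key = random.PRNGKey(0)

# One call per library draws all 10000 length-3 samples at once, so the
# per-call dispatch overhead is paid once rather than 10000 times.
jax_samples = random.normal(key, shape=(10000, 3))
np_samples = npr.normal(size=(10000, 3))
```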

What do you think?

mattjj added the performance and question labels Jul 3, 2019
Member

mattjj commented Jul 3, 2019

By the way, JAX's PRNG code is implemented entirely in Python. Not bad performance compared to the hand-written C/Fortran underlying NumPy's PRNG!

Author

lxuechen commented Jul 3, 2019

Hi Matt,

Thanks for the quick update! I realized that I should've "jitted" the function in the first place to avoid repeating some of the small computations.

Thanks also for pointing out that jax uses async dispatch, as I was really curious about the motivation for a DeviceArray type.

Also, awesome work on the pure Python PRNG!

lxuechen closed this as completed Jul 3, 2019

ssydasheng commented:

import numpy as original_numpy
import jax.numpy as np
import jax.random as random
import time

z = random.normal(random.PRNGKey(123123), shape=[100, 100])
# z = original_numpy.asarray(z) # Adding this reduces running time from 28.2714s to 0.0153s
points = []
start_time = time.time()
for i in range(100):
    if i in [0, 99]: continue
    for j in range(100):
        if j in [0, 99]: continue
        if z[j, i] <= z[j-1, i] and z[j, i] <= z[j+1, i] and z[j, i] <= z[j, i-1] and z[j, i] <= z[j, i+1]:
            points.append([i, j])
total_time = time.time() - start_time
print('Elapsed Time = %.4f' % total_time)

Here is a related example involving a large number of small computations. However, it doesn't seem easy to jit it:

@jit
def compare(z, i, j):
    return z[j, i] <= z[j - 1, i] and z[j, i] <= z[j + 1, i] and z[j, i] <= z[j, i - 1] and z[j, i] <= z[j, i + 1]

This raises the following error:

Traceback (most recent call last):
  File "function/speed", line 19, in <module>
    if compare(z, i, j):
  File "/h/ssy/codes/packages/jax/jax/api.py", line 126, in f_jitted
    out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
  File "/h/ssy/codes/packages/jax/jax/core.py", line 663, in call_bind
    ans = primitive.impl(f, *args, **params)
  File "/h/ssy/codes/packages/jax/jax/interpreters/xla.py", line 673, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device_values, *map(abstractify, args))
  File "/h/ssy/codes/packages/jax/jax/linear_util.py", line 208, in memoized_fun
    ans = call(f, *args)
  File "/h/ssy/codes/packages/jax/jax/interpreters/xla.py", line 685, in _xla_callable
    jaxpr, (pval, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "/h/ssy/codes/packages/jax/jax/linear_util.py", line 147, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "function/speed", line 9, in compare
    return z[j, i] <= z[j - 1, i] and z[j, i] <= z[j + 1, i] and z[j, i] <= z[j, i - 1] and z[j, i] <= z[j, i + 1]
  File "/h/ssy/codes/packages/jax/jax/core.py", line 342, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/h/ssy/codes/packages/jax/jax/abstract_arrays.py", line 38, in error
    raise TypeError(concretization_err_msg(fun))
TypeError: Abstract value passed to bool, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using jit, try using static_argnums or applying jit to smaller subfunctions instead.

Member

mattjj commented Jul 15, 2019

(It might be best to open a separate issue thread.)

Indeed, not everything can be jitted, and this is a good example of code that JAX can't jit and must execute in an "op-by-op" fashion.
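Not something proposed in the thread, but as a sketch: the neighbor comparisons in the example above can be rewritten as whole-array operations. Elementwise `&` never asks for a concrete Python bool (the source of the `and` error), and the 98x98 interior is handled in one call instead of one dispatch per element. `local_minima_mask` is a hypothetical name; the grid and key mirror the example.

```python
import numpy as onp
import jax
from jax import random


@jax.jit
def local_minima_mask(z):
    # c[r, k] is z[r + 1, k + 1]; each other slice is the same window shifted
    # by one cell up, down, left or right.
    c = z[1:-1, 1:-1]
    return ((c <= z[:-2, 1:-1]) & (c <= z[2:, 1:-1])
            & (c <= z[1:-1, :-2]) & (c <= z[1:-1, 2:]))


z = random.normal(random.PRNGKey(123123), shape=(100, 100))
rows, cols = onp.nonzero(onp.asarray(local_minima_mask(z)))
# Convert back to the loop's [i, j] = [column, row] convention and its
# column-major visiting order.
points = sorted([int(k) + 1, int(r) + 1] for r, k in zip(rows, cols))
```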

marcdelabarrera commented:

That's a helpful discussion; I'm having the same problem with jax.random.choice. To my knowledge, it doesn't offer a vectorized version, so I wrote a function that, given one state, samples using jax.random.choice, and then I vmap this function over several states.

For the problem here, does vmap perform better than a for loop, or are they essentially the same?

import jax
import jax.numpy as jnp
from jax import jit, random

@jit
def sample(key):
  val = random.normal(key, shape=(3,))
  return val

def sample_repeatedly_with_split(key):
  key, *subkey = jax.random.split(key, 10000)
  return jax.vmap(sample)(jnp.array(subkey))
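For the jax.random.choice case described above, a minimal sketch of vmapping a per-state sampler, assuming a hypothetical Markov chain whose transition probabilities live in a matrix (`transition`, `sample_next_state` and the chain itself are illustrative, not from the thread). Each state gets its own key from a single split.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical setup: row s of `transition` holds the probabilities of moving
# from state s to each state (0.7 to stay, 0.1 to each of the other three).
num_states = 4
transition = 0.6 * jnp.eye(num_states) + 0.1


def sample_next_state(key, state):
    # Draw one index in [0, num_states) weighted by the row for `state`.
    return random.choice(key, num_states, p=transition[state])


states = jnp.array([0, 1, 2, 3, 0, 1])
keys = random.split(random.PRNGKey(0), states.shape[0])  # one key per state
next_states = jax.jit(jax.vmap(sample_next_state))(keys, states)
```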

Collaborator

jakevdp commented Dec 22, 2022

Your approach of iterating over a large key array and then re-casting it to a jnp.array is going to be very slow. Try this instead:

def sample_repeatedly_with_split(key):
    subkeys = jax.random.split(key, 10000)
    key = subkeys[0]
    return jax.vmap(sample)(subkeys[1:])

I suspect this will virtually always be faster than a Python-side for-loop over the sample function, because a Python for-loop is always executed sequentially, whereas vmap allows the compiler to parallelize operations when possible. Also, if you run both outside JIT, the for-loop approach will have 10000 times the Python overhead of the vmap approach, and so will be much slower when the content of each loop iteration is small.
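A small check of the equivalence behind this advice (sizes arbitrary; a sketch, not a benchmark): vmapping over the split keys yields the same draws as looping over them one at a time, but it issues one batched call instead of one call per key.

```python
import jax
from jax import random


@jax.jit
def sample(key):
    return random.normal(key, shape=(3,))


subkeys = random.split(random.PRNGKey(0), 1000)

# One batched call: vmap maps `sample` over the leading axis of `subkeys`.
batched = jax.vmap(sample)(subkeys)  # shape (1000, 3)

# Python loop: 1000 separate dispatches of the same jitted function.
looped = [sample(k) for k in subkeys]
```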
