# time to jit a function grows superlinear with memory accessed by function #1776

Closed

opened this issue Nov 27, 2019 · 3 comments
Assignees
Labels
question Questions for the JAX team

### lhk commented Nov 27, 2019

Here is a simple example, which numerically integrates the product of two Gaussian pdfs. One of the Gaussians is fixed, with mean always at 0. The other Gaussian varies in its mean:

```python
import time

import jax.numpy as np
from jax import jit
from jax.scipy.stats.norm import pdf

# set up evaluation points for numerical integration
integr_resolution = 6400
lower_bound = -100
upper_bound = 100
integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
proba = pdf(integr_grid)
integration_weight = (upper_bound - lower_bound) / integr_resolution

# integrate with new mean
def integrate(mu_new):
    x_new = integr_grid - mu_new

    proba_new = pdf(x_new)
    total_proba = sum(proba * proba_new * integration_weight)

    return total_proba

print('starting jit')
start = time.perf_counter()
integrate = jit(integrate)
integrate(1)
stop = time.perf_counter()
print('took: ', stop - start)
```

The function looks simple, but its compile time doesn't scale at all:

| integr_resolution | seconds to execute |
|---|---|
| 100 | 0.107 |
| 200 | 0.23 |
| 400 | 0.537 |
| 800 | 1.52 |
| 1600 | 5.2 |
| 3200 | 19 |
| 6400 | 134 |

For reference, the unjitted function, applied to `integr_resolution=6400` takes 0.02s.
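As a side note on how I measured this: JAX dispatches work asynchronously, so a sketch like the following (my own variant of the code above, with an illustrative resolution of 100 so it compiles quickly) uses `block_until_ready()` and times a second call separately, to distinguish compile time from run time:

```python
import time

import jax.numpy as jnp
from jax import jit
from jax.scipy.stats.norm import pdf

# same setup as above, but with a small illustrative resolution
integr_resolution = 100
lower_bound = -100
upper_bound = 100
integr_grid = jnp.linspace(lower_bound, upper_bound, integr_resolution)
proba = pdf(integr_grid)
integration_weight = (upper_bound - lower_bound) / integr_resolution

@jit
def integrate(mu_new):
    proba_new = pdf(integr_grid - mu_new)
    return sum(proba * proba_new * integration_weight)

# first call: traces + compiles; block_until_ready waits for the async result
start = time.perf_counter()
integrate(1).block_until_ready()
compile_and_run = time.perf_counter() - start

# second call: hits the cached executable, so this measures run time only
start = time.perf_counter()
integrate(2).block_until_ready()
run_only = time.perf_counter() - start

print(compile_and_run, run_only)
```

Either way, compile time dominates by orders of magnitude at the larger resolutions.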

I thought that this might be related to the fact that the function is accessing a global variable. But moving the code to set up the integration points inside of the function has no notable influence on the timing. The following code takes 5.36s to run. It corresponds to the table entry with 1600 which previously took 5.2s:

```python
# integrate with new mean
def integrate(mu_new):
    # set up evaluation points for numerical integration
    integr_resolution = 1600
    lower_bound = -100
    upper_bound = 100
    integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
    proba = pdf(integr_grid)
    integration_weight = (upper_bound - lower_bound) / integr_resolution

    x_new = integr_grid - mu_new

    proba_new = pdf(x_new)
    total_proba = sum(proba * proba_new * integration_weight)

    return total_proba
```
What is happening here?

### hawkinsp commented Nov 27, 2019

 It's because the code says `sum` where it should say `np.sum`. `sum` is a Python built-in that extracts each element of a sequence and sums them one by one using the `+` operator. This has the effect of building a large, unrolled chain of adds which XLA takes a long time to compile. (To be honest, I'm kind of amazed this worked at all!) If you use `np.sum`, then JAX builds a single XLA reduction operator, which is much faster to compile. Does this resolve the question? I'm not sure what we could do better here, although I admit it's a bit surprising!
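To make the difference concrete, here is a small sketch (the array size and function names are illustrative, not from the original code) that counts how many primitive equations each variant traces to:

```python
import jax
import jax.numpy as jnp

def sum_builtin(x):
    # Python's built-in sum: one indexing op + one add per element
    return sum(x)

def sum_reduction(x):
    # jnp.sum: a single XLA reduce_sum primitive
    return jnp.sum(x)

x = jnp.arange(100.0)

# count the equations in each traced program
n_builtin = len(jax.make_jaxpr(sum_builtin)(x).jaxpr.eqns)
n_reduction = len(jax.make_jaxpr(sum_reduction)(x).jaxpr.eqns)

print(n_builtin)    # grows linearly with len(x)
print(n_reduction)  # stays constant regardless of len(x)
```

The unrolled program handed to XLA grows with the array length, which is why compile time blows up superlinearly.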

self-assigned this Nov 27, 2019
added the question Questions for the JAX team label Nov 27, 2019

### hawkinsp commented Nov 27, 2019

And just to show how I figured this out: I used `jax.make_jaxpr`, which dumps JAX's internal trace representation of a function. Here, it shows:

```
In [3]: import jax

In [4]: jax.make_jaxpr(integrate)(1)
Out[4]:
{ lambda b c ; ; a.
  let d = convert_element_type[ new_dtype=float32 old_dtype=int32 ] a
      e = sub c d
      f = sub e 0.0
      g = pow f 2.0
      h = div g 1.0
      i = add 1.8378770351409912 h
      j = neg i
      k = div j 2.0
      l = exp k
      m = mul b l
      n = mul m 2.0
      o = slice[ start_indices=(0,) limit_indices=(1,) strides=(1,) operand_shape=(100,) ] n
      p = reshape[ new_sizes=() dimensions=None old_sizes=(1,) ] o
      q = add p 0.0
      r = slice[ start_indices=(1,) limit_indices=(2,) strides=(1,) operand_shape=(100,) ] n
      s = reshape[ new_sizes=() dimensions=None old_sizes=(1,) ] r
      t = add q s
      u = slice[ start_indices=(2,) limit_indices=(3,) strides=(1,) operand_shape=(100,) ] n
      v = reshape[ new_sizes=() dimensions=None old_sizes=(1,) ] u
      w = add t v
      x = slice[ start_indices=(3,) limit_indices=(4,) strides=(1,) operand_shape=(100,) ] n
      y = reshape[ new_sizes=() dimensions=None old_sizes=(1,) ] x
      z = add w y
      ... similarly ...
```

and it's then obvious why this is slow: the program is very big. Contrast the `np.sum` version:

```
In [5]: def integrate(mu_new):
   ...:     x_new = integr_grid - mu_new
   ...:
   ...:     proba_new = pdf(x_new)
   ...:     total_proba = np.sum(proba * proba_new * integration_weight)
   ...:
   ...:     return total_proba
   ...:

In [6]: jax.make_jaxpr(integrate)(1)
Out[6]:
{ lambda b c ; ; a.
  let d = convert_element_type[ new_dtype=float32 old_dtype=int32 ] a
      e = sub c d
      f = sub e 0.0
      g = pow f 2.0
      h = div g 1.0
      i = add 1.8378770351409912 h
      j = neg i
      k = div j 2.0
      l = exp k
      m = mul b l
      n = mul m 2.0
      o = reduce_sum[ axes=(0,) input_shape=(100,) ] n
  in [o] }
```

### lhk commented Nov 27, 2019 • edited

 Thank you very much for this prompt and insightful help. I posted the same question on SO, if you want to get the rep for this great support :D https://stackoverflow.com/questions/59068666/jax-time-to-jit-a-function-grows-superlinear-with-memory-accessed-by-function

closed this as completed Nov 27, 2019