
time to jit a function grows superlinear with memory accessed by function #1776

Closed
lhk opened this issue Nov 27, 2019 · 3 comments
Labels: question (Questions for the JAX team)

lhk commented Nov 27, 2019

Here is a simple example that numerically integrates the product of two Gaussian pdfs. One Gaussian is fixed, with its mean always at 0; the other varies in its mean:

import time

import jax.numpy as np
from jax import jit
from jax.scipy.stats.norm import pdf

# set up evaluation points for numerical integration
integr_resolution = 6400
lower_bound = -100
upper_bound = 100
integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
proba = pdf(integr_grid)
integration_weight = (upper_bound - lower_bound) / integr_resolution


# integrate with new mean
def integrate(mu_new):
    x_new = integr_grid - mu_new

    proba_new = pdf(x_new)
    total_proba = sum(proba * proba_new * integration_weight)

    return total_proba


print('starting jit')
start = time.perf_counter()
integrate = jit(integrate)
integrate(1)
stop = time.perf_counter()
print('took: ', stop - start)

The function looks simple, but it doesn't scale at all:

integr_resolution    seconds to execute
100                  0.107
200                  0.23
400                  0.537
800                  1.52
1600                 5.2
3200                 19
6400                 134

For reference, the unjitted function applied to integr_resolution=6400 takes 0.02s.

I thought this might be related to the function accessing a global variable, but moving the setup of the integration points inside the function has no notable influence on the timing. The following code takes 5.36s to run; it corresponds to the table entry for 1600, which previously took 5.2s:

# integrate with new mean
def integrate(mu_new):
    # set up evaluation points for numerical integration
    integr_resolution = 1600
    lower_bound = -100
    upper_bound = 100
    integr_grid = np.linspace(lower_bound, upper_bound, integr_resolution)
    proba = pdf(integr_grid)
    integration_weight = (upper_bound - lower_bound) / integr_resolution

    x_new = integr_grid - mu_new

    proba_new = pdf(x_new)
    total_proba = sum(proba * proba_new * integration_weight)

    return total_proba

What is happening here?

hawkinsp (Collaborator) commented:

It's because the code says sum where it should say np.sum.

sum is a Python built-in that iterates over a sequence and adds its elements one by one using the + operator. This has the effect of building a large, unrolled chain of adds, which XLA takes a long time to compile. (To be honest, I'm kind of amazed this worked at all!)
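A minimal sketch of what that unrolling looks like (the small array here is just for illustration):

import jax.numpy as np

arr = np.arange(4.0)

# The built-in starts from 0 and folds the elements in with `+`:
# (((0 + arr[0]) + arr[1]) + arr[2]) + arr[3]
# Under tracing, each step becomes a separate slice/reshape/add
# (exactly the pattern visible in the jaxpr dump below), so the
# program grows linearly with the array length.
total = sum(arr)
print(total)  # 6.0, but traced as an unrolled chain of adds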

If you use np.sum, then JAX builds a single XLA reduction operator, which is much faster to compile.
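Concretely, the fix is a one-line change (a sketch reusing the integr_grid, proba, and integration_weight globals from the report):

import jax.numpy as np
from jax import jit
from jax.scipy.stats.norm import pdf

def integrate(mu_new):
    x_new = integr_grid - mu_new
    proba_new = pdf(x_new)
    # np.sum lowers to a single XLA reduce_sum, so compile time
    # no longer grows with integr_resolution
    total_proba = np.sum(proba * proba_new * integration_weight)
    return total_proba

integrate = jit(integrate)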

Does this resolve the question? I'm not sure what we could do better here, although I admit it's a bit surprising!

hawkinsp self-assigned this Nov 27, 2019
hawkinsp added the question (Questions for the JAX team) label Nov 27, 2019

hawkinsp (Collaborator) commented:

And just to show how I figured this out: I used jax.make_jaxpr, which dumps JAX's internal trace representation of a function. Here, it shows:

In [3]: import jax

In [4]: jax.make_jaxpr(integrate)(1)
Out[4]:
{ lambda b c ;  ; a.
  let d = convert_element_type[ new_dtype=float32
                                old_dtype=int32 ] a
      e = sub c d
      f = sub e 0.0
      g = pow f 2.0
      h = div g 1.0
      i = add 1.8378770351409912 h
      j = neg i
      k = div j 2.0
      l = exp k
      m = mul b l
      n = mul m 2.0
      o = slice[ start_indices=(0,)
                 limit_indices=(1,)
                 strides=(1,)
                 operand_shape=(100,) ] n
      p = reshape[ new_sizes=()
                   dimensions=None
                   old_sizes=(1,) ] o
      q = add p 0.0
      r = slice[ start_indices=(1,)
                 limit_indices=(2,)
                 strides=(1,)
                 operand_shape=(100,) ] n
      s = reshape[ new_sizes=()
                   dimensions=None
                   old_sizes=(1,) ] r
      t = add q s
      u = slice[ start_indices=(2,)
                 limit_indices=(3,)
                 strides=(1,)
                 operand_shape=(100,) ] n
      v = reshape[ new_sizes=()
                   dimensions=None
                   old_sizes=(1,) ] u
      w = add t v
      x = slice[ start_indices=(3,)
                 limit_indices=(4,)
                 strides=(1,)
                 operand_shape=(100,) ] n
      y = reshape[ new_sizes=()
                   dimensions=None
                   old_sizes=(1,) ] x
      z = add w y
... similarly ...

and it's then obvious why this is slow: the program is very big.

Contrast the np.sum version:

In [5]: def integrate(mu_new):
   ...:     x_new = integr_grid - mu_new
   ...:
   ...:     proba_new = pdf(x_new)
   ...:     total_proba = np.sum(proba * proba_new * integration_weight)
   ...:
   ...:     return total_proba
   ...:

In [6]: jax.make_jaxpr(integrate)(1)
Out[6]:
{ lambda b c ;  ; a.
  let d = convert_element_type[ new_dtype=float32
                                old_dtype=int32 ] a
      e = sub c d
      f = sub e 0.0
      g = pow f 2.0
      h = div g 1.0
      i = add 1.8378770351409912 h
      j = neg i
      k = div j 2.0
      l = exp k
      m = mul b l
      n = mul m 2.0
      o = reduce_sum[ axes=(0,)
                      input_shape=(100,) ] n
  in [o] }
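
To verify the improvement, one way is to time the first (compiling) call and a second (cached) call separately; a sketch, using block_until_ready() to wait out JAX's asynchronous dispatch:

import time
from jax import jit

integrate_fast = jit(integrate)  # the np.sum version above

start = time.perf_counter()
integrate_fast(1).block_until_ready()  # first call: traces, compiles, runs
print('first call (with compile):', time.perf_counter() - start)

start = time.perf_counter()
integrate_fast(1).block_until_ready()  # second call: reuses the compiled code
print('second call (cached):', time.perf_counter() - start)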

lhk (Author) commented Nov 27, 2019

Thank you very much for this prompt and insightful help.
I posted the same question on SO, if you want to get the rep for this great support :D https://stackoverflow.com/questions/59068666/jax-time-to-jit-a-function-grows-superlinear-with-memory-accessed-by-function

lhk closed this as completed Nov 27, 2019