Function with static_argnums as jitted function argument.
#5609
Comments
Agreed, this is strange... I'd expect to see an error in both cases! I'm not sure why the first one works, but you can fix both by marking the function argument as static:

```python
from functools import partial

import jax


@partial(jax.jit, static_argnums=(1,))
def f(x, fn):
    return fn(fn(x))
```
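A short usage sketch of that fix. `double` is a hypothetical callable used only for illustration. Because `fn` is static, jit makes it part of the compilation cache key instead of trying to trace it as an array input:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def f(x, fn):
    # fn is static: it is hashed into the cache key rather than traced
    return fn(fn(x))


def double(v):
    # hypothetical callable passed as the static argument
    return 2 * v


x = jnp.arange(3.0)
out = f(x, double)
print(out)  # [0. 4. 8.]
```

Each distinct callable passed as `fn` gets its own trace and compilation, so this pattern works best when the same callable objects are reused across calls.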
@jakevdp thanks for the quick response. What if the function has both static and non-static arguments? My actual use case is sparse matrix implementations; a simplified version is below. I suppose this might be easier once your sparse support PR (with primitives) goes through, but in the meantime is there any way to achieve this?

```python
from typing import Callable
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def power_iteration(A_fun: Callable, x0, iterations):
    def cond_fun(state):
        return state[2] < iterations

    def body_fun(state):
        value, vec, it = state
        vec = A_fun(vec)
        value = jnp.linalg.norm(vec)
        vec /= value
        return value, vec, it + 1

    value = jnp.linalg.norm(x0)
    vec = x0 / value
    value, vec, _ = jax.lax.while_loop(cond_fun, body_fun, (value, vec, 0))
    return value, vec


@jax.jit
def dense_matvec(A, x):
    return A @ x


@partial(jax.jit, static_argnums=(3,))
def coo_matvec(data, row, col, nrows, v):
    assert v.ndim == 1
    dv = data * v[col]
    return jnp.zeros(nrows, dtype=dv.dtype).at[row].add(dv)


n = 16
iters = 100
sparsity = 0.1
dtype = jnp.float32

key = jax.random.PRNGKey(0)
vals_key, mask_key, x0_key = jax.random.split(key, 3)
x0 = jax.random.normal(x0_key, shape=(n,), dtype=dtype)
a = jax.random.normal(vals_key, shape=(n, n), dtype=dtype)
mask = jax.random.uniform(mask_key, shape=(n, n), dtype=dtype) < sparsity
# strengthen the diagonal so the eigenvalues are more likely to be real
a = a + n * jnp.eye(n, dtype=dtype)
# always keep the diagonal entries in the sparsity pattern
mask = mask.at[jnp.arange(n), jnp.arange(n)].set(True)

# get COO data
row, col = jnp.where(mask)
data = a[row, col]
# create the masked dense matrix
a = jnp.zeros((n, n), dtype=dtype).at[row, col].set(data)

w, v = jax.jit(jnp.linalg.eig, backend="cpu")(a)
wi = jnp.argmax(jnp.abs(w))
true_value = w[wi]
true_vec = v[:, wi]
print("True:")
print(true_value)
print(true_vec)

# our dense implementation
dense_fun = jax.tree_util.Partial(dense_matvec, a)
dense_value, dense_vec = power_iteration(dense_fun, x0, iters)
print("Dense:")
print(dense_value)
print(dense_vec)

# our COO implementation
coo_fun = jax.tree_util.Partial(coo_matvec, data, row, col, n)
coo_value, coo_vec = power_iteration(coo_fun, x0, iters)
print("COO:")
print(coo_value)
print(coo_vec)
```
I think my fix works there as well. Change the first function definition to this:

```python
@partial(jax.jit, static_argnums=(0,))
def power_iteration(A_fun: Callable, x0, iterations):
    ...  # body unchanged
```

In other words, a callable passed to a jitted function should always be marked static in that jitted function. I'm surprised it would ever work otherwise.
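A self-contained sketch of that change applied to the example above, shrunk to a hypothetical 3×3 diagonal matrix diag(1, 2, 5) in COO form so the dominant eigenvalue (5) is known in advance. It assumes jit accepts the `Partial` as a static argument, which requires the object to be hashable:

```python
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(3,))
def coo_matvec(data, row, col, nrows, v):
    dv = data * v[col]
    return jnp.zeros(nrows, dtype=dv.dtype).at[row].add(dv)


# A_fun is static, so jit does not flatten the Partial into traced leaves;
# the int nrows bound inside it stays a Python int.
@partial(jax.jit, static_argnums=(0,))
def power_iteration(A_fun: Callable, x0, iterations):
    def cond_fun(state):
        return state[2] < iterations

    def body_fun(state):
        _, vec, it = state
        vec = A_fun(vec)
        value = jnp.linalg.norm(vec)
        return value, vec / value, it + 1

    value = jnp.linalg.norm(x0)
    value, vec, _ = jax.lax.while_loop(cond_fun, body_fun, (value, x0 / value, 0))
    return value, vec


# hypothetical test matrix: diag(1, 2, 5) stored in COO form
data = jnp.array([1.0, 2.0, 5.0])
row = jnp.array([0, 1, 2])
col = jnp.array([0, 1, 2])
coo_fun = jax.tree_util.Partial(coo_matvec, data, row, col, 3)
value, vec = power_iteration(coo_fun, jnp.ones(3), 100)
```

The trade-off is that the arrays bound in `coo_fun` are baked into the compiled program as constants, so passing a new `Partial` (for example with different `data`) may trigger a recompilation.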
Digging a bit, I think I see what's going on here. With Lines 301 to 315 in cd4138b
We could fix the issue you're seeing by leaving any static arguments out of the pytree produced by
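A user-side sketch in the spirit of that suggestion, assuming the static value (`nrows`) can be captured by a closure instead of being bound as a `Partial` argument. `jax.tree_util.Partial` treats its wrapped function as static pytree structure and only its bound arguments as leaves, so a value captured by the wrapped function is never traced. `make_coo_matvec` and `apply_twice` are hypothetical names:

```python
from typing import Callable

import jax
import jax.numpy as jnp


@jax.jit
def apply_twice(A_fun: Callable, v):
    # A_fun is NOT static here: the Partial is flattened, and only
    # its bound arrays (data, row, col) become traced leaves
    return A_fun(A_fun(v))


def make_coo_matvec(nrows: int):
    # nrows is captured by the closure, so it stays a Python int
    def matvec(data, row, col, v):
        dv = data * v[col]
        return jnp.zeros(nrows, dtype=dv.dtype).at[row].add(dv)

    return matvec


# hypothetical test matrix: diag(1, 2, 3) in COO form
data = jnp.array([1.0, 2.0, 3.0])
row = jnp.array([0, 1, 2])
col = jnp.array([0, 1, 2])
coo_fun = jax.tree_util.Partial(make_coo_matvec(3), data, row, col)
out = apply_twice(coo_fun, jnp.ones(3))
```

Because the wrapped `matvec` object is part of the pytree structure, reusing the same `make_coo_matvec(3)` result with new `data` values should hit the existing compilation, while creating a new closure object would trigger a retrace.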
Hmm... I would have thought you could set things up such that the outer function doesn't need to be recompiled so long as the passed function doesn't need to be - e.g. changing the data in

Yeah - I spent some time looking into this. JAX's current method of tracking static arguments makes it difficult to do things at this level of granularity. If you pass a
@jakevdp thanks for your investigation. I've got a workaround based on passing in a vector whose size is the number of output rows, so no argument needs to be static:

```python
@jax.jit
def coo_matvec(data, row, col, sized, v):
    assert v.ndim == 1
    dv = data * v[col]
    return jnp.zeros(sized.size, dtype=dv.dtype).at[row].add(dv)
```
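A small sketch of the workaround in use. It relies on array shapes being static even under `jit`, so `sized.size` is a plain Python int at trace time and every argument, including the `Partial`, can be traced. The 3×3 matrix and the `apply` wrapper are hypothetical test fixtures:

```python
import jax
import jax.numpy as jnp


@jax.jit
def coo_matvec(data, row, col, sized, v):
    assert v.ndim == 1
    dv = data * v[col]
    # sized.size comes from the (static) shape, not from traced values
    return jnp.zeros(sized.size, dtype=dv.dtype).at[row].add(dv)


@jax.jit
def apply(A_fun, v):
    # A_fun is an ordinary pytree argument; no static_argnums needed
    return A_fun(v)


dense = jnp.array([[2.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0],
                   [0.0, 3.0, 0.0]])
row, col = jnp.nonzero(dense)  # outside jit, so the result is concrete
data = dense[row, col]
sized = jnp.zeros(dense.shape[0])
coo_fun = jax.tree_util.Partial(coo_matvec, data, row, col, sized)

v = jnp.array([1.0, 2.0, 3.0])
out = apply(coo_fun, v)
```

The cost is an extra dummy array argument; its contents are never read, only its shape.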
I'm looking to pass a function to a jitted function. The argument is itself a jitted function whose static argument is bound using jax.tree_util.Partial. This is similar to #1443. Am I missing something obvious? Is this intended behaviour? If not, is there a workaround?

Error: