Function with static_argnums as jited function argument. #5609

Open
jackd opened this issue Feb 3, 2021 · 7 comments
Labels
enhancement (New feature or request), open (Issues intentionally left open, with no schedule for next steps)

Comments

@jackd
Contributor

jackd commented Feb 3, 2021

I'm looking to pass a function to a jitted function. The argument is itself a jitted function with a static argument, bound using jax.tree_util.Partial. This is similar to #1443. Am I missing something obvious? Is this intended behaviour? If not, is there a work-around?

from functools import partial
import jax
import jax.numpy as jnp

@jax.jit
def f(x, fn):
    return fn(fn(x))

@partial(jax.jit, static_argnums=(0,))
def fn_with_static_arg(p, x):
    return jnp.tile(x, (p,))

@jax.jit
def fn_simple(p, x):
    return x ** p

x = jnp.arange(3)
p = 2
print(f(x, jax.tree_util.Partial(fn_simple, p)))  # works fine
print(f(x, jax.tree_util.Partial(fn_with_static_arg, p)))  # ValueError

Error:

  File ".../jax/api_util.py", line 101, in argnums_partial_except
    "Non-hashable static arguments are not supported, as this can lead "
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses.
Static argument (index 0) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function
fn_with_static_arg is non-hashable.
@jakevdp
Collaborator

jakevdp commented Feb 3, 2021

Agreed this is strange... I'd expect to see an error in both cases! I'm not sure why the first one works, but you can fix both by marking the function argument as static:

@partial(jax.jit, static_argnums=(1,))
def f(x, fn):
    return fn(fn(x))
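
For reference, a minimal self-contained sketch of this suggestion applied to the example from the original post. It assumes the Partial object is hashable (it appears to fall back to identity-based hashing, so a freshly built Partial would likely trigger a retrace each time); the variable names `bound` and `out` are illustrative only.

```python
from functools import partial

import jax
import jax.numpy as jnp


# The callable argument (index 1) is marked static, so it is not traced.
@partial(jax.jit, static_argnums=(1,))
def f(x, fn):
    return fn(fn(x))


@partial(jax.jit, static_argnums=(0,))
def fn_with_static_arg(p, x):
    # p must be a concrete Python int because it determines the output shape.
    return jnp.tile(x, (p,))


x = jnp.arange(3)
bound = jax.tree_util.Partial(fn_with_static_arg, 2)

# Because `bound` is static in f, the bound value 2 stays a Python int.
out = f(x, bound)
print(out.shape)
```

Reusing the same `bound` object across calls should let the compiled version of `f` be reused as well.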

@jackd
Contributor Author

jackd commented Feb 3, 2021

@jakevdp thanks for the quick response. What about if the function contains both static and non-static args? My actual use case is sparse implementations - simplified version below. I suppose this might be easier once your sparse support PR goes through with primitives, but in the meantime is there any way to achieve this?

from typing import Callable
from functools import partial
import jax
import jax.numpy as jnp

@jax.jit
def power_iteration(A_fun: Callable, x0, iterations):
    def cond_fun(state):
        return state[2] < iterations

    def body_fun(state):
        value, vec, it = state
        vec = A_fun(vec)
        value = jnp.linalg.norm(vec)
        vec /= value
        return value, vec, it + 1

    value = jnp.linalg.norm(x0)
    vec = x0 / value
    value, vec, _ = jax.lax.while_loop(cond_fun, body_fun, (value, vec, 0))
    return value, vec

@jax.jit
def dense_matvec(A, x):
    return A @ x

@partial(jax.jit, static_argnums=(3,))
def coo_matvec(data, row, col, nrows, v):
    assert v.ndim == 1
    dv = data * v[col]
    return jnp.zeros(nrows, dtype=dv.dtype).at[row].add(dv)

n = 16
iters = 100
sparsity = 0.1
dtype = jnp.float32
key = jax.random.PRNGKey(0)
vals_key, mask_key, x0_key = jax.random.split(key, 3)
x0 = jax.random.normal(x0_key, shape=(n,), dtype=dtype)
a = jax.random.normal(vals_key, shape=(n, n), dtype=dtype)
mask = jax.random.uniform(mask_key, shape=(n, n), dtype=dtype) < sparsity
# strengthen diagonal so eigvals are more real
a = a + n * jnp.eye(n, dtype=dtype)
# jax.ops.index_update has since been removed from JAX; .at[...].set is the equivalent
mask = mask.at[jnp.arange(n), jnp.arange(n)].set(True)
# get coo data
row, col = jnp.where(mask)
data = a[row, col]

# create masked a
a = jnp.zeros((n, n), dtype=dtype).at[row, col].set(data)

w, v = jax.jit(jnp.linalg.eig, backend="cpu")(a)
wi = jnp.argmax(jnp.abs(w))
true_value = w[wi]
true_vec = v[:, wi]
print("True:")
print(true_value)
print(true_vec)

# our dense implementation
dense_fun = jax.tree_util.Partial(dense_matvec, a)
dense_value, dense_vec = power_iteration(dense_fun, x0, iters)
print("Dense:")
print(dense_value)
print(dense_vec)

# our coo implementation
coo_fun = jax.tree_util.Partial(coo_matvec, data, row, col, n)
coo_value, coo_vec = power_iteration(coo_fun, x0, iters)
print("COO:")
print(coo_value)
print(coo_vec)

@jakevdp
Collaborator

jakevdp commented Feb 3, 2021

I think my fix works there as well. Change the first function definition to this:

@partial(jax.jit, static_argnums=(0,))
def power_iteration(A_fun: Callable, x0, iterations):

In other words, a callable passed to a jitted function should always be marked static in that jitted function. I'm surprised it would ever work otherwise.

@jakevdp
Collaborator

jakevdp commented Feb 3, 2021

Digging a bit, I think I see what's going on here. With jax.tree_util.Partial, bound arguments become part of the pytree, and so they are traced in a jitted context:

jax/jax/tree_util.py

Lines 301 to 315 in cd4138b

class Partial(functools.partial):
  """A version of functools.partial that works in pytrees.

  Use it for partial function evaluation in a way that is compatible with JAX's
  transformations, e.g., ``Partial(func, *args, **kwargs)``.

  (You need to explicitly opt-in to this behavior because we didn't want to give
  functools.partial different semantics than normal function closures.)
  """

register_pytree_node(
    Partial,
    lambda partial_: ((partial_.args, partial_.keywords), partial_.func),
    lambda func, xs: Partial(func, *xs[0], **xs[1]),
)

We could fix the issue you're seeing by leaving any static arguments out of the pytree produced by Partial; that would make things like this more consistent.
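
A small sketch that makes this visible: flattening a Partial exposes the bound argument as a pytree leaf, and inside a jitted function that leaf is no longer a Python int. The helper names (`scale`, `apply`, `seen_types`) are made up for illustration, and the side-effecting append is only there to inspect what tracing sees.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_flatten


def scale(p, x):
    return x * p


bound = Partial(scale, 2)

# The bound argument is a leaf; the wrapped function lives in the treedef.
leaves, treedef = tree_flatten(bound)
print(leaves)

seen_types = []


@jax.jit
def apply(fn, x):
    # Runs at trace time: record what the bound argument looks like inside jit.
    seen_types.append(type(fn.args[0]))
    return fn(x)


y = apply(bound, jnp.arange(3))
print(seen_types[0])  # a tracer type rather than int
```

This is why a bound value that is later used as a `static_argnums` argument fails: by the time the inner function sees it, it is a tracer and not hashable.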

@jackd
Contributor Author

jackd commented Feb 3, 2021

Hmm... I would have thought you could set things up such that the outer function doesn't need to be recompiled so long as the passed function doesn't need to be - e.g. changing the data in a shouldn't require a recompile in either the sparse or dense case, though it would if the shape / dtype / nnz (for the sparse case) changed.

@jakevdp
Collaborator

jakevdp commented Feb 3, 2021

Yeah - I spent some time looking into this. JAX's current method of tracking static arguments makes it difficult to do things at this level of granularity. If you pass a jax.tree_util.Partial function as a non-static argument, all arguments to that function will be traced, and I don't think there's any way around this currently.
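
One possible way to sidestep this, sketched here as an assumption rather than an endorsed pattern: since the wrapped callable itself is kept as static pytree metadata (see the `register_pytree_node` call quoted above), the static value can be closed over in that callable instead of being bound through Partial. `make_coo_matvec` and `apply_twice` are hypothetical names, and the `lru_cache` is only there so that the same function object is returned for the same `nrows`.

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


@functools.lru_cache(maxsize=None)
def make_coo_matvec(nrows):
    """Hypothetical factory: returns a matvec with nrows baked in as a Python int."""

    def coo_matvec(data, row, col, v):
        dv = data * v[col]
        return jnp.zeros(nrows, dtype=dv.dtype).at[row].add(dv)

    return coo_matvec


@jax.jit
def apply_twice(fn, x):
    return fn(fn(x))


# A 3x3 diagonal matrix diag(1, 2, 3) in COO form.
data = jnp.array([1.0, 2.0, 3.0])
row = jnp.array([0, 1, 2])
col = jnp.array([0, 1, 2])

# data, row and col are traced leaves; nrows is hidden inside the closure.
coo_fun = Partial(make_coo_matvec(3), data, row, col)
out = apply_twice(coo_fun, jnp.ones(3))
print(out)
```

Under this arrangement only `data`, `row` and `col` are traced, so changing their values (but not shapes or dtypes) should not require marking the whole callable static in the outer function.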

@jackd
Contributor Author

jackd commented Feb 4, 2021

@jakevdp thanks for your investigation. I've got a work-around based on passing in a vector of size nrows with arbitrary values (sized in the code below). It's dirty, but it works for the moment...

@jax.jit
def coo_matvec(data, row, col, sized, v):
    assert v.ndim == 1
    dv = data * v[col]
    return jnp.zeros(sized.size, dtype=dv.dtype).at[row].add(dv)
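
A short usage sketch of this work-around under a jitted caller. It relies on array shapes being static even when the values are traced, so `sized.size` is a plain Python int inside jit. The toy inputs and the `apply_twice` helper are illustrative and not part of the original code.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


@jax.jit
def coo_matvec(data, row, col, sized, v):
    assert v.ndim == 1
    dv = data * v[col]
    # Only the shape of `sized` is used; its values are ignored.
    return jnp.zeros(sized.size, dtype=dv.dtype).at[row].add(dv)


@jax.jit
def apply_twice(fn, x):
    return fn(fn(x))


# diag(1, 2, 3) in COO form, with a dummy vector carrying nrows in its shape.
data = jnp.array([1.0, 2.0, 3.0])
row = jnp.array([0, 1, 2])
col = jnp.array([0, 1, 2])
sized = jnp.zeros(3)

coo_fun = Partial(coo_matvec, data, row, col, sized)
result = apply_twice(coo_fun, jnp.ones(3))
print(result)
```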

@froystig froystig added enhancement New feature or request open Issues intentionally left open, with no schedule for next steps. labels Feb 6, 2021

3 participants