
partial eval silently skips effects #21713

Open
Alan-Chen99 opened this issue Jun 6, 2024 · 3 comments
Labels
question Questions for the JAX team



Alan-Chen99 commented Jun 6, 2024

Description

import jax
from jax.experimental import checkify
from jax.interpreters import partial_eval


def fn(x):
    checkify.check(x < 10, "checkify")
    jax.debug.print("callback: {}", x)
    return x


jaxpr = jax.make_jaxpr(fn)(1)
# res_avals: abstract values of the residuals passed from the known to the unknown jaxpr
jaxpr_known, jaxpr_unknown, out_unknowns, res_avals = (
    partial_eval.partial_eval_jaxpr_nounits(jaxpr, unknowns=[True], instantiate=False)
)
print(jaxpr_known, jaxpr_unknown, out_unknowns, res_avals)

This prints:

{ lambda ; . let _:i32[] = select_n False 1 -1 in () } { lambda ; a:i32[]. let  in (a,) } [True] []

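I expected jaxpr_unknown to contain the effects (the checkify.check and the jax.debug.print). Instead, neither jaxpr contains them: both were silently dropped.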
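One way to see exactly what was dropped is to list the effectful equations at the top level of each jaxpr. This is a sketch, not part of the original report: effectful_prims is a hypothetical helper, and the code imports the internal module jax._src.interpreters.partial_eval (the same one used later in this thread), which can change between JAX versions.

```python
import jax
from jax.experimental import checkify
# Internal module, not a stable API; may change between JAX versions.
from jax._src.interpreters import partial_eval as pe


def fn(x):
    checkify.check(x < 10, "checkify")
    jax.debug.print("callback: {}", x)
    return x


def effectful_prims(jaxpr):
    """Hypothetical helper: names of top-level primitives whose equations carry effects."""
    return [eqn.primitive.name for eqn in jaxpr.eqns if eqn.effects]


closed = jax.make_jaxpr(fn)(1)
known, unknown, out_unknowns, res_avals = pe.partial_eval_jaxpr_nounits(
    closed, unknowns=[True], instantiate=False)

# partial_eval_jaxpr_nounits returns ClosedJaxprs, so unwrap with .jaxpr.
print("original:", effectful_prims(closed.jaxpr))
print("known:   ", effectful_prims(known.jaxpr))
print("unknown: ", effectful_prims(unknown.jaxpr))
```

Judging by the output above, on jax 0.4.28 the known and unknown lists would both come out empty, while the list for the original jaxpr includes debug_callback.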

Possible workaround: wrap the unknown arguments using a dynamic=False jaxpr trace. I also had to change how pjit_p is handled.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.12.2 (main, Feb 6 2024, 20:19:44) [GCC 13.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1

Alan-Chen99 added the bug (Something isn't working) label Jun 6, 2024
mattjj self-assigned this Jun 6, 2024

mattjj (Member) commented Jun 6, 2024

Thanks for raising this!

Actually, I'm not sure if we should call this a 'bug', since this is an internal API. Is there an issue you're seeing in a public API?

This is one of the main reasons partial_eval_jaxpr_stateful exists. The two may be merged someday, but for the moment this is essentially expected behavior for trace_to_jaxpr_nounits.

Can you say more about what problem you're trying to solve?

mattjj added the question (Questions for the JAX team) label and removed the bug (Something isn't working) label Jun 6, 2024

mattjj (Member) commented Jun 6, 2024

Here's how you would call it, but this is a very internal API so beware:

import jax
from jax.experimental import checkify
from jax._src.interpreters import partial_eval as pe


def fn(x):
    checkify.check(x < 10, "checkify")
    jax.debug.print("callback: {}", x)
    return x


jaxpr = jax.make_jaxpr(fn)(1)
jaxpr_known, jaxpr_unknown, out_unknowns, *_ = pe.partial_eval_jaxpr_stateful(
    jaxpr.jaxpr,  # takes an open Jaxpr, so unwrap the ClosedJaxpr
    in_unknowns=[True],
    in_inst=[True],
    ensure_out_unknowns=False,
    ensure_out_inst=False,
    saveable=lambda *_, **__: True,  # treat every value as saveable
)
print(jaxpr_known, jaxpr_unknown, out_unknowns)
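The same kind of check can be applied to the result of partial_eval_jaxpr_stateful. This is a sketch that again assumes the internal module path and the hypothetical effectful_prims helper; unlike the nounits variant, partial_eval_jaxpr_stateful returns open Jaxprs. Because the debug print's input is unknown here, debug_callback is expected to land in the staged (unknown) jaxpr, but that relies on internal partial-eval rules rather than a documented guarantee.

```python
import jax
from jax.experimental import checkify
# Internal module, not a stable API; may change between JAX versions.
from jax._src.interpreters import partial_eval as pe


def fn(x):
    checkify.check(x < 10, "checkify")
    jax.debug.print("callback: {}", x)
    return x


def effectful_prims(jaxpr):
    """Hypothetical helper: names of top-level primitives whose equations carry effects."""
    return [eqn.primitive.name for eqn in jaxpr.eqns if eqn.effects]


closed = jax.make_jaxpr(fn)(1)
jaxpr_known, jaxpr_staged, out_unknowns, *_ = pe.partial_eval_jaxpr_stateful(
    closed.jaxpr, in_unknowns=[True], in_inst=[True],
    ensure_out_unknowns=False, ensure_out_inst=False,
    saveable=lambda *_, **__: True)

# Both results are open Jaxprs, so .eqns can be read directly.
print("known: ", effectful_prims(jaxpr_known))
print("staged:", effectful_prims(jaxpr_staged))
```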

Alan-Chen99 (Author) commented

Thank you!

While I'm not using a public API, I'm doing something similar to jax.linearize, which also has this problem:

import jax
from jax import Array
from jax import numpy as jnp
from jax._src import checkify, effects

# effects.custom_derivatives_allowed_effects.add_type(checkify.ErrorEffect)


@jax.custom_jvp
def fn(x: Array):
    return x


@fn.defjvp
def testfn_jvp(primals: tuple[Array, ...], tangents: tuple[Array, ...]):

    (x,) = primals
    (tg,) = tangents

    jax.debug.print("callback: {}", tg)
    # checkify.check(tg < 0, "invalid tangent {}", tg)

    return x, tg


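# jax.jvp works (the callback runs):
# jax.jvp(fn, [jnp.array(1.0)], [jnp.array(2.0)])

# jax.linearize has the same problem: the callback in the tangent function is skipped
val, tang = jax.linearize(fn, jnp.array(1.0))
tang(jnp.array(2.0))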
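A public-API way to compare the two paths is to stage out each tangent computation with jax.make_jaxpr and look for the debug_callback primitive. This sketch assumes the pretty-printed jaxpr includes nested sub-jaxprs, so a substring search also covers callbacks inside pjit or custom_jvp_call bodies; mentions_debug_callback is a hypothetical helper. The jax.jvp jaxpr should contain the callback; whether the linearize jaxpr does may depend on the JAX version, so the sketch only prints that result.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def fn(x):
    return x


@fn.defjvp
def fn_jvp(primals, tangents):
    (x,), (tg,) = primals, tangents
    jax.debug.print("callback: {}", tg)
    return x, tg


def mentions_debug_callback(closed_jaxpr):
    # Hypothetical helper. Assumes the printed jaxpr shows nested sub-jaxprs,
    # so a substring search also finds callbacks inside pjit / custom_jvp_call.
    return "debug_callback" in str(closed_jaxpr)


x, t = jnp.array(1.0), jnp.array(2.0)

# Tangent computation via jax.jvp, staged out with make_jaxpr.
jvp_jaxpr = jax.make_jaxpr(lambda tan: jax.jvp(fn, (x,), (tan,))[1])(t)

# Tangent computation via the linear function returned by jax.linearize.
_, f_lin = jax.linearize(fn, x)
lin_jaxpr = jax.make_jaxpr(f_lin)(t)

print("jvp keeps callback:      ", mentions_debug_callback(jvp_jaxpr))
print("linearize keeps callback:", mentions_debug_callback(lin_jaxpr))
```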
