How to compute Jacobian of outputs w.r.t. inputs #224
You can reimplement an object-oriented version of `jacfwd`/`jacrev`. Another option is to use the combination shown in objax/objax/privacy/dpsgd/gradient.py (line 76 at c4785ff).
You cannot use functional JAX transformations (like `jax.jacfwd` or `jax.vmap`) with Objax. All JAX primitives are stateless and purely functional (i.e., they neither have nor assume side effects). Objax provides wrappers for JAX primitives to simplify state management and make it more natural for machine learning applications. So if you try to mix JAX functional transformations with Objax primitives, it will break state management, and the code either won't work at all or will work incorrectly.
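For reference, this is what `jacfwd`/`jacrev` compute in pure JAX, outside any Objax module. A minimal sketch with hypothetical fixed weights `W` and bias `b` (these names and the toy model are assumptions for illustration, not from the issue):

```python
import jax
import jax.numpy as jnp

# Hypothetical pure function: a tiny affine map plus tanh nonlinearity.
# W and b are fixed constants, so f depends only on its input x.
W = jnp.arange(6.0).reshape(3, 2)  # maps 2 inputs to 3 outputs
b = jnp.ones(3)

def f(x):
    return jnp.tanh(W @ x + b)

x = jnp.array([0.5, -0.5])

# Jacobian of the 3 outputs w.r.t. the 2 inputs: shape (3, 2).
J_fwd = jax.jacfwd(f)(x)  # forward-mode
J_rev = jax.jacrev(f)(x)  # reverse-mode

print(J_fwd.shape)                 # (3, 2)
print(jnp.allclose(J_fwd, J_rev))  # True: both modes agree
```

Because `f` here is a plain function with no Objax state, both transformations work directly; the point of the thread is that an Objax module's `__call__` is not such a pure function.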
As I mentioned above, Objax provides wrappers which simplify state management.
I see, thank you for your explanation!
It is hard to implement an object-oriented version of `jacfwd`/`jacrev`.
Hi,

I am new to JAX and Objax, and I would like to compute the partial derivatives (the Jacobian) of a model's outputs w.r.t. its inputs; below is a piece of code.

The docs suggest not mixing JAX and Objax transformations, so my questions are:

- Can I use `jacfwd` or `jacrev`? If not, what is the standard way to calculate a Jacobian?
- Can I mix `vmap` and `objax.Vectorize`?

Thanks.
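For context, the `vmap` half of the question has a standard pure-JAX pattern: composing `jax.vmap` with `jax.jacrev` yields per-example Jacobians over a batch (`objax.Vectorize` plays the `vmap` role on the Objax side). A hedged sketch with a hypothetical toy function `f` (not the asker's code, which is not shown in the scrape):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical elementwise model: 2 inputs -> 2 outputs.
    return jnp.sin(x) * jnp.array([1.0, 2.0])

# Batch of 2 example inputs, each of dimension 2.
xs = jnp.stack([jnp.zeros(2), jnp.ones(2)])

# jax.jacrev(f) gives the (2, 2) Jacobian for one example;
# jax.vmap maps it over the leading batch axis -> shape (2, 2, 2).
batched_jac = jax.vmap(jax.jacrev(f))(xs)
print(batched_jac.shape)  # (2, 2, 2)
```

Since `f` is elementwise here, each per-example Jacobian is diagonal: at `x = 0` it is `diag(cos(0) * [1, 2]) = diag([1, 2])`.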