Strange behavior for jax.scipy.optimize.minimize when using vmap #5732
Comments
I think the issue is that If you change your code to:
you at least get a different error:
I think the right fix is either to similarly cast, or to check if the input is not a tuple and error. @shoyer any opinions?
Thanks for the speedy response! As far as the second error goes, would that be solved with a decorator or some partial eval in the original
@hawkinsp I agree, either of those sounds fine to me.
I'll change it to error if passed something other than a tuple. As to the second error, the contract of
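A minimal sketch of the tuple contract for `args`. It uses a hypothetical quadratic objective `f` that is not from the original report. Extra data is passed as a tuple, even a one-element one such as `(c,)`. Binding the data with `functools.partial`, as floated in the question above, is an alternative that avoids `args` entirely.

```python
import functools

import jax.numpy as jnp
import jax.scipy.optimize as sopt


def f(x, c):
    # Hypothetical scalar objective; extra data `c` follows the variable `x`.
    return jnp.sum((x - c) ** 2)


c = jnp.array([1.0, -2.0, 3.0])
x0 = jnp.zeros(3)

# Extra arguments go in a tuple, even when there is only one of them.
res_args = sopt.minimize(f, x0, args=(c,), method="BFGS")

# Alternative: bind the data up front so only `x` is left free.
res_partial = sopt.minimize(functools.partial(f, c=c), x0, method="BFGS")
```

Both calls should land at the minimizer `x = c`.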
@hawkinsp I'd like to independently minimize multiple functions that operate over the same variables, with different responses. In the toy example above, it would be fitting multiple logistic regression models over the same So rather than doing something like,
```python
import jax.numpy as jnp
import jax.scipy.optimize as sopt
# nll and Y are defined in the original example
res_i = []
for i in range(Y.shape[1]):
    res_i.append(sopt.minimize(nll, jnp.zeros(3), args=(Y.T[i],), method='BFGS'))
```
I could make a single
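A minimal sketch of replacing the loop above with one batched call by applying `jax.vmap` to a function that wraps the whole `minimize` call, so the objective stays scalar. The `nll` (mean logistic-regression negative log-likelihood), `fit_one`, and the synthetic `X`/`Y` data are hypothetical stand-ins for the elided original code. `Y[:, j]` is the same column as `Y.T[j]` in the loop.

```python
import jax
import jax.numpy as jnp
import jax.scipy.optimize as sopt


def nll(beta, X, y):
    # Hypothetical objective: mean Bernoulli NLL of a logistic regression,
    # log(1 + exp(z)) - y * z with z = X @ beta.
    logits = X @ beta
    return jnp.mean(jnp.logaddexp(0.0, logits) - y * logits)


# Hypothetical synthetic data: one shared design X, k binary response columns.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
n, p, k = 200, 3, 4
X = jax.random.normal(key_x, (n, p))
true_B = jnp.array([[0.5, -1.0, 0.3, 0.0],
                    [1.0, 0.2, -0.5, 0.8],
                    [-0.7, 0.4, 0.0, -0.3]])  # shape (p, k)
Y = jax.random.bernoulli(key_y, jax.nn.sigmoid(X @ true_B)).astype(jnp.float32)


def fit_one(y):
    # One independent fit; only the response column changes between problems.
    return sopt.minimize(nll, jnp.zeros(p), args=(X, y), method="BFGS").x


# Python loop: one minimize call per response column.
loop_B = jnp.stack([fit_one(Y[:, j]) for j in range(k)])

# Single batched call: map fit_one over axis 1 (the columns) of Y.
vmap_B = jax.vmap(fit_one, in_axes=1)(Y)  # shape (k, p)
```

The two results should agree up to floating-point tolerance.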
I suspect in that case you want to apply
@hawkinsp, yes that makes sense. I realize now my earlier example where I tried
The issue is actually the
We should consider just removing the
Hi all, thanks for developing such a fantastic package. I've been excited to use JAX in my day-to-day research, as well as in the research of trainees in my group.
I'm experimenting with applying `vmap` to a scalar function and passing the result along to `jax.scipy.optimize.minimize`. I'm seeing a strange error regarding the `in_axes` specification for `vmap` after the fact. Here is code to reproduce the error:
There is some similarly strange behavior when trying to `vmap` over `minimize` directly,
with output: