
Strange behavior for jax.scipy.optimize.minimize when using vmap #5732

Closed · quattro opened this issue Feb 15, 2021 · 9 comments · Fixed by #5741
Labels: bug (Something isn't working)

quattro commented Feb 15, 2021

Hi all, thanks for developing such a fantastic package. I've been excited to use jax in my day-to-day research, as well as in the research of trainees in my group.

I'm experimenting with applying vmap to a scalar function and passing the result along to jax.scipy.optimize.minimize. I'm seeing a strange error about the in_axes specification for vmap. Here is code to reproduce the error:

import jax
import jax.numpy as jnp
import jax.scipy.optimize as sopt
import numpy as np

# set up keys
seed = 1234
rng_key = jax.random.PRNGKey(seed)

# split
rng_key, x_key, b_key, e_key = jax.random.split(rng_key, 4)

# gen random data and a binary outcome
X = jax.random.normal(x_key, shape=(10, 3))
B = jax.random.normal(b_key, shape=(3,2))
eps = jax.random.normal(e_key, shape=(10, 2))
tholds = jnp.array([0, 1])
Y = ((X @ B + eps) > tholds).astype(float)

# loglike
def loglike(b_hat, y):
    t1 = jnp.sum(-jnp.log1p(jnp.exp(X @ b_hat)), axis=0)
    t2 = jnp.sum(y * (X @ b_hat), axis=0)
    return t1 + t2

# nll for min
def nll(b_hat, y):
    return -loglike(b_hat, y)

# vmap to span dimensions
NLL = jax.vmap(nll, (1, 1), 0)

# test it out
print("testing nll = ", NLL(B, Y))

# run it
sopt.minimize(NLL, jnp.zeros((3, 2)), Y, method="BFGS")
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
testing nll =  [5.0628967 5.4723964]
Traceback (most recent call last):
  File "test.py", line 37, in <module>
    sopt.minimize(NLL, jnp.zeros((3, 2)), Y, method="BFGS")
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/minimize.py", line 96, in minimize
    results = minimize_bfgs(fun_with_args, x0, **options)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/bfgs.py", line 99, in minimize_bfgs
    f_0, g_0 = jax.value_and_grad(fun)(x0)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/minimize.py", line 93, in <lambda>
    fun_with_args = lambda x: fun(x, *args)
jax._src.traceback_util.FilteredStackTrace: ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (1, 1) for value tree PyTreeDef(tuple, [*,*,*,*,*,*,*,*,*,*,*]).

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test.py", line 37, in <module>
    sopt.minimize(NLL, jnp.zeros((3, 2)), Y, method="BFGS")
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/minimize.py", line 96, in minimize
    results = minimize_bfgs(fun_with_args, x0, **options)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/bfgs.py", line 99, in minimize_bfgs
    f_0, g_0 = jax.value_and_grad(fun)(x0)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/api.py", line 808, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/api.py", line 1914, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 114, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/minimize.py", line 93, in <lambda>
    fun_with_args = lambda x: fun(x, *args)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/api.py", line 1220, in batched_fun
    in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/api_util.py", line 188, in flatten_axes
    raise ValueError(f"{name} specification must be a tree prefix of the "
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (1, 1) for value tree PyTreeDef(tuple, [*,*,*,*,*,*,*,*,*,*,*]).

There is similarly strange behavior when trying to vmap over minimize directly:

# try vmap across minimize directly
def my_min(b0, y):
    return sopt.minimize(nll, b0, y, method="BFGS")

vmin = jax.vmap(my_min, (1, 1), 0)
vmin(B, Y)

With output:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
testing nll =  [5.0628967 5.4723964]
Traceback (most recent call last):
  File "test.py", line 44, in <module>
    vmin(B, Y)
  File "test.py", line 41, in my_min
    return sopt.minimize(nll, b0, y, method="BFGS")
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/minimize.py", line 96, in minimize
    results = minimize_bfgs(fun_with_args, x0, **options)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/bfgs.py", line 99, in minimize_bfgs
    f_0, g_0 = jax.value_and_grad(fun)(x0)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/minimize.py", line 93, in <lambda>
    fun_with_args = lambda x: fun(x, *args)
jax._src.traceback_util.FilteredStackTrace: TypeError: nll() takes 2 positional arguments but 11 were given

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test.py", line 44, in <module>
    vmin(B, Y)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/api.py", line 1222, in batched_fun
    out_flat = batching.batch(
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "test.py", line 41, in my_min
    return sopt.minimize(nll, b0, y, method="BFGS")
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/minimize.py", line 96, in minimize
    results = minimize_bfgs(fun_with_args, x0, **options)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/bfgs.py", line 99, in minimize_bfgs
    f_0, g_0 = jax.value_and_grad(fun)(x0)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/api.py", line 808, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/api.py", line 1914, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 114, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/scipy/optimize/minimize.py", line 93, in <lambda>
    fun_with_args = lambda x: fun(x, *args)
TypeError: nll() takes 2 positional arguments but 11 were given

hawkinsp commented Feb 16, 2021

I think the issue is that jax.scipy.optimize.minimize expects the args argument to be a tuple. By contrast, scipy casts the argument to a tuple if it is not already: https://github.com/scipy/scipy/blob/v1.6.0/scipy/optimize/_minimize.py#L45-L646

If you change your code to:

sopt.minimize(NLL, jnp.zeros((3, 2)), (Y,), method="BFGS")

you at least get a different error:

TypeError: Gradient only defined for scalar-output functions. Output had shape: (2,).

I think the right fix is either to cast in the same way, or to raise an error if the input is not a tuple.
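
For illustration, a minimal sketch of such a guard (the helper name and structure are hypothetical, not JAX's actual internals):

# Hypothetical helper illustrating the two options; not JAX's actual code.
def _check_args(args):
    if not isinstance(args, tuple):
        # scipy's behavior would instead be to cast: args = (args,)
        raise TypeError(f"args must be a tuple, got {type(args)}")
    return args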

@shoyer any opinions?

hawkinsp added the bug (Something isn't working) label Feb 16, 2021
hawkinsp self-assigned this Feb 16, 2021

quattro commented Feb 16, 2021

Thanks for the speedy response! As for the second error, would that be solved with a decorator, or with some partial evaluation of the original nll function that masks y via argnums?


shoyer commented Feb 16, 2021

@hawkinsp I agree, either of those sounds fine to me.

hawkinsp commented:

I'll change it to raise an error if passed something other than a tuple.

As to the second error, the contract of jax.scipy.optimize.minimize is: "the objective function to be minimized, fun(x, *args) -> float", i.e., it optimizes a scalar objective function. I don't know what you are trying to achieve mathematically here! vmap will turn a scalar-valued function into a vector-valued function; that's what it does. What did you intend to compute here?
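
To make the distinction concrete, a minimal illustration (not from the thread):

import jax
import jax.numpy as jnp

f = lambda b: jnp.sum(b ** 2)  # scalar-valued objective: maps a length-3 vector to a float
vf = jax.vmap(f)               # batched version: maps a (4, 3) array to a length-4 vector
print(vf(jnp.ones((4, 3))))    # [3. 3. 3. 3.] -- a vector output, not a valid scalar objective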


quattro commented Feb 16, 2021

@hawkinsp I'd like to independently minimize multiple functions that operate over the same variables, with different responses. In the toy example above, it would mean fitting multiple logistic regression models on the same X but with different y values.

So rather than doing something like,

res_i = []
for i in range(Y.shape[1]):
    res_i.append(sopt.minimize(nll, jnp.zeros(3), args=(Y.T[i],), method='BFGS'))

I could make a single vmap'd call and retrieve results for each column of Y.

hawkinsp commented:

I suspect in that case you want to apply vmap to the minimize call itself. You don't want to perform a single minimization of a vector objective; you want to perform multiple scalar minimizations.
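
For instance, a sketch along those lines (reusing nll, X, Y, and sopt from the report above; returning only res.x keeps the output to arrays, which also sidesteps the non-array result fields that come up below):

def fit_one(y):
    # One scalar BFGS minimization for a single column of Y.
    res = sopt.minimize(nll, jnp.zeros(3), args=(y,), method="BFGS")
    return res.x

fits = jax.vmap(fit_one, in_axes=1)(Y)  # shape (2, 3): one solution per column of Y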


quattro commented Feb 16, 2021

@hawkinsp, yes, that makes sense. I realize now that my earlier example, where I tried vmap'ing minimize directly (my_min, defined above), still contained the tuple error, which broke the internal call to minimize. Fixing the tuple issue and re-running throws a new error related to the str message in the output from minimize.

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "test.py", line 37, in <module>
    vmin(B, Y)
jax._src.traceback_util.FilteredStackTrace: TypeError: <class 'str'> is not a valid JAX type

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test.py", line 37, in <module>
    vmin(B, Y)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/api.py", line 1222, in batched_fun
    out_flat = batching.batch(
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/linear_util.py", line 179, in call_wrapped
    ans = gen.send(ans)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/interpreters/batching.py", line 71, in _match_axes
    yield map(partial(matchaxis, axis_size), out_dims, out_dim_dests, out_vals)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/_src/util.py", line 41, in safe_map
    return list(map(f, *args))
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/interpreters/batching.py", line 394, in matchaxis
    if core.get_aval(x) is core.abstract_unit:
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/core.py", line 914, in get_aval
    return concrete_aval(x)
  File "/Users/nicholas/opt/miniconda3/lib/python3.8/site-packages/jax/core.py", line 907, in concrete_aval
    raise TypeError(f"{type(x)} is not a valid JAX type")
TypeError: <class 'str'> is not a valid JAX type

hawkinsp commented:

The issue is actually the .message field in the output of minimize, which is a string. Try:

def my_min(b0, y):
    out = sopt.minimize(nll, b0, (y,), method="BFGS")
    return out._replace(message=None)
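
which can then be vmapped over columns as before, e.g. (a sketch; shapes assume the arrays from the original report):

vmin = jax.vmap(my_min, in_axes=(1, 1))
results = vmin(jnp.zeros((3, 2)), Y)
print(results.x)  # shape (2, 3): one BFGS solution per column of Y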


shoyer commented Feb 16, 2021

We should consider just removing the message field and moving it into the docstring, given that it never changes.
