TypeError: iteration over a 0-d array #23

Closed
jondeaton opened this issue Feb 27, 2020 · 4 comments

Comments

@jondeaton

jondeaton commented Feb 27, 2020

When I run the following code, I get a "TypeError: iteration over a 0-d array" error from JAX. The code looks correct to me, and I don't understand where this error is coming from or how to fix it.

import jax
import jax.numpy as jnp
from flax import nn, optim

class CNN(nn.Module):
    def apply(self, x):
        x = nn.Conv(x, features=32, kernel_size=(3, 3))
        x = x.reshape((x.shape[0], -1))
        v = nn.Dense(x, features=1)
        return v

@jax.jit
def train_step(optimizer, observations, returns):
    def loss_fn(model):
        values_pred = model(observations)
        return jnp.square(values_pred - returns).mean()
    optimizer, _ = optimizer.optimize(loss_fn)
    return optimizer

batch_size = 7
input_shape = (batch_size, 32, 32, 5)

key = jax.random.PRNGKey(0)

_, model = CNN.create_by_shape(key, [(input_shape, jnp.float32)])
optimizer = optim.Adam(learning_rate=0.01).create(model)

observations = jax.random.normal(key, input_shape)
returns = jax.random.normal(key, (batch_size, ))

optimizer = train_step(optimizer, observations, returns)

It produces the following stack trace:

Traceback (most recent call last):
  File "test.py", line 33, in <module>
    optimizer = train_step(optimizer, observations, returns)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 150, in f_jitted
    name=flat_fun.__name__)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py", line 605, in call_bind
    outs = primitive.impl(f, *args, **params)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 449, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, *map(arg_spec, args))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 223, in memoized_fun
    ans = call(fun, *args)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py", line 466, in _xla_callable
    jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 152, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "test.py", line 18, in train_step
    optimizer, l, logits = optimizer.optimize(loss_fn)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/optim.py", line 266, in optimize
    loss, aux, grad = self.compute_gradients(loss_fn)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/optim.py", line 250, in compute_gradients
    (loss, aux), grad = grad_fn(self.target)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 413, in value_and_grad_f
    ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 1293, in vjp
    out_primal, out_vjp, aux = ad.vjp(flat_fun, primals_flat, has_aux=True)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/ad.py", line 111, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/ad.py", line 98, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 337, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 156, in call_wrapped
    ans = gen.send(ans)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api_util.py", line 72, in flatten_fun_nokwargs2
    ans, aux = yield py_args, {}
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py", line 300, in __iter__
    return iter(self.aval._iter(self))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/lax/lax.py", line 1497, in _iter
    raise TypeError("iteration over a 0-d array")  # same as numpy error
TypeError: iteration over a 0-d array

Any help is greatly appreciated!
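The last frames of this traceback show flax's compute_gradients unpacking `(loss, aux), grad = grad_fn(self.target)` and JAX unpacking the loss function's result as a pair (`ans, aux = yield py_args, {}`). One plausible reading, offered as an assumption rather than a confirmed diagnosis, is that this flax version expected `loss_fn` to return a `(loss, aux)` tuple, while the `loss_fn` above returns a bare scalar. Unpacking a 0-d array into two names raises exactly this message in NumPy, which the JAX error deliberately mirrors (`# same as numpy error`). The sketch below uses NumPy only; the helper `unpack_loss_and_aux` is hypothetical and merely imitates that unpacking step.

```python
import numpy as np


def unpack_loss_and_aux(result):
    """Hypothetical helper imitating `loss, aux = loss_fn(params)`."""
    # Tuple unpacking calls iter() on `result`; NumPy refuses to iterate a 0-d array.
    loss, aux = result
    return loss, aux


# Shape (), like the scalar returned by jnp.square(...).mean() in loss_fn.
scalar_loss = np.array(0.5)

try:
    unpack_loss_and_aux(scalar_loss)
    error_message = None
except TypeError as err:
    error_message = str(err)
print(error_message)  # iteration over a 0-d array

# Returning a (loss, aux) pair instead unpacks without error.
loss, aux = unpack_loss_and_aux((scalar_loss, {"predictions": None}))
```

If that reading is right, the scalar loss itself is fine; the mismatch is only in how many values the caller expects back.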

jondeaton changed the title from "TypeError: iteration over a 0-d array" to "Strange error, TypeError: iteration over a 0-d array" on Feb 27, 2020
jondeaton changed the title back to "TypeError: iteration over a 0-d array" on Feb 27, 2020
@avital
Contributor

avital commented Mar 6, 2020

Hi @jondeaton, I believe your returns should be target vectors with the same (batch_size, 1) shape as the model's output, so maybe try changing it to returns = jax.random.normal(key, (batch_size, 1))

(though the fact that this error is what appears should be considered a bug!)
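The shape suggestion matters because `nn.Dense(x, features=1)` produces predictions of shape `(batch_size, 1)`, and subtracting a `(batch_size,)` array from that broadcasts to `(batch_size, batch_size)` instead of failing, so the loss silently averages over every prediction/target pair. A NumPy sketch of the shapes involved (NumPy and JAX share the same broadcasting rules; the zero-filled arrays are placeholders for the real values):

```python
import numpy as np

batch_size = 7
values_pred = np.zeros((batch_size, 1))  # shape of the nn.Dense(..., features=1) output

returns_flat = np.zeros((batch_size,))      # the original `returns`
returns_column = np.zeros((batch_size, 1))  # the suggested `returns`

# (7, 1) - (7,) broadcasts to (7, 7): every prediction is compared to every target.
diff_flat = values_pred - returns_flat
# (7, 1) - (7, 1) stays elementwise: one residual per example.
diff_column = values_pred - returns_column
# Dropping the trailing axis of the predictions is an equivalent alternative.
diff_squeezed = values_pred[:, 0] - returns_flat

print(diff_flat.shape)      # (7, 7)
print(diff_column.shape)    # (7, 1)
print(diff_squeezed.shape)  # (7,)
```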

@jondeaton
Author

jondeaton commented Mar 7, 2020

Thanks for the suggestion! Unfortunately, I'm still encountering the same problem after changing returns as you recommended. However, after reinstalling flax at head, I now get a different error and stack trace (with exactly the same code):

Traceback (most recent call last):
  File "test.py", line 25, in <module>
    _, model = CNN.create_by_shape(key, [(input_shape, jnp.float32)])
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/nn/base.py", line 261, in wrapper
    return super_fn(*args, **kwargs)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/nn/base.py", line 381, in create_by_shape
    return jax_utils.partial_eval_by_shape(lazy_create, input_specs)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/jax_utils.py", line 86, in partial_eval_by_shape
    output_shapes = jax.eval_shape(lazy_fn, *input_structs)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py", line 2042, in eval_shape
    out = pe.abstract_eval_fun(fun.call_wrapped, *map(abstractify, args_flat))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 273, in abstract_eval_fun
    instantiate=True)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 354, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py", line 149, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/jonpdeaton/miniconda3/envs/jax/lib/python3.7/site-packages/flax/jax_utils.py", line 79, in lazy_fn
    master = leaves[0].trace.master
AttributeError: 'function' object has no attribute 'master'

Any ideas why this is occurring, or how I could contribute a patch to fix it? This bug is blocking the RL project I'm trying to use flax for, so I'd be happy to send a PR, but I'm not sure where to start looking.

@avital
Contributor

avital commented Mar 10, 2020

Hi Jon, I was able to reproduce this, but once I upgraded to the latest flax and jax (which also required a new version of jaxlib), it was resolved. Can you give that a shot?
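Since the fix here was upgrading flax, jax and jaxlib together, it can help to confirm which versions are actually installed before and after upgrading, and to include them in a bug report. A small sketch using the standard library's `importlib.metadata` (Python 3.8+; on Python 3.7, as in the tracebacks above, the `importlib_metadata` backport offers the same API). The function name `installed_versions` is hypothetical.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "flax")):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

For a CPU-only setup, upgrading all three in one step, e.g. `pip install --upgrade jax jaxlib flax`, keeps them in sync.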

@marcvanzee
Collaborator

I managed to run it without errors as well using the latest version of flax, so I am closing this issue.
