The error occurred while tracing the function init at ninjax.py #1

Closed
emigmo opened this issue Feb 17, 2023 · 8 comments

@emigmo

emigmo commented Feb 17, 2023

Running the demo in Docker:

sh dreamerv3/embodied/scripts/xvfb_run.sh python3 dreamerv3/train.py   --configs dmc_vision --task dmc_walker_walk 
│ /dreamerv3/dreamerv3/ninjax.py:245 in scan                                                       │
│                                                                                                  │
│   242 @jax.named_scope('scan')                                                                   │
│   243 def scan(fun, carry, xs, reverse=False, unroll=1, modify=False):                           │
│   244   fun = pure(fun, nested=True)                                                             │
│ ❱ 245   _prerun(fun, carry, jax.tree_util.tree_map(lambda x: x[0], xs))                          │
│   246   length = len(jax.tree_util.tree_leaves(xs)[0])                                           │
│   247   rngs = rng(length)                                                                       │
│   248   if modify:                                                                               │
│                                                                                                  │
│ /usr/lib/python3.8/contextlib.py:75 in inner                                                     │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /dreamerv3/dreamerv3/ninjax.py:272 in _prerun                                                    │
│                                                                                                  │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│ ❱ 272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│                                                                                                  │
│ /dreamerv3/dreamerv3/ninjax.py:273 in <lambda>                                                   │
│                                                                                                  │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│ ❱ 273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│   276                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float16[16,4096])>with<DynamicJaxprTrace(level=1/0)>
The delete() method was called on the JAX Tracer object Traced<ShapedArray(float16[16,4096])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function init at /dreamerv3/dreamerv3/ninjax.py:163 for jit. This concrete value was not available in Python because it depends on the values of the arguments 'statics', 'rng', and 'args'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
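
The failing pattern can be reproduced outside DreamerV3 in a few lines. This is a minimal sketch, assuming a JAX 0.4.x release in which tracers expose a `delete()` method that raises, as the error message above indicates. `init` and `reproduces_error` are hypothetical names, not ninjax's actual code. Inside `jax.jit` the pytree leaves are tracers, so the `hasattr(x, 'delete') and x.delete()` idiom from `_prerun` now reaches `Tracer.delete()` and raises `ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp


def init(params):
    # Hypothetical stand-in for the function being traced by jax.jit.
    # Inside jit, every leaf of `params` is a Tracer, not a concrete array.
    discarded = {"w": params["w"] * 2}
    # Same idiom as ninjax._prerun: free every leaf that has a .delete() method.
    # Tracers have .delete() in newer JAX, and calling it raises.
    jax.tree_util.tree_map(lambda x: hasattr(x, "delete") and x.delete(), discarded)
    return params["w"].sum()


def reproduces_error():
    params = {"w": jnp.ones((4, 8), jnp.float32)}
    try:
        jax.jit(init)(params)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


print(reproduces_error())
```

Called eagerly (without `jax.jit`), the same `init` runs fine, because the leaves are concrete arrays whose buffers can actually be freed.
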
@IcarusWizard

Same error here when running example.py. Could it be a JAX version issue? JAX has been releasing new versions rapidly lately.

@IcarusWizard

Problem solved by downgrading JAX to 0.3.25.

@danijar
Owner

danijar commented Feb 17, 2023

Thanks for reporting! It looks like an issue with the newest JAX release to me. I'll keep this open for now and will investigate further if it continues to be a problem.

@Alian3785

Yes, I ran it in Codespaces and got this error too. Changing requirements.txt to the following helped:

[Screenshot of the edited requirements.txt]

@danijar
Owner

danijar commented Feb 18, 2023

Thanks for confirming! I've pinned the JAX version in requirements.txt for now. If the issue remains with the next JAX release, I'll investigate further.

@danijar
Owner

danijar commented Feb 20, 2023

I think I've fixed the issue with the newest JAX version, just by commenting out the delete() call that wasn't necessary. @IcarusWizard @Alian3785 would you mind giving it another try with the newest JAX version?
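
The fix here drops the `delete()` call entirely. If eagerly freeing buffers is still wanted, one option is a guard that only deletes concrete arrays and skips tracers. The sketch below is an illustration, not the change made in the repository; `delete_if_concrete` and `init` are hypothetical names, and the guard relies on `Tracer.delete()` raising `jax.errors.ConcretizationTypeError`, as the traceback in this issue shows.

```python
import jax
import jax.numpy as jnp


def delete_if_concrete(x):
    """Free x's device buffer if x is a concrete array; skip tracers and non-arrays."""
    if not hasattr(x, "delete"):
        return
    try:
        x.delete()
    except jax.errors.ConcretizationTypeError:
        # x is a Tracer (we are inside jit/scan tracing): there is no buffer to free.
        pass


def init(params):
    # Hypothetical traced function using the same pattern as ninjax._prerun.
    discarded = {"w": params["w"] * 2}
    jax.tree_util.tree_map(delete_if_concrete, discarded)
    return params["w"].sum()


params = {"w": jnp.ones((4, 8), jnp.float32)}
total = jax.jit(init)(params)   # traces without raising
print(float(total))             # 32.0

# Outside of jit, the same helper really frees the buffer.
scratch = jnp.ones(3)
delete_if_concrete(scratch)
print(scratch.is_deleted())     # True
```

Catching the exception avoids depending on internal tracer classes, at the cost of relying on the error type. Simply removing the call, as done in the fix, is the simpler choice when the memory saving is not needed.
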

@IcarusWizard

IcarusWizard commented Feb 20, 2023

I tried example.py, and it runs after commenting out these two lines. JAX version: 0.4.4.

@danijar
Copy link
Owner

danijar commented Feb 20, 2023

Great, thanks! That works for me, too.

danijar closed this as completed Feb 20, 2023