
[BUG] some bug for jax gpu version #5

Closed
hdadong opened this issue Feb 19, 2023 · 4 comments

hdadong commented Feb 19, 2023

My environment: Ubuntu 20.04, NVIDIA driver 515.86.01 (nvidia-smi reports CUDA Version 11.7), cuDNN 870, Python 3.8.
pip list:

When I run `python example.py`, I get the following error:

(dreamerv3) weidong@user-NULL:~/dreamerv3$ python example.py
2023-02-19 16:52:05.740320: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 16:52:05.740423: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 16:52:05.740432: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/weidong/dreamerv3/example.py:48 in <module>                                                │
│                                                                                                  │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│ ❱ 48   main()                                                                                    │
│   49                                                                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/example.py:40 in main                                                    │
│                                                                                                  │
│   37   env = dreamerv3.wrap_env(env, config.wrapper)                                             │
│   38   env = embodied.BatchEnv([env], parallel=False)                                            │
│   39                                                                                             │
│ ❱ 40   agent = dreamerv3.Agent(env.obs_space, env.act_space, step, config)                       │
│   41   replay = embodied.replay.Uniform(                                                         │
│   42 │     config.batch_length, config.replay_size, logdir / 'replay')                           │
│   43   args = config.run.update(batch_steps=config.batch_size * config.batch_length)             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:20 in __init__                                     │
│                                                                                                  │
│    17 │   configs = agent_cls.configs                                                            │
│    18 │   inner = agent_cls                                                                      │
│    19 │   def __init__(self, obs_space, act_space, step, config):                                │
│ ❱  20 │     super().__init__(agent_cls, obs_space, act_space, step, config)                      │
│    21   return Agent                                                                             │
│    22                                                                                            │
│    23                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:28 in __init__                                     │
│                                                                                                  │
│    25                                                                                            │
│    26   def __init__(self, agent_cls, obs_space, act_space, step, config):                       │
│    27 │   self.config = config.jax                                                               │
│ ❱  28 │   self.setup()                                                                           │
│    29 │   self.agent = agent_cls(obs_space, act_space, step, config, name='agent')               │
│    30 │   self.rng = jaxutils.RNG(config.seed)                                                   │
│    31 │   self.varibs = {}                                                                       │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:73 in setup                                        │
│                                                                                                  │
│    70 │   if self.config.platform == 'cpu':                                                      │
│    71 │     jax.config.update('jax_disable_most_optimizations', self.config.debug)               │
│    72 │   jaxutils.COMPUTE_DTYPE = getattr(jnp, self.config.precision)                           │
│ ❱  73 │   print(f'JAX DEVICES ({jax.local_device_count()}):', jax.devices())                     │
│    74                                                                                            │
│    75   def train(self, data, state=None):                                                       │
│    76 │   data = self._convert_inps(data)                                                        │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:45 │
│ 7 in local_device_count                                                                          │
│                                                                                                  │
│   454                                                                                            │
│   455 def local_device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:           │
│   456   """Returns the number of devices addressable by this process."""                         │
│ ❱ 457   return int(get_backend(backend).local_device_count())                                    │
│   458                                                                                            │
│   459                                                                                            │
│   460 def devices(backend: Optional[Union[str, XlaBackend]] = None) -> List[xla_client.Device]   │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:42 │
│ 5 in get_backend                                                                                 │
│                                                                                                  │
│   422                                                                                            │
│   423 @lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.     │
│   424 def get_backend(platform=None):                                                            │
│ ❱ 425   return _get_backend_uncached(platform)                                                   │
│   426                                                                                            │
│   427                                                                                            │
│   428 def get_device_backend(device=None):                                                       │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:41 │
│ 1 in _get_backend_uncached                                                                       │
│                                                                                                  │
│   408                                                                                            │
│   409   bs = backends()                                                                          │
│   410   if platform is not None:                                                                 │
│ ❱ 411 │   platform = canonicalize_platform(platform)                                             │
│   412 │   backend = bs.get(platform, None)                                                       │
│   413 │   if backend is None:                                                                    │
│   414 │     if platform in _backends_errors:                                                     │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/lib/xla_bridge.py:29 │
│ 4 in canonicalize_platform                                                                       │
│                                                                                                  │
│   291   for p in platforms:                                                                      │
│   292 │   if p in b.keys():                                                                      │
│   293 │     return p                                                                             │
│ ❱ 294   raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "                   │
│   295 │   │   │   │   │    f"platforms that are instances of {platform} are present. "           │
│   296 │   │   │   │   │    "Platforms present are: " + ",".join(b.keys()))                       │
│   297                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. 
Platforms present are: interpreter,cpu
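This `RuntimeError` means the installed `jaxlib` only provides the `interpreter` and `cpu` platforms, while the agent asked JAX for `gpu` (apparently via `config.jax.platform`, which `jaxagent.setup` reads in the traceback). A minimal sketch of probing which platforms JAX can actually use, and falling back to CPU instead of crashing, is below. `available_jax_platforms` and `choose_platform` are hypothetical helpers, not part of DreamerV3 or JAX, and the sketch assumes that `jax.devices(platform)` raises `RuntimeError` for a missing backend, as it does in the traceback above.

```python
import importlib.util
from typing import List, Sequence


def available_jax_platforms(candidates: Sequence[str] = ("gpu", "cpu")) -> List[str]:
    """Return the candidate platforms for which JAX reports at least one device.

    Returns an empty list if JAX is not installed at all.
    """
    if importlib.util.find_spec("jax") is None:
        return []
    import jax

    found = []
    for platform in candidates:
        try:
            if jax.devices(platform):
                found.append(platform)
        except RuntimeError:
            # A CPU-only jaxlib raises here, e.g.
            # "Unknown backend: 'gpu' requested, but no platforms ..."
            continue
    return found


def choose_platform(requested: str, available: Sequence[str]) -> str:
    """Use the requested platform if present, otherwise fall back to CPU."""
    if requested in available:
        return requested
    if "cpu" in available:
        print(f"Platform {requested!r} unavailable (found {list(available)}); using 'cpu'.")
        return "cpu"
    raise RuntimeError(f"Neither {requested!r} nor 'cpu' is available: {list(available)}")


if __name__ == "__main__":
    platforms = available_jax_platforms()
    print("usable platforms:", platforms)
    if platforms:
        print("selected:", choose_platform("gpu", platforms))
```

Falling back to CPU keeps a debugging run alive, but DreamerV3 training on CPU will be very slow, so the real fix is still installing a CUDA-enabled `jaxlib`.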

After reading the JAX GPU installation guide, I installed the GPU builds of JAX and TensorFlow:

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install tensorflow
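Before rerunning `example.py`, it can help to confirm that JAX actually picked up the CUDA build. A short check is sketched below, assuming a JAX release that provides `jax.default_backend()` and `jax.device_count()`. `jax_backend_summary` is a hypothetical helper name.

```python
import importlib.util
from typing import Optional, Tuple


def jax_backend_summary() -> Optional[Tuple[str, int]]:
    """Return (default backend name, device count), or None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return jax.default_backend(), jax.device_count()


if __name__ == "__main__":
    summary = jax_backend_summary()
    if summary is None:
        print("JAX is not installed in this environment.")
    else:
        backend, count = summary
        print(f"default backend: {backend}, devices: {count}")
        if backend != "gpu":
            # A CPU-only jaxlib reports 'cpu' here even when CUDA itself is installed.
            print("GPU jaxlib is not active; check the jax[cuda] install.")
```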

But then I got another error:

(dreamerv3) weidong@user-NULL:~/dreamerv3$ python example.py
2023-02-19 17:07:51.659700: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 17:07:51.659805: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000:/home/weidong/.mujoco/mujoco210/bin:/usr/lib/nvidia:/home/weidong/.mujoco/mujoco200/bin:/usr/lib/nvidia-000
2023-02-19 17:07:51.659814: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
JAX DEVICES (8): [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=2, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=3, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=4, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=5, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=6, process_index=0, slice_index=0), StreamExecutorGpuDevice(id=7, process_index=0, slice_index=0)]
Encoder CNN shapes: {'image': (64, 64, 3)}
Encoder MLP shapes: {}
Decoder CNN shapes: {'image': (64, 64, 3)}
Decoder MLP shapes: {}
Logdir /home/weidong/logdir/run1
Observation space:
  image            Space(dtype=uint8, shape=(64, 64, 3), low=0, high=255)
  reward           Space(dtype=float32, shape=(), low=-inf, high=inf)
  is_first         Space(dtype=bool, shape=(), low=False, high=True)
  is_last          Space(dtype=bool, shape=(), low=False, high=True)
  is_terminal      Space(dtype=bool, shape=(), low=False, high=True)
Action space:
  action           Space(dtype=float32, shape=(17,), low=0, high=1)
  reset            Space(dtype=bool, shape=(), low=False, high=True)
Fill train dataset (1024 steps).
Episode has 147 steps and return 2.1.
Episode has 305 steps and return 2.1.
Episode has 110 steps and return 1.1.
Episode has 176 steps and return 0.1.
Episode has 140 steps and return 0.1.
───────────────────────────────────────────────── Step 1024 ─────────────────────────────────────────────────
episode/length 140 / episode/score 0.1 / episode/sum_abs_reward 2.1 / episode/reward_rate 0.01

Creating new TensorBoard event file writer.
Saved chunk: 20230219T170756F065708-18SVIAerO9mVKu8c3SI3e2-1Oa8uUlL3lQQnPBO7aaVU7-1024.npz
Tracing train function.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/weidong/dreamerv3/example.py:48 in <module>                                                │
│                                                                                                  │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│ ❱ 48   main()                                                                                    │
│   49                                                                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/example.py:44 in main                                                    │
│                                                                                                  │
│   41   replay = embodied.replay.Uniform(                                                         │
│   42 │     config.batch_length, config.replay_size, logdir / 'replay')                           │
│   43   args = config.run.update(batch_steps=config.batch_size * config.batch_length)             │
│ ❱ 44   embodied.run.train(agent, env, replay, logger, args)                                      │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/embodied/run/train.py:79 in train                              │
│                                                                                                  │
│    76   for _ in range(args.pretrain):                                                           │
│    77 │   with timer.scope('dataset'):                                                           │
│    78 │     batch = next(dataset)                                                                │
│ ❱  79 │   _, state[0], _ = agent.train(batch, state[0])                                          │
│    80                                                                                            │
│    81   batch = [None]                                                                           │
│    82   def train_step(tran, worker):                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:80 in train                                        │
│                                                                                                  │
│    77 │   rng = self._next_rngs(mirror=not self.varibs)                                          │
│    78 │   if state is None:                                                                      │
│    79 │     state, self.varibs = self._init_train(self.varibs, rng, data['is_first'])            │
│ ❱  80 │   (outs, state, mets), self.varibs = self._train(                                        │
│    81 │   │   self.varibs, rng, data, state)                                                     │
│    82 │   outs = self._convert_outs(outs)                                                        │
│    83 │   mets = self._convert_mets(mets)                                                        │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:181 in wrapper                                       │
│                                                                                                  │
│   178 │   statics = tuple(sorted([(k, v) for k, v in kwargs.items() if k in static]))            │
│   179 │   kwargs = {k: v for k, v in kwargs.items() if k not in static}                          │
│   180 │   if not hasattr(wrapper, 'keys'):                                                       │
│ ❱ 181 │     created = init(statics, rng, *args, **kwargs)                                        │
│   182 │     wrapper.keys = set(created.keys())                                                   │
│   183 │     for key, value in created.items():                                                   │
│   184 │   │   if key not in state:                                                               │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/traceback_util.py:16 │
│ 3 in reraise_with_filtered_traceback                                                             │
│                                                                                                  │
│   160   def reraise_with_filtered_traceback(*args, **kwargs):                                    │
│   161 │   __tracebackhide__ = True                                                               │
│   162 │   try:                                                                                   │
│ ❱ 163 │     return fun(*args, **kwargs)                                                          │
│   164 │   except Exception as e:                                                                 │
│   165 │     mode = filtering_mode()                                                              │
│   166 │     if is_under_reraiser(e) or mode == "off":                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:237 in       │
│ cache_miss                                                                                       │
│                                                                                                  │
│    234                                                                                           │
│    235   @api_boundary                                                                           │
│    236   def cache_miss(*args, **kwargs):                                                        │
│ ❱  237 │   outs, out_flat, out_tree, args_flat = _python_pjit_helper(                            │
│    238 │   │   fun, infer_params_fn, *args, **kwargs)                                            │
│    239 │                                                                                         │
│    240 │   executable = _read_most_recent_pjit_call_executable()                                 │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:180 in       │
│ _python_pjit_helper                                                                              │
│                                                                                                  │
│    177                                                                                           │
│    178                                                                                           │
│    179 def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):                           │
│ ❱  180   args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(                           │
│    181 │     *args, **kwargs)                                                                    │
│    182   for arg in args_flat:                                                                   │
│    183 │   dispatch.check_arg(arg)                                                               │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/api.py:443 in        │
│ infer_params                                                                                     │
│                                                                                                  │
│    440 │   │     static_argnames=static_argnames, donate_argnums=donate_argnums,                 │
│    441 │   │     device=device, backend=backend, keep_unused=keep_unused,                        │
│    442 │   │     inline=inline, resource_env=None)                                               │
│ ❱  443 │     return pjit.common_infer_params(pjit_info_args, *args, **kwargs)                    │
│    444 │                                                                                         │
│    445 │   has_explicit_sharding = pjit._pjit_explicit_sharding(                                 │
│    446 │   │   in_shardings, out_shardings, device, backend)                                     │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:520 in       │
│ common_infer_params                                                                              │
│                                                                                                  │
│    517 │     hashable_pytree(in_shardings), local_in_avals, in_tree, in_positional_semantics,    │
│    518 │     tuple(isinstance(a, GDA) for a in args_flat), resource_env)                         │
│    519                                                                                           │
│ ❱  520   jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(                          │
│    521 │     flat_fun, hashable_pytree(out_shardings), global_in_avals,                          │
│    522 │     HashableFunction(out_tree, closure=()),                                             │
│    523 │     ('jit' if resource_env is None else 'pjit'))                                        │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/linear_util.py:301   │
│ in memoized_fun                                                                                  │
│                                                                                                  │
│   298 │     ans, stores = result                                                                 │
│   299 │     fun.populate_stores(stores)                                                          │
│   300 │   else:                                                                                  │
│ ❱ 301 │     ans = call(fun, *args)                                                               │
│   302 │     cache[key] = (ans, fun.stores)                                                       │
│   303 │                                                                                          │
│   304 │   return ans                                                                             │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/pjit.py:932 in       │
│ _pjit_jaxpr                                                                                      │
│                                                                                                  │
│    929 │   with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "     │
│    930 │   │   │   │   │   │   │   │      "for pjit in {elapsed_time} sec",                      │
│    931 │   │   │   │   │   │   │   │   │   event=dispatch.JAXPR_TRACE_EVENT):                    │
│ ❱  932 │     jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(                        │
│    933 │   │     fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))            │
│    934   finally:                                                                                │
│    935 │   pxla.positional_semantics.val = prev_positional_val                                   │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/profiler.py:314 in   │
│ wrapper                                                                                          │
│                                                                                                  │
│   311   @wraps(func)                                                                             │
│   312   def wrapper(*args, **kwargs):                                                            │
│   313 │   with TraceAnnotation(name, **decorator_kwargs):                                        │
│ ❱ 314 │     return func(*args, **kwargs)                                                         │
│   315 │   return wrapper                                                                         │
│   316   return wrapper                                                                           │
│   317                                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:1985 in trace_to_jaxpr_dynamic                                                               │
│                                                                                                  │
│   1982 │   │   │   │   │   │      keep_inputs: Optional[List[bool]] = None):                     │
│   1983   with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore            │
│   1984 │   main.jaxpr_stack = ()  # type: ignore                                                 │
│ ❱ 1985 │   jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(                                 │
│   1986 │     fun, main, in_avals, keep_inputs=keep_inputs, debug_info=debug_info)                │
│   1987 │   del main, fun                                                                         │
│   1988   return jaxpr, out_avals, consts                                                         │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/interpreters/partial_eval │
│ .py:2002 in trace_to_subjaxpr_dynamic                                                            │
│                                                                                                  │
│   1999 │   trace = DynamicJaxprTrace(main, core.cur_sublevel())                                  │
│   2000 │   in_tracers = _input_type_to_tracers(trace.new_arg, in_avals)                          │
│   2001 │   in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]                 │
│ ❱ 2002 │   ans = fun.call_wrapped(*in_tracers_)                                                  │
│   2003 │   out_tracers = map(trace.full_raise, ans)                                              │
│   2004 │   jaxpr, consts = frame.to_jaxpr(out_tracers)                                           │
│   2005 │   del fun, main, trace, frame, in_tracers, out_tracers, ans                             │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/linear_util.py:165   │
│ in call_wrapped                                                                                  │
│                                                                                                  │
│   162 │   gen = gen_static_args = out_store = None                                               │
│   163 │                                                                                          │
│   164 │   try:                                                                                   │
│ ❱ 165 │     ans = self.f(*args, **dict(self.params, **kwargs))                                   │
│   166 │   except:                                                                                │
│   167 │     # Some transformations yield from inside context managers, so we have to             │
│   168 │     # interrupt them before reraising the exception. Otherwise they will only            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:166 in init                                          │
│                                                                                                  │
│   163   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│   164   def init(statics, rng, *args, **kwargs):                                                 │
│   165 │   # Return only state so JIT can remove dead code for fast initialization.               │
│ ❱ 166 │   s = fun({}, rng, *args, ignore=True, **dict(statics), **kwargs)[1]                     │
│   167 │   return s                                                                               │
│   168                                                                                            │
│   169   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified                                       │
│                                                                                                  │
│    74 │   before = CONTEXT                                                                       │
│    75 │   try:                                                                                   │
│    76 │     CONTEXT = Context(state.copy(), rng, create, modify, ignore, [])                     │
│ ❱  77 │     out = fun(*args, **kwargs)                                                           │
│    78 │     state = dict(CONTEXT)                                                                │
│    79 │     return out, state                                                                    │
│    80 │   finally:                                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:80 in train                                           │
│                                                                                                  │
│    77 │   self.config.jax.jit and print('Tracing train function.')                               │
│    78 │   metrics = {}                                                                           │
│    79 │   data = self.preprocess(data)                                                           │
│ ❱  80 │   state, wm_outs, mets = self.wm.train(data, state)                                      │
│    81 │   metrics.update(mets)                                                                   │
│    82 │   context = {**data, **wm_outs['post']}                                                  │
│    83 │   start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context)               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:151 in train                                          │
│                                                                                                  │
│   148                                                                                            │
│   149   def train(self, data, state):                                                            │
│   150 │   modules = [self.encoder, self.rssm, *self.heads.values()]                              │
│ ❱ 151 │   mets, (state, outs, metrics) = self.opt(                                               │
│   152 │   │   modules, self.loss, data, state, has_aux=True)                                     │
│   153 │   metrics.update(mets)                                                                   │
│   154 │   return state, outs, metrics                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:410 in __call__                                    │
│                                                                                                  │
│   407 │   │   loss *= sg(self.grad_scale.read())                                                 │
│   408 │     return loss, aux                                                                     │
│   409 │   metrics = {}                                                                           │
│ ❱ 410 │   loss, params, grads, aux = nj.grad(                                                    │
│   411 │   │   wrapped, modules, has_aux=True)(*args, **kwargs)                                   │
│   412 │   if not self.PARAM_COUNTS[self.path]:                                                   │
│   413 │     count = sum([np.prod(x.shape) for x in params.values()])                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:142 in wrapper                                       │
│                                                                                                  │
│   139   backward = jax.value_and_grad(forward, has_aux=True)                                     │
│   140   @functools.wraps(backward)                                                               │
│   141   def wrapper(*args, **kwargs):                                                            │
│ ❱ 142 │   _prerun(fun, *args, **kwargs)                                                          │
│   143 │   assert all(isinstance(x, (str, Module)) for x in keys)                                 │
│   144 │   strs = [x for x in keys if isinstance(x, str)]                                         │
│   145 │   mods = [x for x in keys if isinstance(x, Module)]                                      │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:271 in _prerun                                       │
│                                                                                                  │
│   268 def _prerun(fun, *args, **kwargs):                                                         │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│ ❱ 271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified                                       │
│                                                                                                  │
│    74 │   before = CONTEXT                                                                       │
│    75 │   try:                                                                                   │
│    76 │     CONTEXT = Context(state.copy(), rng, create, modify, ignore, [])                     │
│ ❱  77 │     out = fun(*args, **kwargs)                                                           │
│    78 │     state = dict(CONTEXT)                                                                │
│    79 │     return out, state                                                                    │
│    80 │   finally:                                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:402 in wrapped                                     │
│                                                                                                  │
│   399                                                                                            │
│   400   def __call__(self, modules, lossfn, *args, has_aux=False, **kwargs):                     │
│   401 │   def wrapped(*args, **kwargs):                                                          │
│ ❱ 402 │     outs = lossfn(*args, **kwargs)                                                       │
│   403 │     loss, aux = outs if has_aux else (outs, None)                                        │
│   404 │     assert loss.dtype == jnp.float32, (self.name, loss.dtype)                            │
│   405 │     assert loss.shape == (), (self.name, loss.shape)                                     │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:161 in loss                                           │
│                                                                                                  │
│   158 │   prev_latent, prev_action = state                                                       │
│   159 │   prev_actions = jnp.concatenate([                                                       │
│   160 │   │   prev_action[:, None], data['action'][:, :-1]], 1)                                  │
│ ❱ 161 │   post, prior = self.rssm.observe(                                                       │
│   162 │   │   embed, prev_actions, data['is_first'], prev_latent)                                │
│   163 │   dists = {}                                                                             │
│   164 │   feats = {**post, 'embed': embed}                                                       │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/nets.py:60 in observe                                          │
│                                                                                                  │
│    57 │   step = lambda prev, inputs: self.obs_step(prev[0], *inputs)                            │
│    58 │   inputs = swap(action), swap(embed), swap(is_first)                                     │
│    59 │   start = state, state                                                                   │
│ ❱  60 │   post, prior = jaxutils.scan(step, inputs, start, self._unroll)                         │
│    61 │   post = {k: swap(v) for k, v in post.items()}                                           │
│    62 │   prior = {k: swap(v) for k, v in prior.items()}                                         │
│    63 │   return post, prior                                                                     │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:73 in scan                                         │
│                                                                                                  │
│    70 def scan(fn, inputs, start, unroll=True, modify=False):                                    │
│    71   fn2 = lambda carry, inp: (fn(carry, inp),) * 2                                           │
│    72   if not unroll:                                                                           │
│ ❱  73 │   return nj.scan(fn2, start, inputs, modify=modify)[1]                                   │
│    74   length = len(jax.tree_util.tree_leaves(inputs)[0])                                       │
│    75   carrydef = jax.tree_util.tree_structure(start)                                           │
│    76   carry = start                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:245 in scan                                          │
│                                                                                                  │
│   242 @jax.named_scope('scan')                                                                   │
│   243 def scan(fun, carry, xs, reverse=False, unroll=1, modify=False):                           │
│   244   fun = pure(fun, nested=True)                                                             │
│ ❱ 245   _prerun(fun, carry, jax.tree_util.tree_map(lambda x: x[0], xs))                          │
│   246   length = len(jax.tree_util.tree_leaves(xs)[0])                                           │
│   247   rngs = rng(length)                                                                       │
│   248   if modify:                                                                               │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:272 in _prerun                                       │
│                                                                                                  │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│ ❱ 272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/tree_util.py:207 in  │
│ tree_map                                                                                         │
│                                                                                                  │
│   204   """                                                                                      │
│   205   leaves, treedef = tree_flatten(tree, is_leaf)                                            │
│   206   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]                         │
│ ❱ 207   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))                              │
│   208                                                                                            │
│   209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:                                        │
│   210   return treedef.from_iterable_tree(xs)                                                    │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/tree_util.py:207 in  │
│ <genexpr>                                                                                        │
│                                                                                                  │
│   204   """                                                                                      │
│   205   leaves, treedef = tree_flatten(tree, is_leaf)                                            │
│   206   all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]                         │
│ ❱ 207   return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))                              │
│   208                                                                                            │
│   209 def build_tree(treedef: PyTreeDef, xs: Any) -> Any:                                        │
│   210   return treedef.from_iterable_tree(xs)                                                    │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:273 in <lambda>                                      │
│                                                                                                  │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│ ❱ 273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│   276                                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/site-packages/jax/_src/core.py:734 in       │
│ delete                                                                                           │
│                                                                                                  │
│    731 │     f"The 'copy_to_host_async' method is not available on the JAX Tracer object {self}  │
│    732                                                                                           │
│    733   def delete(self):                                                                       │
│ ❱  734 │   raise ConcretizationTypeError(self,                                                   │
│    735 │     f"The delete() method was called on the JAX Tracer object {self}")                  │
│    736                                                                                           │
│    737   def device(self):                                                                       │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where 
concrete value is expected: Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The delete() method was called on the JAX Tracer object 
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function init at /home/weidong/dreamerv3/dreamerv3/ninjax.py:163 for 
jit. This concrete value was not available in Python because it depends on the values of the arguments 
'statics', 'rng', and 'args'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/weidong/dreamerv3/example.py:48 in <module>                                                │
│                                                                                                  │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│ ❱ 48   main()                                                                                    │
│   49                                                                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/example.py:44 in main                                                    │
│                                                                                                  │
│   41   replay = embodied.replay.Uniform(                                                         │
│   42 │     config.batch_length, config.replay_size, logdir / 'replay')                           │
│   43   args = config.run.update(batch_steps=config.batch_size * config.batch_length)             │
│ ❱ 44   embodied.run.train(agent, env, replay, logger, args)                                      │
│   45                                                                                             │
│   46                                                                                             │
│   47 if __name__ == '__main__':                                                                  │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/embodied/run/train.py:79 in train                              │
│                                                                                                  │
│    76   for _ in range(args.pretrain):                                                           │
│    77 │   with timer.scope('dataset'):                                                           │
│    78 │     batch = next(dataset)                                                                │
│ ❱  79 │   _, state[0], _ = agent.train(batch, state[0])                                          │
│    80                                                                                            │
│    81   batch = [None]                                                                           │
│    82   def train_step(tran, worker):                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxagent.py:80 in train                                        │
│                                                                                                  │
│    77 │   rng = self._next_rngs(mirror=not self.varibs)                                          │
│    78 │   if state is None:                                                                      │
│    79 │     state, self.varibs = self._init_train(self.varibs, rng, data['is_first'])            │
│ ❱  80 │   (outs, state, mets), self.varibs = self._train(                                        │
│    81 │   │   self.varibs, rng, data, state)                                                     │
│    82 │   outs = self._convert_outs(outs)                                                        │
│    83 │   mets = self._convert_mets(mets)                                                        │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:181 in wrapper                                       │
│                                                                                                  │
│   178 │   statics = tuple(sorted([(k, v) for k, v in kwargs.items() if k in static]))            │
│   179 │   kwargs = {k: v for k, v in kwargs.items() if k not in static}                          │
│   180 │   if not hasattr(wrapper, 'keys'):                                                       │
│ ❱ 181 │     created = init(statics, rng, *args, **kwargs)                                        │
│   182 │     wrapper.keys = set(created.keys())                                                   │
│   183 │     for key, value in created.items():                                                   │
│   184 │   │   if key not in state:                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:166 in init                                          │
│                                                                                                  │
│   163   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│   164   def init(statics, rng, *args, **kwargs):                                                 │
│   165 │   # Return only state so JIT can remove dead code for fast initialization.               │
│ ❱ 166 │   s = fun({}, rng, *args, ignore=True, **dict(statics), **kwargs)[1]                     │
│   167 │   return s                                                                               │
│   168                                                                                            │
│   169   @bind(jax.jit, static_argnums=[0], **kwargs)                                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified                                       │
│                                                                                                  │
│    74 │   before = CONTEXT                                                                       │
│    75 │   try:                                                                                   │
│    76 │     CONTEXT = Context(state.copy(), rng, create, modify, ignore, [])                     │
│ ❱  77 │     out = fun(*args, **kwargs)                                                           │
│    78 │     state = dict(CONTEXT)                                                                │
│    79 │     return out, state                                                                    │
│    80 │   finally:                                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:80 in train                                           │
│                                                                                                  │
│    77 │   self.config.jax.jit and print('Tracing train function.')                               │
│    78 │   metrics = {}                                                                           │
│    79 │   data = self.preprocess(data)                                                           │
│ ❱  80 │   state, wm_outs, mets = self.wm.train(data, state)                                      │
│    81 │   metrics.update(mets)                                                                   │
│    82 │   context = {**data, **wm_outs['post']}                                                  │
│    83 │   start = tree_map(lambda x: x.reshape([-1] + list(x.shape[2:])), context)               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:151 in train                                          │
│                                                                                                  │
│   148                                                                                            │
│   149   def train(self, data, state):                                                            │
│   150 │   modules = [self.encoder, self.rssm, *self.heads.values()]                              │
│ ❱ 151 │   mets, (state, outs, metrics) = self.opt(                                               │
│   152 │   │   modules, self.loss, data, state, has_aux=True)                                     │
│   153 │   metrics.update(mets)                                                                   │
│   154 │   return state, outs, metrics                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:410 in __call__                                    │
│                                                                                                  │
│   407 │   │   loss *= sg(self.grad_scale.read())                                                 │
│   408 │     return loss, aux                                                                     │
│   409 │   metrics = {}                                                                           │
│ ❱ 410 │   loss, params, grads, aux = nj.grad(                                                    │
│   411 │   │   wrapped, modules, has_aux=True)(*args, **kwargs)                                   │
│   412 │   if not self.PARAM_COUNTS[self.path]:                                                   │
│   413 │     count = sum([np.prod(x.shape) for x in params.values()])                             │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:142 in wrapper                                       │
│                                                                                                  │
│   139   backward = jax.value_and_grad(forward, has_aux=True)                                     │
│   140   @functools.wraps(backward)                                                               │
│   141   def wrapper(*args, **kwargs):                                                            │
│ ❱ 142 │   _prerun(fun, *args, **kwargs)                                                          │
│   143 │   assert all(isinstance(x, (str, Module)) for x in keys)                                 │
│   144 │   strs = [x for x in keys if isinstance(x, str)]                                         │
│   145 │   mods = [x for x in keys if isinstance(x, Module)]                                      │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:271 in _prerun                                       │
│                                                                                                  │
│   268 def _prerun(fun, *args, **kwargs):                                                         │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│ ❱ 271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:77 in purified                                       │
│                                                                                                  │
│    74 │   before = CONTEXT                                                                       │
│    75 │   try:                                                                                   │
│    76 │     CONTEXT = Context(state.copy(), rng, create, modify, ignore, [])                     │
│ ❱  77 │     out = fun(*args, **kwargs)                                                           │
│    78 │     state = dict(CONTEXT)                                                                │
│    79 │     return out, state                                                                    │
│    80 │   finally:                                                                               │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:402 in wrapped                                     │
│                                                                                                  │
│   399                                                                                            │
│   400   def __call__(self, modules, lossfn, *args, has_aux=False, **kwargs):                     │
│   401 │   def wrapped(*args, **kwargs):                                                          │
│ ❱ 402 │     outs = lossfn(*args, **kwargs)                                                       │
│   403 │     loss, aux = outs if has_aux else (outs, None)                                        │
│   404 │     assert loss.dtype == jnp.float32, (self.name, loss.dtype)                            │
│   405 │     assert loss.shape == (), (self.name, loss.shape)                                     │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/agent.py:161 in loss                                           │
│                                                                                                  │
│   158 │   prev_latent, prev_action = state                                                       │
│   159 │   prev_actions = jnp.concatenate([                                                       │
│   160 │   │   prev_action[:, None], data['action'][:, :-1]], 1)                                  │
│ ❱ 161 │   post, prior = self.rssm.observe(                                                       │
│   162 │   │   embed, prev_actions, data['is_first'], prev_latent)                                │
│   163 │   dists = {}                                                                             │
│   164 │   feats = {**post, 'embed': embed}                                                       │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:353 in wrapper                                       │
│                                                                                                  │
│   350   def wrapper(self, *args, **kwargs):                                                      │
│   351 │   with scope(self._path, absolute=True):                                                 │
│   352 │     with jax.named_scope(self._path.split('/')[-1]):                                     │
│ ❱ 353 │   │   return method(self, *args, **kwargs)                                               │
│   354   return wrapper                                                                           │
│   355                                                                                            │
│   356                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/nets.py:60 in observe                                          │
│                                                                                                  │
│    57 │   step = lambda prev, inputs: self.obs_step(prev[0], *inputs)                            │
│    58 │   inputs = swap(action), swap(embed), swap(is_first)                                     │
│    59 │   start = state, state                                                                   │
│ ❱  60 │   post, prior = jaxutils.scan(step, inputs, start, self._unroll)                         │
│    61 │   post = {k: swap(v) for k, v in post.items()}                                           │
│    62 │   prior = {k: swap(v) for k, v in prior.items()}                                         │
│    63 │   return post, prior                                                                     │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/jaxutils.py:73 in scan                                         │
│                                                                                                  │
│    70 def scan(fn, inputs, start, unroll=True, modify=False):                                    │
│    71   fn2 = lambda carry, inp: (fn(carry, inp),) * 2                                           │
│    72   if not unroll:                                                                           │
│ ❱  73 │   return nj.scan(fn2, start, inputs, modify=modify)[1]                                   │
│    74   length = len(jax.tree_util.tree_leaves(inputs)[0])                                       │
│    75   carrydef = jax.tree_util.tree_structure(start)                                           │
│    76   carry = start                                                                            │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:245 in scan                                          │
│                                                                                                  │
│   242 @jax.named_scope('scan')                                                                   │
│   243 def scan(fun, carry, xs, reverse=False, unroll=1, modify=False):                           │
│   244   fun = pure(fun, nested=True)                                                             │
│ ❱ 245   _prerun(fun, carry, jax.tree_util.tree_map(lambda x: x[0], xs))                          │
│   246   length = len(jax.tree_util.tree_leaves(xs)[0])                                           │
│   247   rngs = rng(length)                                                                       │
│   248   if modify:                                                                               │
│                                                                                                  │
│ /home/weidong/anaconda3/envs/dreamerv3/lib/python3.8/contextlib.py:75 in inner                   │
│                                                                                                  │
│    72 │   │   @wraps(func)                                                                       │
│    73 │   │   def inner(*args, **kwds):                                                          │
│    74 │   │   │   with self._recreate_cm():                                                      │
│ ❱  75 │   │   │   │   return func(*args, **kwds)                                                 │
│    76 │   │   return inner                                                                       │
│    77                                                                                            │
│    78                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:272 in _prerun                                       │
│                                                                                                  │
│   269   if not context().create:                                                                 │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│ ❱ 272   jax.tree_util.tree_map(                                                                  │
│   273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│                                                                                                  │
│ /home/weidong/dreamerv3/dreamerv3/ninjax.py:273 in <lambda>                                      │
│                                                                                                  │
│   270 │   return                                                                                 │
│   271   discarded, state = fun(dict(context()), rng(), *args, ignore=True, **kwargs)             │
│   272   jax.tree_util.tree_map(                                                                  │
│ ❱ 273 │     lambda x: hasattr(x, 'delete') and x.delete(), discarded)                            │
│   274   context().update(state)                                                                  │
│   275                                                                                            │
│   276                                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: 
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The delete() method was called on the JAX Tracer object 
Traced<ShapedArray(float16[16,1024])>with<DynamicJaxprTrace(level=1/0)>
The error occurred while tracing the function init at /home/weidong/dreamerv3/dreamerv3/ninjax.py:163 for 
jit. This concrete value was not available in Python because it depends on the values of the arguments 
'statics', 'rng', and 'args'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
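The last frames show `_prerun` (ninjax.py:271-273) running the function once and then calling `x.delete()` on every leaf of `discarded` that has a `delete` attribute. This happens while `init` is being traced by `jax.jit`, so those leaves are tracers (here `float16[16,1024]`), and calling `delete()` on a tracer raises `ConcretizationTypeError`. One way to avoid this is to skip tracers. The sketch below assumes `jax.core.Tracer` is importable in your JAX version. The helper name `delete_if_concrete` is hypothetical; it is not part of ninjax and is not necessarily how the upstream fix works.

```python
# Hypothetical helper (not part of ninjax): free device buffers only for
# concrete arrays and skip JAX tracers, whose delete() raises an error.
try:
    from jax.core import Tracer
except ImportError:  # JAX missing (or Tracer moved): treat nothing as a tracer
    Tracer = ()


def delete_if_concrete(x):
    """Call x.delete() unless x is a tracer; return True if something was deleted."""
    if isinstance(x, Tracer):
        return False  # abstract value inside jit: there is no buffer to free
    if hasattr(x, 'delete'):
        x.delete()
        return True
    return False
```

In `_prerun`, it would replace the lambda, for example `jax.tree_util.tree_map(delete_if_concrete, discarded)`.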

I think a guide for installing JAX with GPU support should be added to the README.

My current pip list is as follows:

(dreamerv3) weidong@user-NULL:~/dreamerv3$ pip list
Package                      Version
---------------------------- --------------------
absl-py                      1.4.0
astunparse                   1.6.3
cachetools                   5.3.0
certifi                      2022.12.7
charset-normalizer           3.0.1
chex                         0.1.6
cloudpickle                  1.6.0
crafter                      1.8.0
decorator                    5.1.1
dm-tree                      0.1.8
flatbuffers                  23.1.21
gast                         0.4.0
google-auth                  2.16.1
google-auth-oauthlib         0.4.6
google-pasta                 0.2.0
grpcio                       1.51.1
gym                          0.19.0
h5py                         3.8.0
idna                         3.4
imageio                      2.25.1
importlib-metadata           6.0.0
jax                          0.4.4
jaxlib                       0.4.4+cuda11.cudnn86
keras                        2.11.0
libclang                     15.0.6.1
llvmlite                     0.39.1
Markdown                     3.4.1
markdown-it-py               2.1.0
MarkupSafe                   2.1.2
mdurl                        0.1.2
numba                        0.56.4
numpy                        1.23.5
oauthlib                     3.2.2
opensimplex                  0.4.4
opt-einsum                   3.3.0
optax                        0.1.4
packaging                    23.0
Pillow                       9.4.0
pip                          23.0.1
protobuf                     3.19.6
pyasn1                       0.4.8
pyasn1-modules               0.2.8
Pygments                     2.14.0
python-version               0.0.2
requests                     2.28.2
requests-oauthlib            1.3.1
rich                         13.3.1
rsa                          4.9
ruamel.yaml                  0.17.21
ruamel.yaml.clib             0.2.7
scipy                        1.10.0
setuptools                   65.6.3
six                          1.16.0
tensorboard                  2.11.2
tensorboard-data-server      0.6.1
tensorboard-plugin-wit       1.8.1
tensorflow                   2.11.0
tensorflow-cpu               2.11.0
tensorflow-estimator         2.11.0
tensorflow-io-gcs-filesystem 0.30.0
tensorflow-probability       0.19.0
termcolor                    2.2.0
toolz                        0.12.0
typing_extensions            4.5.0
urllib3                      1.26.14
Werkzeug                     2.2.3
wheel                        0.38.4
wrapt                        1.14.1
zipp                         3.14.0
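When you compare environments like this one, the relevant fields are the jax and jaxlib versions and the CUDA/cuDNN tags in the jaxlib local version (`+cuda11.cudnn86` above). The following standard-library sketch reads them. The function names are hypothetical. Newer JAX releases installed through the `jax[cuda12]` extras may ship CUDA support as separate plugin packages, so `jaxlib` may have no `+cuda` suffix. In that case the parser returns `None` for those fields.

```python
import re
from importlib import metadata


def parse_jaxlib_version(version):
    """Split a jaxlib version such as '0.4.4+cuda11.cudnn86' into its parts."""
    base, _, local = version.partition('+')
    cuda = re.search(r'cuda(\d+)', local)
    cudnn = re.search(r'cudnn(\d+)', local)
    return {
        'base': base,
        'cuda': cuda.group(1) if cuda else None,    # e.g. '11'
        'cudnn': cudnn.group(1) if cudnn else None,  # e.g. '86' for cuDNN 8.6
    }


def installed_jax_versions():
    """Return the installed jax/jaxlib versions, with None for a missing package."""
    out = {}
    for name in ('jax', 'jaxlib'):
        try:
            out[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            out[name] = None
    return out
```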
@hdadong changed the title from "[BUG] RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: interpreter,cpu" to "[BUG] some bug for jax gpu version" on Feb 19, 2023
@hdadong
Author

hdadong commented Feb 19, 2023

I fixed this bug by installing older, matching versions of jax and jaxlib:

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install jax==0.3.25

Thanks!
Reference: #1
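Before rerunning `example.py`, it can help to confirm that JAX sees the GPU at all. The original title of this issue was a "no platforms that are instances of gpu" error. The sketch below uses `jax.default_backend()` and `jax.devices()`. The helper name `jax_backend_report` is hypothetical.

```python
def jax_backend_report():
    """Return the JAX version, default backend and devices, or None if JAX is absent."""
    try:
        import jax
    except ImportError:
        return None
    return {
        'jax': jax.__version__,
        'backend': jax.default_backend(),  # 'gpu' when a CUDA build is working
        'devices': [str(d) for d in jax.devices()],
    }


if __name__ == '__main__':
    print(jax_backend_report())
```

If `backend` comes back as `'cpu'`, the installed jaxlib is most likely a CPU-only build, or it cannot find the CUDA libraries.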

@danijar
Owner

danijar commented Feb 19, 2023

Thanks for reporting and great that you figured it out!

@danijar
Owner

danijar commented Feb 20, 2023

@hdadong I think I've fixed the issue with the newest JAX version. Could you try again, please?

@swsychen

swsychen commented Apr 4, 2024

This bug disappears with the following commands in a Python 3.9 environment:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt

The trick is to comment out the two lines for jax and jaxlib in requirements.txt before running the second command.
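You can also script the requirements.txt edit. The sketch below assumes the file has one requirement per line. The helper name `comment_out` is hypothetical. Presumably the point of the trick is that the pinned jax/jaxlib entries would otherwise replace the CUDA 12 build installed by the first command.

```python
import re

# Leading distribution name of a requirement line, e.g. 'jaxlib' in 'jaxlib==0.4.4'.
_NAME = re.compile(r'^\s*([A-Za-z0-9][A-Za-z0-9._-]*)')


def _normalize(name):
    # PEP 503-style normalization: case-insensitive, and '-', '_', '.' are equivalent.
    return re.sub(r'[-_.]+', '-', name).lower()


def comment_out(lines, packages=('jax', 'jaxlib')):
    """Prefix requirement lines for the given packages with '# '."""
    targets = {_normalize(p) for p in packages}
    out = []
    for line in lines:
        m = _NAME.match(line)  # comments, blank lines and '-f ...' options do not match
        if m and _normalize(m.group(1)) in targets:
            out.append('# ' + line)
        else:
            out.append(line)
    return out
```

To apply it in place: `p = pathlib.Path('requirements.txt'); p.write_text('\n'.join(comment_out(p.read_text().splitlines())) + '\n')`.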
