ValueError: Non-hashable static arguments are not supported #33

Closed
Spark001 opened this issue Aug 1, 2022 · 1 comment

Spark001 commented Aug 1, 2022

When I run the training script, I get the error below.
By the way, the training script from nerfies runs fine with a similar config.

train.py:302] Starting training
Traceback (most recent call last):
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 146, in argnums_partial_except
    hash(static_arg)
TypeError: unhashable type: 'NerfModel'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
    app.run(main)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
    state, stats, keys, model_out = ptrain_step(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 1669, in f_pmapped
    out = pxla.xla_pmap(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
  File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 712, in parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 413, in cache_miss
    f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
    app.run(main)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
    state, stats, keys, model_out = ptrain_step(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
    raise ValueError(
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.

The installed JAX versions are:

jax                                0.2.17
jaxlib                             0.1.66+cuda110
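
For readers hitting the same message: the first frame of the traceback shows JAX calling `hash(static_arg)` on the static argument and failing with `TypeError: unhashable type`. The sketch below uses plain Python only (no JAX) to show one common way an object ends up unhashable. The class names and the `is_hashable` helper are hypothetical illustrations, and this is not necessarily the reason `NerfModel` was unhashable in this environment.

```python
from dataclasses import dataclass


@dataclass
class UnhashableConfig:
    # A plain dataclass defines __eq__, and Python then sets __hash__ to None,
    # so instances cannot be hashed.
    width: int


@dataclass(frozen=True)
class HashableConfig:
    # frozen=True (with the default eq=True) generates a field-based __hash__.
    width: int


def is_hashable(obj) -> bool:
    """Mimic the check seen in the traceback: try hash() and catch TypeError."""
    try:
        hash(obj)
    except TypeError:
        return False
    return True


print(is_hashable(UnhashableConfig(8)))
print(is_hashable(HashableConfig(8)))
```

An object that passes this check can at least be hashed; whether it also works as a static argument depends on the JAX and model library versions in use.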

Spark001 commented Aug 1, 2022

Solved by updating the package versions to match requirements.txt.
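
A minimal sketch of how such a mismatch could be spotted before training, assuming the requirements file uses simple `name==version` pins. The helper names `parse_pins` and `find_mismatches` are hypothetical, and the pin in the example is a placeholder, not the version this repository requires.

```python
from importlib import metadata


def parse_pins(lines):
    """Collect exact pins (name==version) from requirements-style lines."""
    pins = {}
    for raw in lines:
        line = raw.split("#", 1)[0].strip()  # drop comments and whitespace
        if "==" not in line:
            continue  # skip blank lines and unpinned entries
        name, version = line.split("==", 1)
        pins[name.strip().lower()] = version.strip()
    return pins


def find_mismatches(pins, installed=None):
    """Return {name: (found, wanted)} for every pin that is not satisfied.

    `installed` may be a dict of versions (handy for testing); when it is
    None, the versions are looked up in the current environment.
    """
    mismatches = {}
    for name, wanted in pins.items():
        if installed is not None:
            found = installed.get(name)
        else:
            try:
                found = metadata.version(name)
            except metadata.PackageNotFoundError:
                found = None
        if found != wanted:
            mismatches[name] = (found, wanted)
    return mismatches


# Placeholder pin for illustration only; read the real pins from requirements.txt.
example_pins = parse_pins(["jax==9.9.9  # hypothetical pin"])
print(find_mismatches(example_pins, installed={"jax": "0.2.17"}))
```

Against a real checkout, something like `find_mismatches(parse_pins(open("requirements.txt")))` would list every pinned package whose installed version differs. Lines with extras or environment markers are not handled by this sketch.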

Spark001 closed this as completed Aug 1, 2022