When I run the training script, I encounter the following error.
By the way, training works fine when I run the training script from nerfies with a similar config.
train.py:302] Starting training
Traceback (most recent call last):
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 146, in argnums_partial_except
    hash(static_arg)
TypeError: unhashable type: 'NerfModel'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
    app.run(main)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
    state, stats, keys, model_out = ptrain_step(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 1669, in f_pmapped
    out = pxla.xla_pmap(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1620, in bind
    return call_bind(self, fun, *args, **params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1551, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 1623, in process
    return trace.process_map(self, fun, tracers, params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/core.py", line 606, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 624, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(fun, backend, axis_name, axis_size,
  File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 262, in memoized_fun
    ans = call(fun, *args)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 712, in parallel_callable
    jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/_src/api.py", line 413, in cache_miss
    f, args = argnums_partial_except(f, static_argnums, args, allow_invalid=True)
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/bfs/sz/Research/hypernerf/train.py", line 370, in <module>
    app.run(main)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/root/anaconda3/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/bfs/sz/Research/hypernerf/train.py", line 330, in main
    state, stats, keys, model_out = ptrain_step(
  File "/root/anaconda3/lib/python3.9/site-packages/jax/api_util.py", line 148, in argnums_partial_except
    raise ValueError(
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'hypernerf.models.NerfModel'> for function train_step is non-hashable.
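For context: JAX uses arguments marked static (`static_argnums` in `jax.jit`, `static_broadcasted_argnums` in `jax.pmap`) as part of its compilation cache key, so they must be hashable. The `cache_miss` frame in `jax/_src/api.py` suggests that `train_step` is wrapped with `jax.jit` and receives the model as static argument 0. One plausible cause, stated here as an unverified assumption: Flax linen modules are turned into dataclasses, and depending on the Flax version, a dataclass built with the default `eq=True` and `frozen=False` gets `__hash__ = None`, which would produce the `TypeError: unhashable type: 'NerfModel'` above. The pure-Python sketch below reproduces only that mechanism; `ToyModel`, `FrozenToyModel`, and `check_static_arg` are hypothetical names, and the helper merely mimics JAX's hash check without calling JAX.

```python
import dataclasses


@dataclasses.dataclass
class ToyModel:
    """Hypothetical stand-in for a module dataclass with the default eq=True."""
    num_layers: int = 8


@dataclasses.dataclass(frozen=True)
class FrozenToyModel:
    """Same field, but frozen=True (with eq=True) makes dataclasses generate __hash__."""
    num_layers: int = 8


def check_static_arg(arg, index=0, fn_name="train_step"):
    """Hypothetical helper mimicking JAX's static-argument check: hash, or raise ValueError."""
    try:
        hash(arg)
    except TypeError:
        raise ValueError(
            f"Static argument (index {index}) of type {type(arg)} "
            f"for function {fn_name} is non-hashable.")
    return arg


# eq=True without frozen=True sets __hash__ to None, so hashing raises TypeError.
print(ToyModel.__hash__)
try:
    check_static_arg(ToyModel())
except ValueError as err:
    print("rejected:", err)

# The frozen variant is hashable and would pass the same check.
check_static_arg(FrozenToyModel())
print("accepted frozen model")
```

The sketch only shows why an unhashable model object is rejected as a static argument. Whether `NerfModel` is hashable in a given environment depends on the installed Flax and hypernerf code, so the frozen dataclass is an illustration, not a proposed patch.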
The JAX versions are:
jax 0.2.17
jaxlib 0.1.66+cuda110
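Because `NerfModel` is a Flax module (hypernerf builds on Flax, which is assumed here rather than stated in this report), the installed Flax version is probably relevant alongside jax and jaxlib. A minimal sketch for collecting all three with the standard-library `importlib.metadata` (Python 3.8+); `package_versions` is a hypothetical helper name:

```python
from importlib import metadata


def package_versions(names=("jax", "jaxlib", "flax")):
    """Return installed distribution versions, or 'not installed' for missing ones."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


for name, version in package_versions().items():
    print(name, version)
```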