contraction error when running mps-rnn repo #1595
Replies: 4 comments
-
NetKet has changed in a few ways since then, notably in the ordering of the sample dimensions. You should probably install the old version of NetKet from Dian's branch. (The plum warning is not important; the other warning appears because you did not load NetKet first, so you'll be running simulations in single precision rather than double.)
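To illustrate the precision point, here is a minimal sketch of turning on 64-bit mode in plain JAX before any arrays are created. It uses the `jax_enable_x64` option named in the warning text; that importing NetKet first has the same effect is an assumption based on the comment above, not something this snippet checks.

```python
import jax

# Enable 64-bit types before creating any arrays.
# Assumption: importing NetKet first would have the same effect,
# as the comment above implies.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With x64 enabled, the requested int64 dtype is kept instead of being
# truncated to int32 (which is what the UserWarning reports).
x = jnp.arange(3, dtype=jnp.int64)
print(x.dtype)
```

The same switch can be set from the shell through the `JAX_ENABLE_X64` environment variable mentioned in the warning.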
-
Hi @mschuylermoss, there have indeed been some changes in package versions over the past few months. I've updated the pinned versions in that repo, so you can now pull the repo and install the dependencies again. If there are still problems, you can open an issue in that repo. @PhilipVinc The problem is an unexpected dimension in a very deep vmap + scan, and I guess it's related to the transpose of samples. I think it's time to finish my RNN PR and do some thorough testing. Last year we were waiting for Flax's RNN API to stabilize; now that it is stable, I can refactor my code on top of it. I'll spend some time on this next week.
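As a rough illustration of how one unexpected axis produces this kind of contraction error, here is a small standalone sketch. It uses NumPy rather than the repo's JAX code, so the error message is worded differently from the one in the original post, and the shapes and the `contract` helper are hypothetical.

```python
import numpy as np


def contract(vec, mat):
    # The subscript 'a' declares exactly one axis for `vec`.
    return np.einsum("a,ab->b", vec, mat)


mat = np.ones((3, 4))
good = np.ones(3)      # 1-D, matches subscript 'a'
bad = np.ones((2, 3))  # hypothetical extra leading axis (e.g. a batch axis)

ok_shape = contract(good, mat).shape

try:
    contract(bad, mat)
    raised = False
except ValueError as err:
    # The operand has more axes than its subscript declares.
    raised = True
    print("ValueError:", err)

# One possible fix: let an ellipsis absorb any leading axes.
fixed = np.einsum("...a,ab->...b", bad, mat)
print(ok_shape, fixed.shape)
```

Patching each contraction this way only hides the symptom if the extra axis comes from a changed sample layout upstream, which matches the chain of follow-on errors described in the original post.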
-
Thank you both for the quick responses. Everything seems to be working for me now! If anything else comes up, I will address it in the original repo.
-
@wdphy16 yes, I had thought about that a while back but forgot to ping you about it :)
-
I downloaded the mps-rnn repo and am trying to run it using the default settings. The only things I specify are:
python vmc.py --net_dim 1 --dtype "float32" --show_progress --cuda "" --run_name "test" --out_dir "./out"
I am using a conda environment with python=3.9.17, but I pip installed everything instead of conda install (I have read that Conda can raise weird errors with NetKet). Other relevant dependencies are: jax==0.4.14, jaxlib=0.4.14, numpy=1.24.4, netket (custom branch), plum-dispatch=2.2.0
When I run the code, everything gets traced, the VMC begins to run, and then it immediately breaks (i.e. the progress bar appears and shows 0%, so the first step of the VMC breaks). The final error reads:
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Einstein sum subscript 'a' does not contain the correct number of indices for operand 0.
The full stack trace is here:
I looked at the dimensions of the objects being contracted and made some very naive adjustments to see if it was actually a bug in the contractions, but that created a breadcrumb trail where one fix led to another contraction error somewhere else, and so on. This is the very first step of the whole repo, so I suspect there is something deeper going on here (and that the issue is not actually with these contractions).
Something else that might be noteworthy is that before the VMC begins running, I get two user warnings that read:
UserWarning: `plum.Val` is deprecated and will be removed in a future version. Please use `typing.Literal` instead.
and
UserWarning: Explicitly requested dtype <class 'numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable.