Formalize operator dtypes #1697
Codecov Report

```diff
@@            Coverage Diff             @@
##           master    #1697      +/-   ##
==========================================
+ Coverage   82.53%   82.75%   +0.21%
==========================================
  Files         298      298
  Lines       18304    18224      -80
  Branches     2763     3504     +741
==========================================
- Hits        15107    15081      -26
+ Misses       2512     2468      -44
+ Partials      685      675      -10
```
Thank you @wdphy16, I think this is very good, and it was long needed.
My main thoughts are the following:
- I don't want NetKet to throw tons of warnings when using Numba operators with x64 disabled.
- I'm not sure it makes sense to talk about weak types for the dtype of Numba operators, as those will be arrays and JAX will not tag them as weak (I think).
Therefore, I think that we should somehow default to having single-precision operators (where relevant) if x64 is disabled, and double-precision operators if x64 is enabled.
What do you think?
…ays promote from float when dtype is None
Thank you @wdphy16, I think this is good to be merged.
As tests pass, please feel free to add one line to the CHANGELOG and then merge.
@wdphy16 Ah sorry, before merging: I forgot, can you double-check that the operators in experimental (fermions) are addressed as well?
That's good for me. Can I merge?
The discussion started in #1543, and now I finally have some time to check it thoroughly.

I think the most important intuition is that, if `dtype` is specified in `__init__`, the property `operator.dtype` should return the same value. (The exception is when the specified dtype is 64-bit and `NETKET_ENABLE_X64` is disabled; then it becomes 32-bit.) This behavior is consistent with `jnp.array`.

If `dtype` is not specified in `__init__`, it is inferred from all the other arguments, and from whether the operator is required to be complex, using JAX's promotion rules. If no other argument can be used to infer it, it defaults to float64 if `NETKET_ENABLE_X64` is enabled, and to float32 otherwise.

If the inferred dtype differs from the specified one, `__init__` may cast to the specified dtype, or raise a `TypeError` if the cast is unphysical or hard to implement.

If the dtype is not specified and it is inferred to be int, we promote it to float, as previously suggested by this test. (I think this is intuitive for physicists who are too lazy to type `.0`, rather than for dtype lawyers.)
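The rules above can be condensed into a small decision procedure. The sketch below is a toy model, not NetKet's implementation: `resolve_dtype` and its helpers are hypothetical names, dtypes are plain strings, and the promotion is a simplified stand-in for JAX's rules that only covers int, float and complex. It assumes inference works like promotion starting from a weakly typed float, which is consistent with the behavior described above: ints become floats, an explicit float32 argument is not upcast, and with no arguments the default precision is used.

```python
# Hypothetical sketch: a toy model of the operator dtype rules, not NetKet's code.
# Dtypes are plain strings such as "int64", "float32" or "complex128", and the
# promotion is a simplified stand-in for JAX's rules (int/float/complex only).

_KINDS = ("int", "float", "complex")  # ordered from narrowest to widest kind


def _split(dtype):
    """Split a dtype name into (kind, bits per real component)."""
    for kind in _KINDS:
        if dtype.startswith(kind):
            bits = int(dtype[len(kind):])
            return kind, bits // 2 if kind == "complex" else bits
    raise ValueError(f"unsupported dtype {dtype!r}")


def _join(kind, bits):
    """Inverse of _split."""
    return f"{kind}{2 * bits if kind == 'complex' else bits}"


def _canonicalize(dtype, x64_enabled):
    """Downcast 64-bit dtypes to 32-bit when x64 is disabled, like jnp.array."""
    kind, bits = _split(dtype)
    return _join(kind, 32) if bits == 64 and not x64_enabled else dtype


def _infer(arg_dtypes, is_complex, x64_enabled):
    """Promote the argument dtypes, starting from a weakly typed float.

    The weak float turns ints into floats but never raises the precision of an
    existing float/complex argument; if no argument carries a precision, the
    default one is used (64-bit with x64 enabled, else 32-bit).
    """
    parts = [_split(d) for d in arg_dtypes]
    kinds = [k for k, _ in parts] + ["float"] + (["complex"] if is_complex else [])
    kind = max(kinds, key=_KINDS.index)
    precisions = [b for k, b in parts if k != "int"]
    bits = max(precisions) if precisions else (64 if x64_enabled else 32)
    return _canonicalize(_join(kind, bits), x64_enabled)


def resolve_dtype(dtype, arg_dtypes, is_complex, x64_enabled):
    """Return the `dtype` an operator would report under the rules above."""
    if dtype is None:
        return _infer(arg_dtypes, is_complex, x64_enabled)
    if is_complex and _split(dtype)[0] != "complex":
        # Casting an operator that must be complex to a real dtype is unphysical.
        raise TypeError(f"real dtype {dtype!r} requested for a complex operator")
    return _canonicalize(dtype, x64_enabled)
```

Only one unphysical case is modeled here, requesting a real dtype for an operator that must be complex; a specified int dtype is passed through unchanged.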
It's also possible to specify an int dtype, and it just works in many cases. In the future, if someone really needs it, we can make it work with `PauliStrings` and raise an error for uint and other dtypes.

Note that when doing in-place arithmetic, some operators actually modify their underlying arrays, so their dtypes never change, and they raise a `TypeError` when casting complex to real. Other operators just call the out-of-place methods, so their dtypes may change. This PR only cleans up `__init__` and does not touch those methods.

Discrete operators
Their `dtype` is the dtype of the matrix elements, and we already ensured that explicitly in most cases. The dtype of expectation values is inferred from both the matrix elements and the wave function.

For `Ising` and `IsingJax`, the dtype was previously inferred to be int if `J` and `h` were ints; now it is promoted to float.

For `Heisenberg`, there was previously no `dtype` argument in `__init__`; now we have added it, and it is handled by `LocalOperator`.

For `BoseHubbard`, we previously did not cast the dtype to 32-bit when x64 is disabled; now we do. Although this operator does not have a JAX version yet, we still do the cast for consistency.

For `LocalLiouvillian`, we also cast the dtype to 32-bit when x64 is disabled. Previously we cast a specified real dtype to complex; now we raise a `TypeError` because it is unphysical.

Pauli strings
Now we infer that the dtype is complex if any string has an odd number of Y (previously, if it had any Y), or if any weight is complex. Previously we cast a specified real dtype to complex when needed; now we raise a `TypeError` because it is unphysical.

I've added a `_reduce_pauli_string` step in `__init__`. After that, strings containing Y can no longer cancel out.

It's still possible that a term has an odd number of Y and a purely imaginary weight, so that the whole Hamiltonian is real and non-Hermitian, and we have to work with a complex array of `weights`. We don't handle that case specially for now.

Also, there is a (maybe subtle) change: when `dtype` is not specified and cannot be inferred from `weights`, it previously defaulted to 32-bit because of this line; now it defaults to 64-bit when x64 is enabled.

Continuous operators
They don't compute matrix elements, and we never explicitly cast the output dtype, so I think their `dtype` should behave like `param_dtype` in Flax.

For `PotentialEnergy` and `KineticEnergy`, the dtype is only used to cast `coefficient` and `mass`, so nothing has changed.

For `SumOperator`, there is a subtle change: previously we inferred the dtype only from the operators; now we also infer it from the coefficients. Note that `dtype` is only the dtype of `coefficients`, and the dtypes of `operators` are unaffected. The reasoning is similar to how we define Flax modules. If we write a `BigModule` that receives the already constructed `small_module_1` and `small_module_2` together with `dtype3`, then we usually expect that `BigModule` will not change the dtypes of the parameters in `small_module_1` and `small_module_2`, and that only the parameters newly defined in `BigModule` have `dtype3`. Conversely, if we write a `BigModule` that only receives `dtype3`, then we expect it to construct its `SmallModule`s using `dtype3`.
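To illustrate the `SumOperator` rule, here is a minimal sketch with hypothetical `ToyOperator` and `ToySumOperator` classes (not NetKet's). The inferred dtype takes both the operators and the coefficients into account, but it is applied only to the coefficients; the operators are stored unchanged. Dtypes are plain strings, the toy promotion only handles float32/float64/complex64/complex128, and Python scalar coefficients are assumed to be weakly typed, so a complex coefficient makes the result complex without raising its precision.

```python
from dataclasses import dataclass

# Hypothetical sketch of the SumOperator dtype rule, not NetKet's code.


def _promote(*dtypes):
    """Toy promotion for 'float32', 'float64', 'complex64' and 'complex128'."""
    is_complex = any(d.startswith("complex") for d in dtypes)
    # Precision of one real component, e.g. 64 for both float64 and complex128.
    bits = max(int(d[7:]) // 2 if d.startswith("complex") else int(d[5:]) for d in dtypes)
    return f"complex{2 * bits}" if is_complex else f"float{bits}"


@dataclass(frozen=True)
class ToyOperator:
    dtype: str


class ToySumOperator:
    def __init__(self, operators, coefficients, dtype=None):
        # Sub-operators are stored as-is: their dtypes never change.
        self.operators = tuple(operators)
        if dtype is None:
            # Infer from the operators *and* the coefficients. A complex Python
            # scalar only contributes "complex", not a precision (weak typing).
            has_complex = any(isinstance(c, complex) for c in coefficients)
            weak = ["complex64"] if has_complex else []
            dtype = _promote(*(op.dtype for op in self.operators), *weak)
        # `dtype` is the dtype of the coefficients only.
        self.dtype = dtype
        cast = complex if dtype.startswith("complex") else float
        # float(1j) raises TypeError, so complex coefficients with a real dtype fail.
        self.coefficients = tuple(cast(c) for c in coefficients)
```

As in the Flax analogy, only what the sum itself owns (its coefficients) follows `dtype`, while the operators passed in keep whatever dtypes they were constructed with.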