Formalize operator dtypes #1697

Merged
merged 25 commits into from Jan 24, 2024

@wdphy16 (Collaborator) commented Jan 19, 2024

The discussion started in #1543, and now I finally have some time to check it thoroughly.

I think the most important intuition is that, if dtype is specified in __init__, the property operator.dtype should return the same value. (Except when the specified dtype is 64-bit, and NETKET_ENABLE_X64 is disabled, it will become 32-bit.) This behavior is consistent with jnp.array.

If not specified in __init__, it will be inferred from all other arguments, and whether the operator is required to be complex, using JAX's promotion rules. If no other argument can be used to infer it, it defaults to float64 if NETKET_ENABLE_X64 is enabled, and float32 otherwise.
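A minimal sketch of the resolution order described above, not NetKet's actual implementation. The helper `resolve_dtype` and its `needs_complex` and `x64_enabled` parameters are hypothetical names, and NumPy's `np.result_type` stands in for JAX's promotion rules, which differ in some corner cases.

```python
import numpy as np

# Map each 64-bit dtype to its 32-bit counterpart (used when x64 is disabled).
_TO_32BIT = {
    np.dtype("float64"): np.dtype("float32"),
    np.dtype("complex128"): np.dtype("complex64"),
    np.dtype("int64"): np.dtype("int32"),
}


def resolve_dtype(specified=None, *args, needs_complex=False, x64_enabled=True):
    """Hypothetical helper: pick the dtype an operator would report."""
    if specified is not None:
        # An explicitly specified dtype always wins.
        dtype = np.dtype(specified)
    else:
        # Infer from every argument that carries numerical data.
        dtypes = [np.asarray(a).dtype for a in args]
        if needs_complex:
            dtypes.append(np.dtype("complex64"))
        # With nothing to infer from, fall back to the default float.
        dtype = np.result_type(*dtypes) if dtypes else np.dtype("float64")
    if not x64_enabled:
        # Mimic the 64-bit to 32-bit truncation applied when x64 is off.
        dtype = _TO_32BIT.get(dtype, dtype)
    return dtype
```

For example, `resolve_dtype(np.float64, x64_enabled=False)` gives `float32`, while `resolve_dtype(None, 1.0, 2j)` infers `complex128`.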

If the inferred dtype is different from the specified one, __init__ may cast it to the specified dtype, or raise a TypeError if it's unphysical or hard to implement.

  • If the inferred dtype is lower than the specified one, we silently upcast it as in numpy.
  • If the inferred dtype is higher than the specified one, and both are real or both are complex, we silently truncate it as in numpy. (I guess people don't like to see a lot of warnings when they have already decided to specify 32-bit...)
  • If the inferred dtype is float64 but the specified one is complex64, we also silently truncate it as in numpy.
  • If the inferred dtype is complex but the specified one is real, we may give a warning and discard the imaginary part, which numpy may already do. Or we may raise an error.
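The four cases above can be sketched as a single decision function. This is only an illustration: `reconcile` and `allow_discard_imag` are hypothetical names, and the sketch picks "raise" as the default for the complex-to-real case, which the list leaves open.

```python
import numpy as np


def reconcile(inferred, specified, allow_discard_imag=False):
    """Hypothetical helper: decide the final dtype when both are known."""
    inferred, specified = np.dtype(inferred), np.dtype(specified)
    complex_to_real = (
        np.issubdtype(inferred, np.complexfloating)
        and not np.issubdtype(specified, np.complexfloating)
    )
    if complex_to_real and not allow_discard_imag:
        raise TypeError(
            f"cannot cast inferred dtype {inferred} to real dtype {specified}"
        )
    # Every other case (upcast, same-kind truncation, float64 -> complex64)
    # silently follows the specified dtype.
    return specified
```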

If the dtype is not specified and it's inferred to be int, we promote it to float, as suggested previously by this test. (I think it's intuitive for physicists who are too lazy to type .0, rather than dtype lawyers.)
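As a rough illustration of the int-to-float rule, assuming a hypothetical helper `promote_inferred_int` (NetKet's real code path may look different):

```python
import numpy as np


def promote_inferred_int(dtype, x64_enabled=True):
    """Hypothetical helper: inferred integer dtypes become the default float."""
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer):
        return np.dtype("float64" if x64_enabled else "float32")
    return dtype


# Couplings written as J=1, h=2 (without ".0") would infer an integer dtype.
inferred = np.result_type(np.asarray(1).dtype, np.asarray(2).dtype)
promoted = promote_inferred_int(inferred)
```

The promotion only applies to an inferred dtype. An int dtype passed explicitly would bypass this helper and be kept.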

It's also possible to specify an int dtype, and it just works in many cases. In the future we can make it work with PauliStrings, and raise an error for uint and other dtypes, if someone really needs that.

Note that when doing in-place arithmetic, some operators actually modify the underlying arrays, so their dtypes never change, and they raise a TypeError if casting complex to real. Other operators just call the out-of-place methods, so their dtypes may change. This PR only cleans up __init__ and does not touch those methods.
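The difference between the two kinds of in-place behavior mirrors what plain NumPy arrays do. The snippet below uses bare arrays rather than NetKet operators, so it only illustrates the underlying mechanism:

```python
import numpy as np

weights = np.array([1.0, 2.0])  # real dtype
phase = np.array([1j, 1j])      # complex dtype

# Out-of-place: a new array is created, so the dtype is promoted to complex.
promoted = weights + phase

# In-place: the existing float64 buffer is reused, so NumPy refuses the cast.
try:
    weights += phase
    in_place_error = None
except TypeError as err:
    in_place_error = err
```

After running this, `promoted` is complex while `weights` keeps its original real dtype and `in_place_error` holds the casting error.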

Discrete operators

Their dtype is the dtype of matrix elements, and we already explicitly ensured that in most cases. The dtype of expectation values will be inferred from both the matrix elements and the wave function.

For Ising and IsingJax, previously the dtype was inferred to be int if J and h are ints, now it's promoted to float.

For Heisenberg, previously there was no argument dtype in __init__, now we've added it and it's handled by LocalOperator.

For BoseHubbard, previously we didn't cast it to x32 when x64 is disabled, now we cast it. Although this operator doesn't have a JAX version yet, we still do the cast to make things more consistent.

For LocalLiouvillian, we also cast it to x32 when x64 is disabled. Previously we cast the specified real dtype to complex, now we raise a TypeError because it's unphysical.

Pauli strings

Now we infer that the dtype is complex if any string has an odd number of Y (previously if it has any Y), or if any weight is complex. Previously we cast the specified real dtype to complex when needed, now we raise a TypeError because it's unphysical.
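The parity rule follows from Y being i times a real matrix: a string with n factors of Y is i^n times a real matrix, which is real exactly when n is even. A small sketch of the check, where `needs_complex_dtype` is a hypothetical name and strings are assumed to be plain Python `str` objects such as "XYZ":

```python
def needs_complex_dtype(strings, weights):
    """Hypothetical check: True if the matrix elements may be complex."""
    # A string is complex only if it contains an odd number of Y factors.
    has_odd_y = any(s.count("Y") % 2 == 1 for s in strings)
    # Any weight with a complex type also forces a complex dtype.
    has_complex_weight = any(isinstance(w, complex) for w in weights)
    return has_odd_y or has_complex_weight
```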

I've added a _reduce_pauli_string in __init__. After that, those strings with Y cannot cancel out.

It's still possible that a term has an odd number of Y and a purely imaginary weight, so that the whole Hamiltonian is real and non-Hermitian, and we have to work with a complex array of weights. We don't specially handle that for now.
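The smallest instance of this corner case is a single Y with weight i, checked here with plain NumPy:

```python
import numpy as np

Y = np.array([[0, -1j], [1j, 0]])

# One Y (odd count) multiplied by a purely imaginary weight.
term = 1j * Y

is_real = np.allclose(term.imag, 0)              # all entries are real
is_hermitian = np.allclose(term, term.conj().T)  # but the matrix is not Hermitian
```

Here `term` equals [[0, 1], [-1, 0]]: its entries are real, yet the weight array that produced it still has to be stored with a complex dtype.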

Also, there is a (maybe subtle) change: When dtype is not specified and cannot be inferred from weights, previously it defaulted to x32 because of this line, now it defaults to x64 when x64 is enabled.

Continuous operators

They don't compute matrix elements, and we never explicitly cast the output dtype, so I think their dtype should behave like param_dtype in Flax.

For PotentialEnergy and KineticEnergy, the dtype is only used to cast coefficient and mass, so nothing is changed.

For SumOperator, there is a subtle change: Previously we only inferred the dtype from the operators, now we also infer it from the coefficients. Note that the dtype is only the dtype of coefficients, and the dtypes of operators are unaffected. The reason is like how we define Flax modules: If we write

small_module_1 = SmallModule(param_dtype=dtype1)
small_module_2 = SmallModule(param_dtype=dtype2)
big_module = BigModule(small_module_1, small_module_2, param_dtype=dtype3)

then we usually expect that BigModule will not change the dtypes of parameters in small_module_1 and small_module_2, and only the parameters newly defined in BigModule have dtype3. On the contrary, if we write

big_module = BigModule(module_type="SmallModule", param_dtype=dtype3)

then we expect that it constructs some SmallModule using dtype3.
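Translated back to operators, the first convention could look like the toy classes below. `Term` and `Sum` are hypothetical stand-ins, not the NetKet classes, and NumPy promotion is used in place of JAX's:

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class Term:
    """Hypothetical stand-in for a continuous operator with one coefficient."""

    coefficient: np.ndarray

    @property
    def dtype(self):
        return self.coefficient.dtype


class Sum:
    """Hypothetical stand-in for a sum of operators with its own coefficients."""

    def __init__(self, *terms, coefficients, dtype=None):
        if dtype is None:
            # Infer from both the operators and the coefficients.
            dtype = np.result_type(
                *(t.dtype for t in terms), np.asarray(coefficients).dtype
            )
        self.terms = terms  # kept as-is: their dtypes are not touched
        self.coefficients = np.asarray(coefficients, dtype=dtype)

    @property
    def dtype(self):
        return self.coefficients.dtype


kinetic = Term(np.asarray(0.5, dtype=np.float32))
potential = Term(np.asarray(1.0, dtype=np.float64))
total = Sum(kinetic, potential, coefficients=[1.0, 2.0], dtype=np.complex64)
```

In this sketch `total.dtype` is `complex64`, while `kinetic` and `potential` keep `float32` and `float64` respectively.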

codecov bot commented Jan 19, 2024

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (bb86b02) 82.53% compared to head (e86e625) 82.75%.
Report is 3 commits behind head on master.

❗ Current head e86e625 differs from pull request most recent head 124d2a2. Consider uploading reports for the commit 124d2a2 to get more accurate results

Files                     Patch %   Lines
netket/utils/numbers.py   60.00%    1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1697      +/-   ##
==========================================
+ Coverage   82.53%   82.75%   +0.21%     
==========================================
  Files         298      298              
  Lines       18304    18224      -80     
  Branches     2763     3504     +741     
==========================================
- Hits        15107    15081      -26     
+ Misses       2512     2468      -44     
+ Partials      685      675      -10     


@PhilipVinc (Member) left a comment

Thank you @wdphy16 I think this is very good, and was long needed.

My main thoughts are the following:

  • I don't want netket to throw tons of warnings if using numba operators with x64 disabled.
  • I'm not sure it makes sense to talk about weak types for the dtype of numba operators, as those will be arrays and jax will not tag them as weak (I think).

Therefore, I think that we should somehow default to having single-precision operators (where relevant) if x64 is disabled, and double-precision operators if x64 is enabled.

What do you think?

netket/operator/_continuous_operator.py (outdated, resolved)
netket/operator/_local_operator/convert.py (outdated, resolved)
netket/operator/_local_operator/helpers.py (outdated, resolved)
test/operator/test_continuous_operator.py (resolved)

@PhilipVinc (Member) left a comment

Thank you @wdphy16, I think this is good to be merged.

As tests pass, please feel free to add one line to the CHANGELOG and then merge.

netket/operator/_ising/base.py (outdated, resolved)

@PhilipVinc (Member) commented

@wdphy16 Ah sorry, before merging, I forgot: can you double-check that the operators in experimental (fermions) are addressed as well?

@PhilipVinc (Member) commented

@wdphy16 I added the changelog and commit a3d1c54 should fix fermions as well. Can you confirm?

@PhilipVinc (Member) left a comment

that's good for me. Can I merge?

@PhilipVinc PhilipVinc merged commit b2d0f88 into netket:master Jan 24, 2024
9 checks passed
@wdphy16 wdphy16 deleted the operator_dtype branch January 24, 2024 09:05