From da6171c7ef98632f764c0e51ecd51000edf001fe Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 11 Jan 2022 12:06:48 +0100 Subject: [PATCH] Throw error real-QGTOnTheFly@complexvector (#885) Check that if the qgt is real, the vector must be real too (currently this throws an untelligible error for QGTJacobianPyTree and returns the wrong result for OnTheFly). --- CHANGELOG.md | 4 +- netket/optimizer/qgt/common.py | 52 +++++++++++++++++++++ netket/optimizer/qgt/qgt_jacobian_pytree.py | 6 +++ netket/optimizer/qgt/qgt_onthefly.py | 5 ++ test/optimizer/test_qgt_solvers.py | 24 ++++++++++ 5 files changed, 90 insertions(+), 1 deletion(-) create mode 100644 netket/optimizer/qgt/common.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d4cdf1b07e..5f6896ed2c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,9 @@ ### Bug Fixes * Initialisation of all implementations of `DenseSymm`, `DenseEquivariant`, `GCNN` now defaults to truncated normals with Lecun variance scaling. For layers without masking, there should be no noticeable change in behaviour. For masked layers, the same variance scaling now works correctly. [#1045](https://github.com/netket/netket/pull/1045) -* Fix bug that prevented gradients of non-hermitian operators to be computed. The feature is still marked as experimental but will now run (we do not guarantee that results are correct). [#1045](https://github.com/netket/netket/pull/1045) +* Fix bug that prevented gradients of non-hermitian operators to be computed. The feature is still marked as experimental but will now run (we do not guarantee that results are correct). [#1053](https://github.com/netket/netket/pull/1053) +* Common lattice constructors such as `Honeycomb` now accepts the same keyword arguments as `Lattice`. [#1046](https://github.com/netket/netket/pull/1046) +* Multiplying a `QGTOnTheFly` representing the real part of the QGT (showing up when the ansatz has real parameters) with a complex vector now throws an error. Previously the result would be wrong, as the imaginary part [was casted away](https://github.com/netket/netket/issues/789#issuecomment-871145119). [#885](https://github.com/netket/netket/pull/885) ## NetKet 3.3 (🎁 20 December 2021) diff --git a/netket/optimizer/qgt/common.py b/netket/optimizer/qgt/common.py new file mode 100644 index 0000000000..893de8e0e5 --- /dev/null +++ b/netket/optimizer/qgt/common.py @@ -0,0 +1,52 @@ +# Copyright 2022 The NetKet Authors - All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from textwrap import dedent + +import jax +from jax import numpy as jnp + +from netket.utils.types import PyTree + + +def check_valid_vector_type(x: PyTree, target: PyTree): + """ + Raises a TypeError if x is complex where target is real, because it is not + supported by QGTOnTheFly and the imaginary part would be dicscarded after + anyhow. + """ + + def check(x, target): + print(f"x={x} target={target}") + par_iscomplex = jnp.iscomplexobj(x) + + # Account for split real-imaginary part in Jacobian*** methods + if isinstance(target, tuple): + vec_iscomplex = True if len(target) == 2 else False + else: + vec_iscomplex = jnp.iscomplexobj(target) + + if not par_iscomplex and vec_iscomplex: + raise TypeError( + dedent( + """ + Cannot multiply the (real part of the) QGT by a complex vector. + You should either take the real part of the vector, or perform + the multiplication against the real and imaginary part of the + vector separately and then recomposing the two. + """ + ) + ) + + jax.tree_multimap(check, x, target) diff --git a/netket/optimizer/qgt/qgt_jacobian_pytree.py b/netket/optimizer/qgt/qgt_jacobian_pytree.py index 1b24328f08..8766f34cf2 100644 --- a/netket/optimizer/qgt/qgt_jacobian_pytree.py +++ b/netket/optimizer/qgt/qgt_jacobian_pytree.py @@ -25,6 +25,7 @@ from ..linear_operator import LinearOperator, Uninitialized +from .common import check_valid_vector_type from .qgt_jacobian_pytree_logic import mat_vec, prepare_centered_oks from .qgt_jacobian_common import choose_jacobian_mode @@ -204,6 +205,8 @@ def _matmul( if self.mode != "holomorphic" and not self._in_solve: vec, reassemble = nkjax.tree_to_real(vec) + check_valid_vector_type(self.params, vec) + if self.scale is not None: vec = jax.tree_multimap(jnp.multiply, vec, self.scale) @@ -227,10 +230,13 @@ def _matmul( def _solve( self: QGTJacobianPyTreeT, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None ) -> PyTree: + # Real-imaginary split RHS in R→R and R→C modes if self.mode != "holomorphic": y, reassemble = nkjax.tree_to_real(y) + check_valid_vector_type(self.params, y) + if self.scale is not None: y = jax.tree_multimap(jnp.divide, y, self.scale) if x0 is not None: diff --git a/netket/optimizer/qgt/qgt_onthefly.py b/netket/optimizer/qgt/qgt_onthefly.py index 0798a61f53..dc9fc1f3c5 100644 --- a/netket/optimizer/qgt/qgt_onthefly.py +++ b/netket/optimizer/qgt/qgt_onthefly.py @@ -23,6 +23,7 @@ from netket.utils.types import PyTree from netket.utils import warn_deprecation +from .common import check_valid_vector_type from .qgt_onthefly_logic import mat_vec_factory, mat_vec_chunked_factory from ..linear_operator import LinearOperator, Uninitialized @@ -157,6 +158,8 @@ def onthefly_mat_treevec( else: ravel_result = False + check_valid_vector_type(S._params, vec) + vec = nkjax.tree_cast(vec, S._params) res = S._mat_vec(vec, S.diag_shift) @@ -172,6 +175,8 @@ def _solve( self: QGTOnTheFlyT, solve_fun, y: PyTree, *, x0: Optional[PyTree], **kwargs ) -> PyTree: + check_valid_vector_type(self._params, y) + y = nkjax.tree_cast(y, self._params) # we could cache this... diff --git a/test/optimizer/test_qgt_solvers.py b/test/optimizer/test_qgt_solvers.py index 711ddb7183..ca2b640173 100644 --- a/test/optimizer/test_qgt_solvers.py +++ b/test/optimizer/test_qgt_solvers.py @@ -25,6 +25,11 @@ from .. import common # noqa: F401 +QGT_types = {} +QGT_types["QGTOnTheFly"] = nk.optimizer.qgt.QGTOnTheFly +# QGT_types["QGTJacobianDense"] = nk.optimizer.qgt.QGTJacobianDense +QGT_types["QGTJacobianPyTree"] = nk.optimizer.qgt.QGTJacobianPyTree + QGT_objects = {} QGT_objects["JacobianPyTree"] = partial(qgt.QGTJacobianPyTree, diag_shift=0.00) @@ -75,3 +80,22 @@ def test_qgt_solve(qgt, vstate, solver, _mpi_size, _mpi_rank): S = qgt(vstate) x, _ = S.solve(solver, vstate.parameters) + + +# Issue #789 https://github.com/netket/netket/issues/789 +# cannot multiply real qgt by complex vector +@common.skipif_mpi +@pytest.mark.parametrize( + "SType", [pytest.param(T, id=name) for name, T in QGT_types.items()] +) +def test_qgt_throws(SType): + hi = nk.hilbert.Spin(s=1 / 2, N=5) + ma = nk.models.RBMModPhase(alpha=1, dtype=float) + sa = nk.sampler.MetropolisLocal(hi, n_chains=16, reset_chains=False) + vs = nk.vqs.MCState(sa, ma, n_samples=100, n_discard_per_chain=100) + + S = vs.quantum_geometric_tensor(SType) + g_cmplx = jax.tree_map(lambda x: x + x * 0.1j, vs.parameters) + + with pytest.raises(TypeError, match="Cannot multiply the"): + S @ g_cmplx