Skip to content

Commit

Permalink
Throw error real-QGTOnTheFly@complexvector (#885)
Browse files Browse the repository at this point in the history
Check that if the qgt is real, the vector must be real too (currently this throws an untelligible error for QGTJacobianPyTree and returns the wrong result for OnTheFly).
  • Loading branch information
PhilipVinc committed Jan 11, 2022
1 parent 30d938e commit da6171c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 1 deletion.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Expand Up @@ -16,7 +16,9 @@

### Bug Fixes
* Initialisation of all implementations of `DenseSymm`, `DenseEquivariant`, `GCNN` now defaults to truncated normals with Lecun variance scaling. For layers without masking, there should be no noticeable change in behaviour. For masked layers, the same variance scaling now works correctly. [#1045](https://github.com/netket/netket/pull/1045)
* Fix bug that prevented gradients of non-hermitian operators to be computed. The feature is still marked as experimental but will now run (we do not guarantee that results are correct). [#1045](https://github.com/netket/netket/pull/1045)
* Fix bug that prevented gradients of non-hermitian operators to be computed. The feature is still marked as experimental but will now run (we do not guarantee that results are correct). [#1053](https://github.com/netket/netket/pull/1053)
* Common lattice constructors such as `Honeycomb` now accepts the same keyword arguments as `Lattice`. [#1046](https://github.com/netket/netket/pull/1046)
* Multiplying a `QGTOnTheFly` representing the real part of the QGT (showing up when the ansatz has real parameters) with a complex vector now throws an error. Previously the result would be wrong, as the imaginary part [was casted away](https://github.com/netket/netket/issues/789#issuecomment-871145119). [#885](https://github.com/netket/netket/pull/885)


## NetKet 3.3 (🎁 20 December 2021)
Expand Down
52 changes: 52 additions & 0 deletions netket/optimizer/qgt/common.py
@@ -0,0 +1,52 @@
# Copyright 2022 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from textwrap import dedent

import jax
from jax import numpy as jnp

from netket.utils.types import PyTree


def check_valid_vector_type(x: PyTree, target: PyTree):
"""
Raises a TypeError if x is complex where target is real, because it is not
supported by QGTOnTheFly and the imaginary part would be dicscarded after
anyhow.
"""

def check(x, target):
print(f"x={x} target={target}")
par_iscomplex = jnp.iscomplexobj(x)

# Account for split real-imaginary part in Jacobian*** methods
if isinstance(target, tuple):
vec_iscomplex = True if len(target) == 2 else False
else:
vec_iscomplex = jnp.iscomplexobj(target)

if not par_iscomplex and vec_iscomplex:
raise TypeError(
dedent(
"""
Cannot multiply the (real part of the) QGT by a complex vector.
You should either take the real part of the vector, or perform
the multiplication against the real and imaginary part of the
vector separately and then recomposing the two.
"""
)
)

jax.tree_multimap(check, x, target)
6 changes: 6 additions & 0 deletions netket/optimizer/qgt/qgt_jacobian_pytree.py
Expand Up @@ -25,6 +25,7 @@

from ..linear_operator import LinearOperator, Uninitialized

from .common import check_valid_vector_type
from .qgt_jacobian_pytree_logic import mat_vec, prepare_centered_oks
from .qgt_jacobian_common import choose_jacobian_mode

Expand Down Expand Up @@ -204,6 +205,8 @@ def _matmul(
if self.mode != "holomorphic" and not self._in_solve:
vec, reassemble = nkjax.tree_to_real(vec)

check_valid_vector_type(self.params, vec)

if self.scale is not None:
vec = jax.tree_multimap(jnp.multiply, vec, self.scale)

Expand All @@ -227,10 +230,13 @@ def _matmul(
def _solve(
self: QGTJacobianPyTreeT, solve_fun, y: PyTree, *, x0: Optional[PyTree] = None
) -> PyTree:

# Real-imaginary split RHS in R→R and R→C modes
if self.mode != "holomorphic":
y, reassemble = nkjax.tree_to_real(y)

check_valid_vector_type(self.params, y)

if self.scale is not None:
y = jax.tree_multimap(jnp.divide, y, self.scale)
if x0 is not None:
Expand Down
5 changes: 5 additions & 0 deletions netket/optimizer/qgt/qgt_onthefly.py
Expand Up @@ -23,6 +23,7 @@
from netket.utils.types import PyTree
from netket.utils import warn_deprecation

from .common import check_valid_vector_type
from .qgt_onthefly_logic import mat_vec_factory, mat_vec_chunked_factory

from ..linear_operator import LinearOperator, Uninitialized
Expand Down Expand Up @@ -157,6 +158,8 @@ def onthefly_mat_treevec(
else:
ravel_result = False

check_valid_vector_type(S._params, vec)

vec = nkjax.tree_cast(vec, S._params)

res = S._mat_vec(vec, S.diag_shift)
Expand All @@ -172,6 +175,8 @@ def _solve(
self: QGTOnTheFlyT, solve_fun, y: PyTree, *, x0: Optional[PyTree], **kwargs
) -> PyTree:

check_valid_vector_type(self._params, y)

y = nkjax.tree_cast(y, self._params)

# we could cache this...
Expand Down
24 changes: 24 additions & 0 deletions test/optimizer/test_qgt_solvers.py
Expand Up @@ -25,6 +25,11 @@

from .. import common # noqa: F401

QGT_types = {}
QGT_types["QGTOnTheFly"] = nk.optimizer.qgt.QGTOnTheFly
# QGT_types["QGTJacobianDense"] = nk.optimizer.qgt.QGTJacobianDense
QGT_types["QGTJacobianPyTree"] = nk.optimizer.qgt.QGTJacobianPyTree

QGT_objects = {}

QGT_objects["JacobianPyTree"] = partial(qgt.QGTJacobianPyTree, diag_shift=0.00)
Expand Down Expand Up @@ -75,3 +80,22 @@ def test_qgt_solve(qgt, vstate, solver, _mpi_size, _mpi_rank):
S = qgt(vstate)

x, _ = S.solve(solver, vstate.parameters)


# Issue #789 https://github.com/netket/netket/issues/789
# cannot multiply real qgt by complex vector
@common.skipif_mpi
@pytest.mark.parametrize(
"SType", [pytest.param(T, id=name) for name, T in QGT_types.items()]
)
def test_qgt_throws(SType):
hi = nk.hilbert.Spin(s=1 / 2, N=5)
ma = nk.models.RBMModPhase(alpha=1, dtype=float)
sa = nk.sampler.MetropolisLocal(hi, n_chains=16, reset_chains=False)
vs = nk.vqs.MCState(sa, ma, n_samples=100, n_discard_per_chain=100)

S = vs.quantum_geometric_tensor(SType)
g_cmplx = jax.tree_map(lambda x: x + x * 0.1j, vs.parameters)

with pytest.raises(TypeError, match="Cannot multiply the"):
S @ g_cmplx

0 comments on commit da6171c

Please sign in to comment.