Skip to content

Commit

Permalink
throw error if real-qgtonthefly times complex vector
Browse files Browse the repository at this point in the history
fixup
  • Loading branch information
PhilipVinc committed Aug 24, 2021
1 parent b3966e4 commit 943cec1
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions netket/optimizer/qgt/qgt_onthefly.py
Expand Up @@ -14,6 +14,7 @@

from typing import Callable, Optional, Union
from functools import partial
from textwrap import dedent

import jax
from jax import numpy as jnp
Expand All @@ -28,6 +29,29 @@
from ..linear_operator import LinearOperator, Uninitialized


def check_valid_vector_type(x, target):
"""
Raises a TypeError if x is complex where target is real, because it is not
supported by QGTOnTheFly and the imaginary part would be dicscarded after
anyhow.
"""

def check(x, target):
if jnp.iscomplexobj(target) and not jnp.iscomplexobj(x):
raise TypeError(
dedent(
"""
Cannot multiply the (real part of the) QGT by a complex vector.
You should either take the real part of the vector, or perform
the multiplication against the real and imaginary part of the
vector separately and then recomposing the two.
"""
)
)

jax.tree_multimap(check,x,target)


def QGTOnTheFly(vstate=None, **kwargs) -> "QGTOnTheFlyT":
"""
Lazy representation of an S Matrix computed by performing 2 jvp
Expand Down Expand Up @@ -137,6 +161,8 @@ def onthefly_mat_treevec(
else:
ravel_result = False

check_valid_vector_type(vec, S._params)

vec = nkjax.tree_cast(vec, S._params)

res = S._mat_vec(vec, S.diag_shift)
Expand All @@ -151,6 +177,8 @@ def onthefly_mat_treevec(
def _solve(
self: QGTOnTheFlyT, solve_fun, y: PyTree, *, x0: Optional[PyTree], **kwargs
) -> PyTree:

check_valid_vector_type(y, self._params)

y = nkjax.tree_cast(y, self._params)

Expand Down

0 comments on commit 943cec1

Please sign in to comment.