Skip to content

Commit

Permalink
solve plum deprecation warnings (#1572)
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipVinc committed Sep 25, 2023
1 parent 83da3c8 commit a88e5c8
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 42 deletions.
25 changes: 17 additions & 8 deletions netket/utils/_dependencies_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from .version_check import module_version, version_string


def create_msg(pkg_name, cur_version, desired_version, extra_msg=""):
def create_msg(pkg_name, cur_version, desired_version, extra_msg="", pip_pkg_name=None):
if pip_pkg_name is None: # pragma: no cover
pip_pkg_name = pkg_name
return dedent(
f"""
Expand All @@ -33,7 +35,7 @@ def create_msg(pkg_name, cur_version, desired_version, extra_msg=""):
Please update `{pkg_name}` by running the command:
pip install --upgrade pip
pip install --upgrade netket {pkg_name}
pip install --upgrade netket {pip_pkg_name}
(assuming you are using pip. Similar commands can be used on conda).
Expand All @@ -48,22 +50,29 @@ def create_msg(pkg_name, cur_version, desired_version, extra_msg=""):
)


if not module_version("jax") >= (0, 4, 0):
if not module_version("jax") >= (0, 4, 3): # pragma: no cover
cur_version = version_string("optax")
raise ImportError(create_msg("jax", cur_version, "0.4"))
raise ImportError(create_msg("jax", cur_version, "0.4.3"))

if not module_version("optax") >= (0, 1, 1):
if not module_version("optax") >= (0, 1, 3): # pragma: no cover
cur_version = version_string("optax")
extra = """Reason: Optax is NetKet's provider of optimisers. Versions before 0.1.1 did not
support complex numbers and silently returned wrong values, especially when
using optimisers involving the norm of the gradient such as `Adam`.
As recent versions of optax correctly work with complex numbers, please upgrade.
"""
raise ImportError(create_msg("optax", cur_version, "0.1.1", extra))
raise ImportError(create_msg("optax", cur_version, "0.1.3", extra))

if not module_version("flax") >= (0, 5, 0):
if not module_version("flax") >= (0, 6, 5): # pragma: no cover
cur_version = version_string("flax")
extra = """Reason: Flax is NetKet's default neural-network library. Versions before 0.5 had
a bug and did not properly support complex numbers.
"""
raise ImportError(create_msg("flax", cur_version, "0.5", extra))
raise ImportError(create_msg("flax", cur_version, "0.6.5", extra))

if not module_version("plum") >= (2, 2, 2): # pragma: no cover
raise ImportError(
create_msg(
"plum", version_string("plum"), "2.2.2", pip_pkg_name="plum-dispatch"
)
)
36 changes: 30 additions & 6 deletions netket/utils/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from typing import Literal, Union

from plum import dispatch, parametric, convert, Val # noqa: F401
from plum import dispatch, parametric, convert # noqa: F401

# Signature-types for True and False
TrueT = Val[True]
FalseT = Val[False]
Bool = Union[TrueT, FalseT]

# Todo: deprecated in netket 3.10/august 2023 . To eventually remove.
def __getattr__(name):
if name in ["TrueT", "FalseT", "Bool"]:
from netket.utils import warn_deprecation as _warn_deprecation

_warn_deprecation(
"""
The variables `nk.utils.dispatch.{TrueT|FalseT|Bool}` are deprecated. Their usages
should instead be replaced by the following objects:
`TrueT` should be replaced by `typing.Literal[True]`
`FalseT` should be replaced by `typing.Literal[False]`
`Bool` should be replaced by `bool`
"""
)
# Deprecated signature-types for True and False
# TrueT = Literal[True]
# FalseT = Literal[False]
# Bool = Union[TrueT, FalseT]
if name == "TrueT":
return Literal[True]
elif name == "FalseT":
return Literal[False]
elif name == "Bool":
return Union[Literal[True], Literal[False]]

raise AttributeError(f"module {__name__} has no attribute {name}")
15 changes: 5 additions & 10 deletions netket/vqs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from netket.operator import AbstractOperator, Squared
from netket.hilbert import AbstractHilbert
from netket.utils.types import PyTree, PRNGKeyT, NNInitFunc
from netket.utils.dispatch import dispatch, TrueT, FalseT
from netket.utils.dispatch import dispatch
from netket.stats import Stats


Expand Down Expand Up @@ -378,19 +378,14 @@ def expect_and_grad(
@nk.vqs.expect_and_grad.register
expect_and_grad(vstate : VStateType, operator: OperatorType,
use_covariance : bool/TrueT/FalseT, * mutable)
use_covariance : bool/Literal[True]/Literal[False], * mutable)
return ...
"""

# convert to type-static True/False
if isinstance(use_covariance, bool):
use_covariance = TrueT() if use_covariance else FalseT()

if use_covariance is None:
if isinstance(operator, Squared):
use_covariance = FalseT()
use_covariance = False
else:
use_covariance = TrueT() if operator.is_hermitian else FalseT()
use_covariance = True if operator.is_hermitian else False

return expect_and_grad(
vstate, operator, use_covariance, *args, mutable=mutable, **kwargs
Expand All @@ -417,6 +412,6 @@ def expect_and_forces(
@nk.vqs.expect_and_forces.register
expect_and_forces(vstate : VStateType, operator: OperatorType,
use_covariance : bool/TrueT/FalseT, * mutable)
use_covariance : bool/Literal[True]/Literal[False], * mutable)
return ...
"""
6 changes: 3 additions & 3 deletions netket/vqs/full_summ/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import partial, lru_cache
from typing import Callable
from typing import Callable, Literal

import jax
from jax import numpy as jnp
Expand All @@ -23,7 +23,7 @@
from netket.operator import Squared
from netket.stats import Stats
from netket.utils.types import PyTree
from netket.utils.dispatch import dispatch, TrueT
from netket.utils.dispatch import dispatch

from netket.operator import DiscreteOperator

Expand Down Expand Up @@ -136,7 +136,7 @@ def _exp_forces(
def expect_and_grad(
vstate: FullSumState,
: DiscreteOperator,
use_covariance: TrueT,
use_covariance: Literal[True],
*,
mutable: CollectionFilter,
) -> tuple[Stats, PyTree]:
Expand Down
5 changes: 3 additions & 2 deletions netket/vqs/mc/mc_mixed_state/expect_grad_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal, Union

from netket.operator import AbstractOperator
from netket.utils.dispatch import Bool

from netket.vqs import expect_and_grad, expect_and_forces

Expand All @@ -25,7 +26,7 @@
def expect_and_grad_nochunking(
vstate: MCMixedState,
operator: AbstractOperator,
use_covariance: Bool,
use_covariance: Union[Literal[True], Literal[False]],
chunk_size: None,
*args,
**kwargs,
Expand Down
13 changes: 6 additions & 7 deletions netket/vqs/mc/mc_state/expect_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import partial
from typing import Callable
from typing import Callable, Literal

import jax
from jax import numpy as jnp
Expand All @@ -24,7 +24,6 @@
from netket.stats import Stats
from netket.utils import mpi
from netket.utils.types import PyTree
from netket.utils.dispatch import TrueT, FalseT

from netket.operator import (
AbstractOperator,
Expand All @@ -42,14 +41,14 @@
from .state import MCState


# Implementation of expect_and_grad for `use_covariance == True` (due to the TrueT
# Implementation of expect_and_grad for `use_covariance == True` (due to the Literal[True]
# type in the signature).` This case is equivalent to the composition of the
# `expect_and_forces` and `_force_to_grad` functions.
@expect_and_grad.dispatch
def expect_and_grad_covariance(
vstate: MCState,
: AbstractOperator,
use_covariance: TrueT,
use_covariance: Literal[True],
*,
mutable: CollectionFilter,
) -> tuple[Stats, PyTree]:
Expand Down Expand Up @@ -79,9 +78,9 @@ def _force_to_grad(Ō_grad, parameters):
# Specialized dispatch rule for pure states with squared operators as well as general operators
# with use_covariance == False (experimental).
@expect_and_grad.dispatch_multi(
(MCState, Squared[DiscreteOperator], FalseT),
(MCState, Squared[AbstractOperator], FalseT),
(MCState, AbstractOperator, FalseT),
(MCState, Squared[DiscreteOperator], Literal[False]),
(MCState, Squared[AbstractOperator], Literal[False]),
(MCState, AbstractOperator, Literal[False]),
)
def expect_and_grad_nonherm(
vstate,
Expand Down
9 changes: 4 additions & 5 deletions netket/vqs/mc/mc_state/expect_grad_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from typing import Any, Union, Literal
import warnings

import jax
Expand All @@ -22,7 +22,6 @@
from netket.operator import AbstractOperator
from netket.stats import Stats
from netket.utils.types import PyTree
from netket.utils.dispatch import TrueT, Bool

from netket.vqs import expect_and_grad, expect_and_forces

Expand All @@ -34,7 +33,7 @@
def expect_and_grad_nochunking( # noqa: F811
vstate: MCState,
operator: AbstractOperator,
use_covariance: Bool,
use_covariance: Union[Literal[True], Literal[False]],
chunk_size: None,
*args,
**kwargs,
Expand All @@ -47,7 +46,7 @@ def expect_and_grad_nochunking( # noqa: F811
def expect_and_grad_fallback( # noqa: F811
vstate: MCState,
operator: AbstractOperator,
use_covariance: Bool,
use_covariance: Union[Literal[True], Literal[False]],
chunk_size: Any,
*args,
**kwargs,
Expand All @@ -66,7 +65,7 @@ def expect_and_grad_fallback( # noqa: F811
def expect_and_grad_covariance_chunked( # noqa: F811
vstate: MCState,
: AbstractOperator,
use_covariance: TrueT,
use_covariance: Literal[True],
chunk_size: int,
*,
mutable: CollectionFilter,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"numpy~=1.20",
"scipy>=1.5.3, <2",
"tqdm>=4.60, <5",
"plum-dispatch>=1.5.1, <3",
"plum-dispatch>=2.2.2, <3",
"numba>=0.52, <0.59",
"igraph>=0.10.0, <0.11.0",
"jax>=0.4.3, <0.5",
Expand Down
30 changes: 30 additions & 0 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,33 @@ def applyfun(pars, x, mutable=False):
res = afun(None, xb, mutable=True)[0]
assert res.shape == (1,)
assert res == jnp.sum(x, axis=-1)


def test_deprecated_dispatch_bool():
from netket.utils import dispatch

@dispatch.dispatch
def test(a):
return 1

with pytest.warns():

@dispatch.dispatch
def test(a: dispatch.TrueT): # noqa: F811
return True

with pytest.warns():

@dispatch.dispatch
def test(b: dispatch.FalseT): # noqa: F811
return False

assert test(1) == 1
assert test(True) is True
assert test(False) is False

with pytest.warns():

@dispatch.dispatch
def test(b: dispatch.Bool): # noqa: F811
return False

0 comments on commit a88e5c8

Please sign in to comment.