Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gp/continuous geometry #1583

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Examples/Continuous/HO.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import netket as nk

import netket.experimental as nkexp
import jax.numpy as jnp


def v(x):
return jnp.linalg.norm(x) ** 2


hilb = nk.hilbert.Particle(N=10, L=(jnp.inf, jnp.inf, jnp.inf), pbc=False)

geometry = nkexp.geometry.Free(dim=3)
hilb = nk.hilbert.Particle(N=10, geometry=geometry)
sab = nk.sampler.MetropolisGaussian(hilb, sigma=0.1, n_chains=16, n_sweeps=32)

ekin = nk.operator.KineticEnergy(hilb, mass=1.0)
Expand Down
22 changes: 15 additions & 7 deletions Examples/Continuous/Helium.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import netket as nk
import netket.experimental as nkexp
import jax.numpy as jnp

from optax._src import linear_algebra


def mycb(step, logged_data, driver):
logged_data["acceptance"] = float(driver.state.sampler_state.acceptance)
logged_data["globalnorm"] = float(linear_algebra.global_norm(driver._loss_grad))
return True


def minimum_distance(x, sdim):
"""Computes distances between particles using minimum image convention"""
Expand Down Expand Up @@ -53,14 +62,13 @@ def potential(x, sdim):
d = 0.3 # 1/Angstrom
rm = 2.9673 # Angstrom
L = N / (0.3 * rm)
hilb = nk.hilbert.Particle(N=N, L=(L,), pbc=True)
sab = nk.sampler.MetropolisGaussian(hilb, sigma=0.05, n_chains=16, n_sweeps=32)

geometry = nkexp.geometry.Cell(basis=L * jnp.eye(1))
hilb = nk.hilbert.Particle(N=N, geometry=geometry)
sab = nk.sampler.MetropolisGaussian(hilb, sigma=0.008, n_chains=16, n_sweeps=32)

ekin = nk.operator.KineticEnergy(hilb, mass=1.0)
pot = nk.operator.PotentialEnergy(hilb, lambda x: potential(x, 1))
ha = ekin + pot

model = nk.models.DeepSetRelDistance(
hilbert=hilb,
cusp_exponent=5,
Expand All @@ -71,8 +79,8 @@ def potential(x, sdim):
)
vs = nk.vqs.MCState(sab, model, n_samples=4096, n_discard_per_chain=128)

op = nk.optimizer.Sgd(0.01)
sr = nk.optimizer.SR(diag_shift=0.01)
op = nk.optimizer.Sgd(0.001)
sr = nk.optimizer.SR(diag_shift=0.001)

gs = nk.VMC(ha, op, sab, variational_state=vs, preconditioner=sr)
gs.run(n_iter=1000, out="Helium_10_1d")
gs.run(n_iter=1000, callback=mycb, out="Helium_10_1d")
2 changes: 2 additions & 0 deletions netket/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"hilbert",
"operator",
"logging",
"geometry",
]

from . import hilbert
Expand All @@ -31,6 +32,7 @@
from . import vqs
from . import logging
from . import qsr
from . import geometry

from .driver import TDVP
from .qsr import QSR
Expand Down
159 changes: 159 additions & 0 deletions netket/experimental/geometry/Cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
import numpy as np
from . import AbstractGeometry
from typing import Optional
from netket.utils import HashableArray


def take_sub(key, x, n):
key, subkey = jax.random.split(key)
ind = jax.random.choice(
subkey, jnp.arange(0, x.shape[0], 1), replace=False, shape=(n,)
)
return x[ind, :]


class Cell(AbstractGeometry):
def __init__(
self,
dim: Optional[int] = None,
basis: Optional = None,
):
"""
Construct a periodic geometry in continuous space, given the dimension or a basis. If only the dimension is
given the standard basis is assumed, e.g. (1,0), (0,1) for 2D space.

Args:
dim: (Optional) The number of spatial dimensions of the physical space. If None, a basis has to be
specified. If int and basis is None, the standard basis is assumed.
basis: (Optional) A basis for the physical space. If basis is None, the standard basis is assumed.
"""
if dim is None and basis is None:
raise ValueError(

Check warning on line 46 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L46

Added line #L46 was not covered by tests
"""Specify either the dimension of the geometry or provide a basis for it."""
)

if basis is None:
basis = jnp.eye(dim)

Check warning on line 51 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L51

Added line #L51 was not covered by tests

super().__init__(basis=basis)

@property
def pbc(self):
return True

@property
def volume(self) -> int:
r"""Returns the volume of the given physical space (defined by the basis)."""
return np.abs(np.linalg.det(self.basis))

@property
def extent(self):
r"""Returns an array of the maximum extension in each spatial direction."""
temp = self.from_basis_to_standard(jnp.ones((1, self.dim)))
return temp

def distance(self, x, y=None, norm=False, tri=False, mode=None):
assert (

Check warning on line 71 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L71

Added line #L71 was not covered by tests
self.dim == x.shape[-1]
), "The dimension of the geometry does not match the dimension of the positions."
if y is None:
dis = x[..., None, :, :] - x[..., None, :]

Check warning on line 75 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L75

Added line #L75 was not covered by tests
else:
assert x.shape[-1] == y.shape[-1]
dis = x[..., None, :, :] - y[..., None, :]

Check warning on line 78 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L77-L78

Added lines #L77 - L78 were not covered by tests

if mode == "MIC":
dis = self.from_standard_to_basis(dis)
dis = jnp.remainder(dis + 1.0 / 2.0, 1.0) - 1.0 / 2.0
dis = self.from_basis_to_standard(dis)

Check warning on line 83 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L81-L83

Added lines #L81 - L83 were not covered by tests
if tri is True:
idx = jnp.triu_indices(dis.shape[1], 1)
dis = dis[..., idx[0], idx[1], :]

Check warning on line 86 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L85-L86

Added lines #L85 - L86 were not covered by tests
if norm is True:
return dis, jnp.linalg.norm(dis)

Check warning on line 88 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L88

Added line #L88 was not covered by tests
else:
return dis

Check warning on line 90 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L90

Added line #L90 was not covered by tests

elif mode == "Periodic":
pdis = self.make_periodic(dis)

Check warning on line 93 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L93

Added line #L93 was not covered by tests

if norm is True:
frac = self.from_standard_to_basis(dis)
sij = jnp.einsum("ik,kj->ij", self.basis, self.basis) / jnp.linalg.norm(

Check warning on line 97 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L96-L97

Added lines #L96 - L97 were not covered by tests
self.basis, axis=-1, keepdims=True
)
t1 = jnp.einsum(

Check warning on line 100 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L100

Added line #L100 was not covered by tests
"...i,ij,...j->...",
1 - jnp.cos(2 * jnp.pi * frac),
sij,
1 - jnp.cos(2 * jnp.pi * frac),
)
t2 = jnp.einsum(

Check warning on line 106 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L106

Added line #L106 was not covered by tests
"...i,ij,...j->...",
jnp.sin(2 * jnp.pi * frac),
sij,
jnp.sin(2 * jnp.pi * frac),
)
if tri is True:
idx = jnp.triu_indices(dis.shape[1], 1)
pdis = pdis[..., idx[0], idx[1], :]
pdisnorm = t1[..., idx[0], idx[1], :] + t2[..., idx[0], idx[1], :]
return pdis, pdisnorm

Check warning on line 116 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L113-L116

Added lines #L113 - L116 were not covered by tests
else:
return pdis, t1 + t2

Check warning on line 118 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L118

Added line #L118 was not covered by tests
else:
return pdis

Check warning on line 120 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L120

Added line #L120 was not covered by tests

raise NotImplementedError

Check warning on line 122 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L122

Added line #L122 was not covered by tests

def add(self, x, y):
frac = self.from_standard_to_basis(x)
frac = (frac + y) % 1.0
return self.from_basis_to_standard(frac)

def random_init(self, shape):
batches, N, _ = shape
key = jax.random.PRNGKey(42)
key = jax.random.split(key, num=batches)

n = int(np.ceil(N ** (1 / self.dim)))
xs = jnp.linspace(0, 1, n)
uniform = jnp.array(jnp.meshgrid(*(self.dim * [xs]))).T.reshape(-1, self.dim)
uniform = jnp.tile(uniform, (batches, 1, 1))
uniform = jax.vmap(take_sub, in_axes=(0, 0, None))(key, uniform, N)
uniform = self.from_basis_to_standard(uniform).reshape(batches, -1)

return uniform

def make_periodic(self, x):
r"""Given a batch of position vectors in the geometry, this function returns a periodic decomposition of these
vectors.
"""
frac = self.from_standard_to_basis(x)
return jnp.concatenate(

Check warning on line 148 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L147-L148

Added lines #L147 - L148 were not covered by tests
(jnp.sin(2 * jnp.pi * frac), jnp.cos(2 * jnp.pi * frac)), axis=-1
)

@property
def _attrs(self):
return (HashableArray(self.basis), HashableArray(self.volume), self.dim)

Check warning on line 154 in netket/experimental/geometry/Cell.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Cell.py#L154

Added line #L154 was not covered by tests

def __repr__(self):
return "PeriodicCell(lattice={}, volume={}, dim={})".format(
HashableArray(self.basis), HashableArray(self.volume), self.dim
)
91 changes: 91 additions & 0 deletions netket/experimental/geometry/Free.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp

from typing import Optional
from . import AbstractGeometry


class Free(AbstractGeometry):
def __init__(
self,
dim: Optional[int] = None,
basis: Optional = None,
):
"""
Construct a periodic geometry in continuous space, given the dimension or a basis. If only the dimension is
given the standard basis is assumed, e.g. (1,0), (0,1) for 2D space.

Args:
dim: (Optional) The number of spatial dimensions of the physical space. If None, a basis has to be
specified. If int and basis is None, the standard basis is assumed.
basis: (Optional) A basis for the physical space. If basis is None, the standard basis is assumed.
"""
if dim is None and basis is None:
raise ValueError(

Check warning on line 36 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L36

Added line #L36 was not covered by tests
"""Specify either the dimension of the geometry or provide a basis for it."""
)

if basis is None:
basis = jnp.eye(dim)

Check warning on line 41 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L41

Added line #L41 was not covered by tests

super().__init__(basis=basis)

@property
def pbc(self):
return False

Check warning on line 47 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L47

Added line #L47 was not covered by tests

def distance(self, x, y=None, norm=False, tri=False, mode="Euclidean"):
assert (

Check warning on line 50 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L50

Added line #L50 was not covered by tests
self.dim == x.shape[-1]
), "The dimension of the geometry does not match the dimension of the positions."
if mode != "Euclidean":
raise ValueError(

Check warning on line 54 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L54

Added line #L54 was not covered by tests
"""There is only the Euclidean mode for free space distance computation."""
)
if mode == "Euclidean":
if y is None:
dis = x[..., None, :, :] - x[..., None, :]

Check warning on line 59 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L59

Added line #L59 was not covered by tests
else:
assert x.shape[-1] == y.shape[-1]
dis = x[..., None, :, :] - y[..., None, :]

Check warning on line 62 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L61-L62

Added lines #L61 - L62 were not covered by tests

if tri:
idx = jnp.triu_indices(dis.shape[1], 1)
dis = dis[..., idx[0], idx[1], :]

Check warning on line 66 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L65-L66

Added lines #L65 - L66 were not covered by tests

if norm and y is None:
return dis, jnp.linalg.norm(dis + jnp.eye(dis.shape[1]), axis=-1) * (

Check warning on line 69 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L69

Added line #L69 was not covered by tests
1 - jnp.eye(dis.shape[1])
)
if norm and y is not None:
dis, jnp.linalg.norm(dis, axis=-1)

Check warning on line 73 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L73

Added line #L73 was not covered by tests

return dis

Check warning on line 75 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L75

Added line #L75 was not covered by tests

raise NotImplementedError

Check warning on line 77 in netket/experimental/geometry/Free.py

View check run for this annotation

Codecov / codecov/patch

netket/experimental/geometry/Free.py#L77

Added line #L77 was not covered by tests

def add(self, x, y):
return x + y

def random_init(self, shape):
batches, N, _ = shape
return jnp.zeros((shape[0], N * self.dim))

@property
def _attrs(self):
return (self.dim,)

def __repr__(self):
return "FreeSpace(dim={})".format(self.dim)
21 changes: 21 additions & 0 deletions netket/experimental/geometry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = [
"AbstractGeometry",
"Cell",
"Free",
]
from .continuous_geometry import AbstractGeometry
from .Cell import Cell
from .Free import Free
Loading
Loading