Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement calculations in periodic boundary conditions #49

Merged
merged 6 commits into from Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
Expand Up @@ -44,7 +44,7 @@ jobs:
flake8 .
- name: Lint with pylint
run: |
pylint ferminet
pylint --fail-under 9.75 ferminet
- name: Type check with pytype
run: |
pytype ferminet
Expand Down
14 changes: 13 additions & 1 deletion ferminet/base_config.py
Expand Up @@ -171,13 +171,25 @@ def default() -> ml_collections.ConfigDict:
'determinants': 16,
'after_determinants': (1,),
},
'envelope_type': 'full', # Where does the envelope go?
'bias_orbitals': False, # include bias in last layer to orbitals
# Whether to use the last layer of the two-electron stream of the
# DetNet
'use_last_layer': False,
# If true, determinants are dense rather than block-sparse
'full_det': True,
# String set to module.make_feature_layer, where make_feature_layer is
# callable (type: MakeFeatureLayer) which creates an object with
# member functions init() and apply() that initialize parameters
# for custom input features and modify raw input features,
# respectively. Module is the absolute module containing
# make_feature_layer.
# If not set, networks.make_ferminet_features is used.
'make_feature_layer_fn': '',
# Additional kwargs to pass into make_local_energy_fn.
'make_feature_layer_kwargs': {},
# Same structure as make_feature_layer
'make_envelope_fn': '',
'make_envelope_kwargs': {}
},
'debug': {
# Check optimizer state, parameters and loss and raise an exception if
Expand Down
43 changes: 43 additions & 0 deletions ferminet/configs/heg.py
@@ -0,0 +1,43 @@
"""Unpolarised 14 electron simple cubic homogeneous electron gas."""
gcassella marked this conversation as resolved.
Show resolved Hide resolved

from ferminet import base_config
from ferminet.utils import system
from ferminet.pbc import envelopes

import numpy as np


def _sc_lattice_vecs(rs: float, nelec: int) -> np.ndarray:
"""Returns simple cubic lattice vectors with Wigner-Seitz radius rs."""
volume = (4 / 3) * np.pi * (rs**3) * nelec
length = volume**(1 / 3)
return length * np.eye(3)


def get_config():
"""Returns config for running unpolarised 14 electron gas with FermiNet."""
# Get default options.
cfg = base_config.default()
cfg.system.electrons = (7, 7)
# A ghost atom at the origin defines one-electron coordinate system.
# Element 'X' is a dummy nucleus with zero charge
cfg.system.molecule = [system.Atom("X", (0., 0., 0.))]
# Pretraining is not currently implemented for systems in PBC
cfg.pretrain.method = None

lattice = _sc_lattice_vecs(1.0, sum(cfg.system.electrons))
kpoints = envelopes.make_kpoints(lattice, cfg.system.electrons)

cfg.system.make_local_energy_fn = "ferminet.pbc.hamiltonian.local_energy"
cfg.system.make_local_energy_kwargs = {"lattice": lattice, "heg": True}
cfg.network.make_feature_layer_fn = \
gcassella marked this conversation as resolved.
Show resolved Hide resolved
"ferminet.pbc.feature_layer.make_pbc_feature_layer"
gcassella marked this conversation as resolved.
Show resolved Hide resolved
cfg.network.make_feature_layer_kwargs = {
"lattice": lattice,
"include_r_ae": False
}
cfg.network.make_envelope_fn = \
gcassella marked this conversation as resolved.
Show resolved Hide resolved
"ferminet.pbc.envelopes.make_multiwave_envelope"
cfg.network.make_envelope_kwargs = {"kpoints": kpoints}
cfg.network.full_det = True
return cfg
56 changes: 28 additions & 28 deletions ferminet/networks.py
Expand Up @@ -116,6 +116,23 @@ class FeatureLayerType(enum.Enum):
STANDARD = enum.auto()


class MakeFeatureLayer(Protocol):

def __call__(self,
charges: jnp.ndarray,
nspins: Sequence[int],
ndim: int,
**kwargs: Any) -> FeatureLayer:
"""Builds the FeatureLayer object.

Args:
charges: (natom) array of atom nuclear charges.
nspins: tuple of the number of spin-up and spin-down electrons.
ndim: dimension of the system.
**kwargs: additional kwargs to use for creating the specific FeatureLayer.
"""


## Network settings ##


Expand All @@ -138,8 +155,6 @@ class FermiNetOptions:
block-diagonalise determinants into spin channels.
bias_orbitals: If true, include a bias in the final linear layer to shape
the outputs into orbitals.
envelope_label: Envelope to use to impose orbitals go to zero at infinity.
See envelopes module.
envelope: Envelope object to create and apply the multiplicative envelope.
feature_layer: Feature object to create and apply the input features for the
one- and two-electron layers.
Expand All @@ -150,11 +165,10 @@ class FermiNetOptions:
determinants: int = 16
full_det: bool = True
bias_orbitals: bool = False
envelope_label: envelopes.EnvelopeLabel = envelopes.EnvelopeLabel.ISOTROPIC
envelope: envelopes.Envelope = attr.ib(
default=attr.Factory(
lambda self: envelopes.get_envelope(self.envelope_label),
takes_self=True))
lambda: envelopes.make_isotropic_envelope(),
takes_self=False))
feature_layer: FeatureLayer = attr.ib(
default=attr.Factory(
lambda self: make_ferminet_features(ndim=self.ndim), takes_self=True))
Expand Down Expand Up @@ -344,8 +358,7 @@ def init_fermi_net_params(
PyTree of network parameters. Spin-dependent parameters are only created for
spin channels containing at least one particle.
"""
if options.envelope_label in (envelopes.EnvelopeLabel.STO,
envelopes.EnvelopeLabel.STO_POLY):
if options.envelope.apply_type == envelopes.EnvelopeType.PRE_ORBITAL:
if options.bias_orbitals:
raise ValueError('Cannot bias orbitals w/STO envelope.')
if hf_solution is not None:
Expand Down Expand Up @@ -700,8 +713,8 @@ def make_fermi_net(
nspins: Tuple[int, int],
charges: jnp.ndarray,
*,
envelope: Union[str, envelopes.EnvelopeLabel] = 'isotropic',
feature_layer: Union[str, FeatureLayerType] = FeatureLayerType.STANDARD,
envelope: Optional[envelopes.Envelope] = None,
feature_layer: Optional[FeatureLayer] = None,
bias_orbitals: bool = False,
use_last_layer: bool = False,
hf_solution: Optional[scf.Scf] = None,
Expand Down Expand Up @@ -742,33 +755,20 @@ def make_fermi_net(
"""
del after_determinants

if isinstance(envelope, str):
envelope = envelope.upper().replace('-', '_')
envelope_label = envelopes.EnvelopeLabel[envelope]
else:
# support naming scheme used in config files.
envelope_label = envelope
if envelope_label == envelopes.EnvelopeLabel.EXACT_CUSP:
envelope_kwargs = {'nspins': nspins, 'charges': charges}
else:
envelope_kwargs = {}
if not envelope:
envelope = envelopes.make_isotropic_envelope()

if isinstance(feature_layer, str):
feature_layer = FeatureLayerType[feature_layer.upper()]
if feature_layer == FeatureLayerType.STANDARD:
feature_layer_fns = make_ferminet_features(charges, nspins)
else:
raise ValueError(f'Unsupported feature layer type: {feature_layer}')
if not feature_layer:
feature_layer = make_ferminet_features(charges, nspins)

options = FermiNetOptions(
hidden_dims=hidden_dims,
use_last_layer=use_last_layer,
determinants=determinants,
full_det=full_det,
bias_orbitals=bias_orbitals,
envelope_label=envelope_label,
envelope=envelopes.get_envelope(envelope_label, **envelope_kwargs),
feature_layer=feature_layer_fns,
envelope=envelope,
feature_layer=feature_layer,
)

init = functools.partial(
Expand Down
13 changes: 13 additions & 0 deletions ferminet/pbc/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
119 changes: 119 additions & 0 deletions ferminet/pbc/envelopes.py
@@ -0,0 +1,119 @@
# Copyright 2022 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

"""Multiplicative envelopes appropriate for periodic boundary conditions.

See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D.,
Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions
with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183."""

from itertools import product
gcassella marked this conversation as resolved.
Show resolved Hide resolved
from ferminet import envelopes
from ferminet.utils import scf
from typing import Mapping, Optional, Sequence, Tuple
gcassella marked this conversation as resolved.
Show resolved Hide resolved

import jax.numpy as jnp


def make_multiwave_envelope(kpoints: jnp.ndarray) -> envelopes.Envelope:
"""Returns an oscillatory envelope.

Envelope consists of a sum of truncated 3D Fourier series, one centered on
each atom, with Fourier frequencies given by kpoints:

sigma_{2i}*cos(kpoints_i.r_{ae}) + sigma_{2i+1}*sin(kpoints_i.r_{ae})

Initialization sets the coefficient of the first term in each
series to 1, and all other coefficients to 0. This corresponds to the
cosine of the first entry in kpoints. If this is [0, 0, 0], the envelope
will evaluate to unity at the beginning of training.

Args:
kpoints: Reciprocal lattice vectors of terms included in the Fourier
series. Shape (nkpoints, ndim) (Note that ndim=3 is currently
a hard-coded default).

Returns:
An instance of ferminet.envelopes.Envelope with apply_type
envelopes.EnvelopeType.PRE_DETERMINANT
"""

def init(natom: int,
output_dims: Sequence[int],
hf: Optional[scf.Scf] = None,
ndim: int = 3) -> Sequence[Mapping[str, jnp.ndarray]]:
"""See ferminet.envelopes.EnvelopeInit."""
del hf, natom, ndim # unused
params = []
nk = kpoints.shape[0]
for output_dim in output_dims:
params.append({'sigma': jnp.zeros((2 * nk, output_dim))})
params[-1]['sigma'] = params[-1]['sigma'].at[0, :].set(1.0)
return params

def apply(*, ae: jnp.ndarray, r_ae: jnp.ndarray, r_ee: jnp.ndarray,
sigma: jnp.ndarray) -> jnp.ndarray:
"""See ferminet.envelopes.EnvelopeApply."""
del r_ae, r_ee # unused
phase_coords = ae @ kpoints.T
waves = jnp.concatenate((jnp.cos(phase_coords), jnp.sin(phase_coords)),
axis=2)
env = waves @ (sigma**2.0)
return jnp.sum(env, axis=1)

return envelopes.Envelope(envelopes.EnvelopeType.PRE_DETERMINANT, init, apply)


def make_kpoints(lattice: jnp.ndarray,
spins: Tuple[int, int],
min_kpoints: Optional[int] = None) -> jnp.ndarray:
"""Generates an array of reciprocal lattice vectors.

Args:
lattice: Matrix whose columns are the primitive lattice vectors of the
system, shape (ndim, ndim). (Note that ndim=3 is currently
a hard-coded default).
spins: Tuple of the number of spin-up and spin-down electrons.
min_kpoints: If specified, the number of kpoints which must be included in
the output. The number of kpoints returned will be the
first filled shell which is larger than this value. Defaults to None,
which results in min_kpoints == sum(spins).

Raises:
ValueError: Fewer kpoints requested by min_kpoints than number of
electrons in the system.

Returns:
jnp.ndarray, shape (nkpoints, ndim), an array of reciprocal lattice
vectors sorted in ascending order according to length.
"""
rec_lattice = 2 * jnp.pi * jnp.linalg.inv(lattice)
# Calculate required no. of k points
if min_kpoints is None:
min_kpoints = sum(spins)
elif min_kpoints < sum(spins):
raise ValueError(
'Number of kpoints must be equal or greater than number of electrons')

dk = 1 + 1e-5
# Generate ordinals of the lowest min_kpoints kpoints
max_k = int(jnp.ceil(min_kpoints * dk)**(1 / 3.))
ordinals = sorted(range(-max_k, max_k+1), key=abs)
ordinals = jnp.asarray(list(product(ordinals, repeat=3)))

kpoints = ordinals @ rec_lattice.T
kpoints = jnp.asarray(sorted(kpoints, key=jnp.linalg.norm))
k_norms = jnp.linalg.norm(kpoints, axis=1)

return kpoints[k_norms <= k_norms[min_kpoints - 1] * dk]