diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 2ddaf82..13e8fbd 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -44,7 +44,7 @@ jobs: flake8 . - name: Lint with pylint run: | - pylint ferminet + pylint --fail-under 9.75 ferminet - name: Type check with pytype run: | pytype ferminet diff --git a/ferminet/base_config.py b/ferminet/base_config.py index e12cfcc..a475ad0 100644 --- a/ferminet/base_config.py +++ b/ferminet/base_config.py @@ -171,13 +171,25 @@ def default() -> ml_collections.ConfigDict: 'determinants': 16, 'after_determinants': (1,), }, - 'envelope_type': 'full', # Where does the envelope go? 'bias_orbitals': False, # include bias in last layer to orbitals # Whether to use the last layer of the two-electron stream of the # DetNet 'use_last_layer': False, # If true, determinants are dense rather than block-sparse 'full_det': True, + # String set to module.make_feature_layer, where make_feature_layer is + # callable (type: MakeFeatureLayer) which creates an object with + # member functions init() and apply() that initialize parameters + # for custom input features and modify raw input features, + # respectively. Module is the absolute module containing + # make_feature_layer. + # If not set, networks.make_ferminet_features is used. + 'make_feature_layer_fn': '', + # Additional kwargs to pass into make_local_energy_fn. + 'make_feature_layer_kwargs': {}, + # Same structure as make_feature_layer + 'make_envelope_fn': '', + 'make_envelope_kwargs': {} }, 'debug': { # Check optimizer state, parameters and loss and raise an exception if diff --git a/ferminet/configs/heg.py b/ferminet/configs/heg.py new file mode 100644 index 0000000..b756919 --- /dev/null +++ b/ferminet/configs/heg.py @@ -0,0 +1,57 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Unpolarised 14 electron simple cubic homogeneous electron gas.""" + +from ferminet import base_config +from ferminet.utils import system +from ferminet.pbc import envelopes + +import numpy as np + + +def _sc_lattice_vecs(rs: float, nelec: int) -> np.ndarray: + """Returns simple cubic lattice vectors with Wigner-Seitz radius rs.""" + volume = (4 / 3) * np.pi * (rs**3) * nelec + length = volume**(1 / 3) + return length * np.eye(3) + + +def get_config(): + """Returns config for running unpolarised 14 electron gas with FermiNet.""" + # Get default options. + cfg = base_config.default() + cfg.system.electrons = (7, 7) + # A ghost atom at the origin defines one-electron coordinate system. + # Element 'X' is a dummy nucleus with zero charge + cfg.system.molecule = [system.Atom("X", (0., 0., 0.))] + # Pretraining is not currently implemented for systems in PBC + cfg.pretrain.method = None + + lattice = _sc_lattice_vecs(1.0, sum(cfg.system.electrons)) + kpoints = envelopes.make_kpoints(lattice, cfg.system.electrons) + + cfg.system.make_local_energy_fn = "ferminet.pbc.hamiltonian.local_energy" + cfg.system.make_local_energy_kwargs = {"lattice": lattice, "heg": True} + cfg.network.make_feature_layer_fn = ( + "ferminet.pbc.feature_layer.make_pbc_feature_layer") + cfg.network.make_feature_layer_kwargs = { + "lattice": lattice, + "include_r_ae": False + } + cfg.network.make_envelope_fn = ( + "ferminet.pbc.envelopes.make_multiwave_envelope") + cfg.network.make_envelope_kwargs = {"kpoints": kpoints} + cfg.network.full_det = True + return cfg diff --git a/ferminet/networks.py b/ferminet/networks.py index 6118c66..262cedf 100644 --- a/ferminet/networks.py +++ b/ferminet/networks.py @@ -116,6 +116,23 @@ class FeatureLayerType(enum.Enum): STANDARD = enum.auto() +class MakeFeatureLayer(Protocol): + + def __call__(self, + charges: jnp.ndarray, + nspins: Sequence[int], + ndim: int, + **kwargs: Any) -> FeatureLayer: + """Builds the FeatureLayer object. + + Args: + charges: (natom) array of atom nuclear charges. + nspins: tuple of the number of spin-up and spin-down electrons. + ndim: dimension of the system. + **kwargs: additional kwargs to use for creating the specific FeatureLayer. + """ + + ## Network settings ## @@ -138,8 +155,6 @@ class FermiNetOptions: block-diagonalise determinants into spin channels. bias_orbitals: If true, include a bias in the final linear layer to shape the outputs into orbitals. - envelope_label: Envelope to use to impose orbitals go to zero at infinity. - See envelopes module. envelope: Envelope object to create and apply the multiplicative envelope. feature_layer: Feature object to create and apply the input features for the one- and two-electron layers. @@ -150,11 +165,10 @@ class FermiNetOptions: determinants: int = 16 full_det: bool = True bias_orbitals: bool = False - envelope_label: envelopes.EnvelopeLabel = envelopes.EnvelopeLabel.ISOTROPIC envelope: envelopes.Envelope = attr.ib( default=attr.Factory( - lambda self: envelopes.get_envelope(self.envelope_label), - takes_self=True)) + lambda: envelopes.make_isotropic_envelope(), + takes_self=False)) feature_layer: FeatureLayer = attr.ib( default=attr.Factory( lambda self: make_ferminet_features(ndim=self.ndim), takes_self=True)) @@ -344,8 +358,7 @@ def init_fermi_net_params( PyTree of network parameters. Spin-dependent parameters are only created for spin channels containing at least one particle. """ - if options.envelope_label in (envelopes.EnvelopeLabel.STO, - envelopes.EnvelopeLabel.STO_POLY): + if options.envelope.apply_type == envelopes.EnvelopeType.PRE_ORBITAL: if options.bias_orbitals: raise ValueError('Cannot bias orbitals w/STO envelope.') if hf_solution is not None: @@ -700,8 +713,8 @@ def make_fermi_net( nspins: Tuple[int, int], charges: jnp.ndarray, *, - envelope: Union[str, envelopes.EnvelopeLabel] = 'isotropic', - feature_layer: Union[str, FeatureLayerType] = FeatureLayerType.STANDARD, + envelope: Optional[envelopes.Envelope] = None, + feature_layer: Optional[FeatureLayer] = None, bias_orbitals: bool = False, use_last_layer: bool = False, hf_solution: Optional[scf.Scf] = None, @@ -742,23 +755,11 @@ def make_fermi_net( """ del after_determinants - if isinstance(envelope, str): - envelope = envelope.upper().replace('-', '_') - envelope_label = envelopes.EnvelopeLabel[envelope] - else: - # support naming scheme used in config files. - envelope_label = envelope - if envelope_label == envelopes.EnvelopeLabel.EXACT_CUSP: - envelope_kwargs = {'nspins': nspins, 'charges': charges} - else: - envelope_kwargs = {} + if not envelope: + envelope = envelopes.make_isotropic_envelope() - if isinstance(feature_layer, str): - feature_layer = FeatureLayerType[feature_layer.upper()] - if feature_layer == FeatureLayerType.STANDARD: - feature_layer_fns = make_ferminet_features(charges, nspins) - else: - raise ValueError(f'Unsupported feature layer type: {feature_layer}') + if not feature_layer: + feature_layer = make_ferminet_features(charges, nspins) options = FermiNetOptions( hidden_dims=hidden_dims, @@ -766,9 +767,8 @@ def make_fermi_net( determinants=determinants, full_det=full_det, bias_orbitals=bias_orbitals, - envelope_label=envelope_label, - envelope=envelopes.get_envelope(envelope_label, **envelope_kwargs), - feature_layer=feature_layer_fns, + envelope=envelope, + feature_layer=feature_layer, ) init = functools.partial( diff --git a/ferminet/pbc/__init__.py b/ferminet/pbc/__init__.py new file mode 100644 index 0000000..b5730d3 --- /dev/null +++ b/ferminet/pbc/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License diff --git a/ferminet/pbc/envelopes.py b/ferminet/pbc/envelopes.py new file mode 100644 index 0000000..466e1a7 --- /dev/null +++ b/ferminet/pbc/envelopes.py @@ -0,0 +1,121 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Multiplicative envelopes appropriate for periodic boundary conditions. + +See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., +Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183. +""" + +import itertools +from typing import Mapping, Optional, Sequence, Tuple + +from ferminet import envelopes +from ferminet.utils import scf + +import jax.numpy as jnp + + +def make_multiwave_envelope(kpoints: jnp.ndarray) -> envelopes.Envelope: + """Returns an oscillatory envelope. + + Envelope consists of a sum of truncated 3D Fourier series, one centered on + each atom, with Fourier frequencies given by kpoints: + + sigma_{2i}*cos(kpoints_i.r_{ae}) + sigma_{2i+1}*sin(kpoints_i.r_{ae}) + + Initialization sets the coefficient of the first term in each + series to 1, and all other coefficients to 0. This corresponds to the + cosine of the first entry in kpoints. If this is [0, 0, 0], the envelope + will evaluate to unity at the beginning of training. + + Args: + kpoints: Reciprocal lattice vectors of terms included in the Fourier + series. Shape (nkpoints, ndim) (Note that ndim=3 is currently + a hard-coded default). + + Returns: + An instance of ferminet.envelopes.Envelope with apply_type + envelopes.EnvelopeType.PRE_DETERMINANT + """ + + def init(natom: int, + output_dims: Sequence[int], + hf: Optional[scf.Scf] = None, + ndim: int = 3) -> Sequence[Mapping[str, jnp.ndarray]]: + """See ferminet.envelopes.EnvelopeInit.""" + del hf, natom, ndim # unused + params = [] + nk = kpoints.shape[0] + for output_dim in output_dims: + params.append({'sigma': jnp.zeros((2 * nk, output_dim))}) + params[-1]['sigma'] = params[-1]['sigma'].at[0, :].set(1.0) + return params + + def apply(*, ae: jnp.ndarray, r_ae: jnp.ndarray, r_ee: jnp.ndarray, + sigma: jnp.ndarray) -> jnp.ndarray: + """See ferminet.envelopes.EnvelopeApply.""" + del r_ae, r_ee # unused + phase_coords = ae @ kpoints.T + waves = jnp.concatenate((jnp.cos(phase_coords), jnp.sin(phase_coords)), + axis=2) + env = waves @ (sigma**2.0) + return jnp.sum(env, axis=1) + + return envelopes.Envelope(envelopes.EnvelopeType.PRE_DETERMINANT, init, apply) + + +def make_kpoints(lattice: jnp.ndarray, + spins: Tuple[int, int], + min_kpoints: Optional[int] = None) -> jnp.ndarray: + """Generates an array of reciprocal lattice vectors. + + Args: + lattice: Matrix whose columns are the primitive lattice vectors of the + system, shape (ndim, ndim). (Note that ndim=3 is currently + a hard-coded default). + spins: Tuple of the number of spin-up and spin-down electrons. + min_kpoints: If specified, the number of kpoints which must be included in + the output. The number of kpoints returned will be the + first filled shell which is larger than this value. Defaults to None, + which results in min_kpoints == sum(spins). + + Raises: + ValueError: Fewer kpoints requested by min_kpoints than number of + electrons in the system. + + Returns: + jnp.ndarray, shape (nkpoints, ndim), an array of reciprocal lattice + vectors sorted in ascending order according to length. + """ + rec_lattice = 2 * jnp.pi * jnp.linalg.inv(lattice) + # Calculate required no. of k points + if min_kpoints is None: + min_kpoints = sum(spins) + elif min_kpoints < sum(spins): + raise ValueError( + 'Number of kpoints must be equal or greater than number of electrons') + + dk = 1 + 1e-5 + # Generate ordinals of the lowest min_kpoints kpoints + max_k = int(jnp.ceil(min_kpoints * dk)**(1 / 3.)) + ordinals = sorted(range(-max_k, max_k+1), key=abs) + ordinals = jnp.asarray(list(itertools.product(ordinals, repeat=3))) + + kpoints = ordinals @ rec_lattice.T + kpoints = jnp.asarray(sorted(kpoints, key=jnp.linalg.norm)) + k_norms = jnp.linalg.norm(kpoints, axis=1) + + return kpoints[k_norms <= k_norms[min_kpoints - 1] * dk] diff --git a/ferminet/pbc/feature_layer.py b/ferminet/pbc/feature_layer.py new file mode 100644 index 0000000..6ca5ec9 --- /dev/null +++ b/ferminet/pbc/feature_layer.py @@ -0,0 +1,94 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Feature layer for periodic boundary conditions. + +See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., +Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183. +""" + +from typing import Optional, Tuple + +import jax.numpy as jnp + +from ferminet import networks + + +def make_pbc_feature_layer(charges: Optional[jnp.ndarray] = None, + nspins: Optional[Tuple[int, ...]] = None, + ndim: int = 3, + lattice: Optional[jnp.ndarray] = None, + include_r_ae: bool = True) -> networks.FeatureLayer: + """Returns the init and apply functions for periodic features. + + Args: + charges: (natom) array of atom nuclear charges. + nspins: tuple of the number of spin-up and spin-down electrons. + ndim: dimension of the system. + lattice: Matrix whose columns are the primitive lattice vectors of the + system, shape (ndim, ndim). + include_r_ae: Flag to enable electron-atom distance features. Set to False + to avoid cusps with ghost atoms in, e.g., homogeneous electron gas. + """ + + del charges, nspins + + if lattice is None: + lattice = jnp.eye(ndim) + + # Calculate reciprocal vectors, factor 2pi omitted + reciprocal_vecs = jnp.linalg.inv(lattice) + lattice_metric = lattice.T @ lattice + + def periodic_norm(vec, metric): + a = (1 - jnp.cos(2 * jnp.pi * vec)) + b = jnp.sin(2 * jnp.pi * vec) + # i,j = nelectron, natom for ae + cos_term = jnp.einsum('ijm,mn,ijn->ij', a, metric, a) + sin_term = jnp.einsum('ijm,mn,ijn->ij', b, metric, b) + return (1 / (2 * jnp.pi)) * jnp.sqrt(cos_term + sin_term) + + def init() -> Tuple[Tuple[int, int], networks.Param]: + if include_r_ae: + return (2 * ndim, 2 * ndim + 1), {} + else: + return (2 * ndim + 1, 2 * ndim + 1), {} + + def apply(ae, r_ae, ee, r_ee) -> Tuple[jnp.ndarray, jnp.ndarray]: + # One e features in phase coordinates, (s_ae)_i = k_i . ae + s_ae = jnp.einsum('il,jkl->jki', reciprocal_vecs, ae) + # Two e features in phase coordinates + s_ee = jnp.einsum('il,jkl->jki', reciprocal_vecs, ee) + # Periodized features + ae = jnp.concatenate( + (jnp.sin(2 * jnp.pi * s_ae), jnp.cos(2 * jnp.pi * s_ae)), axis=-1) + ee = jnp.concatenate( + (jnp.sin(2 * jnp.pi * s_ee), jnp.cos(2 * jnp.pi * s_ee)), axis=-1) + # Distance features defined on orthonormal projections + r_ae = periodic_norm(s_ae, lattice_metric) + # Don't take gradients through |0| + n = ee.shape[0] + s_ee += jnp.eye(n)[..., None] + r_ee = periodic_norm(s_ee, lattice_metric) * (1.0 - jnp.eye(n)) + + if include_r_ae: + ae_features = ae + else: + ae_features = jnp.concatenate((r_ae[..., None], ae), axis=2) + ae_features = jnp.reshape(ae_features, [jnp.shape(ae_features)[0], -1]) + ee_features = jnp.concatenate((r_ee[..., None], ee), axis=2) + return ae_features, ee_features + + return networks.FeatureLayer(init=init, apply=apply) diff --git a/ferminet/pbc/hamiltonian.py b/ferminet/pbc/hamiltonian.py new file mode 100644 index 0000000..4aa9333 --- /dev/null +++ b/ferminet/pbc/hamiltonian.py @@ -0,0 +1,201 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Ewald summation of Coulomb Hamiltonian in periodic boundary conditions. + +See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., +Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183. +""" + +import itertools +from typing import Callable, Sequence + +from ferminet import hamiltonian +from ferminet import networks + +import chex +import jax +import jax.numpy as jnp + + +def make_ewald_potential( + lattice: jnp.ndarray, + atoms: jnp.ndarray, + charges: jnp.ndarray, + truncation_limit: int = 5, + include_heg_background: bool = True +) -> Callable[[jnp.ndarray, jnp.ndarray], float]: + """Creates a function to evaluate infinite Coulomb sum for periodic lattice. + + Args: + lattice: Shape (3, 3). Matrix whose columns are the primitive lattice + vectors. + atoms: Shape (natoms, ndim). Positions of the atoms. + charges: Shape (natoms). Nuclear charges of the atoms. + truncation_limit: Integer. Half side length of cube of nearest neighbours + to primitive cell which are summed over in evaluation of Ewald sum. + Must be large enough to achieve convergence for the real and reciprocal + space sums. + include_heg_background: bool. When True, includes cell-neutralizing + background term for homogeneous electron gas. + + Returns: + Callable with signature f(ae, ee), where (ae, ee) are atom-electon and + electron-electron displacement vectors respectively, which evaluates the + Coulomb sum for the periodic lattice via the Ewald method. + """ + rec = 2 * jnp.pi * jnp.linalg.inv(lattice) + volume = jnp.abs(jnp.linalg.det(lattice)) + # the factor gamma tunes the width of the summands in real / reciprocal space + # and this value is chosen to optimize the convergence trade-off between the + # two sums. See CASINO QMC manual. + gamma = (2.8 / volume**(1 / 3))**2 + ordinals = sorted(range(-truncation_limit, truncation_limit + 1), key=abs) + ordinals = jnp.array(list(itertools.product(ordinals, repeat=3))) + lat_vectors = jnp.einsum('kj,ij->ik', lattice, ordinals) + rec_vectors = jnp.einsum('kj,ij->ik', rec, ordinals[1:]) + rec_vec_square = jnp.einsum('ij,ij->i', rec_vectors, rec_vectors) + lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1) + + def real_space_ewald(separation: jnp.ndarray): + """Real-space Ewald potential between charges seperated by separation.""" + displacements = jnp.linalg.norm( + separation - lat_vectors, axis=-1) # |r - R| + return jnp.sum( + jax.scipy.special.erfc(gamma**0.5 * displacements) / displacements) + + def recp_space_ewald(separation: jnp.ndarray): + """Returns reciprocal-space Ewald potential between charges.""" + return (4 * jnp.pi / volume) * jnp.sum( + jnp.exp(1.0j * jnp.dot(rec_vectors, separation)) * + jnp.exp(-rec_vec_square / (4 * gamma)) / rec_vec_square) + + def ewald_sum(separation: jnp.ndarray): + """Evaluates combined real and reciprocal space Ewald potential.""" + return (real_space_ewald(separation) + recp_space_ewald(separation) - + jnp.pi / (volume * gamma)) + + madelung_const = (jnp.sum( + jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / lat_vec_norm) - + 2 * gamma**0.5 / jnp.pi**0.5) + madelung_const += ((4*jnp.pi / volume) * + jnp.sum(jnp.exp(-rec_vec_square/(4*gamma))/rec_vec_square) - + jnp.pi / (volume*gamma)) + + batch_ewald_sum = jax.vmap(ewald_sum, in_axes=(0,)) + + def atom_electron_potential(ae: jnp.ndarray): + """Evaluates periodic atom-electron potential.""" + nelec = ae.shape[0] + ae = jnp.reshape(ae, [-1, 3]) # flatten electronxatom axis + # calculate potential for each ae pair + ewald = batch_ewald_sum(ae) - madelung_const + return jnp.sum(-jnp.tile(charges, nelec) * ewald) + + def electron_electron_potential(ee: jnp.ndarray): + """Evaluates periodic electron-electron potential.""" + nelec = ee.shape[0] + ee = jnp.reshape(ee, [-1, 3]) + if include_heg_background: + ewald = batch_ewald_sum(ee) + else: + ewald = batch_ewald_sum(ee) - madelung_const + ewald = jnp.reshape(ewald, [nelec, nelec]) + ewald = ewald.at[jnp.diag_indices(nelec)].set(0.0) + if include_heg_background: + return 0.5 * jnp.sum(ewald) + 0.5 * nelec * madelung_const + else: + return 0.5 * jnp.sum(ewald) + + # Atom-atom potential + natom = atoms.shape[0] + if natom > 1: + aa = jnp.reshape(atoms, [1, -1, 3]) - jnp.reshape(atoms, [-1, 1, 3]) + aa = jnp.reshape(aa, [-1, 3]) + chargeprods = (charges[..., None] @ charges[..., None].T).flatten() + ewald = batch_ewald_sum(aa) - madelung_const + ewald = jnp.reshape(ewald, [natom, natom]) + ewald = ewald.at[jnp.diag_indices(natom)].set(0.0) + ewald = ewald.flatten() + atom_atom_potential = 0.5 * jnp.sum(chargeprods * ewald) + else: + atom_atom_potential = 0.0 + + def potential(ae: jnp.ndarray, ee: jnp.ndarray): + """Accumulates atom-electron, atom-atom, and electron-electron potential.""" + # Reduce vectors into first unit cell - Ewald summation + # is only guaranteed to converge close to the origin + phase_ae = jnp.einsum('il,jkl->jki', rec / (2 * jnp.pi), ae) + phase_ee = jnp.einsum('il,jkl->jki', rec / (2 * jnp.pi), ee) + phase_prim_ae = phase_ae % 1 + phase_prim_ee = phase_ee % 1 + prim_ae = jnp.einsum('il,jkl->jki', lattice, phase_prim_ae) + prim_ee = jnp.einsum('il,jkl->jki', lattice, phase_prim_ee) + return jnp.real( + atom_electron_potential(prim_ae) + + electron_electron_potential(prim_ee) + atom_atom_potential) + + return potential + + +def local_energy(f: networks.FermiNetLike, + atoms: jnp.ndarray, + charges: jnp.ndarray, + nspins: Sequence[int], + use_scan: bool = False, + lattice: jnp.ndarray = jnp.eye(3), + heg: bool = True, + convergence_radius: int = 5) -> hamiltonian.LocalEnergy: + """Creates the local energy function in periodic boundary conditions. + + Args: + f: Callable which returns the sign and log of the magnitude of the + wavefunction given the network parameters and configurations data. + atoms: Shape (natoms, ndim). Positions of the atoms. + charges: Shape (natoms). Nuclear charges of the atoms. + nspins: Number of particles of each spin. + use_scan: Whether to use a `lax.scan` for computing the laplacian. + lattice: Shape (ndim, ndim). Matrix of lattice vectors. + heg: bool. Flag to enable features specific to the electron gas. + convergence_radius: int. Radius of cluster summed over by Ewald sums. + + + Returns: + Callable with signature e_l(params, key, data) which evaluates the local + energy of the wavefunction given the parameters params, RNG state key, + and a single MCMC configuration in data. + """ + del nspins + log_abs_f = lambda *args, **kwargs: f(*args, **kwargs)[1] + ke = hamiltonian.local_kinetic_energy(log_abs_f, use_scan=use_scan) + potential_energy = make_ewald_potential(lattice, atoms, charges, + convergence_radius, heg) + + def _e_l(params: networks.ParamTree, key: chex.PRNGKey, + data: jnp.ndarray) -> jnp.ndarray: + """Returns the total energy. + + Args: + params: network parameters. + key: RNG state. + data: MCMC configuration. + """ + del key # unused + ae, ee, _, _ = networks.construct_input_features(data, atoms) + potential = potential_energy(ae, ee) + kinetic = ke(params, data) + return potential + kinetic + + return _e_l diff --git a/ferminet/pbc/tests/__init__.py b/ferminet/pbc/tests/__init__.py new file mode 100644 index 0000000..b5730d3 --- /dev/null +++ b/ferminet/pbc/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License diff --git a/ferminet/pbc/tests/features_test.py b/ferminet/pbc/tests/features_test.py new file mode 100644 index 0000000..1f4db78 --- /dev/null +++ b/ferminet/pbc/tests/features_test.py @@ -0,0 +1,88 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Tests for ferminet.pbc.feature_layer""" + +import jax +import jax.numpy as jnp +from absl.testing import absltest, parameterized +from ferminet import networks +from ferminet.pbc import feature_layer as pbc_feature_layer +import numpy as np + + +class FeatureLayerTest(parameterized.TestCase): + + @parameterized.parameters([True, False]) + def test_shape(self, heg): + """ Assert that output shape of apply() matches what is expected by init() + """ + nspins = (6, 5) + atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) + natom = atoms.shape[0] + charges = jnp.asarray([2, 5, 7]) + key = jax.random.PRNGKey(42) + key, subkey = jax.random.split(key) + xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) + + feature_layer = pbc_feature_layer.make_pbc_feature_layer( + charges, nspins, 3, lattice=jnp.eye(3), include_r_ae=heg) + + dims, params = feature_layer.init() + ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) + + ae_features, ee_features = feature_layer.apply( + ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) + + assert dims[0] * natom == ae_features.shape[-1] + assert dims[1] == ee_features.shape[-1] + + def test_periodicity(self): + nspins = (6, 5) + atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) + charges = jnp.asarray([2, 5, 7]) + key = jax.random.PRNGKey(42) + key, subkey = jax.random.split(key) + xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) + + feature_layer = pbc_feature_layer.make_pbc_feature_layer( + charges, nspins, 3, lattice=jnp.eye(3), include_r_ae=False) + + _, params = feature_layer.init() + ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) + + ae_features_1, ee_features_1 = feature_layer.apply( + ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) + + # Select random electron coordinate to displace by a random lattice vector + key, subkey = jax.random.split(key) + e_idx = jax.random.randint(subkey, (1,), 0, xs.shape[0]) + key, subkey = jax.random.split(key) + randvec = jax.random.randint(subkey, (3,), 0, 100).astype(jnp.float32) + xs = xs.at[e_idx].add(randvec) + + ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) + + ae_features_2, ee_features_2 = feature_layer.apply( + ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) + + atol, rtol = 4.e-3, 4.e-3 + np.testing.assert_allclose( + ae_features_1, ae_features_2, atol=atol, rtol=rtol) + np.testing.assert_allclose( + ee_features_1, ee_features_2, atol=atol, rtol=rtol) + + +if __name__ == '__main__': + absltest.main() diff --git a/ferminet/pbc/tests/hamiltonian_test.py b/ferminet/pbc/tests/hamiltonian_test.py new file mode 100644 index 0000000..43c12fa --- /dev/null +++ b/ferminet/pbc/tests/hamiltonian_test.py @@ -0,0 +1,87 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""Tests for ferminet.pbc.hamiltonian""" + +import jax +import jax.numpy as jnp +from absl.testing import absltest, parameterized +from ferminet import networks +from ferminet import base_config +from ferminet.pbc import feature_layer as pbc_feature_layer +from ferminet.pbc import hamiltonian +from ferminet.pbc import envelopes +import numpy as np + + +class PbcHamiltonianTest(parameterized.TestCase): + + def test_periodicity(self): + cfg = base_config.default() + + nspins = (6, 5) + atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) + charges = jnp.asarray([2, 5, 7]) + key = jax.random.PRNGKey(42) + key, subkey = jax.random.split(key) + xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) + + feature_layer = pbc_feature_layer.make_pbc_feature_layer( + charges, nspins, ndim=3, lattice=jnp.eye(3), include_r_ae=False) + + kpoints = envelopes.make_kpoints(jnp.eye(3), nspins) + + network_init, signed_network, _ = networks.make_fermi_net( + atoms, + nspins, + charges, + envelope=envelopes.make_multiwave_envelope(kpoints), + feature_layer=feature_layer, + bias_orbitals=cfg.network.bias_orbitals, + use_last_layer=cfg.network.use_last_layer, + hf_solution=None, + full_det=cfg.network.full_det, + **cfg.network.detnet) + + key, subkey = jax.random.split(key) + params = network_init(subkey) + + local_energy = hamiltonian.local_energy( + f=signed_network, + atoms=atoms, + charges=charges, + nspins=nspins, + use_scan=False, + lattice=jnp.eye(3), + heg=False) + + key, subkey = jax.random.split(key) + e1 = local_energy(params, subkey, xs.flatten()) + + # Select random electron coordinate to displace by a random lattice vec + key, subkey = jax.random.split(key) + e_idx = jax.random.randint(subkey, (1,), 0, xs.shape[0]) + key, subkey = jax.random.split(key) + randvec = jax.random.randint(subkey, (3,), 0, 100).astype(jnp.float32) + xs = xs.at[e_idx].add(randvec) + + key, subkey = jax.random.split(key) + e2 = local_energy(params, subkey, xs.flatten()) + + atol, rtol = 4.e-3, 4.e-3 + np.testing.assert_allclose(e1, e2, atol=atol, rtol=rtol) + + +if __name__ == '__main__': + absltest.main() diff --git a/ferminet/tests/hamiltonian_test.py b/ferminet/tests/hamiltonian_test.py index 310f1e8..97f3444 100644 --- a/ferminet/tests/hamiltonian_test.py +++ b/ferminet/tests/hamiltonian_test.py @@ -154,8 +154,17 @@ def test_fermi_net_laplacian(self, full_det): cfg.network.full_det = full_det cfg.network.detnet.hidden_dims = ((8, 4),)*2 cfg.network.detnet.determinants = 2 + feature_layer = networks.make_ferminet_features( + charges, + cfg.system.electrons, + cfg.system.ndim, + ) network_init, signed_network, _ = networks.make_fermi_net( - atoms, nspins, charges, full_det=full_det, **cfg.network.detnet) + atoms, nspins, charges, + full_det=full_det, + feature_layer=feature_layer, + **cfg.network.detnet + ) network = lambda params, x: signed_network(params, x)[1] key = jax.random.PRNGKey(47) params = network_init(key) diff --git a/ferminet/tests/networks_test.py b/ferminet/tests/networks_test.py index 8d7d673..1667555 100644 --- a/ferminet/tests/networks_test.py +++ b/ferminet/tests/networks_test.py @@ -59,7 +59,7 @@ def _network_options(): # Key for each option and corresponding values to test. all_options = { 'vmap': [True, False], - 'envelope': list(envelopes.EnvelopeLabel), + 'envelope_label': list(envelopes.EnvelopeLabel), 'bias_orbitals': [True, False], 'full_det': [True, False], 'use_last_layer': [True, False], @@ -95,7 +95,6 @@ def test_antisymmetry(self, envelope_label, dtype): kwargs.update({'charges': charges, 'nspins': nspins}) options = networks.FermiNetOptions( hidden_dims=((16, 16), (16, 16)), - envelope_label=envelope_label, envelope=envelopes.get_envelope(envelope_label, **kwargs)) params = networks.init_fermi_net_params( @@ -196,8 +195,19 @@ def test_fermi_net(self, vmap, **network_options): atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) charges = jnp.asarray([2, 5, 7]) key = jax.random.PRNGKey(42) - + feature_layer = networks.make_ferminet_features( + charges, + nspins, + ndim=3, + ) + kwargs = {} + if network_options['envelope_label'] == envelopes.EnvelopeLabel.EXACT_CUSP: + kwargs.update({'charges': charges, 'nspins': nspins}) + network_options['envelope'] = envelopes.get_envelope( + network_options['envelope_label'], **kwargs) + del network_options['envelope_label'] init, fermi_net, _ = networks.make_fermi_net(atoms, nspins, charges, + feature_layer=feature_layer, **network_options) key, subkey = jax.random.split(key) @@ -211,9 +221,8 @@ def test_fermi_net(self, vmap, **network_options): expected_shape = () key, subkey = jax.random.split(key) - sto_envelopes = (envelopes.EnvelopeLabel.STO, - envelopes.EnvelopeLabel.STO_POLY) - if (network_options['envelope'] in sto_envelopes and + envelope = network_options['envelope'] + if (envelope.apply_type == envelopes.EnvelopeType.PRE_ORBITAL and network_options['bias_orbitals']): with self.assertRaises(ValueError): init(subkey) @@ -229,8 +238,13 @@ def test_spin_polarised_fermi_net(self, nspins, full_det): atoms = jnp.zeros(shape=(1, 3)) charges = jnp.ones(shape=1) key = jax.random.PRNGKey(42) + feature_layer = networks.make_ferminet_features( + charges, + nspins, + ndim=3, + ) init, fermi_net, _ = networks.make_fermi_net( - atoms, nspins, charges, full_det=full_det) + atoms, nspins, charges, feature_layer=feature_layer, full_det=full_det) key, subkey1, subkey2 = jax.random.split(key, num=3) params = init(subkey1) xs = jax.random.uniform(subkey2, shape=(sum(nspins) * 3,)) diff --git a/ferminet/train.py b/ferminet/train.py index b7cfb02..c12d27e 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -24,6 +24,7 @@ from ferminet import checkpoint from ferminet import constants from ferminet import curvature_tags_and_blocks +from ferminet import envelopes from ferminet import hamiltonian from ferminet import loss as qmc_loss_functions from ferminet import mcmc @@ -302,12 +303,39 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): ]) hf_solution = hartree_fock if cfg.pretrain.method == 'direct_init' else None + + if cfg.network.make_feature_layer_fn: + feature_layer_module, feature_layer_fn = ( + cfg.network.make_feature_layer_fn.rsplit('.', maxsplit=1)) + feature_layer_module = importlib.import_module(feature_layer_module) + make_feature_layer = getattr(feature_layer_module, feature_layer_fn) + feature_layer = make_feature_layer( + charges, + cfg.system.electrons, + cfg.system.ndim, + **cfg.network.make_feature_layer_kwargs) # type: networks.FeatureLayer + else: + feature_layer = networks.make_ferminet_features( + charges, + cfg.system.electrons, + cfg.system.ndim, + ) + + if cfg.network.make_envelope_fn: + envelope_module, envelope_fn = ( + cfg.network.make_envelope_fn.rsplit('.', maxsplit=1)) + envelope_module = importlib.import_module(envelope_module) + make_envelope = getattr(envelope_module, envelope_fn) + envelope = make_envelope(**cfg.network.make_envelope_kwargs) # type: envelopes.Envelope + else: + envelope = envelopes.make_isotropic_envelope() + network_init, signed_network, network_options = networks.make_fermi_net( atoms, nspins, charges, - envelope=cfg.network.envelope_type, - feature_layer=cfg.network.get('feature_layer', 'standard'), + envelope=envelope, + feature_layer=feature_layer, bias_orbitals=cfg.network.bias_orbitals, use_last_layer=cfg.network.use_last_layer, hf_solution=hf_solution, diff --git a/ferminet/utils/elements.py b/ferminet/utils/elements.py index 932e7c3..0448c8a 100644 --- a/ferminet/utils/elements.py +++ b/ferminet/utils/elements.py @@ -116,6 +116,7 @@ def nbeta(self) -> int: # [_element(s, n+1) for n, s in enumerate(symbols)] # where symbols is the list of chemical symbols of all elements. _ELEMENTS = ( + Element(symbol='X', atomic_number=0, period=0), Element(symbol='H', atomic_number=1, period=1), Element(symbol='He', atomic_number=2, period=1), Element(symbol='Li', atomic_number=3, period=2), diff --git a/ferminet/utils/tests/elements_test.py b/ferminet/utils/tests/elements_test.py index f84bce6..d17e84f 100644 --- a/ferminet/utils/tests/elements_test.py +++ b/ferminet/utils/tests/elements_test.py @@ -26,7 +26,9 @@ def test_elements(self): for n, element in elements.ATOMIC_NUMS.items(): self.assertEqual(n, element.atomic_number) self.assertEqual(elements.SYMBOLS[element.symbol], element) - if element.symbol in ['Li', 'Na', 'K', 'Rb', 'Cs', 'Fr']: + if element.symbol == 'X': + continue + elif element.symbol in ['Li', 'Na', 'K', 'Rb', 'Cs', 'Fr']: self.assertEqual(element.period, elements.ATOMIC_NUMS[n - 1].period + 1) elif element.symbol != 'H': self.assertEqual(element.period, elements.ATOMIC_NUMS[n - 1].period) @@ -80,7 +82,7 @@ def test_element_group_period(self, element, period, group, spin_config, def test_periods(self): self.assertLen(elements.ATOMIC_NUMS, sum(len(period) for period in elements.PERIODS.values())) - period_length = {1: 2, 2: 8, 3: 8, 4: 18, 5: 18, 6: 32, 7: 32} + period_length = {0: 1, 1: 2, 2: 8, 3: 8, 4: 18, 5: 18, 6: 32, 7: 32} for p, es in elements.PERIODS.items(): self.assertLen(es, period_length[p]) @@ -93,7 +95,7 @@ def test_groups(self): # Iterate over all elements in order of atomic number. Group should # increment monotonically (except for accommodating absence of d block and # presence of f block) and reset to 1 on the first element in each period. - for i in range(1, len(elements.ATOMIC_NUMS)+1): + for i in range(1, len(elements.ATOMIC_NUMS)): element = elements.ATOMIC_NUMS[i] if element.atomic_number in period_starts: prev_group = 0 @@ -126,7 +128,7 @@ def test_groups(self): # Check each group contains the expected number of elements. nelements_in_group = [0]*18 for element in elements.ATOMIC_NUMS.values(): - if element.group != -1: + if element.group != -1 and element.period != 0: nelements_in_group[element.group-1] += 1 self.assertListEqual(nelements_in_group, [7, 6] + [4]*10 + [6]*5 + [7]) diff --git a/pylintrc b/pylintrc index 26c23af..03174d6 100644 --- a/pylintrc +++ b/pylintrc @@ -156,12 +156,6 @@ disable=abstract-method, # mypackage.mymodule.MyReporterClass. output-format=text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - # Tells whether to display a full report or only the messages reports=no @@ -273,6 +267,7 @@ ignore-long-lines=(?x)( ^\s*(\#\ )??$| ^\s*(from\s+\S+\s+)?import\s+.+$| ^.*pytype:\sdisable=.*$| + ^.*(\#\s*)type:\s.*$| ^\s*""".*$| ^.*pylint:\sdisable=.*$) @@ -280,12 +275,6 @@ ignore-long-lines=(?x)( # else. single-line-if-stmt=yes -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check= - # Maximum number of lines in a module max-module-lines=99999