From 4cfdf1e388403e038de816fd718f4239b004efe1 Mon Sep 17 00:00:00 2001 From: Gino Cassella Date: Wed, 29 Jun 2022 12:46:55 +0100 Subject: [PATCH 1/6] Implement calculations in periodic boundary conditions Includes the feature layer, envelope, and Hamiltonian described in ```Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.``` Several changes to the configuration interface for feature layers and envelopes have been made. These are now specified by an interface similar to the local energy function. These changes, and the new periodic boundary conditions, are demonstrated in `ferminet.configs.heg`. --- .github/workflows/ci-build.yaml | 2 +- ferminet/base_config.py | 14 +- ferminet/configs/heg.py | 43 ++++++ ferminet/networks.py | 56 ++++---- ferminet/pbc/__init__.py | 0 ferminet/pbc/envelopes.py | 105 ++++++++++++++ ferminet/pbc/feature_layer.py | 76 ++++++++++ ferminet/pbc/hamiltonian.py | 187 +++++++++++++++++++++++++ ferminet/pbc/tests/__init__.py | 0 ferminet/pbc/tests/features_test.py | 74 ++++++++++ ferminet/pbc/tests/hamiltonian_test.py | 73 ++++++++++ ferminet/tests/hamiltonian_test.py | 11 +- ferminet/tests/networks_test.py | 29 +++- ferminet/train.py | 36 ++++- ferminet/utils/elements.py | 1 + ferminet/utils/tests/elements_test.py | 10 +- pylintrc | 13 +- 17 files changed, 672 insertions(+), 58 deletions(-) create mode 100644 ferminet/configs/heg.py create mode 100644 ferminet/pbc/__init__.py create mode 100644 ferminet/pbc/envelopes.py create mode 100644 ferminet/pbc/feature_layer.py create mode 100644 ferminet/pbc/hamiltonian.py create mode 100644 ferminet/pbc/tests/__init__.py create mode 100644 ferminet/pbc/tests/features_test.py create mode 100644 ferminet/pbc/tests/hamiltonian_test.py diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index 2ddaf82..13e8fbd 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -44,7 +44,7 @@ jobs: flake8 . - name: Lint with pylint run: | - pylint ferminet + pylint --fail-under 9.75 ferminet - name: Type check with pytype run: | pytype ferminet diff --git a/ferminet/base_config.py b/ferminet/base_config.py index e12cfcc..a475ad0 100644 --- a/ferminet/base_config.py +++ b/ferminet/base_config.py @@ -171,13 +171,25 @@ def default() -> ml_collections.ConfigDict: 'determinants': 16, 'after_determinants': (1,), }, - 'envelope_type': 'full', # Where does the envelope go? 'bias_orbitals': False, # include bias in last layer to orbitals # Whether to use the last layer of the two-electron stream of the # DetNet 'use_last_layer': False, # If true, determinants are dense rather than block-sparse 'full_det': True, + # String set to module.make_feature_layer, where make_feature_layer is + # callable (type: MakeFeatureLayer) which creates an object with + # member functions init() and apply() that initialize parameters + # for custom input features and modify raw input features, + # respectively. Module is the absolute module containing + # make_feature_layer. + # If not set, networks.make_ferminet_features is used. + 'make_feature_layer_fn': '', + # Additional kwargs to pass into make_local_energy_fn. + 'make_feature_layer_kwargs': {}, + # Same structure as make_feature_layer + 'make_envelope_fn': '', + 'make_envelope_kwargs': {} }, 'debug': { # Check optimizer state, parameters and loss and raise an exception if diff --git a/ferminet/configs/heg.py b/ferminet/configs/heg.py new file mode 100644 index 0000000..7095c57 --- /dev/null +++ b/ferminet/configs/heg.py @@ -0,0 +1,43 @@ +"""Unpolarised 14 electron simple cubic homogeneous electron gas.""" + +from ferminet import base_config +from ferminet.utils import system +from ferminet.pbc import envelopes + +import numpy as np + + +def _sc_lattice_vecs(rs: float, nelec: int) -> np.ndarray: + """Returns simple cubic lattice vectors with Wigner-Seitz radius rs.""" + volume = (4 / 3) * np.pi * (rs**3) * nelec + length = volume**(1 / 3) + return length * np.eye(3) + + +def get_config(): + """Returns config for running unpolarised 14 electron gas with FermiNet.""" + # Get default options. + cfg = base_config.default() + cfg.system.electrons = (7, 7) + # A ghost atom at the origin defines one-electron coordinate system. + # Element 'X' is a dummy nucleus with zero charge + cfg.system.molecule = [system.Atom("X", (0., 0., 0.))] + # Pretraining is not currently implemented for systems in PBC + cfg.pretrain.method = None + + lattice = _sc_lattice_vecs(1.0, sum(cfg.system.electrons)) + kpoints = envelopes.make_kpoints(lattice, cfg.system.electrons) + + cfg.system.make_local_energy_fn = "ferminet.pbc.hamiltonian.local_energy" + cfg.system.make_local_energy_kwargs = {"lattice": lattice, "heg": True} + cfg.network.make_feature_layer_fn = \ + "ferminet.pbc.feature_layer.make_pbc_feature_layer" + cfg.network.make_feature_layer_kwargs = { + "lattice": lattice, + "include_r_ae": False + } + cfg.network.make_envelope_fn = \ + "ferminet.pbc.envelopes.make_multiwave_envelope" + cfg.network.make_envelope_kwargs = {"kpoints": kpoints} + cfg.network.full_det = True + return cfg diff --git a/ferminet/networks.py b/ferminet/networks.py index 6118c66..262cedf 100644 --- a/ferminet/networks.py +++ b/ferminet/networks.py @@ -116,6 +116,23 @@ class FeatureLayerType(enum.Enum): STANDARD = enum.auto() +class MakeFeatureLayer(Protocol): + + def __call__(self, + charges: jnp.ndarray, + nspins: Sequence[int], + ndim: int, + **kwargs: Any) -> FeatureLayer: + """Builds the FeatureLayer object. + + Args: + charges: (natom) array of atom nuclear charges. + nspins: tuple of the number of spin-up and spin-down electrons. + ndim: dimension of the system. + **kwargs: additional kwargs to use for creating the specific FeatureLayer. + """ + + ## Network settings ## @@ -138,8 +155,6 @@ class FermiNetOptions: block-diagonalise determinants into spin channels. bias_orbitals: If true, include a bias in the final linear layer to shape the outputs into orbitals. - envelope_label: Envelope to use to impose orbitals go to zero at infinity. - See envelopes module. envelope: Envelope object to create and apply the multiplicative envelope. feature_layer: Feature object to create and apply the input features for the one- and two-electron layers. @@ -150,11 +165,10 @@ class FermiNetOptions: determinants: int = 16 full_det: bool = True bias_orbitals: bool = False - envelope_label: envelopes.EnvelopeLabel = envelopes.EnvelopeLabel.ISOTROPIC envelope: envelopes.Envelope = attr.ib( default=attr.Factory( - lambda self: envelopes.get_envelope(self.envelope_label), - takes_self=True)) + lambda: envelopes.make_isotropic_envelope(), + takes_self=False)) feature_layer: FeatureLayer = attr.ib( default=attr.Factory( lambda self: make_ferminet_features(ndim=self.ndim), takes_self=True)) @@ -344,8 +358,7 @@ def init_fermi_net_params( PyTree of network parameters. Spin-dependent parameters are only created for spin channels containing at least one particle. """ - if options.envelope_label in (envelopes.EnvelopeLabel.STO, - envelopes.EnvelopeLabel.STO_POLY): + if options.envelope.apply_type == envelopes.EnvelopeType.PRE_ORBITAL: if options.bias_orbitals: raise ValueError('Cannot bias orbitals w/STO envelope.') if hf_solution is not None: @@ -700,8 +713,8 @@ def make_fermi_net( nspins: Tuple[int, int], charges: jnp.ndarray, *, - envelope: Union[str, envelopes.EnvelopeLabel] = 'isotropic', - feature_layer: Union[str, FeatureLayerType] = FeatureLayerType.STANDARD, + envelope: Optional[envelopes.Envelope] = None, + feature_layer: Optional[FeatureLayer] = None, bias_orbitals: bool = False, use_last_layer: bool = False, hf_solution: Optional[scf.Scf] = None, @@ -742,23 +755,11 @@ def make_fermi_net( """ del after_determinants - if isinstance(envelope, str): - envelope = envelope.upper().replace('-', '_') - envelope_label = envelopes.EnvelopeLabel[envelope] - else: - # support naming scheme used in config files. - envelope_label = envelope - if envelope_label == envelopes.EnvelopeLabel.EXACT_CUSP: - envelope_kwargs = {'nspins': nspins, 'charges': charges} - else: - envelope_kwargs = {} + if not envelope: + envelope = envelopes.make_isotropic_envelope() - if isinstance(feature_layer, str): - feature_layer = FeatureLayerType[feature_layer.upper()] - if feature_layer == FeatureLayerType.STANDARD: - feature_layer_fns = make_ferminet_features(charges, nspins) - else: - raise ValueError(f'Unsupported feature layer type: {feature_layer}') + if not feature_layer: + feature_layer = make_ferminet_features(charges, nspins) options = FermiNetOptions( hidden_dims=hidden_dims, @@ -766,9 +767,8 @@ def make_fermi_net( determinants=determinants, full_det=full_det, bias_orbitals=bias_orbitals, - envelope_label=envelope_label, - envelope=envelopes.get_envelope(envelope_label, **envelope_kwargs), - feature_layer=feature_layer_fns, + envelope=envelope, + feature_layer=feature_layer, ) init = functools.partial( diff --git a/ferminet/pbc/__init__.py b/ferminet/pbc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ferminet/pbc/envelopes.py b/ferminet/pbc/envelopes.py new file mode 100644 index 0000000..5feb82c --- /dev/null +++ b/ferminet/pbc/envelopes.py @@ -0,0 +1,105 @@ +"""Multiplicative envelopes appropriate for periodic boundary conditions. + +See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., +Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.""" + +from itertools import product +from ferminet import envelopes +from ferminet.utils import scf +from typing import Mapping, Optional, Sequence, Tuple + +import jax.numpy as jnp + + +def make_multiwave_envelope(kpoints: jnp.ndarray) -> envelopes.Envelope: + """Returns an oscillatory envelope. + + Envelope consists of a sum of truncated 3D Fourier series, one centered on + each atom, with Fourier frequencies given by kpoints: + + sigma_{2i}*cos(kpoints_i.r_{ae}) + sigma_{2i+1}*sin(kpoints_i.r_{ae}) + + Initialization sets the coefficient of the first term in each + series to 1, and all other coefficients to 0. This corresponds to the + cosine of the first entry in kpoints. If this is [0, 0, 0], the envelope + will evaluate to unity at the beginning of training. + + Args: + kpoints: Reciprocal lattice vectors of terms included in the Fourier + series. Shape (nkpoints, ndim) (Note that ndim=3 is currently + a hard-coded default). + + Returns: + An instance of ferminet.envelopes.Envelope with apply_type + envelopes.EnvelopeType.PRE_DETERMINANT + """ + + def init(natom: int, + output_dims: Sequence[int], + hf: Optional[scf.Scf] = None, + ndim: int = 3) -> Sequence[Mapping[str, jnp.ndarray]]: + """See ferminet.envelopes.EnvelopeInit.""" + del hf, natom, ndim # unused + params = [] + nk = kpoints.shape[0] + for output_dim in output_dims: + params.append({'sigma': jnp.zeros((2 * nk, output_dim))}) + params[-1]['sigma'] = params[-1]['sigma'].at[0, :].set(1.0) + return params + + def apply(*, ae: jnp.ndarray, r_ae: jnp.ndarray, r_ee: jnp.ndarray, + sigma: jnp.ndarray) -> jnp.ndarray: + """See ferminet.envelopes.EnvelopeApply.""" + del r_ae, r_ee # unused + phase_coords = ae @ kpoints.T + waves = jnp.concatenate((jnp.cos(phase_coords), jnp.sin(phase_coords)), + axis=2) + env = waves @ (sigma**2.0) + return jnp.sum(env, axis=1) + + return envelopes.Envelope(envelopes.EnvelopeType.PRE_DETERMINANT, init, apply) + + +def make_kpoints(lattice: jnp.ndarray, + spins: Tuple[int, int], + min_kpoints: Optional[int] = None) -> jnp.ndarray: + """Generates an array of reciprocal lattice vectors. + + Args: + lattice: Matrix whose columns are the primitive lattice vectors of the + system, shape (ndim, ndim). (Note that ndim=3 is currently + a hard-coded default). + spins: Tuple of the number of spin-up and spin-down electrons. + min_kpoints: If specified, the number of kpoints which must be included in + the output. The number of kpoints returned will be the + first filled shell which is larger than this value. Defaults to None, + which results in min_kpoints == sum(spins). + + Raises: + ValueError: Fewer kpoints requested by min_kpoints than number of + electrons in the system. + + Returns: + jnp.ndarray, shape (nkpoints, ndim), an array of reciprocal lattice + vectors sorted in ascending order according to length. + """ + rec_lattice = 2 * jnp.pi * jnp.linalg.inv(lattice) + # Calculate required no. of k points + if min_kpoints is None: + min_kpoints = sum(spins) + elif min_kpoints < sum(spins): + raise ValueError( + 'Number of kpoints must be equal or greater than number of electrons') + + dk = 1 + 1e-5 + # Generate ordinals of the lowest min_kpoints kpoints + max_k = int(jnp.ceil(min_kpoints * dk)**(1 / 3.)) + ordinals = sorted(range(-max_k, max_k+1), key=abs) + ordinals = jnp.asarray(list(product(ordinals, repeat=3))) + + kpoints = ordinals @ rec_lattice.T + kpoints = jnp.asarray(sorted(kpoints, key=jnp.linalg.norm)) + k_norms = jnp.linalg.norm(kpoints, axis=1) + + return kpoints[k_norms <= k_norms[min_kpoints - 1] * dk] diff --git a/ferminet/pbc/feature_layer.py b/ferminet/pbc/feature_layer.py new file mode 100644 index 0000000..17ed513 --- /dev/null +++ b/ferminet/pbc/feature_layer.py @@ -0,0 +1,76 @@ +"""Feature layer for periodic boundary conditions. + +See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., +Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.""" + +from typing import Optional, Tuple + +import jax.numpy as jnp +from ferminet.networks import FeatureLayer, Param + + +def make_pbc_feature_layer(charges: Optional[jnp.ndarray] = None, + nspins: Optional[Tuple[int, ...]] = None, + ndim: int = 3, + lattice: jnp.ndarray = jnp.eye(3), + include_r_ae: bool = True) -> FeatureLayer: + """Returns the init and apply functions for periodic features. + + Args: + charges: (natom) array of atom nuclear charges. + nspins: tuple of the number of spin-up and spin-down electrons. + ndim: dimension of the system. + lattice: Matrix whose columns are the primitive lattice vectors of the + system, shape (ndim, ndim). (Note that ndim=3 is currently + a hard-coded default). + include_r_ae: Flag to enable electron-atom distance features. Set to False + to avoid cusps with ghost atoms in, e.g., homogeneous electron gas. + """ + + del charges, nspins + + # Calculate reciprocal vectors, factor 2pi omitted + reciprocal_vecs = jnp.linalg.inv(lattice) + lattice_metric = lattice.T @ lattice + + def periodic_norm(vec, metric): + a = (1 - jnp.cos(2 * jnp.pi * vec)) + b = jnp.sin(2 * jnp.pi * vec) + # i,j = nelectron, natom for ae + cos_term = jnp.einsum('ijm,mn,ijn->ij', a, metric, a) + sin_term = jnp.einsum('ijm,mn,ijn->ij', b, metric, b) + return (1 / (2 * jnp.pi)) * jnp.sqrt(cos_term + sin_term) + + def init() -> Tuple[Tuple[int, int], Param]: + if include_r_ae: + return (2 * ndim, 2 * ndim + 1), {} + else: + return (2 * ndim + 1, 2 * ndim + 1), {} + + def apply(ae, r_ae, ee, r_ee) -> Tuple[jnp.ndarray, jnp.ndarray]: + # One e features in phase coordinates, (s_ae)_i = k_i . ae + s_ae = jnp.einsum('il,jkl->jki', reciprocal_vecs, ae) + # Two e features in phase coordinates + s_ee = jnp.einsum('il,jkl->jki', reciprocal_vecs, ee) + # Periodized features + ae = jnp.concatenate( + (jnp.sin(2 * jnp.pi * s_ae), jnp.cos(2 * jnp.pi * s_ae)), axis=-1) + ee = jnp.concatenate( + (jnp.sin(2 * jnp.pi * s_ee), jnp.cos(2 * jnp.pi * s_ee)), axis=-1) + # Distance features defined on orthonormal projections + r_ae = periodic_norm(s_ae, lattice_metric) + # Don't take gradients through |0| + n = ee.shape[0] + s_ee += jnp.eye(n)[..., None] + r_ee = periodic_norm(s_ee, lattice_metric) * (1.0 - jnp.eye(n)) + + if include_r_ae: + ae_features = ae + else: + ae_features = jnp.concatenate((r_ae[..., None], ae), axis=2) + ae_features = jnp.reshape(ae_features, [jnp.shape(ae_features)[0], -1]) + ee_features = jnp.concatenate((r_ee[..., None], ee), axis=2) + return ae_features, ee_features + + return FeatureLayer(init=init, apply=apply) diff --git a/ferminet/pbc/hamiltonian.py b/ferminet/pbc/hamiltonian.py new file mode 100644 index 0000000..d7cf59c --- /dev/null +++ b/ferminet/pbc/hamiltonian.py @@ -0,0 +1,187 @@ +"""Ewald summation of Coulomb Hamiltonian in periodic boundary conditions. + +See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., +Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.""" + +from typing import Callable, Sequence + +import chex +from ferminet import networks +from ferminet import hamiltonian +import jax +import jax.numpy as jnp +from itertools import product + + +def make_ewald_potential( + lattice: jnp.ndarray, + atoms: jnp.ndarray, + charges: jnp.ndarray, + truncation_limit: int = 5, + include_heg_background: bool = True +) -> Callable[[jnp.ndarray, jnp.ndarray], float]: + """Creates a function to evaluate infinite Coulomb sum for periodic lattice. + + Args: + lattice: Shape (3, 3). Matrix whose columns are the primitive lattice + vectors. + atoms: Shape (natoms, ndim). Positions of the atoms. + charges: Shape (natoms). Nuclear charges of the atoms. + truncation_limit: Integer. Half side length of cube of nearest neighbours + to primitive cell which are summed over in evaluation of Ewald sum. + Must be large enough to achieve convergence for the real and reciprocal + space sums. + include_heg_background: bool. When True, includes cell-neutralizing + background term for homogeneous electron gas. + + Returns: + Callable with signature f(ae, ee), where (ae, ee) are atom-electon and + electron-electron displacement vectors respectively, which evaluates the + Coulomb sum for the periodic lattice via the Ewald method. + """ + rec = 2 * jnp.pi * jnp.linalg.inv(lattice) + volume = jnp.abs(jnp.linalg.det(lattice)) + # the factor gamma tunes the width of the summands in real / reciprocal space + # and this value is chosen to optimize the convergence trade-off between the + # two sums. See CASINO QMC manual. + gamma = (2.8 / volume**(1 / 3))**2 + ordinals = sorted(range(-truncation_limit, truncation_limit + 1), key=abs) + ordinals = jnp.array(list(product(ordinals, repeat=3))) + lat_vectors = jnp.einsum('kj,ij->ik', lattice, ordinals) + rec_vectors = jnp.einsum('kj,ij->ik', rec, ordinals[1:]) + rec_vec_square = jnp.einsum('ij,ij->i', rec_vectors, rec_vectors) + + def real_sum(separation: jnp.ndarray): + """Real-space Ewald potential between charges seperated by separation. + """ + displacements = jnp.linalg.norm( + separation - lat_vectors, axis=-1) # |r - R| + return jnp.sum( + jax.scipy.special.erfc(gamma**0.5 * displacements) / displacements) + + def recp_sum(separation: jnp.ndarray): + """Reciprocal-space Ewald potential between charges seperated by separation. + """ + return (4 * jnp.pi / volume) * jnp.sum( + jnp.exp(1.0j * jnp.dot(rec_vectors, separation)) * \ + jnp.exp(-rec_vec_square / (4 * gamma)) / rec_vec_square) + + def ewald_sum(separation: jnp.ndarray): + """Evaluates combined real and reciprocal space Ewald potential.""" + return real_sum(separation) + recp_sum(separation) - jnp.pi / ( + volume * gamma) + + lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1) + madelung_const = jnp.sum( + jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / \ + lat_vec_norm) - 2 * gamma**0.5 / jnp.pi**0.5 + madelung_const += (4*jnp.pi / volume) * \ + jnp.sum(jnp.exp(-rec_vec_square/(4*gamma))/rec_vec_square) - \ + jnp.pi / (volume*gamma) + + batch_ewald_sum = jax.vmap(ewald_sum, in_axes=(0,)) + + def atom_electron_potential(ae: jnp.ndarray): + """Evaluates periodic atom-electron potential.""" + nelec = ae.shape[0] + ae = jnp.reshape(ae, [-1, 3]) # flatten electronxatom axis + # calculate potential for each ae pair + ewald = batch_ewald_sum(ae) - madelung_const + return jnp.sum(-jnp.tile(charges, nelec) * ewald) + + def electron_electron_potential(ee: jnp.ndarray): + """Evaluates periodic electron-electron potential.""" + nelec = ee.shape[0] + ee = jnp.reshape(ee, [-1, 3]) + if include_heg_background: + ewald = batch_ewald_sum(ee) + else: + ewald = batch_ewald_sum(ee) - madelung_const + ewald = jnp.reshape(ewald, [nelec, nelec]) + ewald = ewald.at[jnp.diag_indices(nelec)].set(0.0) + if include_heg_background: + return 0.5 * jnp.sum(ewald) + 0.5 * nelec * madelung_const + else: + return 0.5 * jnp.sum(ewald) + + # Atom-atom potential + natom = atoms.shape[0] + if natom > 1: + aa = jnp.reshape(atoms, [1, -1, 3]) - jnp.reshape(atoms, [-1, 1, 3]) + aa = jnp.reshape(aa, [-1, 3]) + chargeprods = (charges[..., None] @ charges[..., None].T).flatten() + ewald = batch_ewald_sum(aa) - madelung_const + ewald = jnp.reshape(ewald, [natom, natom]) + ewald = ewald.at[jnp.diag_indices(natom)].set(0.0) + ewald = ewald.flatten() + atom_atom_potential = 0.5 * jnp.sum(chargeprods * ewald) + else: + atom_atom_potential = 0.0 + + def potential(ae: jnp.ndarray, ee: jnp.ndarray): + """Accumulates atom-electron, atom-atom, and electron-electron potential.""" + # Reduce vectors into first unit cell - Ewald summation + # is only guaranteed to converge close to the origin + phase_ae = jnp.einsum('il,jkl->jki', rec / (2 * jnp.pi), ae) + phase_ee = jnp.einsum('il,jkl->jki', rec / (2 * jnp.pi), ee) + phase_prim_ae = phase_ae % 1 + phase_prim_ee = phase_ee % 1 + prim_ae = jnp.einsum('il,jkl->jki', lattice, phase_prim_ae) + prim_ee = jnp.einsum('il,jkl->jki', lattice, phase_prim_ee) + return jnp.real( + atom_electron_potential(prim_ae) + + electron_electron_potential(prim_ee) + atom_atom_potential) + + return potential + + +def local_energy(f: networks.FermiNetLike, + atoms: jnp.ndarray, + charges: jnp.ndarray, + nspins: Sequence[int], + use_scan: bool = False, + lattice: jnp.ndarray = jnp.eye(3), + heg: bool = True, + convergence_radius: int = 5) -> hamiltonian.LocalEnergy: + """Creates the local energy function in periodic boundary conditions. + + Args: + f: Callable which returns the sign and log of the magnitude of the + wavefunction given the network parameters and configurations data. + atoms: Shape (natoms, ndim). Positions of the atoms. + charges: Shape (natoms). Nuclear charges of the atoms. + nspins: Number of particles of each spin. + use_scan: Whether to use a `lax.scan` for computing the laplacian. + lattice: Shape (ndim, ndim). Matrix of lattice vectors. + heg: bool. Flag to enable features specific to the electron gas. + convergence_radius: int. Radius of cluster summed over by Ewald sums. + + + Returns: + Callable with signature e_l(params, key, data) which evaluates the local + energy of the wavefunction given the parameters params, RNG state key, + and a single MCMC configuration in data. + """ + del nspins + log_abs_f = lambda *args, **kwargs: f(*args, **kwargs)[1] + ke = hamiltonian.local_kinetic_energy(log_abs_f, use_scan=use_scan) + potential_energy = make_ewald_potential(lattice, atoms, charges, + convergence_radius, heg) + + def _e_l(params: networks.ParamTree, key: chex.PRNGKey, + data: jnp.ndarray) -> jnp.ndarray: + """Returns the total energy. + + Args: + params: network parameters. + key: RNG state. + data: MCMC configuration. + """ + del key # unused + ae, ee, _, _ = networks.construct_input_features(data, atoms) + potential = potential_energy(ae, ee) + kinetic = ke(params, data) + return potential + kinetic + + return _e_l diff --git a/ferminet/pbc/tests/__init__.py b/ferminet/pbc/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ferminet/pbc/tests/features_test.py b/ferminet/pbc/tests/features_test.py new file mode 100644 index 0000000..3505fa9 --- /dev/null +++ b/ferminet/pbc/tests/features_test.py @@ -0,0 +1,74 @@ +"""Tests for ferminet.pbc.feature_layer""" + +import jax +import jax.numpy as jnp +from absl.testing import absltest, parameterized +from ferminet import networks +from ferminet.pbc import feature_layer as pbc_feature_layer +import numpy as np + + +class FeatureLayerTest(parameterized.TestCase): + + @parameterized.parameters([True, False]) + def test_shape(self, heg): + """ Assert that output shape of apply() matches what is expected by init() + """ + nspins = (6, 5) + atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) + natom = atoms.shape[0] + charges = jnp.asarray([2, 5, 7]) + key = jax.random.PRNGKey(42) + key, subkey = jax.random.split(key) + xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) + + feature_layer = pbc_feature_layer.make_pbc_feature_layer( + charges, nspins, 3, lattice=jnp.eye(3), include_r_ae=heg) + + dims, params = feature_layer.init() + ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) + + ae_features, ee_features = feature_layer.apply( + ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) + + assert dims[0] * natom == ae_features.shape[-1] + assert dims[1] == ee_features.shape[-1] + + def test_periodicity(self): + nspins = (6, 5) + atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) + charges = jnp.asarray([2, 5, 7]) + key = jax.random.PRNGKey(42) + key, subkey = jax.random.split(key) + xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) + + feature_layer = pbc_feature_layer.make_pbc_feature_layer( + charges, nspins, 3, lattice=jnp.eye(3), include_r_ae=False) + + _, params = feature_layer.init() + ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) + + ae_features_1, ee_features_1 = feature_layer.apply( + ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) + + # Select random electron coordinate to displace by a random lattice vec + key, subkey = jax.random.split(key) + e_idx = jax.random.randint(subkey, (1,), 0, xs.shape[0]) + key, subkey = jax.random.split(key) + randvec = jax.random.randint(subkey, (3,), 0, 100).astype(jnp.float32) + xs = xs.at[e_idx].add(randvec) + + ae, ee, r_ae, r_ee = networks.construct_input_features(xs, atoms) + + ae_features_2, ee_features_2 = feature_layer.apply( + ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) + + atol, rtol = 4.e-3, 4.e-3 + np.testing.assert_allclose( + ae_features_1, ae_features_2, atol=atol, rtol=rtol) + np.testing.assert_allclose( + ee_features_1, ee_features_2, atol=atol, rtol=rtol) + + +if __name__ == '__main__': + absltest.main() diff --git a/ferminet/pbc/tests/hamiltonian_test.py b/ferminet/pbc/tests/hamiltonian_test.py new file mode 100644 index 0000000..beee6f8 --- /dev/null +++ b/ferminet/pbc/tests/hamiltonian_test.py @@ -0,0 +1,73 @@ +"""Tests for ferminet.pbc.hamiltonian""" + +import jax +import jax.numpy as jnp +from absl.testing import absltest, parameterized +from ferminet import networks +from ferminet import base_config +from ferminet.pbc import feature_layer as pbc_feature_layer +from ferminet.pbc import hamiltonian +from ferminet.pbc import envelopes +import numpy as np + + +class PbcHamiltonianTest(parameterized.TestCase): + + def test_periodicity(self): + cfg = base_config.default() + + nspins = (6, 5) + atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) + charges = jnp.asarray([2, 5, 7]) + key = jax.random.PRNGKey(42) + key, subkey = jax.random.split(key) + xs = jax.random.uniform(subkey, shape=(sum(nspins), 3)) + + feature_layer = pbc_feature_layer.make_pbc_feature_layer( + charges, nspins, ndim=3, lattice=jnp.eye(3), include_r_ae=False) + + kpoints = envelopes.make_kpoints(jnp.eye(3), nspins) + + network_init, signed_network, _ = networks.make_fermi_net( + atoms, + nspins, + charges, + envelope=envelopes.make_multiwave_envelope(kpoints), + feature_layer=feature_layer, + bias_orbitals=cfg.network.bias_orbitals, + use_last_layer=cfg.network.use_last_layer, + hf_solution=None, + full_det=cfg.network.full_det, + **cfg.network.detnet) + + key, subkey = jax.random.split(key) + params = network_init(subkey) + + local_energy = hamiltonian.local_energy( + f=signed_network, + atoms=atoms, + charges=charges, + nspins=nspins, + use_scan=False, + lattice=jnp.eye(3), + heg=False) + + key, subkey = jax.random.split(key) + e1 = local_energy(params, subkey, xs.flatten()) + + # Select random electron coordinate to displace by a random lattice vec + key, subkey = jax.random.split(key) + e_idx = jax.random.randint(subkey, (1,), 0, xs.shape[0]) + key, subkey = jax.random.split(key) + randvec = jax.random.randint(subkey, (3,), 0, 100).astype(jnp.float32) + xs = xs.at[e_idx].add(randvec) + + key, subkey = jax.random.split(key) + e2 = local_energy(params, subkey, xs.flatten()) + + atol, rtol = 4.e-3, 4.e-3 + np.testing.assert_allclose(e1, e2, atol=atol, rtol=rtol) + + +if __name__ == '__main__': + absltest.main() diff --git a/ferminet/tests/hamiltonian_test.py b/ferminet/tests/hamiltonian_test.py index 310f1e8..97f3444 100644 --- a/ferminet/tests/hamiltonian_test.py +++ b/ferminet/tests/hamiltonian_test.py @@ -154,8 +154,17 @@ def test_fermi_net_laplacian(self, full_det): cfg.network.full_det = full_det cfg.network.detnet.hidden_dims = ((8, 4),)*2 cfg.network.detnet.determinants = 2 + feature_layer = networks.make_ferminet_features( + charges, + cfg.system.electrons, + cfg.system.ndim, + ) network_init, signed_network, _ = networks.make_fermi_net( - atoms, nspins, charges, full_det=full_det, **cfg.network.detnet) + atoms, nspins, charges, + full_det=full_det, + feature_layer=feature_layer, + **cfg.network.detnet + ) network = lambda params, x: signed_network(params, x)[1] key = jax.random.PRNGKey(47) params = network_init(key) diff --git a/ferminet/tests/networks_test.py b/ferminet/tests/networks_test.py index 8d7d673..01bb114 100644 --- a/ferminet/tests/networks_test.py +++ b/ferminet/tests/networks_test.py @@ -59,7 +59,7 @@ def _network_options(): # Key for each option and corresponding values to test. all_options = { 'vmap': [True, False], - 'envelope': list(envelopes.EnvelopeLabel), + 'envelope_label': list(envelopes.EnvelopeLabel), 'bias_orbitals': [True, False], 'full_det': [True, False], 'use_last_layer': [True, False], @@ -95,7 +95,6 @@ def test_antisymmetry(self, envelope_label, dtype): kwargs.update({'charges': charges, 'nspins': nspins}) options = networks.FermiNetOptions( hidden_dims=((16, 16), (16, 16)), - envelope_label=envelope_label, envelope=envelopes.get_envelope(envelope_label, **kwargs)) params = networks.init_fermi_net_params( @@ -196,8 +195,20 @@ def test_fermi_net(self, vmap, **network_options): atoms = jnp.asarray([[0., 0., 0.2], [1.2, 1., -0.2], [2.5, -0.8, 0.6]]) charges = jnp.asarray([2, 5, 7]) key = jax.random.PRNGKey(42) - + feature_layer = networks.make_ferminet_features( + charges, + nspins, + ndim=3, + ) + kwargs = {} + if network_options['envelope_label'] == envelopes.EnvelopeLabel.EXACT_CUSP: + kwargs.update({'charges': charges, 'nspins': nspins}) + network_options['envelope'] = envelopes.get_envelope( + network_options['envelope_label'], **kwargs + ) + del network_options['envelope_label'] init, fermi_net, _ = networks.make_fermi_net(atoms, nspins, charges, + feature_layer=feature_layer, **network_options) key, subkey = jax.random.split(key) @@ -211,9 +222,8 @@ def test_fermi_net(self, vmap, **network_options): expected_shape = () key, subkey = jax.random.split(key) - sto_envelopes = (envelopes.EnvelopeLabel.STO, - envelopes.EnvelopeLabel.STO_POLY) - if (network_options['envelope'] in sto_envelopes and + if (network_options['envelope'].apply_type == + envelopes.EnvelopeType.PRE_ORBITAL and network_options['bias_orbitals']): with self.assertRaises(ValueError): init(subkey) @@ -229,8 +239,13 @@ def test_spin_polarised_fermi_net(self, nspins, full_det): atoms = jnp.zeros(shape=(1, 3)) charges = jnp.ones(shape=1) key = jax.random.PRNGKey(42) + feature_layer = networks.make_ferminet_features( + charges, + nspins, + ndim=3, + ) init, fermi_net, _ = networks.make_fermi_net( - atoms, nspins, charges, full_det=full_det) + atoms, nspins, charges, feature_layer=feature_layer, full_det=full_det) key, subkey1, subkey2 = jax.random.split(key, num=3) params = init(subkey1) xs = jax.random.uniform(subkey2, shape=(sum(nspins) * 3,)) diff --git a/ferminet/train.py b/ferminet/train.py index b7cfb02..d6ed20d 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -24,6 +24,7 @@ from ferminet import checkpoint from ferminet import constants from ferminet import curvature_tags_and_blocks +from ferminet import envelopes from ferminet import hamiltonian from ferminet import loss as qmc_loss_functions from ferminet import mcmc @@ -302,12 +303,39 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): ]) hf_solution = hartree_fock if cfg.pretrain.method == 'direct_init' else None + + if cfg.network.make_feature_layer_fn: + feature_layer_module, feature_layer_fn = ( + cfg.network.make_feature_layer_fn.rsplit('.', maxsplit=1)) + feature_layer_module = importlib.import_module(feature_layer_module) + make_feature_layer = getattr(feature_layer_module, feature_layer_fn) + feature_layer = make_feature_layer( + charges, + cfg.system.electrons, + cfg.system.ndim, + **cfg.network.make_feature_layer_kwargs) # type: networks.FeatureLayer + else: + feature_layer = networks.make_ferminet_features( + charges, + cfg.system.electrons, + cfg.system.ndim, + ) + + if cfg.network.make_envelope_fn: + envelope_module, envelope_fn = ( + cfg.network.make_envelope_fn.rsplit('.', maxsplit=1)) + envelope_module = importlib.import_module(envelope_module) + make_envelope = getattr(envelope_module, envelope_fn) + envelope = make_envelope(**cfg.network.make_envelope_kwargs) # type: envelopes.Envelope + else: + envelope = envelopes.make_isotropic_envelope() + network_init, signed_network, network_options = networks.make_fermi_net( atoms, nspins, charges, - envelope=cfg.network.envelope_type, - feature_layer=cfg.network.get('feature_layer', 'standard'), + envelope=envelope, + feature_layer=feature_layer, bias_orbitals=cfg.network.bias_orbitals, use_last_layer=cfg.network.use_last_layer, hf_solution=hf_solution, @@ -406,14 +434,14 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): local_energy_module, local_energy_fn = ( cfg.system.make_local_energy_fn.rsplit('.', maxsplit=1)) local_energy_module = importlib.import_module(local_energy_module) - make_local_energy = getattr(local_energy_module, local_energy_fn) # type: hamiltonian.MakeLocalEnergy + make_local_energy = getattr(local_energy_module, local_energy_fn) local_energy = make_local_energy( f=signed_network, atoms=atoms, charges=charges, nspins=nspins, use_scan=False, - **cfg.system.make_local_energy_kwargs) + **cfg.system.make_local_energy_kwargs) # type: hamiltonian.MakeLocalEnergy else: local_energy = hamiltonian.local_energy( f=signed_network, diff --git a/ferminet/utils/elements.py b/ferminet/utils/elements.py index 932e7c3..0448c8a 100644 --- a/ferminet/utils/elements.py +++ b/ferminet/utils/elements.py @@ -116,6 +116,7 @@ def nbeta(self) -> int: # [_element(s, n+1) for n, s in enumerate(symbols)] # where symbols is the list of chemical symbols of all elements. _ELEMENTS = ( + Element(symbol='X', atomic_number=0, period=0), Element(symbol='H', atomic_number=1, period=1), Element(symbol='He', atomic_number=2, period=1), Element(symbol='Li', atomic_number=3, period=2), diff --git a/ferminet/utils/tests/elements_test.py b/ferminet/utils/tests/elements_test.py index f84bce6..d17e84f 100644 --- a/ferminet/utils/tests/elements_test.py +++ b/ferminet/utils/tests/elements_test.py @@ -26,7 +26,9 @@ def test_elements(self): for n, element in elements.ATOMIC_NUMS.items(): self.assertEqual(n, element.atomic_number) self.assertEqual(elements.SYMBOLS[element.symbol], element) - if element.symbol in ['Li', 'Na', 'K', 'Rb', 'Cs', 'Fr']: + if element.symbol == 'X': + continue + elif element.symbol in ['Li', 'Na', 'K', 'Rb', 'Cs', 'Fr']: self.assertEqual(element.period, elements.ATOMIC_NUMS[n - 1].period + 1) elif element.symbol != 'H': self.assertEqual(element.period, elements.ATOMIC_NUMS[n - 1].period) @@ -80,7 +82,7 @@ def test_element_group_period(self, element, period, group, spin_config, def test_periods(self): self.assertLen(elements.ATOMIC_NUMS, sum(len(period) for period in elements.PERIODS.values())) - period_length = {1: 2, 2: 8, 3: 8, 4: 18, 5: 18, 6: 32, 7: 32} + period_length = {0: 1, 1: 2, 2: 8, 3: 8, 4: 18, 5: 18, 6: 32, 7: 32} for p, es in elements.PERIODS.items(): self.assertLen(es, period_length[p]) @@ -93,7 +95,7 @@ def test_groups(self): # Iterate over all elements in order of atomic number. Group should # increment monotonically (except for accommodating absence of d block and # presence of f block) and reset to 1 on the first element in each period. - for i in range(1, len(elements.ATOMIC_NUMS)+1): + for i in range(1, len(elements.ATOMIC_NUMS)): element = elements.ATOMIC_NUMS[i] if element.atomic_number in period_starts: prev_group = 0 @@ -126,7 +128,7 @@ def test_groups(self): # Check each group contains the expected number of elements. nelements_in_group = [0]*18 for element in elements.ATOMIC_NUMS.values(): - if element.group != -1: + if element.group != -1 and element.period != 0: nelements_in_group[element.group-1] += 1 self.assertListEqual(nelements_in_group, [7, 6] + [4]*10 + [6]*5 + [7]) diff --git a/pylintrc b/pylintrc index 26c23af..03174d6 100644 --- a/pylintrc +++ b/pylintrc @@ -156,12 +156,6 @@ disable=abstract-method, # mypackage.mymodule.MyReporterClass. output-format=text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - # Tells whether to display a full report or only the messages reports=no @@ -273,6 +267,7 @@ ignore-long-lines=(?x)( ^\s*(\#\ )??$| ^\s*(from\s+\S+\s+)?import\s+.+$| ^.*pytype:\sdisable=.*$| + ^.*(\#\s*)type:\s.*$| ^\s*""".*$| ^.*pylint:\sdisable=.*$) @@ -280,12 +275,6 @@ ignore-long-lines=(?x)( # else. single-line-if-stmt=yes -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check= - # Maximum number of lines in a module max-module-lines=99999 From be2f449983785035ec8c19e79d5ab1b9d6aa98be Mon Sep 17 00:00:00 2001 From: gcassella Date: Thu, 30 Jun 2022 16:11:57 +0100 Subject: [PATCH 2/6] Add license declarations to ferminet.pbc --- ferminet/pbc/__init__.py | 13 +++++++++++++ ferminet/pbc/envelopes.py | 14 ++++++++++++++ ferminet/pbc/feature_layer.py | 14 ++++++++++++++ ferminet/pbc/hamiltonian.py | 14 ++++++++++++++ 4 files changed, 55 insertions(+) diff --git a/ferminet/pbc/__init__.py b/ferminet/pbc/__init__.py index e69de29..85e4882 100644 --- a/ferminet/pbc/__init__.py +++ b/ferminet/pbc/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ferminet/pbc/envelopes.py b/ferminet/pbc/envelopes.py index 5feb82c..718ef6a 100644 --- a/ferminet/pbc/envelopes.py +++ b/ferminet/pbc/envelopes.py @@ -1,3 +1,17 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Multiplicative envelopes appropriate for periodic boundary conditions. See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., diff --git a/ferminet/pbc/feature_layer.py b/ferminet/pbc/feature_layer.py index 17ed513..fb19c67 100644 --- a/ferminet/pbc/feature_layer.py +++ b/ferminet/pbc/feature_layer.py @@ -1,3 +1,17 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Feature layer for periodic boundary conditions. See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., diff --git a/ferminet/pbc/hamiltonian.py b/ferminet/pbc/hamiltonian.py index d7cf59c..33fd4ec 100644 --- a/ferminet/pbc/hamiltonian.py +++ b/ferminet/pbc/hamiltonian.py @@ -1,3 +1,17 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Ewald summation of Coulomb Hamiltonian in periodic boundary conditions. See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., From 41527bf73db38291778183df25f620e207b4b5d1 Mon Sep 17 00:00:00 2001 From: gcassella Date: Thu, 30 Jun 2022 16:17:14 +0100 Subject: [PATCH 3/6] Remove trailing whitespace from license --- ferminet/pbc/__init__.py | 26 +++++++++++++------------- ferminet/pbc/envelopes.py | 26 +++++++++++++------------- ferminet/pbc/feature_layer.py | 26 +++++++++++++------------- ferminet/pbc/hamiltonian.py | 26 +++++++++++++------------- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/ferminet/pbc/__init__.py b/ferminet/pbc/__init__.py index 85e4882..b5730d3 100644 --- a/ferminet/pbc/__init__.py +++ b/ferminet/pbc/__init__.py @@ -1,13 +1,13 @@ -# Copyright 2022 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License diff --git a/ferminet/pbc/envelopes.py b/ferminet/pbc/envelopes.py index 718ef6a..3cb6fc2 100644 --- a/ferminet/pbc/envelopes.py +++ b/ferminet/pbc/envelopes.py @@ -1,16 +1,16 @@ -# Copyright 2022 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License """Multiplicative envelopes appropriate for periodic boundary conditions. diff --git a/ferminet/pbc/feature_layer.py b/ferminet/pbc/feature_layer.py index fb19c67..31d3a7c 100644 --- a/ferminet/pbc/feature_layer.py +++ b/ferminet/pbc/feature_layer.py @@ -1,16 +1,16 @@ -# Copyright 2022 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License """Feature layer for periodic boundary conditions. diff --git a/ferminet/pbc/hamiltonian.py b/ferminet/pbc/hamiltonian.py index 33fd4ec..b70fdcb 100644 --- a/ferminet/pbc/hamiltonian.py +++ b/ferminet/pbc/hamiltonian.py @@ -1,16 +1,16 @@ -# Copyright 2022 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License """Ewald summation of Coulomb Hamiltonian in periodic boundary conditions. From e24a47099194a5d1bf479d924856749d4a226389 Mon Sep 17 00:00:00 2001 From: gcassella Date: Fri, 1 Jul 2022 09:52:38 +0100 Subject: [PATCH 4/6] Add license declarations to pbc tests and config --- ferminet/configs/heg.py | 14 ++++++++++++++ ferminet/pbc/tests/__init__.py | 13 +++++++++++++ ferminet/pbc/tests/features_test.py | 14 ++++++++++++++ ferminet/pbc/tests/hamiltonian_test.py | 14 ++++++++++++++ 4 files changed, 55 insertions(+) diff --git a/ferminet/configs/heg.py b/ferminet/configs/heg.py index 7095c57..eeb09ec 100644 --- a/ferminet/configs/heg.py +++ b/ferminet/configs/heg.py @@ -1,3 +1,17 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + """Unpolarised 14 electron simple cubic homogeneous electron gas.""" from ferminet import base_config diff --git a/ferminet/pbc/tests/__init__.py b/ferminet/pbc/tests/__init__.py index e69de29..b5730d3 100644 --- a/ferminet/pbc/tests/__init__.py +++ b/ferminet/pbc/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License diff --git a/ferminet/pbc/tests/features_test.py b/ferminet/pbc/tests/features_test.py index 3505fa9..0bf89b2 100644 --- a/ferminet/pbc/tests/features_test.py +++ b/ferminet/pbc/tests/features_test.py @@ -1,3 +1,17 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + """Tests for ferminet.pbc.feature_layer""" import jax diff --git a/ferminet/pbc/tests/hamiltonian_test.py b/ferminet/pbc/tests/hamiltonian_test.py index beee6f8..43c12fa 100644 --- a/ferminet/pbc/tests/hamiltonian_test.py +++ b/ferminet/pbc/tests/hamiltonian_test.py @@ -1,3 +1,17 @@ +# Copyright 2022 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + """Tests for ferminet.pbc.hamiltonian""" import jax From 58c2d65f937ed8fd7c91801b90c1ed94b4efa307 Mon Sep 17 00:00:00 2001 From: gcassella Date: Fri, 1 Jul 2022 09:58:41 +0100 Subject: [PATCH 5/6] Modify default args to ferminet.pbc.feature_layer Default to lattice=None, and instance ndim*ndim identity in this case --- ferminet/pbc/feature_layer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ferminet/pbc/feature_layer.py b/ferminet/pbc/feature_layer.py index 31d3a7c..38cf669 100644 --- a/ferminet/pbc/feature_layer.py +++ b/ferminet/pbc/feature_layer.py @@ -27,7 +27,7 @@ def make_pbc_feature_layer(charges: Optional[jnp.ndarray] = None, nspins: Optional[Tuple[int, ...]] = None, ndim: int = 3, - lattice: jnp.ndarray = jnp.eye(3), + lattice: Optional[jnp.ndarray] = None, include_r_ae: bool = True) -> FeatureLayer: """Returns the init and apply functions for periodic features. @@ -36,14 +36,16 @@ def make_pbc_feature_layer(charges: Optional[jnp.ndarray] = None, nspins: tuple of the number of spin-up and spin-down electrons. ndim: dimension of the system. lattice: Matrix whose columns are the primitive lattice vectors of the - system, shape (ndim, ndim). (Note that ndim=3 is currently - a hard-coded default). + system, shape (ndim, ndim). include_r_ae: Flag to enable electron-atom distance features. Set to False to avoid cusps with ghost atoms in, e.g., homogeneous electron gas. """ del charges, nspins + if lattice is None: + lattice = jnp.eye(ndim) + # Calculate reciprocal vectors, factor 2pi omitted reciprocal_vecs = jnp.linalg.inv(lattice) lattice_metric = lattice.T @ lattice From 676d4a3e3805d6621622f2cbd31b4d2762f6e74c Mon Sep 17 00:00:00 2001 From: gcassella Date: Mon, 4 Jul 2022 10:58:31 +0100 Subject: [PATCH 6/6] Formatting to Google style guide --- ferminet/configs/heg.py | 8 +++--- ferminet/pbc/envelopes.py | 10 ++++--- ferminet/pbc/feature_layer.py | 12 ++++---- ferminet/pbc/hamiltonian.py | 44 ++++++++++++++--------------- ferminet/pbc/tests/features_test.py | 2 +- ferminet/tests/networks_test.py | 19 ++++++------- ferminet/train.py | 10 +++---- 7 files changed, 54 insertions(+), 51 deletions(-) diff --git a/ferminet/configs/heg.py b/ferminet/configs/heg.py index eeb09ec..b756919 100644 --- a/ferminet/configs/heg.py +++ b/ferminet/configs/heg.py @@ -44,14 +44,14 @@ def get_config(): cfg.system.make_local_energy_fn = "ferminet.pbc.hamiltonian.local_energy" cfg.system.make_local_energy_kwargs = {"lattice": lattice, "heg": True} - cfg.network.make_feature_layer_fn = \ - "ferminet.pbc.feature_layer.make_pbc_feature_layer" + cfg.network.make_feature_layer_fn = ( + "ferminet.pbc.feature_layer.make_pbc_feature_layer") cfg.network.make_feature_layer_kwargs = { "lattice": lattice, "include_r_ae": False } - cfg.network.make_envelope_fn = \ - "ferminet.pbc.envelopes.make_multiwave_envelope" + cfg.network.make_envelope_fn = ( + "ferminet.pbc.envelopes.make_multiwave_envelope") cfg.network.make_envelope_kwargs = {"kpoints": kpoints} cfg.network.full_det = True return cfg diff --git a/ferminet/pbc/envelopes.py b/ferminet/pbc/envelopes.py index 3cb6fc2..466e1a7 100644 --- a/ferminet/pbc/envelopes.py +++ b/ferminet/pbc/envelopes.py @@ -16,12 +16,14 @@ See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions -with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.""" +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183. +""" + +import itertools +from typing import Mapping, Optional, Sequence, Tuple -from itertools import product from ferminet import envelopes from ferminet.utils import scf -from typing import Mapping, Optional, Sequence, Tuple import jax.numpy as jnp @@ -110,7 +112,7 @@ def make_kpoints(lattice: jnp.ndarray, # Generate ordinals of the lowest min_kpoints kpoints max_k = int(jnp.ceil(min_kpoints * dk)**(1 / 3.)) ordinals = sorted(range(-max_k, max_k+1), key=abs) - ordinals = jnp.asarray(list(product(ordinals, repeat=3))) + ordinals = jnp.asarray(list(itertools.product(ordinals, repeat=3))) kpoints = ordinals @ rec_lattice.T kpoints = jnp.asarray(sorted(kpoints, key=jnp.linalg.norm)) diff --git a/ferminet/pbc/feature_layer.py b/ferminet/pbc/feature_layer.py index 38cf669..6ca5ec9 100644 --- a/ferminet/pbc/feature_layer.py +++ b/ferminet/pbc/feature_layer.py @@ -16,19 +16,21 @@ See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions -with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.""" +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183. +""" from typing import Optional, Tuple import jax.numpy as jnp -from ferminet.networks import FeatureLayer, Param + +from ferminet import networks def make_pbc_feature_layer(charges: Optional[jnp.ndarray] = None, nspins: Optional[Tuple[int, ...]] = None, ndim: int = 3, lattice: Optional[jnp.ndarray] = None, - include_r_ae: bool = True) -> FeatureLayer: + include_r_ae: bool = True) -> networks.FeatureLayer: """Returns the init and apply functions for periodic features. Args: @@ -58,7 +60,7 @@ def periodic_norm(vec, metric): sin_term = jnp.einsum('ijm,mn,ijn->ij', b, metric, b) return (1 / (2 * jnp.pi)) * jnp.sqrt(cos_term + sin_term) - def init() -> Tuple[Tuple[int, int], Param]: + def init() -> Tuple[Tuple[int, int], networks.Param]: if include_r_ae: return (2 * ndim, 2 * ndim + 1), {} else: @@ -89,4 +91,4 @@ def apply(ae, r_ae, ee, r_ee) -> Tuple[jnp.ndarray, jnp.ndarray]: ee_features = jnp.concatenate((r_ee[..., None], ee), axis=2) return ae_features, ee_features - return FeatureLayer(init=init, apply=apply) + return networks.FeatureLayer(init=init, apply=apply) diff --git a/ferminet/pbc/hamiltonian.py b/ferminet/pbc/hamiltonian.py index b70fdcb..4aa9333 100644 --- a/ferminet/pbc/hamiltonian.py +++ b/ferminet/pbc/hamiltonian.py @@ -16,16 +16,18 @@ See Cassella, G., Sutterud, H., Azadi, S., Drummond, N.D., Pfau, D., Spencer, J.S. and Foulkes, W.M.C., 2022. Discovering Quantum Phase Transitions -with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183.""" +with Fermionic Neural Networks. arXiv preprint arXiv:2202.05183. +""" +import itertools from typing import Callable, Sequence -import chex -from ferminet import networks from ferminet import hamiltonian +from ferminet import networks + +import chex import jax import jax.numpy as jnp -from itertools import product def make_ewald_potential( @@ -61,38 +63,36 @@ def make_ewald_potential( # two sums. See CASINO QMC manual. gamma = (2.8 / volume**(1 / 3))**2 ordinals = sorted(range(-truncation_limit, truncation_limit + 1), key=abs) - ordinals = jnp.array(list(product(ordinals, repeat=3))) + ordinals = jnp.array(list(itertools.product(ordinals, repeat=3))) lat_vectors = jnp.einsum('kj,ij->ik', lattice, ordinals) rec_vectors = jnp.einsum('kj,ij->ik', rec, ordinals[1:]) rec_vec_square = jnp.einsum('ij,ij->i', rec_vectors, rec_vectors) + lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1) - def real_sum(separation: jnp.ndarray): - """Real-space Ewald potential between charges seperated by separation. - """ + def real_space_ewald(separation: jnp.ndarray): + """Real-space Ewald potential between charges seperated by separation.""" displacements = jnp.linalg.norm( separation - lat_vectors, axis=-1) # |r - R| return jnp.sum( jax.scipy.special.erfc(gamma**0.5 * displacements) / displacements) - def recp_sum(separation: jnp.ndarray): - """Reciprocal-space Ewald potential between charges seperated by separation. - """ + def recp_space_ewald(separation: jnp.ndarray): + """Returns reciprocal-space Ewald potential between charges.""" return (4 * jnp.pi / volume) * jnp.sum( - jnp.exp(1.0j * jnp.dot(rec_vectors, separation)) * \ + jnp.exp(1.0j * jnp.dot(rec_vectors, separation)) * jnp.exp(-rec_vec_square / (4 * gamma)) / rec_vec_square) def ewald_sum(separation: jnp.ndarray): """Evaluates combined real and reciprocal space Ewald potential.""" - return real_sum(separation) + recp_sum(separation) - jnp.pi / ( - volume * gamma) - - lat_vec_norm = jnp.linalg.norm(lat_vectors[1:], axis=-1) - madelung_const = jnp.sum( - jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / \ - lat_vec_norm) - 2 * gamma**0.5 / jnp.pi**0.5 - madelung_const += (4*jnp.pi / volume) * \ - jnp.sum(jnp.exp(-rec_vec_square/(4*gamma))/rec_vec_square) - \ - jnp.pi / (volume*gamma) + return (real_space_ewald(separation) + recp_space_ewald(separation) - + jnp.pi / (volume * gamma)) + + madelung_const = (jnp.sum( + jax.scipy.special.erfc(gamma**0.5 * lat_vec_norm) / lat_vec_norm) - + 2 * gamma**0.5 / jnp.pi**0.5) + madelung_const += ((4*jnp.pi / volume) * + jnp.sum(jnp.exp(-rec_vec_square/(4*gamma))/rec_vec_square) - + jnp.pi / (volume*gamma)) batch_ewald_sum = jax.vmap(ewald_sum, in_axes=(0,)) diff --git a/ferminet/pbc/tests/features_test.py b/ferminet/pbc/tests/features_test.py index 0bf89b2..1f4db78 100644 --- a/ferminet/pbc/tests/features_test.py +++ b/ferminet/pbc/tests/features_test.py @@ -65,7 +65,7 @@ def test_periodicity(self): ae_features_1, ee_features_1 = feature_layer.apply( ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params) - # Select random electron coordinate to displace by a random lattice vec + # Select random electron coordinate to displace by a random lattice vector key, subkey = jax.random.split(key) e_idx = jax.random.randint(subkey, (1,), 0, xs.shape[0]) key, subkey = jax.random.split(key) diff --git a/ferminet/tests/networks_test.py b/ferminet/tests/networks_test.py index 01bb114..1667555 100644 --- a/ferminet/tests/networks_test.py +++ b/ferminet/tests/networks_test.py @@ -196,16 +196,15 @@ def test_fermi_net(self, vmap, **network_options): charges = jnp.asarray([2, 5, 7]) key = jax.random.PRNGKey(42) feature_layer = networks.make_ferminet_features( - charges, - nspins, - ndim=3, + charges, + nspins, + ndim=3, ) kwargs = {} if network_options['envelope_label'] == envelopes.EnvelopeLabel.EXACT_CUSP: kwargs.update({'charges': charges, 'nspins': nspins}) network_options['envelope'] = envelopes.get_envelope( - network_options['envelope_label'], **kwargs - ) + network_options['envelope_label'], **kwargs) del network_options['envelope_label'] init, fermi_net, _ = networks.make_fermi_net(atoms, nspins, charges, feature_layer=feature_layer, @@ -222,8 +221,8 @@ def test_fermi_net(self, vmap, **network_options): expected_shape = () key, subkey = jax.random.split(key) - if (network_options['envelope'].apply_type == - envelopes.EnvelopeType.PRE_ORBITAL and + envelope = network_options['envelope'] + if (envelope.apply_type == envelopes.EnvelopeType.PRE_ORBITAL and network_options['bias_orbitals']): with self.assertRaises(ValueError): init(subkey) @@ -240,9 +239,9 @@ def test_spin_polarised_fermi_net(self, nspins, full_det): charges = jnp.ones(shape=1) key = jax.random.PRNGKey(42) feature_layer = networks.make_ferminet_features( - charges, - nspins, - ndim=3, + charges, + nspins, + ndim=3, ) init, fermi_net, _ = networks.make_fermi_net( atoms, nspins, charges, feature_layer=feature_layer, full_det=full_det) diff --git a/ferminet/train.py b/ferminet/train.py index d6ed20d..c12d27e 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -316,9 +316,9 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): **cfg.network.make_feature_layer_kwargs) # type: networks.FeatureLayer else: feature_layer = networks.make_ferminet_features( - charges, - cfg.system.electrons, - cfg.system.ndim, + charges, + cfg.system.electrons, + cfg.system.ndim, ) if cfg.network.make_envelope_fn: @@ -434,14 +434,14 @@ def train(cfg: ml_collections.ConfigDict, writer_manager=None): local_energy_module, local_energy_fn = ( cfg.system.make_local_energy_fn.rsplit('.', maxsplit=1)) local_energy_module = importlib.import_module(local_energy_module) - make_local_energy = getattr(local_energy_module, local_energy_fn) + make_local_energy = getattr(local_energy_module, local_energy_fn) # type: hamiltonian.MakeLocalEnergy local_energy = make_local_energy( f=signed_network, atoms=atoms, charges=charges, nspins=nspins, use_scan=False, - **cfg.system.make_local_energy_kwargs) # type: hamiltonian.MakeLocalEnergy + **cfg.system.make_local_energy_kwargs) else: local_energy = hamiltonian.local_energy( f=signed_network,