Skip to content

Commit

Permalink
Implement the Psiformer, from A Self-Attention Ansatz for Ab Initio Q…
Browse files Browse the repository at this point in the history
…uantum Chemistry, ICLR 2023.

PiperOrigin-RevId: 513235453
Change-Id: Ic38fe88da627debc51e3bd063d873a21f45e46b0
  • Loading branch information
Ingrid von Glehn authored and jsspencer committed Mar 6, 2023
1 parent 9fd04cd commit b5c82de
Show file tree
Hide file tree
Showing 8 changed files with 696 additions and 32 deletions.
22 changes: 17 additions & 5 deletions ferminet/base_config.py
Expand Up @@ -176,14 +176,26 @@ def default() -> ml_collections.ConfigDict:
'blocks': 1, # Number of blocks to split the MCMC sampling into
},
'network': {
'detnet': {
'network_type': 'ferminet', # One of 'ferminet' or 'psiformer'.
# Config specific to original FermiNet architecture.
# Only used if network_type is 'ferminet'.
'ferminet': {
'hidden_dims': ((256, 32), (256, 32), (256, 32), (256, 32)),
'determinants': 16,
# Whether to use the last layer of the two-electron stream of the
# FermiNet.
'use_last_layer': False,
},
# Only used if network_type is 'psiformer'.
'psiformer': {
'num_layers': 2,
'num_heads': 4,
'heads_dim': 64,
'mlp_hidden_dims': (256,),
'use_layer_norm': False,
},
# Config common to all architectures.
'determinants': 16, # Number of determinants.
'bias_orbitals': False, # include bias in last layer to orbitals
# Whether to use the last layer of the two-electron stream of the
# DetNet
'use_last_layer': False,
# If true, determinants are dense rather than block-sparse
'full_det': True,
# If specified, include a pre-determinant Jastrow factor.
Expand Down
18 changes: 13 additions & 5 deletions ferminet/networks.py
Expand Up @@ -169,6 +169,7 @@ def __call__(
r_ae: jnp.ndarray,
ee: jnp.ndarray,
r_ee: jnp.ndarray,
spins: jnp.ndarray,
charges: jnp.ndarray,
) -> jnp.ndarray:
"""Forward evaluation of the equivariant interaction layers.
Expand All @@ -179,6 +180,7 @@ def __call__(
r_ae: electron-nuclear distances.
ee: electron-electron vectors.
r_ee: electron-electron distances.
spins: spin of each electron.
charges: nuclear charges.
Returns:
Expand Down Expand Up @@ -533,6 +535,7 @@ def apply(
r_ae: jnp.ndarray,
ee: jnp.ndarray,
r_ee: jnp.ndarray,
spins: jnp.ndarray,
charges: jnp.ndarray,
) -> jnp.ndarray:
"""Applies the FermiNet interaction layers to a walker configuration.
Expand All @@ -543,14 +546,15 @@ def apply(
r_ae: electron-nuclear distances.
ee: electron-electron vectors.
r_ee: electron-electron distances.
spins: spin of each electron.
charges: nuclear charges.
Returns:
Array of shape (nelectron, output_dim), where the output dimension,
output_dim, is given by init, and is suitable for projection into orbital
space.
"""
del charges # Unused.
del spins, charges # Unused.

ae_features, ee_features = options.feature_layer.apply(
ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params['input']
Expand Down Expand Up @@ -578,7 +582,7 @@ def apply(
def make_orbitals(
nspins: Tuple[int, ...],
charges: jnp.ndarray,
options: FermiNetOptions,
options: BaseNetworkOptions,
equivariant_layers: Tuple[InitLayersFn, ApplyLayersFn],
) -> ...:
"""Returns init, apply pair for orbitals.
Expand Down Expand Up @@ -678,11 +682,15 @@ def apply(
columns under the exchange of inputs of shape (ndet, nalpha+nbeta,
nalpha+nbeta) (or (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta)).
"""
del spins

ae, ee, r_ae, r_ee = construct_input_features(pos, atoms, ndim=options.ndim)
h_to_orbitals = equivariant_layers_apply(
params['layers'], ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, charges=charges
params['layers'],
ae=ae,
r_ae=r_ae,
ee=ee,
r_ee=r_ee,
spins=spins,
charges=charges,
)

if options.envelope.apply_type == envelopes.EnvelopeType.PRE_ORBITAL:
Expand Down
3 changes: 1 addition & 2 deletions ferminet/pbc/tests/hamiltonian_test.py
Expand Up @@ -52,9 +52,8 @@ def test_periodicity(self):
envelope=envelopes.make_multiwave_envelope(kpoints),
feature_layer=feature_layer,
bias_orbitals=cfg.network.bias_orbitals,
use_last_layer=cfg.network.use_last_layer,
full_det=cfg.network.full_det,
**cfg.network.detnet
**cfg.network.ferminet,
)

key, subkey = jax.random.split(key)
Expand Down

0 comments on commit b5c82de

Please sign in to comment.