Skip to content

Commit b5c82de

Browse files
Ingrid von Glehnjsspencer
Ingrid von Glehn
authored andcommitted
Implement the Psiformer, from A Self-Attention Ansatz for Ab Initio Quantum Chemistry, ICLR 2023.
PiperOrigin-RevId: 513235453 Change-Id: Ic38fe88da627debc51e3bd063d873a21f45e46b0
1 parent 9fd04cd commit b5c82de

File tree

8 files changed

+696
-32
lines changed

8 files changed

+696
-32
lines changed

ferminet/base_config.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,26 @@ def default() -> ml_collections.ConfigDict:
176176
'blocks': 1, # Number of blocks to split the MCMC sampling into
177177
},
178178
'network': {
179-
'detnet': {
179+
'network_type': 'ferminet', # One of 'ferminet' or 'psiformer'.
180+
# Config specific to original FermiNet architecture.
181+
# Only used if network_type is 'ferminet'.
182+
'ferminet': {
180183
'hidden_dims': ((256, 32), (256, 32), (256, 32), (256, 32)),
181-
'determinants': 16,
184+
# Whether to use the last layer of the two-electron stream of the
185+
# FermiNet.
186+
'use_last_layer': False,
182187
},
188+
# Only used if network_type is 'psiformer'.
189+
'psiformer': {
190+
'num_layers': 2,
191+
'num_heads': 4,
192+
'heads_dim': 64,
193+
'mlp_hidden_dims': (256,),
194+
'use_layer_norm': False,
195+
},
196+
# Config common to all architectures.
197+
'determinants': 16, # Number of determinants.
183198
'bias_orbitals': False, # include bias in last layer to orbitals
184-
# Whether to use the last layer of the two-electron stream of the
185-
# DetNet
186-
'use_last_layer': False,
187199
# If true, determinants are dense rather than block-sparse
188200
'full_det': True,
189201
# If specified, include a pre-determinant Jastrow factor.

ferminet/networks.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def __call__(
169169
r_ae: jnp.ndarray,
170170
ee: jnp.ndarray,
171171
r_ee: jnp.ndarray,
172+
spins: jnp.ndarray,
172173
charges: jnp.ndarray,
173174
) -> jnp.ndarray:
174175
"""Forward evaluation of the equivariant interaction layers.
@@ -179,6 +180,7 @@ def __call__(
179180
r_ae: electron-nuclear distances.
180181
ee: electron-electron vectors.
181182
r_ee: electron-electron distances.
183+
spins: spin of each electron.
182184
charges: nuclear charges.
183185
184186
Returns:
@@ -533,6 +535,7 @@ def apply(
533535
r_ae: jnp.ndarray,
534536
ee: jnp.ndarray,
535537
r_ee: jnp.ndarray,
538+
spins: jnp.ndarray,
536539
charges: jnp.ndarray,
537540
) -> jnp.ndarray:
538541
"""Applies the FermiNet interaction layers to a walker configuration.
@@ -543,14 +546,15 @@ def apply(
543546
r_ae: electron-nuclear distances.
544547
ee: electron-electron vectors.
545548
r_ee: electron-electron distances.
549+
spins: spin of each electron.
546550
charges: nuclear charges.
547551
548552
Returns:
549553
Array of shape (nelectron, output_dim), where the output dimension,
550554
output_dim, is given by init, and is suitable for projection into orbital
551555
space.
552556
"""
553-
del charges # Unused.
557+
del spins, charges # Unused.
554558

555559
ae_features, ee_features = options.feature_layer.apply(
556560
ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, **params['input']
@@ -578,7 +582,7 @@ def apply(
578582
def make_orbitals(
579583
nspins: Tuple[int, ...],
580584
charges: jnp.ndarray,
581-
options: FermiNetOptions,
585+
options: BaseNetworkOptions,
582586
equivariant_layers: Tuple[InitLayersFn, ApplyLayersFn],
583587
) -> ...:
584588
"""Returns init, apply pair for orbitals.
@@ -678,11 +682,15 @@ def apply(
678682
columns under the exchange of inputs of shape (ndet, nalpha+nbeta,
679683
nalpha+nbeta) (or (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta)).
680684
"""
681-
del spins
682-
683685
ae, ee, r_ae, r_ee = construct_input_features(pos, atoms, ndim=options.ndim)
684686
h_to_orbitals = equivariant_layers_apply(
685-
params['layers'], ae=ae, r_ae=r_ae, ee=ee, r_ee=r_ee, charges=charges
687+
params['layers'],
688+
ae=ae,
689+
r_ae=r_ae,
690+
ee=ee,
691+
r_ee=r_ee,
692+
spins=spins,
693+
charges=charges,
686694
)
687695

688696
if options.envelope.apply_type == envelopes.EnvelopeType.PRE_ORBITAL:

ferminet/pbc/tests/hamiltonian_test.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ def test_periodicity(self):
5252
envelope=envelopes.make_multiwave_envelope(kpoints),
5353
feature_layer=feature_layer,
5454
bias_orbitals=cfg.network.bias_orbitals,
55-
use_last_layer=cfg.network.use_last_layer,
5655
full_det=cfg.network.full_det,
57-
**cfg.network.detnet
56+
**cfg.network.ferminet,
5857
)
5958

6059
key, subkey = jax.random.split(key)

0 commit comments

Comments
 (0)