@@ -169,6 +169,7 @@ def __call__(
169
169
r_ae : jnp .ndarray ,
170
170
ee : jnp .ndarray ,
171
171
r_ee : jnp .ndarray ,
172
+ spins : jnp .ndarray ,
172
173
charges : jnp .ndarray ,
173
174
) -> jnp .ndarray :
174
175
"""Forward evaluation of the equivariant interaction layers.
@@ -179,6 +180,7 @@ def __call__(
179
180
r_ae: electron-nuclear distances.
180
181
ee: electron-electron vectors.
181
182
r_ee: electron-electron distances.
183
+ spins: spin of each electron.
182
184
charges: nuclear charges.
183
185
184
186
Returns:
@@ -533,6 +535,7 @@ def apply(
533
535
r_ae : jnp .ndarray ,
534
536
ee : jnp .ndarray ,
535
537
r_ee : jnp .ndarray ,
538
+ spins : jnp .ndarray ,
536
539
charges : jnp .ndarray ,
537
540
) -> jnp .ndarray :
538
541
"""Applies the FermiNet interaction layers to a walker configuration.
@@ -543,14 +546,15 @@ def apply(
543
546
r_ae: electron-nuclear distances.
544
547
ee: electron-electron vectors.
545
548
r_ee: electron-electron distances.
549
+ spins: spin of each electron.
546
550
charges: nuclear charges.
547
551
548
552
Returns:
549
553
Array of shape (nelectron, output_dim), where the output dimension,
550
554
output_dim, is given by init, and is suitable for projection into orbital
551
555
space.
552
556
"""
553
- del charges # Unused.
557
+ del spins , charges # Unused.
554
558
555
559
ae_features , ee_features = options .feature_layer .apply (
556
560
ae = ae , r_ae = r_ae , ee = ee , r_ee = r_ee , ** params ['input' ]
@@ -578,7 +582,7 @@ def apply(
578
582
def make_orbitals (
579
583
nspins : Tuple [int , ...],
580
584
charges : jnp .ndarray ,
581
- options : FermiNetOptions ,
585
+ options : BaseNetworkOptions ,
582
586
equivariant_layers : Tuple [InitLayersFn , ApplyLayersFn ],
583
587
) -> ...:
584
588
"""Returns init, apply pair for orbitals.
@@ -678,11 +682,15 @@ def apply(
678
682
columns under the exchange of inputs of shape (ndet, nalpha+nbeta,
679
683
nalpha+nbeta) (or (ndet, nalpha, nalpha) and (ndet, nbeta, nbeta)).
680
684
"""
681
- del spins
682
-
683
685
ae , ee , r_ae , r_ee = construct_input_features (pos , atoms , ndim = options .ndim )
684
686
h_to_orbitals = equivariant_layers_apply (
685
- params ['layers' ], ae = ae , r_ae = r_ae , ee = ee , r_ee = r_ee , charges = charges
687
+ params ['layers' ],
688
+ ae = ae ,
689
+ r_ae = r_ae ,
690
+ ee = ee ,
691
+ r_ee = r_ee ,
692
+ spins = spins ,
693
+ charges = charges ,
686
694
)
687
695
688
696
if options .envelope .apply_type == envelopes .EnvelopeType .PRE_ORBITAL :
0 commit comments