In [1]:
%load_ext autoreload
%autoreload 2
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

from neural_pfaffian.nn.envelopes import EfficientEnvelope
from neural_pfaffian.nn.ferminet import FermiNet
from neural_pfaffian.nn.meta_network import MetaGNN
from neural_pfaffian.nn.module import ParamTypes, ReparamModule
from neural_pfaffian.nn.orbitals import Pfaffian
from neural_pfaffian.nn.psiformer import PsiFormer
from neural_pfaffian.nn.wave_function import GeneralizedWaveFunction, WaveFunction
from neural_pfaffian.systems import Systems

In [3]:
system = Systems(
    spins=((2, 2),),
    charges=((2,),),
    electrons=jnp.asarray(np.random.normal(size=(4, 3))),
    nuclei=jnp.zeros((1, 3)),
)
embedding = FermiNet(256, [(256, 32), (256, 32)], jnp.tanh)
embedding = PsiFormer(256, 256, 4, 4, jnp.tanh)
# embedding = Moon(256, 4, 256, 64, 16, 6, jnp.tanh)
envelopes = EfficientEnvelope(1, 1, True, 8)
orbitals = Pfaffian(16, 8, envelopes, 10, 0.1, 1.0, 1e-4)
meta = MetaGNN(None, 64, 128, 4, jnp.tanh, 8, tuple(system.flat_charges))
wf = WaveFunction(embedding, orbitals, [])
g_wf = GeneralizedWaveFunction.create(wf, meta)
params = g_wf.init(jax.random.key(0), system)

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


In [7]:
system2 = system.replace(
    electrons=system.electrons[None],
    nuclei=system.nuclei,
)

In [4]:
g_wf.fix_structure(params, system)

GeneralizedWaveFunction(wave_function=WaveFunction(
    # attributes
    embedding_module = PsiFormer(
        # attributes
        embedding_dim = 256
        dim = 256
        n_head = 4
        n_layer = 4
        activation = tanh
    )
    orbital_module = Pfaffian(
        # attributes
        determinants = 16
        orb_per_nuc = 8
        envelopes = EfficientEnvelope(
            # attributes
            out_dim = 1
            pi_init = 1
            out_per_nuc = True
            env_per_nuc = 8
        )
        hf_match_steps = 10
        hf_match_lr = 0.1
        hf_match_orbitals = 1.0
        hf_match_pfaffian = 0.0001
    )
    jastrow_modules = []
), meta_network=MetaGNN(
    # attributes
    out_structure = {'embedding_module': {'FermiNetFeatures_0': {'kernel': ParamMeta(param_type=<ParamTypes.NUCLEI: ParamType(name='nuclei', chunk_fn=<function chunk_nuclei at 0x7f09c70bafc0>)>, shape_and_dtype=ShapeDtypeStruct(shape=(4, 256), dtype=float32), mean=Array(-0.02240709

In [8]:
g_wf.signed(params, system)

(Array([-1.], dtype=float32), Array([-0.35908234], dtype=float32))

In [10]:
jax.jit(g_wf.signed)(params, system)

(Array([-1.], dtype=float32), Array([-0.35908216], dtype=float32))

In [4]:
g_wf.signed(params, system.replace(electrons=system.electrons[np.array([1, 0, 2, 3])]))

(Array([-1.], dtype=float32), Array([-6.871541], dtype=float32))

In [11]:
class Test(ReparamModule):
    @nn.compact
    def __call__(self, i):
        self.param('x', lambda *_: jnp.zeros(2))
        return self.reparam(
            'test', lambda key: jnp.arange(i), param_type=ParamTypes.NUCLEI
        )[0].sum()

In [12]:
m = Test()
p = m.init(jax.random.PRNGKey(0), 1)

In [13]:
p

{'params': {'x': Array([0., 0.], dtype=float32)},
 'reparam': {'test': Array([0], dtype=int32)},
 'reparam_meta': {'test': ParamMeta(param_type=<ParamTypes.NUCLEI: ParamType(name='nuclei', chunk_fn=<function chunk_nuclei at 0x7fd083833d80>)>, shape_and_dtype=ShapeDtypeStruct(shape=(), dtype=int32), mean=Array(0., dtype=float32), std=Array(0., dtype=float32), bias=True, chunk_axis=None)}}

In [10]:
p['reparam']['test'] = jnp.array([0])
print(m.apply(p, 1))
p['reparam']['test'] = jnp.array([1, 2, 3])
print(m.apply(p, 3))

TypeError: 'NoneType' object is not subscriptable

In [6]:
p['reparam']['test'] = jnp.array([0, 1, 3])
print(m.apply(p, 1))

ValueError: Parameter test has shape (3,) but expected (1,)!