In [5]:
import sys
sys.path.insert(0, "..")

import jax
import jax.numpy as jnp
import jax.random
import flax.linen as nn
import flax
from jax import grad, jit, vmap
from matplotlib import pyplot as plt


import module.samplers.mcmc
import module.hamiltonians
import module.wavefunctions

In [6]:
# Root PRNG key used to initialise the model parameters; a fixed seed keeps
# the notebook reproducible under Restart & Run All.
key = jax.random.PRNGKey(seed=0)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [9]:
class SimpleOrbital(module.wavefunctions.Wavefunction):
    # NOTE(review): empty placeholder subclass. The next cell redefines
    # `SimpleOrbital` as a flax `nn.Module`, so this definition is shadowed
    # immediately and never used. TODO: delete this cell or merge the two.
    ...

In [11]:
class SimpleOrbital(nn.Module):
    """Single-parameter exponential orbital, psi(x) = exp(-|x| / |a|).

    Calling the module returns the *log* amplitude

        log psi(x) = -|x| / |a|,

    where ``|x|`` is the Euclidean norm taken over the last axis of ``x``
    and ``a`` is a learnable scalar. Any leading axes of ``x`` are kept, so
    an input of shape ``(..., d)`` produces an output of shape ``(...)``.
    """

    @nn.compact
    def __call__(self, x):
        # `self.param(name, init_fn, *init_args)` registers a variational
        # parameter. Flax calls `init_fn(rng, *init_args)`, so the initializer
        # here receives shape `()` (a scalar) and dtype `float`.
        # NOTE(review): `nn.initializers.normal()` draws around zero with a
        # small default stddev (see flax docs), so |a| may start very small,
        # giving a sharply peaked orbital. Confirm this initial scale.
        a = self.param("a", nn.initializers.normal(), (), float)

        # Distance from the origin, reduced over the last (coordinate) axis.
        r = jnp.sqrt(jnp.sum(x**2, axis=-1))

        # log psi = -r / |a|. Using |a| keeps the orbital decaying whatever
        # sign the optimiser gives `a`.
        # NOTE(review): JAX's gradient of sqrt at 0 is inf, so d(log psi)/dx
        # evaluated exactly at x == 0 comes out NaN. Confirm callers
        # (e.g. a kinetic-energy Laplacian) never sample the exact origin.
        logpsi = -r / jnp.abs(a)

        # Return the log amplitude. Previously the raw squared radius was
        # returned here, which discarded `logpsi` and left `a` unused.
        return logpsi

In [27]:
def potential(x):
    """Isotropic harmonic potential V(x) = |x|^2 / 2.

    This is 1/2 m omega^2 |x|^2 with m = omega = 1. Coordinates are squared
    and summed over the last axis, so an array of shape ``(..., d)`` yields
    energies of shape ``(...)``.
    """
    squared_radius = jnp.square(x).sum(axis=-1)
    return 0.5 * squared_radius

In [28]:
model = SimpleOrbital()
# `init` uses the example input only to trace shapes; the parameter values
# are drawn from `key`. Any array with per-sample shape (3,) works here.
parameters = model.init(key, jnp.empty((3,)))

In [29]:
# Sanity check: evaluate the model at the origin of 3D space.
model.apply(parameters, jnp.zeros((3,)))

Array(0., dtype=float32)

In [30]:
# Zero-filled batch of shape (5, 3). Presumably 5 configurations with
# 3 coordinates each; TODO confirm the intended layout.
x = jnp.zeros(shape=(5, 3))

In [31]:
# Particle Hamiltonian built from `model` (trial wavefunction) and the
# harmonic `potential` defined earlier; `masses=[1]` presumably means a
# single particle of unit mass.
# `import module.hamiltonians` binds only the top-level name `module`, so
# the submodule must be reached through it. A bare `hamiltonians` would
# raise NameError on a fresh kernel.
# NOTE(review): assumes `module.hamiltonians` exposes its `Particles`
# submodule as an attribute (e.g. imported in its __init__). Confirm.
H = module.hamiltonians.Particles.Particles(model, masses = [1], potential = potential)