# Neural Network Variational Monte Carlo

1 Dimensional, N bosons or N_up and N_down fermions. 

In [43]:
import os
import multiprocessing
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"
os.environ["JAX_ENABLE_X64"]="false"

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
from jax import numpy as jnp
from jax import random, jit, jacfwd, grad, vmap
from jax.tree_util import tree_map
from flax import linen as nn
from typing import Any
import numpy as np
import time
import tqdm 
from tqdm import trange
import gvar as gv
from functools import partial
import csv
import jax.example_libraries.optimizers as jax_opt
import plotly.graph_objects as go
import mc

In [None]:
bosonic = True
if bosonic:
    N = 3
else:
    N_up = 2
    N_down = 1

g = .1
omega = 1
harmonic_omega = 1

In [68]:

class MLP(nn.Module):
    features: list[int]  # e.g. [64, 64, 1]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.Dense(feat)(x)
            x = nn.celu(x)
        x = nn.Dense(self.features[-1])(x)
        return x


@jit
def transform(coords):
    # if running into NaNs, try to increase this
    ret = jnp.zeros(N)
    for i in range(N):
        ret = ret.at[i].set(jnp.sum(jnp.power(coords/C, i + 1)))
    return ret  

def create_model(rng_key, N, features):
    model = MLP(features=features)
    dummy_x = jnp.ones((N,))          # N inputs of length d → shape (N,d)
    params = model.init(rng_key, dummy_x)["params"]
    return model, params


In [69]:
model, params = create_model(random.PRNGKey(int(time.time())), N, [50, 50, 100, 50, 50,1])
print(model)
# print(params)

MLP(
    # attributes
    features = [50, 50, 100, 50, 50, 1]
)


In [70]:

@jit
def A(x, params):
    """Neural network output, A(x) given a set of parameters"""
    return model.apply({'params': params}, transform(x)).squeeze() + omega * jnp.sum(x**2)


@jit
def psi(x, params):
    """Wavefunction, psi(x) given a set of parameters"""
    return jnp.exp(-A(x, params))


print(psi(jnp.ones((N,)), params)) 

0.051811736


In [71]:
mcs = mc.Sampler(psi, (N, 1))

In [72]:

s,sp, ar = mcs.sample_parallel(params, 100000, 1000, 2, .7, jnp.zeros((N,)), num_chains=16)
print(ar)

0.5059033


In [73]:
# s,sp, ar = mcs.sample(params, 100000, 1000, 2, .7, jnp.zeros((N,)))
# print(s.shape)
# print(ar)

In [74]:
# derivative of the wavefunction with respect to the parameters
dnn_dtheta = jit(grad(psi, 1))
vdnn_dtheta = jit(vmap(dnn_dtheta, in_axes=(0, None), out_axes=0))


psi_hessian = jacfwd(jit(grad(psi, 0)), 0)
@jit
def ddpsi(coords, params):
    return jnp.diag(psi_hessian(coords, params))






In [75]:
if bosonic: 
        num_interactions = N * (N - 1) // 2
else:
    num_interactions = N_up * N_down

dA_dtheta = jit(grad(A, 1))
vdA_dtheta = vmap(dA_dtheta, in_axes=(0, None), out_axes=0)

dA_dx = jit(grad(A, 0))
vdA_dx = vmap(dA_dx, in_axes=(0, None), out_axes=0)

A_hessian = jacfwd(jit(grad(A, 0)), 0)
@jit
def d2A_dx2(coords, params):
    return jnp.diag(A_hessian(coords, params))


@jit
def Es_nodelta(coords, params):
    return - (1/2) * (1/ psi(coords, params)) * jnp.sum(ddpsi(coords, params)) + (1/2) * jnp.sum(coords**2) 
    # return (1/2) * jnp.sum(d2A_dx2(coords, params) - (dA_dx(coords, params))**2) + (1/2) * jnp.sum(coords**2) 

vEs_nodelta = vmap(Es_nodelta, in_axes=(0,None), out_axes=0)

@jit
def Es_delta(coords, coords_prime, params, alpha, g):
    return num_interactions * g * (psi(coords_prime, params)**2)/(psi(coords, params)**2) * (1/(jnp.sqrt(jnp.pi)*alpha))*jnp.exp(-(coords[-1]/alpha)**2)

vEs_delta = vmap(Es_delta, in_axes=(0,0, None, None, None), out_axes=0)

def pytree_scalar_mult(scalar, pytree):
    return tree_map(lambda x: scalar * x, pytree)

def pytree_add(a, b):
    return tree_map(lambda x, y: x + y, a, b)

@jit
def tree_mean(pytree_batched):
    return tree_map(lambda x: jnp.mean(x, axis=0), pytree_batched)


@jit
def gradient_comp(coords, coords_prime, params, es_nodelta, energy_calc, es_delta):
    return pytree_add(pytree_scalar_mult(-2 * (es_nodelta - energy_calc), dA_dtheta(coords, params)), pytree_scalar_mult(-2 * es_delta, dA_dtheta(coords_prime, params)))

vgradient_comp = vmap(gradient_comp, in_axes=(0, 0, None, 0, None, 0), out_axes=0)

def gradient(params, g, num_samples=10**3, thermal=200, skip=5, variation_size=1.0):

    samples, samples_prime, _ = mcs.sample(params, num_samples, thermal, skip, variation_size, jnp.zeros((N,)))
    # samples, samples_prime, _ = mcs.sample_parallel(params, num_samples, thermal, skip, variation_size, jnp.zeros((N,)), num_chains=16)

    ys = jnp.array(samples_prime[:, -1]) 
    alpha = jnp.sqrt(jnp.max(abs(jnp.array(ys)))**2/(-jnp.log(jnp.sqrt(jnp.pi)*(10**-10))))

    e_nodeltas = vEs_nodelta(samples, params)
    e_deltas = vEs_delta(samples, samples_prime, params, alpha, g)

    e_term = e_nodeltas + e_deltas
    
    # print(e_nodeltas)
    energy_calc = jnp.mean(e_term)
    
    # compute the uncertainty in the energy
    uncert = jnp.std(e_term)/jnp.sqrt(len(samples)) 
    # gradient computation
    grads = vgradient_comp(samples, samples_prime, params, e_nodeltas, energy_calc, e_deltas)

    # gradient_calc = jnp.mean(grads, axis=0)

    gradient_calc = tree_mean(grads)

    return gradient_calc, energy_calc, uncert




def step(params, opt_state, step_num, num_samples, thermal, skip, variation_size, g):
    """
    One optimization step.
    - params: current parameters pytree
    - opt_state: optimizer state (must be carried across steps)
    Returns: (new_params, new_opt_state, energy, uncert)
    """
    grad, energy, uncert = gradient(
        params,
        g,
        num_samples=num_samples,
        thermal=thermal,
        skip=skip,
        variation_size=variation_size,
    )

    new_opt_state = opt_update(step_num, grad, opt_state)
    new_params = get_params(new_opt_state)

    return new_params, new_opt_state, energy, uncert


def train(params, iterations, num_samples, thermal, skip, variation_size, g):
    """
    Training loop.
    Returns: (hs, us, ns, final_params)
    """
    hs, us = [], []
    ns = np.arange(iterations)

    # Initialize optimizer state ONCE
    opt_state = opt_init(params)

    pbar = trange(iterations, desc="", leave=True)
    old_params = params

    for step_num in pbar:
        new_params, opt_state, energy, uncert = step(
            old_params,
            opt_state,
            step_num,
            num_samples,
            thermal,
            skip,
            variation_size,
            g,
        )

        hs.append(energy)
        us.append(uncert)
        old_params = new_params

        pbar.set_description(f"Energy = {energy}", refresh=True)

        # Use jnp.isnan if energy is a JAX scalar; np.isnan is OK if it's a Python float
        if np.isnan(np.asarray(energy)):
            print("NaN encountered, stopping...")
            break

    return hs, us, ns, old_params
    

In [76]:
opt_init, opt_update, get_params = jax_opt.adam(10 ** (-3))

In [77]:
resultsa = train(params, 200, 5000, 100, 2, 0.7, g)

Energy = 1.5116232633590698: 100%|██████████| 200/200 [00:50<00:00,  3.95it/s]


In [78]:
# using plotly, plot the energy vs iteration number with error bars
def astra_energy():
    return (N * harmonic_omega)/2 - m * g**2  * (N*(N**2 - 1))/(24)

true_energy = astra_energy()

# true_energy = .5 * N

print("True energy: ", true_energy)

fig = go.Figure()
fig.add_trace(go.Scatter(
    x=np.arange(len(resultsa[0])),
    y=resultsa[0],
    error_y=dict(
        type='data',
        array=resultsa[1],
        visible=True
    ),
    mode='lines+markers',
    name='Energy vs Iteration'
))

# add a horizontal line for the true energy
fig.add_trace(go.Scatter(
    x=[0, len(resultsa[0])],
    y=[true_energy, true_energy],
    mode='lines',
    name='True Energy',
    line=dict(dash='dash', color='red')
))

fig.update_layout(
    title='VMC Energy Convergence',
    xaxis_title='Iteration Number',
    yaxis_title='Energy',
    template='plotly_dark'
)
fig.show()

True energy:  1.4999


In [None]:
# test = jnp.array([1.0, 2.0, 3.0])
# test2 = jnp.array([2.0, 1.0, 3.0])
# assert(psi(test, params) == psi(test2, params))

In [None]:
params2 = params.copy()
test = pytree_add(params, params2)
# test = pytree_scalar_mult(2.0, params)



print(params)
print(test)