<a href="https://colab.research.google.com/github/gcassella/NN-VMC/blob/main/helium_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as np
import jax
from jax import random, grad, jacfwd, jacrev, vmap, jit, pmap
from jax.ops import index_add, index_update
from functools import partial

# Make sure the Colab Runtime is set to Accelerator: TPU.
import requests
import os
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1

# The following is required to use TPU Driver as JAX's backend.
from jax.config import config
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

key = random.PRNGKey(0)
key, subkey = random.split(key)

grpc://10.93.56.226:8470


# Stochastic reconfiguration w/ Hylleraas wavefunction

In [2]:
class Wavefunction():
  def __init__(self, f, p0):
    self.f = f
    self.p = p0

    self.hess = jacfwd(jacrev(lambda x: self.f(x, self.p), 0), 0)
    self.p_grad = grad(self.f, 1)
    self.p_gradlog = grad(lambda x, p: np.log(self.f(x, p)), 1)

    # Cache evaluations to speed up?
    self.p_gradlog_eval = jit(lambda x: self.p_gradlog(x, self.p))
    self.p_grad_eval = jit(lambda x: self.p_grad(x, self.p))
    self.lapl_eval = jit(lambda x: np.trace(self.hess(x).reshape(x.shape[0]*x.shape[1], x.shape[0]*x.shape[1])))
    self.eval = jit(lambda x: self.f(x, self.p))
    self.pdf_eval = jit(lambda x: np.power(np.abs(self.eval(x)), 2))


In [3]:
@jit
def hirschfelder_f(x, p):
    r = np.linalg.norm(x, axis=1)
    r1 = r[0]
    r2 = r[1]

    s = r1 + r2
    t = r1 - r2
    u = np.linalg.norm(np.subtract(x[1], x[0]))

    return np.exp(-2*s)*(1 + 0.5*u*np.exp(-p[0]*u))*(1 + p[1]*s*u + p[2]*np.power(t, 2) + p[3]*np.power(u, 2))

hirschfelder = Wavefunction(hirschfelder_f, np.array([1.0, 0.5, 0.5, -0.1]))

@jit
def simple_f(x, p):
    r = np.linalg.norm(x, axis=1)
    r1 = r[0]
    r2 = r[1]

    return np.exp(-p[0]*(r1 + r2))

simple = Wavefunction(simple_f, np.array([2.0]))

In [4]:
@partial(jit, static_argnums=(1,))
def config_step(key, wf, config, config_prob, config_idx, step_size):
    key, subkey = random.split(key)
    move_proposal = random.normal(key, shape=(config.shape[1],))*step_size
    proposal = index_add(config, config_idx%config.shape[0], move_proposal)
    proposal_prob = wf.pdf_eval(proposal)

    uniform = random.uniform(subkey)
    accept = uniform < (proposal_prob / config_prob)

    new_config = np.where(accept, proposal, config)
    config_prob = np.where(accept, proposal_prob, config_prob)
    return new_config, config_prob, config_idx+1

@partial(jit, static_argnums=(1, 2, 3, 4))
def get_configs(key, wf, n_iter, n_equi, step_size, initial_config):
    """
    Carries out Metropolis-Hastings sampling according to the distribution |`wf`|**2.0.
    
    Performs `n_equi` equilibriation steps and `n_iter` sampling steps.
    """
    
    def mh_update(i, state):
      key, config, prob, idx = state
      _, key = random.split(key)
      new_config, new_prob, new_idx = config_step(
          key,
          wf,
          config,
          prob,
          idx,
          step_size
      )
      return (key, new_config, new_prob, new_idx)

    def mh_update_and_store(i, state):
      key, config, prob, idx, configs = state
      _, key = random.split(key)
      new_config, new_prob, new_idx = config_step(
          key,
          wf,
          config,
          prob,
          idx,
          step_size
      )
      new_configs = index_update(configs, idx, new_config)
      return (key, new_config, new_prob, new_idx, new_configs)

    prob = wf.pdf_eval(initial_config)
    key, config, prob, idx = jax.lax.fori_loop(0, n_equi, mh_update, (key, initial_config, prob, 0))
    init_configs = np.zeros((n_iter, *initial_config.shape))
    key, config, prob, idx, configs = jax.lax.fori_loop(0, n_iter, mh_update_and_store, (key, config, prob, 0, init_configs))

    return configs

In [5]:
@partial(jit, static_argnums=(1,))
def itime_hamiltonian(config, wf, tau=0.01):
    n_electron = config.shape[0]
    curr_wf = wf.eval(config)
    acc = 0
    # Calculate kinetic energy
    acc += -0.5*(1/curr_wf)*wf.lapl_eval(config)
    # Calculate electron-electron energy
    for i in range(n_electron):
        for j in range(n_electron):
            if i < j:
                acc += 1 / np.linalg.norm(np.subtract(config[i], config[j]))

    # Calculate electron-nucleus energy, assume z=ne FOR NOW
    for i in range(n_electron):
        acc -= n_electron / np.linalg.norm(config[i])
    # Forget about nucleus - nucleus energy FOR NOW

    return 1-tau*acc

@partial(jit, static_argnums=(1,))
def sr_op(config, wf):
    gradlog = np.concatenate((np.array([1]), np.array(wf.p_gradlog_eval(config))))
    ih = itime_hamiltonian(config, wf)
    
    return np.multiply(gradlog, ih)

@partial(jit, static_argnums=(1,))
def overlap_matrix(config, wf):
    """
    Find the overlap matrix on the space of the parametric derivatives of `wf`
    """
    
    gradlog = np.concatenate((np.array([1]), np.array(wf.p_gradlog_eval(config))))
    overlap_ij = vmap(lambda idx: gradlog[idx[0]]*gradlog[idx[1]])
    
    grid_pairs = np.array([(i,j) for i in range(gradlog.shape[0]) for j in range(gradlog.shape[0])])
    
    return overlap_ij(grid_pairs).reshape(gradlog.shape[0], gradlog.shape[0])

@partial(jit, static_argnums=(1,))
def local_energy(config, wf):
    """
    Local energy operator. Uses JAX autograd to obtain laplacian for KE.
    """

    n_electron = config.shape[0]
    acc = 0
    # Calculate kinetic energy
    acc += -0.5*(1/wf.eval(config))*wf.lapl_eval(config)
    # Calculate electron-electron energy
    for i in range(n_electron):
        for j in range(n_electron):
            if i < j:
                acc += 1 / np.linalg.norm(np.subtract(config[i], config[j]))

    # Calculate electron-nucleus energy, assume z=ne FOR NOW
    for i in range(n_electron):
        acc -= n_electron / np.linalg.norm(config[i])

    return acc

In [6]:
@partial(jit, static_argnums=(1,2,))
def monte_carlo(configs, op, wf):
    """
    Performs a Monte Carlo integration using the `configs` walker positions
    of the expectation value of `op` for the wavefunction `wf`.
    
    Returns the expectation value, variance and a list of the sampled values {O_i}
    """

    samp_rate = 100
    walker_values = vmap(lambda config: op(config, wf))(configs)
    op_output_shape = walker_values[0].shape
    num_blocks = (walker_values.shape[0]//samp_rate)
    blocks = walker_values[:samp_rate*(num_blocks)].reshape((num_blocks, samp_rate, *op_output_shape))
    k = blocks.shape[0]
    block_means = np.mean(blocks, axis=1)
    op_expec = np.mean(block_means, axis=0)
    op_var = 1/(k*(k-1))*np.sum(np.power(block_means - op_expec, 2), axis=0)
    return op_expec, op_var

In [7]:
run_mcmc = vmap(get_configs, in_axes=(0, None, None, None, None, 0), out_axes=0)
run_int = vmap(monte_carlo, in_axes=(0, None, None), out_axes=0)

def reduce_mc_outs(outs):
  mean = np.mean(outs[0], axis=0)
  variance = np.mean(outs[1] + np.power(outs[0] - mean, 2), axis=0)
  return mean, variance

In [8]:
n_equi = 1000
n_iter = 10000
n_chains = 500
xis = random.uniform(key, (n_chains, 2, 3))
keys = random.split(key, n_chains)
configs = run_mcmc(keys, simple, n_iter, n_equi, 0.5, xis)
E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, simple))
overlap_E, overlap_V = reduce_mc_outs(run_int(configs, overlap_matrix, simple))
#sr_E, sr_V = reduce_mc_outs(run_int(configs, sr_op, simple))

In [9]:
E_E

DeviceArray(-2.7496672, dtype=float32)

## Simple WF

In [None]:
key = random.PRNGKey(0)
n_equi = 10000
n_iter = 100000
n_chains = 500
xis = random.uniform(key, (n_chains, 2, 3))
keys = random.split(key, n_chains)
simple = Wavefunction(simple_f, np.array([2.0]))
vals = [np.array(2.0)]

for i in range(40):
  configs = run_mcmc(keys, simple, n_iter, n_equi, 0.5, xis)
  E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, simple))
  overlap_E, overlap_V = reduce_mc_outs(run_int(configs, overlap_matrix, simple))
  sr_E, sr_V = reduce_mc_outs(run_int(configs, sr_op, simple))

  dps = np.linalg.solve(overlap_E, sr_E)
  p0 = np.add(simple.p, dps[1:] / dps[0])
  # VERY IMPORTANT NOTE: JAX will not re-jit the operators if Wavefunction.p
  # is updated internally by, e.g., a getter or setter. I don't know how to solve
  # this problem currently aside from simply reinstantiating Wavefunction each
  # time Wavefunction.p needs to be changed
  #
  # Perhaps this isn't such an issue if one sticks with a purely functional style
  # and uses classes like immutable structs?
  simple = Wavefunction(simple_f, p0)
  vals.append(p0)
  print(p0)

[1.9276963]
[1.8775854]
[1.840691]
[1.8123969]
[1.7903057]
[1.7726845]
[1.7584919]
[1.7468067]
[1.7373891]
[1.7295324]
[1.7229042]
[1.7173197]
[1.712626]
[1.7087704]
[1.7055061]
[1.702796]
[1.700402]
[1.6984478]
[1.696852]
[1.6954029]
[1.6942947]
[1.6933272]
[1.6924536]
[1.6917834]
[1.6911279]
[1.690638]
[1.6901832]
[1.6898696]
[1.6895639]
[1.689298]
[1.6891017]
[1.6889309]
[1.6887196]
[1.6885333]
[1.6884184]
[1.6882296]
[1.6881112]
[1.6879984]
[1.6878452]
[1.6877065]


In [None]:
configs = run_mcmc(keys, simple, n_iter, n_equi, 0.5, xis)
E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, simple))
print("Ground state energy {} pm {} after 20 iterations with parameter {}".format(E_E, np.sqrt(E_V), p0))

Ground state energy -2.8464293479919434 pm 0.0006964870844967663 after 20 iterations with parameter [1.6877065]


For reference the true minimum of $\langle E\rangle \simeq -2.85 \text{a. u.}$ of the simple wf ansatz occurs at $\alpha$=1.6875

## Hirschfelder wavefunction

In [None]:
n_equi = 10000
n_iter = 100000
n_chains = 100
xis = random.uniform(key, (n_chains, 2, 3))
keys = random.split(key, n_chains)
hirschfelder = Wavefunction(hirschfelder_f, np.array([1.0, 0.5, 0.5, -0.1]))
vals = [p0]

for i in range(40):
  configs = run_mcmc(keys, hirschfelder, n_iter, n_equi, 0.5, xis)
  E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, hirschfelder))
  overlap_E, overlap_V = reduce_mc_outs(run_int(configs, overlap_matrix, hirschfelder))
  sr_E, sr_V = reduce_mc_outs(run_int(configs, sr_op, hirschfelder))

  dps = np.linalg.solve(overlap_E, sr_E)
  p0 = np.add(hirschfelder.p, dps[1:] / dps[0])
  hirschfelder = Wavefunction(hirschfelder_f, p0)
  vals.append(p0)
  print(p0)

[ 1.4236624   0.5324303   0.44780415 -0.1185777 ]
[ 1.8274268   0.52555925  0.38857707 -0.14034285]
[ 2.1532216   0.5014799   0.34327382 -0.14871532]
[ 2.1948583   0.47658455  0.31286922 -0.14467369]
[ 2.0115626   0.4537761   0.29153362 -0.13413884]
[ 1.8770361   0.43206456  0.27365276 -0.12360602]
[ 1.7592784   0.41211852  0.2586345  -0.11351466]
[ 1.6649686   0.39407298  0.24570781 -0.10441782]
[ 1.577496    0.37777448  0.23474768 -0.09616143]
[ 1.5076677   0.3633409   0.22492754 -0.0890992 ]
[ 1.4467105   0.35067847  0.21633679 -0.08313622]
[ 1.3945475   0.33940798  0.2087992  -0.0780477 ]
[ 1.3490429   0.32962883  0.20208855 -0.07381274]
[ 1.3063763   0.3208327   0.19627199 -0.07010238]
[ 1.2701937   0.31305918  0.1910933  -0.06695889]
[ 1.2391235   0.3063077   0.18647681 -0.06442851]
[ 1.2104979   0.30031046  0.18252607 -0.06231608]
[ 1.187672    0.29506224  0.17885162 -0.06052488]
[ 1.1666455   0.29016247  0.17578799 -0.05881443]
[ 1.1456738   0.28589332  0.17293528 -0.0574354 ]


In [None]:
configs = run_mcmc(keys, hirschfelder, n_iter, n_equi, 0.5, xis)
E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, hirschfelder))
print("Ground state energy {} pm {} after 40 iterations with parameter {}".format(E_E, np.sqrt(E_V), p0))

Ground state energy -2.901717185974121 pm 0.00027340053929947317 after 40 iterations with parameter [ 0.9998273   0.2546163   0.1520364  -0.04904639]


In [None]:
print("This {}mHa from the true ground state energy".format(
    np.abs(-2.903 - E_E)*1e3
))

This 1.2829303741455078mHa from the true ground state energy
