In [72]:
import jax.numpy as np
from jax import random, grad, jacfwd, jacrev, vmap, jit
from jax.ops import index_add
from functools import partial

key = random.PRNGKey(0)
key, subkey = random.split(key)

# Stochastic reconfiguration w/ Hylleraas wavefunction

In [73]:
class Wavefunction():
    def __init__(self, wf, params):
        self.params = params
        self.wf = wf
    
    def eval(self, x):
        return wf(x, params)

In [74]:
@jit
def hirschfelder(x, p):
    r = np.linalg.norm(x, axis=1)
    r1 = r[0]
    r2 = r[1]

    s = r1 + r2
    t = r1 - r2
    u = np.linalg.norm(np.subtract(x[1], x[0]))

    return np.exp(-2*s)*(1 + 0.5*u*np.exp(-p[0]*u))*(1 + p[1]*s*u + p[2]*np.power(t, 2) + p[3]*np.power(u, 2))

@jit
def simple(x, p):
    r = np.linalg.norm(x, axis=1)
    r1 = r[0]
    r2 = r[1]

    return np.exp(-p[0]*(r1 + r2))

def hess(f): 
    return jacfwd(jacrev(f, 0), 0)

def lapl_evalat(f, x):
    """
    Evaluates (\nabla^2 f)(x) by taking the trace of the Hessian matrix of f
    """
    H = hess(f)(x).reshape(x.shape[0]*x.shape[1], x.shape[0]*x.shape[1])
    return np.trace(H)

In [75]:
def get_configs(config_init, n_iter, n_equi, step_size, wf):
    """
    Carries out Metropolis-Hastings sampling according to the distribution |`wf`|**2.0.
    
    Performs `n_equi` equilibriation steps and `n_iter` sampling steps.
    """
    
    global key
    configs = []
    config = config_init

    i = 0
    accepted = 0
    while i < (n_iter + n_equi):
        config_idx = i % config_init.shape[0]

        key, subkey = random.split(key)
        new_config = index_add(config, config_idx, step_size*random.normal(subkey, (3,)))
        # Isotropic gaussian -> T / T' = 1 so only need pi(R') / pi(R)
        
        sample_dens = jit(lambda x: np.power(np.abs(wf(x, p0)), 2))
        
        acceptance = min(1, sample_dens(new_config) / sample_dens(config))

        key, subkey = random.split(key)
        if random.uniform(subkey) < acceptance:
            config = new_config
            accepted += 1

        # Record position
        if i > n_equi:
            configs.append(config)
        i+=1
    return np.array(configs)

In [79]:
@partial(jit, static_argnums=(1,))
def itime_hamiltonian(config, wf, tau=0.1):
    n_electron = config.shape[0]
    curr_wf = wf(config, p0)
    acc = 0
    # Calculate kinetic energy
    acc += -0.5*(1/curr_wf)*lapl_evalat(lambda x: wf(x, p0), config)
    # Calculate electron-electron energy
    for i in range(n_electron):
        for j in range(n_electron):
            if i < j:
                acc += 1 / np.linalg.norm(np.subtract(config[i], config[j]))

    # Calculate electron-nucleus energy, assume z=ne FOR NOW
    for i in range(n_electron):
        acc -= n_electron / np.linalg.norm(config[i])
    # Forget about nucleus - nucleus energy FOR NOW

    return 1-tau*acc

@partial(jit, static_argnums=(1,))
def sr_op(config, wf):
    gradlog = np.concatenate((np.array([1]), np.array(jit(grad(lambda x, p: np.log(wf(x, p)), 1))(config, p0))))
    ih = itime_hamiltonian(config, wf)
    
    return np.multiply(gradlog, ih)

@partial(jit, static_argnums=(1,))
def overlap_matrix(config, wf):
    """
    Find the overlap matrix on the space of the parametric derivatives of `wf`
    """
    
    gradlog = np.concatenate((np.array([1]), np.array(jit(grad(lambda x, p: np.log(wf(x, p)), 1))(config, p0))))
    overlap_ij = lambda i, j: gradlog[i]*gradlog[j]
    
    grid = np.indices((gradlog.shape[0], gradlog.shape[0]))
    
    return vmap(overlap_ij)(grid[0], grid[1])

In [85]:
# TODO: create wavefunction dataclass that holds current parameterization so i dont have to wrangle passing it around all the time

def monte_carlo(op, wf, configs):
    """
    Performs a Monte Carlo integration using the `configs` walker positions
    of the expectation value of `op` for the wavefunction `wf`.
    
    Returns the expectation value, variance and a list of the sampled values {O_i}
    """

    samp_rate = 20
    walker_values = vmap(jit(lambda config: op(config, wf)))(configs)
    blocks = np.array(np.split(
        walker_values[:samp_rate*(walker_values.shape[0]//samp_rate)],
        samp_rate,
        axis=0
        )
    )
    k = blocks.shape[0]
    block_means = np.mean(blocks, axis=0)
    op_expec = np.mean(block_means, axis=0)
    op_var = 1/(k*(k-1))*np.sum(np.power(block_means - op_expec, 2), axis=0)
    return op_expec, op_var, walker_values

def local_energy(config, wf):
    """
    Local energy operator. Uses JAX autograd to obtain laplacian for KE.
    """

    n_electron = config.shape[0]
    acc = 0
    # Calculate kinetic energy
    acc += -0.5*(1/wf(config, p0))*lapl_evalat(lambda x: wf(x, p0), config)
    # Calculate electron-electron energy
    for i in range(n_electron):
        for j in range(n_electron):
            if i < j:
                acc += 1 / np.linalg.norm(np.subtract(config[i], config[j]))

    # Calculate electron-nucleus energy, assume z=ne FOR NOW
    for i in range(n_electron):
        acc -= n_electron / np.linalg.norm(config[i])

    return acc

In [87]:
from time import time
t = time()
p0 = [1.0]
xi = np.array([[2.0, 1.0, 1.1], [1.0, 1.0, 2.0]])
n_equi = 10
n_iter = 1000
step = 0.5

configs = get_configs(xi, n_iter, n_equi, step, simple)
overlap_E, overlap_V, _ = monte_carlo(overlap_matrix, simple, configs)
coeff_E, coeff_V, _ = monte_carlo(sr_op, simple, configs)
print(overlap_E)
print(coeff_E)
print(time() - t)

[[ 1.       -3.242458]
 [-3.242458 12.447595]]
[ 1.2435613 -3.9221253]
62.737221002578735


In [None]:
from scipy.linalg import solve

dp = solve(overlap_E, coeff_E)
p0 += dp[1:] / dp[0]

In [None]:
np.sqrt(coeff_V)

In [None]:
dp[1:] / dp[0] / p0

# Multi layer perceptron wavefunction

In [14]:
# Code reproduced with modifications from
# https://github.com/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb

# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
      w_key, b_key = random.split(key)
      return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [6, 32, 32, 32, 1]
params = init_network_params(layer_sizes, random.PRNGKey(0))

def tanh(x):
    return np.tanh(x)

def predict(x, p):
    # per-example predictions
    activations = x.flatten()
    params = p
    for w, b in params[:-1]:
        outputs = np.dot(w, activations) + b
        activations = tanh(outputs)
    
    final_w, final_b = params[-1]
    return (np.dot(final_w, activations) + final_b)[0]

In [15]:
p0 = params

In [16]:
%%timeit
predict(xi, p0)

12.5 ms ± 1.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [44]:
sample_dens(xi, predict)

DeviceArray([0.00051011], dtype=float32)

In [49]:
e_expec, e_var, _ = monte_carlo(local_energy, predict, configs)

In [59]:
grad(predict)(xi, p0)

DeviceArray([[-1.1215644e-06,  9.2081854e-07,  1.2777372e-06],
             [-2.7824185e-06,  8.5019434e-07,  2.5524919e-06]],            dtype=float32)

In [61]:
predict

ValueError: All input arrays must have the same shape.

ValueError: All input arrays must have the same shape.