In [1]:
import jax.numpy as np
from jax import random, grad, jacfwd, jacrev, vmap
from jax.ops import index_add

key = random.PRNGKey(0)
key, subkey = random.split(key)



In [2]:
def hirschfelder(x, p):
    r = np.linalg.norm(x, axis=1)
    r1 = r[0]
    r2 = r[1]

    s = r1 + r2
    t = r1 - r2
    u = np.linalg.norm(np.subtract(x[1], x[0]))

    return np.exp(-2*s)*(1 + 0.5*u*np.exp(-p[0]*u))*(1 + p[1]*s*u + p[2]*np.power(t, 2) + p[3]*np.power(u, 2))

def hess(f): 
    return jacfwd(jacrev(f, 0), 0)

def lapl_evalat(f, x):
    """
    Evaluates (\nabla^2 f)(x) by taking the trace of the Hessian matrix of f
    """
    H = hess(f)(x).reshape(x.shape[0]*x.shape[1], x.shape[0]*x.shape[1])
    return np.trace(H)

In [144]:
xi = np.array([[2.0, 1.0, 1.1], [1.0, 1.0, 2.0]])
p0 = [1.013, 0.2119, 0.1406, -0.003]
n_equi = 100
n_iter = 1000
step = 0.5

def sample_dens(config, wf): 
    return np.power(np.abs(wf(config, p0)), 2) # unnormalized pi(R)

def get_configs(config_init, n_iter, n_equi, step_size, wf):
    """
    Carries out Metropolis-Hastings sampling according to the distribution |`wf`|**2.0.
    
    Performs `n_equi` equilibriation steps and `n_iter` sampling steps.
    """
    
    global key
    configs = []
    config = config_init

    i = 0
    accepted = 0
    while i < (n_iter + n_equi):
        config_idx = i % config_init.shape[0]

        key, subkey = random.split(key)
        new_config = index_add(config, config_idx, step_size*random.normal(subkey, (3,)))
        # Isotropic gaussian -> T / T' = 1 so only need pi(R') / pi(R)
        acceptance = np.array([1, sample_dens(new_config, wf) / sample_dens(config, wf)]).min()

        key, subkey = random.split(key)
        if random.uniform(subkey) < acceptance:
            config = new_config
            accepted += 1

        # Record position
        if i > n_equi:
            configs.append(config)
        i+=1
    return np.array(configs)

configs = get_configs(xi, n_iter, n_equi, step, hirschfelder)

In [155]:
def itime_hamiltonian(config, wf, tau=0.01):
    n_electron = config.shape[0]
    curr_wf = wf(config, p0)
    acc = 0
    # Calculate kinetic energy
    acc += -0.5*(1/curr_wf)*lapl_evalat(lambda x: wf(x, p0), config)
    # Calculate electron-electron energy
    for i in range(n_electron):
        for j in range(n_electron):
            if i < j:
                acc += 1 / np.linalg.norm(np.subtract(config[i], config[j]))

    # Calculate electron-nucleus energy, assume z=ne FOR NOW
    for i in range(n_electron):
        acc -= n_electron / np.linalg.norm(config[i])
    # Forget about nucleus - nucleus energy FOR NOW

    return 1-tau*acc

def sr_op(config, wf):
    gradlog = np.concatenate((np.array([1]), np.array(grad(lambda x, p: np.log(wf(x, p)), 1)(config, p0))))
    ih = itime_hamiltonian(config, wf)
    
    return np.multiply(gradlog, ih)

def overlap_matrix(config, wf):
    """
    Find the overlap matrix on the space of the parametric derivatives of `wf`
    """
    
    gradlog = np.concatenate((np.array([1]), np.array(grad(lambda x, p: np.log(wf(x, p)), 1)(config, p0))))
    overlap_ij = lambda i, j: gradlog[i]*gradlog[j]
    
    grid = np.indices((gradlog.shape[0], gradlog.shape[0]))
    
    return vmap(overlap_ij)(grid[0], grid[1])

In [156]:
# TODO: create wavefunction dataclass that holds current parameterization so i dont have to wrangle passing it around all the time

def monte_carlo(op, wf, configs):
    """
    Performs a Monte Carlo integration using the `configs` walker positions
    of the expectation value of `op` for the wavefunction `wf`.
    
    Returns the expectation value, variance and a list of the sampled values {O_i}
    """

    samp_rate = 20
    walker_values = vmap(lambda config: op(config, wf))(configs)
    blocks = np.array(np.split(
        walker_values[:samp_rate*(walker_values.shape[0]//samp_rate)],
        samp_rate,
        axis=0
        )
    )
    k = blocks.shape[0]
    block_means = np.mean(blocks, axis=0)
    op_expec = np.mean(block_means, axis=0)
    op_var = 1/(k*(k-1))*np.sum(np.power(block_means - op_expec, 2), axis=0)
    return op_expec, op_var, walker_values

def local_energy(config, wf):
    """
    Local energy operator. Uses JAX autograd to obtain laplacian for KE.
    """

    n_electron = config.shape[0]
    acc = 0
    # Calculate kinetic energy
    acc += -0.5*(1/wf(config, p0))*lapl_evalat(lambda x: wf(x, p0), config)
    # Calculate electron-electron energy
    for i in range(n_electron):
        for j in range(n_electron):
            if i < j:
                acc += 1 / np.linalg.norm(np.subtract(config[i], config[j]))

    # Calculate electron-nucleus energy, assume z=ne FOR NOW
    for i in range(n_electron):
        acc -= n_electron / np.linalg.norm(config[i])
    # Forget about nucleus - nucleus energy FOR NOW

    return acc

op_expec, op_var, walker_values = monte_carlo(overlap_matrix, hirschfelder, configs)
print(op_expec)
print(np.sqrt(op_var))

[[ 1.         -0.1741187   1.4552355   0.2558543   1.1524497 ]
 [-0.1741187   0.03338361 -0.28719205 -0.0499455  -0.23192196]
 [ 1.4552355  -0.28719205  2.6143005   0.4110514   2.1107445 ]
 [ 0.2558543  -0.0499455   0.4110514   0.1529913   0.33090666]
 [ 1.1524497  -0.23192196  2.1107445   0.33090666  1.7651697 ]]
[[0.         0.0036663  0.03745607 0.01057168 0.04353414]
 [0.0036663  0.00118213 0.01061693 0.00233131 0.01133536]
 [0.03745607 0.01061693 0.10590241 0.0193458  0.10980862]
 [0.01057168 0.00233131 0.0193458  0.01591917 0.01699207]
 [0.04353414 0.01133536 0.10980862 0.01699207 0.11294261]]


In [157]:
overlap_E, overlap_V, _ = monte_carlo(overlap_matrix, hirschfelder, configs)
coeff_E, coeff_V, _ = monte_carlo(sr_op, hirschfelder, configs)
print(overlap_E)
print(coeff_E)

[[ 1.         -0.1741187   1.4552355   0.2558543   1.1524497 ]
 [-0.1741187   0.03338361 -0.28719205 -0.0499455  -0.23192196]
 [ 1.4552355  -0.28719205  2.6143005   0.4110514   2.1107445 ]
 [ 0.2558543  -0.0499455   0.4110514   0.1529913   0.33090666]
 [ 1.1524497  -0.23192196  2.1107445   0.33090666  1.7651697 ]]
[ 1.0288726  -0.17916317  1.4973911   0.26326704  1.1857547 ]


In [158]:
from scipy.linalg import solve

solve(overlap_E, coeff_E)

array([ 1.0270861e+00, -1.5735919e-02,  9.9208474e-04, -2.8693629e-04,
       -2.0163895e-03], dtype=float32)

In [159]:
coeff_E

DeviceArray([ 1.0288726 , -0.17916317,  1.4973911 ,  0.26326704,
              1.1857547 ], dtype=float32)