<a href="https://colab.research.google.com/github/gcassella/NN-VMC/blob/main/helium_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax.numpy as np
import jax
from jax import random, grad, jacfwd, jacrev, vmap, jit, pmap
from jax.ops import index_add, index_update
from functools import partial

key = random.PRNGKey(0)
key, subkey = random.split(key)

# Stochastic reconfiguration w/ Hylleraas wavefunction

In [2]:
class Wavefunction():
  def __init__(self, f, p0):
    self.f = f
    self.p = p0

    self.hess = jacfwd(jacrev(lambda x: self.f(x, self.p), 0), 0)
    self.p_grad = grad(self.f, 1)
    self.p_gradlog = grad(lambda x, p: np.log(self.f(x, p)), 1)

    # Cache evaluations to speed up?
    self.p_gradlog_eval = jit(lambda x: self.p_gradlog(x, self.p))
    self.p_grad_eval = jit(lambda x: self.p_grad(x, self.p))
    self.lapl_eval = jit(lambda x: np.trace(self.hess(x).reshape(x.shape[0]*x.shape[1], x.shape[0]*x.shape[1])))
    self.eval = jit(lambda x: self.f(x, self.p))
    self.pdf_eval = jit(lambda x: np.power(np.abs(self.eval(x)), 2))


In [3]:
@jit
def hirschfelder_f(x, p):
    r = np.linalg.norm(x, axis=1)
    r1 = r[0]
    r2 = r[1]

    s = r1 + r2
    t = r1 - r2
    u = np.linalg.norm(np.subtract(x[1], x[0]))

    return np.exp(-2*s)*(1 + 0.5*u*np.exp(-p[0]*u))*(1 + p[1]*s*u + p[2]*np.power(t, 2) + p[3]*np.power(u, 2))

hirschfelder = Wavefunction(hirschfelder_f, np.array([1.0, 0.5, 0.5, -0.1]))

@jit
def simple_f(x, p):
    r = np.linalg.norm(x, axis=1)
    r1 = r[0]
    r2 = r[1]

    return np.exp(-p[0]*(r1 + r2))

simple = Wavefunction(simple_f, np.array([2.0]))

In [4]:
@partial(jit, static_argnums=(1,))
def config_step(key, wf, config, config_prob, config_idx, step_size):
    key, subkey = random.split(key)
    move_proposal = random.normal(key, shape=(config.shape[1],))*step_size
    proposal = index_add(config, config_idx%config.shape[0], move_proposal)
    proposal_prob = wf.pdf_eval(proposal)

    uniform = random.uniform(subkey)
    accept = uniform < (proposal_prob / config_prob)

    new_config = np.where(accept, proposal, config)
    config_prob = np.where(accept, proposal_prob, config_prob)
    return new_config, config_prob, config_idx+1

@partial(jit, static_argnums=(1, 2, 3, 4))
def get_configs(key, wf, n_iter, n_equi, step_size, initial_config):
    """
    Carries out Metropolis-Hastings sampling according to the distribution |`wf`|**2.0.
    
    Performs `n_equi` equilibriation steps and `n_iter` sampling steps.
    """
    
    def mh_update(i, state):
      key, config, prob, idx = state
      _, key = random.split(key)
      new_config, new_prob, new_idx = config_step(
          key,
          wf,
          config,
          prob,
          idx,
          step_size
      )
      return (key, new_config, new_prob, new_idx)

    def mh_update_and_store(i, state):
      key, config, prob, idx, configs = state
      _, key = random.split(key)
      new_config, new_prob, new_idx = config_step(
          key,
          wf,
          config,
          prob,
          idx,
          step_size
      )
      new_configs = index_update(configs, idx, new_config)
      return (key, new_config, new_prob, new_idx, new_configs)

    prob = wf.pdf_eval(initial_config)
    key, config, prob, idx = jax.lax.fori_loop(0, n_equi, mh_update, (key, initial_config, prob, 0))
    init_configs = np.zeros((n_iter, *initial_config.shape))
    key, config, prob, idx, configs = jax.lax.fori_loop(0, n_iter, mh_update_and_store, (key, config, prob, 0, init_configs))

    return configs

In [5]:
@partial(jit, static_argnums=(1,))
def itime_hamiltonian(config, wf, tau=0.01):
    n_electron = config.shape[0]
    curr_wf = wf.eval(config)
    acc = 0
    # Calculate kinetic energy
    acc += -0.5*(1/curr_wf)*wf.lapl_eval(config)
    # Calculate electron-electron energy
    for i in range(n_electron):
        for j in range(n_electron):
            if i < j:
                acc += 1 / np.linalg.norm(np.subtract(config[i], config[j]))

    # Calculate electron-nucleus energy, assume z=ne FOR NOW
    for i in range(n_electron):
        acc -= n_electron / np.linalg.norm(config[i])
    # Forget about nucleus - nucleus energy FOR NOW

    return 1-tau*acc

@partial(jit, static_argnums=(1,))
def sr_op(config, wf):
    gradlog = np.concatenate((np.array([1]), np.array(wf.p_gradlog_eval(config))))
    ih = itime_hamiltonian(config, wf)
    
    return np.multiply(gradlog, ih)

@partial(jit, static_argnums=(1,))
def overlap_matrix(config, wf):
    """
    Find the overlap matrix on the space of the parametric derivatives of `wf`
    """
    
    gradlog = np.concatenate((np.array([1]), np.array(wf.p_gradlog_eval(config))))
    overlap_ij = vmap(lambda idx: gradlog[idx[0]]*gradlog[idx[1]])
    
    grid_pairs = np.array([(i,j) for i in range(gradlog.shape[0]) for j in range(gradlog.shape[0])])
    
    return overlap_ij(grid_pairs).reshape(gradlog.shape[0], gradlog.shape[0])

@partial(jit, static_argnums=(1,))
def local_energy(config, wf):
    """
    Local energy operator. Uses JAX autograd to obtain laplacian for KE.
    """

    n_electron = config.shape[0]
    acc = 0
    # Calculate kinetic energy
    acc += -0.5*(1/wf.eval(config))*wf.lapl_eval(config)
    # Calculate electron-electron energy
    for i in range(n_electron):
        for j in range(n_electron):
            if i < j:
                acc += 1 / np.linalg.norm(np.subtract(config[i], config[j]))

    # Calculate electron-nucleus energy, assume z=ne FOR NOW
    for i in range(n_electron):
        acc -= n_electron / np.linalg.norm(config[i])

    return acc

In [6]:
@partial(jit, static_argnums=(1,2,))
def monte_carlo(configs, op, wf):
    """
    Performs a Monte Carlo integration using the `configs` walker positions
    of the expectation value of `op` for the wavefunction `wf`.
    
    Returns the expectation value, variance and a list of the sampled values {O_i}
    """

    samp_rate = 100
    walker_values = vmap(lambda config: op(config, wf))(configs)
    op_output_shape = walker_values[0].shape
    num_blocks = (walker_values.shape[0]//samp_rate)
    blocks = walker_values[:samp_rate*(num_blocks)].reshape((num_blocks, samp_rate, *op_output_shape))
    k = blocks.shape[0]
    block_means = np.mean(blocks, axis=1)
    op_expec = np.mean(block_means, axis=0)
    op_var = 1/(k*(k-1))*np.sum(np.power(block_means - op_expec, 2), axis=0)
    return op_expec, op_var

In [7]:
run_mcmc = vmap(get_configs, in_axes=(0, None, None, None, None, 0), out_axes=0)
run_int = vmap(monte_carlo, in_axes=(0, None, None), out_axes=0)

def reduce_mc_outs(outs):
  k = outs[0].shape[0]
  mean = np.mean(outs[0], axis=0)
  variance = (1/k/(k-1))*np.sum(outs[1] + np.power(outs[0] - mean, 2), axis=0)
  return mean, variance

In [None]:
n_equi = 1000
n_iter = 10000
n_chains = 500
xis = random.uniform(key, (n_chains, 2, 3))
keys = random.split(key, n_chains)
configs = run_mcmc(keys, simple, n_iter, n_equi, 0.5, xis)
E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, simple))
overlap_E, overlap_V = reduce_mc_outs(run_int(configs, overlap_matrix, simple))
#sr_E, sr_V = reduce_mc_outs(run_int(configs, sr_op, simple))

In [None]:
E_E

DeviceArray(-2.7496672, dtype=float32)

## Simple WF

In [None]:
key = random.PRNGKey(0)
n_equi = 10000
n_iter = 100000
n_chains = 500
xis = random.uniform(key, (n_chains, 2, 3))
keys = random.split(key, n_chains)
simple = Wavefunction(simple_f, np.array([2.0]))
vals = [np.array(2.0)]

for i in range(40):
  configs = run_mcmc(keys, simple, n_iter, n_equi, 0.5, xis)
  E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, simple))
  overlap_E, overlap_V = reduce_mc_outs(run_int(configs, overlap_matrix, simple))
  sr_E, sr_V = reduce_mc_outs(run_int(configs, sr_op, simple))

  dps = np.linalg.solve(overlap_E, sr_E)
  p0 = np.add(simple.p, dps[1:] / dps[0])
  # VERY IMPORTANT NOTE: JAX will not re-jit the operators if Wavefunction.p
  # is updated internally by, e.g., a getter or setter. I don't know how to solve
  # this problem currently aside from simply reinstantiating Wavefunction each
  # time Wavefunction.p needs to be changed
  #
  # Perhaps this isn't such an issue if one sticks with a purely functional style
  # and uses classes like immutable structs?
  simple = Wavefunction(simple_f, p0)
  vals.append(p0)
  print(p0)

[1.9276963]
[1.8775854]
[1.840691]
[1.8123969]
[1.7903057]
[1.7726845]
[1.7584919]
[1.7468067]
[1.7373891]
[1.7295324]
[1.7229042]
[1.7173197]
[1.712626]
[1.7087704]
[1.7055061]
[1.702796]
[1.700402]
[1.6984478]
[1.696852]
[1.6954029]
[1.6942947]
[1.6933272]
[1.6924536]
[1.6917834]
[1.6911279]
[1.690638]
[1.6901832]
[1.6898696]
[1.6895639]
[1.689298]
[1.6891017]
[1.6889309]
[1.6887196]
[1.6885333]
[1.6884184]
[1.6882296]
[1.6881112]
[1.6879984]
[1.6878452]
[1.6877065]


In [None]:
configs = run_mcmc(keys, simple, n_iter, n_equi, 0.5, xis)
E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, simple))
print("Ground state energy {} pm {} after 20 iterations with parameter {}".format(E_E, np.sqrt(E_V), p0))

Ground state energy -2.8464293479919434 pm 0.0006964870844967663 after 20 iterations with parameter [1.6877065]


For reference the true minimum of $\langle E\rangle \simeq -2.85 \text{a. u.}$ of the simple wf ansatz occurs at $\alpha$=1.6875

## Hirschfelder wavefunction

In [None]:
n_equi = 10000
n_iter = 100000
n_chains = 100
xis = random.uniform(key, (n_chains, 2, 3))
keys = random.split(key, n_chains)
hirschfelder = Wavefunction(hirschfelder_f, np.array([1.0, 0.5, 0.5, -0.1]))
vals = [p0]

for i in range(40):
  configs = run_mcmc(keys, hirschfelder, n_iter, n_equi, 0.5, xis)
  E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, hirschfelder))
  overlap_E, overlap_V = reduce_mc_outs(run_int(configs, overlap_matrix, hirschfelder))
  sr_E, sr_V = reduce_mc_outs(run_int(configs, sr_op, hirschfelder))

  dps = np.linalg.solve(overlap_E, sr_E)
  p0 = np.add(hirschfelder.p, dps[1:] / dps[0])
  hirschfelder = Wavefunction(hirschfelder_f, p0)
  vals.append(p0)
  print(p0)

[ 1.4236624   0.5324303   0.44780415 -0.1185777 ]
[ 1.8274268   0.52555925  0.38857707 -0.14034285]
[ 2.1532216   0.5014799   0.34327382 -0.14871532]
[ 2.1948583   0.47658455  0.31286922 -0.14467369]
[ 2.0115626   0.4537761   0.29153362 -0.13413884]
[ 1.8770361   0.43206456  0.27365276 -0.12360602]
[ 1.7592784   0.41211852  0.2586345  -0.11351466]
[ 1.6649686   0.39407298  0.24570781 -0.10441782]
[ 1.577496    0.37777448  0.23474768 -0.09616143]
[ 1.5076677   0.3633409   0.22492754 -0.0890992 ]
[ 1.4467105   0.35067847  0.21633679 -0.08313622]
[ 1.3945475   0.33940798  0.2087992  -0.0780477 ]
[ 1.3490429   0.32962883  0.20208855 -0.07381274]
[ 1.3063763   0.3208327   0.19627199 -0.07010238]
[ 1.2701937   0.31305918  0.1910933  -0.06695889]
[ 1.2391235   0.3063077   0.18647681 -0.06442851]
[ 1.2104979   0.30031046  0.18252607 -0.06231608]
[ 1.187672    0.29506224  0.17885162 -0.06052488]
[ 1.1666455   0.29016247  0.17578799 -0.05881443]
[ 1.1456738   0.28589332  0.17293528 -0.0574354 ]


In [None]:
configs = run_mcmc(keys, hirschfelder, n_iter, n_equi, 0.5, xis)
E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, hirschfelder))
print("Ground state energy {} pm {} after 40 iterations with parameter {}".format(E_E, np.sqrt(E_V), p0))

Ground state energy -2.901717185974121 pm 0.00027340053929947317 after 40 iterations with parameter [ 0.9998273   0.2546163   0.1520364  -0.04904639]


In [None]:
print("This {}mHa from the true ground state energy".format(
    np.abs(-2.903 - E_E)*1e3
))

This 1.2829303741455078mHa from the true ground state energy


## Hirschfelder-type wavefunction with NNs

$$\psi=e^{-2\left(r_{1}+r_{2}\right)}\left(1+\frac{1}{2} r_{12} e^{-\alpha r_{12}}\right) g\left(r_{1}, r_{2}, r_{12}\right)$$

In [22]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1):
  w_key, b_key = random.split(key)
  return [scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))]

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

@jit
def tanh(x):
    return np.tanh(x)

def predict(x, params):
  # per-example predictions
  r = np.linalg.norm(x, axis=1)
  r1 = r[0]
  r2 = r[1]
  u = np.linalg.norm(np.subtract(x[1], x[0]))

  activations = np.array([r1, r2, u])
  for w, b in params[:-1]:
    outputs = np.dot(w, activations) + b
    activations = tanh(outputs)
  
  final_w, final_b = params[-1]
  outputs = np.dot(final_w, activations) + final_b
  return outputs[0]
 
layer_sizes = [3, 12, 12, 1]
key, subkey = random.split(key)
params = init_network_params(layer_sizes, key)

Pretrain params to match $g(r1,r2,r12) = \left(1+0.2119 s u+0.1406 t^{2}-0.003 u^{2}\right)$

In [9]:
def g(x):
  r = np.linalg.norm(x, axis=1)
  r1 = r[0]
  r2 = r[1]

  s = r1 + r2
  t = r1 - r2
  u = np.linalg.norm(np.subtract(x[1], x[0]))
  return (1 + 0.2119*s*u + 0.1406*t**2.0 - 0.003*u**2.0)

g(np.array([[2.0, 1.0, 1.0], [1.0,1.0,2.0]]))

DeviceArray(2.4620862, dtype=float32)

In [10]:
from jax.experimental import optimizers

In [23]:
print(params)

[[DeviceArray([[ 0.17032038, -0.15212938,  0.4953925 ],
             [ 0.55532193, -1.4711964 ,  0.00492484],
             [ 1.8502427 ,  0.49203628,  1.61406   ],
             [-1.4216027 ,  1.19387   , -0.23470964],
             [-1.0294129 , -0.8314868 , -0.2624491 ],
             [-2.29874   , -0.7720694 ,  0.68447214],
             [-1.0402237 , -0.6654453 ,  0.21445496],
             [ 0.19351044, -0.0628294 ,  0.67243254],
             [ 0.12465691,  0.3315257 , -0.27397084],
             [-0.08468974,  0.0333338 , -0.33523253],
             [-1.5351669 ,  1.7267078 , -1.5075749 ],
             [ 0.7374547 , -0.41872683,  0.21236524]], dtype=float32), DeviceArray([ 1.2887936 , -1.1984264 , -0.13047883, -0.13376504,
              0.14090303,  1.2809039 , -0.882967  ,  0.38619992,
             -0.43700072, -0.10178029, -0.40020615, -1.0684481 ],            dtype=float32)], [DeviceArray([[ 5.81776500e-01, -1.08965719e+00, -1.66196501e+00,
              -4.35995281e-01, -3.88160884e

In [11]:
batch_size = 1000
opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)

def loss(params, inputs, targets):
    # Computes average loss for the batch
    predictions = vmap(predict, in_axes=(0, None))(inputs, params)
    return np.mean((targets - predictions)**2.0)

opt_state = opt_init(params)

@jit
def step(i, opt_state, x1, y1):
    p = get_params(opt_state)
    v, g = jax.value_and_grad(loss)(p, x1, y1)
    return v, opt_update(i, g, opt_state)

for i in range(1000):
  key, subkey = jax.random.split(key)
  xis = random.uniform(key, (batch_size, 2, 3))
  yis = vmap(g)(xis)
             
  v, opt_state = step(i, opt_state, xis, yis)
  print(v)

params = get_params(opt_state)

0.9622067
0.57115227
0.6202285
0.5641536
0.5099169
0.35819966
0.27679816
0.27328286
0.29929578
0.25216475
0.219412
0.1625536
0.15355973
0.11677519
0.10389205
0.10031447
0.093164496
0.0881365
0.07747473
0.05603502
0.041122366
0.041176233
0.040185124
0.03940032
0.03855387
0.033953447
0.032044124
0.02646026
0.023414023
0.024791678
0.024816183
0.027833462
0.025229642
0.021677772
0.023617515
0.01920466
0.020640142
0.021336377
0.023836236
0.021429835
0.020081155
0.018534623
0.018833501
0.017895047
0.016987296
0.020895587
0.018035853
0.018343316
0.017376853
0.017063968
0.015897049
0.016556155
0.01710117
0.016812699
0.015962474
0.014042572
0.014829695
0.014093005
0.01362275
0.013545292
0.012196322
0.013471211
0.0128430985
0.012355731
0.011464764
0.013038701
0.01049162
0.012264994
0.011119368
0.011652776
0.010454827
0.011171997
0.010785171
0.009996867
0.009987246
0.00907038
0.010144171
0.009574603
0.008720641
0.008558175
0.008256233
0.008788521
0.0074349474
0.008565261
0.008287347
0.008347998
0

In [None]:
def nn_hylleraas(x, params):
    r = np.linalg.norm(x, axis=1)
    r1 = r[0]
    r2 = r[1]

    s = r1 + r2
    t = r1 - r2
    u = np.linalg.norm(np.subtract(x[1], x[0]))
    return np.exp(-2*s)*(1 + 0.5*u*np.exp(-u))*predict(x, params)

nn_hylleraas_wf = Wavefunction(nn_hylleraas, params)
print(nn_hylleraas_wf.p_gradlog_eval(np.array([[2.0, 1.0, 1.0], [1.0, 1.0, 2.0]])))
print(hirschfelder_f(np.array([[2.0, 1.0, 1.0], [1.0, 1.0, 2.0]]), [1.013, 0.2119, 0.1406, -0.003]))

In [16]:
# I don't like this but i can't think of a more elegant way of evaluating
# these operators atm without writing custom code for the ML wavefunction
# that unrolls the parameter list

@partial(jit, static_argnums=(1,))
def sr_op_ml(config, wf):
    gradlog = wf.p_gradlog_eval(config)
    ih = itime_hamiltonian(config, wf)
    
    # reuse gradlog to save memory
    gradlog = np.concatenate((np.array([1]), np.concatenate(tuple(np.concatenate((glw.flatten(), gb.flatten())) for (glw, gb) in gradlog))))
    return np.multiply(ih, gradlog)

@partial(jit, static_argnums=(1,))
def overlap_matrix_ml(config, wf):
    """
    Find the overlap matrix on the space of the parametric derivatives of `wf`
    """
    
    gradlog = wf.p_gradlog_eval(config)
    gradlog = np.concatenate((np.array([1]), np.concatenate(tuple(np.concatenate((glw.flatten(), gb.flatten())) for (glw, gb) in gradlog))))
    overlap_ij = vmap(lambda idx: gradlog[idx[0]]*gradlog[idx[1]])
    
    grid_pairs = np.array([(i,j) for i in range(gradlog.shape[0]) for j in range(gradlog.shape[0])])
    
    return overlap_ij(grid_pairs).reshape(gradlog.shape[0], gradlog.shape[0])

In [14]:
n_equi = 100
n_iter = 10000
n_chains = 500
xis = random.uniform(key, (n_chains, 2, 3))
keys = random.split(key, n_chains)
configs = run_mcmc(keys, nn_hylleraas_wf, n_iter, n_equi, 0.5, xis)
# Using VMAP here causes big time memory issues on devices with low memory
# I believe this is because JAX copies the wavefunction parameters to each
# vmap thread that is executing -> n_iter*n_chains*len(params) float32s
# which quickly runs into the hundreds of GB. what is the workaround for this?
# Surely a solved problem?

# Regardless we are still vmapping over each n_iter set of configs inside the
# monte_carlo function, so we incur n_chains serial executions
E_E, E_V = reduce_mc_outs(jax.lax.map(lambda x: monte_carlo(x, local_energy, nn_hylleraas_wf), configs))
overlap_E, overlap_V = reduce_mc_outs(jax.lax.map(lambda x: monte_carlo(x, overlap_matrix_ml, nn_hylleraas_wf), configs))
sr_E, sr_V = reduce_mc_outs(jax.lax.map(lambda x: monte_carlo(x, sr_op_ml, nn_hylleraas_wf), configs))

In [15]:
E_E

DeviceArray(-2.8516073, dtype=float32)

In [None]:
from jax.scipy.sparse.linalg import cg
layer_sizes = [3, 12, 12, 1]
params = list(np.load('good_nn.npy',allow_pickle=True))
params = np.concatenate(tuple(np.concatenate((w.flatten(), b.flatten())) for (w, b) in params))
p_wrapped = []
idx=0
for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):
    p_wrapped.append(
        [params[idx:idx + m*n].reshape((n, m)), params[idx + m*n:idx + (m+1)*(n)]]
    )
    idx += (m+1)*(n)
params = p_wrapped
params

In [45]:
n_equi = 1000
n_iter = 10000
n_chains = 300
xis = random.uniform(key, (n_chains, 2, 3))
keys = random.split(key, n_chains+1)
ml_wf = Wavefunction(nn_hylleraas, params)

p_wrapped = params

for i in range(400):
  keys = random.split(keys[-1], n_chains+1)
  configs = run_mcmc(keys[:-1], ml_wf, n_iter, n_equi, 0.5, xis)
  E_E, E_V = reduce_mc_outs(run_int(configs, local_energy, ml_wf))

  def odotx(x):
      @partial(jit, static_argnums=(1,))
      def op(c, w):
        gradlog = w.p_gradlog_eval(c)
        gradlog = np.concatenate((np.array([1]), np.concatenate(tuple(np.concatenate((glw.flatten(), gb.flatten())) for (glw, gb) in gradlog))))

        return np.multiply(gradlog, np.dot(gradlog, x))

      E, V = reduce_mc_outs(run_int(configs, op, ml_wf))
      return E

  sr_E, sr_V = reduce_mc_outs(run_int(configs, sr_op_ml, ml_wf))

  dps, _ = cg(odotx, sr_E)
  p_flat = np.concatenate(tuple(np.concatenate((w.flatten(), b.flatten())) for (w, b) in p_wrapped))
  dps = dps[1:] / dps[0]
  p_flat = np.add(p_flat, dps)

  sizes = layer_sizes
  idx = 0
  p_wrapped = []
  for m, n in zip(sizes[:-1], sizes[1:]):
    p_wrapped.append(
        [p_flat[idx:idx + m*n].reshape((n, m)), p_flat[idx + m*n:idx + (m+1)*(n)]]
    )
    idx += (m+1)*(n)

  ml_wf = Wavefunction(nn_hylleraas, p_wrapped)
  print("{} pm {} at step {}".format(E_E, np.sqrt(E_V), i))

-2.899165391921997 pm 0.0006035205442458391 at step 0
-2.898836135864258 pm 0.0005917842499911785 at step 1
-2.8989243507385254 pm 0.0005316136521287262 at step 2
-2.8994643688201904 pm 0.0005737557075917721 at step 3
-2.899426221847534 pm 0.0005650423699989915 at step 4
-2.9001028537750244 pm 0.0005424456321634352 at step 5
-2.8991734981536865 pm 0.000598137965425849 at step 6
-2.8997559547424316 pm 0.0006672914605587721 at step 7
-2.8993923664093018 pm 0.0006329311290755868 at step 8
-2.900200843811035 pm 0.0005620394367724657 at step 9
-2.9005956649780273 pm 0.0006317804218269885 at step 10
-2.900709390640259 pm 0.0006901273736730218 at step 11
-2.9003806114196777 pm 0.0006241296650841832 at step 12
-2.8994028568267822 pm 0.0005616185371764004 at step 13
-2.8997457027435303 pm 0.0005336973699741066 at step 14
-2.8989109992980957 pm 0.0005374276079237461 at step 15
-2.900068521499634 pm 0.0005583901074714959 at step 16
-2.9008495807647705 pm 0.000536436855327338 at step 17
-2.8999872

RuntimeError: ignored

In [46]:
params = p_wrapped

In [47]:
np.save('good_nn', params)

In [23]:
E_E

DeviceArray(-2.898688, dtype=float32)

In [None]:
np.save