In [1]:
import os
#must set these before loading numpy:
os.environ["OMP_NUM_THREADS"] = '8'
os.environ["OPENBLAS_NUM_THREADS"] = '8'

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
#os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".1"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

import numpy as np
import jax
import jax.numpy as jnp
import optax
from tqdm import tqdm
from jax import jit
from functools import partial

import jax.example_libraries.optimizers as jax_opt

import matplotlib.pyplot as plt


#wavefunction preparation 
def init_params_SAMPLE(layer_widths):
    weights = []
    biases = []
    for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
        weights.append(jnp.array(np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in)))
        biases.append(jnp.array(np.random.normal(size=(n_out,))))
        
    return [weights, biases]

#initializes a set of NN parameters
def init_params(layer_widths):
    """Draw a fresh set of MLP parameters and return them as one flat vector.

    This is exactly `init_params_SAMPLE` followed by `flatten_params`. Delegating
    removes the duplicated initialisation code while consuming the global NumPy
    RNG in the same order as before.

    Args:
        layer_widths: sequence of layer sizes, input first, output last.
    Returns:
        1-D jnp array holding, layer by layer, the flattened weights then biases.
    """
    return flatten_params(init_params_SAMPLE(layer_widths))

def get_phi_params(layer_widths):
    """Initialise one flat parameter vector per single-particle orbital.

    Returns:
        jnp array of shape (N_UP + N_DOWN, n_params). Rows [0, N_UP) are used by
        PHI_up and rows [N_UP, N_UP + N_DOWN) by PHI_down.
    """
    n_orbitals = N_UP + N_DOWN
    return jnp.array([init_params(layer_widths) for _ in range(n_orbitals)])

@jit
def flatten_params(ps):
    """Concatenate [weights, biases] into a single 1-D parameter vector.

    Layout: for each layer, the flattened weight matrix followed by its bias.
    `unflatten_params` inverts this layout.
    """
    weights, biases = ps
    # Leading empty float array keeps the result dtype/behaviour for zero layers.
    chunks = [jnp.array([])]
    for layer, w in enumerate(weights):
        chunks.append(w.flatten())
        chunks.append(biases[layer].flatten())
    return jnp.concatenate(chunks)

@jit
def unflatten_params(params):
    """Inverse of `flatten_params`: split a flat vector back into [weights, biases].

    The per-layer shapes come from the module-level templates SAMPLE_W and
    SAMPLE_B (set in the config cell from `init_params_SAMPLE`).
    """
    weights, biases = [], []
    offset = 0
    for layer, w_ref in enumerate(SAMPLE_W):
        b_ref = SAMPLE_B[layer]

        stop = offset + w_ref.size
        weights.append(jnp.reshape(params[offset:stop], w_ref.shape))
        offset = stop

        stop = offset + b_ref.size
        biases.append(jnp.reshape(params[offset:stop], b_ref.shape))
        offset = stop
    return [weights, biases]

@jit
def NN(positions, params):
    """Evaluate the MLP on one input vector and return the first output component.

    Hidden layers use CELU activations; the final layer is linear.

    Args:
        positions: 1-D input feature vector.
        params: flat parameter vector in `flatten_params` layout.
    """
    weights, biases = unflatten_params(params)
    h = jnp.array(positions)
    for W, b in zip(weights[:-1], biases[:-1]):
        h = jax.nn.celu(jnp.dot(h, W) + b)
    out = jnp.dot(h, weights[-1]) + biases[-1]
    return out[0]

@partial(jit, static_argnums=(1, ))
def inputs_up(positions, j):
    """Build the network input for spin-up particle `j`.

    The output is [x_j, p_1..p_{N_UP-1}(other ups), p_1..p_{N_DOWN}(downs)].
    Here p_k(S) = sum((S / SYMNUM)**k) is a power sum, which is invariant under
    permutations within each group.
    """
    reordered = jnp.concatenate([jnp.array([positions[j]]), jnp.array(positions[:j]), jnp.array(positions[j+1:])])

    other_ups = reordered[1:N_UP] / SYMNUM
    downs = reordered[N_UP:] / SYMNUM

    up_moments = [jnp.sum(other_ups**k) for k in range(1, N_UP)]
    down_moments = [jnp.sum(downs**k) for k in range(1, N_DOWN+1)]

    return jnp.array([reordered[0], *up_moments, *down_moments])

@partial(jit, static_argnums=(1, ))
def inputs_down(positions, j):
    """Build the network input for spin-down particle `j` (0-based among downs).

    The output is [x_{N_UP+j}, p_1..p_{N_UP}(ups), p_1..p_{N_DOWN-1}(other downs)].
    Here p_k(S) = sum((S / SYMNUM)**k).
    """
    idx = j + N_UP  # position of this down particle in the full coordinate vector
    reordered = jnp.concatenate([jnp.array([positions[idx]]), jnp.array(positions[:idx]),
                                 jnp.array(positions[idx+1:])])

    ups = reordered[1:N_UP+1] / SYMNUM
    other_downs = reordered[N_UP+1:] / SYMNUM

    up_moments = [jnp.sum(ups**k) for k in range(1, N_UP+1)]
    down_moments = [jnp.sum(other_downs**k) for k in range(1, N_DOWN)]

    return jnp.array([reordered[0], *up_moments, *down_moments])
    

@jit
def PHI_up(positions, params):
    """Spin-up Slater determinant det[NN(inputs_up(x, j); params[i])] / sqrt(NUMFACTUP).

    Row i uses orbital network params[i]; column j uses up particle j.
    """
    # The features of particle j do not depend on the orbital, so compute them once per column.
    features = [inputs_up(positions, col) for col in range(N_UP)]
    slater = jnp.array([[NN(feat, params[row]) for feat in features]
                        for row in range(N_UP)])
    return jnp.linalg.det(slater)/jnp.sqrt(NUMFACTUP)

@jit
def PHI_down(positions, params):
    """Spin-down Slater determinant det[NN(inputs_down(x, j); params[N_UP+i])] / sqrt(NUMFACTDOWN).

    Row i uses orbital network params[N_UP + i]; column j uses down particle j.
    """
    # The features of particle j do not depend on the orbital, so compute them once per column.
    features = [inputs_down(positions, col) for col in range(N_DOWN)]
    slater = jnp.array([[NN(feat, params[row+N_UP]) for feat in features]
                        for row in range(N_DOWN)])
    return jnp.linalg.det(slater)/jnp.sqrt(NUMFACTDOWN)

@jit
def Psi(positions, params):
    """Trial many-body wavefunction at one configuration.

    Psi(x) = PHI_up(x) * PHI_down(x) * exp(-omeg * sum(x**2)).

    Args:
        positions: 1-D array of N_UP + N_DOWN coordinates, up particles first.
        params: per-orbital flat parameter vectors, one row per orbital
            (N_UP + N_DOWN rows, as produced by `get_phi_params`).
    Returns:
        Scalar Psi value (real, may be negative).
    """
    # jnp.exp instead of jnp.e**(...): idiomatic and a single exp rather than a pow.
    envelope = jnp.exp(-omeg*jnp.sum(jnp.array(positions)**2.))
    return PHI_up(positions, params)*PHI_down(positions, params)*envelope


#sampling 
@jit
def mcstep_E(xis, limit, positions, params):
    """One Metropolis step with a pre-drawn displacement and acceptance threshold.

    The proposal positions + xis is accepted when
    Psi(new)**2 / Psi(old)**2 >= limit.

    Returns:
        [positions after the step, bool flag telling whether the move was accepted].
    """
    params = jax.device_put(params, device=jax.devices("cpu")[0])

    proposal = jnp.array(positions) + xis
    ratio = Psi(proposal, params)**2./Psi(positions, params)**2.

    accept = lambda _: [proposal, True]
    reject = lambda _: [positions, False]
    return jax.lax.cond(ratio >= limit, accept, reject, ratio)

def get_sequence(stepsize, Nsweeps, keep, Ntherm, positions_initial, params, progress):
    """Run a Metropolis random walk and collect configurations.

    Args:
        stepsize: half-width of the uniform proposal displacement per coordinate.
        Nsweeps: number of proposed moves.
        keep: keep every `keep`-th step.
        Ntherm: steps with index < Ntherm are never kept (thermalisation).
        positions_initial: starting coordinates.
        params: wavefunction parameters.
        progress: when == True, show a tqdm progress bar.
    Returns:
        [list of kept position arrays, acceptance rate = accepted / Nsweeps].
    """
    params = jax.device_put(params, device=jax.devices("cpu")[0])

    # All proposals and acceptance thresholds are drawn up front from the global RNG.
    randoms = np.random.uniform(-stepsize, stepsize, size = (Nsweeps, N_UP+N_DOWN))
    limits = np.random.uniform(0, 1, size = Nsweeps)

    if progress == True:
        steps = tqdm(range(0, Nsweeps), position = 0, leave = True, desc = "MC")
    else:
        steps = range(Nsweeps)

    kept = []
    accepted = 0
    current = positions_initial
    for step in steps:
        current, moved = mcstep_E(randoms[step], limits[step], current, params)
        if moved == True:
            accepted += 1
        if step%keep == 0 and step >= Ntherm:
            kept.append(current)

    return [kept, accepted/Nsweeps]

#gradients and energies
@jit
def potential_minus_delta(positions):
    """Harmonic trap potential 0.5 * MASS * OMEGA**2 * sum(x**2).

    NOTE(review): the name suggests an interaction (delta) term is deliberately
    excluded; only the harmonic part is returned here.
    """
    r_squared = jnp.sum(jnp.array(positions)**2.)
    return 0.5*MASS*(OMEGA**2)*r_squared


@jit
def dpsi_dtheta(positions, params):
    """Gradient of Psi with respect to the parameter array, at fixed positions."""
    grad_wrt_params = jax.grad(Psi, argnums = 1)
    x = jnp.array(positions)
    return grad_wrt_params(x, params)


@jit
def dpsi2_dx2(positions, params):
    """Unmixed second derivatives d^2 Psi / dx_i^2 for every coordinate.

    This is the diagonal of the position Hessian of Psi (forward-over-reverse).
    Summing it gives the Laplacian used in the local energy.

    Returns:
        1-D array with one entry per coordinate.
    """
    x = jnp.array(positions)
    hess = jax.jacfwd(jax.grad(Psi, argnums = 0), argnums = 0)
    # The original also computed the first derivative here but never used it; removed.
    return jnp.diag(hess(x, params))


@jit
@partial(jax.vmap, in_axes=(0, None))
def Esv(pos, params):
    """Local energy E_L = -(1/(2 MASS)) * lap(Psi)/Psi + V, vectorised over samples (axis 0 of pos)."""
    prefactor = -1/(2*MASS)
    kinetic = prefactor*(1/Psi(pos, params))*jnp.sum(dpsi2_dx2(pos, params))
    return kinetic + potential_minus_delta(pos)


@jit
@partial(jax.vmap, in_axes=(0, 0, None, None))
def gradv(pos, ens, params, E):
    """Per-sample energy-gradient term 2 * (dPsi/dtheta / Psi) * (E_L - E).

    Vectorised over samples: `pos` and `ens` (local energies) are mapped along
    axis 0; `params` and the mean energy `E` are shared.
    """
    scale = 2.0/Psi(pos, params)
    return scale*dpsi_dtheta(pos, params)*(ens - E)


#mc observables vectorized
def get_expectsv_progress(sequence, params):
    """Energy, energy gradient and energy error from MC samples, with progress bars.

    Returns:
        [mean local energy, mean of the per-sample gradient terms from `gradv`,
         std(local energy) / sqrt(n_samples)].
    """
    n_samples = len(sequence)

    params = jax.device_put(params, device=jax.devices("cpu")[0])

    # Each single-iteration tqdm loop only serves to show a labelled bar for that stage.
    for _ in tqdm(range(0, 1), position = 0, leave = True, desc = "energy calc"):
        local_E = Esv(sequence, params)
        E_mean = jnp.mean(local_E)
        E_err = jnp.std(local_E)/jnp.sqrt(n_samples)

    for _ in tqdm(range(0, 1), position = 0, leave = True, desc = "grad calc"):
        per_sample = gradv(sequence, local_E, params, E_mean)

    for _ in tqdm(range(1), position = 0, leave = True, desc="means"):
        summary = [E_mean, jnp.sum(per_sample, axis=0)/n_samples, E_err]

    return summary

@jit
def get_expectsv_noprogress(sequence, params):
    """Jitted energy, energy gradient and energy error from MC samples.

    Returns:
        [mean local energy, mean of the per-sample gradient terms from `gradv`,
         std(local energy) / sqrt(n_samples)].
    """
    params = jax.device_put(params, device=jax.devices("cpu")[0])
    n_samples = len(sequence)

    local_E = Esv(sequence, params)
    E_mean = jnp.mean(local_E)
    grad_mean = jnp.sum(gradv(sequence, local_E, params, E_mean), axis=0)/n_samples
    E_err = jnp.std(local_E)/jnp.sqrt(n_samples)

    return [E_mean, grad_mean, E_err]

#computes gradients and energies
def dEdt(stepsize, Nsweeps, keep, Ntherm, params, progress):
    """Sample the wavefunction by Metropolis MC and estimate energy and its gradient.

    The walk starts from coordinates drawn uniformly in [-1, 1].

    Args:
        stepsize, Nsweeps, keep, Ntherm: forwarded to `get_sequence`.
        params: wavefunction parameters.
        progress: truthy -> show progress bars and print acceptance/energy/error.
    Returns:
        (dEdtheta, E, err): energy gradient w.r.t. params, mean energy, and
        the naive standard error of the energy.
    Raises:
        ValueError: if the Keep/Ntherm settings leave no samples.
    """
    params = jax.device_put(params, device=jax.devices("cpu")[0])

    sequence, rate = get_sequence(stepsize, Nsweeps, keep, Ntherm,
                                  jnp.array(np.random.uniform(-1, 1, N_UP+N_DOWN)), params, progress)

    if len(sequence) == 0:
        raise ValueError("No MC samples were kept; check Nsweeps, keep and Ntherm.")
    samples = jnp.array(sequence)

    # get_expectsv_* return [E, grad, err]; unpack in that order. Previously they were
    # unpacked as (grad, E, err), so the scalar energy was handed to the optimiser.
    if progress:
        E, dEdtheta, err = get_expectsv_progress(samples, params)
        print("Accept rate: ", rate)
        print("Energy: ", E)
        print("Error: ", err)
    else:
        E, dEdtheta, err = get_expectsv_noprogress(samples, params)

    return dEdtheta, E, err
    

#performs a training step
def trainstep(stepsize, Nsweeps, keep, Ntherm, initial_params, step_i, opt_state=None, return_state=False):
    """One optimisation step: estimate dE/dtheta by VMC, then apply one optimiser update.

    Uses the module-level optimiser triple opt_init / opt_update / get_params
    (Adam, created in the run cell).

    Args:
        stepsize, Nsweeps, keep, Ntherm: forwarded to `dEdt`.
        initial_params: parameters at which the gradient is estimated. When
            `opt_state` is given, this should equal get_params(opt_state).
        step_i: optimiser step index (used by Adam's bias correction).
        opt_state: optional optimiser state to continue from. When None (the
            default, matching the old behaviour) a fresh state is created,
            which discards Adam's moment estimates.
        return_state: if True, also return the updated optimiser state so it
            can be passed to the next call.
    Returns:
        (new_params, energy, err), plus the new optimiser state if return_state.
    """
    grad, energy, err = dEdt(stepsize, Nsweeps, keep, Ntherm, initial_params, False)

    if opt_state is None:
        opt_state = opt_init(initial_params)
    new_state = opt_update(step_i, grad, opt_state)
    new_params = get_params(new_state)

    if return_state:
        return new_params, energy, err, new_state
    return new_params, energy, err

def train(Ntrains, stepsize, Nsweeps, keep, Ntherm, initial_params):
    """Run `Ntrains` VMC optimisation steps with one persistent optimiser state.

    A single state is created once via the module-level `opt_init` and then
    threaded through every `opt_update`, so Adam's moment estimates accumulate
    across steps. Previously a fresh state was created on every step, which
    reset them.

    Args:
        Ntrains: number of optimisation steps.
        stepsize, Nsweeps, keep, Ntherm: forwarded to `dEdt` each step.
        initial_params: starting parameters.
    Returns:
        [step indices, per-step energies, per-step energy errors, final params].
    """
    energies = []
    errs = []
    ns = np.arange(Ntrains)

    opt_state = opt_init(initial_params)
    params = initial_params

    for n in tqdm(range(0, Ntrains), position = 0, leave = True, desc = "Training"):
        grad, energy, err = dEdt(stepsize, Nsweeps, keep, Ntherm, params, False)
        opt_state = opt_update(n, grad, opt_state)
        params = get_params(opt_state)
        energies.append(energy)
        errs.append(err)

    return [ns, energies, errs, params]

def binning(arr, binsize):
    """Binning estimate of the mean and its error.

    Splits `arr` into consecutive bins of `binsize` elements. The last bin may
    be partial: there are ceil(len(arr) / binsize) bins. The mean is taken over
    the bin means, with error std(bin means) / sqrt(n_bins).

    Returns:
        (mean, error)
    """
    n_bins = -(-len(arr) // binsize)  # ceiling division
    bin_means = [np.mean(arr[k*binsize:(k+1)*binsize]) for k in range(n_bins)]
    return np.mean(bin_means), np.std(bin_means)/np.sqrt(len(bin_means))
    



In [None]:
# Harmonic trap parameters used by potential_minus_delta / Esv
OMEGA = 1.0
MASS = 1.0
omeg = 1.0  # coefficient of the Gaussian envelope exp(-omeg * sum(x**2)) in Psi

# Seed the global NumPy RNG (parameter init and MC proposals/start points) for reproducible runs
SEED = 0
np.random.seed(SEED)

N_UP = 5
N_DOWN = 5

SYMNUM = 4.0  # positions are divided by this before forming the power-sum inputs
# Normalisation of each Slater determinant. It only rescales Psi, which cancels in the
# Metropolis ratio (mcstep_E) and in Esv / gradv (both divide by Psi).
# NOTE(review): an earlier version used factorial(N_UP) / factorial(N_DOWN) here; 3.0 looks
# arbitrary — confirm intent.
NUMFACTUP = 3.0
NUMFACTDOWN = 3.0

# Orbital-network layer widths: N_UP + N_DOWN input features -> 1 output
lss = [N_UP+N_DOWN, 25, 50, 50, 25, 1]

# Template weights/biases whose shapes unflatten_params uses to unpack flat parameter vectors
SAMPLE_W, SAMPLE_B = init_params_SAMPLE(lss)

# One flat parameter vector per orbital (N_UP + N_DOWN rows)
NNvars_init = get_phi_params(lss)

# Total number of trainable parameters
len(NNvars_init[0])*(N_UP+N_DOWN)

In [None]:
# Adam optimiser; trainstep / train read these three module-level functions
LEARNING_RATE = 10**-4
opt_init, opt_update, get_params = jax_opt.adam(LEARNING_RATE)

# 1000 optimisation steps, each from 1000 Metropolis proposals of half-width 0.27,
# keeping every sample with no thermalisation discard
resultsa = train(Ntrains=1000, stepsize=0.27, Nsweeps=1000, keep=1, Ntherm=0,
                 initial_params=NNvars_init)