In [3]:
import numpy as np
import jax
# import flax
# import flax.linen as nn
import jax.numpy as jnp
# import optax
from tqdm import tqdm
from jax import jit
import sys
from scipy import interpolate


import jax.example_libraries.optimizers as jax_opt

import matplotlib.pyplot as plt


In [5]:
def init_params(layer_widths):
    """Randomly initialise the weights and biases of a fully connected network.

    layer_widths: layer sizes [input, hidden1, hidden2, ..., output].

    Returns a list with one [weights, biases] pair per consecutive pair of
    widths:
      * weights has shape (n_in, n_out), drawn from N(0, 1) and scaled by
        sqrt(2 / n_in) (He initialisation);
      * biases has shape (n_out,), drawn from an unscaled N(0, 1).
    Draws come from NumPy's global RNG, so seed np.random for reproducibility.
    """
    layer_pairs = zip(layer_widths[:-1], layer_widths[1:])
    return [
        [
            jnp.array(np.random.normal(size=(fan_in, fan_out)) * np.sqrt(2/fan_in)),
            jnp.array(np.random.normal(size=(fan_out,))),
        ]
        for fan_in, fan_out in layer_pairs
    ]

@jit
def expA(x, params):
    """Neural-network part of A, where the wavefunction is psi = exp(-A).

    x: the symmetrized network inputs (list or 1-D array).
    params: list of [weights, biases] pairs, as produced by init_params.

    Every layer except the last computes CELU(x @ W + b); the last layer is
    linear. Returns the first output unit as a scalar (this notebook builds
    the network with a width-1 output layer).
    """
    hidden_layers = params[:-1]
    output_layer = params[-1]

    activ = jnp.array(x)
    for layer in hidden_layers:
        activ = jax.nn.celu(jnp.dot(activ, jnp.array(layer[0])) + layer[1])

    out = jnp.dot(activ, jnp.array(output_layer[0])) + output_layer[1]
    return out[0]


@jit
def get_inputs(positions, num=2.0):
    """Permutation-symmetric network inputs built from the particle positions.

    Returns [I_1, ..., I_N] with I_n = sum_i (x_i / num)^n, where N is the
    number of positions. The list length therefore matches the NN input width,
    which this notebook sets to the particle count (lss[0] = NP). Power sums are
    unchanged by any permutation of the particles, so the network output is
    symmetric under particle exchange.

    positions: sequence or 1-D array of particle coordinates.
    num: scale divisor. Adjust it so the higher powers do not become too large.
        It defaults to the previously hard-coded 2.0.
    """
    scaled = jnp.array(positions) / num
    # The number of moments comes from the static input shape rather than the
    # global NP (defined in a later cell), so this Python loop unrolls under jit
    # and always agrees with the input.
    n_particles = scaled.shape[0]
    return [jnp.sum(scaled**n) for n in range(1, n_particles + 1)]


@jit
def bosonic_A(positions, Avariables):
    """A(x) for the bosonic trial wavefunction psi = exp(-A).

    A = NN(symmetric inputs of x; Avariables) + 0.5 * sum_i x_i^2.
    The quadratic piece multiplies psi by the Gaussian envelope
    exp(-0.5 * sum_i x_i^2). Avariables are first placed on the CPU device.
    """
    cpu = jax.devices("cpu")[0]
    Avariables = jax.device_put(Avariables, device=cpu)

    nn_part = expA(get_inputs(positions), Avariables)
    gaussian_part = 0.5*jnp.sum(jnp.array(positions)**2.)
    return nn_part + gaussian_part

@jit
def bosonic_psi(positions, Avariables):
    """Trial wavefunction psi(x) = exp(-A(x)), with A given by bosonic_A."""
    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])
    return jnp.exp(-bosonic_A(positions, Avariables))



@jit
def potential_minus_delta(positions):
    """Potential energy excluding the contact (delta-function) term.

    V = 0.5 * MASS * OMEGA^2 * sum_i x_i^2      (harmonic trap)
      + sum_{i<j} SIGMA * |x_i - x_j|           (long-range linear interaction)

    Reads the module-level constants MASS, OMEGA, SIGMA and NP; positions must
    have at least NP entries.
    """
    trap = 0.5*MASS*(OMEGA**2)*(jnp.sum(jnp.array(positions)**2.))

    pair_sum = 0
    for i in range(NP):
        for j in range(i + 1, NP):
            pair_sum += SIGMA*jnp.abs(positions[i]-positions[j])

    return trap + pair_sum
    

@jit
def delta_part(positions, posprime, Avariables, width):
    """Monte Carlo estimator for one pair's contact term GCOUP * delta(x_1 - x_2).

    The delta-function trick rests on the identity
        <delta(x_1 - x_2)> = E_{|psi|^2}[ psi^2(x') / psi^2(x) * g(x_2) ],
    where x' = (x_1, x_1, x_3, ..., x_N) and g is any density that integrates
    to 1. Here g is the Gaussian
        g(y) = exp(-(y / width)^2) / (sqrt(pi) * width),
    evaluated at the sample's second coordinate.

    positions: the MC sample (x_1, x_2, x_3, ..., x_N).
    posprime: the same sample with x_2 replaced by x_1.
    Avariables: NN parameters.
    width: Gaussian width, chosen for each training step (wavefunction).

    Returns psi^2(x') / psi^2(x) * GCOUP * g(x_2); GCOUP is a module-level
    constant.
    """
    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])

    psi_sq_prime = bosonic_psi(posprime, Avariables)**2.
    psi_sq = bosonic_psi(positions, Avariables)**2.

    gauss = (1/(np.sqrt(np.pi)*width))*np.e**(-(positions[1]/width)**2)

    return (psi_sq_prime/psi_sq)*GCOUP*gauss


@jit
def mcstep_E(xis, limit, positions, Avariables):
    """One Metropolis accept/reject step for sampling |psi|^2.

    All coordinates move at once: the proposal is positions + xis. It is
    accepted when (psi(new) / psi(old))^2 >= limit, where limit is a
    uniform(0, 1) number drawn outside this jitted function (by get_sequence).
    Branching on a traced value requires jax.lax.cond.

    Returns [list_of_positions, moved]: the proposal and True if accepted,
    otherwise the old positions and False.
    """
    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])

    proposal = jnp.array(positions) + xis
    ratio = bosonic_psi(proposal, Avariables)/bosonic_psi(positions, Avariables)
    acceptance = ratio**2.

    def accept(_):
        return [list(proposal), True]

    def reject(_):
        return [list(positions), False]

    return jax.lax.cond(acceptance >= limit, accept, reject, acceptance)


def get_sequence(stepsize, Nsweeps, keep, Ntherm, positions_initial, Avariables, progress):
    """Run a Metropolis chain over |psi|^2 and collect thinned samples.

    stepsize: half-width of the proposal. Each step shifts every coordinate by
        an independent uniform(-stepsize, stepsize) number.
    Nsweeps: total number of accept/reject steps.
    keep: a sample is stored every `keep` steps (thinning interval).
    Ntherm: steps with index < Ntherm are discarded as thermalisation.
    positions_initial: starting configuration (for bosons, the zero vector).
    Avariables: NN parameters of the trial wavefunction.
    progress: show a tqdm progress bar if truthy.

    Returns [samples, acceptance_rate], where samples is a list of position
    lists. All random numbers come from NumPy's global RNG. They are
    pre-generated here because the jitted step cannot draw them itself.
    """
    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])

    # proposal displacements, one row of NP numbers per step
    randoms = np.random.uniform(-stepsize, stepsize, size = (Nsweeps, NP))
    # uniform(0, 1) thresholds for the accept/reject test
    limits = np.random.uniform(0, 1, size = Nsweeps)

    # Single loop body; the progress bar only wraps the iterator.
    steps = range(Nsweeps)
    if progress:
        steps = tqdm(steps, position = 0, leave = True, desc = "MC")

    sq = []
    counter = 0
    positions_prev = positions_initial
    for i in steps:
        new, moved = mcstep_E(randoms[i], limits[i], positions_prev, Avariables)

        if moved:
            counter += 1

        if i % keep == 0 and i >= Ntherm:
            sq.append(new)

        positions_prev = new

    return [sq, counter/Nsweeps]


def binning(arr, binsize):
    """Binning analysis for correlated Monte Carlo data.

    Splits arr into consecutive bins of `binsize` samples along its first axis,
    averages each bin, and returns (mean of bin averages, standard error).

    Only complete bins are used. A shorter trailing bin has a larger variance
    and would otherwise get the same weight as the full bins. The standard
    error is the unbiased (ddof=1) standard deviation of the bin averages
    divided by sqrt(number of bins).

    Raises ValueError if binsize < 1, or if there are fewer than two complete
    bins (the error estimate is undefined for a single bin).
    """
    if binsize < 1:
        raise ValueError("binsize must be a positive integer")

    n_bins = len(arr) // binsize
    if n_bins < 2:
        raise ValueError("binning needs at least two complete bins")

    bin_means = [np.mean(arr[k*binsize:(k + 1)*binsize]) for k in range(n_bins)]
    return np.mean(bin_means), np.std(bin_means, ddof=1)/np.sqrt(n_bins)
    

@jit
def derivA_theta(positions, Avariables):
    """Gradient of A with respect to the NN parameters at a single sample.

    Returns dA(positions, Avariables) / dAvariables with the same nested-list
    structure as Avariables. Used to build <E> and <dE/dtheta>.
    """
    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])
    grad_wrt_params = jax.grad(bosonic_A, argnums = 1)
    return grad_wrt_params(jnp.array(positions), Avariables)


@jit
def derivA_x(positions, Avariables):
    """First and diagonal second derivatives of A with respect to positions.

    Returns [grad, diag_hess] where grad[i] = dA/dx_i and
    diag_hess[i] = d^2A/dx_i^2. The kinetic energy needs only the diagonal,
    since its sum is the Laplacian of A. The full Hessian is formed with
    jax.jacfwd of the gradient and then reduced to its diagonal.
    """
    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])

    x = jnp.array(positions)
    grad_A = jax.grad(bosonic_A, argnums = 0)
    full_hessian = jax.jacfwd(grad_A, argnums = 0)(x, Avariables)

    return [grad_A(x, Avariables), jnp.diag(full_hessian)]

@jit
def nn_mult(parameters, scalar):
    """Multiply every array in a nested [[W, b], ...] parameter list by scalar.

    This helper exists because the NN parameters are stored as a list of
    differently-sized arrays. It returns a new list of lists with the same
    structure; the input is not modified.
    """
    return [[scalar*mat for mat in layer] for layer in parameters]

def nn_mean(nnlist):
    """Element-wise mean of several NN parameter sets.

    nnlist: non-empty list of parameter sets. Each set is a list of
        [weights, biases] pairs with matching shapes (e.g. per-sample gradients).

    Returns a new [[W, b], ...] list whose arrays are averaged over nnlist.
    The layer count is taken from nnlist[0], and only the first two arrays of
    each layer are used.
    """
    # running sums, seeded with (references to) the first set's arrays
    sums = [[layer[0], layer[1]] for layer in nnlist[0]]

    for param_set in nnlist[1:]:
        for idx, acc in enumerate(sums):
            acc[0] = jnp.add(acc[0], param_set[idx][0])
            acc[1] = jnp.add(acc[1], param_set[idx][1])

    return nn_mult(sums, 1/len(nnlist))
    

@jit
def gexps(pos, posprime, Avariables, width):
    """Local energy and energy-gradient ingredients at one MC sample.

    pos: MC sample (x_1, x_2, ..., x_N).
    posprime: the same sample with x_2 replaced by x_1, i.e.
        (x_1, x_1, x_3, ..., x_N), used by the delta-function trick.
    Avariables: NN parameters.
    width: width of the Gaussian g used in the delta-function term.

    With psi = exp(-A), the kinetic part of the local energy is
    (1/(2*MASS)) * (sum_i d^2A/dx_i^2 - sum_i (dA/dx_i)^2). The contact term
    is delta_part for the pair (1, 2) multiplied by the number of pairs,
    NP*(NP-1)/2. This relies on all pairs contributing equally for a
    symmetric (bosonic) psi.

    Returns [E_local,
             dA/dtheta at pos,
             E_without_delta * dA/dtheta at pos,
             E_delta * dA/dtheta at posprime].
    """
    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])

    # first derivatives and diagonal second derivatives of A w.r.t. positions
    grad_A, lap_diag = derivA_x(pos, Avariables)
    kinetic = (1/(2*MASS))*(jnp.sum(lap_diag)-jnp.sum(jnp.array(grad_A)**2))
    e_no_delta = kinetic + potential_minus_delta(pos)

    # number of delta-function terms sum_{i<j}
    n_pairs = 0.5*NP*(NP-1)
    e_delta = n_pairs*delta_part(pos, posprime, Avariables, width)

    dA_pos = derivA_theta(pos, Avariables)
    dA_prime = derivA_theta(posprime, Avariables)

    return [e_no_delta + e_delta,
            dA_pos,
            nn_mult(dA_pos, e_no_delta),
            nn_mult(dA_prime, e_delta)]
  
def get_expects(sequence, Avariables, width, progress):
    """Average the per-sample quantities from gexps over an MC sequence.

    sequence: list of MC samples, each a list of positions (x_1, ..., x_N).
    Avariables: NN parameters.
    width: Gaussian width for the delta-function term.
    progress: show a tqdm progress bar if truthy.

    Returns [<E>,
             <dA/dtheta>,
             <E_without_delta * dA/dtheta>,
             <E_delta * dA/dtheta(x')>,
             standard error of <E>].
    The middle three entries are the theta-derivative terms of the energy
    gradient (eq. 18 of the boson notes). The standard error is
    std(E)/sqrt(n), which treats the samples as uncorrelated.

    Raises ValueError if sequence is empty.
    """
    if len(sequence) == 0:
        # previously surfaced as an opaque IndexError inside nn_mean
        raise ValueError("get_expects needs at least one MC sample "
                         "(is Nsweeps larger than Ntherm?)")

    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])

    # Single loop body; the progress bar only wraps the iterator.
    samples = sequence
    if progress:
        samples = tqdm(sequence, position = 0, leave = True, desc = "Seq Calc")

    ens = []
    dAdts = []
    dAdts_times_Emds = []
    dAdts_times_ds = []

    for sample in samples:
        # x' = (x_1, x_1, x_3, ..., x_N) for the delta-function trick
        sample_prime = sample.copy()
        sample_prime[1] = sample_prime[0]

        energy, dAdt, dAdt_times_Emd, dAdt_times_d = gexps(sample, sample_prime, Avariables, width)

        ens.append(energy)
        dAdts.append(dAdt)
        dAdts_times_Emds.append(dAdt_times_Emd)
        dAdts_times_ds.append(dAdt_times_d)

    return [np.mean(ens), nn_mean(dAdts),
            nn_mean(dAdts_times_Emds), nn_mean(dAdts_times_ds),
            np.std(ens)/np.sqrt(len(ens))]

def dEdt(stepsize, Nsweeps, keep, Ntherm, Avariables, progress):
    """Estimate <E>, its error, and dE/dtheta for the given NN parameters.

    Runs an MC chain from the zero configuration; see get_sequence for the
    stepsize / Nsweeps / keep / Ntherm / progress arguments. It then picks the
    delta-term Gaussian width from the samples and averages with get_expects.

    The gradient assembled is
        dE/dtheta = 2<E><dA/dtheta> - 2<E_md * dA/dtheta>
                    - 2<E_delta * dA/dtheta(x')>,
    where E_md is the local energy without the delta term. It is computed as
    3 * (mean of the three terms), i.e. their sum.

    Returns (dEdtheta, <E>, standard error of <E>).
    """
    Avariables = jax.device_put(Avariables, device=jax.devices("cpu")[0])

    sequence, rate = get_sequence(stepsize, Nsweeps, keep, Ntherm, [0.]*NP, Avariables, progress)

    # Width of g(y) for the delta trick, with y_max = max |x_2| over the chain.
    # It is chosen so that exp(-(y_max/width)^2) = sqrt(pi) * 1e-10, i.e.
    # g(y_max) = 1e-10 / width: negligible at the edge of the sampled region.
    second_coords = np.array([sample[1] for sample in sequence])
    y_max = np.max(abs(second_coords))
    width = np.sqrt(y_max**2/-np.log(np.sqrt(np.pi)*(10**-10)))

    expectations = get_expects(sequence, Avariables, width, progress)
    mean_energy = expectations[0]
    energy_error = expectations[4]

    if progress == True:
        print("Accept rate: ", rate)
        print("Width: ", width)
        print("Energy: ", mean_energy)
        print("Error: ", energy_error)

    first = nn_mult(expectations[1], 2.0*mean_energy)
    second = nn_mult(expectations[2], -2.0)
    third = nn_mult(expectations[3], -2.0)

    dEdtheta = nn_mult(nn_mean([first, second, third]), 3.0)

    return dEdtheta, mean_energy, energy_error

def trainstep(stepsize, Nsweeps, keep, Ntherm, initial_params, step_i, opt_state=None, return_state=False):
    """Perform one Adam update of the NN parameters.

    Estimates the energy gradient at initial_params with dEdt (no progress
    output), then applies opt_update for step index step_i. Uses the
    module-level Adam triple opt_init / opt_update / get_params.

    opt_state: Adam state whose parameters are initial_params, carried over
        from the previous step. If None, a fresh state is created with
        opt_init(initial_params). Its moment estimates then start from zero, so
        a training loop should pass the state returned by the previous step.
    return_state: if True, also return the updated optimiser state.

    Returns (new_params, <E>, error), or with return_state=True
    (new_params, <E>, error, new_opt_state).
    """
    grad, energy, err = dEdt(stepsize, Nsweeps, keep, Ntherm, initial_params, False)

    if opt_state is None:
        opt_state = opt_init(initial_params)
    new_state = opt_update(step_i, grad, opt_state)

    if return_state:
        return get_params(new_state), energy, err, new_state
    return get_params(new_state), energy, err


def train(Ntrains, stepsize, Nsweeps, keep, Ntherm, initial_params):
    """Run Ntrains Adam steps starting from initial_params.

    The Adam state (first and second moment estimates) is created once and
    carried across all steps. Re-creating it every step would reset the moments
    and reduce Adam to a momentum-free, roughly sign(gradient)-sized step.

    Returns [step_indices, energies, errors, final_params]. energies and errors
    hold the <E> estimate and its standard error at each step, measured before
    that step's update.
    """
    params = initial_params.copy()
    opt_state = opt_init(params)

    energies = []
    errs = []
    ns = np.arange(Ntrains)

    for n in tqdm(range(0, Ntrains), position = 0, leave = True, desc = "Training"):
        params, energy, err, opt_state = trainstep(
            stepsize, Nsweeps, keep, Ntherm, params, n,
            opt_state=opt_state, return_state=True)
        energies.append(energy)
        errs.append(err)

    return [ns, energies, errs, params]

In [6]:
# Physical constants: harmonic trap frequency and particle mass.
OMEGA = 1.0
MASS = 1.0
# Strength of the long-range linear interaction used by potential_minus_delta.
# NOTE(review): G is not defined in any visible cell, so this line raises
# NameError on Restart & Run All. GCOUP (the contact coupling read by
# delta_part) is not defined anywhere visible either. Define both before this
# line -- TODO confirm whether GCOUP is meant to equal G.
SIGMA = -G/2

# number of particles
NP = 2

# layer widths: [NP symmetric inputs, hidden layers..., 1 output]
lss = [NP, 50, 100, 50, 1]

# NOTE(review): no np.random seed is set anywhere visible, so this
# initialisation (and the MC chains) is not reproducible across runs.
Avars_init = init_params(lss)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [7]:
# Count the trainable parameters of the network with layer widths lss:
# each layer n_in -> n_out contributes n_in*n_out weights plus n_out biases.
lis = lss
total = sum(n_in*n_out + n_out for n_in, n_out in zip(lis[:-1], lis[1:]))

total

10351

In [8]:
# Evaluate <E>, its error and the gradient for a trained parameter set with a
# long MC chain: stepsize=1.0, Nsweeps=500000, keep=100, Ntherm=1000, progress on.
# NOTE(review): `resultsg` is not defined anywhere in this notebook, so this
# cell raises NameError on a fresh kernel (see its output). It presumably
# refers to a training run from a deleted cell. TODO point it at an existing
# result's final parameters (index [3] of a train(...) result) or remove it.
info = dEdt(1.0, 500000, 100, 1000, resultsg[3], True)

NameError: name 'resultsg' is not defined

In [9]:
# Adam optimizer from jax.example_libraries with learning rate 1e-3. The
# returned (init, update, get_params) triple is used by trainstep / train.
opt_init, opt_update, get_params = jax_opt.adam(step_size=10**-3)

In [10]:
# First training round: 20 Adam steps from the random initial parameters
# (MC: stepsize 1.0, 10000 sweeps, keep every 10th sample, 50 thermalisation steps).
resultsa = train(Ntrains=20, stepsize=1.0, Nsweeps=10000, keep=10, Ntherm=50,
                 initial_params=Avars_init)

Training:   0%|          | 0/20 [00:00<?, ?it/s]


NameError: name 'SIGMA' is not defined

In [None]:
# Second round: 30 more steps continuing from resultsa's final parameters,
# with 4x more MC sweeps per step.
resultsb = train(Ntrains=30, stepsize=1.0, Nsweeps=40000, keep=10, Ntherm=50,
                 initial_params=resultsa[3])

In [None]:
# etc. -- further training rounds can be chained the same way, each starting
# from the previous round's final parameters (index [3] of train's result).

In [None]:
# Concatenate the per-step energies and their errors from both training rounds.
results_ens = [*resultsa[1], *resultsb[1]]
results_errs = [*resultsa[2], *resultsb[2]]