In [38]:
import numpy as np
import jax.numpy as jnp
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from pylab import figure, cm
import random
from jax import grad, hessian, jit, vmap
import time
from functools import partial
from IPython.display import clear_output

N = 2
# true_E = N*omega/2 -  g**2/24 * N * (N**2 - 1)
hbar = 1
m = 1

In [39]:
class Network:
    def __init__(self, node_counts=[N, 10, 5, 1]):
        # defining the structure of the neural network
        self.num_layers = len(node_counts)
        # the number of nodes for each layer
        self.node_counts = node_counts
        # the total number of weights
        self.params_length = 0
        for i in range(self.num_layers - 1):
            self.params_length += node_counts[i] * node_counts[i + 1]
            i+=1
        
        # the list that stores the weight matrices (index 0 is the connections from the input to the first hidden layer)
        self.weights = []
        # generate weight matrices with the correct sizes, and random elements
        for i in range(self.num_layers - 1):
            self.weights.append(np.random.randn(node_counts[i + 1], node_counts[i]) * np.sqrt(1. / node_counts[i + 1]))
        self.weights = np.array(self.weights, dtype=object)
        
        # get the shape for reshaping a 1d array to this later
        self.dimensions = []
        for m in self.weights:
            self.dimensions.append(m.shape)

    # define the activation function that we use for the layers
    def l_act(self, x, derivative = False):
        if derivative:
            return (jnp.exp(-x))/((jnp.exp(-x)+1)**2)
        return 1/(1 + jnp.exp(-x))
    
    # define the activation function for the output
    def o_act(self, x):
        return 1/(jnp.exp(-x) + 1)
    @partial(jit, static_argnums=(0,))
    def conv1d(self):
        #start = time.time()
        flattened = []
        # loop through each matrix
        for m in range(self.num_layers - 1):
            mat = self.weights[m]
            for i in range(mat.shape[0]):
                for j in range(mat.shape[1]):
                    flattened.append(mat[i][j])
        #end = time.time()
        #print("conv1d: " + str(end - start))
        return jnp.array(flattened)
        
    @partial(jit, static_argnums=(0,))
    def convnd(self, params):
        #start = time.time()
        ret = []
        # the last index of the matrix
        max_ind = 0
        prev_ind = 0
        for dim in self.dimensions:
            max_ind += dim[0]*dim[1]
            temp = params[prev_ind : max_ind]
            prev_ind = max_ind
            ret.append(jnp.array(temp).reshape(dim))
        #end = time.time()
        #print("convnd: " + str(end - start))
        return ret


    # passing inputs into the neural network and getting an output
    def output(self, coords, params):
        # start = time.time()
        # sort the input coordinates in order to enforce particle swap invariance
        coords.sort()
        coords = jnp.asarray(coords)
        # format the parameters as weights
        self.weights = self.convnd(params)
        # compute the output of the neural network
        for i in range(self.num_layers - 1):
            w = jnp.array(self.weights[i])
            # if its the first layer, dot it against the input and use the activation function
            if i == 0:
                temp = self.l_act(jnp.dot(w, coords))
            elif (i < self.num_layers):
                temp = self.l_act(jnp.dot(w, temp))
            else:
                # on the output layer we se the output activation function
                temp = self.o_act(jnp.dot(w, temp))
        # print("output:" + str(time.time() -  start))
        return temp[0]

# create the network object
nn = Network([N,10,1])

# defines the wavefunction based on calls to the neural network
def psi(coords, params):
    return jnp.exp(-nn.output(coords, params))

df_dtheta_stored = grad(nn.output)
def df_dtheta(coords, params):
    grad = []

    for i in range(nn.params_length):
        grad.append(df_dtheta_stored(coords, params)[i])
    return jnp.array(grad)
# computes the hessian matrix, has all the second order partials
h = hessian(psi)
# return the second derivative of the wavefunction with respect to every parameter
def ddpsi(coords, params):
    return jnp.diagonal(h(coords, params))

## TODO: this is the main time bottleneck!!
# Perhaps we can cache the call of psi? since its reused as the denominator the next iteration, but only sometimes?
def sample(params, num_particles, num_samples):
    # random.seed(seed)
    outputs = []
    coords_t = np.zeros(num_particles)
    for _ in range(num_samples):
        coords_prime = coords_t + np.random.uniform(-1,1,num_particles)
        check = np.random.uniform(0,1) < psi(coords_prime, params)**2/psi(coords_t,params)**2
        if check:
            coords_t = coords_prime
        start = time.time()
        outputs.append(coords_t)
        print(time.time() - start)
    return jnp.array(outputs)


## Getting the average energy via MC
@partial(jit, static_argnums=[2,3,4])
def energy(coords, params, omega, g, num_particles):
    result = 0
    # first get the value of the wavefunction to store
    temppsi = psi(coords, params)
    # then loop over all particles
    for i in range(num_particles):
        # harmonic oscillator potential
        result += .5*m*omega**2*coords[i]**2*temppsi
        # second derivative
        result += - hbar**2/(2*m) * jnp.sum(ddpsi(coords, params))
    result *= temppsi
    # now moving to the second section
    for j in range(1, num_particles):
        for i in range(j - 1):
            result -= temppsi**2 * m*omega*g/hbar * jnp.abs(coords[i] - coords[j])
            # the delta function part
            # get the modified input, make coords[i] and coords[j] the same
            delta = jnp.copy(coords)
            delta = delta.at[i].set(delta[j])
            result += temppsi * g * psi(delta, params)
    return result


venergy = vmap(energy, in_axes=(0, None, None, None, None), out_axes=0)

def avg_energy(params, omega, g, num_particles, num_samples=10**3):
    # first we get a set of samples
    start = time.time()
    samples = sample(params, num_particles, num_samples)
    print(time.time() - start)
    # then use MC integration
    return 1/num_samples * jnp.sum(venergy(samples, params, omega, g, num_particles))

def gradient(params, omega, g, num_particles, num_samples=10**3):
    # get the samples
    samples = sample(params, num_particles, num_samples)


#TODO: eventually add bias nodes to the neural network


In [40]:
print(avg_energy(nn.conv1d(),1,1,2))
#print(df_dtheta(jnp.array([1.0,1.0]), nn.conv1d()))

1.6689300537109375e-06
1.1920928955078125e-06
9.5367431640625e-07
1.1920928955078125e-06
1.430511474609375e-06
1.430511474609375e-06
1.430511474609375e-06
1.1920928955078125e-06
2.1457672119140625e-06
9.5367431640625e-07
9.5367431640625e-07
9.5367431640625e-07
1.1920928955078125e-06
9.5367431640625e-07
7.152557373046875e-07
9.5367431640625e-07
9.5367431640625e-07
1.430511474609375e-06
9.5367431640625e-07
1.9073486328125e-06
7.152557373046875e-07
1.1920928955078125e-06
1.1920928955078125e-06
9.5367431640625e-07
1.6689300537109375e-06
9.5367431640625e-07
1.1920928955078125e-06
9.5367431640625e-07
9.703636169433594e-05
7.152557373046875e-07
9.5367431640625e-07
9.5367431640625e-07
1.430511474609375e-06
1.1920928955078125e-06
7.152557373046875e-07
1.1920928955078125e-06
9.5367431640625e-07
7.152557373046875e-07
9.5367431640625e-07
9.5367431640625e-07
2.1457672119140625e-06
9.5367431640625e-07
1.1920928955078125e-06
9.5367431640625e-07
9.5367431640625e-07
7.152557373046875e-07
1.430511474609