In [3]:
import numpy as np
import jax.numpy as jnp
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from pylab import figure, cm
import random
from jax import grad, hessian, jit, vmap
import time
from IPython.display import clear_output

# Model / physics parameters in natural units (hbar = m = 1).
N = 2        # number of particles (one coordinate each); also the network's input width
omega = 1    # presumably the harmonic trap frequency (enters the N * omega / 2 term)
g = 1        # presumably the interaction coupling strength
# Presumably the reference ground-state energy the variational estimate is compared
# against -- TODO(review): cite the source of this formula.
true_E = N*omega/2 -  g**2/24 * N * (N**2 - 1)
hbar = 1
m = 1

# Seed every RNG this notebook imports so that "Restart & Run All" reproduces the
# random initial network weights (drawn with np.random.randn).
SEED = 0
np.random.seed(SEED)
random.seed(SEED)

In [4]:
class Network:
    """Fully connected, bias-free feed-forward network.

    Connection ``i`` is a weight matrix of shape ``(node_counts[i + 1], node_counts[i])``.
    Hidden layers use ``l_act`` and the final connection uses ``o_act`` (both are
    logistic sigmoids). All forward-pass math goes through ``jax.numpy`` so the
    network can be traced by ``jax.grad`` / ``hessian`` / ``jit``.
    """

    def __init__(self, node_counts=None):
        """Build the layer structure and draw random initial weights.

        Args:
            node_counts: number of nodes in each layer, input layer first and output
                layer last. Defaults to ``[N, 10, 5, 1]`` using the module-level ``N``
                at call time (``None`` replaces the former shared mutable list default).

        Raises:
            ValueError: if fewer than two layers are given (there would be no weights).
        """
        if node_counts is None:
            node_counts = [N, 10, 5, 1]
        node_counts = list(node_counts)
        if len(node_counts) < 2:
            raise ValueError("node_counts needs at least an input and an output layer")

        # defining the structure of the neural network
        self.num_layers = len(node_counts)
        # the number of nodes for each layer
        self.node_counts = node_counts
        # the total number of weights (there are no bias terms)
        self.params_length = sum(
            n_in * n_out for n_in, n_out in zip(node_counts[:-1], node_counts[1:])
        )

        # One weight matrix per connection (index 0 connects the input to the first
        # hidden layer). Fill a preallocated 1-D object array element by element:
        # np.asarray(list_of_matrices, dtype=object) tries to broadcast the matrices
        # together, which raises when their leading dimensions agree and produces a
        # float object tensor when all matrices have the same shape.
        self.weights = np.empty(self.num_layers - 1, dtype=object)
        for i in range(self.num_layers - 1):
            self.weights[i] = np.random.randn(node_counts[i + 1], node_counts[i]) * np.sqrt(1. / node_counts[i + 1])

        # get the shape for reshaping a 1d array to this later
        self.weights_shape = self.weights.shape

    # define the activation function that we use for the layers
    def l_act(self, x, derivative=False):
        """Logistic sigmoid for hidden layers, or its derivative if ``derivative`` is True.

        Uses jax.numpy: NumPy ufuncs cannot consume jax tracers, so ``np.exp`` here
        would make the network non-differentiable with ``jax.grad`` / ``hessian``.
        """
        s = 1 / (1 + jnp.exp(-x))
        if derivative:
            # sigma'(x) = sigma(x) * (1 - sigma(x)); mathematically equal to
            # exp(-x) / (exp(-x) + 1)**2 but avoids inf / inf = nan for very negative x
            return s * (1 - s)
        return s

    # define the activation function for the output
    def o_act(self, x):
        """Logistic sigmoid applied on the output layer."""
        return 1/(jnp.exp(-x) + 1)

    def conv1d(self):
        """Return ``self.weights.ravel()``.

        NOTE(review): ``self.weights`` is a 1-D object array holding one matrix per
        layer, so this returns the per-layer matrices rather than a flat vector of
        ``params_length`` floats. ``convnd`` reverses it with a reshape.
        """
        return self.weights.ravel()

    def convnd(self, params):
        """Reshape ``params`` (as produced by ``conv1d``) back to ``self.weights_shape``."""
        return params.reshape(self.weights_shape)

    # passing inputs into the neural network and getting an output
    def output(self, coords, params):
        """Evaluate the network on one set of input coordinates.

        Args:
            coords: one value per input node; sorted internally, never modified.
            params: weights in the layout produced by ``conv1d``. They are also
                stored on ``self.weights`` (kept for compatibility).

        Returns:
            The first (normally only) entry of the output layer.
        """
        # Sort the coordinates to enforce particle-swap invariance. jnp.sort returns a
        # new array: the old in-place ``coords.sort()`` mutated the caller's array, and
        # on a jax array it returned a sorted copy that was discarded (no sorting at all).
        activation = jnp.sort(jnp.asarray(coords))
        # format the parameters as weights
        self.weights = self.convnd(params)
        last = self.num_layers - 2
        for i in range(self.num_layers - 1):
            z = jnp.dot(self.weights[i], activation)
            # Hidden layers use l_act; only the final connection uses the output
            # activation (the former ``elif i < self.num_layers`` test was always true,
            # so o_act was never reached).
            activation = self.o_act(z) if i == last else self.l_act(z)
        return activation[0]

# Build the network used by psi: N inputs, two hidden layers of 10 nodes, one output node.
nn = Network(node_counts=[N, 10, 10, 1])

def psi(coords, params):
    """Trial wavefunction built from the global network ``nn``: ``exp(-nn.output(...))``.

    NOTE(review): the network's output passes through a sigmoid, so it lies in (0, 1)
    and psi therefore stays within (exp(-1), 1) -- it does not decay for large |coords|.
    """
    exponent = -nn.output(coords, params)
    return jnp.exp(exponent)

# TODO: work out what is needed to compute the average energy and its gradient with
# respect to the network parameters, then reuse the same gradient-descent algorithm.
# Because the number of parameters has grown, computing the gradient may require a
# different method.

#TODO: eventually add bias nodes to the neural network
