In [33]:
import numpy as np
import jax.numpy as jnp
from matplotlib import pyplot as plt
import jax
from jax import grad, hessian, jit, vmap
from jax.nn import celu, relu
import time
from functools import partial
from IPython.display import clear_output
import optax
from tqdm import trange





In [34]:
num_particles = 1
m = 1
hbar = 1
omega = 3

In [35]:
# define a class called Network that takes in any number of inputs, and outputs a single value, and uses any number of hidden layers and neurons
class Network:
    # initialize the network with the number of inputs, outputs, and hidden layers
    def __init__(self, num_inputs, num_outputs, num_hidden_layers, num_neurons):
        # initialize the weights and biases of the network
        self.weights = []
        self.biases = []
        # add the weights and biases of the first hidden layer
        self.weights.append(np.random.randn(num_inputs, num_neurons))
        self.biases.append(np.random.randn(num_neurons))
        # add the weights and biases of the middle hidden layers
        for i in range(num_hidden_layers-1):
            self.weights.append(np.random.randn(num_neurons, num_neurons))
            self.biases.append(np.random.randn(num_neurons))
        # add the weights and biases of the output layer
        self.weights.append(np.random.randn(num_neurons, num_outputs))
        self.biases.append(np.random.randn(num_outputs))
    # define a function that takes in an input and returns the output of the network
    def output(self, x, params):
        # unflatten the weights and biases
        self.set_weights(params)
        # apply the activation function to the input
        a = relu(x)
        # apply the weights and biases of the first hidden layer
        a = np.dot(a, self.weights[0]) + self.biases[0]
        # apply the activation function to the output of the first hidden layer
        a = relu(a)
        # apply the weights and biases of the middle hidden layers
        for i in range(1, len(self.weights)-1):
            a = np.dot(a, self.weights[i]) + self.biases[i]
            a = relu(a)
        # apply the weights and biases of the output layer
        a = np.dot(a, self.weights[-1]) + self.biases[-1]
        # return the output of the network
        return a[0]

    # define a function that flattens the weights and biases of the network into 1 array
    def flatten(self):
        # initialize the flattened array
        flat = []
        # add the weights and biases of the first hidden layer
        flat.append(self.weights[0].flatten())
        flat.append(self.biases[0].flatten())
        # add the weights and biases of the middle hidden layers
        for i in range(1, len(self.weights)-1):
            flat.append(self.weights[i].flatten())
            flat.append(self.biases[i].flatten())
        # add the weights and biases of the output layer
        flat.append(self.weights[-1].flatten())
        flat.append(self.biases[-1].flatten())
        # return the flattened array
        return np.concatenate(flat)
    
    # define a function that takes in an array of weights and biases and sets the weights and biases of the network to those values
    def set_weights(self, flat):
        # initialize the index of the flattened array
        index = 0
        # set the weights and biases of the first hidden layer
        self.weights[0] = flat[index:index+self.weights[0].size].reshape(self.weights[0].shape)
        index += self.weights[0].size
        self.biases[0] = flat[index:index+self.biases[0].size].reshape(self.biases[0].shape)
        index += self.biases[0].size
        # set the weights and biases of the middle hidden layers
        for i in range(1, len(self.weights)-1):
            self.weights[i] = flat[index:index+self.weights[i].size].reshape(self.weights[i].shape)
            index += self.weights[i].size
            self.biases[i] = flat[index:index+self.biases[i].size].reshape(self.biases[i].shape)
            index += self.biases[i].size
        # set the weights and biases of the output layer
        self.weights[-1] = flat[index:index+self.weights[-1].size].reshape(self.weights[-1].shape)
        index += self.weights[-1].size
        self.biases[-1] = flat[index:index+self.biases[-1].size].reshape(self.biases[-1].shape)
        index += self.biases[-1].size

# make a network with 1 input, 1 output, 3 hidden layer, and 10 neurons
net = Network(1, 1, 1, 50)

# print the output of the network for a given input x   
x = np.array([1])
print(net.output(x, net.flatten()))
print(net.flatten())


-0.1747930729609477
[ 1.02256516 -0.94739213 -0.17160132 -0.28744272 -0.14356092  0.93191099
 -1.15554803  0.09102015  1.2926112   1.4739416  -0.16865028 -0.4325779
 -1.13178202  2.38739394 -0.7257647   0.88033135  1.18910717  0.81506199
 -0.91021029 -0.21686561 -0.10101651 -1.29520781 -1.40026974  0.6798522
 -1.30737491  2.67414251 -1.54634183 -0.29573395  1.55867212 -0.5484812
 -0.07306159 -0.14513625  0.34146883  2.27767444  0.05688756 -0.23345156
 -1.25622495 -0.18615627  0.08046792 -1.11762402 -0.34970439 -0.9755842
 -0.61531449 -1.25903736 -0.08851744  0.14859931  0.00351794 -0.87389182
  0.55734954  0.23319716  0.76961262  0.14774445 -1.50618313 -0.13001804
 -0.75854851  0.06110094 -1.30376233 -0.03289771  1.52026765  2.44632524
  0.26446545 -0.90521538  1.03635718 -1.61205628  0.92374755  0.1364236
 -0.0440408   0.57304016 -0.49746337 -0.36044132  1.42688982 -0.69510469
 -1.01115072 -1.44588608 -1.44902824 -0.25530447  0.17673038 -1.00325499
  0.69435459 -0.34207699  1.12580887