In [1]:
import numpy as np
import jax.numpy as jnp
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from pylab import figure, cm
import random
from jax import grad, hessian, jit, vmap
import time
from IPython.display import clear_output

# --- units: Hpsi uses hbar**2 / (2*m) as the kinetic prefactor ---
hbar = 1
m = 1

# --- model parameters ---
N = 2        # number of particles; also the width of the network's input layer
omega = 1    # frequency appearing in the potential V
g = 1        # coupling multiplying the pairwise |x_i - x_j| term in V

# Reference energy for the chosen N, omega and g.
# NOTE(review): presumably the analytic ground-state energy the optimiser is
# compared against - confirm the formula against the model actually coded in V.
true_E = N*omega/2 -  g**2/24 * N * (N**2 - 1)

In [61]:
class Network:
    """Bias-free, fully connected feed-forward network used as a wavefunction ansatz.

    The trainable parameters are the weight matrices only. They can be
    exchanged with the optimiser as one flat vector via `conv1d` / `convnd`.

    Attributes:
        num_layers: number of layers, counting the input and output layers.
        node_counts: number of nodes in each layer (length == num_layers).
        num_particles: size of the input layer (one input per coordinate).
        params_length: total number of weights.
        num_params: alias of `params_length`.
        weights_shape: list of (rows, cols) shapes, one per weight matrix.
        weights: list of randomly initialised weight matrices; index 0 maps
            the input layer to the first hidden layer.
    """

    def __init__(self, num_layers=4, node_counts=None):
        """Build the layer structure and draw random initial weights.

        Args:
            num_layers: number of layers including input and output (>= 2).
            node_counts: nodes per layer; defaults to [N, 10, 5, 1].

        Raises:
            ValueError: if `num_layers` < 2 or `node_counts` does not have
                exactly `num_layers` entries.
        """
        # Resolve the default here rather than in the signature: a list
        # literal in the signature is a single object shared by every call.
        if node_counts is None:
            node_counts = [N, 10, 5, 1]
        node_counts = list(node_counts)

        if num_layers < 2:
            raise ValueError("num_layers must be at least 2 (input and output).")
        if len(node_counts) != num_layers:
            raise ValueError(
                "node_counts is of the wrong length: expected "
                f"{num_layers} entries, got {len(node_counts)}."
            )

        # defining the structure of the neural network
        self.num_layers = num_layers
        self.node_counts = node_counts
        self.num_particles = node_counts[0]

        # one (rows, cols) = (fan_out, fan_in) shape per pair of adjacent layers
        self.weights_shape = [
            (node_counts[i + 1], node_counts[i]) for i in range(num_layers - 1)
        ]

        # the total number of weights
        self.params_length = sum(rows * cols for rows, cols in self.weights_shape)
        self.num_params = self.params_length

        # random weight matrices, scaled by sqrt(1 / rows)
        self.weights = [
            np.random.randn(rows, cols) * np.sqrt(1. / rows)
            for rows, cols in self.weights_shape
        ]

    # define the activation function that we use for the layers
    @staticmethod
    def l_act(x):
        """ReLU: element-wise max(0, x)."""
        return jnp.maximum(0.0, x)

    # define the activation function for the output
    @staticmethod
    def o_act(x):
        """Logistic sigmoid 1 / (1 + exp(-x))."""
        return 1/(jnp.exp(-x) + 1)

    def conv1d(self):
        """Return the current weights as one flat vector of length `params_length`."""
        return jnp.concatenate([jnp.asarray(w).ravel() for w in self.weights])

    def convnd(self, params):
        """Split a flat parameter vector into the per-layer weight matrices.

        Args:
            params: array-like holding exactly `params_length` values, in the
                order produced by `conv1d`.

        Returns:
            A list of 2-D arrays with the shapes in `weights_shape`.

        Raises:
            ValueError: if `params` has the wrong number of elements.
        """
        flat = jnp.asarray(params).reshape(-1)
        if flat.size != self.params_length:
            raise ValueError(
                f"expected {self.params_length} parameters, got {flat.size}."
            )
        matrices = []
        offset = 0
        for rows, cols in self.weights_shape:
            count = rows * cols
            matrices.append(flat[offset:offset + count].reshape(rows, cols))
            offset += count
        return matrices

    # passing inputs into the neural network and getting an output
    def output(self, coords, params):
        """Evaluate the network at `coords` using the weights in `params`.

        Args:
            coords: array-like of length `num_particles`.
            params: flat parameter vector (see `convnd`).

        Returns:
            The output-layer activations with size-1 axes squeezed out, so a
            single-output network yields a scalar.
        """
        # Sort the input coordinates in order to enforce particle swap
        # invariance. jnp.sort returns a new array, so the caller's data is
        # not modified, and it also works on traced values.
        x = jnp.sort(jnp.asarray(coords))

        # Deliberately do not store these on self: this method is traced by
        # jit/grad/vmap, and writing traced values to the instance leaks them.
        layers = self.convnd(params)

        # hidden layers use l_act, the final layer uses o_act
        for w in layers[:-1]:
            x = self.l_act(w @ x)
        x = self.o_act(layers[-1] @ x)
        return jnp.squeeze(x)

# Shared network instance used by the wavefunction helpers below:
# N inputs -> 10 -> 5 -> 1 output.
nn = Network(num_layers=4, node_counts=[N, 10, 5, 1])

def psi(coords, params):
    """Wavefunction value exp(-f), where f = nn.output(coords, params)."""
    network_value = nn.output(coords, params)
    return jnp.exp(-network_value)


# compiled version of psi
psij = jit(psi)
def psi_neg(coords, params):
    """Reciprocal of the wavefunction, 1 / psi(coords, params)."""
    amplitude = psi(coords, params)
    return 1/amplitude


# compiled version of psi_neg
psi_negj = jit(psi_neg)
def logpsi(coords, params):
    """Natural logarithm of the compiled wavefunction psij(coords, params)."""
    amplitude = psij(coords, params)
    return jnp.log(amplitude)


# compiled version of logpsi
logpsij = jit(logpsi)

# Hessian of the compiled wavefunction. jax.hessian differentiates with respect
# to the first positional argument by default, i.e. coords.
h = hessian(psij)


def ddpsi(coords, params):
    """Diagonal of the Hessian of psi with respect to coords.

    Entry i is the second derivative of psi along coordinate i.
    """
    second_derivatives = h(coords, params)
    return jnp.diagonal(second_derivatives)


# compiled gradient of log(psi) with respect to params (argnums=1)
partials = jit(grad(logpsij, argnums=1))


def V(coords, l):
    """Potential energy of the configuration `coords`.

    Sums a harmonic term 0.5 * omega**2 * x_i**2 over all particles and a
    pairwise term -omega * g * |x_i - x_j| over every unordered pair.

    Args:
        coords: sequence or 1-D array of particle coordinates.
        l: unused by the current potential; kept so existing callers, which
            pass it positionally, keep working.

    Returns:
        The scalar potential energy.
    """
    v = 0
    for i in range(len(coords)):
        # harmonic oscillator portion: quadratic in the coordinate
        v += .5 * omega**2 * coords[i]**2
        # pairwise portion; j < i so each pair is counted exactly once
        for j in range(0, i):
            v += - omega * g * jnp.abs(coords[i] - coords[j])
    return v

def Hpsi(coords, params, l):
    """Apply the Hamiltonian to psi at `coords`.

    Returns V * psi - hbar**2 / (2*m) * (sum of the second derivatives of psi
    along each coordinate).
    """
    potential_part = Vj(coords, l) * psij(coords, params)
    laplacian = jnp.sum(ddpsi(coords, params))
    return potential_part - hbar**2/(2*m) * laplacian


def sample(params, num_samples):
    """Draw configurations from |psi|**2 with a Metropolis random walk.

    The walk starts at the origin. Each step displaces every coordinate by an
    independent uniform(-1, 1) amount and accepts the move when a uniform(0, 1)
    draw is below |psi(new)|**2 / |psi(current)|**2; otherwise the current
    configuration is recorded again.

    Randomness comes from the global `random` module, so call `random.seed`
    beforehand for a reproducible chain.

    Args:
        params: flat network parameter vector passed through to psij.
        num_samples: number of configurations to record.

    Returns:
        Array with one recorded configuration per row.
    """
    # The input layer has one node per particle coordinate.
    num_particles = nn.node_counts[0]

    coords_t = [0.0] * num_particles
    # Cache |psi|**2 at the current position so it is evaluated once per
    # accepted move instead of once per step.
    prob_t = psij(jnp.asarray(coords_t), params) ** 2

    outputs = []
    for _ in range(num_samples):
        coords_prime = [x + random.uniform(-1, 1) for x in coords_t]
        prob_prime = psij(jnp.asarray(coords_prime), params) ** 2
        if random.uniform(0, 1) < prob_prime / prob_t:
            coords_t, prob_t = coords_prime, prob_prime
        outputs.append(coords_t)
    return jnp.array(outputs)

def grad_log_psi(coords, params):
    """Gradient of log(psi) with respect to the parameters.

    Args:
        coords: particle coordinates at which to evaluate.
        params: flat network parameter vector.

    Returns:
        Array of length len(params); entry i is d log(psi) / d params[i].
    """
    # `partials` already returns the whole gradient vector, so evaluate it
    # once rather than once per parameter.
    return jnp.asarray(partials(coords, params))

# --- compiled single-sample entry points ---
Vj = jit(V)
Hpsij = jit(Hpsi)

# --- batched versions: map over the leading (sample) axis of the first
# --- argument while sharing params (and l) across the batch
vhpsis = vmap(Hpsij, in_axes=(0, None, None), out_axes=0)
vpsi_neg = vmap(psi_negj, in_axes=(0, None), out_axes=0)
vgradlogs = vmap(grad_log_psi, in_axes=(0, None), out_axes=0)

# row-wise product: both arguments are mapped over their leading axis
vboth = vmap(jnp.multiply, in_axes=(0, 0), out_axes=0)


def avg_energy(params, l , N=10**3):
    """Monte Carlo estimate of the energy for the given parameters.

    Draws N configurations with `sample` and averages (H psi) / psi over them.

    Args:
        params: flat network parameter vector.
        l: forwarded to the Hamiltonian (see Hpsi).
        N: number of samples. Note this shadows the module-level particle
            count `N` inside the function.

    Returns:
        The sample mean of the local energy.
    """
    walkers = sample(params, N)
    # (H psi) / psi evaluated at every sampled configuration
    inverse_psi = vpsi_neg(walkers, params)
    h_psi = vhpsis(walkers, params, l)
    local_energies = jnp.multiply(inverse_psi, h_psi)
    return 1/N * jnp.sum(local_energies)

def vgradient(params, l, N):
    """Estimate the gradient of the energy with respect to the parameters.

    Uses N configurations from `sample` and the batched/compiled helpers to
    form 2 * <E_loc * grad log psi> - 2 * <E_loc> * <grad log psi>, where
    E_loc = (H psi) / psi.

    Args:
        params: flat network parameter vector.
        l: forwarded to the Hamiltonian (see Hpsi).
        N: number of samples.

    Returns:
        Array with one gradient entry per parameter.
    """
    walkers = sample(params, N)

    # per-sample ingredients
    log_grads = vgradlogs(walkers, params)
    h_psi = vhpsis(walkers, params, l)
    inverse_psi = vpsi_neg(walkers, params)
    local_energies = jnp.multiply(inverse_psi, h_psi)

    # the three sample averages entering the estimator
    mean_energy = 1/N * jnp.sum(local_energies)
    mean_log_grad = 1/N * jnp.sum(log_grads, 0)
    mean_energy_times_log_grad = 1/N * jnp.sum(vboth(local_energies, log_grads), 0)

    return (2 * mean_energy_times_log_grad - 2 * mean_energy * mean_log_grad)

def vgrad_opt(start_params, l, learning_rate=.1, max_iterations=100, tolerance=.01, N=10**3):
    """Plain gradient descent on the energy using `vgradient`.

    Args:
        start_params: initial flat parameter vector.
        l: forwarded to `vgradient`.
        learning_rate: multiplier applied to the gradient to obtain the step.
        max_iterations: maximum number of descent steps.
        tolerance: stop as soon as every component of the step is smaller
            than this in absolute value.
        N: number of Monte Carlo samples per gradient estimate.

    Returns:
        List of parameter vectors visited, starting with `start_params`.
    """
    # pick the starting position
    params = start_params
    hist = [start_params]
    # iterate until we run out of iterations or the step drops below tolerance
    for _ in range(max_iterations):
        step = jnp.asarray(learning_rate * vgradient(params, l, N))
        # One vectorised reduction instead of a Python loop over the array,
        # which would pull each element off the device separately.
        if bool(jnp.all(jnp.abs(step) < tolerance)):
            return hist
        # make a step in the direction opposite the gradient
        params = params - step
        hist.append(params)
    return hist