In [13]:
import numpy as np
import jax.numpy as jnp
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from pylab import figure, cm
from jax import grad, hessian, jit, vmap
from jax.nn import celu
import time
from functools import partial
from IPython.display import clear_output
import tensorflow as tf

hbar = 1
m = 1
omega = 1


In [14]:
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    tf.keras.layers.Dense(5, activation="relu",use_bias=False),
    # tf.keras.layers.Dense(20, activation="relu",use_bias=False),
    # tf.keras.layers.Dense(5, activation="relu",use_bias=False),
    tf.keras.layers.Dense(1, use_bias=False)
])

def get_dims(model):
    dimensions = []
    for l in model.layers:
        dimensions.append(l.get_weights()[0].shape)
        return dimensions


def unwrap(model):
    flattened = []
    for l in model.layers:
        w = l.get_weights()[0]
        shape = (len(w), len(w[0]))
        for x in range(shape[0]):
            for y in range(shape[1]):
                flattened.append(w[x][y])
    return flattened

def set_weights(params, model):
    dimensions = get_dims(model)
    # the last index of the matrix
    max_ind = 0
    prev_ind = 0
    for i in range(len(dimensions)):
        dim = dimensions[i]
        max_ind += dim[0]*dim[1]
        temp = params[prev_ind : max_ind]
        prev_ind = max_ind
        model.layers[i].set_weights([jnp.array(temp).reshape(dim)])
    


# print(dimensions)
# print(unwrap(model))
# set_weights(unwrap(model), model)

#print(Weights)
#print(model.output_shape)
# print(model.predict(np.array([1]))[0][0])


def psi(coords, params):
    set_weights(params, model)
    return model.predict(jnp.array([coords]),verbose=0)[0][0] * jnp.exp(coords**2)

def sample(params, num_samples):
    # random.seed(seed)
    outputs = []
    coords_t = 0
    for _ in range(num_samples):
        coords_prime = coords_t + np.random.uniform(-1,1)
        if (np.random.uniform(0,1) < psi(coords_prime, params)**2/psi(coords_t, params)**2):
            coords_t = coords_prime
        outputs.append(coords_t)
    return jnp.array(outputs)

# second derivative of the wavefunction with respect to the coordinate
ddpsi = jit(grad(jit(grad(psi, allow_int = True)), allow_int = True))

@jit
def Hpsi(coords, params, omega):
    return (m*.5*omega**2*coords**2) - hbar**2 / (2*m) * jnp.sum(ddpsi(coords, params)) * 1/psi(coords, params)

venergy = vmap(Hpsi, in_axes=(0, None, None), out_axes=0)


@jit
def logpsi(coords, params):
    return jnp.log(psi(coords, params))

# define the derivative with respect to every parameter of the log of psi:
dlogpsi_dtheta_stored = jit(grad(logpsi, 1))

vlog_term = jit(vmap(dlogpsi_dtheta_stored, in_axes=(0, None), out_axes=0))

vboth = vmap(jnp.multiply, in_axes=(0, 0), out_axes=0)

def gradient(params, omega, num_samples=10**3):
    # get the samples
    samples = sample(params, num_samples)
    psiHpsi = venergy(samples, params, omega)
    logs = vlog_term(samples, params)

    energy = 1/num_samples * jnp.sum(psiHpsi)
    print(energy)
    log_term = 1/num_samples * jnp.sum(logs,0)

    both = 1/num_samples * jnp.sum(vboth(psiHpsi, logs),0)

    gradient_calc = (2 * both - 2*energy * log_term)
    return gradient_calc

def avg_energy(params, omega, num_samples = 10**3):
    samples = sample(params, num_samples)
    psiHpsi = venergy(samples, params, omega)
    return 1/num_samples * jnp.sum(psiHpsi)


def vgrad_opt(start_params, omega, num_samples=10**3, learning_rate=.1, max_iterations=10000, tolerance=.01):
    params = start_params
    hist = [start_params]

    for it in range(max_iterations):
        clear_output(wait=True)
        diff = jnp.asarray((learning_rate * gradient(params, omega, num_samples)))
        #print(diff)
        
        if all((abs(val) < tolerance) for val in diff):
            print("All under tolerance")
            return hist
        # make a step in the direction opposite the gradient
        params = params - diff
        # print(params)
        hist.append(params)
    return hist

In [17]:
xs = np.linspace(-5,5,1000)
ysi = [psi(x, unwrap(model)) for x in xs]
optd = vgrad_opt(unwrap(model), omega)


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[1,5])>with<DynamicJaxprTrace(level=0/3)>
While tracing the function Hpsi at /tmp/ipykernel_18994/2050216976.py:67 for jit, this concrete value was not available in Python because it depends on the values of the argument 'params'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError