In [None]:
import jax.numpy as jnp
import numpy as np
from jax import random
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import os
import time
from scipy import spatial, interpolate

# Random Gaussian Process
def gaussian_process(x , num_curves, length_scale_list, u_mean=0.):
    '''
    Draw random curves from a Gaussian process with an RBF kernel.

    x - discretized locations shared by every curve: a 1-D array of N points
        (an N X 1 column array is also accepted)
    num_curves - number of curves (samples) to draw
    length_scale_list - candidate kernel length scales; each curve picks one
        uniformly at random with np.random.choice
    u_mean - constant mean of the Gaussian process

    Returns a (num_curves, N) numpy array, one curve per row (shape (0, N)
    when num_curves is 0). Randomness comes from numpy's global RNG, so call
    np.random.seed beforehand for reproducible draws.
    '''
    # cdist needs a 2-D (N, 1) array of locations. reshape works for both 1-D
    # and N X 1 inputs (expand_dims would give (N, 1, 1) for the latter).
    X = np.reshape(x, (-1, 1))
    mean = u_mean * np.ones(len(X))
    # Squared pairwise distances do not depend on the length scale, so they
    # are computed once rather than once per curve.
    sq_dist = spatial.distance.cdist(X, X, 'sqeuclidean')

    curves = []
    for _ in range(num_curves):
        length_scale = np.random.choice(length_scale_list)  # kernel length scale for this curve
        # Exponentiated quadratic kernel (squared exponential / Gaussian / RBF)
        cov = np.exp(-0.5 * sq_dist / length_scale**2)
        curves.append(np.random.multivariate_normal(mean=mean, cov=cov, size=1))

    if not curves:
        return np.empty((0, len(X)))
    # Stack once at the end instead of re-allocating with vstack every iteration.
    return np.vstack(curves)

def normalize(ys):
    '''
    Rescale curves so that all values lie in [-1, 1].

    ys - M X N array holding M curves (one per row) with N points each

    If any absolute value exceeds 1, every row is divided by its own maximum
    absolute value (so each such row then peaks at magnitude 1); otherwise
    ys is returned unchanged. All-zero rows stay zero instead of becoming NaN.
    '''
    if np.max(np.abs(ys)) > 1:
        row_max = np.max(np.abs(ys), axis=1, keepdims=True)
        # Guard against 0/0 = NaN for rows that are identically zero.
        row_max = np.where(row_max == 0, 1, row_max)
        ys = np.divide(ys, row_max)
    return ys

# Training data for x
def x_train(Nx):
    '''
    Build the training grid: Nx evenly spaced points on [-1, 1].

    Returns a float32 jax array of shape (Nx, 1), i.e. a single column.
    '''
    grid = np.linspace(-1, 1, Nx)[:, np.newaxis]
    return jnp.array(grid, dtype=jnp.float32)

# Training data for u_k
def u_train(x, Nu, length_scale_list, u_mean=0.):
    '''
    Generate Nu random input functions u_k sampled on the grid x.

    x - 2-D array of grid locations (e.g. the output of x_train); only
        column 0 is used
    Nu - number of curves to generate
    length_scale_list - candidate GP kernel length scales (see gaussian_process)
    u_mean - mean of the underlying Gaussian process

    Returns an (Nu, len(x)) float32 jax array, one curve per row.
    '''
    # np.asarray accepts numpy and jax arrays alike; `.numpy()` is a
    # TensorFlow/PyTorch method and raises AttributeError on jax arrays.
    x = np.asarray(x[:, 0])
    u = gaussian_process(x, Nu, length_scale_list, u_mean)
    # The factor (x+1)(x-1) is zero at x = -1 and x = 1, pinning every curve to 0 there.
    u = 0.8 * u * (x + 1.) * (x - 1.)
    # Sign-preserving amplitude reshaping: exp(1.5*(1-|u|)) is positive and
    # grows as |u| shrinks, so small-magnitude values are amplified the most.
    u = np.exp(1.5 * (1 - np.abs(u))) * u
    return jnp.array(u, dtype=jnp.float32)



# Step 3.0 Generate training data
Nx = 100  # number of grid points along the x-axis (discretized locations)
Nu = 2000  # number of training curves u (inputs for the branch network)
length_scale = [0.5, 1.0]  # candidate length scales for the GP kernel
batch_size = 256


class _BatchLoader:
    '''
    Minimal JAX stand-in for tf.data's from_tensor_slices(...).shuffle(...).batch(...).

    arrays - tuple of arrays sliced together along axis 0 (must share that length)
    batch_size - rows per batch; the last batch may be smaller
    key - jax.random PRNG key; a fresh permutation is drawn on every pass
    shared - tuple of arrays prepended unchanged to every batch (e.g. a grid
        common to all samples, which cannot be sliced alongside them)

    Iterating yields tuples: shared + (a[idx] for a in arrays).
    '''

    def __init__(self, arrays, batch_size, key, shared=()):
        num_rows = arrays[0].shape[0]
        if any(a.shape[0] != num_rows for a in arrays):
            raise ValueError('all sliced arrays must share the same first dimension')
        self.arrays = tuple(arrays)
        self.batch_size = batch_size
        self.key = key
        self.shared = tuple(shared)
        self.num_rows = num_rows

    def __len__(self):
        # Number of batches: ceil(num_rows / batch_size).
        return -(-self.num_rows // self.batch_size)

    def __iter__(self):
        # Split the key so each pass over the data gets a new shuffle.
        self.key, subkey = random.split(self.key)
        perm = random.permutation(subkey, self.num_rows)
        for start in range(0, self.num_rows, self.batch_size):
            idx = perm[start:start + self.batch_size]
            yield self.shared + tuple(a[idx] for a in self.arrays)


X = x_train(Nx)
uk = u_train(X, Nu, length_scale)

data_key = random.PRNGKey(0)
u_key, sampled_key = random.split(data_key)

# Batches of (X, u_batch). X (Nx x 1) is the grid shared by every curve and has a
# different first dimension than uk (Nu x Nx), so it cannot be sliced together with
# uk; it is attached unchanged to each batch instead.
u_train_dataset = _BatchLoader((uk,), batch_size, u_key, shared=(X,))

# Keep every 5th grid point of each curve (20 of the 100 points when Nx = 100).
uk_sampled = uk[:, ::5]
uk_sampled_train_dataset = _BatchLoader((uk_sampled, uk), batch_size, sampled_key)

# Plot the first 10 training curves.
for curve in uk[:10]:
    plt.plot(X[:, 0], curve)

NameError: name 'tf' is not defined