In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from joblib import Parallel, delayed
import multiprocessing

In [None]:
# TensorBoard bookkeeping: each run logs into its own UTC-timestamped folder
from datetime import datetime
root_logdir = "tf_logs"
now = datetime.utcnow().strftime("%Y%m%d%H%M%S")
logdir = root_logdir + "/run-" + now + "/"

In [None]:
# CME and CLE parameters (read by lambdan/mun and lambda1..lambda4)
concA, concB = 10, 20               # concA multiplies k1 (channel 1), concB multiplies k3 (channel 3)
k1, k2, k3, k4 = 6, 1.0, 230, 1000  # one constant per intensity function lambda1..lambda4
vol = 32                            # scales the intensities; values noted in the old comment: 30, 18, 8

In [None]:
# Simulation parameters
nreactions = 4      # number of normal increments drawn per time step in generateBrownian
timesteps = 10000   # length of each Brownian path handed to propagateSDE
dt = 0.001          # time step; the increments have standard deviation sqrt(dt)

# Network parameters
hlayer_depth, hlayer_width = 4, 50   # hlayer_depth-2 dense hidden layers of hlayer_width units are built
activation_func = tf.nn.relu         # activation of the hidden layers
initialLearningRate = 1e-2           # passed to the Adam optimizer (old comment lists 1e-3)
trainingIterations = 4000            # number of optimizer steps
batchSize = 64                       # samples drawn per step (old comment lists 512)

# Domain parameters
domain = [0, 2000]      # [lower, upper] end of the state range (old comment lists [0,1500])
outputResolution = 50   # number of entries in the network output / target vector
# Width of the state range covered by one output entry
outscale = (domain[-1] - domain[0])/outputResolution

# File writing parameters
datasize = 4160  # number of simulated trajectories (old comment lists 1024*5)
stride = 10
# NOTE(review): the name contains a double underscore before "T=" and no
# separator between the T value and datasize (e.g. "..._v32__T=10.04160.txt").
# Kept as is because existing data files may already use this pattern.
filename = "data/schloglCLE_v{}__T={}{}.txt".format(str(vol), str(timesteps*dt), str(datasize))

In [None]:
# Define CME birth/death rates
def lambdan(n):
    """Birth rate at state ``n``: intensity of channel 1 plus intensity of channel 3."""
    quadratic_part = lambda1(n)
    constant_part = lambda3(n)
    return quadratic_part + constant_part
def mun(n):
    """Death rate at state ``n``: intensity of channel 2 plus intensity of channel 4."""
    cubic_part = lambda2(n)
    linear_part = lambda4(n)
    return cubic_part + linear_part

# Define intensity functions
# Each function maps a state n (scalar or NumPy array) to the intensity of one
# channel; the constants concA, concB, k1..k4 and vol are module-level globals.
def lambda1(n):
    """Intensity of channel 1: ``concA*k1*n*(n-1)/vol`` (negative for 0 < n < 1)."""
    return concA*k1*n*(n-1)/vol
def lambda2(n):
    """Intensity of channel 2: ``k2*n*(n-1)*(n-2)/vol**2`` (negative for 1 < n < 2)."""
    return k2*n*(n-1)*(n-2)/vol**2
def lambda3(n):
    """Intensity of channel 3: the constant ``concB*k3*vol``; ``n`` is accepted but unused."""
    return concB*k3*vol
def lambda4(n):
    """Intensity of channel 4: ``n*k4``."""
    return  n*k4

# # Define terminal condition function
# def terminalCondition(n):
#     return 1.0*n

def terminalCondition(n, bounds=None, resolution=None):
    """Encode a terminal state as a smoothed one-hot histogram.

    The interval ``bounds`` is split into ``resolution`` equal bins of width
    ``(bounds[1] - bounds[0]) / resolution``. This is the same width as the
    global ``outscale`` that the plotting cells use to draw the network output,
    so bin ``i`` covers ``[lo + i*dx, lo + (i+1)*dx)``.

    The bin holding the state gets weight 0.8 and each neighbour 0.1. At either
    end of the histogram the missing neighbour's weight is folded into the
    state's own bin (0.9 / 0.1), so the vector always sums to 1. States outside
    ``bounds`` are assigned to the nearest edge bin instead of raising.

    Parameters
    ----------
    n : scalar or array-like
        Terminal state. Only the LAST entry is encoded (as before); the
        notebook passes a length-1 array.
    bounds : sequence of two numbers, optional
        Lower and upper end of the histogram. Defaults to the global ``domain``.
    resolution : int, optional
        Number of bins, at least 2. Defaults to the global ``outputResolution``.

    Returns
    -------
    numpy.ndarray
        Vector of shape ``(resolution,)``.

    Raises
    ------
    ValueError
        If ``n`` is empty, the state is NaN or infinite, or ``resolution < 2``.
    """
    if bounds is None:
        bounds = domain
    if resolution is None:
        resolution = outputResolution
    if resolution < 2:
        raise ValueError("resolution must be at least 2 to place neighbour weights")

    values = np.atleast_1d(np.asarray(n, dtype=float)).ravel()
    if values.size == 0:
        raise ValueError("terminalCondition needs at least one state value")
    value = values[-1]
    if not np.isfinite(value):
        raise ValueError("terminal state must be finite, got {}".format(value))

    lower = bounds[0]
    dx = float(bounds[1] - lower) / resolution
    # Bin lookup by arithmetic; clamp so out-of-range states land in an edge bin.
    index = int(np.floor((value - lower) / dx))
    index = min(max(index, 0), resolution - 1)

    iresult = np.zeros(resolution)
    if index == 0:
        iresult[index] = 0.9
        iresult[index+1] = 0.1
    elif index == resolution - 1:
        # Compare against the last BIN (the old test used len(n)).
        iresult[index-1] = 0.1
        iresult[index] = 0.9
    else:
        iresult[index-1] = 0.1
        iresult[index] = 0.8
        iresult[index+1] = 0.1
    return iresult

def steadystate_solution(n):
    """Return the unnormalised steady-state weight of state ``n``.

    The weight is the running product of ``lambdan(j)/mun(j+1)`` for
    ``j = 0 .. n-1``, starting from 1.0 (so state 0 has weight 1.0).
    ``n`` must be an integer because it is passed to ``range``.
    """
    weight = 1.0
    for state in range(n):
        ratio = lambdan(state)/mun(state+1)
        weight = weight*ratio
    return weight

# Define drift and diffusion coefficients
def drift(x):
    """Drift at state ``x``: channels 1 and 3 enter with a plus sign, 2 and 4 with a minus sign."""
    rate1 = lambda1(x)
    rate2 = lambda2(x)
    rate3 = lambda3(x)
    rate4 = lambda4(x)
    return rate1 - rate2 + rate3 - rate4

def diffusion(x):
    """Return the four diffusion coefficients at the state stored in ``x``.

    The result is ``[sqrt(l1), -sqrt(l2), sqrt(l3), -sqrt(l4)]`` with
    ``l_k = lambda_k(state)``, as a float32 array of shape ``(4,)``.

    Parameters
    ----------
    x : 1-D array-like
        State vector. Only its LAST entry is used, which is what the previous
        loop effectively did; the notebook passes length-1 arrays.

    Notes
    -----
    ``lambda1`` is negative for 0 < x < 1 and ``lambda2`` for 1 < x < 2, so
    taking the square root directly yields NaN there. That NaN would then
    survive the ``< 0`` reset in ``propagateSDE`` and spread through the rest
    of the trajectory. The intensities are therefore clamped at zero before
    the square root. For states where all intensities are non-negative the
    output is unchanged.
    """
    state = x[-1]
    rates = np.array([lambda1(state), lambda2(state), lambda3(state), lambda4(state)],
                     dtype=np.float64)
    signs = np.array([1.0, -1.0, 1.0, -1.0])
    sigma = signs*np.sqrt(np.maximum(rates, 0.0))
    return sigma.astype(np.float32)

def diffusiontf(x):
    """Return the four signed square-root intensities at ``x`` as a TensorFlow tensor.

    NOTE(review): the square roots are taken with ``np.sqrt``, so this
    presumably expects ``x`` to be a NumPy value rather than a graph tensor.
    Confirm before using it inside the TensorFlow graph.
    """
    entry1 = np.sqrt(lambda1(x))
    entry2 = - np.sqrt(lambda2(x))
    entry3 = np.sqrt(lambda3(x))
    entry4 = -np.sqrt(lambda4(x))
    return tf.convert_to_tensor([entry1, entry2, entry3, entry4])

In [None]:
# Generate Brownian paths function
def generateBrownian(n):
    """Draw ``n`` time steps of independent Brownian increments.

    Parameters
    ----------
    n : int
        Number of time steps (rows).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, nreactions)`` whose entries are independent
        ``Normal(0, sqrt(dt))`` samples from NumPy's global random state.

    Notes
    -----
    The column count now follows the global ``nreactions`` instead of a
    hard-coded 4, and all samples are drawn in one vectorised call. The
    values are drawn from the same global random stream, row by row.
    """
    return np.random.normal(0., np.sqrt(dt), size=(n, nreactions))

# Propagate forward SDE function
def propagateSDE(x0, brownianPath):
    """Integrate the SDE forward with explicit Euler-Maruyama steps.

    Parameters
    ----------
    x0 : 1-D array
        Initial state; its length sets the number of trajectory columns.
    brownianPath : sequence of increment vectors
        One row per time step, as produced by ``generateBrownian``.

    Returns
    -------
    numpy.ndarray
        Trajectory of shape ``(len(brownianPath), len(x0))``. Row 0 is ``x0``.
        A state that becomes negative is reset to 0.
    """
    nsteps = len(brownianPath)
    trajectory = np.zeros([nsteps, len(x0)])
    trajectory[0] = 1.0*x0
    for step in range(nsteps-1):
        current = trajectory[step]
        deterministic = current + drift(current)*dt
        noise = np.dot(diffusion(current),brownianPath[step])
        trajectory[step+1] = deterministic + noise
        # Keep the state non-negative (this comparison assumes a 1-D state).
        if trajectory[step+1] < 0:
            trajectory[step+1] = 0
    return trajectory

In [None]:
# Generate network structure

# Placeholders: one scalar state per input row, one target vector per row
networkInput = tf.placeholder(dtype=tf.float32, shape=(None,1), name='input')
networkTarget = tf.placeholder(dtype=tf.float32, shape=(None, outputResolution), name='target')

# Hidden layers: hidden[0] is the input, followed by hlayer_depth-2 dense layers
hidden = [networkInput]
for layer_index in range(hlayer_depth-2):
    hidden.append(tf.layers.dense(hidden[layer_index], hlayer_width, activation=activation_func))

# Add the outermost prediction layer (linear, one output per target entry)
networkPrediction = tf.layers.dense(hidden[-1], outputResolution, activation=None, name='prediction')

In [None]:
# Training procedure

# Deviation of the prediction from the target, used by both metrics below
residual = networkPrediction - networkTarget
loss = tf.reduce_mean(tf.square(residual))   # mean squared deviation (minimised)
error = tf.reduce_max(tf.abs(residual))      # largest absolute deviation (reported only)

# Optimise the loss with Adam
optimizer = tf.train.AdamOptimizer(initialLearningRate)
step = optimizer.minimize(loss)

# TensorBoard: scalar summary of `error` and a writer for the default graph.
# Despite its name, mse_summary tracks the max-abs error, not the MSE.
mse_summary = tf.summary.scalar('Error', error)
file_writer = tf.summary.FileWriter(logdir, tf.get_default_graph())

In [None]:
# Generate Training data
generateData = True

if generateData:
    # Random integer initial states inside [domain[0]+50, domain[1]-100), stored as float32
    x0 = np.random.randint(domain[0] + 50, domain[1] - 100, datasize).astype(np.float32)
    num_cores = multiprocessing.cpu_count()

    # One Brownian path per trajectory, generated in parallel
    brownian_pool = Parallel(n_jobs=num_cores, verbose = 2)
    brownianTrajs = brownian_pool(delayed(generateBrownian)(timesteps)
                                  for _unused in range(datasize))

    # Propagate every initial state along its own Brownian path
    sde_pool = Parallel(n_jobs=num_cores, verbose = 2)
    solutionsSDE = sde_pool(delayed(propagateSDE)(np.array([start]), path)
                            for start, path in zip(x0, brownianTrajs))

    # Results stay in memory. The former dump of each trajectory as one
    # space-separated line into `filename` is intentionally disabled.

In [None]:
# Define input and target data
# Input: first state of each trajectory. Target: histogram encoding of its last state.
inputData = [solutionsSDE[k][0] for k in range(datasize)]
targetData = [terminalCondition(solutionsSDE[k][-1]) for k in range(datasize)]

In [None]:
# Start Tensorflow session and initialize all network variables
init_op = tf.global_variables_initializer()
session = tf.Session()
session.run(init_op)

In [None]:
# Run the training

# Convert the Python lists to arrays once; previously np.array(targetData) and
# the list-to-array conversion inside np.take were redone on every iteration.
inputArray = np.reshape(np.asarray(inputData), [-1,1])
targetArray = np.asarray(targetData)

# run gradient descent steps with Adam
print('\nStarted training...')
print('{:8s}\t{:8s}\t{:8s}'.format('iter', 'l2-loss', 'linf-err'))
print('{:8s}\t{:8s}\t{:8s}'.format(*(3*[8*'-'])))
# The loop variable is `iteration`, not `iter`, so the builtin iter() is not shadowed.
for iteration in range(trainingIterations):
    # generate random batch of inputs and corresponding target values
    indices =  np.random.randint(0, datasize, batchSize)
    inputBatch = inputArray[indices]
    targetBatch = targetArray[indices]

    # take gradient descent step and compute loss & error
    loss_val, error_val, _ = session.run(
        [loss, error, step],
        feed_dict={networkInput: inputBatch, 
                   networkTarget: targetBatch}
    )
    if iteration % 100 == 0:
        print('{:8d}\t{:1.2e}\t{:1.2e}'.format(iteration, loss_val, error_val))
print('...finished training.\n')

In [None]:
# Calculate steady state analytically
# State axis: the integer states domain[0], ..., domain[1]-1 (as floats for plotting).
# The previous linspace was built from the domain LENGTH, which is only correct
# when domain[0] == 0.
n = np.arange(domain[0], domain[1], dtype=float)
ss_solution = np.zeros(len(n))
if len(n) > 0:
    first_state = int(domain[0])
    ss_solution[0] = steadystate_solution(first_state)
    # Each weight is the previous one times lambdan(s-1)/mun(s). This is the
    # same running product steadystate_solution computes, without restarting
    # from state 0 for every entry (O(N) instead of O(N^2) rate evaluations).
    for i in range(1, len(n)):
        state = first_state + i
        ss_solution[i] = ss_solution[i-1]*(lambdan(state-1)/mun(state))

# Normalise to a probability distribution over the truncated state range
ss_solution = ss_solution/np.sum(ss_solution) 

In [None]:
# Plot for vector case

plt.rcParams['figure.figsize'] = (11,8)
# generate full sample grid of input domain
RESOLUTION = domain[1] - domain[0]
xgrid = np.linspace(domain[0], domain[1], num=RESOLUTION)
xgrid = xgrid.astype(int)
input_test_batch = np.reshape(xgrid , [-1,1])

# get model predictions
prediction_test_batch = session.run( networkPrediction, feed_dict={networkInput: input_test_batch})

# Initial state whose predicted terminal distribution is shown
x0 = 400
# Find the grid row that holds state x0, rather than using x0 itself as a row
# index (the two only coincide when the grid starts at 0 with unit spacing).
x0_row = int(np.argmin(np.abs(xgrid - x0)))

# Remove negative entries and renormalize. Work on a copy so the clipping does
# not overwrite the row inside prediction_test_batch.
renormalized_output = prediction_test_batch[x0_row].copy()
renormalized_output[renormalized_output<0] = 0.0
renormalized_output = renormalized_output/np.sum(renormalized_output)

# plot resulting histogram: one bar per output entry, centred in its bin and
# shifted by domain[0]; heights are divided by the bin width to give a density
bin_centers = domain[0] + np.arange(outscale/2,outscale*outputResolution,outscale)
plt.bar(bin_centers,renormalized_output/outscale, 
        width=outscale, label="NN (CLE)", color=(0.0, 0.4, 1.0, 0.5))

#Plot analytic solution
plt.plot(n,ss_solution, '-r', lw = 3, label="Steady state (exact)")

plt.ylabel('Probability', fontsize = 35)
plt.xlabel('$X[T]$', fontsize = 35)
#plt.legend(fontsize = 35)
plt.tick_params(labelsize=30)
plt.locator_params(axis='y', nbins=6)


plt.show()

In [None]:
# Sanity check: both distributions should sum to 1. Returned as one tuple
# because a cell only displays its last expression; the first sum used to be
# computed and silently dropped.
np.sum(ss_solution), np.sum(renormalized_output)

In [None]:
# Average target vector over all samples, compared with the exact steady state
# rescaled to output-entry units (x divided by outscale, y multiplied by it)
mean_target = np.mean(targetData,axis=0)
plt.plot(mean_target)
plt.plot(n/outscale ,ss_solution*outscale, '-r', lw = 3, label="Steady state (exact)")


In [None]:
# Mean over every entry of every target vector; this equals 1/outputResolution
# when each target vector sums to 1
np.asarray(targetData).mean()