<a href="https://colab.research.google.com/github/matinmoezzi/ebola-virus-ode-dnn/blob/main/system%20of%20ODE_keras_lbfgs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Run configuration.
# train_size: number of randomly sampled training points.
# test_size:  number of evenly spaced evaluation points.
train_size, test_size = 10_000, 2_000
# batch_size: points per mini-batch; epochs: passes of the training loop.
batch_size, epochs = 32, 100

In [None]:
# Bounds of the interval on which sample points are generated.
x_min, x_max = -2, 2

In [None]:
# Training points: drawn uniformly at random between x_min and x_max, one per row.
x_train = tf.random.uniform(
    shape=[train_size, 1],
    minval=x_min,
    maxval=x_max,
)

# Evaluation points: an evenly spaced grid turned into a column and cast to float32.
# NOTE(review): the grid stops at x_max - 1 rather than x_max, so it covers less
# than the interval the training points are sampled from -- confirm this is intended.
x_test = tf.cast(
    tf.linspace(x_min, x_max - 1, num=test_size)[:, tf.newaxis],
    dtype=tf.float32,
)

In [None]:
# Training pipeline: one sample per element, shuffled through a 1024-element
# buffer, then grouped into mini-batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

# Evaluation pipeline: same batching, original order kept (no shuffling).
test_dataset = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)

In [None]:
def loss_fn(inputs, grad, logit, init_val):
  """Sum-of-squares loss for the residual `grad + 2*inputs*logit` plus an initial-value penalty.

  Args:
    inputs: sample points at which the residual is evaluated.
    grad: derivative term of the residual, as supplied by the caller.
    logit: model output at `inputs`, as supplied by the caller.
    init_val: value that is penalised for deviating from 1.

  Returns:
    Scalar tensor: the squared residual summed over all entries, plus the
    squared deviation of `init_val` from 1 summed over all entries.
  """
  residual = grad + 2*inputs*logit
  residual_term = tf.reduce_sum(tf.square(residual))
  initial_term = tf.reduce_sum(tf.square(init_val - 1))
  return residual_term + initial_term

In [None]:
def loss_val_grad(model, inputs):
  """Build a value-and-gradient function over the model's flattened weights.

  The returned callable takes a single 1-D tensor holding all trainable
  weights of `model` concatenated together, which is the form expected by
  `tfp.optimizer.lbfgs_minimize`.

  Args:
    model: model whose `trainable_weights` are optimised. It is called on
      `inputs` and on the single point `[[0.0]]`.
    inputs: batch of sample points; the model output is differentiated with
      respect to it.

  Returns:
    A function `func(params) -> (loss_val, grads)`. It carries these extra
    attributes:
      idx: list of index tensors, one per weight, for `tf.dynamic_stitch`.
      part: partition-id tensor for `tf.dynamic_partition`.
      shapes: shapes of the trainable weights.
      update_params: function that writes a flat parameter vector back
        into the model.
  """
  shapes = tf.shape_n(model.trainable_weights)
  n_tensors = len(shapes)

  # tf.dynamic_stitch and tf.dynamic_partition are used below to convert
  # between the list of weight tensors and one flat vector, so the index
  # bookkeeping they need is prepared first.
  count = 0
  idx = []   # stitch indices: position of every weight entry in the flat vector
  part = []  # partition indices: which weight tensor each flat entry belongs to

  for i, shape in enumerate(shapes):
    # np.prod replaces np.product, which was removed in NumPy 2.0.
    n = int(np.prod(shape))
    idx.append(tf.reshape(tf.range(count, count + n, dtype=tf.int32), shape))
    part.extend([i] * n)
    count += n

  part = tf.constant(part)

  def update_params(params):
    """Split the flat vector `params` and assign each piece to its weight."""
    pieces = tf.dynamic_partition(params, part, n_tensors)
    for weight, shape, piece in zip(model.trainable_weights, shapes, pieces):
      weight.assign(tf.reshape(piece, shape))

  def func(params):
    """Load `params` into the model and return (loss, flat gradient)."""
    update_params(params)
    # Only one gradient is taken from the outer tape, so it does not need to
    # be persistent. The inner gradient is computed inside the outer tape's
    # context so that it is recorded and can be differentiated w.r.t. weights.
    with tf.GradientTape() as tp:
      with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(inputs)
        logit = model(inputs)
      df_dx = tape.gradient(logit, inputs)
      # The last argument is the model output at x = 0, used by loss_fn as
      # the initial-value term.
      loss_val = loss_fn(inputs, df_dx, logit, model(tf.constant([[0.0]])))
    grads = tp.gradient(loss_val, model.trainable_weights)
    # Flatten the per-weight gradients into the same layout as `params`.
    grads = tf.dynamic_stitch(idx, grads)
    return loss_val, grads

  func.idx = idx
  func.part = part
  func.shapes = shapes
  func.update_params = update_params

  return func

In [None]:
# Small fully connected network: one 10-unit sigmoid hidden layer followed by
# a single linear output unit.
nn = tf.keras.Sequential()
nn.add(tf.keras.layers.Dense(units=10, activation='sigmoid'))
nn.add(tf.keras.layers.Dense(units=1))
# Create the weights now for inputs with one feature per sample.
nn.build(input_shape=(None, 1))

In [None]:
# Train with L-BFGS: every mini-batch gets its own value-and-gradient closure
# and up to 100 L-BFGS iterations, starting from the current network weights.
for epoch in range(epochs):
  print(f"\nStart of epoch {epoch}:")
  # Iterate over the training batches (the original looped over test_dataset,
  # leaving train_dataset unused).
  for step, x_batch_train in enumerate(train_dataset):
    val_grad_func = loss_val_grad(nn, x_batch_train)
    # Flatten the current weights into the 1-D start point L-BFGS expects.
    lbfgs_init_pos = tf.dynamic_stitch(val_grad_func.idx, nn.trainable_weights)
    optim_results = tfp.optimizer.lbfgs_minimize(
        val_grad_func,
        initial_position=lbfgs_init_pos,
        max_iterations=100,
    )
    # Write the optimised flat vector back into the network.
    val_grad_func.update_params(optim_results.position)

    loss_value = optim_results.objective_value.numpy()

    # Early exit: once a batch loss drops below the tolerance, skip the rest
    # of this epoch. Note this only ends the current epoch, not the whole run.
    if loss_value < 1e-7:
      break

    print(f"\tTraining loss at step {step}: {loss_value}")
    

In [None]:
# Compare the network output with the reference curve exp(-x^2) on [x_min, x_max].
xs = np.linspace(x_min, x_max, num=400)
# np.linspace yields float64; cast explicitly to float32, the dtype of the
# training samples, instead of relying on an implicit cast inside the model.
xs_column = tf.convert_to_tensor(xs.astype(np.float32))[:, tf.newaxis]
plt.plot(xs, np.exp(-xs**2), label='exact')
plt.plot(xs, np.asarray(nn(xs_column)).ravel(), label='approx')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Exact curve vs. network approximation')
plt.legend()
plt.show()