<a href="https://colab.research.google.com/github/matinmoezzi/ebola-virus-ode-dnn/blob/main/sysODE_keras_Adam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

In [None]:
# Number of random training points and of evenly spaced test points.
train_size, test_size = 10_000, 500
# Mini-batch size and number of passes over the training set.
batch_size, epochs = 32, 100

In [None]:
# Lower and upper bound of the interval on which sample points are drawn.
x_min, x_max = 0, 3

In [None]:
# Training points: uniform random samples in [x_min, x_max), one per row
# (tf.random.uniform produces float32 unless told otherwise).
x_train = tf.random.uniform(shape=[train_size, 1], minval=x_min, maxval=x_max)

# Test points: an evenly spaced grid reshaped into a column vector.
# tf.linspace is documented for floating-point start/stop, so the bounds are
# cast explicitly instead of being passed as Python ints; this also keeps
# x_test in the same dtype as x_train.
x_test = tf.linspace(
    tf.cast(x_min, tf.float32),
    tf.cast(x_max, tf.float32),
    num=test_size,
)[:, tf.newaxis]

In [None]:
# Training pipeline: one sample per element, shuffled through a 1024-element
# buffer and then grouped into mini-batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

# Test pipeline: same batching, original order preserved.
test_dataset = tf.data.Dataset.from_tensor_slices(x_test).batch(batch_size)

In [None]:
def _make_trial_network():
  """Return a new network: 10 sigmoid hidden units, then one linear output."""
  hidden = tf.keras.layers.Dense(10, activation='sigmoid')
  output = tf.keras.layers.Dense(1)
  return tf.keras.Sequential([hidden, output])


# One independent trial network per unknown function of the ODE system.
f1_trial = _make_trial_network()
f2_trial = _make_trial_network()
models = [f1_trial, f2_trial]

In [None]:
# Points at which the initial conditions are imposed (x = 0 for both
# functions), stored as 1x1 tensors so they can be fed to the models directly.
f1_x_init = tf.constant(0.0, shape=(1, 1))
f2_x_init = tf.constant(0.0, shape=(1, 1))

# Prescribed values at those points: f1(0) = 0 and f2(0) = 1.
f1_init_val = tf.constant(0.0, shape=(1, 1))
f2_init_val = tf.constant(1.0, shape=(1, 1))
true_init_vals = [f1_init_val, f2_init_val]

In [None]:
def loss_fn(inputs, input_grads, logits, init_vals):
  """Return the summed squared error of the ODE system and its initial values.

  The two residuals penalised are those of

    f1' = cos(x) + f1^2 + f2 - (1 + x^2 + sin(x)^2)
    f2' = 2x - (1 + x^2) sin(x) + f1 f2

  Args:
    inputs: sample points x at which the equations are evaluated.
    input_grads: [f1', f2'], derivatives of the two outputs w.r.t. `inputs`.
    logits: [f1, f2], outputs of the two trial networks at `inputs`.
    init_vals: network outputs at the initial points; compared against the
      module-level `true_init_vals`.

  Returns:
    Scalar tensor: sum of both equations' losses.
  """
  f1 = logits[0]
  f2 = logits[1]
  df1_dx = input_grads[0]
  df2_dx = input_grads[1]

  # Residuals of the two differential equations (zero for an exact solution).
  residual1 = df1_dx - tf.math.cos(inputs) - f1**2 - f2 + (1 + inputs**2 + tf.math.sin(inputs)**2)
  residual2 = df2_dx - 2*inputs + (1 + inputs**2)*tf.math.sin(inputs) - f1*f2

  # Deviation from the prescribed initial values.
  init_error1 = init_vals[0] - true_init_vals[0]
  init_error2 = init_vals[1] - true_init_vals[1]

  # NOTE(review): the initial-value term is added before the reduction; if it
  # has fewer rows than the residual it is broadcast and therefore counted
  # once per sample - confirm this weighting is intended.
  total1 = tf.reduce_sum(tf.square(residual1) + tf.square(init_error1))
  total2 = tf.reduce_sum(tf.square(residual2) + tf.square(init_error2))
  return total1 + total2

In [None]:
def loss_val_grad(models, inputs):
  """Build a value-and-gradient function over the models' flattened weights.

  The returned callable takes a 1-D tensor holding every trainable weight of
  `models`, writes those values into the models, and returns the loss on
  `inputs` together with its gradient flattened in the same layout. That is
  the signature `tfp.optimizer.lbfgs_minimize` expects.

  Args:
    models: list of the two trial networks (index 0 -> f1, index 1 -> f2).
    inputs: batch of sample points passed to every model.

  Returns:
    A function `func(params) -> (loss_val, grads)` carrying these attributes:
      idx: per-weight index tensors for `tf.dynamic_stitch`.
      part: flat partition indices for `tf.dynamic_partition`.
      shapes: shape tensor of every trainable weight, in flattened order.
      update_params: function assigning a flat parameter vector to the models.
  """
  # Models declared without an input shape only create their weights on the
  # first call, so make sure they exist before their shapes are read.
  for model in models:
    if not model.built:
      model(inputs)

  def _flat_weights():
    # Trainable weights of all models, in model order.
    return [w for m in models for w in m.trainable_weights]

  shapes = tf.shape_n(_flat_weights())
  n_tensors = len(shapes)

  # we'll use tf.dynamic_stitch and tf.dynamic_partition later, so we need to
  # prepare required information first
  count = 0
  idx = [] # stitch indices
  part = [] # partition indices

  for i, shape in enumerate(shapes):
    # np.prod replaces np.product, an alias that NumPy 2.0 removed.
    n = int(np.prod(shape))
    idx.append(tf.reshape(tf.range(count, count + n, dtype=tf.int32), shape))
    part.extend([i] * n)
    count += n

  part = tf.constant(part)

  def update_params(params):
    """Split the flat vector `params` and assign each piece to its weight."""
    params_var = tf.dynamic_partition(params, part, n_tensors)
    for weight, shape, param in zip(_flat_weights(), shapes, params_var):
      weight.assign(tf.reshape(param, shape))

  def func(params):
    """Return (loss, flat gradient) for the flat parameter vector `params`."""
    logits = []
    input_grads = []
    update_params(params)
    with tf.GradientTape() as tp:
      # Evaluated inside the tape so that the initial-condition penalty
      # contributes to the gradient, not only to the loss value.
      init_vals = [models[0](f1_x_init), models[1](f2_x_init)]
      for model in models:
        # Inner tape: derivative of the network output w.r.t. its input.
        with tf.GradientTape(watch_accessed_variables=False) as tape:
          tape.watch(inputs)
          logit = model(inputs)
        logits.append(logit)
        # Taken while the outer tape is recording, so the loss can be
        # differentiated through this derivative.
        input_grads.append(tape.gradient(logit, inputs))
      loss_val = loss_fn(inputs, input_grads, logits, init_vals)
    # A single gradient call over all weights; no persistent tape needed.
    grads = tp.gradient(loss_val, _flat_weights())
    grads = tf.dynamic_stitch(idx, grads)
    return loss_val, grads

  func.idx = idx
  func.part = part
  func.shapes = shapes
  func.update_params = update_params

  return func

In [None]:
for epoch in range(epochs):
  print(f"\nStart of epoch {epoch}:")
  for step, x_batch_train in enumerate(train_dataset):
    # Objective (loss and gradient over the flattened weights) for this batch.
    val_grad_func = loss_val_grad(models, x_batch_train)

    # Start L-BFGS from the current weights, flattened with the same index
    # layout the objective uses.
    current_weights = [w for m in models for w in m.trainable_weights]
    lbfgs_init_pos = tf.dynamic_stitch(val_grad_func.idx, current_weights)

    optim_results = tfp.optimizer.lbfgs_minimize(
        val_grad_func,
        initial_position=lbfgs_init_pos,
        max_iterations=100,
    )

    # Write the optimiser's final position back into the models.
    val_grad_func.update_params(optim_results.position)

    # The loss is reported after every batch.
    print(f"\tTraining loss at step {step}: {optim_results.objective_value.numpy()}")
    