In [None]:
import os

# Root of the NeuralODEs_ROM_Closure checkout. Defaults to the original
# author's absolute path; set the NEURALODES_BASEDIR environment variable to
# run from another location without editing the notebook.
basedir = os.environ.get('NEURALODES_BASEDIR', '/home/abhinavgupta0110/NeuralODEs_ROM_Closure')

# Runtime toggles:
#   is_google_colab -> mount Google Drive and install quadpy (next cell)
#   is_use_GPU      -> print nvidia-smi output (next cell)
is_google_colab = False
is_use_GPU = False

### Check for a GPU and mount Google Drive if needed

In [None]:
# Optional environment setup, driven by the is_use_GPU / is_google_colab toggles.
if is_use_GPU:
    # IPython `!cmd` assignment captures the shell output as a list of lines.
    gpu_info = !nvidia-smi
    gpu_info = '\n'.join(gpu_info)
    # Presumably nvidia-smi's error text contains 'failed' when no GPU/driver
    # is reachable — heuristic check, not an exit-code test.
    if gpu_info.find('failed') >= 0:
        print('No GPU found!')
    else:
        print(gpu_info)

if is_google_colab:
    from google.colab import drive
    drive.mount('/content/drive')

    # quadpy is not preinstalled on Colab. NOTE(review): version is unpinned.
    %pip install quadpy
    
# Work from the package directory, presumably so the `src....` imports in the
# next cell resolve relative to the working directory.
os.chdir(os.path.join(basedir, 'neuralDDE_ROM_Closure'))

### Load modules

In [None]:
from src.utilities.DDE_Solver import ddeinttf 
import src.solvers.neuralDistDDE_with_adjoint as nddde
import src.bio_eqn_case.bio_eqn as bio

import time
import sys
from IPython.core.debugger import set_trace

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from shutil import move
import pickle

# Default Keras float type: single precision for all layers/weights.
tf.keras.backend.set_floatx('float32')
import logging
# Only show TensorFlow Python-logger messages at ERROR level or above.
tf.get_logger().setLevel(logging.ERROR)

## Define some useful classes

### Class for user-defined arguments

In [None]:
class bio_eq_nDistDDE_args(nddde.nddde_arguments, bio.bio_eqn_args):
    """Argument container for the neural distributed-DDE biological closure test case.

    Initialises both parent argument classes:
      * ``nddde.nddde_arguments`` with the training / delay settings
        (``data_size`` is taken from ``nt``), and
      * ``bio.bio_eqn_args`` with the biological parameters for the
        low-complexity model ``bio_model_low_complex``.
    A second, independent ``bio.bio_eqn_args`` built for
    ``bio_model_high_complex`` is stored as ``self.bio_args_for_high_complex``.

    Args:
        adj_data_size: forwarded to ``nddde.nddde_arguments``. ``None`` (the
            default) uses ``state_dim``, the value this class always passed;
            an explicit integer is now honoured instead of being ignored.
        bio_model_low_complex: one of ``'NPZ'``, ``'NPZD'``, ``'NNPZD'``;
            determines ``state_dim`` (3, 4 and 5 respectively).
        All other arguments are forwarded unchanged to the parent classes.

    Raises:
        ValueError: if ``bio_model_low_complex`` is not a supported model.
    """

    def __init__(self, batch_time = 12, batch_time_skip = 2, batch_size = 5, epochs = 500, learning_rate = 0.05, decay_rate = 0.95, test_freq = 1, plot_freq = 2, 
                 d_max = 1, nn_d1 = 0., nn_d2 = 0.5, adj_data_size = None,
                 model_dir = 'Bio_nDistDDE_testcase/model_dir_test', restart = 0, val_percentage = 0.2,
                 T = 2000., nt = 4000, z = -15, k_w = 0.067, alpha = 0.025, V_m = 1.5, I_0 = 158.075, K_u = 1., Psi = 1.46,
                    Xi = 0.1, R_m = 1., Lambda = 0.06, gamma = 0.3, Tau = 0.145, Phi = 0.175, Omega = 0.041, T_bio = 30, bio_model_low_complex = 'NPZ', bio_model_high_complex = 'NNPZD', isplot = True): # add more arguments as needed

        # Number of state variables of each supported low-complexity model.
        state_dims = {'NPZ': 3, 'NPZD': 4, 'NNPZD': 5}
        if bio_model_low_complex not in state_dims:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError("bio_model_low_complex must be one of %s, got %r"
                             % (sorted(state_dims), bio_model_low_complex))
        state_dim = state_dims[bio_model_low_complex]

        if adj_data_size is None:
            adj_data_size = state_dim

        nddde.nddde_arguments.__init__(self, data_size = nt, batch_time = batch_time, batch_time_skip = batch_time_skip, batch_size = batch_size, epochs = epochs,
                           learning_rate = learning_rate, decay_rate = decay_rate, test_freq = test_freq, plot_freq = plot_freq, d_max = d_max, nn_d1 = nn_d1,
                           nn_d2 = nn_d2, state_dim = state_dim, adj_data_size = adj_data_size, model_dir = model_dir, restart = restart, val_percentage = val_percentage, isplot = isplot)

        # Biological parameters for the low-complexity (trained) model.
        bio.bio_eqn_args.__init__(self, T = T, nt = nt, z = z, k_w = k_w, alpha = alpha, V_m = V_m, I_0 = I_0, K_u = K_u, Psi = Psi,
                    Xi = Xi, R_m = R_m, Lambda = Lambda, gamma = gamma, Tau = Tau, Phi = Phi, Omega = Omega, T_bio = T_bio, bio_model = bio_model_low_complex)

        # Same parameters, configured for the high-complexity reference model.
        self.bio_args_for_high_complex = bio.bio_eqn_args(T = T, nt = nt, z = z, k_w = k_w, alpha = alpha, V_m = V_m, I_0 = I_0, K_u = K_u, Psi = Psi,
                    Xi = Xi, R_m = R_m, Lambda = Lambda, gamma = gamma, Tau = Tau, Phi = Phi, Omega = Omega, T_bio = T_bio, bio_model = bio_model_high_complex)

        self.bio_model_low_complex = bio_model_low_complex
        self.bio_model_high_complex = bio_model_high_complex

### Define the neural net architecture

In [None]:
class DDEFuncMain(tf.keras.Model):
    """Main (instantaneous) closure network.

    Two 7-unit tanh hidden layers followed by a linear output layer with
    ``state_dim`` units. All kernels use TruncatedNormal(stddev=0.1), and all
    layers have biases. In this notebook it is applied to the full augmented
    state y(t), and its output is added to the low-complexity model's dz/dt.

    Args:
        state_dim: number of output units. ``None`` (default) falls back to the
            notebook-level ``args.state_dim`` as before, which requires
            ``args`` to exist when the model is constructed.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, state_dim = None, **kwargs):
        super(DDEFuncMain, self).__init__(**kwargs)

        if state_dim is None:
            # Backward-compatible path: read the global notebook arguments.
            state_dim = args.state_dim

        # Attribute names (x1, x2, out) are kept: object-based checkpoints key on them.
        self.x1 = tf.keras.layers.Dense(7, activation='tanh',
                                       kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1), use_bias=True)
        
        self.x2 = tf.keras.layers.Dense(7, activation='tanh',
                                       kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1), use_bias=True)
        
        self.out = tf.keras.layers.Dense(state_dim, activation='linear',
                                       kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1), use_bias=True)

    def call(self, z):
        # Explicit x1 -> x2 -> out chain. This is the same order the original
        # obtained by iterating self.layers, which follows assignment order.
        return self.out(self.x2(self.x1(z)))

In [None]:
class DDEFuncAux(tf.keras.Model):
    """Auxiliary (memory) network of the distributed-delay closure.

    Two 5-unit tanh hidden layers and a linear 4-unit output layer. Every
    layer has a bias and a TruncatedNormal(stddev=0.1) kernel initialiser.
    In this notebook it is evaluated on the low-complexity state at the two
    ends of the delay window. The difference of those two outputs is the
    right-hand side of the auxiliary variables.
    """

    def __init__(self, **kwargs):
        super(DDEFuncAux, self).__init__(**kwargs)

        def dense(units, activation):
            # A fresh initializer instance per layer, as in the original.
            return tf.keras.layers.Dense(
                units,
                activation=activation,
                kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1),
                use_bias=True)

        # Attribute names and assignment order are kept (checkpoint keys).
        self.x1 = dense(5, 'tanh')
        self.x2 = dense(5, 'tanh')
        self.out = dense(4, 'linear')

    def call(self, z):
        hidden = self.x1(z)
        hidden = self.x2(hidden)
        return self.out(hidden)

In [None]:
class split_zy:
    """Callable view that keeps only the first ``args.state_dim`` columns of
    a trajectory/history callable.

    ``zy(t)`` is expected to return a 2-D array-like whose leading
    ``state_dim`` columns are the physical state z. The remaining columns
    (the auxiliary variables) are dropped. ``state_dim`` is read from
    ``args`` on every call.
    """

    def __init__(self, zy, args):
        self.zy = zy
        self.args = args

    def __call__(self, t):
        full_state = self.zy(t)
        n_physical = self.args.state_dim
        return full_state[:, :n_physical]

In [None]:
class DistDDEFunc:
    """Right-hand side of the augmented neural distributed-delay DDE.

    The augmented state is [z, y_aux], where z is the first
    ``args.state_dim`` columns.
      dz/dt     = main(state(t)) + rom_model(z-history, t)
      dy_aux/dt = aux(z(t - d[0])) - aux(z(t - d[1]))

    Args (of ``__call__``):
        y: callable history/solution returning the augmented state at a time.
        t: current time.
        d: pair of delays bounding the distributed-delay window.
    """

    def __init__(self, main, aux, rom_model, args):
        self.main, self.aux = main, aux
        self.rom_model = rom_model
        self.args = args

    def __call__(self, y, t ,d):
        # View of the history restricted to the physical state z.
        z_of = split_zy(y, self.args)
        d_near, d_far = d[0], d[1]

        # Known low-complexity model plus the learned instantaneous correction.
        dz_dt = self.main(y(t)) + self.rom_model(z_of, t)
        # Memory term: difference of aux() at the two ends of the delay window.
        dy_dt = self.aux(z_of(t - d_near)) - self.aux(z_of(t - d_far))
        return tf.concat([dz_dt, dy_dt], axis=-1)

### Define a custom loss function

In [None]:
class custom_loss:
    """Training loss for the biological closure model.

    Sum of three terms, each averaged over axis 0:
      1. Euclidean (L2) error between ``pred_y`` and ``true_y`` along the last axis;
      2. count of components of ``pred_y`` lying outside ``[0, T_bio]``;
      3. ``mass_weight`` x |sum of components - T_bio| (total-mass penalty).

    Args:
        T_bio: upper bound and total mass. ``None`` (default) reads the
            notebook-level ``args.T_bio`` at call time, as before.
        mass_weight: weight of the total-mass penalty (default 0.1, the value
            previously hard-coded).
    """

    def __init__(self, T_bio = None, mass_weight = 0.1):
        self.T_bio = T_bio
        self.mass_weight = mass_weight

    def __call__(self, true_y, pred_y):
        # Resolved lazily so the global `args` may be defined after construction.
        T_bio = args.T_bio if self.T_bio is None else self.T_bio

        # Bound constants share pred_y's dtype. An int T_bio (e.g. the class
        # default 30) would otherwise create an int32 constant that cannot be
        # compared with a float tensor.
        lower = tf.constant([0.], dtype=pred_y.dtype)
        upper = tf.constant([T_bio], dtype=pred_y.dtype)
        out_of_bounds = tf.logical_or(tf.less(pred_y, lower), tf.greater(pred_y, upper))
        # NOTE(review): tf.where with constant branches has zero gradient w.r.t.
        # pred_y. This term therefore changes the reported loss value but does
        # not drive predictions back into [0, T_bio]. A differentiable
        # alternative is e.g. tf.nn.relu(-pred_y) + tf.nn.relu(pred_y - T_bio).
        # It is kept as-is to preserve existing loss values.
        bound_violations = tf.where(out_of_bounds, 1., 0.)

        trajectory_error = tf.reduce_mean(
            tf.sqrt(tf.reduce_sum(tf.math.squared_difference(pred_y, true_y), axis=-1)), axis=0)
        bound_penalty = tf.reduce_mean(tf.reduce_sum(bound_violations, axis=-1), axis=0)
        mass_penalty = self.mass_weight * tf.reduce_mean(
            tf.math.abs(tf.reduce_sum(pred_y, axis=-1) - T_bio), axis=0)
        return trajectory_error + bound_penalty + mass_penalty

### Define a custom plotting function

In [None]:
class custom_plot:
    """Plot N, P and Z trajectories for the bio closure experiment.

    Compares the high-complexity reference (``true_y``), the uncorrected
    low-complexity model (``y_no_nn``) and, optionally, the learned model.
    All trajectories are indexed as ``[:, 0, component]``; only components
    0, 1 and 2 (N, P, Z) are drawn.

    Args:
        true_y, y_no_nn: tensors with a ``.numpy()`` method, indexed as above.
        t: time tensor matching the first axis of the trajectories.
        figsave_dir: directory that receives ``img<epoch>`` images.
        args: run arguments; ``restart`` is read, and ``bio_model_low_complex``
            (if present) is used for the legend.
    """

    # (component index, colour, name) for each plotted state variable.
    _COMPONENTS = ((0, 'r', 'N'), (1, 'g', 'P'), (2, 'b', 'Z'))

    def __init__(self, true_y, y_no_nn, t, figsave_dir, args):
        self.true_y = true_y
        self.y_no_nn = y_no_nn
        self.t = t
        self.figsave_dir = figsave_dir
        self.args = args
        self.colors = ['b', 'g', 'r', 'k', 'c', 'm']

    def plot(self, *pred_y, epoch = 0):
        """Draw the comparison figure and save it as ``img<epoch>`` when ``epoch != 0``.

        The learned trajectory ``pred_y[0]`` is drawn when ``epoch != 0`` or
        on restart, provided one was passed.
        """
        low_name = getattr(self.args, 'bio_model_low_complex', 'NPZ')
        t_np = self.t.numpy()
        true_np = self.true_y.numpy()
        no_nn_np = self.y_no_nn.numpy()

        fig, ax = plt.subplots(figsize=(6, 4), facecolor='white')
        ax.set_title('Bio Model Comparison')
        ax.set_xlabel('t')
        ax.set_ylabel('Bio Variable')

        for k, color, name in self._COMPONENTS:
            ax.plot(t_np, true_np[:, 0, k], '-' + color, label = name + ' (High Complex)')
            ax.plot(t_np, no_nn_np[:, 0, k], '--' + color, label = name + ' (' + low_name + ')')

        show_learned = (epoch != 0 or self.args.restart == 1) and len(pred_y) > 0
        if show_learned:
            pred_np = pred_y[0].numpy()
            for k, color, name in self._COMPONENTS:
                ax.plot(t_np, pred_np[:, 0, k], '-.' + color, label = name + ' (Learned)')

        ax.set_xlim(t_np[0], t_np[-1])
        ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")

        # Save before show(). bbox_inches='tight' keeps the legend, which sits
        # outside the axes and was previously cropped from saved images.
        if epoch != 0:
            fig.savefig(os.path.join(self.figsave_dir, 'img' + str(epoch)), bbox_inches='tight')

        plt.show()
        # Release the figure; this is presumably called repeatedly during
        # training, and open figures would otherwise accumulate.
        plt.close(fig)

### Initial Conditions

In [None]:
class initial_cond:
    """Constant initial history for the biological DDE solver.

    Returns the same state for every ``t``, as a float32 tensor of shape
    ``(1, n_states)``. In each supported model the components sum to
    ``app.T_bio`` (total biomass).

    Args:
        app: argument object providing ``bio_model`` ('NPZ', 'NPZD' or
            'NNPZD') and ``T_bio``.

    Raises:
        ValueError: on call, if ``app.bio_model`` is not a supported model.
    """

    def __init__(self, app):
        self.app = app

    def __call__(self, t):
        # t is unused: the history is constant in time.
        model = self.app.bio_model
        T_bio = self.app.T_bio

        if model == 'NPZ':
            x0 = [T_bio - 0.5*2, 0.5, 0.5]
        elif model == 'NPZD':
            x0 = [T_bio - 0.5*2, 0.5, 0.5, 0.]
        elif model == 'NNPZD':
            x0 = [T_bio/2., T_bio/2. - 2*0.5, 0.5, 0.5, 0.]
        else:
            # Previously fell through to an UnboundLocalError on `x0`.
            raise ValueError("Unsupported bio_model %r; expected 'NPZ', 'NPZD' or 'NNPZD'" % (model,))
        return tf.convert_to_tensor([x0], dtype=tf.float32)

### Initialize model related parameters

In [None]:
args = bio_eq_nDistDDE_args(
    # Training schedule.
    batch_time = 6, batch_time_skip = 2, batch_size = 8, epochs = 200,
    learning_rate = 0.05, decay_rate = 0.97,
    test_freq = 1, plot_freq = 1,
    # Delays: nn_d1 / nn_d2 bound the distributed-delay window passed to the
    # network RHS during training; d_max presumably bounds the history length.
    d_max = 3., nn_d1 = 0., nn_d2 = 2.0,
    # Output location and run control. val_percentage = 1. makes the
    # validation horizon as long as the training horizon.
    model_dir = 'Bio_nDistDDE_testcase/model_dir_case', restart = 0, val_percentage = 1.,
    # Time grid: nt points on [0, T].
    T = 20., nt = 400,
    # Biological model parameters.
    z = -15, k_w = 0.067, alpha = 0.025, V_m = 1.5, I_0 = 158.075, K_u = 1., Psi = 1.46,
    Xi = 0.1, R_m = 1.5, Lambda = 0.06, gamma = 0.3, Tau = 0.145, Phi = 0.175, Omega = 0.041,
    T_bio = 30.,
    # Low-complexity model to correct and high-complexity reference model.
    bio_model_low_complex = 'NPZ', bio_model_high_complex = 'NNPZD')

### Create output directories, save a copy of the current script, and pickle the arguments

In [None]:
os.chdir(basedir)

if not os.path.exists(args.model_dir):
    os.makedirs(args.model_dir)

checkpoint_dir_main = os.path.join(args.model_dir, "ckpt_main")
checkpoint_dir_aux = os.path.join(args.model_dir, "ckpt_aux")
checkpoint_prefix_main = os.path.join(checkpoint_dir_main, "cp-{epoch:04d}.ckpt")
checkpoint_prefix_aux = os.path.join(checkpoint_dir_aux, "cp-{epoch:04d}.ckpt")
if not os.path.exists(checkpoint_dir_main):
  os.makedirs(checkpoint_dir_main)
if not os.path.exists(checkpoint_dir_aux):
  os.makedirs(checkpoint_dir_aux)

figsave_dir = os.path.join(args.model_dir, "img")
if not os.path.exists(figsave_dir):
    os.makedirs(figsave_dir)

!jupyter nbconvert --to python neuralDDE_ROM_Closure/testcases/Bio_Eqn/neuralDistDDE_Bio_Eqn_TestCase.ipynb
move("neuralDDE_ROM_Closure/testcases/Bio_Eqn/neuralDistDDE_Bio_Eqn_TestCase.py", os.path.join(args.model_dir, "orig_run_file.py"))

pickle.dump( args, open( os.path.join(args.model_dir, "args.p"), "wb" ) )
with open(os.path.join(args.model_dir, 'args.pkl'), 'wb') as output:
    pickle.dump(args, output, pickle.HIGHEST_PROTOCOL)

### Solve for the high complexity model

In [None]:
t = tf.linspace(0., args.T, args.nt) # Time array: nt points on [0, T], endpoints included

x0_high_complex = initial_cond(args.bio_args_for_high_complex) # Initial conditions (constant history)

# Reference ("truth") trajectory of the high-complexity model over the training window.
x_high_complex = ddeinttf(bio.bio_eqn(args.bio_args_for_high_complex), x0_high_complex, t)

# Compute FOM for the validation time
dt = t[1] - t[0]
# Validation window length = val_percentage x training window length.
val_t_len =  args.val_percentage * (t[-1] - t[0])
n_val = np.ceil(np.abs(val_t_len/dt.numpy())).astype(int)
# NOTE(review): val_t starts at t[-1], so that time point appears in both t
# and val_t. Its spacing is val_t_len/(n_val-1), which is close to, but not
# exactly, dt.
val_t = tf.linspace(t[-1], t[-1] + val_t_len, n_val)

# Continue the high-complexity solve past the training window, presumably using
# an interpolant of the training solution as the history.
val_x_high_complex = ddeinttf(bio.bio_eqn(args.bio_args_for_high_complex), nddde.create_interpolator(x_high_complex, t), val_t)

print('Higher complexity model done!')

### Transform states of high complexity model to low complexity model

In [None]:
# Create modes for the training and validation period combined
true_x_low_complex = bio.convert_high_complex_to_low_complex_states(x_high_complex, args)

x0_low_complex = initial_cond(args)


### Solve the low complexity model

In [None]:
# Uncorrected low-complexity model over the training window (the "no NN" baseline).
x_low_complex = ddeinttf(bio.bio_eqn(args), x0_low_complex, t)

# Continue the baseline over the validation window, presumably with an
# interpolant of the training-window solution as the history.
val_x_low_complex = ddeinttf(bio.bio_eqn(args), nddde.create_interpolator(x_low_complex, t), val_t)

#### Create validation set

In [None]:
# Validation helper for the trainer (built from the low-complexity IC and training grid).
val_obj = nddde.create_validation_set_nddde(x0_low_complex, t, args)

# Validation-window target: high-complexity trajectory mapped to low-complexity states.
val_true_x_low_complex = bio.convert_high_complex_to_low_complex_states(val_x_high_complex, args)

## Main part starts here

### Make objects and define learning-rate schedule

In [None]:
# Running average of iteration time (0.97 presumably the smoothing factor).
time_meter = nddde.RunningAverageMeter(0.97)

# Learnable parts: instantaneous correction (main) and delay-memory term (aux).
func_main = DDEFuncMain()
func_aux = DDEFuncAux()
# Full augmented RHS: low-complexity bio model plus the neural closure terms.
func = DistDDEFunc(func_main, func_aux, bio.bio_eqn(app = args), args)
# Presumably the adjoint equation used to compute gradients through the DDE solve.
adj_func = nddde.nddde_adj_eqn(func, args)
# Mini-batch sampler over the projected high-complexity (target) trajectory.
get_batch = nddde.create_batch(true_x_low_complex, x0_low_complex, t, args)
loss_obj = custom_loss()
# Plots span the training + validation windows: target vs. uncorrected model.
plot_obj = custom_plot(tf.concat([true_x_low_complex, val_true_x_low_complex], axis=0), tf.concat([x_low_complex, val_x_low_complex], axis=0), 
                       tf.concat([t, val_t], axis=0), figsave_dir, args)
loss_history = nddde.history(args)

# Staircase exponential decay: lr = learning_rate * decay_rate**floor(step / decay_steps).
# decay_steps = args.niters — presumably iterations per epoch (TODO confirm in nddde).
initial_learning_rate = args.learning_rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=args.niters,
    decay_rate=args.decay_rate,
    staircase=True)

### Quick check: plot the initial model comparison (restoring saved weights first when restarting)

In [None]:
if args.restart == 1: 
    # Restore the most recent checkpoints of both networks.
    func_main.load_weights(tf.train.latest_checkpoint(checkpoint_dir_main))
    func_aux.load_weights(tf.train.latest_checkpoint(checkpoint_dir_aux))
    # Build the augmented initial history from the low-complexity IC. This is
    # the same x0_low_complex used for training; the previously referenced
    # name `ai_t0` is undefined here and raised NameError on restart.
    process_true_z0 = nddde.process_DistDDE_IC(x0_low_complex, func_aux, t_lowerlim = t[0] - args.nn_d2, t_upperlim = t[0] - args.nn_d1)
    # Roll the restored model out over the training + validation windows.
    pred_zy = ddeinttf(func, process_true_z0, tf.concat([t, val_t], axis=0), fargs=([args.nn_d1, args.nn_d2],))
    
    # Plot only the physical state (first state_dim columns).
    plot_obj.plot(pred_zy[:, :, :args.state_dim], epoch = 0)

    loss_history.read()
    
    # NOTE(review): on restart a fixed schedule (0.05, decay 0.95) is used
    # instead of args.learning_rate / args.decay_rate.
    initial_learning_rate = 0.05
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, decay_steps=args.niters, decay_rate=0.95, staircase=True)
    
else:
    plot_obj.plot(epoch = 0)

### Training starts here

In [None]:
# Separate RMSprop optimizers for the main and aux networks, both driven by lr_schedule.
optimizer_main = tf.keras.optimizers.RMSprop(learning_rate = lr_schedule)
optimizer_aux = tf.keras.optimizers.RMSprop(learning_rate = lr_schedule)

# NOTE(review): the checkpoint_dir_* keywords receive the "cp-{epoch:04d}.ckpt"
# file-name templates (checkpoint_prefix_*), not the bare directories.
# d = [nn_d1, nn_d2] is the distributed-delay window used by DistDDEFunc.
nDistDDE_train_obj = nddde.train_nDistDDE(func = func, adj_func = adj_func, d = [args.nn_d1, args.nn_d2], loss_obj = loss_obj, batch_obj = get_batch,
                            checkpoint_dir_aux = checkpoint_prefix_aux, optimizer_main = optimizer_main, optimizer_aux = optimizer_aux, args = args, plot_obj = plot_obj, time_meter = time_meter, checkpoint_dir_main = checkpoint_prefix_main,
                            validation_obj = val_obj, loss_history_obj = loss_history)

# Fit to the projected high-complexity trajectory; validate on the validation-window target.
nDistDDE_train_obj.train(true_x_low_complex, x0_low_complex, t, val_true_x_low_complex)