In [None]:
import os

# Root directory that contains the `neuralClosureModels` checkout.
# Export NEURAL_CLOSURE_BASEDIR to run the notebook on another machine without
# editing this cell; the original author's path is kept as the fallback.
basedir = os.environ.get('NEURAL_CLOSURE_BASEDIR',
                         '/home/abhinavgupta0110/NeuralODEs_ROM_Closure')

# True when running on Google Colab: Drive is mounted and quadpy is installed.
is_google_colab = False
# True to print the `nvidia-smi` report before anything else runs.
is_use_GPU = False

### Mount the Google drive if needed

In [None]:
if is_use_GPU:
    gpu_info = !nvidia-smi
    gpu_info = '\n'.join(gpu_info)
    if gpu_info.find('failed') >= 0:
        print('No GPU found!')
    else:
        print(gpu_info)

if is_google_colab:
    from google.colab import drive
    drive.mount('/content/drive')

    %pip install quadpy
    
os.chdir(os.path.join(basedir, 'neuralClosureModels'))

%load_ext autoreload

### Load modules

In [None]:
%autoreload 2

from src.utilities.DDE_Solver import ddeinttf 
from src.utilities.helper_classes import * 
import src.solvers.neuralDDE_with_adjoint_accel as ndde
import src.bio_eqn_case.bio_eqn_modcall as bio
from src.bio_eqn_case.Bio_Eqn_Helper_Classes import * 

import time
import sys
from IPython.core.debugger import set_trace

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from shutil import move
import pickle

# Record the TensorFlow version used for this run.
print(tf.__version__) 

# Make float32 the default floating-point type for Keras layers and variables.
tf.keras.backend.set_floatx('float32')
import logging
# Only show TensorFlow log records at ERROR level or above.
tf.get_logger().setLevel(logging.ERROR)

## Define some useful classes

### Class for user-defined arguments

In [None]:
class bio_eq_nDDE_args(ndde.arguments, bio.bio_eqn_args):
    """User-defined arguments for the biological-equation closure test case.

    Combines the training arguments of `ndde.arguments` with the model
    parameters of `bio.bio_eqn_args` for the low-complexity model, and keeps a
    second `bio.bio_eqn_args` instance (`bio_args_for_high_complex`) built from
    the same parameters for the high-complexity model.

    Parameters
    ----------
    batch_time, batch_time_skip, batch_size, epochs, learning_rate, decay_rate,
    test_freq, plot_freq, d_max, rnn_nmax, rnn_dt, model_dir, restart,
    val_percentage, isplot, is_tstart_zero
        Forwarded unchanged to `ndde.arguments.__init__`.
    adj_data_size
        Accepted for signature compatibility but NOT forwarded: the base class
        always receives `adj_data_size = state_dim`.
        NOTE(review): confirm that ignoring this argument is intentional.
    T, nt, z, k_w, alpha, V_m, I_0, K_u, Psi, Xi, R_m, Lambda, gamma, Tau, Phi,
    Omega, T_bio
        Forwarded unchanged to `bio.bio_eqn_args` (for both complexities);
        `nt` is additionally passed to `ndde.arguments` as `data_size`.
    bio_model_low_complex
        One of 'NPZ', 'NPZD', 'NNPZD'; fixes `state_dim` to 3, 4 or 5.
    bio_model_high_complex
        Model name used for `bio_args_for_high_complex`.
    ode_alg_name, nsteps
        Stored on the instance as-is.

    Raises
    ------
    ValueError
        If `bio_model_low_complex` is not one of the supported model names.
    """

    # Number of state variables for each supported low-complexity model.
    _STATE_DIM_BY_MODEL = {'NPZ': 3, 'NPZD': 4, 'NNPZD': 5}

    def __init__(self, batch_time = 12, batch_time_skip = 2, batch_size = 5, epochs = 500, learning_rate = 0.05, decay_rate = 0.95, test_freq = 1, plot_freq = 2, 
                 d_max = 1.1, rnn_nmax = 3, rnn_dt = 0.5, adj_data_size = 2,
                 model_dir = 'Bio_nDDE_testcase/model_dir_test', restart = 0, val_percentage = 0.2,
                 T = 2000., nt = 4000, z = -15, k_w = 0.067, alpha = 0.025, V_m = 1.5, I_0 = 158.075, K_u = 1., Psi = 1.46,
                 Xi = 0.1, R_m = 1., Lambda = 0.06, gamma = 0.3, Tau = 0.145, Phi = 0.175, Omega = 0.041, T_bio = 30, 
                 bio_model_low_complex = 'NPZ', bio_model_high_complex = 'NNPZD', isplot = True, is_tstart_zero = True, 
                 ode_alg_name = 'dopri5', nsteps = 1): # add more arguments as needed

        # Fail early with a readable message instead of leaving `state_dim`
        # unbound when an unknown model name is supplied.
        if bio_model_low_complex not in self._STATE_DIM_BY_MODEL:
            raise ValueError(
                "Unsupported bio_model_low_complex %r; expected one of %s"
                % (bio_model_low_complex, sorted(self._STATE_DIM_BY_MODEL)))
        state_dim = self._STATE_DIM_BY_MODEL[bio_model_low_complex]

        ndde.arguments.__init__(self, data_size = nt, batch_time = batch_time, batch_time_skip = batch_time_skip, batch_size = batch_size, epochs = epochs,
                           learning_rate = learning_rate, decay_rate = decay_rate, test_freq = test_freq, plot_freq = plot_freq, d_max = d_max, rnn_nmax = rnn_nmax, 
                           rnn_dt = rnn_dt, state_dim = state_dim, adj_data_size = state_dim, model_dir = model_dir, restart = restart, val_percentage = val_percentage, isplot = isplot, is_tstart_zero = is_tstart_zero)

        # Parameters shared by the low- and high-complexity argument sets;
        # only `bio_model` differs between the two.
        bio_kwargs = dict(T = T, nt = nt, z = z, k_w = k_w, alpha = alpha, V_m = V_m, I_0 = I_0, K_u = K_u, Psi = Psi,
                          Xi = Xi, R_m = R_m, Lambda = Lambda, gamma = gamma, Tau = Tau, Phi = Phi, Omega = Omega, T_bio = T_bio)

        bio.bio_eqn_args.__init__(self, bio_model = bio_model_low_complex, **bio_kwargs)

        self.bio_args_for_high_complex = bio.bio_eqn_args(bio_model = bio_model_high_complex, **bio_kwargs)

        self.bio_model_low_complex = bio_model_low_complex
        self.bio_model_high_complex = bio_model_high_complex

        self.ode_alg_name = ode_alg_name
        self.nsteps = nsteps

### Define the neural net architecture

In [None]:
class BioConstrainLayer(tf.keras.layers.Layer):
    """Splits the incoming tensor into three channels weighted by a trainable scalar.

    The output is the concatenation, along the last axis, of
    `gamma * input`, `-input` and `(1 - gamma) * input` (named the N, P and Z
    channels in the original code), so the three channels sum to zero.
    `gamma` starts at 0.1 and is kept inside [0, 1] by `constraint`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Trainable partition coefficient, projected by `constraint`.
        self.gamma = tf.Variable(0.1, trainable=True, constraint = self.constraint)

    def constraint(self, gamma):
        """Return `gamma` with values below 0 replaced by 0 and values above 1 by 1."""
        floored = tf.where(gamma >= 0., gamma, 0.)
        return tf.where(gamma <= 1., floored, 1.)

    def call(self, input):
        """Build the three weighted channels and concatenate them on the last axis."""
        channels = [
            self.gamma * input,          # N channel
            - input,                     # P channel
            (1. - self.gamma) * input,   # Z channel
        ]
        return tf.concat(channels, axis=-1)

In [None]:
class DDEFuncMain(tf.keras.Model):
    """Neural-network part of the closure model.

    Holds six tanh `Dense(7)` layers (`x1`..`x6`), a linear `Dense` layer with
    `args.state_dim - 2` units (`out`) and a `BioConstrainLayer` (`bio`).
    Calling the model feeds the input through every entry of `self.layers`
    in sequence.
    """

    def __init__(self, args, **kwargs):
        super().__init__(**kwargs)

        def small_normal():
            # A fresh initializer instance per layer, as in the hand-written version.
            return tf.keras.initializers.TruncatedNormal(stddev=0.1)

        def hidden_layer():
            return tf.keras.layers.Dense(7, activation='tanh',
                                         kernel_initializer=small_normal(), use_bias=True)

        # Keep the assignment order: `pass_layers` walks `self.layers`, which
        # Keras presumably reports in attribute-creation order — verify before
        # reordering.
        self.x1 = hidden_layer()
        self.x2 = hidden_layer()
        self.x3 = hidden_layer()
        self.x4 = hidden_layer()
        self.x5 = hidden_layer()
        self.x6 = hidden_layer()

        self.out = tf.keras.layers.Dense(args.state_dim - 2, activation='linear',
                                         kernel_initializer=small_normal(), use_bias=True)

        self.bio = BioConstrainLayer()

        self.args = args

    @tf.function
    def pass_layers(self, y_nn):
        """Apply every layer in `self.layers` to `y_nn`, one after another."""
        for layer in self.layers:
            y_nn = layer(y_nn)
        return y_nn

    def call(self, y_nn):
        """Forward pass; delegates to the graph-compiled `pass_layers`."""
        return self.pass_layers(y_nn)

In [None]:
class DiscDDEFunc(tf.keras.Model):
    """Right-hand side combining the neural closure with the reduced model.

    `__call__` is overridden directly and returns
    `main(y(t)) + rom_model(y, t, t_start)`.
    """

    def __init__(self, main, rom_model, args, **kwargs):
        super().__init__(**kwargs)

        self.main = main
        self.rom_model = rom_model
        self.args = args

    def process_input(self, y, t ,d):
        """Evaluate the state callable `y` at time `t`; `d` is accepted but unused."""
        return y(t)

    def call_nn_part(self, input):
        """Run the neural network `main` on an already prepared input."""
        return self.main(input)

    def __call__(self, y, t ,d, t_start = np.array([0.])):
        """Return the neural-network term plus the reduced-model term."""
        nn_input = self.process_input(y, t, d)

        # The network term is evaluated first, then the reduced-model term.
        nn_term = self.call_nn_part(nn_input)
        rom_term = self.rom_model(y, t, t_start)

        return nn_term + rom_term

### Initialize model related parameters

In [None]:
args = bio_eq_nDDE_args(
    # --- training / batching ---
    batch_time = 6,
    batch_time_skip = 2,
    batch_size = 4,
    epochs = 350,
    learning_rate = 0.05,
    decay_rate = 0.97,
    test_freq = 1,
    plot_freq = 1,
    d_max = 2.,
    rnn_nmax = 1,
    rnn_dt = 1.0,
    # --- output location / restart / validation ---
    model_dir = 'Bio_nODE_testcase_v3/model_dir_case_test',
    restart = 0,
    val_percentage = 1.,
    # --- biological model parameters ---
    T = 30.,
    nt = 600,
    z = -25,
    k_w = 0.067,
    alpha = 0.025,
    V_m = 1.5,
    I_0 = 158.075,
    K_u = 1.,
    Psi = 1.46,
    Xi = 0.1,
    R_m = 1.52,
    Lambda = 0.06,
    gamma = 0.3,
    Tau = 0.145,
    Phi = 0.175,
    Omega = 0.041,
    T_bio = 30.,
    bio_model_low_complex = 'NPZ',
    bio_model_high_complex = 'NNPZD',
    # --- ODE integrator ---
    ode_alg_name = 'dopri5',
    nsteps = 5,
)

### Make a copy of the current script

In [None]:
# Folder of this test case, relative to `basedir`.
testcase_dir = 'neuralClosureModels/testcases/Bio_Eqn'

# Helper that owns the output directories; calling it with the notebook's name
# presumably stores a copy of the script alongside the results — see `save_dir`.
save_dir_obj = save_dir(
    args = args,
    basedir = basedir,
    testcase_dir = testcase_dir,
)
save_dir_obj(script_name = 'neuralODE_Bio_Eqn_TestCase-Accel')

### Run setup

In [None]:
# Switch to the test-case folder first so that `%run` can find `setup` there.
os.chdir(os.path.join(basedir, testcase_dir))

# `-i` executes `setup` inside the notebook's own namespace, so it can read
# `args` and leave its results behind as notebook variables.
# NOTE(review): later cells use names not defined in this notebook's visible
# cells (e.g. `t`, `x0_low_complex`, `val_obj`); presumably `setup` creates
# them — confirm.
%run -i setup

In [None]:
### Define a custom loss function
class custom_loss:
    """Callable loss: Euclidean error along the last axis, averaged over axis 0."""

    def __init__(self, args):
        self.args = args

    def __call__(self, true_y, pred_y):
        """Return the mean over axis 0 of the L2 norm of `pred_y - true_y` taken along the last axis."""
        # A mask flagging predictions outside [0, args.T_bio] was prototyped
        # here in the original code and is currently disabled.
        squared_err = tf.math.squared_difference(pred_y, true_y)
        l2_err = tf.sqrt(tf.reduce_sum(squared_err, axis=-1))
        return tf.reduce_mean(l2_err, axis=0)

## Main part starts here

### Make objects and define learning-rate schedule

In [None]:
time_meter = ndde.RunningAverageMeter(0.97)

# Right-hand side = neural closure (func_main) + low-complexity model (rom_model).
rom_model = bio.bio_eqn(app = args)
func_main = DDEFuncMain(args)
func = DiscDDEFunc(func_main, rom_model, args)

# Adjoint equation, mini-batch sampler and loss used during training.
adj_func = ndde.adj_eqn_ODE(func, args, rom_model.jac_npz)
get_batch = ndde.create_batch(true_x_low_complex, x0_low_complex, t, args)
loss_obj = custom_loss(args)

# Plots cover the training window followed by the validation window.
plot_obj = custom_plot(
    tf.concat([true_x_low_complex, val_true_x_low_complex], axis=0),
    tf.concat([x_low_complex, val_x_low_complex], axis=0),
    tf.concat([t, val_t], axis=0),
    save_dir_obj.figsave_dir,
    args,
)
loss_history = ndde.history(args)

# Staircase exponential decay: the rate is multiplied by `args.decay_rate`
# once every `args.niters` optimizer steps.
initial_learning_rate = args.learning_rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=args.niters,
    decay_rate=args.decay_rate,
    staircase=True,
)

### Quick test to see what the true coefficients look like

In [None]:
if args.restart == 1:
    # Resume from the most recent checkpoint. `tf.train.latest_checkpoint`
    # returns None when the directory holds no checkpoint, so fail here with a
    # clear message instead of handing None to `load_weights`.
    latest_ckpt = tf.train.latest_checkpoint(save_dir_obj.checkpoint_dir)
    if latest_ckpt is None:
        raise FileNotFoundError(
            'args.restart == 1 but no checkpoint was found in '
            + str(save_dir_obj.checkpoint_dir))
    func.load_weights(latest_ckpt)

    # Integrate the restored model over the training + validation times and plot it.
    pred_y = ddeinttf(func, x0_low_complex, tf.concat([t, val_t], axis=0), fargs=([args.rnn_nmax, args.rnn_dt],))

    plot_obj.plot(pred_y, epoch = 0)

    # Reload the loss history recorded by the previous run.
    loss_history.read()

    # On restart the schedule is rebuilt with a much smaller hard-coded
    # starting rate and a decay rate of 0.95 (not `args.decay_rate`).
    initial_learning_rate = 0.00002
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, decay_steps=args.niters, decay_rate=0.95, staircase=True)

else:
    # Fresh run: plot without a prediction.
    plot_obj.plot(epoch = 0)

### Training starts here

In [None]:
# RMSprop driven by the learning-rate schedule built above.
optimizer = tf.keras.optimizers.RMSprop(learning_rate = lr_schedule)

# The delay specification `d` is passed as [rnn_nmax, rnn_dt].
nDDE_train_obj = ndde.train_nDDE(
    func = func,
    adj_func = adj_func,
    d = [args.rnn_nmax, args.rnn_dt],
    loss_obj = loss_obj,
    batch_obj = get_batch,
    optimizer = optimizer,
    args = args,
    plot_obj = plot_obj,
    time_meter = time_meter,
    checkpoint_dir = save_dir_obj.checkpoint_prefix,
    validation_obj = val_obj,
    loss_history_obj = loss_history,
)

nDDE_train_obj.train(true_x_low_complex, x0_low_complex, t, val_true_x_low_complex)