In [None]:
import os

# Root of the project checkout. It can be overridden with the
# NEURAL_CLOSURE_BASEDIR environment variable so the notebook is not tied to
# one user's home directory; when the variable is unset the original
# hard-coded location is used, so existing setups behave exactly as before.
basedir = os.environ.get('NEURAL_CLOSURE_BASEDIR',
                         '/home/abhinavgupta0110/NeuralODEs_ROM_Closure')

# Runtime switches consumed by the environment-setup cell below.
is_google_colab = False  # True -> mount Google Drive and pip-install quadpy
is_use_GPU = False       # True -> print the output of `nvidia-smi`

### Mount the Google drive if needed

In [None]:
# Optional GPU report: `!nvidia-smi` is an IPython shell capture that returns
# the command output as a list of lines, which is joined into one string here.
if is_use_GPU:
    gpu_info = !nvidia-smi
    gpu_info = '\n'.join(gpu_info)
    # The presence of the word 'failed' in the output is treated as "no GPU".
    if gpu_info.find('failed') >= 0:
        print('No GPU found!')
    else:
        print(gpu_info)

# Colab-only setup: mount Google Drive and install the quadpy package.
if is_google_colab:
    from google.colab import drive
    drive.mount('/content/drive')

    # NOTE(review): the install is not version-pinned, so results may vary
    # between runs as quadpy releases change.
    %pip install quadpy
    
# Change the working directory so the `src.*` imports in the next cell resolve
# relative to <basedir>/neuralClosureModels.
os.chdir(os.path.join(basedir, 'neuralClosureModels'))

# Load the autoreload extension; it is activated with `%autoreload 2` in the
# imports cell.
%load_ext autoreload

### Load modules

In [None]:
%autoreload 2

from src.utilities.DDE_Solver import ddeinttf 
from src.utilities.helper_classes import *
import src.solvers.neuralDistDDE_with_adjoint_accel_Exp as nddde
import src.bio_eqn_case.bio_eqn_1D_modcall_numpy as bio
from src.bio_eqn_case.Bio_Eqn_1D_Helper_Classes import * 

import time
import sys
from IPython.core.debugger import set_trace

import numpy as np
import scipy.interpolate
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import cm
from shutil import move
import pickle

# Record the TensorFlow version in the notebook output.
print(tf.__version__)

# Use float32 as the default Keras floating-point type.
tf.keras.backend.set_floatx('float32')
import logging
# Only let messages of level ERROR and above through TensorFlow's logger.
tf.get_logger().setLevel(logging.ERROR)

## Define some useful classes

### Class for user-defined arguments

In [None]:
class bio_eq_nDistDDE_args(nddde.nddde_arguments, bio.bio_eqn_args):
    """Single container for all user-defined settings of this test case.

    The constructor forwards the training-related keywords to
    ``nddde.nddde_arguments.__init__`` and the biological-model keywords to
    ``bio.bio_eqn_args.__init__`` (using ``bio_model_low_complex`` as the
    ``bio_model``), both initialising ``self``. A second, independent
    ``bio.bio_eqn_args`` instance configured with ``bio_model_high_complex``
    is stored in ``self.bio_args_for_high_complex``.

    ``state_dim`` is derived as ``nz`` times a per-model factor:
    3 for 'NPZ', 4 for 'NPZD' and 5 for 'NNPZD'.

    Raises:
        ValueError: if ``bio_model_low_complex`` is not one of
            'NPZ', 'NPZD' or 'NNPZD'.

    NOTE(review): the ``adj_data_size`` argument is accepted but not used;
    the parent constructor is always given ``adj_data_size = state_dim``.
    Confirm that this is intended.
    """

    def __init__(self, batch_time = 12, batch_time_skip = 2, batch_size = 5, epochs = 500, learning_rate = 0.05, decay_rate = 0.95, test_freq = 1, plot_freq = 2, 
                 d_max = 1, nn_d1 = 0., nn_d2 = 0.5, adj_data_size = 2,
                 model_dir = 'Bio1D_nDistDDE_testcase_v2/model_dir_test', restart = 0, val_percentage = 0.2,
                 T = 365.*2, nt = 365*2, nz = 50, z_max = -100, k_w = 0.067, alpha = 0.025, V_m = 1.5, I_0 = 158.075, 
                 K_u = 1., Psi = 1.46, Xi = 0.1, R_m = 1.5, Lambda = 0.06, gamma = 0.3, Tau = 0.145, 
                 Phi = 0.175, Omega = 0.041, T_bio_min = 10., T_bio_max = 30., wp = 0.65, wd = 8.0, 
                 K_zb = 0.0864, K_z0 = 100.*0.0864, gamma_K = 0.1, T_mld = 365, bio_model_low_complex = 'NPZ', bio_model_high_complex = 'NNPZD', isplot = True,
                 ode_alg_name = 'dopri5', nsteps = 1, is_tstart_zero = True): # add more arguments as needed

        # Per-model multiplier used to size the state vector. Failing early
        # with a clear message replaces the UnboundLocalError that an
        # unrecognised model name would otherwise trigger further down.
        state_dim_factor = {'NPZ': 3, 'NPZD': 4, 'NNPZD': 5}
        if bio_model_low_complex not in state_dim_factor:
            raise ValueError(
                "Unsupported bio_model_low_complex {!r}; expected one of {}".format(
                    bio_model_low_complex, sorted(state_dim_factor)))
        state_dim = state_dim_factor[bio_model_low_complex] * nz

        # Training / solver settings. data_size is set to nt and
        # adj_data_size to state_dim (the adj_data_size argument is ignored).
        nddde.nddde_arguments.__init__(self, data_size = nt, batch_time = batch_time, batch_time_skip = batch_time_skip, batch_size = batch_size, epochs = epochs,
                           learning_rate = learning_rate, decay_rate = decay_rate, test_freq = test_freq, plot_freq = plot_freq, d_max = d_max, nn_d1 = nn_d1,
                           nn_d2 = nn_d2, state_dim = state_dim, adj_data_size = state_dim, model_dir = model_dir, restart = restart, val_percentage = val_percentage, isplot = isplot, is_tstart_zero = is_tstart_zero)

        # Settings of the low-complexity biological model, stored on self.
        bio.bio_eqn_args.__init__(self, T = T, nt = nt, nz = nz, z_max = z_max, k_w = k_w, alpha = alpha, V_m = V_m, I_0 = I_0, K_u = K_u, Psi = Psi,
                    Xi = Xi, R_m = R_m, Lambda = Lambda, gamma = gamma, Tau = Tau, Phi = Phi, Omega = Omega, T_bio_min = T_bio_min, T_bio_max = T_bio_max,
                                  wp = wp, wd = wd, bio_model = bio_model_low_complex, K_zb = K_zb, K_z0 = K_z0, gamma_K = gamma_K, T_mld = T_mld)

        # Separate argument object for the high-complexity model.
        # NOTE(review): K_zb, K_z0, gamma_K and T_mld are not forwarded here,
        # so this object uses bio_eqn_args' own defaults for them - confirm.
        self.bio_args_for_high_complex = bio.bio_eqn_args(T = T, nt = nt, nz = nz, z_max = z_max, k_w = k_w, alpha = alpha, V_m = V_m, I_0 = I_0, K_u = K_u, Psi = Psi,
                    Xi = Xi, R_m = R_m, Lambda = Lambda, gamma = gamma, Tau = Tau, Phi = Phi, Omega = Omega, T_bio_min = T_bio_min, T_bio_max = T_bio_max,
                                  wp = wp, wd = wd, bio_model = bio_model_high_complex)

        self.bio_model_low_complex = bio_model_low_complex
        self.bio_model_high_complex = bio_model_high_complex
        # Solver algorithm name and step count, passed to ddeinttf elsewhere
        # in the notebook as alg_name / nsteps.
        self.ode_alg_name = ode_alg_name
        self.nsteps = nsteps

### Define the neural net architecture

In [None]:
class BioConstrainLayer(tf.keras.layers.Layer):
    """Split one tensor into three channels weighted by a learned scalar.

    Given an input ``x`` the layer returns the concatenation, along the last
    axis, of ``gamma * x``, ``-x`` and ``(1 - gamma) * x`` (labelled N, P and
    Z channels respectively). The three pieces add up to zero element-wise
    for any value of ``gamma``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Trainable scalar starting at 0.1; `self.constraint` is registered as
        # the variable's constraint function.
        self.gamma = tf.Variable(0.1, trainable=True, constraint = self.constraint)

    def constraint(self, gamma):
        """Return ``gamma`` limited to the interval [0, 1].

        Values below 0 map to 0 and values above 1 map to 1.
        """
        non_negative = tf.where(gamma >= 0., gamma, 0.)
        return tf.where(gamma <= 1., non_negative, 1.)

    def call(self, input):
        """Build the (N, P, Z) channels from ``input`` and concatenate them."""
        channels = [
            self.gamma * input,         # N channel
            - input,                    # P channel
            (1. - self.gamma) * input,  # Z channel
        ]
        return tf.concat(channels, axis=-1)

In [None]:
class AddExtraChannel(tf.keras.layers.Layer):
    """Layer that appends extra channels to its input along the last axis."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, input, channels_to_add):
        """Return ``input`` with ``channels_to_add`` concatenated on axis -1."""
        return tf.concat([input, channels_to_add], axis=-1)

In [None]:
class DDEFuncMain(tf.keras.Model):
    """Main neural-network part of the closure.

    The network is: ``AddExtraChannel`` -> six swish ``Conv1D`` layers
    (7, 9, 9, 7, 5, 3 filters) -> one linear ``Conv1D`` with a single filter
    -> ``BioConstrainLayer`` -> ``Flatten('channels_first')``. Every
    convolution has kernel size 1 and stride 1.

    Args:
        args: settings object; ``state_dim`` and ``nz`` are read from it.
    """

    def __init__(self, args, **kwargs):
        super(DDEFuncMain, self).__init__(**kwargs)

        def pointwise_conv(filters, activation='swish'):
            # Kernel-size-1 convolution shared by every conv layer below.
            return tf.keras.layers.Conv1D(
                filters=filters, kernel_size=1, strides=1, activation=activation,
                kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1),
                use_bias=True)

        # IMPORTANT: pass_layers() walks `self.layers` by index, so the order
        # in which layers are assigned below defines the network and must not
        # be changed (assumes Keras lists sub-layers in assignment order).
        self.depth = AddExtraChannel()

        self.c1 = pointwise_conv(7)
        self.c2 = pointwise_conv(9)
        self.c3 = pointwise_conv(9)
        self.c4 = pointwise_conv(7)
        self.c5 = pointwise_conv(5)
        self.c6 = pointwise_conv(3)
        self.c_out = pointwise_conv(1, activation='linear')

        self.bio = BioConstrainLayer()

        self.flat = tf.keras.layers.Flatten('channels_first')

        self.args = args

    def process_input(self, zy):
        """Rearrange a flat ``(batch, features)`` tensor into channel form.

        The first ``state_dim`` columns are reshaped to
        ``(batch, state_dim // nz, -1)`` and transposed so that axis 1 and 2
        swap; the remaining columns are reshaped to ``(batch, -1, nz)`` and
        transposed the same way. Both parts are concatenated on the last axis.
        """
        state_dim = self.args.state_dim
        # Plain Python integer: tf.reshape requires integer dimensions, and a
        # float tensor (as produced by tf.floor on a true division) is not a
        # valid shape entry, in particular when traced inside a tf.function.
        n_fields = int(state_dim // self.args.nz)

        z = zy[:, :state_dim]
        z = tf.reshape(z, [z.shape[0], n_fields, -1])
        z = tf.transpose(z, perm=[0, 2, 1])
        y = zy[:, state_dim:]
        y = tf.reshape(y, [z.shape[0], -1, self.args.nz])
        y = tf.transpose(y, perm=[0, 2, 1])
        z_stack_y = tf.concat([z, y], axis=-1)

        return z_stack_y

    @tf.function
    def pass_layers(self, z, channels_to_add=None):
        """Apply every tracked layer in order.

        The first layer receives ``channels_to_add`` as a second argument;
        all following layers are called with the running tensor only.
        """
        z = self.layers[0](z, channels_to_add)

        for i in range(1, len(self.layers)):
            z = self.layers[i](z)

        return z

    def call(self, z, channels_to_add=None):
        """Reshape ``z`` with process_input and run it through the layers."""
        z = self.process_input(z)

        z = self.pass_layers(z, channels_to_add)

        return z

In [None]:
class DDEFuncAux(tf.keras.Model):
    """Auxiliary network: four swish ``Conv1D`` layers (3, 5, 7, 5 filters),
    a linear two-filter ``Conv1D`` and a ``Flatten('channels_first')``.

    All convolutions use kernel size 1 and stride 1.

    Args:
        args: settings object; only ``nz`` is read from it.
    """

    def __init__(self, args, **kwargs):
        super().__init__(**kwargs)

        def make_conv(n_filters, act):
            # Kernel-size-1 convolution with the initializer used throughout.
            return tf.keras.layers.Conv1D(filters=n_filters, kernel_size=1, strides=1, activation=act,
                                          kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1), use_bias=True)

        # The assignment order is the execution order used by pass_layers().
        self.c1 = make_conv(3, 'swish')
        self.c2 = make_conv(5, 'swish')
        self.c3 = make_conv(7, 'swish')
        self.c4 = make_conv(5, 'swish')
        self.dc_out = make_conv(2, 'linear')
        self.flat = tf.keras.layers.Flatten('channels_first')

        self.args = args

    def process_input(self, z):
        """Reshape ``z`` to ``(batch, -1, nz)`` and swap the last two axes."""
        batch = z.shape[0]
        channel_first = tf.reshape(z, [batch, -1, self.args.nz])
        return tf.transpose(channel_first, perm=[0, 2, 1])

    @tf.function
    def pass_layers(self, z):
        """Feed ``z`` through every tracked layer, in order."""
        for layer in self.layers:
            z = layer(z)
        return z

    def call(self, z):
        """Reshape the input and run it through the layer stack."""
        return self.pass_layers(self.process_input(z))

In [None]:
class split_zy:
    """Callable wrapper that keeps only the leading state columns.

    Calling the object with ``t`` evaluates the wrapped callable ``zy`` at
    ``t`` and returns the first ``args.state_dim`` columns of the result.
    """

    def __init__(self, zy, args):
        self.zy, self.args = zy, args

    def __call__(self, t):
        full_state = self.zy(t)
        n_keep = self.args.state_dim
        return full_state[:, :n_keep]

In [None]:
class DistDDEFunc(tf.keras.Model):
    """Right-hand side combining the two neural networks with ``rom_model``.

    Args:
        main: network applied to the current state plus extra channels.
        aux: network applied to the two delayed states.
        rom_model: callable invoked as ``rom_model(get_z, t, t_start)``.
        args: settings object; ``state_dim``, ``z`` and ``I`` are read.
    """

    def __init__(self, main, aux, rom_model, args, **kwargs):
        super().__init__(**kwargs)
        self.main, self.aux = main, aux
        self.rom_model = rom_model
        self.args = args

    def process_input(self, y, t ,d, t_start):
        """Collect the network inputs and the extra channels.

        Returns a list ``[y(t), y(t - d[0])[:, :state_dim],
        y(t - d[1])[:, :state_dim]]`` and a tensor made of ``args.z`` tiled
        once per entry of ``t_start`` concatenated (last axis) with
        ``args.I(t + t_start[i])`` stacked over ``i``.
        """
        n_state = self.args.state_dim
        n_start = t_start.shape[0]

        # History function evaluated now and at the two delays.
        input = [
            y(t),
            y(t - d[0])[:, :n_state],
            y(t - d[1])[:, :n_state],
        ]

        # args.z gets a leading and a trailing singleton axis, then is
        # repeated along the leading axis.
        z_expanded = tf.expand_dims(tf.expand_dims(self.args.z, axis=-1), axis=0)
        z = tf.tile(z_expanded, [n_start, 1, 1])

        # One args.I evaluation per start time, stacked on axis 0.
        I_per_start = [
            tf.expand_dims(tf.expand_dims(self.args.I(t + t_start[i]), axis=-1), axis=0)
            for i in range(n_start)
        ]
        I = tf.concat(I_per_start, axis=0)

        return input, tf.concat([z, I], axis=-1)

    def call_nn_part(self, input, channels_to_add):
        """Evaluate both networks and concatenate their outputs.

        The auxiliary contribution is ``aux(input[1]) - aux(input[2])``.
        """
        main_out = self.main(input[0], channels_to_add)
        aux_at_d1 = self.aux(input[1])
        aux_at_d2 = self.aux(input[2])
        return tf.concat([main_out, aux_at_d1 - aux_at_d2], axis=-1)

    def __call__(self, y, t ,d, t_start = np.array([0.])):
        """Return the neural-network output plus the zero-padded ROM output."""
        state_only = split_zy(y, self.args)

        nn_inputs, extra_channels = self.process_input(y, t, d, t_start)
        nn_out = self.call_nn_part(nn_inputs, extra_channels)

        rom_out = self.rom_model(state_only, t, t_start)
        # Pad the ROM output with zero columns up to the width of nn_out.
        n_rows, n_cols = nn_out.shape[0], nn_out.shape[1]
        zero_pad = tf.zeros([n_rows, n_cols - rom_out.shape[1]])
        rom_out = tf.concat([rom_out, zero_pad], axis=-1)

        return nn_out + rom_out

### Initialize model related parameters

In [None]:
args = bio_eq_nDistDDE_args(
    # --- keywords forwarded to nddde.nddde_arguments ---
    batch_time = 6,
    batch_time_skip = 2,
    batch_size = 4,
    epochs = 200,
    learning_rate = 0.05,
    decay_rate = 0.97,
    test_freq = 1,
    plot_freq = 1,
    d_max = 5.,
    nn_d1 = 0.,
    nn_d2 = 2.,
    model_dir = 'Bio1D_nDistDDE_testcase_v3/model_dir_case_test',
    restart = 0,
    val_percentage = 1.,
    # --- keywords forwarded to bio.bio_eqn_args ---
    T = 30,
    nt = 300,
    nz = 20,
    z_max = -100,
    k_w = 0.067,
    alpha = 0.025,
    V_m = 1.5,
    I_0 = 158.075,
    K_u = 1.,
    Psi = 1.46,
    Xi = 0.1,
    R_m = 1.5,
    Lambda = 0.06,
    gamma = 0.3,
    Tau = 0.145,
    Phi = 0.175,
    Omega = 0.041,
    T_bio_min = 10.,
    T_bio_max = 30.,
    wp = 0*0.65,  # multiplied by zero: evaluates to 0.0
    wd = 0*8.0,   # multiplied by zero: evaluates to 0.0
    K_zb = 0.0864,
    K_z0 = 100.*0.0864,
    gamma_K = 0.1,
    T_mld = 365,
    # --- model selection, plotting and solver options ---
    bio_model_low_complex = 'NPZ',
    bio_model_high_complex = 'NNPZD',
    isplot = True,
    ode_alg_name = 'dopri5',
    nsteps = 5,
    is_tstart_zero = False,
)

### Make a copy of the current script

In [None]:
# Test-case directory, relative to `basedir`; reused by the setup cell below.
testcase_dir = 'neuralClosureModels/testcases/Bio_Eqn_1D'
# `save_dir` is not defined in this notebook - presumably it comes from one of
# the star imports (src.utilities.helper_classes); verify.
save_dir_obj = save_dir(args = args, basedir = basedir, testcase_dir = testcase_dir, save_user_inputs=False)
# Called with the notebook name; later cells read figsave_dir, checkpoint_dir
# and checkpoint_prefix from this object.
save_dir_obj(script_name = 'neuralDistDDE_Bio_Eqn_1D_TestCase-Accel-ConstrainLayer-AddChannels-Exp')

### Run setup

In [None]:
# Switch to the test-case directory so that the `setup` script is found.
os.chdir(os.path.join(basedir, testcase_dir))

# Run the setup script inside the notebook's own namespace (-i).
# NOTE(review): later cells use names that are not defined anywhere in this
# notebook (e.g. K_z_obj, t, val_t, x0_low_complex, true_x_low_complex,
# val_true_x_low_complex, val_obj); presumably this script creates them - confirm.
%run -i setup

In [None]:
class custom_loss:
    """Loss between true and predicted states, returned per index of axis 1.

    Both inputs are reshaped from ``(d0, d1, features)`` to
    ``(d0, d1, state_dim // nz, -1)`` and transposed to put the
    ``state_dim // nz`` axis last. The loss is the square root of the summed
    squared difference over that last axis (plus 1e-10), averaged over the
    remaining trailing axis and then over axis 0.

    Args:
        args: settings object; ``state_dim``, ``nz`` and ``T_bio`` are read.
    """

    def __init__(self, args):
        self.args = args
        # Stored with two leading singleton axes; not used by __call__.
        self.T_bio = tf.expand_dims(tf.expand_dims(args.T_bio, axis=0), axis=0)

    def __call__(self, true_y, pred_y):
        # Plain Python integer: tf.reshape needs integer dimensions, whereas
        # tf.floor on a true division yields a float tensor.
        n_fields = int(self.args.state_dim // self.args.nz)

        def fields_last(y):
            # (d0, d1, features) -> (d0, d1, n_fields, -1) -> (d0, d1, -1, n_fields)
            y = tf.reshape(y, [y.shape[0], y.shape[1], n_fields, -1])
            return tf.transpose(y, perm=[0, 1, 3, 2])

        true_y = fields_last(true_y)
        pred_y = fields_last(pred_y)

        # The 1e-10 keeps the square root away from exactly zero.
        loss = tf.sqrt(tf.reduce_sum(tf.math.squared_difference(pred_y, true_y), axis=-1) + 1e-10)
        loss = tf.reduce_mean(loss, axis=-1)
        loss = tf.reduce_mean(loss, axis=0)

        return loss

## Main part starts here

### Make objects and define learning-rate schedule

In [None]:
# Running-average meter handed to the trainer as `time_meter`.
time_meter = nddde.RunningAverageMeter(0.97)

# Neural networks and the ROM model, combined into the full right-hand side.
func_main = DDEFuncMain(args)
func_aux = DDEFuncAux(args)
# K_z_obj is not defined in this notebook (presumably created by `%run -i setup`).
rom_model = bio.bio_eqn(args, K_z_obj)
func = DistDDEFunc(func_main, func_aux, rom_model, args)
# Adjoint equation object; it is given the ROM model's jac_npz.
adj_func = nddde.nddde_adj_eqn(func, args, rom_model.jac_npz)
get_batch = nddde.create_batch(true_x_low_complex, x0_low_complex, t, args)
loss_obj = custom_loss(args)
# The plot object receives training and validation data concatenated on axis 0.
# `custom_plot` is not defined in this notebook - presumably star-imported; verify.
plot_obj = custom_plot(tf.concat([true_x_low_complex, val_true_x_low_complex], axis=0), tf.concat([x_low_complex, val_x_low_complex], axis=0), 
                       args.z, tf.concat([t, val_t], axis=0), save_dir_obj.figsave_dir, args)
loss_history = nddde.history(args)

# Staircase exponential decay: the learning rate is multiplied by
# args.decay_rate once every args.niters optimizer steps.
# args.niters is not set in this notebook's argument class (presumably set by
# nddde.nddde_arguments - confirm).
initial_learning_rate = args.learning_rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=args.niters,
    decay_rate=args.decay_rate,
    staircase=True)

### Quick test to see what the true coefficients look like

In [None]:
# On restart: reload the latest checkpoint, integrate the model over the
# training + validation times, plot the prediction and restore the loss history.
if args.restart == 1: 
    func.load_weights(tf.train.latest_checkpoint(save_dir_obj.checkpoint_dir))
    # NOTE(review): `ai_t0` is not defined anywhere in this notebook, while the
    # other cells use `x0_low_complex` as the initial state. Unless the setup
    # script defines `ai_t0`, this branch raises NameError - confirm.
    process_true_z0 = nddde.process_DistDDE_IC(ai_t0, func_aux, t_lowerlim = t[0] - args.nn_d2, t_upperlim = t[0] - args.nn_d1)
    pred_zy = ddeinttf(func, process_true_z0, tf.concat([t, val_t], axis=0), fargs=([args.nn_d1, args.nn_d2],), alg_name = args.ode_alg_name, nsteps = args.nsteps)
    
    # Only the first state_dim entries of the last axis are plotted.
    plot_obj.plot(pred_zy[:, :, :args.state_dim], epoch = 0)

    loss_history.read()
    
    # NOTE(review): on restart the schedule is rebuilt with hard-coded values
    # (0.05 and 0.95) instead of args.learning_rate / args.decay_rate - confirm
    # that overriding the user settings is intended.
    initial_learning_rate = 0.05
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, decay_steps=args.niters, decay_rate=0.95, staircase=True)
    
else:
    # Fresh run: call plot without a prediction.
    plot_obj.plot(epoch = 0)

### Training starts here

In [None]:
# RMSprop driven by the learning-rate schedule built above.
optimizer = tf.keras.optimizers.RMSprop(learning_rate = lr_schedule)

# Assemble the trainer. The delays are passed as d = [nn_d1, nn_d2].
# `val_obj` is not defined in this notebook (presumably created by
# `%run -i setup` - confirm).
nDistDDE_train_obj = nddde.train_nDistDDE(func = func, adj_func = adj_func, d = [args.nn_d1, args.nn_d2], loss_obj = loss_obj, batch_obj = get_batch,
                            optimizer = optimizer, args = args, plot_obj = plot_obj, time_meter = time_meter, checkpoint_dir = save_dir_obj.checkpoint_prefix,
                            validation_obj = val_obj, loss_history_obj = loss_history)

# Start training on the low-complexity data, with the validation data as the
# last argument.
nDistDDE_train_obj.train(true_x_low_complex, x0_low_complex, t, val_true_x_low_complex)