In [None]:
import os

# Root of the repository checkout. Set the NEURALODE_BASEDIR environment variable to
# point elsewhere instead of editing this machine-specific default.
basedir = os.environ.get('NEURALODE_BASEDIR', '/home/abhinavgupta0110/NeuralODEs_ROM_Closure')

# True when running on Google Colab (the setup cell then mounts Drive and installs quadpy).
is_google_colab = False
# True to query `nvidia-smi` and print the attached GPU's status.
is_use_GPU = False

### Check for a GPU, mount Google Drive (Colab only), and set the working directory

In [None]:
if is_use_GPU:
    gpu_info = !nvidia-smi
    gpu_info = '\n'.join(gpu_info)
    if gpu_info.find('failed') >= 0:
        print('No GPU found!')
    else:
        print(gpu_info)

if is_google_colab:
    from google.colab import drive
    drive.mount('/content/drive')

    %pip install quadpy
    
os.chdir(os.path.join(basedir, 'neuralClosureModels'))


### Load modules

In [None]:
from src.utilities.DDE_Solver import ddeinttf 
import src.solvers.neuralDDE_with_adjoint_accel as ndde

import time
import sys
from IPython.core.debugger import set_trace

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from shutil import move
import pickle

import logging

# Keras builds layer weights and computations in single precision by default.
tf.keras.backend.set_floatx('float32')

# Keep TensorFlow quiet: only ERROR-level log records are emitted.
tf.get_logger().setLevel(logging.ERROR)

## Define some useful classes

### Define a custom loss function

In [None]:
class custom_loss(tf.keras.losses.Loss):
    """Euclidean prediction error, averaged over the leading axis.

    For every entry, the L2 norm of `pred_y - true_y` is taken over the last axis, and
    the norms are averaged over axis 0. No further reduction is applied, so the result
    keeps every axis except the first and the last.

    `__call__` is overridden directly, so the Keras base class's `reduction` and
    `sample_weight` handling are bypassed.
    """

    def __call__(self, true_y, pred_y):
        sq_dist = tf.reduce_sum(tf.math.squared_difference(pred_y, true_y), axis=-1)
        # d/dx sqrt(x) is infinite at x == 0, so a prediction that matches the truth
        # exactly (e.g. presumably the first point of a batch integrated from the true
        # state) would back-propagate 0 * inf = NaN. Take the sqrt of a safe value at the
        # zeros and select 0 there: forward values are unchanged, the gradient there is 0,
        # and NaN inputs still propagate (NaN != 0).
        is_zero = tf.equal(sq_dist, 0.)
        safe_sq_dist = tf.where(is_zero, tf.ones_like(sq_dist), sq_dist)
        dist = tf.where(is_zero, tf.zeros_like(sq_dist), tf.sqrt(safe_sq_dist))
        loss = tf.reduce_mean(dist, axis=0)
        return loss

### Define a custom plotting function

In [None]:
class custom_plot:
    """Plot true vs. predicted trajectories and the x-y phase portrait during training.

    Args:
        true_y: reference solution (tensor with `.numpy()`); indexed as
            `[time, 0, component]` below, so presumably shaped (time, 1, 2) — TODO confirm.
        t: time points matching the first axis of `true_y`.
        figsave_dir: directory where figures for non-zero epochs are written.
        args: run arguments; only `args.restart` is read here.
    """

    def __init__(self, true_y, t, figsave_dir, args):
        self.true_y = true_y
        self.t = t
        self.figsave_dir = figsave_dir
        self.args = args

    def plot(self, *pred_y, epoch = 0):
        """Draw trajectories (left) and the phase portrait (right).

        Args:
            *pred_y: predicted solutions; only the first is drawn, and only when
                `epoch != 0` or the run is a restart (`args.restart == 1`).
            epoch: training epoch. Epoch 0 is only displayed; any other epoch is also
                saved to `figsave_dir/img<epoch>`.
        """
        t_np = self.t.numpy()
        true_np = self.true_y.numpy()

        fig = plt.figure(figsize=(12, 4), facecolor='white')
        ax_traj = fig.add_subplot(121, frameon=False)
        ax_phase = fig.add_subplot(122, frameon=False)

        ax_traj.set_title('Trajectories')
        ax_traj.set_xlabel('t')
        ax_traj.set_ylabel('x,y')
        ax_traj.plot(t_np, true_np[:, 0, 0], 'b-', t_np, true_np[:, 0, 1], 'g-')
        ax_traj.set_xlim(t_np.min(), t_np.max())
        ax_traj.set_ylim(-1, 1)

        ax_phase.set_title('Phase Portrait')
        ax_phase.set_xlabel('x')
        ax_phase.set_ylabel('y')
        ax_phase.plot(true_np[:, 0, 0], true_np[:, 0, 1], 'g-')
        ax_phase.set_xlim(-1, 1)
        ax_phase.set_ylim(-1, 1)

        # Guard on `pred_y` so a call without a prediction just shows the truth
        # instead of raising IndexError.
        if pred_y and (epoch != 0 or self.args.restart == 1):
            pred_np = pred_y[0].numpy()
            ax_traj.plot(t_np, pred_np[:, 0, 0], 'b--', t_np, pred_np[:, 0, 1], 'g--')
            ax_phase.plot(pred_np[:, 0, 0], pred_np[:, 0, 1], 'g--')

        # Save before showing: with non-inline backends plt.show() may block and the
        # figure may no longer be in a savable state once its window is closed.
        if epoch != 0:
            fig.savefig(os.path.join(self.figsave_dir, 'img' + str(epoch)))

        plt.show()
        # Release the figure; otherwise one figure per plotted epoch stays alive.
        plt.close(fig)

### Define the neural net architecture

In [None]:
class DDEFunc(tf.keras.Model):
    """Trainable DDE right-hand side: a linear RNN over delayed samples of the state.

    The model has a single `RNN` layer wrapping a 2-unit `SimpleRNNCell` with linear
    activation and no bias, so the output has 2 features.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        cell = tf.keras.layers.SimpleRNNCell(
            2,
            activation='linear',
            use_bias=False,
            kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1),
            recurrent_initializer='random_normal',
        )
        self.rnn_layer = tf.keras.layers.RNN(cell)

    def call(self, y, t, d):
        """Run the network on the delayed history of the state.

        Args:
            y: callable history; `y(s)` returns the state at time `s`.
            t: current time.
            d: `[n_steps, dt]`; the state is sampled at `t - k*dt` for
                `k = n_steps-1, ..., 1, 0` (oldest sample first).

        Returns:
            Output of the last layer applied to the delayed sequence.
        """
        n_steps, step = d[0], d[1]
        # Descending lags as numpy ints, oldest sample first.
        lags = np.arange(n_steps)[::-1]
        # New leading axis indexes the delayed samples: (n_steps, *state_shape).
        history = tf.concat([tf.expand_dims(y(t - lag * step), axis=0) for lag in lags], axis=0)
        # Swap the first two axes so the sample axis is second
        # (presumably (batch, time, state_dim) for the RNN — TODO confirm).
        perm = [1, 0] + list(range(2, history.shape.rank))
        out = tf.transpose(history, perm=perm)
        for layer in self.layers:
            out = layer(out)
        return out

### Initialize model-related parameters and the time grid

In [None]:
# Run configuration. The RNN looks back (rnn_nmax - 1) * rnn_dt = 1.0 time units;
# d_max = 1.1 is presumably the longest delay the solver keeps history for — TODO confirm.
args = ndde.arguments(
    data_size = 100, batch_time = 12, batch_time_skip = 2, batch_size = 5,
    epochs = 500, learning_rate = 0.075, decay_rate = 0.95,
    test_freq = 1, plot_freq = 2,
    d_max = 1.1, rnn_nmax = 3, rnn_dt = 0.5, state_dim = 2, adj_data_size = 2,
    model_dir = 'DDE_runs/model_dir_example', restart = 0, val_percentage = 0.2,
)

# Training time grid: args.data_size evenly spaced points on [0, 10].
t = tf.linspace(0., 10., args.data_size)

### Create output directories and save a copy of this notebook and its arguments

In [None]:
os.chdir(basedir)

if not os.path.exists(args.model_dir):
    os.makedirs(args.model_dir)

checkpoint_dir = os.path.join(args.model_dir, "ckpt")
checkpoint_prefix = os.path.join(checkpoint_dir, "cp-{epoch:04d}.ckpt")
if not os.path.exists(checkpoint_dir):
  os.makedirs(checkpoint_dir)


figsave_dir = os.path.join(args.model_dir, "img")
if not os.path.exists(figsave_dir):
    os.makedirs(figsave_dir)

!jupyter nbconvert --to python neuralDDE_ROM_Closure/examples/neuralDDE_Example.ipynb
move("neuralDDE_ROM_Closure/examples/neuralDDE_Example.py", os.path.join(args.model_dir, "orig_run_file.py"))

with open(os.path.join(args.model_dir, 'args.pkl'), 'wb') as output:
    pickle.dump(args, output, pickle.HIGHEST_PROTOCOL)

### Define initial conditions and other parameters associated with the true DDE

In [None]:
class initial_cond(tf.keras.Model):
    """Constant initial condition / history for the true DDE: the state is (1, 0) for every t."""

    def call(self, t):
        # `t` is ignored; one row with two state components.
        value = [[1., 0.]]
        return tf.convert_to_tensor(value, dtype=tf.float32)

true_y0 = initial_cond()  # Initial conditions (constant in time)
# System matrix of the true DDE; eigenvalues -0.1 ± 2i.
true_A = tf.convert_to_tensor([[-0.1, 2.0],
                               [-2.0, -0.1]], dtype=tf.float32)
# The two delays [d0, d1] of the true DDE.
d = [0.5, 1.]

### Solve for the true DDE

In [None]:
class Lambda(tf.keras.Model):
    """Right-hand side of the true DDE.

    In row-vector form: dy/dt = y(t) A - 0.1 y(t - d[0]) A - 0.1 y(t - d[1]) A,
    with A = `true_A`. Computed in float64 and returned as float32.
    """

    def call(self, y, t, d):
        a_transposed = tf.cast(tf.transpose(true_A), tf.float64)

        def apply_a(state):
            # out[c, a] = sum_b A^T[a, b] * state[c, b] == (state @ A)[c, a]
            return tf.einsum('ab, cb -> ca', a_transposed, tf.cast(state, tf.float64))

        rhs = apply_a(y(t)) - 0.1 * apply_a(y(t - d[0])) - 0.1 * apply_a(y(t - d[1]))
        return tf.cast(rhs, tf.float32)

# Integrate the true DDE on the grid `t`, starting from the history `true_y0`.
true_y = ddeinttf(Lambda(), true_y0, t, fargs=(d,))

### Create validation set

In [None]:
# Validation helper built from the same initial condition, time grid and run arguments.
val_obj = ndde.create_validation_set(
    true_y0, t, args,
)

# Reference validation trajectory from the true dynamics; presumably continued
# beyond the training range from `true_y` — TODO confirm against ndde.
val_true_y = val_obj.data(
    Lambda(), true_y, t, d,
)

## Main part starts here

### Create the training objects and define the learning-rate schedule

In [None]:
# Running average of per-iteration time (0.97 is presumably the smoothing factor).
time_meter = ndde.RunningAverageMeter(0.97)

# Trainable DDE right-hand side and its adjoint equation.
func = DDEFunc()
adj_func = ndde.adj_eqn(func, args)

# Mini-batch sampler over the true trajectory, and the training loss.
get_batch = ndde.create_batch(true_y, true_y0, t, args)
loss_obj = custom_loss()

# Plots cover the training and validation time ranges together.
plot_obj = custom_plot(
    tf.concat([true_y, val_true_y], axis=0),
    tf.concat([t, val_obj.val_t], axis=0),
    figsave_dir,
    args,
)
loss_history = ndde.history(args)

# Staircase exponential decay: lr = initial * decay_rate ** floor(step / args.niters).
initial_learning_rate = args.learning_rate
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=args.niters, decay_rate=args.decay_rate, staircase=True,
)

### Quick look at the true DDE solution (and, on restart, the restored model's prediction)

In [None]:
if args.restart == 1:
    # Resume from the most recent checkpoint in `checkpoint_dir`.
    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt is None:
        # latest_checkpoint returns None when nothing was saved; fail with a clear message
        # instead of passing None to load_weights.
        raise FileNotFoundError(
            f"args.restart == 1 but no checkpoint was found in {checkpoint_dir!r}")
    func.load_weights(latest_ckpt)

    # Show the restored model's prediction against the truth.
    pred_y = ddeinttf(func, true_y0, t, fargs=([args.rnn_nmax, args.rnn_dt],))
    plot_obj.plot(pred_y, epoch = 0)

    # Presumably reloads the loss history saved by the previous run.
    loss_history.read()

    # Restarted runs use a fixed learning rate of 0.05 and a decay of 0.93.
    # NOTE(review): these hard-coded values override args.learning_rate / args.decay_rate;
    # consider exposing them as arguments.
    initial_learning_rate = 0.05
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, decay_steps=args.niters, decay_rate=0.93, staircase=True)

else:
    # Fresh run: only the true solution is plotted.
    plot_obj.plot(epoch = 0)

### Training starts here

In [None]:
# RMSprop driven by `lr_schedule` (the restart cell may have replaced it).
optimizer = tf.keras.optimizers.RMSprop(learning_rate = lr_schedule)

# The trainable DDE sees the last `rnn_nmax` states spaced `rnn_dt` apart.
nDDE_train_obj = ndde.train_nDDE(
    func = func,
    adj_func = adj_func,
    d = [args.rnn_nmax, args.rnn_dt],
    loss_obj = loss_obj,
    batch_obj = get_batch,
    optimizer = optimizer,
    args = args,
    plot_obj = plot_obj,
    time_meter = time_meter,
    # The checkpoint *prefix* (with its `{epoch:04d}` placeholder) is what is passed here.
    checkpoint_dir = checkpoint_prefix,
    validation_obj = val_obj,
    loss_history_obj = loss_history,
)

nDDE_train_obj.train(true_y, true_y0, t, val_true_y)