# In[2]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import math

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Layer

from scipy.integrate import odeint
from ddeint import ddeint

# In[1]:
################## define general constants

def base_freq(ni,bd):
    """Return the base angular frequency pi / (b - a) for each input interval.

    Args:
        ni: number of input dimensions (1, 2 or 3).
        bd: flat sequence of bounds ``[a0, b0, a1, b1, ...]``; the first
            ``2 * ni`` entries are read.

    Returns:
        A single float32 ``tf.constant`` when ``ni == 1``; otherwise a list of
        ``ni`` float32 constants, one per interval, in the order of ``bd``.

    Raises:
        ValueError: if ``ni`` is not 1, 2 or 3.
    """
    if ni not in (1, 2, 3):
        raise ValueError("base_freq supports ni in (1, 2, 3), got {!r}".format(ni))
    # np.pi is used for every axis (the ni == 3 case previously relied on an
    # undefined global `pi`).
    freq = [tf.constant((np.pi)/(bd[2*k+1]-bd[2*k]), dtype=tf.float32)
            for k in range(ni)]
    # 1D callers expect a bare constant rather than a one-element list.
    return freq[0] if ni == 1 else freq


###############################################################
##Network

def dNNsolve(di,do,n_l,bd,step=1.):
    """Build the sin/sigmoid feature network used as the trial solution.

    For every input coordinate two Dense layers of width ``n_l`` are created:
    one with ``tf.sin`` activation and one with ``tf.sigmoid`` activation. The
    per-coordinate sin features are multiplied together, as are the sigmoid
    features, and those two products plus their product are concatenated and
    fed to a linear Dense read-out.

    Args:
        di: number of inputs (1, 2 or 3). For 2 and 3 the model takes a list
            of separate ``(None, 1)`` inputs named ``Input_t``/``Input_x``/
            ``Input_y``; for 1 it takes a single ``(None, 1)`` input.
        do: number of outputs. For ``di == 3`` it must be 1 (one head) or 2
            (two scalar heads ``[u, v]``); for ``di == 1`` it is the width of
            the read-out layer; for ``di == 2`` the read-out is always 1 unit.
        n_l: width of every sin / sigmoid feature layer.
        bd: domain bounds; accepted for interface compatibility, not read here.
        step: scales the upper end of the sin-kernel initializer range.

    Returns:
        An uncompiled ``keras.Model``.

    Raises:
        ValueError: for unsupported ``di``, or ``do`` not in (1, 2) when
            ``di == 3``.

    NOTE(review): the sin initializers read module-level globals ``freq_t``,
    ``freq_x`` and ``freq_y`` (presumably produced by ``base_freq``); they must
    be defined before this is called.
    """
    # `layers` is not imported at the top of this file; resolve it via keras.
    layers = keras.layers
    if di not in (1, 2, 3):
        raise ValueError("dNNsolve supports di in (1, 2, 3), got {!r}".format(di))
    if di == 3 and do not in (1, 2):
        raise ValueError("dNNsolve with di=3 supports do in (1, 2), got {!r}".format(do))

    if di==3:
        # initialization parameters: sin kernels drawn between the base
        # frequency and step*n_l times it; tiny sigmoid kernels; small
        # constant read-out kernels.
        wsint = tf.keras.initializers.RandomUniform(minval=freq_t, maxval=step*freq_t*n_l)
        wsinx = tf.keras.initializers.RandomUniform(minval=freq_x, maxval=step*freq_x*n_l)
        wsiny = tf.keras.initializers.RandomUniform(minval=freq_y, maxval=step*freq_y*n_l)
        wsig = tf.keras.initializers.RandomNormal(mean=0., stddev=0.0001)
        wdense = tf.keras.initializers.Constant(value=0.0001)
        input_t = keras.Input(shape=[1], name="Input_t")
        input_x = keras.Input(shape=[1], name="Input_x")
        input_y = keras.Input(shape=[1], name="Input_y")
        inputs=[input_t,input_x,input_y]
        F_t= layers.Dense(n_l, activation=tf.sin,kernel_initializer=wsint)(input_t)
        F_x= layers.Dense(n_l, activation=tf.sin,kernel_initializer=wsinx)(input_x)
        F_y= layers.Dense(n_l, activation=tf.sin,kernel_initializer=wsiny)(input_y)
        S_t= layers.Dense(n_l, activation=tf.sigmoid,kernel_initializer=wsig)(input_t)
        S_x= layers.Dense(n_l, activation=tf.sigmoid,kernel_initializer=wsig)(input_x)
        S_y= layers.Dense(n_l, activation=tf.sigmoid,kernel_initializer=wsig)(input_y)
        FFF=layers.Multiply()([F_t,F_x,F_y])
        SSS=layers.Multiply()([S_t,S_x,S_y])
        FFFSSS=layers.Multiply()([FFF,SSS])
        concat=layers.concatenate([FFF, SSS,FFFSSS])
        if do==1:
            outputs = layers.Dense(1,kernel_initializer=wdense)(concat)
        else:
            # do == 2: two independent scalar heads.
            u = layers.Dense(1,kernel_initializer=wdense)(concat)
            v = layers.Dense(1,kernel_initializer=wdense)(concat)
            outputs=[u,v]
    elif di==2:
        wsint = tf.keras.initializers.RandomUniform(minval=freq_t, maxval=step*freq_t*n_l)
        wsinx = tf.keras.initializers.RandomUniform(minval=freq_x, maxval=step*freq_x*n_l)
        wsig = tf.keras.initializers.RandomNormal(mean=0., stddev=0.0001)
        wdense = tf.keras.initializers.Constant(value=0.0001)
        input_t = keras.Input(shape=[1], name="Input_t")
        input_x = keras.Input(shape=[1], name="Input_x")
        inputs=[input_t,input_x]
        F_t= layers.Dense(n_l, activation=tf.sin,kernel_initializer=wsint)(input_t)
        F_x= layers.Dense(n_l, activation=tf.sin,kernel_initializer=wsinx)(input_x)
        S_t= layers.Dense(n_l, activation=tf.sigmoid,kernel_initializer=wsig)(input_t)
        S_x= layers.Dense(n_l, activation=tf.sigmoid,kernel_initializer=wsig)(input_x)
        FF=layers.Multiply()([F_t,F_x])
        SS=layers.Multiply()([S_t,S_x])
        FFSS=layers.Multiply()([FF,SS])
        concat=layers.concatenate([FF, SS,FFSS])
        outputs = layers.Dense(1,kernel_initializer=wdense)(concat)
    else:
        # di == 1
        wsint = tf.keras.initializers.RandomUniform(minval=freq_t, maxval=step*freq_t*n_l)
        wsig = tf.keras.initializers.RandomNormal(mean=0., stddev=0.0001)
        wdense = tf.keras.initializers.Constant(value=0.0001)
        inputs = keras.Input(shape=(1,))
        F = layers.Dense(n_l, activation=tf.sin,kernel_initializer=wsint)(inputs)
        S = layers.Dense(n_l, activation=tf.sigmoid,kernel_initializer=wsig)(inputs)
        FS = layers.Multiply()([F,S])
        concat = layers.concatenate([F,S,FS])
        # Read-out width is the `do` parameter (previously the undefined name `no`).
        outputs = layers.Dense(do,kernel_initializer=wdense)(concat)

    return keras.Model(inputs=inputs, outputs=outputs)


# From https://gist.github.com/piyueh/712ec7d4540489aad2dcfb80f9a54993

def function_factory(model, loss, y_true):
    """Create the value-and-gradient function required by tfp.optimizer.lbfgs_minimize.

    Args:
        model [in]: an instance of `tf.keras.Model` or its subclasses whose
            trainable variables are optimised.
        loss [in]: callable invoked as ``loss(y_true, y_true)``; it must return a
            4-tuple ``(total, bulk, IC, board)``. Only ``total`` is differentiated.
        y_true [in]: value forwarded (as both arguments) to ``loss``.
    Returns:
        A ``tf.function`` with signature ``loss_value, gradients = f(params_1d)``,
        where ``params_1d`` and ``gradients`` are flat 1D tensors over all
        trainable variables. The function also carries ``f.iter`` (call
        counter variable), ``f.idx``, ``f.part``, ``f.shapes``,
        ``f.assign_new_model_parameters`` and ``f.history`` (one entry appended
        per call, via tf.py_function, holding ``[total, bulk, IC, board]``).
    """

    # obtain the shapes of all trainable parameters in the model
    shapes = tf.shape_n(model.trainable_variables)
    n_tensors = len(shapes)

    # tf.dynamic_stitch (flatten) and tf.dynamic_partition (unflatten) need
    # per-element index and partition tables, built once here.
    count = 0
    idx = [] # stitch indices
    part = [] # partition indices

    for i, shape in enumerate(shapes):
        # np.prod replaces np.product (deprecated in NumPy 1.25, removed in 2.0);
        # int() gives a plain Python int for list repetition and tf.range.
        n = int(np.prod(shape))
        idx.append(tf.reshape(tf.range(count, count+n, dtype=tf.int32), shape))
        part.extend([i]*n)
        count += n

    part = tf.constant(part)

    @tf.function
    def assign_new_model_parameters(params_1d):
        """Update the model's trainable variables from a flat 1D tensor.

        Args:
            params_1d [in]: a 1D tf.Tensor representing the model's trainable parameters.
        """

        params = tf.dynamic_partition(params_1d, part, n_tensors)
        for i, (shape, param) in enumerate(zip(shapes, params)):
            model.trainable_variables[i].assign(tf.reshape(param, shape))

    @tf.function
    def f(params_1d):
        """Loss and flat gradient for tfp.optimizer.lbfgs_minimize.

        Args:
           params_1d [in]: a 1D tf.Tensor.
        Returns:
            A scalar loss and the gradients w.r.t. the `params_1d`.
        """

        # use GradientTape so that we can calculate the gradient of loss w.r.t. parameters
        with tf.GradientTape() as tape:
            # update the parameters in the model
            assign_new_model_parameters(params_1d)
            # the loss returns (total, bulk, IC, board); only the total is optimised
            loss_value, loss_bulk, loss_IC, loss_board = loss(y_true,y_true)

        # calculate gradients and convert to 1D tf.Tensor
        grads = tape.gradient(loss_value, model.trainable_variables)
        grads = tf.dynamic_stitch(idx, grads)

        # print out iteration & loss
        f.iter.assign_add(1)
        if (f.iter % 50 == 0):
            tf.print("Iter:", f.iter, "loss:", loss_value)

        # store loss components so we can retrieve them later
        tf.py_function(f.history.append, inp=[[loss_value, loss_bulk, loss_IC, loss_board]], Tout=[])

        return loss_value, grads

    # store these information as members so we can use them outside the scope
    f.iter = tf.Variable(0)
    f.idx = idx
    f.part = part
    f.shapes = shapes
    f.assign_new_model_parameters = assign_new_model_parameters
    f.history = []

    return f

############################################################################
def rs(b0,b1,n):
    """Draw an (n, 1) column of uniform samples on [b0, b1) from NumPy's global RNG."""
    column_shape = (n, 1)
    samples = np.random.uniform(low=b0, high=b1, size=column_shape)
    return samples

def cs(c0,n):
    """Return an (n, 1) float32 column in which every entry equals c0."""
    column_shape = (n, 1)
    filled = np.full(shape=column_shape, fill_value=c0, dtype=np.float32)
    return filled

############################################################################
def random_sampling(ni,form='rect'):
    """Return a sampler that draws fresh collocation points for an ni-D domain.

    The returned samplers are intentionally plain (eager) Python functions.
    Under ``@tf.function`` their NumPy random calls would run only once, at
    trace time, and every later call with the same Python arguments would
    return the same baked-in points.

    Samplers:
        * ni == 3: ``(bd, n_bulk, n_IC, n_board) -> (X, t, x, y)``. Draws n_IC
          points on each of t=bd[0] and t=bd[1], n_board points on each of the
          four x/y faces, and n_bulk interior points; rows are shuffled.
        * ni == 2, form == 'rect': ``(bd, n_bulk, n_IC, n_board) -> (X, t, x)``.
          Same layout without the y axis.
        * ni == 2, form == 'disk': ``(bd, n_bulk, n_board) -> (X, t, x)``.
          Bulk radii are sqrt of a uniform draw on [bd[2], bd[3]]; n_board
          points lie on the circle of radius sqrt(bd[3]). Here ``t`` and ``x``
          are the two Cartesian coordinates.
        * ni == 1: ``(bd, n_bulk) -> X``. n_bulk points uniform on [bd[0], bd[1]).

    All returned tensors are float32; ``X`` holds the stacked coordinates
    column-wise.

    Raises:
        ValueError: for an unsupported (ni, form) combination.
    """
    if ni==3:
        def random_sampling_3D(bd,n_bulk,n_IC,n_board):
            pieces = [
                # initial / final time faces
                np.concatenate((cs(bd[0],n_IC),rs(bd[2],bd[3],n_IC),rs(bd[4],bd[5],n_IC)),axis=1),
                np.concatenate((cs(bd[1],n_IC),rs(bd[2],bd[3],n_IC),rs(bd[4],bd[5],n_IC)),axis=1),
                # x = x0 and x = xL faces
                np.concatenate((rs(bd[0],bd[1],n_board),cs(bd[2],n_board),rs(bd[4],bd[5],n_board)),axis=1),
                np.concatenate((rs(bd[0],bd[1],n_board),cs(bd[3],n_board),rs(bd[4],bd[5],n_board)),axis=1),
                # y = y0 and y = yL faces
                np.concatenate((rs(bd[0],bd[1],n_board),rs(bd[2],bd[3],n_board),cs(bd[4],n_board)),axis=1),
                np.concatenate((rs(bd[0],bd[1],n_board),rs(bd[2],bd[3],n_board),cs(bd[5],n_board)),axis=1),
                # interior
                np.concatenate((rs(bd[0],bd[1],n_bulk),rs(bd[2],bd[3],n_bulk),rs(bd[4],bd[5],n_bulk)),axis=1),
            ]
            X=np.concatenate(pieces,axis=0)
            np.random.shuffle(X)
            X=tf.constant(X,dtype=tf.float32)
            t,x,y=tf.split(X,3,axis=1)
            return X, t, x, y
        return random_sampling_3D

    if ni==2 and form=='rect':
        def random_sampling_2D(bd,n_bulk,n_IC,n_board):
            pieces = [
                # initial / final time edges
                np.concatenate((cs(bd[0],n_IC),rs(bd[2],bd[3],n_IC)),axis=1),
                np.concatenate((cs(bd[1],n_IC),rs(bd[2],bd[3],n_IC)),axis=1),
                # x = x0 and x = xL edges
                np.concatenate((rs(bd[0],bd[1],n_board),cs(bd[2],n_board)),axis=1),
                np.concatenate((rs(bd[0],bd[1],n_board),cs(bd[3],n_board)),axis=1),
                # interior
                np.concatenate((rs(bd[0],bd[1],n_bulk),rs(bd[2],bd[3],n_bulk)),axis=1),
            ]
            X=np.concatenate(pieces,axis=0)
            np.random.shuffle(X)
            X=tf.constant(X,dtype=tf.float32)
            t,x=tf.split(X,2,axis=1)
            return X, t, x
        return random_sampling_2D

    if ni==2 and form=='disk':
        def random_sampling_2D(bd,n_bulk,n_board):
            length = tf.cast(tf.sqrt(rs(bd[2],bd[3],n_bulk)),dtype=tf.float32)
            angle  = tf.cast(np.pi *rs(0.,2.,n_bulk),dtype=tf.float32)
            angle_b = tf.cast(np.pi *rs(0.,2.,n_board),dtype=tf.float32)
            # cast first so integer bounds work (tf.sqrt rejects int32)
            radius_b = tf.sqrt(tf.cast(bd[3],dtype=tf.float32))
            X = tf.concat([length *tf.cos(angle),length * tf.sin(angle)],axis=1)
            X = tf.concat([X,tf.concat((radius_b*tf.cos(angle_b),radius_b*tf.sin(angle_b)),axis=1)],axis=0)
            # tf.random.shuffle returns a new tensor; it does not shuffle in place
            X = tf.random.shuffle(X)
            X=tf.cast(X,dtype=tf.float32)
            t,x=tf.split(X,2,axis=1)
            return X, t, x
        return random_sampling_2D

    if ni==1:
        def random_sampling_1D(bd,n_bulk):
            X=rs(bd[0],bd[1],n_bulk)
            np.random.shuffle(X)
            X=tf.constant(X,dtype=tf.float32)
            return X
        return random_sampling_1D

    raise ValueError("random_sampling: unsupported ni={!r}, form={!r}".format(ni, form))
    

def random_sampling_bulk_IC(bd,n_bulk,n_IC):
    """Sample n_IC points on the t = bd[0] face plus n_bulk interior points of a 3D box.

    Rows are shuffled together, converted to a float32 tensor ``X`` of three
    columns (t, x, y), and returned along with the three split columns.
    """
    on_initial_face = np.concatenate(
        (cs(bd[0], n_IC), rs(bd[2], bd[3], n_IC), rs(bd[4], bd[5], n_IC)), axis=1)
    interior = np.concatenate(
        (rs(bd[0], bd[1], n_bulk), rs(bd[2], bd[3], n_bulk), rs(bd[4], bd[5], n_bulk)), axis=1)
    samples = np.vstack((on_initial_face, interior))
    np.random.shuffle(samples)
    X = tf.constant(samples, dtype=tf.float32)
    t, x, y = tf.split(X, 3, axis=1)
    return X, t, x, y

def t0tLsample(bd,eom='wave'):
    """Return the time endpoints ``[[bd[0]], [bd[1]]]`` as a (2, 1) float32 tensor.

    ``eom`` is accepted for signature compatibility but is not used.
    """
    endpoints = np.vstack((cs(bd[0], 1), cs(bd[1], 1)))
    return tf.constant(endpoints, dtype=tf.float32)

def random_sampling_board(bd,n_bo):
    """Sample n_bo points on each of the four spatial faces of a 3D (t, x, y) box.

    The two x-faces share the same random (t, y) draws, and the two y-faces
    share the same random (t, x) draws. Rows are not shuffled.

    Returns:
        ``(t, x, y, I_x0, I_xL, I_y0, I_yL)``: float32 coordinate columns and
        float32 0/1 masks marking points with x == bd[2], x == bd[3],
        y == bd[4] and y == bd[5] respectively.
    """
    t_xface, y_xface = rs(bd[0], bd[1], n_bo), rs(bd[4], bd[5], n_bo)
    t_yface, x_yface = rs(bd[0], bd[1], n_bo), rs(bd[2], bd[3], n_bo)
    faces = [
        np.concatenate((t_xface, cs(bd[2], n_bo), y_xface), axis=1),
        np.concatenate((t_xface, cs(bd[3], n_bo), y_xface), axis=1),
        np.concatenate((t_yface, x_yface, cs(bd[4], n_bo)), axis=1),
        np.concatenate((t_yface, x_yface, cs(bd[5], n_bo)), axis=1),
    ]
    X = tf.constant(np.concatenate(faces, axis=0), dtype=tf.float32)
    t, x, y = tf.split(X, 3, axis=1)
    masks = [tf.cast(coord == float(bd[k]), dtype=tf.float32)
             for coord, k in ((x, 2), (x, 3), (y, 4), (y, 5))]
    return (t, x, y, *masks)

def counters(ni,form='rect'):
    """Return a function that builds float32 0/1 indicator columns for sample points.

    Each returned ``count`` compares coordinate columns exactly against the
    bounds in ``bd`` and concatenates the masks along axis 1:

    * ni == 3: ``count(bd, t, x, y)`` -> [I_t0, I_tL, I_x0, I_xL, I_y0, I_yL, I_bulk]
    * ni == 2, form == 'rect': ``count(bd, t, x)`` -> [I_t0, I_tL, I_x0, I_xL, I_bulk]
    * ni == 2, form == 'disk': ``count(bd, t, x)`` -> [Ic_board, Ic_bulk]; a point
      is on the board when t**2 + x**2 is strictly between bd[3] - 1e-4 and
      bd[3] + 1e-4, and Ic_bulk = 1 - Ic_board
    * ni == 1: ``count(bd, t)`` -> I_bulk only

    I_bulk is 1 minus the sum of the other masks. Any other (ni, form)
    returns None.
    """
    def on_plane(coord, bound):
        # 1. where the coordinate equals the bound exactly, 0. elsewhere
        return tf.cast(coord == float(bound), dtype=tf.float32)

    if ni == 1:
        def count(bd,t):
            I_t0 = on_plane(t, bd[0])
            I_tL = on_plane(t, bd[1])
            return tf.concat([1. - I_t0 - I_tL], axis=1)
        return count

    if ni == 2 and form == 'disk':
        def count(bd,t,x):
            r2 = t**2 + x**2
            near_circle = tf.logical_and(r2 < bd[3]+1e-4, r2 > bd[3]-1e-4)
            Ic_board = tf.cast(near_circle, dtype=tf.float32)
            return tf.concat([Ic_board, 1. - Ic_board], axis=1)
        return count

    if ni == 2 and form == 'rect':
        def count(bd,t,x):
            masks = [on_plane(t, bd[0]), on_plane(t, bd[1]),
                     on_plane(x, bd[2]), on_plane(x, bd[3])]
            return tf.concat(masks + [1. - tf.add_n(masks)], axis=1)
        return count

    if ni == 3:
        def count(bd,t,x,y):
            masks = [on_plane(t, bd[0]), on_plane(t, bd[1]),
                     on_plane(x, bd[2]), on_plane(x, bd[3]),
                     on_plane(y, bd[4]), on_plane(y, bd[5])]
            return tf.concat(masks + [1. - tf.add_n(masks)], axis=1)
        return count

    return None


############################################################################
class Print_Loss_Every_so_many_Epochs(keras.callbacks.Callback):
    """Keras callback printing log10 of the training loss on every 10th epoch."""

    _PERIOD = 10

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self._PERIOD:
            return
        log_loss = np.log10(logs["loss"])
        print("Epoch {}:\t   log10 loss: {:7.2f} ".format(epoch, log_loss))
            
class Print_Loss_Every_so_many_Epochs_1D(keras.callbacks.Callback):
    """Keras callback printing log10 of the training loss on every 100th epoch."""

    _PERIOD = 100

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self._PERIOD:
            return
        log_loss = np.log10(logs["loss"])
        print("Epoch {}:\t   log10 loss: {:7.2f} ".format(epoch, log_loss))
            
############################################################################
#ANIMATION
def update_plot_3to1(frame_number, points_u, plot):
    """Animation callback: replace ``plot[0]`` with the surface for one frame.

    Reads the module-level names ``ax``, ``x_p`` and ``y_p``, which are not
    defined in this function and must exist when it is called.
    """
    previous = plot[0]
    previous.remove()
    frame_values = points_u[:, :, frame_number]
    plot[0] = ax.plot_surface(x_p, y_p, frame_values, cmap="magma")

def update_plot_3to2(frame_number, points_ux, points_uy, plot):
    """Animation callback: replace ``plot[0]`` with a quiver plot for one frame.

    Arrows are coloured by the velocity magnitude ``hypot(ux, uy)``. Reads the
    module-level names ``ax``, ``x_p`` and ``y_p``, which are not defined in
    this function.
    """
    plot[0].remove()
    ux = points_ux[:, :, frame_number]
    uy = points_uy[:, :, frame_number]
    magnitude = np.hypot(ux, uy)
    plot[0] = ax.quiver(x_p, y_p, ux, uy, [magnitude], units='width', scale=1 / 0.15)
    
############################################################################
#Dictionaries to be used in plots

## Equation dictionary name
def eqname_gen(ni,no):
    """Return the mapping from equation key to a human-readable plot title.

    Args:
        ni: number of input dimensions (1, 2 or 3).
        no: number of outputs; only consulted when ``ni == 3`` (1 or 2).

    Returns:
        A new dict ``{equation_key: title}``.

    Raises:
        ValueError: for an unsupported (ni, no) combination (this previously
            surfaced as an UnboundLocalError).
    """
    if ni==3 and no==1:
        return {'waveb': 'Wave equation (1)',
                'wave':  'Wave equation (2)',
                'wavet': 'Traveling wave',
                'poisson1': 'Poisson equation (1)',
                'poisson2': 'Poisson equation (2)',
                'poisson3': 'Poisson equation (3)',
                'heat1': 'Heat equation (1)',
                'heat2': 'Heat equation (2)'}
    if ni==3 and no==2:
        return {'NS': 'Navier Stokes',
                'GT': 'Taylor-Green vortex',
                'LO': 'Lamb-Oseen vortex'}
    if ni==2:
        return {'wave': 'Wave equation',
                'waveN': 'Wave equation - Neumann BCs',
                'twave': 'Traveling wave',
                'par': 'Parabolic equation',
                'poisson': 'Poisson equation 1',
                'heat': 'Heat equation',
                'heat2f': 'Heat equation with 2 frequencies',
                'heat0': 'Heat equation - Neumann BCs',
                'poisson1': 'Poisson equation 2',
                'AdvectionDiffusion': 'Advection-Diffusion equation',
                'Burgers': 'Burgers equation',
                'poisson_par': 'Poisson equation 1'}
    if ni==1:
        return {'mat': 'Mathieu Equation',
                'exp': 'Decaying Exponential',
                'wave': 'Harmonic Oscillator',
                'dho': 'Damped Harmonic Oscillator',
                'linear': 'Linear',
                'oscillon': 'Profile Oscillon 1D',
                'stiff': 'Stiff Equation',
                'gaussian':'Gaussian',
                'delay':'Delay Equation',
                'twofreq': 'Two Frequencies'}
    raise ValueError("eqname_gen: unsupported ni={!r}, no={!r}".format(ni, no))

def eqname_analytic(ni,no):
    """Return the equation keys that have a closed-form solution in sol_analytic.

    Args:
        ni: number of input dimensions (1, 2 or 3).
        no: number of outputs; only consulted when ``ni == 3`` (1 or 2).

    Returns:
        A new list of equation keys.

    Raises:
        ValueError: for an unsupported (ni, no) combination (this previously
            surfaced as an UnboundLocalError).
    """
    if ni==3 and no==1:
        return ['waveb','wave','wavet','poisson1','poisson2','poisson3','heat1','heat2']
    if ni==3 and no==2:
        return ['GT','LO']
    if ni==2:
        return ['wave', 'waveN','twave','heat','heat2f','heat0','par','poisson1','poisson','AdvectionDiffusion']
    if ni==1:
        return ['oscillon', 'wave', 'exp', 'linear', 'gaussian', 'dho','stiff','twofreq']
    raise ValueError("eqname_analytic: unsupported ni={!r}, no={!r}".format(ni, no))

def dictgen(ni,no):
    """Return the equation / boundary-condition string tables for an (ni, no) problem.

    Every entry is a string holding a bracketed list of residual expressions
    (e.g. ``'[dudt-0.05*dduddx]'``); ``'[none]'`` marks a boundary with no
    condition. The strings reference names such as ``u``, ``dudt``,
    ``dduddx``, ``t``, ``x``, ``y``, ``bd``, ``pi``, ``freq_x`` and ``tf``,
    which must exist wherever they are evaluated (not in this function).

    Args:
        ni: number of input dimensions (1, 2 or 3).
        no: number of outputs; only consulted when ``ni == 3`` (1 or 2).

    Returns:
        * ni == 3: ``[eomdict, ICdict, boardx, boardy, wdict]``
        * ni == 2: ``[eomdict, ICdict, boardx, wdict]``
        * ni == 1: ``[eomdict, ICdict]``

        ICdict / boardx / boardy values are ``[condition_at_lower_bound,
        condition_at_upper_bound]``; wdict values are loss weights
        ``[w_bulk, w_IC, w_board]``.

    Raises:
        ValueError: for an unsupported (ni, no) combination (this previously
            surfaced as an UnboundLocalError).
    """
    if ni==3 and no==1:
        ## PDE
        #The form needs to be [ eq1 , eq2 ]
        eomdict={'waveb':'[dduddt-1.**2*(dduddx+dduddy)]',
                 'wave': '[dduddt-1.**2*(dduddx+dduddy)]',
                 'wavet':'[dudt-(1./5.)*(dudx+dudy)]',
                 'poisson1':'[dduddt+dduddx+dduddy+(3.*pi**2)*tf.sin(pi*t)*tf.sin(pi*x)*tf.sin(pi*y)]',
                 'poisson2':'[dduddt+dduddx+dduddy-6.]',
                 'poisson3':'[dduddt+dduddx+dduddy-2.]',
                 'heat1':'[dudt-(dduddx+dduddy)]',
                 'heat2':'[dudt-(dduddx+dduddy)]'}

        #Boundary conditions in t belonging to [t0,tL].
        #The form needs to be [ [ u - u(t0,x,y),dudt - dudt(t0,x,y)] , [u - u(tL,x,y),dudt - dudt(tL,x,y)] ]
        #If no condition needs to be provided on some boundary, use [none].
        ICdict={'waveb':[ '[ (u-tf.sin(freq_x *x)*tf.sin(freq_y *y)) , (dudt-0.) ]' , '[none]' ],
                'wave': [ '[ (u-tf.sin(3.*freq_x *x)*tf.sin(4.*freq_y *y)) , (dudt-0.)]' , '[none]' ],
                'wavet':[ '[ (u-tf.sin(3.*pi *x+2.*pi *y)) ]' , '[none]' ],
                'poisson1':[ '[u-0.]' , '[u-0.]' ],
                'poisson2':[ '[u-(x**2+y**2)]' , '[u-(x**2+y**2+float(bd[1])**2)]' ],
                'poisson3':[ '[u-(x**2-y**2)]' ,'[u-(x**2-y**2+float(bd[1])**2)]'  ],
                'heat1':[ '[u-tf.exp(x+y)]' , '[none]' ],
                'heat2':[ '[u-(1.-y)*tf.exp(x)]' , '[none]' ]}

        #Boundary conditions in x belonging to [x0,xL].
        #The form needs to be [ [u - u(t,x0,y),dudx - dudx(t,x0,y)] , [u - u(t,xL,y),dudx - dudx(t,xL,y)] ]
        #If no condition needs to be provided on some boundary, use [none].
        boardx={'waveb':[ '[u-0.]' , '[u-0.]' ],
               'wave':  [ '[u-0.]' , '[u-0.]' ],
               'wavet': [ '[u-tf.sin(pi *t+2.*pi *y)]' , '[u+tf.sin(pi *t+2.*pi *y)]' ],
               'poisson1':[ '[u-0.]' , '[u-0.]' ],
               'poisson2':[ '[u-(t**2+y**2)]' , '[u-(t**2+y**2+float(bd[3])**2)]' ],
               'poisson3':[ '[u-(t**2-y**2)]' , '[u-(float(bd[3])**2+t**2-y**2)]' ],
               'heat1':[ '[u-tf.exp(2.*t+y)]' , '[u-tf.exp(2.*t+y+float(bd[3]))]' ],
               'heat2':[ '[u-(1.-y)*tf.exp(t)]' , '[u-(1.-y)*tf.exp(t+float(bd[3]))]' ]}

        #Boundary conditions in y belonging to [y0,yL].
        #The form needs to be [ [u - u(t,x,y0),dudy - dudy(t,x,y0)] , [u - u(t,x,yL),dudy - dudy(t,x,yL)] ]
        #If no condition needs to be provided on some boundary, use [none].
        boardy={'waveb':[ '[u-0.]' , '[u-0.]' ],
               'wave':  [ '[u-0.]' , '[u-0.]' ],
               'wavet': [ '[u-tf.sin(pi *t+3.*pi *x)]' , '[u-tf.sin(pi *t+3.*pi *x)]' ],
               'poisson1':[ '[u-0.]' , '[u-0.]' ],
               'poisson2':[ '[u-(t**2+x**2)]' , '[u-(t**2+x**2+float(bd[5])**2)]' ],
               'poisson3':[ '[u-(x**2+t**2)]' , '[u-(x**2+t**2-float(bd[5])**2)]' ],
               'heat1':[ '[u-tf.exp(2.*t+x)]' , '[u-tf.exp(2.*t+x+float(bd[5]))]' ],
               'heat2':[ '[u-tf.exp(x+t)]' , '[u-(1.-float(bd[5]))*tf.exp(x+t)]' ]}

        #weights to be used in loss function
        #the form needs to be [w_bulk,w_IC,w_board]
        wdict={'waveb':[1.,10.,1.],
               'wave': [1.,10.,1.],
               'wavet':[1.,1.,1.],
               'poisson1':[1.,1.,1.],
               'poisson2':[1.,1.,1.],
               'poisson3':[1.,1.,1.],
               'heat1':[1.,1.,10.],
               'heat2':[1.,1.,10.]}

        return [eomdict,ICdict,boardx,boardy,wdict]

    if ni==3 and no==2:

        eomdict={'GT':'[dudt+u*dudx+v*dudy+0.5*tf.exp(-4.*t)*tf.sin(2*x)-(dduddx+dduddy),(dvdt+u*dvdx+v*dvdy+0.5*tf.exp(-4.*t)*tf.sin(2*y)-(ddvddx+ddvddy)),(dudx+dvdy)]',
                 'LO':'[(dwdt+u*dwdx+v*dwdy-(ddwddx+ddwddy)),(dudx+dvdy)]',
                 'NS':'[(dwdt+u*dwdx+v*dwdy-5.*1e-3*(ddwddx+ddwddy)-0.75*(tf.sin(2.*pi*(x+y))+tf.cos(2.*pi*(x+y)))),(dudx+dvdy)]'}

        ICdict={'GT':[ '[(u-tf.cos(x)*tf.sin(y)),(v+tf.cos(y)*tf.sin(x))]' , '[none]' ],
                'LO':[ '[w-tf.exp(-(x**2+y**2)/(4.*bd[0]))/(4.*pi*bd[0])]' , '[none]' ],
                'NS':[ '[(w-pi*(tf.cos(3.*freq_x *x)-tf.cos(3.*freq_y *y)))]' , '[none]' ]}

        #'NS' uses periodic (continuity) conditions between opposite faces
        boardx={'GT':[ '[(u-tf.exp(-2.*t)*tf.sin(y)),(v-0.)]' , '[(u-tf.exp(-2.*t)*tf.sin(y)),(v-0.)]' ],
                'LO':[ '[(u+(1.-tf.exp(-(bd[2]**2+y**2)/(4.*t)))*y/(2.*pi*(bd[2]**2+y**2))),(v-(1.-tf.exp(-(bd[2]**2+y**2)/(4.*t)))*bd[2]/(2.*pi*(bd[2]**2+y**2)))]' , '[(u+(1.-tf.exp(-(bd[3]**2+y**2)/(4.*t)))*y/(2.*pi*(bd[3]**2+y**2))),(v-(1.-tf.exp(-(bd[3]**2+y**2)/(4.*t)))*bd[3]/(2.*pi*(bd[3]**2+y**2)))]' ],
                'NS':[ '[(tf.boolean_mask(u, I_x0)-tf.boolean_mask(u, I_xL)),(tf.boolean_mask(v, I_x0)-tf.boolean_mask(v, I_xL))]' , '[none]' ]}

        boardy={'GT':[ '[(u-0.),(v+tf.exp(-2.*t)*tf.sin(x))]' , '[(u-0.),(v+tf.exp(-2.*t)*tf.sin(x))]' ],
                'LO':[ '[(u+(1.-tf.exp(-(x**2+bd[4]**2)/(4.*t)))*bd[4]/(2.*pi*(x**2+bd[4]**2))),(v-(1.-tf.exp(-(x**2+bd[4]**2)/(4.*t)))*x/(2.*pi*(x**2+bd[4]**2)))]' , '[(u+(1.-tf.exp(-(x**2+bd[5]**2)/(4.*t)))*bd[5]/(2.*pi*(x**2+bd[5]**2))),(v-(1.-tf.exp(-(x**2+bd[5]**2)/(4.*t)))*x/(2.*pi*(x**2+bd[5]**2)))]' ],
                'NS':[ '[(tf.boolean_mask(u, I_y0)-tf.boolean_mask(u, I_yL)),(tf.boolean_mask(v, I_y0)-tf.boolean_mask(v, I_yL))]' , '[none]' ]}

        wdict={'GT':[1.,10.,1.],
               'LO':[1.,1.,1.],
               'NS':[1.,1.,1.]}

        return [eomdict,ICdict,boardx,boardy,wdict]

    if ni==2:

        eomdict={'wave':'[dduddt-1.**2*dduddx]',
                 'waveN':'[dduddt-1.**2*dduddx]',
                 'twave':'[dudt-1.**2*dudx]',
                 'heat':'[dudt-0.05*dduddx]',
                 'heat2f':'[dudt-0.01*dduddx]',
                 'heat0':'[dudt-0.05*dduddx]',
                 'poisson': '[dduddt+dduddx + 2. * np.pi**2. * tf.sin(np.pi * t) * tf.sin(np.pi * x)]',
                 'poisson1':'[dduddt+dduddx-10.*(t-1)*tf.math.cos(5.*x)+25.*(t-1)*(x-1)*tf.math.sin(5.*x)]',
                 'AdvectionDiffusion': '[dudt - 0.25 * dduddx]',
                 'Burgers': '[dudt + u * dudx - 0.25 * dduddx]',
                 'poisson_par': '[(dduddt+dduddx-tf.exp(-(t**2+10.*x**2)))]',
                 'par':'[(dduddt+dduddx-4.)]'}

        ICdict={'wave':  [ '[(u-tf.math.sin(3.*freq_x*x)),(dudt-0.)]' , '[none]' ],
                'waveN': ['[(u-tf.math.cos(3.*freq_x*x)),(dudt-0.)]' , '[none]' ],
                'twave': ['[(u-tf.math.sin(2.*freq_x*x))]' , '[none]' ],
                'heat':  [ '[(u-1.*tf.math.sin(3.*freq_x*x))]', '[none]' ],
                'heat2f':[ '[(u-2.*tf.math.sin(9.*freq_x*x)+0.3*tf.math.sin(4.*freq_x*x))]', '[none]' ],
                'heat0': [ '[(u-1.*tf.math.cos(3.*freq_x*x))]', '[none]' ],
                'poisson':[ '[(u-0.)]', '[u-0.]' ],
                'poisson1':[ '[(u-(1.-x)*tf.math.sin(5.*x))]', '[(u-0.)]' ],
                'AdvectionDiffusion':[ '[(u -0.25*tf.sin(freq_x*x))]', '[none]' ],
                'Burgers': [ '[(u - x * (1. - x))]', '[none]' ],
                'poisson_par':[ '[none]' , '[none]' ],
                'par':[ '[none]' , '[none]' ]}

        # A duplicate (identical) 'heat2f' entry was removed from boardx and wdict.
        boardx={'wave':[ '[(u-0.)]', '[(u-0.)]' ],
                'waveN':[ '[(dudx-0.)]', '[(dudx-0.)]' ],
                'twave':[ '[(u-tf.math.sin(2.*freq_t*t))]', '[(u-tf.math.sin(2.*freq_t*t))]' ],
                'heat':[ '[(u-0.)]', '[(u-0.)]' ],
                'heat2f':[ '[(u-0.)]', '[(u-0.)]' ],
                'heat0':[ '[(dudx-0.)]', '[(dudx-0.)]' ],
                'poisson':[ '[(u-0.)]', '[(u-0.)]' ],
                'poisson1':[ '[(u-0.)]', '[(u-0.)]' ],
                'AdvectionDiffusion': [ '[(u-0.)]', '[(u-0.)]' ],
                'Burgers': [ '[(u-0.)]', '[(u-0.)]' ],
                'poisson_par':[ '[(u-0.)]' , '[none]' ],
                'par':[ '[(u-1.)]' , '[none]' ]}

        wdict={'wave':[1.,10.,1.],
               'waveN':[1.,10.,1.],
               'twave':[1.,10.,1.],
               'heat':[1.,10.,1.],
               'heat2f':[1.,10.,1.],
               'heat0':[1.,10.,1.],
               'poisson':[1.,10.,1.],
               'poisson1':[1.,10.,1.],
               'AdvectionDiffusion':[1.,10.,1.],
               'Burgers':[1.,10.,1.],
               'poisson_par':[1.,10.,1.],
               'par':[1.,10.,1.]}

        return [eomdict,ICdict,boardx,wdict]

    if ni==1:

        eomdict={'oscillon':'[dduddt-0.5**2*u+2.*u**3.]',
                 'mat':'[dduddt+(1.-0.4*tf.cos(2.*t))*u]',
                 'exp':'[dudt+0.5*u]',
                 'wave':'[dduddt+25.*u]',
                 'dho':'[dduddt+0.5*dudt+25.*u]',
                 'linear':'[dudt-1.]',
                 'delay':'[dudt-0.5*u+1.*du]',
                 'gaussian':'[dudt+0.2*t*u]',
                 'stiff':'[dudt+21.*u-tf.exp(-t)]',
                 'twofreq':'[dduddt+u+2.*tf.cos(5.*t)+6.*tf.sin(10.*t)]'}

        ICdict={'oscillon':['[tf.where(u<0.5*u0guess,2*(u-u0guess),0.), dudt-0.]','[dudt+0.5*u]'],
                'mat':['[u-1.0,dudt-0.]','[none]'],
                'exp':['[u-1.0]','[none]'],
                'wave':['[u-1.0,dudt-0.]','[none]'],
                'dho':['[u-1.0,dudt-0.]','[none]'],
                'linear':['[u-1.0]','[none]'],
                'delay':['[u-1.0]','[none]'],
                'gaussian':['[u-1.0]','[none]'],
                'stiff':['[u-1.0]','[none]'],
                'twofreq':['[u-1.0,dudt-0.]','[none]']}

        return [eomdict,ICdict]

    raise ValueError("dictgen: unsupported ni={!r}, no={!r}".format(ni, no))



###########################################
## Define analytical solution to be compared with NN output (where possible) and name dictionaries to be used in plots
## Analytic solutions 
def sol_analytic(ni,no):
    """Return the closed-form reference solution for an (ni, no) problem.

    The returned callable takes an equation key ``eom`` followed by the
    coordinate tensors and evaluates that equation's analytic solution:

    * ni == 3, no == 1: ``sol_3to1(eom, t, x, y) -> u``
    * ni == 3, no == 2: ``sol_3to2(eom, t, x, y) -> (u, v)``
    * ni == 2: ``sol_2to1(eom, t, x) -> u``
    * ni == 1: ``sol_1to1(eom, t) -> u``

    The multi-dimensional solvers return None for keys they do not handle;
    ``sol_1to1`` raises ValueError for unknown keys.

    NOTE(review): several branches read module-level ``freq_t``/``freq_x``/
    ``freq_y`` (presumably from base_freq); they must exist at call time.
    The 'poisson' branch of sol_2to1 uses ``np.sin``, so for eager tensors or
    arrays it yields a NumPy result rather than a tf tensor.

    Raises:
        ValueError: for an unsupported (ni, no) combination.
    """
    # np.pi replaces a bare `pi`, which is not defined anywhere in this block.
    if ni==3 and no==1:
        def sol_3to1(eom,t,x,y):
            if eom=='waveb':
                return tf.sin(freq_x *x)*tf.sin(freq_y *y)*tf.cos(tf.sqrt(2.)*np.pi*t)
            if eom=='wave':
                return tf.sin(3.*freq_x *x)*tf.sin(4.*freq_y *y)*tf.cos(5.*np.pi*t)
            if eom=='wavet' or eom=='wavet2' :
                return tf.sin(3.*np.pi*x+2.*np.pi*y+np.pi*t)
            if eom=='poisson1':
                return tf.sin(np.pi*t)*tf.sin(np.pi*x)*tf.sin(np.pi*y)
            if eom=='poisson2':
                return t**2+x**2+y**2
            if eom=='poisson3':
                return t**2+x**2-y**2
            if eom=='heat1':
                return tf.exp(x+y+2.*t)
            if eom=='heat2':
                return (1.-y)*tf.exp(x+t)
            return None
        return sol_3to1

    if ni==3 and no==2:
        def sol_3to2(eom,t,x,y):
            if eom=='GT':
                u=tf.cos(x)*tf.sin(y)*tf.exp(-2.*t)
                v=-tf.sin(x)*tf.cos(y)*tf.exp(-2.*t)
                return u,v
            if eom=='LO':
                r2=(x**2+y**2)
                u=-y/(2.*np.pi*r2)*(1.-tf.exp(-r2/(4.*t)))
                v=x/(2.*np.pi*r2)*(1.-tf.exp(-r2/(4.*t)))
                return u,v
            return None
        return sol_3to2

    if ni==2:
        def sol_2to1(eom,t,x):
            if eom=='wave':
                return tf.cos(3.*freq_t* t)*tf.sin(3.*freq_x* x)
            if eom=='waveN':
                return tf.cos(3.*freq_t* t)*tf.cos(3.*freq_x* x)
            if eom=='twave':
                return tf.sin(2.*freq_t* (t+x))
            if eom=='heat':
                return tf.sin(3.*freq_x* x)*tf.exp(-0.05*(3.*np.pi)**2*t)
            if eom=='heat2f':
                return 2.*tf.math.sin(9.*np.pi*x)*tf.exp(-0.01*(9.*np.pi)**2*t)-0.3*tf.math.sin(4.*np.pi*x)*tf.exp(-0.01*(4.*np.pi)**2*t)
            if eom=='heat0':
                return tf.cos(3.*freq_x* x)*tf.exp(-0.05*(3.*np.pi)**2*t)
            if eom=='par':
                return x**2+t**2
            if eom=='poisson1':
                return (1-t)*(1-x)*tf.sin(5.*x)
            if eom=='poisson':
                return np.sin(np.pi * x) * np.sin(np.pi * t)
            if eom=='AdvectionDiffusion':
                return 0.25 * tf.exp(-0.25 * np.pi**2. * t) * tf.sin(np.pi * x)
            return None
        return sol_2to1

    if ni==1:
        def sol_1to1(eom,t):
            if eom=='oscillon':
                return 0.5/tf.cosh(0.5*t)
            if eom=='wave':
                return tf.cos(5.*t)
            if eom=='exp':
                return tf.exp(-0.5*t)
            if eom=='linear':
                return 1+t
            if eom=='gaussian':
                return tf.exp(-0.1*t**2)
            if eom=='dho':
                # under-damped oscillator u'' + 0.5 u' + 25 u = 0, u(0)=1, u'(0)=0
                frq=tf.sqrt(25.-0.5**2/4)
                return tf.exp(-0.5/2.*t)*(tf.cos(frq*t)+0.5/(2.*frq)*tf.sin(frq*t))
            if eom=='stiff':
                return 1./20.*(tf.exp(-1.*t)+19.*tf.exp(-21.*t))
            if eom=='twofreq':
                return 1/132.*(121.*tf.cos(t)+11.*tf.cos(5.*t)-80.*tf.sin(t)+8.*tf.sin(10.*t))
            # previously an UnboundLocalError on the undefined result
            raise ValueError("sol_1to1: no analytic solution for eom={!r}".format(eom))
        return sol_1to1

    raise ValueError("sol_analytic: unsupported ni={!r}, no={!r}".format(ni, no))

    
      
####################################################################
def points_plt_mse(ni,no,bd,model,sol,domain='rect'):
    """Evaluate ``model`` on a regular plotting grid and measure its error.

    The grid and the returned tuple depend on the problem dimensionality:

    * ``ni == 3, no == 1``: ``(points_u, rmse, x_p, y_p)`` with ``points_u``
      of shape ``(nplot, nplot, t_frames)``.
    * ``ni == 3, no == 2``: ``(points_u, points_v, rmse, x_p, y_p)``.
    * ``ni == 2`` with ``domain`` ``'disk'`` or ``'rect'``:
      ``(pred, rmse, tplt, xplt)`` on ``n_plt**2`` flattened grid points.
    * ``ni == 1``: ``(pred, rmse, tplt)`` on 20000 points.

    The error is a root-mean-square error. It is computed against
    ``sol(eom, ...)`` when the global ``eom`` is in the global ``analytic``
    collection. For ``ni == 1`` it is also computed against numerical
    references for ``'mat'`` (``odeint``) and ``'delay'`` (``ddeint``).
    Otherwise it stays 0.

    NOTE(review): relies on module-level globals ``eom``, ``analytic``,
    ``t_frames``, ``nplot`` and ``n_plt`` defined elsewhere in the notebook.

    Args:
        ni: number of model inputs (1, 2 or 3).
        no: number of model outputs (only inspected when ``ni == 3``).
        bd: bounds ``[t0, tL, x0, xL, y0, yL]`` (as many as ``ni`` needs).
            For the disk domain, ``bd[3]`` is used as the radius.
        model: Keras model evaluated through ``model.predict``.
        sol: analytic solution callable ``sol(eom, *coords)``.
        domain: ``'rect'`` or ``'disk'``, used only when ``ni == 2``.

    Raises:
        ValueError: for an unsupported ``(ni, no, domain)`` combination.
    """
    if ni==3 and no==1:
        # Fixed (x, y) grid, evaluated frame by frame at t_frames times.
        t_range=np.linspace(float(bd[0]),float(bd[1]),t_frames)
        points_x=np.reshape(np.linspace(float(bd[2]), float(bd[3]), num=nplot),(nplot,1))
        points_y=np.reshape(np.linspace(float(bd[4]), float(bd[5]), num=nplot),(nplot,1))
        y_p,x_p=np.meshgrid(points_y,points_x)
        # x/y inputs do not change between frames: build them once.
        data_x_plot=tf.reshape(tf.cast(x_p,dtype=tf.float32),(nplot**2,1))
        data_y_plot=tf.reshape(tf.cast(y_p,dtype=tf.float32),(nplot**2,1))
        points_u=np.zeros((nplot,nplot,t_frames))
        mse=0.
        for i, t_i in enumerate(t_range):
            tp=tf.ones_like(y_p)*t_i
            data_t_plot=tf.reshape(tf.cast(tp,dtype=tf.float32),(nplot**2,1))
            pred=model.predict([data_t_plot,data_x_plot,data_y_plot])
            if eom in analytic:
                true=sol(eom,data_t_plot,data_x_plot,data_y_plot)
                mse=mse+tf.reduce_sum((true-pred)**2)
            points_u[:,:,i]=np.reshape(pred,(nplot,nplot))
        # Despite the name, this is the root-mean-square error over all frames.
        mse=np.sqrt(mse/(t_frames*nplot**2))
        return points_u, mse, x_p, y_p

    elif ni==3 and no==2:
        t_range=np.linspace(float(bd[0]),float(bd[1]),t_frames)
        points_x=np.reshape(np.linspace(float(bd[2]), float(bd[3]), num=nplot),(nplot,1))
        points_y=np.reshape(np.linspace(float(bd[4]), float(bd[5]), num=nplot),(nplot,1))
        y_p,x_p=np.meshgrid(points_y,points_x)
        data_x_plot=tf.reshape(tf.cast(x_p,dtype=tf.float32),(nplot**2,1))
        data_y_plot=tf.reshape(tf.cast(y_p,dtype=tf.float32),(nplot**2,1))
        points_u=np.zeros((nplot,nplot,t_frames))
        points_v=np.zeros((nplot,nplot,t_frames))
        mse=0.
        for i, t_i in enumerate(t_range):
            tp=tf.ones_like(y_p)*t_i
            data_t_plot=tf.reshape(tf.cast(tp,dtype=tf.float32),(nplot**2,1))
            pred_u, pred_v=model.predict([data_t_plot,data_x_plot,data_y_plot])
            if eom in analytic:
                true_u,true_v=sol(eom,data_t_plot,data_x_plot,data_y_plot)
                mse+=tf.reduce_sum((true_u-pred_u)**2)+tf.reduce_sum((true_v-pred_v)**2)
            points_u[:,:,i]=np.reshape(pred_u,(nplot,nplot))
            points_v[:,:,i]=np.reshape(pred_v,(nplot,nplot))
        # RMSE over both output components and all frames.
        mse=np.sqrt(mse/(2*t_frames*nplot**2))
        return points_u, points_v, mse, x_p, y_p

    elif ni==2 and domain=='disk':
        # Polar grid of radius bd[3]; the two model inputs are the Cartesian
        # coordinates (R cos(theta), R sin(theta)) of each grid point.
        R = np.linspace(0., bd[3], n_plt)
        theta = np.linspace(0, 2*np.pi, n_plt)
        tplt = tf.reshape(tf.cast(np.outer(R, np.cos(theta)),dtype=tf.float32),(n_plt**2,1))
        xplt = tf.reshape(tf.cast(np.outer(R, np.sin(theta)),dtype=tf.float32),(n_plt**2,1))
        pred=model.predict([tplt,xplt])
        mse=0.
        if eom in analytic:
            true=sol(eom,tplt,xplt)
            mse+=np.sqrt(tf.reduce_mean((true-pred)**2))
        return pred, mse, tplt, xplt

    elif ni==2 and domain=='rect':
        T= np.linspace(bd[0], bd[1], n_plt)
        X= np.linspace(bd[2], bd[3], n_plt)
        tplt,xplt=np.meshgrid(T,X)
        tplt=tf.reshape(tf.cast(tplt,dtype=tf.float32),(n_plt**2,1))
        xplt=tf.reshape(tf.cast(xplt,dtype=tf.float32),(n_plt**2,1))
        pred=model.predict([tplt,xplt])
        mse=0.
        if eom in analytic:
            true=sol(eom,tplt,xplt)
            mse+=np.sqrt(tf.reduce_mean((true-pred)**2))
        return pred, mse, tplt, xplt

    elif ni==1:
        # Deliberately NOT named ``n_plt``. Assigning ``n_plt`` anywhere in
        # this function makes it local to the whole function, so the ni == 2
        # branches above would raise UnboundLocalError instead of reading
        # the global grid size.
        n_pts=20000
        tplt0= np.linspace(bd[0], bd[1], n_pts)
        tplt=tf.reshape(tf.cast(tplt0,dtype=tf.float32),(n_pts,1))
        pred=model.predict(tplt)
        rmse=0.
        if eom in analytic:
            true=sol(eom,tplt)
            rmse+=np.sqrt(tf.reduce_mean((true-pred)**2))
        elif eom=='mat':
            # Reference: u'' = (-1 + 0.4 cos 2t) u with u(0)=1, u'(0)=0.
            def model_mat(u,t):
                return [u[1], (- 1. + 0.4 * np.cos(2. * t)) * u[0]]
            true=odeint(model_mat, [1.,0.], tplt0)
            rmse+=np.sqrt(tf.reduce_mean((true[:,0]-pred[:,0])**2))
        elif eom=='delay':
            # Reference: u'(t) = 0.5 u(t) - u(t-1) with history u(t) = 1 - t
            # for t <= 0.
            def model_delay(u, t):
                return 0.5*u(t) - u(t-1.0)
            def values_before_zero(t):
                return 1-t
            true=np.zeros((n_pts,1))
            true[:,0]=ddeint(model_delay, values_before_zero, tplt0)
            rmse+=np.sqrt(tf.reduce_mean((true[:,0]-pred[:,0])**2))

        return pred, rmse, tplt

    raise ValueError(
        f"unsupported combination ni={ni!r}, no={no!r}, domain={domain!r}")
        
        
    
    
#####################################################################################
def to_loss_3to1(eom='wave'):
    """Build the physics-informed loss for a model u(t, x, y) with one output.

    ``eom`` is the key used to look up the residual expressions in the
    module-level dicts ``eomdict``, ``ICdict``, ``boardx`` and ``boardy``,
    whose entries are Python source strings run through ``eval``. The same
    key selects the term weights in ``wdict``.
    NOTE(review): these dicts and ``model`` are globals defined elsewhere in
    the notebook.

    Returns:
        loss_cube: a ``tf.function`` mapping ``(yt, yp)`` to
        ``(loss_value, smse_bulk, smse_IC, smse_board)``.
    """
    
    @tf.function
    def loss_cube(yt,yp):
        """Weighted sum of RMS bulk, initial/final-time and boundary residuals.

        ``yp`` is ignored. ``yt`` is split into 10 columns: ``t, x, y``
        followed by the masks ``I_t0, I_tL, I_x0, I_xL, I_y0, I_yL, I_bulk``.
        The masks multiply the squared residuals and are summed as point
        counts (presumably 0/1 per point; TODO confirm against the data
        generator).
        """
        
        t,x,y,I_t0,I_tL,I_x0,I_xL,I_y0,I_yL,I_bulk=tf.split(yt,10,axis=1)
        # Point counts per region; +1e-10 avoids dividing by zero when a
        # batch contains no points of that kind.
        n_bulk,n_IC,n_tL,n_board=tf.reduce_sum(I_bulk)+1e-10,tf.reduce_sum(I_t0)+1e-10,tf.reduce_sum(I_tL)+1e-10,tf.reduce_sum(I_x0+I_xL+I_y0+I_yL)+1e-10
        # Nested tapes: tape_1 yields first derivatives, and the persistent
        # tape_2 differentiates each of them once more.
        with tf.GradientTape(persistent=True) as tape_2:
            tape_2.watch([t,x,y])
            with tf.GradientTape() as tape_1:
                tape_1.watch([t,x,y]) 
                u=model([t,x,y],training=True)
            dudt, dudx, dudy=tape_1.gradient(u,[t,x,y])
        dduddy=tape_2.gradient(dudy,y)
        dduddx=tape_2.gradient(dudx,x)            
        dduddt=tape_2.gradient(dudt,t) 
        del tape_2

        # The eval'd strings presumably reference the local names above
        # (u, dudt, dudx, dudy, dduddt, dduddx, dduddy, t, x, y); renaming
        # any of them would break the lookup. Each eval yields a list of
        # residual components whose squares are masked and summed.
        # The +1e-20 keeps sqrt's argument strictly positive (its gradient is
        # infinite at 0).
        smse_bulk=tf.sqrt(tf.reduce_sum(sum([element**2 * I_bulk for element in eval(eomdict[eom])]))/n_bulk +1e-20)
        mset0=tf.reduce_sum(sum([element**2 * I_t0 for element in eval(ICdict[eom][0])]))
        msetL=tf.reduce_sum(sum([element**2 * I_tL for element in eval(ICdict[eom][1])]))
        msex0=tf.reduce_sum(sum([element**2 * I_x0 for element in eval(boardx[eom][0])]))
        msexL=tf.reduce_sum(sum([element**2 * I_xL for element in eval(boardx[eom][1])]))
        msey0=tf.reduce_sum(sum([element**2 * I_y0 for element in eval(boardy[eom][0])]))
        mseyL=tf.reduce_sum(sum([element**2 * I_yL for element in eval(boardy[eom][1])]))
        smse_IC=tf.sqrt(tf.reduce_sum(mset0+msetL)/(n_IC+n_tL) +1e-20)
        smse_board=tf.sqrt(tf.reduce_sum(msex0+msexL+msey0+mseyL)/(n_board) +1e-20)
        
        # Per-equation weights for the bulk, initial-condition and boundary terms.
        wbu,wIC,wbo=wdict[eom]
        
        loss_value=wbu*smse_bulk+wIC*smse_IC+wbo*smse_board
        
        return loss_value, smse_bulk, smse_IC, smse_board

  
    return loss_cube


#########################################################################################
def to_loss_3to2(eom='GT'):
    """Build the physics-informed loss for a model (u, v)(t, x, y).

    Supported ``eom`` values are ``'GT'``, ``'LO'`` and ``'NS'``; each builds
    its own ``loss_cube``. Residual expressions come from the module-level
    dicts ``eomdict``, ``ICdict``, ``boardx`` and ``boardy``, whose entries
    are Python source strings run through ``eval``. Term weights come from
    ``wdict[eom]``.
    NOTE(review): these dicts, ``model`` and (for ``'NS'``) ``bd``, ``n_bo``
    and ``random_sampling_board`` are globals defined elsewhere.

    Returns:
        loss_cube: a ``tf.function`` mapping ``(yt, yp)`` to
        ``(loss_value, smse_bulk, smse_IC, smse_board)``.

    Raises:
        ValueError: if ``eom`` is not one of the supported names.
    """
    # In every variant below, the eval'd strings presumably reference the
    # local derivative names (u, v, dudx, ddvddy, w, dwdt, ...); keep those
    # names unchanged.
    if eom=='GT':
        @tf.function
        def loss_cube(yt,yp):
            """GT loss; ``yt`` packs ``t, x, y`` plus 7 mask columns, ``yp`` is ignored."""
            t,x,y,I_t0,I_tL,I_x0,I_xL,I_y0,I_yL,I_bulk=tf.split(yt,10,axis=1)
            # +1e-10 avoids dividing by zero for an empty region in a batch.
            n_bulk,n_IC,n_tL,n_board=tf.reduce_sum(I_bulk)+1e-10,tf.reduce_sum(I_t0)+1e-10,tf.reduce_sum(I_tL)+1e-10,tf.reduce_sum(I_x0+I_xL+I_y0+I_yL)+1e-10
            with tf.GradientTape(persistent=True) as tape2:
                tape2.watch([t,x,y])
                with tf.GradientTape(persistent=True) as tape:
                    tape.watch([t,x,y])
                    u, v =model([t,x,y],training=True)
                dudt, dudx, dudy = tape.gradient(u,[t,x,y])
                dvdt, dvdx, dvdy = tape.gradient(v,[t,x,y])
            dduddx =tape2.gradient(dudx,x)
            dduddy =tape2.gradient(dudy,y)
            ddvddx =tape2.gradient(dvdx,x)
            ddvddy =tape2.gradient(dvdy,y)
            del tape
            del tape2

            # +1e-20 keeps sqrt's argument strictly positive.
            smse_bulk=tf.sqrt(tf.reduce_sum(sum([element**2 * I_bulk for element in eval(eomdict[eom])]))/n_bulk +1e-20)
            mset0=tf.reduce_sum(sum([element**2 * I_t0 for element in eval(ICdict[eom][0])]))
            msetL=tf.reduce_sum(sum([element**2 * I_tL for element in eval(ICdict[eom][1])]))
            msex0=tf.reduce_sum(sum([element**2 * I_x0 for element in eval(boardx[eom][0])]))
            msexL=tf.reduce_sum(sum([element**2 * I_xL for element in eval(boardx[eom][1])]))
            msey0=tf.reduce_sum(sum([element**2 * I_y0 for element in eval(boardy[eom][0])]))
            mseyL=tf.reduce_sum(sum([element**2 * I_yL for element in eval(boardy[eom][1])]))
            smse_IC=tf.sqrt(tf.reduce_sum(mset0+msetL)/(n_IC+n_tL) +1e-20)
            smse_board=tf.sqrt(tf.reduce_sum(msex0+msexL+msey0+mseyL)/(n_board) +1e-20)
            wbu,wIC,wbo=wdict[eom]
            loss_value=wbu*smse_bulk+wIC*smse_IC+wbo*smse_board

            return loss_value, smse_bulk, smse_IC, smse_board

    elif eom=='LO':
        @tf.function
        def loss_cube(yt,yp):
            """LO loss built on ``w = dv/dx - du/dy`` and its derivatives; ``yp`` is ignored."""
            t,x,y,I_t0,I_tL,I_x0,I_xL,I_y0,I_yL,I_bulk=tf.split(yt,10,axis=1)
            n_bulk,n_IC,n_tL,n_board=tf.reduce_sum(I_bulk)+1e-10,tf.reduce_sum(I_t0)+1e-10,tf.reduce_sum(I_tL)+1e-10,tf.reduce_sum(I_x0+I_xL+I_y0+I_yL)+1e-10
            # Three nested tapes: first derivatives of u and v, then w's
            # first derivatives, then w's second spatial derivatives.
            with tf.GradientTape(persistent=True) as tape3:
                tape3.watch([t,x,y])
                with tf.GradientTape() as tape2:
                    tape2.watch([t,x,y])
                    with tf.GradientTape(persistent=True) as tape:
                        tape.watch([t,x,y])
                        u, v =model([t,x,y])
                    dudx, dudy = tape.gradient(u,[x,y])
                    dvdx, dvdy = tape.gradient(v,[x,y])
                    w = dvdx - dudy
                dwdt, dwdx, dwdy=tape2.gradient(w,[t,x,y])
            ddwddx=tape3.gradient(dwdx,x)
            ddwddy=tape3.gradient(dwdy,y)
            del tape
            del tape3

            smse_bulk=tf.sqrt(tf.reduce_sum(sum([element**2 * I_bulk for element in eval(eomdict[eom])]))/n_bulk +1e-20)
            mset0=tf.reduce_sum(sum([element**2 * I_t0 for element in eval(ICdict[eom][0])]))
            msetL=tf.reduce_sum(sum([element**2 * I_tL for element in eval(ICdict[eom][1])]))
            msex0=tf.reduce_sum(sum([element**2 * I_x0 for element in eval(boardx[eom][0])]))
            msexL=tf.reduce_sum(sum([element**2 * I_xL for element in eval(boardx[eom][1])]))
            msey0=tf.reduce_sum(sum([element**2 * I_y0 for element in eval(boardy[eom][0])]))
            mseyL=tf.reduce_sum(sum([element**2 * I_yL for element in eval(boardy[eom][1])]))
            smse_IC=tf.sqrt(tf.reduce_sum(mset0+msetL)/(n_IC+n_tL) +1e-20)
            smse_board=tf.sqrt(tf.reduce_sum(msex0+msexL+msey0+mseyL)/(n_board) +1e-20)
            wbu,wIC,wbo=wdict[eom]
            loss_value=wbu*smse_bulk+wIC*smse_IC+wbo*smse_board

            return loss_value, smse_bulk, smse_IC, smse_board

    elif eom=='NS':
        @tf.function
        def loss_cube(yt,yp):
            """NS loss; boundary points are resampled on every call; ``yp`` is ignored.

            ``yt`` packs 6 columns: ``t, x, y, I_t0, I_tL, I_bulk``.
            """
            t,x,y,I_t0,I_tL,I_bulk=tf.split(yt,6,axis=1)
            n_bulk,n_IC,n_tL=tf.reduce_sum(I_bulk)+1e-10,tf.reduce_sum(I_t0)+1e-10,tf.reduce_sum(I_tL)+1e-10
            with tf.GradientTape(persistent=True) as tape3:
                tape3.watch([t,x,y])
                with tf.GradientTape() as tape2:
                    tape2.watch([t,x,y])
                    with tf.GradientTape(persistent=True) as tape:
                        tape.watch([t,x,y])
                        u, v =model([t,x,y])
                    dudx, dudy = tape.gradient(u,[x,y])
                    dvdx, dvdy = tape.gradient(v,[x,y])
                    w = dvdx - dudy
                dwdt, dwdx, dwdy=tape2.gradient(w,[t,x,y])
            ddwddx=tape3.gradient(dwdx,x)
            ddwddy=tape3.gradient(dwdy,y)
            del tape
            del tape3

            # TODO: add symmetric sampling
            smse_bulk=tf.sqrt(tf.reduce_sum(sum([element**2 * I_bulk for element in eval(eomdict[eom])]))/n_bulk +1e-20)
            mset0=tf.reduce_sum(sum([element**2 * I_t0 for element in eval(ICdict[eom][0])]))
            msetL=tf.reduce_sum(sum([element**2 * I_tL for element in eval(ICdict[eom][1])]))
            smse_IC=tf.sqrt(tf.reduce_sum(mset0+msetL)/(n_IC+n_tL) +1e-20)
            # Fresh boundary sample. u and v are rebound to the model output
            # on these points, so the boundary strings below see the boundary
            # values. The I_* masks returned here are presumably referenced
            # inside those strings, since they are not applied explicitly.
            tb,xb,yb,I_x0,I_xL,I_y0,I_yL=random_sampling_board(bd,n_bo)
            u, v =model([tb,xb,yb])
            msex0=tf.reduce_sum(sum([element**2 for element in eval(boardx[eom][0])]) )
            msexL=tf.reduce_sum(sum([element**2 for element in eval(boardx[eom][1])]) )
            msey0=tf.reduce_sum(sum([element**2 for element in eval(boardy[eom][0])]) )
            mseyL=tf.reduce_sum(sum([element**2 for element in eval(boardy[eom][1])]) )
            smse_board=tf.sqrt(tf.reduce_sum(msex0+msexL+msey0+mseyL)/(2.*n_bo) +1e-20)
            wbu,wIC,wbo=wdict[eom]
            loss_value=wbu*smse_bulk+wIC*smse_IC+wbo*smse_board

            return loss_value, smse_bulk, smse_IC, smse_board

    else:
        # Without this branch, ``return loss_cube`` would fail with an
        # opaque UnboundLocalError.
        raise ValueError(
            f"unsupported eom={eom!r} for a 3-input/2-output loss; "
            "expected 'GT', 'LO' or 'NS'")

    return loss_cube


#############################################################################################
def to_loss_2to1(eom='wave',domain='rect'):
    """Build the physics-informed loss for a model u(t, x) with one output.

    Two variants are defined. ``domain == 'rect'`` returns ``loss_square``,
    and any other value returns ``loss_disk``. Residual expressions come from
    the module-level dicts ``eomdict``, ``ICdict`` and ``boardx``, whose
    entries are Python source strings run through ``eval``. Term weights come
    from ``wdict[eom]``.
    NOTE(review): these dicts and ``model`` are globals defined elsewhere in
    the notebook.

    Returns:
        A ``tf.function`` mapping ``(yt, yp)`` to
        ``(loss_value, smse_bulk, smse_IC, smse_board)``.
    """
    
    @tf.function
    def loss_square(yt,yp):
        """Loss on a rectangle; ``yp`` is ignored.

        ``yt`` packs 7 columns: ``t, x`` followed by the masks
        ``I_t0, I_tL, I_x0, I_xL, I_bulk``, which select where each residual
        applies and are summed as point counts.
        """
        
        t,x,I_t0,I_tL,I_x0,I_xL,I_bulk=tf.split(yt,7,axis=1)
        # +1e-10 avoids dividing by zero for an empty region in a batch.
        n_bulk,n_IC,n_tL,n_board=tf.reduce_sum(I_bulk)+1e-10,tf.reduce_sum(I_t0)+1e-10,tf.reduce_sum(I_tL)+1e-10,tf.reduce_sum(I_x0+I_xL)+1e-10
        # Nested tapes: first derivatives from tape_1, second from tape_2.
        with tf.GradientTape(persistent=True) as tape_2:
            tape_2.watch([t,x])
            with tf.GradientTape() as tape_1:
                tape_1.watch([t,x])
                u=model([t,x])
            dudt, dudx=tape_1.gradient(u,[t,x])
        dduddx=tape_2.gradient(dudx,x)
        dduddt=tape_2.gradient(dudt,t) 
        del tape_2
            
        # The eval'd strings presumably reference the local names above
        # (u, dudt, dudx, dduddt, dduddx, t, x); do not rename them.
        # +1e-20 keeps sqrt's argument strictly positive.
        smse_bulk=tf.sqrt(tf.reduce_sum(sum([element**2 * I_bulk for element in eval(eomdict[eom])]))/n_bulk +1e-20)
        mset0=tf.reduce_sum(sum([element**2 * I_t0 for element in eval(ICdict[eom][0])]))
        msetL=tf.reduce_sum(sum([element**2 * I_tL for element in eval(ICdict[eom][1])]))
        msex0=tf.reduce_sum(sum([element**2 * I_x0 for element in eval(boardx[eom][0])]))
        msexL=tf.reduce_sum(sum([element**2 * I_xL for element in eval(boardx[eom][1])]))
        smse_IC=tf.sqrt(tf.reduce_sum(mset0+msetL)/(n_IC+n_tL) +1e-20)
        smse_board=tf.sqrt(tf.reduce_sum(msex0+msexL)/(n_board)+1e-20)
        wbu,wIC,wbo=wdict[eom]
        
        loss_value=wbu*smse_bulk+wIC*smse_IC+wbo*smse_board
       
        return loss_value, smse_bulk, smse_IC, smse_board

    
    
    @tf.function
    def loss_disk(yt,yp):
        """Loss on a disk (no initial-condition term); ``yp`` is ignored.

        ``yt`` packs 4 columns: the two coordinates (named ``t, x``) followed
        by the masks ``I_board`` and ``I_bulk``. Both ``boardx`` residuals
        are applied on the ``I_board`` points.
        """
        
        t,x,I_board,I_bulk=tf.split(yt,4,axis=1)
        n_bulk,n_board=tf.reduce_sum(I_bulk)+1e-10,tf.reduce_sum(I_board)+1e-10
        with tf.GradientTape(persistent=True) as tape_2:
            tape_2.watch([t,x])
            with tf.GradientTape() as tape_1:
                tape_1.watch([t,x])
                u=model([t,x])                        
            dudt, dudx=tape_1.gradient(u,[t,x])
        dduddx=tape_2.gradient(dudx,x)
        dduddt=tape_2.gradient(dudt,t)
        del tape_2
       
        smse_bulk=tf.sqrt(tf.reduce_sum(sum([element**2 * I_bulk for element in eval(eomdict[eom])]))/n_bulk +1e-20)
        # Here +1e-20 is added element-wise before the reduce_sum, so the
        # boundary sum stays strictly positive and the sqrt below needs no
        # separate epsilon.
        msex0=tf.reduce_sum(sum([element**2 * I_board for element in eval(boardx[eom][0])]) +1e-20)
        msexL=tf.reduce_sum(sum([element**2 * I_board for element in eval(boardx[eom][1])]) +1e-20)
        # No initial-condition term on the disk: a tiny constant placeholder.
        smse_IC=10.**(-20.)
        smse_board=tf.sqrt(tf.reduce_sum(msex0+msexL)/(n_board))
        wbu,wIC,wbo=wdict[eom]
        
        loss_value=wbu*smse_bulk+wIC*smse_IC+wbo*smse_board
        
        return loss_value,  smse_bulk, smse_IC, smse_board

    
    
    # Any domain other than 'rect' selects the disk loss.
    if domain=='rect':
        return loss_square
    else:
        return loss_disk
    
    
#############################################################################################
def to_loss_1to1(eom='wave'):
    """Build the physics-informed loss for a model u(t) with one input/output.

    Residual expressions come from the module-level dicts ``eomdict`` and
    ``ICdict``, whose entries are Python source strings run through ``eval``.
    NOTE(review): these dicts, ``model``, ``bd`` and ``t0tLsample`` are
    globals defined elsewhere in the notebook.

    Returns:
        loss: a ``tf.function`` mapping ``(yt, yp)`` to
        ``(loss_value, smse_bulk, smse_IC, 0.)``.
    """

    @tf.function
    def loss(yt,yp):
        """Unweighted sum of the RMS bulk residual and the end-point residual.

        ``yp`` is ignored. ``yt`` packs 2 columns: ``t`` and the bulk mask
        ``I_bulk``. Two extra points from ``t0tLsample(bd, eom)`` are
        appended. The first is flagged by ``I_t0`` and the second by ``I_tL``.
        """

        t,I_bulk=tf.split(yt,2,axis=1)
        # +1e-10 avoids dividing by zero for a batch with no bulk points,
        # matching to_loss_1to1_minibatch and the other loss builders.
        n_bulk=tf.reduce_sum(I_bulk)+1e-10
        t=tf.concat((t,t0tLsample(bd,eom)),0)
        I_t0=tf.concat((tf.zeros_like(I_bulk),[[1.],[0.]]),0)
        I_tL=tf.concat((tf.zeros_like(I_bulk),[[0.],[1.]]),0)
        I_bulk=tf.concat((I_bulk,[[0.],[0.]]),0)
        # ``dt`` is the delayed time t - 1 (not a time step). ``du`` is
        # u(t - 1), using the history u(s) = 1 - s for s <= 0.
        dt=t-1.
        with tf.GradientTape() as tape_2:
            tape_2.watch(t)
            with tf.GradientTape() as tape_1:
                tape_1.watch(t)
                u=model(t)
                du = tf.where(dt <= 0., 1.-dt, model(dt))
            dudt=tape_1.gradient(u,t)
        dduddt=tape_2.gradient(dudt,t)

        # The eval'd strings presumably reference the local names above
        # (t, u, du, dudt, dduddt); do not rename them. The +1e-20 keeps
        # sqrt's argument strictly positive (its gradient is infinite at 0).
        smse_bulk=tf.sqrt(tf.reduce_sum(sum([element**2 * I_bulk for element in eval(eomdict[eom])]))/n_bulk+1e-20)
        mset0=tf.reduce_sum(sum([element**2 * I_t0 for element in eval(ICdict[eom][0])]))
        msetL=tf.reduce_sum(sum([element**2 * I_tL for element in eval(ICdict[eom][1])]))
        # Only the two appended end points contribute, so this is left
        # un-normalized.
        smse_IC=tf.sqrt(tf.reduce_sum(mset0+msetL)+1e-20)

        loss_value=smse_bulk+smse_IC

        return loss_value, smse_bulk, smse_IC, 0.

    return loss


def to_loss_1to1_minibatch(eom='wave'):
    """Build a 1D physics-informed loss that resamples its points on each call.

    Residual expressions come from the module-level dicts ``eomdict`` and
    ``ICdict``, whose entries are Python source strings run through ``eval``.

    Returns:
        loss: a ``tf.function`` mapping ``(yt, yp)`` to
        ``(loss_value, smse_bulk, smse_IC, 0.)``.
    """
    
    @tf.function
    def loss(yt,yp):
        """Unweighted sum of RMS bulk and end-point residuals on fresh points.

        NOTE(review): both ``yt`` and ``yp`` are ignored. The masks
        ``I_bulk``, ``I_t0``, ``I_tL`` are never assigned in this function,
        so they (and ``bd``, ``model``) are read from the enclosing/module
        scope. Confirm that globals with those names exist, or whether they
        were meant to be split out of ``yt``.
        """
        
        # Generate a new dataset on every call. Bulk points are uniform in
        # [bd[0], bd[1]); I_t0 / I_tL points are pinned to bd[0] / bd[1].
        t = I_bulk*tf.random.uniform(I_bulk.shape, minval=bd[0], maxval=bd[1])+I_t0*bd[0]+I_tL*bd[1]
        # +1e-10 avoids dividing by zero for an empty region.
        n_bulk,n_IC,n_tL=tf.reduce_sum(I_bulk)+1e-10,tf.reduce_sum(I_t0)+1e-10,tf.reduce_sum(I_tL)+1e-10
        # ``dt`` is the delayed time t - 1 (not a time step). ``du`` is
        # u(t - 1), using the history u(s) = 1 - s for s <= 0.
        dt=t-1.
        with tf.GradientTape() as tape_2:
            tape_2.watch(t)
            with tf.GradientTape() as tape_1:
                tape_1.watch(t)
                u=model(t)    
                du = tf.where(dt <= 0., 1.-dt, model(dt))
            dudt=tape_1.gradient(u,t)
        dduddt=tape_2.gradient(dudt,t)
            
        # The eval'd strings presumably reference the local names above
        # (t, u, du, dudt, dduddt); do not rename them. The +1e-20 keeps
        # sqrt's argument strictly positive.
        smse_bulk=tf.sqrt(tf.reduce_sum(sum([element**2 * I_bulk for element in eval(eomdict[eom])]))/n_bulk+1e-20)
        mset0=tf.reduce_sum(sum([element**2 * I_t0 for element in eval(ICdict[eom][0])]))
        msetL=tf.reduce_sum(sum([element**2 * I_tL for element in eval(ICdict[eom][1])]))
        smse_IC=tf.sqrt(tf.reduce_sum(mset0+msetL)/(n_IC+n_tL)+1e-20)
        
        loss_value=smse_bulk+smse_IC
       
        return loss_value, smse_bulk, smse_IC, 0.
    
    return loss



NameError: name 'keras' is not defined