In [145]:
import numpy as np
import tensorflow as tf

In [163]:
#create my dataset:
t_in = tf.linspace(0,3, 100)
#and the matrix to diagonalize
Q = tf.random.normal( (6,6) , dtype='float64') 
A = 0.5*(Q+tf.transpose(Q))

In [164]:
#create a model
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(10, activation="sigmoid", input_shape=(1,), dtype='float64'))
model.add(tf.keras.layers.Dense(6, activation="linear", dtype='float64'))


In [224]:
#define function f which enters the definition of the differential equation
def f( x, A ):
    '''args:
    x: tensor of shape (npoints, n), where npoints is the number of points in the grid.
    A: (symmetric) matrix of size n x n, which is to be diagonalized'''
    
    n = A.shape[0]
    xT = tf.transpose(x) #(n x npoints)
    xxT = tf.linalg.diag_part(x@xT)#shape (npoints,)
    xAxT= tf.linalg.diag_part(x@A@xT)#shape (npoints,)
    mat1 = tf.tensordot(A, xxT, axes=0)#shape (n,n,npoints), this is a pile of npoints matrices
    Id6 = tf.eye(n, dtype='float64')
    mat2 = tf.tensordot(Id6, 1-xAxT, axes=0)#another stack of npoints matrices
    mat_tot = mat1+mat2 #shape(n,n,npoints)
    out = tf.einsum("ijk,kj->ki", mat_tot, x)

    return out

#define ansatz for our solution in terms of model

def ode_solution( t, x0) :
    '''args:
    t: time variable in tensor format (shape must be (N,1))
    x0: column vector of shape (n,1)
    returns solution of differential equation at time t
    '''
    t = t[:, tf.newaxis]
    out = (1-t)*tf.transpose(x0) + t*model(t)
    return out



In [225]:
x0 = tf.ones( (6,1), dtype='float64')

In [226]:
f(ode_solution( t_in , x0),A)

<tf.Tensor: shape=(100, 6), dtype=float64, numpy=
array([[-8.90036386e+00,  2.16354972e+01,  2.60294864e-01,
         1.02402506e+01, -1.83083707e+00, -1.54048417e+01],
       [-7.95947330e+00,  1.92278550e+01, -6.75066628e-02,
         1.03128326e+01, -1.73386355e+00, -1.35695049e+01],
       [-7.06629202e+00,  1.70597616e+01, -3.31288425e-01,
         1.03453810e+01, -1.62752503e+00, -1.18740594e+01],
       [-6.21832647e+00,  1.51172697e+01, -5.36363013e-01,
         1.03397724e+01, -1.51433046e+00, -1.03103858e+01],
       [-5.41307990e+00,  1.33864385e+01, -6.88075860e-01,
         1.02978098e+01, -1.39692044e+00, -8.87042390e+00],
       [-4.64805464e+00,  1.18533328e+01, -7.91805459e-01,
         1.02212199e+01, -1.27806647e+00, -7.54617165e+00],
       [-3.92075451e+00,  1.05040223e+01, -8.52963470e-01,
         1.01116523e+01, -1.16067003e+00, -6.32968386e+00],
       [-3.22868707e+00,  9.32458153e+00, -8.76994709e-01,
         9.97067720e+00, -1.04776156e+00, -5.21307122e+00]

In [108]:
t_in.dtype

tf.float64

In [109]:
x0.dtype

tf.float32

In [None]:
t_in*

In [116]:
t.reshape?

Object `t.reshape` not found.


In [118]:
tf.reshape(t_in, (1000,1))

<tf.Tensor: shape=(1000, 1), dtype=float64, numpy=
array([[0.        ],
       [0.003003  ],
       [0.00600601],
       [0.00900901],
       [0.01201201],
       [0.01501502],
       [0.01801802],
       [0.02102102],
       [0.02402402],
       [0.02702703],
       [0.03003003],
       [0.03303303],
       [0.03603604],
       [0.03903904],
       [0.04204204],
       [0.04504505],
       [0.04804805],
       [0.05105105],
       [0.05405405],
       [0.05705706],
       [0.06006006],
       [0.06306306],
       [0.06606607],
       [0.06906907],
       [0.07207207],
       [0.07507508],
       [0.07807808],
       [0.08108108],
       [0.08408408],
       [0.08708709],
       [0.09009009],
       [0.09309309],
       [0.0960961 ],
       [0.0990991 ],
       [0.1021021 ],
       [0.10510511],
       [0.10810811],
       [0.11111111],
       [0.11411411],
       [0.11711712],
       [0.12012012],
       [0.12312312],
       [0.12612613],
       [0.12912913],
       [0.13213213],
    