In [33]:
%load_ext autoreload
%autoreload 2

import tensorflow as tf
from networks import Encoder, Decoder, MetaModel
import numpy as np
import ast
from sindy import SINDy


encoder = Encoder()
decoder = Decoder()
sindy = SINDy()

models = [encoder, decoder, sindy]
for model in [encoder,decoder, sindy]:
    model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-3))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [34]:
x = tf.random.uniform((2,40,128))
x_dot = tf.random.uniform((2,40,128))

In [39]:
with tf.GradientTape(persistent=True) as tape:
    tape.watch(encoder.trainable_variables)
    tape.watch(decoder.trainable_variables)
    tape.watch(x)
    tape.watch(tf.convert_to_tensor(sindy.coeffs))

    ### LOSS 0###
    z = encoder(x)
    tape.watch(z)
    x_quasi = decoder(z)
    loss0 = tf.keras.losses.MSE(x,x_quasi)

    ### LOSS 1 ###
    zdot_SINDy = sindy(z)
    with tape.stop_recording():
        dpsi_dz = tape.jacobian(x_quasi,z)
    ones = tf.ones(dpsi_dz.shape)
    dpsi_dz = tf.einsum('abcdef,yuidep->abcf',dpsi_dz, ones) ### most of those derivatives are zero, since they correspond to different batches!
    xdot_pred = tf.einsum('ntaj,ntj->nta',dpsi_dz,zdot_SINDy)
    loss1 = tf.keras.losses.MSE(x_dot,xdot_pred)

    ### LOSS 2 ###
    with tape.stop_recording():
        dphi_dx = tape.jacobian(z,x)
    ones = tf.ones(dphi_dx.shape)
    dphi_dx = tf.einsum('abcdef,yuidep->abcf',dphi_dx, ones) ### most of those derivatives are zero, since they correspond to different batches!
    zdot_pred = tf.einsum('ntaj,ntj->nta',dphi_dx,x_dot)
    loss2 = tf.cast(tf.keras.losses.MSE(zdot_pred, zdot_SINDy),tf.float32)

    ### LOSS 3 ###
    loss3 = tf.expand_dims(tf.einsum('ij->',tf.math.abs(sindy.coeffs)),axis=0)
    total_loss = loss0 + loss1 + loss2 + loss3


grads_enc = tape.gradient(total_loss, encoder.trainable_variables)
grads_dec = tape.gradient(total_loss, decoder.trainable_variables)
grads_SINDy_coeffs = [tape.gradient(total_loss, sindy.coeffs)]

gradients = [grads_enc, grads_dec, grads_SINDy_coeffs]
models = [encoder, decoder, sindy]

for model, gradient in zip(models, gradients):
    model.optimizer.apply_gradients(zip(gradient, model.trainable_variables))






In [40]:
sindy.optimizer.apply_gradients(zip(grads_SINDy_coeffs,sindy.trainable_variables))

<tf.Variable 'UnreadVariable' shape=() dtype=int64, numpy=2>

In [42]:
sindy.trainable_variables[0]

<tf.Variable 'Variable:0' shape=(3, 27) dtype=float32, numpy=
array([[0.998    , 0.998    , 0.998    , 1.002    , 1.002    , 0.998    ,
        0.998    , 0.998    , 0.998    , 0.998    , 0.998    , 0.998    ,
        1.002    , 0.998    , 0.998    , 0.998    , 0.998    , 0.998    ,
        0.998    , 0.998    , 0.998    , 0.998    , 0.998    , 0.998    ,
        0.998    , 0.998    , 0.998    ],
       [0.9980001, 0.9980001, 0.998    , 1.002    , 1.002    , 0.998    ,
        0.9980001, 0.998    , 0.998    , 0.9980001, 0.998    , 0.998    ,
        1.0019999, 0.998    , 0.998    , 0.998    , 0.998    , 0.998    ,
        0.998    , 0.998    , 0.998    , 0.998    , 0.998    , 0.998    ,
        0.998    , 0.998    , 0.998    ],
       [0.998    , 0.998    , 0.998    , 1.002    , 1.002    , 0.998    ,
        0.998    , 0.998    , 0.998    , 0.998    , 0.998    , 0.998    ,
        1.002    , 0.998    , 0.998    , 0.998    , 0.998    , 0.998    ,
        0.998    , 0.998    , 0.998    ,

<tf.Variable 'Variable:0' shape=(3, 27) dtype=float64, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])>

In [25]:
sindy1.trainable_variables

[<tf.Tensor: shape=(3, 27), dtype=float32, numpy=
 array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>]

m
m
m


AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute '_in_graph_mode'

[<tf.Variable 'encoder_1/dense_6/kernel:0' shape=(128, 64) dtype=float32, numpy=
 array([[-0.09033427, -0.04396643,  0.03536167, ..., -0.0455751 ,
         -0.09373885,  0.06527469],
        [-0.00925119, -0.07525584, -0.0725927 , ...,  0.06097997,
          0.03062052,  0.00068207],
        [-0.00975733,  0.09546264,  0.07032993, ...,  0.01553587,
         -0.07542764,  0.10051169],
        ...,
        [-0.05952345,  0.00431443, -0.01477603, ..., -0.08673827,
         -0.0715066 , -0.08745359],
        [ 0.08271442,  0.00602508, -0.07120262, ..., -0.0109188 ,
         -0.09876256, -0.07037709],
        [-0.04387222,  0.03468087,  0.03168849, ..., -0.05419493,
         -0.07637918, -0.07978804]], dtype=float32)>,
 <tf.Variable 'encoder_1/dense_6/bias:0' shape=(64,) dtype=float32, numpy=
 array([ 0.00300469,  0.00300475,  0.00299427,  0.00298889, -0.00290534,
         0.0029809 , -0.00030736,  0.00300074,  0.00298342, -0.00293195,
        -0.00295986,  0.00300039, -0.00294927,  0.00300

In [31]:
sindy.trainable_variables = sindy.coeffs

AttributeError: Can't set the attribute "trainable_variables", likely because it conflicts with an existing read-only @property of the object. Please choose a different name.

In [34]:
gradients[-1]

<tf.Tensor: shape=(3, 27), dtype=float32, numpy=
array([[ 4.62363828e+04, -5.75679932e+02,  8.94981003e+01,
         4.31119141e+03,  1.98481674e+01,  8.08720398e+01,
         4.68022736e+02,  7.44796143e+01,  8.00800934e+01,
         8.24283398e+03, -3.59157486e+01,  8.16785660e+01,
         8.28271912e+02,  6.93662491e+01,  8.01541061e+01,
         1.48618622e+02,  7.90241241e+01,  8.00141525e+01,
         1.52377832e+03,  5.95052414e+01,  8.02966843e+01,
         2.12344681e+02,  7.81199341e+01,  8.00272369e+01,
         9.21360474e+01,  7.98274689e+01,  8.00025024e+01],
       [ 1.07571914e+05, -1.44498486e+03,  1.02055939e+02,
         9.93183496e+03, -5.98693695e+01,  8.20244827e+01,
         9.83276001e+02,  6.71666107e+01,  8.01858902e+01,
         1.90920195e+04, -1.89624237e+02,  8.38982544e+01,
         1.82243677e+03,  5.52712059e+01,  8.03578033e+01,
         2.39752747e+02,  7.77311325e+01,  8.00328522e+01,
         3.44303394e+03,  3.23237267e+01,  8.06890717e+01,
      