In [1]:
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras import Input

# Training inputs: `num_samples` scalars drawn uniformly from [-2, 2).
num_samples = 1024
x_train = tf.random.uniform((num_samples,), minval=-2.0, maxval=2.0)
# Targets are all zero. They are stored as a column, shape (num_samples, 1),
# so that a batch of targets lines up element-wise with a (batch, 1) model
# output. A flat (batch,) target would broadcast against a (batch, 1)
# prediction into a (batch, batch) matrix for any batch size above 1.
y_train = tf.zeros((num_samples, 1))
# Small fully connected network mapping one scalar input to one scalar output:
# two tanh layers (16 and 8 units), followed by a linear 4-unit layer and a
# linear single-unit output layer.
inputs = Input(shape=(1,))
hidden = inputs
for units in (16, 8):
    hidden = Dense(units, activation='tanh')(hidden)
hidden = Dense(4)(hidden)
y = Dense(1)(hidden)
model = Model(inputs=inputs, outputs=y)

# Build the input pipeline with the tf.data API: reshape the inputs into a
# single column, pair each row with its target, and emit batches of one sample.
x_train = tf.reshape(x_train, [-1, 1])
samples = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = samples.batch(1)

opt = Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.99)
for step, (x, y_true) in enumerate(dataset):
    # Two nested tapes: the inner one differentiates the model output with
    # respect to its input, the outer one differentiates the loss (which
    # contains that derivative) with respect to the model weights.
    with tf.GradientTape() as model_tape:
        with tf.GradientTape() as loss_tape:
            # `x` is a plain tensor, so it has to be watched explicitly.
            # Watching it is enough; there is no need to wrap the batch in a
            # new tf.Variable on every step.
            loss_tape.watch(x)
            y_pred = model(x)
        # Computed inside `model_tape` so the outer tape records this
        # derivative and can backpropagate through it.
        dy_dx = loss_tape.gradient(y_pred, x)
        # Give the targets the same shape as the predictions so the residual
        # is element-wise rather than a broadcast of (batch,) against
        # (batch, 1).
        target = tf.reshape(y_true, tf.shape(y_pred))
        # Mean squared residual of dy/dx + 3*y against the target.
        loss = tf.math.reduce_mean(tf.square(dy_dx + 3 * y_pred - target))
    grad = model_tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grad, model.trainable_variables))
    # Report the current loss every 20 steps.
    if step % 20 == 0:
        print(f"Step {step}: loss={loss.numpy()}")

Step 0: loss=0.002485536737367511
Step 20: loss=0.036862146109342575
Step 40: loss=0.001743412110954523
Step 60: loss=4.897563940176042e-06
Step 80: loss=0.00010362969624111429
Step 100: loss=0.0001648172183195129
Step 120: loss=2.714315542107215e-06
Step 140: loss=1.7781661881599575e-05
Step 160: loss=6.156313247629441e-06
Step 180: loss=1.7340279327981989e-06
Step 200: loss=7.833919880795293e-06
Step 220: loss=7.178419991760165e-07
Step 240: loss=3.0368861189344898e-05
Step 260: loss=8.075248025818382e-09
Step 280: loss=5.585348503700516e-07
Step 300: loss=2.592063765405328e-06
Step 320: loss=1.3920854371463065e-06
Step 340: loss=2.207995748904068e-06
Step 360: loss=1.2378376595734153e-05
Step 380: loss=0.00019707366300281137
Step 400: loss=0.00015773742052260786
Step 420: loss=0.0010254003573209047
Step 440: loss=0.0001986278803087771
Step 460: loss=0.0008267493103630841
Step 480: loss=0.0001223227591253817
Step 500: loss=5.052189862908563e-06
Step 520: loss=4.286090188543312e-05
St