This Python Notebook demonstrates a verification process for custom gradients, if a corresponding gradient can be accurately computed through automatic differentiation.

This next code cell defines the operation with the custom gradient. The sign of the gradient is flipped.

In [1]:
import tensorflow as tf
class FlipGrad(tf.keras.layers.Layer):
    def __init__(self,*args,**kwargs):
        @tf.custom_gradient
        def flip_grad2(x):
            def grad(dy):
                return -dy
            return tf.identity(x), grad
        super().__init__(*args,**kwargs)
        self.f = lambda x: flip_grad2(x)
    def call(self,inputs):
        return self.f(inputs)

In [2]:
import numpy as np
import post_process_grad as ppg
inputs = tf.keras.layers.Input(shape=(1,))
output1 = tf.keras.layers.Dense(units=1,use_bias=True)(inputs)
output2 = FlipGrad()(output1)
model1 = ppg.Model_record_grad(inputs,output1)
model1.compile(loss = tf.keras.losses.MSE,run_eagerly=True,optimizer=tf.keras.optimizers.SGD())
x = np.random.randn(1000,1)
y = -2*x + 4
xval = np.random.randn(100,1)
yval = -2*xval + 4
model1.fit(x=x,y=y,batch_size=10,epochs=8,shuffle=False,validation_data = (xval,yval))

gradients = []
for grad in model1.gradient_record:
    gradients.append([-elem for elem in grad])
for grad in model1.gradient_record:
    gradients.append(grad)

model2 = ppg.Model_passenger(gradients,inputs,output2)
model2.compile(loss = tf.keras.losses.MSE,run_eagerly=True,optimizer=tf.keras.optimizers.SGD())
model2.fit(x=x,y=y,batch_size=10,epochs=16,shuffle=False,validation_data = (xval,yval))

sumError = 0
for grad1,grad2 in zip(model1.gradient_record,model2.gradient_record[800:]):
    for gradval1,gradval2 in zip(grad1,grad2):
        sumError += (gradval1 + gradval2)**2
print(sumError)

Epoch 1/8
Epoch 2/8
Epoch 3/8
Epoch 4/8
Epoch 5/8
Epoch 6/8
Epoch 7/8
Epoch 8/8
Epoch 1/16
Epoch 2/16
Epoch 3/16
Epoch 4/16
Epoch 5/16
Epoch 6/16
Epoch 7/16
Epoch 8/16
Epoch 9/16
Epoch 10/16
Epoch 11/16
Epoch 12/16
Epoch 13/16
Epoch 14/16
Epoch 15/16
Epoch 16/16
tf.Tensor([[6.4228203e-07]], shape=(1, 1), dtype=float32)
