In [2]:
import tensorflow as tf

# tf.Variable

The input to a graph can be represented by a tf.Variable.

Computations on tf.Variable return tensors.

Computations on tensors in general are **trackable by AutoDiff**, including those derived from tf.Variable.


In [6]:
def f(val):
    """Return the square of ``val`` after wrapping it in a tf.Variable named 'x'.

    Arithmetic on a tf.Variable produces an ordinary tf.Tensor. An int
    argument yields an int32 variable, as the printed dtype below shows.
    """
    squared = tf.Variable(val, name='x') ** 2
    return squared


res = f(10)
print(res)

tf.Tensor(100, shape=(), dtype=int32)


# Gradient Tape

`tf.GradientTape` is used as a context manager (a `with` block). It records the operations on TensorFlow tensors that run inside the block, so that you can take gradients afterwards.

WARNING: you will get `None` as the gradient if your inputs are integers, because the tape does not track integer tensors. Inputs must be **floating-point** (e.g. `tf.float32`).


In [24]:
def f(x):
    """Evaluate the quadratic x**2 + 2*x + 1 using ordinary operators.

    Nothing here refers to AutoDiff; the gradient tape records these
    operations only because they run inside its context.
    """
    return x**2 + 2*x + 1


# Float dtype on purpose: integer sources would give a None gradient.
x = tf.Variable(3.0, dtype=tf.float32)

# Every op applied to ``x`` (and to tensors derived from it) inside this
# block is recorded on ``tape``.
with tf.GradientTape() as tape:
    y = f(x)

# Differentiate the recorded computation: dy/dx evaluated at x = 3.
dy_dx = tape.gradient(target=y, sources=x)
print(dy_dx)


def df_dx(x):
    """Analytic derivative of x**2 + 2*x + 1, used to cross-check the tape.

    Differentiating by hand gives 2*x + 2.
    """
    return 2 * x + 2


# The tape's float32 result (7.9999995) differs from the exact 8.0 only by rounding.
print(df_dx(x))


tf.Tensor(7.9999995, shape=(), dtype=float32)
tf.Tensor(8.0, shape=(), dtype=float32)


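# Matrices

When the target is not a scalar, `tape.gradient` returns the gradient of the **sum** of its elements. Each result therefore has the same shape as the corresponding source, rather than being a full Jacobian.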


In [33]:
# 3x3 matrix A and 3x1 column vector x, both float32.
A = tf.Variable([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]], dtype=tf.float32)
x = tf.Variable([[100],
                 [200],
                 [300]], dtype=tf.float32)

with tf.GradientTape() as tape:
    b = A @ x  # shape (3, 1)

# b is not a scalar, so the tape differentiates sum(b):
#   d sum(b) / dx = A^T @ ones  -> the column sums of A
#   d sum(b) / dA = ones @ x^T  -> every row equals x^T
db_dx, db_dA = tape.gradient(b, [x, A])

print(db_dx, db_dA, sep='\n')

tf.Tensor(
[[12.]
 [15.]
 [18.]], shape=(3, 1), dtype=float32)
tf.Tensor(
[[100. 200. 300.]
 [100. 200. 300.]
 [100. 200. 300.]], shape=(3, 3), dtype=float32)


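# Gradient Descent

One step of gradient descent on $A$: minimise $J(A) = \lVert A x \rVert_2$ by moving $A$ against its gradient, $A \leftarrow A - \alpha \, \partial J / \partial A$.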


In [36]:
A = tf.Variable([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=tf.float32)
x = tf.Variable([[100], [200], [300]], dtype=tf.float32)
alpha = 0.001  # learning rate (step size)

# Objective J(A) = ||A @ x|| (tf.norm, Euclidean by default); the tape records
# b and J so that dJ/dA can be taken.
with tf.GradientTape() as tape:
    b = A @ x
    J = tf.norm(b)

dJ_dA = tape.gradient(J, A)

# One descent step, A <- A - alpha * dJ/dA, applied in place to the Variable.
# Use `alpha` rather than repeating the literal, so changing the learning rate
# actually takes effect.
A.assign_sub(alpha * dJ_dA)

<tf.Variable 'UnreadVariable' shape=(3, 3) dtype=float32, numpy=
array([[0.9770461, 1.9540921, 2.9311383],
       [3.9475338, 4.8950677, 5.842602 ],
       [6.9180217, 7.8360434, 8.7540655]], dtype=float32)>