In [64]:
import tensorflow as tf

### 参考torch.autograd自动求导

In [None]:
# Derivative of f(x) = a*x**2 + b*x + c.
# x is the point of evaluation, held in a tf.Variable; the coefficients are constants.
x = tf.Variable(0.0)
a, b, c = (tf.constant(coefficient) for coefficient in (1.0, -2.0, 1.0))

In [66]:
# Record the forward computation so it can be differentiated afterwards.
# By default GradientTape automatically watches any trainable variable that is
# accessed inside the context (here: x).
# Do not modify the value of x inside the tf.GradientTape() block.
with tf.GradientTape() as tape:
    y = a * tf.pow(x, 2) + b * x + c

# dy/dx = 2*a*x + b, evaluated at the current value of x.
dy_dx = tape.gradient(target=y, sources=x)
print(dy_dx)

tf.Tensor(-2.0, shape=(), dtype=float32)


In [67]:
# Constants can be differentiated against as well, but unlike the variable x
# they have to be registered with the tape explicitly via watch(), which
# accepts a single Tensor or a list of Tensors.
with tf.GradientTape() as tape:
    tape.watch([a, b, c])
    y = a * tf.pow(x, 2) + b * x + c

# One gradient is returned per source, in the same order as the sources list.
dy_dx, dy_da, dy_db, dy_dc = tape.gradient(target=y, sources=[x, a, b, c])
print(dy_dx, dy_da, dy_db, dy_dc, sep="\n")

tf.Tensor(-2.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(1.0, shape=(), dtype=float32)


In [68]:
# Higher-order derivatives are obtained by nesting GradientTapes: the outer
# tape records the gradient computation carried out with the inner tape.
with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
        y = a * tf.pow(x, 2) + b * x + c
    # Computed inside the outer context so that the outer tape records it.
    dy1_dx1 = inner_tape.gradient(target=y, sources=x)
dy2_dx2 = outer_tape.gradient(target=dy1_dx1, sources=x)

# First derivative, followed by the second derivative.
print(dy1_dx1, dy2_dx2, sep="\n")

tf.Tensor(-2.0, shape=(), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)


In [72]:
# NOTE: this rebinds x to a constant (the earlier cells use a tf.Variable named x),
# so re-running those cells after this one will see the constant instead.
x = tf.constant(3.0)

# By default, the resources held by a GradientTape are released as soon as
# gradient() is called once. A persistent tape keeps them, which allows several
# gradient() calls over the same recorded computation.
with tf.GradientTape(persistent=True) as g:
    g.watch(x)  # x is a constant here, so it must be watched explicitly
    y = x * x   # y = x**2
    z = y * y   # z = x**4

# persistent=True is what makes the repeated gradient() calls below possible.
dy_dx = g.gradient(y, x)  # 2*x
print(dy_dx)

dz_dx = g.gradient(z, x)  # 4*x**3
print(dz_dx)

# Asking for the same gradient again still works on a persistent tape.
dz_dx = g.gradient(z, x)
print(dz_dx)

# A persistent tape holds its resources until it is garbage collected;
# drop the reference explicitly once all gradients have been computed.
del g

tf.Tensor(6.0, shape=(), dtype=float32)
tf.Tensor(108.0, shape=(), dtype=float32)
tf.Tensor(108.0, shape=(), dtype=float32)
