In [None]:
# 常规自动微分API: Tensor.backward()、torch.autograd.grad
# JAX启发的functional transform API
    # Google JAX: https://github.com/jax-ml/jax
    # torch.func: https://pytorch.org/docs/main/func.html

In [409]:
import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)

In [410]:
def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

In [411]:
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)  # feature vector

In [412]:
# 我们必须每次使用不同的单位向量逐行计算它
def compute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)

jacobian = compute_jac(xp)

print(jacobian.shape)
print(jacobian[0])  # show first row

torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295,  0.4490,  0.3661, -0.1672, -1.1190,
         0.1705, -0.6683,  0.1851,  0.1630,  0.0634,  0.6547,  0.5908, -0.1308])


In [415]:
# torch.vmap我们可以使用 PyTorch 的函数变换来摆脱 for 循环并矢量化计算
from torch.func import vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)

ft_jacobian, = vmap(vjp_fn)(unit_vectors)

# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)

In [419]:
# torch.func.jacrev一个便捷函数来执行vmap-vjp组合以计算雅可比矩阵
from torch.func import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)

# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
 
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)

In [418]:
# 比较性能
def get_perf(first, first_descriptor, second, second_descriptor):
    """takes torch.benchmark objects and compares delta of second vs first."""
    faster = second.times[0]
    slower = first.times[0]
    gain = (slower-faster)/slower
    if gain < 0: gain *=-1
    final_gain = gain*100
    print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)

print(no_vmap_timer)
print(with_vmap_timer)
get_perf(no_vmap_timer, "without vmap",  with_vmap_timer, "vmap")

<torch.utils.benchmark.utils.common.Measurement object at 0x7f4a73b3d8a0>
compute_jac(xp)
  1.35 ms
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f4a59bbe2f0>
jacrev(predict, argnums=2)(weight, bias, x)
  368.34 us
  1 measurement, 500 runs , 1 thread
 Performance delta: 72.8105 percent improvement with vmap 


In [423]:
# 反向jac和正向jac
from torch.func import jacrev, jacfwd

print("输入少于输出")
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)

bias = torch.randn(Dout)
x = torch.randn(Din)

# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", )

print("输入多于输出")
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")

输入少于输出
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f4a5793ffd0>
jacfwd(predict, argnums=2)(weight, bias, x)
  686.11 us
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f4a579591e0>
jacrev(predict, argnums=2)(weight, bias, x)
  18.99 ms
  1 measurement, 500 runs , 1 thread
 Performance delta: 2667.2973 percent improvement with jacrev 
输入多于输出
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f4a57925c00>
jacfwd(predict, argnums=2)(weight, bias, x)
  7.06 ms
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f4a59bbe380>
jacrev(predict, argnums=2)(weight, bias, x)
  467.05 us
  1 measurement, 500 runs , 1 thread
 Performance delta: 1412.3726 percent improvement with jacfwd 


In [424]:
# Hessians 是雅可比矩阵的雅可比矩阵（或偏导数的偏导数，又称二阶）。
from torch.func import hessian

# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
torch.allclose(hess_api, hess_fwdfwd)

True

In [425]:
# 批量雅可比矩阵和批量黑森矩阵
# 如果您有多个输入参数，比如 func(x, y, z)，in_dims 的长度应该与输入参数的数量相同。
# 对于每个输入参数，您可以指定：
    # None：表示该输入参数不包含批量维度，函数将其视为单个样本。
    # 一个整数：表示该输入参数的哪个维度是批量维度。例如，0 表示第一个维度是批量维度。
    # 一个负数：表示从后向前数的维度。例如，-1 表示最后一个维度是批量维度
batch_size = 64
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")

bias = torch.randn(Dout)

x = torch.randn(batch_size, Din)

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
print(batch_jacobian0.shape)

weight shape = torch.Size([33, 31])
torch.Size([64, 33, 31])


In [426]:
# 如果你有一个从 (B, N) -> (B, M) 的函数，并且确定每个输入都会产生一个独立的输出，那么有时也可以vmap通过对输出求和然后计算该函数的雅可比矩阵来做到这一点：
# 因为别的bs的输出对于当前bs的输入的导数是0
def predict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)

batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
print(batch_jacobian1.shape)

torch.Size([64, 33, 31])


In [427]:
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))

batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape

torch.Size([64, 33, 31, 31])

In [428]:
# 将反向模式 AD 与正向模式 AD 组合（而不是将反向模式与反向模式组合）通常是计算 hvp 的更节省内存的方法，因为正向模式 AD 不需要构建 Autograd 图并保存后向的中间体
from torch.func import jvp, grad, vjp

def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]
def f(x):
  return x.sin().sum()

x = torch.randn(2048)
tangent = torch.randn(2048)

result = hvp(f, (x,), (tangent,))
# 如果 PyTorch 前向 AD 无法覆盖你的操作，那么我们可以用反向模式 AD 来组合反向模式 AD：
def hvp_revrev(f, primals, tangents):
  _, vjp_fn = vjp(grad(f), *primals)
  return vjp_fn(*tangents)

result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])