In [1]:
import torch

In [2]:
# Toy single-layer model: fixed input x, all-zero target y, trainable w and b.
x = torch.ones(5)
y = torch.zeros(3)
# In real models, wrap trainable tensors in nn.Parameter instead of setting
# requires_grad by hand.
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
# Logits of the linear layer, then the binary cross-entropy loss against y.
z = x @ w + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(input=z, target=y)

In [3]:
# Back-propagate from the scalar loss; this fills w.grad and b.grad.
# Pass retain_graph=True if backward() has to be called again on this graph.
loss.backward()
print(w.grad, b.grad, sep="\n")

tensor([[0.2105, 0.2634, 0.2441],
        [0.2105, 0.2634, 0.2441],
        [0.2105, 0.2634, 0.2441],
        [0.2105, 0.2634, 0.2441],
        [0.2105, 0.2634, 0.2441]])
tensor([0.2105, 0.2634, 0.2441])


不需要计算梯度的情况
* fine-tuning: 主体网络的参数保持不变
* 只需要做测试

In [7]:
# A result requires grad as soon as any of its inputs does (w and b do here).
z = x @ w + b
print(z.requires_grad)

# Under torch.no_grad() the same computation is not tracked by autograd.
with torch.no_grad():
    z = x @ w + b
print(z.requires_grad)

True
False


In [12]:
# The z left over from the previous cell was built under torch.no_grad(), so
# it already has requires_grad=False and detaching it would show nothing.
# Build a tracked tensor first so the effect of detach() is actually visible.
z_tracked = torch.matmul(x, w) + b
assert z_tracked.requires_grad  # tracked because w and b require grad

# detach() returns a tensor with the same values that is cut off from the
# autograd graph, hence requires_grad is False.
z_det = z_tracked.detach()
print(z_det.requires_grad)

False


# 向量微分，矩阵微分

In [13]:
# backward() on a non-scalar output does not return the full Jacobian matrix.
# It needs a `gradient` tensor v and computes the vector-Jacobian product
# v^T J instead.
# PyTorch accumulates gradients into .grad across backward() calls, so they
# must be zeroed explicitly between calls.

inp = torch.eye(5, requires_grad=True)
out = (inp + 1) ** 2

# Two consecutive calls without zeroing: the second result is doubled.
for label in ("First call:", "Second call:"):
    out.backward(gradient=torch.ones_like(inp), retain_graph=True)
    print(label, inp.grad)

# Reset the accumulated gradient in place, then back-propagate once more.
inp.grad.zero_()
out.backward(gradient=torch.ones_like(inp), retain_graph=True)
print("Zero call:", inp.grad)

First call: tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])
Second call: tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.],
        [4., 4., 4., 4., 8.]])
Zero call: tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])


In [20]:
from torch.autograd.functional import jacobian
def exp_reducer(x):
    """Exponentiate `x` elementwise, then sum along dim 1 (row-wise sum)."""
    exponentiated = x.exp()
    return exponentiated.sum(dim=1)

# Random 2x3 input for the Jacobian demo.
inputs = torch.rand((2, 3))
# Full Jacobian of exp_reducer at `inputs`; left as the last expression so the
# notebook displays it.
jacobian(func=exp_reducer, inputs=inputs)

tensor([[[1.9369, 1.7160, 1.1201],
         [0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000],
         [1.5937, 1.4162, 1.2642]]])