In [1]:
import torch

In [2]:
# Create tensors. Seed the RNG first so the random values are reproducible
# when the notebook is re-run from a fresh kernel.
torch.manual_seed(0)
x = torch.rand((2, 2))
y = torch.rand((2, 2))
# Only z opts into gradient tracking.
z = torch.rand((2, 2), requires_grad=True)
a = x + y  # neither input is tracked -> a is untracked (no grad_fn)
b = a + z  # z is tracked -> b is tracked and records an AddBackward0 node

In [3]:
# Inspect requires_grad: x and y were created untracked, z opted in;
# a (built only from x and y) stays untracked while b (which uses z) is tracked.
print(f'x: {x.requires_grad}, y: {y.requires_grad}, z: {z.requires_grad}',
      f'a: {a.requires_grad}, b: {b.requires_grad}',
      sep='\n')

x: False, y: False, z: True
a: False, b: True


In [4]:
# Inspect grad_fn: tensors created directly (x, y, z) have none, a has none
# because it is untracked, and b records the addition that produced it.
print(f'x: {x.grad_fn}, y: {y.grad_fn}, z: {z.grad_fn}',
      f'a: {a.grad_fn}, b: {b.grad_fn}',
      sep='\n')

x: None, y: None, z: None
a: None, b: <AddBackward0 object at 0x000001A2ED0F7A48>


### 使用較複雜一點的計算圖來計算梯度(微分)

In [5]:
# Create a 2x2 tensor of ones whose operations autograd will track.
x = torch.ones(2, 2, requires_grad=True)
x

tensor([[1., 1.],
        [1., 1.]], requires_grad=True)

In [6]:
# Build a new tensor by addition; it inherits gradient tracking from x.
y = x.add(2)
(y, y.requires_grad)

(tensor([[3., 3.],
         [3., 3.]], grad_fn=<AddBackward0>),
 True)

In [7]:
# z = y * y * 3 elementwise, then reduce to a scalar with mean().
z = y.mul(y).mul(3)
out = z.mean()
print(z, out, sep='\n')

tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)
tensor(27., grad_fn=<MeanBackward0>)


In [8]:
# Walk the chain x -> y -> z -> out: x was created directly (no grad_fn),
# each later tensor records the operation that produced it.
for t in (x, y, z, out):
    print(t.grad_fn)

None
<AddBackward0 object at 0x000001A2ED0F2D08>
<MulBackward0 object at 0x000001A2ED0F2788>
<MeanBackward0 object at 0x000001A2ED0F2D08>


在這邊我們先手動推導一下：out 對 x 微分後的值應該是多少。其中 $z_j = 3(x_j+2)^2$，而 out 是 4 個元素的平均，即 $out = \frac{1}{4}\sum_j z_j$。

$$
\begin{aligned}
\frac{\partial\, out}{\partial x_i} &= \frac{\partial}{\partial x_i}\left(\frac{1}{4}\sum_{j} z_j\right) \\
                                    &= \frac{\partial}{\partial x_i}\left(\frac{1}{4}\sum_{j} 3(x_j+2)^2\right) \\
                                    &= \frac{\partial}{\partial x_i}\left(\frac{3}{4}(x_i+2)^2\right) \quad (\text{only the } j=i \text{ term depends on } x_i) \\
                                    &= \frac{3}{2}(x_i+2), \quad \text{where } x_i = 1 \\
\frac{\partial\, out}{\partial x_i} &= \frac{3}{2}(1+2) = 4.5
\end{aligned}
$$

In [9]:
# Back-propagate from the scalar `out`. For a scalar output, out.backward()
# implicitly seeds the gradient with ones_like(out); passing it explicitly
# makes that d(out)/d(out) = 1 seed visible.
# NOTE: re-running this cell is expected to fail unless retain_graph=True,
# since autograd frees the graph after backward (see Tensor.backward docs).
out.backward(torch.ones_like(out))

In [10]:
# Gradient of out with respect to x, filled in by the backward pass.
# Check it against the hand derivation: every entry should be 3/2 * (1 + 2) = 4.5.
torch.testing.assert_close(x.grad, torch.full_like(x, 4.5))
x.grad

tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])

### 不記錄倒傳遞（停止追蹤梯度）

In [11]:
# requires_grad=False: nothing computed from `a` is tracked, so neither the
# transformed `a` nor the reduced `b` records a grad_fn.
a = torch.randn(2, 2, requires_grad=False)
a = a.mul(3).div(a.sub(1))
b = a.mul(a).sum()
print(b.requires_grad, a.requires_grad)
print(b.grad_fn)

False False
None


In [12]:
# torch.no_grad(): `a` is tracked, but operations executed inside the
# context are not recorded, so `b` ends up untracked with no grad_fn.
a = torch.randn(2, 2, requires_grad=True)
a = a.mul(3).div(a.sub(1))
with torch.no_grad():
    b = a.mul(a).sum()

print(b.requires_grad, b.grad_fn)

False None
