In [1]:
import torch

In [44]:
# Fix the RNG so the random draws for `w` and `b` below are reproducible.
torch.manual_seed(234)
x = torch.ones(5)   # input vector, shape (5,)
y = torch.zeros(3)  # target vector, shape (3,)
# Parameters are created with requires_grad=True so autograd tracks them.
# Equivalent alternative on an existing tensor: the in-place method
# w.requires_grad_(True) (trailing underscore). Note that `w.requires_grad`
# is a plain bool attribute and cannot be called.
w = torch.randn((5, 3), requires_grad=True)  # weights, shape (5, 3)
b = torch.randn(3, requires_grad=True)       # bias, shape (3,)

In [45]:
# Forward pass: logits of a single linear layer, z = x·w + b.
z = x @ w + b
# The loss is computed straight from the logits. The two-step alternative
# would be a = sigmoid(z) followed by binary_cross_entropy(a, y); the fused
# function below is documented by PyTorch as the more numerically stable form.
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

In [46]:
# Each non-leaf tensor remembers the backward node of the op that produced it.
for node in (z, loss):
    print(node.grad_fn)

<AddBackward0 object at 0x7f7965c7e0b0>
<BinaryCrossEntropyWithLogitsBackward0 object at 0x7f7965c7c6d0>


In [47]:
# Backpropagate from the scalar loss; autograd fills .grad on the leaf
# tensors that were created with requires_grad=True.
loss.backward()
for param in (w, b):
    print(param.grad)

tensor([[0.0011, 0.0690, 0.0342],
        [0.0011, 0.0690, 0.0342],
        [0.0011, 0.0690, 0.0342],
        [0.0011, 0.0690, 0.0342],
        [0.0011, 0.0690, 0.0342]])
tensor([0.0011, 0.0690, 0.0342])


In [50]:
# Compute the weight gradient by hand and compare it with w.grad above.
# With the loss averaged over the N = len(y) outputs:
#   dL/dw = outer(x, a - y) / N,   where a = sigmoid(z)
with torch.no_grad():  # pure arithmetic; no graph needs to be recorded
    a = torch.sigmoid(z)
    grad = torch.outer(x, a - y) / y.shape[0]
print(grad)

tensor([[0.0011, 0.0690, 0.0342],
        [0.0011, 0.0690, 0.0342],
        [0.0011, 0.0690, 0.0342],
        [0.0011, 0.0690, 0.0342],
        [0.0011, 0.0690, 0.0342]])


In [7]:
# Built normally, the result is tracked because w and b require grad.
z = x @ w + b
print(z.requires_grad)

# Built inside no_grad(), the same expression is not tracked by autograd.
with torch.no_grad():
    z = x @ w + b
print(z.requires_grad)

True
False


In [10]:
# detach() returns a tensor that is cut off from the autograd graph;
# the original `z` keeps tracking gradients.
z = x @ w + b
z_d = z.detach()
print(z.requires_grad)
print(z_d.requires_grad)

True
False


In [64]:
# Demonstrates that backward() accumulates into .grad instead of overwriting it.
inp = torch.eye(4, 5, requires_grad=True)
out = ((inp + 1) ** 2).t()  # elementwise derivative w.r.t. inp: 2 * (inp + 1)

# `out` is not a scalar, so backward() needs an explicit upstream gradient.
# retain_graph=True keeps the graph alive so backward() can be called again.
upstream = torch.ones_like(out)

out.backward(upstream, retain_graph=True)
print(f"First call:\n{inp.grad}")

# No reset in between: the second pass is added on top of the first.
out.backward(upstream, retain_graph=True)
print(f"Second call:\n{inp.grad}")

# Zeroing the buffer in place brings back the single-pass gradient.
inp.grad.zero_()
out.backward(upstream, retain_graph=True)
print(f"Third call:\n{inp.grad}")

First call:
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])
Second call:
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])
Third call:
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])
