In [1]:
import torch

# Fix the RNG seed so the randomly initialised parameters below -- and every
# gradient printed later in the notebook -- are identical after
# Restart Kernel -> Run All. (Saved outputs need one re-run to refresh.)
SEED = 0
torch.manual_seed(SEED)

# One-layer model: z = x @ w + b, scored against y with BCE-with-logits.
x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.rand(5, 3, requires_grad=True)  # weights; autograd tracks ops on them
b = torch.rand(3, requires_grad=True)  # bias; tracked as well
z = torch.matmul(x, w) + b  # logits, shape (3,)
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

Tensors, Functions and Computational Graph

In [2]:
# A tensor produced by a tracked operation exposes `grad_fn`, the backward
# node for the operation that created it. Print it for the logits and the loss.
for label, tracked in (
    ("Gradient function for z = ", z),
    ("Gradient Function for loss = ", loss),
):
    print(label, tracked.grad_fn)

Gradient function for z =  <AddBackward0 object at 0x00000211ABDAC400>
Gradient Function for loss =  <BinaryCrossEntropyWithLogitsBackward0 object at 0x00000211ABDAC880>


Computing Gradients

In [3]:
# Back-propagate from the scalar loss; autograd fills in `.grad` on the leaf
# tensors that were created with requires_grad=True (w and b).
loss.backward()

# Show d(loss)/dw followed by d(loss)/db, one tensor per line.
print(w.grad, b.grad, sep="\n")

tensor([[0.3104, 0.3205, 0.3159],
        [0.3104, 0.3205, 0.3159],
        [0.3104, 0.3205, 0.3159],
        [0.3104, 0.3205, 0.3159],
        [0.3104, 0.3205, 0.3159]])
tensor([0.3104, 0.3205, 0.3159])


Disabling Gradient Tracking

In [4]:
# Default behaviour: a result computed from tensors that require gradients
# is itself tracked by autograd.
z = torch.matmul(x, w) + b
print(z.requires_grad)

# The same computation inside torch.no_grad() is not tracked. The flag lives
# on the resulting tensor, so it can be inspected after the block has exited.
with torch.no_grad():
    z = torch.matmul(x, w) + b
print(z.requires_grad)

True
False


In [5]:
# Alternative to torch.no_grad(): compute the result as usual, then call
# detach() on it to obtain a tensor that does not require gradients.
z = x @ w + b
z_det = z.detach()
print(z_det.requires_grad)

False


Jacobian Products

In [6]:
inp = torch.eye(5, requires_grad=True)
out = (inp + 1) ** 2

# `out` is not a scalar, so backward() needs an explicit gradient argument of
# the same shape; a tensor of ones is used here. (The argument-free backward()
# form is the scalar-only shorthand for backward(torch.tensor(1.0)).)
upstream = torch.ones_like(inp)

# retain_graph=True keeps the graph alive so backward can be called again.
out.backward(upstream, retain_graph=True)
print("First call \n", inp.grad)

# Gradients accumulate into inp.grad: a second pass doubles every entry.
out.backward(upstream, retain_graph=True)
print("\n Second call \n", inp.grad)

# Reset the accumulator in place; the next pass then matches the first one.
inp.grad.zero_()
out.backward(upstream, retain_graph=True)
print("\n Call after zeroing gradients \n", inp.grad)

First call 
 tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])

 Second call 
 tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.],
        [4., 4., 4., 4., 8.]])

 Call after zeroing gradients 
 tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])
