In [1]:
# Render matplotlib figures inline in the notebook output.
# NOTE(review): none of the visible cells plot anything; this magic may be unnecessary.
%matplotlib inline

In [8]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Report the installed PyTorch version, then pick a compute device:
# "cuda" when PyTorch can see a GPU, otherwise "cpu".
# NOTE(review): `device` is not used by any of the visible cells below.
print(torch.__version__)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

In [13]:
# Build a minimal one-layer computation z = x @ w + b plus a binary
# cross-entropy loss on it, giving autograd a graph to differentiate.
# Seed the RNG first so w, b (and every gradient printed later) are
# reproducible on Restart & Run All.
torch.manual_seed(0)

x = torch.ones(5)                                # input, shape (5,)
y = torch.zeros(3)                               # target, shape (3,)
w = torch.randn(5, 3, requires_grad=True)        # weights; autograd tracks them
b = torch.randn(3, requires_grad=True)           # bias; autograd tracks it
z = torch.matmul(x, w) + b                       # logits, shape (3,)
loss = F.binary_cross_entropy_with_logits(z, y)  # scalar (default 'mean' reduction)

In [14]:
# `grad_fn` is the autograd node that produced each tensor; it is what
# backward() will call to propagate gradients through z and loss.
z_backward_fn = z.grad_fn
loss_backward_fn = loss.grad_fn
print(f"Gradient function for z= {z_backward_fn}")
print(f"Gradient function for loss = {loss_backward_fn}")

Gradient function for z= <AddBackward0 object at 0x7f68321c9410>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7f68321c9290>


In [15]:
# Backpropagate from the scalar loss, which fills in .grad on the leaf
# tensors w and b. No retain_graph is passed, so the graph is freed here:
# re-running this cell without first re-running the cell that builds
# `loss` raises a RuntimeError.
loss.backward()
for leaf_grad in (w.grad, b.grad):
    print(leaf_grad)

tensor([[0.2322, 0.0427, 0.2269],
        [0.2322, 0.0427, 0.2269],
        [0.2322, 0.0427, 0.2269],
        [0.2322, 0.0427, 0.2269],
        [0.2322, 0.0427, 0.2269]])
tensor([0.2322, 0.0427, 0.2269])


In [16]:
# Recompute z in the normal (grad-enabled) mode: since w and b require
# grad, the result is tracked by autograd.
z = x @ w + b
z_tracked = z.requires_grad
print(z_tracked)

True


In [17]:
# Under torch.no_grad() autograd records nothing, so this z does not
# require grad even though w and b do.
with torch.no_grad():
    z = x @ w + b
print(z.requires_grad)

False


In [18]:
# detach() returns a tensor cut off from the autograd graph, so it never
# requires grad, unlike the freshly tracked z it came from.
z = x @ w + b
z_det = z.detach()
detached_requires_grad = z_det.requires_grad
print(detached_requires_grad)

False


In [19]:
# Demonstrate gradient accumulation: backward() ADDS into .grad, so a
# second call doubles the stored gradient until .grad is zeroed.
# out = (inp + 1)^2 elementwise, so d(out)/d(inp) = 2 * (inp + 1):
# 4 on the identity's diagonal, 2 everywhere else.
inp = torch.eye(5, requires_grad=True)
out = (inp + 1).pow(2)
# backward() on a non-scalar output needs a gradient tensor shaped like
# `out` (not `inp`); the two only coincide here by accident of the op.
grad_seed = torch.ones_like(out)
# retain_graph=True keeps the graph alive for the repeated backward calls.
out.backward(grad_seed, retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(grad_seed, retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(grad_seed, retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.],
        [4., 4., 4., 4., 8.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.],
        [2., 2., 2., 2., 4.]])
