In [38]:
import torch
from torch import nn
from torch import optim

class TinyNetwork(nn.Module):
    def __init__(self,):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(2, 5),
                                    nn.ReLU(),
                                    nn.Linear(5, 1))
        
    def forward(self, x):
        return self.layers(x)

In [39]:
model = TinyNetwork()
print(model)

TinyNetwork(
  (layers): Sequential(
    (0): Linear(in_features=2, out_features=5, bias=True)
    (1): ReLU()
    (2): Linear(in_features=5, out_features=1, bias=True)
  )
)


In [40]:
batch_size = 2
x = torch.randn(batch_size, 2, requires_grad=False)
print(f"x {x.requires_grad} grad {x.grad}")
out = model(x)
print(out)
print(out.requires_grad)

x False grad None
tensor([[-0.2375],
        [-0.1713]], grad_fn=<AddmmBackward0>)
True


In [41]:
print(model.layers[0].weight.grad, model.layers[2].weight.grad)

None None


In [42]:
x = torch.randn(batch_size, 2, requires_grad=True)
print(f"x {x.requires_grad} grad {x.grad}")
out = model(x)
print(f"Out {out} grad {out.grad} requires grad {out.requires_grad}")
print(model.layers[0].weight.grad, model.layers[2].weight.grad)

x True grad None
Out tensor([[-0.1862],
        [-0.2644]], grad_fn=<AddmmBackward0>) grad None requires grad True
None None


In [43]:
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss = (torch.ones(2) - out).pow(2).sum()
loss.backward()
print(f"Out {out} grad {out.grad} requires grad {out.requires_grad}")
print(model.layers[0].weight.grad, model.layers[2].weight.grad)

Out tensor([[-0.1862],
        [-0.2644]], grad_fn=<AddmmBackward0>) grad None requires grad True
tensor([[ 0.0000,  0.0000],
        [-0.5011, -0.3012],
        [-0.2097, -0.1261],
        [-0.4627,  3.4254],
        [ 0.2939, -0.9932]]) tensor([[ 0.0000, -6.8149, -5.0638, -3.7467, -3.1071]])


In [44]:
optimizer.step()
print(f"Out {out} grad {out.grad} requires grad {out.requires_grad}")
print(model.layers[0].weight.grad, model.layers[2].weight.grad)
optimizer.zero_grad(set_to_none=True)
print(f"Out {out} grad {out.grad} requires grad {out.requires_grad}")
print(model.layers[0].weight.grad, model.layers[2].weight.grad)

Out tensor([[-0.1862],
        [-0.2644]], grad_fn=<AddmmBackward0>) grad None requires grad True
tensor([[ 0.0000,  0.0000],
        [-0.5011, -0.3012],
        [-0.2097, -0.1261],
        [-0.4627,  3.4254],
        [ 0.2939, -0.9932]]) tensor([[ 0.0000, -6.8149, -5.0638, -3.7467, -3.1071]])
Out tensor([[-0.1862],
        [-0.2644]], grad_fn=<AddmmBackward0>) grad None requires grad True
None None


In [53]:
x = torch.randn(batch_size, 2, requires_grad=True)
print(f"x grad {x.grad}")
out = model(x)
loss = (torch.ones(2) - out).pow(2).sum()
loss.backward(retain_graph=True)
print(f"x with retain graph {x.grad}")


x grad None
x with retain graph tensor([[-1.5050, -4.3178],
        [ 0.0533,  0.6561]])


In [54]:
print(f"Out {out} grad {out.grad} requires grad {out.requires_grad}")
print(model.layers[0].weight.grad, model.layers[2].weight.grad)

Out tensor([[2.7993],
        [1.3782]], grad_fn=<AddmmBackward0>) grad None requires grad True
tensor([[  0.0000,   0.0000],
        [ -8.2707,  -4.4217],
        [ -0.4875,   2.1057],
        [-18.0650, -16.8640],
        [ -9.2060,   1.5357]]) tensor([[ 0.0000, 29.2423,  5.4341, 24.3544,  5.2050]])
