In [13]:
import torch
import torch.nn as nn
from torch.func import functional_call, grad
import torch.nn.functional as F
from torchinfo import summary


class Model(nn.Module):
    """Two stacked 5x5 convolutions, each followed by a ReLU.

    Channel layout: 1 -> 20, then 20 -> 20. With PyTorch's default stride 1
    and no padding, each convolution trims 4 pixels from every spatial
    dimension (28x28 -> 24x24 -> 20x20 for the input used in this notebook).
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute names conv1/conv2: they define the parameter keys
        # ("conv1.weight", "conv2.bias", ...) and are accessed directly (net.conv2).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU to ``x``."""
        hidden = self.conv1(x).relu()
        return self.conv2(hidden).relu()


# Build the model and freeze the second convolution, so only conv1's weight
# and bias stay trainable. torchinfo lists frozen parameters in parentheses
# and counts them under "Non-trainable params".
net = Model()
for frozen_param in net.conv2.parameters():
    frozen_param.requires_grad = False

summary(net, input_size=(1, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
Model                                    [1, 20, 20, 20]           --
├─Conv2d: 1-1                            [1, 20, 24, 24]           520
├─Conv2d: 1-2                            [1, 20, 20, 20]           (10,020)
Total params: 10,540
Trainable params: 520
Non-trainable params: 10,020
Total mult-adds (M): 4.31
Input size (MB): 0.00
Forward/backward pass size (MB): 0.16
Params size (MB): 0.04
Estimated Total Size (MB): 0.20

In [16]:
# The model and the data must live on the same device. Calling .cuda() only on
# x and t leaves `net` on the CPU, so on a fresh kernel functional_call would
# mix CPU weights with CUDA inputs. Choosing the device here also keeps the
# cell runnable on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)  # nn.Module.to moves the parameters in place; requires_grad flags are kept

# Random input batch and regression target. The target's shape matches the
# model output reported by the summary above ([1, 20, 20, 20]).
x = torch.randn(1, 1, 28, 28, device=device)
t = torch.randn(1, 20, 20, 20, device=device)


def compute_loss(params, x, t, model=None):
    """Mean-squared error between the model's output on ``x`` and the target ``t``.

    Written as a function of ``params`` so that ``torch.func.grad`` can
    differentiate it with respect to its first argument.

    Args:
        params: mapping of parameter name -> tensor. These tensors replace the
            module's own parameters for this call. Names left out (e.g. frozen
            ones) fall back to the tensors registered on the module.
        x: input batch passed to the model.
        t: target tensor; expected to match the model output's shape.
        model: module to evaluate. Defaults to the notebook-level ``net``,
            which is looked up at call time.

    Returns:
        Scalar tensor holding the MSE loss (mean reduction).
    """
    if model is None:
        model = net
    y = functional_call(model, params, x)
    return F.mse_loss(y, t)


# Differentiate only with respect to the trainable parameters. The frozen ones
# (conv2) are left out of `params`, so functional_call uses the module's own
# tensors for them and grad returns no entry for them.
params = {
    name: param
    for name, param in net.named_parameters()
    if param.requires_grad
}

grad_weights = grad(compute_loss)(params, x, t)
# grad mirrors the structure of its first argument, so the keys must match.
assert grad_weights.keys() == params.keys()
grad_weights.keys()

dict_keys(['conv1.weight', 'conv1.bias'])
