In [1]:
import torch
a = torch.ones(5)
a.requires_grad = True

In [2]:
b = 2 * a
b.retain_grad() # to calculate the gradient of non-lead node
c = b.mean()

In [3]:
c.backward()
print(a.grad, b.grad)

tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000]) tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])


In [4]:
# Redo the experiment but with a hook that multiplies b's grad by 2
a = torch.ones(5)
a.requires_grad = True

b = 2 * a
b.retain_grad()

b.register_hook(lambda x: print(x))

b.mean().backward()

print(a.grad, b.grad)

tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000]) tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])


In [34]:
import torch
from torch import nn
import torch.nn.functional as F

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.fc = nn.Linear(160, 5)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = x.flatten(start_dim=1)
        return self.fc(x)

In [33]:
a = torch.ones(5)
a.requires_grad = True

b = 2 * a
b.retain_grad()

b.register_hook(lambda x: x * x)

b.mean().backward()

print(a.grad, b.grad)

tensor([0.0800, 0.0800, 0.0800, 0.0800, 0.0800]) tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])


In [36]:
def hook_fn(module, grad_input, grad_output):
    print(module)
    print('---------------Input grad--------------------')

    for grad in grad_input:
        try:
            print(grad.shape)
        except AttributeError:
            print('None found for Gradient')

    print('---------------Output grad--------------------')
    for grad in grad_output:
        try:
            print(grad.shape)
        except AttributeError:
            print('None found for Gradient')

In [41]:
model = MyNet()
model.conv.register_backward_hook(hook_fn)
model.fc.register_backward_hook(hook_fn)

sample_input = torch.randn((1, 3, 8, 8))
sample_output = model(sample_input)

(1 - sample_output.mean()).backward()

Linear(in_features=160, out_features=5, bias=True)
---------------Input grad--------------------
torch.Size([5])
torch.Size([1, 160])
torch.Size([160, 5])
---------------Output grad--------------------
torch.Size([1, 5])
Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))
---------------Input grad--------------------
None found for Gradient
torch.Size([10, 3, 2, 2])
torch.Size([10])
---------------Output grad--------------------
torch.Size([1, 10, 4, 4])


In [57]:
import torch
from torch import nn
import torch.nn.functional as F

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.fc = nn.Linear(160, 5)

    def forward(self, x):
        x = F.relu(self.conv(x))
        # No gradient shall be backprop conv less than 0
        x.register_hook(lambda grad: torch.clamp(grad, min=0))
        # Print whether there is any negative grad
        x.register_hook(lambda grad: print(f'Gradients less than 0: {bool((grad < 0).any())gras.shap}'))
        x = x.flatten(start_dim=1)
        return self.fc(x)

In [59]:
model = MyNet()

for name, param in model.named_parameters():
    # if the param is from a linear and is a bias
    if 'fc' in name and 'bias' in name:
        param.register_hook(lambda grad: torch.zeros(grad.shape))

pred = model(torch.randn((1, 3, 8, 8)))
(1 - pred).mean().backward()

print(f'The biases are: {model.fc.bias.grad}')

Gradients less than 0: torch.Size([1, 10, 4, 4])
The biases are: tensor([0., 0., 0., 0., 0.])


In [66]:
visualization = {}

def hook_fn(m, i, o):
    visualization[m] = o

model = MyNet()

for name, layer in model._modules.items():
    # layer.register_forward_hook(hook_fn)
    print(f'{name}\t{layer}')
    layer.register_forward_hook(hook_fn)
    
input = torch.randn((1, 3, 8, 8))

pred = model(input)

conv	Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))
fc	Linear(in_features=160, out_features=5, bias=True)


In [1]:
import torch
from torch import nn
import torch.nn.functional as F

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 10, 2, stride=2)
        self.fc = nn.Linear(160, 5)
        self.seq = nn.Sequential(nn.Linear(5, 3), nn.Linear(3, 2))

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        return self.seq(x)

In [3]:
visualization = {}

def hook_fn(m, i, o):
    visualization[m] = o

def get_all_layers(model):
    for name, layer in model._modules.items():
        # If it is a sequential, don't register a hook on it but recursively register hook on all it's module children
        if isinstance(layer, nn.Sequential):
            get_all_layers(layer)
        else:
            # it's a non sequential. Register a hook
            layer.register_forward_hook(hook_fn)

model = MyNet()
get_all_layers(model)
pred = model(torch.randn((1, 3, 8, 8)))
for m, o in visualization.items():
    print(f'{m}\t{o.shape}')

Conv2d(3, 10, kernel_size=(2, 2), stride=(2, 2))	torch.Size([1, 10, 4, 4])
Linear(in_features=160, out_features=5, bias=True)	torch.Size([1, 5])
Linear(in_features=5, out_features=3, bias=True)	torch.Size([1, 3])
Linear(in_features=3, out_features=2, bias=True)	torch.Size([1, 2])
