In [1]:
import torch
import torch.nn as nn

In [2]:
class DeepNN(nn.Module):
    def __init__(self, layer_sizes: list[int], use_shortcut: bool):
        super().__init__()
        self.use_shortcut = use_shortcut
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(layer_sizes[i], layer_sizes[i+1]),
                nn.GELU()
            ) for i in range(len(layer_sizes) - 1)
        ])
    
    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x) if self.use_shortcut else layer(x)
        return x

In [3]:
layer_sizes = [3, 3, 3, 3, 3, 1]
sample_input = torch.tensor([[1., 0., -1.]])
torch.manual_seed(123)
model_without_res = DeepNN(layer_sizes, use_shortcut=False)
model_with_res = DeepNN(layer_sizes, use_shortcut=True)

In [4]:
def print_gradients(model, x):
    output = model(x)
    target = torch.tensor([[0.]])  # dummy target
    loss = nn.MSELoss()
    loss = loss(output, target)
    loss.backward()
    
    for name, param in model.named_parameters():
        if 'weight' in name:
            print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")

In [5]:
print_gradients(model_without_res, sample_input)
# deep neural nets suffer from vanishing gradients as can be seen below

layers.0.0.weight has gradient mean of 0.00020174124801997095
layers.1.0.weight has gradient mean of 0.00012011772923870012
layers.2.0.weight has gradient mean of 0.0007152438047342002
layers.3.0.weight has gradient mean of 0.0013988513965159655
layers.4.0.weight has gradient mean of 0.005049603525549173


In [6]:
print_gradients(model_with_res, sample_input)
# adding residual connections helps to mitigate vanishing gradients

layers.0.0.weight has gradient mean of 0.6178960800170898
layers.1.0.weight has gradient mean of 0.15985536575317383
layers.2.0.weight has gradient mean of 0.3972354233264923
layers.3.0.weight has gradient mean of 0.44717708230018616
layers.4.0.weight has gradient mean of 1.3972887992858887


  return F.mse_loss(input, target, reduction=self.reduction)
