In [1]:
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class GELU(nn.Module):    
    def __init__(self, ):
        super().__init__()
    
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0/torch.pi)) *
            (x + 0.44715 * torch.pow(x, 3))
            ))

### Shortcut Connections

In [5]:
class ExampleDeepNeuralNetowrk(nn.Module):
    
    def __init__(self, layer_sizes, use_shortcut):
        super().__init__()
        self.use_shortcut = use_shortcut
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(layer_sizes[0], layer_sizes[1]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[1], layer_sizes[2]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[2], layer_sizes[3]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[3], layer_sizes[4]), GELU()),
            nn.Sequential(nn.Linear(layer_sizes[4], layer_sizes[5]), GELU()),
        ])
        
    def forward(self, x):
        for layer in self.layers:
            # compute the output of the current layer
            layer_output = layer(x)
            # check if shortcut can be applied
            if self.use_shortcut:
                x = x + layer_output
            else:
                x = layer_output
        return x

In [15]:
layer_sizes = [3, 3, 3, 3, 3, 1]
sample_input = torch.tensor([[1., 0., -1.]])
torch.manual_seed(123)
model_without_shortcut = ExampleDeepNeuralNetowrk(layer_sizes, use_shortcut=False)

In [19]:
def print_gradients(model, x):
    # Forward pass
    output = model(x)
    target = torch.tensor([[0.]])

    # Calculate loss based on how close the target
    # and output are
    loss = nn.MSELoss()
    loss = loss(output, target)
    
    # Backward pass to calculate the gradients
    loss.backward()

    for name, param in model.named_parameters():
        if 'weight' in name:
            # Print the mean absolute gradient of the weights
            print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")

In [20]:
# Gradients are vanishing when we move from latter layers to earlier ones
print_gradients(model_without_shortcut, sample_input)

layers.0.0.weight has gradient mean of 0.00040942657506093383
layers.1.0.weight has gradient mean of 0.000246359093580395
layers.2.0.weight has gradient mean of 0.0014689492527395487
layers.3.0.weight has gradient mean of 0.0027742418460547924
layers.4.0.weight has gradient mean of 0.010052990168333054


In [21]:
torch.manual_seed(123)
model_with_shortcut = ExampleDeepNeuralNetowrk(layer_sizes, use_shortcut=True)
print_gradients(model_with_shortcut, sample_input)

layers.0.0.weight has gradient mean of 0.9476073384284973
layers.1.0.weight has gradient mean of 0.8769886493682861
layers.2.0.weight has gradient mean of 1.3534194231033325
layers.3.0.weight has gradient mean of 1.0655364990234375
layers.4.0.weight has gradient mean of 2.9612925052642822


  return F.mse_loss(input, target, reduction=self.reduction)
