In [2]:
import torch
import torch.nn as nn

In [3]:
import sys
sys.path.append('../')
from modular.GPT_architecture.FeedForwardBlock import GELU

In [4]:
class ExampleDeepNeuralNetwork(nn.Module):
    """Stack of ``Linear -> activation`` layers with optional shortcut (residual) connections.

    One layer is built for each consecutive pair of entries in ``layer_sizes``, so
    ``[3, 3, 3, 3, 3, 1]`` yields five layers (3->3, 3->3, 3->3, 3->3, 3->1).

    Args:
        layer_sizes: Feature sizes ``[input, hidden..., output]``; needs at least two entries.
        use_shortcut: If True, each layer's input is added to its output whenever the two
            tensors have the same shape (layers that change the shape are never short-cut).
        activation: Optional zero-argument callable returning the activation module placed
            after every linear layer. Defaults to the project's ``GELU``.

    Raises:
        ValueError: If ``layer_sizes`` has fewer than two entries.
    """

    def __init__(self, layer_sizes, use_shortcut, activation=None):
        super().__init__()
        if len(layer_sizes) < 2:
            raise ValueError(
                "layer_sizes needs at least 2 entries (input and output size), "
                f"got {len(layer_sizes)}"
            )
        self.use_shortcut = use_shortcut
        make_activation = GELU if activation is None else activation
        # Layers are created in order (layer 0 first), so under a fixed torch seed the
        # parameter initialisation matches the previous hand-written five-layer list.
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(in_features, out_features), make_activation())
            for in_features, out_features in zip(layer_sizes[:-1], layer_sizes[1:])
        ])

    def forward(self, x):
        """Pass ``x`` through every layer, adding the layer input back where shortcuts apply.

        Args:
            x: Input tensor whose last dimension equals ``layer_sizes[0]``.

        Returns:
            The output of the final layer (plus that layer's input, if a shortcut applied).
        """
        for layer in self.layers:
            # compute the output of the current layer
            layer_output = layer(x)
            # a residual connection is only well-defined when input and output shapes agree
            if self.use_shortcut and x.shape == layer_output.shape:
                x = x + layer_output
            else:
                x = layer_output
        return x

This code implements a deep neural network with 5 layers, each consisting of a linear layer followed by a GELU activation function.
* In the forward pass, we pass the input through the layers one at a time. When `use_shortcut=True`, each layer's input is added to its output (a shortcut, or residual, connection) whenever the two have the same shape.

In [5]:
# Four 3->3 layers followed by a 3->1 output layer.
layer_sizes = [3, 3, 3, 3, 3, 1]
# A batch containing a single example with three input features.
sample_input = torch.tensor([[1.0, 0.0, -1.0]])

# Seed right before construction so the weight initialisation is reproducible.
torch.manual_seed(123)
model_without_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=False)


In [6]:
def print_gradients(model, x):
    """Run one forward/backward pass and print the mean absolute gradient of each weight.

    The loss is the mean squared error between ``model(x)`` and an all-zero target with
    the same shape as the output.

    Args:
        model: The ``nn.Module`` whose gradients are inspected.
        x: Input passed to ``model``.

    Returns:
        None. One line is printed per parameter whose name contains ``'weight'``.
    """
    # forward pass
    output = model(x)
    # Zero target shaped like the output: identical to ``torch.tensor([[0.]])`` for a
    # (1, 1) output, but avoids MSELoss broadcasting for other batch/output sizes.
    target = torch.zeros_like(output)

    # calculate loss based on how close the target and output are
    loss_fn = nn.MSELoss()
    loss = loss_fn(output, target)

    # ``backward`` accumulates into ``.grad``; clear any gradients left by a previous call
    # so re-running this on the same model reports this pass's gradients, not their sum.
    model.zero_grad()
    loss.backward()

    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue
        if param.grad is None:
            print(f"{name} has no gradient")
            continue
        # print the mean absolute gradient of the weights
        print(f"{name} has gradient mean of {param.grad.abs().mean().item()}")

* We specify a loss function (mean squared error) that measures how close the model output is to a user-specified target (here, 0).
* Then, when calling `loss.backward()`, PyTorch computes the gradient of the loss with respect to every parameter in the model.
* Finally, we iterate over the parameters via `model.named_parameters()` and print the mean absolute gradient of each weight matrix.

In [7]:
# Baseline: gradients of the plain network without shortcut connections.
print_gradients(model=model_without_shortcut, x=sample_input)

layers.0.0.weight has gradient mean of 0.00020173587836325169
layers.1.0.weight has gradient mean of 0.0001201116101583466
layers.2.0.weight has gradient mean of 0.0007152041653171182
layers.3.0.weight has gradient mean of 0.001398873864673078
layers.4.0.weight has gradient mean of 0.005049646366387606


Let's now instantiate a model with skip connections, using the same random seed so that its initial weights match the model above, and see how it compares.

In [8]:
# Re-seed with the baseline's seed so both networks start from identical weights.
torch.manual_seed(123)
model_with_shortcut = ExampleDeepNeuralNetwork(layer_sizes, use_shortcut=True)

print_gradients(model=model_with_shortcut, x=sample_input)

layers.0.0.weight has gradient mean of 0.2216978669166565
layers.1.0.weight has gradient mean of 0.20694100856781006
layers.2.0.weight has gradient mean of 0.3289698660373688
layers.3.0.weight has gradient mean of 0.2665731906890869
layers.4.0.weight has gradient mean of 1.3258538246154785


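As we can see from the output, the last layer (`layers.4`) still has a larger gradient than the other layers.
* However, the gradient magnitude stabilizes as we move toward the first layer (`layers.0`) and doesn't shrink to a vanishingly small value. Compare this with the model without shortcuts. There, the mean gradient falls from about 0.005 at `layers.4` to about 0.0002 at `layers.0`: the vanishing gradient problem that shortcut connections mitigate.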