In [182]:
import torch
import torch.nn as nn


# Input shape used below is (B, T, C); only C (the last dim) sets each layer's width.
B, T, C = 4, 5, 6

# Seed torch's RNG so weight init and the random input -- and therefore the printed
# gradient statistics -- are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)


class Model(nn.Module):
    """Stack of same-width Linear layers for inspecting per-layer weight gradients.

    Each layer computes ``Linear -> [LayerNorm] -> activation [-> + layer input]``.
    Sub-layers are stored in ``nn.ModuleList`` on an ``nn.Module`` so PyTorch
    registers their weights: ``parameters()``, ``zero_grad()``, ``.to(device)``,
    ``state_dict()`` and optimizers can all see them (plain Python lists hide them).

    Args:
        num_layers: number of Linear layers to stack.
        shape: input shape; only ``shape[-1]`` is used, as both the input and the
            output width of every layer.
        activation: one of ``"relu"``, ``"sigmoid"``, ``"tanh"``.
        use_ln: if True, apply ``nn.LayerNorm`` after each Linear, before the activation.
        use_res: if True, add each layer's input to its activated output.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, num_layers, shape, activation="relu", use_ln=False, use_res=False):
        super().__init__()  # must run before any submodule is assigned
        width = shape[-1]
        self.activation = self.get_activation(activation)
        self.use_ln, self.use_res = use_ln, use_res
        self.layers = nn.ModuleList(nn.Linear(width, width) for _ in range(num_layers))
        # Left empty when LayerNorm is disabled, as before.
        self.lns = nn.ModuleList(
            nn.LayerNorm(width) for _ in range(num_layers)
        ) if use_ln else nn.ModuleList()

    def get_activation(self, name):
        """Return a new activation module for ``name``.

        Raises:
            ValueError: if ``name`` is not ``"relu"``, ``"sigmoid"`` or ``"tanh"``.
        """
        if name == "relu":
            return nn.ReLU()
        if name == "sigmoid":
            return nn.Sigmoid()
        if name == "tanh":
            return nn.Tanh()
        raise ValueError(f"unknown activation '{name}'")

    def forward(self, x):
        """Run ``x`` through every layer and return the final activations.

        Invoked via ``model(x)`` (``nn.Module.__call__``).

        Args:
            x: tensor whose last dimension equals the layer width.

        Returns:
            Tensor with the same shape as ``x``.
        """
        for idx, layer in enumerate(self.layers):
            out = layer(x)
            if self.use_ln:
                out = self.lns[idx](out)
            out = self.activation(out)
            if self.use_res:
                out = out + x  # residual connection around the whole layer
            x = out
        return x

    def print_grad(self):
        """Print min / mean / max / mean(|g|) of each Linear layer's weight gradient.

        Layers are numbered from 1 (the layer applied first to the input). A layer
        whose weight has no gradient yet (e.g. ``backward()`` has not been called)
        is reported as such rather than raising ``AttributeError`` on ``None``.
        """
        for idx, layer in enumerate(self.layers, start=1):
            grad = layer.weight.grad
            if grad is None:
                print("layer %d -> no gradient" % idx)
                continue
            print("layer %d -> min: %.12f, mean: %.12f, max: %.12f, abs_mean: %.12f" % (idx, grad.min(), grad.mean(), grad.max(), grad.abs().mean()))

In [183]:
# Baseline run: 10-layer ReLU stack with the defaults use_ln=False, use_res=False.
x = torch.randn((B, T, C), requires_grad=True)
model = Model(10, (B, T, C), activation="relu")
y = model(x).mean()  # reduce to a scalar so backward() needs no explicit gradient
y.backward()
model.print_grad()

layer 1 -> min: -0.000008724685, mean: 0.000000016982, max: 0.000007885125, abs_mean: 0.000002013164
layer 2 -> min: -0.000021151631, mean: -0.000001646425, max: 0.000019664803, abs_mean: 0.000004690145
layer 3 -> min: -0.000068766269, mean: -0.000008068234, max: 0.000003547475, abs_mean: 0.000008744186
layer 4 -> min: -0.000058036661, mean: -0.000001599074, max: 0.000032604883, abs_mean: 0.000005611385
layer 5 -> min: 0.000000000000, mean: 0.000018623370, max: 0.000206593948, abs_mean: 0.000018623370
layer 6 -> min: -0.000622075284, mean: -0.000067921806, max: 0.000000000000, abs_mean: 0.000067921806
layer 7 -> min: -0.002337198704, mean: -0.000007956284, max: 0.001580619952, abs_mean: 0.000262820016
layer 8 -> min: -0.002778632101, mean: 0.002014670055, max: 0.016389533877, abs_mean: 0.002334865741
layer 9 -> min: 0.000000000000, mean: 0.002392121358, max: 0.026031041518, abs_mean: 0.002392121358
layer 10 -> min: 0.000000000000, mean: 0.006217778195, max: 0.059228196740, abs_mean: 0.