In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [None]:
# Print tensors in fixed-point instead of scientific notation, so near-zero
# values (such as the per-row means of normalized tensors computed later) are
# easier to read.
torch.set_printoptions(sci_mode=False)

In [None]:
# Seed first: both the input batch and the Linear layer's initial weights are
# drawn from the global RNG, in this order.
torch.manual_seed(123)

# Two examples, five features each (the last dim feeds nn.Linear(5, ...)).
batch_example = torch.randn(2, 5)

# Tiny demo network: project 5 -> 6 features, then apply ReLU.
layers = nn.Sequential(nn.Linear(5, 6), nn.ReLU())

out = layers(batch_example)
out

In [None]:
# Per-row statistics over the last (feature) dimension. keepdim=True keeps
# them broadcastable against `out` for the normalization that follows.
mean = out.mean(dim=-1, keepdim=True)
var = out.var(dim=-1, keepdim=True)  # torch.var default: unbiased (N - 1) estimate
print(f"mean={mean}", f"var={var}", sep="\n")
var.size()

In [None]:
# Standardize each row of `out` with the statistics computed in the previous
# cell, then verify that each row now has mean ~0 and variance ~1.
# The check stats get their own names so this cell never overwrites the
# `mean`/`var` it reads; otherwise re-running the cell would normalize with the
# already-normalized stats and silently leave `out_norm` ~= `out`.
# NOTE: no epsilon here, so a row with zero variance (e.g. an all-zero ReLU
# output) would divide by zero; the LayerNorm class below adds `eps`.
out_norm = (out - mean) / torch.sqrt(var)
norm_mean = out_norm.mean(dim=-1, keepdim=True)
norm_var = out_norm.var(dim=-1, keepdim=True)
print(f"mean={norm_mean}")
print(f"var={norm_var}")

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, with a learnable affine map.

    Each vector along the last axis is shifted to zero mean and scaled to unit
    (biased) variance, then multiplied element-wise by ``scale`` and offset by
    ``shift``.

    Args:
        emb_dim: size of the last dimension of the inputs; sets the length of
            the ``scale`` and ``shift`` parameters.
        eps: small constant added to the variance before the square root so a
            zero-variance row does not divide by zero. Defaults to 1e-5.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        # Initialized to the identity transform: scale = 1, shift = 0.
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` of shape (..., emb_dim); the output has the same shape."""
        mean = x.mean(dim=-1, keepdim=True)
        # Biased variance (divide by N, not N - 1), matching torch.nn.LayerNorm.
        # `unbiased=False` is equivalent to `correction=0` in newer PyTorch.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * x_norm + self.shift

In [None]:
# Apply the custom LayerNorm to the demo batch (5 = its feature width).
ln = LayerNorm(emb_dim=5)
out = ln(batch_example)
out.size()

In [None]:
# Sanity-check the LayerNorm output `out` (not `out_norm` from the earlier
# manual normalization): each row should have mean ~0 and variance ~1.
# The biased variance (unbiased=False) matches how LayerNorm.forward normalizes;
# at initialization (scale=1, shift=0) its eps keeps the variance just below 1.
ln_mean = out.mean(dim=-1, keepdim=True)
ln_var = out.var(dim=-1, keepdim=True, unbiased=False)
print(f"mean={ln_mean}")
print(f"var={ln_var}")

In [None]:
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

        GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    Same formula as ``torch.nn.GELU(approximate="tanh")``. Applied
    element-wise, so the output has the same shape as the input.
    """

    # sqrt(2 / pi) as a Python float, computed once. Building a new
    # float32 CPU tensor on every forward call was wasteful and truncated the
    # constant to float32 precision for float64 inputs.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply GELU element-wise to ``x``."""
        cubic_term = x + 0.044715 * torch.pow(x, 3)
        return 0.5 * x * (1 + torch.tanh(self._SQRT_2_OVER_PI * cubic_term))

In [None]:
# Scratch inspection: ten evenly spaced points from -1 to 1 (both inclusive).
a = torch.linspace(start=-1, end=1, steps=10)
a

In [None]:
# Compare the custom GELU with ReLU side by side over x in [-3, 3].
# Uses the explicit Figure/Axes interface rather than the pyplot state machine.
gelu, relu = GELU(), nn.ReLU()

x = torch.linspace(-3, 3, 100)
y_gelu, y_relu = gelu(x), relu(x)

fig, axes = plt.subplots(1, 2, figsize=(8, 3))
for ax, y, label in zip(axes, [y_gelu, y_relu], ["GELU", "RELU"]):
    ax.plot(x, y)
    ax.set_title(f"{label} activation")
    ax.set_xlabel("x")
    ax.set_ylabel(f"{label}(x)")
    ax.grid(True)

fig.tight_layout()
plt.show()

In [None]:
# Print the element count of each parameter tensor of `layers`, one per line,
# in the order `parameters()` yields them.
sizes = map(torch.Tensor.numel, layers.parameters())
for size in sizes:
    print(size)