E01

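When all weights and biases are initialized to zero, the network does not remain completely stuck, because the output bias receives a non-zero gradient from the softmax cross-entropy loss. That gradient is the batch average of softmax(logits) minus the one-hot targets. With all-zero logits the softmax is uniform, so the gradient is non-zero whenever the target characters are not uniformly distributed. As a result, the output bias learns the global character frequencies, effectively turning the model into a unigram language model. Every other parameter, however, receives exactly zero gradient. This is stronger than an ordinary symmetry problem, where neurons would at least receive identical non-zero updates. The hidden activations are tanh(0) = 0 for every neuron, so the output weight matrix, whose gradient is proportional to those activations, gets zero gradient. The gradient flowing back into the hidden layer is multiplied by the output weights, which are still zero, so all earlier weights and biases get zero gradient too. Since none of these parameters ever move, the situation never resolves itself. Consequently, only the output bias is trained, the network ignores its context, and the final performance is poor: no better than a unigram model.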
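The claim can be checked with a small experiment. The following is a minimal sketch that assumes a synthetic stand-in for the character data: random inputs and 5 classes drawn from hypothetical skewed probabilities. The sizes, step count, and learning rate are arbitrary choices, not the original makemore setup. A 2-layer tanh MLP starts with every parameter at zero and is trained with plain gradient descent. Afterwards the hidden-layer and output weights should still be exactly zero, and softmax(b2) should match the empirical class frequencies.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy setup (assumption): random inputs, 5 classes with skewed frequencies
n, nin, nh, nout = 1000, 8, 16, 5
X = torch.randn(n, nin)
probs = torch.tensor([0.4, 0.3, 0.15, 0.1, 0.05])
Y = torch.multinomial(probs, n, replacement=True)

# 2-layer MLP with every weight and bias initialized to zero
W1 = torch.zeros(nin, nh, requires_grad=True)
b1 = torch.zeros(nh, requires_grad=True)
W2 = torch.zeros(nh, nout, requires_grad=True)
b2 = torch.zeros(nout, requires_grad=True)
params = [W1, b1, W2, b2]

for step in range(1000):
    h = torch.tanh(X @ W1 + b1)      # tanh(0) = 0 for every hidden unit
    logits = h @ W2 + b2             # only b2 contributes to the logits
    loss = F.cross_entropy(logits, Y)
    for p in params:
        p.grad = None
    loss.backward()
    with torch.no_grad():
        for p in params:
            p -= 1.0 * p.grad        # plain gradient descent, lr = 1.0

# W1, b1, W2 never received a non-zero gradient, so they are still exactly zero
still_zero = all(bool((p == 0).all()) for p in (W1, b1, W2))

# softmax(b2) converges to the empirical class frequencies: a unigram model
unigram = F.softmax(b2.detach(), dim=0)
empirical = torch.bincount(Y, minlength=nout).float() / n
print("hidden/output weights still zero:", still_zero)
print("softmax(b2):", unigram)
print("empirical  :", empirical)
```

The loss restricted to b2 is convex, so gradient descent drives softmax(b2) to the empirical distribution. The rest of the network stays frozen at zero no matter how long training runs.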

E02

In [1]:
import torch
import torch.nn.functional as F

torch.manual_seed(42)

# dimensions
nin = 10
nh1 = 64
nh2 = 64
nout = 5
eps = 1e-5

# dummy data
X = torch.randn(512, nin)
Y = torch.randint(0, nout, (512,))

# layer 1
W1 = torch.randn(nin, nh1) * 0.1
b1 = torch.zeros(nh1)
bn1_gamma = torch.ones(nh1)
bn1_beta = torch.zeros(nh1)
bn1_running_mean = torch.zeros(nh1)
bn1_running_var = torch.ones(nh1)

# layer 2
W2 = torch.randn(nh1, nh2) * 0.1
b2 = torch.zeros(nh2)
bn2_gamma = torch.ones(nh2)
bn2_beta = torch.zeros(nh2)
bn2_running_mean = torch.zeros(nh2)
bn2_running_var = torch.ones(nh2)

# output layer (no BN)
W3 = torch.randn(nh2, nout) * 0.1
b3 = torch.zeros(nout)

params = [
    W1, b1, bn1_gamma, bn1_beta,
    W2, b2, bn2_gamma, bn2_beta,
    W3, b3
]
for p in params:
    p.requires_grad = True


In [2]:
for step in range(2000):

    # forward: Linear -> BatchNorm (batch statistics) -> tanh, for two layers
    h1 = X @ W1 + b1
    mu1 = h1.mean(0)
    var1 = h1.var(0, unbiased=False)

    # update running statistics without autograd tracking; otherwise the
    # running buffers keep every step's graph alive and memory keeps growing
    with torch.no_grad():
        bn1_running_mean = 0.9 * bn1_running_mean + 0.1 * mu1
        bn1_running_var  = 0.9 * bn1_running_var  + 0.1 * var1

    h1n = (h1 - mu1) / torch.sqrt(var1 + eps)
    h1 = torch.tanh(bn1_gamma * h1n + bn1_beta)

    h2 = h1 @ W2 + b2
    mu2 = h2.mean(0)
    var2 = h2.var(0, unbiased=False)

    with torch.no_grad():
        bn2_running_mean = 0.9 * bn2_running_mean + 0.1 * mu2
        bn2_running_var  = 0.9 * bn2_running_var  + 0.1 * var2

    h2n = (h2 - mu2) / torch.sqrt(var2 + eps)
    h2 = torch.tanh(bn2_gamma * h2n + bn2_beta)

    logits = h2 @ W3 + b3
    loss = F.cross_entropy(logits, Y)

    # backward
    for p in params:
        p.grad = None
    loss.backward()

    # SGD update with learning rate 0.1
    for p in params:
        p.data += -0.1 * p.grad

In [3]:
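# Fold BatchNorm 1 into Linear 1 (inference only): output unit j is scaled by
# gamma_j / sqrt(running_var_j + eps), i.e. column j of W1 (shape: nin x nh1)
with torch.no_grad():
    scale1 = bn1_gamma / torch.sqrt(bn1_running_var + eps)
    W1_fused = scale1 * W1
    b1_fused = scale1 * (b1 - bn1_running_mean) + bn1_beta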


In [4]:
# Fold BatchNorm 2 into Linear 2 in the same way (columns of W2, shape: nh1 x nh2)
with torch.no_grad():
    scale2 = bn2_gamma / torch.sqrt(bn2_running_var + eps)
    W2_fused = scale2 * W2
    b2_fused = scale2 * (b2 - bn2_running_mean) + bn2_beta


In [5]:
@torch.no_grad()
def forward_with_bn(X):
    """Reference inference pass: BatchNorm normalizes with the running statistics."""
    h1 = X @ W1 + b1
    h1 = (h1 - bn1_running_mean) / torch.sqrt(bn1_running_var + eps)
    h1 = torch.tanh(bn1_gamma * h1 + bn1_beta)

    h2 = h1 @ W2 + b2
    h2 = (h2 - bn2_running_mean) / torch.sqrt(bn2_running_var + eps)
    h2 = torch.tanh(bn2_gamma * h2 + bn2_beta)

    return h2 @ W3 + b3


In [6]:
@torch.no_grad()
def forward_fused(X):
    """Inference pass with each BatchNorm folded into the preceding Linear layer."""
    h1 = torch.tanh(X @ W1_fused + b1_fused)
    h2 = torch.tanh(h1 @ W2_fused + b2_fused)
    return h2 @ W3 + b3


In [7]:
out1 = forward_with_bn(X)
out2 = forward_fused(X)

print("max absolute difference:", (out1 - out2).abs().max().item())


max absolute difference: 7.62939453125e-06


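At inference time, BatchNorm can be folded into the preceding Linear layer by absorbing its scale and shift into that layer's weights and bias. After training a 3-layer MLP with BatchNorm, we computed fused weights and biases from the running mean and variance together with gamma and beta. Replacing each Linear + BatchNorm pair with its fused Linear layer reproduced the original inference outputs up to float32 rounding error, with a maximum absolute difference of about 7.6e-6 in the logits. This confirms that BatchNorm can be removed entirely at test time. The folding only works in inference mode. During training, BatchNorm normalizes with the statistics of the current batch, which change from batch to batch, so it cannot be expressed as fixed weights.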
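The same folding can be applied to the standard PyTorch modules. The following is a minimal sketch that assumes a `torch.nn.Linear` followed by a `torch.nn.BatchNorm1d` in eval mode. `fold_linear_bn` is a hypothetical helper written for illustration, not a PyTorch API. The layer sizes and the warm-up passes that populate the running statistics are likewise arbitrary. One detail differs from the notebook: `nn.Linear` stores its weight as `(out_features, in_features)`, so the per-unit scale multiplies rows rather than columns.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration
nin, nout = 10, 64
lin = nn.Linear(nin, nout)
bn = nn.BatchNorm1d(nout)

# Give gamma/beta non-default values and run a few training-mode passes
# so the running statistics are non-trivial
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    for _ in range(20):
        bn(lin(torch.randn(256, nin) * 3 + 1))

bn.eval()  # inference mode: normalize with running_mean / running_var


def fold_linear_bn(lin: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear:
    """Return a Linear layer equivalent to bn(lin(x)) with bn in eval mode."""
    fused = nn.Linear(lin.in_features, lin.out_features)
    with torch.no_grad():
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # shape (out,)
        # weight is (out, in): scale each row, one row per output unit
        fused.weight.copy_(lin.weight * scale[:, None])
        fused.bias.copy_(scale * (lin.bias - bn.running_mean) + bn.bias)
    return fused


fused = fold_linear_bn(lin, bn)
x = torch.randn(32, nin)
with torch.no_grad():
    diff = (bn(lin(x)) - fused(x)).abs().max().item()
print("max absolute difference:", diff)
```

The helper returns a new layer instead of modifying `lin` in place, so the original Linear + BatchNorm pair stays available as a reference for the comparison.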