In [None]:
import numpy as np

# ---- BatchNorm "live" step-by-step demo ----
# Hand-computes batch normalization on a tiny batch so every intermediate
# quantity can be printed and inspected:
#   1) per-feature batch statistics, 2) normalization, 3) learnable affine
#   transform, 4) invariance to an upstream scale/shift, 5) one update of the
#   running statistics.


def _batch_norm_forward(x, scale, shift, eps):
    """Apply training-mode batch normalization to a (batch, features) array.

    Returns a tuple ``(mean, var, x_hat, out)`` where ``mean``/``var`` are the
    per-feature batch statistics (population variance, ddof=0), ``x_hat`` is
    the normalized input and ``out`` is ``scale * x_hat + shift``.
    """
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    # eps keeps the denominator away from zero for (near-)constant features.
    x_hat = (x - mean) / np.sqrt(var + eps)
    out = scale * x_hat + shift
    return mean, var, x_hat, out


def _blend_running_stat(previous, batch_value, decay):
    """Exponential moving average: keep ``decay`` of the old value, rest is new."""
    return decay * previous + (1.0 - decay) * batch_value


# Simulated pre-activations of a Dense layer: 4 samples (rows) x 3 units (columns).
z = np.array(
    [
        [1.0, 2.0, 3.0],
        [2.0, 0.0, 4.0],
        [0.0, -1.0, 2.0],
        [1.0, 1.0, 5.0],
    ],
    dtype=np.float32,
)

# Small constant added to the variance for numerical stability
# (Keras BatchNormalization defaults to 1e-3, PyTorch BatchNorm to 1e-5).
epsilon = 1e-5
# Learnable per-feature scale and shift, fixed here to arbitrary demo values.
gamma = np.array([1.0, 0.5, 2.0], dtype=np.float32)
beta = np.array([0.0, 1.0, -1.0], dtype=np.float32)

print("Pre-activation z (shape batch=4, features=3):\n", z, "\n")

# Steps 1-3: batch statistics -> normalize -> affine transform.
mu_B, var_B, z_hat, y = _batch_norm_forward(z, gamma, beta, epsilon)

print("Batch mean (mu_B):", mu_B)
print("Batch variance (var_B):", var_B, "\n")

print("Normalized z (z_hat):\n", np.round(z_hat, 4), "\n")
print("Mean(z_hat) per feature (≈0):", np.round(z_hat.mean(axis=0), 6))
print("Var(z_hat)  per feature (≈1):", np.round(z_hat.var(axis=0), 6), "\n")

print("gamma:", gamma, "beta:", beta)
print("BatchNorm output y = gamma*z_hat + beta:\n", np.round(y, 4), "\n")

# Step 4 (optional): rescale and offset the input as if the previous layer had
# drifted. The batch statistics change, but the normalized values (and hence
# the output for the same gamma/beta) stay essentially the same.
z_shifted = 2.0 * z + 10.0
mu_B_s, var_B_s, z_hat_s, y_s = _batch_norm_forward(z_shifted, gamma, beta, epsilon)

print("After strong shift/scale (z' = 2*z + 10):")
print("mu_B':", mu_B_s, "var_B':", var_B_s)
print("Normalized z' (z_hat'):\n", np.round(z_hat_s, 4))
print("Mean(z_hat') per feature (≈0):", np.round(z_hat_s.mean(axis=0), 6))
print("Var(z_hat')  per feature (≈1):", np.round(z_hat_s.var(axis=0), 6))
print("Output y' with same gamma,beta:\n", np.round(y_s, 4), "\n")

# Step 5 (optional): a single update of the running statistics. These moving
# averages are what replace the batch statistics at inference time. They start
# from mean=0 / var=1 and move a fraction (1 - momentum) toward the batch values.
momentum = 0.9
running_mean = _blend_running_stat(np.zeros_like(mu_B), mu_B, momentum)
running_var = _blend_running_stat(np.ones_like(var_B), var_B, momentum)

print("One-step running stats update (momentum=0.9):")
print("running_mean:", np.round(running_mean, 6))
print("running_var :", np.round(running_var, 6))
