In [1]:
import numpy as np

def batch_norm(x, gamma=1, beta=0, eps=1e-5, momentum=0.9, running_mean=None, running_var=None, training=True):
    """
    Simple batch normalization implementation.

    In training mode the input is normalized with the statistics of the
    current batch (computed along axis 0) and the running statistics are
    updated as an exponentially weighted average. In inference mode the
    supplied running statistics are used for normalization and returned
    unchanged.

    Parameters:
    x: input data (n_samples, n_features); statistics are taken over axis 0
    gamma: scale parameter (default: 1)
    beta: shift parameter (default: 0)
    eps: small constant added to the variance for numerical stability
    momentum: weight kept on the previous running statistic when updating
        (new = momentum * old + (1 - momentum) * batch)
    running_mean: previous running mean (None for first call)
    running_var: previous running variance (None for first call)
    training: whether in training mode (affects normalization behavior)

    Returns:
    normalized_x: batch-normalized output, gamma * x_norm + beta
    new_running_mean: updated running mean (the input value in inference mode)
    new_running_var: updated running variance (the input value in inference mode)

    Raises:
    ValueError: if training is False and either running statistic is None
    """
    if not training:
        # Inference cannot fall back to batch statistics, so both running
        # statistics are required.
        if running_mean is None or running_var is None:
            raise ValueError("Running statistics must be provided for inference mode")

        x_norm = (x - running_mean) / np.sqrt(running_var + eps)
        return gamma * x_norm + beta, running_mean, running_var

    # Batch statistics; np.var uses ddof=0 here, i.e. the biased estimate.
    mean = np.mean(x, axis=0)
    var = np.var(x, axis=0)

    # Each running statistic is initialised or updated on its own, so passing
    # only one of them neither crashes on `momentum * None` nor silently
    # discards the one that was provided.
    if running_mean is None:
        new_running_mean = mean
    else:
        new_running_mean = momentum * running_mean + (1 - momentum) * mean

    if running_var is None:
        new_running_var = var
    else:
        new_running_var = momentum * running_var + (1 - momentum) * var

    # Training output is normalized with the current batch statistics,
    # not the running ones.
    x_norm = (x - mean) / np.sqrt(var + eps)

    # Scale and shift
    normalized_x = gamma * x_norm + beta

    return normalized_x, new_running_mean, new_running_var

# Demo input: a batch of 5 samples with 3 features, drawn with mean 5 and std 2
data = 2 * np.random.randn(5, 3) + 5

# Training step 1: no running statistics are passed in on this call
normalized_data, running_mean, running_var = batch_norm(data, training=True)

# Compare per-feature statistics before and after normalization
print("Original data mean:", data.mean(axis=0))
print("Original data std:", data.std(axis=0))
print("Normalized data mean:", normalized_data.mean(axis=0))
print("Normalized data std:", normalized_data.std(axis=0))

# Training step 2: feed the running statistics from step 1 back in
new_data = 2 * np.random.randn(5, 3) + 5
normalized_data2, running_mean, running_var = batch_norm(
    new_data, running_mean=running_mean, running_var=running_var, training=True
)

# Inference: normalize a single sample with the accumulated running statistics
test_sample = np.random.randn(1, 3)
normalized_test, _, _ = batch_norm(
    test_sample, running_mean=running_mean, running_var=running_var, training=False
)

Original data mean: [4.4921557  3.83512478 5.82073526]
Original data std: [1.46873    0.92218634 1.71888805]
Normalized data mean: [4.44089210e-17 5.77315973e-16 3.16413562e-16]
Normalized data std: [0.99999768 0.99999412 0.99999831]
