In [1]:
import numpy as np

In [2]:
class BatchNormalization:
    """Batch normalization over axis 0 (the batch axis) with running statistics.

    In training mode each call normalizes ``x`` with the mean/variance of the
    current batch and folds those statistics into exponential moving averages.
    In inference mode the stored running averages are used instead.

    Parameters
    ----------
    momentum : float
        Fraction of the previous running statistic retained on each update:
        ``running = momentum * running + (1 - momentum) * batch_stat``.
    epsilon : float
        Added to the variance before the square root to avoid division by zero.
    """

    def __init__(self, momentum=0.9, epsilon=1e-5):
        self.momentum = momentum
        self.epsilon = epsilon
        # Filled in by the first training-mode forward pass (leading axis of size 1
        # because the statistics are computed with keepdims=True).
        self.running_mean = None
        self.running_var = None
        self.gamma = 1.0  # Scale parameter (scalar by default, shared by all features)
        self.beta = 0.0   # Shift parameter (scalar by default, shared by all features)

    def forward(self, x, training=True):
        """Normalize ``x`` and apply the affine transform ``gamma * x_norm + beta``.

        Parameters
        ----------
        x : array_like
            Input batch; statistics are computed over axis 0.
        training : bool, default True
            If True, normalize with the current batch's statistics and update the
            running averages. If False, normalize with the running averages and
            leave them unchanged.

        Returns
        -------
        numpy.ndarray
            The normalized, scaled and shifted batch (same shape as ``x`` for the
            default scalar ``gamma``/``beta``).

        Raises
        ------
        RuntimeError
            If called with ``training=False`` before any training-mode call has
            populated the running statistics.
        """
        if training:
            mean = np.mean(x, axis=0, keepdims=True)
            # Biased (population, ddof=0) variance of the batch.
            var = np.var(x, axis=0, keepdims=True)

            if self.running_mean is None:
                # First batch seeds the running statistics directly instead of
                # blending them with an arbitrary initial value.
                self.running_mean = mean
                self.running_var = var
            else:
                self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
                self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
        else:
            # Without this guard, None would flow into the arithmetic below and
            # surface as an unhelpful TypeError.
            if self.running_mean is None or self.running_var is None:
                raise RuntimeError(
                    "BatchNormalization.forward(training=False) called before any "
                    "training pass; running statistics are not initialized."
                )
            mean = self.running_mean
            var = self.running_var

        # Normalize
        x_norm = (x - mean) / np.sqrt(var + self.epsilon)

        # Scale and shift
        return self.gamma * x_norm + self.beta

In [3]:
# Legacy global seed kept so the sample batch is identical to earlier runs.
np.random.seed(42)
x = np.random.randn(10, 5)  # 10 samples x 5 features

bn = BatchNormalization()
x_normalized = bn.forward(x, training=True)

# Sanity checks: with the default gamma=1 and beta=0, every feature of the
# output should have zero mean and (up to epsilon) unit variance.
np.testing.assert_allclose(x_normalized.mean(axis=0), 0.0, atol=1e-12)
np.testing.assert_allclose(x_normalized.var(axis=0), 1.0, atol=1e-3)

x_normalized[:3]