<a href="https://colab.research.google.com/github/howardatri/notes/blob/main/BatchNorm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np

class BatchNorm:
    """Batch normalization for 2-D inputs of shape ``(N, num_features)``.

    In training mode each feature column is normalized with the mean and
    (biased) variance of the current batch, and exponential running
    averages of those statistics are updated. In inference mode the running
    averages are used instead. The normalized values are then scaled by
    ``gamma`` and shifted by ``beta``.

    ``momentum`` is the weight kept on the *previous* running statistic:
    ``running = momentum * running + (1 - momentum) * batch_statistic``.

    Attributes:
        gamma, beta: Per-feature scale and shift, shape ``(num_features,)``.
        running_mean, running_var: Running statistics used for inference.
        mean, var: Mean and variance of the most recent training batch.
        x_centered, std_inv, x_norm, N: Values cached by the most recent
            training-mode ``forward`` call and consumed by ``backward``.
    """

    def __init__(self, num_features: int, epsilon: float = 1e-5, momentum: float = 0.9):
        self.num_features = num_features
        self.epsilon = epsilon
        self.momentum = momentum

        # Learnable parameters
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)

        # Running mean and variance for inference
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

        # Cache filled by forward(training=True) and read by backward()
        self.x_norm = None
        self.x_centered = None
        self.std_inv = None
        self.var = None
        self.mean = None
        self.N = 0

    def forward(self, x: np.ndarray, training: bool = True) -> np.ndarray:
        """
        Forward pass for batch normalization.

        Args:
            x (np.ndarray): Input data of shape (N, D) where N is batch size
                and D equals ``num_features``.
            training (bool): If True, normalize with the batch statistics,
                update the running statistics and cache the intermediates
                needed by ``backward``. If False, normalize with the running
                statistics and leave all state untouched.

        Returns:
            np.ndarray: Output after batch normalization, shape (N, D).

        Raises:
            ValueError: If ``x`` is not of shape (N, num_features), or if a
                training-mode batch is empty.
        """
        x = np.asarray(x)
        if x.ndim != 2 or x.shape[1] != self.num_features:
            # Without this check a wrong feature count would broadcast
            # silently and corrupt the shape of the running statistics.
            raise ValueError(
                f"expected input of shape (N, {self.num_features}), got {x.shape}"
            )

        if not training:
            # Use running mean and variance for inference
            x_norm = (x - self.running_mean) / np.sqrt(self.running_var + self.epsilon)
            return self.gamma * x_norm + self.beta

        if x.shape[0] == 0:
            # mean/var of an empty batch are NaN and would poison the
            # running statistics permanently.
            raise ValueError("cannot compute batch statistics from an empty batch")

        self.N = x.shape[0]
        self.mean = np.mean(x, axis=0)
        self.var = np.var(x, axis=0)

        # Update running mean and variance
        self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * self.mean
        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * self.var

        self.x_centered = x - self.mean
        # epsilon is added in a fresh expression so self.var stays the true
        # batch variance instead of being overwritten with var + epsilon.
        self.std_inv = 1 / np.sqrt(self.var + self.epsilon)
        self.x_norm = self.x_centered * self.std_inv
        return self.gamma * self.x_norm + self.beta

    def backward(self, dout: np.ndarray):
        """
        Backward pass for batch normalization.

        Uses the intermediates cached by the most recent training-mode
        ``forward`` call.

        Args:
            dout (np.ndarray): Gradient of the loss with respect to the output
                of batch normalization; same shape as that output.

        Returns:
            tuple: Gradients of the loss with respect to input (dx), gamma (dgamma), and beta (dbeta).

        Raises:
            RuntimeError: If no training-mode forward pass has been run yet.
            ValueError: If ``dout`` does not match the cached batch shape.
        """
        if self.x_norm is None:
            raise RuntimeError(
                "backward() called before forward(x, training=True); nothing is cached"
            )
        dout = np.asarray(dout)
        if dout.shape != self.x_norm.shape:
            raise ValueError(
                f"dout has shape {dout.shape}, expected {self.x_norm.shape}"
            )

        dbeta = np.sum(dout, axis=0)
        dgamma = np.sum(dout * self.x_norm, axis=0)

        dx_norm = dout * self.gamma
        dstd_inv = np.sum(dx_norm * self.x_centered, axis=0)
        # d(std_inv)/d(var) = -0.5 * (var + eps)^(-3/2) = -0.5 * std_inv^3
        dvar = dstd_inv * (-0.5) * self.std_inv ** 3

        # The variance path adds nothing to dmean: it is multiplied by
        # sum(x_centered), which is zero by construction.
        dmean = -self.std_inv * np.sum(dx_norm, axis=0)

        dx = dx_norm * self.std_inv + dvar * (2 / self.N) * self.x_centered + dmean / self.N

        return dx, dgamma, dbeta

# Example usage:
# Assume x is your input data with shape (batch_size, num_features)
# num_features = 10
# batch_size = 32
# x = np.random.randn(batch_size, num_features)

# bn = BatchNorm(num_features)

# Forward pass in training mode
# out_train = bn.forward(x, training=True)
# print("Output in training mode:\n", out_train)

# Forward pass in inference mode (after training)
# out_inference = bn.forward(x, training=False)
# print("Output in inference mode:\n", out_inference)

# Example backward pass (assuming you have a gradient dout)
# dout = np.random.randn(*out_train.shape)
# dx, dgamma, dbeta = bn.backward(dout)
# print("Gradient with respect to input:\n", dx)
# print("Gradient with respect to gamma:\n", dgamma)
# print("Gradient with respect to beta:\n", dbeta)