Here is the complete mathematical formulation of the batch normalization backward pass, derived step by step:

## Batch Normalization Backward Pass - Mathematical Formulation

### Forward Pass Recap

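Given input $\mathbf{x} \in \mathbb{R}^{N \times D}$ (a batch of $N$ samples with $D$ features), learnable parameters $\gamma, \beta \in \mathbb{R}^D$, and a small constant $\epsilon > 0$ for numerical stability, the training-mode forward pass computes, for each feature $j$: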

$$\begin{aligned}
\mu_j &= \frac{1}{N} \sum_{i=1}^N x_{i,j} && \text{(mean)} \\
\sigma^2_j &= \frac{1}{N} \sum_{i=1}^N (x_{i,j} - \mu_j)^2 && \text{(variance)} \\
\hat{x}_{i,j} &= \frac{x_{i,j} - \mu_j}{\sqrt{\sigma^2_j + \epsilon}} && \text{(normalization)} \\
y_{i,j} &= \gamma_j \hat{x}_{i,j} + \beta_j && \text{(scale and shift)}
\end{aligned}$$
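For reference, the four forward equations can be sketched in NumPy as follows. This is a minimal sketch assuming training mode (batch statistics, no running averages); the function name `batchnorm_forward` and the cache layout `(x, mu, var, eps, gamma, beta)` are illustrative choices that mirror the inputs listed in the summary, not a fixed API. Note that `np.var` divides by $N$ by default, which matches the biased variance above.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm forward pass for x of shape (N, D)."""
    mu = x.mean(axis=0)                     # per-feature mean, shape (D,)
    var = x.var(axis=0)                     # biased variance (divides by N), shape (D,)
    x_hat = (x - mu) / np.sqrt(var + eps)   # normalized input, shape (N, D)
    y = gamma * x_hat + beta                # scale and shift, broadcast over rows
    cache = (x, mu, var, eps, gamma, beta)  # hypothetical cache layout for the backward pass
    return y, cache

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(8, 4))
y, cache = batchnorm_forward(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0))  # close to 0 for every feature
print(y.std(axis=0))   # close to 1 (slightly below, because of eps)
```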

### Backward Pass Derivation

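Given the upstream gradient $\frac{\partial \mathcal{L}}{\partial y_{i,j}}$, denoted $dy_{i,j}$ (`dout` in code), we compute the gradients with respect to $\beta$, $\gamma$, and $\mathbf{x}$. In the code snippets, `x_norm` is $\hat{x}$, `x_centered` is $x - \mu$, `sample_var` is $\sigma^2$, `std` is $\sqrt{\sigma^2 + \epsilon}$, and every sum over `axis=0` runs over the batch dimension $i$.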

#### **Step 1: Gradient w.r.t. β (shift parameter)**

$$\frac{\partial \mathcal{L}}{\partial \beta_j} = \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial y_{i,j}} = \sum_{i=1}^N dy_{i,j}$$

**Code:** `dbeta = np.sum(dout, axis=0)`

---

#### **Step 2: Gradient w.r.t. γ (scale parameter)**

$$\frac{\partial \mathcal{L}}{\partial \gamma_j} = \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial y_{i,j}} \cdot \frac{\partial y_{i,j}}{\partial \gamma_j} = \sum_{i=1}^N dy_{i,j} \cdot \hat{x}_{i,j}$$

**Code:** `dgamma = np.sum(dout * x_norm, axis=0)`
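Steps 1 and 2 can be checked numerically with a surrogate loss $\mathcal{L} = \sum_{i,j} dy_{i,j}\, y_{i,j}$, whose gradient with respect to $y$ is exactly `dout`. The sketch below assumes random stand-in data; `forward` and `loss` are hypothetical helpers written only for this check.

```python
import numpy as np

def forward(x, gamma, beta, eps=1e-5):
    """Return y and x_hat for a training-mode batch norm on (N, D) input."""
    x_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)
    return gamma * x_hat + beta, x_hat

rng = np.random.default_rng(1)
N, D = 6, 3
x = rng.normal(size=(N, D))
gamma, beta = rng.normal(size=D), rng.normal(size=D)
dout = rng.normal(size=(N, D))  # stands in for the upstream gradient dL/dy

def loss(g, b):
    # Surrogate loss L = sum(dout * y), so dL/dy is exactly dout
    return np.sum(dout * forward(x, g, b)[0])

_, x_norm = forward(x, gamma, beta)
dbeta = np.sum(dout, axis=0)            # Step 1
dgamma = np.sum(dout * x_norm, axis=0)  # Step 2

# Central differences, perturbing one component of beta or gamma at a time
h = 1e-6
num_dbeta = np.array([(loss(gamma, beta + h * e) - loss(gamma, beta - h * e)) / (2 * h)
                      for e in np.eye(D)])
num_dgamma = np.array([(loss(gamma + h * e, beta) - loss(gamma - h * e, beta)) / (2 * h)
                       for e in np.eye(D)])
print(np.max(np.abs(dbeta - num_dbeta)), np.max(np.abs(dgamma - num_dgamma)))
```

Because $y$ is linear in $\gamma$ and $\beta$, the finite differences here agree with the analytic sums up to floating-point rounding.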

---

#### **Step 3: Gradient w.r.t. normalized x**

$$\frac{\partial \mathcal{L}}{\partial \hat{x}_{i,j}} = \frac{\partial \mathcal{L}}{\partial y_{i,j}} \cdot \frac{\partial y_{i,j}}{\partial \hat{x}_{i,j}} = dy_{i,j} \cdot \gamma_j$$

**Code:** `dx_norm = dout * gamma`

---

#### **Step 4: Gradient w.r.t. variance**

Since $\hat{x}_{i,j} = (x_{i,j} - \mu_j) \cdot (\sigma^2_j + \epsilon)^{-1/2}$:

$$\frac{\partial \hat{x}_{i,j}}{\partial \sigma^2_j} = (x_{i,j} - \mu_j) \cdot \left(-\frac{1}{2}\right)(\sigma^2_j + \epsilon)^{-3/2}$$

Therefore:

$$\frac{\partial \mathcal{L}}{\partial \sigma^2_j} = \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial \hat{x}_{i,j}} \cdot (x_{i,j} - \mu_j) \cdot \left(-\frac{1}{2}\right)(\sigma^2_j + \epsilon)^{-3/2}$$

**Code:** `dvar = np.sum(dx_norm * x_centered, axis=0) * -0.5 * (sample_var + eps) ** (-1.5)`
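One way to confirm Step 4 is to treat $\sigma^2_j$ as an independent input with $x_{i,j} - \mu_j$ held fixed, exactly as the partial derivative does, and compare against central finite differences. This sketch assumes random stand-in values for `dx_norm`; `loss_of_var` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(2)
N, D, eps = 5, 3, 1e-5
x = rng.normal(size=(N, D))
dx_norm = rng.normal(size=(N, D))  # stands in for dL/dx_hat from Step 3
x_centered = x - x.mean(axis=0)
sample_var = x.var(axis=0)

# Analytic partial from Step 4
dvar = np.sum(dx_norm * x_centered, axis=0) * -0.5 * (sample_var + eps) ** (-1.5)

def loss_of_var(v):
    # Contribution of x_hat to L as a function of the variance node only,
    # with x_centered held fixed (mirrors the partial derivative)
    return np.sum(dx_norm * x_centered / np.sqrt(v + eps))

h = 1e-6
num_dvar = np.array([(loss_of_var(sample_var + h * e) - loss_of_var(sample_var - h * e)) / (2 * h)
                     for e in np.eye(D)])
print(np.max(np.abs(dvar - num_dvar)))
```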

---

#### **Step 5: Gradient w.r.t. mean**

The mean $\mu_j$ affects the loss through TWO paths:
1. Direct path through $\hat{x}_{i,j}$
2. Indirect path through $\sigma^2_j$

$$\frac{\partial \mathcal{L}}{\partial \mu_j} = \underbrace{\sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial \hat{x}_{i,j}} \cdot \frac{\partial \hat{x}_{i,j}}{\partial \mu_j}}_{\text{direct}} + \underbrace{\frac{\partial \mathcal{L}}{\partial \sigma^2_j} \cdot \frac{\partial \sigma^2_j}{\partial \mu_j}}_{\text{through variance}}$$

Where:
- $\frac{\partial \hat{x}_{i,j}}{\partial \mu_j} = -(\sigma^2_j + \epsilon)^{-1/2}$
- $\frac{\partial \sigma^2_j}{\partial \mu_j} = -\frac{2}{N}\sum_{i=1}^N (x_{i,j} - \mu_j)$

Thus:

$$\frac{\partial \mathcal{L}}{\partial \mu_j} = \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial \hat{x}_{i,j}} \cdot \left(-\frac{1}{\sqrt{\sigma^2_j + \epsilon}}\right) + \frac{\partial \mathcal{L}}{\partial \sigma^2_j} \cdot \left(-\frac{2}{N}\sum_{i=1}^N (x_{i,j} - \mu_j)\right)$$

**Code:** `dmu = np.sum(dx_norm * -1.0 / std, axis=0) + dvar * np.sum(-2.0 * x_centered, axis=0) / N`
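The through-variance term in Step 5 is always zero, because it multiplies $\sum_{i}(x_{i,j} - \mu_j) = \sum_i x_{i,j} - N\mu_j = 0$. The sketch below (random stand-in data, variable names following the snippets above) checks this numerically, so in practice $\frac{\partial \mathcal{L}}{\partial \mu_j}$ reduces to the direct term $-\sum_i \frac{\partial \mathcal{L}}{\partial \hat{x}_{i,j}} \big/ \sqrt{\sigma^2_j + \epsilon}$.

```python
import numpy as np

rng = np.random.default_rng(3)
N, D, eps = 7, 4, 1e-5
x = rng.normal(size=(N, D))
dx_norm = rng.normal(size=(N, D))  # stands in for dL/dx_hat
x_centered = x - x.mean(axis=0)
sample_var = x.var(axis=0)
std = np.sqrt(sample_var + eps)
dvar = np.sum(dx_norm * x_centered, axis=0) * -0.5 * (sample_var + eps) ** (-1.5)

# Step 5 as written, with both paths
dmu = np.sum(dx_norm * -1.0 / std, axis=0) + dvar * np.sum(-2.0 * x_centered, axis=0) / N

# The through-variance path is scaled by sum_i x_centered[i, j], which is zero
print(np.sum(x_centered, axis=0))  # zero up to floating-point rounding
dmu_direct_only = -np.sum(dx_norm, axis=0) / std
print(np.max(np.abs(dmu - dmu_direct_only)))
```

Keeping the term is harmless and mirrors the computational graph; dropping it saves one reduction per feature.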

---

#### **Step 6: Gradient w.r.t. input x**

The input $x_{i,j}$ affects the loss through THREE paths:
1. Direct path through $\hat{x}_{i,j}$
2. Path through $\sigma^2_j$
3. Path through $\mu_j$

$$\frac{\partial \mathcal{L}}{\partial x_{i,j}} = \underbrace{\frac{\partial \mathcal{L}}{\partial \hat{x}_{i,j}} \cdot \frac{\partial \hat{x}_{i,j}}{\partial x_{i,j}}}_{\text{direct}} + \underbrace{\frac{\partial \mathcal{L}}{\partial \sigma^2_j} \cdot \frac{\partial \sigma^2_j}{\partial x_{i,j}}}_{\text{through variance}} + \underbrace{\frac{\partial \mathcal{L}}{\partial \mu_j} \cdot \frac{\partial \mu_j}{\partial x_{i,j}}}_{\text{through mean}}$$

Where:
- $\frac{\partial \hat{x}_{i,j}}{\partial x_{i,j}} = (\sigma^2_j + \epsilon)^{-1/2}$
- $\frac{\partial \sigma^2_j}{\partial x_{i,j}} = \frac{2}{N}(x_{i,j} - \mu_j)$
- $\frac{\partial \mu_j}{\partial x_{i,j}} = \frac{1}{N}$
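The variance partial in the list above holds $\mu_j$ fixed, because the dependence of $\sigma^2_j$ on $\mu_j$ is already carried by the third path (through $\frac{\partial \mathcal{L}}{\partial \mu_j}$, which includes the variance route from Step 5). Differentiating the definition of $\sigma^2_j$ with $\mu_j$ treated as a constant:

$$\left.\frac{\partial \sigma^2_j}{\partial x_{i,j}}\right|_{\mu_j\ \text{fixed}} = \frac{1}{N}\sum_{k=1}^N 2\,(x_{k,j}-\mu_j)\,\frac{\partial x_{k,j}}{\partial x_{i,j}} = \frac{2}{N}(x_{i,j}-\mu_j)$$

since $\partial x_{k,j}/\partial x_{i,j}$ equals 1 for $k=i$ and 0 otherwise.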

Therefore:

$$\boxed{\frac{\partial \mathcal{L}}{\partial x_{i,j}} = \frac{\partial \mathcal{L}}{\partial \hat{x}_{i,j}} \cdot \frac{1}{\sqrt{\sigma^2_j + \epsilon}} + \frac{\partial \mathcal{L}}{\partial \sigma^2_j} \cdot \frac{2(x_{i,j} - \mu_j)}{N} + \frac{\partial \mathcal{L}}{\partial \mu_j} \cdot \frac{1}{N}}$$

**Code:** `dx = dx_norm / std + dvar * 2.0 * x_centered / N + dmu / N`
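Substituting Steps 4 and 5 into the boxed result and using $\hat{x}_{i,j} = (x_{i,j}-\mu_j)/\sqrt{\sigma^2_j+\epsilon}$ collapses the three paths into a single expression:

$$\frac{\partial \mathcal{L}}{\partial x_{i,j}} = \frac{1}{N\sqrt{\sigma^2_j+\epsilon}}\left(N\,\frac{\partial \mathcal{L}}{\partial \hat{x}_{i,j}} - \sum_{k=1}^N \frac{\partial \mathcal{L}}{\partial \hat{x}_{k,j}} - \hat{x}_{i,j}\sum_{k=1}^N \frac{\partial \mathcal{L}}{\partial \hat{x}_{k,j}}\,\hat{x}_{k,j}\right)$$

The sketch below, using random stand-in inputs, checks that this compact form matches the step-by-step code. It also shows that each column of `dx` sums to zero, which is expected because adding the same constant to every sample of a feature leaves the output unchanged.

```python
import numpy as np

rng = np.random.default_rng(4)
N, D, eps = 10, 3, 1e-5
x = rng.normal(size=(N, D))
dx_norm = rng.normal(size=(N, D))  # stands in for dL/dx_hat
mu, var = x.mean(axis=0), x.var(axis=0)
std = np.sqrt(var + eps)
x_centered = x - mu
x_norm = x_centered / std

# Step-by-step version (Steps 4 to 6)
dvar = np.sum(dx_norm * x_centered, axis=0) * -0.5 * (var + eps) ** (-1.5)
dmu = np.sum(dx_norm * -1.0 / std, axis=0) + dvar * np.sum(-2.0 * x_centered, axis=0) / N
dx_steps = dx_norm / std + dvar * 2.0 * x_centered / N + dmu / N

# Collapsed single expression
dx_compact = (N * dx_norm - dx_norm.sum(axis=0)
              - x_norm * np.sum(dx_norm * x_norm, axis=0)) / (N * std)

print(np.max(np.abs(dx_steps - dx_compact)))
print(dx_compact.sum(axis=0))  # zero up to floating-point rounding
```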

---

### Summary: Complete Backward Pass Algorithm

```
Input: dout (upstream gradient), cache (x, μ, σ², ε, γ, β)

1. Compute intermediate values:
   std = √(σ² + ε)
   x_centered = x - μ
   x_norm = x_centered / std

2. Compute gradients:
   dbeta = Σᵢ dout[i,j]
   dgamma = Σᵢ dout[i,j] · x_norm[i,j]
   dx_norm = dout · γ
   dvar = Σᵢ dx_norm[i,j] · x_centered[i,j] · (-½)(σ² + ε)^(-3/2)
   dmu = Σᵢ dx_norm[i,j] · (-1/std) + dvar · Σᵢ(-2·x_centered[i,j])/N
   dx = dx_norm/std + dvar·2·x_centered/N + dmu/N

Output: dx, dgamma, dbeta
```
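The summary above maps almost line for line onto NumPy. The sketch below is one possible implementation, assuming float arrays of shape (N, D) and a cache tuple in the summary's input order; `batchnorm_forward`, `batchnorm_backward`, and `numeric_grad_x` are hypothetical names, not a library API. It ends with a central-difference gradient check on the surrogate loss $\mathcal{L} = \sum_{i,j} dy_{i,j}\, y_{i,j}$.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    mu, var = x.mean(axis=0), x.var(axis=0)
    y = gamma * (x - mu) / np.sqrt(var + eps) + beta
    return y, (x, mu, var, eps, gamma, beta)

def batchnorm_backward(dout, cache):
    x, mu, var, eps, gamma, _beta = cache  # beta is not needed for gradients
    N = x.shape[0]
    # 1. Intermediate values
    std = np.sqrt(var + eps)
    x_centered = x - mu
    x_norm = x_centered / std
    # 2. Gradients (sums run over the batch axis)
    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(dout * x_norm, axis=0)
    dx_norm = dout * gamma
    dvar = np.sum(dx_norm * x_centered, axis=0) * -0.5 * (var + eps) ** (-1.5)
    dmu = np.sum(dx_norm * -1.0 / std, axis=0) + dvar * np.sum(-2.0 * x_centered, axis=0) / N
    dx = dx_norm / std + dvar * 2.0 * x_centered / N + dmu / N
    return dx, dgamma, dbeta

def numeric_grad_x(x, gamma, beta, dout, h=1e-5):
    """Central-difference gradient of L = sum(dout * y) with respect to x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += h
        xm[idx] -= h
        lp = np.sum(dout * batchnorm_forward(xp, gamma, beta)[0])
        lm = np.sum(dout * batchnorm_forward(xm, gamma, beta)[0])
        grad[idx] = (lp - lm) / (2 * h)
    return grad

rng = np.random.default_rng(5)
x = rng.normal(size=(6, 4))
gamma, beta = rng.normal(size=4), rng.normal(size=4)
dout = rng.normal(size=(6, 4))

_, cache = batchnorm_forward(x, gamma, beta)
dx, dgamma, dbeta = batchnorm_backward(dout, cache)
num_dx = numeric_grad_x(x, gamma, beta, dout)
print(np.max(np.abs(dx - num_dx)))  # small, on the order of finite-difference error
```

A check like this is a practical way to catch a missing path: dropping either the variance or the mean contribution to `dx` makes the mismatch grow far beyond finite-difference error.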

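This formulation accounts for every path through which $\mu_j$ and $\sigma^2_j$ couple the samples in a batch, which is why each $\frac{\partial \mathcal{L}}{\partial x_{i,j}}$ depends on the whole column $dy_{1:N,j}$ rather than on $dy_{i,j}$ alone.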