# Batch Normalization: Spline Partition Perspective

## Overview

This notebook demonstrates key insights from the paper "Batch Normalization Explained", which shows that:

1. **Spline Partitions**: ReLU networks partition input space into linear regions where the network behaves as an affine transformation
2. **BN Alignment**: Batch Normalization moves partition boundaries to align with training data
3. **Margin Effect**: BN increases decision boundary margins, improving generalization

## Key Concepts

### What are Spline Partitions?

Networks built from piecewise linear activations (ReLU, Leaky-ReLU) compose piecewise linear functions, so the whole network is a **Continuous Piecewise Affine (CPA) spline**. Three ideas describe its geometry:

- **Linear regions**: Convex regions on which the network acts as a single affine transformation
- **Partition boundaries**: The sets where a neuron's pre-activation equals zero (hyperplanes for the first layer, piecewise-linear "bent" hyperplanes for deeper layers)
- **Folding process**: Each additional layer folds and refolds these boundaries, creating increasingly complex partitions
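
As a minimal sketch of the first two points, the toy NumPy network below (random weights, one hidden layer, not one of the models trained later in this notebook) checks that two points sharing an activation pattern are mapped by the same affine function. The helpers `activation_pattern` and `region_affine_map` are illustrative names, not part of `deepkit`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy ReLU network with random weights: 2 inputs -> 8 hidden units -> 1 output
W1, b1 = rng.normal(size=(8, 2)), rng.normal(size=8)
W2, b2 = rng.normal(size=(1, 8)), rng.normal(size=1)

def forward(x):
    return W2 @ np.maximum(W1 @ x + b1, 0.0) + b2

def activation_pattern(x):
    """Which hidden units are active; constant inside one linear region."""
    return W1 @ x + b1 > 0

def region_affine_map(x):
    """Return (A, c) with forward(z) == A @ z + c for all z in x's region."""
    mask = activation_pattern(x)
    return (W2 * mask) @ W1, (W2 * mask) @ b1 + b2

x = np.array([0.3, -0.2])
A, c = region_affine_map(x)

# Distance from x to the nearest partition boundary (a first-layer hyperplane)
radius = np.min(np.abs(W1 @ x + b1) / np.linalg.norm(W1, axis=1))
z = x + 0.5 * radius * np.array([1.0, 0.0])  # a second point in the same region

print(np.array_equal(activation_pattern(x), activation_pattern(z)))
print(np.allclose(forward(z), A @ z + c))
```

Moving less than `radius` cannot flip any hidden unit, which is why `z` is guaranteed to stay in the same region as `x`.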

### How Does BN Help?

BN does more than normalize activations; it also **repositions the partition boundaries relative to the data**:

- **Without BN**: Nothing ties boundary positions to the data, so many can fall in empty regions (wasted capacity)
- **With BN**: Boundaries concentrate densely around training data (efficient use of capacity)
- **Result**: Better generalization through larger effective margins
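
To make the repositioning concrete, here is a small NumPy sketch for a single first-layer unit. It assumes BN uses full-batch statistics and a zero learnable shift (beta = 0); the data, weights, and the helper `mean_distance_to_boundary` are made up for illustration. The unit's boundary is the line where its pre-activation is zero, and subtracting the batch mean moves that line so it passes through the data centroid.

```python
import numpy as np

rng = np.random.default_rng(0)
# Data cluster far from the origin
X = rng.normal(loc=[5.0, -3.0], scale=0.5, size=(200, 2))
w, b = np.array([1.0, 1.0]), 0.1  # one unit with pre-activation z = w.x + b

def mean_distance_to_boundary(X, w, offset):
    """Mean Euclidean distance from the rows of X to the line w.x + offset = 0."""
    return np.mean(np.abs(X @ w + offset)) / np.linalg.norm(w)

z = X @ w + b
# BN with beta = 0 maps z to (z - mean(z)) / std(z); the zero set of that
# expression is the line w.x + b - mean(z) = 0 (the division does not move it)
bn_offset = b - z.mean()

dist_plain = mean_distance_to_boundary(X, w, b)
dist_bn = mean_distance_to_boundary(X, w, bn_offset)
# Zero (up to rounding) means the BN boundary passes through the centroid
centroid_residual = X.mean(axis=0) @ w + bn_offset

print(f"mean distance without BN: {dist_plain:.3f}, with BN: {dist_bn:.3f}")
```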

## Reference

- Paper: [[Batch Normalization Explained.pdf]] in ML Notes vault
- Deepkit utilities: `spline_partitions.py`, `boundary_analysis.py`

## Setup & Imports

Import necessary libraries and configure plotting.

In [None]:
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_circles, make_moons, make_classification
from IPython.display import clear_output
import sys
sys.path.append('../src')

from deepkit.spline_partitions import (
    visualize_partitions_2d,
    compute_partition_entropy
)
from deepkit.boundary_analysis import (
    compute_decision_margin,
    compute_boundary_alignment,
    compute_gradient_norm_statistics
)

plt.style.use('seaborn-v0_8-darkgrid')
print('Imports successful!')

## Dataset: 2D Synthetic Classification

We use 2D datasets to visualize spline partitions and decision boundaries clearly.

In [None]:
# Generate 2D dataset with clear non-linear decision boundary
X, y = make_circles(
    n_samples=500,
    noise=0.1,
    factor=0.5,
    random_state=42
)

X = jnp.array(X)
y = jnp.array(y)

# Visualize dataset
fig, ax = plt.subplots(figsize=(8, 8))
ax.scatter(X[:, 0], X[:, 1], c=y, cmap='RdYlBu', alpha=0.6, edgecolors='black')
ax.set_title('2D Synthetic Dataset: Concentric Circles')
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
plt.tight_layout()
plt.show()

print(f'Dataset shape: {X.shape}')
print(f'Classes: {jnp.unique(y)}')

## Model Definitions

Define two simple MLP architectures:
1. **Without BN**: Standard ReLU network
2. **With BN**: Same architecture with BatchNorm after each linear layer

In [None]:
class MLPWithoutBN(nnx.Module):
    """Simple MLP without Batch Normalization"""
    def __init__(self, rngs: nnx.Rngs):
        self.hidden1 = nnx.Linear(2, 64, rngs=rngs)
        self.hidden2 = nnx.Linear(64, 64, rngs=rngs)
        self.output = nnx.Linear(64, 2, rngs=rngs)
    
    def __call__(self, x):
        x = self.hidden1(x)
        x = nnx.relu(x)
        x = self.hidden2(x)
        x = nnx.relu(x)
        return self.output(x)


class MLPWithBN(nnx.Module):
    """Same MLP with Batch Normalization"""
    def __init__(self, rngs: nnx.Rngs):
        self.hidden1 = nnx.Linear(2, 64, rngs=rngs)
        self.bn1 = nnx.BatchNorm(64, rngs=rngs)
        self.hidden2 = nnx.Linear(64, 64, rngs=rngs)
        self.bn2 = nnx.BatchNorm(64, rngs=rngs)
        self.output = nnx.Linear(64, 2, rngs=rngs)
    
    def __call__(self, x):
        x = self.hidden1(x)
        x = self.bn1(x)
        x = nnx.relu(x)
        x = self.hidden2(x)
        x = self.bn2(x)
        x = nnx.relu(x)
        return self.output(x)

# Initialize models
key = jax.random.key(1337)
model_no_bn = MLPWithoutBN(rngs=nnx.Rngs(key))
model_with_bn = MLPWithBN(rngs=nnx.Rngs(key))

print('Models initialized successfully!')
# Count trainable parameters only (nnx.Param); BN running statistics are excluded
def count_params(model):
    return sum(p.size for p in jax.tree_util.tree_leaves(nnx.state(model, nnx.Param)))

print(f'Without BN parameters: {count_params(model_no_bn):,}')
print(f'With BN parameters: {count_params(model_with_bn):,}')

## Visualization: Untrained Networks

Before training, let's visualize the random decision boundaries.

In [None]:
# Set both models to eval mode for stable visualization
model_no_bn.eval()
model_with_bn.eval()

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

visualize_partitions_2d(
    model_no_bn, X, y,
    ax=axes[0],
    title='Without Batch Normalization (Untrained)'
)

visualize_partitions_2d(
    model_with_bn, X, y,
    ax=axes[1],
    title='With Batch Normalization (Untrained)'
)

plt.tight_layout()
plt.show()

print('Notice: Both untrained models show random decision boundaries')

## Training Setup

Configure training with identical hyperparameters for both models.

In [None]:
# Training hyperparameters
lr = 0.01
momentum = 0.9
num_epochs = 50
batch_size = 64

# Optimizers
opt_no_bn = nnx.Optimizer(model_no_bn, optax.sgd(learning_rate=lr, momentum=momentum), wrt=nnx.Param)
opt_with_bn = nnx.Optimizer(model_with_bn, optax.sgd(learning_rate=lr, momentum=momentum), wrt=nnx.Param)

# Loss function
def loss_fn(model, batch, labels):
    logits = model(batch)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss

# Accuracy function
@nnx.jit
def accuracy(model, batch, labels):
    logits = model(batch)
    preds = jnp.argmax(logits, axis=-1)
    return jnp.mean(preds == labels)

# Training step
@nnx.jit
def train_step(model, optimizer, batch, labels):
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch, labels)
    optimizer.update(model, grads)
    return loss

print('Training configured!')
print(f'Learning rate: {lr}')
print(f'Epochs: {num_epochs}')
print(f'Batch size: {batch_size}')

## Training Loop

Train both models simultaneously and track their progress.

In [None]:
# Training history
history = {
    'no_bn': {'loss': [], 'acc': []},
    'with_bn': {'loss': [], 'acc': []}
}

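# Number of full mini-batches per epoch (a trailing partial batch is dropped)
num_batches = len(X) // batch_size
shuffle_key = jax.random.key(0)  # separate key for per-epoch shuffling

try:
    for epoch in range(num_epochs):
        model_no_bn.train()
        model_with_bn.train()
        
        # Reshuffle each epoch so the dropped remainder and the batch statistics vary
        shuffle_key, perm_key = jax.random.split(shuffle_key)
        perm = jax.random.permutation(perm_key, len(X))
        
        epoch_loss_no_bn = 0.0
        epoch_loss_with_bn = 0.0
        epoch_acc_no_bn = 0.0
        epoch_acc_with_bn = 0.0
        
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            batch_idx = perm[start_idx:end_idx]
            batch_X = X[batch_idx]
            batch_y = y[batch_idx]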
            
            # Train both models
            loss_no_bn = train_step(model_no_bn, opt_no_bn, batch_X, batch_y)
            loss_with_bn = train_step(model_with_bn, opt_with_bn, batch_X, batch_y)
            
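            # Compute accuracy on the same batch, after the update. The models are still
            # in train mode, so BN normalizes with batch statistics in this forward pass.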
            acc_no_bn = accuracy(model_no_bn, batch_X, batch_y)
            acc_with_bn = accuracy(model_with_bn, batch_X, batch_y)
            
            epoch_loss_no_bn += float(loss_no_bn)
            epoch_loss_with_bn += float(loss_with_bn)
            epoch_acc_no_bn += float(acc_no_bn)
            epoch_acc_with_bn += float(acc_with_bn)
        
        # Average metrics
        history['no_bn']['loss'].append(epoch_loss_no_bn / num_batches)
        history['with_bn']['loss'].append(epoch_loss_with_bn / num_batches)
        history['no_bn']['acc'].append(epoch_acc_no_bn / num_batches)
        history['with_bn']['acc'].append(epoch_acc_with_bn / num_batches)
        
        # Visualize progress
        if epoch % 10 == 0 or epoch == num_epochs - 1:
            clear_output(wait=True)
            
            fig, axes = plt.subplots(1, 2, figsize=(14, 4))
            
            # Loss curves
            axes[0].plot(history['no_bn']['loss'], label='No BN', alpha=0.7)
            axes[0].plot(history['with_bn']['loss'], label='With BN', alpha=0.7)
            axes[0].set_title('Training Loss')
            axes[0].set_xlabel('Epoch')
            axes[0].set_ylabel('Loss')
            axes[0].legend()
            
            # Accuracy curves
            axes[1].plot(history['no_bn']['acc'], label='No BN', alpha=0.7)
            axes[1].plot(history['with_bn']['acc'], label='With BN', alpha=0.7)
            axes[1].set_title('Training Accuracy')
            axes[1].set_xlabel('Epoch')
            axes[1].set_ylabel('Accuracy')
            axes[1].legend()
            
            plt.tight_layout()
            plt.show()
            
            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f"  No BN   - Loss: {history['no_bn']['loss'][-1]:.4f}, Acc: {history['no_bn']['acc'][-1]:.4f}")
            print(f"  With BN - Loss: {history['with_bn']['loss'][-1]:.4f}, Acc: {history['with_bn']['acc'][-1]:.4f}")

except KeyboardInterrupt:
    print('\nTraining interrupted!')
else:
    print('\nTraining complete!')

## Visualization: Trained Networks

Now let's see how the decision boundaries have changed after training.

In [None]:
# Set models to eval mode for visualization
model_no_bn.eval()
model_with_bn.eval()

fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# Untrained vs trained for no BN
key_untrained = jax.random.key(1337)
model_untrained = MLPWithoutBN(rngs=nnx.Rngs(key_untrained))
model_untrained.eval()

visualize_partitions_2d(
    model_untrained, X, y,
    ax=axes[0, 0],
    title='No BN: Before Training'
)

visualize_partitions_2d(
    model_no_bn, X, y,
    ax=axes[0, 1],
    title='No BN: After Training'
)

# Untrained vs trained for with BN
model_bn_untrained = MLPWithBN(rngs=nnx.Rngs(key_untrained))
model_bn_untrained.eval()

visualize_partitions_2d(
    model_bn_untrained, X, y,
    ax=axes[1, 0],
    title='With BN: Before Training'
)

visualize_partitions_2d(
    model_with_bn, X, y,
    ax=axes[1, 1],
    title='With BN: After Training'
)

plt.tight_layout()
plt.show()

print('Observe how BN affects the decision boundary geometry!')

## Analysis: Boundary Alignment

Quantify how closely the boundaries sit to the training data: a smaller mean distance and a higher boundary density near the data indicate better alignment.
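
The implementation of `compute_boundary_alignment` is not shown in this notebook. As a rough sketch of what such a metric could look like, the hypothetical helper below handles first-layer hyperplanes only (deeper boundaries are bent and need more work). It reports the mean distance from each point to its nearest hyperplane and the average number of hyperplanes within a chosen `radius`; the function name, the `radius` parameter, and the toy inputs are all assumptions.

```python
import numpy as np

def first_layer_alignment(W, b, X, radius=0.1):
    """Distance statistics between points X (n, d) and hyperplanes W[k].x + b[k] = 0.

    Returns the mean distance from each point to its nearest hyperplane and
    the average number of hyperplanes lying within `radius` of a point.
    """
    # dist[i, k] = |W[k].X[i] + b[k]| / ||W[k]||
    dist = np.abs(X @ W.T + b) / np.linalg.norm(W, axis=1)
    mean_nearest = dist.min(axis=1).mean()
    density = (dist < radius).sum(axis=1).mean()
    return mean_nearest, density

# Toy check: the two coordinate axes act as boundaries
W_toy = np.array([[1.0, 0.0], [0.0, 1.0]])
b_toy = np.zeros(2)
X_toy = np.array([[3.0, 4.0], [0.05, 2.0]])
mean_nearest, density = first_layer_alignment(W_toy, b_toy, X_toy)
print(mean_nearest, density)
```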

In [None]:
# Compute boundary alignment metrics
mean_dist_no_bn, density_no_bn = compute_boundary_alignment(model_no_bn, X, y)
mean_dist_with_bn, density_with_bn = compute_boundary_alignment(model_with_bn, X, y)

print('=== Boundary Alignment Analysis ===')
print(f'\nWithout Batch Normalization:')
print(f'  Mean distance to boundary: {mean_dist_no_bn:.4f}')
print(f'  Boundary density: {density_no_bn:.4f}')

print(f'\nWith Batch Normalization:')
print(f'  Mean distance to boundary: {mean_dist_with_bn:.4f}')
print(f'  Boundary density: {density_with_bn:.4f}')

print(f'\nKey Insights:')
if mean_dist_with_bn < mean_dist_no_bn:
    improvement = (mean_dist_no_bn - mean_dist_with_bn) / mean_dist_no_bn * 100
    print(f'  ✓ BN reduces mean distance by {improvement:.1f}%')
    print(f'  → Boundaries are closer to data (better alignment)')
else:
    print(f'  ✗ BN increased distance (unexpected)')

if density_with_bn > density_no_bn:
    print(f'  ✓ BN increases boundary density')
    print(f'  → More decision boundaries near data points')

## Analysis: Decision Margins

Measure the margin: the distance from the decision boundary to the nearest training point.
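
For intuition, one way to estimate such a margin is a bisection search along segments between points that receive different predictions. The sketch below is a hypothetical stand-in, not the implementation of `compute_decision_margin`; `estimate_margin` and its `num_steps` argument are assumed names. Because it only searches along chosen segments, the result is an upper bound on the true margin.

```python
import numpy as np

def estimate_margin(predict, X, y, num_steps=30):
    """Upper-bound the margin by bisecting segments between opposite predictions.

    `predict` maps an (n, d) array to integer labels. For each correctly
    classified point, bisect the segment to the nearest point with a different
    prediction to locate a label flip, and return the smallest flip distance.
    """
    preds = predict(X)
    best = np.inf
    for i in np.flatnonzero(preds == y):
        other = X[preds != preds[i]]
        if len(other) == 0:
            continue
        target = other[np.argmin(np.linalg.norm(other - X[i], axis=1))]
        lo, hi = 0.0, 1.0  # prediction matches at lo and differs at hi
        for _ in range(num_steps):
            mid = 0.5 * (lo + hi)
            point = X[i] + mid * (target - X[i])
            if predict(point[None, :])[0] == preds[i]:
                lo = mid
            else:
                hi = mid
        best = min(best, hi * np.linalg.norm(target - X[i]))
    return best

# Toy check with a linear classifier whose boundary is the line x0 = 0
predict = lambda Z: (Z[:, 0] > 0).astype(int)
X_toy = np.array([[1.0, 0.0], [-2.0, 0.0]])
y_toy = np.array([1, 0])
margin = estimate_margin(predict, X_toy, y_toy)
print(f"estimated margin: {margin:.6f}")
```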

In [None]:
# Compute decision margins
margin_no_bn = compute_decision_margin(model_no_bn, X, y, num_steps=30)
margin_with_bn = compute_decision_margin(model_with_bn, X, y, num_steps=30)

print('=== Decision Margin Analysis ===')
print(f'\nWithout Batch Normalization:')
print(f'  Decision margin: {margin_no_bn:.4f}')

print(f'\nWith Batch Normalization:')
print(f'  Decision margin: {margin_with_bn:.4f}')

print(f'\nComparison:')
if margin_with_bn > margin_no_bn:
    improvement = (margin_with_bn / margin_no_bn - 1) * 100
    print(f'  ✓ BN increases margin by {improvement:.1f}%')
    print(f'  → Larger margins typically indicate better generalization')
elif margin_with_bn < margin_no_bn:
    decrease = (margin_no_bn / margin_with_bn - 1) * 100
    print(f'  ✗ BN decreases margin by {decrease:.1f}%')
    print(f'  → This may indicate overfitting or other issues')
else:
    print(f'  = Margins are similar')

# Visualize margins
fig, ax = plt.subplots(figsize=(8, 6))
models = ['No BN', 'With BN']
margins = [margin_no_bn, margin_with_bn]
colors = ['red', 'blue']

bars = ax.bar(models, margins, color=colors, alpha=0.7)
ax.set_ylabel('Decision Margin')
ax.set_title('Decision Margin Comparison')

# Add value labels on bars
for bar, margin in zip(bars, margins):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
            f'{margin:.4f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

## Analysis: Gradient Norms

Examine gradient norm statistics as a rough proxy for loss landscape smoothness.

In [None]:
# Compute gradient statistics
grad_stats_no_bn = compute_gradient_norm_statistics(model_no_bn, X, y)
grad_stats_with_bn = compute_gradient_norm_statistics(model_with_bn, X, y)

print('=== Gradient Norm Analysis ===')
print(f'\nWithout Batch Normalization:')
print(f'  Mean gradient norm: {grad_stats_no_bn["mean_grad_norm"]:.4f}')

print(f'\nWith Batch Normalization:')
print(f'  Mean gradient norm: {grad_stats_with_bn["mean_grad_norm"]:.4f}')

print(f'\nInterpretation:')
if grad_stats_with_bn['mean_grad_norm'] < grad_stats_no_bn['mean_grad_norm']:
    reduction = (1 - grad_stats_with_bn['mean_grad_norm'] / grad_stats_no_bn['mean_grad_norm']) * 100
    print(f'  ✓ BN reduces gradient norms by {reduction:.1f}%')
    print(f'  → Smoother loss landscape (easier optimization)')
else:
    increase = (grad_stats_with_bn['mean_grad_norm'] / grad_stats_no_bn['mean_grad_norm'] - 1) * 100
    print(f'  BN increases gradient norms by {increase:.1f}%')

## Summary & Key Findings

### What We Observed

1. **Spline Partitions**: ReLU networks create complex piecewise-linear decision boundaries

2. **BN Alignment**: Batch Normalization affects where partition boundaries form
   - Boundaries should be more densely packed around training data
   - This represents more efficient use of network capacity

3. **Decision Margins**: BN affects the distance from decision boundaries to data
   - Larger margins generally correlate with better generalization
   - The paper links this to the "jitter" that batch statistics introduce during training

4. **Optimization**: Gradient norms differ with/without BN
   - Reflects changes in loss landscape geometry
   - A smoother landscape makes optimization easier, which is the argument of Santurkar et al.

### Connection to Paper

The paper's key insight: **Batch Normalization is not just about normalization**. It geometrically repositions the spline partition boundaries so that they are more useful for the actual data distribution.

- **Without BN**: Nothing ties boundary positions to the data, so many fall in empty regions (wasted capacity)
- **With BN**: Boundaries concentrate around data (efficient capacity utilization)

This geometric effect, combined with the "jitter" from batch statistics, creates larger effective margins and better generalization.
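
The "jitter" can be illustrated with a short NumPy sketch for one first-layer unit, assuming a zero BN shift (beta = 0) and synthetic Gaussian data. In training mode BN subtracts the mini-batch mean, so the unit's boundary passes through the centroid of the current batch and moves from batch to batch. The helpers `boundary_shift` and `jitter` are illustrative names only.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(512, 2))
w = np.array([1.0, 1.0])  # weight vector of one first-layer unit

def boundary_shift(batch):
    """Signed offset, along w, of the batch's BN boundary from the full-data centroid."""
    # With beta = 0 the boundary is the line w.(x - mean(batch)) = 0
    return (batch.mean(axis=0) - X.mean(axis=0)) @ w / np.linalg.norm(w)

def jitter(batch_size, num_batches=200):
    """Standard deviation of the boundary position over random mini-batches."""
    shifts = [
        boundary_shift(X[rng.choice(len(X), batch_size, replace=False)])
        for _ in range(num_batches)
    ]
    return np.std(shifts)

jitter_small, jitter_large = jitter(16), jitter(128)
print(f"boundary jitter, batch 16: {jitter_small:.3f}, batch 128: {jitter_large:.3f}")
```

Smaller batches give noisier batch means and therefore more boundary movement, which is one reason the strength of this effect depends on batch size.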

### Further Experiments

To extend this work:
- Try different 2D datasets (moons, spirals)
- Vary network depth and width
- Compare with other normalization techniques (Layer Norm, Group Norm)
- Analyze how effects scale with dimensionality

### References

- Paper: Balestriero and Baraniuk, "Batch Normalization Explained" (2022)
- Code: `deepkit.spline_partitions`, `deepkit.boundary_analysis`
- Related: Santurkar et al., "How Does Batch Normalization Help Optimization?" (NeurIPS 2018)