# Activation Functions and Normalization

Activation functions and normalization layers are critical components that enable deep neural networks to learn complex patterns and train stably.

In this tutorial, you will learn:

- 🎯 **Activation Functions** - ReLU, GELU, Sigmoid, Tanh and variants
- 📊 **Batch Normalization** - Stabilizing training with BatchNorm
- 📐 **Layer Normalization** - Alternative normalization for sequences
- 🔲 **Group Normalization** - For small batch sizes
- ⚖️ **When to use each** - Practical guidelines

## Why Are These Important?

**Activation Functions** introduce non-linearity. Without them, a stack of linear layers collapses into a single linear map, so extra depth adds no expressive power.  
**Normalization Layers** stabilize training, accelerate convergence, and can act as mild regularizers.
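
To make the non-linearity point concrete, here is a minimal NumPy sketch. It is not part of the brainstate API, and the names `W1`, `W2`, and `x_demo` are made up for illustration. It checks that two stacked linear maps equal one linear map, and that a ReLU in between breaks the equivalence.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 8))      # hypothetical first layer: 4 inputs -> 8 hidden units
W2 = rng.normal(size=(8, 3))      # hypothetical second layer: 8 hidden units -> 3 outputs
x_demo = rng.normal(size=(5, 4))  # batch of 5 inputs

# Two linear layers with no activation equal one linear layer with weights W1 @ W2.
two_linear = (x_demo @ W1) @ W2
one_linear = x_demo @ (W1 @ W2)
print("Linear stack collapses:", np.allclose(two_linear, one_linear))

# A ReLU between the layers breaks the equivalence.
with_relu = np.maximum(x_demo @ W1, 0.0) @ W2
print("Still a single linear map with ReLU:", np.allclose(with_relu, one_linear))
```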

In [None]:
import brainstate
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

## 1. Activation Functions

An activation function is applied, usually element-wise, to a neuron's pre-activation (the weighted sum of its inputs plus a bias) to produce the neuron's output.

### ReLU Family

#### Standard ReLU

In [None]:
# ReLU: max(0, x)
relu = brainstate.nn.ReLU()

x = jnp.linspace(-3, 3, 100)
y = relu(x)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(x, y, linewidth=2, label='ReLU')
plt.axhline(0, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('ReLU Activation', fontweight='bold')
plt.legend()

# Test with data
test_input = jnp.array([-2, -1, 0, 1, 2])
test_output = relu(test_input)
plt.subplot(1, 2, 2)
plt.stem(test_input, test_output, basefmt=' ')
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('ReLU on Sample Points', fontweight='bold')
plt.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Input:  {test_input}")
print(f"Output: {test_output}")
print("\n✅ Advantages: Simple, fast, gradient is exactly 1 for positive inputs (no vanishing there)")
print("⚠️  Issue: 'Dying ReLU', where a neuron whose input stays negative outputs 0 and gets zero gradient")
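
The dying ReLU issue can be seen directly from the gradient. The following is a hand-written NumPy sketch, not the brainstate implementation; `relu_grad` and the toy `weight`/`bias` values are assumptions chosen for illustration. It shows that a unit whose pre-activation is negative for every input receives no gradient at all.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def relu_grad(z):
    # Convention assumed here: the derivative at exactly z == 0 is taken to be 0.
    return (z > 0).astype(z.dtype)

z = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print("ReLU:    ", relu(z))
print("Gradient:", relu_grad(z))

# A "dead" unit: a large negative bias keeps every pre-activation below zero,
# so the gradient with respect to the weight is zero for every input in the batch.
inputs = np.linspace(-1.0, 1.0, 5)
weight, bias = 1.0, -5.0
pre_act = weight * inputs + bias
grad_w = relu_grad(pre_act) * inputs  # d relu(w * x + b) / dw for each input
print("Weight gradient of dead unit:", grad_w)
```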

#### Leaky ReLU

In [None]:
# LeakyReLU: max(alpha * x, x)
leaky_relu = brainstate.nn.LeakyReLU(negative_slope=0.1)

y_leaky = leaky_relu(x)

plt.figure(figsize=(8, 5))
plt.plot(x, y, linewidth=2, label='ReLU', alpha=0.5)
plt.plot(x, y_leaky, linewidth=2, label='Leaky ReLU (α=0.1)')
plt.axhline(0, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('ReLU vs Leaky ReLU', fontweight='bold')
plt.legend()
plt.show()

print("✅ Advantage: Allows small gradient for negative values")
print("   Helps prevent dying ReLU problem")
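
As a rough check of the formula in the comment above, this NumPy sketch writes Leaky ReLU and its gradient by hand. The function names are illustrative, not brainstate functions. It assumes `0 < negative_slope < 1`, which is when `max(alpha * x, x)` and the piecewise definition agree.

```python
import numpy as np

def leaky_relu(z, negative_slope=0.1):
    # Piecewise form: identity for positive inputs, a small slope otherwise.
    return np.where(z > 0, z, negative_slope * z)

def leaky_relu_grad(z, negative_slope=0.1):
    # The gradient never drops to zero, unlike plain ReLU.
    return np.where(z > 0, 1.0, negative_slope)

z = np.array([-2.0, -0.5, 0.5, 2.0])
out = leaky_relu(z)
grad = leaky_relu_grad(z)

print("Output:  ", out)
print("Gradient:", grad)
print("Matches max(alpha * x, x):", np.allclose(out, np.maximum(0.1 * z, z)))
```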

### Modern Activation Functions

#### GELU (Gaussian Error Linear Unit)

In [None]:
# GELU: x * Φ(x), where Φ is the standard normal CDF (a smooth alternative to ReLU)
gelu = brainstate.nn.GELU()

y_gelu = gelu(x)

plt.figure(figsize=(8, 5))
plt.plot(x, y, linewidth=2, label='ReLU', alpha=0.5)
plt.plot(x, y_gelu, linewidth=2, label='GELU')
plt.axhline(0, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('ReLU vs GELU', fontweight='bold')
plt.legend()
plt.show()

print("✅ Used in Transformers (BERT, GPT)")
print("✅ Smooth, differentiable everywhere")
print("✅ Motivated as the expected value of a stochastic, dropout-like gate (the function itself is deterministic)")
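
GELU is commonly written either exactly, as `x * Φ(x)`, or with a tanh-based approximation. The sketch below implements both by hand and compares them. Which variant `brainstate.nn.GELU` uses by default is not stated in this tutorial, so treat the comparison as background rather than a description of that layer.

```python
import math
import numpy as np

def gelu_exact(z):
    # x * Phi(x), with the standard normal CDF written in terms of erf.
    return np.array([0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in z])

def gelu_tanh(z):
    # Widely used tanh approximation of the same function.
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * z * (1.0 + np.tanh(c * (z + 0.044715 * z**3)))

z = np.linspace(-3.0, 3.0, 13)
max_gap = np.max(np.abs(gelu_exact(z) - gelu_tanh(z)))
print(f"Largest gap between exact and tanh forms on [-3, 3]: {max_gap:.2e}")
```

Unlike ReLU, GELU is slightly negative for small negative inputs, which is visible as the dip left of zero in the plot above.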

### Classic Activations

#### Sigmoid and Tanh

In [None]:
# Sigmoid: 1 / (1 + exp(-x))
sigmoid = brainstate.nn.Sigmoid()

# Tanh: (exp(x) - exp(-x)) / (exp(x) + exp(-x))
tanh = brainstate.nn.Tanh()

y_sigmoid = sigmoid(x)
y_tanh = tanh(x)

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(x, y_sigmoid, linewidth=2, label='Sigmoid', color='blue')
plt.axhline(0.5, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Sigmoid: σ(x) = 1/(1+e⁻ˣ)', fontweight='bold')
plt.legend()
plt.ylim([-0.1, 1.1])

plt.subplot(1, 2, 2)
plt.plot(x, y_tanh, linewidth=2, label='Tanh', color='green')
plt.axhline(0, color='gray', linestyle='--', alpha=0.3)
plt.axvline(0, color='gray', linestyle='--', alpha=0.3)
plt.grid(alpha=0.3)
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Tanh: tanh(x)', fontweight='bold')
plt.legend()
plt.ylim([-1.1, 1.1])

plt.tight_layout()
plt.show()

print("Sigmoid:")
print("  ✅ Output range: (0, 1) - good for probabilities")
print("  ⚠️  Vanishing gradients for large |x|")
print("\nTanh:")
print("  ✅ Output range: (-1, 1) - zero-centered")
print("  ⚠️  Also suffers from vanishing gradients")
print("  📝 Used for the candidate and hidden state in RNNs/LSTMs (the LSTM gates themselves use sigmoid)")
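
The vanishing-gradient warnings above follow from the derivatives, σ'(x) = σ(x)(1 − σ(x)) and tanh'(x) = 1 − tanh²(x). This small NumPy sketch evaluates both at a few points and also checks the identity tanh(x) = 2σ(2x) − 1. The helper name `sigmoid_fn` is illustrative.

```python
import numpy as np

def sigmoid_fn(z):
    return 1.0 / (1.0 + np.exp(-z))

z = np.array([0.0, 2.0, 5.0, 10.0])

sig_grad = sigmoid_fn(z) * (1.0 - sigmoid_fn(z))  # peaks at 0.25 when z = 0
tanh_grad = 1.0 - np.tanh(z) ** 2                 # peaks at 1.0 when z = 0

for zi, sg, tg in zip(z, sig_grad, tanh_grad):
    print(f"x={zi:5.1f}  sigmoid'={sg:.2e}  tanh'={tg:.2e}")

# tanh is a rescaled and shifted sigmoid.
identity_holds = np.allclose(np.tanh(z), 2.0 * sigmoid_fn(2.0 * z) - 1.0)
print("tanh(x) == 2*sigmoid(2x) - 1:", identity_holds)
```

Because the sigmoid gradient is at most 0.25, a chain of sigmoid layers shrinks gradients even near zero, which is one reason ReLU-style activations are preferred in hidden layers.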

### Comparing All Activations

In [None]:
# Compare multiple activations
activations = {
    'ReLU': brainstate.nn.ReLU(),
    'LeakyReLU': brainstate.nn.LeakyReLU(0.1),
    'GELU': brainstate.nn.GELU(),
    'ELU': brainstate.nn.ELU(),
    'Sigmoid': brainstate.nn.Sigmoid(),
    'Tanh': brainstate.nn.Tanh(),
}

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for idx, (name, activation) in enumerate(activations.items()):
    y_act = activation(x)
    axes[idx].plot(x, y_act, linewidth=2.5, color=f'C{idx}')
    axes[idx].axhline(0, color='gray', linestyle='--', alpha=0.3)
    axes[idx].axvline(0, color='gray', linestyle='--', alpha=0.3)
    axes[idx].grid(alpha=0.3)
    axes[idx].set_xlabel('Input', fontsize=10)
    axes[idx].set_ylabel('Output', fontsize=10)
    axes[idx].set_title(name, fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n📊 Activation Function Guide:\n")
print("ReLU:        Default choice, fast, works well in most cases")
print("LeakyReLU:   When dying ReLU is a problem")
print("GELU:        Transformers, NLP models")
print("ELU:         Smooth ReLU variant that saturates to -α for negative inputs")
print("Sigmoid:     Output layer for binary classification, LSTM/GRU gates")
print("Tanh:        RNN/LSTM hidden and candidate states, when zero-centered output needed")

### Softmax - For Classification

In [None]:
# Softmax: converts logits to probabilities
softmax = brainstate.nn.Softmax()

# Example logits from a classifier
logits = jnp.array([2.0, 1.0, 0.5, 3.0, 0.1])
probs = softmax(logits)

print("Logits:       ", logits)
print("Probabilities:", probs)
print(f"Sum of probs:  {jnp.sum(probs):.6f} (should be 1.0)")

# Visualize
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

classes = ['Cat', 'Dog', 'Bird', 'Fish', 'Horse']

axes[0].bar(classes, logits, color='steelblue', alpha=0.7)
axes[0].set_ylabel('Logit Value')
axes[0].set_title('Raw Logits', fontweight='bold')
axes[0].grid(axis='y', alpha=0.3)

axes[1].bar(classes, probs, color='coral', alpha=0.7)
axes[1].set_ylabel('Probability')
axes[1].set_title('After Softmax', fontweight='bold')
axes[1].set_ylim([0, 1])
axes[1].grid(axis='y', alpha=0.3)

# Add values on bars
for i, v in enumerate(probs):
    axes[1].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=9)

plt.tight_layout()
plt.show()

print("\n✅ Softmax converts logits to valid probability distribution")
print("   Use for multi-class classification output layer")
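
For reference, softmax is `exp(z_i) / sum_j exp(z_j)`. A direct implementation overflows for large logits, so the usual trick is to subtract the maximum first, which leaves the result unchanged. The sketch below is a hand-written NumPy version, not the brainstate layer; the name `stable_softmax` is made up.

```python
import numpy as np

def stable_softmax(logits, axis=-1):
    # Subtracting the max does not change the result but keeps exp() in range.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

logits_demo = np.array([2.0, 1.0, 0.5, 3.0, 0.1])
probs_demo = stable_softmax(logits_demo)

# Softmax is shift-invariant: adding a constant to every logit gives the same
# probabilities, even when a naive exp(1003.0) would overflow to inf.
probs_shifted = stable_softmax(logits_demo + 1000.0)

print("Probabilities:", probs_demo)
print("Sum:", probs_demo.sum())
print("Shift-invariant:", np.allclose(probs_demo, probs_shifted))
```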

## 2. Normalization Layers

Normalization stabilizes training by controlling the distribution of activations.

### Batch Normalization

In [None]:
# BatchNorm: normalizes each feature using statistics computed across the batch dimension
brainstate.random.seed(42)
batch_norm = brainstate.nn.BatchNorm1d(num_features=10)

print("BatchNorm1d:")
print(batch_norm)

# Create batch of data with varying statistics
batch_size = 32
features = 10
x = brainstate.random.randn(batch_size, features) * 5 + 10  # mean≈10, std≈5

print(f"\nBefore BatchNorm:")
print(f"  Mean: {jnp.mean(x, axis=0)[:3]}...")
print(f"  Std:  {jnp.std(x, axis=0)[:3]}...")

# Apply batch norm
y = batch_norm(x)

print(f"\nAfter BatchNorm:")
print(f"  Mean: {jnp.mean(y, axis=0)[:3]}...")
print(f"  Std:  {jnp.std(y, axis=0)[:3]}...")

print("\n✅ Normalizes each feature to mean≈0, std≈1 using batch statistics")
print("✅ Then applies learnable scale (γ) and shift (β) parameters")
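
What the layer computes in training mode can be written in a few lines. This is a minimal NumPy sketch of the standard batch normalization formula, `y = γ * (x − μ) / sqrt(σ² + ε) + β`, with μ and σ² taken over the batch axis. It ignores running statistics, and the function name and `eps` value are assumptions rather than brainstate defaults.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    # Per-feature statistics over the batch axis (axis 0).
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # Learnable scale and shift; gamma=1, beta=0 leaves x_hat unchanged.
    return gamma * x_hat + beta

rng = np.random.default_rng(42)
x_bn = rng.normal(loc=10.0, scale=5.0, size=(32, 10))  # mean≈10, std≈5
y_bn = batch_norm_train(x_bn, gamma=np.ones(10), beta=np.zeros(10))

print("Mean after:", y_bn.mean(axis=0)[:3])
print("Std after: ", y_bn.std(axis=0)[:3])
```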

### Visualizing BatchNorm Effect

In [None]:
# Generate data with different distributions
brainstate.random.seed(0)
x1 = brainstate.random.randn(1000) * 10 + 50  # High variance, high mean
x2 = brainstate.random.randn(1000) * 2 - 5     # Low variance, negative mean

# One BatchNorm layer with a single feature; the 1000 samples act as the batch dimension
bn = brainstate.nn.BatchNorm1d(num_features=1)

# Reshape to (batch, features)
x1_batch = x1[:, None]
x2_batch = x2[:, None]

# Apply normalization
y1 = bn(x1_batch).flatten()
y2 = bn(x2_batch).flatten()

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Distribution 1 - before
axes[0, 0].hist(np.array(x1), bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[0, 0].set_title('Distribution 1 - Before BN', fontweight='bold')
axes[0, 0].set_xlabel('Value')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].axvline(jnp.mean(x1), color='red', linestyle='--', label=f'μ={jnp.mean(x1):.1f}')
axes[0, 0].legend()

# Distribution 1 - after
axes[0, 1].hist(np.array(y1), bins=50, alpha=0.7, color='green', edgecolor='black')
axes[0, 1].set_title('Distribution 1 - After BN', fontweight='bold')
axes[0, 1].set_xlabel('Value')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].axvline(jnp.mean(y1), color='red', linestyle='--', label=f'μ={jnp.mean(y1):.2f}')
axes[0, 1].legend()

# Distribution 2 - before
axes[1, 0].hist(np.array(x2), bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[1, 0].set_title('Distribution 2 - Before BN', fontweight='bold')
axes[1, 0].set_xlabel('Value')
axes[1, 0].set_ylabel('Frequency')
axes[1, 0].axvline(jnp.mean(x2), color='red', linestyle='--', label=f'μ={jnp.mean(x2):.1f}')
axes[1, 0].legend()

# Distribution 2 - after
axes[1, 1].hist(np.array(y2), bins=50, alpha=0.7, color='green', edgecolor='black')
axes[1, 1].set_title('Distribution 2 - After BN', fontweight='bold')
axes[1, 1].set_xlabel('Value')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].axvline(jnp.mean(y2), color='red', linestyle='--', label=f'μ={jnp.mean(y2):.2f}')
axes[1, 1].legend()

plt.tight_layout()
plt.show()

print("BatchNorm maps both distributions to roughly zero mean and unit variance, regardless of their original scale")

### Layer Normalization

In [None]:
# LayerNorm: Normalizes across features (not batch)
layer_norm = brainstate.nn.LayerNorm(normalized_shape=(10,))

print("LayerNorm:")
print(layer_norm)

# Single sample
x_single = brainstate.random.randn(10) * 5 + 10

print(f"\nBefore LayerNorm (single sample):")
print(f"  Values: {x_single}")
print(f"  Mean: {jnp.mean(x_single):.3f}")
print(f"  Std:  {jnp.std(x_single):.3f}")

y_single = layer_norm(x_single)

print(f"\nAfter LayerNorm:")
print(f"  Values: {y_single}")
print(f"  Mean: {jnp.mean(y_single):.3f}")
print(f"  Std:  {jnp.std(y_single):.3f}")

print("\n✅ LayerNorm works on single samples")
print("✅ Popular in Transformers and RNNs")
print("✅ Independent of batch size")
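
The same computation can be sketched by hand. The only change from batch normalization is the axis: statistics are taken over the feature axis of each sample. This is an illustrative NumPy version with an assumed `eps`, not the brainstate implementation.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Statistics over the last (feature) axis, computed separately per sample.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
gamma, beta = np.ones(10), np.zeros(10)

x_one = rng.normal(loc=10.0, scale=5.0, size=(10,))    # one sample, no batch axis
x_many = rng.normal(loc=10.0, scale=5.0, size=(4, 10))  # the same function handles a batch

y_one = layer_norm(x_one, gamma, beta)
y_many = layer_norm(x_many, gamma, beta)

print("Single sample mean/std:", y_one.mean(), y_one.std())
print("Per-sample means in a batch:", y_many.mean(axis=-1))
```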

### Group Normalization

In [None]:
# GroupNorm: Divides channels into groups and normalizes within groups
group_norm = brainstate.nn.GroupNorm(num_groups=4, num_channels=16)

print("GroupNorm:")
print(group_norm)

# 2D feature map: (batch, height, width, channels)
x_img = brainstate.random.randn(2, 8, 8, 16) * 3 + 2
y_img = group_norm(x_img)

print(f"\nInput shape: {x_img.shape}")
print(f"Output shape: {y_img.shape}")
print(f"Number of groups: {group_norm.num_groups}")
print(f"Channels per group: {16 // 4}")

print("\n✅ Works well with small batch sizes")
print("✅ Alternative to BatchNorm for small batches")
print("✅ Used in computer vision models")
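
To show which values share statistics, here is a hand-written NumPy sketch of group normalization for a channels-last `(batch, height, width, channels)` tensor. It omits the learnable scale and shift, assumes contiguous channel groups, and uses made-up names, so read it as an illustration of the idea rather than the brainstate layer.

```python
import numpy as np

def group_norm(x, num_groups, eps=1e-5):
    n, h, w, c = x.shape
    # Split the channel axis into (groups, channels_per_group).
    grouped = x.reshape(n, h, w, num_groups, c // num_groups)
    # One mean/variance per sample and per group, over space and the group's channels.
    mean = grouped.mean(axis=(1, 2, 4), keepdims=True)
    var = grouped.var(axis=(1, 2, 4), keepdims=True)
    normed = (grouped - mean) / np.sqrt(var + eps)
    return normed.reshape(n, h, w, c)

rng = np.random.default_rng(0)
x_gn = rng.normal(loc=2.0, scale=3.0, size=(2, 8, 8, 16))
y_gn = group_norm(x_gn, num_groups=4)

per_group = y_gn.reshape(2, 8, 8, 4, 4)
print("Output shape:", y_gn.shape)
print("Mean per (sample, group):\n", per_group.mean(axis=(1, 2, 4)))
```

No statistic is shared across the batch axis, which is why the result does not depend on batch size.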

### Comparing Normalization Methods

In [None]:
import pandas as pd

comparison = pd.DataFrame({
    'Method': ['BatchNorm', 'LayerNorm', 'GroupNorm', 'InstanceNorm'],
    'Normalizes Over': ['Batch + spatial, per channel', 'Features, per sample', 'Channel groups + spatial, per sample', 'Spatial, per sample and channel'],
    'Best For': ['CNNs, large batches', 'RNNs, Transformers', 'Small batches', 'Style transfer'],
    'Batch Dependent': ['Yes', 'No', 'No', 'No'],
    'Typical Use': ['Vision', 'NLP', 'Vision (small batch)', 'GANs, Style'],
})

print("\n📊 Normalization Method Comparison:\n")
print(comparison.to_string(index=False))

print("\n\n🎯 Quick Guide:")
print("  • Large batch + CNN → BatchNorm")
print("  • Small batch + CNN → GroupNorm")
print("  • Sequences/RNN/Transformer → LayerNorm")
print("  • Single-sample batches → LayerNorm or GroupNorm (no dependence on batch statistics)")
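
The "Batch Dependent" column can be checked with a small experiment. The sketch below is plain NumPy with a made-up `normalize` helper. It puts the same sample into two different batches and compares batch-axis statistics (BatchNorm-style) with feature-axis statistics (LayerNorm-style).

```python
import numpy as np

def normalize(x, axis, eps=1e-5):
    mean = x.mean(axis=axis, keepdims=True)
    var = x.var(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 6))

# The same sample placed in two batches with different companions.
batch_a = np.concatenate([sample, rng.normal(size=(3, 6))])
batch_b = np.concatenate([sample, rng.normal(loc=5.0, size=(3, 6))])

# BatchNorm-style (axis=0): the output for row 0 depends on the other rows.
bn_a, bn_b = normalize(batch_a, axis=0)[0], normalize(batch_b, axis=0)[0]
# LayerNorm-style (axis=-1): the output for row 0 is the same in both batches.
ln_a, ln_b = normalize(batch_a, axis=-1)[0], normalize(batch_b, axis=-1)[0]

print("BatchNorm-style output unchanged:", np.allclose(bn_a, bn_b))
print("LayerNorm-style output unchanged:", np.allclose(ln_a, ln_b))
```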

## 3. Putting It All Together

Building a complete network with activations and normalization:

In [None]:
class ModernCNN(brainstate.graph.Node):
    """CNN with modern activations and normalization."""
    
    def __init__(self, num_classes=10):
        super().__init__()
        
        # Block 1: Conv + BatchNorm + GELU
        self.conv1 = brainstate.nn.Conv2d(3, 64, kernel_size=(3, 3), padding='SAME')
        self.bn1 = brainstate.nn.BatchNorm2d(64)
        self.act1 = brainstate.nn.GELU()
        self.pool1 = brainstate.nn.MaxPool2d(kernel_size=(2, 2), strides=(2, 2))
        
        # Block 2
        self.conv2 = brainstate.nn.Conv2d(64, 128, kernel_size=(3, 3), padding='SAME')
        self.bn2 = brainstate.nn.BatchNorm2d(128)
        self.act2 = brainstate.nn.GELU()
        self.pool2 = brainstate.nn.MaxPool2d(kernel_size=(2, 2), strides=(2, 2))
        
        # Classifier
        self.flatten = brainstate.nn.Flatten()
        self.fc1 = brainstate.nn.Linear(128 * 8 * 8, 256)
        self.ln = brainstate.nn.LayerNorm((256,))
        self.act3 = brainstate.nn.GELU()
        self.dropout = brainstate.nn.Dropout(p=0.5)
        self.fc2 = brainstate.nn.Linear(256, num_classes)
    
    def __call__(self, x):
        # Block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.pool1(x)
        
        # Block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act2(x)
        x = self.pool2(x)
        
        # Classifier
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.ln(x)
        x = self.act3(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# Create and test
brainstate.random.seed(0)
model = ModernCNN(num_classes=10)

# Forward pass
x = brainstate.random.randn(4, 32, 32, 3)  # 4 images
logits = model(x)

print("Modern CNN with GELU + BatchNorm + LayerNorm:")
print(model)
print(f"\nInput: {x.shape}")
print(f"Output: {logits.shape}")
print(f"\nLogits: {logits[0]}")
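
The `128 * 8 * 8` input size of `fc1` comes from tracking the spatial size through the network. The sketch below redoes that arithmetic in plain Python. It assumes the convolutions use stride 1, so that `'SAME'` padding preserves height and width; the helper name is made up.

```python
def size_after_block(size, pool=2):
    # Assumed: a stride-1 'SAME' convolution keeps the spatial size,
    # then 2x2 max pooling with stride 2 halves it.
    return size // pool

spatial = 32  # input images are 32 x 32
for block in range(2):
    spatial = size_after_block(spatial)
    print(f"After block {block + 1}: {spatial} x {spatial}")

channels = 128  # output channels of the second convolution
flat_features = channels * spatial * spatial
print("Features entering fc1:", flat_features)
```

If the input resolution changes, the first argument of `fc1` has to change with it.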

## Summary

In this tutorial, you learned:

✅ **Activation Functions**
  - ReLU family (ReLU, Leaky ReLU, ELU)
  - Modern activations (GELU)
  - Classic activations (Sigmoid, Tanh)
  - Softmax for classification

✅ **Normalization Layers**
  - BatchNorm for large-batch training
  - LayerNorm for sequences and transformers
  - GroupNorm for small-batch scenarios

✅ **Practical Guidelines**
  - When to use each activation
  - When to use each normalization
  - How to combine them effectively

### Quick Reference Card

| Task | Activation | Normalization |
|------|-----------|---------------|
| **CNN (large batch)** | ReLU/GELU | BatchNorm |
| **CNN (small batch)** | ReLU/GELU | GroupNorm |
| **Transformer/NLP** | GELU | LayerNorm |
| **RNN/LSTM** | Sigmoid (gates), Tanh (state) | LayerNorm |
| **Binary output** | Sigmoid | - |
| **Multi-class output** | Softmax | - |

### Best Practices

1. 🎯 **Use ReLU/GELU** as default activations
2. 📊 **Add normalization** after conv/linear layers
3. ⚡ **Order**: Conv/Linear → Norm → Activation
4. 🔍 **Experiment** with activation functions
5. 📝 **Use appropriate normalization** for your batch size

### Next Steps

Continue with:
- **Recurrent Networks** - Handle sequential data
- **Training** - Optimize with gradient descent
- **Advanced Architectures** - ResNets, Transformers