# Visualize and Compare Activation Functions

We'll train a **simple MLP** on **MNIST** using **Sigmoid**, **Tanh**, and **ReLU**.  
We’ll compare:
- Training speed (epochs to reach ~95% val accuracy)
- Final accuracy
- Gradient flow (via histogram of gradients)

> **Minimal & Fast** – uses Keras, small network, early stopping.

In [1]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import seaborn as sns

# Set seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Load MNIST
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28*28).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28*28).astype("float32") / 255.0

# One-hot encode labels
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

## Define MLP Model Factory

In [2]:
def create_mlp(activation):
    model = keras.Sequential([
        layers.Dense(128, activation=activation, input_shape=(784,)),
        layers.Dense(64, activation=activation),
        layers.Dense(10, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

## Train Models with Early Stopping
We'll track:
- `history` → loss/accuracy
- `grad_norms` → average gradient norm per epoch (to detect vanishing/exploding)

In [3]:
activations = ['sigmoid', 'tanh', 'relu']
histories = {}
grad_norms = {}

# Custom callback to record gradient norms
class GradientLogger(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.norms = []

    def on_epoch_end(self, epoch, logs=None):
        # Compute gradients using GradientTape
        with tf.GradientTape() as tape:
            # Forward pass on a batch (use first batch_size samples)
            batch_x = self.model.input[0][:self.model.optimizer.batch_size or 32]
            batch_y = self.model.input[1][:self.model.optimizer.batch_size or 32]
            predictions = self.model(batch_x, training=True)
            loss = self.model.compiled_loss(batch_y, predictions)

        grads = tape.gradient(loss, self.model.trainable_variables)
        grad_norms = [tf.norm(g).numpy() for g in grads if g is not None]
        if grad_norms:
            self.norms.append(np.mean(grad_norms))
        else:
            self.norms.append(0.0)

# Re-define activations & histories (clear previous)
activations = ['sigmoid', 'tanh', 'relu']
histories = {}
grad_norms = {}

for act in activations:
    print(f"\nTraining with {act.upper()}...")
    model = create_mlp(act)

    grad_logger = GradientLogger()

    early_stop = keras.callbacks.EarlyStopping(
        monitor='val_accuracy', patience=3, restore_best_weights=True, min_delta=0.001
    )

    history = model.fit(
        x_train, y_train,
        epochs=50,
        batch_size=128,
        validation_split=0.2,
        verbose=1,
        callbacks=[early_stop, grad_logger]
    )

    histories[act] = history
    grad_norms[act] = grad_logger.norms


Training with SIGMOID...


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/50
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.6339 - loss: 1.4447

AttributeError: The layer sequential has never been called and thus has no defined input.

## Plot Training Speed & Accuracy

In [None]:
plt.figure(figsize=(15, 5))

# Accuracy over epochs
plt.subplot(1, 3, 1)
for act in activations:
    val_acc = histories[act].history['val_accuracy']
    plt.plot(val_acc, label=f'{act} (best: {max(val_acc):.3f})')
plt.title('Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

# Training loss
plt.subplot(1, 3, 2)
for act in activations:
    loss = histories[act].history['loss']
    plt.plot(loss, label=act)
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Gradient norm (log scale)
plt.subplot(1, 3, 3)
for act in activations:
    norms = grad_norms[act]
    plt.plot(norms, label=act)
plt.yscale('log')
plt.title('Avg Gradient Norm (log scale)')
plt.xlabel('Epoch')
plt.ylabel('Gradient Norm')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

## Final Results Table

In [None]:
results = []
for act in activations:
    h = histories[act]
    best_val_acc = max(h.history['val_accuracy'])
    epochs_trained = len(h.history['val_accuracy'])
    final_grad_norm = grad_norms[act][-1] if grad_norms[act] else None
    results.append({
        'Activation': act,
        'Best Val Acc': f'{best_val_acc:.4f}',
        'Epochs': epochs_trained,
        'Final Grad Norm': f'{final_grad_norm:.2e}' if final_grad_norm else 'N/A'
    })

import pandas as pd
df = pd.DataFrame(results)
df