## Activation functions: vanishing gradient problem

### Introduction
Deep learning models are powerful, but they sometimes encounter issues during training. One such issue is the *vanishing gradient problem*, where the gradients of earlier layers become extremely small, making it difficult for the network to learn effectively. This means the earlier layers of a deep network learn very slowly or stop learning completely.

As you explore deep learning, you will come across terms such as *sigmoid*, *tanh*, and *ReLU*. These are *activation functions*, which determine the value of the output in each layer of a neural network. For example, the output of the final layer might be used to predict a 1 or 0 based on the presence of heart disease.

In this section we look at what the vanishing gradient problem is, why it happens in deep neural networks, how it affects training, and which solutions help. We will also explain how *batch normalisation* and *residual connections* work, and why these techniques can help. You will then have a clearer understanding of this fundamental issue and be equipped to fix it if it occurs.

By the end, you should understand why deep networks struggle to learn, how to detect the problem, and what techniques you can use to fix it.

### What are gradients?
Before discussing the vanishing gradient problem, we need to understand what gradients are.

- Neural networks learn using gradients, which indicate how much to adjust each weight to reduce the error (or loss).
- This is done using *backpropagation*, a process that calculates the gradient at each layer of the network.
- The gradient is a measure of how a small change in a weight affects the loss function.

The key idea is that if the gradients become too small, the updates to the weights also become small, meaning the model struggles to learn effectively.
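
As a minimal sketch of these ideas, the snippet below uses a made-up one-weight model (`prediction = w * x`) with a squared-error loss. The numbers and the helper names `loss` and `gradient` are purely illustrative assumptions; the point is only to show a gradient being computed and used for one weight update.

```python
# Toy model: prediction = w * x, loss = (prediction - y)^2
x, y = 2.0, 10.0          # one training example (made-up values)
w = 1.0                   # initial weight
learning_rate = 0.05

def loss(w):
    return (w * x - y) ** 2

def gradient(w):
    # d(loss)/dw = 2 * (w*x - y) * x
    return 2 * (w * x - y) * x

loss_before = loss(w)
w_new = w - learning_rate * gradient(w)   # one gradient descent step
loss_after = loss(w_new)

print(f"gradient at w={w}: {gradient(w)}")
print(f"weight: {w} -> {w_new}")
print(f"loss:   {loss_before} -> {loss_after}")
```

The size of the update is `learning_rate * gradient`, so if the gradient were close to zero the weight would barely move, whatever the loss.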

#### Why do gradients vanish?

When training deep neural networks, backpropagation works out how much each weight contributes to the final error; this quantity is the *gradient*. In deep networks, these gradients can become so small that they effectively disappear. This is called the *vanishing gradient* problem, and it occurs for several reasons:

- *Activation saturation*: Functions like *sigmoid* and *tanh* squash input values into a narrow range. When the input is very large or very small, these functions become almost flat, which means their gradient (slope) is close to zero. This makes it hard for the network to learn.

- *Repeated multiplication*: During backpropagation, each layer multiplies the gradient arriving from the layer after it by its own local derivative. If each of these factors is less than one, their repeated multiplication quickly shrinks the gradient to a tiny value, often close to zero.

- *Poor weight initialisation*: If the weights are badly chosen at the start, gradients can vanish early in training. Weights that are too large push activations straight into the flat areas of the activation function, while weights that are too small shrink the signal (and the gradient passing back through them) at every layer.

When gradients vanish, the earlier layers of the network receive little or no useful information about how to improve. As a result, they stop learning, and the deeper the network, the more serious the problem becomes.
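
The repeated-multiplication effect can be illustrated with simple arithmetic. The sketch below looks only at the activation-derivative factor and ignores the weights, which is a simplifying assumption; it uses the fact that the sigmoid derivative never exceeds 0.25.

```python
# The sigmoid derivative is at most 0.25, so in the best case each sigmoid
# layer scales the backpropagated gradient by 0.25 (weights are ignored here,
# which is a simplification).
max_sigmoid_slope = 0.25

factors = {depth: max_sigmoid_slope ** depth for depth in (1, 2, 4, 8, 16)}

for depth, factor in factors.items():
    print(f"{depth:2d} layers -> gradient scaled by at most {factor:.2e}")
```

Even under this best-case assumption, the factor shrinks geometrically with depth.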

To address this, modern networks often use alternative activation functions (such as *ReLU*) and improved weight initialisation techniques to help keep gradients flowing.
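
To make the initialisation point concrete, here is a rough NumPy sketch comparing a naive initialisation (standard deviation 1) with a Glorot-style scale of `sqrt(2 / (fan_in + fan_out))` for a single sigmoid layer. The random input batch is a stand-in for real data and the helper `mean_sigmoid_slope` is a name introduced only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_inputs, n_units = 784, 128
x = rng.random((256, n_inputs))            # stand-in for 256 flattened images in [0, 1]

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def mean_sigmoid_slope(weight_std):
    """Average sigmoid derivative over one dense layer with random weights."""
    W = rng.normal(0.0, weight_std, size=(n_inputs, n_units))
    z = x @ W                              # pre-activations of the layer
    s = sigmoid(z)
    return float(np.mean(s * (1 - s)))

slope_large_init = mean_sigmoid_slope(1.0)                 # naive: std = 1
glorot_std = np.sqrt(2.0 / (n_inputs + n_units))           # Glorot-style scale
slope_glorot_init = mean_sigmoid_slope(glorot_std)

print(f"mean slope, std=1.0 init:      {slope_large_init:.3f}")
print(f"mean slope, Glorot-style init: {slope_glorot_init:.3f}")
```

With the large weights most units start in the flat regions of the sigmoid, so the average slope is much lower than with the scaled initialisation.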

### Install Python libraries

In [None]:
!pip install tensorflow torch numpy matplotlib

### Plotting Activation functions and their derivatives
Let's explore why certain activations (like sigmoid) are more prone to vanishing gradients, while others (like ReLU) avoid it.

Before we start, a *derivative* represents the slope or steepness of an activation function: it tells us how much the output of the function changes when the input changes slightly. During *backpropagation*, these derivatives are used to calculate how much each *weight* in the network should be adjusted.

If the derivative is very small (close to zero), changes in the input barely affect the output, so the weight update is tiny.

When this happens across many layers, the effect compounds, and the *gradient* becomes so small that earlier layers stop learning effectively. Different *activation functions* behave differently in this respect.
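
One way to see the "small change in input, small change in output" idea is to estimate a slope numerically. The sketch below uses the sigmoid function (covered in the next section) and compares a finite-difference estimate with the known formula `s * (1 - s)`; the helper names and the step size `h` are assumptions chosen for illustration.

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def numerical_slope(f, x, h=1e-5):
    # Central difference: (f(x + h) - f(x - h)) / (2h)
    return (f(x + h) - f(x - h)) / (2 * h)

def sigmoid_slope(x):
    # Analytic derivative of the sigmoid
    s = sigmoid(x)
    return s * (1 - s)

for x in (0.0, 2.0, 6.0):
    print(f"x={x}: numerical={numerical_slope(sigmoid, x):.6f}, "
          f"analytic={sigmoid_slope(x):.6f}")
```

The two estimates agree closely, and both shrink as the input moves away from zero.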



### Sigmoid
Sigmoid outputs values between 0 and 1, but it flattens out quickly for large positive or negative inputs. In those flat regions, the derivative becomes very small, which is why sigmoid is prone to vanishing gradients:



In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Define the sigmoid activation function
def sigmoid(x):
    return 1 / (1 + np.exp(-x))

# Define the derivative of the sigmoid function
def sigmoid_derivative(x):
    return sigmoid(x) * (1 - sigmoid(x))

# Generate 200 evenly spaced input values between -5 and 5
x = np.linspace(-5, 5, 200)

# Compute the sigmoid values and their derivatives
y_sigmoid = sigmoid(x)
y_derivative = sigmoid_derivative(x)

# Create a figure with two side-by-side subplots
plt.figure(figsize=(8, 4))

# Plot the sigmoid function
plt.subplot(1, 2, 1)
plt.plot(x, y_sigmoid, label='Sigmoid', color='blue', linewidth=2)
plt.title('Sigmoid function')                  
plt.xlabel('Input value (x)')                  
plt.ylabel('Output')                           
plt.legend()                                   

# Plot the derivative of the sigmoid function
plt.subplot(1, 2, 2)
plt.plot(x, y_derivative, label='Sigmoid Derivative', color='red', linestyle='--', linewidth=2)
plt.title('Sigmoid Derivative')                
plt.xlabel('Input value (x)')                  
plt.ylabel('Gradient')                         
plt.legend()                                   

# Adjust layout to prevent overlap
plt.tight_layout()

# Show the figure
plt.show()


The *sigmoid* function squashes any input into a value between 0 and 1. It has an *S-shape* — very flat at the far left and right, and steepest in the middle around 0. The *derivative* of the sigmoid tells us how much the output changes for a small change in input. This is used during learning to update weights. In the plot, you can see that the derivative is:
  - Highest at the centre (around input = 0)
  - Very small (close to zero) at the edges.

This matters because during training, if most of the inputs to sigmoid fall into those flat edge areas, the derivative is tiny and that leads to *vanishing gradients*. This means the network struggles to update weights in those regions, especially in early layers of deep networks.

### Tanh
*Tanh* is similar in shape to *sigmoid*, but instead of squashing values into a range between 0 and 1, it outputs values between -1 and 1. This makes it *zero-centred*, which can help with optimisation, since it balances positive and negative values more naturally during training. 

However, like sigmoid, *tanh* still becomes flat at its extremes: when the input is very positive or very negative, the output levels off and the *derivative* becomes very small. As a result, the gradients passed back through these regions can still vanish.

That said, because *tanh* has a steeper slope around zero and covers a wider range, it tends to suffer *slightly less* from the vanishing gradient problem compared to *sigmoid*. Nonetheless, it's still not ideal for very deep networks:

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Define the tanh activation function
def tanh(x):
    return np.tanh(x)

# Define the derivative of the tanh function
def tanh_derivative(x):
    return 1 - np.tanh(x)**2

# Generate 200 evenly spaced values from -5 to 5
x = np.linspace(-5, 5, 200)

# Calculate the tanh function and its derivative for each x
y_tanh = tanh(x)
y_deriv = tanh_derivative(x)

# Create a figure with two side-by-side plots
plt.figure(figsize=(8, 4))

# Plot the tanh function
plt.subplot(1, 2, 1)
plt.plot(x, y_tanh, label='Tanh', color='blue', linewidth=2)

plt.title('Tanh function')                  

plt.xlabel('Input value (x)')               
plt.ylabel('Output')                        

plt.legend()                                

# Plot the derivative of tanh
plt.subplot(1, 2, 2)
plt.plot(x, y_deriv, label='Tanh Derivative', color='red', linestyle='--', linewidth=2)

plt.title('Tanh Derivative')                
plt.xlabel('Input value (x)')               
plt.ylabel('Gradient')                      

plt.legend()                                

# Adjust layout to prevent overlap
plt.tight_layout()

# Show the figure
plt.show()


In the left plot, the solid curve shows the *tanh* activation function. It has an *S-shape*, similar to *sigmoid*, but instead of going from 0 to 1, it goes from -1 to 1. That means its output is *zero-centred*, which helps the network handle both positive and negative values more symmetrically.

The dashed curve (right plot) shows the *derivative* of *tanh*. This tells us how much the function output changes when the input changes slightly, and it's this gradient that’s used during training to adjust the weights. From the plots, you can see:

- The *tanh* function is steepest in the middle (around input = 0). This is where the gradient is largest, which means learning happens more effectively here.
- As you move away from zero in either direction (large positive or negative inputs), the *tanh* curve flattens out, approaching -1 or 1.
- In these flat outer regions, the *derivative* drops sharply toward zero: the dashed line gets very close to the horizontal axis.

This means if your network’s activations fall into these extreme zones, the gradients become tiny. *Tanh* has this issue, but it tends to be slightly better than *sigmoid* because:
- It’s centred around zero (which helps with weight updates)
- It has a steeper slope near the origin (so the central gradient is stronger)

But like *sigmoid*, it's not ideal for very deep networks, which is why *ReLU* and its variants are often preferred.
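
To put numbers on the comparison, this short sketch evaluates both derivatives at a few hand-picked inputs (the input values are arbitrary choices for illustration).

```python
import numpy as np

def sigmoid_derivative(x):
    s = 1 / (1 + np.exp(-x))
    return s * (1 - s)

def tanh_derivative(x):
    return 1 - np.tanh(x) ** 2

inputs = np.array([0.0, 1.0, 2.0, 5.0])
sig_slopes = sigmoid_derivative(inputs)
tanh_slopes = tanh_derivative(inputs)

for x, s, t in zip(inputs, sig_slopes, tanh_slopes):
    print(f"x={x:3.1f}  sigmoid'={s:.4f}  tanh'={t:.4f}")
```

At zero the *tanh* slope is 1.0 against 0.25 for *sigmoid*, which is the "stronger central gradient" described above. Further from zero, however, the *tanh* slope falls away faster, so the advantage only holds while activations stay near the centre.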

### ReLU
*ReLU* (Rectified Linear Unit) is one of the most widely used activation functions in deep learning. It’s very simple: for any input less than zero, the output is zero; for any input greater than zero, the output increases linearly. So the function looks like a flat line for negative values and a diagonal line for positives. This simplicity is actually its strength. The *derivative* of ReLU is:

- 0 when the input is negative (because the function is flat there),
- 1 when the input is positive (since the function just increases directly with input).

This has a big advantage over *sigmoid* and *tanh*: there is no *flattening* in the positive range, so the gradient doesn’t shrink as the input grows. Gradients therefore stay large enough during backpropagation for weights in early layers to keep updating effectively.

However, there's a small drawback too. For negative inputs, the derivative is zero, so if a neuron's input is always negative, it may "die" (i.e. stop learning). Despite this, *ReLU* is still highly effective and forms the default choice in most modern deep networks. If you have heard of variants of ReLU, like *Leaky ReLU*, these essentially help address the “dying neuron” issue:


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Define the ReLU activation function: outputs 0 for negatives, x for positives
def relu(x):
    return np.maximum(0, x)

# Define the derivative of ReLU: 0 for x ≤ 0, 1 for x > 0
def relu_derivative(x):
    return np.where(x > 0, 1, 0)

# Generate 200 evenly spaced input values from -5 to 5
x = np.linspace(-5, 5, 200)

# Calculate the ReLU values and their derivatives for these inputs
y_relu = relu(x)
y_deriv = relu_derivative(x)

# Create a figure with two subplots side by side
plt.figure(figsize=(8, 4))

# First subplot: plot the ReLU function
plt.subplot(1, 2, 1)
plt.plot(x, y_relu, label='ReLU', color='blue', linewidth=2)

plt.title('ReLU function')                   

plt.xlabel('Input value (x)')               
plt.ylabel('Output')                        

plt.legend()                                

# Second subplot: plot the derivative of the ReLU function
plt.subplot(1, 2, 2)

plt.plot(x, y_deriv, label='ReLU Derivative', color='red', linestyle='--', linewidth=2)

plt.title('ReLU Derivative')                
plt.xlabel('Input value (x)')               
plt.ylabel('Gradient')                      

plt.legend()                                

# Automatically adjust layout to prevent overlap
plt.tight_layout()

plt.show()


In the plot above, you’ll see:

The solid line shows the *ReLU* function: flat at 0 for negative inputs, and increasing linearly for positives. The dashed line shows the *derivative*:
  - It's 0 for all negative inputs: no change, hence no gradient.
  - It's 1 for all positive inputs: the gradient passes through at full strength.

This is a big contrast to *sigmoid* and *tanh*, where the derivative becomes very small at both ends. With *ReLU*, there's no risk of the gradient vanishing on the positive side so information flows better through the network, and training is usually faster and more stable.
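
The *Leaky ReLU* variant mentioned earlier can be sketched in the same style. The slope `alpha=0.01` for negative inputs is an assumed value for illustration; libraries use different defaults.

```python
import numpy as np

def leaky_relu(x, alpha=0.01):
    # Positive inputs pass through unchanged; negatives are scaled by alpha
    return np.where(x > 0, x, alpha * x)

def leaky_relu_derivative(x, alpha=0.01):
    # Slope is 1 for positive inputs and alpha (not 0) for the rest
    return np.where(x > 0, 1.0, alpha)

x = np.array([-3.0, -0.5, 0.5, 3.0])
print("outputs:    ", leaky_relu(x))
print("derivatives:", leaky_relu_derivative(x))
```

Because the derivative for negative inputs is small but not zero, a neuron whose input is always negative still receives some gradient and can recover.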

### MNIST dataset
We will demonstrate how a deep network using *sigmoid* can suffer from vanishing gradients, while a *ReLU-based* network trains more effectively. The *MNIST dataset* contains images of handwritten digits (`0` to `9`). Each image is 28×28 pixels, and the task is to correctly identify the digit in the image. We will construct and train each model and then perform a final comparison. First, let's load and preprocess the data:

### Load the data

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.datasets import mnist

# Load the MNIST dataset (handwritten digits 0–9)
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

#### Resampling
We draw a smaller random subset from the full training and test sets so that subsequent experiments run faster. We draw 5,000 unique indices from the rows of `X_train` and 1,000 from `X_test` (without replacement, so no sample is picked twice). We then use those index arrays to select the corresponding images and their matching labels, overwriting `X_train`, `Y_train`, `X_test`, and `Y_test` with the smaller versions. The result is a smaller, randomly selected train-test split that approximately mirrors the distribution of the full dataset but trains and evaluates more quickly:

In [None]:
# Take random subsets of the data for demonstration
np.random.seed(7)

train_sample_size = 5000
test_sample_size  = 1000

train_idxs = np.random.choice(X_train.shape[0], size=train_sample_size, replace=False)
test_idxs  = np.random.choice(X_test.shape[0],  size=test_sample_size,  replace=False)

X_train = X_train[train_idxs]
Y_train = Y_train[train_idxs]

X_test  = X_test[test_idxs]
Y_test  = Y_test[test_idxs]


### Preprocessing
We preprocess the data by normalising the pixel values from the range 0 to 255 down to the range 0 to 1:

In [None]:
# Normalise pixel values to [0,1]
X_train = X_train.astype('float32') / 255.0
X_test  = X_test.astype('float32')  / 255.0

We will set an equal number of epochs for each model:

In [None]:
# Define a common number of epochs for each model (you can increase or decrease this if you wish)
num_epochs = 10

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Flatten, Dense
from tensorflow.keras.optimizers import SGD

# Build the neural network 
model_sigmoid = Sequential([
    # Define the input layer to accept 28×28 grayscale images
    Input(shape=(28, 28)),          
    # Flatten the 2D image into a 1D vector of length 784
    Flatten(),                      
    # First hidden layer with 128 neurons and sigmoid activation
    Dense(128, activation='sigmoid'),
    # Second hidden layer, same size and activation to add depth
    Dense(128, activation='sigmoid'),
    # Third hidden layer to capture additional non-linear patterns
    Dense(128, activation='sigmoid'),
    # Fourth hidden layer for even richer feature extraction
    Dense(128, activation='sigmoid'),
    # Output layer with 10 neurons (one per digit class), using softmax
    # so the outputs sum to 1 and represent a probability distribution
    Dense(10, activation='softmax')
])

# Compile the model:
# SGD optimiser with learning rate 0.01 for weight updates
# sparse_categorical_crossentropy loss since labels are integer-encoded
# track accuracy to monitor how many digits are classified correctly
model_sigmoid.compile(
    optimizer=SGD(learning_rate=0.01),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model on the 5 000-sample subset:
# X_train, Y_train: training data and labels
# validation_data: uses X_test, Y_test to check generalisation after each epoch
# epochs=num_epochs: number of full passes through the training data
# verbose=1: show a progress bar with loss and accuracy
history_sigmoid = model_sigmoid.fit(
    X_train, Y_train,
    validation_data=(X_test, Y_test),
    epochs=num_epochs,
    verbose=1
)


### Evaluate
After ten epochs your model is stuck at around 10–11 % accuracy with a loss of roughly 2.3, which is exactly what you’d expect from random guessing on ten classes (–log(1/10) ≈ 2.3). In practice, stacking four sigmoid-activated layers causes the activations to saturate and gradients to vanish, so the network barely updates its weights from their random initial values.
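
The chance-level figures quoted above can be checked with a couple of lines. This sketch assumes a network that outputs the same probability for every class; the variable names are illustrative.

```python
import numpy as np

num_classes = 10
uniform_probs = np.full(num_classes, 1 / num_classes)   # a network that guesses uniformly
true_label = 3                                          # any digit gives the same result

# Sparse categorical cross-entropy for one sample: -log(p[true class])
chance_loss = -np.log(uniform_probs[true_label])
chance_accuracy = 1 / num_classes

print(f"chance-level loss:     {chance_loss:.4f}")
print(f"chance-level accuracy: {chance_accuracy:.0%}")
```

A training loss that stays near this value is a useful warning sign that the network has not moved away from guessing.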

### Training accuracy with Sigmoid

In [None]:
import matplotlib.pyplot as plt

# Plot the training accuracy over epochs for the sigmoid-based model
plt.plot(history_sigmoid.history['accuracy'], label='Train Acc (Sigmoid)')

# Plot the validation accuracy over epochs for the same model
plt.plot(history_sigmoid.history['val_accuracy'], label='Val Acc (Sigmoid)')

# Label the x-axis to show training epochs
plt.xlabel('Epochs')

# Label the y-axis to show accuracy
plt.ylabel('Accuracy')

# Set the plot title
plt.title('Network with Sigmoid')

# Display the legend to differentiate between training and validation curves
plt.legend()

# Show the plot
plt.show()


Notice that the accuracy barely moves away from chance level. This is a sign that gradients may be very small in the earlier layers, which stalls learning.

### Comparing with ReLU activation
Now, let’s build a similar network using *ReLU* instead of sigmoid for comparison:

In [None]:
# Define a neural network model using ReLU activations
model_relu = Sequential([
    Input(shape=(28, 28)),                         # Input layer for 28x28 grayscale images (matches the sigmoid model)
    Flatten(),                                     # Flatten 28x28 images into 784-dimensional vectors
    Dense(128, activation='relu'),                 # First hidden layer with ReLU activation
    Dense(128, activation='relu'),                 # Second hidden layer
    Dense(128, activation='relu'),                 # Third hidden layer
    Dense(128, activation='relu'),                 # Fourth hidden layer
    Dense(10, activation='softmax')                # Output layer (10 classes, softmax for classification)
])

# Compile the model using:
# Stochastic Gradient Descent (SGD) optimiser
# Sparse categorical crossentropy loss
# Accuracy as the evaluation metric
model_relu.compile(
    optimizer=SGD(learning_rate=0.01),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model on the training data and validate on the test set
history_relu = model_relu.fit(
    X_train, Y_train,
    validation_data=(X_test, Y_test),
    epochs=num_epochs,   # Number of epochs
    verbose=1             # Show training progress
)


### Evaluate
Switching to ReLU activations immediately lets the network learn meaningful features: training accuracy jumps from about 16 % in the first epoch to over 90 % by epoch 8, and training loss steadily falls from around 2.26 down to 0.29. 

On the validation side you see a rapid improvement too: val accuracy climbs from roughly 51 % at the start to a peak of about 89.3 % in epoch 9, with val loss falling to around 0.33. However, by epoch 10 the validation loss creeps back up (to 0.38) and val accuracy dips to 88.6 %, signalling the first signs of overfitting.

In practice you’d capture the best generalisation by stopping training around epoch 8 or 9 (where val loss is lowest).
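
The stopping rule described here can be sketched in plain Python. The validation losses below are hypothetical values made up for illustration (they are not the output of the run above), and `best_epoch_with_patience` is a name introduced only for this sketch. In Keras, the built-in `EarlyStopping` callback (with its `monitor`, `patience`, and `restore_best_weights` arguments) covers the same idea.

```python
def best_epoch_with_patience(val_losses, patience=2):
    """Return (best epoch, epoch at which training stops), both 1-based."""
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch       # stop early
    return best_epoch, len(val_losses)         # ran out of epochs

# Hypothetical validation losses, one per epoch (illustrative only)
val_losses = [1.90, 1.10, 0.70, 0.52, 0.44, 0.39, 0.36, 0.34, 0.33, 0.38, 0.41]
best, stopped = best_epoch_with_patience(val_losses, patience=2)
print(f"best epoch: {best}, training stopped after epoch: {stopped}")
```

The `patience` value controls how many non-improving epochs are tolerated before stopping, which avoids reacting to a single noisy epoch.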


### Training accuracy with ReLU

In [None]:
import matplotlib.pyplot as plt

# Plot the training accuracy over epochs for the ReLU-based model
plt.plot(history_relu.history['accuracy'], label='Train Acc (ReLU)')

# Plot the validation accuracy over epochs for the same model
plt.plot(history_relu.history['val_accuracy'], label='Val Acc (ReLU)')

plt.title('Network with ReLU')

plt.xlabel('Epochs')
plt.ylabel('Accuracy')

plt.legend()

# Render the plot
plt.show()


As we can see, the ReLU network does much better: training accuracy rises quickly from epoch to epoch.

### Comparing Sigmoid vs ReLU
We produce a validation accuracy plot to compare the two trained models. You will see that as training progresses over multiple epochs, the plot tracks how well each model performs on unseen validation data.

The ReLU-based model generally shows a clear advantage. It tends to improve more quickly in the early stages of training and reaches a higher overall accuracy. This is because ReLU avoids the issue of *vanishing gradients* we mentioned, which often affects sigmoid functions in deeper networks. With ReLU, gradients remain strong for positive inputs, allowing the network to learn more effectively across all layers.

In contrast, the sigmoid-based model usually improves more slowly and may plateau at a lower accuracy. This is due to the saturation of the sigmoid function at high or low input values, which causes its gradients to shrink. As a result, learning becomes inefficient in deeper parts of the network, and performance is limited.

Overall, the plot highlights a key reason why ReLU has become the standard activation function in deep learning. It enables faster, more stable training and often leads to better generalisation on tasks like image classification:

In [None]:
# Plot validation accuracy for both models across training epochs
plt.plot(history_sigmoid.history['val_accuracy'], label='Val Acc (Sigmoid)')  # Sigmoid-based model
plt.plot(history_relu.history['val_accuracy'], label='Val Acc (ReLU)')        # ReLU-based model

# Label the x-axis (training epochs)
plt.xlabel('Epochs')

# Label the y-axis (validation accuracy)
plt.ylabel('Validation Accuracy')

# Add a title to describe the comparison
plt.title('Sigmoid vs ReLU on MNIST')

# Add a legend to identify each curve
plt.legend()

# Display the plot
plt.show()


The key takeaway from the plot is that the ReLU model typically improves faster, indicating it suffers *less* from the vanishing gradient problem. ReLU can therefore be your go-to option for many tasks, but always experiment first!
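
One way to detect the problem directly is to look at how large the gradient is at each layer. The sketch below is a simplified NumPy stand-in for the Keras models above: it builds a random, untrained dense network, runs a manual backward pass for a toy loss (the sum of the final activations), and records the gradient norm at every layer. The depth, width, initialisation scales, and the helper name `layer_gradient_norms` are all assumptions made for illustration.

```python
import numpy as np

def layer_gradient_norms(activation, depth=8, width=64, seed=0):
    """Norm of dLoss/d(pre-activation) at each layer of a random dense network."""
    rng = np.random.default_rng(seed)
    if activation == "sigmoid":
        f = lambda z: 1 / (1 + np.exp(-z))
        f_prime = lambda z: f(z) * (1 - f(z))
        weight_std = np.sqrt(1.0 / width)      # Glorot-style scale (fan_in == fan_out)
    else:  # "relu"
        f = lambda z: np.maximum(0, z)
        f_prime = lambda z: (z > 0).astype(float)
        weight_std = np.sqrt(2.0 / width)      # He-style scale

    weights = [rng.normal(0, weight_std, (width, width)) for _ in range(depth)]
    a = rng.normal(size=(32, width))           # random input batch

    # Forward pass, remembering each layer's pre-activation z
    pre_activations = []
    for W in weights:
        z = a @ W
        pre_activations.append(z)
        a = f(z)

    # Backward pass for the toy loss L = sum of final activations
    grad_a = np.ones_like(a)
    norms = []
    for W, z in zip(reversed(weights), reversed(pre_activations)):
        grad_z = grad_a * f_prime(z)           # back through the activation
        norms.append(np.linalg.norm(grad_z))
        grad_a = grad_z @ W.T                  # back through the dense layer
    return norms[::-1]                         # index 0 = first (earliest) layer

sigmoid_norms = layer_gradient_norms("sigmoid")
relu_norms = layer_gradient_norms("relu")
sigmoid_ratio = sigmoid_norms[0] / sigmoid_norms[-1]
relu_ratio = relu_norms[0] / relu_norms[-1]

print(f"sigmoid: first-layer / last-layer gradient norm = {sigmoid_ratio:.2e}")
print(f"relu:    first-layer / last-layer gradient norm = {relu_ratio:.2e}")
```

In this toy setting the sigmoid network's first-layer gradient is far smaller than its last-layer gradient, while the ReLU network keeps the two much closer. The same check can be applied to a real model by inspecting per-layer gradients during training.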

### Batch normalisation

*Batch Normalisation* (often shortened to *BatchNorm*) is a widely used technique in deep learning that helps speed up and stabilise the training of neural networks. It works by normalising the activations (i.e. the outputs) of each layer so that they have a more consistent distribution, typically with a mean close to 0 and a standard deviation close to 1 within each mini-batch during training. It helps in several ways:

- *Stable distributions*:  
   Without batch normalisation, the output of one layer might vary significantly during training, especially as the layers before it update their weights. This forces the next layer to constantly readjust to changing inputs. BatchNorm addresses this by ensuring that each layer sees data with a more stable mean and variance, which makes learning smoother.

- *Reduces internal covariate shift*:  
   As training progresses, the distribution of activations within the network can shift, causing instability and slower learning. This is known as internal *covariate shift*. BatchNorm helps to reduce this problem by keeping the activation distributions more consistent across training iterations.
   >
   > *Covariate shift* happens when the type of data a model sees *changes*, either during training or between training and testing. For example, suppose you train a model to recognise animals in colour photos and later give it black-and-white photos. The inputs have changed, even though the task (recognising animals) is the same. That change is a kind of *covariate shift*. In deep learning, internal covariate shift means this kind of shift happens *inside* the network, between layers. As the model trains, the outputs from one layer can change a lot, which makes it harder for the next layer to learn properly.
   >
- *Maintains larger gradients*:  
   When activations are kept within a moderate, non-saturated range, the *gradients* used in backpropagation are less likely to vanish. This is especially useful for deep networks where earlier layers often struggle to learn. BatchNorm helps maintain useful gradient sizes, allowing the network to train more effectively from end to end.

There are other benefits too:

- *Faster convergence*: networks often train much faster with batch normalisation, sometimes allowing for higher learning rates.
- *Regularisation effect*: BatchNorm can reduce the need for techniques like dropout, as it introduces some noise during training (due to per-batch statistics), which has a slight regularising effect.
- *Improved performance*: in many cases, models with batch normalisation not only train faster but also achieve better final accuracy.
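
The normalisation step itself is short. Below is a minimal NumPy sketch of the training-time forward pass, assuming a 2D input of shape `(batch, features)`; the function name `batch_norm_forward` and the example data are illustrative, and the running averages that real implementations keep for use at inference time are left out.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Normalise each feature (column) over the mini-batch, then scale and shift."""
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    x_hat = (x - batch_mean) / np.sqrt(batch_var + eps)   # mean 0, std close to 1
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(128, 4))   # mini-batch: 128 samples, 4 features
gamma = np.ones(4)     # learnable scale, initialised to 1
beta = np.zeros(4)     # learnable shift, initialised to 0

out = batch_norm_forward(x, gamma, beta)
print("input  mean/std per feature:", x.mean(axis=0).round(2), x.std(axis=0).round(2))
print("output mean/std per feature:", out.mean(axis=0).round(2), out.std(axis=0).round(2))
```

The learnable `gamma` and `beta` let the network undo or adjust the normalisation if a different scale or offset turns out to be more useful.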

### Model
This next code sets up a controlled experiment to compare how inserting Batch Normalisation affects a simple feed-forward network on MNIST. First, it one-hot encodes the integer labels into 10-dimensional vectors so that both models use categorical cross-entropy.

The `no BN` model simply flattens each 28×28 image, applies two sigmoid-activated dense layers (128 then 64 units), and ends with a softmax output for the ten digit classes. The `with BN` model has the same overall shape, but each dense layer is split into a linear transform followed immediately by Batch Normalisation, then the sigmoid activation. Both models use vanilla SGD at a learning rate of 0.1 and are trained for three epochs on the 5,000-image training subset created earlier.

Finally, the code builds two smaller sub-models that take the same input images and tap the activations of the first hidden layer (post-sigmoid for the no-BN model, post-BatchNorm+sigmoid for the BN model). 

When we run a batch of 256 test images through these sub-models, we end up with two activation matrices (`act_no_bn` and `act_with_bn`) that we can inspect to see how Batch Normalisation changes the distribution of neuron activations before the next layer:


In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization, Activation, Input

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD

# Convert labels to one-hot encoded vectors
Y_train_cat = to_categorical(Y_train, 10)

# Build a model WITHOUT Batch Normalisation
input_1 = Input(shape=(28, 28))                        # Input layer for 28x28 images
x1 = Flatten()(input_1)                                # Flatten image to 784-d vector
x1 = Dense(128, activation='sigmoid')(x1)              # First hidden layer with sigmoid
x1 = Dense(64, activation='sigmoid')(x1)               # Second hidden layer
output_1 = Dense(10, activation='softmax')(x1)         # Output layer with softmax for 10 classes
model_no_bn = Model(input_1, output_1)                 # Create the model

# Build a model WITH Batch Normalisation
input_2 = Input(shape=(28, 28))                        # Input layer
x2 = Flatten()(input_2)                                # Flatten image
x2 = Dense(128)(x2)                                    # First hidden layer (no activation yet)
x2 = BatchNormalization()(x2)                          # BatchNorm to stabilise activations
x2 = Activation('sigmoid')(x2)                         # Apply sigmoid after normalisation
x2 = Dense(64)(x2)                                     # Second hidden layer
x2 = BatchNormalization()(x2)                          # BatchNorm again
x2 = Activation('sigmoid')(x2)                         # Sigmoid activation
output_2 = Dense(10, activation='softmax')(x2)         # Output layer
model_with_bn = Model(input_2, output_2)               # Create the model

# Create separate optimisers for both models
opt_no_bn = SGD(learning_rate=0.1)
opt_with_bn = SGD(learning_rate=0.1)

# Compile both models using categorical cross-entropy loss and accuracy as a metric
model_no_bn.compile(optimizer=opt_no_bn, loss='categorical_crossentropy', metrics=['accuracy'])
model_with_bn.compile(optimizer=opt_with_bn, loss='categorical_crossentropy', metrics=['accuracy'])

# Train each model briefly for 3 epochs on the training set
model_no_bn.fit(X_train, Y_train_cat, epochs=3, batch_size=128, verbose=0)
model_with_bn.fit(X_train, Y_train_cat, epochs=3, batch_size=128, verbose=0)

# Select a small batch of test images for visualising activations
sample = X_test[:256]

# Create models that output the first hidden layer's activations
layer_no_bn = Model(inputs=model_no_bn.input, outputs=model_no_bn.layers[2].output)
layer_with_bn = Model(inputs=model_with_bn.input, outputs=model_with_bn.layers[4].output)

# Get activation values from the selected sample
act_no_bn = layer_no_bn(sample).numpy()
act_with_bn = layer_with_bn(sample).numpy()


### Visualise the effects of BatchNormalisation

In [None]:
import matplotlib.pyplot as plt

# Create a figure with two side-by-side plots
plt.figure(figsize=(10, 4))

# Plot for activations WITHOUT BatchNorm
plt.subplot(1, 2, 1)  # First subplot (left side)
plt.hist(act_no_bn.flatten(), bins=50, color='skyblue', edgecolor='black')  # Plot histogram of raw activations

plt.title('Activations WITHOUT BatchNorm')  

plt.xlabel('Activation Value')               
plt.ylabel('Frequency')                      

# Plot for activations WITH BatchNorm
plt.subplot(1, 2, 2)  # Second subplot (right side)
plt.hist(act_with_bn.flatten(), bins=50, color='lightgreen', edgecolor='black')  # Histogram after BatchNorm

plt.title('Activations WITH BatchNorm')     

plt.xlabel('Activation Value')              
plt.ylabel('Frequency')                     

# Adjust layout to avoid overlapping elements
plt.tight_layout()

plt.show()


Without BatchNorm, the activation values, especially when using functions like *sigmoid*, often get pushed towards the extremes: very close to 0 or very close to 1. This is because as signals pass through multiple layers, the outputs can grow or shrink unpredictably. When the activations end up in these extreme zones, the *sigmoid* function flattens out, and its gradient becomes very small. This is called *saturation*, and it slows down learning, particularly in the early layers of the network.

With BatchNorm, the values entering the activation function are automatically scaled and shifted so they stay in a more moderate, central range, usually centred around 0 with a consistent spread. The resulting sigmoid outputs therefore tend to sit nearer 0.5 than the extremes of 0 and 1. This keeps the values away from the saturated parts of the activation function, allowing the gradients to remain large enough for effective learning. The result is smoother gradient flow, better weight updates, and faster, more stable training.
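
The link between normalisation and gradient size can be checked numerically. The sketch below uses hypothetical pre-activation values with a large mean and spread (made up for illustration), standardises them in the way BatchNorm would with `gamma=1` and `beta=0`, and compares the average sigmoid derivative before and after.

```python
import numpy as np

def sigmoid_derivative(z):
    s = 1 / (1 + np.exp(-z))
    return s * (1 - s)

rng = np.random.default_rng(1)

# Hypothetical pre-activations that have drifted: large mean and spread
z_raw = rng.normal(loc=6.0, scale=4.0, size=10_000)

# The same values after standardisation (BatchNorm with gamma=1, beta=0)
z_norm = (z_raw - z_raw.mean()) / z_raw.std()

mean_slope_raw = sigmoid_derivative(z_raw).mean()
mean_slope_norm = sigmoid_derivative(z_norm).mean()

print(f"mean sigmoid slope without normalisation: {mean_slope_raw:.3f}")
print(f"mean sigmoid slope with normalisation:    {mean_slope_norm:.3f}")
```

The normalised values sit in the steep central part of the sigmoid, so the average slope is several times larger.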

### Residual connections (Skip Connections)

*Residual connections*, introduced in *ResNet* architectures, allow the network to "skip over" layers by adding the input of a block directly to its output. Rather than learning the full transformation, the block learns the *residual*: the difference between the desired output and its input.
>
> In *ResNet* (Residual Network) architectures, the key innovation is the introduction of *residual connections*, or *skip connections* that we are discussing now. As we said, these allow the network to bypass one or more layers by adding the input directly to the output of a block, helping to solve the degradation problem where adding more layers leads to worse performance. This design enables the successful training of very deep networks (e.g. 50, 101, or even 152 layers), overcoming issues like vanishing gradients that previously limited most convolutional networks to 20–30 layers.  
>
> ResNet essentially made training deep neural networks possible  and practical, laying the foundation for many modern architectures in computer vision and beyond. Its residual connections have become a core building block in deep learning, enabling both depth and stability in complex models.
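
The reason a skip connection helps gradients can be written in one line. The notation is introduced here for illustration only: $x$ is the block input, $F$ stands for the layers inside the block, $y$ is the block output and $\mathcal{L}$ is the loss. Any activation applied after the addition is ignored for simplicity.

$$
y = x + F(x)
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial x}
= \frac{\partial \mathcal{L}}{\partial y}\left(I + \frac{\partial F}{\partial x}\right)
$$

Because of the identity term $I$, the gradient arriving at $x$ still contains an unscaled copy of the gradient at $y$, even when $\partial F / \partial x$ is very small.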

Let's visualise the activations to see how skip connections preserve information across layers:

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten, Add, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD


# Build a plain deep model (no skip connections)
inp_plain = Input(shape=(28, 28))                      # Input for 28x28 images
x = Flatten()(inp_plain)                               # Flatten to 784
x = Dense(128, activation='relu')(x)                   # First dense layer
x = Dense(128, activation='relu')(x)                   # Second dense layer
x = Dense(128, activation='relu')(x)                   # Third dense layer
out_plain = Dense(10, activation='softmax')(x)         # Output layer (softmax for 10 classes)
model_plain = Model(inputs=inp_plain, outputs=out_plain)

# Build a residual model (with skip connection)
inp_res = Input(shape=(28, 28))                        # Input layer
x = Flatten()(inp_res)                                 # Flatten image

h1 = Dense(128, activation='relu')(x)                  # First hidden layer
h2 = Dense(128, activation='relu')(h1)                 # Second hidden layer
h3 = Dense(128)(h2)                                    # Third hidden layer (no activation yet)

res = Add()([h1, h3])                                  # Add skip connection: h1 + h3
res = Activation('relu')(res)                          # Apply activation after addition

out_res = Dense(10, activation='softmax')(res)         # Output layer
model_res = Model(inputs=inp_res, outputs=out_res)

# Compile both models with separate optimisers
model_plain.compile(optimizer=SGD(learning_rate=0.1),
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])

model_res.compile(optimizer=SGD(learning_rate=0.1),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

# Train both models briefly for 3 epochs
# (X_train, Y_train_cat and X_test are assumed to have been prepared in an earlier cell)
model_plain.fit(X_train, Y_train_cat, epochs=3, batch_size=128, verbose=0)
model_res.fit(X_train, Y_train_cat, epochs=3, batch_size=128, verbose=0)

# Take a small sample from the test set for activation visualisation
sample = X_test[:256]

# Create models to extract activations from the final hidden layer
intermediate_plain = Model(inputs=model_plain.input, outputs=model_plain.layers[-2].output)
intermediate_res = Model(inputs=model_res.input, outputs=model_res.layers[-2].output)

# Get the activations for the sample
act_plain = intermediate_plain(sample).numpy()
act_res = intermediate_res(sample).numpy()


Plotting the results gives:

In [None]:
import matplotlib.pyplot as plt

# Create a figure with two subplots side by side
plt.figure(figsize=(12, 6))

# Plot histogram of activations from the plain (non-residual) model
plt.subplot(1, 2, 1)
plt.hist(act_plain.flatten(), bins=50, color='orange', edgecolor='black')
plt.title('Plain model Activations')
plt.xlabel('Activation Value')
plt.ylabel('Frequency')

# Plot histogram of activations from the residual model
plt.subplot(1, 2, 2)
plt.hist(act_res.flatten(), bins=50, color='green', edgecolor='black')
plt.title('Residual model Activations')
plt.xlabel('Activation Value')
plt.ylabel('Frequency')

# Automatically adjust spacing between plots
plt.tight_layout()

plt.show()


Looking at the plot, the activations of the *plain model* may appear more compressed or irregular. This effect tends to grow as a network deepens, especially when it has no normalisation such as BatchNorm or when it uses saturating activations instead of ReLU.

In the *residual model*, the skip connection adds the earlier representation `h1` back in before the final ReLU. This helps preserve variation in the activations, keeps gradients flowing, and lets the model learn deeper representations without degradation.
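
The claim about gradients can be checked directly on a toy problem. The sketch below is a hand-written NumPy forward and backward pass, not the Keras models above. The width, depth, sigmoid activation, random weights and the toy loss (the sum of the final outputs) are all illustrative assumptions chosen to make the effect visible.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def input_gradient_norm(x, weights, residual):
    """Forward then backward pass through a stack of sigmoid layers.

    Returns the norm of d(loss)/d(x) for the toy loss = sum(final output).
    """
    h = x
    local_grads = []                          # sigmoid'(z) at every layer
    for W in weights:
        a = sigmoid(W @ h)
        local_grads.append(a * (1.0 - a))
        h = (h + a) if residual else a        # skip connection adds the input back

    g = np.ones_like(h)                       # d(loss)/d(final output)
    for W, d in zip(reversed(weights), reversed(local_grads)):
        through_layer = W.T @ (g * d)         # gradient that passes through the layer
        g = (g + through_layer) if residual else through_layer
    return np.linalg.norm(g)

rng = np.random.default_rng(0)
width, depth = 64, 20                         # illustrative sizes
weights = [rng.normal(0.0, 1.0 / np.sqrt(width), size=(width, width))
           for _ in range(depth)]
x = rng.normal(size=width)

norm_plain = input_gradient_norm(x, weights, residual=False)
norm_res = input_gradient_norm(x, weights, residual=True)
print(f"gradient norm at the input, plain stack:    {norm_plain:.3e}")
print(f"gradient norm at the input, residual stack: {norm_res:.3e}")
```

In the plain stack, every layer multiplies the gradient by a sigmoid derivative of at most 0.25, so it shrinks rapidly with depth. In the residual stack, the identity path carries the gradient back without that scaling.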

### What have we learnt?
We explored some of the core challenges and solutions in training deep neural networks, with a particular focus on activation functions, normalisation techniques, and architectural design strategies. We began by examining the *vanishing gradient* problem, a key issue that arises when gradients shrink as they are propagated backward through multiple layers. This is especially common when using activation functions like *sigmoid* or *tanh*, which saturate at their extremes and produce very small derivatives. As a result, early layers in a network may stop learning entirely, slowing or even stalling the training process.

To address this, we looked at the *ReLU* activation function, which has become the default in many modern architectures. Unlike sigmoid and tanh, ReLU does not saturate in the positive range and has a constant derivative of 1 for positive inputs. This helps preserve gradient strength, making learning more efficient and allowing deeper networks to be trained effectively. Through visualisations, we saw how ReLU maintains larger gradients, while sigmoid and tanh activations quickly plateau.

We then turned to *Batch Normalisation*, a technique that standardises activations within each mini-batch during training. This stabilises the distribution of inputs seen by each layer, reducing *internal covariate shift* and allowing for faster and more reliable convergence. BatchNorm also helps keep gradients in a healthy range, mitigating vanishing gradients and making training less sensitive to weight initialisation. When comparing activation distributions with and without BatchNorm, we observed that the technique produces more centred and balanced outputs, an important factor in stable training.

Finally, we explored *residual connections* (or *skip connections*), a structural innovation that allows inputs to bypass one or more layers and be added directly to a deeper layer’s output. Residual connections make it easier for the network to learn incremental improvements over previous representations, rather than learning everything from scratch. More importantly, they create direct pathways for gradients to flow backward, greatly improving the trainability of very deep networks.

Through our code and visual comparison, we saw how residual connections can help maintain activation diversity and guard against the degradation often seen in plain deep networks.