# Only Maths You Will Ever Need for Deep Learning

## Section 4.2: Derivatives of Vector Element-wise Binary Operators

Neural networks often use element-wise operations on vectors, such as addition, subtraction, multiplication, and division. In these operations, each output element depends only on the corresponding elements of the input vectors. This leads to a major simplification when computing derivatives (Jacobians): the Jacobian matrices are diagonal.

Let's see how this works for common element-wise operations between two vectors $w$ and $x$.

In [1]:
import numpy as np

# Example vectors
w = np.array([2.0, 3.0, 4.0])
x = np.array([5.0, 6.0, 7.0])

# Element-wise operations
y_add = w + x
print("Addition (w + x):", y_add)

y_sub = w - x
print("Subtraction (w - x):", y_sub)

y_mul = w * x
print("Element-wise Multiplication (w * x):", y_mul)

y_div = w / x
print("Element-wise Division (w / x):", y_div)

Addition (w + x): [ 7.  9. 11.]
Subtraction (w - x): [-3. -3. -3.]
Element-wise Multiplication (w * x): [10. 18. 28.]
Element-wise Division (w / x): [0.4        0.5        0.57142857]


### Jacobians for Element-wise Operations

For $y = w + x$:
- $\frac{\partial y}{\partial w} = I$ (identity matrix)
- $\frac{\partial y}{\partial x} = I$

For $y = w - x$:
- $\frac{\partial y}{\partial w} = I$
- $\frac{\partial y}{\partial x} = -I$

For $y = w \odot x$ (element-wise multiplication):
- $\frac{\partial y}{\partial w} = \mathrm{diag}(x)$
- $\frac{\partial y}{\partial x} = \mathrm{diag}(w)$

For $y = w / x$ (element-wise division):
- $\frac{\partial y}{\partial w} = \mathrm{diag}(1/x)$
- $\frac{\partial y}{\partial x} = \mathrm{diag}(-w/x^2)$

Let's compute these Jacobians in code for our example vectors.

In [2]:
# Jacobian for y = w + x (identity matrix)
J_add_w = np.eye(len(w))
J_add_x = np.eye(len(x))
print("Jacobian of y = w + x with respect to w:\n", J_add_w)
print("Jacobian of y = w + x with respect to x:\n", J_add_x)

# Jacobian for y = w - x
J_sub_w = np.eye(len(w))
J_sub_x = -np.eye(len(x))
print("Jacobian of y = w - x with respect to w:\n", J_sub_w)
print("Jacobian of y = w - x with respect to x:\n", J_sub_x)

# Jacobian for y = w * x (element-wise multiplication)
J_mul_w = np.diag(x)
J_mul_x = np.diag(w)
print("Jacobian of y = w * x with respect to w:\n", J_mul_w)
print("Jacobian of y = w * x with respect to x:\n", J_mul_x)

# Jacobian for y = w / x (element-wise division)
J_div_w = np.diag(1 / x)
J_div_x = np.diag(-w / x**2)
print("Jacobian of y = w / x with respect to w:\n", J_div_w)
print("Jacobian of y = w / x with respect to x:\n", J_div_x)

Jacobian of y = w + x with respect to w:
 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Jacobian of y = w + x with respect to x:
 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Jacobian of y = w - x with respect to w:
 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Jacobian of y = w - x with respect to x:
 [[-1. -0. -0.]
 [-0. -1. -0.]
 [-0. -0. -1.]]
Jacobian of y = w * x with respect to w:
 [[5. 0. 0.]
 [0. 6. 0.]
 [0. 0. 7.]]
Jacobian of y = w * x with respect to x:
 [[2. 0. 0.]
 [0. 3. 0.]
 [0. 0. 4.]]
Jacobian of y = w / x with respect to w:
 [[0.2        0.         0.        ]
 [0.         0.16666667 0.        ]
 [0.         0.         0.14285714]]
Jacobian of y = w / x with respect to x:
 [[-0.08        0.          0.        ]
 [ 0.         -0.08333333  0.        ]
 [ 0.          0.         -0.08163265]]


### Key Takeaway

For element-wise binary operations between vectors, the Jacobians with respect to each input are diagonal matrices. Each diagonal entry is simply the scalar derivative of the output element with respect to the corresponding input element. This makes backpropagation through such layers very efficient.
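
One way to sanity-check these formulas is to compare them against a finite-difference approximation. The sketch below is not part of the original cells: `numerical_jacobian` is a hypothetical helper introduced here, and the step size `eps` is an assumed value that happens to work well for these small examples.

```python
import numpy as np

def numerical_jacobian(f, v, eps=1e-6):
    """Approximate the Jacobian of f at v with central differences (hypothetical helper)."""
    v = np.asarray(v, dtype=float)
    n_out = f(v).shape[0]
    J = np.zeros((n_out, v.shape[0]))
    for j in range(v.shape[0]):
        step = np.zeros_like(v)
        step[j] = eps
        # Column j holds the change in every output when only v[j] moves
        J[:, j] = (f(v + step) - f(v - step)) / (2 * eps)
    return J

w = np.array([2.0, 3.0, 4.0])
x = np.array([5.0, 6.0, 7.0])

# Hold w fixed and differentiate y = w / x with respect to x
J_div_x_numeric = numerical_jacobian(lambda v: w / v, x)
J_div_x_formula = np.diag(-w / x**2)

# Hold x fixed and differentiate y = w * x with respect to w
J_mul_w_numeric = numerical_jacobian(lambda v: v * x, w)
J_mul_w_formula = np.diag(x)

print("Division matches:", np.allclose(J_div_x_numeric, J_div_x_formula, atol=1e-6))
print("Multiplication matches:", np.allclose(J_mul_w_numeric, J_mul_w_formula, atol=1e-6))
```

The off-diagonal entries of the numerical Jacobians come out as zero because perturbing one input element leaves every other output element unchanged.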

## Section 4.3: Derivatives Involving Scalar Expansion

In neural networks, we often perform operations between a vector and a scalar, such as adding a scalar bias to every element of a vector or multiplying a vector by a scalar. These operations can be viewed as element-wise operations in which the scalar is broadcast to match the vector's size.

Let's explore the derivatives for these cases:

### Example 1: Addition of a Scalar to a Vector

Let $y = x + z$, where $x$ is a vector and $z$ is a scalar. This is equivalent to $y_i = x_i + z$ for each $i$.

- **Derivative with respect to $x$:**
  - $\frac{\partial y}{\partial x} = I$ (identity matrix)
- **Derivative with respect to $z$:**
  - $\frac{\partial y}{\partial z} = \mathbf{1}$ (a column vector of ones)

In [3]:
import numpy as np

x = np.array([1.0, 2.0, 3.0])
z = 5.0

y_add = x + z
print("y = x + z:", y_add)

# Jacobian with respect to x (identity matrix)
J_add_x = np.eye(len(x))
print("Jacobian of y = x + z with respect to x:\n", J_add_x)

# Derivative with respect to scalar z (vector of ones)
J_add_z = np.ones((len(x), 1))
print("Jacobian of y = x + z with respect to z:\n", J_add_z)

y = x + z: [6. 7. 8.]
Jacobian of y = x + z with respect to x:
 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Jacobian of y = x + z with respect to z:
 [[1.]
 [1.]
 [1.]]


### Example 2: Multiplication of a Vector by a Scalar

Let $y = xz$, where $x$ is a vector and $z$ is a scalar. This is equivalent to $y_i = x_i \cdot z$ for each $i$.

- **Derivative with respect to $x$:**
  - $\frac{\partial y}{\partial x} = zI$ (scalar $z$ times the identity matrix)
- **Derivative with respect to $z$:**
  - $\frac{\partial y}{\partial z} = x$ (the vector $x$ as a column vector)

In [4]:
# y = x * z

y_mul = x * z
print("y = x * z:", y_mul)

# Jacobian with respect to x (z times identity matrix)
J_mul_x = z * np.eye(len(x))
print("Jacobian of y = x * z with respect to x:\n", J_mul_x)

# Derivative with respect to scalar z (vector x)
J_mul_z = x.reshape(-1, 1)
print("Jacobian of y = x * z with respect to z:\n", J_mul_z)

y = x * z: [ 5. 10. 15.]
Jacobian of y = x * z with respect to x:
 [[5. 0. 0.]
 [0. 5. 0.]
 [0. 0. 5.]]
Jacobian of y = x * z with respect to z:
 [[1.]
 [2.]
 [3.]]


### Key Takeaway

When a vector and a scalar are combined via addition or multiplication, the Jacobian with respect to the vector is diagonal, while the derivative with respect to the scalar is a column vector. This reflects how a change in the scalar affects every element of the output: equally for addition (each entry is 1), and in proportion to the corresponding $x_i$ for multiplication (each entry is $x_i$).
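
These shapes matter once an upstream gradient is involved. As a minimal sketch, assume a hypothetical upstream gradient `g` $= \partial L / \partial y$ for some downstream scalar loss $L$ (the values of `g` are made up for illustration and are not part of the original example). Multiplying `g` by the Jacobians above gives the gradient for each input; because the scalar was broadcast to every element, its gradient collects a contribution from every output element.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
z = 5.0

# Hypothetical upstream gradient dL/dy, treated as a row vector (assumed values)
g = np.array([0.1, -0.2, 0.3])

# y = x + z: dy/dx = I and dy/dz = column of ones
dL_dx_add = g @ np.eye(len(x))             # equals g
dL_dz_add = (g @ np.ones((len(x), 1)))[0]  # equals sum(g)

# y = x * z: dy/dx = z * I and dy/dz = x as a column
dL_dx_mul = g @ (z * np.eye(len(x)))       # equals z * g
dL_dz_mul = (g @ x.reshape(-1, 1))[0]      # equals dot(g, x)

print("Addition:       dL/dx =", dL_dx_add, " dL/dz =", dL_dz_add)
print("Multiplication: dL/dx =", dL_dx_mul, " dL/dz =", dL_dz_mul)
```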

## Section 4.4: Vector Sum Reduction

In machine learning, we often reduce a vector to a scalar by summing its elements, such as when computing the total loss. This section explains how to compute the derivative of such a sum with respect to the input vector.

### Example 1: Derivative of the Sum of a Vector

Let $y = \sum x_i$, where $x$ is a vector. The derivative of $y$ with respect to $x$ is a row vector of ones:

$$
\nabla y = \frac{\partial y}{\partial x} = [1, 1, \ldots, 1]
$$

Intuition: Changing any $x_j$ by $\Delta$ changes $y$ by $\Delta$.

In [5]:
import numpy as np

x = np.array([2.0, 3.0, 4.0])
y = np.sum(x)
print("y = sum(x):", y)

grad_y_x = np.ones_like(x)
print("Gradient of y = sum(x) with respect to x:", grad_y_x)

y = sum(x): 9.0
Gradient of y = sum(x) with respect to x: [1. 1. 1.]


### Example 2: Derivative of the Sum of a Vector-Scalar Product

Let $y = \sum (x_i z)$, where $x$ is a vector and $z$ is a scalar. This is equivalent to $y = z \cdot \sum x_i$.

- **Derivative with respect to $x$:**
  - $\nabla y = [z, z, \ldots, z]$
- **Derivative with respect to $z$:**
  - $\frac{\partial y}{\partial z} = \sum x_i$

In [6]:
z = 5.0
y2 = np.sum(x * z)
print("y = sum(x * z):", y2)

grad_y2_x = np.ones_like(x) * z
print("Gradient of y = sum(x * z) with respect to x:", grad_y2_x)

grad_y2_z = np.sum(x)
print("Derivative of y = sum(x * z) with respect to z:", grad_y2_z)

y = sum(x * z): 45.0
Gradient of y = sum(x * z) with respect to x: [5. 5. 5.]
Derivative of y = sum(x * z) with respect to z: 9.0


### Key Takeaway

When you sum the elements of a vector, the gradient with respect to the vector is a row vector of ones. For the sum of a vector scaled by a scalar $z$, the gradient with respect to the vector is a row vector whose entries all equal $z$, and the derivative with respect to $z$ is the sum of the vector's elements. Both results appear repeatedly in backpropagation.
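
The intuition that changing any $x_j$ by $\Delta$ changes the sum by $\Delta$ can be observed directly by perturbing one element at a time. This is an illustrative sketch only; the perturbation size `delta` is an arbitrary assumed value.

```python
import numpy as np

x = np.array([2.0, 3.0, 4.0])
z = 5.0
delta = 0.5  # assumed perturbation size

for j in range(len(x)):
    x_shifted = x.copy()
    x_shifted[j] += delta
    change_sum = np.sum(x_shifted) - np.sum(x)             # expect delta
    change_scaled = np.sum(x_shifted * z) - np.sum(x * z)  # expect z * delta
    print(f"x[{j}] += {delta}: sum changes by {change_sum}, "
          f"scaled sum changes by {change_scaled}")

# Perturbing z instead changes the scaled sum by delta * sum(x)
change_from_z = np.sum(x * (z + delta)) - np.sum(x * z)
print(f"z += {delta}: scaled sum changes by {change_from_z}")
```

Dividing each change by `delta` recovers the gradient entries: 1 for the plain sum, `z` for the scaled sum with respect to each $x_j$, and $\sum x_i$ for the scaled sum with respect to `z`.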

## Section 4.5: The Chain Rules

The chain rule is the backbone of how derivatives are computed for complex, nested functions—exactly what neural networks are! It allows us to efficiently compute how a final output (like a loss) changes with respect to every parameter in a deeply nested function by breaking the computation into simpler steps.

### 4.5.1 Single-variable Chain Rule

Suppose $y = f(g(x))$. The chain rule states:

$$
\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}
$$
where $u = g(x)$ and $y = f(u)$.

**Example:**
Let $y = \sin(x^2)$. Set $u = x^2$, so $y = \sin(u)$.
- $du/dx = 2x$
- $dy/du = \cos(u)$
- $dy/dx = \cos(x^2) \cdot 2x$

In [7]:
import numpy as np

def chain_rule_example(x):
    u = x**2
    y = np.sin(u)
    dy_du = np.cos(u)
    du_dx = 2 * x
    dy_dx = dy_du * du_dx
    return y, dy_dx

x = 1.5
result, grad = chain_rule_example(x)
print(f"y = sin(x^2) at x={x}: {result}")
print(f"dy/dx at x={x}: {grad}")

y = sin(x^2) at x=1.5: 0.7780731968879212
dy/dx at x=1.5: -1.8845208681682175


### 4.5.2 Single-variable Total-Derivative Chain Rule

If $y$ depends on $x$ both directly and through an intermediate variable, the total derivative sums the contributions of all paths:

Let $u_1(x) = x^2$, $y = u_2(x, u_1) = x + u_1$.

$$
\frac{dy}{dx} = \frac{\partial u_2}{\partial x} + \frac{\partial u_2}{\partial u_1} \cdot \frac{du_1}{dx}
$$

For $y = x + x^2$:
- $\partial u_2/\partial x = 1$
- $\partial u_2/\partial u_1 = 1$
- $du_1/dx = 2x$
- $dy/dx = 1 + 2x$

In [8]:
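def total_derivative_example(x):
    u1 = x**2
    y = x + u1
    # Direct path: partial of u2 with respect to x
    du2_dx = 1.0
    # Indirect path: partial of u2 with respect to u1, times du1/dx
    du2_du1 = 1.0
    du1_dx = 2 * x
    dy_dx = du2_dx + du2_du1 * du1_dx
    return y, dy_dx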

x = 2.0
result, grad = total_derivative_example(x)
print(f"y = x + x^2 at x={x}: {result}")
print(f"dy/dx at x={x}: {grad}")

y = x + x^2 at x=2.0: 6.0
dy/dx at x=2.0: 5.0


### 4.5.3 Vector Chain Rule

For vector functions, the chain rule generalizes to Jacobians:

If $u = g(x)$ and $y = f(u)$, then
$$
\frac{\partial y}{\partial x} = \frac{\partial y}{\partial u} \cdot \frac{\partial u}{\partial x}
$$
where $\frac{\partial y}{\partial u}$ and $\frac{\partial u}{\partial x}$ are Jacobian matrices, and the multiplication is matrix multiplication.

**Example:**
Let $g(x) = [x^2, \sin(x)]^T$, $f(u) = [u_1 + u_2, u_1 u_2]^T$.

- $u_1 = x^2$, $u_2 = \sin(x)$
- $y_1 = u_1 + u_2$, $y_2 = u_1 u_2$

Compute the Jacobians and apply the chain rule.

In [9]:
def vector_chain_rule_example(x):
    # Intermediate vector
    u1 = x**2
    u2 = np.sin(x)
    u = np.array([u1, u2])
    # Output vector
    y1 = u1 + u2
    y2 = u1 * u2
    y = np.array([y1, y2])
    # Jacobian of f with respect to u (2x2)
    df_du = np.array([
        [1, 1],        # dy1/du1, dy1/du2
        [u2, u1]       # dy2/du1, dy2/du2
    ])
    # Jacobian of u with respect to x (2x1)
    du_dx = np.array([[2*x], [np.cos(x)]])
    # Chain rule: Jacobian of y with respect to x (2x1)
    dy_dx = df_du @ du_dx
    return y, dy_dx

x = 1.0
result, grad = vector_chain_rule_example(x)
print(f"y = f(g(x)) at x={x}: {result}")
print(f"dy/dx at x={x}:\n{grad}")

y = f(g(x)) at x=1.0: [1.84147098 0.84147098]
dy/dx at x=1.0:
[[2.54030231]
 [2.22324428]]


### Key Takeaway

The chain rule, in its various forms, is the fundamental mechanism for calculating derivatives of complex, nested functions. In neural networks, the vector chain rule allows us to efficiently compute gradients for backpropagation by chaining together local derivatives (Jacobians) from each layer.

## Section 5: The Gradient of Neuron Activation

Let's apply the chain rule to a single neuron. We'll compute how the neuron's activation changes with respect to its weights and bias, which is fundamental for backpropagation.

A typical neuron computes:
1. **Affine function:** $z = w \cdot x + b$
2. **Activation function:** $a = A(z)$, e.g., $A(z) = \max(0, z)$ (ReLU)

We want to find $\frac{\partial a}{\partial w}$ and $\frac{\partial a}{\partial b}$.

### Step 1: Gradients of the Affine Part $z = w \cdot x + b$

- $\frac{\partial z}{\partial w} = x^T$
- $\frac{\partial z}{\partial b} = 1$

**Intuition:**
- Changing $w_j$ changes $z$ by $x_j$.
- Changing $b$ changes $z$ by $1$.
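
The result $\frac{\partial z}{\partial w} = x^T$ can also be read as an instance of the vector chain rule, combining the element-wise product from Section 4.2 with the sum reduction from Section 4.4. The sketch below spells that out; the intermediate names `du_dw` and `dz_du` are introduced here for illustration.

```python
import numpy as np

w = np.array([0.5, -1.0, 2.0])
x = np.array([1.0, 2.0, 3.0])

# Step 1: u = w * x (element-wise), so du/dw = diag(x)
du_dw = np.diag(x)

# Step 2: z = sum(u) + b, so dz/du is a row vector of ones
dz_du = np.ones((1, len(x)))

# Vector chain rule: dz/dw = dz/du @ du/dw, a 1 x n row vector equal to x^T
dz_dw = dz_du @ du_dw
print("dz/dw:", dz_dw)
```

The bias does not appear in either Jacobian because it is added after the sum and does not depend on $w$.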

In [10]:
import numpy as np

# Example input
x = np.array([1.0, 2.0, 3.0])
w = np.array([0.5, -1.0, 2.0])
b = 0.1

# Affine part
z = np.dot(w, x) + b
print(f"z = w·x + b: {z}")

# Gradients
dz_dw = x
print("∂z/∂w:", dz_dw)

dz_db = 1.0
print("∂z/∂b:", dz_db)

z = w·x + b: 4.6
∂z/∂w: [1. 2. 3.]
∂z/∂b: 1.0


### Step 2: Gradient of the Activation Function (ReLU)

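Let $a = \max(0, z)$. The derivative of ReLU is:
- $\frac{da}{dz} = 0$ if $z \leq 0$
- $\frac{da}{dz} = 1$ if $z > 0$

Strictly speaking, ReLU is not differentiable at $z = 0$; assigning the derivative 0 there is the convention used throughout this notebook.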

In [11]:
def relu(z):
    return np.maximum(0, z)

def relu_grad(z):
    return 1.0 if z > 0 else 0.0

# Activation
a = relu(z)
print(f"a = ReLU(z): {a}")

# Derivative of activation with respect to z
da_dz = relu_grad(z)
print(f"∂a/∂z at z={z}: {da_dz}")

a = ReLU(z): 4.6
∂a/∂z at z=4.6: 1.0


### Step 3: Combine Using the Chain Rule

By the chain rule:
- $\frac{\partial a}{\partial w} = \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial w}$
- $\frac{\partial a}{\partial b} = \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial b}$

So:
- $\frac{\partial a}{\partial w} = 0^T$ if $z \leq 0$, $x^T$ if $z > 0$
- $\frac{\partial a}{\partial b} = 0$ if $z \leq 0$, $1$ if $z > 0$

In [12]:
# Gradients of activation with respect to w and b
# Chain rule: multiply the ReLU derivative by the affine gradients from Step 1
da_dw = da_dz * dz_dw
da_db = da_dz * dz_db

print("∂a/∂w:", da_dw)
print("∂a/∂b:", da_db)

∂a/∂w: [1. 2. 3.]
∂a/∂b: 1.0


### Key Takeaway

For a neuron with $a = \max(0, w \cdot x + b)$:
- $\frac{\partial a}{\partial w} = x^T$ if $w \cdot x + b > 0$, $0^T$ otherwise
- $\frac{\partial a}{\partial b} = 1$ if $w \cdot x + b > 0$, $0$ otherwise

These local gradients are essential for backpropagation, telling us how to update the neuron's weights and bias.
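
The example above only exercises the active case. As a minimal sketch, the function below (a hypothetical helper named `neuron_gradients`, not part of the original cells) packages the three steps and is evaluated once with the original bias and once with an assumed bias of `-10.0`, chosen only so that $z < 0$.

```python
import numpy as np

def neuron_gradients(w, x, b):
    """Return (a, da/dw, da/db) for a = max(0, w.x + b)."""
    z = float(np.dot(w, x) + b)
    a = max(0.0, z)
    da_dz = 1.0 if z > 0 else 0.0  # convention: derivative 0 at z == 0
    # Chain rule: da/dw = da/dz * dz/dw, with dz/dw = x and dz/db = 1
    return a, da_dz * x, da_dz * 1.0

x = np.array([1.0, 2.0, 3.0])
w = np.array([0.5, -1.0, 2.0])

# Active case: same values as the cells above (z > 0)
a_on, da_dw_on, da_db_on = neuron_gradients(w, x, b=0.1)
print("active:  ", a_on, da_dw_on, da_db_on)

# Inactive case: hypothetical bias chosen so that z < 0
a_off, da_dw_off, da_db_off = neuron_gradients(w, x, b=-10.0)
print("inactive:", a_off, da_dw_off, da_db_off)
```

In the inactive case every gradient is zero, so a gradient-descent step would leave this neuron's parameters unchanged for that input.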

## Section 6: The Gradient of the Neural Network Loss Function

This section brings together all the matrix calculus concepts to show how a neural network learns by minimizing a loss function. We'll derive the gradients of the Mean Squared Error (MSE) loss with respect to a neuron's weights $w$ and bias $b$ using the chain rule and the neuron's activation gradients from Section 5.

### Loss Function Setup

Suppose we have $N$ training examples. For each input $x_i$ and target $y_i$:
- The neuron computes $z_i = w \cdot x_i + b$
- The activation is $a_i = \max(0, z_i)$ (ReLU)
- The loss for one instance is $L_i = (y_i - a_i)^2$
- The total loss is $C(w, b, X, y) = \frac{1}{N} \sum_i (y_i - a_i)^2$

We want $\frac{\partial C}{\partial w}$ and $\frac{\partial C}{\partial b}$.

In [13]:
import numpy as np

# Example data: 3 training instances, 3 features
X = np.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0]
])
y = np.array([10.0, 20.0, 30.0])
w = np.array([0.5, -1.0, 2.0])
b = 0.1
N = X.shape[0]

def relu(z):
    return np.maximum(0, z)

def relu_grad(z):
    return (z > 0).astype(float)

# Forward pass
z = X @ w + b  # shape (N,)
a = relu(z)    # shape (N,)
loss = np.mean((y - a) ** 2)
print(f"Loss: {loss}")

Loss: 138.97666666666666


### Gradient with Respect to Weights $w$

For each instance $i$:
- $e_i = (w \cdot x_i + b) - y_i$
- If $z_i > 0$ (neuron active):
  - $\frac{\partial L_i}{\partial w} = 2 e_i x_i^T$
- If $z_i \leq 0$ (neuron inactive):
  - $\frac{\partial L_i}{\partial w} = 0$

Averaging over all $N$ instances:
$$
\frac{\partial C}{\partial w} = \frac{2}{N} \sum_{i: z_i > 0} e_i x_i^T
$$
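
The step from the per-instance loss to these expressions is an application of the chain rule from Section 4.5 together with the activation gradient from Section 5. Spelled out for an active instance ($z_i > 0$, so $a_i = z_i$ and $\frac{\partial a_i}{\partial w} = x_i^T$):

$$
\frac{\partial L_i}{\partial w}
= \frac{\partial}{\partial w} (y_i - a_i)^2
= 2 (y_i - a_i) \cdot \left( -\frac{\partial a_i}{\partial w} \right)
= 2 (a_i - y_i) \, x_i^T
= 2 e_i x_i^T
$$

For an inactive instance, $\frac{\partial a_i}{\partial w} = 0^T$, so the whole product vanishes.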

In [14]:
# Compute gradients for all instances
active = z > 0      # Boolean mask: instances on which the neuron is active (z_i > 0)
errors = z - y      # e_i = (w·x_i + b) - y_i

grad_w = np.zeros_like(w)
for i in range(N):
    if active[i]:
        grad_w += 2 * errors[i] * X[i]
grad_w /= N
print("∂C/∂w:", grad_w)

∂C/∂w: [-109.2 -131.  -152.8]


### Gradient with Respect to Bias $b$

For each instance $i$:
- If $z_i > 0$:
  - $\frac{\partial L_i}{\partial b} = 2 e_i$
- If $z_i \leq 0$:
  - $\frac{\partial L_i}{\partial b} = 0$

Averaging over all $N$ instances:
$$
\frac{\partial C}{\partial b} = \frac{2}{N} \sum_{i: z_i > 0} e_i
$$

In [15]:
grad_b = 0.0
for i in range(N):
    if active[i]:
        grad_b += 2 * errors[i]
grad_b /= N
print("∂C/∂b:", grad_b)

∂C/∂b: -21.8
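
The per-instance loops above can also be written without explicit iteration. The following is a sketch of an equivalent vectorized form rather than the notebook's own code; `masked_errors` is a name introduced here for illustration.

```python
import numpy as np

X = np.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0]
])
y = np.array([10.0, 20.0, 30.0])
w = np.array([0.5, -1.0, 2.0])
b = 0.1
N = X.shape[0]

z = X @ w + b
# Zero out the error of every instance on which the neuron is inactive
masked_errors = (z - y) * (z > 0)

# (2/N) * sum_i e_i x_i  and  (2/N) * sum_i e_i, restricted to active instances
grad_w = (2.0 / N) * (X.T @ masked_errors)
grad_b = (2.0 / N) * masked_errors.sum()

print("∂C/∂w:", grad_w)
print("∂C/∂b:", grad_b)
```

Multiplying by the Boolean mask plays the role of the `if active[i]` test, and `X.T @ masked_errors` performs the weighted sum of the rows of `X` in one call.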


### Key Takeaway

The gradients $\frac{\partial C}{\partial w}$ and $\frac{\partial C}{\partial b}$ tell us how to update the neuron's weights and bias to reduce the loss. These are the fundamental building blocks for training neural networks using gradient descent.
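
A common way to gain confidence in hand-derived gradients is a numerical gradient check: nudge each parameter slightly and observe how the loss responds. The sketch below assumes the same data as above; the `loss` helper and the step size `eps` are choices made here for illustration.

```python
import numpy as np

X = np.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0]
])
y = np.array([10.0, 20.0, 30.0])
w = np.array([0.5, -1.0, 2.0])
b = 0.1

def loss(w, b):
    """MSE loss of a single ReLU neuron over the whole data set."""
    a = np.maximum(0, X @ w + b)
    return np.mean((y - a) ** 2)

eps = 1e-6  # assumed step size

# Central differences, one weight at a time
numeric_grad_w = np.zeros_like(w)
for j in range(len(w)):
    step = np.zeros_like(w)
    step[j] = eps
    numeric_grad_w[j] = (loss(w + step, b) - loss(w - step, b)) / (2 * eps)

numeric_grad_b = (loss(w, b + eps) - loss(w, b - eps)) / (2 * eps)

print("Numerical ∂C/∂w:", numeric_grad_w)
print("Numerical ∂C/∂b:", numeric_grad_b)
```

The numerical values should agree with the analytic gradients computed above up to small floating-point error. A check like this is only reliable away from $z_i = 0$, where ReLU has a kink.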

## Section 7: End-to-End Backpropagation Algorithm for Regression Using Gradient Descent

Now that we've derived the gradients for a single neuron, let's put it all together and implement the full backpropagation algorithm for a simple regression task using a single-layer neural network (one neuron) and gradient descent.

We'll use Mean Squared Error (MSE) as the loss function and update the weights and bias using the gradients computed in the previous section.

### Algorithm Steps

1. **Initialize parameters:** Initialize the weights $w$ randomly and choose a starting bias $b$ (the code below starts $b$ at zero).
2. **Forward pass:** For each training example, compute the neuron's output $a_i = \max(0, w \cdot x_i + b)$.
3. **Compute loss:** Calculate the MSE loss $C(w, b, X, y) = \frac{1}{N} \sum_i (y_i - a_i)^2$.
4. **Backward pass (compute gradients):**
    - Compute $\frac{\partial C}{\partial w}$ and $\frac{\partial C}{\partial b}$ as derived previously.
5. **Parameter update:** Update $w$ and $b$ using gradient descent:
    - $w \leftarrow w - \eta \frac{\partial C}{\partial w}$
    - $b \leftarrow b - \eta \frac{\partial C}{\partial b}$
6. **Repeat:** Iterate steps 2-5 for a fixed number of epochs or until convergence.

In [16]:
import numpy as np

# Example data: 3 training instances, 3 features
X = np.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0]
])
y = np.array([10.0, 20.0, 30.0])

# Hyperparameters
learning_rate = 0.01
num_epochs = 100

# Initialize weights and bias
np.random.seed(42)
w = np.random.randn(X.shape[1])
b = 0.0

# ReLU and its gradient
def relu(z):
    return np.maximum(0, z)

def relu_grad(z):
    return (z > 0).astype(float)

N = X.shape[0]

for epoch in range(num_epochs):
    # Forward pass
    z = X @ w + b  # shape (N,)
    a = relu(z)    # shape (N,)
    loss = np.mean((y - a) ** 2)

    # Gradients
    errors = z - y
    active = z > 0
    grad_w = np.zeros_like(w)
    grad_b = 0.0
    for i in range(N):
        if active[i]:
            grad_w += 2 * errors[i] * X[i]
            grad_b += 2 * errors[i]
    grad_w /= N
    grad_b /= N

    # Parameter update
    w -= learning_rate * grad_w
    b -= learning_rate * grad_b

    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}: Loss = {loss:.4f}")

print("Final weights:", w)
print("Final bias:", b)

Epoch 1: Loss = 252.0780
Epoch 10: Loss = 45.8938
Epoch 20: Loss = 7.5255
Epoch 30: Loss = 1.7175
Epoch 40: Loss = 0.7380
Epoch 50: Loss = 0.4949
Epoch 60: Loss = 0.3805
Epoch 70: Loss = 0.3010
Epoch 80: Loss = 0.2394
Epoch 90: Loss = 0.1907
Epoch 100: Loss = 0.1518
Final weights: [0.55084591 0.6847681  2.23962159]
Final bias: 0.7689006443276262


### Key Takeaway

This is the complete end-to-end backpropagation and gradient descent algorithm for a single neuron (regression with ReLU activation). The same principles extend to deeper networks, where gradients are propagated backward layer by layer using the chain rule.

## Bridge: From Shallow to Deep Networks – Why New Techniques Are Needed

So far in this notebook, we've explored the fundamentals of matrix calculus, gradients, and backpropagation for shallow neural networks. We've seen how to compute gradients for a single neuron and how to train it using gradient descent. But as we move to deeper networks (many layers, many neurons per layer), new challenges arise:

- Gradients can vanish or explode as they propagate through many layers, making training unstable or ineffective.
- The choice of activation function and weight initialization becomes critical for stable learning.
- Techniques like Batch Normalization are needed to keep the signal and gradients flowing well through deep architectures.

The next sections will address these challenges, starting with the vanishing/exploding gradient problem and how Batch Normalization helps.

## The Vanishing/Exploding Gradients Problems

As neural networks become deeper, a critical challenge emerges: gradients can either vanish (become extremely small) or explode (become extremely large) as they are propagated backward through many layers during training. This makes it very difficult for the lower layers to learn effectively.

- **Vanishing Gradients:** If the weights and activation derivatives are small, gradients shrink exponentially as they move backward, causing early layers to learn very slowly or not at all.
- **Exploding Gradients:** If weights or derivatives are large, gradients can grow exponentially, making training unstable and causing divergence.

**Why does this happen?**
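- Each layer's gradient is a product of many terms (weights and activation derivatives). If the magnitudes of these terms are consistently below 1, the product shrinks toward zero and gradients vanish; if they are consistently above 1, the product grows rapidly and gradients explode.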
- This is especially problematic with saturating activation functions (like sigmoid/tanh) and poor weight initialization.

**Key insight:** For stable training, the variance of activations and gradients should remain roughly constant across layers. Proper initialization and activation function choice are crucial.
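
To make the "product of many terms" concrete, consider a deliberately simplified toy model: a chain of scalar sigmoid layers that all share the same weight and the same pre-activation. This setup is an assumption made here for illustration, not a model of a real network. Each layer then contributes a factor $|w| \cdot \sigma'(z)$ to the backpropagated gradient, and $\sigma'(z)$ is at most 0.25.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def backprop_scale(weight, pre_activation, n_layers):
    """Gradient scaling after n_layers identical scalar sigmoid layers (toy model)."""
    s = sigmoid(pre_activation)
    local_factor = abs(weight) * s * (1.0 - s)  # |w| * sigmoid'(z) for one layer
    return local_factor ** n_layers

for n in (1, 5, 10, 20):
    small = backprop_scale(1.0, 0.0, n)  # per-layer factor 0.25
    large = backprop_scale(5.0, 0.0, n)  # per-layer factor 1.25
    print(f"{n:2d} layers: |w|=1 -> {small:.3e}   |w|=5 -> {large:.3e}")
```

With a per-layer factor of 0.25 the gradient shrinks by many orders of magnitude within 20 layers, while a factor of 1.25 makes it grow steadily, which is the vanishing and exploding behavior described above.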

## Interactive Visualization of Vanishing/Exploding Gradients and Activation Functions

This interactive dashboard helps you build intuition for two of the most important mathematical phenomena in deep learning: the vanishing/exploding gradient problem and the behavior of different activation functions, including modern choices like ELU and SELU.

## What the Dashboard Shows

**1. Gradient Magnitude Across Layers (First Plot):**
- Shows how the magnitude of a gradient changes as it is propagated backward through multiple layers of a deep neural network.
- The x-axis is the layer number $l$, and the y-axis (log scale) is the gradient magnitude, modeled as $|w|^l$, where $|w|$ is the absolute value of a weight shared by every layer and $l$ is the number of layers the gradient has passed through.
- **Vanishing gradients** occur when $|w| < 1$, causing the gradient to shrink exponentially with depth. **Exploding gradients** occur when $|w| > 1$, causing the gradient to grow exponentially.
- Use the **Weight (|w|)** and **Layers** sliders to see how these parameters affect gradient flow.

**2. Activation Function Plot (Second Plot):**
- Shows the shape of the selected activation function (ReLU, Leaky ReLU, ELU, SELU, Sigmoid, or Tanh).
- ELU (Exponential Linear Unit) and SELU (Scaled ELU) are more recent activations that keep the gradient nonzero for negative inputs; SELU is additionally scaled to promote self-normalization.
- The x-axis is the input $z$, and the y-axis is the output of the activation function.
- Nonsaturating activations (like ReLU, Leaky ReLU, ELU, SELU) help maintain healthy gradients, while saturating activations (like Sigmoid and Tanh) can cause gradients to vanish for large $|z|$.
- Use the **Activation** dropdown and the **Leaky/ELU α** slider (for Leaky ReLU and ELU) to explore different behaviors.

**3. Activation Derivative Plot (Third Plot):**
- Shows the derivative of the selected activation function, which directly affects how gradients are backpropagated.
- For ReLU, the derivative is 1 for $z > 0$ and 0 for $z < 0$. For Leaky ReLU, the derivative is always nonzero, helping prevent "dead" neurons.
- For ELU and SELU, the derivative is smooth and nonzero for $z < 0$, which keeps gradients flowing for negative inputs; SELU's fixed scale and $\alpha$ constants are what promote self-normalization.
- For Sigmoid and Tanh, the derivative approaches zero for large $|z|$, illustrating why these activations can cause vanishing gradients in deep networks.
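
The same contrast can be seen numerically without the dashboard. The sketch below evaluates three of the derivatives at a handful of inputs; the sample points in `z` are arbitrary values chosen here for illustration.

```python
import numpy as np

def sigmoid_grad(z):
    s = 1.0 / (1.0 + np.exp(-z))
    return s * (1.0 - s)

def tanh_grad(z):
    return 1.0 - np.tanh(z) ** 2

def relu_grad(z):
    return (z > 0).astype(float)

# Arbitrary sample inputs, from strongly negative to strongly positive
z = np.array([-10.0, -2.0, 0.5, 2.0, 10.0])

np.set_printoptions(precision=2, suppress=False)
print("z:          ", z)
print("sigmoid'(z):", sigmoid_grad(z))
print("tanh'(z):   ", tanh_grad(z))
print("relu'(z):   ", relu_grad(z))
```

At $|z| = 10$ the sigmoid and tanh derivatives are already several orders of magnitude below 1, whereas the ReLU derivative stays at exactly 1 for every positive input.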

## How to Use the Dashboard
- Adjust the **Weight (|w|)** slider to see how the magnitude of weights affects gradient flow through layers.
- Change the **Layers** slider to simulate deeper or shallower networks.
- Modify the **Leaky/ELU α** slider to see how the slope for negative values in Leaky ReLU and ELU impacts the activation and its derivative.
- Switch between activation functions (including ELU and SELU) to compare their shapes and derivatives.

## Why This Matters
- **Vanishing/exploding gradients** are a core challenge in training deep neural networks. If gradients vanish, early layers learn very slowly; if they explode, training becomes unstable.
- **Activation function choice** is critical for stable training. Modern deep learning relies on nonsaturating activations (ReLU, Leaky ReLU, ELU, SELU) and careful initialization to keep gradients healthy.
- ELU and SELU are especially useful in deep architectures because they maintain nonzero gradients for all input values; SELU additionally has self-normalizing properties.

This dashboard provides a hands-on way to experiment with these concepts and see their effects in real time, deepening your understanding of the mathematical foundations of deep learning.

In [None]:
# Interactive dashboard for Vanishing/Exploding Gradients and Activation Functions
import numpy as np
import plotly.graph_objects as go
from ipywidgets import interact, interactive, fixed, widgets
from IPython.display import display, HTML

# Helper functions
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1 - s)
def tanh_grad(x):
    return 1 - np.tanh(x) ** 2
def relu(x):
    return np.maximum(0, x)
def relu_grad(x):
    return (x > 0).astype(float)
def leaky_relu(x, alpha=0.2):
    return np.where(x > 0, x, alpha * x)
def leaky_relu_grad(x, alpha=0.2):
    return np.where(x > 0, 1, alpha)
def elu(x, alpha=1.0):
    return np.where(x >= 0, x, alpha * (np.exp(x) - 1))
def elu_grad(x, alpha=1.0):
    return np.where(x >= 0, 1, alpha * np.exp(x))
def selu(x):
    # Standard SELU parameters
    alpha = 1.6732632423543772
    scale = 1.0507009873554805
    return scale * elu(x, alpha)
def selu_grad(x):
    alpha = 1.6732632423543772
    scale = 1.0507009873554805
    return scale * elu_grad(x, alpha)

# Create the initial data
x = np.linspace(-5, 5, 400)
layers = np.arange(1, 21)

# Function to create and display the dashboard
def create_dashboard(weight=0.8, n_layers=5, alpha=0.2, activation='ReLU'):
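    # Build three separate figures (gradient magnitude, activation, derivative)

    # 1. Gradient magnitude plot: |w|^l for l = 1..n_layers
    # A local `layers` shadows the module-level array so the Layers slider takes effect
    layers = np.arange(1, n_layers + 1)
    grad_mag = weight ** layers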
    grad_fig = go.Figure()
    grad_fig.add_trace(go.Scatter(
        x=layers, 
        y=grad_mag, 
        mode='lines+markers', 
        name='Gradient Magnitude',
        line=dict(color='royalblue', width=3),
        marker=dict(size=8)
    ))
    grad_fig.update_layout(
        title='Gradient Magnitude Across Layers',
        xaxis=dict(title='Layer'),
        yaxis=dict(
            title='Gradient Magnitude', 
            type='log',
            range=[-5, np.log10(max(10, weight ** 20))],
        ),
        height=400,
        margin=dict(l=50, r=50, b=50, t=80),
    )
    
    # 2. Activation function plot
    if activation == 'ReLU':
        act_y = relu(x)
    elif activation == 'Leaky ReLU':
        act_y = leaky_relu(x, alpha)
    elif activation == 'Sigmoid':
        act_y = sigmoid(x)
    elif activation == 'Tanh':
        act_y = np.tanh(x)
    elif activation == 'ELU':
        act_y = elu(x, alpha)
    elif activation == 'SELU':
        act_y = selu(x)
    
    act_fig = go.Figure()
    act_fig.add_trace(go.Scatter(
        x=x, 
        y=act_y, 
        mode='lines', 
        name='Activation',
        line=dict(color='forestgreen', width=3)
    ))
    act_fig.update_layout(
        title=f'{activation} Activation Function',
        xaxis=dict(title='Input z'),
        yaxis=dict(
            title='Activation', 
            range=[-1.5, 5]
        ),
        height=400,
        margin=dict(l=50, r=50, b=50, t=80),
    )
    
    # 3. Activation derivative plot
    if activation == 'ReLU':
        grad_y = relu_grad(x)
    elif activation == 'Leaky ReLU':
        grad_y = leaky_relu_grad(x, alpha)
    elif activation == 'Sigmoid':
        grad_y = sigmoid_grad(x)
    elif activation == 'Tanh':
        grad_y = tanh_grad(x)
    elif activation == 'ELU':
        grad_y = elu_grad(x, alpha)
    elif activation == 'SELU':
        grad_y = selu_grad(x)
    
    deriv_fig = go.Figure()
    deriv_fig.add_trace(go.Scatter(
        x=x, 
        y=grad_y, 
        mode='lines', 
        name='Derivative',
        line=dict(color='crimson', width=3)
    ))
    deriv_fig.update_layout(
        title=f'{activation} Derivative',
        xaxis=dict(title='Input z'),
        yaxis=dict(
            title='Derivative', 
            range=[-0.1, 1.9]  # SELU's derivative reaches about 1.76 just below z = 0
        ),
        height=400,
        margin=dict(l=50, r=50, b=50, t=80),
    )
    
    return grad_fig, act_fig, deriv_fig

# Create interactive widgets
weight_slider = widgets.FloatSlider(
    value=0.8,
    min=0.1,
    max=2.0,
    step=0.01,
    description='Weight (|w|):',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='500px')
)

layers_slider = widgets.IntSlider(
    value=5,
    min=1,
    max=20,
    step=1,
    description='Layers:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='500px')
)

alpha_slider = widgets.FloatSlider(
    value=0.2,
    min=0.01,
    max=0.5,
    step=0.01,
    description='Leaky/ELU α:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='500px')
)

activation_dropdown = widgets.Dropdown(
    options=['ReLU', 'Leaky ReLU', 'ELU', 'SELU', 'Sigmoid', 'Tanh'],
    value='ReLU',
    description='Activation:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='500px')
)

# Function to update the dashboard
def update_dashboard(weight, n_layers, alpha, activation):
    grad_fig, act_fig, deriv_fig = create_dashboard(weight, n_layers, alpha, activation)
    display(HTML("<div style='display: flex; flex-direction: column; gap: 20px;'>"))
    display(grad_fig)
    display(act_fig)
    display(deriv_fig)
    display(HTML("</div>"))

# Set up the interactive plot
interact(update_dashboard,
         weight=weight_slider,
         n_layers=layers_slider,
         alpha=alpha_slider,
         activation=activation_dropdown)


In [18]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Create a subplot with 2 side-by-side cells
fig = make_subplots(
    rows=1, cols=2, 
    subplot_titles=(
        "<b>Recurrent Neuron - Rolled Representation</b>", 
        "<b>Recurrent Neuron - Unrolled Through Time</b>"
    ),
    horizontal_spacing=0.05
)

# --- Left side - rolled representation ---
# Draw the neuron as a circle
fig.add_shape(
    type="circle",
    xref="x", yref="y",
    x0=0.3, y0=0.3,
    x1=0.7, y1=0.7,
    fillcolor="rgba(173, 216, 230, 0.7)",  # lightblue with transparency
    line_color="black",
    line_width=2,
    row=1, col=1
)

# Add text for the neuron
fig.add_annotation(
    x=0.5, y=0.5,
    text="y(t)",
    showarrow=False,
    font=dict(size=14, color="black"),
    row=1, col=1
)

# Create a curved self-connection loop using a custom SVG path
# Two cubic Bézier segments: from the right edge of the neuron over the top, then down to the left edge
fig.add_shape(
    type="path",
    path="M 0.7 0.5 C 0.8 0.5 0.8 0.9 0.5 0.9 C 0.2 0.9 0.2 0.5 0.3 0.5",
    line=dict(color="green", width=3),
    row=1, col=1
)

# Add arrow head to the loop
fig.add_annotation(
    x=0.3, y=0.5,
    ax=0.35, ay=0.5,
    text="",
    showarrow=True,
    arrowhead=2,
    arrowsize=1.5,
    arrowcolor="green",
    arrowwidth=2,
    row=1, col=1
)

# Add text for the recurrent connection
fig.add_annotation(
    x=0.5, y=0.85,
    text="y(t-1)",
    font=dict(color="green", size=12),
    showarrow=False,
    row=1, col=1
)

# Draw the input arrow
fig.add_annotation(
    x=0.35, y=0.35,
    ax=0.1, ay=0.25,
    axref="x", ayref="y",  # arrow tail in data coordinates, not pixel offsets
    text="",
    showarrow=True,
    arrowhead=2,
    arrowsize=1.5,
    arrowcolor="blue",
    arrowwidth=2,
    row=1, col=1
)

# Add text for the input
fig.add_annotation(
    x=0.15, y=0.2,
    text="x(t)",
    font=dict(color="blue", size=12),
    showarrow=False,
    row=1, col=1
)

# Draw the output arrow
fig.add_annotation(
    x=0.9, y=0.25,
    ax=0.65, ay=0.35,
    axref="x", ayref="y",  # arrow tail in data coordinates, not pixel offsets
    text="",
    showarrow=True,
    arrowhead=2,
    arrowsize=1.5,
    arrowcolor="red",
    arrowwidth=2,
    row=1, col=1
)

# Add text for the output
fig.add_annotation(
    x=0.9, y=0.2,
    text="output",
    font=dict(color="red", size=12),
    showarrow=False,
    row=1, col=1
)

# --- Right side - unrolled representation ---
# Draw neurons at different time steps
t_positions = [0.2, 0.4, 0.6, 0.8]
t_labels = ['t-3', 't-2', 't-1', 't']

for i, (pos, label) in enumerate(zip(t_positions, t_labels)):
    # Draw neuron circle with hover info
    fig.add_shape(
        type="circle",
        xref="x2", yref="y2",
        x0=pos-0.05, y0=0.45,
        x1=pos+0.05, y1=0.55,
        fillcolor="rgba(173, 216, 230, 0.7)",  # lightblue with transparency
        line_color="black",
        line_width=2,
        row=1, col=2
    )
    
    # Add neuron label
    fig.add_annotation(
        x=pos, y=0.5,
        text=f"y({label})",
        showarrow=False,
        font=dict(size=11),
        row=1, col=2
    )
    
    # Draw input arrow
    fig.add_annotation(
        x=pos, y=0.4,
        ax=pos, ay=0.3,
        axref="x2", ayref="y2",  # arrow tail in data coordinates, not pixel offsets
        text="",
        showarrow=True,
        arrowhead=2,
        arrowsize=1.5,
        arrowcolor="blue",
        arrowwidth=2,
        row=1, col=2
    )
    
    # Add input label
    fig.add_annotation(
        x=pos, y=0.25,
        text=f"x({label})",
        showarrow=False,
        font=dict(size=10, color="blue"),
        row=1, col=2
    )
    
    # Draw output arrow
    fig.add_annotation(
        x=pos, y=0.7,
        ax=pos, ay=0.6,
        axref="x2", ayref="y2",  # arrow tail in data coordinates, not pixel offsets
        text="",
        showarrow=True,
        arrowhead=2,
        arrowsize=1.5,
        arrowcolor="red",
        arrowwidth=2,
        row=1, col=2
    )
    
    # Add output label
    fig.add_annotation(
        x=pos, y=0.75,
        text=f"y({label})",
        showarrow=False,
        font=dict(size=10, color="red"),
        row=1, col=2
    )
    
    # Draw connections between neurons
    if i > 0:
        fig.add_annotation(
            x=pos, y=0.5,
            ax=t_positions[i-1], ay=0.5,
            axref="x2", ayref="y2",  # arrow tail in data coordinates, not pixel offsets
            text="",
            showarrow=True,
            arrowhead=2,
            arrowsize=1.5,
            arrowcolor="green",
            arrowwidth=2,
            row=1, col=2
        )

# Add text for unrolling in time
fig.add_annotation(
    x=0.5, y=0.1,
    text="<b>Unrolling through time</b>",
    showarrow=False,
    font=dict(size=12, color="black"),
    row=1, col=2
)

# Add legend using invisible scatter traces
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='lines',
        line=dict(color='blue', width=2),
        name='Input Connection'
    ),
    row=1, col=1
)
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='lines',
        line=dict(color='red', width=2),
        name='Output Connection'
    ),
    row=1, col=1
)
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='lines',
        line=dict(color='green', width=2),
        name='Recurrent Connection'
    ),
    row=1, col=1
)

# Update layout
fig.update_xaxes(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False, row=1, col=1)
fig.update_yaxes(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False, row=1, col=1)
fig.update_xaxes(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False, row=1, col=2)
fig.update_yaxes(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False, row=1, col=2)

fig.update_layout(
    height=500,
    width=900,
    margin=dict(l=20, r=20, t=60, b=20),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.1,
        xanchor="center",
        x=0.5
    ),
    hoverlabel=dict(
        bgcolor="white",
        font_size=12
    ),
    title=dict(
        text="<b>Recurrent Neural Network Structure</b>",
        x=0.5
    )
)

fig.show()
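
The two panels above describe the same computation: the rolled neuron applies one update rule repeatedly, while the unrolled view writes out each application explicitly. Here is a minimal NumPy sketch of that equivalence. It assumes a single tanh neuron with scalar weights `w_x`, `w_y` and bias `b`; these names and values are illustrative and are not taken from the figure.

```python
import numpy as np

# Hypothetical scalar parameters for a single recurrent neuron
w_x, w_y, b = 0.5, 0.8, 0.0

def step(x_t, y_prev):
    """One application of y(t) = tanh(w_x * x(t) + w_y * y(t-1) + b)."""
    return np.tanh(w_x * x_t + w_y * y_prev + b)

x = [1.0, 0.5, -0.3, 0.2]  # inputs x(t-3), x(t-2), x(t-1), x(t)

# Rolled view: one cell reused inside a loop
y = 0.0  # initial state
rolled = []
for x_t in x:
    y = step(x_t, y)
    rolled.append(y)

# Unrolled view: the same cell written out once per time step
y0 = step(x[0], 0.0)
y1 = step(x[1], y0)
y2 = step(x[2], y1)
y3 = step(x[3], y2)
unrolled = [y0, y1, y2, y3]

print("Rolled and unrolled outputs match:", np.allclose(rolled, unrolled))
```

Because both views share the same weights at every step, unrolling changes only how the computation is drawn, not what is computed.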

In [19]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Create a 2x2 subplot for the four RNN configurations
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=(
        '<b>Sequence-to-Sequence</b>', 
        '<b>Sequence-to-Vector</b>', 
        '<b>Vector-to-Sequence</b>', 
        '<b>Encoder-Decoder</b>'
    ),
    vertical_spacing=0.1,
    horizontal_spacing=0.1
)

# Helper function to draw RNN configurations
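def draw_rnn_config(fig, row, col, input_seq=True, output_seq=True, has_encoder=False, has_decoder=False, steps=4):
    # steps: number of time steps drawn for the non-encoder-decoder layouts (4 is an assumed default)
    # Set up coordinates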
    x_spacing = 0.15
    y_input = 0.2
    y_rnn = 0.5
    y_output = 0.8
    start_x = 0.2
    
    # Draw RNN boxes
    if has_encoder and has_decoder:
        # Encoder part
        for i in range(3):
            x_pos = start_x + i * x_spacing
            fig.add_shape(
                type="rect",
                xref=f"x{(row-1)*2+col}", yref=f"y{(row-1)*2+col}",
                x0=x_pos-0.05, y0=y_rnn-0.05, 
                x1=x_pos+0.05, y1=y_rnn+0.05,
                fillcolor="rgba(173, 216, 230, 0.7)",  # Light blue with transparency
                line=dict(color="black", width=1.5),
                row=row, col=col
            )
            
            # Recurrent connections
            if i > 0:
                prev_x_pos = start_x + (i-1) * x_spacing
                fig.add_annotation(
                    x=x_pos-0.05, y=y_rnn,
                    ax=prev_x_pos+0.05, ay=y_rnn,
                    axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
                    text="",
                    showarrow=True,
                    arrowhead=2,
                    arrowsize=1,
                    arrowcolor="green",
                    arrowwidth=1.5,
                    row=row, col=col
                )
        
        # Decoder part
        for i in range(3, 6):
            x_pos = start_x + i * x_spacing
            fig.add_shape(
                type="rect",
                xref=f"x{(row-1)*2+col}", yref=f"y{(row-1)*2+col}",
                x0=x_pos-0.05, y0=y_rnn-0.05, 
                x1=x_pos+0.05, y1=y_rnn+0.05,
                fillcolor="rgba(255, 250, 205, 0.7)",  # Light yellow with transparency
                line=dict(color="black", width=1.5),
                row=row, col=col
            )
            
            # Recurrent connections
            if i > 3:
                prev_x_pos = start_x + (i-1) * x_spacing
                fig.add_annotation(
                    x=x_pos-0.05, y=y_rnn,
                    ax=prev_x_pos+0.05, ay=y_rnn,
                    axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
                    text="",
                    showarrow=True,
                    arrowhead=2,
                    arrowsize=1,
                    arrowcolor="green",
                    arrowwidth=1.5,
                    row=row, col=col
                )
                
        # Connection from encoder to decoder
        encoder_end_x = start_x + 2*x_spacing
        decoder_start_x = start_x + 3*x_spacing
        fig.add_annotation(
            x=decoder_start_x-0.05, y=y_rnn,
            ax=encoder_end_x+0.05, ay=y_rnn,
            axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
            text="",
            showarrow=True,
            arrowhead=2,
            arrowsize=1.5,
            arrowcolor="purple",
            arrowwidth=2,
            row=row, col=col
        )
        fig.add_annotation(
            x=start_x + 2.5*x_spacing, y=y_rnn+0.08,
            text="context vector",
            font=dict(size=9, color="purple"),
            showarrow=False,
            row=row, col=col
        )
    else:    
        # Standard RNN boxes
        for i in range(steps):
            x_pos = start_x + i * x_spacing
            fig.add_shape(
                type="rect",
                xref=f"x{(row-1)*2+col}", yref=f"y{(row-1)*2+col}",
                x0=x_pos-0.05, y0=y_rnn-0.05, 
                x1=x_pos+0.05, y1=y_rnn+0.05,
                fillcolor="rgba(173, 216, 230, 0.7)",  # Light blue with transparency
                line=dict(color="black", width=1.5),
                row=row, col=col
            )
            
            # Recurrent connections
            if i > 0:
                prev_x_pos = start_x + (i-1) * x_spacing
                fig.add_annotation(
                    x=x_pos-0.05, y=y_rnn,
                    ax=prev_x_pos+0.05, ay=y_rnn,
                    axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
                    text="",
                    showarrow=True,
                    arrowhead=2,
                    arrowsize=1,
                    arrowcolor="green",
                    arrowwidth=1.5,
                    row=row, col=col
                )
    
    # Draw inputs
    if input_seq:
        for i in range(steps if not has_decoder else 3):
            x_pos = start_x + i * x_spacing
            
            # Input arrows
            fig.add_annotation(
                x=x_pos, y=y_rnn-0.05,
                ax=x_pos, ay=y_input+0.01,
                axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
                text="",
                showarrow=True,
                arrowhead=2,
                arrowsize=1.5,
                arrowcolor="blue",
                arrowwidth=2,
                row=row, col=col
            )
            fig.add_annotation(
                x=x_pos, y=y_input-0.05,
                text=f'X(t{i})',
                font=dict(size=9, color="blue"),
                showarrow=False,
                row=row, col=col
            )
    else:
        # Just a single input for vector-to-sequence
        fig.add_shape(
            type="rect",
            xref=f"x{(row-1)*2+col}", yref=f"y{(row-1)*2+col}",
            x0=start_x-0.05, y0=y_input-0.05, 
            x1=start_x+0.05, y1=y_input+0.05,
            fillcolor="rgba(144, 238, 144, 0.7)",  # Light green with transparency
            line=dict(color="black", width=1.5),
            row=row, col=col
        )
        fig.add_annotation(
            x=start_x, y=y_input,
            text='X',
            font=dict(size=9),
            showarrow=False,
            row=row, col=col
        )
        
        # Connect to all RNN cells or just the first for encoder-decoder
        if has_decoder:
            # Connect to first encoder cell only
            fig.add_annotation(
                x=start_x, y=y_rnn-0.05,
                ax=start_x, ay=y_input+0.05,
                axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
                text="",
                showarrow=True,
                arrowhead=2,
                arrowsize=1.5,
                arrowcolor="blue",
                arrowwidth=1.5,
                row=row, col=col
            )
        else:
            # Connect to all RNN cells with curved arrows
            for i in range(steps):
                x_pos = start_x + i * x_spacing
                
                # Create curved arrow paths
                x_mid = (start_x + x_pos)/2
                y_mid = y_input + 0.1 + i * 0.02  # Progressively higher curves
                
                # Add curved trace
                fig.add_trace(
                    go.Scatter(
                        x=[start_x, x_mid, x_pos],
                        y=[y_input, y_mid, y_rnn-0.05],
                        mode="lines",
                        line=dict(color="blue", width=1.5),
                        showlegend=False,
                    ),
                    row=row, col=col
                )
                
                # Add arrowhead
                fig.add_annotation(
                    x=x_pos, y=y_rnn-0.05,
                    ax=x_pos-(x_pos-x_mid)/10, ay=y_rnn-0.05-(y_rnn-0.05-y_mid)/10,
                    axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
                    text="",
                    showarrow=True,
                    arrowhead=2,
                    arrowsize=1,
                    arrowcolor="blue",
                    arrowwidth=1.5,
                    row=row, col=col
                )
    
    # Draw outputs
    if output_seq:
        output_range = range(steps) if not has_encoder else range(3, 6)
        for i in output_range:
            x_pos = start_x + i * x_spacing
            
            # Output arrows
            fig.add_annotation(
                x=x_pos, y=y_output,
                ax=x_pos, ay=y_rnn+0.05,
                axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
                text="",
                showarrow=True,
                arrowhead=2,
                arrowsize=1.5,
                arrowcolor="red",
                arrowwidth=2,
                row=row, col=col
            )
            fig.add_annotation(
                x=x_pos, y=y_output+0.05,
                text=f'Y(t{i-3 if has_encoder else i})',
                font=dict(size=9, color="red"),
                showarrow=False,
                row=row, col=col
            )
    else:
        # Just a single output for sequence-to-vector
        x_pos = start_x + (steps-1) * x_spacing
        fig.add_shape(
            type="rect",
            xref=f"x{(row-1)*2+col}", yref=f"y{(row-1)*2+col}",
            x0=x_pos-0.05, y0=y_output-0.05, 
            x1=x_pos+0.05, y1=y_output+0.05,
            fillcolor="rgba(250, 128, 114, 0.7)",  # Salmon with transparency
            line=dict(color="black", width=1.5),
            row=row, col=col
        )
        fig.add_annotation(
            x=x_pos, y=y_output,
            text='Y',
            font=dict(size=9),
            showarrow=False,
            row=row, col=col
        )
        
        # Connect last RNN cell to output
        fig.add_annotation(
            x=x_pos, y=y_output-0.05,
            ax=x_pos, ay=y_rnn+0.05,
            axref=f"x{(row-1)*2+col}", ayref=f"y{(row-1)*2+col}",  # tail in data coordinates, not pixels
            text="",
            showarrow=True,
            arrowhead=2,
            arrowsize=1.5,
            arrowcolor="red",
            arrowwidth=1.5,
            row=row, col=col
        )

# Draw the four different configurations
draw_rnn_config(fig, 1, 1, input_seq=True, output_seq=True)
draw_rnn_config(fig, 1, 2, input_seq=True, output_seq=False)
draw_rnn_config(fig, 2, 1, input_seq=False, output_seq=True)
draw_rnn_config(fig, 2, 2, input_seq=True, output_seq=True, has_encoder=True, has_decoder=True)

# Add legend using invisible scatter traces
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='lines',
        line=dict(color='blue', width=2),
        name='Input Connection'
    )
)
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='lines',
        line=dict(color='red', width=2),
        name='Output Connection'
    )
)
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='lines',
        line=dict(color='green', width=2),
        name='Recurrent Connection'
    )
)
# The encoder-decoder panel is always drawn, so always include its legend entry
fig.add_trace(
    go.Scatter(
        x=[None], y=[None],
        mode='lines',
        line=dict(color='purple', width=2),
        name='Context Vector'
    )
)

# Update layout for all subplots
for i in range(1, 3):
    for j in range(1, 3):
        fig.update_xaxes(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False, row=i, col=j)
        fig.update_yaxes(range=[0, 1], showgrid=False, zeroline=False, showticklabels=False, row=i, col=j)

fig.update_layout(
    height=800,
    width=900,
    showlegend=True,
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.1,
        xanchor="center",
        x=0.5
    ),
    margin=dict(l=20, r=20, t=60, b=50),
    title=dict(
        text="<b>RNN Sequence Processing Configurations</b>",
        x=0.5
    )
)

fig.show()
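
The four configurations in the figure differ only in which inputs are fed to the network and which outputs are kept. The sketch below makes that concrete with a tiny tanh RNN. The helper `run_rnn`, the random weights, and the sizes are all hypothetical choices for illustration. For brevity, the encoder and decoder share the same weights here, and the decoder receives zero inputs; a real encoder-decoder would typically use separate weights.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, hidden_size, steps = 2, 3, 4
Wx = rng.normal(scale=0.1, size=(input_size, hidden_size))   # input weights (illustrative)
Wh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))  # recurrent weights (illustrative)

def run_rnn(inputs, h0=None):
    """Return the hidden state at every time step, shape (len(inputs), hidden_size)."""
    h = np.zeros(hidden_size) if h0 is None else h0
    states = []
    for x_t in inputs:
        h = np.tanh(x_t @ Wx + h @ Wh)
        states.append(h)
    return np.stack(states)

X = rng.normal(size=(steps, input_size))

# Sequence-to-sequence: keep one output per input step
seq_to_seq = run_rnn(X)

# Sequence-to-vector: feed the whole sequence, keep only the last output
seq_to_vec = run_rnn(X)[-1]

# Vector-to-sequence: feed the same single vector at every step
vec_to_seq = run_rnn(np.tile(X[0], (steps, 1)))

# Encoder-decoder: the encoder's final state (context vector) initialises the decoder
context = run_rnn(X[:3])[-1]
decoded = run_rnn(np.zeros((3, input_size)), h0=context)

for name, out in [("seq-to-seq", seq_to_seq), ("seq-to-vector", seq_to_vec),
                  ("vector-to-seq", vec_to_seq), ("encoder-decoder", decoded)]:
    print(f"{name:16s} output shape: {out.shape}")
```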

In [None]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Define a simple RNN class
class SimpleRNN:
    def __init__(self, input_size, hidden_size):
        # Initialize weights and biases with controlled values for better visualization
        np.random.seed(42)  # For reproducibility
        self.Wx = np.random.randn(input_size, hidden_size) * 0.01  # Input weights
        self.Wy = np.random.randn(hidden_size, hidden_size) * 0.01  # Recurrent weights
        self.b = np.zeros((1, hidden_size))  # Bias
        
        self.hidden_size = hidden_size
        
    def forward(self, x_sequence):
        # x_sequence shape: (seq_length, input_size)
        seq_length = len(x_sequence)
        # Initialize states and outputs
        h = np.zeros((seq_length + 1, self.hidden_size))
        
        # Forward pass through time
        for t in range(seq_length):
            # Current input
            x_t = x_sequence[t:t+1]
            # Previous hidden state
            h_prev = h[t:t+1]
            
            # Calculate current hidden state (which is also the output for SimpleRNN)
            h[t+1] = np.tanh(np.dot(x_t, self.Wx) + np.dot(h_prev, self.Wy) + self.b)
        
        return h[1:]  # Return all states except the initial state h[0]

# Example: Generate a sequence with specific pattern for better visualization
seq_length = 15  # Longer sequence for better visualization
input_size = 2
hidden_size = 3

# Create input sequence with a pattern
x_sequence = np.zeros((seq_length, input_size))
# Add a sine wave pattern to first dimension
x_sequence[:, 0] = np.sin(np.linspace(0, 4*np.pi, seq_length))
# Add a cosine wave pattern to second dimension (out of phase)
x_sequence[:, 1] = np.cos(np.linspace(0, 3*np.pi, seq_length))

# Initialize RNN and run forward pass
rnn = SimpleRNN(input_size, hidden_size)
outputs = rnn.forward(x_sequence)

# Visualize the outputs with Plotly - improved version
fig = make_subplots(rows=2, cols=1, 
                   subplot_titles=('<b>RNN Hidden States Over Time</b>', 
                                  '<b>Input Sequence</b>'),
                   vertical_spacing=0.15,
                   row_heights=[0.7, 0.3])

# Plot hidden states
for i in range(hidden_size):
    fig.add_trace(
        go.Scatter(
            x=list(range(seq_length)),
            y=outputs[:, i],
            mode='lines+markers',
            name=f'Neuron {i+1}',
            line=dict(width=2),
            marker=dict(size=6),
            hovertemplate='Time Step: %{x}<br>Value: %{y:.4f}<extra></extra>'
        ),
        row=1, col=1
    )

# Plot input sequence
for i in range(input_size):
    fig.add_trace(
        go.Scatter(
            x=list(range(seq_length)),
            y=x_sequence[:, i],
            mode='lines+markers',
            name=f'Input {i+1}',
            line=dict(dash='dash', width=2),
            marker=dict(size=5),
            hovertemplate='Time Step: %{x}<br>Value: %{y:.4f}<extra></extra>'
        ),
        row=2, col=1
    )

# Update layout
fig.update_layout(
    height=700,
    width=900,
    title=dict(
        text='<b>RNN Hidden States Visualization</b>',
        x=0.5,
        font=dict(size=18)
    ),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.2,
        xanchor="center",
        x=0.5
    ),
    margin=dict(l=60, r=30, t=100, b=100),
    hovermode='closest'
)

# Update axes
fig.update_xaxes(title_text='Time Step', row=2, col=1)
fig.update_yaxes(title_text='Hidden State Values', row=1, col=1)
fig.update_yaxes(title_text='Input Values', row=2, col=1)

fig.show()

# Visualize how recurrent connections carry information forward using Plotly - improved version
# Choose one dimension to visualize
dim_to_visualize = 0

# Create a new sequence with a spike in the first time step
spike_sequence = np.zeros((seq_length, input_size))
spike_sequence[0, dim_to_visualize] = 3.0  # Strong initial signal

# Run the RNN
spike_outputs = rnn.forward(spike_sequence)

# Create subplots
fig = make_subplots(
    rows=2, cols=1, 
    subplot_titles=("<b>Input with Spike at Time 0</b>", 
                   "<b>RNN Response to Spike Input (Memory Effect)</b>"),
    vertical_spacing=0.15,
    row_heights=[0.3, 0.7]
)

# Add spike input plot - improved visual
fig.add_trace(
    go.Bar(
        x=[0],
        y=[spike_sequence[0, dim_to_visualize]],
        name='Input Spike',
        marker=dict(color='rgba(255, 0, 0, 0.7)'),
        width=0.3,
        hovertemplate='Spike value: %{y:.2f}<extra></extra>'
    ),
    row=1, col=1
)

# Add zero values for other time steps for clear visualization
fig.add_trace(
    go.Bar(
        x=list(range(1, seq_length)),
        y=np.zeros(seq_length-1),
        name='No Input',
        marker=dict(color='rgba(200, 200, 200, 0.4)'),
        hoverinfo='skip',
        showlegend=False
    ),
    row=1, col=1
)

# Add neuron output traces with better styling
colors = ['rgb(31, 119, 180)', 'rgb(255, 127, 14)', 'rgb(44, 160, 44)']
for i in range(hidden_size):
    fig.add_trace(
        go.Scatter(
            x=list(range(seq_length)),
            y=spike_outputs[:, i],
            mode='lines+markers',
            name=f'Neuron {i+1}',
            line=dict(color=colors[i], width=2),
            marker=dict(size=6, color=colors[i]),
            hovertemplate='Time: %{x}<br>Activation: %{y:.4f}<extra></extra>'
        ),
        row=2, col=1
    )

# Update layout
fig.update_layout(
    height=700,
    width=900,
    title=dict(
        text='<b>Memory Effect in RNN</b>',
        x=0.5,
        font=dict(size=18)
    ),
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.15,
        xanchor="center",
        x=0.5
    ),
    margin=dict(l=60, r=30, t=100, b=100),
    hovermode='closest'
)

# Update axes
fig.update_xaxes(title_text="Time Step", range=[-0.5, seq_length-0.5], row=2, col=1)
fig.update_xaxes(title_text="Time Step", range=[-0.5, seq_length-0.5], row=1, col=1)
fig.update_yaxes(title_text="Input Value", row=1, col=1)
fig.update_yaxes(title_text="Hidden State Values", row=2, col=1)

# Add annotations to highlight the decay phenomenon
max_idx = np.argmax(spike_outputs[:, 0])
max_val = spike_outputs[max_idx, 0]
last_val = spike_outputs[-1, 0]

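fig.add_annotation(
    x=max_idx,
    y=max_val,
    text="Initial Response",
    showarrow=True,
    arrowhead=2,
    ax=20,
    ay=-40,
    row=2, col=1  # annotate the response subplot, not the input subplot
)

fig.add_annotation(
    x=seq_length-1,  # last_val is the activation at the final time step
    y=last_val,
    text="Fading Memory",
    showarrow=True,
    arrowhead=2,
    ax=-30,
    ay=-40,
    row=2, col=1  # annotate the response subplot, not the input subplot
)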

fig.show()
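
How quickly the spike fades depends on the size of the recurrent weights. The sketch below is a rough illustration of that dependence and does not reproduce the cell above. The function `spike_response` and its `recurrent_scale` parameter are hypothetical, and it draws its own random weights, so the numbers will differ from the plots.

```python
import numpy as np

def spike_response(recurrent_scale, seq_length=6, hidden_size=3, seed=42):
    """Norm of the hidden state at each step after a single input spike at t=0."""
    rng = np.random.default_rng(seed)
    Wx = rng.normal(size=(1, hidden_size)) * 0.01                      # input weights
    Wy = rng.normal(size=(hidden_size, hidden_size)) * recurrent_scale  # recurrent weights
    x = np.zeros((seq_length, 1))
    x[0, 0] = 3.0  # spike at the first time step, zero input afterwards
    h = np.zeros(hidden_size)
    norms = []
    for t in range(seq_length):
        h = np.tanh(x[t] @ Wx + h @ Wy)
        norms.append(np.linalg.norm(h))
    return np.array(norms)

small = spike_response(recurrent_scale=0.01)  # same scale as the class above
large = spike_response(recurrent_scale=0.5)   # larger recurrent weights

print("recurrent scale 0.01:", small)
print("recurrent scale 0.5: ", large)
```

With recurrent weights scaled by 0.01, each step shrinks the remembered signal by roughly two orders of magnitude, so the "memory" is gone almost immediately; larger recurrent weights let the spike persist for more steps.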

In [None]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Define a simple RNN with BPTT
class SimpleRNNWithBPTT:
    def __init__(self, input_size, hidden_size, output_size, learning_rate=0.01):
        np.random.seed(42)  # For reproducible results
        # Initialize weights and biases
        self.Wx = np.random.randn(input_size, hidden_size) * 0.01
        self.Wy = np.random.randn(hidden_size, hidden_size) * 0.01
        self.Wo = np.random.randn(hidden_size, output_size) * 0.01
        self.bh = np.zeros((1, hidden_size))
        self.bo = np.zeros((1, output_size))
        
        self.learning_rate = learning_rate
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.output_size = output_size
        
    def forward(self, x_sequence):
        # x_sequence shape: (seq_length, input_size)
        seq_length = len(x_sequence)
        # Initialize states and outputs
        h = np.zeros((seq_length + 1, self.hidden_size))  # +1 for the initial state h[0]
        y = np.zeros((seq_length, self.output_size))
        
        # Forward pass through time
        for t in range(seq_length):
            # Current input
            x_t = x_sequence[t:t+1]
            # Previous hidden state
            h_prev = h[t:t+1]
            
            # Calculate current hidden state
            h[t+1] = np.tanh(np.dot(x_t, self.Wx) + np.dot(h_prev, self.Wy) + self.bh)
            # Calculate output
            y[t] = np.dot(h[t+1], self.Wo) + self.bo
        
        return h[1:], y  # Return hidden states and outputs
    
    def backprop(self, x, targets, h, outputs):
        # Initialize gradients
        dWx = np.zeros_like(self.Wx)
        dWy = np.zeros_like(self.Wy)
        dWo = np.zeros_like(self.Wo)
        dbh = np.zeros_like(self.bh)
        dbo = np.zeros_like(self.bo)
        
        # Initial gradient of hidden state
        dh_next = np.zeros((1, self.hidden_size))
        
        seq_length = len(x)
        
        # Backpropagation through time
        for t in reversed(range(seq_length)):
            # Gradient from output (squared-error loss; equals the MSE gradient up to a constant factor)
            dy = outputs[t] - targets[t:t+1]
            
            # Gradient of output weights
            h_t_reshaped = h[t].reshape(1, -1)
            dWo += np.dot(h_t_reshaped.T, dy)
            dbo += dy
            
            # Gradient into hidden layer
            dh = np.dot(dy, self.Wo.T) + dh_next
            
            # Gradient through tanh
            dh_raw = (1 - h[t] ** 2) * dh
            
            # Gradient of hidden weights, biases, and recurrent weights
            dbh += dh_raw
            dWx += np.dot(x[t].reshape(-1, 1), dh_raw)
            
            # We need the previous hidden state
            prev_h = np.zeros((1, self.hidden_size)) if t == 0 else h[t-1:t]
            dWy += np.dot(prev_h.T, dh_raw)
            
            # Gradient for next iteration
            dh_next = np.dot(dh_raw, self.Wy.T)
        
        # Clip gradients to prevent exploding gradients
        for grad in [dWx, dWy, dWo, dbh, dbo]:
            np.clip(grad, -5, 5, out=grad)
        
        # Update weights
        self.Wx -= self.learning_rate * dWx
        self.Wy -= self.learning_rate * dWy
        self.Wo -= self.learning_rate * dWo
        self.bh -= self.learning_rate * dbh
        self.bo -= self.learning_rate * dbo

# Create a simple sequence prediction task
# We'll create a sine wave and have the RNN predict the next value
def generate_sine_wave(seq_length, frequency=0.1):
    x = np.linspace(0, 10, seq_length)
    sine_wave = np.sin(x * frequency)
    return sine_wave.reshape(-1, 1)

# Generate data
seq_length = 100
sine_wave = generate_sine_wave(seq_length + 1)  # +1 to allow for targets

# Create training data and targets
x_train = sine_wave[:-1]  # All but the last point
y_train = sine_wave[1:]   # All but the first point (shifted by 1)

# Initialize RNN with smaller learning rate for smoother convergence
rnn = SimpleRNNWithBPTT(input_size=1, hidden_size=16, output_size=1, learning_rate=0.005)

# Training loop with logging for visualization
epochs = 300
loss_history = []
predictions_over_time = []

# Store initial predictions
_, initial_preds = rnn.forward(x_train)
predictions_over_time.append(initial_preds.copy())

# Training epochs
for epoch in range(epochs):
    # Forward pass
    h, y_pred = rnn.forward(x_train)
    
    # Calculate loss (MSE)
    loss = np.mean((y_pred - y_train) ** 2)
    loss_history.append(loss)
    
    # Save predictions periodically
    if epoch % 50 == 0 and epoch > 0:
        predictions_over_time.append(y_pred.copy())
    
    # Backpropagation
    rnn.backprop(x_train, y_train, h, y_pred)
    
    # Print progress
    if epoch % 30 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

# Get final predictions
_, final_predictions = rnn.forward(x_train)
predictions_over_time.append(final_predictions)

# Visualize training loss history with improved Plotly visualization
fig_loss = go.Figure()
fig_loss.add_trace(
    go.Scatter(
        x=list(range(epochs)),
        y=loss_history,
        mode='lines',
        name='Training Loss',
        line=dict(
            color='rgb(31, 119, 180)',
            width=2
        ),
        hovertemplate='Epoch: %{x}<br>Loss: %{y:.6f}<extra></extra>'
    )
)

# Add markers at specific loss values
marker_epochs = [0, 50, 100, 200, epochs-1]
marker_losses = [loss_history[e] for e in marker_epochs]

fig_loss.add_trace(
    go.Scatter(
        x=marker_epochs,
        y=marker_losses,
        mode='markers',
        marker=dict(
            color='red',
            size=8,
            line=dict(
                color='darkred',
                width=1.5
            )
        ),
        name='Highlighted Points',
        hovertemplate='Epoch: %{x}<br>Loss: %{y:.6f}<extra></extra>'
    )
)

fig_loss.update_layout(
    title=dict(
        text='<b>RNN Training Loss Over Time</b>',
        x=0.5,
        font=dict(size=18)
    ),
    xaxis_title='Epoch',
    yaxis_title='MSE Loss',
    yaxis=dict(
        type='log',
        showgrid=True,
        gridwidth=0.5,
        gridcolor='lightgray'
    ),
    width=900,
    height=500,
    margin=dict(l=60, r=30, t=80, b=60),
    hovermode='closest',
    legend=dict(
        yanchor="top",
        y=0.99,
        xanchor="right",
        x=0.99,
        bgcolor="rgba(255,255,255,0.8)"
    ),
    shapes=[
        # Annotation line for convergence region
        dict(
            type="line",
            yref="y",
            y0=loss_history[-1]*1.2,
            y1=loss_history[-1]*1.2,
            xref="x",
            x0=epochs*0.6,
            x1=epochs-1,
            line=dict(
                color="green",
                width=2,
                dash="dot",
            ),
        ),
    ],
    annotations=[
        dict(
            x=epochs*0.8,
            y=loss_history[-1]*1.3,
            text="Convergence Region",
            showarrow=False,
            font=dict(color="green")
        )
    ]
)

# Show the loss curve
fig_loss.show()

# Visualize predictions with interactive Plotly
# Create a subplot with 2 rows
fig_pred = make_subplots(
    rows=2, 
    cols=1,
    subplot_titles=(
        '<b>RNN Sine Wave Prediction (Final Result)</b>', 
        '<b>Absolute Error of Final Predictions</b>'
    ),
    vertical_spacing=0.15,
    row_heights=[0.6, 0.4]
)

# Plot the final predictions vs true values
fig_pred.add_trace(
    go.Scatter(
        x=list(range(seq_length)),
        y=y_train.flatten(),
        mode='lines',
        name='True Values',
        line=dict(color='blue', width=2.5),
        hovertemplate='Time Step: %{x}<br>True Value: %{y:.4f}<extra></extra>'
    ),
    row=1, col=1
)
fig_pred.add_trace(
    go.Scatter(
        x=list(range(seq_length)),
        y=final_predictions.flatten(),
        mode='lines',
        name='Predicted Values',
        line=dict(color='red', width=2, dash='dot'),
        hovertemplate='Time Step: %{x}<br>Prediction: %{y:.4f}<extra></extra>'
    ),
    row=1, col=1
)

# Calculate and plot error
error = np.abs(y_train.flatten() - final_predictions.flatten())
fig_pred.add_trace(
    go.Bar(
        x=list(range(seq_length)),
        y=error,
        name='Absolute Error',
        marker_color='rgba(255, 165, 0, 0.5)',
        hovertemplate='Time Step: %{x}<br>Error: %{y:.4f}<extra></extra>'
    ),
    row=2, col=1
)

# Show predictions at different stages of training
stages = ['Initial', 'After 50 epochs', 'After 100 epochs', 
          'After 150 epochs', 'After 200 epochs', 'After 250 epochs', 'Final']
colors = ['rgba(128,128,128,0.7)', 'rgba(250,128,114,0.7)', 'rgba(255,165,0,0.7)', 
          'rgba(107,142,35,0.7)', 'rgba(65,105,225,0.7)', 'rgba(186,85,211,0.7)', 'rgba(255,0,0,0.7)']

# Select subset of time steps for clarity
time_steps = np.arange(0, seq_length, 3)

# Add one trace per saved training stage; the slider below toggles their visibility
for i, preds in enumerate(predictions_over_time):
    epoch_num = 0 if i == 0 else (50*i if i < len(predictions_over_time)-1 else epochs)
    
    trace = go.Scatter(
        x=time_steps,
        y=preds.flatten()[time_steps],
        mode='lines',
        name=f'{stages[i]} (Epoch {epoch_num})',
        line=dict(color=colors[i], width=2, dash='dot' if i < len(predictions_over_time)-1 else 'solid'),
        visible=(i == len(predictions_over_time)-1),  # Only show final prediction by default
        hovertemplate=f'{stages[i]}<br>Time: %{{x}}<br>Value: %{{y:.4f}}<extra></extra>'
    )
    fig_pred.add_trace(trace, row=1, col=1)

# Add slider for prediction over time
sliders = [dict(
    active=len(predictions_over_time)-1,
    currentvalue={"prefix": "Showing: "},
    pad={"t": 50},
    steps=[dict(
        method="update",
        label=stages[i],
        args=[{"visible": [True, True, True] + [i==j for j in range(len(predictions_over_time))]}],
    ) for i in range(len(predictions_over_time))]
)]

# Configure the layout
fig_pred.update_layout(
    height=800,
    width=900,
    title=dict(
        text='<b>RNN Sine Wave Prediction</b>',
        x=0.5,
        font=dict(size=18)
    ),
    sliders=sliders,
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.2,
        xanchor="center",
        x=0.5
    ),
    margin=dict(l=60, r=30, t=80, b=100),
    hovermode='closest'
)

# Update axes
fig_pred.update_xaxes(title_text="Time Step", row=1, col=1)
fig_pred.update_xaxes(title_text="Time Step", row=2, col=1)
fig_pred.update_yaxes(title_text="Value", row=1, col=1)
fig_pred.update_yaxes(title_text="Absolute Error", row=2, col=1)

# Show the figure
fig_pred.show()

Epoch 0, Loss: 0.2760
Epoch 30, Loss: 0.0489
Epoch 60, Loss: 0.0180
Epoch 90, Loss: 0.0002
Epoch 120, Loss: 0.0001
Epoch 150, Loss: 0.0001
Epoch 180, Loss: 0.0000
Epoch 210, Loss: 0.0000
Epoch 240, Loss: 0.0000
Epoch 270, Loss: 0.0000


## Limitations of Basic RNNs: Vanishing and Exploding Gradients

While the basic RNN structure we've explored can theoretically process sequences of arbitrary length, in practice it struggles to learn long-range dependencies. This is due to two major problems:

### 1. Vanishing Gradients

During BPTT, gradients are multiplied by the recurrent weight matrix $W_y$ (and by the derivative of the activation function) once per time step as they flow backward through time. If the largest singular value of $W_y$ is less than 1 (and the activation derivative is at most 1, as for tanh), the gradients shrink exponentially with each time step, becoming effectively zero for steps far in the past. This means:

- The network cannot learn long-term dependencies
- Only recent inputs influence predictions
- Training becomes extremely slow for earlier time steps

### 2. Exploding Gradients

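Conversely, if the recurrent weights are large enough that repeated multiplication by $W_y$ amplifies the gradient (roughly, when $W_y$ has singular values well above 1), gradients can grow exponentially during backpropagation, leading to: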

- Numerical instability
- Extremely large weight updates
- Training divergence
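
Both effects can be reproduced with a few lines of NumPy. The sketch below is illustrative only: `backprop_norms` is a hypothetical helper, the recurrence is assumed to be linear (the activation derivative is treated as 1), and the recurrent matrix is assumed to be a scaled orthogonal matrix so that every singular value equals `scale`.

```python
import numpy as np

def backprop_norms(scale, hidden_size=4, steps=50, seed=0):
    """Norm of a gradient vector after each backward step through time.

    Illustrative assumptions: the recurrent matrix is `scale` times a random
    orthogonal matrix (all singular values equal `scale`), and the activation
    derivative is ignored (treated as 1).
    """
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((hidden_size, hidden_size)))
    W = scale * q
    grad = np.ones(hidden_size)
    norms = []
    for _ in range(steps):
        grad = W.T @ grad  # one step of backpropagation through time
        norms.append(np.linalg.norm(grad))
    return np.array(norms)

vanishing = backprop_norms(scale=0.9)
exploding = backprop_norms(scale=1.1)
print(f"scale 0.9, gradient norm after 50 steps: {vanishing[-1]:.4e}")
print(f"scale 1.1, gradient norm after 50 steps: {exploding[-1]:.4e}")
```

Under these assumptions the norm after $k$ steps is exactly the initial norm times `scale**k`, which makes the exponential shrinkage or growth easy to see.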

### Solutions

To address these problems, several improved architectures and training techniques have been developed:

1. **Long Short-Term Memory (LSTM)** - Introduces gates to control information flow and maintain gradients
2. **Gated Recurrent Unit (GRU)** - A simplified version of LSTM with fewer parameters
3. **Gradient clipping** - Prevents exploding gradients by limiting their magnitude
4. **Skip connections** - Creates shortcuts for gradient flow across multiple time steps

Gated architectures such as LSTM and GRU have largely replaced simple RNNs in practical applications that require learning long-range dependencies.
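
Of the remedies listed above, gradient clipping is the simplest to show in isolation. The following is a minimal sketch of clipping by global norm, assuming gradients are held in a list of NumPy arrays; `clip_by_global_norm` is a hypothetical helper written for this example, not a function from any library.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    # Scale down only when the norm exceeds the threshold; never scale up.
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    return [g * scale for g in grads], total_norm

# Illustrative "exploded" gradients for two parameter arrays
grads = [np.array([[30.0, -40.0]]), np.array([0.0, 120.0])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=5.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print("norm before clipping:", norm_before)
print("norm after clipping: ", norm_after)
```

Because every array is multiplied by the same factor, the direction of the update is preserved and only its length changes.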

## Summary: Key Concepts in Recurrent Neural Networks

1. **Structure of RNNs**:
   - Recurrent neurons have connections pointing backward, forming loops
   - These loops create a form of memory allowing the network to consider past inputs
   - Mathematically represented as: $y(t) = \phi(W_x^T x(t) + W_y^T y(t-1) + b)$

2. **Unrolling through time**:
   - An RNN can be visualized by unrolling it through time
   - Each time step becomes a layer in a very deep network that shares weights

3. **Sequence Processing Configurations**:
   - Sequence-to-Sequence: Maps input sequences to output sequences
   - Sequence-to-Vector: Summarizes a sequence into a single vector
   - Vector-to-Sequence: Generates a sequence from a static input
   - Encoder-Decoder: Combines the above for complex tasks like translation

4. **Training via BPTT**:
   - Forward pass processes the sequence step by step
   - Cost function evaluates selected outputs
   - Gradients flow backward through time
   - Shared weights result in summing gradients across time steps

5. **Limitations**:
   - Basic RNNs struggle with long sequences due to vanishing/exploding gradients
   - Advanced architectures like LSTM and GRU address these limitations

Understanding these concepts provides the foundation for working with sequential data in deep learning, from natural language processing to time series forecasting and beyond.

# Long Short-Term Memory (LSTM) Networks: Overcoming RNN Limitations

LSTMs were introduced to address the vanishing and exploding gradient problems that plague basic RNNs, enabling the learning of long-range dependencies in sequential data. They achieve this by introducing a memory cell and a set of gates that control the flow of information.

## Motivation: Why LSTMs?
- **Vanishing gradients** in RNNs make it hard to learn dependencies over long sequences.
- LSTMs use a memory cell and gating mechanisms that let gradients flow largely unchanged over many time steps (as long as the forget gate stays close to 1), preserving information and enabling learning of long-term dependencies.

## LSTM Architecture and Gating Mechanisms
An LSTM cell consists of:
- **Cell state ($c_t$):** The memory of the network, running through the cell with only minor linear interactions.
- **Hidden state ($h_t$):** The output at each time step.
- **Gates:** Neural network layers that control the flow of information:
    - **Forget gate ($f_t$):** Decides what information to discard from the cell state.
    - **Input gate ($i_t$):** Decides what new information to add to the cell state.
    - **Output gate ($o_t$):** Decides what part of the cell state to output.

### LSTM Equations
Given input $x_t$, previous hidden state $h_{t-1}$, and previous cell state $c_{t-1}$:

$$
\begin{align*}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) &\text{(forget gate)} \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) &\text{(input gate)} \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) &\text{(candidate cell state)} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t &\text{(new cell state)} \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) &\text{(output gate)} \\
h_t &= o_t \odot \tanh(c_t) &\text{(new hidden state)}
\end{align*}
$$

Where $\sigma$ is the sigmoid function, $\tanh$ is the hyperbolic tangent, and $\odot$ denotes element-wise multiplication.
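
The cell-state update is an element-wise operation, so its Jacobian with respect to $c_{t-1}$ is diagonal, just like the element-wise products earlier in this notebook. The sketch below checks this numerically. It assumes the gate activations are held fixed for the step (their own dependence on $h_{t-1}$ is ignored), and the gate values are random placeholders rather than outputs of a trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size = 3

# Gate activations treated as constants for this step (illustrative values).
f_t = rng.uniform(0.0, 1.0, hidden_size)       # forget gate
i_t = rng.uniform(0.0, 1.0, hidden_size)       # input gate
c_tilde = rng.uniform(-1.0, 1.0, hidden_size)  # candidate cell state
c_prev = rng.standard_normal(hidden_size)

def cell_update(c_prev):
    """c_t = f_t * c_{t-1} + i_t * c~_t (element-wise)."""
    return f_t * c_prev + i_t * c_tilde

# Central finite-difference Jacobian d c_t / d c_{t-1}
eps = 1e-6
jacobian = np.zeros((hidden_size, hidden_size))
for j in range(hidden_size):
    step = np.zeros(hidden_size)
    step[j] = eps
    jacobian[:, j] = (cell_update(c_prev + step) - cell_update(c_prev - step)) / (2 * eps)

print("Numerical Jacobian:\n", np.round(jacobian, 4))
print("diag(f_t):\n", np.round(np.diag(f_t), 4))
```

Along this direct path the gradient over many steps is a product of forget-gate values, so it is preserved when the forget gate stays near 1 and decays when it is near 0.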

### LSTM Cell Diagram
Below is a schematic of the LSTM cell, showing the flow of information and the gating mechanisms:

![LSTM Cell Diagram](https://colah.github.io/posts/2015-08-Understanding-LSTMs/img/LSTM3-chain.png)

*Source: Chris Olah's blog*

## LSTM vs. RNN: Key Differences
- **Memory cell:** LSTMs have an explicit memory cell, RNNs do not.
- **Gating:** LSTMs use gates to control information flow, RNNs use a single activation.
- **Gradient flow:** LSTMs preserve gradients over long sequences, RNNs do not.

## LSTM Implementation and Visualization
Let's implement a simple LSTM from scratch (no deep learning frameworks) and visualize its hidden and cell states over time.

In [None]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

class SimpleLSTM:
    def __init__(self, input_size, hidden_size):
        np.random.seed(42)
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Weights for gates and cell
        self.Wf = np.random.randn(input_size, hidden_size) * 0.1
        self.Uf = np.random.randn(hidden_size, hidden_size) * 0.1
        self.bf = np.zeros((1, hidden_size))
        
        self.Wi = np.random.randn(input_size, hidden_size) * 0.1
        self.Ui = np.random.randn(hidden_size, hidden_size) * 0.1
        self.bi = np.zeros((1, hidden_size))
        
        self.Wc = np.random.randn(input_size, hidden_size) * 0.1
        self.Uc = np.random.randn(hidden_size, hidden_size) * 0.1
        self.bc = np.zeros((1, hidden_size))
        
        self.Wo = np.random.randn(input_size, hidden_size) * 0.1
        self.Uo = np.random.randn(hidden_size, hidden_size) * 0.1
        self.bo = np.zeros((1, hidden_size))
    
    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))
    
    def forward(self, x_sequence):
        seq_length = len(x_sequence)
        h = np.zeros((seq_length + 1, self.hidden_size))
        c = np.zeros((seq_length + 1, self.hidden_size))
        f_s, i_s, o_s = [], [], []
        for t in range(seq_length):
            x_t = x_sequence[t:t+1]
            h_prev = h[t:t+1]
            c_prev = c[t:t+1]
            f_t = self.sigmoid(x_t @ self.Wf + h_prev @ self.Uf + self.bf)
            i_t = self.sigmoid(x_t @ self.Wi + h_prev @ self.Ui + self.bi)
            c_hat_t = np.tanh(x_t @ self.Wc + h_prev @ self.Uc + self.bc)
            c[t+1] = f_t * c_prev + i_t * c_hat_t
            o_t = self.sigmoid(x_t @ self.Wo + h_prev @ self.Uo + self.bo)
            h[t+1] = o_t * np.tanh(c[t+1])
            f_s.append(f_t.squeeze())
            i_s.append(i_t.squeeze())
            o_s.append(o_t.squeeze())
        return h[1:], c[1:], np.array(f_s), np.array(i_s), np.array(o_s)

# Example sequence
seq_length = 15
input_size = 2
hidden_size = 3
x_sequence = np.zeros((seq_length, input_size))
x_sequence[:, 0] = np.sin(np.linspace(0, 4*np.pi, seq_length))
x_sequence[:, 1] = np.cos(np.linspace(0, 3*np.pi, seq_length))
lstm = SimpleLSTM(input_size, hidden_size)
h, c, f_s, i_s, o_s = lstm.forward(x_sequence)

# Visualization
fig = make_subplots(rows=3, cols=1, subplot_titles=(
    '<b>LSTM Hidden States Over Time</b>',
    '<b>LSTM Cell States Over Time</b>',
    '<b>LSTM Gate Activations (Forget, Input, Output)</b>'),
    vertical_spacing=0.12, row_heights=[0.33, 0.33, 0.34])

# Hidden states
for i in range(hidden_size):
    fig.add_trace(go.Scatter(x=list(range(seq_length)), y=h[:, i], mode='lines+markers',
                             name=f'hidden {i+1}', legendgroup='h', showlegend=(i==0)), row=1, col=1)
# Cell states
for i in range(hidden_size):
    fig.add_trace(go.Scatter(x=list(range(seq_length)), y=c[:, i], mode='lines+markers',
                             name=f'cell {i+1}', legendgroup='c', showlegend=(i==0)), row=2, col=1)
# Gates (average across hidden units)
fig.add_trace(go.Scatter(x=list(range(seq_length)), y=f_s.mean(axis=1), mode='lines+markers',
                         name='Forget Gate (avg)', line=dict(color='blue')), row=3, col=1)
fig.add_trace(go.Scatter(x=list(range(seq_length)), y=i_s.mean(axis=1), mode='lines+markers',
                         name='Input Gate (avg)', line=dict(color='green')), row=3, col=1)
fig.add_trace(go.Scatter(x=list(range(seq_length)), y=o_s.mean(axis=1), mode='lines+markers',
                         name='Output Gate (avg)', line=dict(color='red')), row=3, col=1)

fig.update_layout(height=900, width=900, title='<b>LSTM Dynamics Visualization</b>',
                  legend=dict(orientation='h', yanchor='bottom', y=-0.15, xanchor='center', x=0.5),
                  margin=dict(l=60, r=30, t=100, b=100), hovermode='closest')
fig.update_xaxes(title_text='Time Step', row=1, col=1)
fig.update_xaxes(title_text='Time Step', row=2, col=1)
fig.update_xaxes(title_text='Time Step', row=3, col=1)
fig.update_yaxes(title_text='Hidden State', row=1, col=1)
fig.update_yaxes(title_text='Cell State', row=2, col=1)
fig.update_yaxes(title_text='Gate Activation', row=3, col=1)
fig.show()


# Deep Dive: Interactive Visualization of Vanishing/Exploding Gradients and Activation Functions

(See interactive dashboard below for details.)

# Transformers: The Basics
Transformers are a foundational architecture in modern deep learning, especially for natural language processing. Unlike RNNs, Transformers process all tokens in a sequence simultaneously, using self-attention to model relationships between tokens. However, this parallelism means they lack an inherent sense of order, making positional encoding essential.

## Self-Attention and Permutation Invariance
The self-attention mechanism allows each token to attend to every other token in the sequence. However, without positional information, self-attention is permutation-equivariant: shuffling the input tokens simply shuffles the outputs in the same way, so the model cannot tell one ordering from another. This is problematic for language, where word order matters.
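
This order-blindness can be checked directly. The sketch below uses a single-head attention layer with random weights and no positional information; `softmax` and `self_attention` are small helpers defined here for illustration. Shuffling the input rows only shuffles the output rows in the same way, so each token's output does not depend on where it sits in the sequence.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention without positional information."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax(scores) @ V

rng = np.random.default_rng(0)
seq_len, d_model = 5, 8
X = rng.standard_normal((seq_len, d_model))
Wq, Wk, Wv = (rng.standard_normal((d_model, d_model)) for _ in range(3))

perm = rng.permutation(seq_len)
out = self_attention(X, Wq, Wk, Wv)
out_shuffled = self_attention(X[perm], Wq, Wk, Wv)

# The shuffled input produces the same per-token outputs, just reordered.
print(np.allclose(out_shuffled, out[perm]))
```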

## Positional Encoding: The Traditional Approach
To address the lack of order, the original Transformer paper introduced positional encodings—vectors added to token embeddings to inject position information. These encodings are typically sinusoidal functions of the position index, allowing the model to distinguish between different positions in the sequence.

## Traditional Positional Encoding: A Deeper Dive

In the original Transformer paper "Attention Is All You Need," the authors faced a critical challenge: how to inject information about token positions into a model that processes all tokens simultaneously. Their solution was elegantly simple yet effective.

### Fixed Sinusoidal Positional Encoding

Unlike token embeddings which are learned during training, positional encodings were **fixed** using carefully designed sinusoidal functions:

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

Where:
- $pos$ is the position in the sequence (0, 1, 2, ...)
- $i$ indexes the sine/cosine pair ($i = 0, 1, \dots, d_{model}/2 - 1$), so dimensions $2i$ and $2i+1$ share one frequency
- $d_{model}$ is the embedding dimension

### Why This Approach Works

1. **Unique positional signature**: Each position gets a unique vector through the $pos$ term

2. **Multi-scale representation**: Different dimensions oscillate at different frequencies:
   - Early dimensions change rapidly with position (high frequency)
   - Later dimensions change slowly (low frequency)
   - This creates a multi-resolution representation of position

3. **Extrapolation**: The sinusoidal pattern allows the model to potentially handle sequences longer than those seen during training

4. **Zero additional parameters**: Unlike learned position embeddings, this approach requires no additional trainable weights

5. **Relative position encoding**: For any fixed offset $k$, $PE_{(pos+k)}$ can be expressed as a linear function of $PE_{(pos)}$, making it easier for the model to learn to attend by relative positions
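
The relative-position property in item 5 follows from the angle-addition identities and can be verified numerically. The sketch below is a small NumPy illustration; `sinusoidal_pe` and `offset_matrix` are hypothetical helpers written for this check, using the same sine-in-even, cosine-in-odd layout as the formulas above.

```python
import numpy as np

def sinusoidal_pe(pos, d_model=8, base=10000.0):
    """Sinusoidal encoding of one position: sin in even slots, cos in odd slots."""
    freqs = 1.0 / base ** (np.arange(0, d_model, 2) / d_model)
    pe = np.empty(d_model)
    pe[0::2] = np.sin(pos * freqs)
    pe[1::2] = np.cos(pos * freqs)
    return pe

def offset_matrix(k, d_model=8, base=10000.0):
    """Block-diagonal matrix M_k with PE(pos + k) = M_k @ PE(pos) for every pos."""
    freqs = 1.0 / base ** (np.arange(0, d_model, 2) / d_model)
    M = np.zeros((d_model, d_model))
    for i, w in enumerate(freqs):
        c, s = np.cos(k * w), np.sin(k * w)
        # sin(a + b) = sin(a)cos(b) + cos(a)sin(b)
        # cos(a + b) = cos(a)cos(b) - sin(a)sin(b)
        M[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]
    return M

k = 5
M_k = offset_matrix(k)  # depends on the offset k only, not on pos
for pos in (0, 3, 42):
    print(pos, np.allclose(M_k @ sinusoidal_pe(pos), sinusoidal_pe(pos + k)))
```

The matrix depends only on the offset $k$, which is what makes a fixed offset expressible as one linear map for all positions.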

### How It's Applied

The positional encoding is simply added to the token embedding before the first layer of the Transformer:

$$Input\_Representation = Token\_Embedding + Positional\_Encoding$$

This element-wise addition combines semantic information (from token embeddings) with positional information (from positional encodings) into a single representation that preserves the dimensionality of the original embeddings.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def generate_traditional_positional_encoding(max_seq_length=100, d_model=64):
    """Generate traditional sinusoidal positional encodings from the Transformer paper"""
    position = torch.arange(max_seq_length).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
    
    pos_encoding = torch.zeros(max_seq_length, d_model)
    pos_encoding[:, 0::2] = torch.sin(position * div_term)
    pos_encoding[:, 1::2] = torch.cos(position * div_term)
    
    return pos_encoding

# Compare positional encoding at different positions
def compare_pe_at_positions(pe_matrix, positions=(0, 1, 10, 50), d_model=64):
    fig = go.Figure()
    
    for pos in positions:
        if pos < pe_matrix.shape[0]:
            fig.add_trace(go.Scatter(
                x=list(range(d_model)),
                y=pe_matrix[pos].numpy(),
                mode='lines',
                name=f'Position {pos}'
            ))
    
    fig.update_layout(
        title='Positional Encoding Vectors at Different Positions',
        xaxis_title='Embedding Dimension',
        yaxis_title='Encoding Value',
        legend_title='Position',
        width=800,
        height=500
    )
    return fig

# Generate the positional encodings
pe_matrix = generate_traditional_positional_encoding(max_seq_length=100, d_model=64)

# Compare specific positions to see how they differ
compare_pe_at_positions(pe_matrix)

# Rotary Position Embedding (RoPE): A Deeper Dive
RoPE is an elegant method for encoding positional information in Transformers. Instead of adding position vectors, RoPE rotates the query and key vectors in a position-dependent way, enabling the model to capture both absolute and relative positional information.

## From Traditional Positional Encoding to RoPE

While traditional positional encodings were a breakthrough, Rotary Position Embedding (RoPE) represents an evolution that addresses some of their limitations.

### Key Differences and Improvements

1. **Integration Method**:
   - **Traditional**: Adds positional vectors to token embeddings before attention
   - **RoPE**: Applies rotations directly to queries and keys during attention computation

2. **Relative Position Modeling**:
   - **Traditional**: Captures relative position indirectly and imperfectly
   - **RoPE**: Explicitly encodes relative positions through dot product properties

3. **Extrapolation to Longer Sequences**:
   - **Traditional**: Limited extrapolation ability
   - **RoPE**: Better generalization to sequences longer than training data

4. **Mathematical Properties**:
   - **Traditional**: Addition can change vector norms and affect attention score distributions
   - **RoPE**: Rotation preserves vector norms and maintains stable attention distributions

5. **Implementation**:
   - **Traditional**: Applied once at embedding layer
   - **RoPE**: Applied at each attention layer, directly to Q and K matrices

Both approaches use sinusoidal functions at their core, but RoPE applies them as rotation matrices rather than additive vectors. This seemingly small change leads to significantly improved properties for attention-based language models, especially when handling long sequences and relative positional relationships.
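
The norm-preservation point in item 4 can be seen on a single pair of dimensions. This is a toy sketch with arbitrary values (the vector and angle are chosen for illustration only): rotating the pair leaves its length unchanged, while adding a sinusoidal offset changes it.

```python
import numpy as np

x = np.array([3.0, 4.0])  # one (x_j, x_{j+1}) pair of a query or key, norm 5
theta = 0.7               # arbitrary rotation angle in radians

rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
rotated = rotation @ x                                 # RoPE-style rotation
added = x + np.array([np.sin(theta), np.cos(theta)])   # additive sinusoidal offset

print("original norm: ", np.linalg.norm(x))
print("after rotation:", np.linalg.norm(rotated))
print("after addition:", np.linalg.norm(added))
```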

## The Core Mechanism: Rotation in 2D Space
RoPE operates by pairing the dimensions of the query and key vectors and applying a 2D rotation to each pair. For a d-dimensional vector, there are d/2 pairs. The rotation angle for each pair depends on the position and the pair index, introducing multiple frequencies into the encoding.

In [None]:
import numpy as np

def rotate_pair(x_j, x_j1, theta):
    """Apply 2D rotation to a pair of values."""
    x_j_new = x_j * np.cos(theta) - x_j1 * np.sin(theta)
    x_j1_new = x_j * np.sin(theta) + x_j1 * np.cos(theta)
    return x_j_new, x_j1_new

# Example: rotate (1, 0) by 45 degrees
x_j, x_j1 = 1.0, 0.0
theta = np.pi / 4
rotate_pair(x_j, x_j1, theta)

(np.float64(0.7071067811865476), np.float64(0.7071067811865475))

## Multi-Frequency Encoding: Local and Global Context
Each pair of dimensions is rotated by a different frequency, determined by $\theta_i = 1 / \text{base}^{2i/d}$. Early pairs (small $i$) rotate quickly, capturing local relationships. Later pairs (large $i$) rotate slowly, capturing global context. This multi-scale approach is more powerful than using a single frequency.
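
To make "quickly" and "slowly" concrete, the short sketch below computes the per-step rotation angle of each pair and the number of positions needed for one full turn. The values `d = 8` and `base = 10000` are example settings chosen for readability.

```python
import numpy as np

d, base = 8, 10000
pair_index = np.arange(d // 2)
theta = 1.0 / base ** (2 * pair_index / d)  # rotation per position step (radians)
wavelength = 2 * np.pi / theta              # positions needed for one full turn

for i, (t, w) in enumerate(zip(theta, wavelength)):
    print(f"pair {i}: theta = {t:.4f} rad/step, full rotation every {w:.1f} positions")
```

With these settings each successive pair rotates ten times more slowly than the previous one.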

In [None]:
import plotly.graph_objs as go
import numpy as np

positions = np.arange(0, 50)
d = 8  # example dimensionality
base = 10000

traces = []
for i in range(d // 2):
    theta_i = 1 / (base ** (2 * i / d))
    angles = positions * theta_i
    x = np.cos(angles)
    y = np.sin(angles)
    traces.append(go.Scatter(x=x, y=y, mode='lines', name=f'Pair {i}'))

layout = go.Layout(title='RoPE: Rotation Trajectories for Different Dimension Pairs',
                   xaxis=dict(title='cos(mθ_i)'),
                   yaxis=dict(title='sin(mθ_i)'))
fig = go.Figure(data=traces, layout=layout)
fig.show()

## Relative Position from Absolute Rotations
A key property of RoPE is that, after rotation, the dot product between a query at position m and a key at position n depends only on their relative distance (m - n). This means the attention mechanism becomes sensitive to relative positions, which is crucial for modeling language context.
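
A quick numerical check of this property: if the query and key positions are shifted by the same amount, the dot product after rotation should not change. The sketch below assumes the same consecutive-pair layout used in this notebook; `rope_rotate` is a hypothetical helper defined only for this check.

```python
import numpy as np

def rope_rotate(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[0], x[1]), (x[2], x[3]), ... by pos * theta_i."""
    d = len(x)
    theta = 1.0 / base ** (np.arange(0, d, 2) / d)
    cos, sin = np.cos(pos * theta), np.sin(pos * theta)
    out = np.empty_like(x, dtype=float)
    out[0::2] = x[0::2] * cos - x[1::2] * sin
    out[1::2] = x[0::2] * sin + x[1::2] * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

m, n = 3, 7
reference = rope_rotate(q, m) @ rope_rotate(k, n)
for shift in (1, 10, 100):
    shifted = rope_rotate(q, m + shift) @ rope_rotate(k, n + shift)
    # Same offset (m - n), different absolute positions
    print(shift, np.isclose(shifted, reference))
```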

In [None]:
def rope_dot_product(q, k, m, n, base=10000):
    d = len(q)
    result = 0.0
    for i in range(0, d, 2):
        theta_i = 1 / (base ** (2 * (i//2) / d))
        angle_m = m * theta_i
        angle_n = n * theta_i
        # Rotate q and k
        q_rot = np.array([q[i] * np.cos(angle_m) - q[i+1] * np.sin(angle_m),
                          q[i] * np.sin(angle_m) + q[i+1] * np.cos(angle_m)])
        k_rot = np.array([k[i] * np.cos(angle_n) - k[i+1] * np.sin(angle_n),
                          k[i] * np.sin(angle_n) + k[i+1] * np.cos(angle_n)])
        result += np.dot(q_rot, k_rot)
    return result

# Example: dot product for different relative positions
q = np.random.randn(8)
k = np.random.randn(8)
rel_positions = np.arange(-10, 11)
dots = [rope_dot_product(q, k, m=0, n=rp) for rp in rel_positions]

import plotly.express as px
fig = px.line(x=rel_positions, y=dots, labels={'x':'Relative Position (n)', 'y':'Dot Product'},
              title='RoPE Dot Product vs. Relative Position')
fig.show()

## Understanding RoPE Dot Product vs. Relative Position

This graph illustrates one of the most remarkable properties of Rotary Position Embedding: how the dot product between query and key vectors varies based on their relative position, not their absolute positions.

**What this visualization shows:**

- **Dot product magnitude varies with relative distance**: The y-axis shows the dot product value between a query at position 0 and keys at various relative positions (x-axis). This demonstrates how attention scores in RoPE naturally depend on how far apart tokens are.

- **Direction sensitivity**: The curve is generally not symmetric around relative position zero. Each dimension pair contributes a cosine term (even in the offset) and a sine term (odd in the offset), so tokens at equal distances before and after the current position can receive different scores. This is what lets the attention mechanism distinguish direction.

- **Periodic patterns**: The dot product shows a complex periodic pattern resulting from the multiple rotation frequencies used in different dimension pairs. This rich signal helps the model learn nuanced relationships between tokens based on their spacing.

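- **Decay with distance**: With random vectors, as in this small example, the curve need not decrease monotonically. RoPE is nonetheless designed so that attention scores tend to shrink on average as the relative distance grows, aligning with the intuition that nearby tokens are often more relevant than distant ones.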

This property is crucial for language modeling because it allows the attention mechanism to focus on contextually relevant tokens based on their relative positions, regardless of where they appear in the absolute sequence.

# The "Wow" Factors of RoPE
- **Relative from Absolute**: Encodes absolute positions via rotation, but attention scores depend on relative positions.
- **Multi-Scale Positional Information**: Different rotation frequencies capture both local and global context.
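- **Extrapolation**: The same rotation rule applies at any position, so the encoding is well defined for sequences longer than those seen during training, although quality at unseen lengths is not guaranteed.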
- **No Extra Learnable Parameters**: RoPE uses fixed transformations, keeping the model efficient.
- **Deep Integration**: RoPE is applied directly to Q and K vectors, making it a core part of the attention mechanism.

# Transformer Implementation with PyTorch

In this section, we'll implement a complete Transformer model with both traditional sinusoidal positional encoding and Rotary Position Embedding (RoPE). We'll use this to demonstrate how these positional encodings work in practice for next word prediction tasks.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## Traditional Positional Encoding

First, let's implement the original sinusoidal positional encoding from the "Attention Is All You Need" paper.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x has shape [batch_size, seq_len, embedding_dim]
        # Add positional encoding to embeddings
        x = x + self.pe[:, :x.size(1), :]
        return x

# Visualize the traditional positional encoding
def visualize_positional_encoding(model_dim=64, seq_length=100):
    pos_enc = PositionalEncoding(model_dim)
    dummy_input = torch.zeros(1, seq_length, model_dim)
    encoded = pos_enc(dummy_input)[0]  # Remove batch dimension
    
    # Create a heatmap with plotly
    fig = px.imshow(encoded.detach().numpy(),
                    labels=dict(x="Dimension", y="Position", color="Value"),
                    title="Traditional Sinusoidal Positional Encoding",
                    color_continuous_scale="RdBu_r",
                    zmin=-1, zmax=1)
    fig.update_layout(width=800, height=500)
    return fig

# Display the visualization
visualize_positional_encoding()

## Understanding Traditional Positional Encoding Visualization

The heatmap above visualizes the traditional sinusoidal positional encoding used in the original Transformer architecture. This encoding is crucial because Transformers process tokens in parallel, with no inherent sense of their order in the sequence.

**Key observations:**

- **Vertical patterns**: Each column represents a dimension in the embedding space. Note how some dimensions change rapidly (high-frequency components) while others change slowly (low-frequency components).
  
- **Position uniqueness**: Each row represents a unique position encoding vector, ensuring that each position in a sequence gets a distinctive representation.

- **Sinusoidal pattern**: The encoding uses sine and cosine functions with different frequencies, creating the wave-like patterns you see. Because the functions are defined for any position, this approach may help the model generalize to sequence lengths it hasn't seen during training.

- **Color variations**: Blue represents negative values, red represents positive values, with white near zero. These variations help the model distinguish between different positions.

This encoding is added directly to token embeddings before they're processed by the attention mechanism. However, this approach has limitations in capturing relative positions effectively, which is addressed by RoPE.

## Rotary Position Embedding (RoPE)

Now, let's implement RoPE, which applies rotation to query and key vectors in self-attention rather than adding positional vectors to the embeddings.

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding implementation."""
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # Create position indices
        position = torch.arange(seq_len, device=self.inv_freq.device).float()
        # Compute angles for each position and frequency
        angles = position.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        # Return sines and cosines for easier rotation
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    def _apply_rotary_pos_emb(self, x, cos, sin):
        # x shape: [batch_size, num_heads, seq_len, head_dim]
        # (cos and sin are broadcast as [1, 1, seq_len, head_dim/2])
        # Reshape x for easier rotation
        x_reshape = x.reshape(*x.shape[:-1], -1, 2)
        x1, x2 = x_reshape.unbind(-1)  # Split into pairs
        
        # Apply rotation using the rotation matrix [cos -sin; sin cos]
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        
        # Stack and reshape back
        return torch.stack([out1, out2], dim=-1).flatten(-2)

    def apply_rotary_embedding(self, q, k, seq_len):
        # Get rotation angles
        cos_sin = self.forward(seq_len)
        cos, sin = cos_sin[..., 0], cos_sin[..., 1]
        
        # Reshape for broadcasting
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        
        # Apply rotations to q and k
        q_rotated = self._apply_rotary_pos_emb(q, cos, sin)
        k_rotated = self._apply_rotary_pos_emb(k, cos, sin)
        
        return q_rotated, k_rotated

# Function to visualize rotation trajectories in RoPE
def visualize_rope_trajectories(dim=64, seq_length=50, base=10000):
    # Compute frequencies
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_length)
    
    # Create subplots for different frequency pairs
    pairs_to_plot = [0, dim//8, dim//4, 3*dim//8]  # Plot 4 different frequency pairs
    fig = make_subplots(rows=2, cols=2, 
                       subplot_titles=[f"Dimension Pair {i}" for i in pairs_to_plot])
    
    for i, pair_idx in enumerate(pairs_to_plot):
        # Calculate row and column for subplot
        row = i // 2 + 1
        col = i % 2 + 1
        
        # Get frequency for this pair
        freq = freqs[pair_idx].item()
        
        # Compute rotation angles
        angles = positions.float() * freq
        
        # Compute x, y coordinates on the circle
        x = torch.cos(angles)
        y = torch.sin(angles)
        
        # Create a color gradient based on position
        colors = positions.numpy()
        
        # Add scatter trace to subplot
        fig.add_trace(
            # Convert tensors to NumPy / Python values before handing them to Plotly
            go.Scatter(x=x.numpy(), y=y.numpy(), mode='lines+markers',
                      marker=dict(color=colors, colorscale='Viridis', size=8),
                      line=dict(color='rgba(150,150,150,0.3)'),
                      text=[f"Pos {p}" for p in positions.tolist()],
                      hoverinfo='text',
                      name=f"Pair {pair_idx}"),
            row=row, col=col
        )
        
        # Add annotations for specific positions
        special_pos = [0, 10, 20, 30, 40]
        for pos in special_pos:
            if pos < seq_length:
                fig.add_annotation(
                    x=x[pos].item(), y=y[pos].item(),
                    text=str(pos),
                    showarrow=True, arrowhead=2,
                    arrowsize=1, arrowwidth=1, ax=0, ay=-20,
                    row=row, col=col
                )
    
    # Update layout
    fig.update_layout(
        title="RoPE: Rotation Trajectories for Different Dimension Pairs",
        height=700, width=900,
        showlegend=False,
    )
    
    # Update axis properties for all subplots
    for i in range(1, 3):
        for j in range(1, 3):
            fig.update_xaxes(title_text="cos(mθ)", range=[-1.1, 1.1], row=i, col=j)
            fig.update_yaxes(title_text="sin(mθ)", range=[-1.1, 1.1], row=i, col=j)
    
    return fig

# Display the visualization
visualize_rope_trajectories()

## Understanding RoPE Rotation Trajectories

This visualization demonstrates how Rotary Position Embedding (RoPE) operates by rotating vector pairs in different dimensions at various frequencies. Unlike traditional positional encoding that adds position vectors to embeddings, RoPE directly rotates query and key vectors in the attention mechanism.

**Key insights from the visualization:**

- **Multiple rotation frequencies**: Each subplot shows how a different dimension pair rotates at its own frequency. With the default `dim=64`, the plotted pairs are 0, 8, 16, and 24. The earliest pair (Dimension Pair 0) rotates quickly as the position changes, while later pairs (Dimension Pairs 8, 16, and 24) rotate progressively more slowly.

- **Position mapping to angles**: Each colored dot represents a position in the sequence, mapped to a specific angle on the unit circle. The color gradient helps track how positions (0, 10, 20, etc.) move around the circle.

- **Multi-scale representation**: The varying rotation speeds create a multi-scale representation of position. Fast-rotating pairs capture local, fine-grained positional relationships, while slow-rotating pairs capture global context.

- **Circular nature**: Every point lies on the unit circle, so the rotation changes only the direction of a vector pair, never its length. The encoding is therefore bounded and numerically well-behaved at any position.

By using these rotations instead of additive encodings, RoPE elegantly integrates positional information into the attention mechanism while preserving the ability to model relative positions effectively.
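
To put numbers on the rotation speeds, the sketch below computes the angular step per position and the full-turn period (in positions) for the four plotted pairs. It assumes the defaults used by the visualization (`dim=64`, `base=10000`); the helper name `rope_period` is hypothetical and introduced here only for illustration.

```python
import numpy as np

def rope_period(pair_idx, dim=64, base=10000):
    """Return (theta, period): radians per position and positions per full turn."""
    theta = 1.0 / (base ** (2 * pair_idx / dim))
    return theta, 2 * np.pi / theta

for pair_idx in [0, 8, 16, 24]:  # same pairs as pairs_to_plot when dim=64
    theta, period = rope_period(pair_idx)
    print(f"pair {pair_idx:2d}: theta = {theta:.4f} rad/position, "
          f"period = {period:.1f} positions")
```

Over 50 positions, pair 0 wraps around the circle nearly eight times, pair 8 covers most of one turn, and pairs 16 and 24 trace only short arcs of roughly 0.5 and 0.05 radians.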

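## The Core Mechanism: Rotation in 2D Space
RoPE operates by pairing the dimensions of the query and key vectors and applying a 2D rotation to each pair. For a d-dimensional vector, there are d/2 pairs. The rotation angle for each pair depends on the position and the pair index, introducing multiple frequencies into the encoding.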

In [None]:
import numpy as np

def rotate_pair(x_j, x_j1, theta):
    """Apply 2D rotation to a pair of values."""
    x_j_new = x_j * np.cos(theta) - x_j1 * np.sin(theta)
    x_j1_new = x_j * np.sin(theta) + x_j1 * np.cos(theta)
    return x_j_new, x_j1_new

# Example: rotate (1, 0) by 45 degrees
x_j, x_j1 = 1.0, 0.0
theta = np.pi / 4
rotate_pair(x_j, x_j1, theta)
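
Two properties of this rotation are worth checking numerically: rotating by one angle and then another is the same as rotating once by their sum, and the length of the pair never changes. The sketch below redefines `rotate_pair` so it runs on its own; the angles and the input pair are arbitrary values chosen for illustration.

```python
import numpy as np

def rotate_pair(x_j, x_j1, theta):
    """Apply a 2D rotation by theta (radians) to the pair (x_j, x_j1)."""
    return (x_j * np.cos(theta) - x_j1 * np.sin(theta),
            x_j * np.sin(theta) + x_j1 * np.cos(theta))

a, b = 0.3, 1.1     # arbitrary angles, for illustration only
v = (0.5, -2.0)     # arbitrary input pair

two_steps = rotate_pair(*rotate_pair(*v, a), b)   # rotate by a, then by b
one_step = rotate_pair(*v, a + b)                 # rotate once by a + b

print(np.allclose(two_steps, one_step))                # rotations compose additively
print(np.isclose(np.hypot(*one_step), np.hypot(*v)))   # the length is unchanged
```

Additive composition is what later allows two absolute rotations to collapse into a single rotation by the position difference.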

## Multi-Frequency Encoding: Local and Global Context
Each pair of dimensions is rotated at its own frequency, θ_i = 1 / (base^(2i/d)) for pair index i = 0, ..., d/2 - 1. Early pairs (small i) rotate quickly and resolve small position differences, capturing local relationships. Later pairs (large i) rotate slowly and stay unambiguous over long distances, capturing global context. A single frequency cannot do both: a fast one repeats after a few positions, and a slow one barely separates neighboring tokens.

In [None]:
import plotly.graph_objects as go
import numpy as np

positions = np.arange(0, 50)
d = 8  # example dimensionality
base = 10000

traces = []
for i in range(d // 2):
    theta_i = 1 / (base ** (2 * i / d))
    angles = positions * theta_i
    x = np.cos(angles)
    y = np.sin(angles)
    traces.append(go.Scatter(x=x, y=y, mode='lines', name=f'Pair {i}'))

layout = go.Layout(title='RoPE: Rotation Trajectories for Different Dimension Pairs',
                   xaxis=dict(title='cos(mθ_i)'),
                   # Lock the aspect ratio so unit circles are drawn as circles
                   yaxis=dict(title='sin(mθ_i)', scaleanchor='x', scaleratio=1))
fig = go.Figure(data=traces, layout=layout)
fig.show()
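
One way to see why several frequencies are needed is to look for positions that a single fast pair confuses. The sketch below uses the same `d = 8` and `base = 10000` as the plot; the helper `wrapped_angle_gap` and the choice of positions 0 and 44 are assumptions made for this illustration.

```python
import numpy as np

def wrapped_angle_gap(pos_a, pos_b, theta):
    """Smallest absolute angle (radians) between two positions rotating at rate theta."""
    diff = (pos_a - pos_b) * theta
    return abs((diff + np.pi) % (2 * np.pi) - np.pi)

d, base = 8, 10000
thetas = [1 / (base ** (2 * i / d)) for i in range(d // 2)]  # 1, 0.1, 0.01, 0.001

# On the fastest pair, positions 0 and 44 are 44 rad apart: almost exactly 7 full turns.
for i, theta in enumerate(thetas):
    gap = wrapped_angle_gap(0, 44, theta)
    print(f"pair {i}: gap between positions 0 and 44 = {gap:.4f} rad")
```

On pair 0 alone the two positions are nearly indistinguishable (about 0.018 rad apart), while pair 1 separates them by about 1.88 rad. The slower pairs resolve what the fast pair aliases.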

## Relative Position from Absolute Rotations
A key property of RoPE is that, after rotation, the dot product between a query at position m and a key at position n depends only on their relative offset (m - n), not on m and n individually. The attention score is therefore a function of relative position, which is crucial for modeling language context.
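
The reason can be seen on a single dimension pair. Writing $R(\alpha)$ for the 2D rotation matrix by angle $\alpha$ (notation introduced here for this sketch) and using $R(\alpha)^\top = R(-\alpha)$ and $R(\alpha)R(\beta) = R(\alpha + \beta)$:

$$
\big(R(m\theta_i)\,q\big)^\top \big(R(n\theta_i)\,k\big)
= q^\top R(m\theta_i)^\top R(n\theta_i)\,k
= q^\top R\big((n - m)\theta_i\big)\,k
$$

The absolute positions $m$ and $n$ cancel and only their difference remains. Summing this expression over all pairs gives the full dot product.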

In [None]:
def rope_dot_product(q, k, m, n, base=10000):
    d = len(q)
    result = 0.0
    for i in range(0, d, 2):
        theta_i = 1 / (base ** (2 * (i//2) / d))
        angle_m = m * theta_i
        angle_n = n * theta_i
        # Rotate q and k
        q_rot = np.array([q[i] * np.cos(angle_m) - q[i+1] * np.sin(angle_m),
                          q[i] * np.sin(angle_m) + q[i+1] * np.cos(angle_m)])
        k_rot = np.array([k[i] * np.cos(angle_n) - k[i+1] * np.sin(angle_n),
                          k[i] * np.sin(angle_n) + k[i+1] * np.cos(angle_n)])
        result += np.dot(q_rot, k_rot)
    return result

# Example: dot product for different relative positions
np.random.seed(0)  # fixed seed so the plot is reproducible
q = np.random.randn(8)
k = np.random.randn(8)
rel_positions = np.arange(-10, 11)
# The query stays at m = 0 while the key position n sweeps from -10 to 10
dots = [rope_dot_product(q, k, m=0, n=rp) for rp in rel_positions]

import plotly.express as px
fig = px.line(x=rel_positions, y=dots,
              labels={'x': 'Key position n (query at m = 0)', 'y': 'Dot Product'},
              title='RoPE Dot Product vs. Relative Position')
fig.show()

## Understanding RoPE Dot Product vs. Relative Position

This graph illustrates one of the most remarkable properties of Rotary Position Embedding: how the dot product between query and key vectors varies based on their relative position, not their absolute positions.

**What this visualization shows:**

- **Dot product magnitude varies with relative distance**: The y-axis shows the dot product value between a query at position 0 and keys at various relative positions (x-axis). This demonstrates how attention scores in RoPE naturally depend on how far apart tokens are.

- **Asymmetry around zero**: For each dimension pair, the score is a cosine term (even in the offset) plus a sine term (odd in the offset). The curve is symmetric around zero only when the sine terms vanish, which does not hold for random q and k. This asymmetry is what lets the attention mechanism tell a token 5 positions before the query from one 5 positions after it.

- **Oscillating patterns**: The dot product is a sum of sinusoids, one per dimension pair, each with its own rotation frequency. Over a short window the fast pairs dominate the visible wiggles, while the slow pairs contribute an almost constant offset. This rich signal helps the model learn nuanced relationships between tokens based on their spacing.

- **Long-range decay**: With random q and k over a window of only ±10 positions, the curve need not shrink with distance. The long-term decay usually attributed to RoPE concerns an upper bound on the score over long ranges, not a guarantee for every query-key pair.

This property is crucial for language modeling because it allows the attention mechanism to focus on contextually relevant tokens based on their relative positions, regardless of where they appear in the absolute sequence.
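
Both claims, dependence on the offset only and sensitivity to its sign, can be checked directly. The sketch below is a vectorized NumPy variant of the loop-based function above; the name `rope_score`, the seed, and the test positions are assumptions chosen for illustration.

```python
import numpy as np

def rope_score(q, k, m, n, base=10000):
    """Dot product of q rotated to position m and k rotated to position n."""
    d = len(q)
    theta = 1.0 / (base ** (np.arange(0, d, 2) / d))   # one frequency per pair

    def rotate(x, pos):
        x1, x2 = x[0::2], x[1::2]                      # interleaved pairs (x0, x1), (x2, x3), ...
        c, s = np.cos(pos * theta), np.sin(pos * theta)
        return np.stack([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).reshape(-1)

    return float(rotate(q, m) @ rotate(k, n))

rng = np.random.default_rng(0)   # fixed seed, chosen arbitrarily for reproducibility
q, k = rng.standard_normal(8), rng.standard_normal(8)

# The same offset (n - m = 4) at very different absolute positions gives the same score.
print(np.isclose(rope_score(q, k, 3, 7), rope_score(q, k, 103, 107)))

# Opposite offsets generally give different scores, so direction is distinguishable.
print(rope_score(q, k, 0, 4), rope_score(q, k, 0, -4))
```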

# The "Wow" Factors of RoPE
- **Relative from Absolute**: Encodes absolute positions via rotation, but attention scores depend on relative positions.
- **Multi-Scale Positional Information**: Different rotation frequencies capture both local and global context.
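- **Extrapolation**: Rotations are defined for every position, so RoPE can be evaluated on sequences longer than those seen in training. In practice, quality often degrades beyond the training length unless the frequencies are rescaled (for example, with position interpolation).
- **No Extra Learnable Parameters**: RoPE uses fixed transformations, so it adds nothing to the model's parameter count.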
- **Deep Integration**: RoPE is applied directly to Q and K vectors, making it a core part of the attention mechanism.

# Transformer Implementation with PyTorch

In this section, we'll implement a complete Transformer model with both traditional sinusoidal positional encoding and Rotary Position Embedding (RoPE), and use it to demonstrate how these positional encodings behave in practice on a next-word prediction task.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Traditional Positional Encoding

First, let's implement the original sinusoidal positional encoding from the "Attention Is All You Need" paper.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x has shape [batch_size, seq_len, embedding_dim]
        # Add positional encoding to embeddings
        x = x + self.pe[:, :x.size(1), :]
        return x

# Visualize the traditional positional encoding
def visualize_positional_encoding(model_dim=64, seq_length=100):
    pos_enc = PositionalEncoding(model_dim)
    dummy_input = torch.zeros(1, seq_length, model_dim)
    encoded = pos_enc(dummy_input)[0]  # Remove batch dimension
    
    # Create a heatmap with plotly
    fig = px.imshow(encoded.detach().numpy(),
                    labels=dict(x="Dimension", y="Position", color="Value"),
                    title="Traditional Sinusoidal Positional Encoding",
                    color_continuous_scale="RdBu_r",
                    zmin=-1, zmax=1)
    fig.update_layout(width=800, height=500)
    return fig

# Display the visualization
visualize_positional_encoding()

## Understanding Traditional Positional Encoding Visualization

The heatmap above visualizes the traditional sinusoidal positional encoding used in the original Transformer architecture. This encoding is crucial because Transformers process tokens in parallel, with no inherent sense of their order in the sequence.

**Key observations:**

- **Vertical patterns**: Each column represents a dimension in the embedding space. The low-index dimensions (left) change rapidly with position (high-frequency components), while the high-index dimensions (right) change slowly (low-frequency components).

- **Position uniqueness**: Each row represents a unique position encoding vector, ensuring that each position in a sequence gets a distinctive representation.

- **Sinusoidal pattern**: The encoding uses sine and cosine functions with different frequencies, creating the wave-like patterns you see. The original paper chose this form in the hope that it would let the model extrapolate to sequence lengths longer than those seen during training.

- **Color variations**: Blue represents negative values and red represents positive values, with white near zero. The colors are purely a visual aid; the model only sees the numeric values.

This encoding is added directly to token embeddings before they're processed by the attention mechanism. However, this approach has limitations in capturing relative positions effectively, which is addressed by RoPE.
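
The limitation can be made concrete with a small NumPy sketch. The encodings on their own have a dot product that depends only on the offset, but once they are added to token embeddings, cross-terms appear that depend on absolute position. The function `sinusoidal_pe` is a hypothetical NumPy counterpart of the class above, the random vectors stand in for token embeddings, and the learned query/key projections of real attention are omitted for simplicity.

```python
import numpy as np

def sinusoidal_pe(max_len, d_model, base=10000.0):
    """NumPy version of the sinusoidal table: sin on even columns, cos on odd columns."""
    pos = np.arange(max_len)[:, None]
    div = np.exp(np.arange(0, d_model, 2) * (-np.log(base) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(pos * div)
    pe[:, 1::2] = np.cos(pos * div)
    return pe

pe = sinusoidal_pe(200, 16)
rng = np.random.default_rng(0)   # arbitrary seed for reproducibility
x, y = rng.standard_normal(16), rng.standard_normal(16)   # stand-ins for two token embeddings

# The encodings alone have an offset-only dot product ...
print(np.isclose(pe[3] @ pe[7], pe[103] @ pe[107]))
# ... but after adding them to embeddings, the score also depends on absolute position.
print((x + pe[3]) @ (y + pe[7]), (x + pe[103]) @ (y + pe[107]))
```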

## Rotary Position Embedding (RoPE)

Now, let's implement RoPE, which applies rotation to query and key vectors in self-attention rather than adding positional vectors to the embeddings.

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding implementation."""
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # Create position indices
        position = torch.arange(seq_len, device=self.inv_freq.device).float()
        # Compute angles for each position and frequency
        angles = position.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        # Stack cos and sin along a new last axis: shape [seq_len, dim/2, 2]
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    def _apply_rotary_pos_emb(self, x, cos, sin):
        # x shape: [batch_size, num_heads, seq_len, head_dim], so that it
        # broadcasts against cos/sin of shape [1, 1, seq_len, head_dim/2]
        # Reshape x for easier rotation
        x_reshape = x.reshape(*x.shape[:-1], -1, 2)
        x1, x2 = x_reshape.unbind(-1)  # Split into pairs
        
        # Apply rotation using the rotation matrix [cos -sin; sin cos]
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        
        # Stack and reshape back
        return torch.stack([out1, out2], dim=-1).flatten(-2)

    def apply_rotary_embedding(self, q, k, seq_len):
        # Get rotation angles
        cos_sin = self.forward(seq_len)
        cos, sin = cos_sin[..., 0], cos_sin[..., 1]
        
        # Reshape for broadcasting
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        
        # Apply rotations to q and k
        q_rotated = self._apply_rotary_pos_emb(q, cos, sin)
        k_rotated = self._apply_rotary_pos_emb(k, cos, sin)
        
        return q_rotated, k_rotated
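
The following NumPy sketch mirrors the same interleaved pairing and broadcasting. It is an illustration rather than a drop-in replacement for the class. The function name `rope_rotate` is hypothetical, and the 4D layout `[batch, heads, seq_len, head_dim]` is an assumption inferred from the `[1, 1, seq_len, dim/2]` shape of `cos` and `sin` above.

```python
import numpy as np

def rope_rotate(x, base=10000):
    """Rotate interleaved pairs of x, shaped [batch, heads, seq_len, head_dim]."""
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))   # [head_dim/2]
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]            # [seq_len, head_dim/2]
    cos, sin = np.cos(angles), np.sin(angles)     # broadcast over batch and heads
    x1, x2 = x[..., 0::2], x[..., 1::2]           # first and second element of each pair
    out = np.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(x.shape)                   # re-interleave the rotated pairs

rng = np.random.default_rng(0)                    # arbitrary seed
q = rng.standard_normal((2, 4, 10, 16))           # batch=2, heads=4, seq_len=10, head_dim=16
q_rot = rope_rotate(q)

print(q_rot.shape)                                # same shape as the input
print(np.allclose(np.linalg.norm(q_rot, axis=-1),
                  np.linalg.norm(q, axis=-1)))    # vector norms are preserved
print(np.allclose(q_rot[:, :, 0], q[:, :, 0]))    # position 0 is left unrotated
```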

## Multi-Frequency Encoding: Local and Global Context
Each pair of dimensions is rotated by a different frequency, determined by θ_i = 1 / (base^(2i/d)). Early pairs (small i) rotate quickly, capturing local relationships. Later pairs (large i) rotate slowly, capturing global context. This multi-scale approach is more powerful than using a single frequency.

In [None]:
import plotly.graph_objs as go
import numpy as np

positions = np.arange(0, 50)
d = 8  # example dimensionality
base = 10000

traces = []
for i in range(d // 2):
    theta_i = 1 / (base ** (2 * i / d))
    angles = positions * theta_i
    x = np.cos(angles)
    y = np.sin(angles)
    traces.append(go.Scatter(x=x, y=y, mode='lines', name=f'Pair {i}'))

layout = go.Layout(title='RoPE: Rotation Trajectories for Different Dimension Pairs',
                   xaxis=dict(title='cos(mθ_i)'),
                   yaxis=dict(title='sin(mθ_i)'))
fig = go.Figure(data=traces, layout=layout)
fig.show()

## Relative Position from Absolute Rotations
A key property of RoPE is that, after rotation, the dot product between a query at position m and a key at position n depends only on their relative distance (m - n). This means the attention mechanism becomes sensitive to relative positions, which is crucial for modeling language context.

In [None]:
def rope_dot_product(q, k, m, n, base=10000):
    d = len(q)
    result = 0.0
    for i in range(0, d, 2):
        theta_i = 1 / (base ** (2 * (i//2) / d))
        angle_m = m * theta_i
        angle_n = n * theta_i
        # Rotate q and k
        q_rot = np.array([q[i] * np.cos(angle_m) - q[i+1] * np.sin(angle_m),
                          q[i] * np.sin(angle_m) + q[i+1] * np.cos(angle_m)])
        k_rot = np.array([k[i] * np.cos(angle_n) - k[i+1] * np.sin(angle_n),
                          k[i] * np.sin(angle_n) + k[i+1] * np.cos(angle_n)])
        result += np.dot(q_rot, k_rot)
    return result

# Example: dot product for different relative positions
q = np.random.randn(8)
k = np.random.randn(8)
rel_positions = np.arange(-10, 11)
dots = [rope_dot_product(q, k, m=0, n=rp) for rp in rel_positions]

import plotly.express as px
fig = px.line(x=rel_positions, y=dots, labels={'x':'Relative Position (n)', 'y':'Dot Product'},
              title='RoPE Dot Product vs. Relative Position')
fig.show()

## Understanding RoPE Dot Product vs. Relative Position

This graph illustrates one of the most remarkable properties of Rotary Position Embedding: how the dot product between query and key vectors varies based on their relative position, not their absolute positions.

**What this visualization shows:**

- **Dot product magnitude varies with relative distance**: The y-axis shows the dot product value between a query at position 0 and keys at various relative positions (x-axis). This demonstrates how attention scores in RoPE naturally depend on how far apart tokens are.

- **Symmetry around zero**: Notice how the dot product pattern is roughly symmetric around relative position zero. This means that the attention mechanism treats tokens at equal distances before and after the current position similarly, while still being able to distinguish direction.

- **Periodic patterns**: The dot product shows a complex periodic pattern resulting from the multiple rotation frequencies used in different dimension pairs. This rich signal helps the model learn nuanced relationships between tokens based on their spacing.

- **Graceful decay**: As the relative distance increases in either direction, the dot product typically decreases, aligning with the intuition that nearby tokens are often more relevant than distant ones.

This property is crucial for language modeling because it allows the attention mechanism to focus on contextually relevant tokens based on their relative positions, regardless of where they appear in the absolute sequence.

# The "Wow" Factors of RoPE
- **Relative from Absolute**: Encodes absolute positions via rotation, but attention scores depend on relative positions.
- **Multi-Scale Positional Information**: Different rotation frequencies capture both local and global context.
- **Excellent Extrapolation**: Generalizes well to longer sequences due to the periodic nature of rotations.
- **No Extra Learnable Parameters**: RoPE uses fixed transformations, keeping the model efficient.
- **Deep Integration**: RoPE is applied directly to Q and K vectors, making it a core part of the attention mechanism.

# Transformer Implementation with PyTorch

In this section, we'll implement a complete Transformer model with both traditional sinusoidal positional encoding and Rotary Position Embedding (RoPE). We'll use this to demonstrate how these positional encodings work in practice for next word prediction tasks.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Traditional Positional Encoding

First, let's implement the original sinusoidal positional encoding from the "Attention Is All You Need" paper.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x has shape [batch_size, seq_len, embedding_dim]
        # Add positional encoding to embeddings
        x = x + self.pe[:, :x.size(1), :]
        return x

# Visualize the traditional positional encoding
def visualize_positional_encoding(model_dim=64, seq_length=100):
    pos_enc = PositionalEncoding(model_dim)
    dummy_input = torch.zeros(1, seq_length, model_dim)
    encoded = pos_enc(dummy_input)[0]  # Remove batch dimension
    
    # Create a heatmap with plotly
    fig = px.imshow(encoded.detach().numpy(),
                    labels=dict(x="Dimension", y="Position", color="Value"),
                    title="Traditional Sinusoidal Positional Encoding",
                    color_continuous_scale="RdBu_r",
                    zmin=-1, zmax=1)
    fig.update_layout(width=800, height=500)
    return fig

# Display the visualization
visualize_positional_encoding()

## Understanding Traditional Positional Encoding Visualization

The heatmap above visualizes the traditional sinusoidal positional encoding used in the original Transformer architecture. This encoding is crucial because Transformers process tokens in parallel, with no inherent sense of their order in the sequence.

**Key observations:**

- **Vertical patterns**: Each column represents a dimension in the embedding space. Note how some dimensions change rapidly (high-frequency components) while others change slowly (low-frequency components).
  
- **Position uniqueness**: Each row represents a unique position encoding vector, ensuring that each position in a sequence gets a distinctive representation.

- **Sinusoidal pattern**: The encoding uses sine and cosine functions with different frequencies, creating the wave-like patterns you see. This approach allows the model to generalize to sequence lengths it hasn't seen during training.

- **Color variations**: Blue represents negative values, red represents positive values, with white near zero. These variations help the model distinguish between different positions.

This encoding is added directly to token embeddings before they're processed by the attention mechanism. However, this approach has limitations in capturing relative positions effectively, which is addressed by RoPE.

## Rotary Position Embedding (RoPE)

Now, let's implement RoPE, which applies rotation to query and key vectors in self-attention rather than adding positional vectors to the embeddings.

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding implementation."""
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # Create position indices
        position = torch.arange(seq_len, device=self.inv_freq.device).float()
        # Compute angles for each position and frequency
        angles = position.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        # Return sines and cosines for easier rotation
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    def _apply_rotary_pos_emb(self, x, cos, sin):
        # x shape: [batch_size, seq_len, dim]
        # Reshape x for easier rotation
        x_reshape = x.reshape(*x.shape[:-1], -1, 2)
        x1, x2 = x_reshape.unbind(-1)  # Split into pairs
        
        # Apply rotation using the rotation matrix [cos -sin; sin cos]
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        
        # Stack and reshape back
        return torch.stack([out1, out2], dim=-1).flatten(-2)

    def apply_rotary_embedding(self, q, k, seq_len):
        # Get rotation angles
        cos_sin = self.forward(seq_len)
        cos, sin = cos_sin[..., 0], cos_sin[..., 1]
        
        # Reshape for broadcasting
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        
        # Apply rotations to q and k
        q_rotated = self._apply_rotary_pos_emb(q, cos, sin)
        k_rotated = self._apply_rotary_pos_emb(k, cos, sin)
        
        return q_rotated, k_rotated

# Function to visualize rotation trajectories in RoPE
def visualize_rope_trajectories(dim=64, seq_length=50, base=10000):
    # Compute frequencies
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_length)
    
    # Create subplots for different frequency pairs
    pairs_to_plot = [0, dim//8, dim//4, 3*dim//8]  # Plot 4 different frequency pairs
    fig = make_subplots(rows=2, cols=2, 
                       subplot_titles=[f"Dimension Pair {i}" for i in pairs_to_plot])
    
    for i, pair_idx in enumerate(pairs_to_plot):
        # Calculate row and column for subplot
        row = i // 2 + 1
        col = i % 2 + 1
        
        # Get frequency for this pair
        freq = freqs[pair_idx].item()
        
        # Compute rotation angles
        angles = positions.float() * freq
        
        # Compute x, y coordinates on the circle
        x = torch.cos(angles)
        y = torch.sin(angles)
        
        # Create a color gradient based on position
        colors = positions.numpy()
        
        # Add scatter trace to subplot
        fig.add_trace(
            go.Scatter(x=x, y=y, mode='lines+markers',
                      marker=dict(color=colors, colorscale='Viridis', size=8),
                      line=dict(color='rgba(150,150,150,0.3)'),
                      text=[f"Pos {p}" for p in positions],
                      hoverinfo='text',
                      name=f"Pair {pair_idx}"),
            row=row, col=col
        )
        
        # Add annotations for specific positions
        special_pos = [0, 10, 20, 30, 40]
        for pos in special_pos:
            if pos < seq_length:
                fig.add_annotation(
                    x=x[pos].item(), y=y[pos].item(),
                    text=str(pos),
                    showarrow=True, arrowhead=2,
                    arrowsize=1, arrowwidth=1, ax=0, ay=-20,
                    row=row, col=col
                )
    
    # Update layout
    fig.update_layout(
        title="RoPE: Rotation Trajectories for Different Dimension Pairs",
        height=700, width=900,
        showlegend=False,
    )
    
    # Update axis properties for all subplots
    for i in range(1, 3):
        for j in range(1, 3):
            fig.update_xaxes(title_text="cos(mθ)", range=[-1.1, 1.1], row=i, col=j)
            fig.update_yaxes(title_text="sin(mθ)", range=[-1.1, 1.1], row=i, col=j)
    
    return fig

# Display the visualization
visualize_rope_trajectories()

## Understanding RoPE Rotation Trajectories

This visualization demonstrates how Rotary Position Embedding (RoPE) operates by rotating vector pairs in different dimensions at various frequencies. Unlike traditional positional encoding that adds position vectors to embeddings, RoPE directly rotates query and key vectors in the attention mechanism.

**Key insights from the visualization:**

- **Multiple rotation frequencies**: Each subplot shows how a different dimension pair rotates at its own frequency. Early pairs (Dimension Pair 0) rotate quickly with position changes, while later pairs (Dimension Pair 24, 32, 48) rotate more slowly.

- **Position mapping to angles**: Each colored dot represents a position in the sequence, mapped to a specific angle on the unit circle. The color gradient helps track how positions (0, 10, 20, etc.) move around the circle.

- **Multi-scale representation**: The varying rotation speeds create a multi-scale representation of position. Fast-rotating pairs capture local, fine-grained positional relationships, while slow-rotating pairs capture global context.

- **Circular nature**: The circular trajectories ensure that the encoding is bounded and well-behaved numerically, avoiding the exploding or vanishing issues that can occur with some other position encodings.

By using these rotations instead of additive encodings, RoPE elegantly integrates positional information into the attention mechanism while preserving the ability to model relative positions effectively.

## The Core Mechanism: Rotation in 2D Space
RoPE operates by pairing the dimensions of the query and key vectors and applying a 2D rotation to each pair. For a d-dimensional vector, there are d/2 pairs. The rotation angle for each pair depends on the position and the pair index, introducing multiple frequencies into the encoding.

In [None]:
import numpy as np

def rotate_pair(x_j, x_j1, theta):
    """Apply 2D rotation to a pair of values."""
    x_j_new = x_j * np.cos(theta) - x_j1 * np.sin(theta)
    x_j1_new = x_j * np.sin(theta) + x_j1 * np.cos(theta)
    return x_j_new, x_j1_new

# Example: rotate (1, 0) by 45 degrees
x_j, x_j1 = 1.0, 0.0
theta = np.pi / 4
rotate_pair(x_j, x_j1, theta)
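
The same rotation can be written as a 2×2 matrix, which makes two properties easy to check numerically: a rotation preserves length, and two rotations compose by adding their angles. The sketch below is illustrative only; `rotation_matrix` is a helper name introduced here, not part of the notebook's code.

```python
import numpy as np

def rotation_matrix(theta):
    """Return the 2x2 matrix that rotates a vector counter-clockwise by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s],
                     [s,  c]])

a, b = 0.3, 1.1  # two arbitrary angles in radians
R_a, R_b = rotation_matrix(a), rotation_matrix(b)

# Same example as rotate_pair: rotate (1, 0) by 45 degrees
v = np.array([1.0, 0.0])
v_rot = rotation_matrix(np.pi / 4) @ v

# Rotations keep the vector's length unchanged
length_preserved = np.isclose(np.linalg.norm(v_rot), np.linalg.norm(v))
# Rotating by b and then by a equals a single rotation by a + b
angles_add = np.allclose(R_a @ R_b, rotation_matrix(a + b))

print(v_rot, length_preserved, angles_add)
```

The angle-addition property is what allows two rotations by absolute positions to collapse into one rotation by the position difference.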

## Multi-Frequency Encoding: Local and Global Context
Each pair of dimensions rotates at its own frequency, $\theta_i = \mathrm{base}^{-2i/d}$ for pair index $i = 0, \dots, d/2 - 1$. Early pairs (small $i$) rotate quickly, so they separate nearby positions. Later pairs (large $i$) rotate slowly, so they still distinguish positions that are far apart. A single frequency could not do both: it would either wrap around quickly, making distant positions collide, or barely move between neighbouring positions.

In [None]:
import plotly.graph_objs as go
import numpy as np

positions = np.arange(0, 50)
d = 8  # example dimensionality
base = 10000

traces = []
for i in range(d // 2):
    theta_i = 1 / (base ** (2 * i / d))
    angles = positions * theta_i
    x = np.cos(angles)
    y = np.sin(angles)
    traces.append(go.Scatter(x=x, y=y, mode='lines', name=f'Pair {i}'))

layout = go.Layout(title='RoPE: Rotation Trajectories for Different Dimension Pairs',
                   xaxis=dict(title='cos(mθ_i)'),
                   yaxis=dict(title='sin(mθ_i)'))
fig = go.Figure(data=traces, layout=layout)
fig.show()
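
One way to read these frequencies is as wavelengths: pair i completes a full turn every 2π/θ_i positions. The sketch below computes them for the same `d = 8` and `base = 10000`; `rope_wavelengths` is a name introduced here for illustration.

```python
import numpy as np

def rope_wavelengths(d, base=10000):
    """Number of positions each dimension pair needs to complete one full rotation."""
    i = np.arange(d // 2)                 # pair indices 0 .. d/2 - 1
    theta = 1.0 / (base ** (2 * i / d))   # rotation angle per position step
    return 2 * np.pi / theta

wavelengths = rope_wavelengths(8)
for i, w in enumerate(wavelengths):
    print(f"pair {i}: one full turn every {w:.1f} positions")
```

With these settings the fastest pair repeats after about 6 positions, while the slowest needs several thousand, which is why the slow pairs trace only short arcs over 50 positions.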

## Relative Position from Absolute Rotations
A key property of RoPE is that, after rotation, the dot product between a query at position m and a key at position n depends only on their offset (m - n), not on m and n individually. Attention scores therefore encode relative position even though each vector was rotated by its own absolute position, which is crucial for modeling language context.
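
To see why, take one dimension pair $i$ and write $R(\varphi)$ for the 2D rotation by angle $\varphi$. The short derivation below is a sketch in notation introduced here ($R$, $q^{(i)} = (q_{2i}, q_{2i+1})$ and $k^{(i)} = (k_{2i}, k_{2i+1})$ are not names from the notebook's code). It uses only $R(\varphi)^\top = R(-\varphi)$ and $R(a)R(b) = R(a+b)$.

$$
\begin{aligned}
\big(R(m\theta_i)\,q^{(i)}\big)^\top \big(R(n\theta_i)\,k^{(i)}\big)
&= q^{(i)\top} R(m\theta_i)^\top R(n\theta_i)\, k^{(i)} \\
&= q^{(i)\top} R\big((n-m)\theta_i\big)\, k^{(i)} \\
&= (q_{2i}k_{2i} + q_{2i+1}k_{2i+1})\cos\big((n-m)\theta_i\big)
 + (q_{2i+1}k_{2i} - q_{2i}k_{2i+1})\sin\big((n-m)\theta_i\big).
\end{aligned}
$$

Summing over all pairs gives the full score, and $m$ and $n$ enter only through their difference. The cosine term is even in the offset and the sine term is odd, so the score can differ between a key that comes before the query and one that comes after it.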

In [None]:
def rope_dot_product(q, k, m, n, base=10000):
    d = len(q)
    result = 0.0
    for i in range(0, d, 2):
        theta_i = 1 / (base ** (2 * (i//2) / d))
        angle_m = m * theta_i
        angle_n = n * theta_i
        # Rotate q and k
        q_rot = np.array([q[i] * np.cos(angle_m) - q[i+1] * np.sin(angle_m),
                          q[i] * np.sin(angle_m) + q[i+1] * np.cos(angle_m)])
        k_rot = np.array([k[i] * np.cos(angle_n) - k[i+1] * np.sin(angle_n),
                          k[i] * np.sin(angle_n) + k[i+1] * np.cos(angle_n)])
        result += np.dot(q_rot, k_rot)
    return result

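# Example: dot product for different relative positions.
# The query stays at m=0, so the key position n is also the offset n - m.
np.random.seed(0)  # fixed seed so the plot is reproducible
q = np.random.randn(8)
k = np.random.randn(8)
rel_positions = np.arange(-10, 11)
dots = [rope_dot_product(q, k, m=0, n=rp) for rp in rel_positions]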

import plotly.express as px
fig = px.line(x=rel_positions, y=dots, labels={'x':'Relative Position (n)', 'y':'Dot Product'},
              title='RoPE Dot Product vs. Relative Position')
fig.show()

## Understanding RoPE Dot Product vs. Relative Position

This graph illustrates one of the most remarkable properties of Rotary Position Embedding: how the dot product between query and key vectors varies based on their relative position, not their absolute positions.

**What this visualization shows:**

- **Dot product magnitude varies with relative distance**: The y-axis shows the dot product value between a query at position 0 and keys at various relative positions (x-axis). This demonstrates how attention scores in RoPE naturally depend on how far apart tokens are.

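- **No exact symmetry around zero**: For each dimension pair, the score combines an even term proportional to the cosine of the offset angle with an odd term proportional to its sine. Offsets of +Δ and -Δ therefore generally give different values, which is what lets the attention mechanism distinguish tokens before the current position from tokens after it. The curve is exactly symmetric only in special cases, such as q = k.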

- **Periodic patterns**: The dot product shows a complex periodic pattern resulting from the multiple rotation frequencies used in different dimension pairs. This rich signal helps the model learn nuanced relationships between tokens based on their spacing.

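- **Decay with distance**: The RoPE authors derive an upper bound on the score that tends to shrink as the relative distance grows, matching the intuition that nearby tokens are often more relevant than distant ones. A single random q and k in 8 dimensions, as plotted here, mostly oscillates and need not show a clean decay.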

This property is crucial for language modeling because it allows the attention mechanism to focus on contextually relevant tokens based on their relative positions, regardless of where they appear in the absolute sequence.
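
The relative-position property and the direction sensitivity can both be checked numerically. The following is a minimal sketch under the same pairing convention as `rope_dot_product`; `rope_rotate` and `score` are helper names introduced here, and the specific positions and vectors are arbitrary choices.

```python
import numpy as np

def rope_rotate(vec, pos, base=10000):
    """Rotate consecutive pairs (vec[2i], vec[2i+1]) by the angle pos * theta_i."""
    d = len(vec)
    theta = 1.0 / (base ** (2 * np.arange(d // 2) / d))
    cos, sin = np.cos(pos * theta), np.sin(pos * theta)
    out = np.empty(d)
    out[0::2] = vec[0::2] * cos - vec[1::2] * sin
    out[1::2] = vec[0::2] * sin + vec[1::2] * cos
    return out

def score(q, k, m, n):
    """Dot product of a query rotated to position m and a key rotated to position n."""
    return rope_rotate(q, m) @ rope_rotate(k, n)

rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

# Shifting both positions by the same amount leaves the score unchanged
shift_invariant = np.isclose(score(q, k, 3, 7), score(q, k, 103, 107))

# With q perpendicular to k in a single pair, only the odd (sine) term survives,
# so a key one step ahead and a key one step behind get opposite scores
q_perp, k_perp = np.array([1.0, 0.0]), np.array([0.0, 1.0])
forward = score(q_perp, k_perp, 0, 1)
backward = score(q_perp, k_perp, 0, -1)

print(shift_invariant, forward, backward)
```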

# The "Wow" Factors of RoPE
- **Relative from Absolute**: Encodes absolute positions via rotation, but attention scores depend on relative positions.
- **Multi-Scale Positional Information**: Different rotation frequencies capture both local and global context.
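- **Extrapolation**: Rotations are defined for any position, so no fixed-size position table limits the sequence length. Quality on sequences much longer than those seen in training is not guaranteed, though.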
- **No Extra Learnable Parameters**: RoPE uses fixed transformations, keeping the model efficient.
- **Deep Integration**: RoPE is applied directly to Q and K vectors, making it a core part of the attention mechanism.

# Transformer Implementation with PyTorch

In this section, we'll implement a complete Transformer model with both traditional sinusoidal positional encoding and Rotary Position Embedding (RoPE). We'll use it to demonstrate how these positional encodings work in practice on next-word prediction tasks.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Traditional Positional Encoding

First, let's implement the original sinusoidal positional encoding from the "Attention Is All You Need" paper.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x has shape [batch_size, seq_len, embedding_dim]
        # Add positional encoding to embeddings
        x = x + self.pe[:, :x.size(1), :]
        return x

# Visualize the traditional positional encoding
def visualize_positional_encoding(model_dim=64, seq_length=100):
    pos_enc = PositionalEncoding(model_dim)
    dummy_input = torch.zeros(1, seq_length, model_dim)
    encoded = pos_enc(dummy_input)[0]  # Remove batch dimension
    
    # Create a heatmap with plotly
    fig = px.imshow(encoded.detach().numpy(),
                    labels=dict(x="Dimension", y="Position", color="Value"),
                    title="Traditional Sinusoidal Positional Encoding",
                    color_continuous_scale="RdBu_r",
                    zmin=-1, zmax=1)
    fig.update_layout(width=800, height=500)
    return fig

# Display the visualization
visualize_positional_encoding()

## Understanding Traditional Positional Encoding Visualization

The heatmap above visualizes the traditional sinusoidal positional encoding used in the original Transformer architecture. This encoding is crucial because Transformers process tokens in parallel, with no inherent sense of their order in the sequence.

**Key observations:**

- **Vertical patterns**: Each column represents a dimension in the embedding space. Note how some dimensions change rapidly (high-frequency components) while others change slowly (low-frequency components).
  
- **Position uniqueness**: Each row represents a unique position encoding vector, ensuring that each position in a sequence gets a distinctive representation.

- **Sinusoidal pattern**: The encoding uses sine and cosine functions with different frequencies, creating the wave-like patterns you see. Because the functions are defined for any position, this approach may allow the model to handle sequence lengths it hasn't seen during training.

- **Color variations**: Blue represents negative values, red represents positive values, with white near zero. The distinct pattern of values in each row is what lets the model tell positions apart.

This encoding is added directly to token embeddings before they're processed by the attention mechanism. However, this approach has limitations in capturing relative positions effectively, which is addressed by RoPE.
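
A quick numerical check helps make the relative-position limitation concrete. In the sketch below (NumPy rather than PyTorch, with `sinusoidal_pe` as a helper name introduced here), the dot product of two raw encodings depends only on their offset, because sin(a)sin(b) + cos(a)cos(b) = cos(a - b). In a Transformer, however, the encoding is first added to a token embedding and then passed through learned query and key projections, so the attention score also contains content-position cross terms and is no longer a function of the offset alone.

```python
import numpy as np

def sinusoidal_pe(num_positions, d_model, base=10000.0):
    """Sinusoidal encoding table of shape [num_positions, d_model] (d_model assumed even)."""
    position = np.arange(num_positions)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(base) / d_model))
    pe = np.zeros((num_positions, d_model))
    pe[:, 0::2] = np.sin(position * div_term)  # even indices
    pe[:, 1::2] = np.cos(position * div_term)  # odd indices
    return pe

pe = sinusoidal_pe(200, 64)
offset = 5

# Same offset, three different absolute starting positions
dots = [pe[p] @ pe[p + offset] for p in (0, 50, 150)]
same_for_all_starts = np.allclose(dots, dots[0])
print(dots, same_for_all_starts)
```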

## Rotary Position Embedding (RoPE)

Now, let's implement RoPE, which applies rotation to query and key vectors in self-attention rather than adding positional vectors to the embeddings.

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding implementation."""
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # Create position indices
        position = torch.arange(seq_len, device=self.inv_freq.device).float()
        # Compute angles for each position and frequency
        angles = position.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        # Return sines and cosines for easier rotation
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    def _apply_rotary_pos_emb(self, x, cos, sin):
        # x shape: [batch_size, num_heads, seq_len, head_dim]
        # (4D, so it broadcasts with cos/sin of shape [1, 1, seq_len, head_dim/2])
        # Reshape x for easier rotation
        x_reshape = x.reshape(*x.shape[:-1], -1, 2)
        x1, x2 = x_reshape.unbind(-1)  # Split into pairs
        
        # Apply rotation using the rotation matrix [cos -sin; sin cos]
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        
        # Stack and reshape back
        return torch.stack([out1, out2], dim=-1).flatten(-2)

    def apply_rotary_embedding(self, q, k, seq_len):
        # Get rotation angles
        cos_sin = self.forward(seq_len)
        cos, sin = cos_sin[..., 0], cos_sin[..., 1]
        
        # Reshape for broadcasting
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        
        # Apply rotations to q and k
        q_rotated = self._apply_rotary_pos_emb(q, cos, sin)
        k_rotated = self._apply_rotary_pos_emb(k, cos, sin)
        
        return q_rotated, k_rotated

# Function to visualize rotation trajectories in RoPE
def visualize_rope_trajectories(dim=64, seq_length=50, base=10000):
    # Compute frequencies
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_length)
    
    # Create subplots for different frequency pairs
    pairs_to_plot = [0, dim//8, dim//4, 3*dim//8]  # Plot 4 different frequency pairs
    fig = make_subplots(rows=2, cols=2, 
                       subplot_titles=[f"Dimension Pair {i}" for i in pairs_to_plot])
    
    for i, pair_idx in enumerate(pairs_to_plot):
        # Calculate row and column for subplot
        row = i // 2 + 1
        col = i % 2 + 1
        
        # Get frequency for this pair
        freq = freqs[pair_idx].item()
        
        # Compute rotation angles
        angles = positions.float() * freq
        
        # Compute x, y coordinates on the circle
        x = torch.cos(angles)
        y = torch.sin(angles)
        
        # Create a color gradient based on position
        colors = positions.numpy()
        
        # Add scatter trace to subplot
        fig.add_trace(
            go.Scatter(x=x, y=y, mode='lines+markers',
                      marker=dict(color=colors, colorscale='Viridis', size=8),
                      line=dict(color='rgba(150,150,150,0.3)'),
                      text=[f"Pos {p}" for p in positions],
                      hoverinfo='text',
                      name=f"Pair {pair_idx}"),
            row=row, col=col
        )
        
        # Add annotations for specific positions
        special_pos = [0, 10, 20, 30, 40]
        for pos in special_pos:
            if pos < seq_length:
                fig.add_annotation(
                    x=x[pos].item(), y=y[pos].item(),
                    text=str(pos),
                    showarrow=True, arrowhead=2,
                    arrowsize=1, arrowwidth=1, ax=0, ay=-20,
                    row=row, col=col
                )
    
    # Update layout
    fig.update_layout(
        title="RoPE: Rotation Trajectories for Different Dimension Pairs",
        height=700, width=900,
        showlegend=False,
    )
    
    # Update axis properties for all subplots
    for i in range(1, 3):
        for j in range(1, 3):
            fig.update_xaxes(title_text="cos(mθ)", range=[-1.1, 1.1], row=i, col=j)
            fig.update_yaxes(title_text="sin(mθ)", range=[-1.1, 1.1], row=i, col=j)
    
    return fig

# Display the visualization
visualize_rope_trajectories()
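
The pairing and broadcasting used by `_apply_rotary_pos_emb` can be mirrored in NumPy to check the shapes involved. This is a sketch under the assumption that queries and keys arrive as `[batch, heads, seq_len, head_dim]`; `rope_angles` and `apply_rope` are names introduced here, not part of the `RotaryEmbedding` class.

```python
import numpy as np

def rope_angles(seq_len, head_dim, base=10000):
    """Angle for every (position, pair) combination, shape [seq_len, head_dim/2]."""
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    return np.arange(seq_len)[:, None] * inv_freq[None, :]

def apply_rope(x, angles):
    """Rotate consecutive pairs of the last axis; x is [batch, heads, seq_len, head_dim]."""
    cos = np.cos(angles)[None, None]          # [1, 1, seq_len, head_dim/2]
    sin = np.sin(angles)[None, None]
    pairs = x.reshape(*x.shape[:-1], -1, 2)   # [..., head_dim/2, 2]
    x1, x2 = pairs[..., 0], pairs[..., 1]
    out = np.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(x.shape)               # interleave the pairs back

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 10, 16))       # batch=2, heads=4, seq_len=10, head_dim=16
q_rot = apply_rope(q, rope_angles(seq_len=10, head_dim=16))
print(q_rot.shape)
```

Position 0 has angle zero in every pair, so the first token is left unchanged, and every other token keeps its norm.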

## Understanding RoPE Rotation Trajectories

This visualization demonstrates how Rotary Position Embedding (RoPE) operates by rotating vector pairs in different dimensions at various frequencies. Unlike traditional positional encoding that adds position vectors to embeddings, RoPE directly rotates query and key vectors in the attention mechanism.

**Key insights from the visualization:**

- **Multiple rotation frequencies**: Each subplot shows how a different dimension pair rotates at its own frequency. The first pair (Dimension Pair 0) rotates quickly as position changes, while the later pairs plotted with the default `dim=64` (Dimension Pairs 8, 16, and 24) rotate progressively more slowly.

- **Position mapping to angles**: Each colored dot represents a position in the sequence, mapped to a specific angle on the unit circle. The color gradient helps track how positions (0, 10, 20, etc.) move around the circle.

- **Multi-scale representation**: The varying rotation speeds create a multi-scale representation of position. Fast-rotating pairs capture local, fine-grained positional relationships, while slow-rotating pairs capture global context.

- **Circular nature**: Every point stays on the unit circle because a rotation changes a vector's direction but not its length. The encoding therefore stays bounded at any position, and rotated queries and keys keep their original norms.

By using these rotations instead of additive encodings, RoPE elegantly integrates positional information into the attention mechanism while preserving the ability to model relative positions effectively.

## The Core Mechanism: Rotation in 2D Space
RoPE operates by pairing the dimensions of the query and key vectors and applying a 2D rotation to each pair. For a d-dimensional vector, there are d/2 pairs. The rotation angle for each pair depends on the position and the pair index, introducing multiple frequencies into the encoding.

In [None]:
import numpy as np

def rotate_pair(x_j, x_j1, theta):
    """Apply 2D rotation to a pair of values."""
    x_j_new = x_j * np.cos(theta) - x_j1 * np.sin(theta)
    x_j1_new = x_j * np.sin(theta) + x_j1 * np.cos(theta)
    return x_j_new, x_j1_new

# Example: rotate (1, 0) by 45 degrees
x_j, x_j1 = 1.0, 0.0
theta = np.pi / 4
rotate_pair(x_j, x_j1, theta)

## Multi-Frequency Encoding: Local and Global Context
Each pair of dimensions rotates at its own frequency, $\theta_i = \mathrm{base}^{-2i/d}$ for pair index $i = 0, \dots, d/2 - 1$. Early pairs (small $i$) rotate quickly, so they separate nearby positions. Later pairs (large $i$) rotate slowly, so they still distinguish positions that are far apart. A single frequency could not do both: it would either wrap around quickly, making distant positions collide, or barely move between neighbouring positions.

In [None]:
import plotly.graph_objs as go
import numpy as np

positions = np.arange(0, 50)
d = 8  # example dimensionality
base = 10000

traces = []
for i in range(d // 2):
    theta_i = 1 / (base ** (2 * i / d))
    angles = positions * theta_i
    x = np.cos(angles)
    y = np.sin(angles)
    traces.append(go.Scatter(x=x, y=y, mode='lines', name=f'Pair {i}'))

layout = go.Layout(title='RoPE: Rotation Trajectories for Different Dimension Pairs',
                   xaxis=dict(title='cos(mθ_i)'),
                   yaxis=dict(title='sin(mθ_i)'))
fig = go.Figure(data=traces, layout=layout)
fig.show()

## Relative Position from Absolute Rotations
A key property of RoPE is that, after rotation, the dot product between a query at position m and a key at position n depends only on their offset (m - n), not on m and n individually. Attention scores therefore encode relative position even though each vector was rotated by its own absolute position, which is crucial for modeling language context.

In [None]:
def rope_dot_product(q, k, m, n, base=10000):
    d = len(q)
    result = 0.0
    for i in range(0, d, 2):
        theta_i = 1 / (base ** (2 * (i//2) / d))
        angle_m = m * theta_i
        angle_n = n * theta_i
        # Rotate q and k
        q_rot = np.array([q[i] * np.cos(angle_m) - q[i+1] * np.sin(angle_m),
                          q[i] * np.sin(angle_m) + q[i+1] * np.cos(angle_m)])
        k_rot = np.array([k[i] * np.cos(angle_n) - k[i+1] * np.sin(angle_n),
                          k[i] * np.sin(angle_n) + k[i+1] * np.cos(angle_n)])
        result += np.dot(q_rot, k_rot)
    return result

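# Example: dot product for different relative positions
np.random.seed(0)  # fixed seed so the curve is reproducible
q = np.random.randn(8)
k = np.random.randn(8)
rel_positions = np.arange(-10, 11)
# The query is fixed at position 0, so the key position n equals the offset n - m
dots = [rope_dot_product(q, k, m=0, n=rp) for rp in rel_positions]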

import plotly.express as px
fig = px.line(x=rel_positions, y=dots, labels={'x':'Relative Position (n)', 'y':'Dot Product'},
              title='RoPE Dot Product vs. Relative Position')
fig.show()
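
The plot keeps the query at position 0, so it does not by itself show that absolute position is irrelevant. A small check, sketched here with a vectorized helper `rope_score` (an illustrative name, assuming the same pairing and frequencies as `rope_dot_product`), evaluates the same offset at several absolute positions.

```python
import numpy as np

def rope_score(q, k, m, n, base=10000.0):
    """Dot product of q rotated to position m and k rotated to position n."""
    d = len(q)
    theta = base ** (-2.0 * np.arange(d // 2) / d)

    def rotate(x, pos):
        cos, sin = np.cos(pos * theta), np.sin(pos * theta)
        out = np.empty(d)
        out[0::2] = x[0::2] * cos - x[1::2] * sin
        out[1::2] = x[0::2] * sin + x[1::2] * cos
        return out

    return float(rotate(q, m) @ rotate(k, n))

rng = np.random.default_rng(0)   # seed chosen arbitrarily for reproducibility
q, k = rng.standard_normal(8), rng.standard_normal(8)

# Same offset n - m = 4, evaluated at three different absolute positions
scores = [rope_score(q, k, m, m + 4) for m in (0, 17, 250)]
print(scores)
```

The three printed values agree up to floating-point rounding, since shifting both positions by the same amount leaves the score unchanged.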

## Understanding RoPE Dot Product vs. Relative Position

This graph shows how the RoPE dot product between a query and a key changes with their relative position. The query stays at position 0 while the key moves from position -10 to 10. Because the score depends only on the offset between the two, the same curve would appear with the query placed at any other position.

**What this visualization shows:**

- **Dot product varies with relative offset**: The y-axis shows the dot product between a query at position 0 and a key at each relative position on the x-axis. The vectors q and k are held fixed, so all of the variation comes from the rotations. This is how RoPE makes attention scores depend on how far apart tokens are.

- **No symmetry around zero in general**: For each dimension pair, the score combines a cosine term (even in the offset) and a sine term (odd in the offset). The sine term is what lets attention tell a key that precedes the query from one that follows it at the same distance. For random q and k the curve is therefore usually not a mirror image around zero; it is exactly symmetric only in special cases such as q = k.

- **Oscillating patterns**: The curve is a sum of sinusoids, one per dimension pair, each at its own rotation frequency. Because the frequencies differ by large factors, the sum generally does not repeat with a single short period. This rich signal helps the model learn nuanced relationships between tokens based on their spacing.

- **Decay is a tendency, not a guarantee**: With many dimension pairs, contributions at different frequencies increasingly cancel at large offsets, so the magnitude of the score tends to shrink with distance. This aligns with the intuition that nearby tokens are often more relevant than distant ones. With only 4 pairs, random q and k, and offsets up to 10, as in this plot, the curve may not show a clean decay.

This property is crucial for language modeling because it allows the attention mechanism to focus on contextually relevant tokens based on their relative positions, regardless of where they appear in the absolute sequence.
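
The shape of this curve can also be examined in closed form. Expanding the rotation for one pair gives a cosine term weighted by `q0*k0 + q1*k1` and a sine term weighted by `q1*k0 - q0*k1`, evaluated at the offset `n - m` times the pair's frequency. The sketch below (the helper name `score_vs_offset` is illustrative) evaluates that expression and tests whether the curve mirrors around zero, once for random `q` and `k` and once for `k = q`.

```python
import numpy as np

def score_vs_offset(q, k, offsets, base=10000.0):
    """Closed-form RoPE score for offsets delta = n - m (query at m, key at n)."""
    d = len(q)
    theta = base ** (-2.0 * np.arange(d // 2) / d)
    q0, q1, k0, k1 = q[0::2], q[1::2], k[0::2], k[1::2]
    a = q0 * k0 + q1 * k1        # weights of the even (cosine) part
    b = q1 * k0 - q0 * k1        # weights of the odd (sine) part
    angles = np.outer(offsets, theta)          # shape (len(offsets), d/2)
    return np.cos(angles) @ a + np.sin(angles) @ b

rng = np.random.default_rng(1)   # seed chosen arbitrarily
q, k = rng.standard_normal(8), rng.standard_normal(8)
offsets = np.arange(-10, 11)

mixed = score_vs_offset(q, k, offsets)   # independent query and key
same = score_vs_offset(q, q, offsets)    # key equal to query

print("q != k mirrors around zero:", np.allclose(mixed, mixed[::-1]))
print("q == k mirrors around zero:", np.allclose(same, same[::-1]))
```

When `k = q` the sine weights `b` are all zero, so only the even cosine part remains and the curve is exactly symmetric. For independent vectors the sine part is generally nonzero, which is what encodes direction.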

# The "Wow" Factors of RoPE
- **Relative from Absolute**: Encodes absolute positions via rotation, but attention scores depend on relative positions.
- **Multi-Scale Positional Information**: Different rotation frequencies capture both local and global context.
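- **Extrapolation**: Rotations are defined for any position, so RoPE can be evaluated beyond the training length without a fixed-size lookup table. In practice, quality often degrades at lengths much longer than those seen in training unless the frequencies are rescaled (for example, with position interpolation).
- **No Extra Learnable Parameters**: The rotation angles are fixed functions of position and pair index, so RoPE adds no trainable weights.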
- **Deep Integration**: RoPE is applied directly to Q and K vectors, making it a core part of the attention mechanism.

# Transformer Implementation with PyTorch

In this section, we'll implement a complete Transformer model with both traditional sinusoidal positional encoding and Rotary Position Embedding (RoPE), and use it to show how each positional encoding behaves in practice on a next-word prediction task.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Traditional Positional Encoding

First, let's implement the original sinusoidal positional encoding from the "Attention Is All You Need" paper.

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Add batch dimension
        pe = pe.unsqueeze(0)
        
        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x has shape [batch_size, seq_len, embedding_dim]
        # Add positional encoding to embeddings
        x = x + self.pe[:, :x.size(1), :]
        return x

# Visualize the traditional positional encoding
def visualize_positional_encoding(model_dim=64, seq_length=100):
    pos_enc = PositionalEncoding(model_dim)
    dummy_input = torch.zeros(1, seq_length, model_dim)
    encoded = pos_enc(dummy_input)[0]  # Remove batch dimension
    
    # Create a heatmap with plotly
    fig = px.imshow(encoded.detach().numpy(),
                    labels=dict(x="Dimension", y="Position", color="Value"),
                    title="Traditional Sinusoidal Positional Encoding",
                    color_continuous_scale="RdBu_r",
                    zmin=-1, zmax=1)
    fig.update_layout(width=800, height=500)
    return fig

# Display the visualization
visualize_positional_encoding()

## Understanding Traditional Positional Encoding Visualization

The heatmap above visualizes the traditional sinusoidal positional encoding used in the original Transformer architecture. This encoding is crucial because Transformers process tokens in parallel, with no inherent sense of their order in the sequence.

**Key observations:**

- **Vertical patterns**: Each column represents a dimension in the embedding space. Low-index dimensions (left) change rapidly with position (high-frequency components), while high-index dimensions (right) change slowly (low-frequency components).
  
- **Position uniqueness**: Each row represents a unique position encoding vector, ensuring that each position in a sequence gets a distinctive representation.

- **Sinusoidal pattern**: The encoding uses sine and cosine functions with different frequencies, creating the wave-like patterns you see. Because these functions are defined for any position, the encoding can be computed for sequences longer than those seen in training. The original paper hypothesized that this helps the model extrapolate, though it is not guaranteed.

- **Color variations**: Blue represents negative values, red represents positive values, with white near zero. Every value lies in [-1, 1] because each entry is a sine or a cosine.

This encoding is added directly to token embeddings before they're processed by the attention mechanism. Because position and content are mixed by addition, the attention score between two tokens contains content-position cross terms and is not a function of their relative offset alone. RoPE addresses this by rotating queries and keys instead.
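
One way to see the limitation is to compare a position-only dot product with the score of embeddings that have positions added to them. The sketch below is a simplification under stated assumptions: it uses NumPy to stay independent of the model code, random vectors `x_a` and `x_b` as stand-in token embeddings, and no learned query or key projections. The names `sinusoidal_pe` and `raw_score` are illustrative.

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    """Sinusoidal table: sine on even columns, cosine on odd columns."""
    position = np.arange(max_len)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

pe = sinusoidal_pe(max_len=200, d_model=64)
rng = np.random.default_rng(0)   # seed chosen arbitrarily
x_a, x_b = rng.standard_normal(64), rng.standard_normal(64)  # stand-in embeddings

def raw_score(m, n):
    # Un-projected dot product of (embedding + position) vectors
    return float((x_a + pe[m]) @ (x_b + pe[n]))

# Position-only term: equal for equal offsets (both pairs are 7 apart)
print(pe[3] @ pe[10], pe[103] @ pe[110])
# Full additive score: content-position cross terms depend on absolute position
print(raw_score(3, 10), raw_score(103, 110))
```

The first pair of values matches because the position-only dot product reduces to a sum of cosines of the offset. The second pair differs because the terms mixing an embedding with a position vector change as the tokens move.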

## Rotary Position Embedding (RoPE)

Now, let's implement RoPE, which applies rotation to query and key vectors in self-attention rather than adding positional vectors to the embeddings.

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding implementation."""
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # Create position indices
        position = torch.arange(seq_len, device=self.inv_freq.device).float()
        # Compute angles for each position and frequency
        angles = position.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        # Return sines and cosines for easier rotation
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    def _apply_rotary_pos_emb(self, x, cos, sin):
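        # x shape: [batch_size, num_heads, seq_len, head_dim]
        # cos, sin shape: [1, 1, seq_len, head_dim/2], broadcast over batch and heads
        # Reshape x so the last axis holds consecutive (even, odd) pairs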
        x_reshape = x.reshape(*x.shape[:-1], -1, 2)
        x1, x2 = x_reshape.unbind(-1)  # Split into pairs
        
        # Apply rotation using the rotation matrix [cos -sin; sin cos]
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        
        # Stack and reshape back
        return torch.stack([out1, out2], dim=-1).flatten(-2)

    def apply_rotary_embedding(self, q, k, seq_len):
        # Get rotation angles
        cos_sin = self.forward(seq_len)
        cos, sin = cos_sin[..., 0], cos_sin[..., 1]
        
        # Reshape for broadcasting
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim/2]
        
        # Apply rotations to q and k
        q_rotated = self._apply_rotary_pos_emb(q, cos, sin)
        k_rotated = self._apply_rotary_pos_emb(k, cos, sin)
        
        return q_rotated, k_rotated

# Function to visualize rotation trajectories in RoPE
def visualize_rope_trajectories(dim=64, seq_length=50, base=10000):
    # Compute frequencies
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_length)
    
    # Create subplots for different frequency pairs
    pairs_to_plot = [0, dim//8, dim//4, 3*dim//8]  # Plot 4 different frequency pairs
    fig = make_subplots(rows=2, cols=2, 
                       subplot_titles=[f"Dimension Pair {i}" for i in pairs_to_plot])
    
    for i, pair_idx in enumerate(pairs_to_plot):
        # Calculate row and column for subplot
        row = i // 2 + 1
        col = i % 2 + 1
        
        # Get frequency for this pair
        freq = freqs[pair_idx].item()
        
        # Compute rotation angles
        angles = positions.float() * freq
        
        # Compute x, y coordinates on the circle
        x = torch.cos(angles)
        y = torch.sin(angles)
        
        # Create a color gradient based on position
        colors = positions.numpy()
        
        # Add scatter trace to subplot
        fig.add_trace(
            go.Scatter(x=x, y=y, mode='lines+markers',
                      marker=dict(color=colors, colorscale='Viridis', size=8),
                      line=dict(color='rgba(150,150,150,0.3)'),
                      text=[f"Pos {p}" for p in positions.tolist()],  # plain ints, not tensors
                      hoverinfo='text',
                      name=f"Pair {pair_idx}"),
            row=row, col=col
        )
        
        # Add annotations for specific positions
        special_pos = [0, 10, 20, 30, 40]
        for pos in special_pos:
            if pos < seq_length:
                fig.add_annotation(
                    x=x[pos].item(), y=y[pos].item(),
                    text=str(pos),
                    showarrow=True, arrowhead=2,
                    arrowsize=1, arrowwidth=1, ax=0, ay=-20,
                    row=row, col=col
                )
    
    # Update layout
    fig.update_layout(
        title="RoPE: Rotation Trajectories for Different Dimension Pairs",
        height=700, width=900,
        showlegend=False,
    )
    
    # Update axis properties for all subplots
    for i in range(1, 3):
        for j in range(1, 3):
            fig.update_xaxes(title_text="cos(mθ)", range=[-1.1, 1.1], row=i, col=j)
            fig.update_yaxes(title_text="sin(mθ)", range=[-1.1, 1.1], row=i, col=j)
    
    return fig

# Display the visualization
visualize_rope_trajectories()

## Understanding RoPE Rotation Trajectories

This visualization demonstrates how Rotary Position Embedding (RoPE) operates by rotating vector pairs in different dimensions at various frequencies. Unlike traditional positional encoding that adds position vectors to embeddings, RoPE directly rotates query and key vectors in the attention mechanism.

**Key insights from the visualization:**

- **Multiple rotation frequencies**: Each subplot shows how a different dimension pair rotates at its own frequency. With the default `dim=64`, the plotted pairs are 0, 8, 16, and 24. The earliest pair (Dimension Pair 0) rotates quickly as position changes, while later pairs (Dimension Pair 8, 16, 24) rotate progressively more slowly.

- **Position mapping to angles**: Each colored dot represents a position in the sequence, mapped to a specific angle on the unit circle. The color gradient helps track how positions (0, 10, 20, etc.) move around the circle.

- **Multi-scale representation**: The varying rotation speeds create a multi-scale representation of position. Fast-rotating pairs capture local, fine-grained positional relationships, while slow-rotating pairs capture global context.

- **Circular nature**: Every trajectory stays on the unit circle because a rotation changes a vector's direction but not its length. Applying RoPE therefore leaves the norms of the query and key vectors unchanged, so position information is injected without rescaling them.

By using these rotations instead of additive encodings, RoPE elegantly integrates positional information into the attention mechanism while preserving the ability to model relative positions effectively.
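
Both properties, unchanged norms and scores that depend on relative position, can be checked on batched tensors. The following is a minimal sketch, assuming inputs of shape `[batch, heads, seq_len, head_dim]` and the same interleaved pairing as `RotaryEmbedding`. It is written in NumPy to stay independent of the class above, and `apply_rope` is an illustrative helper rather than the notebook's method.

```python
import numpy as np

def apply_rope(x, base=10000.0):
    """Rotate interleaved pairs of x with shape [batch, heads, seq_len, head_dim]."""
    seq_len, dim = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]   # [seq_len, dim/2]
    cos, sin = np.cos(angles), np.sin(angles)                  # broadcast over batch, heads
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)   # seed chosen arbitrarily
batch, heads, seq_len, head_dim = 2, 4, 16, 32
q = rng.standard_normal((batch, heads, seq_len, head_dim))
k = rng.standard_normal((batch, heads, seq_len, head_dim))

# Rotations are orthogonal, so every vector keeps its length
norms_ok = np.allclose(np.linalg.norm(q, axis=-1),
                       np.linalg.norm(apply_rope(q), axis=-1))

# Repeat one query and one key at every position: logits[m, n] then depends on n - m only,
# so each diagonal of the logits matrix is constant
q_same = np.broadcast_to(q[:, :, :1, :], q.shape).copy()
k_same = np.broadcast_to(k[:, :, :1, :], k.shape).copy()
logits = apply_rope(q_same) @ np.swapaxes(apply_rope(k_same), -1, -2)
diagonals_ok = np.allclose(logits[..., :-1, :-1], logits[..., 1:, 1:])

print(norms_ok, diagonals_ok)
```

Repeating the same content at every position isolates the positional effect: any variation left in the logits comes from the rotations alone, and it is constant along each diagonal.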

## Rotary Position Embedding (RoPE)

Now, let's implement RoPE, which applies rotation to query and key vectors in self-attention rather than adding positional vectors to the embeddings.

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding implementation."""
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # Create position indices
        position = torch.arange(seq_len, device=self.inv_freq.device).float()
        # Compute angles for each position and frequency
        angles = position.unsqueeze(1) * self.inv_freq.unsqueeze(0)
        # Return sines and cosines for easier rotation
        return torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)

    def _apply_rotary_pos_emb(self, x, cos, sin):
        # x shape: [batch_size, seq_len, dim]
        # Reshape x for easier rotation
        x_reshape = x.reshape(*x.shape[:-1], -1, 2)
        x1, x2 = x_reshape.unbind(-1)  # Split into pairs
        
        # Apply rotation using the rotation matrix [cos -sin; sin cos]
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        
        # Stack and reshape back
        return torch.stack([out1, out2], dim=-1).flatten(-2)

    def apply_rotary_embedding(self, q, k, seq_len):
        # Get rotation angles
        cos_sin = self