# Tutorial 3: Intro to Recurrent Neural Networks: Math, Training, and the Copy Task

# Instructor: Dr. Ankur Mali
# University of South Florida (Spring 2025)
### In this tutorial we build RNNs directly from their equations and compare three popular frameworks (JAX, TensorFlow, and PyTorch)

## Vanilla RNN (see your slides for a more in-depth explanation)

### Forward Pass (Inference) -- Stage 1
Given an input $\mathbf{x}_t$ at time $t$ and the previous hidden state $\mathbf{h}_{t-1}$,
\begin{aligned}
\mathbf{x}_t \in \mathbb{R}^{d_{\text{in}}},\quad \mathbf{h}_{t-1} \in \mathbb{R}^{d_{\text{hid}}}
\end{aligned}
we define the RNN parameters:
\begin{aligned}
\mathbf{W}_x \in \mathbb{R}^{d_{\text{in}} \times d_{\text{hid}}}, \quad
\mathbf{W}_h \in \mathbb{R}^{d_{\text{hid}} \times d_{\text{hid}}}, \quad
\mathbf{b}_h \in \mathbb{R}^{d_{\text{hid}}}.
\end{aligned}

The hidden state is updated as:
\begin{aligned}
\mathbf{h}_t = \tanh\Bigl(\mathbf{x}_t\,\mathbf{W}_x \;+\;\mathbf{h}_{t-1}\,\mathbf{W}_h \;+\;\mathbf{b}_h\Bigr).
\end{aligned}

Over a sequence $(\mathbf{x}_1, \dots, \mathbf{x}_T)$, we start from a zero state and unroll:
\begin{aligned}
\mathbf{h}_0 = \mathbf{0},\quad
\mathbf{h}_1 = \tanh(\mathbf{x}_1 \mathbf{W}_x + \mathbf{h}_0 \mathbf{W}_h + \mathbf{b}_h),\,\dots,\,
\mathbf{h}_T = \tanh(\mathbf{x}_T \mathbf{W}_x + \mathbf{h}_{T-1} \mathbf{W}_h + \mathbf{b}_h).
\end{aligned}

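Optionally, each hidden state $\mathbf{h}_t$ can be projected to an output dimension (here $d_{\text{in}}$, because the copy task below reconstructs the input) using $\mathbf{W}_{\text{out}} \in \mathbb{R}^{d_{\text{hid}} \times d_{\text{in}}}$ and $\mathbf{b}_{\text{out}} \in \mathbb{R}^{d_{\text{in}}}$: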
\begin{aligned}
\mathbf{\hat{y}}_t = \mathbf{h}_t \mathbf{W}_{\text{out}} + \mathbf{b}_{\text{out}}
\end{aligned}
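
To make the shapes concrete, here is a minimal NumPy sketch of the forward pass above. It is an illustration independent of the three framework implementations later in the notebook; the function name `rnn_forward` and all sizes are hypothetical. It uses the same row-vector convention ($\mathbf{x}_t \mathbf{W}_x$), so batching only adds a leading axis.

```python
import numpy as np

def rnn_forward(X, W_x, W_h, b_h, W_out, b_out):
    """Unroll h_t = tanh(x_t W_x + h_{t-1} W_h + b_h) and project y_hat_t = h_t W_out + b_out.

    X: [batch, T, d_in]. Returns (Y_hat [batch, T, d_in], H [batch, T, d_hid]).
    """
    batch, T, _ = X.shape
    h = np.zeros((batch, W_h.shape[0]))                   # h_0 = 0
    hs, ys = [], []
    for t in range(T):
        h = np.tanh(X[:, t, :] @ W_x + h @ W_h + b_h)     # hidden-state update
        hs.append(h)
        ys.append(h @ W_out + b_out)                      # optional output projection
    return np.stack(ys, axis=1), np.stack(hs, axis=1)

# Hypothetical sizes, chosen only for illustration
rng = np.random.default_rng(0)
d_in, d_hid, T, batch = 3, 5, 4, 2
W_x = 0.1 * rng.standard_normal((d_in, d_hid))
W_h = 0.1 * rng.standard_normal((d_hid, d_hid))
b_h = np.zeros(d_hid)
W_out = 0.1 * rng.standard_normal((d_hid, d_in))
b_out = np.zeros(d_in)
X = rng.standard_normal((batch, T, d_in))

Y_hat, H = rnn_forward(X, W_x, W_h, b_h, W_out, b_out)
print(Y_hat.shape, H.shape)   # (2, 4, 3) (2, 4, 5)
```

Each $\mathbf{h}_t$ depends on every earlier input only through $\mathbf{h}_{t-1}$, which is why the loop cannot be parallelized over time.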


### Remaining Stages
We define a loss (Stage 2) over all time steps, for instance the mean squared error:
\begin{aligned}
\mathbf{L} = \frac{1}{T} \sum_{t=1}^T \left\|\,\mathbf{\hat{y}}_t - \mathbf{y}_t\,\right\|^2,
\end{aligned}
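where the code below also averages over the batch and feature dimensions. Gradients are computed with Backpropagation Through Time (BPTT) (Stage 3), and an optimizer (e.g., SGD or Adam) uses them to update the parameters (Stage 4). For plain gradient descent with learning rate $\eta$: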
\begin{aligned}
\theta \,\leftarrow\, \theta \;-\; \eta \,\nabla_\theta \,\mathbf{L}.
\end{aligned}
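
Stage 3 can be written out by hand for the vanilla RNN. The sketch below is a single-sequence NumPy version using the loss defined above; the helper `loss_and_grads` and all sizes are illustrative assumptions, not part of the tutorial's code. It runs the forward pass while caching every $\mathbf{h}_t$, then walks backward in time, accumulating each step's contribution to the shared weights. A central finite difference checks one gradient entry, and one plain gradient-descent step (Stage 4) is applied.

```python
import numpy as np

def loss_and_grads(params, X, Y):
    """Loss L = (1/T) * sum_t ||y_hat_t - y_t||^2 for one sequence, plus BPTT gradients."""
    W_x, W_h, b_h, W_out, b_out = params
    T = X.shape[0]
    hs = [np.zeros(W_h.shape[0])]            # hs[t] is h_t, starting from h_0 = 0
    ys = []
    for t in range(T):                        # forward pass, caching every hidden state
        hs.append(np.tanh(X[t] @ W_x + hs[-1] @ W_h + b_h))
        ys.append(hs[-1] @ W_out + b_out)
    loss = sum(np.sum((ys[t] - Y[t]) ** 2) for t in range(T)) / T

    grads = [np.zeros_like(p) for p in params]
    dW_x, dW_h, db_h, dW_out, db_out = grads  # same arrays as in grads, updated in place
    dh_next = np.zeros_like(hs[0])            # gradient arriving from step t+1
    for t in reversed(range(T)):              # backward pass through time
        dy = 2.0 * (ys[t] - Y[t]) / T         # dL/dy_hat_t
        dW_out += np.outer(hs[t + 1], dy)
        db_out += dy
        dh = dy @ W_out.T + dh_next           # output path + recurrent path
        da = dh * (1.0 - hs[t + 1] ** 2)      # back through tanh
        dW_x += np.outer(X[t], da)
        dW_h += np.outer(hs[t], da)           # hs[t] is h_{t-1} for this step
        db_h += da
        dh_next = da @ W_h.T                  # pass the gradient on to h_{t-1}
    return loss, grads

rng = np.random.default_rng(0)
d_in, d_hid, T = 3, 4, 5
shapes = [(d_in, d_hid), (d_hid, d_hid), (d_hid,), (d_hid, d_in), (d_in,)]
params = [0.5 * rng.standard_normal(s) for s in shapes]
X = rng.standard_normal((T, d_in))
Y = X.copy()                                  # copy-task target

loss, grads = loss_and_grads(params, X, Y)

# Central finite-difference check on one entry of W_h (params[1])
eps, i, j = 1e-6, 1, 2
params[1][i, j] += eps
loss_plus, _ = loss_and_grads(params, X, Y)
params[1][i, j] -= 2 * eps
loss_minus, _ = loss_and_grads(params, X, Y)
params[1][i, j] += eps                        # restore the original value
numeric = (loss_plus - loss_minus) / (2 * eps)
print(abs(numeric - grads[1][i, j]))         # expected to be tiny

# Stage 4: one plain gradient-descent step
eta = 0.01
new_params = [p - eta * g for p, g in zip(params, grads)]
print(loss, loss_and_grads(new_params, X, Y)[0])   # the loss should decrease slightly
```

The line `dh = dy @ W_out.T + dh_next` is where "through time" enters: the gradient at $\mathbf{h}_t$ collects both its own output error and everything later steps send back through $\mathbf{W}_h$. The frameworks below compute the same quantities automatically (`loss.backward()`, `tape.gradient`, `jax.grad`).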

---

## GRU

### Forward Pass (Inference)
A Gated Recurrent Unit (GRU) adds a reset gate $\mathbf{r}_t$ and an update gate $\mathbf{z}_t$ to the vanilla recurrence:

\begin{aligned}
\mathbf{z}_t &= \sigma\!\bigl(\mathbf{x}_t \mathbf{W}_z + \mathbf{h}_{t-1}\,\mathbf{U}_z + \mathbf{b}_z\bigr), \\
\mathbf{r}_t &= \sigma\!\bigl(\mathbf{x}_t \mathbf{W}_r + \mathbf{h}_{t-1}\,\mathbf{U}_r + \mathbf{b}_r\bigr), \\
\tilde{\mathbf{h}}_t &= \tanh\!\bigl(\mathbf{x}_t \mathbf{W}_h + (\mathbf{r}_t \odot \mathbf{h}_{t-1})\,\mathbf{U}_h + \mathbf{b}_h\bigr), \\
\mathbf{h}_t &= (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} \;+\; \mathbf{z}_t \odot \tilde{\mathbf{h}}_t.
\end{aligned}

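where $\sigma$ is the logistic sigmoid and $\odot$ denotes elementwise multiplication. With the row-vector convention used above, $\mathbf{W}_z, \mathbf{W}_r, \mathbf{W}_h \in \mathbb{R}^{d_{\text{in}} \times d_{\text{hid}}}$, $\mathbf{U}_z, \mathbf{U}_r, \mathbf{U}_h \in \mathbb{R}^{d_{\text{hid}} \times d_{\text{hid}}}$, and the biases lie in $\mathbb{R}^{d_{\text{hid}}}$. Note that the GRU's $\mathbf{W}_h$ multiplies the input $\mathbf{x}_t$, whereas the vanilla RNN's $\mathbf{W}_h$ is the hidden-to-hidden matrix.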
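
The four GRU equations translate almost line by line into code. Below is a minimal NumPy sketch of one step; the function name `gru_cell` and the dictionary keys are illustrative assumptions. It follows the convention written above. Library implementations may differ in detail: for example, PyTorch's `torch.nn.GRU` applies the reset gate after the hidden-state matrix product and swaps the roles of $\mathbf{z}_t$ and $1 - \mathbf{z}_t$, so weights are not directly interchangeable.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_cell(x_t, h_prev, p):
    """One GRU step following the equations above (row-vector convention, x_t @ W).

    p is a dict with hypothetical keys W_z, U_z, b_z, W_r, U_r, b_r, W_h, U_h, b_h.
    """
    z = sigmoid(x_t @ p["W_z"] + h_prev @ p["U_z"] + p["b_z"])              # update gate
    r = sigmoid(x_t @ p["W_r"] + h_prev @ p["U_r"] + p["b_r"])              # reset gate
    h_tilde = np.tanh(x_t @ p["W_h"] + (r * h_prev) @ p["U_h"] + p["b_h"])  # candidate state
    return (1.0 - z) * h_prev + z * h_tilde                                 # blend old and new

rng = np.random.default_rng(0)
d_in, d_hid = 3, 4
p = {}
for g in "zrh":
    p[f"W_{g}"] = 0.1 * rng.standard_normal((d_in, d_hid))
    p[f"U_{g}"] = 0.1 * rng.standard_normal((d_hid, d_hid))
    p[f"b_{g}"] = np.zeros(d_hid)

x_t = rng.standard_normal((2, d_in))          # batch of 2 inputs
h_prev = rng.standard_normal((2, d_hid))
h_t = gru_cell(x_t, h_prev, p)
print(h_t.shape)                              # (2, 4)

# With a very negative update-gate bias, z_t is close to 0 and the GRU copies h_{t-1}
p_copy = dict(p, b_z=np.full(d_hid, -50.0))
print(np.allclose(gru_cell(x_t, h_prev, p_copy), h_prev))   # True
```

The last check shows the gate's role: with $\mathbf{z}_t \approx \mathbf{0}$ the cell passes $\mathbf{h}_{t-1}$ through unchanged, which is how a GRU can hold information across many steps.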

### Remaining Stages
As in the vanilla RNN, define a loss $\mathbf{L}$ (e.g., MSE). The same BPTT logic applies, but the derivatives now also pass through the gating operations. All parameters ($\mathbf{W}_z, \mathbf{U}_z, \ldots$) can then be updated by any gradient-based optimizer.
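
To see how the gates change BPTT, compare the one-step Jacobians $\mathbf{J}_t = \partial \mathbf{h}_t / \partial \mathbf{h}_{t-1}$, whose entry $(i,j)$ is $\partial h_{t,i} / \partial h_{t-1,j}$. This short derivation is added for illustration and follows the row-vector equations above. For the vanilla RNN,
\begin{aligned}
\mathbf{J}_t^{\text{RNN}} = \operatorname{diag}\bigl(\mathbf{1} - \mathbf{h}_t \odot \mathbf{h}_t\bigr)\,\mathbf{W}_h^\top,
\qquad
\frac{\partial \mathbf{h}_T}{\partial \mathbf{h}_k} = \mathbf{J}_T \mathbf{J}_{T-1} \cdots \mathbf{J}_{k+1},
\end{aligned}
so a gradient sent back $T-k$ steps is multiplied by $\mathbf{W}_h^\top$ (and by tanh derivatives of at most 1) $T-k$ times, which tends to make it vanish or explode. For the GRU, differentiating $\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t$ gives
\begin{aligned}
\mathbf{J}_t^{\text{GRU}} &= \operatorname{diag}(\mathbf{1} - \mathbf{z}_t)
+ \operatorname{diag}(\tilde{\mathbf{h}}_t - \mathbf{h}_{t-1})\,\frac{\partial \mathbf{z}_t}{\partial \mathbf{h}_{t-1}}
+ \operatorname{diag}(\mathbf{z}_t)\,\frac{\partial \tilde{\mathbf{h}}_t}{\partial \mathbf{h}_{t-1}}, \\
\frac{\partial \mathbf{z}_t}{\partial \mathbf{h}_{t-1}} &= \operatorname{diag}\bigl(\mathbf{z}_t \odot (\mathbf{1} - \mathbf{z}_t)\bigr)\,\mathbf{U}_z^\top.
\end{aligned}
When the update gate is nearly closed ($\mathbf{z}_t \approx \mathbf{0}$), the last two terms are close to zero and $\mathbf{J}_t^{\text{GRU}} \approx \mathbf{I}$, so gradients can pass through many steps largely unchanged. This additive path is the usual explanation for why gated units tend to train more easily on long sequences.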

---

## Optimizer
A typical training loop includes:

1. **Forward pass**: compute model outputs $\mathbf{\hat{y}}_t$.
2. **Loss computation**: $\mathbf{L}(\mathbf{\hat{y}}_t, \mathbf{y}_t)$.
3. **Backward pass**: compute $\nabla_\theta \mathbf{L}$ via BPTT.
4. **Parameter update**:
   \begin{aligned}
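   \theta \leftarrow \theta - \eta \;\nabla_\theta \,\mathbf{L}.
   \end{aligned}
   This is plain gradient descent (stochastic gradient descent, SGD, when $\nabla_\theta \mathbf{L}$ is estimated on a mini-batch); optimizers such as Adam and RMSProp instead scale each parameter's step adaptively using running statistics of past gradients.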
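
Step 4 above shows the plain gradient-descent rule. To illustrate how an adaptive optimizer differs, here is a minimal NumPy sketch of the standard Adam update (bias-corrected running averages of the gradient and squared gradient, with the usual defaults $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$). `adam_step` is an illustrative helper, not a framework API, and the toy quadratic loss is a hypothetical stand-in for an RNN, used only to show the four stages of the loop.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step count used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias-corrected estimates
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Toy problem: minimise L(theta) = ||theta - c||^2, whose gradient is 2 (theta - c)
c = np.array([1.0, -2.0, 3.0])
theta = np.zeros(3)
m, v = np.zeros(3), np.zeros(3)
for t in range(1, 2001):
    pred = theta                                # 1. forward pass (trivial here)
    loss = np.sum((pred - c) ** 2)              # 2. loss computation
    grad = 2 * (pred - c)                       # 3. backward pass (analytic gradient)
    theta, m, v = adam_step(theta, grad, m, v, t, lr=0.05)   # 4. parameter update
print(theta)   # close to c
```

Because the step is divided by $\sqrt{\hat{v}}$, Adam's first updates have magnitude close to `lr` regardless of the gradient's scale, which is one reason it is a common default for RNNs.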

---

## The Copy Task
The **copy task** is a simple sequence-to-sequence challenge:

- **Input**: a sequence of random vectors $(\mathbf{x}_1, \dots, \mathbf{x}_T)$.
- **Target**: the **same** sequence $(\mathbf{x}_1, \dots, \mathbf{x}_T)$.

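Thus, the model should learn to produce $\mathbf{\hat{y}}_t \approx \mathbf{x}_t$ at each time step $t$. It is a simple but revealing test of whether a model can retain and reproduce a sequence. Note that in this synchronous version the target $\mathbf{x}_t$ is also the current input, so the hidden state only has to encode the most recent input; the task becomes a genuine test of **memory** when the targets are delayed, i.e. the model must output $\mathbf{x}_{t-k}$ at step $t$ for some delay $k > 0$.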
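
The data for this task is easy to generate. The tutorial's `run_benchmark` draws `np.random.rand(1000, seq_length, input_size)` and sets `Y_train = X_train.copy()`. The hypothetical helper below reproduces that case with `delay=0` and, as an assumption for illustration, also supports a delayed variant in which the target at step $t$ is $\mathbf{x}_{t-\text{delay}}$ (earlier targets padded with zeros), which forces the network to carry each input forward in its hidden state.

```python
import numpy as np

def make_copy_data(n, T, d, delay=0, seed=0):
    """Hypothetical helper: random copy-task data of shape [n, T, d].

    delay=0 gives targets equal to inputs (the setup used in this tutorial).
    delay>0 asks for x_{t-delay} at step t; the first `delay` targets are zeros
    (one possible padding convention, assumed here).
    """
    assert 0 <= delay < T
    rng = np.random.default_rng(seed)
    X = rng.random((n, T, d)).astype(np.float32)   # uniform in [0, 1), like np.random.rand
    Y = np.zeros_like(X)
    Y[:, delay:, :] = X[:, :T - delay, :]          # shift inputs forward by `delay` steps
    return X, Y

X, Y = make_copy_data(n=4, T=6, d=3)
print(np.array_equal(X, Y))                        # True: synchronous copy
Xd, Yd = make_copy_data(n=4, T=6, d=3, delay=2)
print(np.array_equal(Yd[:, 2:], Xd[:, :4]))        # True: target at t is the input at t-2
```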


In [29]:
import torch
import tensorflow as tf
import jax
import jax.numpy as jnp
from jax import random, lax
import time
import numpy as np
from functools import partial

########################################
# Custom RNN Cell (Core Computation)
########################################

# ------- PyTorch Single-Step RNN Cell -------
class RNNCellPyTorch(torch.nn.Module):
    """
    A single-step RNN cell in PyTorch.
    h_t = tanh(x_t @ W_x + h_{t-1} @ W_h + b_h)
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # For a single step:  input: (batch_size, input_size)
        #                    hidden: (batch_size, hidden_size)
        self.W_x = torch.nn.Parameter(torch.randn(input_size, hidden_size) * 0.1)
        self.W_h = torch.nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.1)
        self.b_h = torch.nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x_t, h_prev):
        # x_t: [batch_size, input_size]
        # h_prev: [batch_size, hidden_size]
        h_t = torch.tanh(x_t @ self.W_x + h_prev @ self.W_h + self.b_h)
        return h_t

# ------- Higher-level PyTorch RNN that unrolls over time -------
class RNNPyTorch(torch.nn.Module):
    """
    Unrolls the RNNCell over a full sequence.
    Also includes an output projection from hidden_size -> input_size
    so we can do an MSE loss vs. the original input.
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn_cell = RNNCellPyTorch(input_size, hidden_size)
        # Output projection to match the original input dimension for copy task
        self.W_out = torch.nn.Parameter(torch.randn(hidden_size, input_size) * 0.1)
        self.b_out = torch.nn.Parameter(torch.zeros(input_size))

    def forward(self, X):
        # X: [batch_size, seq_length, input_size]
        batch_size, seq_length, _ = X.shape
        h = torch.zeros(batch_size, self.hidden_size, device=X.device)
        outputs = []
        for t in range(seq_length):
            x_t = X[:, t, :]  # [batch_size, input_size]
            h = self.rnn_cell(x_t, h)  # [batch_size, hidden_size]
            # Project hidden -> input_size
            out_t = h @ self.W_out + self.b_out
            outputs.append(out_t.unsqueeze(1))  # shape [batch_size,1,input_size]
        # Concatenate across time
        return torch.cat(outputs, dim=1)  # [batch_size, seq_length, input_size]

########################################
# TensorFlow Implementation
########################################

# ------- Single-Step RNN Cell -------
class RNNCellTF(tf.keras.layers.Layer):
    """
    A single-step RNN cell in TensorFlow.
    h_t = tanh(x_t @ W_x + h_{t-1} @ W_h + b_h)
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.W_x = self.add_weight(
            shape=(input_size, hidden_size), initializer="random_normal", trainable=True
        )
        self.W_h = self.add_weight(
            shape=(hidden_size, hidden_size), initializer="random_normal", trainable=True
        )
        self.b_h = self.add_weight(
            shape=(hidden_size,), initializer="zeros", trainable=True
        )

    def call(self, x_t, h_prev):
        h_t = tf.math.tanh(
            tf.matmul(x_t, self.W_x) + tf.matmul(h_prev, self.W_h) + self.b_h
        )
        return h_t

# ------- Higher-level TF RNN that unrolls over time -------
class RNNTF(tf.keras.layers.Layer):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn_cell = RNNCellTF(input_size, hidden_size)
        # Output projection
        self.W_out = self.add_weight(
            shape=(hidden_size, input_size), initializer="random_normal", trainable=True
        )
        self.b_out = self.add_weight(
            shape=(input_size,), initializer="zeros", trainable=True
        )

    def call(self, X):
        # X: [batch_size, seq_length, input_size]
        batch_size = tf.shape(X)[0]
        seq_length = tf.shape(X)[1]
        h = tf.zeros((batch_size, self.hidden_size), dtype=X.dtype)
        outputs = []
        for t in range(seq_length):
            x_t = X[:, t, :]
            h = self.rnn_cell(x_t, h)
            out_t = tf.matmul(h, self.W_out) + self.b_out
            outputs.append(tf.expand_dims(out_t, axis=1))
        return tf.concat(outputs, axis=1)  # [batch_size, seq_length, input_size]



########################################
# Training / Benchmark
########################################

# -------------- PyTorch Benchmark --------------
def benchmark_pytorch(input_size, hidden_size, X_train, Y_train, epochs=10, lr=0.01):
    model = RNNPyTorch(input_size, hidden_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()
    X_torch = torch.tensor(X_train, dtype=torch.float32)
    Y_torch = torch.tensor(Y_train, dtype=torch.float32)

    # Start timing after data conversion, as in the TensorFlow and JAX benchmarks
    start_time = time.time()

    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(X_torch)  # [batch_size, seq_length, input_size]
        loss = criterion(output, Y_torch)
        loss.backward()
        optimizer.step()
        #print(f"Epoch {epoch} | Loss torch: {loss.item():.6f}")

    return time.time() - start_time

# -------------- TensorFlow Benchmark --------------
def benchmark_tensorflow(input_size, hidden_size, X_train, Y_train, epochs=10, lr=0.01):
    model = RNNTF(input_size, hidden_size)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    loss_fn = tf.keras.losses.MeanSquaredError()

    X_tf = tf.convert_to_tensor(X_train, dtype=tf.float32)
    Y_tf = tf.convert_to_tensor(Y_train, dtype=tf.float32)

    start_time = time.time()
    for epoch in range(epochs):
        with tf.GradientTape() as tape:
            output = model(X_tf)
            loss = loss_fn(Y_tf, output)  # Keras losses take (y_true, y_pred)
        grads = tape.gradient(loss, model.trainable_variables)
        #print(f"Epoch {epoch} | Loss TF: {loss.numpy():.6f}")
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return time.time() - start_time


########################################
# JAX Implementation (fast version: jit-compiled training step, lax.scan over time)
########################################

def RNNCellJAX(params, x_t, h_prev):
    W_x, W_h, b_h = params
    h_t = jnp.tanh(jnp.dot(x_t, W_x) + jnp.dot(h_prev, W_h) + b_h)
    return h_t

def RNNJAX_unroll(params, X):
    W_x, W_h, b_h, W_out, b_out = params
    batch_size, seq_length, _ = X.shape
    X_t = jnp.swapaxes(X, 0, 1)  # [seq_length, batch_size, input_size]

    def step_fn(h_prev, x_t):
        # Reuse the single-step cell defined above, then project back to input_size
        h_t = RNNCellJAX((W_x, W_h, b_h), x_t, h_prev)
        out_t = jnp.dot(h_t, W_out) + b_out
        return h_t, out_t

    h0 = jnp.zeros((batch_size, W_h.shape[0]))
    final_h, outs = lax.scan(step_fn, h0, X_t)
    # outs: [seq_length, batch_size, input_size]
    outs = jnp.swapaxes(outs, 0, 1)  # [batch_size, seq_length, input_size]
    return outs

########################################
# Training / Benchmark
########################################

def init_jax_params(key, input_size, hidden_size):
    k1, k2, k3, k4, k5 = jax.random.split(key, 5)
    W_x = 0.1 * jax.random.normal(k1, (input_size, hidden_size))
    W_h = 0.1 * jax.random.normal(k2, (hidden_size, hidden_size))
    b_h = jnp.zeros((hidden_size,))
    W_out = 0.1 * jax.random.normal(k3, (hidden_size,  input_size))
    b_out = jnp.zeros((input_size,))
    return (W_x, W_h, b_h, W_out, b_out)

def loss_fn(params, x, y):
    pred = RNNJAX_unroll(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, lr):
    # One plain gradient-descent update (note: the PyTorch and TF benchmarks use Adam)
    grads = jax.grad(loss_fn)(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))


def benchmark_jax(input_size, hidden_size, X_train, Y_train, epochs=10, lr=0.01):
    key = random.PRNGKey(42)
    params = init_jax_params(key, input_size, hidden_size)

    X_jax = jnp.array(X_train)
    Y_jax = jnp.array(Y_train)

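    # Warm-up call: triggers JIT compilation so it is excluded from the timing
    jax.block_until_ready(train_step(params, X_jax, Y_jax, lr))

    start_time = time.time()
    p = params
    for epoch in range(epochs):
        p = train_step(p, X_jax, Y_jax, lr)
        #current_loss = loss_fn(p, X_jax, Y_jax)
        #print(f"Epoch {epoch} | Loss jax: {current_loss:.6f}")

    # JAX dispatches work asynchronously; wait for the results before stopping the clock
    jax.block_until_ready(p)
    total_time = time.time() - start_time

    return total_time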

############################
# Main Run
############################
def run_benchmark():
    seq_length = 20
    batch_size = 32  # not used: every step trains on all 1000 sequences at once (full batch)
    input_size = 10
    hidden_size = 128
    num_epochs = 10

    np.random.seed(42)
    X_train = np.random.rand(1000, seq_length, input_size).astype(np.float32)
    Y_train = X_train.copy()

    # PyTorch
    pytorch_time = benchmark_pytorch(input_size, hidden_size, X_train, Y_train, num_epochs)

    # TensorFlow
    tensorflow_time = benchmark_tensorflow(input_size, hidden_size, X_train, Y_train, num_epochs)

    # JAX
    jax_time = benchmark_jax(input_size, hidden_size, X_train, Y_train, num_epochs)

    print(f"PyTorch Time: {pytorch_time:.4f} s")
    print(f"TensorFlow Time: {tensorflow_time:.4f} s")
    print(f"JAX Time: {jax_time:.4f} s")






## Things to Learn

########################################
# JAX Implementation: slow version. It also works; compare it with the fast version above to learn how to speed things up in JAX (this one calls jax.grad without jax.jit and has no warm-up call)
########################################
# def RNNCellJAX(params, x_t, h_prev):
#     """
#     Single-step RNN cell in JAX.
#     params = (W_x, W_h, b_h)
#     h_t = tanh( x_t*W_x + h_prev*W_h + b )
#     """
#     W_x, W_h, b_h = params
#     h_t = jnp.tanh(jnp.dot(x_t, W_x) + jnp.dot(h_prev, W_h) + b_h)
#     return h_t

# def RNNJAX_unroll(params, X):
#     """
#     Unroll RNNCellJAX across time.
#     params_main = (W_x, W_h, b_h, W_out, b_out)
#     X: [batch_size, seq_length, input_size]
#     We'll swap so we scan over seq_length dimension.
#     """
#     W_x, W_h, b_h, W_out, b_out = params
#     batch_size, seq_length, _ = X.shape

#     # Swap to [seq_length, batch_size, input_size]
#     X_t = jnp.swapaxes(X, 0, 1)

#     def step_fn(h_prev, x_t):
#         h_t = jnp.tanh(jnp.dot(x_t, W_x) + jnp.dot(h_prev, W_h) + b_h)
#         # Output projection back to input_size
#         out_t = jnp.dot(h_t, W_out) + b_out
#         return h_t, out_t

#     h0 = jnp.zeros((batch_size, W_h.shape[0]))
#     final_h, outs = lax.scan(step_fn, h0, X_t)
#     # outs: [seq_length, batch_size, input_size]
#     # we want [batch_size, seq_length, input_size], so swap axes
#     outs = jnp.swapaxes(outs, 0, 1)
#     return outs


# -------------- JAX Benchmark -- slow version--------------
# def init_jax_params(key, input_size, hidden_size):
#     # W_x: [input_size, hidden_size]
#     # W_h: [hidden_size, hidden_size]
#     # b_h: [hidden_size]
#     # W_out: [hidden_size, input_size]
#     # b_out: [input_size]

#     k1, k2, k3, k4, k5 = jax.random.split(key, 5)
#     W_x = 0.1 * jax.random.normal(k1, (input_size, hidden_size))
#     W_h = 0.1 * jax.random.normal(k2, (hidden_size, hidden_size))
#     b_h = jnp.zeros((hidden_size,))
#     W_out = 0.1 * jax.random.normal(k3, (hidden_size,  input_size))
#     b_out = jnp.zeros((input_size,))
#     return (W_x, W_h, b_h, W_out, b_out)

# def benchmark_jax(input_size, hidden_size, X_train, Y_train, epochs=10, lr=0.01):
#     key = random.PRNGKey(42)
#     params = init_jax_params(key, input_size, hidden_size)

#     def loss_fn(p, x, y):
#         pred = RNNJAX_unroll(p, x)
#         return jnp.mean((pred - y) ** 2)

#     grad_fn = jax.grad(loss_fn)
#     X_jax = jnp.array(X_train)
#     Y_jax = jnp.array(Y_train)

#     start_time = time.time()
#     p = params
#     for epoch in range(epochs):
#         grads = grad_fn(p, X_jax, Y_jax)
#         p = [param - lr*g for param, g in zip(p, grads)]

#     return time.time() - start_time
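
The fast JAX version replaces a Python `for` loop over time with `jax.lax.scan`. The JAX documentation explains `scan` through an equivalent pure-Python loop; the NumPy sketch below mirrors that description (`scan_reference` is an illustrative name, not a JAX function) and applies it to the same step logic as `RNNJAX_unroll`, so the time-major shapes can be checked without JAX installed.

```python
import numpy as np

def scan_reference(f, init, xs):
    """Pure-Python loop with the same contract as jax.lax.scan(f, init, xs):
    f(carry, x) -> (new_carry, y); returns (final_carry, stacked ys)."""
    carry = init
    ys = []
    for x in xs:                     # iterates over the leading axis of xs
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

# Same step function as RNNJAX_unroll, written with NumPy and hypothetical sizes
rng = np.random.default_rng(0)
B, T, d_in, d_hid = 2, 5, 3, 4
W_x = 0.1 * rng.standard_normal((d_in, d_hid))
W_h = 0.1 * rng.standard_normal((d_hid, d_hid))
b_h = np.zeros(d_hid)
W_out = 0.1 * rng.standard_normal((d_hid, d_in))
b_out = np.zeros(d_in)

def step_fn(h_prev, x_t):
    h_t = np.tanh(x_t @ W_x + h_prev @ W_h + b_h)
    return h_t, h_t @ W_out + b_out

X = rng.random((B, T, d_in))
final_h, outs = scan_reference(step_fn, np.zeros((B, d_hid)), np.swapaxes(X, 0, 1))
print(outs.shape, final_h.shape)   # (5, 2, 3) (2, 4): time-major, as in RNNJAX_unroll
```

Under `jax.jit`, a Python loop over time is unrolled into one copy of the step per time step, so the compiled program grows with $T$; `lax.scan` traces `step_fn` once and compiles a single loop, keeping the program size independent of sequence length.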


In [30]:
run_benchmark()

PyTorch Time: 0.2036 s
TensorFlow Time: 1.1847 s
JAX Time: 0.0039 s
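
Wall-clock comparisons like this one are sensitive to how each framework is timed: JAX dispatches work asynchronously, the first call of a jitted function includes compilation, and in this benchmark the frameworks also differ in optimizer (Adam for PyTorch and TensorFlow, plain gradient descent for JAX) and in compilation (TensorFlow runs eagerly here, without `tf.function`). A small, framework-agnostic timing helper can make these choices explicit. The sketch below is a hypothetical helper (`time_fn` is not part of any library) that separates warm-up calls from timed calls and lets the caller supply a synchronization function.

```python
import time
import numpy as np

def time_fn(fn, *args, warmup=1, repeats=10, sync=lambda out: out):
    """Hypothetical helper: average wall-clock seconds per call of fn(*args).

    The `warmup` calls are not timed (for a jitted function the first call
    includes compilation). `sync` is applied to every result before moving on;
    with JAX one would pass jax.block_until_ready so that asynchronous
    dispatch is not mistaken for finished computation.
    """
    for _ in range(warmup):
        sync(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        sync(fn(*args))
    return (time.perf_counter() - start) / repeats

# Usage with NumPy, which runs synchronously, so the default no-op sync is enough
A = np.random.default_rng(0).standard_normal((256, 256))
print(f"{time_fn(np.matmul, A, A):.6f} s per call")
```

With JAX, a call such as `time_fn(train_step, params, X_jax, Y_jax, lr, sync=jax.block_until_ready)` would time the jitted step from the benchmark above with compilation excluded; for PyTorch on a GPU, a `sync` that calls `torch.cuda.synchronize()` plays the same role.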
