# Recurrent Neural Networks

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
DEVICE

'mps'

## 1. Sequential Data and RNN

### 1.1 Motivation

Sequential data, such as text, audio, or time series, has an inherent **temporal order**: each element may depend on the elements that came before it. Traditional models like **MLPs** and **CNNs** either consume a fixed-size input all at once or look only at fixed-size local windows. As a result, they struggle to handle sequences of varying length and to capture **long-term dependencies** and **contextual relationships** between distant elements. **Recurrent Neural Networks (RNNs)** were introduced to address this. They process a sequence one step at a time while maintaining a **hidden state** that carries information from past time steps, which allows them to learn temporal patterns and dependencies over time.


### 1.2 RNN Cell mechanism

An RNN stores a hidden state $\mathbf{h}_t \in \mathbb{R}^d$. Its value at time $t$ depends on the hidden state at the previous time step and on the current input $\mathbf{x}_t$. That is,
$$
\mathbf{h}_t = f(\mathbf{h}_{t-1}, \mathbf{x}_t).
$$
The output at time $t$ is then computed from the hidden state
$$
\mathbf{y}_t = g(\mathbf{h}_t).
$$
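In the simplest (Elman) RNN, $f$ and $g$ are affine maps followed by a nonlinearity. They are parameterized by three weight matrices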
$$
\begin{aligned}
	\mathbf{W}_{hh}:  & \text{ Update the hidden state using the previous hidden state.}\\
	\mathbf{W}_{hx}:  & \text{ Update the hidden state using the input.}\\
	\mathbf{W}_{yh}:  & \text{ Calculate output using the hidden state.} \\
\end{aligned}
$$
Including the biases, and choosing $\tanh$ as the hidden activation and the sigmoid $\sigma$ as the output activation (other choices are possible), we can write
$$
\begin{aligned}
	\mathbf{h}_t &= \tanh(
		\mathbf{W}_{hh} \mathbf{h}_{t-1} +
		\mathbf{W}_{hx} \mathbf{x}_t +
		\mathbf{b}_h
	), \\
	\mathbf{y}_t &= \sigma(
		\mathbf{W}_{yh} \mathbf{h}_t +
		\mathbf{b}_y
	).
\end{aligned}
$$
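In the implementation below, each `nn.Linear` layer already includes a bias term, so we don't add the biases explicitly. Both `Whh` and `Whx` carry a bias, so the hidden update actually has two bias vectors. Their sum plays the role of $\mathbf{b}_h$, which makes this equivalent to a single bias. The output layer `Wyh` is also left linear (no $\sigma$). The output activation is usually applied afterwards or folded into the loss function; for example, `nn.CrossEntropyLoss` expects raw logits.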
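
To make the update rule concrete, here is a minimal sketch that computes one step of the hidden-state equation with explicit tensors. It then compares the result against PyTorch's built-in `torch.nn.RNNCell`, which serves here only as a reference and is not the cell implemented in this notebook. The sketch assumes batches are stored as rows, which is why the weight matrices are transposed. The sizes and seed are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_dim, hidden_dim, batch_size = 8, 16, 4

# PyTorch's built-in cell (tanh nonlinearity by default), used only as a reference.
ref = nn.RNNCell(input_dim, hidden_dim)

x_t = torch.randn(batch_size, input_dim)
h_prev = torch.randn(batch_size, hidden_dim)

with torch.no_grad():
    W_hx, W_hh = ref.weight_ih, ref.weight_hh  # (hidden_dim, input_dim), (hidden_dim, hidden_dim)
    b_h = ref.bias_ih + ref.bias_hh            # two stored bias vectors act as a single b_h
    # h_t = tanh(W_hh h_{t-1} + W_hx x_t + b_h); batches are rows, hence the transposes
    h_manual = torch.tanh(h_prev @ W_hh.T + x_t @ W_hx.T + b_h)
    h_ref = ref(x_t, h_prev)

print(h_manual.shape, torch.allclose(h_manual, h_ref, atol=1e-6))
```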

### 1.3 Implementation

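We implement the cell and the model separately. The cell performs a single time step: given the current input and the previous hidden state, it returns the updated hidden state. The model then applies the cell at every time step of the sequence and maps each hidden state to an output.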

In [22]:
class RNNCell(nn.Module):
	"""
	RNN Cell update rule.
	"""
	def __init__(self, input_dim, hidden_dim):
		"""
		Inputs:
		- input_dim:  dimension of the input
		- hidden_dim: dimension of the hidden state
		"""
		super().__init__()
		self.d_input = input_dim
		self.d_hidden = hidden_dim

		# Initialize the matrices
		self.Whh = nn.Linear(hidden_dim, hidden_dim)
		self.Whx = nn.Linear(input_dim,  hidden_dim)

	def forward(self, x_t, h_prev):
		"""
		Input:
		- x_t: input at the current time step        (batch_size, input_dim)
		- h_prev: hidden state at previous time step (batch_size, hidden_dim)
	
		Returns the updated hidden state (batch_size, hidden_dim).
		"""
		return torch.tanh(self.Whh(h_prev) + self.Whx(x_t))


In [None]:
class RNN(nn.Module):
	"""
	RNN model
	"""
	def __init__(self, input_dim, output_dim, hidden_dim):
		"""
		Inputs:
		- input_dim:  dimension of the input
		- output_dim: dimension of the output
		- hidden_dim: dimension of the hidden state
		"""
		super().__init__()
		self.d_input = input_dim
		self.d_output = output_dim
		self.d_hidden = hidden_dim

		self.RNNCell = RNNCell(input_dim, hidden_dim)
		self.Wyh = nn.Linear(hidden_dim, output_dim)

	def forward(self, x, h_0=None):
		"""
		Inputs:
		- x: input (batch_size, T, input_dim)
		- h_0 (optional): initial hidden state (batch_size, hidden_dim); defaults to zeros

		Output:
		- y: outputs at every time step (batch_size, T, output_dim)
		- h_T: final hidden state (batch_size, hidden_dim)
		"""
		# shapes and device
		batch_size, T, _ = x.shape
		device = x.device

		# initializing h_0
		if h_0 is None:
			h_t = torch.zeros(batch_size, self.d_hidden, device=device)
		else:
			h_t = h_0

		y_out = []
		for t in range(T):
			x_t = x[:, t, :]
			# update using RNNCell
			h_t = self.RNNCell(x_t, h_t)
			y_t = self.Wyh(h_t)
			y_out.append(y_t)

		return torch.stack(y_out, dim=1), h_t

We can run a quick test to check that the shapes work out:

In [32]:
# Quick test

# Hyperparameters
batch_size = 4
seq_len = 10
input_dim = 8
hidden_dim = 16
output_dim = 3

# Create random input on your global device
x = torch.randn(batch_size, seq_len, input_dim, device=DEVICE)

# Initialize and move model to the same device
model = RNN(input_dim, output_dim, hidden_dim).to(DEVICE)

# Forward pass
y_out, h_T = model(x)

# Print results
print(f"Input shape:     {x.shape}, device: {x.device}")
print(f"Output shape:    {y_out.shape}, device: {y_out.device}")
print(f"Hidden state shape: {h_T.shape}, device: {h_T.device}")

Input shape:     torch.Size([4, 10, 8]), device: mps:0
Output shape:    torch.Size([4, 10, 3]), device: mps:0
Hidden state shape: torch.Size([4, 16]), device: mps:0
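
The optional `h_0` argument lets the model resume from a previous hidden state, so a long sequence can be processed in chunks. The sketch below checks this on CPU. So that it runs on its own, it uses `TinyRNN`, a hypothetical condensed version of the `RNN` class above with the same update rule. The sequence length, chunk split, and seed are arbitrary choices.

```python
import torch
import torch.nn as nn


class TinyRNN(nn.Module):
    """Hypothetical condensed copy of the RNN above: same update rule, same outputs."""

    def __init__(self, input_dim, output_dim, hidden_dim):
        super().__init__()
        self.d_hidden = hidden_dim
        self.Whh = nn.Linear(hidden_dim, hidden_dim)
        self.Whx = nn.Linear(input_dim, hidden_dim)
        self.Wyh = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, h_0=None):
        batch_size, T, _ = x.shape
        if h_0 is None:
            h_t = torch.zeros(batch_size, self.d_hidden, device=x.device)
        else:
            h_t = h_0
        y_out = []
        for t in range(T):
            h_t = torch.tanh(self.Whh(h_t) + self.Whx(x[:, t, :]))
            y_out.append(self.Wyh(h_t))
        return torch.stack(y_out, dim=1), h_t


torch.manual_seed(0)
model = TinyRNN(input_dim=8, output_dim=3, hidden_dim=16)
x = torch.randn(4, 10, 8)

with torch.no_grad():
    y_full, h_full = model(x)             # all 10 steps at once
    y_a, h_a = model(x[:, :6])            # first 6 steps
    y_b, h_b = model(x[:, 6:], h_0=h_a)   # last 4 steps, resuming from h_a
y_chunked = torch.cat([y_a, y_b], dim=1)

print(torch.allclose(y_full, y_chunked), torch.allclose(h_full, h_b))
```

Carrying the state across chunks like this is also how truncated backpropagation through time is usually set up. In that setting, the carried state is typically detached (`h_a.detach()`) so that gradients stop at the chunk boundary.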


## 2. Gated Recurrent Networks

RNNs struggle to “remember” information when dependencies in a sequence are separated by many time steps. During training, the gradient that carries the learning signal from step $T$ back to an earlier step $k$ is multiplied by the Jacobian $\partial \mathbf{h}_t / \partial \mathbf{h}_{t-1}$ once for every step in between. Its norm therefore tends to shrink or grow exponentially with $T - k$; in other words, the gradients **vanish or explode**. This makes it difficult for the network to learn weights related to long-term dependencies, and standard RNNs tend to forget earlier context as sequences grow longer. **Gated recurrent networks**, such as **LSTMs** and **GRUs**, address this issue by introducing **gates** that regulate how information is stored, forgotten, and passed forward through time. The gates let the network preserve important information over long distances. This mitigates the vanishing gradient problem and enables the model to learn long-term temporal patterns. Exploding gradients are usually handled separately, for example by gradient clipping.
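
A small experiment can make the vanishing-gradient claim concrete. The sketch below rests on a few assumptions: a vanilla RNN update with PyTorch's default `nn.Linear` initialization, hidden size 16, random Gaussian inputs, and a scalar loss $L$ equal to the sum of the final hidden state. It measures $\|\partial L / \partial \mathbf{h}_0\|$, i.e. how strongly the final state still depends on the initial one, as the sequence length $T$ grows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_dim, hidden_dim = 8, 16

# Vanilla RNN update with PyTorch's default nn.Linear initialization (assumed setup).
Whh = nn.Linear(hidden_dim, hidden_dim)
Whx = nn.Linear(input_dim, hidden_dim)


def grad_norm_wrt_h0(T):
    """Return ||dL/dh_0|| for L = sum(h_T) after T steps on a random input sequence."""
    h_0 = torch.zeros(1, hidden_dim, requires_grad=True)
    x = torch.randn(1, T, input_dim)
    h_t = h_0
    for t in range(T):
        h_t = torch.tanh(Whh(h_t) + Whx(x[:, t, :]))
    h_t.sum().backward()  # the gradient flows back through all T steps
    return h_0.grad.norm().item()


norms = {T: grad_norm_wrt_h0(T) for T in (1, 5, 20, 50)}
for T, n in norms.items():
    print(f"T = {T:2d}: ||dL/dh_0|| = {n:.2e}")
```

With this initialization the norm typically drops by many orders of magnitude as $T$ grows. The exact numbers depend on the seed, the weights, and the inputs.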

### 2.1 Long Short-Term Memory (LSTM)

#### 2.1.1 LSTM Cell mechanism

Like an RNN cell, an LSTM cell has a hidden state $\mathbf{h}_t \in \mathbb{R}^d$. In addition, it has a cell state $\mathbf{C}_t \in \mathbb{R}^d$, which acts as a long-term memory that runs through the entire sequence and is modified only by elementwise scaling and addition. This lets information flow across many time steps **without being repeatedly multiplied by weight matrices**, which mitigates the **vanishing gradient problem** that plagues standard RNNs. An LSTM has three gates, each a vector of values in $(0, 1)$. Two of them control the update of the cell state, and one controls how the cell state is exposed as the hidden state. They are
- **forget gate** ($f_t$): how much of the previous cell state $\mathbf{C}_{t-1}$ to keep (values near 0 *forget* it)
- **input gate** ($i_t$): how much of the new candidate value $\tilde{\mathbf{C}}_t$ to *input* into the cell state
- **output gate** ($o_t$): how much of the updated cell state to *output* as the hidden state $\mathbf{h}_t$

The update rules are as follows, where $\sigma$ is the logistic sigmoid and $\odot$ denotes elementwise multiplication:
$$
\begin{aligned}
	f_t &= \sigma(
		\mathbf{W}_{fx} \mathbf{x}_t +
		\mathbf{W}_{fh} \mathbf{h}_{t-1} +
		\mathbf{b}_f
	) \\
	i_t &= \sigma(
		\mathbf{W}_{ix} \mathbf{x}_t +
		\mathbf{W}_{ih} \mathbf{h}_{t-1} +
		\mathbf{b}_i
	) \\
	o_t &= \sigma(
		\mathbf{W}_{ox} \mathbf{x}_t +
		\mathbf{W}_{oh} \mathbf{h}_{t-1} +
		\mathbf{b}_o
	) \\
	\tilde{\mathbf{C}}_t &= \tanh(
		\mathbf{W}_{cx} \mathbf{x}_t +
		\mathbf{W}_{ch} \mathbf{h}_{t-1} +
		\mathbf{b}_c
	) \\
	\mathbf{C}_t &= f_t \odot \mathbf{C}_{t-1} + i_t \odot \tilde{\mathbf{C}}_t \\
	\mathbf{h}_t &= o_t \odot \tanh(\mathbf{C}_t)
\end{aligned}
$$
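
The claim that the cell state avoids repeated multiplication by weights can be read off the update rule for $\mathbf{C}_t$. Consider only the direct path through $\mathbf{C}_{t-1}$, treating $f_t$, $i_t$, and $\tilde{\mathbf{C}}_t$ as constants (a simplification, since they depend on $\mathbf{h}_{t-1}$). The Jacobian is then diagonal, and the products below are taken elementwise:
$$
\begin{aligned}
	\frac{\partial \mathbf{C}_t}{\partial \mathbf{C}_{t-1}} &= \operatorname{diag}(f_t),
	&
	\frac{\partial \mathbf{C}_T}{\partial \mathbf{C}_k} &= \prod_{t=k+1}^{T} \operatorname{diag}(f_t) = \operatorname{diag}\Big(\prod_{t=k+1}^{T} f_t\Big).
\end{aligned}
$$
For comparison, the vanilla RNN update $\mathbf{h}_t = \tanh(\mathbf{W}_{hh} \mathbf{h}_{t-1} + \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{b}_h)$ gives
$$
\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-1}} = \operatorname{diag}\big(1 - \mathbf{h}_t \odot \mathbf{h}_t\big)\, \mathbf{W}_{hh},
$$
so the gradient from step $T$ back to step $k$ contains $T - k$ factors of $\mathbf{W}_{hh}$. Along the cell-state path, the gradient is instead scaled elementwise by the forget gates. Components for which the network learns $f_t \approx 1$ can carry gradient across many steps almost unchanged.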


#### 2.1.2 Implementation

In [38]:
class LSTMCell(nn.Module):
	"""
	LSTM Cell update rule.
	"""
	def __init__(self, input_dim, hidden_dim):
		"""
		Inputs:
		- input_dim:  dimension of the input
		- hidden_dim: dimension of the hidden state
		"""
		super().__init__()
		self.d_input = input_dim
		self.d_hidden = hidden_dim
		
		# Initialize the matrices
		self.Wfx = nn.Linear(input_dim, hidden_dim)
		self.Wfh = nn.Linear(hidden_dim, hidden_dim)
		self.Wix = nn.Linear(input_dim, hidden_dim)
		self.Wih = nn.Linear(hidden_dim, hidden_dim)
		self.Wox = nn.Linear(input_dim, hidden_dim)
		self.Woh = nn.Linear(hidden_dim, hidden_dim)
		self.Wcx = nn.Linear(input_dim, hidden_dim)
		self.Wch = nn.Linear(hidden_dim, hidden_dim)
	
	def forward(self, x_t, c_prev, h_prev):
		"""
		Input:
		- x_t: input at the current time step        (batch_size,  input_dim)
		- c_prev: cell state at previous time step   (batch_size, hidden_dim)
		- h_prev: hidden state at previous time step (batch_size, hidden_dim)

		Returns
		- c_t: updated cell state   (batch_size, hidden_dim)
		- h_t: updated hidden state (batch_size, hidden_dim)
		"""
		# calculate the gates
		f = torch.sigmoid(self.Wfx(x_t) + self.Wfh(h_prev))
		i = torch.sigmoid(self.Wix(x_t) + self.Wih(h_prev))
		o = torch.sigmoid(self.Wox(x_t) + self.Woh(h_prev))

		# new (proposed) value
		c_tilde = torch.tanh(self.Wcx(x_t) + self.Wch(h_prev))

		# updated cell value
		c_t = f * c_prev + i * c_tilde

		# updated hidden state
		h_t = o * torch.tanh(c_t)

		return c_t, h_t
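
As a sanity check of the update rules, the sketch below writes one LSTM step directly from the equations. It reuses the weights of PyTorch's built-in `torch.nn.LSTMCell` and compares the result against that module's output. It relies on PyTorch stacking the four blocks of `weight_ih`/`weight_hh` in the order input, forget, candidate, output. That is the layout described in the PyTorch documentation, but it is worth confirming for your version. `lstm_step` is a hypothetical helper written for this check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_dim, hidden_dim, batch_size = 8, 16, 4

ref = nn.LSTMCell(input_dim, hidden_dim)  # PyTorch's built-in LSTM cell, used as a reference


def lstm_step(x_t, c_prev, h_prev, cell):
    """One LSTM step written directly from the update rules, reusing `cell`'s weights."""
    # All four pre-activations at once: shape (batch_size, 4 * hidden_dim).
    z = (x_t @ cell.weight_ih.T + cell.bias_ih
         + h_prev @ cell.weight_hh.T + cell.bias_hh)
    # Blocks are stacked in the order (input, forget, candidate, output).
    z_i, z_f, z_c, z_o = z.chunk(4, dim=1)
    i, f, o = torch.sigmoid(z_i), torch.sigmoid(z_f), torch.sigmoid(z_o)
    c_tilde = torch.tanh(z_c)
    c_t = f * c_prev + i * c_tilde   # C_t = f * C_{t-1} + i * C~_t (elementwise)
    h_t = o * torch.tanh(c_t)        # h_t = o * tanh(C_t)
    return c_t, h_t


x_t = torch.randn(batch_size, input_dim)
h_prev = torch.randn(batch_size, hidden_dim)
c_prev = torch.randn(batch_size, hidden_dim)

with torch.no_grad():
    c_ours, h_ours = lstm_step(x_t, c_prev, h_prev, ref)
    h_ref, c_ref = ref(x_t, (h_prev, c_prev))  # nn.LSTMCell takes and returns (h, c)

print(torch.allclose(c_ours, c_ref, atol=1e-6), torch.allclose(h_ours, h_ref, atol=1e-6))
```

Packing the four blocks into one `(4 * hidden_dim, ·)` matrix computes every gate with two matrix multiplications instead of eight. It is mathematically equivalent to the eight separate `nn.Linear` layers in the cell above. Note also that `nn.LSTMCell` orders the states as `(h, c)`, whereas the cell above uses `(c, h)`.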

In [39]:
class LSTM(nn.Module):
	"""
	LSTM model
	"""
	def __init__(self, input_dim, output_dim, hidden_dim):
		"""
		Inputs:
		- input_dim:  dimension of the input
		- output_dim: dimension of the output
		- hidden_dim: dimension of the hidden state
		"""
		super().__init__()
		self.d_input = input_dim
		self.d_output = output_dim
		self.d_hidden = hidden_dim

		self.LSTMCell = LSTMCell(input_dim, hidden_dim)
		self.Wyh = nn.Linear(hidden_dim, output_dim)

	def forward(self, x, c_0=None, h_0=None):
		"""
		Inputs:
		- x: input (batch_size, T, input_dim)
		- c_0 (optional): initial cell state (batch_size, hidden_dim); defaults to zeros
		- h_0 (optional): initial hidden state (batch_size, hidden_dim); defaults to zeros

		Output:
		- y: outputs at every time step (batch_size, T, output_dim)
		- c_T: final cell state (batch_size, hidden_dim)
		- h_T: final hidden state (batch_size, hidden_dim)
		"""
		# shapes and device
		batch_size, T, _ = x.shape
		device = x.device

		# initializing the cell and hidden states
		if c_0 is None:
			c_t = torch.zeros(batch_size, self.d_hidden, device=device)
		else:
			c_t = c_0

		if h_0 is None:
			h_t = torch.zeros(batch_size, self.d_hidden, device=device)
		else:
			h_t = h_0

		# update
		y_out = []
		for t in range(T):
			x_t = x[:, t, :]
			# update using LSTMCell
			c_t, h_t = self.LSTMCell(x_t, c_t, h_t)
			y_t = self.Wyh(h_t)
			y_out.append(y_t)

		return torch.stack(y_out, dim=1), c_t, h_t

In [45]:
# Quick test

# Hyperparameters
batch_size = 4
seq_len = 10
input_dim = 8
hidden_dim = 16
output_dim = 3

# Random input on the correct device
x = torch.randn(batch_size, seq_len, input_dim, device=DEVICE)

# Initialize model and move to device
model = LSTM(input_dim, output_dim, hidden_dim).to(DEVICE)

# Forward pass (without providing initial states)
y_out, c_T, h_T = model(x)

# Print shapes and device info
print(f"Input shape:   {x.shape}, device: {x.device}")
print(f"Output shape:  {y_out.shape}, device: {y_out.device}")
print(f"Cell state:    {c_T.shape},    device: {c_T.device}")
print(f"Hidden state:  {h_T.shape},    device: {h_T.device}")

Input shape:   torch.Size([4, 10, 8]), device: mps:0
Output shape:  torch.Size([4, 10, 3]), device: mps:0
Cell state:    torch.Size([4, 16]),    device: mps:0
Hidden state:  torch.Size([4, 16]),    device: mps:0
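
Each of the four blocks in the LSTM (three gates plus the candidate $\tilde{\mathbf{C}}_t$) has its own input and hidden matrices. An LSTM cell therefore has four times as many parameters as an RNN cell with the same dimensions. For brevity, the count below uses PyTorch's built-in `nn.RNNCell` and `nn.LSTMCell` as stand-ins. Like the cells above, they keep two bias vectors per block, so the totals match the hand-written versions. `n_params` is a small helper introduced here.

```python
import torch.nn as nn


def n_params(module):
    """Total number of scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters())


input_dim, hidden_dim = 8, 16
rnn_cell = nn.RNNCell(input_dim, hidden_dim)
lstm_cell = nn.LSTMCell(input_dim, hidden_dim)

# One block: hidden_dim * (input_dim + hidden_dim) weights + 2 * hidden_dim biases
per_block = hidden_dim * (input_dim + hidden_dim) + 2 * hidden_dim
print(n_params(rnn_cell), n_params(lstm_cell), per_block)  # 416 1664 416
```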


### 2.2 Gated Recurrent Unit (GRU)

#### 2.2.1 GRU Cell mechanism

[WIP]

#### 2.2.2 Implementation

[WIP]

## 3. Backpropagation Through Time

[WIP]

## 4. Testing

[WIP]