# Recurrent Neural Networks

Recurrent Neural Networks (RNNs) are designed to process sequential data by maintaining hidden states that capture temporal dependencies.

In this tutorial, you will learn:

- 🔄 **RNN Basics** - Understanding recurrence and hidden states
- 🧠 **LSTM** - Long Short-Term Memory for long-term dependencies
- ⚡ **GRU** - Gated Recurrent Unit as an efficient alternative
- 📊 **Sequence Processing** - Handling variable-length sequences
- 💡 **Practical Examples** - Sequence classification, time series prediction

## Why RNNs?

RNNs excel at:
- 📝 Natural language processing
- 🎵 Speech and audio
- 📈 Time series forecasting
- 🎬 Video analysis

In [None]:
import brainstate as bst
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

## 1. RNN Basics

An RNN processes sequences one element at a time, updating its hidden state:

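$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b)$$

Here $x_t$ is the input at step $t$, $h_{t-1}$ is the previous hidden state, $W_{xh}$ and $W_{hh}$ are the input-to-hidden and hidden-to-hidden weight matrices, and $b$ is a bias vector.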
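As a minimal sketch of this update rule, the step below is written in plain NumPy rather than with `bst.nn.RNNCell`. The names `rnn_step`, `W_xh`, `W_hh`, and `b`, as well as the weight scale, are illustrative assumptions and say nothing about how the library stores or initializes its parameters.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b):
    """One vanilla RNN update: h_t = tanh(W_xh @ x_t + W_hh @ h_prev + b)."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b)

rng = np.random.default_rng(0)
input_size, hidden_size = 10, 20

# Illustrative small random weights (not the library's initialization)
W_xh = 0.1 * rng.standard_normal((hidden_size, input_size))
W_hh = 0.1 * rng.standard_normal((hidden_size, hidden_size))
b = np.zeros(hidden_size)

# Apply the same step repeatedly, feeding each new hidden state back in
h = np.zeros(hidden_size)
for x_t in rng.standard_normal((5, input_size)):  # 5 timesteps
    h = rnn_step(x_t, h, W_xh, W_hh, b)

print(h.shape)  # (20,)
```

The same weights are reused at every timestep; only the hidden state changes, which is what makes the network recurrent.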

### Simple RNN Cell

In [None]:
# RNNCell: Basic recurrent unit
bst.random.seed(42)
rnn_cell = bst.nn.RNNCell(
    input_size=10,
    hidden_size=20
)

print("RNNCell:")
print(rnn_cell)
print(f"\nInput size: {rnn_cell.input_size}")
print(f"Hidden size: {rnn_cell.hidden_size}")

# Initialize hidden state
hidden = jnp.zeros(20)

# Process a single timestep
x_t = bst.random.randn(10)
hidden_new = rnn_cell(x_t, hidden)

print(f"\nInput shape: {x_t.shape}")
print(f"Previous hidden: {hidden.shape}")
print(f"New hidden: {hidden_new.shape}")

### Processing Sequences

Let's process a complete sequence:

In [None]:
class SimpleRNN(bst.graph.Node):
    """Simple RNN for sequence processing."""
    
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn_cell = bst.nn.RNNCell(input_size, hidden_size)
        self.output_layer = bst.nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size
    
    def __call__(self, sequence):
        """
        Process a sequence.
        
        Args:
            sequence: (seq_length, input_size)
            
        Returns:
            outputs: (seq_length, output_size)
        """
        seq_length = sequence.shape[0]
        
        # Initialize hidden state
        hidden = jnp.zeros(self.hidden_size)
        
        outputs = []
        for t in range(seq_length):
            # Update hidden state
            hidden = self.rnn_cell(sequence[t], hidden)
            
            # Generate output
            output = self.output_layer(hidden)
            outputs.append(output)
        
        return jnp.stack(outputs)

# Create RNN
bst.random.seed(0)
rnn = SimpleRNN(input_size=5, hidden_size=10, output_size=3)

# Test sequence
sequence = bst.random.randn(7, 5)  # 7 timesteps, 5 features
outputs = rnn(sequence)

print(f"Input sequence shape: {sequence.shape}")
print(f"Output sequence shape: {outputs.shape}")
print(f"\nOutputs:\n{outputs}")

### Visualizing RNN Processing

In [None]:
# Generate sine wave sequence
t = jnp.linspace(0, 4 * jnp.pi, 50)
sine_wave = jnp.sin(t)

# Prepare as sequence (add feature dimension)
sequence = sine_wave[:, None]

# Create RNN
bst.random.seed(42)
rnn = SimpleRNN(input_size=1, hidden_size=16, output_size=1)

# Process sequence
outputs = rnn(sequence).flatten()

# Plot
plt.figure(figsize=(12, 4))
plt.plot(t, sine_wave, linewidth=2, label='Input (Sine)', alpha=0.7)
plt.plot(t, outputs, linewidth=2, label='RNN Output', alpha=0.7)
plt.xlabel('Time')
plt.ylabel('Value')
plt.title('RNN Processing Temporal Sequence', fontweight='bold')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

print("Untrained RNN: the output reflects random initial weights, not a learned pattern")

## 2. LSTM - Long Short-Term Memory

LSTM mitigates the vanishing gradient problem by adding a cell state whose contents are controlled by three gates:

- **Forget gate**: What to forget from cell state
- **Input gate**: What new information to store
- **Output gate**: What to output from cell state
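The three gates can be sketched as one LSTM step in plain NumPy. This follows the standard textbook formulation; the function name `lstm_step`, the stacked weight layout, and the gate ordering `[f, i, o, g]` are assumptions made for illustration and may differ from what `bst.nn.LSTMCell` does internally.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM update.

    W: (4H, I), U: (4H, H), b: (4H,), stacked in the order [f, i, o, g].
    """
    H = h_prev.shape[0]
    z = W @ x_t + U @ h_prev + b
    f = sigmoid(z[0:H])          # forget gate: how much of c_prev to keep
    i = sigmoid(z[H:2 * H])      # input gate: how much new content to write
    o = sigmoid(z[2 * H:3 * H])  # output gate: how much of the cell to expose
    g = np.tanh(z[3 * H:4 * H])  # candidate values to write into the cell
    c_new = f * c_prev + i * g   # additive cell update
    h_new = o * np.tanh(c_new)
    return h_new, c_new

rng = np.random.default_rng(0)
input_size, hidden_size = 10, 20
W = 0.1 * rng.standard_normal((4 * hidden_size, input_size))
U = 0.1 * rng.standard_normal((4 * hidden_size, hidden_size))
b = np.zeros(4 * hidden_size)

h, c = np.zeros(hidden_size), np.zeros(hidden_size)
for x_t in rng.standard_normal((5, input_size)):
    h, c = lstm_step(x_t, h, c, W, U, b)

print(h.shape, c.shape)  # (20,) (20,)
```

The cell update `c_new = f * c_prev + i * g` is additive, so when the forget gate stays close to 1 information and gradients can pass through many steps with little attenuation.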

### LSTM Cell

In [None]:
# LSTMCell: Advanced recurrent unit
bst.random.seed(42)
lstm_cell = bst.nn.LSTMCell(
    input_size=10,
    hidden_size=20
)

print("LSTMCell:")
print(lstm_cell)

# LSTM has two states: hidden (h) and cell (c)
hidden = jnp.zeros(20)
cell = jnp.zeros(20)

# Process one timestep
x_t = bst.random.randn(10)
hidden_new, cell_new = lstm_cell(x_t, (hidden, cell))

print(f"\nInput: {x_t.shape}")
print(f"Hidden state: {hidden_new.shape}")
print(f"Cell state: {cell_new.shape}")
print("\n✅ LSTM maintains both hidden and cell states")
print("✅ Cell state provides long-term memory")

### LSTM Network

In [None]:
class LSTMNet(bst.graph.Node):
    """LSTM network for sequence processing."""
    
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm_cell = bst.nn.LSTMCell(input_size, hidden_size)
        self.output_layer = bst.nn.Linear(hidden_size, output_size)
        self.hidden_size = hidden_size
    
    def __call__(self, sequence):
        seq_length = sequence.shape[0]
        
        # Initialize states
        hidden = jnp.zeros(self.hidden_size)
        cell = jnp.zeros(self.hidden_size)
        
        outputs = []
        for t in range(seq_length):
            # Update LSTM states
            hidden, cell = self.lstm_cell(sequence[t], (hidden, cell))
            
            # Generate output
            output = self.output_layer(hidden)
            outputs.append(output)
        
        return jnp.stack(outputs)

# Create LSTM
bst.random.seed(0)
lstm = LSTMNet(input_size=5, hidden_size=20, output_size=3)

# Test
sequence = bst.random.randn(10, 5)
outputs = lstm(sequence)

print("LSTM Network:")
print(lstm)
print(f"\nSequence: {sequence.shape} → Output: {outputs.shape}")

## 3. GRU - Gated Recurrent Unit

GRU simplifies LSTM with fewer gates:

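- **Reset gate**: How much of the previous hidden state to use when computing the new candidate state
- **Update gate**: How much of the hidden state to replace with the candidate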
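The two gates can be sketched as one GRU step in plain NumPy. This is a textbook formulation, not necessarily what `bst.nn.GRUCell` implements: the function name `gru_step`, the parameter dictionary, and the convention that the update gate `z` weights the candidate (some libraries use `z` to weight the old state instead) are all assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, p):
    """One GRU update using a dict `p` of weights (hypothetical layout)."""
    r = sigmoid(p["W_r"] @ x_t + p["U_r"] @ h_prev + p["b_r"])  # reset gate
    z = sigmoid(p["W_z"] @ x_t + p["U_z"] @ h_prev + p["b_z"])  # update gate
    # Candidate state: the reset gate scales the previous hidden state
    h_cand = np.tanh(p["W_h"] @ x_t + p["U_h"] @ (r * h_prev) + p["b_h"])
    # Interpolate between the old state and the candidate
    return (1.0 - z) * h_prev + z * h_cand

rng = np.random.default_rng(0)
input_size, hidden_size = 10, 20
p = {}
for name in ("r", "z", "h"):
    p[f"W_{name}"] = 0.1 * rng.standard_normal((hidden_size, input_size))
    p[f"U_{name}"] = 0.1 * rng.standard_normal((hidden_size, hidden_size))
    p[f"b_{name}"] = np.zeros(hidden_size)

h = np.zeros(hidden_size)
for x_t in rng.standard_normal((5, input_size)):
    h = gru_step(x_t, h, p)

print(h.shape)  # (20,)
```

Because the new state is a convex combination of the old state and a `tanh` candidate, the hidden state stays within [-1, 1] when it starts at zero.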

### GRU Cell

In [None]:
# GRUCell: Efficient alternative to LSTM
bst.random.seed(42)
gru_cell = bst.nn.GRUCell(
    input_size=10,
    hidden_size=20
)

print("GRUCell:")
print(gru_cell)

# GRU only has hidden state (no separate cell state)
hidden = jnp.zeros(20)
x_t = bst.random.randn(10)
hidden_new = gru_cell(x_t, hidden)

print(f"\nInput: {x_t.shape}")
print(f"Hidden state: {hidden_new.shape}")
print("\n✅ Simpler than LSTM (no cell state)")
print("✅ Faster training, fewer parameters")
print("✅ Often performs similarly to LSTM")

### Comparing RNN, LSTM, and GRU

In [None]:
# Create all three types
bst.random.seed(0)
input_size, hidden_size = 5, 10

rnn_cell = bst.nn.RNNCell(input_size, hidden_size)
lstm_cell = bst.nn.LSTMCell(input_size, hidden_size)
gru_cell = bst.nn.GRUCell(input_size, hidden_size)

# Test on same sequence
sequence = bst.random.randn(20, input_size)

# Process with RNN
h_rnn = jnp.zeros(hidden_size)
rnn_states = []
for x_t in sequence:
    h_rnn = rnn_cell(x_t, h_rnn)
    rnn_states.append(h_rnn)

# Process with LSTM
h_lstm = jnp.zeros(hidden_size)
c_lstm = jnp.zeros(hidden_size)
lstm_states = []
for x_t in sequence:
    h_lstm, c_lstm = lstm_cell(x_t, (h_lstm, c_lstm))
    lstm_states.append(h_lstm)

# Process with GRU
h_gru = jnp.zeros(hidden_size)
gru_states = []
for x_t in sequence:
    h_gru = gru_cell(x_t, h_gru)
    gru_states.append(h_gru)

# Visualize hidden states
rnn_states = jnp.stack(rnn_states)
lstm_states = jnp.stack(lstm_states)
gru_states = jnp.stack(gru_states)

plt.figure(figsize=(15, 4))

plt.subplot(1, 3, 1)
plt.imshow(np.array(rnn_states.T), aspect='auto', cmap='viridis')
plt.colorbar()
plt.title('RNN Hidden States', fontweight='bold')
plt.xlabel('Time')
plt.ylabel('Hidden Unit')

plt.subplot(1, 3, 2)
plt.imshow(np.array(lstm_states.T), aspect='auto', cmap='viridis')
plt.colorbar()
plt.title('LSTM Hidden States', fontweight='bold')
plt.xlabel('Time')
plt.ylabel('Hidden Unit')

plt.subplot(1, 3, 3)
plt.imshow(np.array(gru_states.T), aspect='auto', cmap='viridis')
plt.colorbar()
plt.title('GRU Hidden States', fontweight='bold')
plt.xlabel('Time')
plt.ylabel('Hidden Unit')

plt.tight_layout()
plt.show()

print("Each cell type produces a different hidden-state pattern for the same input (all weights are untrained)")

## 4. Practical Example: Sequence Classification

Let's build a complete sequence classifier:

In [None]:
class SequenceClassifier(bst.graph.Node):
    """Classify sequences using LSTM."""
    
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        
        # LSTM layer
        self.lstm = bst.nn.LSTMCell(input_size, hidden_size)
        
        # Classifier
        self.fc = bst.nn.Linear(hidden_size, num_classes)
        self.hidden_size = hidden_size
    
    def __call__(self, sequence):
        """
        Classify a sequence.
        
        Args:
            sequence: (seq_length, input_size)
            
        Returns:
            logits: (num_classes,)
        """
        # Initialize states
        hidden = jnp.zeros(self.hidden_size)
        cell = jnp.zeros(self.hidden_size)
        
        # Process sequence
        for t in range(sequence.shape[0]):
            hidden, cell = self.lstm(sequence[t], (hidden, cell))
        
        # Use final hidden state for classification
        logits = self.fc(hidden)
        return logits

# Create classifier
bst.random.seed(42)
classifier = SequenceClassifier(
    input_size=8,
    hidden_size=32,
    num_classes=3
)

print("Sequence Classifier:")
print(classifier)

# Test on sequences of different lengths (processed one at a time, not as a batch)
sequences = [
    bst.random.randn(10, 8),  # Short sequence
    bst.random.randn(15, 8),  # Medium sequence
    bst.random.randn(20, 8),  # Long sequence
]

print("\nClassifying sequences of different lengths:")
for i, seq in enumerate(sequences):
    logits = classifier(seq)
    pred = jnp.argmax(logits)
    print(f"  Sequence {i+1} (length={seq.shape[0]:2d}): logits={logits}, predicted class={pred}")
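The loop above handles variable lengths by processing one sequence at a time. To process several sequences together, a common approach is to pad them to a shared length and use a mask so that padded steps leave the hidden state unchanged. The sketch below shows this with a vanilla RNN in plain NumPy; `pad_sequences`, `run_masked_rnn`, and the weight names are hypothetical helpers, not part of `brainstate`.

```python
import numpy as np

def pad_sequences(sequences):
    """Zero-pad (length_n, features) arrays to (batch, max_len, features)."""
    lengths = np.array([len(s) for s in sequences])
    max_len, features = lengths.max(), sequences[0].shape[1]
    padded = np.zeros((len(sequences), max_len, features))
    for n, s in enumerate(sequences):
        padded[n, :len(s)] = s
    # mask[n, t] is True where timestep t holds real data for sequence n
    mask = np.arange(max_len)[None, :] < lengths[:, None]
    return padded, mask

def run_masked_rnn(padded, mask, W_xh, W_hh, b):
    """Return the hidden state at each sequence's true final timestep."""
    batch, max_len, _ = padded.shape
    h = np.zeros((batch, W_hh.shape[0]))
    for t in range(max_len):
        h_cand = np.tanh(padded[:, t] @ W_xh.T + h @ W_hh.T + b)
        # Keep the old state wherever this timestep is padding
        h = np.where(mask[:, t, None], h_cand, h)
    return h

rng = np.random.default_rng(0)
input_size, hidden_size = 8, 32
W_xh = 0.1 * rng.standard_normal((hidden_size, input_size))
W_hh = 0.1 * rng.standard_normal((hidden_size, hidden_size))
b = np.zeros(hidden_size)

sequences = [rng.standard_normal((n, input_size)) for n in (10, 15, 20)]
padded, mask = pad_sequences(sequences)
h_batch = run_masked_rnn(padded, mask, W_xh, W_hh, b)

print(padded.shape, mask.shape, h_batch.shape)  # (3, 20, 8) (3, 20) (3, 32)
```

Each row of `h_batch` matches what the same RNN would produce on the unpadded sequence, so a classifier head can be applied to the whole batch at once.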

## 5. Time Series Prediction

In [None]:
# Generate synthetic time series
def generate_time_series(n_steps=100):
    t = jnp.linspace(0, 10, n_steps)
    # Combination of sine waves
    series = jnp.sin(t) + 0.5 * jnp.sin(3 * t) + 0.1 * bst.random.randn(n_steps)
    return series

# Create sequences for prediction
def create_sequences(data, seq_length=10):
    X, y = [], []
    for i in range(len(data) - seq_length):
        X.append(data[i:i+seq_length])
        y.append(data[i+seq_length])
    return jnp.array(X), jnp.array(y)

# Generate data
bst.random.seed(0)
time_series = generate_time_series(200)
X, y = create_sequences(time_series, seq_length=15)

# Add feature dimension
X = X[:, :, None]

print(f"Time series data: {time_series.shape}")
print(f"Sequences (X): {X.shape}")
print(f"Targets (y): {y.shape}")

# Create predictor
class TimeSeriesPredictor(bst.graph.Node):
    def __init__(self, hidden_size=32):
        super().__init__()
        self.gru = bst.nn.GRUCell(input_size=1, hidden_size=hidden_size)
        self.fc = bst.nn.Linear(hidden_size, 1)
        self.hidden_size = hidden_size
    
    def __call__(self, sequence):
        hidden = jnp.zeros(self.hidden_size)
        for t in range(sequence.shape[0]):
            hidden = self.gru(sequence[t], hidden)
        prediction = self.fc(hidden)
        return prediction[0]

# Create and test predictor
bst.random.seed(42)
predictor = TimeSeriesPredictor(hidden_size=64)

# Make predictions on the first 50 windows (there is no train/test split here)
predictions = []
for i in range(min(50, len(X))):
    pred = predictor(X[i])
    predictions.append(pred)

predictions = jnp.array(predictions)
targets = y[:len(predictions)]

# Plot
plt.figure(figsize=(12, 5))
plt.plot(targets, label='True Values', linewidth=2, alpha=0.7)
plt.plot(predictions, label='Predictions (Untrained)', linewidth=2, alpha=0.7)
plt.xlabel('Time Step')
plt.ylabel('Value')
plt.title('Time Series Prediction with GRU', fontweight='bold')
plt.legend()
plt.grid(alpha=0.3)
plt.show()

mse = jnp.mean((predictions - targets) ** 2)
print(f"\nMSE (untrained): {mse:.4f}")
print("💡 With training, GRU can learn to predict future values")
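The predictor above is untrained. As a rough sketch of what training could look like, the example below fits a vanilla RNN (swapped in for the GRU to keep the code short) to a noise-free sine wave using pure JAX with `jax.lax.scan`, `jax.vmap`, and plain gradient descent. It deliberately avoids `brainstate` training utilities, which this tutorial does not cover; `init_params`, `predict`, `train_step`, and the learning rate are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05  # illustrative value, not tuned

def init_params(key, input_size=1, hidden_size=16):
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "W_xh": 0.1 * jax.random.normal(k1, (hidden_size, input_size)),
        "W_hh": 0.1 * jax.random.normal(k2, (hidden_size, hidden_size)),
        "b": jnp.zeros(hidden_size),
        "w_out": 0.1 * jax.random.normal(k3, (hidden_size,)),
    }

def predict(params, window):
    """Run the RNN over one window of shape (seq_length, 1); return a scalar."""
    def step(h, x_t):
        h = jnp.tanh(params["W_xh"] @ x_t + params["W_hh"] @ h + params["b"])
        return h, None
    h0 = jnp.zeros(params["b"].shape[0])
    h_final, _ = jax.lax.scan(step, h0, window)
    return params["w_out"] @ h_final

def loss_fn(params, X, y):
    # vmap applies `predict` to every window in the batch
    preds = jax.vmap(predict, in_axes=(None, 0))(params, X)
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, X, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return params, loss

# Sliding windows over a sine wave: predict the value after each window
series = jnp.sin(jnp.linspace(0, 4 * jnp.pi, 80))
seq_length = 10
X = jnp.stack(
    [series[i:i + seq_length] for i in range(len(series) - seq_length)]
)[:, :, None]
y = series[seq_length:]

params = init_params(jax.random.PRNGKey(0))
losses = []
for _ in range(200):
    params, loss = train_step(params, X, y)
    losses.append(float(loss))

print(f"MSE before: {losses[0]:.4f}, after: {losses[-1]:.4f}")
```

Using `jax.lax.scan` instead of a Python `for` loop keeps the compiled program small, since the loop body is traced once rather than unrolled for every timestep.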

## Summary

In this tutorial, you learned:

✅ **RNN Basics**
  - Recurrence and hidden states
  - Processing sequences step-by-step
  - Building simple RNN networks

✅ **LSTM**
  - Gating mechanisms (forget, input, output)
  - Cell state for long-term memory
  - Handling long-term dependencies

✅ **GRU**
  - Simplified gating (reset, update)
  - Fewer parameters than LSTM
  - Efficient alternative

✅ **Practical Applications**
  - Sequence classification
  - Time series prediction
  - Variable-length sequences

### Quick Comparison

| Model | States | Gates | Best For |
|-------|--------|-------|----------|
| **RNN** | 1 (h) | 0 | Short sequences, simple patterns |
| **LSTM** | 2 (h, c) | 3 | Long sequences, complex dependencies |
| **GRU** | 1 (h) | 2 | Balance of complexity and performance |
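The gate counts in the table translate directly into parameter counts. The sketch below assumes the common textbook layout in which every gate (and the candidate) has its own input weights, recurrent weights, and a single bias vector; actual implementations may differ slightly, for example by using two bias vectors per gate. The helper `count_params` is hypothetical.

```python
def count_params(input_size, hidden_size, num_blocks):
    """Parameters for `num_blocks` weight blocks, each with one bias vector."""
    per_block = hidden_size * (input_size + hidden_size) + hidden_size
    return num_blocks * per_block

input_size, hidden_size = 10, 20
counts = {
    "RNN": count_params(input_size, hidden_size, 1),   # candidate only
    "LSTM": count_params(input_size, hidden_size, 4),  # 3 gates + candidate
    "GRU": count_params(input_size, hidden_size, 3),   # 2 gates + candidate
}

for name, n in counts.items():
    print(f"{name}: {n} parameters")
# RNN: 620 parameters
# LSTM: 2480 parameters
# GRU: 1860 parameters
```

Under these assumptions a GRU has three quarters of the parameters of an LSTM with the same input and hidden sizes.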

### When to Use Each

- 🎯 **Start with GRU** - Good default choice
- 📚 **Use LSTM** - When you need maximum capacity for long-term memory
- ⚡ **Use RNN** - For simple patterns or as a baseline

### Best Practices

1. 🔄 **Initialize hidden states to zero**
2. 📊 **Normalize input sequences**
3. 🎯 **Use gradient clipping** to prevent exploding gradients
4. 💾 **Carry hidden states across chunks** when running inference on long sequences
5. 🔍 **Try bidirectional RNNs** for offline sequence processing
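Gradient clipping, mentioned in the list above, is often done by rescaling all gradients together whenever their combined norm exceeds a threshold. The sketch below shows this in plain NumPy on a dictionary of gradient arrays; the helper `clip_by_global_norm` and the example values are assumptions for illustration, and optimizer libraries typically ship their own version.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients so that their joint L2 norm is at most max_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads.values()))
    # scale == 1.0 when the norm is already within the limit
    scale = min(1.0, max_norm / (global_norm + 1e-12))
    clipped = {name: g * scale for name, g in grads.items()}
    return clipped, global_norm

# Example gradients with joint norm sqrt(3^2 + 4^2) = 5
grads = {
    "W_hh": np.array([[3.0, 0.0], [0.0, 0.0]]),
    "b": np.array([4.0, 0.0]),
}

clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped.values()))

print(f"norm before: {norm_before:.2f}, after: {norm_after:.2f}")
# norm before: 5.00, after: 1.00
```

Rescaling every gradient by the same factor preserves the update direction and only limits its size, which is why the global-norm variant is usually preferred over clipping each element independently.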

### Next Steps

Continue with:
- **Dynamical Systems** - Brain-inspired temporal models
- **Attention Mechanisms** - Beyond RNNs (Transformers)
- **Training** - Optimize RNNs effectively