# Tutorial 19: Sequence Modeling with RNNs and Transformers

In this tutorial, we'll build sequence models for tasks like text generation and time series prediction using RNNs, LSTMs, and attention mechanisms.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Build RNN and LSTM models for sequence tasks
- Implement attention mechanisms
- Create a simple Transformer architecture
- Train models for text generation
- Handle variable-length sequences
- Implement teacher forcing and sampling strategies
- Visualize attention weights

## What We'll Build

We'll create:
- Character-level language model with RNN/LSTM
- LSTM language model augmented with self-attention
- Simple Transformer for sequence modeling
- Text generation pipeline

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple, List, Dict, Optional
import string

# Set random seeds. BrainState's RNG drives parameter initialisation, while
# NumPy's global RNG drives batch shuffling (np.random.permutation) and
# character sampling (np.random.choice) later on, so both are seeded to make
# runs reproducible.
bst.random.seed(42)
np.random.seed(42)

print(f"JAX devices: {jax.devices()}")

## 1. Dataset Preparation: Character-Level Text

In [None]:
# Sample text data: a tiny training corpus, one sentence per line, with a
# newline before the first sentence and after the last one.
sample_text = "\n".join([
    "",
    "The quick brown fox jumps over the lazy dog.",
    "Neural networks are powerful tools for machine learning.",
    "Deep learning has revolutionized artificial intelligence.",
    "Transformers have become the dominant architecture for NLP.",
    "Attention is all you need for sequence modeling.",
    "Recurrent neural networks process sequences step by step.",
    "LSTMs solve the vanishing gradient problem in RNNs.",
    "BrainState makes it easy to build neural networks with JAX.",
    "",
])

class CharacterTokenizer:
    """Character-level tokenizer mapping each distinct character to an int id.

    Ids are assigned in sorted character order, so the mapping is
    deterministic for a given corpus.
    """

    def __init__(self, text: str):
        # Vocabulary = sorted set of distinct characters in the corpus.
        self.chars = sorted(set(text))
        self.vocab_size = len(self.chars)

        # Lookup tables in both directions.
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = dict(enumerate(self.chars))

    def encode(self, text: str) -> np.ndarray:
        """Convert *text* to a 1-D integer array of token ids.

        Raises:
            KeyError: if *text* contains a character not in the vocabulary.
        """
        # Explicit integer dtype: without it, encode("") would yield a
        # float64 array that cannot be used for indexing.
        return np.array([self.char_to_idx[ch] for ch in text], dtype=int)

    def decode(self, indices) -> str:
        """Convert an iterable of token ids (ints or an integer array) to text."""
        return ''.join(self.idx_to_char[int(idx)] for idx in indices)

# Build the vocabulary from the corpus and encode the whole corpus once.
tokenizer = CharacterTokenizer(sample_text)
encoded_text = tokenizer.encode(sample_text)
preview_ids = encoded_text[:20]

print(f"Vocabulary size: {tokenizer.vocab_size}")
print(f"Characters: {''.join(tokenizer.chars[:20])}...")
print(f"\nOriginal text length: {len(sample_text)}")
print(f"Encoded length: {len(encoded_text)}")
print(f"\nSample encoding: {preview_ids}")
print(f"Decoded: {tokenizer.decode(preview_ids)}")

### Create Training Sequences

In [None]:
def create_sequences(data, seq_length):
    """Create input-target pairs for next-token prediction with a sliding window.

    Pair ``i`` is ``(data[i:i+seq_length], data[i+1:i+seq_length+1])``: the
    target is the input shifted one position ahead.

    Args:
        data: 1-D sequence of encoded tokens.
        seq_length: Length of input sequences (must be >= 1).

    Returns:
        inputs: Array of shape (n_sequences, seq_length)
        targets: Array of shape (n_sequences, seq_length)
        where ``n_sequences = max(len(data) - seq_length, 0)``. Both arrays
        keep ``data``'s dtype and are 2-D even when no window fits.

    Raises:
        ValueError: if ``seq_length < 1``.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")

    data = np.asarray(data)
    n_sequences = max(len(data) - seq_length, 0)

    # Row i of `window_idx` is [i, i+1, ..., i+seq_length]: one input window
    # plus the single extra token its shifted target needs.
    window_idx = np.arange(n_sequences)[:, None] + np.arange(seq_length + 1)
    windows = data[window_idx]  # (n_sequences, seq_length + 1)

    # Independent contiguous copies, so inputs and targets don't share memory.
    inputs = np.ascontiguousarray(windows[:, :-1])
    targets = np.ascontiguousarray(windows[:, 1:])
    return inputs, targets

# Create sequences
seq_length = 20
X, y = create_sequences(encoded_text, seq_length)

# Split into train/val. Window i spans characters i..i+seq_length (input
# plus shifted target), so neighbouring windows overlap. Skipping
# seq_length windows after the training split ensures that no validation
# window shares a character position with any training window.
split_idx = int(0.8 * len(X))
val_start = split_idx + seq_length
X_train, X_val = X[:split_idx], X[val_start:]
y_train, y_val = y[:split_idx], y[val_start:]

print(f"Training sequences: {X_train.shape}")
print(f"Validation sequences: {X_val.shape}")
print(f"\nExample sequence:")
print(f"Input:  {tokenizer.decode(X_train[0])}")
print(f"Target: {tokenizer.decode(y_train[0])}")

## 2. RNN Language Model

In [None]:
class RNNLanguageModel(bst.graph.Node):
    """Character-level language model built on a vanilla recurrent cell.

    Pipeline: embedding lookup -> recurrent cell unrolled over time ->
    linear read-out of every hidden state onto the vocabulary.
    """

    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=128):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        # Trainable token lookup table (one row per token), small random init.
        init_table = bst.random.randn(vocab_size, embedding_dim) * 0.1
        self.embedding = bst.ParamState(init_table)

        # Recurrent transition, then projection of hidden states to logits.
        self.rnn_cell = bst.nn.RNNCell(embedding_dim, hidden_dim)
        self.fc_out = bst.nn.Linear(hidden_dim, vocab_size)

    def __call__(self, x, hidden=None):
        """Run the model over a batch of token-id sequences.

        Args:
            x: Integer token ids of shape (batch, seq_len).
            hidden: Optional starting hidden state of shape
                (batch, hidden_dim); zeros are used when omitted.

        Returns:
            logits: Array of shape (batch, seq_len, vocab_size).
            hidden: Hidden state after the final time step.
        """
        n_batch, n_steps = x.shape
        h = jnp.zeros((n_batch, self.hidden_dim)) if hidden is None else hidden

        tokens = self.embedding.value[x]  # (batch, seq_len, embedding_dim)

        # Unroll the cell step by step, keeping every intermediate state.
        per_step = []
        for step in range(n_steps):
            h = self.rnn_cell(tokens[:, step, :], h)
            per_step.append(h)

        states_over_time = jnp.stack(per_step, axis=1)  # (batch, seq_len, hidden_dim)
        logits = jax.vmap(self.fc_out)(states_over_time)
        return logits, h

# Create model
rnn_model = RNNLanguageModel(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=64,
    hidden_dim=128
)

# Test forward pass
test_input = jnp.array(X_train[:4])
test_logits, test_hidden = rnn_model(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Logits shape: {test_logits.shape}")
print(f"Hidden shape: {test_hidden.shape}")

# Count parameters leaf-wise: a ParamState's value is not guaranteed to be a
# single array (a layer may store e.g. a dict of weight and bias), and such a
# container has no `.size` of its own.
n_params = sum(
    leaf.size
    for state in rnn_model.states(bst.ParamState).values()
    for leaf in jax.tree_util.tree_leaves(state.value)
)
print(f"\nTotal parameters: {n_params:,}")

## 3. LSTM Language Model

In [None]:
class LSTMLanguageModel(bst.graph.Node):
    """Character-level language model built on an LSTM cell.

    Same layout as the vanilla-RNN model, except that the recurrent state is
    a ``(hidden, cell)`` pair.
    """

    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=128):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        # Trainable token lookup table (one row per token), small random init.
        init_table = bst.random.randn(vocab_size, embedding_dim) * 0.1
        self.embedding = bst.ParamState(init_table)

        # Recurrent transition, then projection of hidden states to logits.
        self.lstm_cell = bst.nn.LSTMCell(embedding_dim, hidden_dim)
        self.fc_out = bst.nn.Linear(hidden_dim, vocab_size)

    def __call__(self, x, state=None):
        """Run the model over a batch of token-id sequences.

        Args:
            x: Integer token ids of shape (batch, seq_len).
            state: Optional ``(hidden, cell)`` pair, each of shape
                (batch, hidden_dim); both default to zeros.

        Returns:
            logits: Array of shape (batch, seq_len, vocab_size).
            state: ``(hidden, cell)`` after the final time step.
        """
        n_batch, n_steps = x.shape
        if state is None:
            zeros = jnp.zeros((n_batch, self.hidden_dim))
            state = (zeros, zeros)
        h, c = state

        tokens = self.embedding.value[x]  # (batch, seq_len, embedding_dim)

        # Unroll the cell; only the hidden half of the state feeds the read-out.
        per_step = []
        for step in range(n_steps):
            h, c = self.lstm_cell(tokens[:, step, :], (h, c))
            per_step.append(h)

        states_over_time = jnp.stack(per_step, axis=1)  # (batch, seq_len, hidden_dim)
        logits = jax.vmap(self.fc_out)(states_over_time)
        return logits, (h, c)

# Create LSTM model
lstm_model = LSTMLanguageModel(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=64,
    hidden_dim=128
)

# Test forward pass
test_logits, test_state = lstm_model(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Logits shape: {test_logits.shape}")
print(f"Hidden state shape: {test_state[0].shape}")
print(f"Cell state shape: {test_state[1].shape}")

# Count parameters leaf-wise: a ParamState's value may be a container (e.g. a
# dict of weight and bias) rather than one array, so `.size` alone can fail.
n_params = sum(
    leaf.size
    for state in lstm_model.states(bst.ParamState).values()
    for leaf in jax.tree_util.tree_leaves(state.value)
)
print(f"\nTotal parameters: {n_params:,}")

## 4. Training Setup

In [None]:
def cross_entropy_loss(logits, targets):
    """Mean per-token cross-entropy between logits and integer targets.

    Args:
        logits: Unnormalised scores of shape (batch, seq_len, vocab_size).
        targets: Integer class ids of shape (batch, seq_len).

    Returns:
        Scalar negative log-likelihood averaged over all batch * seq_len tokens.
    """
    _, _, n_classes = logits.shape

    # Flatten batch and time into one token axis.
    flat_log_probs = jax.nn.log_softmax(logits.reshape(-1, n_classes), axis=-1)
    target_mask = jax.nn.one_hot(targets.reshape(-1), n_classes)

    # Select each token's target log-probability via the one-hot mask.
    token_nll = -jnp.sum(target_mask * flat_log_probs, axis=-1)
    return jnp.mean(token_nll)

def perplexity(loss):
    """Return the perplexity ``exp(loss)`` implied by a mean cross-entropy loss."""
    ppl = jnp.exp(loss)
    return ppl

def train_step(model, x_batch, y_batch, learning_rate=0.001):
    """Perform one plain-SGD training step and return the batch loss.

    Args:
        model: Node whose call returns ``(logits, extra)`` for input ids.
        x_batch: Input token ids, shape (batch, seq_len).
        y_batch: Target token ids, shape (batch, seq_len).
        learning_rate: SGD step size.

    Returns:
        The loss computed before the update, as a Python float.
    """
    with bst.environ.context(fit=True):
        # Collect the parameter states once; reused for the grad call and the update.
        param_states = model.states(bst.ParamState)
        inputs = jnp.asarray(x_batch)
        targets = jnp.asarray(y_batch)

        def loss_fn():
            logits, _ = model(inputs)
            return cross_entropy_loss(logits, targets)

        # With return_value=True, bst.augment.grad returns (grads, value):
        # gradients first, loss value second.
        grads, loss = bst.augment.grad(
            loss_fn,
            param_states,
            return_value=True
        )()

        # Update leaf-wise: a ParamState's value may be a container (e.g. a
        # dict of weight and bias) for which in-place `-=` is not defined.
        for name, grad in grads.items():
            state = param_states[name]
            state.value = jax.tree_util.tree_map(
                lambda p, g: p - learning_rate * g, state.value, grad
            )

        return float(loss)

def eval_step(model, x_batch, y_batch):
    """Return the cross-entropy of ``model`` on one batch, computed with fit=False."""
    with bst.environ.context(fit=False):
        predictions, _ = model(jnp.array(x_batch))
        batch_loss = cross_entropy_loss(predictions, jnp.array(y_batch))
        return float(batch_loss)

# Sanity-check the untrained LSTM's loss on one batch; the training step
# itself is exercised by the training loop below.
batch_size = 32
x_batch, y_batch = X_train[:batch_size], y_train[:batch_size]

initial_loss = eval_step(lstm_model, x_batch, y_batch)
print(f"Initial loss: {initial_loss:.4f}")
print(f"Initial perplexity: {perplexity(initial_loss):.2f}")

## 5. Train LSTM Model

In [None]:
def train_epoch(model, X, y, batch_size, learning_rate):
    """Train ``model`` for one shuffled pass over ``(X, y)``.

    Args:
        model: Model passed through to ``train_step``.
        X: Input sequences, indexable by an integer array.
        y: Target sequences aligned with ``X``.
        batch_size: Maximum number of sequences per step.
        learning_rate: SGD step size.

    Returns:
        Mean training loss over the epoch. Each batch is weighted by its size,
        so a short final batch does not count as much as a full one. Returns
        ``nan`` when ``X`` is empty.
    """
    n_samples = len(X)
    if n_samples == 0:
        return float('nan')

    # Shuffle data
    order = np.random.permutation(n_samples)

    weighted_sum = 0.0
    for start in range(0, n_samples, batch_size):
        # Slicing clips at the end, so the last batch may be shorter.
        batch_indices = order[start:start + batch_size]
        loss = train_step(model, X[batch_indices], y[batch_indices], learning_rate)
        weighted_sum += loss * len(batch_indices)

    return weighted_sum / n_samples

def evaluate(model, X, y, batch_size):
    """Compute the mean loss of ``model`` over ``(X, y)`` in fixed-order batches.

    Args:
        model: Model passed through to ``eval_step``.
        X: Input sequences.
        y: Target sequences aligned with ``X``.
        batch_size: Maximum number of sequences per batch.

    Returns:
        Per-sequence mean loss. Batches are weighted by their size, so the
        result does not depend on how the data happens to split into batches.
        Returns ``nan`` when ``X`` is empty.
    """
    n_samples = len(X)
    if n_samples == 0:
        return float('nan')

    weighted_sum = 0.0
    for start in range(0, n_samples, batch_size):
        x_batch = X[start:start + batch_size]
        y_batch = y[start:start + batch_size]
        weighted_sum += eval_step(model, x_batch, y_batch) * len(x_batch)

    return weighted_sum / n_samples

# Training configuration
config = {
    'num_epochs': 50,
    'batch_size': 64,
    'learning_rate': 0.003,
}

# Create fresh model
model = LSTMLanguageModel(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=64,
    hidden_dim=128,
)

# Per-epoch loss curves, filled in by the loop below.
history = {'train_loss': [], 'val_loss': []}
divider = "=" * 60
last_epoch = config['num_epochs'] - 1

print("Training LSTM Language Model")
print(divider)

for epoch in range(config['num_epochs']):
    # One shuffled pass over the training set, then a full validation sweep.
    train_loss = train_epoch(
        model, X_train, y_train, config['batch_size'], config['learning_rate']
    )
    val_loss = evaluate(model, X_val, y_val, config['batch_size'])

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    # Report every 10th epoch and always the final one.
    if epoch == last_epoch or epoch % 10 == 0:
        print(f"Epoch {epoch:2d}: "
              f"train_loss={train_loss:.4f} (ppl={perplexity(train_loss):.2f}), "
              f"val_loss={val_loss:.4f} (ppl={perplexity(val_loss):.2f})")

print(divider)
print("Training completed!")

### Visualize Training

In [None]:
# Plot training curves: loss on the left, perplexity (exp of loss) on the right.
fig, (loss_ax, ppl_ax) = plt.subplots(1, 2, figsize=(14, 5))

epochs = range(len(history['train_loss']))
train_ppl = [perplexity(loss) for loss in history['train_loss']]
val_ppl = [perplexity(loss) for loss in history['val_loss']]

panels = [
    (loss_ax, history['train_loss'], history['val_loss'],
     'Train Loss', 'Val Loss', 'Loss', 'Training and Validation Loss'),
    (ppl_ax, train_ppl, val_ppl,
     'Train Perplexity', 'Val Perplexity', 'Perplexity',
     'Training and Validation Perplexity'),
]

for ax, train_curve, val_curve, train_label, val_label, y_label, title in panels:
    ax.plot(epochs, train_curve, 'b-', label=train_label, linewidth=2)
    ax.plot(epochs, val_curve, 'r-', label=val_label, linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Text Generation

In [None]:
def generate_text(model, start_text, length=100, temperature=1.0):
    """Generate text using the trained model.

    At every step the model is run from a fresh (zero) state on the most
    recent ``seq_length`` characters, the same conditions as the
    ``seq_length``-character training windows. The previous version carried
    the recurrent state from the last call *and* re-fed the whole window, so
    the network saw every context character twice. It also left-padded with
    token 0, which is a real character in the vocabulary. No padding is
    needed here because the models accept any sequence length.

    Args:
        model: Trained language model whose call ``model(x)`` returns
            ``(logits, extra)`` with logits shaped (1, seq_len, vocab_size).
        start_text: Non-empty starting string; every character must be in
            the tokenizer's vocabulary.
        length: Number of characters to generate
        temperature: Sampling temperature (> 0; higher = more random)

    Returns:
        ``start_text`` followed by the generated characters.

    Raises:
        ValueError: if ``start_text`` is empty or ``temperature`` is not positive.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    generated = [int(idx) for idx in tokenizer.encode(start_text)]

    with bst.environ.context(fit=False):
        for _ in range(length):
            # Context = last seq_length characters, as a batch of one.
            window = jnp.asarray([generated[-seq_length:]], dtype=jnp.int32)
            logits, _ = model(window)

            # Distribution over the next character from the last time step.
            scaled = logits[0, -1, :] / temperature
            probs = np.asarray(jax.nn.softmax(scaled), dtype=np.float64)
            probs /= probs.sum()  # renormalise in float64 for np.random.choice

            next_idx = int(np.random.choice(probs.size, p=probs))
            generated.append(next_idx)

    return tokenizer.decode(generated)

# Sample continuations of the same prompt at increasing temperatures.
start_text = "The "
temperatures = [0.5, 1.0, 1.5]
rule = "-" * 70

print("Generated Text Samples:")
print("=" * 70)

for temp in temperatures:
    sample = generate_text(model, start_text, length=150, temperature=temp)
    print(f"\nTemperature = {temp}:")
    print(sample)
    print(rule)
generated = sample

## 7. Attention Mechanism

In [None]:
class ScaledDotProductAttention(bst.graph.Node):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(dim)) V``."""

    def __init__(self, dim):
        super().__init__()
        # sqrt of the query/key feature size used to scale the raw scores.
        self.scale = jnp.sqrt(dim)

    def __call__(self, query, key, value, mask=None):
        """Compute attention.

        Any number of leading (batch/head) axes is supported; the last two
        axes are (positions, features).

        Args:
            query: Query tensor of shape (..., seq_len_q, dim)
            key: Key tensor of shape (..., seq_len_k, dim)
            value: Value tensor of shape (..., seq_len_k, dim_v)
            mask: Optional boolean mask broadcastable to
                (..., seq_len_q, seq_len_k); True marks positions that may be
                attended to.

        Returns:
            output: Attended values, shape (..., seq_len_q, dim_v)
            attention_weights: Shape (..., seq_len_q, seq_len_k); each row
                sums to 1.
        """
        # Swap only the last two axes so extra leading axes (e.g. heads) work.
        scores = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / self.scale

        if mask is not None:
            # Most negative finite value of the scores' dtype. A literal -1e9
            # overflows to -inf in float16, and a fully masked row would then
            # produce NaN in the softmax.
            scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)

        attention_weights = jax.nn.softmax(scores, axis=-1)
        output = jnp.matmul(attention_weights, value)

        return output, attention_weights

class AttentionLSTM(bst.graph.Node):
    """LSTM language model with causal self-attention over its hidden states.

    Each position's LSTM output is concatenated with an attention-weighted
    summary of the LSTM outputs at that position and earlier ones, then
    projected to vocabulary logits.
    """

    def __init__(self, vocab_size, embedding_dim=64, hidden_dim=128):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        # Embedding
        self.embedding = bst.ParamState(
            bst.random.randn(vocab_size, embedding_dim) * 0.1
        )

        # LSTM
        self.lstm_cell = bst.nn.LSTMCell(embedding_dim, hidden_dim)

        # Attention
        self.attention = ScaledDotProductAttention(hidden_dim)

        # Output: reads [LSTM output, attended summary], hence 2 * hidden_dim.
        self.fc_out = bst.nn.Linear(hidden_dim * 2, vocab_size)

    def __call__(self, x):
        """Forward pass with causal attention.

        Args:
            x: Integer token ids of shape (batch, seq_len).

        Returns:
            logits: Shape (batch, seq_len, vocab_size).
            attn_weights: Shape (batch, seq_len, seq_len); row t is
                (numerically) zero beyond column t.
        """
        batch_size, seq_len = x.shape

        # Embed
        embedded = self.embedding.value[x]

        # LSTM processing from a zero state
        hidden = jnp.zeros((batch_size, self.hidden_dim))
        cell = jnp.zeros((batch_size, self.hidden_dim))

        lstm_outputs = []
        for t in range(seq_len):
            hidden, cell = self.lstm_cell(embedded[:, t, :], (hidden, cell))
            lstm_outputs.append(hidden)

        lstm_outputs = jnp.stack(lstm_outputs, axis=1)  # (batch, seq_len, hidden_dim)

        # Causal mask: position t may only attend to positions <= t. Without
        # it, the output at t could read LSTM states that already consumed
        # x[t+1], which is exactly the character being predicted.
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

        # Self-attention
        attended, attn_weights = self.attention(
            lstm_outputs, lstm_outputs, lstm_outputs, mask=causal_mask
        )

        # Concatenate LSTM output and attended output
        combined = jnp.concatenate([lstm_outputs, attended], axis=-1)

        # Output projection
        logits = jax.vmap(self.fc_out)(combined)

        return logits, attn_weights

# Create attention model
attn_model = AttentionLSTM(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=64,
    hidden_dim=128
)

# Test
test_logits, test_attn = attn_model(test_input)
print(f"Logits shape: {test_logits.shape}")
print(f"Attention weights shape: {test_attn.shape}")

# Count parameters leaf-wise: a ParamState's value may be a container (e.g. a
# dict of weight and bias) rather than one array, so `.size` alone can fail.
n_params = sum(
    leaf.size
    for state in attn_model.states(bst.ParamState).values()
    for leaf in jax.tree_util.tree_leaves(state.value)
)
print(f"\nTotal parameters: {n_params:,}")

### Visualize Attention

In [None]:
# Get attention weights for a single training window. attn_model is not
# trained in this notebook, so the pattern reflects its random initialisation.
sample_input = jnp.array(X_train[0:1])  # Single sequence

with bst.environ.context(fit=False):
    _, attn_weights = attn_model(sample_input)

input_text = tokenizer.decode(X_train[0])

# One tick label per input character, so rows and columns map to characters.
# Newline and space would be invisible, so draw them as "\n" and a middle dot.
tick_labels = [{'\n': '\\n', ' ': '·'}.get(ch, ch) for ch in input_text]
positions = range(len(tick_labels))

# Plot attention heatmap
fig, ax = plt.subplots(figsize=(10, 8))
image = ax.imshow(np.asarray(attn_weights[0]), cmap='viridis', aspect='auto')
fig.colorbar(image, ax=ax, label='Attention Weight')
ax.set_xticks(positions)
ax.set_xticklabels(tick_labels)
ax.set_yticks(positions)
ax.set_yticklabels(tick_labels)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('Self-Attention Weights')

# Caption with the raw input. repr() keeps it on one line, since the window
# may contain newline characters (this one starts with one).
ax.text(0.5, -0.1, f'Input: {input_text!r}',
        transform=ax.transAxes, ha='center')

fig.tight_layout()
plt.show()

## 8. Simple Transformer Block

In [None]:
class TransformerBlock(bst.graph.Node):
    """Post-norm Transformer block: multi-head self-attention + feed-forward.

    Each sub-layer is applied as ``LayerNorm(x + sublayer(x))``. The previous
    version ignored ``num_heads`` and attended over the raw inputs with no
    learned projections. This one projects to queries, keys and values,
    attends independently in ``num_heads`` heads of size
    ``dim // num_heads``, and mixes the heads with an output projection.
    """

    def __init__(self, dim, num_heads=4, ff_dim=256):
        super().__init__()
        if num_heads < 1 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Learned query/key/value projections and head-mixing output projection.
        self.q_proj = bst.nn.Linear(dim, dim)
        self.k_proj = bst.nn.Linear(dim, dim)
        self.v_proj = bst.nn.Linear(dim, dim)
        self.out_proj = bst.nn.Linear(dim, dim)

        # Attention runs per head, so scores are scaled by sqrt(head_dim).
        self.attention = ScaledDotProductAttention(self.head_dim)

        # Layer normalization
        self.ln1 = bst.nn.LayerNorm([dim])
        self.ln2 = bst.nn.LayerNorm([dim])

        # Feed-forward network
        self.ff1 = bst.nn.Linear(dim, ff_dim)
        self.ff2 = bst.nn.Linear(ff_dim, dim)

    def _split_heads(self, t):
        """Reshape (batch, seq, dim) -> (batch * num_heads, seq, head_dim)."""
        batch, seq, _ = t.shape
        t = t.reshape(batch, seq, self.num_heads, self.head_dim)
        t = t.transpose(0, 2, 1, 3)  # (batch, heads, seq, head_dim)
        return t.reshape(batch * self.num_heads, seq, self.head_dim)

    def _merge_heads(self, t, batch):
        """Reshape (batch * num_heads, seq, head_dim) -> (batch, seq, dim)."""
        seq = t.shape[1]
        t = t.reshape(batch, self.num_heads, seq, self.head_dim)
        t = t.transpose(0, 2, 1, 3)  # (batch, seq, heads, head_dim)
        return t.reshape(batch, seq, self.dim)

    def __call__(self, x, mask=None):
        """Forward pass.

        Args:
            x: Input of shape (batch, seq_len, dim)
            mask: Optional boolean attention mask (True = may attend), of
                shape (seq_len, seq_len) or (batch, seq_len, seq_len).

        Returns:
            Output of same shape as input
        """
        batch = x.shape[0]

        # Project, then fold heads into the leading axis: (batch*heads, seq, head_dim).
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        if mask is not None:
            mask = jnp.asarray(mask)
            if mask.ndim == 3 and mask.shape[0] != 1:
                # Per-example mask: repeat it for each head so row b*heads+h
                # lines up with example b after _split_heads.
                mask = jnp.repeat(mask, self.num_heads, axis=0)

        heads, _ = self.attention(q, k, v, mask)
        attn_out = self.out_proj(self._merge_heads(heads, batch))

        # Self-attention with residual
        x = self.ln1(x + attn_out)

        # Feed-forward with residual
        ff_out = self.ff2(jax.nn.relu(self.ff1(x)))
        x = self.ln2(x + ff_out)

        return x

class SimpleTransformer(bst.graph.Node):
    """Simple causal (decoder-only) Transformer for sequence modeling.

    token embedding + sinusoidal positions -> stacked Transformer blocks with
    a causal mask -> per-position vocabulary logits.
    """

    def __init__(self, vocab_size, dim=128, num_layers=2, num_heads=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim

        # Embedding
        self.embedding = bst.ParamState(
            bst.random.randn(vocab_size, dim) * 0.1
        )

        # Transformer blocks (also registered as attributes so the graph
        # sees their states)
        self.blocks = []
        for i in range(num_layers):
            block = TransformerBlock(dim, num_heads, ff_dim=dim*4)
            self.blocks.append(block)
            setattr(self, f'block_{i}', block)

        # Output
        self.fc_out = bst.nn.Linear(dim, vocab_size)

    @staticmethod
    def _positional_encoding(seq_len, dim):
        """Fixed sinusoidal position encodings of shape (seq_len, dim)."""
        positions = jnp.arange(seq_len)[:, None]
        freq_idx = jnp.arange(0, dim, 2)
        inv_freq = jnp.exp(-jnp.log(10000.0) * freq_idx / dim)
        angles = positions * inv_freq[None, :]  # (seq_len, ceil(dim / 2))
        pe = jnp.zeros((seq_len, dim))
        pe = pe.at[:, 0::2].set(jnp.sin(angles))
        pe = pe.at[:, 1::2].set(jnp.cos(angles[:, : dim // 2]))
        return pe

    def __call__(self, x):
        """Forward pass.

        Args:
            x: Integer token ids of shape (batch, seq_len).

        Returns:
            Logits of shape (batch, seq_len, vocab_size); position t depends
            only on tokens 0..t.
        """
        seq_len = x.shape[1]

        # Embed
        h = self.embedding.value[x]

        # Self-attention is permutation-equivariant, so without position
        # information the model cannot tell character order apart. Add fixed
        # sinusoidal encodings (no extra parameters).
        h = h + self._positional_encoding(seq_len, self.dim).astype(h.dtype)

        # Causal mask: each position attends only to itself and earlier ones.
        # Otherwise position t could read x[t+1], the token it must predict.
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

        # Transformer blocks
        for block in self.blocks:
            h = block(h, causal_mask)

        # Output projection
        logits = jax.vmap(self.fc_out)(h)

        return logits

# Create Transformer
transformer = SimpleTransformer(
    vocab_size=tokenizer.vocab_size,
    dim=128,
    num_layers=2,
    num_heads=4
)

# Test
test_logits = transformer(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_logits.shape}")

# Count parameters leaf-wise: a ParamState's value may be a container (e.g. a
# dict of weight and bias) rather than one array, so `.size` alone can fail.
n_params = sum(
    leaf.size
    for state in transformer.states(bst.ParamState).values()
    for leaf in jax.tree_util.tree_leaves(state.value)
)
print(f"\nTotal parameters: {n_params:,}")

## Summary

In this tutorial, we built sequence models:

1. **Character-Level Tokenization**: Simple text encoding
2. **RNN Language Model**: Basic recurrent architecture
3. **LSTM Language Model**: Improved with LSTM cells
4. **Training Pipeline**: Complete training loop for sequences
5. **Text Generation**: Sampling strategies with temperature
6. **Attention Mechanism**: Scaled dot-product attention
7. **Attention Visualization**: Heatmaps of attention weights
8. **Transformer Block**: Self-attention with feed-forward

## Key Takeaways

- **RNNs process sequences step-by-step**
- **LSTMs handle long-term dependencies better**
- **Attention allows focusing on relevant parts**
- **Transformers replace recurrence with self-attention (plus position-wise feed-forward layers)**
- **Temperature controls generation randomness**
- **Perplexity measures language model quality**
- Sequence modeling requires careful state management

## Next Steps

In the next tutorial, we'll explore:
- **Brain-Inspired Computing**: Spiking neural networks
- Neurodynamics and biological models
- Plasticity and learning rules
- Event-driven computation