# Char RNN: Training a recurrent neural network for character-level language modeling

In this notebook, we will train a recurrent neural network to perform character-level language modeling on the classic corpus of Shakespeare's works.

In [None]:
import sys

sys.path.append("..")
import os
import numpy as np
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
import jax.random as jr
from functools import partial
from tqdm import tqdm
import matplotlib.pyplot as plt

# Using PyTorch's DataLoader for convenience (optional)
import torch
from torch.utils.data import Dataset, DataLoader

from rnn_jax.cells.gated import LongShortTermMemoryCell, LongShortTermMemory
from rnn_jax.layers import RNNEncoder, DeepRNN

print("Imports OK")

## Data Preprocessing

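Now we read and preprocess the corpus, which is stored in a single `.txt` file. Newlines are replaced by spaces, so the text becomes one continuous stream of characters.
We then index each distinct character with a unique integer from $0$ to $n_{vocab} - 1$, where $n_{vocab}$ is the vocabulary size.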

In [None]:
# Read and preprocess the corpus (Shakespeare).
file_path = "shakespeare.txt"
if not os.path.exists(file_path):
    # Explicit raise instead of `assert`, which is skipped under `python -O`.
    raise FileNotFoundError(f"File not found: {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
    text = f.read()
# Normalize Windows line endings, then turn every newline into a single space
# so the text is one continuous character stream. Runs of blank lines become
# runs of spaces; they are not collapsed.
text = text.replace("\r\n", "\n").replace("\n", " ")
# wrap with "<" / ">" markers for clarity
text = f"<{text}>"
print("Loaded text length:", len(text))
# Build the character-level vocabulary: every distinct character gets an
# integer id in [0, vocab_size), assigned in sorted character order.
chars = sorted(set(text))
char_to_index = {c: i for i, c in enumerate(chars)}
index_to_char = dict(enumerate(chars))
vocab_size = len(chars)
print("Vocab size:", vocab_size)

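Feeding the whole text to the network as a single sequence would make training far too slow.
For this reason, we split the text into chunks of 128 characters, using a sliding window with a stride of 32 characters; the windows overlap, so most characters appear in several training sequences. The target for each chunk is the same chunk shifted one character to the right, i.e. the model learns to predict the next character.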

The `BookDataset` class is a small utility that wraps the input and target arrays in a PyTorch `Dataset`, so that they can be shuffled and batched by a `DataLoader`.

In [None]:
# Create sliding-window sequences for char-level modeling.
# Each input is `seq_len` consecutive character ids; its target is the same
# window shifted one character to the right (next-character prediction).
seq_len = 128
step = 32
encoded = np.array([char_to_index[c] for c in text], dtype=np.int32)
if len(encoded) < seq_len + 1:
    raise ValueError(
        f"Corpus has {len(encoded)} characters; at least {seq_len + 1} are "
        "needed to build a single (input, target) pair."
    )
# All windows of length seq_len + 1, starting at 0, step, 2*step, ... up to
# and including the last valid start, len(encoded) - seq_len - 1.
windows = np.lib.stride_tricks.sliding_window_view(encoded, seq_len + 1)[::step]
# np.array(...) copies the read-only strided view into contiguous, writable
# arrays (torch.from_numpy warns on non-writable arrays).
inputs = np.array(windows[:, :-1], dtype=np.int32)
targets = np.array(windows[:, 1:], dtype=np.int32)
print(f"Built {len(inputs):,} sequences")
# downsample for smoke-testing if dataset is too large (max_samples <= 0 disables it)
max_samples = -1
if max_samples > 0 and len(inputs) > max_samples:
    rng = np.random.default_rng(0)  # seeded so the subset is reproducible
    idxs = rng.choice(len(inputs), size=max_samples, replace=False)
    inputs = inputs[idxs]
    targets = targets[idxs]
    print(f"Downsampled dataset to {len(inputs)} samples")
print("Dataset shape (inputs, targets):", inputs.shape, targets.shape)
# Decode the first input window back to text as a sanity check.
print("".join(index_to_char[i] for i in inputs[0]))


# Torch dataset wrapper (simple)
class BookDataset(Dataset):
    """Map-style PyTorch dataset pairing input windows with target windows.

    Both NumPy arrays are wrapped with ``torch.from_numpy`` (so the tensors
    share memory with the arrays), and each item is returned as a pair of
    ``int64`` tensors.

    Args:
        X: NumPy array of input sequences, indexed along axis 0.
        Y: NumPy array of target sequences, aligned with ``X`` along axis 0.
    """

    def __init__(self, X, Y):
        self.X, self.Y = (torch.from_numpy(arr) for arr in (X, Y))

    def __len__(self):
        # One sample per entry along the leading axis of the inputs.
        return self.X.size(0)

    def __getitem__(self, idx):
        sample_in = self.X[idx]
        sample_out = self.Y[idx]
        # `.long()` converts to int64.
        return sample_in.long(), sample_out.long()


batch_size = 128
dataset = BookDataset(inputs, targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print("Prepared dataloader with batch size", batch_size)

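This example also shows how easily new models can be built on top of the cells provided by the library.
For instance, `CharRNN` is simply an embedding layer, followed by an RNN encoder (here wrapping a single LSTM cell) and a linear layer that maps each hidden state to logits over the vocabulary.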

In [None]:
class CharRNN(eqx.Module):
    """Character-level language model: embedding -> RNN encoder -> linear decoder.

    Operates on a single, unbatched sequence of character ids; wrap the model
    in ``eqx.filter_vmap`` to process a batch.
    """

    embed: eqx.nn.Embedding  # maps a character id to an `emb_dim` vector
    encoder: RNNEncoder  # wraps an LSTM cell; applied to the embedded sequence
    decoder: eqx.nn.Linear  # projects a `hidden_dim` vector to `vocab_size` logits

    def __init__(self, vocab_size, emb_dim, hidden_dim, *, key):
        """Build the model.

        Args:
            vocab_size: number of distinct characters (embedding rows and
                decoder outputs).
            emb_dim: size of each character embedding (LSTM input size).
            hidden_dim: LSTM hidden size (decoder input size).
            key: JAX PRNG key, split into one sub-key per sub-module.
        """
        k_embed, k_cell, k_encoder, k_decoder = jr.split(key, 4)
        self.embed = eqx.nn.Embedding(vocab_size, emb_dim, key=k_embed)
        cell = LongShortTermMemoryCell(emb_dim, hidden_dim, key=k_cell)
        self.encoder = RNNEncoder(cell=cell, key=k_encoder)
        self.decoder = eqx.nn.Linear(hidden_dim, vocab_size, key=k_decoder)

    def __call__(self, x):
        """Compute next-character logits for one sequence.

        Args:
            x: integer array of shape ``(seq_len,)`` holding character ids.
                Already-embedded ``(seq_len, emb_dim)`` inputs are NOT
                supported: the embedding layer needs scalar indices.

        Returns:
            Logits with one row of ``vocab_size`` values per hidden vector
            returned by the encoder (presumably ``(seq_len, vocab_size)`` --
            depends on ``RNNEncoder`` returning every step; TODO confirm).

        Raises:
            ValueError: if ``x`` is not one-dimensional.
        """
        if jnp.ndim(x) != 1:
            raise ValueError(
                f"CharRNN expects a 1-D sequence of character ids, got shape {jnp.shape(x)}"
            )
        x_emb = eqx.filter_vmap(self.embed)(x)  # (seq_len, emb_dim)
        hidden_all = self.encoder(x_emb)
        # Apply the decoder independently to every hidden vector (leading axis).
        logits = jax.vmap(self.decoder)(hidden_all)
        return logits


Now we create a simple model: a 256-dimensional character embedding followed by a single LSTM layer with 512 hidden units.

In [None]:
# Model hyper-parameters: 256-d character embeddings feeding a single LSTM
# layer with 512 hidden units. vocab_size equals the one computed during
# preprocessing (same expression).
vocab_size = len(chars)
emb_dim, hidden_dim = 256, 512
key = jr.PRNGKey(0)
model = CharRNN(
    vocab_size=vocab_size,
    emb_dim=emb_dim,
    hidden_dim=hidden_dim,
    key=key,
)
print("Model created: vocab", vocab_size, "emb", emb_dim, "hidden", hidden_dim)

In [None]:
# Gradients are clipped to a global norm of 1.0, then RMSProp is applied with
# a constant learning rate of 1e-3. No learning-rate schedule is used; to add
# one, pass an optax schedule (e.g. optax.cosine_decay_schedule) to rmsprop
# instead of the constant.
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.rmsprop(1e-3))

# Optimizer state covers only the floating-point array leaves of the model.
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_value_and_grad(has_aux=True)
def forward_and_loss(model, X, Y):
    """Mean token-level cross-entropy of ``model`` on a batch.

    The model is vmapped over the leading (batch) axis of ``X``; ``Y`` holds
    the integer target ids. Because of ``has_aux=True``, the decorated
    function returns ``((loss, logits), grads)`` where the gradients are taken
    with respect to the model's floating-point leaves.
    """
    batched_model = eqx.filter_vmap(model)
    logits = batched_model(X)  # (batch, seq_len, vocab)
    per_token = optax.softmax_cross_entropy_with_integer_labels(logits, Y)
    return per_token.mean(), logits


@eqx.filter_jit
def update(model, opt_state, X, Y):
    """Run one jitted optimization step on the batch ``(X, Y)``.

    Returns the updated model, the updated optimizer state and the batch loss
    (computed before the update).
    """
    trainable = eqx.filter(model, eqx.is_inexact_array)
    (loss, _logits), grads = forward_and_loss(model, X, Y)
    updates, new_opt_state = optimizer.update(grads, opt_state, params=trainable)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, loss

def train_model(model, opt_state, dataloader, n_epochs=1):
    """Train ``model`` for ``n_epochs`` passes over ``dataloader``.

    Each batch is converted from torch tensors to JAX arrays and passed to the
    module-level jitted ``update`` step.

    Args:
        model: the Equinox model to train.
        opt_state: optimizer state matching ``model``.
        dataloader: iterable of ``(inputs, targets)`` torch tensor batches.
        n_epochs: number of passes over ``dataloader``.

    Returns:
        ``(model, opt_state, b_losses)`` where ``b_losses`` lists the loss of
        every batch across ALL epochs, in training order (empty when
        ``n_epochs`` is 0).
    """
    # Accumulate across epochs; resetting per epoch would keep only the last.
    b_losses = []
    for epoch in range(n_epochs):
        print(f"Epoch {epoch + 1}/{n_epochs}")
        bar = tqdm(dataloader)
        for Xb, Yb in bar:
            # Bridge torch -> JAX through NumPy.
            Xa = jnp.asarray(Xb.numpy())
            Ya = jnp.asarray(Yb.numpy())
            model, opt_state, loss = update(model, opt_state, Xa, Ya)
            batch_loss = float(loss)  # single host transfer per batch
            b_losses.append(batch_loss)
            bar.set_description(f"batch loss: {batch_loss:.4f}")
    return model, opt_state, b_losses

In [None]:
# Train for one epoch on a loader with 32 samples per batch (the earlier
# `dataloader` uses batch_size=128 and is not used here).
small_dataloader = DataLoader(dataset, shuffle=True, batch_size=32)
model, opt_state, b_losses = train_model(model, opt_state, small_dataloader, 1)

In [None]:
# Plot the per-batch training losses returned by `train_model`. A plain line
# (no per-point markers) stays readable with thousands of batches.
fig, ax = plt.subplots()
ax.plot(b_losses, alpha=0.7)
ax.set_xlabel("Batch")
ax.set_ylabel("Loss")
ax.set_title("Training loss per batch")
plt.show()


In [None]:
def generate_text(model, seed, max_len=100, temperature=1.0, key=None):
    """Generate text by sampling characters from ``model`` one at a time.

    All seed characters except the last are fed through the recurrent cell to
    warm up its state. The last seed character is the input of the first
    sampling step, and each sampled character is fed back as the next input.

    Args:
        model: ``CharRNN``-style module exposing ``embed``, ``encoder.cell``
            (with an ``hdim`` attribute) and ``decoder``.
        seed: non-empty priming string; every character must appear in the
            module-level ``char_to_index`` vocabulary.
        max_len: number of characters to sample after the seed.
        temperature: positive softmax temperature; values < 1 sharpen the
            distribution, values > 1 flatten it.
        key: optional ``jax.random`` PRNG key for reproducible sampling. When
            ``None``, a fresh key is drawn from NumPy's global RNG at every
            step, so results vary between calls.

    Returns:
        ``seed`` followed by ``max_len`` sampled characters.

    Raises:
        ValueError: if ``seed`` is empty or ``temperature`` is not positive.
        KeyError: if ``seed`` contains a character outside the vocabulary.
    """
    if not seed:
        raise ValueError("seed must be a non-empty string")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    idxs = [char_to_index[c] for c in seed]
    hdim = model.encoder.cell.hdim
    # Initial recurrent state: a pair of zero vectors of size hdim
    # (presumably the LSTM (hidden, cell) pair -- TODO confirm with the cell).
    state = jnp.zeros((hdim,)), jnp.zeros((hdim,))
    print("---------")
    # Warm up with the seed. `model.embed` maps a scalar index to a full
    # embedding vector, which is passed to the cell unchanged.
    for idx in idxs[:-1]:
        state, _ = model.encoder.cell(model.embed(idx), state)
    pieces = [seed]
    # Generate new characters
    for _ in range(max_len):
        x_emb = model.embed(idxs[-1])
        state, h = model.encoder.cell(x_emb, state)
        logits = model.decoder(h)
        probs = jax.nn.softmax(logits / temperature)
        if key is None:
            step_key = jr.PRNGKey(np.random.randint(0, 1_000_000))
        else:
            key, step_key = jr.split(key)
        # Sample over the decoder's output size, i.e. the vocabulary.
        next_idx = int(jax.random.choice(step_key, logits.shape[-1], p=probs))
        idxs.append(next_idx)
        pieces.append(index_to_char[next_idx])
    return "".join(pieces)

In [None]:
# Sample from the trained model at several temperatures. Each sample is bound
# to `sample` rather than `text`, so the corpus loaded earlier is not
# overwritten (re-running the preprocessing cells would otherwise encode the
# last generated sample instead of the book).
for temp in [0.1, 0.75, 1.0, 1.5]:
    print(f"Generating text with temperature = {temp}")
    sample = generate_text(model, seed=" ", max_len=1000, temperature=temp)
    print(sample)
    print("=" * 30)

At low temperatures, the model seems to imitate the style of the original Shakespeare corpus well. Raising the temperature too far produces increasingly incoherent, nonsensical samples.