# Tutorial: Building a GRU Language Model with PyTorch
 
**Target Audience:** MSc Computer Science Students  
**Topic:** Natural Language Processing, Recurrent Neural Networks  
 
## 1. Introduction
 
In this tutorial, we will build a generative Language Model (LM) using a Gated Recurrent Unit (GRU). The goal of a language model is to predict the next token in a sequence given the previous context. Formally, given a sequence of tokens $x_1, x_2, ..., x_t$, the model attempts to learn the probability distribution:
 
$P(x_{t+1} | x_1, ..., x_t) $
 
We will use a **Character-Level** model for this demonstration. This means the model processes text character-by-character (including spaces and punctuation). While less semantic than word-level models, character-level models have a smaller vocabulary and are excellent for demonstrating the mechanics of RNNs without massive datasets.

## 2. Theoretical Background: The GRU

Standard Recurrent Neural Networks (RNNs) suffer from the **vanishing gradient problem**. During backpropagation through time (BPTT), gradients can shrink exponentially as they propagate backward, making it difficult for the model to learn long-range dependencies.
 
The *Gated Recurrent Unit (GRU)*, introduced by Cho et al. (2014), solves this using gating mechanisms that regulate the flow of information. It is mathematically similar to the LSTM but more computationally efficient.

## Mathematical Formulation

At time step $t$, given input $x_t$ and previous hidden state $h_{t-1}$, the GRU computes:

### 1.  *Reset Gate ($r_t$):* Determines how much of the past information to forget.
### $r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr})$

### 2.  *Update Gate ($z_t$):* Determines how much of the past state to carry over to the new state.
### $z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz})$

### 3.  *Candidate Hidden State ($\tilde{n}_t$):* Computes new candidate information.
### $\tilde{n}_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})) $

### 4.  *Final Hidden State ($h_t$):* Linear interpolation between the previous state and the candidate state.
### $ h_t = (1 - z_t) \odot \tilde{n}_t + z_t \odot h_{t-1} $

Where $\sigma$ is the sigmoid function, $\odot$ is the Hadamard (element-wise) product, and $W$ and $b$ are learnable weights and biases.


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import time

In [2]:
# Set seed for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- CRITICAL IMPORTS FOR TPU/XLA ---
# Note: torch_xla is only needed for Google Cloud TPU environments
# This notebook will work fine without it on CPU/CUDA/MPS
XLA_AVAILABLE = False
xm = None

# Uncomment the following lines if running on TPU:
# try:
#     import torch_xla.core.xla_model as xm
#     XLA_AVAILABLE = True
# except ImportError:
#     XLA_AVAILABLE = False
#     print("WARNING: torch_xla not found. Running on CPU/CUDA fallback.")
# --- END XLA IMPORTS ---

# Set device for PyTorch operations
if XLA_AVAILABLE:
    # Use xm.xla_device() to get the primary TPU core device
    DEVICE = xm.xla_device()
    N_DEVICES = 1 # Force single device count
    print(f"Using Single XLA Device: {DEVICE}")
elif torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.cuda.manual_seed_all(SEED)
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')

print(f'Using device: {DEVICE}')


# Uncomment the following lines if running on Google Colab
# from google.colab import drive
# drive.mount('/content/drive')

Using device: mps


## 3. The Dataset

### For this tutorial, we will use a small dataset of Aesop's Fables. In a real-world scenario, you might use Project Gutenberg texts or the WikiText dataset.
### We simulate loading a file by defining a raw string below.

In [3]:
RAW_TEXT = """
Title: The Wolf and the Lamb.
Wolf, meeting with a Lamb astray from the fold, resolved not to lay violent hands on him, but to find some plea to justify to the Lamb the Wolf's right to eat him. 
He thus addressed him: "Sirrah, last year you grossly insulted me." 
"Indeed," bleated the Lamb in a mournful tone of voice, "I was not then born." 
Then said the Wolf, "You feed in my pasture." 
"No, good sir," replied the Lamb, "I have not yet tasted grass." 
Again said the Wolf, "You drink of my well." 
"No," exclaimed the Lamb, "I never yet drank water, for as yet my mother's milk is both food and drink to me.
" Upon which the Wolf seized him and ate him up, saying, "Well! I won't remain supperless, even though you refute every one of my imputations.
" The tyrant will always find a pretext for his tyranny.

Title: The Bat and the Weasels.
A Bat who fell upon the ground and was caught by a Weasel pleaded to be spared his life. 
The Weasel refused, saying that he was by nature the enemy of all birds. 
The Bat assured him that he was not a bird, but a mouse, and thus was set free. 
Shortly afterwards the Bat again fell to the ground and was caught by another Weasel, whom he likewise entreated not to eat him. 
The Weasel said that he had a special hostility to mice. 
The Bat assured him that he was not a mouse, but a bird, and thus escaped. 
It is wise to turn circumstances to good account.

Title: The Ass and the Grasshopper.
An Ass having heard some Grasshoppers chirping, was highly enchanted; and, desiring to possess the like charms of melody, demanded what sort of food they lived on to give them such beautiful voices. 
They replied, "The dew." The Ass resolved that he would live only upon dew, and in a short time died of hunger.

Title: The Lion and the Mouse.
A Lion was awakened from sleep by a Mouse running over his face. 
Rising up angrily, he caught him and was about to kill him, when the Mouse piteously entreated, saying: "If you would only spare my life, I would be sure to repay your kindness." 
The Lion laughed and let him go. 
It happened shortly after this that the Lion was caught by some hunters, who bound him by strong ropes to the ground. 
The Mouse, recognizing his roar, came and gnawed the rope with his teeth, and set him free, exclaiming: "You ridiculed the idea of my ever being able to help you, not expecting to receive from me any repayment of your favor; now you know that it is possible for even a Mouse to con benefits on a Lion."
"""

## Preprocessing: Clean slightly and analyze stats

In [4]:
text = RAW_TEXT.strip()
print(f"Dataset length: {len(text)} characters")

Dataset length: 2487 characters


## 4. Preprocessing
 
### Neural Networks operate on numbers (tensors), not strings. We need to pipeline our data:
#### 1.  **Vocabulary Construction:** Find the set of unique characters.
#### 2.  **Indexing:** Create mappings `char -> int` and `int -> char`.
#### 3.  **Sequence Generation:** Create sliding windows of text.

In [5]:
# 1. Vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Unique characters: {vocab_size}")
print(f"Vocabulary: {''.join(chars)}")

# 2. Mappings
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

# 3. Encode the entire text
encoded_text = [char_to_idx[ch] for ch in text]
encoded_tensor = torch.tensor(encoded_text, dtype=torch.long)

Unique characters: 48
Vocabulary: 
 !"',.:;ABGHILMNRSTUWYabcdefghijklmnoprstuvwxyz


## Custom Dataset Class
 
#### We will create sequences of length `seq_len`. 
* **Input ($x$):** Characters at indices $[i, i+seq\_len-1]$
* **Target ($y$):** Characters at indices $[i+1, i+seq\_len]$
 
#### For example, if text is "HELLO" and `seq_len`=3:
* Input: "HEL"
* Target: "ELL"

In [7]:
class TextDataset(Dataset):
    def __init__(self, text_tensor, seq_len):
        self.text_tensor = text_tensor
        self.seq_len = seq_len
        
    def __len__(self):
        # We can fit len(text) - seq_len sequences
        return len(self.text_tensor) - self.seq_len
    
    def __getitem__(self, idx):
        # Input sequence
        input_seq = self.text_tensor[idx : idx + self.seq_len]
        # Target sequence (shifted by 1)
        target_seq = self.text_tensor[idx + 1 : idx + self.seq_len + 1]
        
        return input_seq, target_seq

# Hyperparameters for Data
SEQ_LEN = 100
BATCH_SIZE = 32

dataset = TextDataset(encoded_tensor, SEQ_LEN)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Verify shapes
x_batch, y_batch = next(iter(dataloader))
print(f"Input batch shape: {x_batch.shape}") # [Batch, Seq_Len]
print(f"Target batch shape: {y_batch.shape}")

Input batch shape: torch.Size([32, 100])
Target batch shape: torch.Size([32, 100])



## 5. Model Architecture

### We will define a `GRULanguageModel` class inheriting from `nn.Module`.

### **Architecture:**
### 1.  **Embedding Layer:** Converts integer indices to dense vectors.
###     * Input: `(Batch, Seq_Len)`
###     * Output: `(Batch, Seq_Len, Embedding_Dim)`
### 2.  **GRU Layer:** Processes the sequence.
###     * Input: `(Batch, Seq_Len, Embedding_Dim)`
###     * Output: `(Batch, Seq_Len, Hidden_Dim)`
### 3.  **Fully Connected (Linear) Layer:** Maps hidden state to vocabulary logits.
###     * Input: `(Batch, Seq_Len, Hidden_Dim)`
###     * Output: `(Batch, Seq_Len, Vocab_Size)`

In [8]:
class GRULanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=1):
        super(GRULanguageModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
        # 1. Embedding Layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # 2. GRU Layer
        # batch_first=True expects input shape (batch, seq, feature)
        self.gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        
        # 3. Output Layer
        self.fc = nn.Linear(hidden_dim, vocab_size)
        
    def forward(self, x, hidden):
        # x shape: (batch_size, seq_len)
        
        # Embed: (batch_size, seq_len, embedding_dim)
        embeds = self.embedding(x)
        
        # GRU Forward
        # out shape: (batch_size, seq_len, hidden_dim)
        # hidden shape: (num_layers, batch_size, hidden_dim)
        out, hidden = self.gru(embeds, hidden)
        
        # Flatten output for Linear layer
        # reshape to (batch_size * seq_len, hidden_dim)
        out = out.contiguous().view(-1, self.hidden_dim)
        
        # Project to vocabulary size
        # out shape: (batch_size * seq_len, vocab_size)
        out = self.fc(out)
        
        return out, hidden
    
    def init_hidden(self, batch_size):
        # Initialize hidden state with zeros
        weight = next(self.parameters()).data
        return weight.new(self.num_layers, batch_size, self.hidden_dim).zero_()


## 6. Training Loop

### We use **CrossEntropyLoss** which combines `LogSoftmax` and `NLLLoss`.
### Note on dimensions: PyTorch's `CrossEntropyLoss` expects:
### * Input (Logits): $(N, C)$ where $C$ is class count (vocab size).
### * Target: $(N)$ where values are indices $[0, C-1]$.
### This is why we flattened the output in the `forward` method.

In [9]:
# Model Hyperparameters
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2
LEARNING_RATE = 0.002
EPOCHS = 50

# Instantiate Model, Loss, and Optimizer
model = GRULanguageModel(vocab_size, EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(model)

# Calculate total trainable parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Bi-directional GRU Model Parameters: {total_params:,}")

GRULanguageModel(
  (embedding): Embedding(48, 128)
  (gru): GRU(128, 256, num_layers=2, batch_first=True)
  (fc): Linear(in_features=256, out_features=48, bias=True)
)
Bi-directional GRU Model Parameters: 709,680


In [10]:
# Training
loss_history = []
model.train()

print(f"Starting training for {EPOCHS} epochs...")
start_time = time.time()

for epoch in range(EPOCHS):
    # Initialize hidden state for the batch
    # Note: In stateful RNN training, we might detach hidden states between batches 
    # to prevent backpropagating through the entire dataset history, 
    # but keep the values to maintain context. 
    # Here, for simplicity in a shuffled loader, we init hidden per batch.
    h = model.init_hidden(BATCH_SIZE)
    
    epoch_loss = 0
    
    for x, y in dataloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        
        # Detach hidden state to prevent backpropagating through history 
        # beyond the current batch (Truncated BPTT)
        h = h.data 
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        output, h = model(x, h)
        
        # Reshape target to align with output (batch_size * seq_len)
        y = y.view(-1)
        
        # Compute loss
        loss = criterion(output, y)
        
        # Backward pass
        loss.backward()
        
        # Gradient Clipping (prevents exploding gradients, common in RNNs)
        nn.utils.clip_grad_norm_(model.parameters(), 5)
        
        if XLA_AVAILABLE:
            # XLA specific optimization step
            xm.optimizer_step(optimizer)
            xm.mark_step() # Signal end of computation step to XLA
        else:
            optimizer.step()
        
        epoch_loss += loss.item()
        
    avg_loss = epoch_loss / len(dataloader)
    loss_history.append(avg_loss)
    
    if (epoch+1) % 10 == 0:
        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f}")

print(f"Training finished in {time.time()-start_time:.2f}s")

# Save the model
# Use xm.save if XLA is available, otherwise standard torch.save
save_model = xm.save if XLA_AVAILABLE else torch.save
save_model(model.state_dict(), "gru_chara_lm.pth")

"""
# Plotting Loss
plt.figure(figsize=(10,5))
plt.plot(loss_history)
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()
"""

Starting training for 50 epochs...
Epoch 10/50 | Loss: 0.0883
Epoch 20/50 | Loss: 0.0802
Epoch 30/50 | Loss: 0.0761
Epoch 40/50 | Loss: 0.0750
Epoch 50/50 | Loss: 0.0736
Training finished in 249.26s


'\n# Plotting Loss\nplt.figure(figsize=(10,5))\nplt.plot(loss_history)\nplt.title("Training Loss")\nplt.xlabel("Epoch")\nplt.ylabel("Loss")\nplt.show()\n'

## 7. Text Generation

### To generate text, we:
### 1.  Feed a "seed" string into the model to build up the hidden state.
### 2.  Predict the next character probability distribution.
### 3.  **Sample** from this distribution (don't just take the `argmax`, or the text will be repetitive).

### Temperature Sampling
### We use a hyperparameter $T$ (temperature) to control randomness.
### $$ P_i = \frac{\exp(z_i / T)}{\sum \exp(z_j / T)} $$
### **High T (>1.0):** Flattens distribution (more random/creative).
### **Low T (<1.0):** Sharpens distribution (more confident/conservative).

In [13]:
def generate_text(model, start_str="The", len_generated=200, temperature=0.8):
    model.eval()
    
    # Initialize hidden state (batch size 1)
    hidden = model.init_hidden(1)
    
    input_seq = torch.tensor([char_to_idx[ch] for ch in start_str]).unsqueeze(0).to(DEVICE)
    
    generated_text = start_str
    
    with torch.no_grad():
        # 1. Build up hidden state with start_str
        # We run the whole seed sequence through. 
        # We only care about the hidden state output after the last char.
        out, hidden = model(input_seq, hidden)
        
        # The input for the next step is the last character of the seed
        last_char_idx = input_seq[:, -1]
        
        for i in range(len_generated):
            # Forward pass with single character
            # Reshape input to (1, 1) -> (Batch, Seq)
            current_input = last_char_idx.unsqueeze(1)
            
            # Embed and GRU
            # Note: We manually perform the forward logic here because our 
            # model.forward flattens output, which is slightly annoying for generation loop.
            # Let's use the model components directly for clarity.
            
            emb = model.embedding(current_input)
            out, hidden = model.gru(emb, hidden)
            
            # Output shape: (1, 1, hidden) -> (1, hidden)
            out = out.squeeze(1)
            logits = model.fc(out)
            
            # Apply temperature
            probs = torch.softmax(logits / temperature, dim=1).cpu().numpy()[0]
            
            # Sample from distribution
            next_char_idx = np.random.choice(vocab_size, p=probs)
            
            # Append result
            generated_text += idx_to_char[next_char_idx]
            
            # Update input for next step
            last_char_idx = torch.tensor([next_char_idx]).to(DEVICE)
            
    return generated_text

# Model Hyperparameters
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2

# Instantiate and load the model
model = GRULanguageModel(vocab_size, EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS).to(DEVICE)
model.load_state_dict(torch.load("gru_chara_lm.pth"))
model.eval()

# Demo 1: Moderate Temperature
print("--- Generated Text (T=0.8) ---")
print(generate_text(model, start_str="The Lion", len_generated=300, temperature=0.8))

print("\n--- Generated Text (T=0.2, Conservative) ---")
print(generate_text(model, start_str="The Lion", len_generated=300, temperature=0.2))

--- Generated Text (T=0.8) ---
The Lion feed him and ate him up, saying, "Well! I won't remain supperless, even though you refute every one of my imputations.
" The tyrant will always find a pretext for his tyranny.

Title: The Bat and the Weasels.
A Bat who fell upon the ground and was caught by another Weasel, whom he likewise entreate

--- Generated Text (T=0.2, Conservative) ---
The Lion and the Mouse.
A Lion was awakened from sleep by a Mouse running over his face. 
Rising up angrily, he caught him and was about to kill him, when the Mouse piteously entreated, saying: "If you would only spare my life, I would be sure to repay your kindness." 
The Lion laughed and let him go. 
It h
