# Text Generation with LSTMs in PyTorch


## Step 1: Setting Up the Environment and Configuration


In [None]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Training corpus: a tiny example string the model learns to continue.
TEXT_DATA = "hello world this is a simple text for our lstm model"

# Model shape.
SEQUENCE_LENGTH = 5   # context window: characters seen before predicting the next one
HIDDEN_SIZE = 128     # width of the LSTM hidden state
NUM_LAYERS_LSTM = 1   # how many LSTM layers are stacked

# Optimisation.
NUM_EPOCHS = 200      # training iterations
LEARNING_RATE = 1e-3

# Generation.
GENERATION_START_TEXT = "hello"  # seed text that primes generation
GENERATION_LENGTH = 3            # how many new characters to produce

## Step 2: Preparing Text Data


In [None]:
# Sorted unique characters give a deterministic vocabulary: the same text
# always produces the same char <-> index mapping.
unique_chars = sorted(set(TEXT_DATA))
vocab_size = len(unique_chars)  # size of the character vocabulary

# Lookups between characters and their integer indices.
char_to_idx = {ch: idx for idx, ch in enumerate(unique_chars)}
idx_to_char = {idx: ch for idx, ch in enumerate(unique_chars)}

# Slide a SEQUENCE_LENGTH-character window over the text: each window is one
# input sample and the character right after it is that sample's target.
encoded_text = [char_to_idx[c] for c in TEXT_DATA]
window_starts = range(len(encoded_text) - SEQUENCE_LENGTH)
sequences_idx = [encoded_text[i:i + SEQUENCE_LENGTH] for i in window_starts]
targets_idx = [encoded_text[i + SEQUENCE_LENGTH] for i in window_starts]

# One-hot encode the inputs for the LSTM by selecting rows of an identity
# matrix (vectorised, instead of setting elements one by one in Python).
# The explicit reshape keeps the index tensor 2-D even when there are no windows.
# Shape: (number_of_sequences, sequence_length, vocabulary_size), float32.
sequence_index_tensor = torch.tensor(sequences_idx, dtype=torch.long).reshape(len(sequences_idx), SEQUENCE_LENGTH)
input_sequences_one_hot = torch.eye(vocab_size, dtype=torch.float32)[sequence_index_tensor]

# Target class indices, the format nn.CrossEntropyLoss expects.
target_tensor = torch.tensor(targets_idx, dtype=torch.long)

## Step 3: Building the LSTM Model


In [None]:
class TextGeneratorLSTM(nn.Module):
    """Character-level LSTM that scores the next character of a sequence.

    Args:
        input_size: Size of each input feature vector (the vocabulary size
            when inputs are one-hot encoded characters).
        hidden_size: Number of features in the LSTM hidden state.
        output_size: Number of output classes (the vocabulary size).
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True: input/output tensors are (batch, seq, features).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)

        # Maps the LSTM hidden state to one unnormalised score (logit) per
        # vocabulary entry; softmax / CrossEntropyLoss turn these into probabilities.
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hc=None):
        """Run the LSTM over ``x`` and score the next character.

        Args:
            x: Tensor of shape (batch, seq, input_size).
            hc: Optional ``(h_0, c_0)`` tuple, each of shape
                (num_layers, batch, hidden_size). ``None`` starts from zero states.

        Returns:
            ``(logits, (h_n, c_n))`` where ``logits`` has shape
            (batch, output_size) and is computed from the last time step only.
        """
        # When hc is None, nn.LSTM zero-initialises (h_0, c_0) itself using the
        # input's dtype and device. A hand-made float32 torch.zeros init would
        # break models converted to float64/float16.
        out, hidden_state = self.lstm(x, hc)

        # The last time step's output summarises the whole input sequence.
        logits = self.fc(out[:, -1, :])

        # The final hidden state is returned so callers can continue from it.
        return logits, hidden_state

## Step 4: Training the LSTM Model


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = TextGeneratorLSTM(input_size=vocab_size,
                          hidden_size=HIDDEN_SIZE,
                          output_size=vocab_size,  # one logit per vocabulary character
                          num_layers=NUM_LAYERS_LSTM)
model.to(device)

# CrossEntropyLoss consumes raw logits plus integer class targets
# (the index of the true next character).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Every epoch trains on the whole dataset as one full batch, so the data is
# copied to the device once here rather than on every iteration.
inputs_one_hot_batch = input_sequences_one_hot.to(device)
targets_batch = target_tensor.to(device)

print(f"Starting training for {NUM_EPOCHS} epochs...")
for epoch in range(NUM_EPOCHS):
    model.train()  # training mode (affects dropout etc., if the model had any)

    optimizer.zero_grad()  # clear gradients left over from the previous step

    # hc is not passed, so each sequence starts from a zero hidden state and
    # the samples are treated independently.
    output_logits, _ = model(inputs_one_hot_batch)

    loss = criterion(output_logits, targets_batch)
    loss.backward()   # backpropagate to compute gradients
    optimizer.step()  # update the weights

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {loss.item():.4f}')
print("Training complete.")

Using device: cpu
Starting training for 200 epochs...
Epoch [10/200], Loss: 2.7648
Epoch [20/200], Loss: 2.6198
Epoch [30/200], Loss: 2.5377
Epoch [40/200], Loss: 2.4807
Epoch [50/200], Loss: 2.3800
Epoch [60/200], Loss: 2.1800
Epoch [70/200], Loss: 1.8761
Epoch [80/200], Loss: 1.4604
Epoch [90/200], Loss: 1.0242
Epoch [100/200], Loss: 0.6256
Epoch [110/200], Loss: 0.3596
Epoch [120/200], Loss: 0.2154
Epoch [130/200], Loss: 0.1342
Epoch [140/200], Loss: 0.0885
Epoch [150/200], Loss: 0.0628
Epoch [160/200], Loss: 0.0475
Epoch [170/200], Loss: 0.0375
Epoch [180/200], Loss: 0.0305
Epoch [190/200], Loss: 0.0254
Epoch [200/200], Loss: 0.0215
Training complete.


## Step 5: Generating Text


In [None]:
def generate_text(model, start_text, length, char_to_idx, idx_to_char, sequence_length, vocab_size, device):
    """Greedily generate ``length`` characters after ``start_text``.

    At every step the model sees the last ``sequence_length`` characters (a
    sliding window), starting from a fresh zero hidden state. This matches how
    the model is trained: each window is processed independently. The
    highest-scoring next character is then appended.

    Carrying the previous step's hidden state while also re-feeding the
    overlapping window would make the model see the same characters twice.

    Args:
        model: Module called as ``model(x)`` that returns ``(logits, hidden)``;
            ``x`` is a one-hot float tensor of shape (1, window, vocab_size).
        start_text: Seed string. Characters missing from ``char_to_idx`` are
            skipped (with a warning) for conditioning but kept in the output.
        length: Number of characters to generate.
        char_to_idx: Mapping from character to vocabulary index.
        idx_to_char: Mapping from vocabulary index to character.
        sequence_length: Maximum number of trailing characters fed to the model.
        vocab_size: Width of the one-hot encoding.
        device: Device on which the model inputs are created.

    Returns:
        The seed text (or the vocabulary's index-0 character when no seed
        character is known) followed by the generated characters.
    """
    model.eval()  # evaluation mode (disables dropout etc., if the model has any)

    # Encode the seed, keeping only characters the vocabulary knows.
    current_indices = []
    for c in start_text:
        if c in char_to_idx:
            current_indices.append(char_to_idx[c])
        else:
            print(f"Warning: Character '{c}' in start_text not found in vocabulary. Skipping.")

    if not current_indices:
        print(f"Warning: No valid characters in start_text '{start_text}'. Using a default seed.")
        # Fallback: seed with the character at vocabulary index 0.
        first_char_idx = 0
        current_indices = [first_char_idx]
        generated_text = idx_to_char[first_char_idx]
    else:
        generated_text = start_text

    # Rows of the identity matrix are the one-hot vectors; built once, reused.
    one_hot_rows = torch.eye(vocab_size, dtype=torch.float32, device=device)
    new_chars = []

    with torch.no_grad():  # inference only: no autograd bookkeeping
        for _ in range(length):
            window = current_indices[-sequence_length:]
            window_idx = torch.tensor(window, dtype=torch.long, device=device)
            model_input = one_hot_rows[window_idx].unsqueeze(0)  # (1, window, vocab_size)

            # No hidden state is passed: each window starts from zero, as in training.
            output_logits, _ = model(model_input)

            # Greedy decoding: pick the highest-scoring character.
            predicted_char_idx = int(output_logits.argmax(dim=1).item())
            new_chars.append(idx_to_char[predicted_char_idx])
            current_indices.append(predicted_char_idx)

    return generated_text + "".join(new_chars)

print("\nStarting text generation...")

# Generate GENERATION_LENGTH characters after the configured seed using the
# trained model and the vocabulary built during data preparation.
generated_output = generate_text(
    model=model,
    start_text=GENERATION_START_TEXT,
    length=GENERATION_LENGTH,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
    sequence_length=SEQUENCE_LENGTH,
    vocab_size=vocab_size,
    device=device,
)

# Show the seed next to the full result (normally the seed plus new characters).
print(f"Seed Text: '{GENERATION_START_TEXT}'")
print(f"Generated Text: '{generated_output}'")


Starting text generation...
Seed Text: 'hello'
Generated Text: 'hello hi'
