# How Large Language Models Use Sliding Windows for Next-Word Prediction — With PyTorch

## Introduction

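In this post, we’ll dive into how Large Language Models (LLMs) learn to predict the next word in a sequence, and how the sliding window technique turns raw text into training examples for that task. Next-token prediction is the training objective behind models like GPT-3, and the sliding window is a simple way to build examples for it.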

We’ll explain:

- How to tokenize text and convert it into a form that models can learn from.
- How to implement the sliding window approach using PyTorch’s Dataset class.
- How to structure the solution across multiple Python files for clarity and maintainability.

By the end of this post, you’ll have a solid understanding of how LLMs learn to predict text, along with a reusable code structure for further experiments.

## Understanding the Sliding Window Approach and Next-Word Prediction

Next-Word Prediction is the core task that LLMs are trained on. When a model generates text, it predicts what word (or token) comes next in a sentence based on the preceding context. For example, if the model sees the phrase “The cat sat on the,” it should predict that “mat” is likely to follow.
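Under the hood, a model doesn’t emit a word directly. It assigns a score (a logit) to every token in its vocabulary, and a softmax turns those scores into probabilities. The sketch below uses a made-up four-word vocabulary and made-up scores. No real model is involved; it only illustrates how "mat" would come out as the most likely continuation:

```python
import math

# Hypothetical vocabulary and made-up scores (logits) that a model might
# assign after reading "The cat sat on the"
vocab = ["mat", "dog", "moon", "the"]
logits = [3.0, 1.0, 0.5, 0.1]

def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax(logits)
for word, p in zip(vocab, probs):
    print(f"{word:>5}: {p:.3f}")

# Greedy choice: take the highest-probability token as the prediction
predicted = vocab[probs.index(max(probs))]
print("Predicted next word:", predicted)
```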

The Sliding Window Approach is a method to create training examples for this task. Instead of feeding the model the entire sentence at once, we divide the text into smaller, overlapping chunks or “windows.” Each window contains a fixed number of tokens that serve as input, and the token immediately following the window is the target for prediction.

Here’s how the sliding window works in practice:

- Window Size: The number of tokens the model sees at one time. For instance, with a window size of 4, the model sees four tokens (in this post, each word is one token).
- Stride: How many tokens the window moves forward between consecutive training examples. With a stride of 1, consecutive windows overlap in all but one token. With a stride of 2, the window moves two tokens, so every other starting position is skipped.

For example, given the sentence:

```
"The quick brown fox jumps over the lazy dog."
```
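With a window size of 4 and a stride of 1, splitting on whitespace (so the final token is "dog."), the input-output pairs generated are:

```
Input: ["The", "quick", "brown", "fox"] → Output: "jumps"
Input: ["quick", "brown", "fox", "jumps"] → Output: "over"
Input: ["brown", "fox", "jumps", "over"] → Output: "the"
Input: ["fox", "jumps", "over", "the"] → Output: "lazy"
Input: ["jumps", "over", "the", "lazy"] → Output: "dog."
```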

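Each pair is one training example. The model sees the context window and is trained to assign high probability to the target token. Over many such windows it learns how words depend on the context before them, and this is what enables it to generate coherent text.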
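The pairs above can be generated with a few lines of Python. Here is a minimal sketch using naive whitespace splitting. The helper `sliding_window_pairs` is a hypothetical name and is not part of the project files. The sketch prints the pairs for a stride of 1 and a stride of 2:

```python
def sliding_window_pairs(tokens, window_size, stride=1):
    """Return (input_window, next_token) pairs from a list of tokens."""
    pairs = []
    # Stop early enough that a target token still exists after each window
    for start in range(0, len(tokens) - window_size, stride):
        window = tokens[start:start + window_size]
        target = tokens[start + window_size]
        pairs.append((window, target))
    return pairs

sentence = "The quick brown fox jumps over the lazy dog."
tokens = sentence.split()  # naive whitespace split: "dog." stays one token

for stride in (1, 2):
    print(f"stride={stride}")
    for window, target in sliding_window_pairs(tokens, window_size=4, stride=stride):
        print(f"  {window} -> {target!r}")
```

With a stride of 2, only three pairs remain, because every other starting position is skipped. In general, for N tokens, a window size w, and a stride s (with N > w), this loop produces floor((N - w - 1) / s) + 1 pairs. Here that gives floor(4 / 1) + 1 = 5 and floor(4 / 2) + 1 = 3.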

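The animation below shows a window of five tokens sliding across a sentence. The token right after the window is highlighted as the prediction target. The animation is generated by the code at the end of this post.

![A red window of five tokens slides across a sentence while the next token is highlighted as the prediction target](token_prediction.gif "Next-token prediction with a sliding window")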


## 1. Structuring the Code: Files and Organization
To keep things clean and maintainable, we’ll break the code into three files:

- tokenizer.py: For tokenizing the text and converting it to token IDs.
- dataset.py: Defines the dataset class using the PyTorch Dataset base class.
- main.py: The main script where we prepare the dataset and run the sliding window approach.

Here’s how we’ll organize the project:

```
llm_sliding_window/
│
├── tokenizer.py
├── dataset.py
└── main.py
```

## 2. Tokenizing Text (tokenizer.py)
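LLMs don’t work directly with raw text. They first split it into tokens and map each token to an integer ID. Production models use subword tokenizers such as Byte Pair Encoding. To keep things simple, in tokenizer.py we’ll define a word-level tokenizer that maps each unique word to a token ID.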

```python
# tokenizer.py
class SimpleTokenizer:
    def __init__(self, text):
        # Split text into words on whitespace (punctuation stays attached, e.g. "industry.")
        self.words = text.split()
        # Assign each unique word a token ID in order of first appearance.
        # dict.fromkeys preserves insertion order, so the IDs are the same on every run
        # (iterating over a set would depend on Python's string hash randomization).
        self.tokens = {word: idx for idx, word in enumerate(dict.fromkeys(self.words))}
        # Reverse mapping, built once and reused by decode()
        self.id_to_word = {idx: word for word, idx in self.tokens.items()}

    def tokenize(self):
        """Convert words to token IDs"""
        return [self.tokens[word] for word in self.words]

    def decode(self, token_ids):
        """Convert token IDs back to words"""
        return [self.id_to_word[token] for token in token_ids]

# Example usage
if __name__ == "__main__":
    text = "Lorem Ipsum is simply dummy text of the printing and typesetting industry."
    tokenizer = SimpleTokenizer(text)
    token_ids = tokenizer.tokenize()
    print("Token IDs:", token_ids)
```
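To see what "each unique word" means in practice, here is a standalone sketch that assigns IDs in order of first appearance to a sentence with a repeated word. The helper `build_vocab` is a hypothetical name and is not part of tokenizer.py. Both occurrences of "the" receive the same ID, and decoding recovers the original word list:

```python
def build_vocab(words):
    """Map each unique word to an integer ID in order of first appearance."""
    vocab = {}
    for word in words:
        if word not in vocab:
            vocab[word] = len(vocab)
    return vocab

words = "the cat sat on the mat".split()
vocab = build_vocab(words)
token_ids = [vocab[w] for w in words]
id_to_word = {i: w for w, i in vocab.items()}

print("Vocabulary:", vocab)
print("Token IDs:", token_ids)  # both occurrences of "the" share one ID
print("Decoded:", [id_to_word[i] for i in token_ids])
```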

## 3. Creating a Custom Dataset (dataset.py)
The heart of our solution is creating a custom dataset class that will generate input-output pairs (sliding windows) for training. We’ll inherit from PyTorch’s Dataset class to allow this dataset to be used with PyTorch’s data-loading and training tools.

```python
# dataset.py
import torch
from torch.utils.data import Dataset

class SlidingWindowDataset(Dataset):
    def __init__(self, tokenized_text, window_size, stride=1):
        self.tokenized_text = tokenized_text
        self.window_size = window_size
        self.stride = stride
        self.input_windows, self.output_tokens = self._generate_windows()
    
    def _generate_windows(self):
        """Generate input-output pairs using a sliding window approach."""
        input_windows = []
        output_tokens = []
        for i in range(0, len(self.tokenized_text) - self.window_size, self.stride):
            # Input window of size `window_size`
            input_windows.append(self.tokenized_text[i:i + self.window_size])
            # The next token (the one we want to predict)
            output_tokens.append(self.tokenized_text[i + self.window_size])
        return input_windows, output_tokens
    
    def __len__(self):
        """Return the total number of samples."""
        return len(self.input_windows)
    
    def __getitem__(self, idx):
        """Return a single input-output pair."""
        return torch.tensor(self.input_windows[idx]), torch.tensor(self.output_tokens[idx])

# Example usage
if __name__ == "__main__":
    tokenized_text = [2, 4, 5, 6, 7, 1, 3, 8, 9, 10, 11, 12]  # Example tokenized text
    dataset = SlidingWindowDataset(tokenized_text, window_size=4, stride=2)
    for i in range(len(dataset)):
        input_window, next_token = dataset[i]
        print(f"Input: {input_window}, Next token: {next_token}")
```
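If you’d rather avoid the Python loop, PyTorch’s `Tensor.unfold(dimension, size, step)` can cut all windows at once. The following is a sketch that assumes the whole token list fits in one tensor. It unfolds windows of `window_size + 1` tokens so that the last column of each window is the target. For the example above, it yields the same four pairs as `_generate_windows`. The function name `windows_with_unfold` is hypothetical.

```python
import torch

def windows_with_unfold(tokenized_text, window_size, stride=1):
    """Vectorized sliding windows: returns (inputs, targets) tensors."""
    tokens = torch.tensor(tokenized_text, dtype=torch.long)
    # Each chunk holds window_size context tokens plus the one target token
    chunks = tokens.unfold(0, window_size + 1, stride)
    inputs = chunks[:, :-1]  # shape: (num_windows, window_size)
    targets = chunks[:, -1]  # shape: (num_windows,)
    return inputs, targets

tokenized_text = [2, 4, 5, 6, 7, 1, 3, 8, 9, 10, 11, 12]
inputs, targets = windows_with_unfold(tokenized_text, window_size=4, stride=2)
print(inputs)
print(targets)
```

One difference from the loop version: `unfold` raises an error when the text is shorter than `window_size + 1` tokens, whereas the loop simply produces an empty dataset.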

## 4. Bringing It All Together (main.py)
Now that we have the SimpleTokenizer and SlidingWindowDataset ready, let’s use them together in the main.py file to simulate a real-world scenario. We’ll tokenize the text, create sliding windows, and print out the input-output pairs.

```python
# main.py
from torch.utils.data import DataLoader

from tokenizer import SimpleTokenizer
from dataset import SlidingWindowDataset

# Sample text
text = "Lorem Ipsum is simply dummy text of the printing and typesetting industry."

# Step 1: Tokenize the text
tokenizer = SimpleTokenizer(text)
tokenized_text = tokenizer.tokenize()
print("Tokenized text:", tokenized_text)

# Step 2: Create the dataset
window_size = 4
stride = 1
dataset = SlidingWindowDataset(tokenized_text, window_size=window_size, stride=stride)

# Step 3: Load the data using PyTorch's DataLoader (batch_size=1 for simplicity)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# Step 4: Iterate through the dataloader and print the input-output pairs
print("\nSliding Window Input-Output Pairs:")
for i, (input_window, next_token) in enumerate(dataloader):
    # input_window has shape (1, window_size): take the single sample in the batch
    # (indexing with [0] is safer than squeeze(), which also drops a size-1 window dimension)
    decoded_input = tokenizer.decode(input_window[0].tolist())
    decoded_token = tokenizer.decode([next_token.item()])
    print(f"Sample {i+1}: Input: {decoded_input} -> Predict: {decoded_token}")
```
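main.py uses `batch_size=1` for readability, but training normally uses larger batches. The sketch below shows what the default collate function does in that case: it stacks samples into an inputs tensor of shape `(batch, window_size)` and a targets tensor of shape `(batch,)`. To keep the sketch self-contained, it uses `TensorDataset` as a stand-in for `SlidingWindowDataset`. Since `SlidingWindowDataset.__getitem__` also returns a pair of tensors, batching should behave the same way.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for SlidingWindowDataset so the sketch runs on its own:
# 12 token IDs, window_size=4, stride=1 -> 8 (window, next token) pairs
tokenized_text = list(range(12))
window_size = 4
starts = range(len(tokenized_text) - window_size)
inputs = torch.tensor([tokenized_text[i:i + window_size] for i in starts])
targets = torch.tensor([tokenized_text[i + window_size] for i in starts])
dataset = TensorDataset(inputs, targets)

# batch_size > 1: samples are stacked along a new first dimension;
# shuffle=True reorders the pairs on every pass, as is usual for training
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_shapes = []
for inputs_batch, targets_batch in dataloader:
    batch_shapes.append((tuple(inputs_batch.shape), tuple(targets_batch.shape)))
    print("inputs:", tuple(inputs_batch.shape), "targets:", tuple(targets_batch.shape))
```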

## Explanation of the Workflow

- tokenizer.py: This file contains the SimpleTokenizer class, which splits the text into words, assigns token IDs to each unique word, and converts the text into a list of token IDs.
- dataset.py: We define the SlidingWindowDataset, which generates sliding window input-output pairs. The dataset inherits from PyTorch’s Dataset class to make it compatible with PyTorch’s training ecosystem. This custom dataset can be easily expanded and reused.
- main.py: This script brings everything together. It tokenizes the input text, creates sliding windows of token IDs, and uses PyTorch’s DataLoader to batch and iterate through the dataset. Each sliding window is shown along with the next token to predict.

## 5. Running the Code

Let’s now run the main.py file to see the sliding window technique in action. After tokenizing the text and generating sliding windows, the output will look like this:

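```
Tokenized text: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

Sliding Window Input-Output Pairs:
Sample 1: Input: ['Lorem', 'Ipsum', 'is', 'simply'] -> Predict: ['dummy']
Sample 2: Input: ['Ipsum', 'is', 'simply', 'dummy'] -> Predict: ['text']
Sample 3: Input: ['is', 'simply', 'dummy', 'text'] -> Predict: ['of']
Sample 4: Input: ['simply', 'dummy', 'text', 'of'] -> Predict: ['the']
Sample 5: Input: ['dummy', 'text', 'of', 'the'] -> Predict: ['printing']
Sample 6: Input: ['text', 'of', 'the', 'printing'] -> Predict: ['and']
Sample 7: Input: ['of', 'the', 'printing', 'and'] -> Predict: ['typesetting']
Sample 8: Input: ['the', 'printing', 'and', 'typesetting'] -> Predict: ['industry.']
```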

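Each sample pairs a four-word context with the word that follows it. With 12 tokens, a window size of 4, and a stride of 1, the dataset produces 8 such pairs. These are exactly the examples a model would be trained on to predict the next word from its context.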
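To see how such pairs feed into training, here is a rough sketch with a deliberately tiny, hypothetical model. Each context token is embedded, the embeddings are concatenated, and a linear layer scores every vocabulary entry. This is not how GPT-style models are built. Those use a Transformer and compute a next-token loss at every position. The sketch only shows that the (window, next token) pairs plug directly into a standard cross-entropy training loop:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data: 12 token IDs, window of 4, stride 1 -> 8 (context, next token) pairs
vocab_size, window_size, embed_dim = 12, 4, 16
tokens = torch.arange(vocab_size)
chunks = tokens.unfold(0, window_size + 1, 1)
inputs, targets = chunks[:, :-1], chunks[:, -1]

# Hypothetical toy model: embed, flatten, then one logit per vocabulary entry
model = nn.Sequential(
    nn.Embedding(vocab_size, embed_dim),             # (batch, window) -> (batch, window, embed_dim)
    nn.Flatten(),                                    # -> (batch, window * embed_dim)
    nn.Linear(window_size * embed_dim, vocab_size),  # -> (batch, vocab_size) logits
)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

losses = []
for step in range(200):
    logits = model(inputs)
    loss = loss_fn(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

predictions = model(inputs).argmax(dim=-1)
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
print("targets:    ", targets.tolist())
print("predictions:", predictions.tolist())
```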

## Conclusion

In this tutorial, we explored the sliding window technique used to train LLMs for next-word prediction. We built a simple but effective solution using PyTorch and organized our code into three separate files:

- tokenizer.py for tokenization,
- dataset.py for creating the sliding windows, and
- main.py for tying everything together and iterating through the dataset.

By following this approach, you can easily extend the solution to larger datasets, more complex tokenization schemes (like Byte Pair Encoding), and even start training your own language models!

## Next Steps
- Expand the tokenizer to handle more complex inputs.
- Experiment with different window sizes and strides to see their impact on data sampling.
- Use this dataset with a neural network for next-token prediction.

## Generating the Animated GIFs

The code below generates both animations with matplotlib. token_prediction.gif is the one embedded earlier in this post. Saving with `writer='pillow'` requires the Pillow package.

```python
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as patches

def create_sliding_window_animation(input_text, window_size, stride, output_file='sliding_window.gif'):
    # Set up the figure: a title/text panel on top and the word strip below.
    # gridspec_kw is used because passing height_ratios directly to
    # plt.subplots only works on newer matplotlib versions.
    fig, (ax_text, ax_main) = plt.subplots(2, 1, figsize=(12, 6), gridspec_kw={'height_ratios': [1, 3]})
    plt.subplots_adjust(hspace=0.3)

    # Process input text
    words = input_text.split()

    # Initialize the main plot
    ax_main.set_xlim(0, len(words))
    ax_main.set_ylim(0, 1)
    ax_main.set_xticks([])
    ax_main.set_yticks([])

    # Add the words to the main plot
    for i, word in enumerate(words):
        ax_main.text(i + 0.5, 0.5, word, ha='center', va='center', fontsize=10)

    # Create the sliding window
    window = patches.Rectangle((0, 0), window_size, 1, fill=False, edgecolor='red', lw=2)
    ax_main.add_patch(window)

    # Set up the text area
    ax_text.axis('off')
    title = ax_text.text(0.5, 0.7, "Sliding Window in LLM Training", ha='center', va='center', fontsize=14, fontweight='bold')
    window_text = ax_text.text(0.5, 0.3, '', ha='center', va='center', fontsize=12)

    def animate(frame):
        i = frame * stride
        window.set_x(i)
        current_words = words[i:i+window_size]
        window_text.set_text(f"Current window: {' '.join(current_words)}")
        return window, window_text

    # Calculate the number of frames based on stride
    num_frames = (len(words) - window_size) // stride + 1

    # Create the animation
    anim = animation.FuncAnimation(fig, animate, frames=num_frames, interval=1000, blit=True)

    # Save the animation as a gif
    anim.save(output_file, writer='pillow', fps=1)
    plt.close(fig)

    print(f"Animation saved as {output_file}")

# Example usage
input_text = "Lorem Ipsum is simply dummy text of the printing and typesetting industry."
create_sliding_window_animation(input_text=input_text, window_size=6, stride=2)
```

```
Animation saved as sliding_window.gif
```


```python
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as patches
import numpy as np

def create_token_prediction_animation(tokens, words, window_size, output_file='token_prediction.gif'):
    # Set up the figure: a title/text panel on top and the token/word strip below
    fig, (ax_text, ax_main) = plt.subplots(2, 1, figsize=(12, 6), gridspec_kw={'height_ratios': [1, 3]})
    plt.subplots_adjust(hspace=0.3)

    # Initialize the main plot
    ax_main.set_xlim(0, len(tokens) + 1)
    ax_main.set_ylim(0, 2)
    ax_main.set_xticks([])
    ax_main.set_yticks([])

    # Add the tokens (top row) and words (bottom row) to the main plot
    for i, (token, word) in enumerate(zip(tokens, words)):
        ax_main.text(i + 0.5, 1.5, str(token), ha='center', va='center', fontsize=10, color='black')  # Token ID
        ax_main.text(i + 0.5, 0.5, word, ha='center', va='center', fontsize=10, color='green')  # Word

    # Create the sliding window for context tokens
    window = patches.Rectangle((0, 0), window_size, 2, fill=False, edgecolor='red', lw=2)
    ax_main.add_patch(window)

    # Blue highlight for the target token and word, created once and moved every frame
    # (creating new text objects in each frame would leave old highlights on screen)
    target_token_text = ax_main.text(0, 1.5, '', ha='center', va='center', fontsize=10, color='blue')
    target_word_text = ax_main.text(0, 0.5, '', ha='center', va='center', fontsize=10, color='blue')

    # Set up the text area
    ax_text.axis('off')
    title = ax_text.text(0.5, 0.7, "LLM Predicting the Next Token", ha='center', va='center', fontsize=14, fontweight='bold')
    window_text = ax_text.text(0.5, 0.3, '', ha='center', va='center', fontsize=12)

    def animate(frame):
        i = frame
        current_tokens = tokens[i:i+window_size]
        current_words = words[i:i+window_size]
        next_token = tokens[i + window_size] if i + window_size < len(tokens) else None
        next_word = words[i + window_size] if i + window_size < len(words) else None

        # Update window position
        window.set_x(i)

        # Update the text to display the current context tokens and next predicted token
        context_text = f"Context tokens: {' '.join(map(str, current_tokens))}"
        context_words = f"Context words: {' '.join(current_words)}"
        prediction_text = f"Predicted next token: {next_token} ({next_word})" if next_token is not None else "End of sequence"
        window_text.set_text(f"{context_text}\n{context_words}\n{prediction_text}")

        # Move the highlight to the token and word just after the window
        target_x = i + window_size + 0.5
        target_token_text.set_position((target_x, 1.5))
        target_token_text.set_text('' if next_token is None else str(next_token))
        target_word_text.set_position((target_x, 0.5))
        target_word_text.set_text('' if next_word is None else next_word)

        return window, window_text, target_token_text, target_word_text

    # One frame per window that still has a next token to predict
    num_frames = len(tokens) - window_size

    # Create the animation
    anim = animation.FuncAnimation(fig, animate, frames=num_frames, interval=1000, blit=True)

    # Save the animation as a gif
    anim.save(output_file, writer='pillow', fps=1)
    plt.close(fig)

    print(f"Animation saved as {output_file}")

# Example usage
dummy_text = "Lorem Ipsum is simply dummy text of the printing and typesetting industry."
words = dummy_text.split()  # List of words
# Simulated token IDs for each word (random, so they change between runs)
tokens = list(np.random.randint(1000, 1100, len(words)))

create_token_prediction_animation(tokens=tokens, words=words, window_size=5)
```

```
Animation saved as token_prediction.gif
```
