<a href="https://colab.research.google.com/github/banno-0720/Deep-Learning-Projects/blob/main/Belief_State_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Belief State Transformer Paper Replication

Based on the research paper "The Belief State Transformer", which was published by Microsoft on February 20, 2025.

[Read the original research paper](https://arxiv.org/pdf/2410.23506).

For a better understanding of Belief State Transformers, [watch this video on the topic](https://youtu.be/aqhbRtB2Fyg?si=ABz33R6ZfdWue-mi).

## Why replicate a machine learning research paper?

A machine learning research paper is often a presentation of months of work and experiments done by some of the best machine learning teams in the world condensed into a few pages of text.

And if these experiments lead to better results in an area related to the problem you're working on, it'd be nice to check them out.

Also, replicating the work of others is a fantastic way to practice your skills.

<img src="https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/08-george-hotz-quote.png" width=600 alt="george hotz quote saying to get better at being a machine learning engineer, download a paper, implement it and keep going until you have skills"/>

*George Hotz is founder of [comma.ai](https://comma.ai/), a self-driving car company and livestreams machine learning coding on [Twitch](https://www.twitch.tv/georgehotz) and those videos get posted in full to [YouTube](https://www.youtube.com/c/georgehotzarchive). I pulled this quote from one of his livestreams. The "٭" is to note that machine learning engineering often involves the extra step(s) of preprocessing data and making your models available for others to use (deployment).*

## What is a Belief State?

For any probability distribution over a set of sequences $P(x_{1:T})$, and for any partial sequence $s = x_{1:t}$, we define a vector $v_s$ to be a **belief state** for $s$ if there exists a randomized function $g$ such that

$$
g(v_s) \sim P(x_{t+1:T} \mid x_{1:t}).
$$

In other words, sampling $g(v_s)$ yields a sample from the conditional distribution $P(x_{t+1:T} \mid x_{1:t})$.


By definition, a belief state captures all available information relevant for predicting the future tokens.
Once the belief state is learned, there is no additional useful information to be gained—everything
necessary for future predictions is already encoded within it.
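
To make the definition more concrete, here is a small illustration that is not from the paper. In a first-order Markov chain the future depends on the past only through the last token, so the last token by itself is a valid belief state. The transition table `TRANSITIONS` and the helper names `belief_state` and `g` are made up for this sketch.

```python
import random

# Hypothetical toy transition table for a first-order Markov chain.
TRANSITIONS = {
    "a": {"a": 0.1, "b": 0.9},
    "b": {"a": 0.5, "b": 0.5},
}

def belief_state(prefix):
    """For a first-order Markov chain, the last token summarises the whole prefix."""
    return prefix[-1]

def g(v_s, num_future_tokens, rng):
    """Randomized function g: samples the future tokens from the belief state alone."""
    future, current = [], v_s
    for _ in range(num_future_tokens):
        options = TRANSITIONS[current]
        current = rng.choices(list(options), weights=list(options.values()))[0]
        future.append(current)
    return future

# Two different prefixes ending in the same token share a belief state,
# so g draws their futures from the same conditional distribution.
rng = random.Random(0)
v1 = belief_state(["a", "a", "b"])
v2 = belief_state(["b", "b", "b"])
sample = g(v1, num_future_tokens=4, rng=rng)
print(v1 == v2, sample)
```

For real text no such simple summary exists, which is why the belief state has to be learned.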



## Why use Belief State Transformer?

In the research paper, the authors evaluate the model on two tasks:
1. Star graph problem: it is the clear winner, with much higher accuracy than the next-token, data-augmentation, FIM (fill-in-the-middle) and teacherless baselines.
2. TinyStories dataset: each text is divided into three parts (prefix, missing middle and suffix) and the model has to predict the missing middle part. It achieves very high accuracy on this task.

Its self-evaluation feature does require more computational power, but in return it reaches almost perfect accuracy.
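
If the star graph problem is new to you: as I understand it, each example is a graph with one central node and several arms. The model is given the edge list, the start node (the center) and a goal node (the end of one arm), and has to output the path from start to goal. The sketch below builds one such instance. The function name `make_star_graph` and the sizes are my own choices, and the exact input format used in the paper may differ.

```python
import random

def make_star_graph(num_arms, arm_length, rng):
    """Build one hypothetical star graph instance: shuffled edges, start, goal, path."""
    num_nodes = 1 + num_arms * arm_length
    labels = list(range(num_nodes))
    rng.shuffle(labels)  # random node names so a path cannot simply be memorised
    center, rest = labels[0], labels[1:]
    arms = [rest[i * arm_length:(i + 1) * arm_length] for i in range(num_arms)]

    edges = []
    for arm in arms:
        previous = center
        for node in arm:
            edges.append((previous, node))
            previous = node
    rng.shuffle(edges)  # the edge list is presented in random order

    target_arm = rng.choice(arms)
    path = [center] + target_arm
    return edges, center, target_arm[-1], path

rng = random.Random(0)
edges, start, goal, path = make_star_graph(num_arms=3, arm_length=2, rng=rng)
print("edges:", edges)
print("start:", start, "goal:", goal, "path:", path)
```

The hard part is the very first step: every arm looks the same from the center, so the right first move can only be chosen by taking the goal into account.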

# 0. Getting Setup

In [1]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
# Set up device agnostic code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


# 1. Get Data

In [4]:
!pip install datasets


In [5]:
from datasets import load_dataset

In [7]:
# Load the dataset
dataset = load_dataset("roneneldan/TinyStories")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})


# 2. Creating Datasets and Dataloaders

In [9]:
def transform_story(story, min_prefix_tokens=5, min_suffix_tokens=5):
    """
    Transforms a story into prefix, missing middle, and suffix parts.
    """
    tokens = story.split()  # Simple tokenization; consider a more robust tokenizer if needed.
    n = len(tokens)
    if n < (min_prefix_tokens + min_suffix_tokens + 1):
        return None
    prefix_end = random.randint(min_prefix_tokens, n - min_suffix_tokens - 1)
    missing_end = random.randint(prefix_end + 1, n - min_suffix_tokens)
    prefix = " ".join(tokens[:prefix_end])
    missing = " ".join(tokens[prefix_end:missing_end])
    suffix = " ".join(tokens[missing_end:])
    return prefix, missing, suffix

In [13]:
from transformers import AutoTokenizer

In [14]:
class TinyStoriesDataset(Dataset):
    def __init__(self, hf_dataset, split="train", tokenizer=None):
        self.data = hf_dataset[split]
        # Load the tokenizer here rather than in the default argument,
        # which would be evaluated once when the class is defined.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Each TinyStories example is a dictionary with a single 'text' field.
        story = self.data[idx]["text"]
        transformed = transform_story(story)
        if transformed is None:
            # If the story is too short to split, use the full story as the prefix.
            transformed = (story, "", "")
        prefix, missing, suffix = transformed

        # For simplicity, here we concatenate the parts with plain-text markers.
        # You could also process each part separately.
        # The markers are not registered as special tokens, so the tokenizer
        # splits them into ordinary subword tokens.
        full_input = f"[PREFIX] {prefix} [MISSING] {missing} [SUFFIX] {suffix}"
        # Use the shared tokenizer so a token maps to the same id in every example.
        indices = self.tokenizer.encode(full_input)

        # Convert to tensor
        x = torch.tensor(indices, dtype=torch.long)
        return x

In [15]:
train_dataset = TinyStoriesDataset(dataset, split="train")
val_dataset = TinyStoriesDataset(dataset, split="validation")

In [16]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=lambda x: nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=0))
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=lambda x: nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=0))
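
The `collate_fn` above is needed because stories have different lengths, and a batch has to be a single rectangular tensor. Here is a small standalone sketch of what `pad_sequence` does with three made-up sequences. The extra `mask` is my own addition, not part of the loaders above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three "stories" of different lengths, already converted to token ids.
batch = [torch.tensor([5, 6, 7]), torch.tensor([8]), torch.tensor([9, 10])]

# Shorter sequences are filled up with padding_value on the right.
padded = pad_sequence(batch, batch_first=True, padding_value=0)
print(padded.shape)
print(padded)

# A boolean mask built from the original lengths marks the real tokens,
# so padding can be ignored later (for example in the loss).
lengths = torch.tensor([len(x) for x in batch])
mask = torch.arange(padded.size(1))[None, :] < lengths[:, None]
print(mask)
```

Note that with `padding_value=0` a real token with id 0 looks the same as padding, so a mask based on lengths (or a dedicated padding id) is the safer way to tell them apart.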

In [17]:
print("Train dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))

Train dataset size: 2119719
Validation dataset size: 21990


# 3. Belief State Transformer Overview

The Belief State Transformer is designed to capture a compact representation (or *belief state*) of a partial sequence. Given a sequence $x_{1:T}$ and a partial sequence $s = x_{1:t}$, the model learns a vector $v_s$ such that a randomized function $g(v_s)$ can sample from the conditional distribution $P(x_{t+1:T} \mid x_{1:t})$.

The architecture generally consists of:
- **Forward Encoder:** Processes the prefix $x_{1:t}$ from left to right to produce forward states.
- **Backward Encoder:** Processes the suffix in reverse (right to left) to produce backward states that capture future context.
- **Text Head:** Combines a forward state and a backward state to predict tokens in the missing (middle) section: the next token after the prefix and the previous token before the suffix.
- **Efficient Prefix-Suffix Loss Computation:** Computes loss over all valid prefix-suffix pairs using a specialized loss function that handles multiple pairs efficiently.

This design allows the model to perform the "fill-in-the-middle" task effectively.
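
The components listed above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the paper's implementation: the class name `TinyBeliefStateTransformer`, the shared embeddings, the single-layer encoders and the 0-based indices `t` (prefix length) and `s` (suffix start) are all assumptions I made to keep it short.

```python
import torch
import torch.nn as nn

class TinyBeliefStateTransformer(nn.Module):
    """Hypothetical minimal layout: two encoders plus a text head (not the paper's code)."""

    def __init__(self, vocab_size, embed_dim=32, num_heads=4, max_len=64):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_len, embed_dim)
        self.forward_encoder = self._make_encoder(embed_dim, num_heads)
        self.backward_encoder = self._make_encoder(embed_dim, num_heads)
        # One head, two outputs: next token after the prefix, previous token before the suffix.
        self.text_head = nn.Linear(2 * embed_dim, 2 * vocab_size)

    @staticmethod
    def _make_encoder(embed_dim, num_heads):
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=4 * embed_dim,
            dropout=0.0, batch_first=True,
        )
        return nn.TransformerEncoder(layer, num_layers=1)

    def _embed(self, x):
        positions = torch.arange(x.size(1), device=x.device)
        return self.token_emb(x) + self.pos_emb(positions)

    def encode(self, x):
        """Return forward states f and backward states b, both (batch, seq_len, embed_dim)."""
        seq_len = x.size(1)
        # True entries are blocked: position i may only attend to positions <= i.
        causal = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        f = self.forward_encoder(self._embed(x), mask=causal)
        # Run a causal encoder over the reversed sequence, then flip back,
        # so that b[:, j] only depends on the tokens x[:, j:].
        b = self.backward_encoder(self._embed(x.flip(1)), mask=causal).flip(1)
        return f, b

    def forward(self, x, t, s):
        """Logits for x[:, t] and x[:, s - 1], given prefix x[:, :t] and suffix x[:, s:]."""
        f, b = self.encode(x)
        joint = torch.cat([f[:, t - 1], b[:, s]], dim=-1)
        next_logits, prev_logits = self.text_head(joint).chunk(2, dim=-1)
        return next_logits, prev_logits

torch.manual_seed(0)
model = TinyBeliefStateTransformer(vocab_size=50)
x = torch.randint(0, 50, (2, 10))
next_logits, prev_logits = model(x, t=3, s=6)
print(next_logits.shape, prev_logits.shape)
```

Here the prefix is `x[:, :3]`, the suffix is `x[:, 6:]`, and the three tokens in between are the missing middle. The causal masks matter: without them a forward state could peek at the tokens it is supposed to predict.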


# 4. Equations

For any probability distribution over a set of sequences $P(x_{1:T})$, and for any partial sequence $s = x_{1:t}$, we define a vector $v_s$ to be a **belief state** for $s$ if there exists a randomized function $g$ such that:

$$
g(v_s) \sim P(x_{t+1:T} \mid x_{1:t}).
$$

The prefix-suffix loss can be computed efficiently as follows:

1. Compute the forward states $F$ for every prefix and the backward states $B$ for every suffix (by running the backward encoder over the reversed sequence).
2. For each valid prefix-suffix pair, extract the corresponding states.
3. Concatenate the states and pass through the text head to obtain logits.
4. Reshape and compute the cross-entropy loss over all pairs.

The pseudocode for the loss function is provided in the research paper.
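
As a rough companion to those steps, here is a sketch of the loss that works directly on precomputed states. It is my own simplified version, not the paper's pseudocode: the function name `prefix_suffix_loss`, the 0-based index convention and the choice to skip empty prefixes and suffixes are all assumptions.

```python
import torch
import torch.nn as nn

def prefix_suffix_loss(f, b, x, text_head):
    """Cross-entropy over all pairs (t, s) with 1 <= t < s <= seq_len - 1.

    f[:, i] is the forward state after the tokens x[:, :i + 1];
    b[:, j] is the backward state for the tokens x[:, j:].
    """
    seq_len = f.size(1)
    vocab_size = text_head.out_features // 2

    # Step 2: all valid (prefix length t, suffix start s) pairs with t < s.
    t_idx, s_idx = torch.triu_indices(seq_len, seq_len, offset=1, device=f.device)
    keep = t_idx >= 1  # skip the empty prefix in this sketch
    t_idx, s_idx = t_idx[keep], s_idx[keep]

    # Step 3: concatenate the matching states and apply the text head.
    joint = torch.cat([f[:, t_idx - 1], b[:, s_idx]], dim=-1)  # (batch, pairs, 2 * d)
    next_logits, prev_logits = text_head(joint).chunk(2, dim=-1)

    # Step 4: flatten over batch and pairs, then compute both cross-entropy terms.
    next_targets = x[:, t_idx]      # token right after the prefix
    prev_targets = x[:, s_idx - 1]  # token right before the suffix
    loss_next = nn.functional.cross_entropy(
        next_logits.reshape(-1, vocab_size), next_targets.reshape(-1)
    )
    loss_prev = nn.functional.cross_entropy(
        prev_logits.reshape(-1, vocab_size), prev_targets.reshape(-1)
    )
    return loss_next + loss_prev

# Step 1 is replaced by random states here, just to exercise the function.
torch.manual_seed(0)
batch, seq_len, d, vocab = 2, 6, 8, 20
f = torch.randn(batch, seq_len, d)
b = torch.randn(batch, seq_len, d)
x = torch.randint(0, vocab, (batch, seq_len))
head = nn.Linear(2 * d, 2 * vocab)
loss = prefix_suffix_loss(f, b, x, head)
print(loss.item())
```

The point of the "efficient" part is that the encoders run once per sequence, while only the cheap text head runs once per pair. The number of pairs grows quadratically with the sequence length, so longer sequences may need the pairs to be subsampled or chunked.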


# 5. Creating the Transformer Encoder

In [21]:
class SimpleTransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # batch_first=True makes every layer take and return (batch, seq_len, embed_dim).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len) of token ids
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        # No transpose is needed because the layers were built with batch_first=True.
        # Note: no positional encoding or attention mask is applied here.
        encoded = self.transformer_encoder(embedded)  # (batch, seq_len, embed_dim)
        return self.dropout(encoded)

In [24]:
# Example instantiation:
vocab_size = len(train_dataset.tokenizer)  # Embedding table must cover every id the tokenizer can produce
embed_dim = 512
num_heads = 8
hidden_dim = 2048
num_layers = 4

In [23]:
transformer_encoder = SimpleTransformerEncoder(vocab_size, embed_dim, num_heads, hidden_dim, num_layers).to(device)
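
Before training, it is worth checking how big this configuration is. The sketch below counts parameters for one encoder layer with the hyperparameters above and adds the embedding table. The helper `count_parameters` is my own, and `vocab_size = 1000` is only a placeholder value.

```python
import torch.nn as nn

def count_parameters(module):
    """Total number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

embed_dim, num_heads, hidden_dim, num_layers = 512, 8, 2048, 4
vocab_size = 1000  # placeholder; use the size of your tokenizer's vocabulary

# Multi-head attention splits embed_dim evenly across the heads.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads

layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, batch_first=True
)
per_layer = count_parameters(layer)
embedding_params = vocab_size * embed_dim
total = embedding_params + num_layers * per_layer
print(f"head_dim={head_dim}, per_layer={per_layer:,}, "
      f"embedding={embedding_params:,}, total={total:,}")
```

With a real tokenizer vocabulary the embedding table becomes a much larger share of the total, so it is worth recomputing this once the tokenizer is fixed.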