<a href="https://colab.research.google.com/github/banno-0720/Deep-Learning-Projects/blob/main/Belief_State_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Belief State Transformer Paper Replication

Based on the research paper "The Belief State Transformer", published by Microsoft on February 20, 2025.

Read the original research paper [on arXiv](https://arxiv.org/pdf/2410.23506).

For a more accessible explanation of Belief State Transformers, watch [this video](https://youtu.be/aqhbRtB2Fyg?si=ABz33R6ZfdWue-mi).

## Why replicate a machine learning research paper?

A machine learning research paper is often a presentation of months of work and experiments done by some of the best machine learning teams in the world condensed into a few pages of text.

And if these experiments lead to better results in an area related to the problem you're working on, it'd be nice to check them out.

Also, replicating the work of others is a fantastic way to practice your skills.

<img src="https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/08-george-hotz-quote.png" width=600 alt="george hotz quote saying to get better at being a machine learning engineer, download a paper, implement it and keep going until you have skills"/>

*George Hotz is the founder of [comma.ai](https://comma.ai/), a self-driving car company, and livestreams machine learning coding on [Twitch](https://www.twitch.tv/georgehotz); those videos are posted in full to [YouTube](https://www.youtube.com/c/georgehotzarchive). I pulled this quote from one of his livestreams. The "٭" notes that machine learning engineering often involves extra steps such as preprocessing data and making your models available for others to use (deployment).*

## What is a Belief State?

For any probability distribution over a set of sequences $P(x_{1:T})$, and for any partial sequence $s = x_{1:t}$, we define a vector $v_s$ to be a **belief state** for $s$ if there exists a randomized function $g$ such that

$$
g(v_s) \sim P(x_{t+1:T} \mid x_{1:t}).
$$

In other words, sampling $g(v_s)$ yields a sample from the conditional distribution $P(x_{t+1:T} \mid x_{1:t})$.


By definition, a belief state captures all the information in the prefix that is relevant for predicting future tokens. Once the belief state is known, the rest of the prefix adds nothing: everything needed for future predictions is already encoded in it.
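To make the definition concrete, here is a toy example that is not from the paper (every name below is hypothetical): a Pólya urn that starts with one red and one blue ball, where each drawn ball goes back in together with one extra ball of the same color. A prefix of draws can be arbitrarily long, but the pair of color counts is already a belief state: two prefixes with the same counts in a different order have exactly the same conditional distribution over the future, and the randomized function `g` samples continuations from the counts alone.

```python
import random
from fractions import Fraction


def seq_prob(seq):
    """Exact probability of a color sequence ('R'/'B') drawn from a Polya urn
    that starts with one red and one blue ball; each drawn ball is returned
    together with one extra ball of the same color."""
    red, blue, p = 1, 1, Fraction(1)
    for color in seq:
        if color == "R":
            p *= Fraction(red, red + blue)
            red += 1
        else:
            p *= Fraction(blue, red + blue)
            blue += 1
    return p


def conditional(prefix, future):
    """P(future | prefix) = P(prefix + future) / P(prefix), using the full prefix."""
    return seq_prob(prefix + future) / seq_prob(prefix)


def belief_state(prefix):
    """v_s: the color counts seen so far; the order of the draws is discarded."""
    return (prefix.count("R"), prefix.count("B"))


def g(v, length, rng):
    """Randomized function g: samples a continuation from the belief state v alone."""
    red, blue = v[0] + 1, v[1] + 1  # add back the urn's initial red and blue ball
    out = []
    for _ in range(length):
        if rng.random() < red / (red + blue):
            out.append("R")
            red += 1
        else:
            out.append("B")
            blue += 1
    return "".join(out)


# Two prefixes with the same counts but a different order
print(belief_state("RRB"), belief_state("BRR"))          # (2, 1) (2, 1)
print(conditional("RRB", "R"), conditional("BRR", "R"))  # 3/5 3/5
print(g(belief_state("RRB"), 10, random.Random(0)))
```

The counts are far more compact than the prefix, yet `g` needs nothing else to sample the future with the right distribution, which is exactly the property the definition asks for.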



## Why use a Belief State Transformer?

The paper evaluates the model on two tasks:
1. **Star graph problem:** the BST is the clear winner, with much higher accuracy than the next-token, data-augmentation, fill-in-the-middle (FIM) and teacherless baselines.
2. **TinyStories dataset:** each story is split into three parts (prefix, missing middle, suffix), and the model must generate the missing middle given the prefix and the suffix. The paper reports that the BST outperforms the FIM baseline on this task.

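The trade-off is compute. Training scores every valid prefix-suffix pair, so its cost grows quadratically with sequence length, and using the belief state to evaluate candidate completions (self-evaluation) adds work at inference time. In return, the model reaches almost perfect accuracy on the star graph task.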
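To get a feel for that cost, the sketch below counts the prefix-suffix pairs scored per sequence, assuming the convention used by the loss later in this notebook: 0-indexed positions `i` and `j` with `j - i >= 2` (a non-empty prefix, a non-empty suffix, and at least one missing token). It also estimates the size of the logits tensor, assuming two predictions per pair, a 10,000-token vocabulary and float32 values; the helper names are illustrative.

```python
def count_pairs(T):
    """Closed form: number of pairs (i, j) with 0 <= i, j < T and j - i >= 2."""
    return (T - 1) * (T - 2) // 2


def count_pairs_brute_force(T):
    """Direct enumeration, to check the closed form."""
    return sum(1 for i in range(T) for j in range(T) if j - i >= 2)


def logits_bytes(T, vocab_size=10_000, bytes_per_float=4):
    """Memory for one sequence's logits: 2 predictions per pair, one score per vocabulary entry."""
    return count_pairs(T) * 2 * vocab_size * bytes_per_float


for T in (32, 64, 128):
    print(T, count_pairs(T), f"{logits_bytes(T) / 1e9:.2f} GB")
```

At `T = 128` (the `max_length` used below) this gives 8,001 pairs and roughly 0.64 GB of logits per sequence before gradients are counted, which is consistent with the `batch_size=1` used for training later on.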

# 0. Getting Set Up

In [92]:
import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
# Set up device agnostic code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


# 1. Get Data

In [3]:
!pip install datasets

In [4]:
from datasets import load_dataset

In [5]:
# Load the dataset
dataset = load_dataset("mintujupally/ROCStories")
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 78528
    })
    test: Dataset({
        features: ['text'],
        num_rows: 19633
    })
})


# 2. Build a reduced global vocabulary

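Instead of keeping every distinct word (which could give a huge vocabulary), we build a frequency-based vocabulary: the five special tokens come first, followed by the 9,995 most frequent whitespace-separated words in the training split. Any other word is later mapped to `[UNK]`.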
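A minimal sketch of the same frequency cut-off on a toy corpus (the corpus and the `toy_vocab` / `encode` helpers are illustrative, not part of the notebook): special tokens take the first indices, only the most frequent words fill the remaining slots, and every other word falls back to `[UNK]`.

```python
from collections import Counter

SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[PREFIX]", "[MISSING]", "[SUFFIX]"]


def toy_vocab(texts, max_vocab_size):
    counter = Counter(tok for text in texts for tok in text.split())
    vocab = {tok: idx for idx, tok in enumerate(SPECIAL_TOKENS)}
    # Fill the remaining slots with the most frequent words
    for tok, _ in counter.most_common(max_vocab_size - len(SPECIAL_TOKENS)):
        vocab.setdefault(tok, len(vocab))
    return vocab


def encode(text, vocab):
    # Words outside the vocabulary map to the [UNK] index
    return [vocab.get(tok, vocab["[UNK]"]) for tok in text.split()]


texts = ["the cat sat", "the cat ran", "the dog"]
vocab = toy_vocab(texts, max_vocab_size=7)  # 5 special tokens + 2 words
print(vocab)
print(encode("the cat sat", vocab))  # [5, 6, 1]
```

Only "the" (3 occurrences) and "cat" (2) make the cut, so "sat", "ran" and "dog" all become `[UNK]`. With 10,000 slots the same rule keeps the common words and collapses the long tail of rare ones.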

In [6]:
from collections import Counter

In [70]:
def build_vocab(hf_dataset, split="train", max_vocab_size=10000):
    counter = Counter()
    key = "text"
    for example in hf_dataset[split]:
        # The 'text' field may be a list of sentences; if so, join them into a single string.
        if isinstance(example[key], list):
            text = " ".join(example[key])
        else:
            text = example[key]
        tokens = text.split()  # Simple tokenization; for production, use a robust tokenizer.
        counter.update(tokens)
    # Add special tokens needed for our task.
    special_tokens = ["[PAD]", "[UNK]", "[PREFIX]", "[MISSING]", "[SUFFIX]"]
    most_common = counter.most_common(max_vocab_size - len(special_tokens))
    vocab = {token: idx for idx, token in enumerate(special_tokens)}
    for token, freq in most_common:
        if token not in vocab:
            vocab[token] = len(vocab)
    return vocab

In [71]:
global_vocab = build_vocab(dataset, split="train", max_vocab_size=10000)
print("Global vocabulary size:", len(global_vocab))

Global vocabulary size: 10000


In [72]:
# Use this reduced vocabulary size in the rest of the model.
vocab_size = len(global_vocab)

# 3. Creating Datasets and Dataloaders

In [73]:
def transform_story(text, min_prefix_tokens=5, min_suffix_tokens=5):
    """
    Transforms a story into prefix, missing middle, and suffix parts.
    """
    tokens = text.split()  # Simple tokenization
    n = len(tokens)
    if n < (min_prefix_tokens + min_suffix_tokens + 1):
        return None
    prefix_end = random.randint(min_prefix_tokens, n - min_suffix_tokens - 1)
    missing_end = random.randint(prefix_end + 1, n - min_suffix_tokens)
    prefix = " ".join(tokens[:prefix_end])
    missing = " ".join(tokens[prefix_end:missing_end])
    suffix = " ".join(tokens[missing_end:])
    return prefix, missing, suffix
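
A quick sanity check on the split (a sketch: the sample story is made up, and the function is repeated so the block runs on its own). For any story long enough to split, joining the three parts gives back the original text, the prefix and suffix have at least 5 tokens each, and the missing middle has at least one token; shorter stories return `None`.

```python
import random


def transform_story(text, min_prefix_tokens=5, min_suffix_tokens=5):
    # Same logic as the notebook's transform_story
    tokens = text.split()
    n = len(tokens)
    if n < (min_prefix_tokens + min_suffix_tokens + 1):
        return None
    prefix_end = random.randint(min_prefix_tokens, n - min_suffix_tokens - 1)
    missing_end = random.randint(prefix_end + 1, n - min_suffix_tokens)
    return (" ".join(tokens[:prefix_end]),
            " ".join(tokens[prefix_end:missing_end]),
            " ".join(tokens[missing_end:]))


story = "Tom went to the store . He bought milk and bread . Then he walked home ."
random.seed(0)
for _ in range(1000):
    prefix, missing, suffix = transform_story(story)
    assert " ".join([prefix, missing, suffix]) == story
    assert len(prefix.split()) >= 5 and len(suffix.split()) >= 5
    assert len(missing.split()) >= 1
print(transform_story("Too short to split"))  # None
```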

In [74]:
class ROCStoriesDataset(Dataset):
    def __init__(self, hf_dataset, split="train", vocab=None, max_length=128):
        self.data = hf_dataset[split]
        self.vocab = vocab  # Use the global reduced vocab for tokenization.
        if "[UNK]" not in self.vocab:
            self.vocab["[UNK]"] = len(self.vocab)
        self.max_length = max_length  # Limit sequence length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Read the "text" field. If it's a list of sentences, join it into a single string.
        story_field = self.data[idx]["text"]
        if isinstance(story_field, list):
            text = " ".join(story_field)
        else:
            text = story_field

        transformed = transform_story(text)
        if transformed is None:
            transformed = (text, "", "")
        prefix, missing, suffix = transformed

        # Concatenate parts with special tokens.
        full_input = f"[PREFIX] {prefix} [MISSING] {missing} [SUFFIX] {suffix}"
        tokens = full_input.split()
        # Truncate to max_length tokens.
        tokens = tokens[:self.max_length]
        indices = [self.vocab.get(token, self.vocab["[UNK]"]) for token in tokens]
        x = torch.tensor(indices, dtype=torch.long)
        return x

In [75]:
# Create datasets and dataloaders using ROCStories.
train_dataset = ROCStoriesDataset(dataset, split="train", vocab=global_vocab, max_length=128)
test_dataset = ROCStoriesDataset(dataset, split="test", vocab=global_vocab, max_length=128)

In [110]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True,
                          collate_fn=lambda x: nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=0))
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                        collate_fn=lambda x: nn.utils.rnn.pad_sequence(x, batch_first=True, padding_value=0))

In [111]:
print("Train dataset size:", len(train_dataset))
print("Validation dataset size:", len(test_dataset))

Train dataset size: 78528
Validation dataset size: 19633


# 4. Belief State Transformer Overview

The Belief State Transformer is designed to capture a compact representation (or *belief state*) of a partial sequence. Given a sequence $ x_{1:T} $ and a partial sequence $ s = x_{1:t} $, the model learns a vector $ v_s $ such that a randomized function $ g(v_s) $ can sample from the conditional distribution $ P(x_{t+1:T} \mid x_{1:t}) $.

The architecture generally consists of:
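- **Forward Encoder:** Reads the sequence left to right; its state at position $ t $ summarizes the prefix $ x_{1:t} $.
- **Backward Encoder:** Reads the sequence right to left (it runs on the reversed sequence); its state at position $ k $ summarizes the suffix $ x_{k:T} $.
- **Text Head:** Combines one forward state and one backward state to predict the tokens at both edges of the missing (middle) section: the next token after the prefix and the previous token before the suffix.
- **Efficient Prefix-Suffix Loss Computation:** Runs each encoder only once per sequence, then computes the loss over all valid prefix-suffix pairs at once instead of re-encoding every pair.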

This design allows the model to perform the "fill-in-the-middle" task effectively.
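The bookkeeping behind the text head can be sketched in plain Python (a hypothetical helper, not from the paper). In the 0-indexed positions used by the loss code, a prefix ending at `i` and a suffix starting at `j`, with `j - i >= 2`, give two targets: the first missing token `tokens[i + 1]` for the forward direction and the last missing token `tokens[j - 1]` for the backward direction.

```python
def prefix_suffix_targets(tokens):
    """Enumerate every split that leaves at least one missing token.

    Uses 0-indexed positions i < j with j - i >= 2:
      prefix = tokens[:i + 1], suffix = tokens[j:],
      forward target  = tokens[i + 1]  (first missing token),
      backward target = tokens[j - 1]  (last missing token).
    """
    T = len(tokens)
    examples = []
    for i in range(T):
        for j in range(i + 2, T):
            examples.append((tokens[:i + 1], tokens[j:], tokens[i + 1], tokens[j - 1]))
    return examples


for prefix, suffix, next_tok, prev_tok in prefix_suffix_targets(["the", "cat", "sat", "down"]):
    print(prefix, suffix, "->", next_tok, prev_tok)
```

When exactly one token is missing (`j == i + 2`), both targets are the same token, approached from either side.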


# 5. Equations

For any probability distribution over a set of sequences $ P(x_{1:T}) $, and for any partial sequence $ s = x_{1:t} $, we define a vector $ v_s $ to be a **belief state** for $ s $ if there exists a randomized function $ g $ such that:

$$
g(v_s) \sim P(x_{t+1:T} \mid x_{1:t}).
$$

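The prefix-suffix loss is computed efficiently as follows:

1. Compute the forward states $ F $ by running the forward encoder on the sequence, and the backward states $ B $ by running the backward encoder on the reversed sequence and flipping the result back.
2. For each valid prefix-suffix pair (one that leaves at least one missing token between prefix and suffix), gather the corresponding forward and backward states.
3. Concatenate each pair of states and pass them through the text head to obtain two sets of logits: one for the token right after the prefix and one for the token right before the suffix.
4. Flatten the logits and labels and compute the cross-entropy loss over all pairs.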
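
Written as a single equation, the steps above amount to the following loss (a sketch in this notebook's notation and conventions, not a formula quoted from the paper): $f_t$ is the forward state after reading the prefix $x_{1:t}$, $b_k$ is the backward state after reading the suffix $x_{k:T}$ from right to left, and $p_F$ and $p_B$ are the two softmax outputs of the text head.

$$
\mathcal{L} = -\frac{1}{2\,|\mathcal{P}|} \sum_{(t,k) \in \mathcal{P}} \Big[ \log p_F\big(x_{t+1} \mid f_t, b_k\big) + \log p_B\big(x_{k-1} \mid f_t, b_k\big) \Big],
\qquad
\mathcal{P} = \{(t,k) : 1 \le t < k \le T,\ k - t \ge 2\}
$$

Each pair contributes two cross-entropy terms, hence the $2\,|\mathcal{P}|$ in the normalizer, and since $|\mathcal{P}| = (T-1)(T-2)/2$, the number of terms grows quadratically with $T$.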

The pseudocode for the loss function is provided in the research paper.


# 6. Creating the Transformer Encoder

In [112]:
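class SimpleTransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, dropout=0.1, max_len=512):
        super(SimpleTransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional embeddings: self-attention by itself ignores token order
        self.pos_embedding = nn.Embedding(max_len, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, causal=True):
        # x: (batch, seq_len)
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device)
        embedded = self.embedding(x) + self.pos_embedding(positions)  # (batch, seq_len, embed_dim)
        # The layers were built with batch_first=True, so they take (batch, seq_len, embed_dim)
        # directly; transposing to (seq_len, batch, embed_dim) would swap the two axes.
        mask = None
        if causal:
            # Additive causal mask: position i may only attend to positions <= i
            mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=x.device), diagonal=1)
        encoded = self.transformer_encoder(embedded, mask=mask)  # (batch, seq_len, embed_dim)
        return self.dropout(encoded)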
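
For the belief-state objective, the forward encoder's state at position `i` must not see tokens after `i`; otherwise it could read the very token it is asked to predict. A causal attention mask enforces this. The sketch below (a hypothetical helper in plain Python) builds the additive mask in the form `nn.TransformerEncoder` accepts as `mask`: 0 where attention is allowed and -inf where it is blocked, the same pattern as `torch.triu(torch.full((T, T), float("-inf")), diagonal=1)`.

```python
NEG_INF = float("-inf")


def causal_mask(T):
    """Row q lists what query position q may attend to: keys k <= q get 0.0 (allowed),
    keys k > q get -inf (blocked, because softmax then gives them zero weight)."""
    return [[0.0 if k <= q else NEG_INF for k in range(T)] for q in range(T)]


for row in causal_mask(4):
    print(row)
```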

In [113]:
transformer_encoder = SimpleTransformerEncoder(vocab_size, embed_dim=512, num_heads=8, hidden_dim=2048, num_layers=4).to(device)

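# 7. Putting It All Together: Belief State Transformer

Note that the forward and backward encoders below are simplified stand-ins for the causal transformers used in the paper: each is an embedding followed by a linear layer, applied to every token independently, so a state only sees its own token. The `SimpleTransformerEncoder` instantiated above is not used by this model.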

In [114]:
class BeliefStateTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim):
        super(BeliefStateTransformer, self).__init__()
        # Forward encoder for the prefix
        self.enc_F = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )
        # Backward encoder for the suffix (applied on reversed sequence)
        self.enc_B = nn.Sequential(
            nn.Embedding(vocab_size, embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )
        # Text head: combines forward and backward states to produce token logits.
        self.text_head = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LeakyReLU(),
            nn.Linear(embed_dim, vocab_size * 2)
        )

    def forward(self, x):
        # x: (batch, T)
        f = self.enc_F(x)          # forward states, read left to right: (batch, T, embed_dim)
        b = self.enc_B(x.flip(1))  # backward states, computed on the reversed sequence: (batch, T, embed_dim)
        return f, b
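
The flip bookkeeping is easy to get wrong, so here is a plain-Python sketch of it (illustrative only; string concatenation stands in for a causal backward encoder, so each state is literally the text it has read). A backward encoder reads the sequence right to left by running on the reversed sequence; flipping its outputs back lines them up so that the state at position `j` summarizes the suffix `x[j:]`.

```python
x = ["a", "b", "c", "d", "e"]
T = len(x)

# Stand-in for a causal encoder run on the reversed sequence: the state at reversed
# position k has read reversed_x[:k + 1], i.e. the original suffix x[T - 1 - k:]
reversed_x = x[::-1]
all_b = ["".join(reversed(reversed_x[:k + 1])) for k in range(T)]

# Flip back along the time axis, as belief_state_objective does with all_b.flip(1)
backward_state = all_b[::-1]

for j in range(T):
    print(j, backward_state[j])  # the state at j summarizes x[j:]
```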

In [115]:
def belief_state_objective(all_f, all_b, text_head, x):
    """
    Efficient computation of the prefix-suffix loss.
    x: (batch, T) token indices.
    all_f: forward states; all_f[:, i] summarizes the prefix x[:, :i+1].
    all_b: backward states computed on the reversed sequence x.flip(1).
    text_head: head network producing 2 * vocab_size logits per pair
               (next token after the prefix, previous token before the suffix).
    """
    bs, T = x.shape
    forward_state = all_f             # (batch, T, embed_dim)
    backward_state = all_b.flip(1)    # flip back so backward_state[:, j] summarizes x[:, j:]
    ft = torch.arange(T, dtype=torch.long, device=x.device)
    bt = torch.arange(T, dtype=torch.long, device=x.device)
    combinations = torch.cartesian_prod(ft, bt)  # All (i, j) pairs
    # Keep pairs with j - i >= 2, i.e. at least one missing token between prefix and suffix
    fb_pairs = combinations[(combinations[:, 1] - combinations[:, 0]) >= 2]
    f_idxs = fb_pairs[:, 0]
    b_idxs = fb_pairs[:, 1]
    nt_idxs = f_idxs + 1  # next token after the prefix (first missing token)
    pt_idxs = b_idxs - 1  # previous token before the suffix (last missing token)

    f_selected = forward_state[:, f_idxs]   # (batch, num_pairs, embed_dim)
    b_selected = backward_state[:, b_idxs]  # (batch, num_pairs, embed_dim)

    single_labels_f = x[:, nt_idxs].unsqueeze(2)  # (batch, num_pairs, 1)
    # The backward state at j has already read x[j], so its target is x[j - 1], not x[j]
    single_labels_b = x[:, pt_idxs].unsqueeze(2)  # (batch, num_pairs, 1)
    single_labels = torch.cat((single_labels_f, single_labels_b), dim=2)  # (batch, num_pairs, 2)

    # Concatenate forward and backward states and score both targets at once
    logits = text_head(torch.cat([f_selected, b_selected], dim=2))  # (batch, num_pairs, vocab_size*2)
    fb_numpairs = fb_pairs.shape[0]
    logits = logits.reshape((bs * fb_numpairs * 2, -1))  # (batch*num_pairs*2, vocab_size)
    single_labels = single_labels.reshape((bs * fb_numpairs * 2))
    # ignore_index=0 skips [PAD] labels when sequences in a batch are padded
    loss = nn.CrossEntropyLoss(ignore_index=0)(logits, single_labels)
    return loss

In [116]:
model = BeliefStateTransformer(vocab_size, embed_dim=512).to(device)

# 8. Setting up Training Code

In [117]:
num_epochs = 3
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    # Use tqdm to track progress over batches
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch = batch.to(device)

        # Forward pass: compute forward and backward states.
        f, b = model(batch)

        # Detach f and b into fresh leaf tensors: the O(T^2) pairwise loss then
        # backpropagates only into the text head and into _f / _b.
        _f = f.detach().requires_grad_()
        _b = b.detach().requires_grad_()

        optimizer.zero_grad()
        loss = belief_state_objective(_f, _b, model.text_head, batch)
        loss.backward()

        # Second stage: send the accumulated gradients on _f and _b back through
        # the encoders, so each encoder runs a single backward pass per batch.
        f.backward(_f.grad)
        b.backward(_b.grad)

        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

Epoch 1/3:   0%|          | 111/78528 [00:02<33:59, 38.46it/s]


KeyboardInterrupt: 