# Beyond Human Data: Scaling Self-Training for Problem-Solving with Language Models
(Link: https://arxiv.org/pdf/2312.06585) \
In this notebook, we study the key ideas of the ReST-EM paper by implementing one iteration of its Expectation-Maximization (EM) algorithm. The algorithm uses self-training to iteratively refine a language model's ability to solve problems, going beyond human-annotated data. Its key steps are:

* Expectation (E-step): The model generates multiple candidate solutions for each problem, and each candidate is scored for correctness. In this notebook, the score is a binary reward from an external verifier.
* Maximization (M-step): The candidates judged correct (reward = 1) are selected and used to fine-tune the model, improving its performance in subsequent iterations.

In this notebook, we will:

* Implement one full iteration of the ReST-EM self-training loop.
* Generate candidate solutions using a base language model.
* Assign a binary reward to each generated solution.
* Select the correct solutions and use them to update the model.

This approach enables language models to bootstrap their own learning process, extending their capabilities beyond the limitations of human-labeled datasets.
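
To make the generate-then-filter idea concrete before any model is loaded, here is a minimal, dependency-free sketch. Everything in it is illustrative: `toy_generate` is a hypothetical stand-in for sampling from a language model, and `binary_reward` uses exact string matching, whereas the notebook itself uses an LLM-based verifier.

```python
def toy_generate(question, num_samples):
    """Hypothetical stand-in for sampling candidates from a language model."""
    # Fixed fake candidates so the example is reproducible.
    candidates = {"2 + 2": ["4", "5", "4"], "3 * 3": ["6", "8", "7"]}
    return candidates[question][:num_samples]


def binary_reward(candidate, reference):
    """Return 1 if the candidate exactly matches the reference answer, else 0."""
    return 1 if candidate.strip() == reference.strip() else 0


def e_step(dataset, num_samples):
    """Generate candidates for each question and keep only those with reward 1."""
    kept = []
    for question, reference in dataset:
        for candidate in toy_generate(question, num_samples):
            if binary_reward(candidate, reference) == 1:
                kept.append((question, candidate))
    return kept


dataset = [("2 + 2", "4"), ("3 * 3", "9")]
filtered = e_step(dataset, num_samples=3)
print(filtered)
```

The filtered list is what an M-step would fine-tune on. A question with no correct candidate (here, `3 * 3`) simply contributes nothing in that iteration.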

Please note that running this notebook requires a Colab Pro subscription.

First, let us install dependencies:

In [None]:
!pip install --upgrade pip
!pip install -q git+https://github.com/google-deepmind/gemma.git
!pip install rich
!pip install openai



Now, let us import the necessary packages:

In [None]:
import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = ''
os.environ["KAGGLE_KEY"] = ''

from rich import print
import enum
import re
import string

import chex
import jax
import jax.numpy as jnp
import optax

import tensorflow as tf
import tensorflow_datasets as tfds

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
# Import OpenAI client for answer verification.
from openai import OpenAI

Now, we perform the following steps:

1. Download and Set Up the Gemma Model \
The model checkpoint and tokenizer are downloaded from Kaggle via `kagglehub`.

2. Implement the Gemma Tokenizer \
A custom `GemmaTokenizer` class wraps `SentencePieceProcessor` and provides methods for:
  * Tokenizing input text into token IDs.
  * Converting token IDs back into text.
  * Tokenizing input text as a TensorFlow operation for integration into TF pipelines.

3. Answer Verification Using the OpenAI API \
`verify_answers` checks each model-generated answer against the ground-truth answer using OpenAI's GPT-4 and returns a binary result (1 for correct, 0 for incorrect).

4. Data Extraction and Utility Functions \
Functions in this section handle:
  * Extracting math questions and answers from a text file.
  * Preparing inputs for model training by generating attention masks and token positions.

5. Forward Pass and Loss Computation \
`forward_and_loss_fn` performs a forward pass through the Transformer model and computes the negative log-likelihood (NLL) loss for next-token prediction.

6. Input Processing for Model Training \
`prepare_input` tokenizes a question and its answer, builds the input mask, and prepares the positions and attention mask so that the data is properly formatted before it is fed to the model.
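
The masked next-token NLL from step 5 can be illustrated without a model. The sketch below is a plain-Python toy, not the notebook's JAX implementation: the logits are made up, the helper names are hypothetical, and padding is marked by an explicit mask rather than derived from a tokenizer's pad ID.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def masked_next_token_nll(logits, tokens, mask):
    """Average NLL of predicting token t+1 from position t, skipping padding.

    logits: list of L rows with V entries each; tokens, mask: lists of length L.
    """
    total, count = 0.0, 0
    # Position t predicts token t+1: drop the last logits row and the first token.
    for step_logits, target, keep in zip(logits[:-1], tokens[1:], mask[1:]):
        if keep:
            total -= log_softmax(step_logits)[target]
            count += 1
    return total / max(count, 1)


vocab_size = 4
uniform_logits = [[0.0] * vocab_size for _ in range(5)]
tokens = [2, 1, 3, 0, 0]  # the last two entries are padding in this toy
mask = [1, 1, 1, 0, 0]    # 1 = real token, 0 = padding
loss = masked_next_token_nll(uniform_logits, tokens, mask)
print(loss)
```

With uniform logits every token has probability 1/V, so the loss equals log(V) regardless of the targets. This makes a convenient sanity check for any NLL implementation.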

In [None]:
# Set GEMMA variant and download model from KaggleHub.
GEMMA_VARIANT = '2b-it'  # or '2b'
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/gemma/flax/{GEMMA_VARIANT}')
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT)
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'tokenizer.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)


# ----------------------------
# Define Gemma Tokenizer class
# ----------------------------
class GemmaTokenizer:
    def __init__(self, spm_processor: spm.SentencePieceProcessor):
        self._spm_processor = spm_processor

    @property
    def pad_id(self) -> int:
        """Fast access to the pad ID."""
        return self._spm_processor.pad_id()

    def tokenize(self, example: str, prefix: str = '', suffix: str = '', add_eos: bool = True) -> jax.Array:
        """Tokenize the input string."""
        int_list = [self._spm_processor.bos_id()]
        int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))
        if add_eos:
            int_list.append(self._spm_processor.eos_id())
        return jnp.array(int_list, dtype=jnp.int32)

    def tokenize_tf_op(self, str_tensor: tf.Tensor, prefix: str = '', suffix: str = '', add_eos: bool = True) -> tf.Tensor:
        """A TensorFlow operator for tokenization."""
        encoded = tf.numpy_function(
            self.tokenize,
            [str_tensor, prefix, suffix, add_eos],
            tf.int32)
        encoded.set_shape([None])
        return encoded

    def to_string(self, tokens: jax.Array) -> str:
        """Convert an array of tokens back to a string."""
        return self._spm_processor.DecodeIds(tokens.tolist())


# ----------------------------
# Verification Function Using OpenAI API
# ----------------------------
def verify_answers(questions_dict, model_output):
    """
    Verifies if the model output matches the correct answers.

    Args:
        questions_dict (list): A list of dictionaries containing 'question' and 'answer'.
        model_output (str): The model-generated answer as a string.

    Returns:
        dict: A dictionary with 1 if the answer is correct, 0 otherwise.
    """
    verification_results = {}
    client = OpenAI(
        api_key=""
    )
    for entry in questions_dict:
        question = entry["question"]
        correct_answer = entry["answer"]
        model_answer = model_output
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are a math verifier that checks if an answer is correct."},
                {"role": "user", "content": f"Question: {question}\nCorrect Answer: {correct_answer}\nModel Answer: {model_answer}\nIs the model answer correct? Respond with only '1' for correct or '0' for incorrect."}
            ]
        )
        print("Verification response:", response)
        print("\nVerification verdict:", response.choices[0].message.content)
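        # Tolerate surrounding whitespace or trailing punctuation in the verdict.
        verdict = response.choices[0].message.content.strip()
        verification_results[question] = 1 if verdict.startswith("1") else 0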
    return verification_results


# ----------------------------
# Data and Utility Functions
# ----------------------------
def extract_math_questions(file_path):
    """
    Extract math questions and answers from a text file.
    Assumes every question is immediately followed by its answer.
    """
    questions_list = []
    with open(file_path, 'r') as file:
        lines = file.readlines()
        for i in range(0, len(lines) - 1, 2):
            question = lines[i].strip()
            answer = lines[i + 1].strip()
            questions_list.append({"question": question, "answer": answer})
    return questions_list


def forward_and_loss_fn(params, *, model: transformer_lib.Transformer,
                        input_tokens: jax.Array,  # Shape [B, L]
                        input_mask: jax.Array,    # Shape [B, L]
                        positions: jax.Array,     # Shape [B, L]
                        attention_mask: jax.Array # Shape [B, L, L]
                        ) -> jax.Array:
    """
    Forward pass and negative log-likelihood loss for next-token prediction.
    """
    logits, _ = model.apply(params, input_tokens, positions, None, attention_mask)
    logits = logits[0, :-1]  # Exclude the last time step.
    target_tokens = input_tokens[0, 1:]
    target_mask = input_mask[0, 1:]
    one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])
    one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., None]
    norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)
    loss = -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
    return loss


def get_attention_mask_and_positions(example: jax.Array, pad_id: int) -> tuple[jax.Array, jax.Array]:
    """
    Build positional encodings and a causal attention mask for the given tokens.
    """
    pad_mask = example != pad_id
    current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
    attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
    return current_token_position, attention_mask


def prepare_input(question: str, answer: str, tokenizer: GemmaTokenizer):
    """
    Concatenate question and answer, tokenize, and build input masks and positions.
    """
    text = question + "\n" + answer
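    # NOTE: this keeps only the first 15 characters (presumably to limit memory use),
    # so most of the question and answer is discarded; raise or remove the limit
    # for meaningful training.
    text = text[:15]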
    tokens = tokenizer.tokenize(text)
    input_mask = (tokens != tokenizer.pad_id).astype(jnp.int32)
    input_tokens = tokens[None, :]
    input_mask = input_mask[None, :]
    positions, attn_mask = get_attention_mask_and_positions(input_tokens, tokenizer.pad_id)
    return input_tokens, input_mask, positions, attn_mask
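
As a rough illustration of what the positions and causal attention mask look like, here is a plain-Python sketch for a single unbatched sequence. It is an assumption about the behaviour of `build_positions_from_mask` and `make_causal_attn_mask`, not a copy of the library code; the real helpers operate on batched JAX arrays and may differ in detail. The function names below are hypothetical.

```python
def positions_from_mask(pad_mask):
    """Assign 0-based positions to real tokens; padding repeats the last position."""
    positions, seen = [], 0
    for keep in pad_mask:
        seen += int(keep)
        positions.append(max(seen - 1, 0))
    return positions


def causal_attention_mask(pad_mask):
    """Row q, column k is True when query q may attend to key k.

    A key is visible only if it is not padding and does not lie in the future.
    """
    n = len(pad_mask)
    return [[bool(pad_mask[k]) and k <= q for k in range(n)] for q in range(n)]


pad_mask = [1, 1, 1, 0]  # three real tokens followed by one padding token
toy_positions = positions_from_mask(pad_mask)
toy_attention = causal_attention_mask(pad_mask)
print(toy_positions)
for row in toy_attention:
    print(row)
```

Each row of the mask is lower-triangular up to the query position, and the padding column is never attended to.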



Next, we download the Mathematics Dataset and inspect a sample of the data:

In [None]:

# ----------------------------
# Download the Mathematics Dataset and Setup
# ----------------------------
!git clone https://github.com/deepmind/mathematics_dataset
!pip install --upgrade mathematics_dataset/
!python -m mathematics_dataset.generate --filter=linear_1d
!python -m mathematics_dataset.generate_to_file --output_dir "./math_dataset_output"

# Set the path to your training file (update if needed)
file_path = "./math_dataset_output/train-easy/algebra__linear_1d.txt"
questions = extract_math_questions(file_path)
print("Loaded questions:", questions)




In [None]:
questions = extract_math_questions(file_path)
for q in questions:
  print("Question:",q['question'])
  print("Correct Answer:", q["answer"])
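
The pairing logic in `extract_math_questions` relies on the file alternating question and answer lines. The sketch below reproduces that logic on an in-memory list so it can be checked without generating the dataset. The sample lines are made up for illustration, and `pair_lines` is a hypothetical helper, not part of the notebook.

```python
def pair_lines(lines):
    """Pair each even-indexed line (question) with the line after it (answer)."""
    pairs = []
    for i in range(0, len(lines) - 1, 2):
        pairs.append({"question": lines[i].strip(), "answer": lines[i + 1].strip()})
    return pairs


# Made-up sample content; real files come from the dataset generator.
sample_lines = [
    "Solve 2*x = 8 for x.\n",
    "4\n",
    "Solve 3*y + 1 = 10 for y.\n",
    "3\n",
    "Solve 5*z = 10 for z.\n",  # trailing question without an answer
]
pairs = pair_lines(sample_lines)
for pair in pairs:
    print(pair)
```

A trailing question without an answer is dropped silently, so an odd line count is worth checking for if the file might be truncated.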

# Load Model, Tokenizer, and Create Sampler

In [None]:
# ----------------------------
# Load Model, Tokenizer, and Create Sampler
# ----------------------------
params = params_lib.load_and_format_params(CKPT_PATH)
config_2b = transformer_lib.TransformerConfig.from_params(params, cache_size=30)
model_2b = transformer_lib.Transformer(config=config_2b)

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
tokenizer = GemmaTokenizer(vocab)
sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params = params['transformer'],
)



# Generate outputs from the model for questions in the Mathematics Dataset

In [None]:

model_outputs = []
for q in questions:
  # Sample once per question and reuse the result for both storage and display.
  output_text = sampler(
      [q['question']],
      total_generation_steps=200,
      ).text[0]
  model_outputs.append(output_text)
  print("Question:", q['question'])
  print("Model Output:")
  print(output_text)


# IMPLEMENTATION OF REST-EM ALGORITHM

The ReST-EM training loop implements one iteration of the Expectation-Maximization (EM) algorithm for self-training language models. The loop consists of two main phases:

1. Expectation Step (E-Step): Generating and Evaluating Samples \
For each question in the dataset, multiple candidate answers are generated using the language model.
Each generated answer is evaluated by an external verifier (here, OpenAI's GPT-4) to determine whether it is correct.
Only the correctly verified answers (reward = 1) are retained for training.

2. Maximization Step (M-Step): Fine-Tuning the Model \
The model is updated using a reward-weighted negative log-likelihood loss.
Gradients are computed for each verified (rewarded) sample, accumulated, and averaged.
The model parameters are updated using stochastic gradient descent (SGD) with the averaged gradients. If no correct samples are found, the update step is skipped for that iteration.
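
Written out, the update performed in the M-step looks roughly as follows. This is a reading of the notebook's own code (`forward_and_loss_fn` together with the gradient averaging in the training loop), not a formula quoted from the paper. The notation is introduced here for illustration: $\mathcal{D}_i$ is the set of retained samples, $s$ is the tokenized concatenation of question $x$ and answer $y$ with length $L$, $m_t$ is the input mask (1 for real tokens, 0 for padding), and $\eta$ is the learning rate. The small constant added to the denominator in the code is omitted.

```latex
\ell_\theta(x, y) = -\frac{\sum_{t=1}^{L-1} m_{t+1} \, \log p_\theta\!\left(s_{t+1} \mid s_{\le t}\right)}{\sum_{t=1}^{L-1} m_{t+1}},
\qquad
\mathcal{L}(\theta) = \frac{1}{|\mathcal{D}_i|} \sum_{(x, y, r) \in \mathcal{D}_i} r \, \ell_\theta(x, y),
\qquad
\theta \leftarrow \theta - \eta \, \nabla_\theta \mathcal{L}(\theta)
```

Because only samples with $r = 1$ are retained in $\mathcal{D}_i$, the reward weighting reduces to an ordinary average of per-sample NLL losses over the correct samples.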


In [None]:
# ----------------------------
# ReST-EM Training Loop
# ----------------------------

num_iterations = 1
num_samples_per_question = 3  # number of generated samples per question in the E-step
learning_rate = 1e-4
optimizer = optax.sgd(learning_rate)
opt_state = optimizer.init(params)

print("\nStarting ReST-EM training loop...")
for iteration in range(num_iterations):
    print(f"\n=== ReST-EM Iteration {iteration} ===")
    D_i = []  # Will store tuples (question, generated_answer, reward)
    # E-step: Generate dataset D_i by sampling outputs and verifying each one.
    for entry in questions:
        q_text = entry["question"]
        correct_answer = entry["answer"]
        generated_samples = []
        for _ in range(num_samples_per_question):
            # Generate one sample.
            output_text = sampler([q_text], total_generation_steps=200).text[0]
            # Use the OpenAI-based verifier.
            verif_result = verify_answers([{"question": q_text, "answer": correct_answer}], output_text)
            r = verif_result[q_text]
            print(f"Verification for question: {q_text}\nGenerated Answer: {output_text}\nReward: {r}")
            generated_samples.append((q_text, output_text, r))
        # Keep only samples with reward 1.
        for sample in generated_samples:
            if sample[2] == 1:
                D_i.append(sample)

    print(f"Iteration {iteration}: Generated {len(D_i)} correct samples.")
    if len(D_i) == 0:
        print("No correct samples generated; skipping improvement step this iteration.")
        continue

    # M-step: Fine-tune the model using the reward-weighted negative log-likelihood loss.
    def loss_for_sample(p, question_text, answer_text):
        input_tokens, input_mask, positions, attn_mask = prepare_input(question_text, answer_text, tokenizer)
        return forward_and_loss_fn(params={'params': p['transformer']}, model=model_2b, input_tokens=input_tokens,
                                   input_mask=input_mask, positions=positions, attention_mask=attn_mask)

    grads_sum = None
    loss_sum = 0.0
    sample_count = 0

    for (q_text, a_text, reward) in D_i:
        if reward == 0:
            continue
        # Differentiate the reward-weighted loss with respect to the full parameter tree.
        loss_val, grads = jax.value_and_grad(lambda p: reward * loss_for_sample(p, q_text, a_text))(params)
        if grads_sum is None:
            grads_sum = grads
        else:
            grads_sum = jax.tree_util.tree_map(lambda a, b: a + b, grads_sum, grads)
        loss_sum += loss_val
        sample_count += 1
    del grads

    if sample_count > 0:
        avg_grads = jax.tree_util.tree_map(lambda x: x / sample_count, grads_sum)
        updates, opt_state = optimizer.update(avg_grads, opt_state)
        params = optax.apply_updates(params, updates)
        avg_loss = loss_sum / sample_count
        print(f"Iteration {iteration}: Updated model with average loss: {avg_loss}")
    else:
        print("No samples with reward 1 found for gradient update in this iteration.")

print("\nReST-EM training loop complete.")
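
The accumulate, average, and update pattern used in the M-step can be seen in isolation with a dependency-free toy. In the sketch below, `tree_map` is a hypothetical stand-in for `jax.tree_util.tree_map` that handles only nested dicts, the gradients are made-up numbers, and the last line is a hand-written SGD step rather than a call to `optax`.

```python
def tree_map(fn, *trees):
    """Apply fn leaf-wise across nested dicts with identical structure."""
    first = trees[0]
    if isinstance(first, dict):
        return {key: tree_map(fn, *(t[key] for t in trees)) for key in first}
    return fn(*trees)


toy_params = {"layer": {"w": 1.0, "b": 0.5}}
per_sample_grads = [
    {"layer": {"w": 1.0, "b": 0.0}},
    {"layer": {"w": 3.0, "b": 1.0}},
]

# Accumulate gradients one sample at a time, as the training loop does.
toy_grads_sum = None
for g in per_sample_grads:
    if toy_grads_sum is None:
        toy_grads_sum = g
    else:
        toy_grads_sum = tree_map(lambda a, b: a + b, toy_grads_sum, g)

# Average over samples, then take one plain SGD step.
toy_avg_grads = tree_map(lambda x: x / len(per_sample_grads), toy_grads_sum)
toy_lr = 0.5
toy_new_params = tree_map(lambda p, grad: p - toy_lr * grad, toy_params, toy_avg_grads)
print(toy_avg_grads)
print(toy_new_params)
```

Accumulating one sample at a time keeps only two gradient trees in memory at once, which matters when each tree is as large as the model itself.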




# Inference: Generate answers for the training questions using the updated model.

In [None]:
# ----------------------------
# Inference: Generate answers for the training questions using the updated model.
# ----------------------------
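# Rebuild the sampler so that generation uses the updated parameters;
# the sampler created earlier still holds the pre-update weights.
sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['transformer'],
)

print("\nFinal model generation on training questions:")
final_outputs = []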
for entry in questions:
    q = entry["question"]
    out = sampler([q], total_generation_steps=200).text[0]
    final_outputs.append(out)
    print("Q:", q)
    print("A:", out)
    print("------------------")

# Optionally, collect a final set of training data by verifying the model outputs.
training_data = []
for entry, output in zip(questions, final_outputs):
    verif = verify_answers([entry], output)
    if list(verif.values())[0] == 1:
        training_data.append(entry)

print("Training Data (Correct Answers):")
print(training_data)


That's it!