Goal: implement GRPO (Group Relative Policy Optimization) from scratch and use it to train a model on a maths dataset (GSM8K).

In [1]:
%pip install tf-keras # for some reason, Hugging Face cannot work without it
%pip install flash-attn # FlashAttention2
%pip install wandb # Weights and Biases
%pip install 'accelerate>=0.26.0'
%pip install transformers # Hugging Face Transformers API
%pip install datasets # Hugging Face Datasets API

Collecting tf-keras
  Downloading tf_keras-2.19.0-py3-none-any.whl.metadata (1.8 kB)
Collecting tensorflow<2.20,>=2.19 (from tf-keras)
  Downloading tensorflow-2.19.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting absl-py>=1.0.0 (from tensorflow<2.20,>=2.19->tf-keras)
  Downloading absl_py-2.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting astunparse>=1.6.0 (from tensorflow<2.20,>=2.19->tf-keras)
  Downloading astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow<2.20,>=2.19->tf-keras)
  Downloading flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 (from tensorflow<2.20,>=2.19->tf-keras)
  Downloading gast-0.6.0-py3-none-any.whl.metadata (1.3 kB)
Collecting google-pasta>=0.1.1 (from tensorflow<2.20,>=2.19->tf-keras)
  Downloading google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow<2.20,>

In [2]:
# Import necessary libraries
# Basic Python libraries for various operations
import random
import copy
import re
import os
import numpy as np
import wandb

# PyTorch and related libraries for deep learning
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

# Hugging Face libraries for transformer models
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def set_random_seed(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Value passed to Python's `random`, NumPy's global generator,
            PyTorch's CPU generator and (when CUDA is available) all CUDA
            generators.

    Side effects:
        Switches cuDNN to deterministic mode and disables its benchmark
        autotuner, which may cost some performance.
    """
    # Python, NumPy and PyTorch (CPU) all receive the same seed.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators only exist when a GPU is visible.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels; no autotuning between runs.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Seed once, before any data shuffling or sampling happens.
set_random_seed(42)

In [4]:
# Configure Weights & Biases (wandb) logging.
#
# SECURITY: credentials must never be committed to a notebook. Provide the key
# from outside the code instead, e.g. `export WANDB_API_KEY=...` before starting
# Jupyter, or run `wandb login` once on this machine. The key that used to be
# hardcoded in this cell has been exposed and should be revoked in the W&B
# account settings.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; wandb will rely on `wandb login` credentials or ask for a key.")
os.environ["WANDB_PROJECT"] = "GRPO-Qwen-1.5-Instruct-Multi-GPU"

In [5]:
# The model generates output in an XML-like format with <reasoning> and <answer> tags.
# Two functions are provided:
# 1. extract_answer_from_model_output: Extracts content within <answer> tags.
# 2. extract_answer_from_dataset: Extracts the expected answer from the GSM8K dataset using "####" as a delimiter.

# Instruction prepended to every question by prepare_dataset. Note that the
# template itself contains an <answer>...</answer> pair whose body is the literal
# "..." placeholder; extract_answer_from_model_output treats that placeholder as
# "no answer".
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

def extract_answer_from_model_output(text):
    """Return the content of the last <answer>...</answer> pair in `text`.

    Args:
        text (str): Generated text, possibly containing several answer blocks.

    Returns:
        str | None: The stripped text between the last "<answer>" and the first
        "</answer>" that follows it. None when there is no opening tag, no
        closing tag after the last opening tag, or the content is the literal
        "..." placeholder from the prompt template.
    """
    # Everything after the LAST opening tag is the candidate answer region.
    _, opening_tag, tail = text.rpartition("<answer>")
    if not opening_tag:
        return None

    # The answer ends at the FIRST closing tag inside that region.
    body, closing_tag, _ = tail.partition("</answer>")
    if not closing_tag:
        return None

    body = body.strip()
    if body == "...":
        return None
    return body

def extract_answer_from_dataset(text):
    """Return the text between the first and second "####" delimiters.

    Args:
        text (str): A dataset answer string that uses "####" as the delimiter
            in front of the final answer.

    Returns:
        str | None: The stripped segment following the first "####" (up to the
        next "####", if any), or None when the delimiter is absent.
    """
    segments = text.split("####")
    return segments[1].strip() if len(segments) > 1 else None

In [6]:
# Prepare the dataset for training
def prepare_dataset(split="train"):
    """Load a GSM8K split and turn each row into a prompt/answer pair.

    Args:
        split (str): Name of the split to select from the 'openai/gsm8k'
            ('main' configuration) dataset.

    Returns:
        list[dict]: One dict per row with keys "prompt" (system prompt and
        question joined by build_prompt) and "answer" (the reference answer
        extracted by extract_answer_from_dataset).
    """
    rows = load_dataset('openai/gsm8k', 'main')[split]

    def to_training_example(row):
        # System instruction first, then the question itself.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": row["question"]},
        ]
        return {
            "prompt": build_prompt(messages),
            "answer": extract_answer_from_dataset(row["answer"]),
        }

    return [to_training_example(row) for row in rows]


def build_prompt(messages):
    """Join the stripped "content" of every message with newlines.

    Args:
        messages (list[dict]): Chat-style messages; only the "content" key is
            read, the "role" key is ignored.

    Returns:
        str: The message contents, each stripped of surrounding whitespace,
        separated by a single newline.
    """
    stripped_contents = (message["content"].strip() for message in messages)
    return "\n".join(stripped_contents)

In [7]:
def extract_last_number(text):
    """Parse the number that ends `text`, ignoring '$' and '%' characters.

    Args:
        text (str): Text that may end with a number.

    Returns:
        float | None: The trailing number when it sits at the start of the
        text or directly after whitespace or '='; otherwise None.
    """
    cleaned = text
    for symbol in ('$', '%'):
        cleaned = cleaned.replace(symbol, '')

    trailing_number = re.search(r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$', cleaned)
    if trailing_number is None:
        return None
    return float(trailing_number.group(1))


def extract_single_number(text):
    """Return the number in `text` only when exactly one number is present.

    Args:
        text (str): Text to scan for (optionally negative, optionally
            decimal) numbers.

    Returns:
        float | None: The number when the text contains exactly one match,
        otherwise None.
    """
    found = re.findall(r'-?\d*\.?\d+', text)
    if len(found) != 1:
        return None
    return float(found[0])

def evaluate_model(model, tokenizer, eval_examples, device):
    """Generate an answer for every evaluation example and report the accuracy.

    Args:
        model: Language model exposing `eval()`, `train()` and `generate()`.
        tokenizer: Tokenizer used to encode prompts and decode generations.
        eval_examples (list[dict]): Examples with "prompt" and "answer" keys.
        device: Device the encoded prompt is moved to before generation.

    Returns:
        float: Accuracy in percent (0.0 when `eval_examples` is empty).

    Explanation:
        1. Sets the model to evaluation mode.
        2. For each example in the evaluation set:
           - Encodes the prompt and generates a response using the model.
           - Extracts the predicted answer from the generated response.
           - Compares the predicted answer with the expected answer using multiple methods:
             a. Exact string matching
             b. Single number extraction and comparison
             c. Last number extraction and comparison
           - Prints detailed information about each example.
        3. Calculates and returns the overall accuracy.
        4. Returns the model to training mode.
    """
    model.eval() # Set the model to evaluation mode
    correct = 0 # Initialize the number of correct predictions
    total = len(eval_examples) # Get the total number of examples
    print("\n" + "="*50)
    print("Evaluation on {} examples".format(total))
    print("="*50)

    for example in eval_examples:
        # Get the full prompt and expected answer
        full_prompt = example["prompt"]
        expected = example["answer"]

        # Tokenize and generate response. The attention mask is passed
        # explicitly: the pad token equals the EOS token in this notebook, so
        # the mask cannot be inferred reliably from the input ids alone.
        encoded = tokenizer(full_prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_new_tokens=512, # Maximum number of tokens to generate
                temperature=0.7, # Sampling temperature
                num_return_sequences=1, # Number of sequences to generate
                pad_token_id=tokenizer.pad_token_id, # Padding token ID
                eos_token_id=tokenizer.eos_token_id, # End-of-sequence token ID
                forced_eos_token_id=tokenizer.eos_token_id, # Force the model to stop at the end-of-sequence token
                early_stopping=False, # Do not stop early
            )
        # The decoded text contains the prompt followed by the completion; the
        # extractor looks at the LAST <answer> block and rejects the "..."
        # placeholder that the prompt template contains.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        try:
            # Extract the answer and check for correctness
            predicted = extract_answer_from_model_output(response)

            # Try different methods for comparison. A missing prediction (None)
            # must never count as an exact match, even if the reference answer
            # is missing as well.
            if predicted is not None and predicted == expected:
                is_correct = True
            else:
                # Try single number matching
                pred_num = extract_single_number(str(predicted))
                exp_num = extract_single_number(str(expected))
                if pred_num is not None and exp_num is not None and pred_num == exp_num:
                    is_correct = True
                else:
                    # Try last number matching
                    pred_num = extract_last_number(str(predicted))
                    exp_num = extract_last_number(str(expected))
                    is_correct = (pred_num is not None and exp_num is not None and pred_num == exp_num)

            if is_correct:
                correct += 1

            # Print detailed information about the example
            print("\nPrompt:")
            print(full_prompt)
            print("\nExpected Answer:")
            print(expected)
            print("\nPredicted Answer:")
            print(predicted)
            print("\nFull Generated Response:")
            print(response)
            print("\nCorrect:", "✓" if is_correct else "✗")
            print("-"*50)

        except Exception as e:
            # Best effort: an example that cannot be parsed is reported and
            # counted as incorrect instead of aborting the whole evaluation.
            print("\nFailed to parse model output for example:")
            print(full_prompt)
            print("Error:", e)
            print("-"*50)

    # Calculate the overall accuracy (guarding against an empty evaluation set)
    accuracy = (correct / total) * 100 if total else 0.0
    print(f"\nAccuracy: {correct}/{total} ({accuracy:.2f}%)")
    print("="*50)

    model.train() # Set the model back to training mode
    return accuracy

In [8]:
"""
correctness_reward:
This function assigns rewards based on whether the generated answer is correct. It compares the extracted answer from the model output with the expected answer, 
using both exact string matching and numeric equivalence checks. An exact match earns a higher reward (2.0), while a match based on numeric equivalence receives a smaller reward (1.5).
       1. Extracts the content from each completion.
       2. Extracts the answer portion from each response using extract_answer_from_model_output.
       3. Assigns rewards based on matching criteria:
          - 2.0 points for an exact match
          - 1.5 points for numeric equivalence (when values match but format differs)
          - 0.0 points for incorrect answers
       4. Tracks completion lengths for analysis.

format_reward:
This function encourages the model to adhere to the desired XML-like output format. It provides a small reward for the presence of the <reasoning>, </reasoning>, <answer>, and </answer> tags in the generated text.
We use a relatively small value of 0.2 for each of the four tags because the model is already capable of using these tags from a previous supervised fine-tuning step, so we give this small reward 
so that it doesn't forget to do so because of the RL updates.
       1. Extracts the content from each completion.
       2. Evaluates format compliance by checking for required XML tags:
          - 0.2 points for each tag present (<reasoning>, </reasoning>, <answer>, </answer>)
          - Maximum score of 0.8 for perfect format compliance
       3. Stores and returns the format compliance scores.

combined_reward:
       1. Calculates separate rewards for correctness and format compliance.
       2. Combines the rewards with the following weights:
          - Correctness score range: 0.0 to 2.0
          - Format score range: 0.0 to 0.8
          - Total possible range: 0.0 to 2.8
       3. Returns the combined reward for each example.
"""

def correctness_reward(prompts, completions, answer, **kwargs):
    """Score each completion by whether its extracted answer is correct.

    Args:
        prompts: Unused; accepted so all reward functions share one signature.
        completions (list): One entry per completion; `completion[0]['content']`
            holds the generated text.
        answer (list): Expected answer for each completion, in the same order.
        **kwargs: Ignored.

    Returns:
        list[float]: Per completion, 2.0 for an exact string match with the
        expected answer, 1.5 when both sides contain exactly one number and
        those numbers are equal, 0.0 otherwise.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted = [extract_answer_from_model_output(response) for response in responses]
    rewards = []

    for r, a in zip(extracted, answer):
        # A completion without a parsable answer (None) must not be rewarded as
        # an exact match, even when the reference answer is missing too.
        if r is not None and r == a:
            rewards.append(2.0)
        else:
            r_num = extract_single_number(str(r))
            a_num = extract_single_number(str(a))
            if r_num is not None and a_num is not None and r_num == a_num:
                rewards.append(1.5)
            else:
                rewards.append(0.0)
    return rewards

def format_reward(completions, **kwargs):
    """Reward the presence of the four XML-like format tags.

    Args:
        completions (list): One entry per completion; `completion[0]['content']`
            holds the generated text.
        **kwargs: Ignored.

    Returns:
        list[float]: Per completion, 0.2 for each of "<reasoning>",
        "</reasoning>", "<answer>" and "</answer>" found in the text
        (0.0 up to 0.8).
    """
    required_tags = ("<reasoning>", "</reasoning>", "<answer>", "</answer>")
    return [
        # Start from 0.0 and add 0.2 per tag, in the order listed above.
        sum((0.2 for tag in required_tags if tag in completion[0]['content']), 0.0)
        for completion in completions
    ]
    
def combined_reward(prompts, completions, answer):
    """Add the correctness reward and the format reward for each completion.

    Args:
        prompts: Forwarded to correctness_reward.
        completions (list): Completions in the format both reward functions expect.
        answer (list): Expected answers, one per completion.

    Returns:
        list[float]: Element-wise sum of correctness_reward and format_reward.
    """
    correctness_scores = correctness_reward(prompts, completions, answer)
    format_scores = format_reward(completions)
    return [
        correctness + formatting
        for correctness, formatting in zip(correctness_scores, format_scores)
    ]

In [9]:
# TODO: GRPO Code
# we implement all the building blocks of the GRPO algorithm from scratch. The implementation assumes that the machine running the code has at least 2 GPUs. 
# We use PyTorch's DataParallel API to distribute the policy model across the GPU cores, one copy of the model per GPU core. The batch is split between the GPU cores.

def selective_log_softmax(logits, input_ids):
    """Log-probability assigned to each given token id.

    Args:
        logits (torch.Tensor): Raw model logits; the last dimension is
            normalised with log-softmax.
        input_ids (torch.Tensor): Token ids to look up; indexes the last
            dimension of `logits`.

    Returns:
        torch.Tensor: Log probabilities of the selected tokens, with the same
        shape as `input_ids`.
    """
    # Normalise over the last dimension.
    vocab_log_probs = logits.log_softmax(dim=-1)
    # gather needs an index with a trailing axis; drop that axis again afterwards.
    lookup_index = input_ids[..., None]
    selected = torch.gather(vocab_log_probs, -1, lookup_index)
    return selected[..., 0]

def compute_log_probs(model, input_ids, attention_mask, logits_to_keep):
    """Log probabilities of the last `logits_to_keep` tokens of each sequence.

    Args:
        model: The language model; called with `input_ids` and `attention_mask`
            and expected to return an object with a `.logits` attribute.
        input_ids (torch.Tensor): Token IDs for input sequences.
        attention_mask (torch.Tensor): Attention mask for input sequences.
        logits_to_keep (int): Number of tokens to keep from the end of the sequence.

    Returns:
        torch.Tensor: Log probabilities of the selected tokens.
    """
    model_output = model(input_ids=input_ids, attention_mask=attention_mask)
    # The logit at position t predicts token t+1, so the final position has no target.
    next_token_logits = model_output.logits[:, :-1, :]
    # Keep only the trailing window, for the logits and for their target tokens.
    window_logits = next_token_logits[:, -logits_to_keep:, :]
    window_ids = input_ids[:, -logits_to_keep:]
    return selective_log_softmax(window_logits, window_ids)

def create_completion_mask(completion_ids, eos_token_id):
    """Mask that keeps every token up to and including the first EOS token.

    Args:
        completion_ids (torch.Tensor): Token IDs of the generated completions,
            one row per completion.
        eos_token_id (int): The ID of the end-of-sequence token.

    Returns:
        torch.Tensor: Integer mask shaped like `completion_ids`, with 1 for
        positions up to and including the first EOS of the row and 0 after it.
        Rows without any EOS are all 1.
    """
    eos_hits = completion_ids == eos_token_id
    num_rows, seq_len = eos_hits.size(0), eos_hits.size(1)

    # argmax over the 0/1 hits gives the first EOS position; it also returns 0
    # for rows without any EOS, so those rows get `seq_len` as cutoff instead.
    has_eos = eos_hits.any(dim=1)
    first_eos = eos_hits.int().argmax(dim=1)
    cutoff = torch.full((num_rows,), seq_len, dtype=torch.long, device=completion_ids.device)
    cutoff = torch.where(has_eos, first_eos, cutoff)

    positions = torch.arange(seq_len, device=completion_ids.device).unsqueeze(0)
    keep = positions <= cutoff.unsqueeze(1)
    return keep.int()

def generate_completions(model, tokenizer, prompts, num_generations=4, max_completion_length=32):
    """
    Generates multiple completions for each prompt.

    Args:
        model: The language model.
        tokenizer: The tokenizer for encoding and decoding text.
        prompts (list): List of text prompts.
        num_generations (int): Number of completions to generate per prompt.
        max_completion_length (int): Maximum number of tokens to generate.

    Returns:
        tuple: Containing prompt IDs, prompt mask, completion IDs, and completion mask.
        The prompt IDs and prompt mask are already repeated `num_generations`
        times, so all four tensors have one row per generated completion.

    Explanation:
        1. Encodes the prompts (left-padded) and moves them to the device that
           holds the model's first parameter.
        2. Repeats each prompt num_generations times to generate multiple completions.
        3. Generates completions by sampling with temperature 1.0.
        4. Extracts the completion IDs (excluding the prompt tokens).
        5. Creates a mask for the completions using create_completion_mask.
    """
    # Send the inputs to wherever the model actually lives instead of assuming
    # "cuda:0"; fall back to the previous default for parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
    prompt_ids = inputs["input_ids"].to(device)
    prompt_mask = inputs["attention_mask"].to(device)
    print(f"Input batch size: {prompt_ids.size(0)}, Device before model: {prompt_ids.device}")
    # With left padding every prompt ends at the same column, so one shared
    # prompt length is enough to slice the completions off the outputs below.
    prompt_length = prompt_ids.size(1)
    prompt_ids = prompt_ids.repeat_interleave(num_generations, dim=0)
    prompt_mask = prompt_mask.repeat_interleave(num_generations, dim=0)
    outputs = model.generate(
        prompt_ids,
        attention_mask=prompt_mask,
        max_new_tokens=max_completion_length,
        do_sample=True,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=False
    )
    print(f"Output batch size: {outputs.size(0)}, Device after model: {outputs.device}")
    completion_ids = outputs[:, prompt_length:]
    completion_mask = create_completion_mask(completion_ids, tokenizer.eos_token_id)
    return prompt_ids, prompt_mask, completion_ids, completion_mask

def generate_rollout_data(model, ref_model, tokenizer, batch_samples, num_generations, max_completion_length):
    """
    Samples completions for a batch and records everything a GRPO update needs.

    Args:
        model: The policy model being trained.
        ref_model: The reference model for KL divergence calculation.
        tokenizer: The tokenizer for encoding and decoding text.
        batch_samples (list): Batch of training samples; each is either a dict
            with "prompt" and "answer" keys or a (prompt, answer) sequence.
        num_generations (int): Number of completions to generate per sample.
        max_completion_length (int): Maximum completion length.

    Returns:
        dict: Token ids, masks, policy and reference log probabilities computed
        without gradients, the decoded completions wrapped for the reward
        functions, prompts and answers repeated once per completion, and the
        bookkeeping values "logits_to_keep", "batch_size" and "num_generations".
    """
    # Split the samples into prompts and reference answers.
    prompts, answers = [], []
    for sample in batch_samples:
        if isinstance(sample, dict):
            prompts.append(sample["prompt"])
            answers.append(sample["answer"])
        else:
            prompts.append(sample[0])
            answers.append(sample[1])

    # Sampling and the two log-probability passes need no gradients.
    with torch.no_grad():
        generation = generate_completions(
            model, tokenizer, prompts, num_generations, max_completion_length
        )
        prompt_ids, prompt_mask, completion_ids, completion_mask = generation
        full_ids = torch.cat([prompt_ids, completion_ids], dim=1)
        full_mask = torch.cat([prompt_mask, completion_mask], dim=1)
        completion_length = completion_ids.size(1)
        policy_log_probs = compute_log_probs(model, full_ids, full_mask, completion_length)
        reference_log_probs = compute_log_probs(ref_model, full_ids, full_mask, completion_length)

    # Reward functions read completion[0]['content'].
    decoded_completions = []
    for ids in completion_ids:
        decoded_completions.append([{'content': tokenizer.decode(ids, skip_special_tokens=True)}])

    # One prompt/answer entry per completion, in the row order used by repeat_interleave.
    prompts_per_completion, answers_per_completion = [], []
    for prompt in prompts:
        prompts_per_completion.extend(prompt for _ in range(num_generations))
    for reference in answers:
        answers_per_completion.extend(reference for _ in range(num_generations))

    rollout = {}
    rollout["input_ids"] = full_ids
    rollout["attention_mask"] = full_mask
    rollout["completion_mask"] = completion_mask
    rollout["old_log_probs"] = policy_log_probs
    rollout["ref_log_probs"] = reference_log_probs
    rollout["formatted_completions"] = decoded_completions
    rollout["repeated_prompts"] = prompts_per_completion
    rollout["repeated_answers"] = answers_per_completion
    rollout["logits_to_keep"] = completion_length
    rollout["batch_size"] = len(prompts)
    rollout["num_generations"] = num_generations
    return rollout

def grpo_loss(model, ref_model, rollout_data, tokenizer, reward_function, beta=0.01, epsilon=0.2):
    """
    Computes the GRPO loss for updating the policy model.

    Args:
        model: The policy model being trained.
        ref_model: Unused here (the reference log probabilities come from
            `rollout_data`); kept so the signature stays stable.
        rollout_data (dict): Data generated by generate_rollout_data.
        tokenizer: Unused here; kept so the signature stays stable.
        reward_function: Function that calculates rewards for completions. It is
            called with `prompts`, `completions` and `answer` keyword arguments.
        beta (float): KL penalty coefficient.
        epsilon (float): Clipping parameter for PPO.

    Returns:
        tuple: (loss, avg_reward) where `loss` is the scalar torch.Tensor to be
        minimized and `avg_reward` is the mean reward of the batch as a float.

    Explanation:
        1. Computes current token log probabilities using the policy model.
        2. Calculates the probability ratio between current and old policies.
        3. Computes rewards using the provided reward_function.
        4. Calculates advantages by standardizing rewards within each prompt.
        5. Computes the PPO surrogate objective with clipping.
        6. Calculates the KL divergence between reference and policy models.
        7. Combines surrogate loss and KL penalty.
        8. Averages the loss over the unmasked tokens of each completion, then
           over completions.
    """
    input_ids = rollout_data["input_ids"]
    attention_mask = rollout_data["attention_mask"]
    completion_mask = rollout_data["completion_mask"]
    logits_to_keep = rollout_data["logits_to_keep"]
    old_log_probs = rollout_data["old_log_probs"]
    ref_log_probs = rollout_data["ref_log_probs"]
    token_log_probs = compute_log_probs(model, input_ids, attention_mask, logits_to_keep)
    ratio = torch.exp(token_log_probs - old_log_probs)
    # Create the rewards on the device of the log probabilities they are
    # combined with, rather than assuming "cuda:0".
    rewards = torch.tensor(
        reward_function(prompts=rollout_data["repeated_prompts"], completions=rollout_data["formatted_completions"], answer=rollout_data["repeated_answers"]),
        dtype=torch.float32,
        device=token_log_probs.device
    )
    batch_size = rollout_data["batch_size"]
    num_generations = rollout_data["num_generations"]
    # One row per prompt, one column per completion of that prompt.
    rewards = rewards.view(batch_size, num_generations)
    avg_reward = rewards.mean().item()
    print("Average Reward:", avg_reward)
    mean_rewards = rewards.mean(dim=1, keepdim=True)
    if num_generations > 1:
        std_rewards = rewards.std(dim=1, keepdim=True)
    else:
        # The sample standard deviation of one value is NaN, which would turn
        # every advantage (and the loss) into NaN. A single-completion group
        # carries no relative signal, so its advantage becomes 0 instead.
        std_rewards = torch.zeros_like(mean_rewards)
    # Group-relative advantage, one value per completion, broadcast over tokens.
    advantages = ((rewards - mean_rewards) / (std_rewards + 1e-4)).view(-1, 1)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    surrogate_loss = torch.min(surr1, surr2)
    # Per-token KL estimate of the form exp(d) - d - 1 with d = ref - policy.
    log_ratio_to_ref = ref_log_probs - token_log_probs
    kl = torch.exp(log_ratio_to_ref) - log_ratio_to_ref - 1
    per_token_loss = surrogate_loss - beta * kl
    # Guard the per-sequence token count so a fully masked row cannot divide by zero.
    tokens_per_sequence = completion_mask.sum(dim=1).clamp(min=1)
    loss = -((per_token_loss * completion_mask).sum(dim=1) / tokens_per_sequence).mean()
    return loss, avg_reward

def train_with_grpo(model, tokenizer, train_data, num_iterations=1, num_steps=500, batch_size=4,
                              num_generations=4, max_completion_length=128, beta=0.1,
                              learning_rate=5e-6, mu=3, epsilon=0.2, reward_function=None, device_ids=None):
    """
    Trains the policy model with iterative GRPO updates.

    Args:
        model: The language model to train.
        tokenizer: The tokenizer for encoding and decoding text.
        train_data (list): Training dataset.
        num_iterations (int): Number of outer iterations (reference model updates).
        num_steps (int): Number of batch updates per iteration.
        batch_size (int): Number of prompts per batch.
        num_generations (int): Number of completions per prompt.
        max_completion_length (int): Maximum token length for completions.
        beta (float): KL penalty coefficient.
        learning_rate (float): Learning rate for optimizer.
        mu (int): Number of policy updates per batch.
        epsilon (float): PPO clipping parameter.
        reward_function: Function that calculates rewards for completions.
        device_ids (list): List of GPU device IDs for DataParallel; at least
            two are required.

    Returns:
        The trained model (unwrapped from DataParallel).

    Raises:
        AssertionError: If `device_ids` is None or has fewer than two entries.

    Explanation:
        1. For each outer iteration:
           - Creates a reference model as a frozen deep copy of the current policy model.
           - Reinitializes the AdamW optimizer for the policy model.
           - For each training step:
             a. Samples a batch of examples from the training data.
             b. Generates rollout data including completions and log probabilities.
             c. For mu iterations:
                i. Computes the GRPO loss.
                ii. Clips the gradient norm to 0.1 and updates the policy model.
                iii. Logs the metrics to wandb when a run is active.

    NOTE(review): the model is wrapped in nn.DataParallel, but rollouts and the
    loss are computed through `model.module` (the unwrapped model), so
    DataParallel's batch splitting is not exercised by this loop. Confirm
    whether that is intended before relying on multi-GPU speed-ups.
    """
    assert device_ids is not None and len(device_ids) > 1, "This code needs at least 2 GPU cores to run!"

    # Wrap model with DataParallel if multiple GPUs are available.

    model = nn.DataParallel(model, device_ids=device_ids)
    print(f"Model wrapped with DataParallel across GPUs: {device_ids}")

    # Outer loop: iterative GRPO updates.
    for iteration in range(num_iterations):
        print(f"\nIteration {iteration+1}/{num_iterations}")

        # Create a reference model (deep copy) and set it to eval mode.
        ref_model = copy.deepcopy(model.module)
        ref_model.eval()
        for param in ref_model.parameters():
            param.requires_grad = False
        print("Reference model created.")

        # Reinitialize the optimizer for this iteration.
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        model.train()

        # Inner loop: one rollout followed by `mu` policy updates per step.
        for step in range(num_steps):
            batch_samples = random.sample(train_data, batch_size)
            with torch.no_grad():
                rollout_data = generate_rollout_data(
                    model.module,
                    ref_model,
                    tokenizer,
                    batch_samples,
                    num_generations,
                    max_completion_length
                )
            for grpo_iter in range(mu):
                loss, avg_reward = grpo_loss(
                    model.module,
                    ref_model,
                    rollout_data,
                    tokenizer,
                    reward_function,
                    beta=beta,
                    epsilon=epsilon
                )
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
                optimizer.step()
                # Log to wandb only when a run is active: wandb.log raises if
                # wandb.init() was never called, which would abort training.
                if wandb.run is not None:
                    wandb.log({
                        "loss": loss.item(),
                        "average_reward": avg_reward,
                        "iteration": iteration + 1,
                        "step": step + 1,
                        "grpo_iter": grpo_iter + 1
                    })
                print(f"Iteration {iteration+1}/{num_iterations}, Step {step+1}/{num_steps}, "
                      f"GRPO iter {grpo_iter+1}/{mu}, loss: {loss.item():.4f}")
                #for i in range(torch.cuda.device_count()):
                #    print(f"GPU {i} Usage: {torch.cuda.memory_allocated(i) / 1024**2:.2f} MiB, "
                #          f"Utilization: {torch.cuda.utilization(i)}%")
                # Uncomment to see the GPU utilization stats
    return model.module

In [None]:
"""
TODO:
- Many transformer models store past activations (the attention KV cache) to speed up inference.
  Disabling this cache during training reduces memory usage, at the cost of recomputing those activations when they are needed.
- If an input does not require gradients, it won’t contribute to the learning process.
- Enables gradient checkpointing to save memory.
  Normally, activations are stored for all layers during the forward pass to compute gradients in the backward pass.
  With gradient checkpointing:
   - Only some activations are saved.
   - Other activations are recomputed on-demand during backpropagation.
   - This significantly reduces memory consumption but slightly increases computation time.
"""

# optimize model memory
def optimize_model_memory(model):
    """Put the model in a memory-saving training configuration.

    Switches the model to training mode, turns off `config.use_cache`, makes
    sure the input embeddings' output requires gradients, and enables
    gradient checkpointing.

    Args:
        model: Model exposing `train()`, `config`, `gradient_checkpointing_enable()`
            and either `enable_input_require_grads()` or `get_input_embeddings()`.

    Returns:
        The same model object, modified in place.
    """
    model.train()
    model.config.use_cache = False  # no KV caching while training

    if not hasattr(model, "enable_input_require_grads"):
        # Fallback: mark the embedding output as requiring gradients via a forward hook.
        def _mark_output_requires_grad(module, inputs, output):
            output.requires_grad = True

        embedding_layer = model.get_input_embeddings()
        embedding_layer.register_forward_hook(_mark_output_requires_grad)
    else:
        model.enable_input_require_grads()

    model.gradient_checkpointing_enable()
    return model

In [None]:
"""
The device is determined (GPU if available, otherwise CPU).
The pre-trained Qwen2.5-1.5B-Instruct model and tokenizer are loaded. The tokenizer's pad token is set to the eos_token.
A small subset of the dataset is reserved for evaluation to provide a baseline.
The model is optimized for memory efficiency by enabling gradient checkpointing and disabling KV caching.
Step 1: The model is evaluated before fine-tuning to establish a baseline accuracy.
Step 2: Reinforcement learning fine-tuning is performed using the train_with_grpo function with our defined reward functions (format_reward and correctness_reward, combined into combined_reward). The model is trained using multiple GPUs.
Step 3: The final, fine-tuned model and tokenizer are saved to disk.
"""

# Model Execution
# End-to-end driver: load the model and tokenizer, hold out a small evaluation
# set, measure baseline accuracy, fine-tune with GRPO, re-evaluate, and save.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using primary device: {device}")

model_name = "Qwen/Qwen2.5-1.5B-Instruct"
output_dir = "ms_math_solver_model"  # single source of truth for the save location

print("Downloading Model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16, # Use bfloat16 for memory efficiency
    device_map="auto" # Automatically distribute model across available GPUs
)
print("Model Downloaded.")

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") # Load tokenizer, set padding side to left
tokenizer.pad_token = tokenizer.eos_token # Set the pad token to the EOS token
model.config.pad_token_id = tokenizer.eos_token_id # Set the model's pad token ID to the EOS token ID
model.config.eos_token_id = tokenizer.eos_token_id # Set the model's end-of-sequence token ID

num_gpus = torch.cuda.device_count()
print(f"Number of GPUs available: {num_gpus}")
# train_with_grpo asserts that at least two device ids are supplied, so on a
# machine with fewer than two GPUs the training call below fails by design.
device_ids = list(range(num_gpus)) if num_gpus > 1 else None

all_data = prepare_dataset("train") # Load the training dataset
random.shuffle(all_data) # Shuffle the dataset for randomness
size_of_eval_set = 30 # Size of the evaluation set
eval_data = all_data[:size_of_eval_set] # Reserve a subset for evaluation
train_data = all_data[size_of_eval_set:] # Use the remaining data for training

print("\nInitial Evaluation before the finetuning...")
pre_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device) # Evaluate the model before fine-tuning
print(f"Initial Evaluation Accuracy before GRPO: {pre_grpo_accuracy:.2f}%")

model = optimize_model_memory(model) # Optimize the model for memory efficiency

print("\nFine-tuning the model using GRPO...")
training_config = {
    'num_iterations': 1,
    'num_steps': 500,
    'batch_size': 7,
    'num_generations': 12,
    'max_completion_length': 400,
    'beta': 0.4, # KL penalty coefficient
    'learning_rate': 5e-6, 
    'mu': 1, # Number of policy updates per batch
    'epsilon': 0.1 # PPO clipping parameter
}

# initialize wandb
wandb.init(project=os.environ["WANDB_PROJECT"], reinit=True)
print("Wandb initialized.")

model = train_with_grpo(
    model=model, 
    tokenizer=tokenizer,
    train_data=train_data,
    reward_function=combined_reward,
    device_ids=device_ids,
    **training_config 
)

wandb.finish()
print("Training completed.")

print("\nEvaluation after the finetuning...")
post_grpo_accuracy = evaluate_model(model, tokenizer, eval_data, device) # Evaluate the model after fine-tuning
print(f"Final Evaluation Accuracy after GRPO: {post_grpo_accuracy:.2f}%")

print("\nSaving the final model and tokenizer...")
model.save_pretrained(output_dir) # Save the final model
tokenizer.save_pretrained(output_dir) # Save the final tokenizer