<a href="https://colab.research.google.com/github/nam-2pir/docs/blob/main/colabs/GRPO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://withpi.ai"><img src="https://withpi.ai/logoFullBlack.svg" width="240"></a>

<a href="https://code.withpi.ai"><font size="4">Documentation</font></a>

<a href="https://play.withpi.ai"><font size="4">Technique Catalog</font></a>

# Reinforcement Learning GRPO

This notebook is the companion to the RL playground.

Description: Train models to more deeply learn patterns from your data.

## Install and initialize SDK

You'll need a WITHPI_API_KEY from https://play.withpi.ai.  Add it to your notebook secrets (the key symbol) on the left.

Run the cells below to install packages and load the SDK.

In [1]:
%%capture
%pip install withpi httpx datasets tqdm

In [64]:
import os
import json
from pathlib import Path
from collections import defaultdict

import datasets
import httpx
from google.colab import files, userdata
from withpi import PiClient
from withpi.types import Contract

from rich.console import Console
from rich.table import Table
from rich.live import Live

# PiClient() is built without arguments, so the key stored in the notebook
# secrets is exported to the environment first for the client to pick up.
os.environ["WITHPI_API_KEY"] = userdata.get("WITHPI_API_KEY")
client = PiClient()

# Shared rich console used when rendering training progress.
console = Console()


def print_contract(contract: Contract):
    """Print a contract as an indented outline.

    Each dimension's label is printed on its own line, followed by one
    tab-indented line per sub-dimension description.
    """
    for dim in contract.dimensions:
        print(dim.label)
        descriptions = [sub.description for sub in dim.sub_dimensions]
        for text in descriptions:
            print(f"\t{text}")


def print_scores(pi_scores):
    """Print a Pi Score response as an indented, per-dimension listing.

    For every dimension the total score is printed first, then one
    tab-indented line per subdimension score, then a blank spacer. A rule
    and the overall total score close the listing.
    """
    for name, scores in pi_scores.dimension_scores.items():
        print(f"{name}: {scores.total_score}")
        sub_scores = scores.subdimension_scores
        for sub_name in sub_scores:
            print(f"\t{sub_name}: {sub_scores[sub_name]}")
        print("\n")

    # Footer: separator rule followed by the overall score.
    print("---------------------")
    print(f"Total score: {pi_scores.total_score}")


def save_file(filename: str, model: str):
    """Write `model` to `filename`, then offer that file for download."""
    target = Path(filename)
    target.write_text(model)
    files.download(filename)


def load_contract(url: str) -> Contract:
    """Download a Contract JSON document and validate it.

    Args:
        url: Location of the JSON-serialized contract.

    Returns:
        The Contract parsed from the response body.

    Raises:
        httpx.HTTPStatusError: If the response status indicates failure.
    """
    resp = httpx.get(url)
    # Surface HTTP failures directly; otherwise an error page would be handed
    # to the validator and show up as a misleading validation error.
    resp.raise_for_status()
    return Contract.model_validate_json(resp.content)


def generate_table(
    job_id: str, training_data: dict, is_done: bool, additional_columns: dict[str, str]
):
    """Build a rich Table summarizing training progress for a job.

    Args:
        job_id: Job identifier shown in the table title.
        training_data: Mapping of step -> metrics dict; one row is rendered
            per entry, in iteration order.
        is_done: When False, a trailing "..." row is appended to signal that
            more steps are expected.
        additional_columns: Mapping of column header -> metric key looked up
            in each step's metrics dict.

    Returns:
        The populated Table.
    """
    table = Table(title=f"Training Status for {job_id}")

    # Fixed columns, in the same order as the cells passed to add_row below.
    table.add_column("Step", justify="right", style="cyan")
    table.add_column("Epoch", justify="right", style="cyan")
    table.add_column("Learning Rate", justify="right", style="cyan")
    table.add_column("Train Loss", justify="right", style="magenta")
    table.add_column("Eval Loss", justify="right", style="green")
    for header in additional_columns:
        table.add_column(header, justify="right", style="black")

    def format_num(num: float | None, digits: int = 4) -> str:
        """Format a metric to `digits` decimals; a missing value renders as "X"."""
        if num is None:
            return "X"
        return format(num, f".{digits}f")

    for step, data in training_data.items():
        extra_cells = [
            format_num(data.get(metric_key))
            for metric_key in additional_columns.values()
        ]
        table.add_row(
            str(step),
            format_num(data.get("epoch")),
            format_num(data.get("learning_rate"), digits=10),
            format_num(data.get("loss")),
            format_num(data.get("eval_loss")),
            *extra_cells,
        )

    if not is_done:
        # Emit exactly one cell per declared column. A hard-coded cell count
        # makes rich append an unnamed column whenever the row has more
        # cells than the table has columns (e.g. no additional columns).
        table.add_row("...", *[""] * (len(table.columns) - 1))

    return table


def _record_training_line(training_data: dict, line) -> None:
    """Merge one JSON status line into `training_data`, keyed by its "step".

    Lines that are not JSON objects carrying a "step" key (for example the
    "DONE" sentinel) are ignored.
    """
    try:
        data_dict = json.loads(line)
        training_data[data_dict["step"]].update(data_dict)
    except (ValueError, KeyError, TypeError):
        # ValueError: not valid JSON; KeyError: no "step" field;
        # TypeError: non-string line, non-object JSON, or unhashable step.
        pass


def stream_response(job_id: str, method, additional_columns: dict[str, str]):
    """Follow a job until it leaves the QUEUED/RUNNING states.

    While the job is queued or running, its messages are streamed and a live
    progress table is redrawn after every line. If the job is already DONE
    and nothing was streamed, the table is built once from the job's
    `detailed_status` and printed.

    Args:
        job_id: Identifier of the job to follow.
        method: A Pi client object exposing `retrieve` and
            `with_streaming_response.stream_messages`.
        additional_columns: Mapping of column header -> metric key, passed
            through to `generate_table`.

    Returns:
        The last `retrieve` response, i.e. the job status once its state is
        neither "QUEUED" nor "RUNNING".
    """
    training_data = defaultdict(dict)
    # True once at least one streamed line has been drawn in the live table,
    # so the final table is not printed a second time.
    is_log_console = False

    while True:
        response = method.retrieve(job_id=job_id)
        if response.state not in ("QUEUED", "RUNNING"):
            if response.state == "DONE" and not is_log_console:
                # detailed_status may be absent; treat that as "no lines".
                for line in response.detailed_status or []:
                    _record_training_line(training_data, line)
                console.print(
                    generate_table(
                        job_id,
                        training_data,
                        is_done=True,
                        additional_columns=additional_columns,
                    )
                )
            return response

        with method.with_streaming_response.stream_messages(
            job_id=job_id, timeout=None
        ) as stream:
            with Live(auto_refresh=True, console=console, refresh_per_second=4) as live:
                is_done = False
                for line in stream.iter_lines():
                    if line == "DONE":
                        is_done = True
                    _record_training_line(training_data, line)
                    live.update(
                        generate_table(
                            job_id,
                            training_data,
                            is_done,
                            additional_columns=additional_columns,
                        )
                    )
                    is_log_console = True

# Load a contract and dataset

We have a pre-existing contract you can play with.


In [51]:
import datasets

# Location of the pre-existing contract JSON.
contract_url = "https://raw.githubusercontent.com/withpi/cookbook-withpi/refs/heads/main/contracts/tldr.json"
tldr_contract = load_contract(contract_url)

# Keep only the first `num_examples` rows of the train split.
num_examples = 1200
train_split = datasets.load_dataset("withpi/tldr")["train"]
tldr_data = train_split.select(range(num_examples))

print(tldr_data)

Dataset({
    features: ['prompt', 'completion'],
    num_rows: 1200
})


## Kick off the job

The GRPO job internally performs a 90/10 train-test split, which is why the loader is not splitting the input data.

This process takes a while; please be patient while a cloud GPU is acquired, fine-tuning is performed, and a result is returned.

In [None]:
# Only the prompt column is sent; each example carries it as "llm_input".
grpo_examples = [{"llm_input": row["prompt"]} for row in tldr_data]

status = client.model.rl.grpo.start_job(
    contract=tldr_contract,
    examples=grpo_examples,
    model="LLAMA_3.2_1B",
    num_train_epochs=5,
)
print(status)

In [None]:
response = stream_response(
    # Follow the job started by start_job rather than a hard-coded job id,
    # which would show a different (possibly inaccessible) job.
    status.job_id,
    client.model.rl.grpo,
    # Table header -> metric key looked up in each streamed status line.
    additional_columns={
        "Train Pi Reward": "rewards/pi_reward_func",
        "Train Std Reward": "reward_std",
        "Eval Pi Reward": "eval_rewards/pi_reward_func",
        "Eval Std Reward": "eval_reward_std",
        "Train KL": "kl",
        "Eval KL": "eval_kl",
        "Train Completion Length": "completion_length",
        "Eval Completion Length": "eval_completion_length",
    },
)

# Fail with a readable message instead of an IndexError/TypeError when the
# job finished without producing a model.
if not response.trained_models:
    raise RuntimeError(
        f"GRPO job {status.job_id} ended in state {response.state!r} "
        "without a trained model"
    )

trained_model = response.trained_models[0]
print(f"GRPO model = {trained_model.model_dump_json(indent=2)}")
repo_name = trained_model.hf_model_name

Output()