In [1]:
import os

from eval_utils import display_prompt, display_responses
from load_torchtune_ds import load_gutenberg_dataset

from vllm import LLM, SamplingParams
from torch.utils.data import DataLoader
from torchtune.dev.grpo.data import padded_collate_rl
from torchtune import config
from torchtune.config._utils import _get_component_from_path

from omegaconf import DictConfig

from IPython.display import display, HTML

  from .autonotebook import tqdm as notebook_tqdm


INFO 03-22 21:41:26 __init__.py:190] Automatically detected platform cuda.


2025-03-22 21:41:26,471	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
world_size = 1
rank = 1
batch_size = 2
grpo_size = 2

### NOTE: this is what we are replacing.
# cfg_dataset = DictConfig({
#     '_component_': 'torchtune.dev.grpo.gsm8k.gsm8k_dataset',
#     'partition': '3-5/100' 
# })

root_path = os.path.expanduser('~/dev/nebius-experiments/projects/torchtune/trained_models/')

cfg_tokenizer = DictConfig({
    '_component_': 'torchtune.models.llama3.llama3_tokenizer',
    'path': os.path.join(root_path, 'Llama3_3_70B_GRPOd_gsm8k_default_reward/original/tokenizer.model'),
    'max_seq_len': 'null'
})
collate_fn = 'torchtune.dev.grpo.data.padded_collate_rl'

tokenizer = config.instantiate(cfg_tokenizer)
collate_fn = _get_component_from_path(collate_fn)

In [3]:
# Create the dataset for historical context reasoning
data_path = os.path.join(os.getcwd(), "gutenberg_dataset")
if not os.path.exists(data_path):
    raise ValueError("Did you run the download.py script?")

dataset = load_gutenberg_dataset(tokenizer, data_path=data_path)

Loaded 144 passages from 144 files


Saving the dataset (1/1 shards): 100%|██████████| 144/144 [00:00<00:00, 33506.03 examples/s]


In [4]:
# Create a dataloader
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda batch: padded_collate_rl(
        batch,
        padding_idx=tokenizer.pad_id,
        ignore_idx=-100,  # CROSS_ENTROPY_IGNORE_IDX
    ),
)

In [5]:
batch = next(dataloader._get_iterator())
batch

{'tokens': tensor([[128000,     32,  21765,  ...,  22103,     25,    220],
         [128000,     32,  21765,  ..., 128004, 128004, 128004]]),
 'answers': ['enlightenment (1725)', 'enlightenment (1725)']}

In [6]:
batch = next(dataloader._get_iterator())
tokens = batch["tokens"]         # tokenized prompts
answers = batch["answers"]       # untokenized answers
tokens = tokens                  # [batch_size x num_tokens_per_prompt]
tokens_ls = tokens.tolist()
out = []
_prompts = []
_answers = []
for i in range(tokens.shape[0]):
    prompt = tokenizer.decode(tokens_ls[i])
    _prompts.extend([prompt] * grpo_size) 
    answer = answers[i]
    _answers.extend([answer] * grpo_size)
    # display(HTML(display_prompt(
    #     prompt, 
    #     answer, 
    #     tokenizer
    # )))

In [7]:
from xml.etree import ElementTree as ET
from typing import Tuple
import torch
from torchtune.modules.transforms.tokenizers import ModelTokenizer
import re

def extract_tags(text: str) -> dict[str, list[str]]:
    """
    Parse XML-like tags from text. Returns a dictionary with keys 'think', 'answer_era', and 'answer_date'.
    The values are lists of strings, with each string being the content of a tag.
    """
    xml_string = f"<root>{text}</root>"
    try:
        root = ET.fromstring(xml_string)
        return {
            "think": [elem.text if elem.text is not None else "" for elem in root.findall("think")],
            "answer_era": [elem.text if elem.text is not None else "" for elem in root.findall("answer_era")],
            "answer_date": [elem.text if elem.text is not None else "" for elem in root.findall("answer_date")]
        }
    except ET.ParseError:
        return {"think": [], "answer_era": [], "answer_date": []}

def shaped_correctness_reward(answer: str, completion: str) -> tuple[float, float]:
    """
    Reward function for verifiable rewards with shaping for era and date identification.

    Args:
        answer (str): Ground-truth answer in the format "era (date)"
        completion (str): Model's completion
    Returns:
        reward: (float) A shaped reward for format and correctness
        success: (float) A binary measure of success (1 if fully successful, 0 otherwise)
    """
    reward = 0.0
    success = 0.0
    
    # Parse the ground truth era and date
    gt_match = re.match(r'([a-z]+)\s*\((\d+)\)', answer.lower())
    if gt_match:
        gt_era = gt_match.group(1)
        gt_date = gt_match.group(2)
    else:
        # Fallback if parsing fails
        gt_era = answer
        gt_date = ""
    
    # Extract tags from completion
    tags = extract_tags(completion)
    
    # Format rewards - consistent with the original example
    if len(tags["think"]) == 1:
        reward += 5.0  # Reward for having a thinking section
    
    if len(tags["answer_era"]) == 1:
        reward += 2.5  # Reward for having an era answer tag
    
    if len(tags["answer_date"]) == 1:
        reward += 2.5  # Reward for having a date answer tag
    
    # Correctness rewards for era
    if tags["answer_era"] and any(gt_era == attempt.lower() for attempt in tags["answer_era"]):
        # One of the answer_era tags has the exact right era
        reward += 20.0
    elif tags["answer_era"] and any(gt_era in attempt.lower() for attempt in tags["answer_era"]):
        # One of the answer_era tags contains the right era as a substring
        reward += 10.0
    
    # Correctness rewards for date
    if gt_date and tags["answer_date"]:
        try:
            gt_year = int(gt_date)
            for attempt in tags["answer_date"]:
                if attempt.isdigit():
                    attempt_year = int(attempt)
                    year_diff = abs(gt_year - attempt_year)
                    
                    if year_diff == 0:
                        # Exact date match
                        reward += 20.0
                        break
                    elif year_diff <= 20:
                        # Within 20 years
                        reward += 15.0
                        break
                    elif year_diff <= 50:
                        # Within 50 years
                        reward += 10.0
                        break
        except ValueError:
            pass  # Ignore non-numeric date values
    
    # Full success - both era and date are correct
    if tags["answer_era"] and tags["answer_date"] and \
       len(tags["answer_era"]) > 0 and tags["answer_era"][-1].lower() == gt_era and \
       len(tags["answer_date"]) > 0 and tags["answer_date"][-1].isdigit() and abs(int(tags["answer_date"][-1]) - int(gt_date)) <= 20:
        reward = 100.0
        success = 1.0
    
    return reward, success

def batch_shaped_correctness_reward(
    tokenizer: ModelTokenizer, completions: torch.Tensor, answers: list[str]
) -> Tuple[torch.Tensor]:
    """Utility function to apply the shaped reward function to a GRPO-style batch of completions."""

    batch_size, grpo_size, *_ = completions.shape
    rewards = torch.zeros(batch_size, grpo_size, dtype=torch.float32)
    successes = torch.zeros(batch_size, grpo_size, dtype=torch.float32)
    # completions :: [B, G, L]
    for b in range(batch_size):
        for g in range(grpo_size):
            text_completion = tokenizer.decode(
                completions[b, g].tolist()
            )  # skips special tokens, stops at eos
            reward, success = shaped_correctness_reward(
                answer=answers[b], completion=text_completion
            )
            rewards[b, g] = reward
            successes[b, g] = success

    return rewards, successes

### Load model

In [8]:
path = '/tmp/Llama-3.2-3B-Instruct/'
llm = LLM(
    model=path, 
    task="generate", 
    trust_remote_code=True,
    # tensor_parallel_size=1,
    dtype='bfloat16'
)

INFO 03-22 21:41:31 config.py:1556] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 03-22 21:41:31 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.2) with config: model='/tmp/Llama-3.2-3B-Instruct/', speculative_config=None, tokenizer='/tmp/Llama-3.2-3B-Instruct/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=/tmp/Llama-3.2-3B-Instruct/, num_scheduler_steps=1, multi_step_stream_outputs=True,

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:00<00:00,  1.20it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.86it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.72it/s]



INFO 03-22 21:41:34 model_runner.py:1115] Loading model weights took 6.0160 GB
INFO 03-22 21:41:35 worker.py:267] Memory profiling takes 0.50 seconds
INFO 03-22 21:41:35 worker.py:267] the current vLLM instance can use total_gpu_memory (79.10GiB) x gpu_memory_utilization (0.90) = 71.19GiB
INFO 03-22 21:41:35 worker.py:267] model weights take 6.02GiB; non_torch_memory takes 0.15GiB; PyTorch activation peak memory takes 1.21GiB; the rest of the memory reserved for KV Cache is 63.81GiB.
INFO 03-22 21:41:35 executor_base.py:110] # CUDA blocks: 37339, # CPU blocks: 2340
INFO 03-22 21:41:35 executor_base.py:115] Maximum concurrency for 131072 tokens per request: 4.56x
INFO 03-22 21:41:37 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_u

Capturing CUDA graph shapes: 100%|██████████| 35/35 [00:10<00:00,  3.44it/s]

INFO 03-22 21:41:47 model_runner.py:1562] Graph capturing finished in 10 secs, took 0.29 GiB
INFO 03-22 21:41:47 llm_engine.py:431] init engine (profile, create kv cache, warmup model) took 13.29 seconds





In [9]:
sampling_params = SamplingParams(
    temperature=0.8, 
    top_p=0.95,
    max_tokens=512
)
output = llm.generate(_prompts, sampling_params)

Processed prompts: 100%|██████████| 4/4 [00:02<00:00,  1.50it/s, est. speed input: 843.85 toks/s, output: 726.39 toks/s]


In [10]:
stop_token_ids = [
    128001,
    128009,
    128008
]
pad_id = 128004
max_tokens = 512

data = []
for o in output:
    out_tokens = list(o.outputs[0].token_ids)
    if len(out_tokens) < max_tokens:
        out_tokens += [pad_id] * (max_tokens - len(out_tokens))
    data.append(out_tokens)
responses=torch.tensor(data, dtype=torch.int32).reshape(batch_size, grpo_size, max_tokens)

In [11]:
responses.shape # [batch_size, grpo_size, generation_max_tokens]

torch.Size([2, 2, 512])

In [12]:
rewards, successes = batch_shaped_correctness_reward(
    tokenizer=tokenizer, 
    completions=responses, 
    answers=_answers
)

In [13]:
rewards

tensor([[100., 100.],
        [  0.,   0.]])

In [14]:
successes

tensor([[1., 1.],
        [0., 0.]])

In [15]:
advantages = (rewards - rewards.mean(1, keepdim=True)) / (
    rewards.std(1, keepdim=True) + 1e-4
)

In [16]:
advantages

tensor([[0., 0.],
        [0., 0.]])

In [None]:
display(HTML(
    display_responses(
        responses,
        tokenizer, 
        grpo_size, 
        advantages=advantages, 
        rewards=rewards, 
        successes=successes
    )
))

IndexError: index 2 is out of bounds for dimension 0 with size 2