In [1]:
import torch

if torch.cuda.is_available():
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU detected.")


GPU detected: NVIDIA GeForce RTX 3060


In [2]:
from unsloth import FastLanguageModel
from trl import GRPOConfig, GRPOTrainer
import torch

if torch.cuda.is_available():
    print(f"GPU detected: {torch.cuda.get_device_name(0)}")
else:
    print("No GPU detected.")
import os
os.environ["TORCH_LOGS"] = "recompiles"
os.environ['TORCHDYNAMO_CACHE_SIZE_LIMIT'] = '999999999'

import torch
import torch._dynamo 
torch._dynamo.config.cache_size_limit = 64

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from pprint import pprint

import re
from datasets import load_dataset, Dataset

# FOR INFERENCES WITH VLLM
from vllm import SamplingParams

# FOR LOGGING

import wandb
wandb.login()

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 07-19 15:29:52 [__init__.py:244] Automatically detected platform cuda.
GPU detected: NVIDIA GeForce RTX 3060


[34m[1mwandb[0m: Currently logged in as: [33mjknguyen3010[0m ([33mjknguyen3010-university-of-buffalo[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
dataset = load_dataset("openai/gsm8k", "main")['train']
print(dataset)


Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})


In [8]:
pprint(dataset[0])

{'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n'
           'Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and '
           'May.\n'
           '#### 72',
 'question': 'Natalia sold clips to 48 of her friends in April, and then she '
             'sold half as many clips in May. How many clips did Natalia sell '
             'altogether in April and May?'}


# Preprocessing 

In [25]:
# Constant
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


In [23]:
def extract_hash_answer(text):
    # The issue is that in your test, 'text' is a tuple, not a string.
    # So 'if delim in text' is False, because 'delim' is not in the tuple.
    # You should pass a string, not a tuple, to this function.
    delim = "####"
    if isinstance(text, str) and delim in text:
        return text.split(delim)[1].strip()
    return None
        

text = 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n'\
    'Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and '\
    'May.\n'\
    '#### 72'

expected = "72"
res = extract_hash_answer(text)
assert res == "72", f'not correct, res = {res}'

# EXTRACT XML ANSWER
def extract_xml_answer(text: str) -> str:
    import re
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return ""

sample_text = "<reasoning>\nSome reasoning here.\n</reasoning>\n<answer>\nThe answer is 42\n</answer>"
expected_answer = "The answer is 42"
extracted = extract_xml_answer(sample_text)
assert extracted == expected_answer, f"extract_xml_answer failed: got {extracted!r}, expected {expected_answer!r}"



In [27]:
gsm8k = dataset.map(lambda x: { # Note that this map function will keep the original features unless overriden
    "prompt": [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": x['question']},
    ],
    'answer': extract_hash_answer(x['answer'])
}) 

pprint(gsm8k[0])

{'answer': '72',
 'prompt': [{'content': '\n'
                        'Respond in the following format:\n'
                        '<reasoning>\n'
                        '...\n'
                        '</reasoning>\n'
                        '<answer>\n'
                        '...\n'
                        '</answer>\n',
             'role': 'system'},
            {'content': 'Natalia sold clips to 48 of her friends in April, and '
                        'then she sold half as many clips in May. How many '
                        'clips did Natalia sell altogether in April and May?',
             'role': 'user'}],
 'question': 'Natalia sold clips to 48 of her friends in April, and then she '
             'sold half as many clips in May. How many clips did Natalia sell '
             'altogether in April and May?'}


# Define reward functions
- correctness 
- integer reward function
- strict format func
- soft format func
- count xml tags <answer>, <reason> , penelize extra tags 

In [37]:
def correct_reward_func(prompts, completions, answer, **kwargs):
    reward = kwargs.get('reward', 2.0)
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_res = [extract_xml_answer(x) for x in responses]
    print(f'#######################\nQuestion\n{q}\nAnswer:\n{answer[0]}\nResponse:\n{responses[0]}\nExtracted:\n{extracted_res}\n')
    return [reward if r == a else 0.0 for r,a in zip(extracted_res, answer)]

# Unit tests for correct_reward_func using only assert

# Test 1: Correct answer
test_prompts = [[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "6*7?"}
]]
test_completions = [[
    {"role": "assistant", "content": "<reasoning>6*7=42.</reasoning><answer>42</answer>"}
]]
test_answer = ["42"]
test_reward = correct_reward_func(test_prompts, test_completions, test_answer)
print("Test 1:", test_reward)
assert test_reward == [2.0], f"Test 1 failed: got {test_reward}, expected [2.0]"

# Test 2: Incorrect answer
test_completions_wrong = [[
    {"role": "assistant", "content": "<reasoning>6*7=42.</reasoning><answer>41</answer>"}
]]
test_reward_wrong = correct_reward_func(test_prompts, test_completions_wrong, test_answer)
print("Test 2:", test_reward_wrong)
assert test_reward_wrong == [0.0], f"Test 2 failed: got {test_reward_wrong}, expected [0.0]"

# Test 3: Multiple completions, mixed correctness
test_completions_multi = [
    [{"role": "assistant", "content": "<reasoning>6*7=42.</reasoning><answer>42</answer>"}],
    [{"role": "assistant", "content": "<reasoning>6*7=42.</reasoning><answer>41</answer>"}]
]
test_answers_multi = ["42", "42"]
test_reward_multi = correct_reward_func(test_prompts, test_completions_multi, test_answers_multi)
print("Test 3:", test_reward_multi)
assert test_reward_multi == [2.0, 0.0], f"Test 3 failed: got {test_reward_multi}, expected [2.0, 0.0]"

# Test 4: Custom reward value
test_reward_custom = correct_reward_func(test_prompts, test_completions, test_answer, reward=5.0)
print("Test 4:", test_reward_custom)
assert test_reward_custom == [5.0], f"Test 4 failed: got {test_reward_custom}, expected [5.0]"

# Test 5: No <answer> tag in completion
test_completions_no_tag = [[
    {"role": "assistant", "content": "<reasoning>6*7=42.</reasoning>42"}
]]
test_reward_no_tag = correct_reward_func(test_prompts, test_completions_no_tag, test_answer)
print("Test 5:", test_reward_no_tag)
assert test_reward_no_tag == [0.0], f"Test 5 failed: got {test_reward_no_tag}, expected [0.0]"

#######################
Question
6*7?
Answer:
42
Response:
<reasoning>6*7=42.</reasoning><answer>42</answer>
Extracted:
['42']

Test 1: [2.0]
#######################
Question
6*7?
Answer:
42
Response:
<reasoning>6*7=42.</reasoning><answer>41</answer>
Extracted:
['41']

Test 2: [0.0]
#######################
Question
6*7?
Answer:
42
Response:
<reasoning>6*7=42.</reasoning><answer>42</answer>
Extracted:
['42', '41']

Test 3: [2.0, 0.0]
#######################
Question
6*7?
Answer:
42
Response:
<reasoning>6*7=42.</reasoning><answer>42</answer>
Extracted:
['42']

Test 4: [5.0]
#######################
Question
6*7?
Answer:
42
Response:
<reasoning>6*7=42.</reasoning>42
Extracted:
['']

Test 5: [0.0]


In [39]:
# Reward if the model output integer as answer, not float 
def integer_reward_func(completions, **kwargs):
    reward = kwargs.get('reward', 0.5)
    responses = [comp[0]['content'] for comp in completions]
    extracted_res = [extract_xml_answer(x) for x in responses]
    return [reward if r.isdigit() else 0.0 for r in extracted_res]

# Unit tests for integer_reward_func
test_completions_1 = [[{"role": "assistant", "content": "<answer>42</answer>"}]]
assert integer_reward_func(test_completions_1) == [0.5], "Test 1 failed"
print("Test 1: Pass")

test_completions_2 = [[{"role": "assistant", "content": "<answer>abc</answer>"}]]
assert integer_reward_func(test_completions_2) == [0.0], "Test 2 failed"
print("Test 2: Pass")

test_completions_3 = [
    [{"role": "assistant", "content": "<answer>123</answer>"}],
    [{"role": "assistant", "content": "<answer>xyz</answer>"}]
]
assert integer_reward_func(test_completions_3) == [0.5, 0.0], "Test 3 failed"
print("Test 3: Pass")

test_completions_4 = [[{"role": "assistant", "content": "<answer>7</answer>"}]]
assert integer_reward_func(test_completions_4, reward=2.0) == [2.0], "Test 4 failed"
print("Test 4: Pass")

test_completions_5 = [[{"role": "assistant", "content": "<answer></answer>"}]]
assert integer_reward_func(test_completions_5) == [0.0], "Test 5 failed"
print("Test 5: Pass")

Test 1: Pass
Test 2: Pass
Test 3: Pass
Test 4: Pass
Test 5: Pass


In [42]:
def strict_format_reward_func(completions, **kwargs):
    reward = kwargs.get("reward", 0.5)
    pattern = r"<reasoning>\n.*?\n<reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]['content'] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [reward if match else 0.0 for match in matches]

# Unit tests for strict_format_reward_func
test_completions_strict_1 = [[{"role": "assistant", "content": "<reasoning>\nstep1\n<reasoning>\n<answer>\n42\n</answer>\n"}]]
assert strict_format_reward_func(test_completions_strict_1) == [0.5], "Test 1 failed"
print("strict_format_reward_func Test 1: Pass")

test_completions_strict_2 = [[{"role": "assistant", "content": "<reasoning>step1<reasoning><answer>42</answer>"}]]
assert strict_format_reward_func(test_completions_strict_2) == [0.0], "Test 2 failed"
print("strict_format_reward_func Test 2: Pass")

test_completions_strict_3 = [
    [{"role": "assistant", "content": "<reasoning>\nfoo\n<reasoning>\n<answer>\nbar\n</answer>\n"}],
    [{"role": "assistant", "content": "<reasoning>foo<reasoning><answer>bar</answer>"}]
]
assert strict_format_reward_func(test_completions_strict_3) == [0.5, 0.0], "Test 3 failed"
print("strict_format_reward_func Test 3: Pass")

test_completions_strict_4 = [[{"role": "assistant", "content": "<reasoning>\nstep1\n<reasoning>\n<answer>\n42\n</answer>\n"}]]
assert strict_format_reward_func(test_completions_strict_4, reward=2.0) == [2.0], "Test 4 failed"
print("strict_format_reward_func Test 4: Pass")

test_completions_strict_5 = [[{"role": "assistant", "content": ""}]]
assert strict_format_reward_func(test_completions_strict_5) == [0.0], "Test 5 failed"
print("strict_format_reward_func Test 5: Pass")







strict_format_reward_func Test 1: Pass
strict_format_reward_func Test 2: Pass
strict_format_reward_func Test 3: Pass
strict_format_reward_func Test 4: Pass
strict_format_reward_func Test 5: Pass


In [43]:
def soft_format_reward_func(completions, **kwargs):
    reward = kwargs.get("reward", 0.5)
    pattern = r"<reasoning>.*?<reasoning>\s<answer>.*?</answer>$"
    responses = [completion[0]['content'] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [reward if match else 0.0 for match in matches]

# Unit tests for soft_format_reward_func
test_completions_soft_1 = [[{"role": "assistant", "content": "<reasoning>foo<reasoning> <answer>bar</answer>"}]]
assert soft_format_reward_func(test_completions_soft_1) == [0.5], "Test 1 failed"
print("soft_format_reward_func Test 1: Pass")

test_completions_soft_2 = [[{"role": "assistant", "content": "<reasoning>foo<reasoning><answer>bar</answer>"}]]
assert soft_format_reward_func(test_completions_soft_2) == [0.0], "Test 2 failed"
print("soft_format_reward_func Test 2: Pass")

test_completions_soft_3 = [
    [{"role": "assistant", "content": "<reasoning>abc<reasoning> <answer>xyz</answer>"}],
    [{"role": "assistant", "content": "<reasoning>abc<reasoning><answer>xyz</answer>"}]
]
assert soft_format_reward_func(test_completions_soft_3) == [0.5, 0.0], "Test 3 failed"
print("soft_format_reward_func Test 3: Pass")

test_completions_soft_4 = [[{"role": "assistant", "content": "<reasoning>foo<reasoning> <answer>bar</answer>"}]]
assert soft_format_reward_func(test_completions_soft_4, reward=2.0) == [2.0], "Test 4 failed"
print("soft_format_reward_func Test 4: Pass")

test_completions_soft_5 = [[{"role": "assistant", "content": ""}]]
assert soft_format_reward_func(test_completions_soft_5) == [0.0], "Test 5 failed"
print("soft_format_reward_func Test 5: Pass")

soft_format_reward_func Test 1: Pass
soft_format_reward_func Test 2: Pass
soft_format_reward_func Test 3: Pass
soft_format_reward_func Test 4: Pass
soft_format_reward_func Test 5: Pass


In [48]:
#reward counts XML tags, penalize extra content
def count_xml(text):
    count = 0.0

    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")) * 0.001 #extra content after /answer is penalized
    if text.count("\n</answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")) * 0.001 #extra content after /answer is penalized
    return count

def xml_count_reward_func(completions, **kwargs):
    contents = [completion[0]['content'] for completion in completions]
    return [count_xml(c) for c in contents]

# Create a sample text to test count_xml
sample_text = "<reasoning>\nThis is my reasoning.\n</reasoning>\n<answer>\n42\n</answer>\n"
count_result = count_xml(sample_text)
print(f"Sample text:\n{sample_text}\nCount: {count_result}")

# Quick test for xml_count_reward_func
test_completions_xml = [
    [{"role": "assistant", "content": "<reasoning>\nReasoning here.\n</reasoning>\n<answer>\nAnswer here.\n</answer>\n"}],
    [{"role": "assistant", "content": "<reasoning>\nReasoning only.\n</reasoning>\n"}],
    [{"role": "assistant", "content": "<answer>\nJust answer.\n</answer>\n"}],
    [{"role": "assistant", "content": "No tags here."}]
]

xml_counts = xml_count_reward_func(test_completions_xml)
print("xml_count_reward_func test results:", xml_counts)



Sample text:
<reasoning>
This is my reasoning.
</reasoning>
<answer>
42
</answer>

Count: 0.496
xml_count_reward_func test results: [0.496, 0.25, 0.246, 0.0]
