In [1]:
import datasets
from textwrap import dedent
import os

In [2]:
# Load the hint/solution dataset and immediately carve out a held-out test
# split (10% of the rows, fixed seed) in one chained expression.
ds = datasets.load_dataset(
    'Asap7772/hint_sol_gen_omnimath_41_mini',
    split='train',
).train_test_split(test_size=0.1, seed=42)
ds

DatasetDict({
    train: Dataset({
        features: ['problem', 'answer', 'nohint_responses', 'hint_responses', 'nohint_correct', 'hint_correct', 'source', 'difficulty', 'model', 'all_hints'],
        num_rows: 727
    })
    test: Dataset({
        features: ['problem', 'answer', 'nohint_responses', 'hint_responses', 'nohint_correct', 'hint_correct', 'source', 'difficulty', 'model', 'all_hints'],
        num_rows: 81
    })
})

In [3]:
from transformers import AutoTokenizer
# Tokenizer for Qwen/Qwen3-0.6B. The mapping functions in this notebook call its
# apply_chat_template method to render prompts as chat-formatted text.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

In [4]:
# SFT for hint generation
def map_hint_gen(example):
    """Turn one dataset row into a (query, completion) pair for hint generation.

    The query is a single-turn chat prompt asking for hints on
    ``example['problem']``, rendered as text with the module-level
    ``tokenizer``. The completion is ``example['all_hints']`` with surrounding
    whitespace removed. Both values are written onto ``example`` under the keys
    ``'query'`` and ``'completion'``, and that same mapping is returned.
    """
    # Assemble the instruction text line by line. The empty first entry and the
    # whitespace-only last entry are removed again by the strip() below.
    raw_lines = [
        "",
        "# Task",
        "Generate a set of 5 hints for the following problem. Write the hints within a set of xml tags <note> and </note>.",
        "",
        "# Problem",
        f"{example['problem']}",
        "",
        "# Hints",
        "    ",
    ]
    user_message = dedent("\n".join(raw_lines)).strip()

    messages = [{'role': 'user', 'content': user_message}]
    # tokenize=False: the rendered chat is kept as text rather than token ids.
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    example['query'] = rendered
    example['completion'] = example['all_hints'].strip()
    return example

# Apply map_hint_gen to every row of both splits (one row per call, spread over
# worker processes) and drop all of the original columns afterwards.
ds_hint_gen = ds.map(
    map_hint_gen,
    batched=False,
    num_proc=os.cpu_count(),
    remove_columns=ds['train'].column_names,
)
ds_hint_gen

Map (num_proc=24):   0%|          | 0/727 [00:00<?, ? examples/s]

Map (num_proc=24):   0%|          | 0/81 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['query', 'completion'],
        num_rows: 727
    })
    test: Dataset({
        features: ['query', 'completion'],
        num_rows: 81
    })
})

In [5]:
# Sanity check: show the rendered prompt and the target text of the first
# training row.
first_hint_gen_row = ds_hint_gen['train'][0]
print(first_hint_gen_row['query'])
print(first_hint_gen_row['completion'])

<|im_start|>user
# Task
Generate a set of 5 hints for the following problem. Write the hints within a set of xml tags <note> and </note>.

# Problem
Find all pairs of integers $a,b$ for which there exists a polynomial $P(x) \in \mathbb{Z}[X]$ such that product $(x^2+ax+b)\cdot P(x)$ is a polynomial of a form \[ x^n+c_{n-1}x^{n-1}+\cdots+c_1x+c_0  \] where each of $c_0,c_1,\ldots,c_{n-1}$ is equal to $1$ or $-1$.

# Hints<|im_end|>
<|im_start|>assistant

<notes>
  <note>
    <description>When multiplying two polynomials to obtain a target with specific leading and constant terms, equate the highest- and lowest-degree coefficients first.  This immediately forces the leading coefficients of the factors to multiply to the target’s leading coefficient, and similarly for constant terms.  It provides quick divisibility or sign constraints on those factor coefficients.</description>
    <example>Suppose f(x)=u x+p and g(x)=v x+q produce a product with leading coefficient 1 and constant term ±1

In [6]:
# Persist the hint-generation SFT splits as parquet files.
# NOTE(review): this is an absolute, user-specific location; adjust it when
# running the notebook on another machine.
hint_gen_out_dir = '/home/anikait.singh/rl_behaviors_verl_stable/hint_gen_sft'
# Create the target directory up front so the cell does not depend on the
# parquet writer creating missing parent directories.
os.makedirs(hint_gen_out_dir, exist_ok=True)
ds_hint_gen['train'].to_parquet(os.path.join(hint_gen_out_dir, 'train.parquet'))
# Kept as the last expression so its return value is still shown as the cell output.
ds_hint_gen['test'].to_parquet(os.path.join(hint_gen_out_dir, 'test.parquet'))

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

245450

In [12]:
import random
# Threshold for the uniform random draw in map_fn_hint_cond_sol: a draw below
# this value selects the no-hint prompt, otherwise the hint-conditioned prompt.
per_no_hint = 0.2

# Instruction line placed in front of the problem text in every solution prompt.
math_prefix = "Solve the following math problem. Give your final answer as \\boxed{}."

def map_fn_hint_cond_sol(example, seed=0):
    """Build one (query, completion) pair for hint-conditioned solution SFT.

    Each row is assigned one of two prompt kinds: with probability
    ``per_no_hint`` the plain prompt (``math_prefix`` + problem), otherwise the
    hint prompt (``math_prefix`` + problem + ``example['all_hints']``). The
    completion is the first response of the assigned kind whose correctness
    flag is truthy. If the assigned kind has no correct response,
    ``'completion'`` is set to ``None`` so the row can be filtered out later;
    there is deliberately no fallback to the other kind.

    The assignment is drawn from a private generator seeded with ``seed`` and
    the problem text rather than from the global ``random`` state. It is
    therefore reproducible across kernel restarts and does not depend on how
    rows are distributed over worker processes. Rows with identical problem
    text receive the same assignment.

    Reads the module-level names ``per_no_hint``, ``math_prefix`` and
    ``tokenizer`` at call time.

    Args:
        example: mapping with the keys 'problem', 'all_hints',
            'nohint_responses', 'nohint_correct', 'hint_responses' and
            'hint_correct'.
        seed: value mixed into the per-row draw; change it to obtain a
            different, still reproducible, assignment.

    Returns:
        The same ``example`` mapping with 'completion' and 'query' set.
    """
    problem = example['problem']

    # A str seed is converted to an integer deterministically by Random itself,
    # so (unlike hash()) the draw is identical in every process and run.
    draw = random.Random(f"{seed}:{problem}").random()

    if draw < per_no_hint:
        prompt = f"{math_prefix}\n{problem}"
        responses = example['nohint_responses']
        correct_flags = example['nohint_correct']
    else:
        prompt = f"{math_prefix}\n{problem}\n{example['all_hints']}"
        responses = example['hint_responses']
        correct_flags = example['hint_correct']

    # First response flagged correct, or None when the assigned kind has none.
    example['completion'] = next(
        (response for response, is_correct in zip(responses, correct_flags) if is_correct),
        None,
    )

    chat = [{'role': 'user', 'content': prompt}]
    example['query'] = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    return example

# Map every row of both splits to a (query, completion) pair and discard all of
# the original columns.
ds_hint_cond_sol = ds.map(
    map_fn_hint_cond_sol,
    batched=False,
    num_proc=os.cpu_count(),
    remove_columns=ds['train'].column_names,
)
def filter_fn_hint_cond_sol(example):
    """Return True for rows whose 'completion' value is set (is not None)."""
    completion = example['completion']
    if completion is None:
        return False
    return True

# Keep only the rows that received a completion (filter_fn_hint_cond_sol rejects
# rows whose 'completion' is None).
ds_hint_cond_sol = ds_hint_cond_sol.filter(filter_fn_hint_cond_sol, num_proc=os.cpu_count())

# Persist the hint-conditioned solution SFT splits as parquet files.
# NOTE(review): this is an absolute, user-specific location; adjust it when
# running the notebook on another machine.
hint_cond_sol_out_dir = '/home/anikait.singh/rl_behaviors_verl_stable/hint_cond_sol_gen_sft'
# Create the target directory up front so the cell does not depend on the
# parquet writer creating missing parent directories.
os.makedirs(hint_cond_sol_out_dir, exist_ok=True)
ds_hint_cond_sol['train'].to_parquet(os.path.join(hint_cond_sol_out_dir, 'train.parquet'))
# Kept as the last expression so its return value is still shown as the cell output.
ds_hint_cond_sol['test'].to_parquet(os.path.join(hint_cond_sol_out_dir, 'test.parquet'))


Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

229156

In [11]:
# Sanity check: show the prompt and the target solution of the first training
# row. The row is selected before the column (as in the hint-generation check
# earlier in the notebook) so that only one row has to be read, rather than a
# whole column just to look at its first entry.
first_hint_cond_sol_row = ds_hint_cond_sol['train'][0]
print(first_hint_cond_sol_row['query'])
print(first_hint_cond_sol_row['completion'])

<|im_start|>user
Solve the following math problem. Give your final answer as \boxed{}.
If $3+\triangle=5$ and $\triangle+\square=7$, what is the value of $\triangle+\Delta+\Delta+\square+\square$?
<notes>
  <note>
    <description>Translate stated relationships into algebraic equations by assigning each unknown to a variable and expressing each “sum equals” or “difference equals” statement as an equation. This converts a word or symbol puzzle into a form amenable to algebraic manipulation.</description>
    <example>Suppose you’re told “a plus b equals c.” Introduce variables x and y for the unknowns and write x + y = c. If another statement is “x minus d equals e,” write x − d = e. This formalization allows systematic solving.</example>
  </note>
  <note>
    <description>Isolate a variable in a simple linear equation by applying inverse operations. For an equation of the form x + A = B or x − A = B, subtract or add A to both sides to solve for x.</description>
    <example>Given x + 