In [2]:
"""
GRPO training pipeline for Q-agent: generate JSON (topic, question, choices, answer, explanation).
Uses a 3-row sanity dataset; replace with your full dataset for real training.
"""
import json
import re
import os
from datasets import Dataset
from unsloth import FastLanguageModel
import torch
from trl import GRPOConfig, GRPOTrainer

# ---------------------------------------------------------------------------
# 1. Sanity dataset: 3 rows in your JSON format
# ---------------------------------------------------------------------------
def _mcq(topic, question, choices, answer, explanation):
    """Assemble one sanity-dataset row with keys in the JSON-format order."""
    return {
        "topic": topic,
        "question": question,
        "choices": list(choices),
        "answer": answer,
        "explanation": explanation,
    }


# Three hand-written rows, one per topic, used only to smoke-test the pipeline.
RAW_DATA = [
    _mcq(
        topic="Alphanumeric Series",
        question="What comes next in the series: A1, B2, C3, D4, ?",
        choices=("A) E5", "B) F6", "C) D5", "D) E4"),
        answer="A",
        explanation="The pattern is one letter and one number each increasing by 1. So next is E5.",
    ),
    _mcq(
        topic="Number Series",
        question="Next number in sequence: 2, 6, 12, 20, 30, ?",
        choices=("A) 40", "B) 42", "C) 44", "D) 36"),
        answer="B",
        explanation="Differences are 4, 6, 8, 10 (even numbers). Next difference is 12, so 30 + 12 = 42.",
    ),
    _mcq(
        topic="Logical Reasoning",
        question="If all roses are flowers and some flowers fade quickly, which must be true?",
        choices=(
            "A) All roses fade quickly",
            "B) Some roses may fade quickly",
            "C) No roses fade quickly",
            "D) All flowers are roses",
        ),
        answer="B",
        explanation="Some flowers fade quickly; roses are a subset of flowers. So some roses may be in that subset.",
    ),
]

# System message sent with every prompt: three lines joined by "\n", with no
# trailing newline. The last line is the literal JSON template for the model.
SYSTEM_PROMPT = (
    "You are a Q-agent. Given a topic, output exactly one JSON object with no other text.\n"
    "Use this format only:\n"
    '{"topic": "<Topic>", "question": "<full question>", '
    '"choices": ["A) ...", "B) ...", "C) ...", "D) ..."], '
    '"answer": "<A or B or C or D>", '
    '"explanation": "<brief explanation, at most 100 words>"}'
)


def build_grpo_dataset(raw_rows):
    """Turn JSON-format rows into a two-column GRPO dataset.

    Each row contributes a chat-style ``prompt`` (the shared system message
    plus a user message naming the row's topic) and an ``answer`` holding the
    row's answer, stripped and upper-cased. Only ``item["topic"]`` and
    ``item["answer"]`` are read from each row.
    """

    def chat_for(topic):
        # Two-message conversation: fixed system prompt + topic-specific request.
        request = f"Generate a single multiple-choice question in the required JSON format for the topic: {topic}."
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": request},
        ]

    # One pass over the rows, reading topic before answer for each row.
    pairs = [
        (chat_for(row["topic"]), row["answer"].strip().upper())
        for row in raw_rows
    ]
    columns = {
        "prompt": [chat for chat, _ in pairs],
        "answer": [letter for _, letter in pairs],
    }
    return Dataset.from_dict(columns)


# Training data: the in-file sanity rows converted to prompt/answer columns.
dataset = build_grpo_dataset(RAW_DATA)

# ---------------------------------------------------------------------------
# 2. Model and LoRA
# ---------------------------------------------------------------------------
lora_rank = 8
max_seq_length = 1024
# The default is the local Qwen2.5-14B-Instruct checkpoint. For a quick
# pipeline sanity run, point GRPO_MODEL_PATH at a smaller model instead.
model_path = os.environ.get("GRPO_MODEL_PATH", "/workspace/AAIPL/hf_models/Qwen2.5-14B-Instruct")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=max_seq_length,
)

# LoRA adapters go on the attention projections and the MLP projections.
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=attention_projections + mlp_projections,
    lora_alpha=2 * lora_rank,  # alpha is kept at twice the rank
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

# ---------------------------------------------------------------------------
# 3. Lengths and GRPO config
# ---------------------------------------------------------------------------
# Token budget: the prompt gets a fixed share of the sequence length and the
# completion gets whatever remains.
max_prompt_length = 256
max_completion_length = max_seq_length - max_prompt_length

# Completions sampled per prompt. With only 2, every step where both samples
# earn the same total reward has zero reward spread (reward_std == 0).
num_generations = 2

training_args = GRPOConfig(
    temperature=1.0,
    learning_rate=5e-5,
    weight_decay=0.001,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    optim="adamw_8bit",
    logging_steps=1,
    # Unsloth expects the per-device batch size to be a multiple of
    # num_generations. The previous value of 1 was overridden to 2 at runtime
    # with a warning, so tie the two together explicitly.
    per_device_train_batch_size=num_generations,
    gradient_accumulation_steps=1,
    num_generations=num_generations,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_completion_length,
    max_steps=10,
    save_steps=10,
    report_to="none",
    output_dir="outputs",
)

# ---------------------------------------------------------------------------
# 4. Reward functions. Full signature: (prompts, completions, answer, **kwargs);
#    functions that only need the completions take (completions, **kwargs).
#    Each completion is a list like [{"role": "assistant", "content": "..."}]; we use its content.
# ---------------------------------------------------------------------------

# First fenced block (``` or ```json); group 1 is the text between the fences.
_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)```")
_REQUIRED_KEYS = frozenset({"topic", "question", "choices", "answer", "explanation"})
_VALID_ANSWERS = frozenset("ABCD")
_MAX_EXPLANATION_WORDS = 100


def get_content(completion):
    """Return the text of a completion, or "" when none can be extracted.

    Accepts a plain string, or a non-empty list/tuple whose first element is a
    message dict; in that case the dict's "content" value is returned. Any
    other shape, or a non-string "content", yields "" so callers can always
    treat the result as text.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, (list, tuple)) and completion and isinstance(completion[0], dict):
        content = completion[0].get("content", "")
        return content if isinstance(content, str) else ""
    return ""


def _parse_completion_json(completion):
    """Return the JSON object (dict) carried by a completion, or None.

    If the text contains a ``` fence, only the first fenced block is parsed.
    None is returned for an unterminated fence, unparseable text, or JSON
    whose top-level value is not an object.
    """
    text = get_content(completion)
    if "```" in text:
        fenced = _FENCE_RE.search(text)
        if fenced is None:
            # Opening fence without a closing one: nothing usable to parse.
            return None
        text = fenced.group(1)
    try:
        data = json.loads(text.strip())
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, dict) else None


def _answer_letter(data):
    """Return the stripped, upper-cased "answer" string, or "" if absent or not a string."""
    answer = data.get("answer") if data is not None else None
    return answer.strip().upper() if isinstance(answer, str) else ""


def reward_valid_json(completions, **kwargs):
    """Score 1.0 for a JSON object with all required keys and exactly 4 choices.

    Returns one float per completion; anything unparseable, missing a required
    key, or whose "choices" is not a 4-element list scores 0.0.
    """
    scores = []
    for completion in completions:
        data = _parse_completion_json(completion)
        is_valid = False
        if data is not None and _REQUIRED_KEYS.issubset(data):
            choices = data.get("choices")
            is_valid = isinstance(choices, list) and len(choices) == 4
        scores.append(1.0 if is_valid else 0.0)
    return scores


def reward_answer_format(completions, **kwargs):
    """Score 1.0 when the "answer" field is one of A, B, C or D.

    The comparison is made after stripping whitespace and upper-casing.
    Returns one float per completion; 0.0 for anything else.
    """
    return [
        1.0 if _answer_letter(_parse_completion_json(completion)) in _VALID_ANSWERS else 0.0
        for completion in completions
    ]


def reward_explanation_length(completions, **kwargs):
    """Score the "explanation" field by its whitespace-separated word count.

    Returns one float per completion:
      * 1.0 for 1..100 words,
      * 0.5 for an empty or missing explanation in an otherwise parseable object,
      * 0.0 for more than 100 words, a non-string explanation, or unparseable output.
    """
    scores = []
    for completion in completions:
        data = _parse_completion_json(completion)
        explanation = (data.get("explanation") or "") if data is not None else None
        if not isinstance(explanation, str):
            scores.append(0.0)
            continue
        word_count = len(explanation.split())
        if word_count > _MAX_EXPLANATION_WORDS:
            scores.append(0.0)
        elif word_count == 0:
            # NOTE(review): half credit for an empty explanation is kept from
            # the original scoring expression; confirm it is intended.
            scores.append(0.5)
        else:
            scores.append(1.0)
    return scores


def reward_answer_correctness(prompts, completions, answer, **kwargs):
    """Score 2.0 when the generated "answer" equals the reference answer.

    ``answer`` may be a single string or a list/tuple of references. A shorter
    reference sequence is repeated cyclically to cover every completion.
    Returns one float per completion; 0.0 whenever the reference is missing,
    empty, or not a string, so an empty reference can never "match" an empty
    or missing generated answer.

    NOTE(review): the prompt built by build_grpo_dataset names only the topic,
    so the model never sees the question the reference letter belongs to.
    Confirm that matching on the letter alone is the intended signal.
    """
    if answer is None:
        return [0.0] * len(completions)
    refs = answer if isinstance(answer, (list, tuple)) else [answer]
    if not refs:
        # No references at all: nothing can be correct (and nothing to cycle).
        return [0.0] * len(completions)
    if len(refs) < len(completions):
        refs = [refs[i % len(refs)] for i in range(len(completions))]
    scores = []
    for completion, ref in zip(completions, refs):
        expected = ref.strip().upper() if isinstance(ref, str) else ""
        generated = _answer_letter(_parse_completion_json(completion))
        scores.append(2.0 if expected and generated == expected else 0.0)
    return scores


# ---------------------------------------------------------------------------
# 5. Trainer and train
# ---------------------------------------------------------------------------
# All four reward functions are handed to the trainer together, in this order.
reward_functions = [
    reward_valid_json,
    reward_answer_format,
    reward_explanation_length,
    reward_answer_correctness,
]

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
    reward_funcs=reward_functions,
)

if __name__ == "__main__":
    trainer.train()
    # Save the trained model under <output_dir>/final.
    final_dir = os.path.join(training_args.output_dir, "final")
    model.save_pretrained(final_dir)


Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now.
==((====))==  Unsloth 2025.10.9: Fast Qwen2 patching. Transformers: 4.56.2. vLLM: 0.11.1rc3.dev39+gf417746ad.rocm700.
   \\   /|    . Num GPUs = 1. Max memory: 255.688 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0a0+git1c57644. ROCm Toolkit: 7.0.51831-a3e329ad8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = True]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Unsloth 2025.10.9 patched 48 layers with 48 QKV layers, 48 O layers and 48 MLP layers.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.


Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 2


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 3 | Num Epochs = 4 | Total steps = 10
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2
 "-____-"     Trainable parameters = 34,406,400 of 14,804,440,064 (0.23% trained)


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / reward_valid_json / mean,rewards / reward_valid_json / std,rewards / reward_answer_format / mean,rewards / reward_answer_format / std,rewards / reward_explanation_length / mean,rewards / reward_explanation_length / std,rewards / reward_answer_correctness / mean,rewards / reward_answer_correctness / std
1,0.0,3.0,0.0,131.5,124.0,139.0,0.0,131.5,124.0,139.0,0,0,0,0,0,0.001366,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
2,0.0,3.0,0.0,125.0,111.0,139.0,0.0,125.0,111.0,139.0,No Log,No Log,No Log,No Log,No Log,0.001617,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
3,0.0,4.0,1.414214,128.5,118.0,139.0,0.0,128.5,118.0,139.0,No Log,No Log,No Log,No Log,No Log,0.00154,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.414214
4,0.0,4.0,1.414214,123.0,120.0,126.0,0.0,123.0,120.0,126.0,No Log,No Log,No Log,No Log,No Log,0.001311,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.414214
5,0.0,3.0,0.0,133.0,122.0,144.0,0.0,133.0,122.0,144.0,No Log,No Log,No Log,No Log,No Log,0.003699,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
6,0.0,4.0,1.414214,142.5,137.0,148.0,0.0,142.5,137.0,148.0,No Log,No Log,No Log,No Log,No Log,0.003174,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.414214
7,0.0,4.0,1.414214,133.0,119.0,147.0,0.0,133.0,119.0,147.0,No Log,No Log,No Log,No Log,No Log,0.011095,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.414214
8,0.0,4.0,1.414214,107.0,101.0,113.0,0.0,107.0,101.0,113.0,No Log,No Log,No Log,No Log,No Log,0.007963,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.414214
9,0.0,3.0,0.0,146.5,139.0,154.0,0.0,146.5,139.0,154.0,No Log,No Log,No Log,No Log,No Log,0.008826,1.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
10,0.0,4.0,1.414214,135.0,119.0,151.0,0.0,135.0,119.0,151.0,No Log,No Log,No Log,No Log,No Log,0.033648,1.0,0.0,1.0,0.0,1.0,0.0,1.0,1.414214


Unsloth: Will smartly offload gradients to save VRAM!
