In [None]:
!pip install unsloth vllm
!pip install --upgrade pillow
!pip install llama-cpp-python
!pip install langchain-community

In [None]:
import os
import re
from difflib import SequenceMatcher

import torch
import pandas as pd
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
from datasets import load_dataset, Dataset

# Apply Unsloth's "GRPO" patch to FastLanguageModel. This runs before any model
# is loaded in the cells below.
# NOTE(review): presumably this must happen before the GRPO trainer is imported — confirm against the Unsloth docs.
PatchFastRL("GRPO", FastLanguageModel)

In [None]:
# Load the 4-bit quantised base model together with its tokenizer.
# max_lora_rank here matches the adapter rank `r` requested below.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=512,
    max_lora_rank=64,
    load_in_4bit=True,
    fast_inference=True,
    gpu_memory_utilization=0.8,
)

# Attach LoRA adapters to the attention and MLP projection layers.
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=64,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)


In [None]:
# Every CSV in this notebook lives under data/riddles/; create the directory up
# front so the writes below do not fail on a fresh checkout.
os.makedirs("data/riddles", exist_ok=True)

# Download the Riddles QA dataset from the Hugging Face Hub.
df_riddles_qa = load_dataset("D4ve-R/riddles-qa")

# Save its train split to a CSV file.
df_riddles_qa["train"].to_csv("data/riddles/riddles_qa.csv")

# Download the RiddleSense dataset from the Hugging Face Hub.
df_riddle_sense = load_dataset("INK-USC/riddle_sense")

def get_correct_answer(example):
    """Return the text of the choice whose label equals the example's answerKey.

    `example` is indexed as example["answerKey"] and example["choices"], where
    the latter holds two parallel sequences under "label" and "text".
    Raises ValueError (from ``.index``) when the key is not among the labels.
    """
    choices = example["choices"]
    # Position of the correct label; "text" is indexed with the same position.
    position = choices["label"].index(example["answerKey"])
    return choices["text"][position]

# Turn the RiddleSense train split into a two-column frame: the riddle text and
# the text of its correct choice.
riddle_sense_train = df_riddle_sense["train"]
df_riddle_sense = pd.DataFrame(
    {
        "riddle": riddle_sense_train["question"],
        "answer": [get_correct_answer(example) for example in riddle_sense_train],
    }
)

# Write the frame to CSV without the pandas index column.
df_riddle_sense.to_csv("data/riddles/riddle_sense.csv", index=False)

In [None]:
# Reload the three riddle sources from disk as pandas DataFrames.
# The first two files are the CSVs written by the download cell above.
df_riddles_qa = pd.read_csv("data/riddles/riddles_qa.csv")
df_riddle_sense = pd.read_csv("data/riddles/riddle_sense.csv")
# Riddles.csv is not created by this notebook and must already exist locally.
# Column names are supplied here, so pandas reads the file's first row as data.
# NOTE(review): if Riddles.csv has a header row it will show up as a data row — confirm.
df_riddles = pd.read_csv("data/riddles/Riddles.csv", names=["riddle", "answer", "hint"])

In [None]:
# Preview the first rows of each source, in the order they were loaded.
display(df_riddles_qa.head(), df_riddle_sense.head(), df_riddles.head())

In [None]:
# Remove repeated (answer, riddle) pairs inside each source. The column-wise
# concat below aligns on that pair as an index, so it has to be unique.
# (df_riddle_sense is deliberately left out of the merge.)
df_riddles_qa = df_riddles_qa.drop_duplicates(subset=["answer", "riddle"])
df_riddles = df_riddles.drop_duplicates(subset=["answer", "riddle"])

# Align both sources on (answer, riddle), turn the index back into columns and
# discard the "hint" column.
dataset = (
    pd.concat(
        [
            df_riddles_qa.set_index(["answer", "riddle"]),
            df_riddles.set_index(["answer", "riddle"]),
        ],
        axis=1,
    )
    .reset_index()
    .drop("hint", axis=1)
)

In [None]:
# Long answers (more than 4 whitespace-separated words) are treated as
# reasoning rather than answers and are filtered out. Rows with a missing
# answer compare as False here and are dropped as well.
# drop=True renumbers the rows without adding the old labels as a stray
# "index" column, and chaining avoids mutating a filtered slice in place.
dataset = dataset[dataset["answer"].str.split().str.len() <= 4].reset_index(drop=True)

In [None]:
# Show the merged pandas frame, then convert it to a Hugging Face Dataset.
display(dataset)
# NOTE(review): `dataset` is rebound from a DataFrame to a Dataset here, so
# re-running only this cell hands a Dataset to from_pandas.
dataset = Dataset.from_pandas(dataset)

In [None]:
# System prompt that shows the model the tagged layout it must answer in.
# The tag lines carry no trailing whitespace: the format rewards in this
# notebook look for tags followed directly by a newline, so the example layout
# has to be newline-exact too.
SYSTEM_PROMPT = """
Solve the riddle and respond in the following format. Note the clue is an important keyword or phrase from the riddle:
<clue>
...
</clue>
<reason>
...
</reason>
<answer>
...
</answer>
"""

def extract_xml_tag(text: str, tag: str) -> str:
    """Return the stripped text between the last ``<tag>`` and the next ``</tag>``.

    A missing opening tag means the search starts at the beginning of `text`;
    a missing closing tag means it runs to the end. With neither tag present
    the whole of `text` is returned, stripped.
    """
    opening, closing = f"<{tag}>", f"</{tag}>"
    # Everything after the last opening tag (or all of text when there is none).
    after_opening = text.split(opening)[-1]
    # Everything before the first closing tag that follows.
    inner, _, _ = after_opening.partition(closing)
    return inner.strip()

def get_riddles() -> Dataset:
    """Build the training set of chat-formatted riddle prompts.

    Reads ``riddles_qa.csv`` and ``Riddles.csv`` from ``data/riddles/``, merges
    them on the (answer, riddle) pair, keeps rows whose answer is non-empty and
    at most 4 words long, and wraps every riddle in a system + user message
    pair. ``riddle_sense.csv`` is intentionally not part of the mix.

    For 1-shot prompting, insert an example user/assistant exchange between the
    system and user messages built below.

    Returns:
        A Dataset with a "prompt" column (list of role/content dicts) and an
        "answer" column (the reference answer string).
    """
    key_columns = ["answer", "riddle"]

    # Load CSVs. Only frames loaded here are touched, so the function neither
    # depends on nor mutates DataFrames left behind by earlier cells.
    sources = [
        pd.read_csv("data/riddles/riddles_qa.csv"),
        pd.read_csv("data/riddles/Riddles.csv", names=["riddle", "answer", "hint"]),
    ]

    # Drop duplicates: the column-wise concat below aligns on the key pair as
    # an index, which has to be unique within each source.
    sources = [frame.drop_duplicates(subset=key_columns) for frame in sources]

    # Merge data: one row per distinct (answer, riddle) pair across sources.
    merged = pd.concat(
        [frame.set_index(key_columns) for frame in sources], axis=1
    ).reset_index()

    # Drop hint column safely (it may be absent).
    merged = merged.drop(columns=["hint"], errors="ignore")

    # Keep only short answers (<= 4 words); missing answers compare as False.
    merged = merged[merged["answer"].str.split().str.len() <= 4]

    # Drop missing or empty answers.
    merged = merged.dropna(subset=["answer"])
    merged = merged[merged["answer"].str.strip() != ""]

    # Convert to structured format. A plain comprehension also copes with an
    # empty frame, where DataFrame.apply(axis=1) would not return a Series.
    records = [
        {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': riddle},
            ],
            'answer': answer,
        }
        for riddle, answer in zip(merged["riddle"], merged["answer"])
    ]

    return Dataset.from_list(records)

# Build the training Dataset. This rebinds `dataset`, replacing the value
# assigned by the exploratory cells above.
dataset=get_riddles()

## Reward Functions

### Correct Answer
Rewards a completion when its extracted answer and the reference answer match, i.e. one contains the other (case-insensitive).

In [None]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose ``<answer>`` block matches the reference answer.

    Args:
        prompts: one message list per completion; the last message's
            'content' is printed as the riddle in the debug trace.
        completions: one list per sample; element 0 is a dict whose 'content'
            is the generated text.
        answer: reference answers, aligned with `completions`.
        **kwargs: ignored.

    Returns:
        2.0 for every completion whose extracted answer contains, or is
        contained in, the reference answer (case-insensitive); 0.0 otherwise.
    """
    responses = [completion[0]['content'].strip().lower() for completion in completions]
    extracted_responses = [extract_xml_tag(r, "answer").strip().lower() for r in responses]
    correct_answers = [a.strip().lower() for a in answer]

    def is_similar(r, a):
        # "" is a substring of every string, so without this guard an empty
        # <answer> block would collect the full reward.
        if not r or not a:
            return False
        return r in a or a in r

    # Debug trace of the first sample; skipped for an empty batch so the
    # [0] lookups cannot raise.
    if prompts and responses and answer:
        q = prompts[0][-1]['content']
        print('-'*20, f"\nRiddle:\n{q}", f"\nAnswer:\n{answer[0]}\n", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}\n")

    return [2.0 if is_similar(r, a) else 0.0 for r, a in zip(extracted_responses, correct_answers)]

### Strict Correct XML Format
Rewards a perfectly formatted response with proper XML tag order and structure

In [None]:
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact newline-delimited tag layout.

    A completion scores 0.5 when it is exactly
    ``<clue>\\n…\\n</clue>\\n<reason>\\n…\\n</reason>\\n<answer>\\n…\\n</answer>``
    with nothing before or after it and each tag used once; otherwise 0.0.

    Args:
        completions: one list per sample; element 0 is a dict whose "content"
            is the generated text.
        **kwargs: ignored.
    """
    pattern = r"^<clue>\n.*?\n</clue>\n<reason>\n.*?\n</reason>\n<answer>\n.*?\n</answer>$"
    tags = ("<clue>", "</clue>", "<reason>", "</reason>", "<answer>", "</answer>")
    responses = [completion[0]["content"] for completion in completions]

    def is_strict(r):
        # DOTALL lets a tag body span several lines; without it "." stops at
        # the first newline and any multi-line reasoning can never match.
        if re.match(pattern, r, flags=re.DOTALL) is None:
            return False
        # "$" also matches just before a final newline, so trailing text is
        # ruled out explicitly.
        if not r.endswith("</answer>"):
            return False
        # With DOTALL a lazy body could swallow a repeated section, so every
        # tag must occur exactly once.
        return all(r.count(tag) == 1 for tag in tags)

    return [0.5 if is_strict(r) else 0.0 for r in responses]

### Soft Correct XML Format
Rewards a response as long as the XML tags are present in the right order. Does not enforce exact whitespace around the tags.

In [None]:
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions whose tags are present and correctly ordered.

    Unlike the strict check, any amount of whitespace is accepted around the
    tags and around the whole completion. Scores 0.5 on a match, else 0.0.

    Args:
        completions: one list per sample; element 0 is a dict whose "content"
            is the generated text.
        **kwargs: ignored.
    """
    pattern = r"^<clue>\s*.*?\s*</clue>\s*<reason>\s*.*?\s*</reason>\s*<answer>\s*.*?\s*</answer>$"
    responses = [completion[0]["content"] for completion in completions]
    # DOTALL lets a tag body contain newlines; stripping makes leading and
    # trailing whitespace irrelevant, in keeping with this being the lenient check.
    matches = [re.match(pattern, r.strip(), flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

### XML Count
Gives partial credit (0.125 each) for every expected opening or closing XML tag that appears exactly once, regardless of order.

In [None]:
def count_xml(text) -> float:
    """Score how many of the six expected tag markers occur exactly once.

    Each marker (a tag together with its adjacent newline) that occurs exactly
    once in `text` adds 0.125, so the result lies between 0.0 and 0.75.
    """
    expected_markers = (
        "<clue>\n",
        "\n</clue>\n",
        "<reason>\n",
        "\n</reason>\n",
        "\n<answer>\n",
        "\n</answer>",
    )
    # Start the sum from 0.0 so the result is a float even with no matches.
    return sum(
        (0.125 for marker in expected_markers if text.count(marker) == 1),
        0.0,
    )

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Return the count_xml score of each completion's generated text.

    Args:
        completions: one list per sample; element 0 is a dict whose "content"
            is the generated text.
        **kwargs: ignored.
    """
    scores = []
    for completion in completions:
        generated_text = completion[0]["content"]
        scores.append(count_xml(generated_text))
    return scores

### Answer Brevity
Rewards short answers of 4 words or fewer. All answers in the dataset are 4 words or fewer.

In [None]:
def short_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions whose ``<answer>`` block is short but not empty.

    Args:
        completions: one list per sample; element 0 is a dict whose 'content'
            is the generated text.
        **kwargs: ignored.

    Returns:
        0.5 for every completion whose extracted answer has between 1 and 4
        whitespace-separated words, 0.0 otherwise.
    """
    responses = [completion[0]['content'].strip() for completion in completions]
    extracted_responses = [extract_xml_tag(r, "answer").strip() for r in responses]

    def shortness_score(r):
        # An empty answer has zero words; it must not collect the brevity
        # reward, otherwise answering nothing becomes a winning strategy.
        word_count = len(r.split())
        return 0.5 if 1 <= word_count <= 4 else 0.0

    return [shortness_score(r) for r in extracted_responses]

### Clue Extraction
Rewards a clue that is a substring from the original riddle. (Must be robust against model responding with individual letters)

In [None]:
def compute_reward(clue, riddle):
    """Score a clue: 0.5 if it is a short phrase taken from the riddle, else 0.0.

    The clue must be 1 to 5 whitespace-separated words long and its words must
    appear in the riddle as a contiguous run of whole words. Matching ignores
    case and punctuation. Clues made only of single characters are rejected,
    so isolated letters and word fragments do not count as clues.

    Args:
        clue: text extracted from the model's <clue> block.
        riddle: the riddle text the clue should come from.
    """
    clue_length = len(clue.split())
    if clue_length < 1 or clue_length > 5:
        return 0.0

    clue_words = re.findall(r"\w+", clue.lower())
    riddle_words = re.findall(r"\w+", riddle.lower())

    # Nothing but punctuation, or only single characters such as "e" or
    # "k e y s": a plain substring test would accept these against almost any riddle.
    if not clue_words or all(len(word) == 1 for word in clue_words):
        return 0.0

    # Slide a window over the riddle's words looking for the clue's word run.
    span = len(clue_words)
    found = any(
        riddle_words[start:start + span] == clue_words
        for start in range(len(riddle_words) - span + 1)
    )
    return 0.5 if found else 0.0


def clue_extraction_reward_func(prompts, completions, **kwargs) -> list[float]:
    """Reward completions whose ``<clue>`` block is taken from their own riddle.

    Args:
        prompts: one message list per completion; the last message's
            'content' is the riddle text.
        completions: one list per sample; element 0 is a dict whose 'content'
            is the generated text.
        **kwargs: ignored.

    Returns:
        The compute_reward score of each completion's extracted clue.
    """
    responses = [completion[0]['content'].strip() for completion in completions]
    extracted_clues = [extract_xml_tag(r, "clue").strip() for r in responses]

    if len(prompts) == len(extracted_clues):
        # Score each clue against the riddle it was generated for; a batch can
        # hold completions for more than one prompt.
        riddles = [prompt[-1]['content'] for prompt in prompts]
    else:
        # Lengths disagree: fall back to one shared riddle for the whole batch.
        riddles = [prompts[0][-1]['content']] * len(extracted_clues)

    return [compute_reward(clue, riddle) for clue, riddle in zip(extracted_clues, riddles)]


In [None]:
import wandb

# Initialize WandB. These values are run metadata only; they are kept equal to
# the GRPOConfig used for training (max_steps is 10000 there) so that the
# logged configuration describes the run that actually happens.
wandb.init(project="riddle-llm-training", name="grpo_run_5", config={
    "learning_rate": 5e-6,
    "num_train_epochs": 10,
    "batch_size": 1,
    "max_steps": 10000,
    "gradient_accumulation_steps": 1,
})

In [None]:
from trl import GRPOConfig, GRPOTrainer
# GRPO training configuration, grouped by concern.
training_args = GRPOConfig(
    # --- generation ---
    use_vllm=True,  # use vLLM for fast inference
    num_generations=8,  # decrease if out of memory
    max_prompt_length=256,
    max_completion_length=200,
    # --- optimiser and schedule ---
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    max_grad_norm=0.1,
    # --- precision: bf16 when the GPU supports it, fp16 otherwise ---
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    # --- batching ---
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,  # increase to 4 for smoother training
    # --- run length ---
    # NOTE(review): both limits are set; per the transformers docs a positive
    # max_steps overrides num_train_epochs.
    num_train_epochs=10,
    max_steps=10000,
    # --- logging and checkpoints ---
    logging_steps=1,
    save_steps=250,
    report_to="wandb",
    output_dir="outputs",
)

# Log training arguments to WandB. to_dict() is the config's own serialiser
# (enums become plain values and token fields are masked), unlike vars(), which
# exposes the raw attributes. allow_val_change lets these values replace keys
# that wandb.init() already set.
wandb.config.update(training_args.to_dict(), allow_val_change=True)

In [None]:
# Assemble the GRPO trainer: the LoRA-wrapped model, its tokenizer, the six
# reward functions defined above and the riddle dataset.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        # formatting rewards
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        # content rewards
        short_reward_func,
        clue_extraction_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)

# Start training.
trainer.train()