In [73]:

import re
import ast
from dataclasses import dataclass, field

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config


def accuracy_reward(completions, solution, **kwargs):
    """Reward function that checks if the completion is the same as the ground truth."""
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, solution):
        try:
            # Regular expression to extract content between <answer> and </answer>
            pattern = r"<answer>(.*?)</answer>"

            # Find all matches
            matches = re.findall(pattern, content, re.DOTALL)[0]
            print(matches)
            reward = 1.0 if sol in matches else 0.0
        except Exception:  # if it fails for any reason, return 0.0
            reward = 0.0
        rewards.append(reward)
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return rewards


def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


reward_funcs_registry = {
    "accuracy": accuracy_reward,
    "format": format_reward,
}

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


In [30]:
 # Load the dataset
dataset = load_dataset('lordspline/arc-agi')

# Format into conversation
def make_conversation(example):
    examples = ''
    ex_num = 1
    for ex in example['train']:
        ex_in = "\n".join(" ".join(map(str, row)) for row in ex['input'])
        ex_out = "\n".join(" ".join(map(str, row)) for row in ex['output'])
        examples += f'Example {ex_num}: \n\nInput:\n{ex_in}\nOutput:\n{ex_out}\n\n'
        ex_num += 1
    test_in = "\n".join(" ".join(map(str, row)) for row in example["test"][0]["input"])

    question = f'Find the common rule that maps an input grid to an output grid, given the examples below.\n{examples} Below is a test input grid. Predict the corresponding output grid by applying the rule you found. Your final answer should just be the text output grid itself. \n\nInput:\n{test_in}\n'
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question},
        ],
    }

def make_solution(example):
    solution = "\n".join(" ".join(map(str, row)) for row in example["test"][0]["output"])
    return {
        "solution": f'{solution}'
    }

dataset = dataset.map(make_conversation)
dataset = dataset.map(make_solution)
dataset = dataset.remove_columns("train")
dataset = dataset.remove_columns("test")


Map: 100%|██████████| 400/400 [00:00<00:00, 519.05 examples/s] 
Map: 100%|██████████| 400/400 [00:00<00:00, 580.16 examples/s]
Map: 100%|██████████| 5/5 [00:00<00:00, 88.66 examples/s]
Map: 100%|██████████| 400/400 [00:00<00:00, 1492.57 examples/s]
Map: 100%|██████████| 400/400 [00:00<00:00, 1004.49 examples/s]
Map: 100%|██████████| 5/5 [00:00<00:00, 86.09 examples/s]


In [61]:
print(dataset['training'][0]['prompt'][1]['content'])
print(dataset['training'][0]['solution'])

Find the common rule that maps an input grid to an output grid, given the examples below.
Example 1: 

Input:
0 0 5
0 5 0
5 0 0
Output:
3 3 3
4 4 4
2 2 2

Example 2: 

Input:
0 0 5
0 0 5
0 0 5
Output:
3 3 3
3 3 3
3 3 3

Example 3: 

Input:
5 0 0
0 5 0
5 0 0
Output:
2 2 2
4 4 4
2 2 2

Example 4: 

Input:
0 5 0
0 0 5
0 5 0
Output:
4 4 4
3 3 3
4 4 4

 Below is a test input grid. Predict the corresponding output grid by applying the rule you found. Your final answer should just be the text output grid itself. 

Input:
0 0 5
5 0 0
0 5 0

3 3 3
2 2 2
4 4 4


In [78]:
answer = [[{'role': 'assistant', 'content': '''<think> 1, 2, 3, 4, 5 </think><answer> 
3 3 3
2 2 2
4 4 2 </answer>'''
}]]

print(accuracy_reward(answer, [dataset['training'][0]['solution']]))

print(format_reward(answer))

 
3 3 3
2 2 2
4 4 2 
[0.0]
[1.0]
