# Search with base models

## Goal

Can we solve ARC tasks using base models with access to a DSL?

## Imports

In [None]:
import os
from arc25.utils import get_least_used_gpu_index
from arc25.logging import configure_logging, log_execution_time

# Apply the project's logging configuration before any other cell runs.
configure_logging()
# Restrict CUDA to the single device index returned by get_least_used_gpu_index().
# This is set before torch is imported in the following cell.
os.environ['CUDA_VISIBLE_DEVICES'] = str(get_least_used_gpu_index())

In [None]:
import importlib
import inspect
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

from arc25.training_tasks import *
from arc25.encoders import create_grid_encoder
from arc25.prompting import pretty_print_prompt, Template

## Prompt

Reference prompt (SOAR): https://github.com/flowersteam/SOAR/blob/main/soar/prompt.py

In [None]:
def extract_footprint(module_name: str, show_types: bool = False) -> str:
    """
    Summarise the public functions defined in a module as a bullet list.

    The module is imported by name. Every top-level function that is defined
    in that module itself (functions imported into it are ignored) and whose
    name does not start with an underscore contributes one line of the form:

      - func_name(arg1, arg2)

    With show_types=True the full signature is rendered instead (annotations,
    defaults and return annotation); otherwise only the parameter names are
    listed. Lines follow the name-sorted order produced by
    inspect.getmembers and are joined with newlines.
    """
    module = importlib.import_module(module_name)

    def render_signature(func) -> str:
        signature = inspect.signature(func)
        if show_types:
            return str(signature)
        # Names only: annotations and default values are dropped.
        return "(" + ", ".join(signature.parameters) + ")"

    return "\n".join(
        f"- {func_name}{render_signature(func)}"
        for func_name, func in inspect.getmembers(module, inspect.isfunction)
        # Keep only public functions that this module defines itself.
        if func.__module__ == module_name and not func_name.startswith("_")
    )

# Preview the typed footprint of the arc25.BARC_dsl module (the same text that
# create_prompt_from_task embeds in the prompt).
print(extract_footprint('arc25.BARC_dsl', show_types=True))

In [None]:
# Load the ARC training challenges JSON into memory; get_task looks tasks up
# in this dict by task name.
with open('/mnt/hdd0/Kaggle/arc25/data/arc-prize-2024/arc-agi_training_challenges.json', 'r') as f:
    training_challenges = json.load(f)

def get_task(task_name):
    """Build a Task from the train samples of the named training challenge.

    Looks `task_name` up in the module-level `training_challenges` dict,
    wraps each train sample's input and output grid in `Img`, and returns a
    `Task` with an empty `code` string and `name` set to `task_name`.

    Raises:
        ValueError: if `task_name` is not a key of `training_challenges`.
    """
    if task_name not in training_challenges:
        raise ValueError(f"Task {task_name} not found in training challenges.")

    train_pairs = training_challenges[task_name]['train']
    return Task(
        inputs=[Img(pair['input']) for pair in train_pairs],
        outputs=[Img(pair['output']) for pair in train_pairs],
        code='',
        name=task_name,
    )

In [None]:
# System message placed first in the chat messages built by create_prompt_from_task.
system_prompt = """You are an advanced AI assistant specialized in solving Abstract Reasoning Corpus (ARC-AGI) tasks."""


# User-message template. It is rendered with two variables:
#   dsl           - text listing the DSL functions the model may call
#   train_samples - sequence of items exposing `input` and `output` grid texts,
#                   rendered as numbered examples via the `{% for %}` loop
# NOTE(review): the prompt text is model-facing; the wording is kept verbatim.
prompt_template = Template("""You are tasked with solving a transformation problem from the Abstraction and Reasoning Challenge (ARC).
Implement the transformation rules as a Python function.
You should only write the implemented the transformation in code.
You must write code in triple backticks (```python and then ```). You must write a function called `transform` which takes a single argument, the input grid as `list[list[int]]`, and returns the transformed grid (also as `list[list[int]]`).

## Key Priors:

- **Objectness**: Consider the grid as containing objects (groups of connected cells) rather than just individual pixels.
- **Goal-Directed**: The transformation should achieve a specific goal, such as creating symmetry or changing the color of specific objects.
- **Numbers & Counting**: Keep track of the number of objects, sizes, and their relative positions.
- **Geometry & Topology**: Use spatial relationships such as adjacency, enclosure, or symmetry.

Carefully analyze the examples and find the underlying transformation logic.

## Domain Specific Primitive Functions

You can use the already implemented following functions to manipulate the grid:

{{ dsl }}

## Examples

Below are several input-output examples that illustrate the transformation.
Your function should generalize the pattern from these examples to solve any input following the same logic.

{% for sample in train_samples %}
### Example {{ loop.index }}

#### Input

{{ sample.input }}

#### Output

{{ sample.output }}
{% endfor %}
""")


def create_prompt_from_task(task, grid_encoder, tokenizer):
    """Render the chat prompt string for one task.

    Each (input, output) pair from `task.inputs` / `task.outputs` is converted
    to text with `grid_encoder.to_text`. The user message is `prompt_template`
    rendered with those samples plus the typed footprint of `arc25.BARC_dsl`.
    The system and user messages are then passed through the tokenizer's chat
    template with `tokenize=False` and `add_generation_prompt=True`, and the
    result of that call is returned.
    """
    train_samples = []
    for input_grid, output_grid in zip(task.inputs, task.outputs):
        train_samples.append({
            'input': grid_encoder.to_text(input_grid),
            'output': grid_encoder.to_text(output_grid),
        })

    dsl_footprint = extract_footprint('arc25.BARC_dsl', show_types=True)
    user_message = prompt_template.render(train_samples=train_samples, dsl=dsl_footprint)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

In [None]:
@log_execution_time
def load_model(base_model_path, lora_path=None):
    """Load a 4-bit quantized causal LM and its tokenizer, optionally with a LoRA adapter.

    Args:
        base_model_path: path or identifier passed to
            `AutoModelForCausalLM.from_pretrained`.
        lora_path: optional adapter path. When given, the tokenizer is loaded
            from this path instead of `base_model_path`, and the base model is
            wrapped with `PeftModel.from_pretrained(..., is_trainable=True)`.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    # Imported locally: the notebook never imports the stdlib `logging` module
    # explicitly, so the call below would otherwise only work if a star import
    # happened to leak the name into the namespace.
    import logging

    logging.info(f"Loading model from {base_model_path} and LoRA from {lora_path}")
    torch.cuda.empty_cache()

    # 4-bit NF4 quantization with double quantization and float16 compute;
    # the 'gate' and 'lm_head' modules are listed as skip modules.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        llm_int8_enable_fp32_cpu_offload=True,
        llm_int8_skip_modules=['gate', 'lm_head'],
    )

    model = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype="auto", device_map="auto", quantization_config=bnb_config)

    # The tokenizer comes from the adapter directory when an adapter is used,
    # otherwise from the base model.
    tokenizer_path = base_model_path if lora_path is None else lora_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    if lora_path is not None:
        model = PeftModel.from_pretrained(model, lora_path, is_trainable=True)
    return model, tokenizer

## First steps

In [None]:
# Load the model and tokenizer from a local checkpoint; no LoRA adapter is passed.
model, tokenizer = load_model("/home/gbarbadillo/models/Qwen2.5-Coder-0.5B-Instruct")

In [None]:
# Load the tokenizer on its own, so this cell does not need the model to be loaded.
tokenizer = AutoTokenizer.from_pretrained("/home/gbarbadillo/models/Qwen2.5-Coder-0.5B-Instruct")
# NOTE(review): this token count of the DSL footprint is computed but neither
# printed nor stored; a notebook only displays the last expression of a cell.
len(tokenizer.tokenize(extract_footprint('arc25.BARC_dsl', show_types=True)))
# Build the chat prompt for training task '0b148d64' with the grid encoder
# created from the spec string, then pretty-print it.
prompt = create_prompt_from_task(get_task('0b148d64'), grid_encoder = create_grid_encoder('GridShapeEncoder(RowNumberEncoder(MinimalGridEncoder()))'), tokenizer=tokenizer)
pretty_print_prompt(prompt, default_color='white')

In [None]:
# Tokenize the prompt as a batch of one and move the tensors to the model's device.
model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
# Sample one completion of at most 512 new tokens (temperature 1.0, top_p 0.90).
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512,
    do_sample=True,
    temperature=1.0,
    top_p=0.90,
    num_return_sequences=1,
)
# Slice off the first len(input_ids) positions, i.e. the prompt tokens, so that
# only the generated continuation is decoded.
generated_ids = generated_ids[:, len(model_inputs.input_ids[0]):]
predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(predictions[0])

## TODO

- [x] Create a prompt with the available DSL functions and the training ARC task
- [ ] Verify the effect of caching
- [ ] Update the library to be able to select which DSL to use when executing code
- [ ] Try to solve some easy tasks
- [ ] Create a refine prompt
- [ ] Make a more complex tree search