# Search with base models

## Goal

Can we solve ARC tasks using base models with access to a DSL?

## Imports

In [None]:
import os
import logging
from arc25.utils import get_least_used_gpu_index
from arc25.logging import configure_logging, log_execution_time

# Apply the project's logging configuration before any other code logs.
configure_logging()
# Restrict this process to the single GPU index returned by
# get_least_used_gpu_index().
# NOTE(review): presumably this has to run before torch / vllm are imported
# (they are imported in a later cell) for the restriction to apply — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = str(get_least_used_gpu_index())

# Add VLLM specific environment variables to avoid common issues
os.environ['VLLM_USE_MODELSCOPE'] = 'False'
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

In [None]:
import time
import importlib
import inspect
import json
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
from vllm import LLM, SamplingParams

from arc25.training_tasks import *
from arc25.encoders import create_grid_encoder
from arc25.prompting import pretty_print_prompt, Template

## Code

### Prompt

Reference prompt (SOAR): https://github.com/flowersteam/SOAR/blob/main/soar/prompt.py

In [None]:
def extract_footprint(module_name: str, show_types: bool = False) -> str:
    """
    Describe the public functions of a module as a bullet list.

    The module is imported by name. Every top-level function that is defined
    in that module itself (not merely imported into it) and whose name does
    not start with an underscore produces one line of the form:

      - dsl.func_name(arg1, arg2)

    The ``dsl.`` prefix is fixed. Lines follow the order returned by
    ``inspect.getmembers`` (sorted by name) and are joined with newlines.

    Args:
        module_name: Importable dotted name of the module to inspect.
        show_types: If True the full signature is emitted (annotations,
            defaults and return annotation); otherwise only the bare
            parameter names are listed.

    Returns:
        The newline-separated listing, or an empty string when the module
        defines no public functions.
    """
    mod = importlib.import_module(module_name)
    # Compare against the module's real name rather than the requested one:
    # the two differ for aliased modules (e.g. 'os.path' resolves to
    # 'posixpath' or 'ntpath'), which would otherwise filter out everything.
    defining_module = mod.__name__
    footprints = []

    for name, fn in inspect.getmembers(mod, inspect.isfunction):
        # skip private helpers and functions imported from elsewhere
        if fn.__module__ != defining_module or name.startswith("_"):
            continue

        sig = inspect.signature(fn)
        if show_types:
            sig_text = str(sig)
        else:
            # names only: annotations, defaults and * / ** markers are dropped
            sig_text = f"({', '.join(sig.parameters)})"

        footprints.append(f"- dsl.{name}{sig_text}")

    return "\n".join(footprints)

# Preview the typed listing of the BARC DSL functions; the same call is used
# when the prompt is built.
print(extract_footprint('arc25.BARC_dsl', show_types=True))

In [None]:
# Load the ARC-AGI 2024 training challenges from a machine-specific absolute
# path. The result is indexed by task name; each entry is read through its
# 'train' list of samples with 'input' and 'output' grids (see get_task).
with open('/mnt/hdd0/Kaggle/arc25/data/arc-prize-2024/arc-agi_training_challenges.json', 'r') as f:
    training_challenges = json.load(f)

def get_task(task_name):
    """
    Build a `Task` from the train samples of one training challenge.

    Every train sample's 'input' and 'output' grid is wrapped in `Img`; the
    resulting task has an empty `code` and is named after `task_name`.

    Raises:
        ValueError: if `task_name` is not a key of `training_challenges`.
    """
    if task_name not in training_challenges:
        raise ValueError(f"Task {task_name} not found in training challenges.")

    train_samples = training_challenges[task_name]['train']
    input_imgs = [Img(pair['input']) for pair in train_samples]
    output_imgs = [Img(pair['output']) for pair in train_samples]
    return Task(inputs=input_imgs, outputs=output_imgs, code='', name=task_name)

In [None]:
# System message placed first in every chat prompt (see create_prompt_from_task).
system_prompt = """You are an advanced AI assistant specialized in solving Abstract Reasoning Corpus (ARC-AGI) tasks."""


# User message template. It is rendered with two variables:
#   - dsl: text listing of the available DSL functions
#   - train_samples: sequence of items exposing `.input` and `.output` text
# NOTE(review): the {{ ... }} / {% ... %} syntax looks Jinja2-style; confirm
# against arc25.prompting.Template.
# NOTE(review): the sentence "You should only write the implemented the
# transformation in code." reads as garbled; it is left untouched because
# editing the prompt changes what the model sees.
prompt_template = Template(
"""You are tasked with solving a transformation problem from the Abstraction and Reasoning Challenge (ARC).
Implement the transformation rules as a Python function.
You should only write the implemented the transformation in code.
You must write code in triple backticks (```python and then ```). You must write a function called `transform` which takes a single argument, the input grid as `list[list[int]]`, and returns the transformed grid (also as `list[list[int]]`).

## Key Priors:

- **Objectness**: Consider the grid as containing objects (groups of connected cells) rather than just individual pixels.
- **Goal-Directed**: The transformation should achieve a specific goal, such as creating symmetry or changing the color of specific objects.
- **Numbers & Counting**: Keep track of the number of objects, sizes, and their relative positions.
- **Geometry & Topology**: Use spatial relationships such as adjacency, enclosure, or symmetry.

Carefully analyze the examples and find the underlying transformation logic.

## Domain Specific Primitive Functions

You can use the already implemented following functions to manipulate the grid:

{{ dsl }}

The dsl has been already imported, so just simply call the functions as needed. F.e. dsl.foo()
Do not import the dsl again, just use it directly.

## Examples

Below are several input-output examples that illustrate the transformation.
Your function should generalize the pattern from these examples to solve any input following the same logic.

{% for sample in train_samples %}
### Example {{ loop.index }}

#### Input

{{ sample.input }}

#### Output

{{ sample.output }}
{% endfor %}
""")


def create_prompt_from_task(task, grid_encoder, tokenizer):
    """
    Render the chat prompt (as text) for one task.

    Each (input, output) pair of the task is converted to text with
    `grid_encoder.to_text`, the user message is rendered from
    `prompt_template` together with the typed BARC DSL listing, and the
    system + user messages are formatted with the tokenizer's chat template.
    The result is not tokenized and ends with the generation prompt.
    """
    train_samples = []
    for grid, output in zip(task.inputs, task.outputs):
        encoded_pair = {
            'input': grid_encoder.to_text(grid),
            'output': grid_encoder.to_text(output),
        }
        train_samples.append(encoded_pair)

    dsl_listing = extract_footprint('arc25.BARC_dsl', show_types=True)
    user_message = prompt_template.render(train_samples=train_samples, dsl=dsl_listing)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

### Model

In [None]:
@log_execution_time
def load_model(model_path, use_4bit_quantization=False):
    """
    Create the vLLM engine and the matching tokenizer for `model_path`.

    GPU memory is cleaned up first via `cleanup_gpu`. When
    `use_4bit_quantization` is True the engine is created with
    bitsandbytes quantization. For a path ending in '.gguf' the tokenizer is
    loaded from a 'tokenizer' directory next to the file; otherwise it is
    loaded from `model_path` itself.

    Returns:
        Tuple `(llm, tokenizer)`.
    """
    logging.info(f"Loading model from {model_path}")
    cleanup_gpu()

    engine_kwargs = dict(
        model=model_path,
        gpu_memory_utilization=0.9,
        trust_remote_code=True,
        dtype="bfloat16",
        tensor_parallel_size=1,  # single GPU
        quantization="bitsandbytes" if use_4bit_quantization else None,
        # seems to be enabled by default, but be explicit
        enable_prefix_caching=True,
        # otherwise the 14B model fails with a "context length exceeded" error
        max_model_len=32000,
    )
    llm = LLM(**engine_kwargs)

    is_gguf = model_path.endswith('.gguf')
    tokenizer_path = (
        os.path.join(os.path.dirname(model_path), 'tokenizer') if is_gguf else model_path
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    return llm, tokenizer


def cleanup_gpu():
    """Run garbage collection and empty the CUDA cache before loading VLLM.

    When CUDA is available, also waits for pending GPU work to finish.
    """
    gc.collect()
    cuda = torch.cuda
    cuda.empty_cache()
    if not cuda.is_available():
        return
    cuda.synchronize()

### Code

In [None]:
def parse_python_code(text):
    """
    Extract the first ```python fenced code block from a model response.

    Lines containing 'import dsl' or 'from dsl' are dropped from the
    extracted code, and surrounding whitespace is stripped.

    Args:
        text: Raw response text.

    Returns:
        The extracted code, or '' when there is no ```python fence or the
        fence is never closed before the next ```python marker / end of text.
    """
    fence_open = '```python'
    fence_close = '```'

    if fence_open not in text:
        return ''
    # Text between the first opening fence and the next opening fence (if any)
    after_open = text.split(fence_open)[1]
    if fence_close not in after_open:
        return ''

    code = after_open.split(fence_close)[0]
    kept_lines = [
        line for line in code.split('\n')
        if 'import dsl' not in line and 'from dsl' not in line
    ]
    # A single strip is enough: it also removes blank lines left behind when
    # the first or last lines were filtered out.
    return '\n'.join(kept_lines).strip()

## First steps

In [None]:
# Local checkpoint used for this experiment. The commented-out alternative
# points to a q4_k_m GGUF file of the 14B Coder model.
model_path = "/home/gbarbadillo/models/Qwen2.5-Coder-3B-Instruct"
# model_path = '/home/gbarbadillo/models/Qwen2.5-Coder-14B-Instruct-GGUF/qwen2.5-coder-14b-instruct-q4_k_m.gguf'
# Create the engine and tokenizer once; the cells below reuse both.
llm, tokenizer = load_model(model_path, use_4bit_quantization=False)

In [None]:
# Use the first 10 training tasks (in the file's key order) for this run.
task_ids = list(training_challenges.keys())[:10]
# Grid-to-text encoder built from its textual spec.
grid_encoder = create_grid_encoder('GridShapeEncoder(RowNumberEncoder(MinimalGridEncoder()))')
# One chat-formatted prompt per task, in the same order as task_ids.
prompts = [create_prompt_from_task(get_task(task_id), grid_encoder=grid_encoder, tokenizer=tokenizer) for task_id in task_ids]

In [None]:
# Sample n=64 completions per prompt, each capped at 2048 new tokens.
sampling_params = SamplingParams(n=64, temperature=1.0, top_p=0.95, max_tokens=2048)
t0 = time.time()
outputs = llm.generate(prompts, sampling_params)
# Stop the clock right after generation so the token bookkeeping below is not
# counted as generation time.
t1 = time.time()
total_tokens = sum(sum(len(_output.token_ids) for _output in output.outputs) for output in outputs)
print(f"Total tokens generated: {total_tokens}")
print(f"Time taken: {t1 - t0:.2f} seconds")
print(f"Average time per task: {(t1 - t0) / len(outputs):.2f} seconds")
# The value is divided by both the number of tasks and n, i.e. it is the mean
# length of a single sampled response, so it is labelled accordingly.
print(f"Average tokens per response: {total_tokens / len(outputs) / sampling_params.n:.2f} tokens")
print(f"Average tokens per second: {total_tokens / (t1 - t0):.2f} tokens/second")

In [None]:
# The DSL module is passed to the executor so generated code can call dsl.*
import arc25.BARC_dsl as dsl

# Per task: every parsed code snippet, and the outputs of the snippets that
# ran without raising.
# NOTE(review): the two lists are not index-aligned, because code is recorded
# before it is executed while outputs are only recorded on success.
predicted_code = {key: [] for key in task_ids}
predicted_outputs = {key: [] for key in task_ids}
# outputs holds one entry per prompt, in the same order as task_ids
for task_id, responses in zip(task_ids, outputs):
    task = get_task(task_id)
    for i, output in enumerate(responses.outputs):
        code = parse_python_code(output.text)
        # responses without a closed ```python block are ignored
        if code:
            predicted_code[task_id].append(code)
            try:
                # NOTE(review): safe_code_execution presumably comes from the
                # arc25.training_tasks star import — confirm.
                task_predicted_outputs = safe_code_execution(code, task.inputs, func_name='transform', dsl=dsl)
                predicted_outputs[task_id].append(task_predicted_outputs)
            except Exception as e:
                # Generated code may fail in arbitrary ways; log and move on.
                logging.error(f"Error executing code for task {task_id}, response {i}: {e}")

In [None]:
# Report, per task and overall, how many of the n sampled responses produced
# code that executed without raising.
total = 0
for task_id, task_predicted_outputs in predicted_outputs.items():
    total += len(task_predicted_outputs)
    print(f"Task {task_id} valid predicted outputs: {len(task_predicted_outputs)}/{sampling_params.n}")
print(f"Total valid predicted outputs: {total}/{len(task_ids) * sampling_params.n} ({total/len(task_ids)/sampling_params.n:.1%})")

In [None]:
# Report, per task and overall, how many parsed snippets contain the text
# 'dsl.' (a plain substring check), relative to the n sampled responses.
total = 0
for task_id, task_code_predictions in predicted_code.items():
    n_functions_using_dsl = sum(1 for code in task_code_predictions if 'dsl.' in code)
    total += n_functions_using_dsl
    print(f"Task {task_id} use of dsl functions: {n_functions_using_dsl}/{sampling_params.n}")
print(f"Total use of dsl functions: {total}/{len(task_ids) * sampling_params.n} ({total/len(task_ids)/sampling_params.n:.1%})")

In [None]:
# Show the first parsed snippet for one hard-coded task.
# NOTE(review): raises KeyError if this task is not among task_ids and
# IndexError if no code was parsed for it.
print(predicted_code['025d127b'][0])

## TODO

- [x] Create a prompt with the available DSL functions and the training ARC task
- [x] Fix VLLM initialization issues with proper memory management
- [x] Verify the effect of caching
- [x] Generate some code that can be used to test the new BARC dsl
- [x] Update the library to be able to select which DSL to use when executing code
- [x] Verify that I can execute the code generated with the BARC dsl
- [ ] Try to solve some easy task with independent sampling
  - [ ] How frequently is the dsl used?
  - [ ] Influence of the model
- [ ] Create a refine prompt
- [ ] Make a more complex tree search