# Create OpenAI Batch Files for Systematicity Abstract Visual Reasoning Task
This script creates the prompts for the systematicity experiment. We evaluate `GPT-4o` on the abstract visual reasoning task via the OpenAI API. To do so, we create batch files, each containing a chunk of the test set. The input is the same as the one given to the meta-learning model, preceded by a prompt that describes the task. The expected output is the predicted output grid.

## Create Batch File
We exploit OpenAI's Batch API to make efficient use of their model and reduce API costs. For this, we first need to create a batch file that contains all the prompts we want to evaluate.
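
As a rough illustration, each line of a batch file is one JSON object describing a single request. The sketch below mirrors the fields used later in this notebook (`custom_id`, `method`, `url`, `body`); the helper name `make_batch_request` and the placeholder prompt are hypothetical and only serve the example.

```python
import json
from typing import Any, Dict


def make_batch_request(
    custom_id: str, model: str, prompt: str, max_tokens: int = 1000
) -> Dict[str, Any]:
    # Hypothetical helper: builds one chat-completions request for a batch file.
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
        },
    }


# Placeholder prompt; the real content is the task prompt plus one test sample.
request = make_batch_request("test_sample_0", "gpt-4o-2024-08-06", "### Input: ...")

# A batch file is JSONL: one serialized request per line.
line = json.dumps(request)
print(line)
```

The `custom_id` is what lets each response be matched back to its test sample once the batch has finished.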

In [None]:
MODEL = "gpt-4o-2024-08-06"

SEED = 1860
DATA_DIR = f"data/split_seed_{SEED}"
FILE_NAME = f"systematicity_seed_{SEED}"

In [None]:
from pathlib import Path

# Output paths
CURR_FILE_PATH = Path.cwd().resolve()
OUT_DIR = f"{MODEL}/batch_files/split_seed_{SEED}"

### Model Prompt

In [None]:
user_prompt = """### Task Description:
You must solve an abstract visual reasoning task by identifying geometric transformations (e.g., rotation, translation, color changes, etc.) applied to objects within a 10x10 grid.

To infer the correct geometric transformation, you are given a series of **12 pairs of input-output examples**. Each example pair consists of:
- An **input grid**: a 10x10 list of lists (2d array), where each element is an integer (0-9).
- A corresponding **output grid**: a 10x10 list of lists (2d array) that has undergone a transformation based on a specific geometric rule.

The first 6 example pairs demonstrate primitive transformations based on the object's color, shape, or the presence of an additional object.
For instance, objects of a certain color within the 10x10 input grid might undergo a translation, while objects of a certain shape (distinct numerical pattern) are being rotated.

The latter 6 example pairs involve **composite transformations**, meaning multiple transformations are applied simultaneously.
For instance, for objects that have the appropriate color **and** shape, both a translation and rotation are applied simultaneously.

For the final prediction you need to understand and further combine the transformations displayed in the provided examples and apply them to the final input grid.

#### Your Task:
1. **Analyze** the example pairs to infer the transformation rules applied to each input grid.
2. **Identify** how these transformations might combine to generate the output grids.
3. **Apply** the deduced transformations to the final input grid.
4. **Output** the correctly transformed 10x10 grid.

### Output Requirements:
- **Return only the final output grid.**
- Do not include any extra text, explanations, or comments.
- The output must be formatted exactly as:
 `output: [[...]]`
- The output grid must be a 10x10 list of lists containing only integers between 0 and 9 (inclusive).
- Do not include unnecessary line breaks or additional text beyond the specified format.

### Input Format:
You will receive the following data:
1. **Study examples:** A list of 12 study example pairs, formatted as:
  `example input 1: [[...]], example output 1: [[...]], ..., example input 12: [[...]], example output 12: [[...]]`
2. **Final input:** A single 10x10 list of lists on which you must apply the inferred transformation(s).

Your goal is to determine the correct transformation and return the final output grid.

### Input:
"""
user_prompt
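
The prompt asks for a reply of the form `output: [[...]]`, which makes the response straightforward to check. The following is a minimal sketch of such a check, assuming the reply follows that format exactly; `parse_output_grid` is a hypothetical helper and not part of this notebook or of `vmlc`.

```python
import ast
from typing import List, Optional


def parse_output_grid(response_text: str, size: int = 10) -> Optional[List[List[int]]]:
    """Return the grid from an `output: [[...]]` reply, or None if it is malformed."""
    prefix = "output:"
    text = response_text.strip()
    if not text.startswith(prefix):
        return None
    try:
        # literal_eval only evaluates Python literals, so it is safe on model output.
        grid = ast.literal_eval(text[len(prefix):].strip())
    except (ValueError, SyntaxError, TypeError):
        return None
    is_valid = (
        isinstance(grid, list)
        and len(grid) == size
        and all(
            isinstance(row, list)
            and len(row) == size
            and all(
                isinstance(v, int) and not isinstance(v, bool) and 0 <= v <= 9
                for v in row
            )
            for row in grid
        )
    )
    return grid if is_valid else None


blank_grid = [[0] * 10 for _ in range(10)]
print(parse_output_grid(f"output: {blank_grid}") == blank_grid)
print(parse_output_grid("output: [[1, 2]]"))
```

Returning `None` for anything that is not a well-formed 10x10 grid of integers in 0-9 keeps malformed replies separate from wrong predictions.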

### Get Data

In [None]:
from vmlc.utils.utils import load_jsonl

test_data = load_jsonl(
    file_path=f"{DATA_DIR}/test_{FILE_NAME}.jsonl"
)
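
The script below reads two fields from every test sample: `study_examples` (a list of input-output pairs) and `queries` (where `queries[0][0]` is the final input grid). A small sanity check along these lines can catch malformed records early. It is only a sketch: the field names come from the script, the count of 12 pairs comes from the prompt, and `check_sample` and the toy record are hypothetical. How the grids themselves are stored is not shown here, so the check deliberately leaves them uninspected.

```python
from typing import Any, Dict


def check_sample(sample: Dict[str, Any], num_study_examples: int = 12) -> None:
    """Assert that a sample has the fields the batch-file script relies on."""
    assert "study_examples" in sample, "missing 'study_examples'"
    assert "queries" in sample, "missing 'queries'"
    assert len(sample["study_examples"]) == num_study_examples
    for pair in sample["study_examples"]:
        # Each study example is an [input_grid, output_grid] pair.
        assert len(pair) == 2
    # The script uses queries[0][0] as the final input grid.
    assert len(sample["queries"]) >= 1 and len(sample["queries"][0]) >= 1


# Toy record with blank 10x10 grids, for illustration only.
blank = [[0] * 10 for _ in range(10)]
toy_sample = {
    "study_examples": [[blank, blank] for _ in range(12)],
    "queries": [[blank, blank]],
}
check_sample(toy_sample)
print("sample ok")
```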

### Script

In [None]:
from typing import Any, List, Dict, Optional

from vmlc.utils.utils import save_dicts_as_jsonl

def prepare_study_examples(study_examples: List[List[List[List[str]]]]) -> str:
    """Format the study examples as numbered `example input` / `example output` lines."""
    study_example_str = ""

    for idx, input_output_pair in enumerate(study_examples):
        assert len(input_output_pair) == 2, f"Invalid number of input and output grids! {len(input_output_pair)}"
        input_grid = f"\nexample input {idx + 1}: {input_output_pair[0]}"
        output_grid = f"\nexample output {idx + 1}: {input_output_pair[1]}"

        study_example_str += input_grid + output_grid
    
    return study_example_str


def prepare_batch_files(
    test_data: List[Dict[str, Any]],
    user_prompt: str,
    num_samples_per_batch_file: int,
    model: str,
    out_dir: str,
    few_shot_examples: Optional[List[Dict[str, str]]] = None  # chat messages prepended to each request
) -> None:

    """Split `test_data` into chunks and write one Batch API JSONL file per chunk."""
    curr_idx = 0

    while curr_idx < len(test_data):
        batch_file_content: List[Dict[str, Any]] = []
        curr_samples = test_data[curr_idx:curr_idx + num_samples_per_batch_file]

        for sample_num, sample in enumerate(curr_samples):
            batch_user_messages: List[Dict[str, str]] = []

            if few_shot_examples is not None:
                batch_user_messages += few_shot_examples

            study_example_str = prepare_study_examples(sample['study_examples'])
            input_grid_str = sample['queries'][0][0]
            
            batch_user_messages += [
                {
                    "role": "user",
                    "content": user_prompt + f"Study examples:{study_example_str}\n\n" + f"Final input:\n{input_grid_str}"
                }
            ]
    
            # Reasoning models take `max_completion_tokens` instead of `max_tokens`.
            if "o3-mini" in model or "o1" in model:
                batch_file_content.append(
                    {
                        "custom_id": f"test_sample_{curr_idx+sample_num}",
                        "method": "POST",
                        "url": "/v1/chat/completions",
                        "body": {
                            "model": model,
                            "messages": batch_user_messages,
                            "max_completion_tokens": 26000,
                            "reasoning_effort": "low",
                        }
                    }
                )
            else:
                batch_file_content.append(
                    {
                        "custom_id": f"test_sample_{curr_idx+sample_num}",
                        "method": "POST",
                        "url": "/v1/chat/completions",
                        "body": {
                            "model": model,
                            "messages": batch_user_messages,
                            "max_tokens": 1000
                        }
                    }
                )
        
        if few_shot_examples is None:
            file_name = f"batch_file_samples_{curr_idx}-{min(curr_idx + num_samples_per_batch_file - 1, len(test_data) - 1)}.jsonl"
        else:
            file_name = f"batch_file_few_shots_samples_{curr_idx}-{min(curr_idx + num_samples_per_batch_file - 1, len(test_data) - 1)}.jsonl"
        
        save_dicts_as_jsonl(
            data=batch_file_content,
            filepath=f"{out_dir}/{file_name}"
        )

        curr_idx += num_samples_per_batch_file
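
The file names written by `prepare_batch_files` encode the inclusive index range of the samples in each chunk. The sketch below reproduces just that index arithmetic; `chunk_ranges` is a hypothetical helper, and the test-set size of 6000 is made up for the example.

```python
from typing import List, Tuple


def chunk_ranges(num_samples: int, chunk_size: int) -> List[Tuple[int, int]]:
    """Return inclusive (start, end) sample indices for each batch file."""
    return [
        (start, min(start + chunk_size - 1, num_samples - 1))
        for start in range(0, num_samples, chunk_size)
    ]


# Hypothetical test set of 6000 samples, split into chunks of 2500.
for start, end in chunk_ranges(6000, 2500):
    print(f"batch_file_samples_{start}-{end}.jsonl")
# batch_file_samples_0-2499.jsonl
# batch_file_samples_2500-4999.jsonl
# batch_file_samples_5000-5999.jsonl
```

The last chunk is simply shorter when the number of samples is not a multiple of the chunk size.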

In [None]:
prepare_batch_files(
    test_data=test_data,
    user_prompt=user_prompt,
    num_samples_per_batch_file=2500,
    model=MODEL,
    out_dir=OUT_DIR,
    few_shot_examples=None
)
