# Evaluate Model on Abstract Visual Reasoning Task
We use `GPT-4o` via the OpenAI API to evaluate the model on the abstract visual reasoning task. We create batch files that each contain a chunk of the test set. The input is the same as the one given to the meta-learning model, with an additional prompt that instructs the model on the respective task. The output should be the predicted output grid.

### Evaluate Batch Files
The following code evaluates the error types in the batch files. The batch files are stored in the `batch_files` directory. The code reads the batch files, sends the statements to the GPT-4o model, and evaluates the results. The results are stored in the `gpt-4o` directory.

### Parameters

In [None]:
# True: use the image batch files ("image_batch_files"); False: use the text-only ones ("batch_files").
IMAGE_INPUT: bool = True
# True: use the "_only_few_shots" variant of the data directory instead of the default one.
ONLY_FEW_SHOTS: bool = False

In [None]:
# Seed that identifies the data split, and the exact model snapshot being evaluated.
SEED = 1860
MODEL = "gpt-4o-2024-08-06"

FILE_NAME = f"systematicity_seed_{SEED}"

# NOTE: the name is misspelled ("BACTH"); it is kept as-is because other cells may refer to it.
if IMAGE_INPUT:
    BACTHFOLDER = "image_batch_files"
else:
    BACTHFOLDER = "batch_files"

# Directory holding the batch files: <MODEL>/<BACTHFOLDER>/split_seed_<SEED>[_only_few_shots]
DATA_DIR = "/".join([MODEL, BACTHFOLDER, f"split_seed_{SEED}"]) + (
    "_only_few_shots" if ONLY_FEW_SHOTS else ""
)
DATA_DIR

In [None]:
from pathlib import Path

# Paths
# Results go to <parent of cwd>/experimental_results/<MODEL>/<IMG_SPEC>/<FEW_SHOT_SPEC>.
CURR_FILE_PATH = Path.cwd().resolve()

if IMAGE_INPUT:
    IMG_SPEC = "with_images"
else:
    IMG_SPEC = "text_only"

if ONLY_FEW_SHOTS:
    FEW_SHOT_SPEC = "only_few_shots"
else:
    FEW_SHOT_SPEC = "vanilla"

OUT_DIR = str(
    CURR_FILE_PATH.parent.joinpath("experimental_results", MODEL, IMG_SPEC, FEW_SHOT_SPEC)
)
OUT_DIR

### Keys

In [None]:
import os
import openai
from dotenv import load_dotenv

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Name of the environment variable that holds the OpenAI API key for this project.
API_KEY_ENV_VAR = "OPENAI_API_VMLC"

# Retrieve the API key from the environment variable
api_key = os.getenv(API_KEY_ENV_VAR)

# Fail early with a message that names the variable that is actually read
# (the previous message pointed at OPENAI_API_KEY, which this notebook never uses).
if not api_key:
    raise ValueError(
        f"API key not found. Ensure the {API_KEY_ENV_VAR} environment variable is set correctly."
    )

# Set the key on the module-level OpenAI configuration as well
openai.api_key = api_key

### Script
Scripts to evaluate the error types in the models' responses.

#### Check Status of Batch Files

In [None]:
from openai import OpenAI

# Client used by the batch helpers defined in this notebook.
client = OpenAI(api_key=api_key)

def get_batch_status(num_jobs: int = 12) -> list[str]:
    """Print the id and status of up to ``num_jobs`` batches and return their ids.

    Uses the notebook-level ``client`` (the same client every other helper uses)
    and reads the status straight from the listed batch objects, so no extra
    ``retrieve`` request is made per batch.

    Args:
        num_jobs: Maximum number of batches to report.

    Returns:
        The ids of the reported batches, in the order the API lists them.
    """
    batch_id_list: list[str] = []
    if num_jobs <= 0:
        return batch_id_list

    # `limit` is the page size (the API accepts 1-100); iterating the result
    # fetches further pages on demand, so we stop explicitly at `num_jobs`.
    batches = client.batches.list(limit=min(num_jobs, 100))

    for batch in batches:
        if len(batch_id_list) >= num_jobs:
            break
        print(f"id: {batch.id} - status: {batch.status}")
        batch_id_list.append(batch.id)

    return batch_id_list

# Print the status of the listed batches and keep their ids.
batch_ids: list[str] = get_batch_status()

#### Create batch file

In [None]:
from pathlib import Path

def create_batch(
    batch_file_path: Path,
    description: str
) -> None:
    """Upload a batch input file and create a chat-completions batch job for it.

    Uses the notebook-level ``client``.

    Args:
        batch_file_path: Path of the ``.jsonl`` batch input file to upload.
        description: Text stored under the ``description`` key of the batch metadata.
    """
    # Open the file in a context manager so the handle is closed once the
    # upload has finished (it was previously left open).
    with open(batch_file_path, "rb") as batch_file:
        batch_input_file = client.files.create(
            file=batch_file,
            purpose="batch"
        )

    client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
        metadata={
            "description": description
        }
    )

In [None]:
from openai import OpenAI

client = OpenAI(api_key=api_key)

# Batch input file to submit, located inside DATA_DIR.
in_file_name = "batch_file_samples_0-2499.jsonl"
batch_file_path = "/".join((DATA_DIR, in_file_name))

# Human-readable description stored in the batch metadata:
# first the study-example setup, then the input modality.
description = (
    "3 few-shot study examples, 1 input query, "
    if ONLY_FEW_SHOTS
    else "12 systematic study examples, 1 input query, "
) + ("with image data." if IMAGE_INPUT else "text only.")

create_batch(batch_file_path=batch_file_path, description=description)

#### Retrieve Data

In [None]:
import json
from openai import OpenAI

from vmlc.utils.utils import save_dicts_as_jsonl


def get_batch_status(num_jobs: int = 12) -> list[str]:
    """Print the id and status of up to ``num_jobs`` batches and return their ids.

    Uses the notebook-level ``client`` (the same client every other helper uses)
    and reads the status straight from the listed batch objects, so no extra
    ``retrieve`` request is made per batch.

    Args:
        num_jobs: Maximum number of batches to report.

    Returns:
        The ids of the reported batches, in the order the API lists them.
    """
    batch_id_list: list[str] = []
    if num_jobs <= 0:
        return batch_id_list

    # `limit` is the page size (the API accepts 1-100); iterating the result
    # fetches further pages on demand, so we stop explicitly at `num_jobs`.
    batches = client.batches.list(limit=min(num_jobs, 100))

    for batch in batches:
        if len(batch_id_list) >= num_jobs:
            break
        print(f"id: {batch.id} - status: {batch.status}")
        batch_id_list.append(batch.id)

    return batch_id_list


def read_file_content(
    client: OpenAI,
    file_id: str
) -> list[dict]:
    """Download a JSONL file through ``client`` and parse it into a list of dicts.

    Args:
        client: Client whose ``files.content`` call fetches the file.
        file_id: Id of the file to download.

    Returns:
        One parsed JSON object per non-empty line of the file.
    """
    raw_text = client.files.content(file_id).read().decode("utf-8")

    records: list[dict] = []
    for line in raw_text.strip().split("\n"):
        # Empty lines carry no record and are skipped.
        if line:
            records.append(json.loads(line))

    return records


def extract_relevant_content(
    in_content: list[dict],
    out_content: list[dict]
) -> list[dict]:
    """Combine each batch response with the request that produced it.

    Rows are paired by ``custom_id`` rather than by position, because the
    Batch API does not guarantee that output lines come back in input order.

    Args:
        in_content: Parsed rows of the batch input file.
        out_content: Parsed rows of the batch output file.

    Returns:
        One dict per output row (in output order) with the keys ``custom_id``,
        ``status_code``, ``error``, ``evaluator_input`` (the ``content`` of every
        input message) and ``evaluator_response`` (the ``content`` of the first
        choice, or ``None`` when the row carries no choices, e.g. a failed request).

    Raises:
        ValueError: If the two files differ in length, or an output row refers
            to a ``custom_id`` that is not present in the input file.
    """
    if len(in_content) != len(out_content):
        raise ValueError(
            "Length mismatch between input file and output file! "
            f"({len(in_content)} input rows vs. {len(out_content)} output rows)"
        )

    inputs_by_id = {in_row["custom_id"]: in_row for in_row in in_content}
    relevant_content: list[dict] = []

    for out_row in out_content:
        custom_id = out_row["custom_id"]
        if custom_id not in inputs_by_id:
            raise ValueError(
                f"Output row {custom_id!r} has no matching request in the input file!"
            )
        in_row = inputs_by_id[custom_id]

        # A failed request may have no response or no choices; keep the row
        # (with its status code and error) instead of crashing on the lookup.
        response = out_row.get("response") or {}
        choices = (response.get("body") or {}).get("choices") or []
        evaluator_response = choices[0]["message"]["content"] if choices else None

        relevant_content.append(
            {
                "custom_id": custom_id,
                "status_code": response.get("status_code"),
                "error": out_row.get("error"),
                "evaluator_input": [
                    message["content"] for message in in_row["body"]["messages"]
                ],
                "evaluator_response": evaluator_response,
            }
        )

    return relevant_content


def retrieve_batch_response(
    client: OpenAI,
    batch_id: str
) -> None:
    """Download and save the results of a batch, if the batch has completed.

    For a completed batch, the full output rows are written to
    ``{OUT_DIR}/openai_output/batch_id_<id>.jsonl`` and the condensed rows
    produced by ``extract_relevant_content`` to ``{OUT_DIR}/batch_id_<id>.jsonl``.
    For any other status only a message is printed.

    Args:
        client: Client used to look up the batch and download its files.
        batch_id: Id of the batch to fetch.
    """
    batch = client.batches.retrieve(batch_id)
    state = batch.status
    print(f"Batch ID: {batch_id} - Status: {state}")

    # Nothing to download until the batch has finished.
    if state != "completed":
        print(f"Batch {batch_id} not completed yet. Please try again later.")
        return

    print(f"Saving batch {batch_id}...")
    file_name = f"batch_id_{batch_id}.jsonl"

    in_content = read_file_content(client=client, file_id=batch.input_file_id)
    out_content = read_file_content(client=client, file_id=batch.output_file_id)

    # Keep the complete, unmodified output.
    save_dicts_as_jsonl(data=out_content, filepath=f"{OUT_DIR}/openai_output/{file_name}")

    # Keep the condensed input/response records as well.
    relevant_content = extract_relevant_content(in_content=in_content, out_content=out_content)
    save_dicts_as_jsonl(data=relevant_content, filepath=f"{OUT_DIR}/{file_name}")

In [None]:
# Ids of up to three listed batches; these are the ones downloaded in the next cell.
last_batch_ids: list[str] = get_batch_status(num_jobs=3)
last_batch_ids

In [None]:
client = OpenAI(api_key=api_key)

# Fetch every listed batch; a failure on one batch is printed and the loop moves on to the next.
for batch_id in last_batch_ids:
    try:
        retrieve_batch_response(client=client, batch_id=batch_id)
    except Exception as err:
        print(f"An error occurred: {err}")
