# Create OpenAI Image Batch Files for Few-Shot Abstract Visual Reasoning Task
This script creates the prompts for the few-shot experiment. We use `GPT-4o` via the OpenAI API to evaluate the model on the abstract visual reasoning task. We create batch files that each contain a chunk of the test set, including images. The input is the same as the one given to the meta-learning model, but it is additionally rendered as an image and accompanied by a prompt that describes the task. The expected output is the predicted output grid. Each task in this experiment comes with three few-shot examples.

## Create Image Batch File
We use OpenAI's Batch API to process the requests asynchronously and reduce API costs. For this, we first need to create batch files that contain all the prompts we want to evaluate.
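A batch file is a JSONL file in which every line is one self-contained request. The sketch below serializes a single line with the same fields that `prepare_batch_files` writes later in this notebook. `make_batch_request` is a hypothetical helper, and the message content is a placeholder; the real requests also carry the image part.

```python
import json
from typing import Any, Dict, List


def make_batch_request(custom_id: str, model: str, messages: List[Dict[str, Any]],
                       max_tokens: int = 1000) -> str:
    """Serialize one Batch API request as a single JSONL line (hypothetical helper)."""
    request = {
        "custom_id": custom_id,          # identifies the request; must be unique within a batch
        "method": "POST",
        "url": "/v1/chat/completions",   # endpoint the batch targets
        "body": {"model": model, "messages": messages, "max_tokens": max_tokens},
    }
    # json.dumps without indent keeps the whole request on one line, as JSONL requires
    return json.dumps(request)


line = make_batch_request(
    custom_id="test_sample_0",
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "example prompt"}],
)
print(line)
```

Writing one request per line is what lets the API match each response back to its sample through `custom_id`.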

In [None]:
MODEL = "gpt-4o-2024-08-06"

SEED = 1860
DATA_DIR = f"data/split_seed_{SEED}_only_few_shots"
FILE_NAME = f"systematicity_seed_{SEED}"

In [None]:
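import os

OUT_DIR = f"{MODEL}/image_batch_files/split_seed_{SEED}_only_few_shots"
os.makedirs(OUT_DIR, exist_ok=True)  # make sure the output directory exists before writing batch files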

### Model Prompt

In [None]:
user_prompt = """### Task Description:
You must solve an abstract visual reasoning task by identifying geometric transformations (e.g., rotation, translation, color changes, etc.) applied to objects within a 10x10 grid.

To infer the correct geometric transformation, you are given a series of **3 pairs of input-output examples**. Each example pair consists of:
- An **input grid**: a 10x10 list of lists (2d array), where each element is an integer (0-9).
- A corresponding **output grid**: a 10x10 list of lists (2d array) that has undergone a transformation based on a specific geometric rule.

For the prediction you need to understand the transformations displayed in the provided examples and apply them to the final input grid.

#### Your Task:
1. **Analyze** the example pairs to infer the transformation rules applied to each input grid.
2. **Identify** how these transformations are applied to generate the output grids.
3. **Apply** the deduced transformations to the final input grid.
4. **Output** the correctly transformed 10x10 grid.

### Output Requirements:
- **Return only the final output grid.**
- Do not include any extra text, explanations, or comments.
- Do not generate any code to solve the task.
- The output must be formatted exactly as:
 `output: [[...]]`
- The output grid must be a 10x10 list of lists containing only integers between 0 and 9 (inclusive).
- Do not include unnecessary line breaks or additional text beyond the specified format.

### Input Format:
You will receive the following data:
1. **Study examples:** A list of 3 few-shot example pairs, formatted as:
  `example input 1: [[...]], example output 1: [[...]], ..., example input 3: [[...]], example output 3: [[...]]`
2. **Final input:** A single 10x10 list of lists on which you must apply the inferred transformation(s).
3. **Image input:** Additionally, you receive an image that visualizes the 3 few-shot example pairs and the final input query.

Your goal is to determine the correct transformation and return the final output grid.

### Input:
"""
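The prompt asks the model to answer strictly as `output: [[...]]`. When evaluating the returned batch, a parser along the following lines could turn a reply back into a grid and reject replies that violate the stated format. This is a sketch: `parse_output_grid` is a hypothetical helper, not part of the `vmlc` package, and it assumes the reply contains a single grid.

```python
import ast
import re
from typing import List, Optional


def parse_output_grid(response_text: str, size: int = 10) -> Optional[List[List[int]]]:
    """Extract the grid from a reply of the form 'output: [[...]]' (hypothetical helper).

    Returns None if the reply does not follow the format requested in the prompt.
    """
    match = re.search(r"output:\s*(\[\[.*\]\])", response_text, flags=re.DOTALL)
    if match is None:
        return None
    try:
        # literal_eval only accepts Python literals, so no model-generated code is executed
        grid = ast.literal_eval(match.group(1))
    except (ValueError, SyntaxError, TypeError):
        return None
    # Enforce the 10x10 shape and the integers-in-0..9 constraint stated in the prompt
    if not isinstance(grid, list) or len(grid) != size:
        return None
    for row in grid:
        if not isinstance(row, list) or len(row) != size:
            return None
        if any(not isinstance(v, int) or not 0 <= v <= 9 for v in row):
            return None
    return grid


reply = "output: " + str([[0] * 10 for _ in range(10)])
grid = parse_output_grid(reply)
```

Returning `None` instead of raising makes it easy to count malformed replies separately from wrong predictions.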

### Get Data

In [None]:
from vmlc.utils.utils import load_jsonl

test_data = load_jsonl(
    file_path=f"{DATA_DIR}/test_{FILE_NAME}.jsonl"
)
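`load_jsonl` comes from the project's own `vmlc` package and its implementation is not shown here. Assuming it reads one JSON object per line, a stdlib equivalent would look roughly like this (`load_jsonl_sketch` is a hypothetical name). The toy episode mirrors how the code below indexes each test sample: `study_examples` is a list of `[input, output]` pairs and the query input grid sits at `queries[0][0]`; that structure is inferred from the indexing, not from a schema.

```python
import json
import os
import tempfile
from typing import Any, Dict, List


def load_jsonl_sketch(file_path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list of dicts, skipping blank lines (hypothetical helper)."""
    with open(file_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Toy episode with 2x2 grids; real episodes use 10x10 grids and 3 study examples
toy_episode = {
    "study_examples": [[[[0, 1], [0, 0]], [[1, 0], [0, 0]]]],  # one [input, output] pair
    "queries": [[[[0, 0], [1, 0]], [[0, 0], [0, 1]]]],          # queries[0][0] is the query input
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(toy_episode) + "\n")
    episodes = load_jsonl_sketch(path)
```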

### Image Data

In [None]:
import io
import base64
from matplotlib.figure import Figure

def encode_image(image_path: str) -> str:
    """
    Encodes an image file to a base64 string.

    This function reads the image from the specified path in binary mode,
    encodes the image data to base64, and returns the resulting string.

    Args:
        image_path (str): The file path to the image to be encoded.

    Returns:
        str: The base64-encoded string representation of the image.
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
        
def decode_image(base64_string: str, output_image_path: str) -> None:
    """
    Decodes a base64 string into an image file.

    This function decodes the provided base64-encoded string back into binary data
    and writes the resulting data to the specified output file path.

    Args:
        base64_string (str): The base64-encoded string of the image.
        output_image_path (str): The file path where the decoded image will be saved.
    """
    image_data = base64.b64decode(base64_string)
    with open(output_image_path, "wb") as output_file:
        output_file.write(image_data)

def encode_figure(figure: Figure, fmt: str = "png") -> str:
    """
    Encodes a matplotlib figure as a base64 string without writing it to disk.
    
    The function saves the provided figure into an in-memory buffer in the specified format,
    encodes the resulting image data into a base64 string, and returns that string.
    
    Args:
        figure (Figure): The matplotlib figure to encode.
        fmt (str, optional): The file format for the encoded image (e.g., 'png', 'jpeg').
                             Defaults to "png".
    
    Returns:
        str: A base64-encoded string representing the image of the figure.
    """
    # Create an in-memory bytes buffer
    buffer = io.BytesIO()
    
    # Save the figure to the buffer
    figure.savefig(buffer, format=fmt, bbox_inches="tight")
    
    # Ensure the buffer's pointer is at the beginning of the stream
    buffer.seek(0)
    
    # Read the buffer's content and encode it to base64
    encoded_image = base64.b64encode(buffer.read()).decode("utf-8")
    
    # Clean up the buffer
    buffer.close()
    
    return encoded_image
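The base64 string is later embedded in a `data:` URL, and the MIME type in that URL should describe the bytes that were actually encoded. Since `encode_figure` defaults to PNG, a `data:image/png;base64,...` prefix is the consistent choice. A minimal sketch that derives the MIME type from the file signature instead of hard-coding it; `to_data_url` is a hypothetical helper that only recognizes PNG and JPEG.

```python
import base64

# File signatures ("magic bytes") of the two formats handled here
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
JPEG_SIGNATURE = b"\xff\xd8\xff"


def to_data_url(encoded_image: str) -> str:
    """Wrap a base64 image string in a data URL with a matching MIME type (hypothetical helper)."""
    # The signature lives in the first bytes of the decoded file
    header = base64.b64decode(encoded_image)[:8]
    if header.startswith(PNG_SIGNATURE):
        mime = "image/png"
    elif header.startswith(JPEG_SIGNATURE):
        mime = "image/jpeg"
    else:
        raise ValueError("Unrecognized image format")
    return f"data:{mime};base64,{encoded_image}"


fake_png = base64.b64encode(PNG_SIGNATURE + b"rest-of-file").decode("utf-8")
print(to_data_url(fake_png)[:30])
```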

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors

from typing import Any, Dict, List, Tuple

def get_custom_cmap() -> mcolors.ListedColormap:
    """
    Returns a ListedColormap using a specified color order.

    Returns:
        mcolors.ListedColormap: The custom colormap.
    """
    colors = [
        "black",   # Background (value 0)
        "red",     # value 1
        "orange",  # value 2
        "yellow",  # value 3
        "green",   # value 4
        "blue",    # value 5
        "purple",  # value 6
        "pink",    # value 7
        "cyan",    # value 8
        "grey",    # value 9
        "white"    # Extra if needed
    ]
    return mcolors.ListedColormap(colors)

def plot_grid(ax: plt.Axes, grid: Any, title: str, cmap: mcolors.Colormap,
              vmin: int = 0, vmax: int = 10) -> None:
    """
    Plots a grid on the given axes with a title.

    Args:
        ax (plt.Axes): Axes on which to plot.
        grid (Any): Grid data (convertible to a NumPy array).
        title (str): Title for the subplot.
        cmap (mcolors.Colormap): Colormap to use.
        vmin (int, optional): Minimum data value. Defaults to 0.
        vmax (int, optional): Maximum data value. Defaults to 10.
    """
    grid_array = np.array(grid)
    rows, cols = grid_array.shape
    ax.imshow(grid_array, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title, fontsize=9)
    ax.set_xticks(np.arange(cols + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(rows + 1) - 0.5, minor=True)
    ax.grid(True, which="minor", color="white", linewidth=0.5)
    ax.set_xticks([])
    ax.set_yticks([])

def add_arrow(ax: plt.Axes) -> None:
    """
    Adds an arrow annotation to the axes and hides the axis.

    Args:
        ax (plt.Axes): Axes on which to add the arrow.
    """
    ax.annotate(
        "",
        xy=(1.4, 0.5),
        xytext=(-0.8, 0.5),
        arrowprops=dict(arrowstyle="->", lw=1),
        xycoords="axes fraction",
        textcoords="axes fraction"
    )
    ax.axis("off")

def extract_grid_pairs(study_examples: List[Any]) -> List[Tuple[Any, Any]]:
    """
    Extracts input/output grid pairs from the nested 'study_examples' structure.
    
    Assumes each study example is an [input_grid, output_grid] pair, i.e.
    example[0] is the input grid and example[1] is the output grid.

    Args:
        study_examples (List[Any]): The nested study_examples data.

    Returns:
        List[Tuple[Any, Any]]: A list of (input_grid, output_grid) pairs.
    """
    pairs = []
    for input_output_pair in study_examples:
        pairs.append((input_output_pair[0], input_output_pair[1]))
    return pairs

def extract_query_input(queries: List[Any]) -> Any:
    """
    Extracts the input grid for the query from the nested 'queries' structure.
    
    Assumes queries are nested as: queries[0][0] is the query input grid.

    Args:
        queries (List[Any]): The nested queries data.

    Returns:
        Any: The query input grid.

    Raises:
        ValueError: If the query grid is not found.
    """
    try:
        return queries[0][0]
    except (IndexError, TypeError):
        raise ValueError("Query input grid not found in the provided data.")

def plot_study_examples(test_episode: Dict[str, Any], verbose: int = 0) -> str:
    """
    Plots a series of input/output grid pairs from study_examples and a query input
    from queries, then returns the plot as a base64-encoded image.

    Args:
        test_episode (Dict[str, Any]): Data containing keys 'study_examples' and 'queries'.
        verbose (int, optional): If > 0, print the image size and display the figure. Defaults to 0.

    Returns:
        str: Base64-encoded image string of the resulting plot.
    """
    cmap = get_custom_cmap()

    # Extract grid pairs from the study_examples field
    study_examples = test_episode.get("study_examples", [])
    grids = extract_grid_pairs(study_examples)
    num_examples = len(grids)

    # Extract the query input grid from queries
    queries = test_episode.get("queries", [])
    query_input = extract_query_input(queries)

    # Create figure: one row per grid pair plus one extra row for the query; 3 columns layout
    fig, axes = plt.subplots(num_examples + 1, 3,
                             figsize=(7.68, 2 * (num_examples + 2)),
                             gridspec_kw={'width_ratios': [1, 0.1, 1]})
    fig.suptitle("Study Examples", fontsize=12, fontweight="bold", x=0.02, ha="left", y=0.82)
    plt.subplots_adjust(top=0.78, hspace=0.3)

    # Plot each study example (input/output pair)
    for i in range(num_examples):
        plot_grid(axes[i, 0], grids[i][0], f"Example input {i+1}:", cmap)
        add_arrow(axes[i, 1])
        plot_grid(axes[i, 2], grids[i][1], f"Example output {i+1}:", cmap)

    # Plot query input on the last row (only the input grid)
    query_ax = axes[num_examples, 0]
    pos = query_ax.get_position()
    new_y0 = pos.y0 - 0.01
    new_y1 = pos.y1 - 0.01
    query_ax.set_position([pos.x0, new_y0, pos.width, new_y1 - new_y0])
    # Figure-level label in figure coordinates, independent of the current axes
    fig.text(0, new_y1 + 0.01, "Query", fontsize=12, fontweight="bold", ha="left")
    plot_grid(query_ax, query_input, "Final input:", cmap)
    add_arrow(axes[num_examples, 1])
    axes[num_examples, 2].axis("off")

    # Optional: print image size info
    fig_width, fig_height = fig.get_size_inches()
    dpi = fig.dpi

    if verbose > 0:
        print(f"Image size: {int(fig_width * dpi)}x{int(fig_height * dpi)} pixels")
        plt.show()
        
    encoded_image = encode_figure(fig)
    plt.close(fig)
    return encoded_image
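`plot_grid` normalizes cell values with `vmin=0` and `vmax=10`, and the custom colormap has 11 entries. The sketch below reproduces that lookup in plain Python to check that each value from 0 to 9 lands on its own color. It assumes our reading of matplotlib's float lookup in `Colormap.__call__`: the normalized value is multiplied by the number of colors and truncated to an integer, and a value of exactly 1.0 maps to the last color. `color_for_value` is a hypothetical helper.

```python
COLORS = ["black", "red", "orange", "yellow", "green",
          "blue", "purple", "pink", "cyan", "grey", "white"]


def color_for_value(value: int, vmin: int = 0, vmax: int = 10) -> str:
    """Approximate matplotlib's Normalize + ListedColormap lookup for one value (sketch)."""
    n = len(COLORS)
    normalized = (value - vmin) / (vmax - vmin)   # linear normalization to [0, 1]
    index = int(normalized * n)                   # scale by the number of colors, then truncate
    index = min(max(index, 0), n - 1)             # 1.0 (value 10) maps to the last color
    return COLORS[index]


mapping = {v: color_for_value(v) for v in range(10)}
print(mapping)
```

For example, value 7 gives `0.7 * 11 = 7.7`, which truncates to index 7 (`"pink"`), so the 11th color (`"white"`) is only reached by the value 10, which never occurs in the grids.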

In [None]:
prompt_example = plot_study_examples(test_data[1], verbose=1)

### Script

In [None]:
from typing import Any, List, Dict, Optional

from vmlc.utils.utils import save_dicts_as_jsonl

def prepare_study_examples(study_examples: List[List[List[List[int]]]]) -> str:
    """Serializes the study examples into the text format described in the prompt."""
    study_example_str = ""

    for idx, input_output_pair in enumerate(study_examples):
        assert len(input_output_pair) == 2, f"Invalid number of input and output grids! {len(input_output_pair)}"
        input_grid = f"\nexample input {idx + 1}: {input_output_pair[0]}"
        output_grid = f"\nexample output {idx + 1}: {input_output_pair[1]}"

        study_example_str += input_grid + output_grid
    
    return study_example_str


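def prepare_batch_files(
    test_data: List[Dict[str, Any]],
    user_prompt: str,
    num_samples_per_batch_file: int,
    model: str,
    out_dir: str,
    few_shot_examples: Optional[List[Dict[str, Any]]] = None
) -> None:
    """
    Writes the test set as OpenAI batch files (JSONL), one file per chunk of
    `num_samples_per_batch_file` samples.

    Each request contains the text prompt with the serialized study examples and
    the final input, plus the rendered image of the episode. If `few_shot_examples`
    (a list of chat messages) is given, it is prepended to every request.
    """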

    curr_idx = 0

    while curr_idx < len(test_data):
        batch_file_content: List[Dict[str, Any]] = []
        curr_samples = test_data[curr_idx:curr_idx + num_samples_per_batch_file]

        for sample_num, sample in enumerate(curr_samples):
            print(f"sample: {curr_idx+sample_num}")
            batch_user_messages: List[Dict[str, str]] = []

            if few_shot_examples is not None:
                batch_user_messages += few_shot_examples

            study_example_str = prepare_study_examples(sample['study_examples'])
            input_grid_str = sample['queries'][0][0]
            
            batch_user_messages += [
                {
                    "role": "user",
                    "content": [
                        { "type": "text", "text": user_prompt + f"Study examples:{study_example_str}\n\n" + f"Final input:\n{input_grid_str}"},
                        {
                            "type": "image_url",
                            "image_url": {
                                # encode_figure produces PNG by default, so the MIME type must match
                                "url": f"data:image/png;base64,{plot_study_examples(sample)}",
                                "detail": "high"
                            },
                        },
                    ]
                }
            ]

            batch_file_content.append(
                {
                    "custom_id": f"test_sample_{curr_idx+sample_num}",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": model,
                        "messages": batch_user_messages,
                        "max_tokens": 1000
                    }
                }
            )
        
        # Name each file after the first and last sample index it contains
        last_idx = min(curr_idx + num_samples_per_batch_file - 1, len(test_data) - 1)
        prefix = "batch_file_samples" if few_shot_examples is None else "batch_file_few_shots_samples"
        file_name = f"{prefix}_{curr_idx}-{last_idx}.jsonl"
        
        save_dicts_as_jsonl(
            data=batch_file_content,
            filepath=f"{out_dir}/{file_name}"
        )

        curr_idx += num_samples_per_batch_file
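`prepare_batch_files` slices the test set into consecutive chunks of `num_samples_per_batch_file` samples and names each file after the first and last sample index it contains, truncating the last chunk at the end of the data. The same arithmetic in isolation, using a made-up test-set size of 6,000 samples (not the actual size of this split); `chunk_ranges` is a hypothetical helper.

```python
from typing import List, Tuple


def chunk_ranges(num_samples: int, chunk_size: int) -> List[Tuple[int, int]]:
    """Inclusive (first, last) sample indices per batch file, mirroring prepare_batch_files."""
    ranges = []
    for start in range(0, num_samples, chunk_size):
        last = min(start + chunk_size - 1, num_samples - 1)  # the last chunk may be shorter
        ranges.append((start, last))
    return ranges


for first, last in chunk_ranges(num_samples=6000, chunk_size=2500):
    print(f"batch_file_samples_{first}-{last}.jsonl")
```

With these numbers, the loop yields three files covering samples 0-2499, 2500-4999 and 5000-5999.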

In [None]:
prepare_batch_files(
    test_data=test_data,
    user_prompt=user_prompt,
    num_samples_per_batch_file=2500,
    model=MODEL,
    out_dir=OUT_DIR,
    few_shot_examples=None
)
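Before uploading, it can be worth sanity-checking the generated files. A minimal sketch, assuming the files are plain JSONL with one request per line as written above: `check_batch_file` is a hypothetical helper that verifies the required keys and `custom_id` uniqueness and reports the file size. The Batch API enforces its own limits on request count and file size, so compare the reported numbers against the current OpenAI documentation.

```python
import json
import os
import tempfile
from typing import Any, Dict

REQUIRED_KEYS = {"custom_id", "method", "url", "body"}


def check_batch_file(path: str) -> Dict[str, Any]:
    """Validate a batch JSONL file and return simple statistics (hypothetical helper)."""
    custom_ids = set()
    num_requests = 0
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            request = json.loads(line)
            missing = REQUIRED_KEYS - set(request)
            if missing:
                raise ValueError(f"line {line_no}: missing keys {sorted(missing)}")
            if request["custom_id"] in custom_ids:
                raise ValueError(f"line {line_no}: duplicate custom_id {request['custom_id']}")
            custom_ids.add(request["custom_id"])
            num_requests += 1
    return {"num_requests": num_requests, "size_mb": os.path.getsize(path) / 1e6}


# Usage on a tiny synthetic file; in the notebook, point it at the files in OUT_DIR instead
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        for i in range(3):
            f.write(json.dumps({"custom_id": f"test_sample_{i}", "method": "POST",
                                "url": "/v1/chat/completions", "body": {}}) + "\n")
    stats = check_batch_file(demo_path)
```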