## Augmenting the ARC Training Data

### Dataset preparation

In [2]:
import os
import json

def read_json_files(directory):
    """Load every ``.json`` file in ``directory`` into a list.

    Files are visited in sorted filename order so the resulting list (and
    everything derived from it downstream) is reproducible; ``os.listdir``
    on its own returns entries in arbitrary, filesystem-dependent order.

    Args:
        directory: Path of the folder to scan (sub-folders are not walked).

    Returns:
        A list with one parsed JSON value per successfully decoded file.
        Files that are not valid JSON are reported on stdout and skipped.
    """
    json_data = []

    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.json'):
            continue
        file_path = os.path.join(directory, filename)

        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(file_path, 'r', encoding='utf-8') as file:
            try:
                json_data.append(json.load(file))
            except json.JSONDecodeError as e:
                # Best-effort load: report the bad file and keep going.
                print(f"Error reading {filename}: {e}")

    return json_data

# ARC training split: each `.json` file in this folder becomes one entry of all_json_data.
directory = 'arc_data/training'
all_json_data = read_json_files(directory=directory)

In [3]:
def transform_input(data, grid_transform=None):
    """Render an ARC task's train and test pairs as prompt text.

    Every pair becomes an ``###Input:`` section followed by an ``###Output:``
    section. A grid is written one row per line with cells separated by a
    single space; rows that contain no cells are skipped.

    Args:
        data: Task dict whose ``'train'`` and ``'test'`` entries are lists of
            pairs, each with an ``'input'`` grid and normally an ``'output'``
            grid (lists of rows of cell values).
        grid_transform: Optional callable applied to every input and output
            grid before it is rendered (e.g. a flip or a rotation). The
            default ``None`` renders the grids unchanged.

    Returns:
        ``{'train': str, 'test': str}`` holding the concatenated pair text of
        each split. A pair without an ``'output'`` key (e.g. an unlabelled
        test pair) gets an empty output section, so its text ends right after
        the ``###Output:`` marker instead of raising ``KeyError``.
    """
    def render(grid):
        # Optionally re-orient the grid, then emit one "a b c\n" line per non-empty row.
        if grid_transform is not None:
            grid = grid_transform(grid)
        return ''.join(' '.join(map(str, row)) + '\n' for row in grid if len(row) > 0)

    transformed_data = {}
    for key in ['train', 'test']:
        transformed_data[key] = ''.join(
            '\n###Input:\n' + render(case['input'])
            + '\n###Output:\n' + (render(case['output']) if 'output' in case else '')
            for case in data[key]
        )

    return transformed_data

In [4]:
def extract_after_output(text):
    """Return the part of ``text`` after its last ``###Output:`` marker.

    Prompt text holds one marker per example pair, so the *last* marker is
    used: searching from the left would return everything after the first
    output, including any later ``###Input:`` sections.

    Args:
        text: Rendered prompt text.

    Returns:
        The text following the final ``'###Output:\\n'``, or ``text``
        unchanged when the marker does not occur.
    """
    marker = '###Output:\n'
    index = text.rfind(marker)
    if index == -1:
        return text
    return text[index + len(marker):]


def extract_before_output(text):
    """Return the part of ``text`` before its last ``###Output:`` marker.

    Mirrors ``extract_after_output``: cutting at the last marker keeps every
    earlier pair (and the final input) intact instead of truncating at the
    first example's output.

    Args:
        text: Rendered prompt text.

    Returns:
        The text preceding the final ``'###Output:\\n'``, or ``text``
        unchanged when the marker does not occur.
    """
    index = text.rfind('###Output:\n')
    if index == -1:
        return text
    return text[:index]

In [5]:
# Instruction text placed at the start of every training record.
DEFAULT_PROMPT = "We are playing a game which involves transforming a 2D input grid of digits into an output grid of digits. Every below pair of grids contains the same transformation. Each Input grid is followed by an Output grid which applies the same transformation as previous Input/Output pairs. Given the provided examples, output the correct grid for the last input."

def generate_train_prompt(data_point):
    """Turn a rendered task into a ``{'text', 'labels'}`` training record.

    Args:
        data_point: Dict with ``'train'`` and ``'test'`` prompt strings, such
            as those returned by the ``transform_input*`` functions.

    Returns:
        A dict where ``'text'`` is ``DEFAULT_PROMPT``, the train text and the
        test text joined by newlines, and ``'labels'`` is what
        ``extract_after_output`` yields for the test text, whitespace-stripped.
    """
    train_text, test_text = data_point['train'], data_point['test']
    label = extract_after_output(test_text).strip()
    prompt = '{}\n{}\n{}'.format(DEFAULT_PROMPT, train_text, test_text)
    return {'text': prompt, 'labels': label}

In [6]:
def flip_2d_list(matrix, flip_type):
    """Mirror a 2D list.

    Args:
        matrix: Sequence of rows.
        flip_type: ``'horizontal'`` reverses the cells inside every row
            (left-right mirror); ``'vertical'`` reverses the order of the
            rows (top-bottom mirror).

    Returns:
        For ``'horizontal'``, a new list of reversed row copies; for
        ``'vertical'``, ``matrix`` sliced in reverse (the row objects
        themselves are shared, not copied).

    Raises:
        ValueError: If ``flip_type`` is neither of the two options.
    """
    if flip_type == 'vertical':
        return matrix[::-1]
    if flip_type != 'horizontal':
        raise ValueError("Invalid flip type. Use 'horizontal' or 'vertical'.")
    return [mirrored[::-1] for mirrored in matrix]

def rotate_matrix_90_degrees(matrix):
    """Return ``matrix`` rotated 90 degrees clockwise as a new list of lists.

    The rows are read bottom-to-top and then transposed, so the original
    bottom row becomes the first column. Rows are expected to share one
    length; ``zip`` silently truncates ragged rows to the shortest.
    """
    bottom_up = matrix[::-1]
    return list(map(list, zip(*bottom_up)))

def rotate_matrix_270_degrees(matrix):
    """Return ``matrix`` rotated 90 degrees counter-clockwise (270 clockwise).

    The matrix is transposed and the resulting rows are put in reverse
    order, so the original last column becomes the first row. Rows are
    expected to share one length; ``zip`` truncates ragged rows.
    """
    transposed = list(map(list, zip(*matrix)))
    transposed.reverse()
    return transposed


In [7]:
def transform_input_horizontal(data):
    """Render an ARC task as prompt text with every grid mirrored left-to-right.

    The input and output grid of each train and test pair go through
    ``flip_2d_list(grid, 'horizontal')`` and are then written one line per
    row, cells separated by single spaces (rows with no cells emit nothing).

    Returns:
        ``{'train': str, 'test': str}``: the ``###Input:`` / ``###Output:``
        sections of every pair in that split, concatenated.
    """
    def grid_text(grid):
        return ''.join(
            ' '.join(str(cell) for cell in row) + '\n'
            for row in grid
            if len(row) > 0
        )

    rendered = {}
    for split in ('train', 'test'):
        sections = []
        for pair in data[split]:
            shown_in = flip_2d_list(pair['input'], 'horizontal')
            shown_out = flip_2d_list(pair['output'], 'horizontal')
            sections.append(
                '\n###Input:\n' + grid_text(shown_in)
                + '\n###Output:\n' + grid_text(shown_out)
            )
        rendered[split] = ''.join(sections)
    return rendered


In [8]:
def transform_input_vertical(data):
    """Render an ARC task as prompt text with every grid mirrored top-to-bottom.

    The input and output grid of each train and test pair go through
    ``flip_2d_list(grid, 'vertical')`` and are then written one line per
    row, cells separated by single spaces (rows with no cells emit nothing).

    Returns:
        ``{'train': str, 'test': str}``: the ``###Input:`` / ``###Output:``
        sections of every pair in that split, concatenated.
    """
    def grid_text(grid):
        return ''.join(
            ' '.join(str(cell) for cell in row) + '\n'
            for row in grid
            if len(row) > 0
        )

    rendered = {}
    for split in ('train', 'test'):
        sections = []
        for pair in data[split]:
            shown_in = flip_2d_list(pair['input'], 'vertical')
            shown_out = flip_2d_list(pair['output'], 'vertical')
            sections.append(
                '\n###Input:\n' + grid_text(shown_in)
                + '\n###Output:\n' + grid_text(shown_out)
            )
        rendered[split] = ''.join(sections)
    return rendered


In [9]:
def transform_input_270(data):
    """Render an ARC task as prompt text with every grid rotated 270 degrees.

    The input and output grid of each train and test pair go through
    ``rotate_matrix_270_degrees`` and are then written one line per row,
    cells separated by single spaces (rows with no cells emit nothing).

    Returns:
        ``{'train': str, 'test': str}``: the ``###Input:`` / ``###Output:``
        sections of every pair in that split, concatenated.
    """
    def grid_text(grid):
        return ''.join(
            ' '.join(str(cell) for cell in row) + '\n'
            for row in grid
            if len(row) > 0
        )

    rendered = {}
    for split in ('train', 'test'):
        sections = []
        for pair in data[split]:
            shown_in = rotate_matrix_270_degrees(pair['input'])
            shown_out = rotate_matrix_270_degrees(pair['output'])
            sections.append(
                '\n###Input:\n' + grid_text(shown_in)
                + '\n###Output:\n' + grid_text(shown_out)
            )
        rendered[split] = ''.join(sections)
    return rendered

In [10]:
def transform_input_90(data):
    """Render an ARC task as prompt text with every grid rotated 90 degrees.

    The input and output grid of each train and test pair go through
    ``rotate_matrix_90_degrees`` and are then written one line per row,
    cells separated by single spaces (rows with no cells emit nothing).

    Returns:
        ``{'train': str, 'test': str}``: the ``###Input:`` / ``###Output:``
        sections of every pair in that split, concatenated.
    """
    def grid_text(grid):
        return ''.join(
            ' '.join(str(cell) for cell in row) + '\n'
            for row in grid
            if len(row) > 0
        )

    rendered = {}
    for split in ('train', 'test'):
        sections = []
        for pair in data[split]:
            shown_in = rotate_matrix_90_degrees(pair['input'])
            shown_out = rotate_matrix_90_degrees(pair['output'])
            sections.append(
                '\n###Input:\n' + grid_text(shown_in)
                + '\n###Output:\n' + grid_text(shown_out)
            )
        rendered[split] = ''.join(sections)
    return rendered


In [13]:
from itertools import permutations

In [34]:
# Upper bound on how many orderings of each task's train pairs are kept.
# None keeps all n! orderings (the setting behind the count printed below);
# an int trades coverage for a dataset that fits in memory. Orderings come
# from itertools.permutations, so a cap keeps the first ones in its
# lexicographic-by-position order.
MAX_TRAIN_PERMUTATIONS = None

# Grid views rendered for every ordering, in the order records are appended.
AUGMENTATIONS = (
    transform_input,
    transform_input_90,
    transform_input_270,
    transform_input_vertical,
    transform_input_horizontal,
)

# NOTE: the (misspelled) name is kept because later cells read it.
augemented_data = []

for task in all_json_data:
    # Iterate lazily: materialising every permutation up front costs n! tuples per task.
    for order_index, train_order in enumerate(permutations(task['train'])):
        if MAX_TRAIN_PERMUTATIONS is not None and order_index >= MAX_TRAIN_PERMUTATIONS:
            break
        # Shallow copy instead of overwriting task['train'], so all_json_data
        # stays untouched and this cell gives the same result when re-run.
        permuted_task = {**task, 'train': list(train_order)}
        for augment in AUGMENTATIONS:
            augemented_data.append(generate_train_prompt(augment(permuted_task)))

In [35]:
# Number of prompt/label records produced by the augmentation loop; it is
# also embedded in the saved file's name.
num_tasks = len(augemented_data)

In [36]:
# Report the size of the augmented dataset.
print('Number of tasks in the new augmented dataset: {}'.format(len(augemented_data)))

Number of tasks in the new augmented dataset: 18668610


### Save the augmented dataset

In [37]:
# The record count is embedded in the output file name.
output_path = f"arc_augmented_{num_tasks}_tasks.json"
with open(output_path, "w") as file:
    json.dump(augemented_data, file, indent=4)