In [None]:
import json
import os

import numpy as np
#np.random.seed(42)

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

# Which entry of `objects` to generate puzzles for. It must be a key of
# `objects`; every name listed for it is loaded from assets/<name>.png.
task_name = "plants"

# Object names available per task. The generator samples at least 4 distinct
# objects per puzzle, so each task needs at least 4 names. To add a task,
# put its images under assets/ and add an entry here, e.g.
#     "vehicles": ["car", "truck", "motorcycle", "jeep", "tractor"]
objects = {"plants": ["apple", "pineapple", "banana", "avocado", "ginger", "kiwi", "pear", "watermelon"]}
if task_name not in objects:
    raise KeyError(f"Unknown task {task_name!r}; expected one of {sorted(objects)}")
all_objects = objects[task_name]

background_colors = ["darkkhaki", "navajowhite", "gold"]  # one is picked per puzzle
lines = []  # serialized JSON records, one per generated puzzle
file_num = 0  # index used in the next puzzle's id and image filename
def _sample_slots(n_slot_cols, n_slot_rows, count):
    """Return `count` distinct (col, row) slots drawn uniformly at random.

    Slots are the 2x2 cells inside one half of the board, with
    0 <= col < n_slot_cols and 0 <= row < n_slot_rows. Sampling without
    replacement means positions can never collide, so no retry loop is
    needed. numpy raises ValueError if `count` exceeds the number of slots.
    """
    flat = np.random.choice(n_slot_cols * n_slot_rows, count, replace=False)
    return [divmod(int(k), n_slot_rows) for k in flat]


def _add_square(ax, x, y, size, facecolor, edgecolor):
    """Add a size x size square with lower-left corner (x, y) to `ax`."""
    ax.add_patch(patches.Rectangle((x, y), size, size, linewidth=1,
                                   edgecolor=edgecolor, facecolor=facecolor))


def _add_slot_tiles(ax, x, y, facecolor, edgecolor):
    """Tile the 2x2 slot whose lower-left corner is (x, y) with four 1x1 squares."""
    for i in (0, 1):
        for j in (0, 1):
            _add_square(ax, x + i, y + j, 1, facecolor, edgecolor)


border_color = "black"
thick_linewidth = 8
os.makedirs("puzzles", exist_ok=True)  # savefig/open below fail if it is missing

for _ in range(5):
    # Pick at least 4 *distinct* objects; duplicates would collapse in the
    # position dict and skew the space check and the correct/incorrect split.
    n_objects = int(np.random.randint(4, len(all_objects) + 1))
    chosen = [str(o) for o in np.random.choice(all_objects, n_objects, replace=False)]

    # Board size: rows in {4, 6, 8}; cols a multiple of 4 so each half has an
    # even number of interior columns.
    rows = 2 * int(np.random.randint(2, 5))
    min_cols = int(4 * n_objects / (rows - 2) + 2) + 1
    cols = 4 * int(np.random.randint(min_cols // 2, min_cols // 2 + 4))

    # Inside its one-cell frame, each half holds a grid of 2x2 slots.
    n_slot_cols = cols // 4 - 1  # (cols/2 - 2) interior columns / 2
    n_slot_rows = rows // 2 - 1  # (rows - 2) interior rows / 2
    spaces = n_slot_cols * n_slot_rows

    # Between 1 and n_objects - 1 distinct objects are "correct", so there is
    # always at least one correct and at least one incorrect answer.
    n_correct = int(np.random.randint(1, n_objects))
    correct = set(str(o) for o in np.random.choice(chosen, n_correct, replace=False))
    incorrect = set(chosen) - correct

    print(f"Correct: {correct}")

    if spaces < n_objects:
        # The sampled board cannot hold every object; try another puzzle.
        print("Skipping due to error: There is not enough space to fit all the objects")
        continue

    # Up to 3 extra empty slots (never more than the free slots allow) are
    # left untiled on the left half as decoy holes.
    max_holes = min(spaces - n_objects, 4)
    optional_holes = int(np.random.randint(0, max_holes)) if max_holes > 0 else 0

    slots = _sample_slots(n_slot_cols, n_slot_rows, n_objects + optional_holes)
    object_positions = dict(zip(chosen, slots[:n_objects]))
    hole_positions = set(slots[n_objects:])
    # Left-half slots that stay untiled: under each correct object, plus holes.
    open_slots = hole_positions | {pos for obj, pos in object_positions.items() if obj in correct}

    grid_color = str(np.random.choice(background_colors))
    # Created only after every skip check so no empty figures are left behind.
    # Figures are not closed so the notebook can still display them inline.
    fig, ax = plt.subplots(figsize=(cols / 2, rows / 2))

    # Frame cells: top/bottom rows, both outer columns, and the two centre
    # columns that separate the halves.
    for x in range(cols):
        _add_square(ax, x, 0, 1, grid_color, border_color)
        _add_square(ax, x, rows - 1, 1, grid_color, border_color)
    for y in range(rows):
        for x in (0, cols - 1, cols // 2 - 1, cols // 2):
            _add_square(ax, x, y, 1, grid_color, border_color)

    for r in range(n_slot_rows):
        y = 1 + 2 * r
        for c in range(n_slot_cols):
            # Right half is always fully tiled.
            _add_slot_tiles(ax, cols // 2 + 1 + 2 * c, y, grid_color, border_color)
            # Left half mirrors the right half about x = cols/2, except that
            # open slots are left untiled.
            if (c, r) not in open_slots:
                _add_slot_tiles(ax, cols // 2 - 3 - 2 * c, y, grid_color, border_color)

    # Thick outer border and a thinner border around the left half.
    ax.add_patch(patches.Rectangle((0, 0), cols, rows, linewidth=thick_linewidth,
                                   edgecolor=border_color, facecolor="#00000000"))
    ax.add_patch(patches.Rectangle((0, 0), cols / 2, rows, linewidth=thick_linewidth / 2,
                                   edgecolor=border_color, facecolor="#00000000"))

    # Every chosen object (correct or not) is drawn in its slot on the right half.
    for obj, (c, r) in object_positions.items():
        x = 1 + cols / 2 + 2 * c
        y = 1 + 2 * r
        _add_square(ax, x, y, 2, grid_color, border_color)
        img = mpimg.imread(f"assets/{obj}.png")
        ax.imshow(img, extent=[x, x + 2, y, y + 2], aspect="auto", zorder=2)

    ax.set_xlim(0, cols)
    ax.set_ylim(0, rows)
    ax.axis("off")
    ax.set_aspect("equal")  # after imshow, which set aspect="auto"

    filename = f"puzzles/{task_name}-{file_num}.png"
    fig.savefig(filename, bbox_inches="tight")
    puzzle = {
        "id": f"{task_name}-{file_num}",
        "correct": sorted(correct),
        "incorrect": sorted(incorrect),
        "image": filename,
    }
    lines.append(json.dumps(puzzle) + "\n")
    file_num += 1

# Persist all generated puzzle records as one JSON Lines file; each entry of
# `lines` already ends with a newline.
with open(f"puzzles/{task_name}.jsonl", "w") as jsonl_file:
    jsonl_file.write("".join(lines))
