In [1]:
import json
import os
from functools import partial

import jax
from jax import numpy as jnp
from tqdm import tqdm

In [2]:
folder = "data/training"

In [3]:
file_paths = [
    os.path.join(folder, f)
    for f in os.listdir(folder) 
    if f.endswith('.json')
]

puzzles = []
for file_path in file_paths:
    puzzle = json.load(open(file_path, 'r'))
    puzzle['puzzle_id'] = file_path.split('/')[-1].split('.')[0]
    puzzles.append(puzzle)

len(puzzles), puzzles[0]

(400,
 {'train': [{'input': [[0, 0, 5], [0, 5, 0], [5, 0, 0]],
    'output': [[3, 3, 3], [4, 4, 4], [2, 2, 2]]},
   {'input': [[0, 0, 5], [0, 0, 5], [0, 0, 5]],
    'output': [[3, 3, 3], [3, 3, 3], [3, 3, 3]]},
   {'input': [[5, 0, 0], [0, 5, 0], [5, 0, 0]],
    'output': [[2, 2, 2], [4, 4, 4], [2, 2, 2]]},
   {'input': [[0, 5, 0], [0, 0, 5], [0, 5, 0]],
    'output': [[4, 4, 4], [3, 3, 3], [4, 4, 4]]}],
  'test': [{'input': [[0, 0, 5], [5, 0, 0], [0, 5, 0]],
    'output': [[3, 3, 3], [2, 2, 2], [4, 4, 4]]}],
  'puzzle_id': 'a85d4709'})

In [4]:
puzzles_flat = []
for puzzle in puzzles:
    for pair in puzzle['train']:
        puzzles_flat.append({
            "puzzle_id": puzzle['puzzle_id'],
            "x": jnp.asarray(pair['input']),
            "y": jnp.asarray(pair['output'])
        })

len(puzzles_flat), puzzles_flat[0]

(1302,
 {'puzzle_id': 'a85d4709', 'x': Array([[0, 0, 5],
         [0, 5, 0],
         [5, 0, 0]], dtype=int32), 'y': Array([[3, 3, 3],
         [4, 4, 4],
         [2, 2, 2]], dtype=int32)})

In [5]:
def d8_aug(puzzle, op_idx):
    ops = [
        lambda x: x,
        partial(jnp.rot90, k=1),
        partial(jnp.rot90, k=2),
        partial(jnp.rot90, k=3),
        jnp.fliplr,
        jnp.flipud,
        jnp.transpose,
        lambda x: jnp.transpose(jnp.rot90(x, k=1)),
    ]
    return {
        **puzzle,
        "x": ops[op_idx](puzzle["x"]), 
        "y": ops[op_idx](puzzle["y"]),
        "d8_aug": op_idx
    }

def colour_aug(puzzle, colours):
    return {
        **puzzle,
        "x": colours[puzzle["x"]],
        "y": colours[puzzle["x"]],
        "colour_aug": colours
    }

In [6]:
key = jax.random.key(0)
puzzles_aug = []
n_augs = 1000
for puzzle in tqdm(puzzles_flat):
    # no augs
    base = puzzle.copy() 
    base["d8_aug"] = -1
    base["colour_aug"] = jnp.arange(10)
    puzzles_aug.append(base)
    puzzles_meta = {(base['puzzle_id'], base['d8_aug'], str(base["colour_aug"]))}

    # keep trying augs until unique n_augs
    current_augs = 0
    while current_augs < n_augs:
        key, op_key, colour_key = jax.random.split(key, 3)
        op_idx = jax.random.randint(op_key, (), 0, 8).item()
        colours = jax.random.permutation(colour_key, jnp.arange(10))
        aug_puzzle = colour_aug(d8_aug(puzzle.copy(), op_idx), colours)
        aug_meta = (aug_puzzle['puzzle_id'], aug_puzzle['d8_aug'], str(aug_puzzle["colour_aug"]))
        if aug_meta not in puzzles_meta:
            puzzles_aug.append(aug_puzzle)
            puzzles_meta.add(aug_meta)
            current_augs += 1

100%|██████████| 1302/1302 [38:35<00:00,  1.78s/it]   


In [8]:
output_path = f"data/training_n_augs={n_augs}.jsonl"
with open(output_path, 'w') as file:
    for aug_puzzle in tqdm(puzzles_aug):
        out_puzzle = {
            **aug_puzzle,
            "x": aug_puzzle["x"].tolist(),
            "y": aug_puzzle["y"].tolist(),
            "colour_aug": aug_puzzle["colour_aug"].tolist()
        }
        json.dump(out_puzzle, file)
        file.write("\n")

100%|██████████| 1303302/1303302 [03:17<00:00, 6588.35it/s] 
