In [2]:
import json
import os
from math import factorial, log2, log10, ceil

import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

plt.rcParams["font.family"] = "serif"

In [46]:
# pancake
def get_move(n, k):
    """Return the permutation of range(n) that flips (reverses) the first k entries.

    The rest of the arrangement, from index k onward, is left in place.
    """
    ids = np.arange(n)
    head, tail = ids[:k], ids[k:]
    return np.concatenate((head[::-1], tail))

def get_moves(n):
    """Stack every prefix flip of size 2..n, each row appearing twice in a row.

    Row order is flip(2), flip(2), flip(3), flip(3), ..., flip(n), flip(n).
    """
    flips = [get_move(n, size) for size in range(2, n + 1)]
    return np.repeat(np.stack(flips), 2, axis=0)

def get_target(n):
    """Return the target arrangement of n pancakes: 0, 1, ..., n-1 in order."""
    solved = np.arange(n)
    return solved

In [47]:
# Offset added to the loop index j to form puzzle ids in file names:
# pancake_10 is written as p044, ..., pancake_55 as p053.
j_shift = 44

In [48]:
# Write one generator file per pancake size n = 10, 15, ..., 55 and print a summary table.
# Create the output directory up front so the cell also works on a fresh checkout.
os.makedirs('../generators', exist_ok=True)

print(f"p   | {'puzzle':10s} | #actions   | # unique elements | # elements |")
for j in range(10):
    n = 5*j+10

    actions = get_moves(n).tolist()
    # get_moves lists every flip twice in a row; name each pair "k" and "k'".
    names = [f"{i}{suffix}" for i in range(2, n+1) for suffix in ("", "'")]
    assert len(names) == len(actions), "every action needs exactly one name"

    generators = {'actions':actions, 'names':names}
    with open(f'../generators/p{j+j_shift:03d}.json', 'w') as f:
        json.dump(generators, f)

    print(f"{j+j_shift:03d} | pancake_{str(n):2s} | {str(len(actions)):3s}        | {str(n):17s} | {str(n):10s} |")

p   | puzzle     | #actions   | # unique elements | # elements |
044 | pancake_10 | 18         | 10                | 10         |
045 | pancake_15 | 28         | 15                | 15         |
046 | pancake_20 | 38         | 20                | 20         |
047 | pancake_25 | 48         | 25                | 25         |
048 | pancake_30 | 58         | 30                | 30         |
049 | pancake_35 | 68         | 35                | 35         |
050 | pancake_40 | 78         | 40                | 40         |
051 | pancake_45 | 88         | 45                | 45         |
052 | pancake_50 | 98         | 50                | 50         |
053 | pancake_55 | 108        | 55                | 55         |


In [49]:
def generate_inverse_moves(moves):
    """Map each move name to the position of its inverse in the same list.

    A primed name ("k'") is paired with its unprimed form ("k"), and vice versa.
    Raises ValueError (from list.index) if a partner name is missing.
    """
    def _partner(name):
        # Strip every apostrophe if present, otherwise append one.
        return name.replace("'", "") if "'" in name else name + "'"

    return [moves.index(_partner(name)) for name in moves]

def random_step(states, last_moves, moves=None, inverse=None):
    """Apply one uniformly random move to every state, never picking the inverse of its previous move.

    Args:
        states: 2-D tensor with one state per row; row i becomes
            ``states[i, moves[next_moves[i]]]`` (via ``torch.gather`` along dim 1).
        last_moves: 1-D int64 tensor holding the previous move index per row. A negative
            value (e.g. the ``-1`` used before the first step) means "no previous move".
        moves: 2-D int64 tensor of move permutations, one per row. Defaults to the
            notebook-global ``all_moves`` (looked up at call time).
        inverse: 1-D int64 tensor mapping a move index to the index of its inverse.
            Defaults to the notebook-global ``inverse_moves``.

    Returns:
        (new_states, next_moves): the permuted states and a 1-D int64 tensor with the
        chosen move index for each row.

    NOTE(review): only the single named inverse is blocked. When a move's permutation is
    identical to its inverse's (the pancake generators store each flip twice, as "k" and
    "k'"), repeating the same move still undoes the previous step.
    """
    if moves is None:
        moves = all_moves
    if inverse is None:
        inverse = inverse_moves

    batch = states.size(0)
    device = states.device
    if batch == 0:
        # Nothing to move.
        return states, torch.empty(0, dtype=torch.int64, device=device)

    allowed = torch.ones((batch, moves.size(0)), dtype=torch.bool, device=device)
    # Only rows with a real previous move get a ban. Indexing `inverse` with the -1
    # sentinel would wrap around to inverse[-1] and wrongly forbid that move.
    has_prev = last_moves >= 0
    rows = torch.arange(batch, device=device)[has_prev]
    allowed[rows, inverse[last_moves[has_prev]]] = False

    # squeeze(1), not squeeze(): keep the batch dimension even when batch == 1.
    next_moves = torch.multinomial(allowed.float(), 1).squeeze(1)
    new_states = torch.gather(states, 1, moves[next_moves])
    return new_states, next_moves

In [None]:
# For every pancake puzzle, build N scrambled states by random walks of N_STEPS moves from the
# target arrangement, then give a random half of them one extra move so both step-count
# parities appear. Save the states and the target.
N = 100
N_STEPS = 10_000
SEED = 0
os.makedirs('../datasets', exist_ok=True)
os.makedirs('../targets', exist_ok=True)

for j in range(10):
    n = 5*j+10
    k = 0
    pid = j + j_shift

    # Seed per puzzle so any single puzzle's dataset can be regenerated on its own.
    torch.manual_seed(SEED + pid)

    with open(f'../generators/p{pid:03d}.json', 'r') as f:
        generators = json.load(f)
    # Look keys up by name instead of unpacking .values(), which silently depends on key order.
    all_moves = torch.tensor(generators['actions'], dtype=torch.int64)
    move_names = generators['names']

    solution_state = get_target(n)
    num_elements = len(np.unique(solution_state))
    # Pick the narrowest of int8/16/32/64 (bit width 2**b) that is at least log2(num_elements) bits.
    b = max(3, ceil(log2(log2(num_elements))))
    dtype = {3:torch.int8, 4:torch.int16, 5:torch.int32, 6:torch.int64}[b]

    # Values go up to num_elements - 1; if that exceeds the signed range, shift them down.
    if num_elements >= 2**(2**b-1):
        shift = 2**(2**b-1)
    else:
        shift = 0

    solution_state = torch.tensor(np.array(solution_state) - shift, dtype=dtype)
    inverse_moves = torch.tensor(generate_inverse_moves(move_names), dtype=torch.int64)

    last_moves = torch.full((N,), -1, dtype=torch.int64)
    # repeat (not expand) so rows own their memory and the masked write below is always safe.
    rnd_states = solution_state[None].repeat(N, 1)
    for _ in range(N_STEPS):
        rnd_states, last_moves = random_step(rnd_states, last_moves)

    # One extra move for a random half of the walks.
    mask = torch.randint(0, 2, (N,), dtype=torch.bool)
    if mask.any():
        rnd_states[mask], _ = random_step(rnd_states[mask], last_moves[mask])

    torch.save(rnd_states, f"../datasets/p{pid:03d}-t{k:03d}-rnd.pt")
    torch.save(solution_state, f"../targets/p{pid:03d}-t{k:03d}.pt")

    print(f"{pid:03d}.{k:03d}", end='\r')

045.000

In [51]:
# Display the scrambled states left by the most recent iteration of the dataset loop above.
rnd_states

tensor([[20, 43, 44,  ..., 15, 14, 24],
        [32,  1,  0,  ...,  2, 33, 20],
        [12, 21, 37,  ..., 25,  0, 45],
        ...,
        [35, 45, 39,  ..., 24, 17, 16],
        [25, 52,  0,  ..., 35,  2, 49],
        [ 5,  4, 49,  ...,  7,  8, 22]], dtype=torch.int8)