In [79]:
import torch
import numpy as np

from math import factorial, log2, log10, ceil
import matplotlib.pyplot as plt
plt.rcParams["font.family"] = "serif"

import json
from tqdm import tqdm

In [3]:
def state2im(state, n=None):
    """Render a puzzle state vector as the n-by-n board it encodes.

    The first ``n*n - 1`` entries of ``state`` are the tile values; with a
    leading 0 prepended they form an n-by-n block.  That block is tiled
    periodically onto a ``(2n-1) x (2n-1)`` canvas and an n-by-n window is cut
    out of it.  The window position is given by ``state[n*n - 1]``
    (horizontal) and ``state[n*(n+1) - 1]`` (vertical).

    Parameters
    ----------
    state : 1-D integer sequence of length ``n*n + 2*n - 1``.
    n : int, optional
        Board side length.  When omitted it is inferred from ``len(state)``
        so the function no longer depends on a notebook-level global ``n``
        (which other cells rebind inside their loops).

    Returns
    -------
    numpy.ndarray of shape ``(n, n)`` (a view into a temporary canvas).

    Raises
    ------
    ValueError
        If ``n`` is omitted and ``len(state)`` is not of the form
        ``n*n + 2*n - 1``.
    """
    if n is None:
        # len(state) = (n*n - 1) + n + n = (n + 1)**2 - 2
        n = int(round((len(state) + 2) ** 0.5)) - 1
        if n < 1 or n*n + 2*n - 1 != len(state):
            raise ValueError(
                f"cannot infer board size from a state of length {len(state)}; "
                "pass n explicitly"
            )

    M = np.zeros((2*n-1, 2*n-1), dtype=int)
    # Bottom-right n-by-n corner holds the block (0 followed by the tiles).
    M[n-1:, n-1:] = np.concatenate(([0], state[:n**2-1])).reshape(n, n)
    # Wrap around: copy the last n-1 rows, then the last n-1 columns, to the front.
    M[np.arange(n-1)] = M[np.arange(n-1)+n]
    M[:, np.arange(n-1)] = M[:, np.arange(n-1)+n]
    # Window offsets read from the two counter entries of the state.
    idx_hor, idx_ver = n-1-state[n*n-1], n-1-state[n*(n+1)-1]
    return M[idx_ver:idx_ver+n, idx_hor:idx_hor+n]

In [34]:
def get_target(n):
    """Return the solved state vector for an n-by-n board.

    Layout: the tile labels ``1 .. n*n-1`` followed by two identical
    length-n counter segments, each equal to ``[0, n-1, n-2, ..., 1]``.
    The result has ``n*n + 2*n - 1`` entries.
    """
    tiles = np.arange(1, n**2)
    # Descending 0..n-1, rotated one place so that the 0 comes first.
    counter = np.roll(np.arange(n)[::-1], 1)
    return np.concatenate((tiles, counter, counter))

def get_moves(n):
    """Build the four move permutations for an n-by-n board.

    Each move is an index array ``g`` of length ``n*n + 2*n - 1`` that is
    applied to a state as ``state[g]``.  The rows of the returned array are,
    in order, ``gr`` (right), ``gl`` (its inverse), ``gd`` (down) and ``gu``
    (its inverse).

    A move permutes the ``n*n - 1`` tile entries and rotates one of the two
    trailing length-n counter segments: ``gr`` rotates the first segment,
    ``gd`` the second.
    """
    counter = np.arange(n)
    grid = np.arange(n**2).reshape(n, n) - 1   # -1 marks the cell dropped below

    # "Right": rotate row 0 without its first cell, then every other row in full.
    by_rows = grid.copy()
    by_rows[0, 1:] = np.roll(by_rows[0, 1:], -1)
    by_rows[1:] = np.roll(by_rows[1:], -1, axis=1)
    tiles_right = by_rows.ravel()[1:]

    # "Down": the same construction along columns.
    by_cols = grid.copy()
    by_cols[1:, 0] = np.roll(by_cols[1:, 0], -1)
    by_cols[:, 1:] = np.roll(by_cols[:, 1:], -1, axis=0)
    tiles_down = by_cols.ravel()[1:]

    first_segment = counter + n*n - 1
    second_segment = counter + n*(n+1) - 1
    rotated = np.roll(counter, 1)

    gr = np.concatenate((tiles_right, rotated + n*n - 1, second_segment))
    gd = np.concatenate((tiles_down, first_segment, rotated + n*(n+1) - 1))

    # argsort of a permutation is its inverse permutation.
    gl, gu = np.argsort(gr), np.argsort(gd)

    return np.stack((gr, gl, gd, gu))

## Test

In [88]:
n = 3

V0 = get_target(n)
gr, gl, gd, gu = get_moves(n)

# One array per line: the solved state, then the "right" and "down" permutations.
print(V0, gr, gd, sep='\n')

[1 2 3 4 5 6 7 8 0 2 1 0 2 1]
[ 1  0  3  4  2  6  7  5 10  8  9 11 12 13]
[ 3  4  5  6  7  2  0  1  8  9 10 13 11 12]


In [90]:
# Apply "down" then "right" to the solved 3x3 state and check every board.
# The expected boards are for the n = 3 state built in the cell above.
state = V0.copy()
print("initial state:\n", state2im(state), sep='')
assert state2im(state).tolist() == [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
state = state[gd]
print("step down:\n", state2im(state), sep='')
assert state2im(state).tolist() == [[3, 1, 2], [0, 4, 5], [6, 7, 8]]
state = state[gr]
print("step right:\n", state2im(state), sep='')
assert state2im(state).tolist() == [[3, 1, 2], [4, 0, 5], [6, 7, 8]]
# Applying the inverse moves in reverse order (gl, then gu) must restore V0.
assert np.array_equal(state[gl][gu], V0)

initial state:
[[0 1 2]
 [3 4 5]
 [6 7 8]]
step down:
[[3 1 2]
 [0 4 5]
 [6 7 8]]
step right:
[[3 1 2]
 [4 0 5]
 [6 7 8]]


In [91]:
# Apply "right" then "down" to the solved 3x3 state and check every board.
# The expected boards are for the n = 3 state built in the cell above.
state = V0.copy()
print("initial state:\n", state2im(state), sep='')
assert state2im(state).tolist() == [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
state = state[gr]
print("step right:\n", state2im(state), sep='')
assert state2im(state).tolist() == [[1, 0, 2], [3, 4, 5], [6, 7, 8]]
state = state[gd]
print("step down:\n", state2im(state), sep='')
assert state2im(state).tolist() == [[1, 4, 2], [3, 0, 5], [6, 7, 8]]
# Applying the inverse moves in reverse order (gu, then gl) must restore V0.
assert np.array_equal(state[gu][gl], V0)

initial state:
[[0 1 2]
 [3 4 5]
 [6 7 8]]
step right:
[[1 0 2]
 [3 4 5]
 [6 7 8]]
step down:
[[1 4 2]
 [3 0 5]
 [6 7 8]]


## Save generators

In [71]:
# Offset added to the loop index j to build the puzzle id used in file names
# (p{j+j_shift:03d}); the loops below use n = j+3 for the board side length.
j_shift = 26 # n=3:26, n=4:27, n=5:28, n=6:29

In [69]:
# Header row of the summary table printed while the generator files are written.
print(f"p   | {'puzzle':12s} | #actions | # unique elements | # elements |")
# Move labels, in the row order returned by get_moves; a primed name is the
# inverse of the unprimed one.
names = ["r", "r'", "d", "d'"]
for j in range(8):
    n = j+3  # board side length: 3 .. 10

    solution_state = get_target(n)
    actions = get_moves(n).tolist()
    generators = dict(actions=actions, names=names)

    # One JSON file per puzzle holding the move permutations and their names.
    with open(f'../generators/p{j+j_shift:03d}.json', 'w') as f:
        f.write(json.dumps(generators))

    print(f"{j+j_shift:03d} | puzzle_{str(n**2-1):5s} | {len(actions)}        | {str(n*(2+n)-1):17s} | {str(n*n-1):10s} |")

p   | puzzle       | #actions | # unique elements | # elements |
026 | puzzle_8     | 4        | 14                | 8          |
027 | puzzle_15    | 4        | 23                | 15         |
028 | puzzle_24    | 4        | 34                | 24         |
029 | puzzle_35    | 4        | 47                | 35         |
030 | puzzle_48    | 4        | 62                | 48         |
031 | puzzle_63    | 4        | 79                | 63         |
032 | puzzle_80    | 4        | 98                | 80         |
033 | puzzle_99    | 4        | 119               | 99         |


## Save datasets

In [70]:
def generate_inverse_moves(moves):
    """Map every move to the position of its inverse in ``moves``.

    A name containing an apostrophe (``a'``) is paired with the same name
    stripped of apostrophes (``a``); any other name is paired with itself
    plus a trailing apostrophe.  Returns a list of indices, one per move.
    Raises ``ValueError`` (from ``list.index``) when a partner is missing.
    """
    def partner(name):
        if "'" in name:
            return name.replace("'", "")
        return name + "'"

    return [moves.index(partner(name)) for name in moves]

def random_step(states, last_moves):
    """Apply one uniformly random move to every state, never undoing the last move.

    Reads the notebook-level tensors ``all_moves`` (one index permutation per
    move) and ``inverse_moves`` (index of each move's inverse).

    Parameters
    ----------
    states : tensor of shape (batch, state_len)
    last_moves : int64 tensor of shape (batch,)
        Index of the move previously applied to each state.  A negative value
        means "no previous move", in which case every move is allowed.

    Returns
    -------
    (new_states, next_moves) : the permuted states, shape (batch, state_len),
        and the chosen move indices, shape (batch,).
    """
    batch = states.size(0)
    if batch == 0:
        # Nothing to sample; skip the multinomial call on an empty batch.
        return states.clone(), torch.empty(0, dtype=torch.int64, device=states.device)

    possible_moves = torch.ones((batch, all_moves.size(0)), dtype=torch.bool, device=states.device)
    rows = torch.arange(batch, device=states.device)

    # A negative sentinel must not be used as an index: inverse_moves[-1] would
    # silently forbid the inverse of the *last* move in the table.  Clamp only
    # to obtain a valid index, then write False for rows that have a previous
    # move and True (i.e. leave allowed) for rows that do not.
    has_previous = last_moves >= 0
    blocked = inverse_moves[last_moves.clamp(min=0)]
    possible_moves[rows, blocked] = ~has_previous

    # squeeze(1), not squeeze(): a one-row batch must stay 1-D so that
    # all_moves[next_moves] keeps its batch dimension for gather.
    next_moves = torch.multinomial(possible_moves.float(), 1).squeeze(1)
    new_states = torch.gather(states, 1, all_moves[next_moves])
    return new_states, next_moves

In [85]:
N = 100  # number of scrambled states generated per puzzle
torch.manual_seed(0)  # fixed seed so that re-running the cell reproduces the same datasets
for j in range(8):
    n = j+3
    k = 0

    solution_state = get_target(n)

    with open(f'../generators/p{j+j_shift:03d}.json', 'r') as f:
        generators = json.load(f)
    # Look the fields up by name instead of unpacking .values(), which silently
    # depends on the key order of the JSON object.
    all_moves = torch.tensor(generators['actions'], dtype=torch.int64)
    move_names = generators['names']

    # Pick the narrowest signed integer dtype (at least 8 bits) whose bit width
    # 2**b can represent num_elements distinct values.
    num_elements = len(np.unique(solution_state))
    b = max(3, ceil(log2(log2(num_elements))))
    dtype = {3:torch.int8, 4:torch.int16, 5:torch.int32, 6:torch.int64}[b]

    # If the values would not fit in the non-negative half of the signed range,
    # shift them down by half the range.
    if num_elements >= 2**(2**b-1):
        shift = 2**(2**b-1)
    else:
        shift = 0

    solution_state = torch.tensor(np.array(solution_state) - shift, dtype=dtype)

    inverse_moves = torch.tensor(generate_inverse_moves(move_names), dtype=torch.int64)

    # -1 marks "no previous move" for the first step of the random walk.
    last_moves = torch.full((N,), -1, dtype=torch.int64)
    rnd_states = solution_state[None].expand(N, -1)

    for _ in range(10_000):
        rnd_states, last_moves = random_step(rnd_states, last_moves)

    # Give a random subset of the states one extra move, so the dataset mixes
    # walks of 10,000 and 10,001 moves.  The step is computed for the whole
    # batch and selected with the mask; slicing out rnd_states[mask] could hand
    # random_step a batch of one or zero rows.
    mask = torch.randint(0, 2, (N,), dtype=torch.bool)
    stepped_states, _ = random_step(rnd_states, last_moves)
    rnd_states = torch.where(mask[:, None], stepped_states, rnd_states)

    torch.save(rnd_states, f"../datasets/p{j+j_shift:03d}-t{k:03d}-rnd.pt")
    torch.save(solution_state, f"../targets/p{j+j_shift:03d}-t{k:03d}.pt")

    print(f"{j+j_shift:03d}.{k:03d}", end='\r')

033.000

In [86]:
# Klein-Cook