In [12]:
class TrainConfig:
    max_depth = 30                          # God's Number
    batch_size_per_depth = 1000             # number of scrambles per batch
    num_steps = 10000                       # number of batches
    learning_rate = 1e-3
    INTERVAL_PLOT, INTERVAL_SAVE = 100, 1000
    ENABLE_FP16 = False                     # Set this to True if you want to train the model faster

class SearchConfig:
    beam_width = 2**11.                      # This controls the trade-off between time and optimality
    max_depth = TrainConfig.max_depth * 2   # Any number above God's Number will do
    ENABLE_FP16 = False                     # Set this to True if you want to solve faster

In [13]:
import os
import time
import random
import pickle
import numpy as np
from copy import deepcopy
from contextlib import nullcontext
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
import matplotlib.colors as Colors
from cycler import cycler; plt.rcParams["axes.prop_cycle"] = cycler(color=["#000000", "#2180FE", "#EB4275"])
from IPython.display import clear_output

import torch
from torch import nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from numpy import random

print(f'device: {device}')
print(f'os.cpu_count(): {os.cpu_count()}')

device: cpu
os.cpu_count(): 8


In [14]:
"""
Colors:
         0 0 0 0
         0 0 0 0
         0 0 0 0
         0 0 0 0

1 1 1 1  2 2 2 2  3 3 3 3  4 4 4 4
1 1 1 1  2 2 2 2  3 3 3 3  4 4 4 4
1 1 1 1  2 2 2 2  3 3 3 3  4 4 4 4
1 1 1 1  2 2 2 2  3 3 3 3  4 4 4 4

         5 5 5 5
         5 5 5 5
         5 5 5 5
         5 5 5 5

             ** 16 17 **
             12 ** ** 13
              8 ** **  9
             **  0  1 **

** 12  8 **  **  0  1 **  **  9 13 **  ** 17 16 **
19 ** **  2   2 ** **  3   3 ** ** 18  18 ** ** 19
21 ** **  4   4 ** **  5   5 ** ** 20  20 ** ** 21
** 14 10 **  **  6  7 **  ** 11 15 **  ** 23 22 **

             **  6  7 **
             10 ** ** 11
             14 ** ** 15
             ** 22 23 **


0 = white
1 = orange
2 = green
3 = red
4 = blue
5 = yellow



Indices (starting with 16n):

              0  1  2  3
              4  5  6  7
              8  9 10 11
             12 13 14 15

16 17 18 19  32 33 34 35  48 49 50 51  64 65 66 67
20 21 22 23  36 37 38 39  52 53 54 55  68 69 70 71
24 25 26 27  40 41 42 43  56 57 58 59  72 73 74 75
28 29 30 31  44 45 46 47  60 61 62 63  76 77 78 79

             80 81 82 83
             84 85 86 87
             88 89 90 91
             92 93 94 95


Moves:
1U 1D 1L 1R 1B 1F 1U' 1D' 1L' 1R' 1B' 1F'
2U 2D 2L 2R 2B 2F 2U' 2D' 2L' 2R' 2B' 2F'

Corner indices:

             ** ** ** **
             **  0  1 **
             **  2  3 **
             ** ** ** **

** ** ** **  ** ** ** **  ** ** ** **  ** ** ** **
**  4  5 **  **  8  9 **  ** 12 13 **  ** 16 17 **
**  6  7 **  ** 10 11 **  ** 14 15 **  ** 18 19 **
** ** ** **  ** ** ** **  ** ** ** **  ** ** ** **

             ** ** ** **
             ** 20 21 **
             ** 22 23 **
             ** ** ** **

Edge Indices: (front to back -> left to right -> top to bottom)

             ** 16 17 **
             12 ** ** 13
              8 ** **  9
             **  0  1 **

** 12  8 **  **  0  1 **  **  9 13 **  ** 17 16 **
19 ** **  2   2 ** **  3   3 ** ** 18  18 ** ** 19
21 ** **  4   4 ** **  5   5 ** ** 20  20 ** ** 21
** 14 10 **  **  6  7 **  ** 11 15 **  ** 23 22 **

             **  6  7 **
             10 ** ** 11
             14 ** ** 15
             ** 22 23 **

Corner Indices: (front to back -> left to right -> top to bottom)

              4 ** **  5
             ** ** ** **
             ** ** ** **
              0 ** **  1

 4 ** **  0   0 ** **  1   1 ** **  5   5 ** **  4
** ** ** **  ** ** ** **  ** ** ** **  ** ** ** **
** ** ** **  ** ** ** **  ** ** ** **  ** ** ** **
 6 ** **  2   2 ** **  3   3 ** **  7   7 ** **  6

              2 ** **  3
             ** ** ** **
             ** ** ** **
              6 ** **  7

"""

"\nColors:\n         0 0 0 0\n         0 0 0 0\n         0 0 0 0\n         0 0 0 0\n\n1 1 1 1  2 2 2 2  3 3 3 3  4 4 4 4\n1 1 1 1  2 2 2 2  3 3 3 3  4 4 4 4\n1 1 1 1  2 2 2 2  3 3 3 3  4 4 4 4\n1 1 1 1  2 2 2 2  3 3 3 3  4 4 4 4\n\n         5 5 5 5\n         5 5 5 5\n         5 5 5 5\n         5 5 5 5\n\n             ** 16 17 **\n             12 ** ** 13\n              8 ** **  9\n             **  0  1 **\n\n** 12  8 **  **  0  1 **  **  9 13 **  ** 17 16 **\n19 ** **  2   2 ** **  3   3 ** ** 18  18 ** ** 19\n21 ** **  4   4 ** **  5   5 ** ** 20  20 ** ** 21\n** 14 10 **  **  6  7 **  ** 11 15 **  ** 23 22 **\n\n             **  6  7 **\n             10 ** ** 11\n             14 ** ** 15\n             ** 22 23 **\n\n\n0 = white\n1 = orange\n2 = green\n3 = red\n4 = blue\n5 = yellow\n\n\n\nIndices (starting with 16n):\n\n              0  1  2  3\n              4  5  6  7\n              8  9 10 11\n             12 13 14 15\n\n16 17 18 19  32 33 34 35  48 49 50 51  64 65 66 67\n20 21 22 

In [15]:
class Cube4:

    def __init__(self):
        self.DTYPE = np.int64
        self.reset()
        self.goal = np.arange(0, 16 * 6, dtype=self.DTYPE) // 16

        faces = ["U", "D", "L", "R", "B", "F"]
        degrees = ["", "'"]
        widths = ["1", "2"]
        degrees_inference = degrees[::-1]
        self.moves = [f"{w}{f}{n}" for w in widths for f in faces for n in degrees]
        self.moves_inference = [f"{w}{f}{n}" for w in widths for f in faces for n in degrees_inference]

        self.pairing = {
            "R": "L",
            "L": "R",
            "F": "B",
            "B": "F",
            "U": "D",
            "D": "U",
        }
        self.rotations = {
            'x': "1R 2R 1L' 2L'",
            'y': "1U 2U 1D' 2D'",
            'z': "1F 2F 1B' 2B'"
        }
        self.rotation_scrambles = [i + " " + j for i in ["", "1L 2L 1R' 2R'", "1L 2L 1R' 2R' 1L 2L 1R' 2R'", "1L' 2L' 1R 2R", "1U 2U 1D' 2D'", "1U' 2U' 1D 2D"] for j in ["", "1F 2F 1B' 2B'", "1F 2F 1B' 2B' 1F 2F 1B' 2B'", "1F' 2F' 1B 2B"]]

        # Prohibit obviously redundant moves.
        self.moves_available_after = {
            m: [v for v in self.moves if v[1] != m[1]] + [m]
            for m in self.moves
        } # self-cancelling moves on the same face

        # [OPTIMIZATION] slicing by move string (e.g., R', U, F) => indices (e.g., 2, 6, 1)
        self.moves_ix = [self.moves.index(m) for m in self.moves]
        self.moves_ix_available_after = {
            self.moves.index(m): [self.moves.index(m) for m in available_moves]
            for m, available_moves in self.moves_available_after.items()
        }

        self.moves_ix_inference = [self.moves.index(m) for m in self.moves_inference]
        self.pairing_ix = {
            0: 1,
            1: 0,
            2: 3,
            3: 2,
            4: 5,
            5: 4,
        } # Points to the opposite face index

        # Vectorize the sticker group replacement operations
        self.__vectorize_moves()

    def __str__(self):
        """Returns a string representation of the cube."""
        a = ""
        for i in range(0, 4):
            a += "         "
            for j in range(0, 4):
                a += str(self.state[4*i + j]) + " "
            a += "\n"
        a += "\n"
        for b in range(0, 4):
            for i in range(1, 5):
                for j in range(0, 4):
                    a += str(self.state[16*i + 4*b + j]) + " "
                a += " "
            a += "\n"
        a += "\n"
        for i in range(0, 4):
            a += "         "
            for j in range(0, 4):
                a += str(self.state[80 + 4*i + j]) + " "
            a += "\n"


        return a

    def reset(self, train=False):
      """Resets the cube to the solved state. If train mode is on, then solved states are defined as reduced 3x3 states. """
      self.state = np.arange(0, 16 * 6, dtype=self.DTYPE) // 16
      if train:
        self.scramble_corners()
        self.scramble_edges(paired=True)
        self.rotate_randomly()

    def is_solved(self):
        """Checks if the cube is in the solved state."""
        return self.are_centers_solved() and self.are_edges_solved() and self.paired_edge_parity() == 0 and self.permutation_parity() == 0

    def are_centers_solved(self):
        """Checks if center pieces are matching each other on every side."""
        center_indices = np.array([5, 6, 9, 10, 21, 22, 25, 26, 37, 38, 41, 42, 53, 54, 57, 58, 69, 70, 73, 74, 85, 86, 89, 90])
        return np.all([np.all(self.state[center_indices[4*i:4*i+4]] == self.state[center_indices[4*i]]) for i in range(6)])

    def are_edges_solved(self):
        """Checks if edge pieces are matching each other in each slot."""
        edge_indices = np.array([[13, 33], [14, 34], [23, 36], [39, 52], [27, 40], [43, 56], [45, 81], [46, 82], [8, 18], [11, 49], [30, 84], [61, 87], [4, 17], [7, 50], [29, 88], [62, 91], [66, 1], [65, 2], [68, 55], [71, 20], [72, 59], [75, 24], [78, 93], [77, 94]])
        edge_pairs = np.array([[0, 1], [2, 4], [3, 5], [6, 7], [8, 12], [9, 13], [10, 14], [11, 15], [16, 17], [18, 20], [19, 21], [22, 23]])
        return np.all([self.state[edge_indices[pair[0]][0]] == self.state[edge_indices[pair[1]][0]] and self.state[edge_indices[pair[0]][1]] == self.state[edge_indices[pair[1]][1]] for pair in edge_pairs])

    def scramble_centers(self):
        """Scramble the center pieces."""
        indices = np.array([5, 6, 9, 10, 21, 22, 25, 26, 37, 38, 41, 42, 53, 54, 57, 58, 69, 70, 73, 74, 85, 86, 89, 90])
        colors = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5])

        np.random.shuffle(colors)
        self.state[indices] = colors

    def scramble_edges(self, paired=False):
        """Scramble the edge pieces. If paired mode is on, edge pairs are kept intact and scrambled together, maintaining edge and permutation parity."""
        indices = np.array([[13, 33], [14, 34], [23, 36], [39, 52], [27, 40], [43, 56], [45, 81], [46, 82], [8, 18], [11, 49], [30, 84], [61, 87], [4, 17], [7, 50], [29, 88], [62, 91], [66, 1], [65, 2], [68, 55], [71, 20], [72, 59], [75, 24], [78, 93], [77, 94]])
        colors = np.array([[0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 2], [1, 4], [1, 5], [2, 0], [2, 1], [2, 3], [2, 5], [3, 0], [3, 2], [3, 4], [3, 5], [4, 0], [4, 1], [4, 3], [4, 5], [5, 1], [5, 2], [5, 3], [5, 4]])
        edge_pairs = np.array([[0, 1], [2, 4], [3, 5], [6, 7], [8, 12], [9, 13], [10, 14], [11, 15], [16, 17], [18, 20], [19, 21], [22, 23]])


        if paired:
          while True:
              indices2 = np.array([[13, 33], [23, 36], [39, 52], [45, 81], [8, 18], [11, 49], [30, 84], [61, 87], [66, 1], [68, 55], [71, 20], [78, 93]])
              colors2 = np.array([[0, 1], [0, 2], [0, 3], [0, 4], [1, 2], [1, 4], [1, 5], [2, 3], [2, 5], [3, 4], [3, 5], [4, 5]])
              [np.random.shuffle(i) for i in colors2]
              np.random.shuffle(colors2)
              self.state[indices2] = colors2
              for pair in edge_pairs:
                self.state[indices[pair[1]]] = self.state[indices[pair[0]]]

              if self.paired_edge_parity() == 1:
                rand_pair = edge_pairs[np.random.randint(12)]
                temp = self.state[indices[rand_pair[0]][0]]
                self.state[indices[rand_pair[0]][0]] = self.state[indices[rand_pair[0]][1]]
                self.state[indices[rand_pair[1]][0]] = self.state[indices[rand_pair[1]][1]]
                self.state[indices[rand_pair[0]][1]] = temp
                self.state[indices[rand_pair[1]][1]] = temp

              if self.permutation_parity() == 0:
                break
        else:
          [np.random.shuffle(i) for i in colors]
          np.random.shuffle(colors)
          self.state[indices] = colors


    def scramble_corners(self):
        """Scramble the corner pieces, maintaining the corner parity invariant."""
        #ccw order, white/yellow first. By ordering in ccw, the number of cw turns to orient each corner (the parity) is equal to the index of the white/yellow face in the colors array
        indices = np.array([[12, 19, 32], [15, 35, 48], [80, 44, 31], [83, 60, 47],
                            [0, 67, 16], [3, 51, 64], [92, 28, 79], [95, 76, 63]])
        colors = np.array([[0, 1, 2], [0, 2, 3], [5, 2, 1], [5, 3, 2],
                          [0, 4, 1], [0, 3, 4], [5, 1, 4], [5, 4, 3]])

        #randomly rotates each corner
        for i in range(8):
          colors[i] = np.roll(colors[i], np.random.randint(3))

        # permutes the corners
        np.random.shuffle(colors)

        #sums the indexes of the white/yellow faces in the colors array
        parity = sum([np.where(i%5==0)[0][0] for i in colors]) % 3 # i%5=0 is satisfied when i = 0,5

        #correct parity by rotating a random edge
        rand_edge = np.random.randint(8)
        colors[rand_edge] = np.roll(colors[rand_edge], parity*2%3) # parity*2%3 maps to {0:0, 1:2, 2:1}

        # parity = sum([np.where(i%5==0)[0][0] for i in colors]) % 3     <- the new parity should always equal 0

        self.state[indices] = colors

    def rotate_randomly(self):
        """Randomly rotates the cube."""
        i = np.random.randint(len(self.rotation_scrambles))
        self.apply_scramble(self.rotation_scrambles[i])

    def corner_parity(self):
        """Computes the corner parity of the cube."""
        indices = np.array([[12, 19, 32], [15, 35, 48], [80, 44, 31], [83, 60, 47],
                            [0, 67, 16], [3, 51, 64], [92, 28, 79], [95, 76, 63]])
        colors = self.state[indices]

        return sum([np.where(i%5==0)[0][0] for i in colors]) % 3


    def paired_edge_parity(self):
        """Computes the edge parity of the cube."""
        assert self.are_edges_solved() == True

        indices = np.array([[13, 33], [14, 34], [23, 36], [39, 52], [27, 40], [43, 56], [45, 81], [46, 82], [8, 18], [11, 49], [30, 84], [61, 87], [4, 17], [7, 50], [29, 88], [62, 91], [66, 1], [65, 2], [68, 55], [71, 20], [72, 59], [75, 24], [78, 93], [77, 94]])
        edge_pairs = np.array([[0, 1], [2, 4], [3, 5], [6, 7], [8, 12], [9, 13], [10, 14], [11, 15], [16, 17], [18, 20], [19, 21], [22, 23]])

        parity = 0

        for i in [0, 3, 4, 5, 6, 7, 8, 11]:
            edge = indices[edge_pairs[i][0]]
            if edge[0] <= 15 or edge[0] >= 80:
              X = edge[0]
              YZ = edge[1]
            else:
              X = edge[1]
              YZ = edge[0]

            if self.state[X] == 0 or self.state[X] == 5:
              parity += 1
            elif (self.state[X] == 1 or self.state[X] == 3) and (self.state[YZ] == 2 or self.state[YZ] == 4):
              parity += 1
        for i in [1, 2, 9, 10]:
            edge = indices[edge_pairs[i][0]]

            if (edge[0] >= 32 and edge[0] <= 47) or (edge[0] >= 64 and edge[0] <= 79):
              Y = edge[0]
              Z = edge[1]
            else:
              Y = edge[1]
              Z = edge[0]
            if self.state[Z] == 0 or self.state[Z] == 5:
              parity += 1
            elif self.state[Y] == 2 or self.state[Y] == 4:
              parity += 1

        return parity % 2

    def reset_rotation(self):
        """Resets the cube's rotation to the default orientation (white on top, green at front)"""
        rotations = []

        if self.state[5] == 2:
            [self.apply_scramble(self.rotations['x']) for _ in range(3)]
            rotations.append("x'")
        if self.state[21] == 2:
            [self.apply_scramble(self.rotations['y']) for _ in range(3)]
            rotations.append("y'")
        if self.state[53] == 2:
            self.apply_scramble(self.rotations['y'])
            rotations.append("y")
        if self.state[69] == 2:
            [self.apply_scramble(self.rotations['y']) for _ in range(2)]
            rotations.append("y2")
        if self.state[85] == 2:
            self.apply_scramble(self.rotations['x'])
            rotations.append("x")

        if self.state[21] == 0:
            self.apply_scramble(self.rotations['z'])
            rotations.append("z")
        if self.state[53] == 0:
            [self.apply_scramble(self.rotations['z']) for _ in range(3)]
            rotations.append("z'")
        if self.state[85] == 0:
            [self.apply_scramble(self.rotations['z']) for _ in range(2)]
            rotations.append("z2")


        return rotations

    def permutation_parity(self):
        """Computes the permutation parity of the cube."""
        temp_cube = Cube4()
        temp_cube.state = self.state
        temp_cube.reset_rotation()

        parity = 0

        indices = np.array([[12, 19, 32], [15, 35, 48], [80, 44, 31], [83, 60, 47],
                            [0, 67, 16], [3, 51, 64], [92, 28, 79], [95, 76, 63]])

        corner_index_from_colors = {
          (0, 1, 2): 0,
          (0, 2, 3): 1,
          (1, 2, 5): 2,
          (2, 3, 5): 3,
          (0, 1, 4): 4,
          (0, 3, 4): 5,
          (1, 4, 5): 6,
          (3, 4, 5): 7
        }

        g = []

        for i in range(8):
            piece_colors = tuple(sorted([temp_cube.state[indices[i][j]] for j in range(3)]))
            g.append(corner_index_from_colors[piece_colors])

        v = [False for _ in range(8)]
        stack = [i for i in reversed(range(8))]

        while stack:
          node = stack.pop()
          v[node] = True
          if not v[g[node]]:
            parity += 1
            stack.append(g[node])

        edge_pair_index_from_colors = {
          (0, 2): 0,
          (1, 2): 1,
          (2, 3): 2,
          (2, 5): 3,
          (0, 1): 4,
          (0, 3): 5,
          (1, 5): 6,
          (3, 5): 7,
          (0, 4): 8,
          (3, 4): 9,
          (1, 4): 10,
          (4, 5): 11,
        }

        indices = np.array([[13, 33], [14, 34], [23, 36], [39, 52], [27, 40], [43, 56], [45, 81], [46, 82], [8, 18], [11, 49], [30, 84], [61, 87], [4, 17], [7, 50], [29, 88], [62, 91], [66, 1], [65, 2], [68, 55], [71, 20], [72, 59], [75, 24], [78, 93], [77, 94]])
        edge_pairs = np.array([[0, 1], [2, 4], [3, 5], [6, 7], [8, 12], [9, 13], [10, 14], [11, 15], [16, 17], [18, 20], [19, 21], [22, 23]])

        g = []

        for i in range(12):
            piece_colors = tuple(sorted([temp_cube.state[indices[edge_pairs[i][0]][j]] for j in range(2)]))
            g.append(edge_pair_index_from_colors[piece_colors])

        v = [False for _ in range(12)]
        stack = [i for i in reversed(range(12))]

        while stack:
          node = stack.pop()
          v[node] = True
          if not v[g[node]]:
            parity += 1
            stack.append(g[node])

        return parity % 2

    def finger(self, move):
        """Applies a single move on the cube state using move string."""
        if move[0] in self.rotations:
            if move[-1] == "'":
                for _ in range(3):
                    self.apply_scramble(self.rotations[move[0]])
            else:
                self.apply_scramble(self.rotations[move])
            return
        self.state[self.sticker_target[move]] = self.state[self.sticker_source[move]]

    def finger_ix(self, ix):
        """The same `finger` method **but using indices of moves for faster execution"""
        self.state[self.sticker_target_ix[ix]] = self.state[self.sticker_source_ix[ix]]

    def apply_scramble(self, scramble):
        """Applies a sequence of moves (scramble) to the cube state."""
        if isinstance(scramble, str):
            scramble = scramble.split()

        scramble2 = []
        for i in range(len(scramble)):
            a = ""
            if scramble[i][0].isdigit() or scramble[i][0] in self.rotations:
              a += scramble[i]
              scramble2.append(a)
              continue
            if "w" in scramble[i]:
                a += "2"
            else:
                a += "1"
            a += scramble[i][0]
            if scramble[i][-1] == "'":
                a += "'"
            elif scramble[i][-1] == "2":
                a += "2"
            scramble2.append(a)
            if a[0] == "2":
                scramble2.append("1" + a[1:])
        for m in scramble2:
            if m[-1]=='2':
                for _ in range(2):
                    self.finger(m[:-1])
            else:
                    self.finger(m)

    def scrambler(self, scramble_length):
        """
        Generates a random scramble of given length and returns the cube state and scramble moves as a generator.
        Please note that index-based implementations (faster) follow commented lexical logics.
        """
        while True:
            # Reset the cube state, scramble, and return cube state and scramble moves
            self.reset()
            scramble = []

            for i in range(scramble_length):
                if i:
                    last_move = scramble[-1]
                    if i > 1:   # [3rd~ moves]
                        while True:
                            # move = random.choice(self.moves_available_after[last_move])
                            move = random.choice(self.moves_ix_available_after[last_move])

                            if scramble[-2] == last_move == move:
                                # Three subsequent moves on the same face, which could be one
                                continue
                            # elif (
                            #     scramble[-2][0] == move[0] and len(scramble[-2] + move) == 3
                            #     and last_move[0] == self.pairing[move[0]]
                            # ):
                            # elif (
                                # scramble[-2]//2 == move//2 and scramble[-2]%2 != move%2
                                # and last_move//2 == self.pairing_ix[move//2]
                            # ):
                                # Two mutually canceling moves sandwiching an opposite face move
                                # continue
                            else:
                                break
                    else:       # [2nd move]
                        # move = random.choice(self.moves_available_after[last_move])
                        move = random.choice(self.moves_ix_available_after[last_move])
                else:           # [1st move]
                    # move = random.choice(self.moves)
                    move = random.choice(self.moves_ix)

                # self.finger(move)
                self.finger_ix(move)
                scramble.append(move)

                yield self.state, move


    def __vectorize_moves(self):
        """
        Vectorizes the sticker group replacement operations for faster computation.
        This method defines ```self.sticker_target``` and ```self.sticker_source``` to manage sticker colors (target is replaced by source).
        They define indices of target and source stickers so that the moves can be vectorized.
        """
        self.sticker_target, self.sticker_source = dict(), dict()

        self.sticker_replacement = {
            # Sticker A is replaced by another sticker at index B -> A:B
            '1U':{0: 12, 1: 8, 2: 4, 3: 0, 4: 13, 5: 9, 6: 5, 7: 1, 8: 14, 9: 10, 10: 6, 11: 2, 12: 15, 13: 11, 14: 7, 15: 3, 16: 32, 17: 33, 18: 34, 19: 35, 32: 48, 33: 49, 34: 50, 35: 51, 48: 64, 49: 65, 50: 66, 51: 67, 64: 16, 65: 17, 66: 18, 67: 19},
            '1D':{83: 80, 87: 81, 91: 82, 95: 83, 82: 84, 86: 85, 90: 86, 94: 87, 81: 88, 85: 89, 89: 90, 93: 91, 80: 92, 84: 93, 88: 94, 92: 95, 44: 28, 45: 29, 46: 30, 47: 31, 60: 44, 61: 45, 62: 46, 63: 47, 76: 60, 77: 61, 78: 62, 79: 63, 28: 76, 29: 77, 30: 78, 31: 79},
            '1L':{16: 28, 17: 24, 18: 20, 19: 16, 20: 29, 21: 25, 22: 21, 23: 17, 24: 30, 25: 26, 26: 22, 27: 18, 28: 31, 29: 27, 30: 23, 31: 19, 0: 79, 4: 75, 8: 71, 12: 67, 32: 0, 36: 4, 40: 8, 44: 12, 80: 32, 84: 36, 88: 40, 92: 44, 67: 92, 71: 88, 75: 84, 79: 80},
            '1R':{48: 60, 49: 56, 50: 52, 51: 48, 52: 61, 53: 57, 54: 53, 55: 49, 56: 62, 57: 58, 58: 54, 59: 50, 60: 63, 61: 59, 62: 55, 63: 51, 3: 35, 7: 39, 11: 43, 15: 47, 35: 83, 39: 87, 43: 91, 47: 95, 83: 76, 87: 72, 91: 68, 95: 64, 64: 15, 68: 11, 72: 7, 76: 3},
            '1B':{64: 76, 65: 72, 66: 68, 67: 64, 68: 77, 69: 73, 70: 69, 71: 65, 72: 78, 73: 74, 74: 70, 75: 66, 76: 79, 77: 75, 78: 71, 79: 67, 0: 51, 1: 55, 2: 59, 3: 63, 51: 95, 55: 94, 59: 93, 63: 92, 92: 16, 93: 20, 94: 24, 95: 28, 16: 3, 20: 2, 24: 1, 28: 0},
            '1F':{32: 44, 33: 40, 34: 36, 35: 32, 36: 45, 37: 41, 38: 37, 39: 33, 40: 46, 41: 42, 42: 38, 43: 34, 44: 47, 45: 43, 46: 39, 47: 35, 12: 31, 13: 27, 14: 23, 15: 19, 48: 12, 52: 13, 56: 14, 60: 15,  80: 60, 81: 56, 82: 52, 83: 48, 19: 80, 23: 81, 27: 82, 31: 83},
            '2U':{20: 36, 21: 37, 22: 38, 23: 39, 36: 52, 37: 53, 38: 54, 39: 55, 52: 68, 53: 69, 54: 70, 55: 71, 68: 20, 69: 21, 70: 22, 71: 23} | {a: a for a in range(0, 16)},
            '2D':{24: 72, 25: 73, 26: 74, 27: 75, 40: 24, 41: 25, 42: 26, 43: 27, 56: 40, 57: 41, 58: 42, 59: 43, 72: 56, 73: 57, 74: 58, 75: 59} | {a: a for a in range(0, 16)},
            '2L':{1: 78, 5: 74, 9: 70, 13: 66, 33: 1, 37: 5, 41: 9, 45: 13, 81: 33, 85: 37, 89: 41, 93: 45, 66: 93, 70: 89, 74: 85, 78: 81} | {a: a for a in range(16, 32)},
            '2R':{2: 34, 6: 38, 10: 42, 14: 46, 34: 82, 38: 86, 42: 90, 46: 94, 82: 77, 86: 73, 90: 69, 94: 65, 65: 14, 69: 10, 73: 6, 77: 2} | {a: a for a in range(16, 32)},
            '2B':{4: 50, 5: 54, 6: 58, 7: 62, 50: 91, 54: 90, 58: 89, 62: 88, 88: 17, 89: 21, 90: 25, 91: 29, 17: 7, 21: 6, 25: 5, 29: 4} | {a: a for a in range(32, 48)},
            '2F':{8: 30, 9: 26, 10: 22, 11: 18, 18: 84, 22: 85, 26: 86, 30: 87, 84: 61, 85: 57, 86: 53, 87: 49, 49: 8, 53: 9, 57: 10, 61: 11} | {a: a for a in range(32, 48)}
        }
        for m in self.moves:
            if len(m) == 2:
                assert m in self.sticker_replacement
            else:
                if m[-1] == "'":
                    self.sticker_replacement[m] = {
                        v: k for k, v in self.sticker_replacement[m[:2]].items()
                    }
                elif m[-1] == "2":
                    self.sticker_replacement[m] = {
                        k: self.sticker_replacement[m[:2]][v]
                        for k, v in self.sticker_replacement[m[:2]].items()
                    }
                else:
                    raise

            self.sticker_target[m] = list(self.sticker_replacement[m].keys())
            self.sticker_source[m] = list(self.sticker_replacement[m].values())

            for i, idx in enumerate(self.sticker_target[m]):
                assert self.sticker_replacement[m][idx] == self.sticker_source[m][i]

        # For index slicing
        self.sticker_target_ix = np.array([np.array(self.sticker_target[m]) for m in self.moves])
        self.sticker_source_ix = np.array([np.array(self.sticker_source[m]) for m in self.moves])


In [16]:
env = Cube4()

In [17]:
class LinearBlock(nn.Module):
    """
    Linear layer with ReLU and BatchNorm
    """
    def __init__(self, input_prev, embed_dim):
        super(LinearBlock, self).__init__()
        self.fc = nn.Linear(input_prev, embed_dim)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(embed_dim)

    def forward(self, inputs):
        x = inputs
        x = self.fc(x)
        x = self.relu(x)
        x = self.bn(x)
        return x

class ResidualBlock(nn.Module):
    """
    Residual block with two linear layers
    """
    def __init__(self, embed_dim):
        super(ResidualBlock, self).__init__()
        self.layers = nn.ModuleList([
            LinearBlock(embed_dim, embed_dim),
            LinearBlock(embed_dim, embed_dim)
        ])

    def forward(self, inputs):
        x = inputs
        for layer in self.layers:
            x = layer(x)
        x += inputs # skip-connection
        return x

class Model(nn.Module):
    """
    Fixed architecture following DeepCubeA.
    """
    def __init__(self, input_dim=576, output_dim=len(env.moves)):
        super(Model, self).__init__()
        self.input_dim = input_dim
        self.embedding = LinearBlock(input_dim, 5000)
        self.layers = nn.ModuleList([
            LinearBlock(5000,1000),
            ResidualBlock(1000),
            ResidualBlock(1000),
            ResidualBlock(1000),
            ResidualBlock(1000)
        ])
        self.output = nn.Linear(1000, output_dim)

    def forward(self, inputs):
        # int indices => float one-hot vectors
        x = nn.functional.one_hot(inputs, num_classes=6).to(torch.float)
        x = x.reshape(-1, self.input_dim)
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        logits = self.output(x)
        return logits

model = Model()
#model.load_state_dict(torch.load('/content/drive/MyDrive/no_parity_reduced_15000steps.pth', map_location=device))
model.to(device)

Model(
  (embedding): LinearBlock(
    (fc): Linear(in_features=576, out_features=5000, bias=True)
    (relu): ReLU()
    (bn): BatchNorm1d(5000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layers): ModuleList(
    (0): LinearBlock(
      (fc): Linear(in_features=5000, out_features=1000, bias=True)
      (relu): ReLU()
      (bn): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1-4): 4 x ResidualBlock(
      (layers): ModuleList(
        (0-1): 2 x LinearBlock(
          (fc): Linear(in_features=1000, out_features=1000, bias=True)
          (relu): ReLU()
          (bn): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
  )
  (output): Linear(in_features=1000, out_features=24, bias=True)
)

In [21]:
class ScrambleGenerator(torch.utils.data.Dataset):
    def __init__(
            self,
            num_workers=os.cpu_count(),
            max_depth=TrainConfig.max_depth,
            total_samples=TrainConfig.num_steps*TrainConfig.batch_size_per_depth
        ):
        self.num_workers = num_workers
        self.max_depth = max_depth
        self.envs = [Cube4() for _ in range(num_workers)]
        [env.reset(train=True) for env in self.envs]
        self.generators = [env.scrambler(self.max_depth) for env in self.envs]

        self.total_samples = total_samples

    def __len__(self):
        return self.total_samples

    def __getitem__(self, i):
        ''' generate one scramble, consisting of `self.max_depth` data points '''
        worker_idx = i % self.num_workers
        X = np.zeros((self.max_depth, 96), dtype=int)
        y = np.zeros((self.max_depth,), dtype=int)
        for j in range(self.max_depth):
            state, last_move = next(self.generators[worker_idx])
            X[j, :] = state
            y[j] = last_move
        return X, y

dataloader = torch.utils.data.DataLoader(
    ScrambleGenerator(),
    num_workers=os.cpu_count() if torch.cuda.is_available() else 0, # can't multiprocess with CPU I guess? Not entirely sure, but if I'm not using CUDA then num_workers must be 0 to work
    batch_size=TrainConfig.batch_size_per_depth
)

In [22]:
# models/cube4.pth was trained for 15,000 steps

def plot_loss_curve(h, iteration):
    _, ax = plt.subplots(1, 1)
    ax.plot(h)
    ax.set_xlabel("Steps")
    ax.set_ylabel("Cross-entropy loss")
    ax.set_xscale("log")

    plt.show()

def train(model, dataloader):
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=TrainConfig.learning_rate)
    g = iter(dataloader)
    h = []
    ctx = torch.cuda.amp.autocast(dtype=torch.float16) if TrainConfig.ENABLE_FP16 else nullcontext()

    for i in trange(1, TrainConfig.num_steps + 1):
        batch_x, batch_y = next(g)
        batch_x, batch_y = batch_x.reshape(-1, 96).to(device), batch_y.reshape(-1).to(device)

        with ctx:
            pred_y = model(batch_x)
            loss = loss_fn(pred_y, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        h.append(loss.item())
        if TrainConfig.INTERVAL_PLOT and i % TrainConfig.INTERVAL_PLOT == 0:
            clear_output()
            plot_loss_curve(h, model.state_dict())
        if TrainConfig.INTERVAL_SAVE and i % TrainConfig.INTERVAL_SAVE == 0:
            torch.save(model.state_dict(), f"/content/drive/MyDrive/parity_test_{i}steps.pth")
            print("Model saved.")
    print(f"Trained on data equivalent to {TrainConfig.batch_size_per_depth * TrainConfig.num_steps} solves.")
    return model

model = train(model, dataloader)

  0%|          | 11/10000 [00:53<13:36:02,  4.90s/it]


KeyboardInterrupt: 

In [28]:
model = Model(input_dim=576, output_dim=len(env.moves))
model.load_state_dict(torch.load('../efficientcube/models/cube4.pth', map_location=device, weights_only=True))
model.to(device)

Model(
  (embedding): LinearBlock(
    (fc): Linear(in_features=576, out_features=5000, bias=True)
    (relu): ReLU()
    (bn): BatchNorm1d(5000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layers): ModuleList(
    (0): LinearBlock(
      (fc): Linear(in_features=5000, out_features=1000, bias=True)
      (relu): ReLU()
      (bn): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1-4): 4 x ResidualBlock(
      (layers): ModuleList(
        (0-1): 2 x LinearBlock(
          (fc): Linear(in_features=1000, out_features=1000, bias=True)
          (relu): ReLU()
          (bn): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
  )
  (output): Linear(in_features=1000, out_features=24, bias=True)
)

In [34]:
@torch.no_grad()
def beam_search(
        env,
        model,
        beam_width=SearchConfig.beam_width,
        max_depth=SearchConfig.max_depth,
        skip_redundant_moves=True,
    ):
    """
    Best-first search algorithm.
    Input:
        env: A scrambled instance of the given environment.
        model: PyTorch model used to predict the next move(s).
        beam_width: Number of top solutions to return per depth.
        max_depth: Maximum depth of the search tree.
        skip_redundant_moves: If True, skip redundant moves.
    Output:
        if solved successfully:
            True, {'solutions':solution path, "num_nodes_generated":number of nodes expanded, "times":time taken to solve}
        else:
            False, None
    """
    model.eval()
    with torch.cuda.amp.autocast(dtype=torch.float16) if SearchConfig.ENABLE_FP16 else nullcontext():
        # metrics
        num_nodes_generated, time_0 = 0, time.time()
        candidates = [
            {"state":deepcopy(env.state), "path":[], "value":1.}
        ] # list of dictionaries

        for depth in range(max_depth+1):
            # TWO things at a time for every candidate: 1. check if solved & 2. add to batch_x
            batch_x = np.zeros((len(candidates), env.state.shape[-1]), dtype=np.int64)
            for i,c in enumerate(candidates):
                c_path, env.state = c["path"], c["state"]
                if c_path:
                    env.finger_ix(c_path[-1])
                    num_nodes_generated += 1
                    if env.is_solved():
                        # Revert: array of indices => array of notations
                        c_path = [str(env.moves[i]) for i in c_path]
                        return True, {'solutions':c_path, "num_nodes_generated":num_nodes_generated, "times":time.time()-time_0}
                batch_x[i, :] = env.state

            # after checking the nodes expanded at the deepest
            if depth==max_depth:
                print("Solution not found.")
                return False, None

            # make predictions with the trained DNN
            batch_x = torch.from_numpy(batch_x).to(device)
            batch_p = model(batch_x)
            batch_p = torch.nn.functional.softmax(batch_p, dim=-1)
            batch_p = batch_p.detach().cpu().numpy()

            # loop over candidates
            candidates_next_depth = []  # storage for the depth-level candidates storing (path, value, index).
            for i, c in enumerate(candidates):
                c_path = c["path"]
                value_distribution = batch_p[i, :] # output logits for the given state
                value_distribution *= c["value"] # multiply the cumulative probability so far of the expanded path

                for m, value in zip(env.moves_ix_inference, value_distribution): # iterate over all possible moves.
                    # predicted value to expand the path with the given move.

                    if c_path and skip_redundant_moves:
                        if m not in env.moves_ix_available_after[c_path[-1]]:
                            # Two mutually canceling moves
                            continue
                        elif len(c_path) > 1:
                            # if c_path[-2] == c_path[-1] == m:
                            if c_path[-2] == c_path[-1] == m:
                                # Three subsequent moves that could be one
                                continue
                            # elif (
                            #     c_path[-2][0] == m[0] and len(c_path[-2] + m) == 3
                            #     and c_path[-1][0] == env.pairing[m[0]]
                            # ):
                            #elif (
                                #c_path[-2]//2 == m//2 and c_path[-2]%2 != m%2
                                #and c_path[-1]//2 == env.pairing_ix[m//2]
                            #):
                                # Two mutually canceling moves sandwiching an opposite face move
                                #continue

                    # add to the next-depth candidates unless 'continue'd.
                    candidates_next_depth.append({
                        'state':deepcopy(c['state']),
                        "path": c_path+[m],
                        "value":value,
                    })

            # sort potential paths by expected values and renew as 'candidates'
            candidates = sorted(candidates_next_depth, key=lambda item: -item['value'])
            # if the number of candidates exceed that of beam width 'beam_width'
            candidates = candidates[:int(beam_width)]


In [35]:
result_ours = {
    "solutions":[],
    "num_nodes_generated":[],
    "times":[]
}
test_scrambles = ["L2 F B' L F2 D' B' R F' U L2 D R2 B2 L2 D R2 U2 L2 Rw2 F Fw2 D B2 U2 Rw2 D B U2 Rw2 Uw2 B2 R' U Rw' U2 R' Fw' Uw' Fw2 R' F Uw2 B' U"]
for scramble in tqdm(test_scrambles, position=0):
    # reset and scramble
    env.reset()
    env.apply_scramble(scramble)

    #env = scramble
    # solve
    success, result = beam_search(env, model)
    if success:
        for k in result_ours.keys():
            result_ours[k].append(result[k])
    else:
        result_ours["solutions"].append(None)
    #print(env)

result_ours['solution_lengths'] = [len(e) for e in result_ours['solutions'] if e]
result_ours['solution_lengths_count'] = {
    i: result_ours["solution_lengths"].count(i)
    for i in range(min(result_ours["solution_lengths"]), max(result_ours["solution_lengths"]))
}
f"Successfully solved {len(result_ours['times'])} cases out of {len(result_ours['solutions'])}"

100%|██████████| 1/1 [00:06<00:00,  6.22s/it]


'Successfully solved 1 cases out of 1'

In [36]:
print(result_ours)
print(env)
result_ours

{'solutions': [["1F'", '2L', '1U', '2F', '2R', "1F'", "2U'", "2D'", '2B', "1U'", '2R', "2F'", "1B'", "1R'", "2U'", '2D', '1R', '2B', "1R'", '2D', '1L', "1F'", "1R'", "1U'", '1F', '1D', '1R', "2D'"]], 'num_nodes_generated': [51740], 'times': [6.2072529792785645], 'solution_lengths': [28], 'solution_lengths_count': {}}
         1 3 3 1 
         0 0 0 3 
         0 0 0 3 
         5 1 1 1 

4 4 4 4  3 4 4 5  2 4 4 4  5 2 2 0  
1 1 1 0  1 2 2 2  0 3 3 5  3 4 4 5  
1 1 1 0  1 2 2 2  0 3 3 5  3 4 4 5  
3 2 2 3  5 2 2 3  0 4 4 1  0 0 0 0  

         2 5 5 4 
         1 5 5 5 
         1 5 5 5 
         2 3 3 2 



{'solutions': [["1F'",
   '2L',
   '1U',
   '2F',
   '2R',
   "1F'",
   "2U'",
   "2D'",
   '2B',
   "1U'",
   '2R',
   "2F'",
   "1B'",
   "1R'",
   "2U'",
   '2D',
   '1R',
   '2B',
   "1R'",
   '2D',
   '1L',
   "1F'",
   "1R'",
   "1U'",
   '1F',
   '1D',
   '1R',
   "2D'"]],
 'num_nodes_generated': [51740],
 'times': [6.2072529792785645],
 'solution_lengths': [28],
 'solution_lengths_count': {}}