<a href="https://colab.research.google.com/github/lambroz/Rubik-Cube/blob/main/CubeGym.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load libraries

In [1]:
!sudo apt-get update
!pip install rubik-cube
!pip install gym==0.21.0
!pip install stable_baselines3
!pip install sb3-contrib

0% [Working]            Get:1 http://security.ubuntu.com/ubuntu bionic-security InRelease [88.7 kB]
0% [Connecting to archive.ubuntu.com (185.125.190.36)] [1 InRelease 14.2 kB/88.                                                                               Ign:2 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  InRelease
0% [Connecting to archive.ubuntu.com (185.125.190.36)] [1 InRelease 25.8 kB/88.                                                                               Get:3 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease [3,626 B]
0% [Connecting to archive.ubuntu.com (185.125.190.36)] [1 InRelease 28.7 kB/88.0% [Connecting to archive.ubuntu.com (185.125.190.36)] [1 InRelease 40.2 kB/88.                                                                               Hit:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  InRelease
0% [Connecting to archive.ubuntu.com (185.125.190.3

In [2]:
import gym
import math
import time
import numpy as np
from random import choices
from rubik.cube import Cube
from google.colab import drive
from stable_baselines3 import PPO, DQN
from sb3_contrib import QRDQN
from stable_baselines3.common.env_checker import check_env

In [3]:
drive.mount('/content/drive')

Mounted at /content/drive


# Set up environment

In [4]:
class VCube(Cube):
    
    COLORS = ['B', 'G', 'O', 'R', 'W', 'Y']
    MOVES = ['F', 'B', 'R', 'L', 'D', 'U'] 

    def __init__(self, cube_str='OOOOOOOOOYYYWWWGGGBBBYYYWWWGGGBBBYYYWWWGGGBBBRRRRRRRRR', random_shuffles=0):
        super().__init__(cube_str)
        if random_shuffles:
          self.shuffle(random_shuffles) 
        
    def cube_key(self):
        """
        Creates a signature key that identifies the cube.
        """
        # Starts from the flat_str representation.
        s = self.flat_str()
        # Removes the letters in the center that do not influence the cube state.
        s_no_centers = s[:4] + s[5:22] + s[23:25] + s[26:28] + s[29:31] + s[32:49] + s[50:]
        # One-hot encoding of the colours.
        key = np.array([1 if self.COLORS[i] == l else 0 for l in s_no_centers for i in range(6)], dtype=np.int32)
        return key

    @classmethod
    def decode_cube_key(cls, key):
      """
      Decodes signature key and returns cube string.
      """
      cube_str = ''
      for k, b in enumerate(key):
        if b == 1:
          cube_str += cls.COLORS[k % 6]
        # Add center pieces
        if len(cube_str) == 4:
          cube_str += 'O'
        elif len(cube_str) == 22:
          cube_str += 'Y'
        elif len(cube_str) == 25:
          cube_str += 'W'
        elif len(cube_str) == 28:
          cube_str += 'G'
        elif len(cube_str) == 31:
          cube_str += 'B'
        elif len(cube_str) == 49:
          cube_str += 'R'
      return cube_str 

    def shuffle(self, k=20):
        """
        Shuffle the cube with n moves.
        """
        seq = ' '.join(choices(self.MOVES, k=k))
        self.sequence(seq)
        
    def evaluate_cube(self):
        """
        Count pieces in the correct positions.
        """
        solved_cube = 'OOOOOOOOOYYYWWWGGGBBBYYYWWWGGGBBBYYYWWWGGGBBBRRRRRRRRR'
        s = self.flat_str()
        n_matching_pieces = sum([1 if s[i] == solved_cube[i] else 0 for i in range(54)])
        return n_matching_pieces

In [5]:
solved_cube = VCube('O' * 9 + ('Y' * 3 + 'W' * 3 + 'G' * 3 + 'B' * 3) * 3 + 'R' * 9)
#solved_cube.shuffle()
print(solved_cube)

    OOO
    OOO
    OOO
YYY WWW GGG BBB
YYY WWW GGG BBB
YYY WWW GGG BBB
    RRR
    RRR
    RRR


In [6]:
print(solved_cube.flat_str())

key = solved_cube.cube_key()
print(len(key))

decoded_key = VCube.decode_cube_key(key)
print(decoded_key)

OOOOOOOOOYYYWWWGGGBBBYYYWWWGGGBBBYYYWWWGGGBBBRRRRRRRRR
288
OOOOOOOOOYYYWWWGGGBBBYYYWWWGGGBBBYYYWWWGGGBBBRRRRRRRRR


In [11]:
class RubikCubeEnv(gym.Env):
    """Custom Environment that follows gym interface"""
    metadata = {"render.modes": ["human"]}

    def __init__(self):
        super(RubikCubeEnv, self).__init__()
        self.action_space = gym.spaces.Discrete(len(VCube.MOVES))
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(288,), dtype=np.int32)
        self.cube = None
        self.n_moves = None

    def step(self, action):
      # Play
      move = self.cube.MOVES[action]
      self.cube.sequence(move)
      self.n_moves += 1
      # Observe cube and compute rewards
      observation = self.cube.cube_key() 
      reward = math.pow(self.cube.evaluate_cube(), 2) / 1e7
      if self.cube.is_solved():
          reward += 1
          done = True
      elif self.n_moves == 1000:
          done = True
      else:
          done = False
      return observation, reward, done, {}

    def reset(self):
        # Reset counter
        self.n_moves = 0
        # Reshuffle cube and observe it
        random_shuffles = round(np.random.normal(loc=5, scale=7, size=None)) + 2
        self.cube = VCube(random_shuffles=random_shuffles)
        observation = self.cube.cube_key()
        return observation  

    def render(self, mode="human"):
        print(self.cube)

    def close(self):
        pass

    def get_n_moves(self):
      return self.n_moves

# Test the environment

In [12]:
env = RubikCubeEnv()
check_env(env)
env.reset()

array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
       0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 1], dtype=int32)

# Create the model

In [13]:
algo = 'QRDQN-1024-1024-128-v4'
models_dir = f'drive/MyDrive/models/{algo}'
logdir = 'drive/MyDrive/models/logs'

In [14]:
#model = QRDQN('MlpPolicy', env, verbose=1, tensorboard_log=logdir, policy_kwargs={'net_arch' : [1024, 1024, 128]})
#model = DQN('MlpPolicy', env, verbose=1, tensorboard_log=logdir, policy_kwargs={'net_arch' : [512, 512]})
#model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=logdir)
#model = PPO.load('drive/MyDrive/models/PPO/6000000', env=env)
model = DQN.load('drive/MyDrive/models/DQN512/4000000', env=env)
#model = QRDQN.load('drive/MyDrive/models/QRDQN-512-512-512-v2/2000000', env=env)
#model = QRDQN.load('drive/MyDrive/models/QRDQN-1024-1024-128-v4/6000000', env=env)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [None]:
TIMESTEPS = 1000000
for i in range(1, 20):
    model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False, tb_log_name=algo)
    model.save(f'{models_dir}/{TIMESTEPS*i}')

In [22]:
obs = env.reset()
env.render()
done = False
while not done:
  action, _states = model.predict(obs)
  obs, rewards, done, info = env.step(action)
  counter = env.get_n_moves()
  if counter % 100 == 1:
    print('---------------------')
    env.render()
    print(f'counter: {counter}')
    print(f'reward: {rewards}')

    BRW
    OOO
    WWW
GWR GGG OBY RBO
GYY OWR GGY RBO
RYY OWR GGO YYY
    WRB
    WRB
    BBB
---------------------
    YYO
    OOO
    WWW
WWR GGG OBB YRR
RYY OWR GGB YBB
BYY OWR GGB YOO
    WRB
    WRB
    GGR
counter: 1
reward: 3.61e-05
---------------------
    WOW
    WOW
    BRB
YYG RWR YGG OBO
YYG RWR YGG OBO
YYG RWR YGG OBO
    WOW
    BRB
    BRB
counter: 101
reward: 4.84e-05
---------------------
    OOR
    OOR
    OOR
YYY WWB GGG WBB
YYY WWB GGG WBB
YYY WWB GGG WBB
    RRO
    RRO
    RRO
counter: 201
reward: 0.0001764
---------------------
    OOR
    OOR
    OOR
YYY WWB GGG WBB
YYY WWB GGG WBB
YYY WWB GGG WBB
    RRO
    RRO
    RRO
counter: 301
reward: 0.0001764
---------------------
    RRY
    OOY
    OOY
WBB YYR WWW OGG
YYR WWW OGG WBB
YYO BBB RGG WBB
    GGG
    RRO
    RRO
counter: 401
reward: 5.29e-05
---------------------
    OOR
    OOR
    OOR
YYY WWB GGG WBB
YYY WWB GGG WBB
YYY WWB GGG WBB
    RRO
    RRO
    RRO
counter: 501
reward: 0.0001764
---------------