# Rubik's Cube Reinforcement Learning Solver

This notebook demonstrates how to build a reinforcement learning agent that solves a simplified (2x2x2) Rubik's Cube.

## Overview
- Create a Rubik's Cube environment
- Implement Deep Q-Network (DQN)
- Train the agent
- Visualize the solution

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque, namedtuple
import matplotlib.pyplot as plt
from copy import deepcopy

# Pick the compute device: a CUDA GPU if PyTorch can see one, otherwise the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Using device: {device}')

## 1. Rubik's Cube Environment

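We'll implement the 2x2x2 cube (Pocket Cube) instead of the standard 3x3x3: it has only 24 stickers and about 3.7 million reachable positions, which makes it much easier for an agent to learn.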

In [None]:
class RubiksCube2x2:
    """
    2x2x2 Rubik's Cube (Pocket Cube) environment.

    Faces: 0=Front, 1=Back, 2=Up, 3=Down, 4=Left, 5=Right
    Colors: 0-5 representing each face color

    ``self.state`` is a (6, 2, 2) integer array indexed as
    ``state[face][row, col]``. In the solved state, face ``i`` holds color ``i``.

    Sticker layout: the index patterns in the ``_move_*`` methods imply which
    row/column of each face borders which neighbouring face:
      - Front, Back, Left, Right: row 0 borders Up, row 1 borders Down.
      - Up:    row 0 borders Back,  row 1 borders Front, col 0 Left, col 1 Right.
      - Down:  row 0 borders Front, row 1 borders Back,  col 0 Left, col 1 Right.
      - Front: col 0 borders Left,  col 1 borders Right.
      - Back:  col 0 borders Right, col 1 borders Left.
      - Left:  col 0 borders Back,  col 1 borders Front.
      - Right: col 0 borders Front, col 1 borders Back.
    Where two neighbouring strips run in opposite index order, the move
    methods reverse the strip with ``[::-1]`` while transferring it.
    """
    
    # Actions: each face can be rotated clockwise or counter-clockwise.
    # Action index i maps to ACTIONS[i]; a trailing "'" marks counter-clockwise.
    ACTIONS = ['F', "F'", 'B', "B'", 'U', "U'", 'D', "D'", 'L', "L'", 'R', "R'"]
    
    def __init__(self):
        # Start from the solved cube.
        self.reset()
    
    def reset(self):
        """Reset to the solved state and return the network-ready state vector."""
        # Each face is a 2x2 array filled with its own color index.
        self.state = np.array([np.full((2, 2), i) for i in range(6)])
        return self.get_state()
    
    def get_state(self):
        """Return the state flattened to 24 float32 values in [0, 1].

        Each value is a color index (0-5) divided by 5.
        """
        return self.state.flatten().astype(np.float32) / 5.0
    
    def is_solved(self):
        """Return True if every face shows a single color.

        A face only has to be uniform; its color need not equal its face index.
        """
        for face in self.state:
            if not np.all(face == face[0, 0]):
                return False
        return True
    
    def _rotate_face(self, face_idx, clockwise=True):
        """Rotate the stickers of one face 90 degrees in place.

        The direction refers to the face's 2x2 array as indexed (row 0 on top):
        ``np.rot90`` with k=-1 turns it clockwise and k=1 counter-clockwise.
        """
        if clockwise:
            self.state[face_idx] = np.rot90(self.state[face_idx], -1)
        else:
            self.state[face_idx] = np.rot90(self.state[face_idx], 1)
    
    def move(self, action_idx):
        """Apply the move ``ACTIONS[action_idx]`` to the cube in place.

        Raises IndexError if ``action_idx`` is outside ``range(len(ACTIONS))``.
        """
        action = self.ACTIONS[action_idx]
        clockwise = "'" not in action
        face = action[0]
        
        if face == 'F':
            self._move_front(clockwise)
        elif face == 'B':
            self._move_back(clockwise)
        elif face == 'U':
            self._move_up(clockwise)
        elif face == 'D':
            self._move_down(clockwise)
        elif face == 'L':
            self._move_left(clockwise)
        elif face == 'R':
            self._move_right(clockwise)
    
    def _move_front(self, clockwise=True):
        """Rotate front face and cycle the adjacent strips.

        Clockwise strips move Up row 1 -> Right col 0 -> Down row 0 ->
        Left col 1 -> Up row 1; counter-clockwise is the reverse cycle.
        """
        self._rotate_face(0, clockwise)
        # Save the Up strip before it is overwritten by the first assignment.
        temp = self.state[2][1, :].copy()
        if clockwise:
            self.state[2][1, :] = self.state[4][:, 1][::-1]
            self.state[4][:, 1] = self.state[3][0, :]
            self.state[3][0, :] = self.state[5][:, 0][::-1]
            self.state[5][:, 0] = temp
        else:
            self.state[2][1, :] = self.state[5][:, 0]
            self.state[5][:, 0] = self.state[3][0, :][::-1]
            self.state[3][0, :] = self.state[4][:, 1]
            self.state[4][:, 1] = temp[::-1]
    
    def _move_back(self, clockwise=True):
        """Rotate back face and cycle the adjacent strips.

        Clockwise strips move Up row 0 -> Left col 0 -> Down row 1 ->
        Right col 1 -> Up row 0; counter-clockwise is the reverse cycle.
        """
        self._rotate_face(1, clockwise)
        # Save the Up strip before it is overwritten by the first assignment.
        temp = self.state[2][0, :].copy()
        if clockwise:
            self.state[2][0, :] = self.state[5][:, 1]
            self.state[5][:, 1] = self.state[3][1, :][::-1]
            self.state[3][1, :] = self.state[4][:, 0]
            self.state[4][:, 0] = temp[::-1]
        else:
            self.state[2][0, :] = self.state[4][:, 0][::-1]
            self.state[4][:, 0] = self.state[3][1, :]
            self.state[3][1, :] = self.state[5][:, 1][::-1]
            self.state[5][:, 1] = temp
    
    def _move_up(self, clockwise=True):
        """Rotate up face and cycle the top rows of the side faces.

        Clockwise rows move Front -> Left -> Back -> Right -> Front;
        counter-clockwise is the reverse cycle. No reversal is needed because
        all side faces index their top row in a consistent order.
        """
        self._rotate_face(2, clockwise)
        # Save the Front strip before it is overwritten by the first assignment.
        temp = self.state[0][0, :].copy()
        if clockwise:
            self.state[0][0, :] = self.state[5][0, :]
            self.state[5][0, :] = self.state[1][0, :]
            self.state[1][0, :] = self.state[4][0, :]
            self.state[4][0, :] = temp
        else:
            self.state[0][0, :] = self.state[4][0, :]
            self.state[4][0, :] = self.state[1][0, :]
            self.state[1][0, :] = self.state[5][0, :]
            self.state[5][0, :] = temp
    
    def _move_down(self, clockwise=True):
        """Rotate down face and cycle the bottom rows of the side faces.

        Clockwise rows move Front -> Right -> Back -> Left -> Front;
        counter-clockwise is the reverse cycle.
        """
        self._rotate_face(3, clockwise)
        # Save the Front strip before it is overwritten by the first assignment.
        temp = self.state[0][1, :].copy()
        if clockwise:
            self.state[0][1, :] = self.state[4][1, :]
            self.state[4][1, :] = self.state[1][1, :]
            self.state[1][1, :] = self.state[5][1, :]
            self.state[5][1, :] = temp
        else:
            self.state[0][1, :] = self.state[5][1, :]
            self.state[5][1, :] = self.state[1][1, :]
            self.state[1][1, :] = self.state[4][1, :]
            self.state[4][1, :] = temp
    
    def _move_left(self, clockwise=True):
        """Rotate left face and cycle the adjacent columns.

        Clockwise strips move Front col 0 -> Down col 0 -> Back col 1 ->
        Up col 0 -> Front col 0; counter-clockwise is the reverse cycle.
        """
        self._rotate_face(4, clockwise)
        # Save the Front strip before it is overwritten by the first assignment.
        temp = self.state[0][:, 0].copy()
        if clockwise:
            self.state[0][:, 0] = self.state[2][:, 0]
            self.state[2][:, 0] = self.state[1][:, 1][::-1]
            self.state[1][:, 1] = self.state[3][:, 0][::-1]
            self.state[3][:, 0] = temp
        else:
            self.state[0][:, 0] = self.state[3][:, 0]
            self.state[3][:, 0] = self.state[1][:, 1][::-1]
            self.state[1][:, 1] = self.state[2][:, 0][::-1]
            self.state[2][:, 0] = temp
    
    def _move_right(self, clockwise=True):
        """Rotate right face and cycle the adjacent columns.

        Clockwise strips move Front col 1 -> Up col 1 -> Back col 0 ->
        Down col 1 -> Front col 1; counter-clockwise is the reverse cycle.
        """
        self._rotate_face(5, clockwise)
        # Save the Front strip before it is overwritten by the first assignment.
        temp = self.state[0][:, 1].copy()
        if clockwise:
            self.state[0][:, 1] = self.state[3][:, 1]
            self.state[3][:, 1] = self.state[1][:, 0][::-1]
            self.state[1][:, 0] = self.state[2][:, 1][::-1]
            self.state[2][:, 1] = temp
        else:
            self.state[0][:, 1] = self.state[2][:, 1]
            self.state[2][:, 1] = self.state[1][:, 0][::-1]
            self.state[1][:, 0] = self.state[3][:, 1][::-1]
            self.state[3][:, 1] = temp
    
    def scramble(self, num_moves=10):
        """Scramble the cube with ``num_moves`` random moves.

        Moves are drawn independently with ``random.randint``, so a scramble
        can partially or fully cancel itself (e.g. F followed by F').

        Returns:
            The list of applied move names, in order.
        """
        moves = []
        for _ in range(num_moves):
            action = random.randint(0, len(self.ACTIONS) - 1)
            self.move(action)
            moves.append(self.ACTIONS[action])
        return moves
    
    def visualize(self):
        """Print each face as a 2x2 grid of color letters."""
        face_names = ['Front', 'Back', 'Up', 'Down', 'Left', 'Right']
        colors = ['G', 'B', 'W', 'Y', 'O', 'R']  # Green, Blue, White, Yellow, Orange, Red
        
        for i, name in enumerate(face_names):
            print(f"{name}:")
            for row in self.state[i]:
                print(' '.join(colors[int(c)] for c in row))
            print()

# Sanity check: a fresh cube must report solved; then scramble it and look again.
def _show_cube(c):
    """Print every face of cube *c*, followed by whether it is solved."""
    c.visualize()
    print(f"Is solved: {c.is_solved()}")


cube = RubiksCube2x2()
print("Solved cube:")
_show_cube(cube)

moves = cube.scramble(5)
print(f"\nAfter scrambling with: {moves}")
_show_cube(cube)

## 2. Deep Q-Network (DQN)

In [None]:
class DQN(nn.Module):
    """Deep Q-Network for Rubik's Cube.

    A fully connected network that maps a flattened cube state to one Q-value
    per action. Each hidden layer is ``Linear -> ReLU -> BatchNorm1d``,
    followed by a final ``Linear`` output layer.

    Args:
        state_size: Length of the input state vector.
        action_size: Number of actions (width of the output layer).
        hidden_sizes: Widths of the hidden layers, in order. Any iterable of
            ints works. The default is an immutable tuple, so it cannot be
            accidentally shared and mutated between instances.

    Note:
        Because of ``BatchNorm1d``, a forward pass in training mode needs a
        batch of more than one sample. Call ``.eval()`` for single-state
        inference.
    """

    def __init__(self, state_size, action_size, hidden_sizes=(256, 256, 128)):
        super().__init__()

        layers = []
        in_features = state_size
        for hidden_size in hidden_sizes:
            layers += [
                nn.Linear(in_features, hidden_size),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_size),
            ]
            in_features = hidden_size
        layers.append(nn.Linear(in_features, action_size))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for the states in ``x``.

        For a 2-D input of shape ``(batch, state_size)`` the output has shape
        ``(batch, action_size)``.
        """
        return self.network(x)

# Smoke test: build the network on the chosen device and report its size.
state_size = 6 * 2 * 2  # 6 faces * 4 squares
action_size = 12  # 12 possible moves
model = DQN(state_size, action_size).to(device)
print(model)
n_params = sum(p.numel() for p in model.parameters())
print(f'\nParameters: {n_params:,}')

## 3. Replay Buffer

In [None]:
# One stored experience, in the order push() expects its arguments.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'done'))

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for experience replay.

    Once ``capacity`` transitions are held, each new push evicts the oldest.
    """

    def __init__(self, capacity):
        # deque(maxlen=...) discards from the left automatically when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition given as (state, action, next_state, reward, done)."""
        self.buffer.append(Transition._make(args))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        The result is transposed column-wise into a single Transition whose
        fields are tuples (all states, all actions, ...).
        """
        picked = random.sample(self.buffer, batch_size)
        columns = zip(*picked)
        return Transition._make(columns)

    def __len__(self):
        return len(self.buffer)

## 4. DQN Agent

In [None]:
class DQNAgent:
    """DQN Agent for solving Rubik's Cube.

    Holds a policy network, a periodically synced target network, an Adam
    optimizer and an experience-replay buffer. Provides epsilon-greedy action
    selection and a single gradient step on a replayed minibatch.

    Args:
        state_size: Length of the state vector fed to the networks.
        action_size: Number of discrete actions.
        lr: Adam learning rate.
        gamma: Discount factor for the bootstrapped target.
        epsilon_start: Initial exploration rate.
        epsilon_end: Lower bound for the exploration rate.
        epsilon_decay: Multiplicative factor applied by ``decay_epsilon``.
        buffer_size: Capacity of the replay buffer.
        batch_size: Minibatch size sampled per ``train_step``.
        target_update: Copy the policy weights into the target network after
            every ``target_update`` train steps that performed an update.
    """

    def __init__(
        self,
        state_size,
        action_size,
        lr=1e-4,
        gamma=0.99,
        epsilon_start=1.0,
        epsilon_end=0.01,
        epsilon_decay=0.995,
        buffer_size=100000,
        batch_size=64,
        target_update=10
    ):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size
        self.target_update = target_update

        # Networks: the target net starts as an exact copy and is only used
        # for inference, so it is kept in eval mode.
        self.policy_net = DQN(state_size, action_size).to(device)
        self.target_net = DQN(state_size, action_size).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.memory = ReplayBuffer(buffer_size)

        # Number of gradient updates performed so far (drives target syncs).
        self.steps = 0

    def select_action(self, state):
        """Select an action index using an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action is returned.
        Otherwise the argmax of the policy network's Q-values is returned.
        The network runs in eval mode (BatchNorm cannot process a single
        sample in training mode), and its previous train/eval mode is
        restored afterwards.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size - 1)

        was_training = self.policy_net.training
        self.policy_net.eval()
        try:
            with torch.no_grad():
                state_tensor = torch.as_tensor(
                    state, dtype=torch.float32, device=device
                ).unsqueeze(0)
                q_values = self.policy_net(state_tensor)
        finally:
            # Put the network back in whatever mode the caller had it in.
            self.policy_net.train(was_training)
        return q_values.argmax().item()

    def train_step(self):
        """Perform one DQN update on a minibatch sampled from the replay buffer.

        Returns:
            The loss as a float, or ``0.0`` when the buffer holds fewer than
            ``batch_size`` transitions (no update is performed then).
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        batch = self.memory.sample(self.batch_size)

        states = torch.as_tensor(np.array(batch.state), dtype=torch.float32, device=device)
        actions = torch.as_tensor(batch.action, dtype=torch.long, device=device).unsqueeze(1)
        next_states = torch.as_tensor(np.array(batch.next_state), dtype=torch.float32, device=device)
        rewards = torch.as_tensor(batch.reward, dtype=torch.float32, device=device)
        dones = torch.as_tensor(batch.done, dtype=torch.float32, device=device)

        # Q(s, a) for the actions actually taken. Squeeze only the action
        # dimension so the shape is exactly (batch,), matching target_q.
        current_q = self.policy_net(states).gather(1, actions).squeeze(1)

        with torch.no_grad():
            next_q = self.target_net(next_states).max(1)[0]
            # Terminal transitions (done == 1) do not bootstrap.
            target_q = rewards + (1 - dones) * self.gamma * next_q

        loss = F.smooth_l1_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

        # Periodically sync the target network with the policy network.
        self.steps += 1
        if self.steps % self.target_update == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        return loss.item()

    def decay_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never dropping below ``epsilon_end``."""
        self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)

## 5. Training Loop

In [None]:
def train_agent(agent, num_episodes=1000, max_steps=50, scramble_depth=3):
    """Train the DQN agent on freshly scrambled cubes.

    Each episode resets the cube, scrambles it with ``scramble_depth`` random
    moves, and lets the agent act for at most ``max_steps`` moves. Every move
    is pushed to ``agent.memory`` and followed by one ``agent.train_step()``.
    Epsilon is decayed once per episode.

    Rewards: +100 for the move that solves the cube (this ends the episode)
    and -1 for every other move.

    Args:
        agent: Object providing ``select_action``, ``memory.push``,
            ``train_step``, ``decay_epsilon`` and ``epsilon`` (e.g. DQNAgent).
        num_episodes: Number of training episodes.
        max_steps: Maximum number of moves per episode.
        scramble_depth: Number of random moves used to scramble each cube.

    Returns:
        dict of per-episode lists: ``episode_rewards``, ``episode_lengths``
        (moves taken), ``solved`` (bool) and ``losses`` (mean train loss per
        move taken).
    """
    cube = RubiksCube2x2()

    history = {
        'episode_rewards': [],
        'episode_lengths': [],
        'solved': [],
        'losses': []
    }

    for episode in range(num_episodes):
        cube.reset()
        cube.scramble(scramble_depth)
        # A random scramble can cancel itself out (e.g. F then F'). That
        # episode would start solved and teach nothing, so scramble again.
        # The depth guard avoids an infinite loop when no moves are applied.
        while scramble_depth > 0 and cube.is_solved():
            cube.reset()
            cube.scramble(scramble_depth)
        state = cube.get_state()

        episode_reward = 0
        episode_loss = 0
        steps_taken = 0  # stays valid even when max_steps <= 0

        for _ in range(max_steps):
            action = agent.select_action(state)
            cube.move(action)
            next_state = cube.get_state()
            steps_taken += 1

            done = cube.is_solved()
            reward = 100 if done else -1  # small penalty for each extra move

            agent.memory.push(state, action, next_state, reward, done)

            episode_loss += agent.train_step()
            episode_reward += reward
            state = next_state

            if done:
                break

        agent.decay_epsilon()

        history['episode_rewards'].append(episode_reward)
        history['episode_lengths'].append(steps_taken)
        history['solved'].append(cube.is_solved())
        # Mean loss over every move taken, including single-move episodes.
        history['losses'].append(episode_loss / steps_taken if steps_taken else 0)

        if (episode + 1) % 100 == 0:
            recent_solved = sum(history['solved'][-100:])
            avg_reward = np.mean(history['episode_rewards'][-100:])
            print(f'Episode {episode + 1}/{num_episodes} - '
                  f'Solved: {recent_solved}/100 - '
                  f'Avg Reward: {avg_reward:.2f} - '
                  f'Epsilon: {agent.epsilon:.3f}')

    return history

In [None]:
# Build the agent with the notebook's training hyperparameters, then train it.
state_size = 24
action_size = 12

agent_kwargs = dict(
    lr=1e-3,
    gamma=0.99,
    epsilon_start=1.0,
    epsilon_end=0.1,
    epsilon_decay=0.995,
    buffer_size=50000,
    batch_size=64,
    target_update=100,
)
agent = DQNAgent(state_size, action_size, **agent_kwargs)

# 500 short episodes is only a quick demo; raise num_episodes for better results.
history = train_agent(agent, num_episodes=500, max_steps=20, scramble_depth=3)

## 6. Visualize Training Progress

In [None]:
def _moving_average(values, window):
    """Return ``(episodes, smoothed)`` for a trailing moving average of *values*.

    ``episodes[i]`` is the 1-based episode number at which the i-th averaging
    window ends, so curves line up with the true episode axis. The window is
    clamped to ``len(values)``, so short runs still produce a meaningful curve.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return np.array([]), np.array([])
    w = max(1, min(window, values.size))
    smoothed = np.convolve(values, np.ones(w) / w, mode='valid')
    episodes = np.arange(w, values.size + 1)
    return episodes, smoothed


fig, axes = plt.subplots(2, 2, figsize=(14, 10))
window = 50

# (axis, history key, y-label, title, y-scale)
plot_specs = [
    (axes[0, 0], 'episode_rewards', 'Reward', 'Episode Rewards (Smoothed)', 1),
    (axes[0, 1], 'solved', 'Solve Rate (%)', 'Solve Rate (Smoothed)', 100),
    (axes[1, 0], 'episode_lengths', 'Steps', 'Episode Lengths (Smoothed)', 1),
    (axes[1, 1], 'losses', 'Loss', 'Training Loss (Smoothed)', 1),
]

for ax, key, ylabel, title, scale in plot_specs:
    episodes, smoothed = _moving_average(history[key], window)
    ax.plot(episodes, smoothed * scale)
    ax.set_xlabel('Episode')
    ax.set_ylabel(ylabel)
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 7. Test the Trained Agent

In [None]:
def test_agent(agent, num_tests=10, scramble_depth=3, max_steps=20):
    """Run the agent on freshly scrambled cubes and report how many it solves.

    Actions come from ``agent.select_action``, so set ``agent.epsilon = 0``
    beforehand for purely greedy evaluation.

    Args:
        agent: Object providing ``select_action(state) -> action index``.
        num_tests: Number of scrambled cubes to attempt.
        scramble_depth: Number of random moves per scramble.
        max_steps: Move budget per attempt.

    Returns:
        list[bool]: one entry per test, True if that cube was solved.
    """
    cube = RubiksCube2x2()
    results = []

    for test in range(num_tests):
        cube.reset()
        scramble_moves = cube.scramble(scramble_depth)
        print(f"\nTest {test + 1}: Scrambled with {scramble_moves}")

        state = cube.get_state()
        solution = []

        for step in range(max_steps):
            action = agent.select_action(state)
            cube.move(action)
            solution.append(cube.ACTIONS[action])
            state = cube.get_state()

            if cube.is_solved():
                print(f"  Solved in {step + 1} moves: {solution}")
                results.append(True)
                break
        else:
            # The for-else runs only when the move budget ran out unsolved.
            print(f"  Failed to solve in {max_steps} moves")
            results.append(False)

    num_solved = sum(results)
    # Guard the percentage against num_tests == 0 (ZeroDivisionError).
    rate = 100 * num_solved / num_tests if num_tests else 0.0
    print(f"\nSuccess rate: {num_solved}/{num_tests} ({rate:.1f}%)")
    return results

# Greedy evaluation: with epsilon at 0, select_action never explores.
# (This does not change the network's train/eval mode; select_action
# switches the network to eval mode itself for each prediction.)
agent.epsilon = 0
test_agent(agent, scramble_depth=3, num_tests=10)

## 8. Save Model

In [None]:
# Persist both networks and the optimizer state so they can be reloaded later.
checkpoint_path = 'rubiks_cube_dqn.pth'
checkpoint = {
    'policy_net': agent.policy_net.state_dict(),
    'target_net': agent.target_net.state_dict(),
    'optimizer': agent.optimizer.state_dict(),
}
torch.save(checkpoint, checkpoint_path)
print(f'Model saved to {checkpoint_path}')

## Next Steps

1. **Increase Training**: Train for more episodes with deeper scrambles
2. **Curriculum Learning**: Start with simple scrambles and increase difficulty
3. **Better Reward Shaping**: Add intermediate rewards for partial solves
4. **Double DQN**: Use double Q-learning for more stable training
5. **Dueling DQN**: Separate value and advantage streams
6. **3x3x3 Cube**: Scale up to standard Rubik's Cube
7. **AlphaZero Approach**: Combine MCTS with neural networks