In [None]:
from shared.rubiks.env import RubiksCubeEnv
from shared.dqn.model import DQN
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Initialize the environment (constructed with size=3 and one scramble move by default)
env = RubiksCubeEnv(size=3, scramble_moves=1)

# Configuration: run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random number generator the notebook touches so runs are reproducible:
# Python's `random` (epsilon-greedy coin flips), NumPy's global RNG, PyTorch
# (weight initialization), and the environment with its action/observation spaces.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    # Seed all visible CUDA devices, not only the current one
    torch.cuda.manual_seed_all(seed)
env.reset(seed=seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)

# Display the selected device as the cell output
device

In [None]:
# Curriculum helper: decides how many scramble moves to use for a given episode.
# The scramble count grows as training progresses so the task gets harder over time.
def get_scramble_moves(episode, max_episodes, min_scrambles=1, max_scrambles=10):
    """Return the scramble count for ``episode``, ramping linearly over training.

    The ``max_episodes`` episodes are split into equally sized consecutive
    buckets, one per scramble level from ``min_scrambles`` up to and including
    ``max_scrambles``. Integer arithmetic is used so bucket boundaries are
    exact and the final episodes really reach ``max_scrambles``.

    Args:
        episode: Zero-based index of the current episode. Values outside
            ``[0, max_episodes - 1]`` are clamped into that range.
        max_episodes: Total number of training episodes.
        min_scrambles: Scramble count used at the start of training.
        max_scrambles: Scramble count used at the end of training.

    Returns:
        int: Number of scramble moves, within ``[min_scrambles, max_scrambles]``.
    """
    min_scrambles = int(min_scrambles)
    max_scrambles = int(max_scrambles)

    # Degenerate schedules: nothing to ramp over, so stay at the easiest level.
    if max_episodes <= 0 or max_scrambles <= min_scrambles:
        return min_scrambles

    levels = max_scrambles - min_scrambles + 1
    # Clamp so out-of-range episode indices cannot leave the configured bounds.
    clamped_episode = min(max(int(episode), 0), int(max_episodes) - 1)
    return int(min_scrambles + (clamped_episode * levels) // int(max_episodes))

# DQN without Replay Buffer or Target Network

In [None]:
# Online DQN: one gradient step per environment step on the transition that was
# just observed. No replay buffer and no target network are used.
GAMMA = 0.99  # Discount factor applied to the bootstrapped next-state value
EPS_START = 0.9 # Initial exploration rate, 1.0 means 100% exploration, 0.0 means 100% exploitation
EPS_END = 0.01  # Lower bound for the exploration rate
EPS_DECAY = 0.001  # Amount subtracted from epsilon after every episode
LEARNING_RATE = 3e-4

# Reset the environment
env.reset(seed=seed)
input_dim = env.observation_space.shape[0]  # 6 * size^2
output_dim = env.action_space.n  # 12 (6 faces * 2 rotations)

# Reset the model weights. The network must live on the same device as the
# state tensors created below, otherwise the forward pass fails on CUDA.
policy_net = DQN(input_dim, output_dim).to(device)  # Initialize with random weights and biases
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
loss_fn = nn.SmoothL1Loss()
epsilon = EPS_START

num_episodes = 2000
max_steps = 10

# Reward tracking (one total reward per episode)
reward_list = []

for episode in range(num_episodes):
    # Reset the game state for each episode
    state, info = env.reset()
    state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) # [1, input_dim]
    total_reward = 0

    for step in range(max_steps):
        # Select action (e-greedy); `action` is a plain Python int here
        if random.random() < epsilon: # explore
            action = env.action_space.sample()
        else: # exploit
            with torch.no_grad():
                q_values = policy_net(state_tensor) # [1, output_dim], output_dim would be 12 as there are 12 actions
            action = torch.max(q_values, dim=1).indices.item()  # Get the index of the action with the highest Q-value

        # Perform action in the environment
        observation, reward, terminated, truncated, _ = env.step(action)
        next_state_tensor = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        # Pin the dtype so the target matches the network output regardless of
        # whether the environment returns an int, float or NumPy scalar reward.
        reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device)
        action_tensor = torch.tensor([[action]], dtype=torch.long, device=device)  # [1, 1] index for gather
        done = terminated or truncated
        total_reward += reward

        # Get predicted Q-value for the current state-action pair
        state_action_values = policy_net(state_tensor).gather(1, action_tensor).squeeze(1)

        # Calculate target Q-value
        with torch.no_grad():
            # NOTE: We aren't using a target network here, just the policy_net itself. This code leads to the 'chasing a moving target' problem.
            next_state_values = policy_net(next_state_tensor).max(dim=1).values  # max Q-value over the action dimension
        # Only a true terminal state has zero future value. A truncated episode
        # (time limit) still bootstraps from the next state.
        if terminated:
            expected_state_action_values = reward_tensor
        else:
            expected_state_action_values = reward_tensor + GAMMA * next_state_values

        # Calculate loss
        loss = loss_fn(state_action_values, expected_state_action_values) # input, target

        optimizer.zero_grad() # Clear gradients
        loss.backward() # Backpropagation, this computes gradients
        optimizer.step() # Update weights based on gradients

        # Move to the next state
        state_tensor = next_state_tensor

        if done:
            break

    # Decay epsilon (less exploration over time)
    epsilon = max(EPS_END, epsilon - EPS_DECAY)

    reward_list.append(total_reward)

    print(f"Episode {episode+1}, Total Reward: {total_reward:.1f}, Epsilon: {epsilon:.3f}")

In [None]:
# Plot the total reward of every episode together with its moving average.
# `reward_list` holds one entry per episode, so the x-axis is the episode index.
window_size = 50  # adjust this as needed
reward_array = np.array(reward_list, dtype=float)

plt.figure(figsize=(10, 5))
plt.plot(reward_array, color='lightgray', label='Raw Reward')
# The moving average is only defined once at least `window_size` episodes exist;
# with fewer, np.convolve(mode='valid') would not line up with the x range.
if len(reward_array) >= window_size:
    moving_avg = np.convolve(reward_array, np.ones(window_size) / window_size, mode='valid')
    plt.plot(range(window_size - 1, len(reward_array)), moving_avg, color='blue', label=f'{window_size}-Episode Moving Avg')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Reward per Episode with Moving Average')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# DQN with Replay Buffer and No Target Network

In [None]:
from shared.dqn.memory import ReplayMemory, Transition

# DQN with experience replay: transitions are stored in a buffer and each
# gradient step is taken on a random mini-batch. No target network is used.
GAMMA = 0.99  # Discount factor applied to the bootstrapped next-state value
EPS_START = 0.9 # Initial exploration rate, 1.0 means 100% exploration, 0.0 means 100% exploitation
EPS_END = 0.01  # Lower bound for the exploration rate
EPS_DECAY = 0.0001  # Amount subtracted from epsilon after every episode
LEARNING_RATE = 3e-4

# Reset the environment
env.reset(seed=seed)
input_dim = env.observation_space.shape[0]  # 6 * size^2
output_dim = env.action_space.n  # 12 (6 faces * 2 rotations)

# Reset the model weights. The network must live on the same device as the
# state tensors created below, otherwise the forward pass fails on CUDA.
policy_net = DQN(input_dim, output_dim).to(device)  # Initialize with random weights and biases
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
loss_fn = nn.SmoothL1Loss()
epsilon = EPS_START

num_episodes = 4000
max_steps = 30

memory = ReplayMemory(10000)  # Initialize replay memory with a capacity of 10,000
BATCH_SIZE = 32  # Size of the batch for training

# Reward tracking (one total reward per episode)
reward_list = []

for episode in range(num_episodes):
    # Reset the game state for each episode, scrambling more as training progresses
    scramble_moves = get_scramble_moves(episode, num_episodes)
    state, info = env.reset(options={'scramble_moves': scramble_moves})
    state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) # [1, input_dim]
    total_reward = 0

    for step in range(max_steps):
        # Select action (e-greedy)
        if random.random() < epsilon: # explore
            action_tensor = torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long) # [1, 1] tensor for action
        else: # exploit
            with torch.no_grad():
                q_values = policy_net(state_tensor) # [1, output_dim], output_dim would be 12 as there are 12 actions
            action_tensor = torch.max(q_values, dim=1).indices.view(1, 1)  # Get the index of the action with the highest Q-value

        # Perform action in the environment
        observation, reward, terminated, truncated, _ = env.step(action_tensor.item())
        next_state_tensor = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
        # Pin the dtype so batched targets match the network output dtype
        reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device) # [1] tensor for reward
        terminated_batch = torch.tensor([terminated], device=device, dtype=torch.bool) # [1] tensor for termination status
        done = terminated or truncated
        total_reward += reward

        # Terminal transitions are stored with next_state=None so the non-final
        # mask below excludes them from bootstrapping.
        next_state = None if terminated else next_state_tensor

        # Store transition in replay memory
        memory.push(state_tensor, action_tensor, next_state, reward_tensor, terminated_batch)

        # Only optimize once enough transitions are stored. While the buffer is
        # warming up the episode keeps running so it can collect more experience.
        if len(memory) >= BATCH_SIZE:
            # Sample a batch from memory
            transitions = memory.sample(BATCH_SIZE)
            # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
            # detailed explanation). This converts batch-array of Transitions
            # to Transition of batch-arrays.
            batch = Transition(*zip(*transitions))

            # Compute a mask of non-final states and concatenate the batch elements
            # (a final state would've been the one after which simulation ended)
            non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                          device=device, dtype=torch.bool)
            state_batch = torch.cat(batch.state)
            action_batch = torch.cat(batch.action)
            reward_batch = torch.cat(batch.reward)

            # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
            # columns of actions taken. These are the actions which would've been taken
            # for each batch state according to policy_net
            state_action_values = policy_net(state_batch).gather(1, action_batch)

            # Calculate target Q-value; terminal next states keep a value of zero
            next_state_values = torch.zeros(BATCH_SIZE, device=device)
            if non_final_mask.any():  # torch.cat fails on an empty list
                non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
                with torch.no_grad():
                    # NOTE: We aren't using a target network here, just the policy_net itself. This code leads to the 'chasing a moving target' problem.
                    next_state_values[non_final_mask] = policy_net(non_final_next_states).max(1).values

            # Compute the expected Q values
            expected_state_action_values = (reward_batch + GAMMA * next_state_values).unsqueeze(1)

            # Calculate loss
            loss = loss_fn(state_action_values, expected_state_action_values) # input, target

            optimizer.zero_grad() # Clear gradients
            loss.backward() # Backpropagation, this computes gradients
            optimizer.step() # Update weights based on gradients

        # Move to the next state
        state_tensor = next_state_tensor

        if done:
            break

    # Decay epsilon (less exploration over time)
    epsilon = max(EPS_END, epsilon - EPS_DECAY)

    reward_list.append(total_reward)

    print(f"Episode {episode+1}, Scramble Count: {scramble_moves}, Total Reward: {total_reward:.1f}, Epsilon: {epsilon:.3f}")

In [None]:
# Plot the total reward of every episode together with its moving average.
# `reward_list` holds one entry per episode, so the x-axis is the episode index.
import matplotlib.pyplot as plt
import numpy as np

window_size = 50  # adjust this as needed
reward_array = np.array(reward_list, dtype=float)

plt.figure(figsize=(10, 5))
plt.plot(reward_array, color='lightgray', label='Raw Reward')
# The moving average is only defined once at least `window_size` episodes exist;
# with fewer, np.convolve(mode='valid') would not line up with the x range.
if len(reward_array) >= window_size:
    moving_avg = np.convolve(reward_array, np.ones(window_size) / window_size, mode='valid')
    plt.plot(range(window_size - 1, len(reward_array)), moving_avg, color='blue', label=f'{window_size}-Episode Moving Avg')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Reward per Episode with Moving Average')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# DQN with Replay Buffer and Target Network

In [None]:
from shared.dqn.memory import ReplayMemory, Transition

# DQN with experience replay and a target network. Training runs as a staged
# curriculum: `num_episodes` episodes at each scramble level from 1 to 10.
GAMMA = 0.99  # Discount factor applied to the bootstrapped next-state value
EPS_START = 0.9 # Initial exploration rate, 1.0 means 100% exploration, 0.0 means 100% exploitation
EPS_END = 0.01  # Lower bound for the exploration rate
EPS_DECAY = 0.001  # Amount subtracted from epsilon after every episode
LEARNING_RATE = 3e-4
TARGET_UPDATE_FREQUENCY = 100  # Update target network every 100 episodes

# Reset the environment
env.reset(seed=seed)
input_dim = env.observation_space.shape[0]  # 6 * size^2
output_dim = env.action_space.n  # 12 (6 faces * 2 rotations)

# Initialize both policy and target networks
policy_net = DQN(input_dim, output_dim).to(device)  # Main network for action selection and training
target_net = DQN(input_dim, output_dim).to(device)   # Target network for stable Q-value targets

# Initialize target network with same weights as policy network
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()  # Set target network to evaluation mode

optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
loss_fn = nn.SmoothL1Loss()
epsilon = EPS_START

num_episodes = 2000  # Episodes per scramble level

memory = ReplayMemory(10000)  # Initialize replay memory with a capacity of 10,000
BATCH_SIZE = 32  # Size of the batch for training

# Reward tracking (one total reward per episode, across all scramble levels)
reward_list = []

for scramble_moves in range(1, 11):
    epsilon = EPS_START  # Reset epsilon for each scramble level
    for episode in range(num_episodes):
        # Reset the game state for each episode
        state, info = env.reset(options={'scramble_moves': scramble_moves})
        state_tensor = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) # [1, input_dim]
        total_reward = 0

        # Step budget per episode is twice the scramble depth
        for step in range(scramble_moves * 2):
            # Select action (e-greedy)
            if random.random() < epsilon: # explore
                action_tensor = torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long) # [1, 1] tensor for action
            else: # exploit
                with torch.no_grad():
                    q_values = policy_net(state_tensor) # [1, output_dim], output_dim would be 12 as there are 12 actions
                action_tensor = torch.max(q_values, dim=1).indices.view(1, 1)  # Get the index of the action with the highest Q-value

            # Perform action in the environment
            observation, reward, terminated, truncated, _ = env.step(action_tensor.item())
            next_state_tensor = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
            # Pin the dtype so batched targets match the network output dtype
            reward_tensor = torch.tensor([reward], dtype=torch.float32, device=device) # [1] tensor for reward
            terminated_batch = torch.tensor([terminated], device=device, dtype=torch.bool) # [1] tensor for termination status
            done = terminated or truncated
            total_reward += reward

            # Terminal transitions are stored with next_state=None so the
            # non-final mask below excludes them from bootstrapping.
            next_state = None if terminated else next_state_tensor

            # Store transition in replay memory
            memory.push(state_tensor, action_tensor, next_state, reward_tensor, terminated_batch)

            # Only optimize once enough transitions are stored. While the buffer is
            # warming up the episode keeps running so it can collect more experience.
            if len(memory) >= BATCH_SIZE:
                # Sample a batch from memory
                transitions = memory.sample(BATCH_SIZE)
                # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
                # detailed explanation). This converts batch-array of Transitions
                # to Transition of batch-arrays.
                batch = Transition(*zip(*transitions))

                # Compute a mask of non-final states and concatenate the batch elements
                # (a final state would've been the one after which simulation ended)
                non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                              device=device, dtype=torch.bool)
                state_batch = torch.cat(batch.state)
                action_batch = torch.cat(batch.action)
                reward_batch = torch.cat(batch.reward)

                # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
                # columns of actions taken. These are the actions which would've been taken
                # for each batch state according to policy_net
                state_action_values = policy_net(state_batch).gather(1, action_batch)

                # Calculate target Q-value; terminal next states keep a value of zero
                next_state_values = torch.zeros(BATCH_SIZE, device=device)
                if non_final_mask.any():  # torch.cat fails on an empty list
                    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
                    with torch.no_grad():
                        # Use target network to compute next state values for stability
                        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

                # Compute the expected Q values
                expected_state_action_values = (reward_batch + GAMMA * next_state_values).unsqueeze(1)

                # Calculate loss
                loss = loss_fn(state_action_values, expected_state_action_values) # input, target

                optimizer.zero_grad() # Clear gradients
                loss.backward() # Backpropagation, this computes gradients
                optimizer.step() # Update weights based on gradients

            # Move to the next state
            state_tensor = next_state_tensor

            if done:
                break

        # Update target network periodically (the episode counter restarts at
        # every scramble level, so this also syncs at the start of each level)
        if episode % TARGET_UPDATE_FREQUENCY == 0:
            target_net.load_state_dict(policy_net.state_dict())
            print(f"Target network updated at episode {episode}")

        # Decay epsilon (less exploration over time)
        epsilon = max(EPS_END, epsilon - EPS_DECAY)

        reward_list.append(total_reward)

        print(f"Scramble Count: {scramble_moves}, Episode {episode+1}, Total Reward: {total_reward:.1f}, Epsilon: {epsilon:.3f}")

In [None]:
# Plot the total reward of every episode together with its moving average.
# `reward_list` holds one entry per episode, so the x-axis is the episode index.
import matplotlib.pyplot as plt
import numpy as np

window_size = 50  # adjust this as needed
reward_array = np.array(reward_list, dtype=float)

plt.figure(figsize=(10, 5))
plt.plot(reward_array, color='lightgray', label='Raw Reward')
# The moving average is only defined once at least `window_size` episodes exist;
# with fewer, np.convolve(mode='valid') would not line up with the x range.
if len(reward_array) >= window_size:
    moving_avg = np.convolve(reward_array, np.ones(window_size) / window_size, mode='valid')
    plt.plot(range(window_size - 1, len(reward_array)), moving_avg, color='blue', label=f'{window_size}-Episode Moving Avg')
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Reward per Episode with Moving Average')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Save Current Model

In [None]:
import os

# Destination file for the serialized checkpoint
ARTIFACT_PATH = "artifacts/dqn_rubik_v1.pt"

# Make sure the parent directory exists before writing
artifact_dir = os.path.dirname(ARTIFACT_PATH)
os.makedirs(artifact_dir, exist_ok=True)

# Bundle the policy network weights with the metadata needed to rebuild the model
checkpoint_meta = {
    "input_dim": input_dim,
    "output_dim": output_dim,
    "faces_order": ['U', 'D', 'L', 'R', 'F', 'B'],  # intended to match the cube environment's face order
}
checkpoint = {
    "model_state": policy_net.state_dict(),
    "meta": checkpoint_meta,
}
torch.save(checkpoint, ARTIFACT_PATH)
print(f"Saved trained model to {ARTIFACT_PATH}")
