In [1]:
import torch
import numpy as np
# from cube_env.env import RubiksCubeEnv
from cube_env.cube_env import RubiksCubeEnv
from agent.agent import DQNAgent
from cube_env.phases import count_solved_edges_first_layer
import polars as pl
import plotly.express as px
import plotly.graph_objects as go
from cube.utils import d_action_turn

In [2]:
# Greedy evaluation of the trained "first-layer edges" DQN policy.
# Produces `l_episodes` (1-based episode indices) and `l_l_actions`
# (one list of chosen action ids per episode) for the cells below.
model_path = "models/model_edges_first_layer.pt"
obs_size = 54   # NOTE(review): presumably 6 faces x 9 stickers - confirm against RubiksCubeEnv
n_actions = 18  # NOTE(review): presumably 6 faces x 3 turn kinds - confirm against RubiksCubeEnv
num_eval_episodes = 3
max_steps = 100  # per-episode cap so a policy that never finishes cannot loop forever

# The optimisation hyper-parameters are only constructor arguments here;
# this cell never runs a training step.
agent = DQNAgent(
    obs_dim=obs_size,
    n_actions=n_actions,
    lr=1e-3,
    gamma=0.99,
    batch_size=64,
    buffer_capacity=10_000,
)

# map_location="cpu" lets a checkpoint saved on a GPU machine load on a
# CPU-only one; load_state_dict then copies the tensors onto whatever
# device the network's parameters already live on.
# NOTE(review): torch.load unpickles the file - only load trusted
# checkpoints (or pass weights_only=True if the installed torch supports it).
state_dict = torch.load(model_path, map_location="cpu")
agent.policy_net.load_state_dict(state_dict)
agent.policy_net.eval()

env = RubiksCubeEnv()

l_episodes = []
l_l_actions = []

try:
    # Pure inference: no gradients are needed while rolling out the policy.
    with torch.no_grad():
        for episode in range(1, num_eval_episodes + 1):
            # Each episode starts from its own reset; no extra reset is
            # needed before the loop.
            obs, _ = env.reset()
            l_actions = []
            done = False
            steps = 0
            total_reward = 0

            while not done and steps < max_steps:
                # epsilon=0.0 -> no exploration requested from the agent.
                action = agent.select_action(obs, epsilon=0.0)
                l_actions.append(action)
                obs, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total_reward += reward
                steps += 1

            solved = count_solved_edges_first_layer(env.cube) == 4

            if not done:
                # The step cap was reached before the environment reported
                # termination/truncation.
                print(
                    f"PANIK: episode {episode} did not finish within "
                    f"{max_steps} steps (solved={solved}, "
                    f"total reward={total_reward:.2f})"
                )

            env.render()
            l_episodes.append(episode)
            l_l_actions.append(l_actions)
finally:
    # Release the environment even if an episode raises.
    env.close()

In [76]:
# Inspect the action ids chosen in each evaluation episode
# (one inner list per episode, in the order the actions were taken).
l_l_actions

[[3, 0, 15, 5, 14, 6, 1, 7, 11],
 [12, 17, 8, 13, 0, 7, 5, 1, 6, 11],
 [3, 8, 1, 13, 9, 14],
 [16, 11, 13, 4, 12, 0, 13, 11],
 [0, 11, 2, 14, 6, 1, 7, 3, 6],
 [17, 5, 15, 5, 17, 3, 17, 14, 16, 9, 14, 1, 5],
 [10, 1, 5, 1, 4, 12, 3, 17, 3, 17, 14, 9, 14, 1, 8],
 [0, 4, 15, 12, 11, 1, 5],
 [8, 10, 5, 11, 6, 11, 0, 5],
 [6, 12, 17, 1, 11, 1, 14, 1, 5],
 [2, 6, 13, 4, 13, 11],
 [17, 3, 1, 8, 10, 12, 1, 5, 13],
 [1, 9, 4, 7, 0, 5, 9, 12],
 [0, 13, 17, 8, 0, 5, 14, 9, 3, 0, 4, 14],
 [17, 8, 0, 5, 11, 12],
 [9, 6, 11, 1, 10, 14, 7, 3, 6],
 [13, 11, 1, 8],
 [15, 6, 2, 14, 8, 1, 11, 3, 7, 4],
 [10, 5, 8, 1, 12, 4],
 [16, 7, 16, 5, 13, 10],
 [15, 3, 1, 11, 13, 2, 14, 0, 5],
 [10, 16, 13, 3, 0, 11],
 [12, 16, 2, 17, 4, 1, 10, 8, 3],
 [15, 10, 7, 1, 5, 13, 10],
 [0, 11, 7, 5, 12, 3, 1, 8, 2, 5],
 [6, 10, 7, 5, 12, 0, 5, 1, 8],
 [17, 0, 13, 4, 16, 10, 13, 11, 3],
 [12, 3, 15, 9, 12, 3, 7, 4],
 [10, 6, 3, 6, 0, 14],
 [1, 9, 7, 12, 4, 16, 12, 0, 5, 1, 8],
 [4, 6, 16, 0, 9, 5, 0, 8, 9],
 [15, 7, 0, 14

In [97]:
# One row per evaluation episode: the raw action list, its length
# ("moves_count"), and the overall mean / standard deviation of that length
# repeated on every row. Only the mean is plotted below; "std" is kept in
# the frame but not drawn.
pl_data = (
    pl.DataFrame({"episodes": l_episodes, "actions": l_l_actions})
    .with_columns(pl.col("actions").list.len().alias("moves_count"))
    .with_columns(
        pl.col("moves_count").mean().alias("mean"),
        pl.col("moves_count").std().alias("std"),
    )
)

# Per-episode move count as a line trace.
moves_trace = go.Scatter(
    x=pl_data["episodes"],
    y=pl_data["moves_count"],
    name="Moves count",
)

# Constant mean drawn as a dashed reference line over the same x range.
mean_trace = go.Scatter(
    x=pl_data["episodes"],
    y=pl_data["mean"],
    name="Mean moves count",
    line={"dash": "dash"},
)

fig = go.Figure(data=[moves_trace, mean_trace])

fig.update_layout(
    height=650,
    title=dict(
        text="Moves count for solving edges of the first layer in inference",
        x=0.5,
        y=0.94,
    ),
    xaxis={"title": "Episode"},
    yaxis={"title": "Moves count"},
)

fig.show()