In [None]:
%reload_ext autoreload
%autoreload 2
%pylab inline

import os
import pickle
import random
import time
from collections import namedtuple
from itertools import count

import gym
import numpy as np
import pycuber
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
import tqdm
from gym import spaces
from gym.utils import seeding
from tensorboardX import SummaryWriter
from torch.distributions import Categorical

import cube
import dqn
import solver

In [None]:
# Pickled state dict that is loaded into the policy network below.
path_weights = 'weights/dqn_steps_10_mean_100_gamma_0.9_batch_512_episodes_100000.pkl'

policy = dqn.DQN(n_space=cube.N_SPACE, n_action=cube.N_ACTION)
# Open the checkpoint in a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only point path_weights at files you produced or otherwise trust.
with open(path_weights, 'rb') as weights_file:
    policy.load_state_dict(pickle.load(weights_file))

test_cube = cube.CubeEnv(steps=5)

In [None]:
def test_solver(steps, solver, n_iter, time_limit):
    success_rate = []
    actions_length = []
    solve_times = []
    
    for _ in range(n_iter):
        solving_cube = cube.get_shuffled_cube(steps=steps)
        start_time = time.time()
        is_done, actions, depth, value = \
            solver.solve(solving_cube, time_limit=time_limit)
        solve_time = time.time() - start_time
        
        success_rate.append(is_done)
        if is_done:
            actions_length.append(len(actions))
            solve_times.append(solve_time)
            
    return np.mean(success_rate), actions_length, solve_times

In [None]:
# Build the three solvers that are benchmarked below; each one is constructed
# from the same loaded `policy`.
greedy = solver.GreedySolver(policy)
# Results for this solver are stored under the 'naive' key of `result`.
mcts = solver.SimpleMCTSSolver(policy, tau=1.7)
# NOTE(review): `value_f` is not defined in any cell visible here -- confirm it
# is created before this cell runs, otherwise this line raises NameError on a
# fresh kernel (Restart & Run All).
ucb = solver.UCBSolver(policy, value_f, c=3, tau=1.7)

In [None]:
# One empty bucket per solver; the benchmark loop fills each bucket as
# result[solver_key][steps] -> metrics dict.
result = {solver_key: {} for solver_key in ('greedy', 'naive', 'ucb')}

In [None]:
# Number of trials per (solver, steps) pair and the `time_limit` handed to
# test_solver for every solve.
N_ITER = 100
TIME_LIMIT = 30

# (result key, solver instance) pairs, evaluated in this order for every value
# of `steps`. The SimpleMCTSSolver instance is reported under the key 'naive'.
SOLVERS = (
    ('greedy', greedy),
    ('naive', mcts),
    ('ucb', ucb),
)

for steps in range(6, 11):
    print(steps)

    # Same benchmark for every solver; only the result key and the solver
    # instance differ, so iterate instead of repeating the stanza.
    for label, candidate in SOLVERS:
        success_rate, actions_length, solve_times = test_solver(
            steps=steps, solver=candidate, n_iter=N_ITER, time_limit=TIME_LIMIT)
        result[label][steps] = {
            'success_rate': success_rate,
            'actions_length': actions_length,
            'solve_times': solve_times}
        print(label, success_rate)

    print('-' * 10)

6
greedy 0.83
naive 0.95
----------
7
greedy 0.63
naive 0.87
----------
8
