In [1]:
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl
import gym_rubiks
import numpy as np

In [2]:
# Smoke-test the environment: build it, print its spaces, reset it, render the
# initial state, and take a single random step.
env = gym_rubiks.make("rubiks-2x2-5-v0")
print('observation space:', env.observation_space)
print('action space:', env.action_space)

obs = env.reset()
env.render()

# step() is unpacked as the classic gym 4-tuple (obs, reward, done, info).
action = env.action_space.sample()
obs, r, done, info = env.step(action)

observation space: Box(6, 2, 2)
action space: Discrete(9)


In [3]:
class QFunction(chainer.Chain):
    """Convolutional Q-network for the 2x2 cube environment.

    Architecture: one 2-D convolution (6 -> 2000 channels, kernel size 2,
    stride 1), then a 2000 -> 256 fully connected layer with ReLU, then a
    linear head producing one Q-value per action.

    NOTE(review): the 6 input channels and kernel size 2 are hard-coded to
    match the ``Box(6, 2, 2)`` observation space; ``obs_size`` is not used
    to build the layers.
    """

    def __init__(self, obs_size, n_actions, n_hidden_channels=9):
        """
        Args:
            obs_size (int): Flattened observation size (currently unused).
            n_actions (int): Number of discrete actions; size of the output
                layer.
            n_hidden_channels (int): Retained for backward compatibility
                only. It previously (incorrectly) set the output size; the
                head now always emits ``n_actions`` Q-values.
        """
        super().__init__()
        with self.init_scope():
            self.conv_layers = chainer.ChainList(
                L.ConvolutionND(2, 6, 2000, 2, stride=1))
            self.l0 = L.Linear(2000, 256)
            # One output per action. Sizing this by n_hidden_channels only
            # worked by coincidence (default 9 == Discrete(9)).
            self.l1 = L.Linear(256, n_actions)

    def __call__(self, x, test=False):
        """Compute Q-values for a batch of observations.

        Args:
            x (ndarray or chainer.Variable): A batch of observations;
                presumably shaped (batch, 6, 2, 2) -- TODO confirm.
            test (bool): Unused; kept for interface compatibility.

        Returns:
            chainerrl.action_value.DiscreteActionValue wrapping the Q-values.
        """
        h = x
        for layer in self.conv_layers:
            h = F.relu(layer(h))
        # L.Linear flattens all non-batch axes of the conv output.
        h = F.relu(self.l0(h))
        # No activation on the head: Q-values must be able to go negative
        # (the training log shows episode returns of -10).
        q_values = self.l1(h)
        return chainerrl.action_value.DiscreteActionValue(q_values)

# Flattened observation size, derived from the env rather than hard-coded
# (Box(6, 2, 2) -> 6 * 2 * 2 = 24).
obs_size = int(np.prod(env.observation_space.shape))
print(obs_size)
n_actions = env.action_space.n
# Fully connected Q-network with 2 hidden layers of 2000 units each.
q_func = chainerrl.q_functions.FCStateQFunctionWithDiscreteAction(
    obs_size, n_actions,
    n_hidden_layers=2, n_hidden_channels=2000)
# Alternative Q-networks (kept for reference):
#q_func = chainerrl.q_functions.DuelingDQN(n_actions,6)
#q_func = QFunction(obs_size,n_actions)

24


In [4]:
# Move the Q-network to GPU 0 only when CUDA is available, so the notebook
# still runs (on CPU) on machines without a GPU.
if chainer.cuda.available:
    q_func.to_gpu(0)

<chainerrl.q_functions.state_q_functions.FCStateQFunctionWithDiscreteAction at 0x14ff7013668>

In [5]:
# Adam optimizer over q_func's parameters. A comparatively large epsilon
# (0.01) is used for numerical stability.
optimizer = chainer.optimizers.Adam(eps=0.01)
optimizer.setup(q_func)

In [6]:

# Discount factor applied to future rewards.
gamma = 0.95

# Epsilon-greedy exploration: with probability 0.3 take a random action
# drawn from env.action_space.sample, otherwise the greedy one.
explorer = chainerrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.3, random_action_func=env.action_space.sample)

# Experience replay buffer holding up to 10**6 transitions.
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10 ** 6)


def phi(x):
    """Feature extractor: cast an observation to float32 for Chainer.

    The observation dtype of gym_rubiks is not shown here; ``copy=False``
    avoids copying when ``x`` is already float32.
    """
    return x.astype(np.float32, copy=False)


# Double DQN agent: updates start after 500 buffered transitions, the
# network is updated every step (update_interval=1), and the target network
# is synced every 100 steps (target_update_interval=100).
agent = chainerrl.agents.DoubleDQN(
    q_func, optimizer, replay_buffer, gamma, explorer,
    replay_start_size=500, update_interval=1,
    target_update_interval=100, phi=phi)

In [7]:

# Train for n_episodes episodes, each capped at max_episode_len steps.
# n_done counts, cumulatively, the episodes that ended with done=True
# (reported below as '# solved').
n_episodes = 4000
max_episode_len = 10
n_done = 0
for i in range(1, n_episodes + 1):
    obs = env.reset()
    reward = 0
    done = False
    R = 0  # undiscounted return of this episode
    for t in range(max_episode_len):
        # env.render()  # uncomment to watch the agent act
        action = agent.act_and_train(obs, reward)
        obs, reward, done, _ = env.step(action)
        R += reward
        n_done += done
        if done:
            break
    if i % 100 == 0:
        print('episode:', i,
              'R:', R,
              'statistics:', agent.get_statistics())
        print('# solved:',n_done)
    # done is False here when the episode hit the step cap.
    agent.stop_episode_and_train(obs, reward, done)
print('Finished.')
print(n_done)

episode: 100 R: -10.0 statistics: [('average_q', -0.09931802910729833), ('average_loss', 0.2986903550690484)]
# solved: 3
episode: 200 R: -10.0 statistics: [('average_q', -4.670691649217793), ('average_loss', 0.45458787703651327)]
# solved: 10
episode: 300 R: -10.0 statistics: [('average_q', -8.402759642871336), ('average_loss', 1.0090146785682381)]
# solved: 25
episode: 400 R: -10.0 statistics: [('average_q', -9.358780449533208), ('average_loss', 1.8668534571750335)]
# solved: 45
episode: 500 R: -10.0 statistics: [('average_q', -5.623351658065536), ('average_loss', 2.9940957960784873)]
# solved: 70
episode: 600 R: -10.0 statistics: [('average_q', 12.361717644956313), ('average_loss', 4.879309392489103)]
# solved: 90
episode: 700 R: 93.0 statistics: [('average_q', 37.63880908081049), ('average_loss', 5.099146608848921)]
# solved: 133
episode: 800 R: 92.0 statistics: [('average_q', 53.696247997463004), ('average_loss', 3.3553696429220947)]
# solved: 180
episode: 900 R: 95.0 statistics: 

In [8]:
def run_test_episodes(agent, env, n_episodes=100, max_steps=20, render=True):
    """Run evaluation episodes with ``agent.act`` and print each return.

    Args:
        agent: A chainerrl agent; actions come from ``agent.act`` (not
            ``act_and_train``), and ``agent.stop_episode()`` is called after
            every episode.
        env: Environment exposing ``reset()``, ``step()`` and ``render()``.
        n_episodes (int): Number of evaluation episodes.
        max_steps (int): Step cap per episode.
        render (bool): If True, call ``env.render()`` before every step.

    Returns:
        list: The return (sum of rewards) of each episode, in order.
    """
    returns = []
    for i in range(n_episodes):
        obs = env.reset()
        done = False
        R = 0
        t = 0
        while not done and t < max_steps:
            if render:
                env.render()
            action = agent.act(obs)
            obs, r, done, _ = env.step(action)
            R += r
            t += 1
        print('test episode:', i, 'R:', R)
        returns.append(R)
        agent.stop_episode()
    return returns


test_returns = run_test_episodes(agent, env)

test episode: 0 R: 98.0
test episode: 1 R: 97.0
test episode: 2 R: 99.0
test episode: 3 R: 99.0
test episode: 4 R: 100.0
test episode: 5 R: 97.0
test episode: 6 R: 99.0
test episode: 7 R: 97.0
test episode: 8 R: 96.0
test episode: 9 R: 97.0
test episode: 10 R: -20.0
test episode: 11 R: 97.0
test episode: 12 R: 99.0
test episode: 13 R: 98.0
test episode: 14 R: 96.0
test episode: 15 R: 98.0
test episode: 16 R: -20.0
test episode: 17 R: 98.0
test episode: 18 R: 96.0
test episode: 19 R: 98.0
test episode: 20 R: 98.0
test episode: 21 R: 99.0
test episode: 22 R: 98.0
test episode: 23 R: 98.0
test episode: 24 R: 97.0
test episode: 25 R: 96.0
test episode: 26 R: 100.0
test episode: 27 R: 98.0
test episode: 28 R: 98.0
test episode: 29 R: 98.0
test episode: 30 R: 98.0
test episode: 31 R: 99.0
test episode: 32 R: 98.0
test episode: 33 R: 98.0
test episode: 34 R: -20.0
test episode: 35 R: 95.0
test episode: 36 R: 99.0
test episode: 37 R: 98.0
test episode: 38 R: 97.0
test episode: 39 R: 99.0
test 

In [9]:
# Persist the trained agent to the 'agent5' directory.
save_dir = 'agent5'
agent.save(save_dir)

In [10]:
# Reload from 'agent5', the directory the save cell above writes, so this
# round-trip works on a fresh kernel. To evaluate a different checkpoint,
# change this path explicitly.
agent.load('agent5')

In [11]:
# Evaluate the (re)loaded agent: 100 episodes of at most 20 steps each,
# rendering every step and printing each episode's return.
for i in range(100):
    obs = env.reset()
    R = 0
    for _step in range(20):
        env.render()
        action = agent.act(obs)
        obs, r, done, _ = env.step(action)
        R += r
        if done:
            break
    print('test episode:', i, 'R:', R)
    agent.stop_episode()

test episode: 0 R: 98.0
test episode: 1 R: 97.0
test episode: 2 R: 98.0
test episode: 3 R: 99.0
test episode: 4 R: 100.0
test episode: 5 R: 99.0
test episode: 6 R: 97.0
test episode: 7 R: 98.0
test episode: 8 R: 99.0
test episode: 9 R: 99.0
test episode: 10 R: 98.0
test episode: 11 R: 97.0
test episode: 12 R: 95.0
test episode: 13 R: 99.0
test episode: 14 R: 96.0
test episode: 15 R: 99.0
test episode: 16 R: 96.0
test episode: 17 R: 97.0
test episode: 18 R: 98.0
test episode: 19 R: 98.0
test episode: 20 R: 98.0
test episode: 21 R: 98.0
test episode: 22 R: 98.0
test episode: 23 R: 97.0
test episode: 24 R: -20.0
test episode: 25 R: 100.0
test episode: 26 R: 98.0
test episode: 27 R: 98.0
test episode: 28 R: 98.0
test episode: 29 R: 97.0
test episode: 30 R: 97.0
test episode: 31 R: 98.0
test episode: 32 R: -20.0
test episode: 33 R: 98.0
test episode: 34 R: 98.0
test episode: 35 R: 99.0
test episode: 36 R: 97.0
test episode: 37 R: 97.0
test episode: 38 R: 96.0
test episode: 39 R: 97.0
test e