## DQNをねずみ学習問題に適用してみる

In [1]:
#coding:utf-8
######skinner_DQN.py##########
import numpy as np
import chainer
from chainer import functions as F
from chainer import links as L
import chainerrl

  from ._conv import register_converters as _register_converters


In [2]:
class QFunction(chainer.Chain):
    """Fully connected Q-function: two tanh hidden layers and a linear output.

    Args:
        obs_size: Input size of the first linear layer.
        n_actions: Output size of the last linear layer (one value per
            discrete action).
        n_hidden_channels: Width shared by both hidden layers.
    """

    def __init__(self, obs_size, n_actions, n_hidden_channels=2):
        super(QFunction, self).__init__()
        width = n_hidden_channels
        # Links are registered in the same order as before: l1, l2, l3.
        with self.init_scope():
            self.l1 = L.Linear(obs_size, width)
            self.l2 = L.Linear(width, width)
            self.l3 = L.Linear(width, n_actions)

    def __call__(self, x, test=False):
        """Return the action values for ``x`` as a DiscreteActionValue.

        ``test`` is accepted for call compatibility but is not used here.
        """
        hidden = F.tanh(self.l2(F.tanh(self.l1(x))))
        q_values = self.l3(hidden)
        return chainerrl.action_value.DiscreteActionValue(q_values)
    
def random_action():
    """Pick action 0 or 1 at random using NumPy's global random state."""
    candidate_actions = [0, 1]
    return np.random.choice(candidate_actions)

def step(state, action):
    """Advance the two-state environment by one action.

    Transitions (state, action) -> (next state, reward):
        (0, 0)         -> (1, 0)
        (0, other)     -> (0, 0)
        (non-0, 0)     -> (0, 0)
        (non-0, other) -> (1, 1)

    Args:
        state: Current state; a scalar or a one-element array compared
            against 0.
        action: Chosen action; compared against 0.

    Returns:
        A tuple ``(next_state, reward)`` where ``next_state`` is a
        one-element NumPy array and ``reward`` is the int 0 or 1.
    """
    switched_on = not (state == 0)
    if action == 0:
        # Action 0 flips the state and never yields a reward.
        return np.array([0 if switched_on else 1]), 0
    if switched_on:
        # Only a non-zero action taken in a non-zero state is rewarded.
        return np.array([1]), 1
    return np.array([0]), 0

# Training script: builds a DQN agent and runs it against step() for
# num_episodes episodes of max_number_of_steps steps each, then saves the agent.
gamma = 0.9
# NOTE(review): alpha is not referenced anywhere in this cell.
alpha = 0.5
max_number_of_steps = 5    # number of steps per trial (episode)
num_episodes = 50

# Q-function with a 1-element observation and 2 actions.
q_func = QFunction(1, 2)
optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)
# NOTE(review): decay_steps is given the episode count. If the explorer counts
# agent time steps, epsilon reaches end_epsilon after 50 steps (10 episodes),
# not after 50 episodes -- confirm this is intended.
explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(start_epsilon=1.0, 
                                        end_epsilon=0.1, decay_steps=num_episodes,
                                        random_action_func=random_action)
replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10**6)
# Convert observations to float32 (no copy when the dtype already matches).
phi = lambda x: x.astype(np.float32, copy=False)
# NOTE(review): this run performs only num_episodes * max_number_of_steps = 250
# steps in total, which is below replay_start_size=500. If updates only begin
# once the buffer holds replay_start_size transitions, the network is never
# trained in this script -- verify against the ChainerRL DQN documentation.
agent = chainerrl.agents.DQN(
                            q_func, optimizer, replay_buffer, 
                            gamma, explorer, replay_start_size=500,
                            update_interval=1, target_update_interval=100,
                            phi=phi)
#agent.load("agent")

for episode in range(num_episodes):
    # Every episode starts from state 0 with zero accumulated reward.
    state = np.array([0])
    R = 0
    reward = 0
    # Passed unchanged to stop_episode_and_train below: each fixed-length
    # episode is reported to the agent as finished.
    done = True
    
    for t in range(max_number_of_steps):
        # The reward handed to the agent is the one produced by the previous
        # step (0 on the first step of the episode).
        action = agent.act_and_train(state, reward)
        next_state, reward = step(state, action)
        print(state, action, reward)
        R += reward    # accumulate the reward
        state = next_state
    agent.stop_episode_and_train(state, reward, done)
    
    print("episode: {}    total reward: {}".format(episode+1, R))
agent.save("agent")

[0] 0 0
[1] 0 0
[0] 1 0
[0] 0 0
[1] 0 0
episode: 1    total reward: 0
[0] 1 0
[0] 0 0
[1] 1 1
[1] 1 1
[1] 0 0
episode: 2    total reward: 2
[0] 0 0
[1] 1 1
[1] 1 1
[1] 1 1
[1] 0 0
episode: 3    total reward: 3
[0] 0 0
[1] 0 0
[0] 1 0
[0] 1 0
[0] 0 0
episode: 4    total reward: 0
[0] 1 0
[0] 0 0
[1] 1 1
[1] 1 1
[1] 1 1
episode: 5    total reward: 3
[0] 1 0
[0] 0 0
[1] 0 0
[0] 0 0
[1] 1 1
episode: 6    total reward: 1
[0] 0 0
[1] 0 0
[0] 0 0
[1] 1 1
[1] 1 1
episode: 7    total reward: 2
[0] 0 0
[1] 1 1
[1] 1 1
[1] 1 1
[1] 1 1
episode: 8    total reward: 4
[0] 0 0
[1] 1 1
[1] 1 1
[1] 1 1
[1] 1 1
episode: 9    total reward: 4
[0] 1 0
[0] 0 0
[1] 0 0
[0] 0 0
[1] 1 1
episode: 10    total reward: 1
[0] 0 0
[1] 0 0
[0] 0 0
[1] 1 1
[1] 1 1
episode: 11    total reward: 2
[0] 0 0
[1] 1 1
[1] 1 1
[1] 1 1
[1] 1 1
episode: 12    total reward: 4
[0] 0 0
[1] 1 1
[1] 1 1
[1] 1 1
[1] 1 1
episode: 13    total reward: 4
[0] 0 0
[1] 1 1
[1] 1 1
[1] 1 1
[1] 1 1
episode: 14    total reward: 4
[0] 0 0
[1] 1 1

## OpenAI Gymによる倒立振子

In [1]:
#import myenv
import numpy as np
import gym    #倒立振子の実行環境
from gym import wrappers    #Gymの画像保存
import time
import chainer
import chainer.functions as F
import chainer.links as L
import chainerrl

  from ._conv import register_converters as _register_converters


In [2]:
# Definition of the Q-function
class QFunction(chainer.Chain):
    """Fully connected Q-function: two tanh hidden layers and a linear output.

    Args:
        bos_size: Input size of the first linear layer (the name is
            presumably a typo for ``obs_size``; kept for compatibility).
        n_actions: Output size of the last linear layer (one value per
            discrete action).
        n_hidden_channels: Width shared by both hidden layers.
    """

    def __init__(self, bos_size, n_actions, n_hidden_channels=2):
        super().__init__()
        width = n_hidden_channels
        # Links are registered in the same order as before: l0, l1, l2.
        with self.init_scope():
            self.l0 = L.Linear(bos_size, width)
            self.l1 = L.Linear(width, width)
            self.l2 = L.Linear(width, n_actions)

    def __call__(self, x, test=False):
        """Return the action values for ``x`` as a DiscreteActionValue.

        ``test`` is accepted for call compatibility but is not used here.
        """
        first = F.tanh(self.l0(x))
        second = F.tanh(self.l1(first))
        q_values = self.l2(second)
        return chainerrl.action_value.DiscreteActionValue(q_values)
    
# Training script: builds a DQN agent for the "CartPole-v0" gym environment and
# trains it for num_episodes episodes of at most max_number_of_steps steps.
env = gym.make("CartPole-v0")

gamma = 0.9
# NOTE(review): alpha is not referenced anywhere in this cell.
alpha = 0.5
max_number_of_steps = 200
num_episodes = 300


# Network input/output sizes are taken from the environment's spaces.
# n_hidden_channels is left at its default of 2.
q_func = QFunction(env.observation_space.shape[0], env.action_space.n)
optimizer = chainer.optimizers.Adam(eps=1e-2)
optimizer.setup(q_func)

# NOTE(review): decay_steps is given the episode count. If the explorer counts
# agent time steps, epsilon reaches end_epsilon after 300 steps rather than
# after 300 episodes -- confirm this is intended.
explorer = chainerrl.explorers.LinearDecayEpsilonGreedy(start_epsilon=1.0, 
                                end_epsilon=0.1, decay_steps=num_episodes, 
                                random_action_func=env.action_space.sample)

replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=10**6)
# Convert observations to float32 (no copy when the dtype already matches).
phi = lambda x: x.astype(np.float32, copy=False)

agent = chainerrl.agents.DQN(
                    q_func, optimizer, replay_buffer, gamma, explorer, 
                    replay_start_size=500, update_interval=1, target_update_interval=100,
                    phi=phi)


for episode in range(num_episodes):
    observation = env.reset()
    done = False
    reward = 0
    R = 0
    
    for t in range(max_number_of_steps):
        # Render only every 100th episode (0, 100, 200).
        if episode % 100 == 0:
            env.render()
        # The reward handed to the agent is the one returned by the previous
        # env.step call (0 on the first step of the episode).
        action = agent.act_and_train(observation, reward)
        observation, reward, done, info = env.step(action)
        R += reward
        if done:
            break
    # done stays False here if the step limit was hit without the environment
    # reporting termination.
    agent.stop_episode_and_train(observation, reward, done)
    # Report the episode return and agent statistics every 10th episode.
    if episode % 10 == 0:
        print("episode:{}  R:{}  statistics:{}".format(episode, R, agent.get_statistics()))

WARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.
break eposode: 0
episode:0  R:15.0  statistics:[('average_q', 0.0034965859030964223), ('average_loss', 0)]
break eposode: 1
break eposode: 2
break eposode: 3
break eposode: 4
break eposode: 5
break eposode: 6
break eposode: 7
break eposode: 8
break eposode: 9
break eposode: 10
episode:10  R:12.0  statistics:[('average_q', 0.019593022549087534), ('average_loss', 0)]
break eposode: 11
break eposode: 12
break eposode: 13
break eposode: 14
break eposode: 15
break eposode: 16
break eposode: 17
break eposode: 18
break eposode: 19
break eposode: 20
episode:20  R:12.0  statistics:[('average_q', 0.033206120577753705), ('average_loss', 0)]
break eposode: 21
break eposode: 22
break eposode: 23
break eposode: 24
break eposode: 25
break eposode: 26
break eposode: 27
break eposode: 28
break eposode: 29
break eposode: 30
episode:30  R:11.0  statistics:[('average_q', 0.04849225172857148), ('aver

break eposode: 283
break eposode: 284
break eposode: 285
break eposode: 286
break eposode: 287
break eposode: 288
break eposode: 289
break eposode: 290
episode:290  R:10.0  statistics:[('average_q', 5.4270701639145), ('average_loss', 0.7646095346116722)]
break eposode: 291
break eposode: 292
break eposode: 293
break eposode: 294
break eposode: 295
break eposode: 296
break eposode: 297
break eposode: 298
break eposode: 299


In [4]:
# Collect the id of every environment spec in gym's registry; the bare
# expression on the last line makes the notebook display the list.
envids = [spec.id for spec in gym.envs.registry.all()]
envids

['Copy-v0',
 'RepeatCopy-v0',
 'ReversedAddition-v0',
 'ReversedAddition3-v0',
 'DuplicatedInput-v0',
 'Reverse-v0',
 'CartPole-v0',
 'CartPole-v1',
 'MountainCar-v0',
 'MountainCarContinuous-v0',
 'Pendulum-v0',
 'Acrobot-v1',
 'LunarLander-v2',
 'LunarLanderContinuous-v2',
 'BipedalWalker-v2',
 'BipedalWalkerHardcore-v2',
 'CarRacing-v0',
 'Blackjack-v0',
 'KellyCoinflip-v0',
 'KellyCoinflipGeneralized-v0',
 'FrozenLake-v0',
 'FrozenLake8x8-v0',
 'CliffWalking-v0',
 'NChain-v0',
 'Roulette-v0',
 'Taxi-v2',
 'GuessingGame-v0',
 'HotterColder-v0',
 'Reacher-v2',
 'Pusher-v2',
 'Thrower-v2',
 'Striker-v2',
 'InvertedPendulum-v2',
 'InvertedDoublePendulum-v2',
 'HalfCheetah-v2',
 'Hopper-v2',
 'Swimmer-v2',
 'Walker2d-v2',
 'Ant-v2',
 'Humanoid-v2',
 'HumanoidStandup-v2',
 'FetchSlide-v1',
 'FetchPickAndPlace-v1',
 'FetchReach-v1',
 'FetchPush-v1',
 'HandReach-v0',
 'HandManipulateBlockRotateZ-v0',
 'HandManipulateBlockRotateParallel-v0',
 'HandManipulateBlockRotateXYZ-v0',
 'HandManipul