pip install keras-rl

In [1]:
import random

import numpy as np
import gym

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPool2D

from rl.agents.dqn import DQNAgent
from rl.policy import EpsGreedyQPolicy
from rl.memory import SequentialMemory

Using TensorFlow backend.


In [2]:
# Environment and reproducibility settings.
ENV_NAME = 'CartPole-v0'
SEED = 123

# Create the environment and read how many discrete actions CartPole offers.
env = gym.make(ENV_NAME)

# Seed every RNG the training loop draws from:
#   - NumPy: keras-rl's EpsGreedyQPolicy draws its exploration decisions from np.random;
#   - Python `random`: keras-rl's SequentialMemory samples replay batches with random.sample;
#   - the environment itself (initial states).
# NOTE(review): Keras/TensorFlow weight initialisation is still unseeded, so runs
# are not fully reproducible yet — TODO seed the backend for the installed TF version.
np.random.seed(SEED)
random.seed(SEED)
env.seed(SEED)
nb_actions = env.action_space.n

In [3]:
# Q-network: maps one observation to one Q-value per action. The leading 1 in the
# input shape matches the replay memory's window_length=1 (one observation per state).
# Try different network architectures and activation functions here.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(16))
# NOTE(review): dropout is unusual in a DQN Q-network — it randomly perturbs the
# online network on every training update. Consider removing it if learning stalls
# (episode rewards in the recorded run stay around 10).
model.add(Dropout(0.2))
model.add(Activation('relu'))
model.add(Dense(nb_actions, activation="linear"))
# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
model.summary()

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 4)                 0         
_________________________________________________________________
dense_1 (Dense)              (None, 16)                80        
_________________________________________________________________
dropout_1 (Dropout)          (None, 16)                0         
_________________________________________________________________
activation_1 (Activation)    (None, 16)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 34        
Total params: 114
Trainable params: 114
Non-trainable params: 0
_______________________

In [4]:
# DQN agent: epsilon-greedy exploration, a replay buffer holding the last 50k
# transitions, and a soft target-network update (target_model_update < 1 blends
# 1% of the online weights into the target network each step).
policy = EpsGreedyQPolicy()
memory = SequentialMemory(limit=50000, window_length=1)
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,
               target_model_update=1e-2, policy=policy)
# `lr` is the Keras 2.2 argument name (Keras 2.3+ renamed it `learning_rate`).
dqn.compile(Adam(lr=1e-3), metrics=['mae'])

# Train. Rendering (visualize=True) is only for show and throttles training to
# roughly the render frame rate (~60 steps/s in the recorded log); set it to
# False for much faster training.
# Keep the returned History so per-episode statistics can be inspected or plotted
# later instead of being discarded.
train_history = dqn.fit(env, nb_steps=5000, visualize=True, verbose=2)

Training for 5000 steps ...
    9/5000: episode: 1, duration: 1.495s, episode steps: 9, steps per second: 6, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.162 [-2.882, 1.749], loss: --, mean_absolute_error: --, mean_q: --
Instructions for updating:
Use tf.cast instead.




   20/5000: episode: 2, duration: 0.543s, episode steps: 11, steps per second: 20, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.111 [-3.316, 2.196], loss: 0.583898, mean_absolute_error: 0.640469, mean_q: 0.659885
   29/5000: episode: 3, duration: 0.200s, episode steps: 9, steps per second: 45, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.889 [0.000, 1.000], mean observation: -0.111 [-2.226, 1.423], loss: 0.471900, mean_absolute_error: 0.607929, mean_q: 0.846932




   39/5000: episode: 4, duration: 0.167s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.900 [0.000, 1.000], mean observation: -0.152 [-2.770, 1.738], loss: 0.418853, mean_absolute_error: 0.587467, mean_q: 0.919964




   47/5000: episode: 5, duration: 0.133s, episode steps: 8, steps per second: 60, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.138 [-2.563, 1.604], loss: 0.455315, mean_absolute_error: 0.644158, mean_q: 1.144103
   56/5000: episode: 6, duration: 0.154s, episode steps: 9, steps per second: 59, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.124 [-2.762, 1.768], loss: 0.393256, mean_absolute_error: 0.609623, mean_q: 1.247866
   65/5000: episode: 7, duration: 0.164s, episode steps: 9, steps per second: 55, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.139 [-2.798, 1.804], loss: 0.415989, mean_absolute_error: 0.624854, mean_q: 1.288174
   74/5000: episode: 8, duration: 0.148s, episode steps: 9, steps per second: 61, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.00

  326/5000: episode: 34, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.138 [-3.067, 2.005], loss: 0.735744, mean_absolute_error: 1.057123, mean_q: 2.766982
  337/5000: episode: 35, duration: 0.183s, episode steps: 11, steps per second: 60, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.909 [0.000, 1.000], mean observation: -0.134 [-2.740, 1.737], loss: 0.703948, mean_absolute_error: 1.018812, mean_q: 2.842488
  346/5000: episode: 36, duration: 0.149s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.159 [-2.825, 1.740], loss: 0.645382, mean_absolute_error: 1.067219, mean_q: 2.903301
  355/5000: episode: 37, duration: 0.150s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean acti

  613/5000: episode: 63, duration: 0.149s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.889 [0.000, 1.000], mean observation: -0.143 [-2.358, 1.416], loss: 0.698736, mean_absolute_error: 1.827471, mean_q: 4.018883
  622/5000: episode: 64, duration: 0.150s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.889 [0.000, 1.000], mean observation: -0.143 [-2.482, 1.607], loss: 0.695523, mean_absolute_error: 1.841189, mean_q: 4.044074
  630/5000: episode: 65, duration: 0.133s, episode steps: 8, steps per second: 60, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.156 [-2.523, 1.552], loss: 0.686945, mean_absolute_error: 1.832919, mean_q: 4.084052
  640/5000: episode: 66, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action

  913/5000: episode: 93, duration: 0.163s, episode steps: 10, steps per second: 61, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.900 [0.000, 1.000], mean observation: -0.139 [-2.584, 1.566], loss: 0.587097, mean_absolute_error: 2.232606, mean_q: 5.033912
  922/5000: episode: 94, duration: 0.148s, episode steps: 9, steps per second: 61, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.889 [0.000, 1.000], mean observation: -0.147 [-2.460, 1.543], loss: 0.586012, mean_absolute_error: 2.240338, mean_q: 4.916302
  933/5000: episode: 95, duration: 0.187s, episode steps: 11, steps per second: 59, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.818 [0.000, 1.000], mean observation: -0.104 [-2.586, 1.792], loss: 0.725483, mean_absolute_error: 2.330693, mean_q: 4.930007
  943/5000: episode: 96, duration: 0.164s, episode steps: 10, steps per second: 61, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean ac

 1193/5000: episode: 122, duration: 0.168s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.800 [0.000, 1.000], mean observation: -0.111 [-2.358, 1.549], loss: 0.547462, mean_absolute_error: 2.622486, mean_q: 5.223836
 1201/5000: episode: 123, duration: 0.131s, episode steps: 8, steps per second: 61, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 1.000 [1.000, 1.000], mean observation: -0.163 [-2.528, 1.529], loss: 0.575894, mean_absolute_error: 2.691753, mean_q: 5.361913
 1212/5000: episode: 124, duration: 0.186s, episode steps: 11, steps per second: 59, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.818 [0.000, 1.000], mean observation: -0.129 [-2.622, 1.757], loss: 0.714374, mean_absolute_error: 2.717479, mean_q: 5.299494
 1221/5000: episode: 125, duration: 0.146s, episode steps: 9, steps per second: 62, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean 

 1498/5000: episode: 152, duration: 0.218s, episode steps: 13, steps per second: 60, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.692 [0.000, 1.000], mean observation: -0.121 [-1.823, 1.127], loss: 0.460617, mean_absolute_error: 3.026842, mean_q: 5.570987
 1508/5000: episode: 153, duration: 0.163s, episode steps: 10, steps per second: 61, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.700 [0.000, 1.000], mean observation: -0.115 [-1.830, 1.181], loss: 0.534764, mean_absolute_error: 3.110224, mean_q: 5.641264
 1520/5000: episode: 154, duration: 0.202s, episode steps: 12, steps per second: 60, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.583 [0.000, 1.000], mean observation: -0.110 [-1.846, 1.150], loss: 0.551166, mean_absolute_error: 3.052565, mean_q: 5.493619
 1532/5000: episode: 155, duration: 0.198s, episode steps: 12, steps per second: 61, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], m

 1840/5000: episode: 180, duration: 0.181s, episode steps: 11, steps per second: 61, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.000 [0.000, 0.000], mean observation: 0.137 [-2.131, 3.256], loss: 2.584592, mean_absolute_error: 4.045169, mean_q: 6.979732
 1853/5000: episode: 181, duration: 0.219s, episode steps: 13, steps per second: 59, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.231 [0.000, 1.000], mean observation: 0.108 [-1.538, 2.459], loss: 1.705484, mean_absolute_error: 3.749029, mean_q: 6.553955
 1863/5000: episode: 182, duration: 0.167s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.300 [0.000, 1.000], mean observation: 0.106 [-1.402, 2.079], loss: 1.447718, mean_absolute_error: 3.761361, mean_q: 6.691996
 1876/5000: episode: 183, duration: 0.214s, episode steps: 13, steps per second: 61, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean

 2129/5000: episode: 208, duration: 0.184s, episode steps: 11, steps per second: 60, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.273 [0.000, 1.000], mean observation: 0.122 [-1.178, 1.804], loss: 2.636891, mean_absolute_error: 4.345182, mean_q: 7.435962
 2144/5000: episode: 209, duration: 0.248s, episode steps: 15, steps per second: 60, episode reward: 15.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.333 [0.000, 1.000], mean observation: 0.097 [-1.127, 1.761], loss: 2.595536, mean_absolute_error: 4.461460, mean_q: 7.670774
 2153/5000: episode: 210, duration: 0.150s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.333 [0.000, 1.000], mean observation: 0.114 [-1.028, 1.595], loss: 3.124922, mean_absolute_error: 4.428094, mean_q: 7.572355
 2164/5000: episode: 211, duration: 0.182s, episode steps: 11, steps per second: 60, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean a

 2433/5000: episode: 237, duration: 0.151s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.111 [0.000, 1.000], mean observation: 0.138 [-1.346, 2.255], loss: 2.821190, mean_absolute_error: 4.655288, mean_q: 7.974102
 2441/5000: episode: 238, duration: 0.132s, episode steps: 8, steps per second: 61, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.000 [0.000, 0.000], mean observation: 0.135 [-1.612, 2.564], loss: 2.457121, mean_absolute_error: 4.761731, mean_q: 8.267947
 2450/5000: episode: 239, duration: 0.149s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.000 [0.000, 0.000], mean observation: 0.134 [-1.792, 2.771], loss: 2.469788, mean_absolute_error: 4.654816, mean_q: 8.069149
 2460/5000: episode: 240, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean actio

 2720/5000: episode: 266, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.100 [0.000, 1.000], mean observation: 0.128 [-1.563, 2.495], loss: 2.900705, mean_absolute_error: 5.003161, mean_q: 8.594482
 2730/5000: episode: 267, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.000 [0.000, 0.000], mean observation: 0.110 [-1.969, 2.990], loss: 3.246304, mean_absolute_error: 5.035436, mean_q: 8.610533
 2739/5000: episode: 268, duration: 0.150s, episode steps: 9, steps per second: 60, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.000 [0.000, 0.000], mean observation: 0.143 [-1.769, 2.833], loss: 3.005422, mean_absolute_error: 4.917148, mean_q: 8.451508
 2749/5000: episode: 269, duration: 0.165s, episode steps: 10, steps per second: 61, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean a

 3012/5000: episode: 295, duration: 0.234s, episode steps: 14, steps per second: 60, episode reward: 14.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.429 [0.000, 1.000], mean observation: 0.081 [-1.023, 1.457], loss: 3.111619, mean_absolute_error: 4.998004, mean_q: 8.424567
 3027/5000: episode: 296, duration: 0.252s, episode steps: 15, steps per second: 60, episode reward: 15.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.400 [0.000, 1.000], mean observation: 0.083 [-0.799, 1.327], loss: 2.879693, mean_absolute_error: 4.878584, mean_q: 8.192976
 3036/5000: episode: 297, duration: 0.146s, episode steps: 9, steps per second: 62, episode reward: 9.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.333 [0.000, 1.000], mean observation: 0.123 [-0.797, 1.435], loss: 2.765136, mean_absolute_error: 4.941566, mean_q: 8.411473
 3047/5000: episode: 298, duration: 0.183s, episode steps: 11, steps per second: 60, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean a

 3317/5000: episode: 325, duration: 0.132s, episode steps: 8, steps per second: 60, episode reward: 8.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.125 [0.000, 1.000], mean observation: 0.143 [-1.405, 2.206], loss: 2.413092, mean_absolute_error: 4.678598, mean_q: 7.973662
 3330/5000: episode: 326, duration: 0.216s, episode steps: 13, steps per second: 60, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.385 [0.000, 1.000], mean observation: 0.108 [-1.122, 1.757], loss: 2.092863, mean_absolute_error: 4.794716, mean_q: 8.233026
 3340/5000: episode: 327, duration: 0.167s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.300 [0.000, 1.000], mean observation: 0.145 [-0.959, 1.703], loss: 2.652409, mean_absolute_error: 4.745459, mean_q: 7.933897
 3354/5000: episode: 328, duration: 0.233s, episode steps: 14, steps per second: 60, episode reward: 14.000, mean reward: 1.000 [1.000, 1.000], mean a

 3637/5000: episode: 354, duration: 0.199s, episode steps: 12, steps per second: 60, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.417 [0.000, 1.000], mean observation: 0.114 [-0.740, 1.263], loss: 2.099151, mean_absolute_error: 4.555119, mean_q: 7.753989
 3650/5000: episode: 355, duration: 0.217s, episode steps: 13, steps per second: 60, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.385 [0.000, 1.000], mean observation: 0.085 [-0.998, 1.441], loss: 2.280572, mean_absolute_error: 4.712335, mean_q: 8.016583
 3664/5000: episode: 356, duration: 0.234s, episode steps: 14, steps per second: 60, episode reward: 14.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.357 [0.000, 1.000], mean observation: 0.073 [-0.960, 1.489], loss: 1.921081, mean_absolute_error: 4.693148, mean_q: 8.107881
 3674/5000: episode: 357, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean

 3977/5000: episode: 383, duration: 0.282s, episode steps: 17, steps per second: 60, episode reward: 17.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.471 [0.000, 1.000], mean observation: 0.059 [-0.980, 1.357], loss: 2.265724, mean_absolute_error: 4.618610, mean_q: 7.787598
 3990/5000: episode: 384, duration: 0.217s, episode steps: 13, steps per second: 60, episode reward: 13.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.462 [0.000, 1.000], mean observation: 0.089 [-0.831, 1.220], loss: 2.099079, mean_absolute_error: 4.671297, mean_q: 8.013646
 4009/5000: episode: 385, duration: 0.316s, episode steps: 19, steps per second: 60, episode reward: 19.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.474 [0.000, 1.000], mean observation: 0.068 [-0.780, 1.161], loss: 2.486734, mean_absolute_error: 4.690415, mean_q: 7.916396
 4021/5000: episode: 386, duration: 0.199s, episode steps: 12, steps per second: 60, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean

 4592/5000: episode: 412, duration: 0.200s, episode steps: 12, steps per second: 60, episode reward: 12.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.417 [0.000, 1.000], mean observation: 0.122 [-0.764, 1.415], loss: 2.491603, mean_absolute_error: 4.755912, mean_q: 8.060043
 4602/5000: episode: 413, duration: 0.166s, episode steps: 10, steps per second: 60, episode reward: 10.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.400 [0.000, 1.000], mean observation: 0.128 [-0.954, 1.579], loss: 1.840471, mean_absolute_error: 4.822448, mean_q: 8.403212
 4613/5000: episode: 414, duration: 0.183s, episode steps: 11, steps per second: 60, episode reward: 11.000, mean reward: 1.000 [1.000, 1.000], mean action: 0.455 [0.000, 1.000], mean observation: 0.132 [-0.740, 1.247], loss: 2.209652, mean_absolute_error: 4.782045, mean_q: 8.209985
 4630/5000: episode: 415, duration: 0.282s, episode steps: 17, steps per second: 60, episode reward: 17.000, mean reward: 1.000 [1.000, 1.000], mean

<keras.callbacks.History at 0x17761d7fba8>

In [5]:
# Evaluate the trained agent over 5 rendered episodes, keeping the History of
# per-episode rewards, then close the environment to release the render window
# that visualize=True opened.
test_history = dqn.test(env, nb_episodes=5, visualize=True)
env.close()

Testing for 5 episodes ...
Episode 1: reward: 20.000, steps: 20
Episode 2: reward: 18.000, steps: 18
Episode 3: reward: 15.000, steps: 15
Episode 4: reward: 17.000, steps: 17
Episode 5: reward: 17.000, steps: 17


<keras.callbacks.History at 0x17667c14e80>