In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import gym
import matplotlib.pyplot as plt

In [2]:
# 相关超参数
BATCH_SIZE = 32
LR = 0.01
EPSILON = 0.9
GAMMA = 0.9
TARGET_REPLACE_ITER = 100
MEMORY_CAPACITY = 2000
env = gym.make('CartPole-v0').unwrapped
N_ACTIONS = env.action_space.n
N_STATES = env.observation_space.shape[0]

In [3]:
# 定义Net
class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        
        self.fc1 = tf.keras.layers.Dense(units=50,
                                        kernel_initializer=tf.keras.initializers.random_normal(mean=0.0,stddev=0.01),
                                        input_shape=(N_STATES,), activation='relu')
        self.fc2 = tf.keras.layers.Dense(units=N_ACTIONS,
                                        kernel_initializer=tf.keras.initializers.random_normal(mean=0.0,stddev=0.01))
    def call(self, state):
        y = self.fc1(state)
        actions_value = self.fc2(y)
        return actions_value

In [6]:
# 定义DQN类
class Agent_DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()
        self.step = 0
        self.memory_counter = 0
        self.memory = np.zeros(shape=(MEMORY_CAPACITY, N_STATES * 2 + 2))
        self.optimizer = tf.keras.optimizers.Adam((LR))
        self.loss = tf.keras.losses.MSE
        self.eval_net.compile(optimizer=self.optimizer,
                         loss=self.loss)
        self.target_net.compile(optimizer=self.optimizer,
                         loss=self.loss)
        
        self.target_net(tf.ones((1,N_STATES)))
        self.eval_net(tf.ones((1,N_STATES)))
        self.eval_net.set_weights(self.target_net.get_weights())
       
        
    
    def choice_action(self, state):
        '''
        state: array([1,2,3,4])
        '''
        # 如果小于epsilon则选择最优的策略
        state = np.expand_dims(state, axis=0)
        if np.random.rand() < EPSILON:
            return np.argmax(np.array(self.eval_net(state)))
        else:
            return np.random.randint(2)
    
    def store_transition(self, state, action, reward, next_state):
        transition = np.hstack([state, action, reward, next_state])
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index] = transition
        self.memory_counter += 1
    
    def learn(self):
        if self.step % 100 == 0:
            self.target_net.set_weights(self.eval_net.get_weights())
        self.step += 1
        # 抽取数据
        index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[index, : ]
        b_state = b_memory[:, :N_STATES]
        b_action = b_memory[:, N_STATES: N_STATES+1]
        b_reward = b_memory[:, N_STATES+1: N_STATES+2]
        b_next_state = b_memory[:, N_STATES+2:]
        
        q_target = self.target_net(b_next_state)
        U = tf.constant(b_reward, dtype=tf.float32) + GAMMA * tf.reduce_max(q_target, axis=1, keepdims=True)
        self.eval_net.fit(b_state, U)

In [None]:
Agent = Agent_DQN()
episode_rewards = []
for i in range(500):
    if i % 10 == 0:
        print('<<<<<<<<<Episode: %s' % i)
    # 重置环境
    state = env.reset()                                                     
    # 初始化该循环对应的episode的总奖励
    episode_reward_sum = 0                                              

    while True:
        #env.render()
        action = Agent.choice_action(state)
        next_state, reward, done, _ = env.step(action)
        
        # 修改奖励 (不修改也可以，修改奖励只是为了更快地得到训练好的摆杆)
        x, x_dot, theta, theta_dot = next_state
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        new_reward = r1 + r2
        # 存储样本
        Agent.store_transition(state, action, reward, next_state)
        # 逐步加上一个episode内每个step的reward
        episode_reward_sum += reward
        # 更新状态
        state = next_state                                                
        # 如果累计的transition数量超过了记忆库的固定容量2000
        if Agent.memory_counter > 2000:
            # 开始学习 (抽取记忆，即32个transition，并对评估网络参数进行更新，并在开始学习后每隔100次将评估网络的参数赋给目标网络)
            Agent.learn()
        if done:       # 如果done为True
            # round()方法返回episode_reward_sum的小数点四舍五入到2个数字
            print('episode%s---reward_sum: %s' % (i, round(episode_reward_sum, 2)))
            episode_rewards.append(episode_reward_sum)
            break                                             # 该episode结束

<<<<<<<<<Episode: 0
episode0---reward_sum: 9.0
episode1---reward_sum: 8.0
episode2---reward_sum: 10.0
episode3---reward_sum: 9.0
episode4---reward_sum: 8.0
episode5---reward_sum: 10.0
episode6---reward_sum: 8.0
episode7---reward_sum: 9.0
episode8---reward_sum: 9.0
episode9---reward_sum: 9.0
<<<<<<<<<Episode: 10
episode10---reward_sum: 9.0
episode11---reward_sum: 10.0
episode12---reward_sum: 18.0
episode13---reward_sum: 16.0
episode14---reward_sum: 8.0
episode15---reward_sum: 11.0
episode16---reward_sum: 9.0
episode17---reward_sum: 10.0
episode18---reward_sum: 9.0
episode19---reward_sum: 12.0
<<<<<<<<<Episode: 20
episode20---reward_sum: 9.0
episode21---reward_sum: 8.0
episode22---reward_sum: 9.0
episode23---reward_sum: 9.0
episode24---reward_sum: 8.0
episode25---reward_sum: 10.0
episode26---reward_sum: 13.0
episode27---reward_sum: 9.0
episode28---reward_sum: 9.0
episode29---reward_sum: 8.0
<<<<<<<<<Episode: 30
episode30---reward_sum: 9.0
episode31---reward_sum: 9.0
episode32---reward_su

episode190---reward_sum: 37.0
episode191---reward_sum: 27.0
episode192---reward_sum: 10.0
episode193---reward_sum: 11.0
episode194---reward_sum: 11.0
episode195---reward_sum: 15.0
episode196---reward_sum: 14.0
episode197---reward_sum: 14.0


episode198---reward_sum: 10.0
episode199---reward_sum: 16.0
<<<<<<<<<Episode: 200
episode200---reward_sum: 14.0
episode201---reward_sum: 12.0
episode202---reward_sum: 9.0
episode203---reward_sum: 9.0
episode204---reward_sum: 10.0
episode205---reward_sum: 11.0
episode206---reward_sum: 14.0


episode207---reward_sum: 15.0
episode208---reward_sum: 16.0
episode209---reward_sum: 16.0
<<<<<<<<<Episode: 210
episode210---reward_sum: 10.0
episode211---reward_sum: 13.0
episode212---reward_sum: 12.0
episode213---reward_sum: 11.0
episode214---reward_sum: 13.0
episode215---reward_sum: 12.0


episode216---reward_sum: 12.0
episode217---reward_sum: 9.0
episode218---reward_sum: 10.0
episode219---reward_sum: 10.0
<<<<<<<<<Episode: 220
episode220---reward_sum: 13.0
episode221---reward_sum: 16.0
episode222---reward_sum: 18.0
episode223---reward_sum: 14.0
episode224---reward_sum: 13.0


episode225---reward_sum: 15.0
episode226---reward_sum: 10.0
episode227---reward_sum: 13.0
episode228---reward_sum: 28.0
episode229---reward_sum: 26.0
<<<<<<<<<Episode: 230
episode230---reward_sum: 14.0
episode231---reward_sum: 14.0


episode232---reward_sum: 14.0
episode233---reward_sum: 19.0
episode234---reward_sum: 14.0
episode235---reward_sum: 16.0
episode236---reward_sum: 13.0
episode237---reward_sum: 17.0
episode238---reward_sum: 16.0
episode239---reward_sum: 15.0
<<<<<<<<<Episode: 240


episode240---reward_sum: 20.0
episode241---reward_sum: 13.0
episode242---reward_sum: 20.0
episode243---reward_sum: 20.0
episode244---reward_sum: 12.0
episode245---reward_sum: 14.0
episode246---reward_sum: 16.0


episode247---reward_sum: 15.0
episode248---reward_sum: 17.0
episode249---reward_sum: 21.0
<<<<<<<<<Episode: 250
episode250---reward_sum: 17.0
episode251---reward_sum: 15.0
episode252---reward_sum: 22.0


episode253---reward_sum: 20.0
episode254---reward_sum: 19.0
episode255---reward_sum: 20.0
episode256---reward_sum: 26.0
episode257---reward_sum: 23.0
episode258---reward_sum: 21.0


episode259---reward_sum: 15.0
<<<<<<<<<Episode: 260
episode260---reward_sum: 14.0
episode261---reward_sum: 15.0
episode262---reward_sum: 22.0
episode263---reward_sum: 15.0
episode264---reward_sum: 22.0


episode265---reward_sum: 17.0
episode266---reward_sum: 18.0
episode267---reward_sum: 22.0
episode268---reward_sum: 18.0
episode269---reward_sum: 14.0
<<<<<<<<<Episode: 270
episode270---reward_sum: 13.0
episode271---reward_sum: 20.0


episode272---reward_sum: 23.0
episode273---reward_sum: 18.0
episode274---reward_sum: 21.0
episode275---reward_sum: 21.0
episode276---reward_sum: 30.0
episode277---reward_sum: 17.0


episode278---reward_sum: 26.0
episode279---reward_sum: 18.0
<<<<<<<<<Episode: 280
episode280---reward_sum: 15.0
episode281---reward_sum: 27.0
episode282---reward_sum: 22.0


episode283---reward_sum: 26.0
episode284---reward_sum: 31.0
episode285---reward_sum: 16.0
episode286---reward_sum: 15.0
episode287---reward_sum: 25.0
episode288---reward_sum: 16.0


episode289---reward_sum: 15.0
<<<<<<<<<Episode: 290
episode290---reward_sum: 16.0
episode291---reward_sum: 13.0
episode292---reward_sum: 8.0
episode293---reward_sum: 14.0
episode294---reward_sum: 19.0
episode295---reward_sum: 18.0
episode296---reward_sum: 17.0


episode297---reward_sum: 9.0
episode298---reward_sum: 14.0
episode299---reward_sum: 15.0
<<<<<<<<<Episode: 300
episode300---reward_sum: 22.0
episode301---reward_sum: 9.0
episode302---reward_sum: 14.0
episode303---reward_sum: 10.0
episode304---reward_sum: 14.0
episode305---reward_sum: 11.0


episode306---reward_sum: 9.0
episode307---reward_sum: 16.0
episode308---reward_sum: 10.0
episode309---reward_sum: 9.0
<<<<<<<<<Episode: 310
episode310---reward_sum: 14.0
episode311---reward_sum: 9.0
episode312---reward_sum: 18.0
episode313---reward_sum: 17.0
episode314---reward_sum: 10.0


episode315---reward_sum: 15.0
episode316---reward_sum: 9.0
episode317---reward_sum: 21.0
episode318---reward_sum: 9.0
episode319---reward_sum: 8.0
<<<<<<<<<Episode: 320
episode320---reward_sum: 21.0
episode321---reward_sum: 21.0
episode322---reward_sum: 10.0
episode323---reward_sum: 10.0


episode324---reward_sum: 20.0
episode325---reward_sum: 33.0
episode326---reward_sum: 9.0
episode327---reward_sum: 12.0
episode328---reward_sum: 11.0
episode329---reward_sum: 14.0
<<<<<<<<<Episode: 330
episode330---reward_sum: 14.0


episode331---reward_sum: 12.0
episode332---reward_sum: 11.0
episode333---reward_sum: 23.0
episode334---reward_sum: 21.0
episode335---reward_sum: 12.0
episode336---reward_sum: 11.0
episode337---reward_sum: 14.0
episode338---reward_sum: 9.0


episode339---reward_sum: 10.0
<<<<<<<<<Episode: 340
episode340---reward_sum: 10.0
episode341---reward_sum: 20.0
episode342---reward_sum: 39.0
episode343---reward_sum: 15.0
episode344---reward_sum: 11.0


episode345---reward_sum: 29.0
episode346---reward_sum: 20.0
episode347---reward_sum: 20.0
episode348---reward_sum: 10.0
episode349---reward_sum: 15.0
<<<<<<<<<Episode: 350
episode350---reward_sum: 10.0
episode351---reward_sum: 10.0
episode352---reward_sum: 12.0


episode353---reward_sum: 11.0
episode354---reward_sum: 9.0
episode355---reward_sum: 21.0
episode356---reward_sum: 15.0
episode357---reward_sum: 31.0
episode358---reward_sum: 22.0


episode359---reward_sum: 12.0
<<<<<<<<<Episode: 360
episode360---reward_sum: 15.0
episode361---reward_sum: 23.0
episode362---reward_sum: 15.0
episode363---reward_sum: 48.0


episode364---reward_sum: 14.0
episode365---reward_sum: 18.0
episode366---reward_sum: 13.0
episode367---reward_sum: 13.0
episode368---reward_sum: 13.0
episode369---reward_sum: 35.0
<<<<<<<<<Episode: 370


episode370---reward_sum: 18.0
episode371---reward_sum: 37.0
episode372---reward_sum: 12.0
episode373---reward_sum: 32.0
episode374---reward_sum: 16.0


episode375---reward_sum: 19.0
episode376---reward_sum: 23.0
episode377---reward_sum: 30.0
episode378---reward_sum: 20.0
episode379---reward_sum: 30.0
<<<<<<<<<Episode: 380


episode380---reward_sum: 17.0
episode381---reward_sum: 27.0
episode382---reward_sum: 34.0
episode383---reward_sum: 34.0


episode384---reward_sum: 16.0
episode385---reward_sum: 18.0
episode386---reward_sum: 42.0
episode387---reward_sum: 25.0


episode388---reward_sum: 28.0
episode389---reward_sum: 41.0
<<<<<<<<<Episode: 390
episode390---reward_sum: 16.0
episode391---reward_sum: 15.0


episode392---reward_sum: 40.0
episode393---reward_sum: 18.0
episode394---reward_sum: 17.0
episode395---reward_sum: 16.0
episode396---reward_sum: 40.0
episode397---reward_sum: 22.0


episode398---reward_sum: 16.0
episode399---reward_sum: 61.0
<<<<<<<<<Episode: 400


episode400---reward_sum: 41.0
episode401---reward_sum: 19.0
episode402---reward_sum: 11.0
episode403---reward_sum: 37.0
episode404---reward_sum: 25.0
episode405---reward_sum: 15.0


episode406---reward_sum: 21.0
episode407---reward_sum: 34.0
episode408---reward_sum: 30.0
episode409---reward_sum: 22.0
<<<<<<<<<Episode: 410


episode410---reward_sum: 16.0
episode411---reward_sum: 13.0
episode412---reward_sum: 37.0
episode413---reward_sum: 24.0
episode414---reward_sum: 10.0
episode415---reward_sum: 31.0


episode416---reward_sum: 19.0
episode417---reward_sum: 53.0
episode418---reward_sum: 40.0


episode419---reward_sum: 35.0
<<<<<<<<<Episode: 420
episode420---reward_sum: 37.0
episode421---reward_sum: 21.0


episode422---reward_sum: 69.0
episode423---reward_sum: 26.0
episode424---reward_sum: 51.0


episode425---reward_sum: 29.0
episode426---reward_sum: 14.0
episode427---reward_sum: 38.0
episode428---reward_sum: 19.0


episode429---reward_sum: 33.0
<<<<<<<<<Episode: 430
episode430---reward_sum: 39.0
episode431---reward_sum: 31.0


episode432---reward_sum: 44.0
episode433---reward_sum: 28.0
episode434---reward_sum: 20.0
episode435---reward_sum: 44.0


episode436---reward_sum: 31.0
episode437---reward_sum: 45.0
episode438---reward_sum: 17.0
episode439---reward_sum: 14.0
<<<<<<<<<Episode: 440
episode440---reward_sum: 25.0


episode441---reward_sum: 23.0
episode442---reward_sum: 21.0
episode443---reward_sum: 37.0
episode444---reward_sum: 29.0


episode445---reward_sum: 32.0
episode446---reward_sum: 31.0
episode447---reward_sum: 42.0


episode448---reward_sum: 24.0
episode449---reward_sum: 32.0
<<<<<<<<<Episode: 450
episode450---reward_sum: 14.0
episode451---reward_sum: 24.0
episode452---reward_sum: 32.0


episode453---reward_sum: 36.0
episode454---reward_sum: 41.0
episode455---reward_sum: 43.0


episode456---reward_sum: 20.0
episode457---reward_sum: 20.0
episode458---reward_sum: 17.0
episode459---reward_sum: 22.0
<<<<<<<<<Episode: 460
episode460---reward_sum: 19.0
episode461---reward_sum: 18.0


episode462---reward_sum: 13.0
episode463---reward_sum: 35.0
episode464---reward_sum: 31.0
episode465---reward_sum: 16.0


episode466---reward_sum: 39.0
episode467---reward_sum: 26.0
episode468---reward_sum: 13.0


In [None]:
fig, axes = plt.subplots(1,1,figsize=(8,6))
axes.plot(episode_rewards)
axes.set_title('episode_rewards')