Today, we look at deep reinforcement learning and implement a deep Q-learning agent. As an exercise, you can try implementing an actor-critic agent.

# Deep Q-learning

Implementing a deep Q-learning agent is quite straightforward. We define a neural network model that predicts the values of the Q function, store the agent's experiences in a replay buffer, and use random samples from the buffer to fit the network weights after each episode. For the fitting, we can use the standard Keras `fit` method with mean squared error loss. We only need to compute the target values correctly, according to the Bellman equation: the target for the action that was taken is `r` if the episode ended with that step and `r + gamma * max Q(s', a')` otherwise, while the targets for the other actions are left equal to the current predictions.
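The target computation mentioned above can be tried in isolation, without a network. The following is a minimal sketch, assuming the Q-values for the states and next states are already available as NumPy arrays; the function name `q_targets` and its argument layout are hypothetical and used only for illustration.

```python
import numpy as np

def q_targets(pred, next_pred, actions, rewards, dones, gamma=0.95):
    """Return a copy of `pred` with Bellman targets written into the taken actions.

    pred, next_pred: float arrays of shape (batch, num_actions) holding
                     Q(s, .) and Q(s', .)
    actions: int array (batch,), rewards: float array (batch,),
    dones: bool array (batch,)
    """
    targets = pred.copy()
    # bootstrap from the best next action, except on terminal transitions
    bootstrap = gamma * next_pred.max(axis=1) * (~dones)
    targets[np.arange(len(actions)), actions] = rewards + bootstrap
    return targets

pred = np.array([[1.0, 2.0], [0.5, 0.0]])
next_pred = np.array([[3.0, 4.0], [9.0, 9.0]])
targets = q_targets(pred, next_pred,
                    actions=np.array([0, 1]),
                    rewards=np.array([1.0, 1.0]),
                    dones=np.array([False, True]))
print(targets)
```

Only the entry of the taken action changes in each row. The remaining entries equal the predictions, so they contribute zero error to the mean squared error loss.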

In [3]:
# loosely based on https://keon.io/deep-q-learning/

import random

import tensorflow as tf
import numpy as np
import gym

class DQNAgent:
    
    def __init__(self, num_inputs, num_outputs, batch_size = 32, num_batches = 16):
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.batch_size = batch_size
        self.num_batches = num_batches
        self.eps = 1.0
        self.eps_decay = 0.995
        self.gamma = 0.95
        self.exp_buffer = []
        self.build_model()
    
    # builds the Q-network model
    def build_model(self):
        # two hidden layers with ReLU activations, one linear output per action
        self.model = tf.keras.models.Sequential([
            tf.keras.layers.Dense(24, input_dim=self.num_inputs, activation='relu'),
            tf.keras.layers.Dense(24, activation='relu'),
            tf.keras.layers.Dense(self.num_outputs, activation='linear')])
        opt = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.model.compile(optimizer=opt, loss='mse', run_eagerly=True)
        
    # returns agent's action - epsilon greedy when training and the best one otherwise
    def action(self, state, train=False):
        if train and np.random.uniform() < self.eps:
            return np.random.randint(self.num_outputs)
        else: 
            return np.argmax(self.model.predict(state, verbose=False)[0])
        
    # save experience to buffer
    def record_experience(self, exp):
        self.exp_buffer.append(exp)
        if len(self.exp_buffer) > 2000:
            self.exp_buffer = self.exp_buffer[-2000:]
    
    # train based on buffer
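    def train(self):
        # wait until the buffer holds more than one batch of experiences
        if len(self.exp_buffer) <= self.batch_size:
            return

        for _ in range(self.num_batches):
            # sample a random batch of (state, action, reward, next_state, done) tuples
            batch = random.sample(self.exp_buffer, self.batch_size)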
            states = np.array([s for (s, _, _, _, _) in batch])
            next_states = np.array([ns for (_, _, _, ns, _) in batch])
            states = states.reshape((-1, self.num_inputs))
            next_states = next_states.reshape((-1, self.num_inputs))
            pred = self.model.predict(states, verbose=False)
            next_pred = self.model.predict(next_states, verbose=False)
            # compute the target values
            for i, (s, a, r, ns, d) in enumerate(batch):
                pred[i][a] = r
                if not d:
                    pred[i][a] = r + self.gamma*np.amax(next_pred[i])

            self.model.fit(states, pred, epochs=1, verbose=False)
        # decrease epsilon for the epsilon-greedy strategy
        if self.eps > 0.01:
            self.eps = self.eps*self.eps_decay

# create agent (4 inputs, 2 actions)
agent = DQNAgent(4, 2)

env = gym.make("CartPole-v1")
print(env.action_space)

# train the network on 1000 runs of the environment
rewards = []
for i in range(1000):
    obs, _ = env.reset()
    obs = np.reshape(obs, (1, -1))
    terminated = False
    truncated = False
    R = 0
    t = 0
    while not (terminated or truncated):
        old_state = obs
        action = agent.action(obs, train=True)
        # step returns (observation, reward, terminated, truncated, info)
        obs, r, terminated, truncated, _ = env.step(action)
        R += r
        t += 1
        # the final step of an episode is stored with a fixed reward of 10
        r = r if not (terminated or truncated) else 10
        obs = np.reshape(obs, (1, -1))
        agent.record_experience((old_state, action, r, obs, terminated or truncated))
    agent.train()
    
    rewards.append(R)
    print(i, R)


Discrete(2)
0 15.0
1 11.0
2 12.0
3 25.0
4 14.0
5 12.0
6 16.0
7 11.0
8 21.0
9 18.0
10 14.0
11 18.0
12 26.0
13 34.0
14 38.0
15 14.0
16 28.0
17 22.0
18 44.0
19 15.0
20 15.0
21 25.0
22 19.0
23 14.0
24 13.0
25 13.0
26 32.0
27 12.0
28 15.0
29 41.0
30 11.0
31 16.0
32 15.0
33 17.0
34 15.0
35 21.0
36 17.0
37 10.0
38 10.0
39 17.0
40 31.0
41 27.0
42 11.0
43 20.0
44 10.0
45 20.0
46 26.0
47 19.0
48 19.0
49 14.0
50 24.0
51 10.0
52 9.0
53 21.0
54 16.0
55 23.0
56 13.0
57 34.0
58 44.0
59 10.0
60 9.0
61 48.0
62 17.0
63 27.0
64 29.0
65 22.0
66 42.0
67 30.0
68 31.0
69 69.0
70 17.0
71 19.0
72 30.0
73 39.0
74 16.0
75 19.0
76 26.0
77 49.0
78 59.0
79 34.0
80 19.0
81 19.0
82 57.0
83 83.0
84 46.0
85 11.0
86 40.0
87 44.0
88 34.0
89 19.0
90 41.0
91 31.0
92 51.0
93 24.0
94 44.0
95 57.0
96 24.0
97 56.0
98 28.0
99 39.0
100 25.0
101 13.0
102 15.0
103 30.0
104 31.0
105 36.0
106 29.0
107 41.0
108 15.0
109 43.0
110 38.0
111 64.0
112 33.0
113 15.0
114 41.0
115 34.0
116 52.0
117 15.0
118 39.0
119 59.0
120 30.0
121 44.0
12

KeyboardInterrupt: 

Let us test how well the agent solves the problem.

In [5]:
env = gym.make("CartPole-v1")
print(env.action_space)

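obs, _ = env.reset()
obs = np.reshape(obs, (1, -1))
# initialise both flags here so the loop does not depend on values left over from training
terminated = False
truncated = False
R = 0
t = 0
while not (terminated or truncated):
    action = agent.action(obs, train=False)
    # step returns (observation, reward, terminated, truncated, info)
    obs, r, terminated, truncated, _ = env.step(action)
    obs = np.reshape(obs, (1, -1))
    R += r
    t += 1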
        
print(R)

Discrete(2)
187.0
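A single episode gives a noisy estimate of the agent's quality, so it may help to average the total reward over several episodes. The following is a minimal sketch, assuming an environment with the same five-value `step` API as above; the helper `evaluate` and its parameters are hypothetical and not part of Gym.

```python
import numpy as np

def evaluate(env, policy, episodes=10):
    """Run `policy` for several episodes and return the mean total reward.

    `env.step` is assumed to return (obs, reward, terminated, truncated, info);
    `policy` maps a raw observation to an action.
    """
    totals = []
    for _ in range(episodes):
        obs, _ = env.reset()
        total, finished = 0.0, False
        while not finished:
            obs, r, terminated, truncated, _ = env.step(policy(obs))
            total += r
            finished = terminated or truncated
        totals.append(total)
    return float(np.mean(totals))

# Hypothetical usage with the agent trained above:
# evaluate(gym.make("CartPole-v1"),
#          lambda obs: agent.action(np.reshape(obs, (1, -1))))
```

Passing the policy as a function keeps the helper independent of the agent class, so the same code can compare the trained agent with, for example, a random policy.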


In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize=(12,8))
plt.plot(rewards)
plt.ylabel('Reward')
plt.xlabel('Episode')
plt.show()

# Exercise (not homework)

Choose one of the OpenAI Gym environments with continuous actions and try to solve it using the actor-critic approach.